オンラインクラスタリングの例
このオンラインクラスタリングの例では、Pub/Subからテキストを読み取り、言語モデルを使用してテキストを埋め込みに変換し、BIRCHを使用してテキストをクラスタリングできるリアルタイムクラスタリングパイプラインの設定方法を示します。
クラスタリング用データセット
この例では、6つの基本的な感情(怒り、恐怖、喜び、愛、悲しみ、驚き)を持つ20,000件の英語のTwitterメッセージを含むemotionというデータセットを使用します。このデータセットには、トレーニング、検証、テストの3つの分割があります。テキストとデータセットのカテゴリ(クラス)が含まれているため、これは教師ありデータセットです。このデータセットにアクセスするには、Hugging Faceデータセットページを使用してください。
以下のテキストは、データセットのトレーニング分割からの例を示しています。
テキスト | 感情の種類 |
---|---|
少し時間を取って投稿します。欲張りで間違っていると感じています。 | 怒り |
暖炉について懐かしく思うたびに、それがまだ敷地内にあることを知っています。 | 愛 |
推奨量の何倍ものミリグラムを服用していて、ずっと早く眠りに落ちますが、とてもおかしな気分にもなります。 | 恐怖 |
デンマークへのボート旅行で | 喜び |
SFの世界では、基本的に偽物のように感じています。 | 悲しみ |
週に数回、幻覚、動く人々や人物、音、振動に苦しめられるようになりました。 | 恐怖 |
クラスタリングアルゴリズム
ツイートのクラスタリングには、BIRCHと呼ばれる増分クラスタリングアルゴリズムを使用します。これは、階層を使用してバランスの取れた反復的な削減とクラスタリングを意味し、特に大規模なデータセットに対して階層的クラスタリングを実行するために使用される教師なしデータマイニングアルゴリズムです。BIRCHの利点の1つは、与えられたリソース(メモリと時間制約)に対して最高の品質のクラスタリングを生成しようとして、着信する多次元のメトリックデータポイントをインクリメンタルかつ動的にクラスタリングできることです。
Pub/Subへのインジェスト
この例では、Pub/Subからツイートをクラスタリングしながら読み取ることができるように、データをPub/Subにインジェストすることから始めます。Pub/Subは、アプリケーションやサービス間でイベントデータを交換するためのメッセージングサービスです。ストリーミング分析とデータ統合パイプラインは、Pub/Subを使用してデータのインジェストと配信を行います。
Pub/Subへのデータインジェストの完全なサンプルコードは、GitHubにあります。
インジェストパイプラインのファイル構造を以下の図に示します。
write_data_to_pubsub_pipeline/
├── pipeline/
│ ├── __init__.py
│ ├── options.py
│ └── utils.py
├── __init__.py
├── config.py
├── main.py
└── setup.py
pipeline/utils.py
には、感情データセットのロードとデータ変換に使用される2つのbeam.DoFn
に関するコードが含まれています。
pipeline/options.py
には、Dataflowパイプラインを設定するためのパイプラインオプションが含まれています。
config.py
は、GCP PROJECT_IDやNUM_WORKERSなど、複数回使用される変数を定義します。
setup.py
は、パイプラインを実行するためのパッケージと要件を定義します。
main.py
には、パイプラインコードと、パイプラインの実行に使用される追加の関数が含まれています。
パイプラインの実行
まず、必要なパッケージをインストールします。
- ローカルマシンで:
python main.py
- Dataflow向けのGCPで:
python main.py --mode cloud
write_data_to_pubsub_pipeline
には、4つの異なるトランスフォームが含まれています。
- Hugging Faceデータセットを使用して感情データセットを読み込みます(簡素化のため、6つのクラスではなく3つのクラスからサンプルを取得します)。
- 各テキストに一意の識別子(UID)を関連付けます。
- Pub/Subが期待する形式にテキストを変換します。
- フォーマットされたメッセージをPub/Subに書き込みます。
ストリーミングデータのクラスタリング
データをPub/Subにインジェストした後、ストリーミングメッセージをPub/Subから読み取り、言語モデルを使用してテキストを埋め込みに変換し、BIRCHを使用して埋め込みをクラスタリングする2番目のパイプラインを調べます。
前述のすべてのステップの完全なサンプルコードは、GitHubにあります。
clustering_pipelineのファイル構造は次のとおりです。
clustering_pipeline/
├── pipeline/
│ ├── __init__.py
│ ├── options.py
│ └── transformations.py
├── __init__.py
├── config.py
├── main.py
└── setup.py
pipeline/transformations.py
には、パイプラインで使用されるさまざまなbeam.DoFn
のコードが含まれています。
pipeline/options.py
には、Dataflowパイプラインを設定するためのパイプラインオプションが含まれています。
config.py
は、Google Cloud PROJECT_IDやNUM_WORKERSなど、複数回使用される変数を定義します。
setup.py
は、パイプラインを実行するためのパッケージと要件を定義します。
main.py
には、パイプラインコードと、パイプラインの実行に使用される追加の関数が含まれています。
パイプラインの実行
必要なパッケージをインストールし、データをPub/Subにプッシュします。
- ローカルマシンで:
python main.py
- Dataflow向けのGCPで:
python main.py --mode cloud
パイプラインは、次のステップに分割できます。
- Pub/Subからメッセージを読み取ります。
- Pub/Subメッセージを、キーがUIDで値がTwitterテキストであるディクショナリの
PCollection
に変換します。 - トカナイズを使用してテキストをトランスフォーマーで読み取れるトークンID整数にエンコードします。
- RunInferenceを使用して、トランスフォーマーベースの言語モデルからベクトル埋め込みを取得します。
- クラスタリングのために埋め込みを正規化します。
- 状態付き処理を使用してBIRCHクラスタリングを実行します。
- クラスタに割り当てられたテキストを出力します。
次のコードは、Pub/Subからのメッセージが読み取られてディクショナリに変換されるパイプラインの最初の2つのステップを示しています。
次のセクションでは、3つの重要なパイプラインステップを調べます。
- テキストのトークン化。
- トークン化されたテキストを入力して、トランスフォーマーベースの言語モデルから埋め込みを取得します。
- 状態付き処理を使用してクラスタリングを実行します。
言語モデルからの埋め込みの取得
テキストデータをクラスタリングするには、テキストを統計分析に適した数値のベクトルにマッピングする必要があります。この例では、sentence-transformers/stsb-distilbert-base/stsb-distilbert-baseと呼ばれるトランスフォーマーベースの言語モデルを使用します。これは、文と段落を768次元の密ベクトル空間にマッピングし、クラスタリングやセマンティック検索などのタスクに使用できます。
言語モデルは生のテキストではなくトークン化された入力を期待しているため、まずテキストをトークン化します。トークン化は、テキストをモデルに入力して予測を取得できるように変換する前処理タスクです。
ここでは、tokenize_sentence
はテキストとIDを含む辞書を受け取り、テキストをトークン化し、(テキスト、ID)とトークン化された出力を返す関数です。
トークン化された出力は、埋め込みを取得するために言語モデルに渡されます。言語モデルから埋め込みを取得するには、Apache BeamのRunInference()
を使用します。
より良いクラスタを作成するために、各Twitterテキストの埋め込みを取得した後、埋め込みを正規化します。
StatefulOnlineClustering
データはストリーミングされるため、BIRCHのような反復クラスタリングアルゴリズムを使用する必要があります。また、アルゴリズムは反復的であるため、Twitterテキストが到着したときに更新できるよう、以前の状態を保存するメカニズムが必要です。**ステートフル処理**により、DoFn
は永続的な状態を持つことができ、各要素の処理中に読み書きできます。ステートフル処理の詳細については、Apache Beamによるステートフル処理を参照してください。
この例では、Pub/Subから新しいメッセージが読み取られるたびに、クラスタリングモデルの既存の状態を取得し、更新し、状態に書き戻します。
BIRCHは並列化をサポートしていないため、ステートフル処理を実行しているワーカーが1つだけであることを確認する必要があります。そのため、Beam.Map
を使用して、各テキストに同じキー1
を関連付けます。
StatefulOnlineClustering
は、テキストの埋め込みを受け取り、クラスタリングモデルを更新するDoFn
です。状態を保存するために、ストレージのコンテナとして機能するReadModifyWriteStateSpec
状態オブジェクトを使用します。
class StatefulOnlineClustering(beam.DoFn):
BIRCH_MODEL_SPEC = ReadModifyWriteStateSpec("clustering_model", PickleCoder())
DATA_ITEMS_SPEC = ReadModifyWriteStateSpec("data_items", PickleCoder())
EMBEDDINGS_SPEC = ReadModifyWriteStateSpec("embeddings", PickleCoder())
UPDATE_COUNTER_SPEC = ReadModifyWriteStateSpec("update_counter", PickleCoder())
この例では、4つの異なるReadModifyWriteStateSpecオブジェクト
を宣言しています。
BIRCH_MODEL_SPEC
は、クラスタリングモデルの状態を保持します。DATA_ITEMS_SPEC
は、これまでに見られたTwitterテキストを保持します。EMBEDDINGS_SPEC
は、正規化された埋め込みを保持します。UPDATE_COUNTER_SPEC
は、処理されたテキストの数を保持します。
これらのReadModifyWriteStateSpec
オブジェクトは、process
関数に追加の引数として渡されます。新しいニュースアイテムが入力されると、異なるオブジェクトの既存の状態を取得し、更新してから、永続的な共有状態として書き戻します。
def process(
self,
element,
model_state=beam.DoFn.StateParam(BIRCH_MODEL_SPEC),
collected_docs_state=beam.DoFn.StateParam(DATA_ITEMS_SPEC),
collected_embeddings_state=beam.DoFn.StateParam(EMBEDDINGS_SPEC),
update_counter_state=beam.DoFn.StateParam(UPDATE_COUNTER_SPEC),
*args,
**kwargs,
):
"""
Takes the embedding of a document and updates the clustering model
Args:
element: The input element to be processed.
model_state: This is the state of the clustering model. It is a stateful parameter,
which means that it will be updated after each call to the process function.
collected_docs_state: This is a stateful dictionary that stores the documents that
have been processed so far.
collected_embeddings_state: This is a dictionary of document IDs and their embeddings.
update_counter_state: This is a counter that keeps track of how many documents have been
processed.
"""
# 1. Initialise or load states
clustering = model_state.read() or Birch(n_clusters=None, threshold=0.7)
collected_documents = collected_docs_state.read() or {}
collected_embeddings = collected_embeddings_state.read() or {}
update_counter = update_counter_state.read() or Counter()
# 2. Extract document, add to state, and add to clustering model
_, doc = element
doc_id = doc["id"]
embedding_vector = doc["embedding"]
collected_embeddings[doc_id] = embedding_vector
collected_documents[doc_id] = {"id": doc_id, "text": doc["text"]}
update_counter = len(collected_documents)
clustering.partial_fit(np.atleast_2d(embedding_vector))
# 3. Predict cluster labels of collected documents
cluster_labels = clustering.predict(
np.array(list(collected_embeddings.values())))
# 4. Write states
model_state.write(clustering)
collected_docs_state.write(collected_documents)
collected_embeddings_state.write(collected_embeddings)
update_counter_state.write(update_counter)
yield {
"labels": cluster_labels,
"docs": collected_documents,
"id": list(collected_embeddings.keys()),
"counter": update_counter,
}
GetUpdates
は、新しいメッセージが到着するたびに、各Twitterメッセージに割り当てられたクラスタを出力するDoFn
です。
最終更新日: 2024年10月31日
必要な情報は見つかりましたか?
すべて役に立ち、分かりやすかったですか?変更したい点があれば教えてください!