オンラインクラスタリングの例

このオンラインクラスタリングの例では、Pub/Subからテキストを読み取り、言語モデルを使用してテキストを埋め込みに変換し、BIRCHを使用してテキストをクラスタリングできるリアルタイムクラスタリングパイプラインの設定方法を示します。

クラスタリング用データセット

この例では、6つの基本的な感情(怒り、恐怖、喜び、愛、悲しみ、驚き)を持つ20,000件の英語のTwitterメッセージを含むemotionというデータセットを使用します。このデータセットには、トレーニング、検証、テストの3つの分割があります。テキストとデータセットのカテゴリ(クラス)が含まれているため、これは教師ありデータセットです。このデータセットにアクセスするには、Hugging Faceデータセットページを使用してください。

以下のテキストは、データセットのトレーニング分割からの例を示しています。

テキスト感情の種類
少し時間を取って投稿します。欲張りで間違っていると感じています。怒り
暖炉について懐かしく思うたびに、それがまだ敷地内にあることを知っています。
推奨量の何倍ものミリグラムを服用していて、ずっと早く眠りに落ちますが、とてもおかしな気分にもなります。恐怖
デンマークへのボート旅行で喜び
SFの世界では、基本的に偽物のように感じています。悲しみ
週に数回、幻覚、動く人々や人物、音、振動に苦しめられるようになりました。恐怖

クラスタリングアルゴリズム

ツイートのクラスタリングには、BIRCHと呼ばれる増分クラスタリングアルゴリズムを使用します。これは、階層を使用してバランスの取れた反復的な削減とクラスタリングを意味し、特に大規模なデータセットに対して階層的クラスタリングを実行するために使用される教師なしデータマイニングアルゴリズムです。BIRCHの利点の1つは、与えられたリソース(メモリと時間制約)に対して最高の品質のクラスタリングを生成しようとして、着信する多次元のメトリックデータポイントをインクリメンタルかつ動的にクラスタリングできることです。

Pub/Subへのインジェスト

この例では、Pub/Subからツイートをクラスタリングしながら読み取ることができるように、データをPub/Subにインジェストすることから始めます。Pub/Subは、アプリケーションやサービス間でイベントデータを交換するためのメッセージングサービスです。ストリーミング分析とデータ統合パイプラインは、Pub/Subを使用してデータのインジェストと配信を行います。

Pub/Subへのデータインジェストの完全なサンプルコードは、GitHubにあります。

インジェストパイプラインのファイル構造を以下の図に示します。

write_data_to_pubsub_pipeline/
├── pipeline/
│   ├── __init__.py
│   ├── options.py
│   └── utils.py
├── __init__.py
├── config.py
├── main.py
└── setup.py

pipeline/utils.pyには、感情データセットのロードとデータ変換に使用される2つのbeam.DoFnに関するコードが含まれています。

pipeline/options.pyには、Dataflowパイプラインを設定するためのパイプラインオプションが含まれています。

config.pyは、GCP PROJECT_IDやNUM_WORKERSなど、複数回使用される変数を定義します。

setup.pyは、パイプラインを実行するためのパッケージと要件を定義します。

main.pyには、パイプラインコードと、パイプラインの実行に使用される追加の関数が含まれています。

パイプラインの実行

まず、必要なパッケージをインストールします。

  1. ローカルマシンで:python main.py
  2. Dataflow向けのGCPで:python main.py --mode cloud

write_data_to_pubsub_pipelineには、4つの異なるトランスフォームが含まれています。

  1. Hugging Faceデータセットを使用して感情データセットを読み込みます(簡素化のため、6つのクラスではなく3つのクラスからサンプルを取得します)。
  2. 各テキストに一意の識別子(UID)を関連付けます。
  3. Pub/Subが期待する形式にテキストを変換します。
  4. フォーマットされたメッセージをPub/Subに書き込みます。

ストリーミングデータのクラスタリング

データをPub/Subにインジェストした後、ストリーミングメッセージをPub/Subから読み取り、言語モデルを使用してテキストを埋め込みに変換し、BIRCHを使用して埋め込みをクラスタリングする2番目のパイプラインを調べます。

前述のすべてのステップの完全なサンプルコードは、GitHubにあります。

clustering_pipelineのファイル構造は次のとおりです。

clustering_pipeline/
├── pipeline/
│   ├── __init__.py
│   ├── options.py
│   └── transformations.py
├── __init__.py
├── config.py
├── main.py
└── setup.py

pipeline/transformations.pyには、パイプラインで使用されるさまざまなbeam.DoFnのコードが含まれています。

pipeline/options.pyには、Dataflowパイプラインを設定するためのパイプラインオプションが含まれています。

config.pyは、Google Cloud PROJECT_IDやNUM_WORKERSなど、複数回使用される変数を定義します。

setup.pyは、パイプラインを実行するためのパッケージと要件を定義します。

main.pyには、パイプラインコードと、パイプラインの実行に使用される追加の関数が含まれています。

パイプラインの実行

必要なパッケージをインストールし、データをPub/Subにプッシュします。

  1. ローカルマシンで:python main.py
  2. Dataflow向けのGCPで:python main.py --mode cloud

パイプラインは、次のステップに分割できます。

  1. Pub/Subからメッセージを読み取ります。
  2. Pub/Subメッセージを、キーがUIDで値がTwitterテキストであるディクショナリのPCollectionに変換します。
  3. トカナイズを使用してテキストをトランスフォーマーで読み取れるトークンID整数にエンコードします。
  4. RunInferenceを使用して、トランスフォーマーベースの言語モデルからベクトル埋め込みを取得します。
  5. クラスタリングのために埋め込みを正規化します。
  6. 状態付き処理を使用してBIRCHクラスタリングを実行します。
  7. クラスタに割り当てられたテキストを出力します。

次のコードは、Pub/Subからのメッセージが読み取られてディクショナリに変換されるパイプラインの最初の2つのステップを示しています。

    docs = (
        pipeline
        | "Read from PubSub"
        >> ReadFromPubSub(subscription=cfg.SUBSCRIPTION_ID, with_attributes=True)
        | "Decode PubSubMessage" >> beam.ParDo(Decode())
    )

次のセクションでは、3つの重要なパイプラインステップを調べます。

  1. テキストのトークン化。
  2. トークン化されたテキストを入力して、トランスフォーマーベースの言語モデルから埋め込みを取得します。
  3. 状態付き処理を使用してクラスタリングを実行します。

言語モデルからの埋め込みの取得

テキストデータをクラスタリングするには、テキストを統計分析に適した数値のベクトルにマッピングする必要があります。この例では、sentence-transformers/stsb-distilbert-base/stsb-distilbert-baseと呼ばれるトランスフォーマーベースの言語モデルを使用します。これは、文と段落を768次元の密ベクトル空間にマッピングし、クラスタリングやセマンティック検索などのタスクに使用できます。

言語モデルは生のテキストではなくトークン化された入力を期待しているため、まずテキストをトークン化します。トークン化は、テキストをモデルに入力して予測を取得できるように変換する前処理タスクです。

    normalized_embedding = (
        docs
        | "Tokenize Text" >> beam.Map(tokenize_sentence)

ここでは、tokenize_sentenceはテキストとIDを含む辞書を受け取り、テキストをトークン化し、(テキスト、ID)とトークン化された出力を返す関数です。

トークン化された出力は、埋め込みを取得するために言語モデルに渡されます。言語モデルから埋め込みを取得するには、Apache BeamのRunInference()を使用します。

    | "Get Embedding" >> RunInference(KeyedModelHandler(model_handler))

より良いクラスタを作成するために、各Twitterテキストの埋め込みを取得した後、埋め込みを正規化します。

    | "Normalize Embedding" >> beam.ParDo(NormalizeEmbedding())

StatefulOnlineClustering

データはストリーミングされるため、BIRCHのような反復クラスタリングアルゴリズムを使用する必要があります。また、アルゴリズムは反復的であるため、Twitterテキストが到着したときに更新できるよう、以前の状態を保存するメカニズムが必要です。**ステートフル処理**により、DoFnは永続的な状態を持つことができ、各要素の処理中に読み書きできます。ステートフル処理の詳細については、Apache Beamによるステートフル処理を参照してください。

この例では、Pub/Subから新しいメッセージが読み取られるたびに、クラスタリングモデルの既存の状態を取得し、更新し、状態に書き戻します。

    clustering = (
        normalized_embedding
        | "Map doc to key" >> beam.Map(lambda x: (1, x))
        | "StatefulClustering using Birch" >> beam.ParDo(StatefulOnlineClustering())
    )

BIRCHは並列化をサポートしていないため、ステートフル処理を実行しているワーカーが1つだけであることを確認する必要があります。そのため、Beam.Mapを使用して、各テキストに同じキー1を関連付けます。

StatefulOnlineClusteringは、テキストの埋め込みを受け取り、クラスタリングモデルを更新するDoFnです。状態を保存するために、ストレージのコンテナとして機能するReadModifyWriteStateSpec状態オブジェクトを使用します。

class StatefulOnlineClustering(beam.DoFn):

    BIRCH_MODEL_SPEC = ReadModifyWriteStateSpec("clustering_model", PickleCoder())
    DATA_ITEMS_SPEC = ReadModifyWriteStateSpec("data_items", PickleCoder())
    EMBEDDINGS_SPEC = ReadModifyWriteStateSpec("embeddings", PickleCoder())
    UPDATE_COUNTER_SPEC = ReadModifyWriteStateSpec("update_counter", PickleCoder())

この例では、4つの異なるReadModifyWriteStateSpecオブジェクトを宣言しています。

これらのReadModifyWriteStateSpecオブジェクトは、process関数に追加の引数として渡されます。新しいニュースアイテムが入力されると、異なるオブジェクトの既存の状態を取得し、更新してから、永続的な共有状態として書き戻します。

def process(
    self,
    element,
    model_state=beam.DoFn.StateParam(BIRCH_MODEL_SPEC),
    collected_docs_state=beam.DoFn.StateParam(DATA_ITEMS_SPEC),
    collected_embeddings_state=beam.DoFn.StateParam(EMBEDDINGS_SPEC),
    update_counter_state=beam.DoFn.StateParam(UPDATE_COUNTER_SPEC),
    *args,
    **kwargs,
):
  """
      Takes the embedding of a document and updates the clustering model

      Args:
        element: The input element to be processed.
        model_state: This is the state of the clustering model. It is a stateful parameter,
        which means that it will be updated after each call to the process function.
        collected_docs_state: This is a stateful dictionary that stores the documents that
        have been processed so far.
        collected_embeddings_state: This is a dictionary of document IDs and their embeddings.
        update_counter_state: This is a counter that keeps track of how many documents have been
      processed.
      """
  # 1. Initialise or load states
  clustering = model_state.read() or Birch(n_clusters=None, threshold=0.7)
  collected_documents = collected_docs_state.read() or {}
  collected_embeddings = collected_embeddings_state.read() or {}
  update_counter = update_counter_state.read() or Counter()

  # 2. Extract document, add to state, and add to clustering model
  _, doc = element
  doc_id = doc["id"]
  embedding_vector = doc["embedding"]
  collected_embeddings[doc_id] = embedding_vector
  collected_documents[doc_id] = {"id": doc_id, "text": doc["text"]}
  update_counter = len(collected_documents)

  clustering.partial_fit(np.atleast_2d(embedding_vector))

  # 3. Predict cluster labels of collected documents
  cluster_labels = clustering.predict(
      np.array(list(collected_embeddings.values())))

  # 4. Write states
  model_state.write(clustering)
  collected_docs_state.write(collected_documents)
  collected_embeddings_state.write(collected_embeddings)
  update_counter_state.write(update_counter)
  yield {
      "labels": cluster_labels,
      "docs": collected_documents,
      "id": list(collected_embeddings.keys()),
      "counter": update_counter,
  }

GetUpdatesは、新しいメッセージが到着するたびに、各Twitterメッセージに割り当てられたクラスタを出力するDoFnです。

updated_clusters = clustering | "Format Update" >> beam.ParDo(GetUpdates())