線上分群範例

線上分群範例示範如何設定一個即時分群管道,它可以從 Pub/Sub 讀取文字、使用語言模型將文字轉換為嵌入,並使用 BIRCH 對文字進行分群。

用於分群的資料集

此範例使用一個名為 emotion 的資料集,其中包含 20,000 則帶有 6 種基本情緒的英文 Twitter 訊息:憤怒、恐懼、喜悅、愛、悲傷和驚訝。該資料集有三個分割:訓練、驗證和測試。由於它包含資料集的文字和類別 (類別),因此它是一個監督式資料集。要存取此資料集,請使用 Hugging Face 資料集頁面

以下文字顯示資料集訓練分割的範例

文字情緒類型
im grabbing a minute to post i feel greedy wrong憤怒
i am ever feeling nostalgic about the fireplace i will know that it is still on the property
ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny恐懼
on a boat trip to denmark喜悅
i feel you know basically like a fake in the realm of science fiction悲傷
i began having them several times a week feeling tortured by the hallucinations moving people and figures sounds and vibrations恐懼

分群演算法

對於推文的分群,我們使用一種稱為 BIRCH 的增量分群演算法。它代表使用階層的平衡迭代減少和分群,是一種用於在特別大型的資料集上執行階層分群的非監督式資料採礦演算法。BIRCH 的優點是它能夠增量且動態地對傳入的多維度度量資料點進行分群,以嘗試為給定的資源 (記憶體和時間限制) 生成最佳品質的分群。

擷取到 Pub/Sub

範例首先將資料擷取到 Pub/Sub,以便我們可以在分群時從 Pub/Sub 讀取推文。Pub/Sub 是一種訊息服務,用於在應用程式和服務之間交換事件資料。串流分析和資料整合管道使用 Pub/Sub 來擷取和分發資料。

您可以在 GitHub 中找到將資料擷取到 Pub/Sub 的完整範例程式碼

擷取管道的檔案結構如下圖所示

write_data_to_pubsub_pipeline/
├── pipeline/
│   ├── __init__.py
│   ├── options.py
│   └── utils.py
├── __init__.py
├── config.py
├── main.py
└── setup.py

pipeline/utils.py 包含用於載入情緒資料集和兩個用於資料轉換的 beam.DoFn 的程式碼。

pipeline/options.py 包含用於設定 Dataflow 管道的管道選項。

config.py 定義一些多次使用的變數,例如 GCP PROJECT_ID 和 NUM_WORKERS。

setup.py 定義管道執行所需的套件和需求。

main.py 包含管道程式碼和一些用於執行管道的額外函式。

執行管道

首先,安裝所需的套件。

  1. 在本機上:python main.py
  2. 在 GCP 上用於 Dataflow:python main.py --mode cloud

write_data_to_pubsub_pipeline 包含四個不同的轉換

  1. 使用 Hugging Face 資料集載入情緒資料集 (為了簡化,我們從三個類別而不是六個類別中取樣)。
  2. 將每段文字與唯一識別碼 (UID) 建立關聯。
  3. 將文字轉換為 Pub/Sub 預期的格式。
  4. 將格式化的訊息寫入 Pub/Sub。

串流資料的分群

在將資料擷取到 Pub/Sub 之後,檢查第二個管道,我們從 Pub/Sub 讀取串流訊息,使用語言模型將文字轉換為嵌入,並使用 BIRCH 對嵌入進行分群。

您可以在 GitHub 中找到前面提到之所有步驟的完整範例程式碼。

clustering_pipeline 的檔案結構為

clustering_pipeline/
├── pipeline/
│   ├── __init__.py
│   ├── options.py
│   └── transformations.py
├── __init__.py
├── config.py
├── main.py
└── setup.py

pipeline/transformations.py 包含用於管道中不同 beam.DoFn 的程式碼。

pipeline/options.py 包含用於設定 Dataflow 管道的管道選項。

config.py 定義多次使用的變數,例如 Google Cloud PROJECT_ID 和 NUM_WORKERS。

setup.py 定義管道執行所需的套件和需求。

main.py 包含管道程式碼和一些用於執行管道的額外函式。

執行管道

安裝所需的套件並將資料推送至 Pub/Sub。

  1. 在本機上:python main.py
  2. 在 GCP 上用於 Dataflow:python main.py --mode cloud

該管道可以分解為以下步驟

  1. 從 Pub/Sub 讀取訊息。
  2. 將 Pub/Sub 訊息轉換為字典的 PCollection,其中鍵是 UID,值是 Twitter 文字。
  3. 使用分詞器將文字編碼為變壓器可讀取的符記 ID 整數。
  4. 使用 RunInference 從基於變壓器的語言模型取得向量嵌入。
  5. 正規化嵌入以進行分群。
  6. 使用具狀態處理執行 BIRCH 分群。
  7. 印出分配至各群集的文字。

以下程式碼顯示管道的前兩個步驟,其中會讀取來自 Pub/Sub 的訊息並轉換為字典。

    docs = (
        pipeline
        | "Read from PubSub"
        >> ReadFromPubSub(subscription=cfg.SUBSCRIPTION_ID, with_attributes=True)
        | "Decode PubSubMessage" >> beam.ParDo(Decode())
    )

接下來的章節將檢視三個重要的管道步驟

  1. 將文字符號化 (Tokenize)。
  2. 將符號化的文字輸入,以從基於 Transformer 的語言模型取得嵌入 (Embedding)。
  3. 使用具狀態處理執行分群。

從語言模型取得嵌入

為了對文字資料進行分群,您需要將文字映射到適合統計分析的數值向量。此範例使用名為 sentence-transformers/stsb-distilbert-base/stsb-distilbert-base 的基於 Transformer 的語言模型。它將句子和段落映射到 768 維的密集向量空間,您可以將其用於分群或語意搜尋等任務。

由於語言模型預期輸入的是符號化文字而不是原始文字,因此請先從文字符號化開始。符號化是一種預處理任務,它會轉換文字,以便將其輸入到模型中以取得預測。

    normalized_embedding = (
        docs
        | "Tokenize Text" >> beam.Map(tokenize_sentence)

此處,tokenize_sentence 是一個函數,它接受包含文字和 ID 的字典,對文字進行符號化,並返回一個元組 (文字,id) 和符號化輸出。

然後將符號化輸出傳遞給語言模型以取得嵌入。為了從語言模型取得嵌入,我們使用 Apache Beam 的 RunInference()

    | "Get Embedding" >> RunInference(KeyedModelHandler(model_handler))

為了產生更好的群集,在取得每段 Twitter 文字的嵌入之後,會對嵌入進行正規化。

    | "Normalize Embedding" >> beam.ParDo(NormalizeEmbedding())

StatefulOnlineClustering

由於資料是串流的,因此您需要使用迭代分群演算法,例如 BIRCH。由於該演算法是迭代的,因此您需要一種機制來儲存先前的狀態,以便在 Twitter 文字到達時可以更新。具狀態處理使 DoFn 能夠具有持久狀態,可以在處理每個元素時讀取和寫入。有關具狀態處理的更多資訊,請參閱使用 Apache Beam 的具狀態處理

在此範例中,每次從 Pub/Sub 讀取新訊息時,您都會擷取分群模型的現有狀態、更新它,然後將其寫回狀態。

    clustering = (
        normalized_embedding
        | "Map doc to key" >> beam.Map(lambda x: (1, x))
        | "StatefulClustering using Birch" >> beam.ParDo(StatefulOnlineClustering())
    )

由於 BIRCH 不支援平行化,因此您需要確保只有一個 worker 執行所有具狀態處理。為此,請使用 Beam.Map 將每個文字關聯到相同的鍵 1

StatefulOnlineClustering 是一個 DoFn,它接收文字的嵌入並更新分群模型。為了儲存狀態,它使用 ReadModifyWriteStateSpec 狀態物件,該物件充當儲存的容器。

class StatefulOnlineClustering(beam.DoFn):

    BIRCH_MODEL_SPEC = ReadModifyWriteStateSpec("clustering_model", PickleCoder())
    DATA_ITEMS_SPEC = ReadModifyWriteStateSpec("data_items", PickleCoder())
    EMBEDDINGS_SPEC = ReadModifyWriteStateSpec("embeddings", PickleCoder())
    UPDATE_COUNTER_SPEC = ReadModifyWriteStateSpec("update_counter", PickleCoder())

此範例宣告四個不同的 ReadModifyWriteStateSpec 物件

這些 ReadModifyWriteStateSpec 物件會作為額外引數傳遞給 process 函數。當新項目傳入時,我們會擷取不同物件的現有狀態、更新它們,然後將它們寫回作為持久共享狀態。

def process(
    self,
    element,
    model_state=beam.DoFn.StateParam(BIRCH_MODEL_SPEC),
    collected_docs_state=beam.DoFn.StateParam(DATA_ITEMS_SPEC),
    collected_embeddings_state=beam.DoFn.StateParam(EMBEDDINGS_SPEC),
    update_counter_state=beam.DoFn.StateParam(UPDATE_COUNTER_SPEC),
    *args,
    **kwargs,
):
  """
      Takes the embedding of a document and updates the clustering model

      Args:
        element: The input element to be processed.
        model_state: This is the state of the clustering model. It is a stateful parameter,
        which means that it will be updated after each call to the process function.
        collected_docs_state: This is a stateful dictionary that stores the documents that
        have been processed so far.
        collected_embeddings_state: This is a dictionary of document IDs and their embeddings.
        update_counter_state: This is a counter that keeps track of how many documents have been
      processed.
      """
  # 1. Initialise or load states
  clustering = model_state.read() or Birch(n_clusters=None, threshold=0.7)
  collected_documents = collected_docs_state.read() or {}
  collected_embeddings = collected_embeddings_state.read() or {}
  update_counter = update_counter_state.read() or Counter()

  # 2. Extract document, add to state, and add to clustering model
  _, doc = element
  doc_id = doc["id"]
  embedding_vector = doc["embedding"]
  collected_embeddings[doc_id] = embedding_vector
  collected_documents[doc_id] = {"id": doc_id, "text": doc["text"]}
  update_counter = len(collected_documents)

  clustering.partial_fit(np.atleast_2d(embedding_vector))

  # 3. Predict cluster labels of collected documents
  cluster_labels = clustering.predict(
      np.array(list(collected_embeddings.values())))

  # 4. Write states
  model_state.write(clustering)
  collected_docs_state.write(collected_documents)
  collected_embeddings_state.write(collected_embeddings)
  update_counter_state.write(update_counter)
  yield {
      "labels": cluster_labels,
      "docs": collected_documents,
      "id": list(collected_embeddings.keys()),
      "counter": update_counter,
  }

GetUpdates 是一個 DoFn,它會在每次有新訊息到達時,印出分配給每個 Twitter 訊息的群集。

updated_clusters = clustering | "Format Update" >> beam.ParDo(GetUpdates())