從 Java SDK 使用 RunInference

此範例中的管道是以 Java 撰寫,並從 Google Cloud Storage 讀取輸入資料。藉由PythonExternalTransform的協助,會呼叫一個複合 Python 轉換來執行預處理、後處理和推論。最後,資料會寫回 Java 管道中的 Google Cloud Storage。

您可以在Beam 儲存庫中找到此範例中使用的程式碼。

NLP 模型和資料集

使用 bert-base-uncased 自然語言處理 (NLP) 模型進行推論。此模型是開源的,可在HuggingFace上取得。此 BERT 模型用於根據句子的上下文預測句子的最後一個單字。

我們也使用IMDB 電影評論資料集,這是一個在 Kaggle 上提供的開源資料集。

以下是預處理後資料的範例

文字最後一個字
其中一位評論員提到,在觀看 1 集 Oz 後,您會被 [MASK]。迷住
一個很棒的小 [MASK]製作
所以我不是 Boll 作品的忠實粉絲,但話說回來,沒有多少人是 [MASK]。
這是一部關於三個成為 [MASK] 的囚犯的奇幻電影。出名
有些電影根本不應該被 [MASK]。重拍
凱倫·卡彭特的故事稍微展現了歌手凱倫·卡彭特複雜的 [MASK]。生活

多語言推論管道

使用多語言管道時,您可以存取更大的轉換池。如需更多資訊,請參閱 Apache Beam 程式設計指南中的多語言管道

自訂 Python 轉換

除了執行推論之外,我們還需要對資料執行預處理和後處理。後處理資料可以解釋輸出。為了執行這三項任務,會撰寫一個單一複合自訂 PTransform,每個任務都有一個單位 DoFn 或 PTransform,如下列程式碼片段所示

def expand(self, pcoll):
    return (
    pcoll
    | 'Preprocess' >> beam.ParDo(self.Preprocess(self._tokenizer))
    | 'Inference' >> RunInference(KeyedModelHandler(self._model_handler))
    | 'Postprocess' >> beam.ParDo(self.Postprocess(
        self._tokenizer))
    )

首先,進行資料的預處理。在此範例中,會針對 BERT 模型清除和標記化原始文字資料。所有這些步驟都會在 Preprocess DoFn 中執行。Preprocess DoFn 會採用單一元素做為輸入,並傳回包含原始文字和標記化文字的清單。

然後,預處理後的資料會用於進行推論。這是在 Apache Beam SDK 中已提供的RunInference PTransform 中完成的。RunInference PTransform 需要一個參數,即模型處理常式。在此範例中,會使用 KeyedModelHandler,因為 Preprocess DoFn 也會輸出原始句子。您可以根據您的需求變更預處理的方式。此模型處理常式是在複合 PTransform 的下列初始化函式中定義

def __init__(self, model, model_path):
    self._model = model
    logging.info(f"Downloading {self._model} model from GCS.")
    self._model_config = BertConfig.from_pretrained(self._model)
    self._tokenizer = BertTokenizer.from_pretrained(self._model)
    self._model_handler = self.PytorchModelHandlerKeyedTensorWrapper(
        state_dict_path=(model_path),
        model_class=BertForMaskedLM,
        model_params={'config': self._model_config},
        device='cuda:0')

使用 PytorchModelHandlerKeyedTensorWrapper,它是 PytorchModelHandlerKeyedTensor 模型處理常式的包裝函式。PytorchModelHandlerKeyedTensor 模型處理常式會對 PyTorch 模型進行推論。因為從 BertTokenizer 產生的標記化字串可能具有不同的長度,而且 stack() 需要張量具有相同的大小,所以 PytorchModelHandlerKeyedTensorWrapper 會將批次大小限制為 1。將 max_batch_size 限制為 1 表示 run_inference() 呼叫的每個批次都包含一個範例。以下程式碼顯示包裝函式的定義

class PytorchModelHandlerKeyedTensorWrapper(PytorchModelHandlerKeyedTensor):

    def batch_elements_kwargs(self):
      return {'max_batch_size': 1}

另一種方法是讓所有張量具有相同的長度。此範例說明如何執行此操作。

ModelConfigModelTokenizer 會在初始化函式中載入。ModelConfig 用於定義模型架構,而 ModelTokenizer 用於標記化輸入資料。以下兩個參數會用於這些任務

這兩個參數都會在 Java PipelineOptions 中指定。

最後,我們在 Postprocess DoFn 中後處理模型預測。Postprocess DoFn 會傳回原始文字、句子的最後一個單字和預測的單字。

將 Python 程式碼編譯為套件

自訂 Python 程式碼需要寫在本機套件中,並編譯為 tarball。然後,Java 管道可以使用此套件。以下範例說明如何將 Python 套件編譯為 tarball

 pip install --upgrade build && python -m build --sdist

為了執行此操作,需要 setup.py。tarball 的路徑將用作 Java 管道的管道選項中的引數。

執行 Java 管道

Java 管道定義於 MultiLangRunInference 類別中。在此管道中,資料會從 Google Cloud Storage 讀取,接著應用跨語言 Python 轉換,最後將輸出寫回 Google Cloud Storage。

PythonExternalTransform 用於將跨語言 Python 轉換注入 Java 管道。PythonExternalTransform 接受一個字串參數,該參數為 Python 轉換的完整限定名稱。

withKwarg 方法用於指定 Python 轉換所需的參數。在此範例中,指定了 modelmodel_path 參數。這些參數用於複合 Python PTransform 的初始化函數中,如第一節所示。最後,withExtraPackages 方法用於指定 Python 轉換所需的額外 Python 依賴項。在此範例中,使用了 local_packages 列表,其中包含 Python 需求以及編譯後的 tarball 路徑。

若要執行管道,請使用以下命令

mvn compile exec:java -Dexec.mainClass=org.apache.beam.examples.MultiLangRunInference \
    -Dexec.args="--runner=DataflowRunner \
                 --project=$GCP_PROJECT\
                 --region=$GCP_REGION \
                 --gcpTempLocation=gs://$GCP_BUCKET/temp/ \
                 --inputFile=gs://$GCP_BUCKET/input/imdb_reviews.csv \
                 --outputFile=gs://$GCP_BUCKET/output/ouput.txt \
                 --modelPath=gs://$GCP_BUCKET/input/bert-model/bert-base-uncased.pth \
                 --modelName=$MODEL_NAME \
                 --localPackage=$LOCAL_PACKAGE" \
    -Pdataflow-runner

指定了標準的 Google Cloud 和 Runner 參數。inputFileoutputFile 參數用於指定輸入和輸出檔案。modelPathmodelName 自訂參數會傳遞至 PythonExternalTransform。最後,localPackage 參數用於指定已編譯的 Python 套件的路徑,其中包含自訂 Python 轉換。

最後的說明

請以此範例為基礎,建立其他自訂的多語言推論管道。您也可以使用其他 SDK。例如,Go 也具有可以進行跨語言轉換的包裝器。如需更多資訊,請參閱 Apache Beam 程式設計指南中的在 Go 管道中使用跨語言轉換

此範例中使用的完整程式碼可在 GitHub 上找到。