使用 WatchFilePattern 在 RunInference 中自動更新 ML 模型

此範例中的管道使用 RunInference PTransform 來使用 TensorFlow 模型對影像執行推論。它使用一個 側輸入 PCollection,該側輸入發出 ModelMetadata 來更新模型。

使用側輸入,您可以即時更新模型(在 ModelHandler 組態物件中傳遞),即使 Beam 管道仍在執行。這可以透過利用 Beam 提供的其中一種模式(例如 WatchFilePattern)來完成,也可以透過設定自訂的側輸入 PCollection 來完成,該側輸入定義模型更新的邏輯。

如需更多關於側輸入的資訊,請參閱 Apache Beam 程式設計指南中的側輸入章節。

此範例使用 WatchFilePattern 作為側輸入。WatchFilePattern 用於根據時間戳記監看符合 file_pattern 的檔案更新。它會發出最新的 ModelMetadata,此 ModelMetadata 用於 RunInference PTransform 中,以自動更新 ML 模型,而無需停止 Beam 管道。

設定來源

若要讀取影像名稱,請使用 Pub/Sub 主題作為來源。Pub/Sub 主題發出一個 UTF-8 編碼的模型路徑,用於讀取和預處理影像以執行推論。

用於影像分割的模型

為了本範例的目的,請使用以 HDF5 格式儲存的 TensorFlow 模型。

預先處理影像以進行推論

Pub/Sub 主題會發出一個影像路徑。我們需要讀取和預處理影像,才能將其用於 RunInference。read_image 函式用於讀取影像以進行推論。

import io
from PIL import Image
from apache_beam.io.filesystems import FileSystems
import numpy
import tensorflow as tf

def read_image(image_file_name):
  with FileSystems().open(image_file_name, 'r') as file:
    data = Image.open(io.BytesIO(file.read())).convert('RGB')
  img = data.resize((224, 224))
  img = numpy.array(img) / 255.0
  img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)
  return img_tensor

現在,讓我們進入管道程式碼。

管道步驟:

  1. 從 Pub/Sub 主題取得影像名稱。
  2. 使用 read_image 函式讀取和預處理影像。
  3. 將影像傳遞至 RunInference PTransform。RunInference 會將 model_handlermodel_metadata_pcoll 作為輸入參數。

對於 model_handler,我們使用 TFModelHandlerTensor

from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
# initialize TFModelHandlerTensor with a .h5 model saved in a directory accessible by the pipeline.
tf_model_handler = TFModelHandlerTensor(model_uri='gs://<your-bucket>/<model_path.h5>')

model_metadata_pcoll 是 RunInference PTransform 的一個側輸入 PCollection。此側輸入用於更新 model_handler 中的模型,而無需停止 beam 管道。我們將使用 WatchFilePattern 作為側輸入來監看符合 .h5 檔案的 glob 模式。

model_metadata_pcoll 預期一個與 AsSingleton 相容的 ModelMetadata 的 PCollection。由於管道使用 WatchFilePattern 作為側輸入,它將處理視窗化,並將輸出包裝到 ModelMetadata 中。

管道開始處理資料之後,當您看到從 RunInference PTransform 發出一些輸出時,請將符合 file_pattern.h5 TensorFlow 模型上傳到 Google Cloud Storage 儲存貯體。RunInference 將使用 WatchFilePattern 作為側輸入來更新 TFModelHandlerTensormodel_uri

注意:側輸入更新頻率是不確定的,並且更新之間可能會間隔較長的時間。

import apache_beam as beam
from apache_beam.ml.inference.utils import WatchFilePattern
from apache_beam.ml.inference.base import RunInference
with beam.Pipeline() as pipeline:

  file_pattern = 'gs://<your-bucket>/*.h5'
  pubsub_topic = '<topic_emitting_image_names>'

  side_input_pcoll = (
    pipeline
    | "FilePatternUpdates" >> WatchFilePattern(file_pattern=file_pattern))

  images_pcoll = (
    pipeline
    | "ReadFromPubSub" >> beam.io.ReadFromPubSub(topic=pubsub_topic)
    | "DecodeBytes" >> beam.Map(lambda x: x.decode('utf-8'))
    | "PreProcessImage" >> beam.Map(read_image)
  )

  inference_pcoll = (
    images_pcoll
    | "RunInference" >> RunInference(
    model_handler=tf_model_handler,
    model_metadata_pcoll=side_input_pcoll))

後處理 PredictionResult 物件

推論完成時,RunInference 會輸出一個包含 exampleinferencemodel_id 欄位的 PredictionResult 物件。model_id 用於識別執行推論時所使用的模型。

from apache_beam.ml.inference.base import PredictionResult

class PostProcessor(beam.DoFn):
  """
  Process the PredictionResult to get the predicted label and model id used for inference.
  """
  def process(self, element: PredictionResult) -> typing.Iterable[str]:
    predicted_class = numpy.argmax(element.inference[0], axis=-1)
    labels_path = tf.keras.utils.get_file(
        'ImageNetLabels.txt',
        'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
    )
    imagenet_labels = numpy.array(open(labels_path).read().splitlines())
    predicted_class_name = imagenet_labels[predicted_class]
    return predicted_class_name.title(), element.model_id

post_processor_pcoll = (inference_pcoll | "PostProcessor" >> PostProcessor())

執行管道

result = pipeline.run().wait_until_finish()

注意ModelMetaData 物件的 model_name 將會附加為 RunInference PTransform 計算的指標的前綴。

最後的說明

當您將側輸入與 RunInference PTransform 一起使用時,可以使用此範例作為模式,以自動更新模型而無需停止管道。您可以在 GitHub 上看到類似的 PyTorch 範例。