使用 WatchFilePattern 在 RunInference 中自動更新 ML 模型
此範例中的管道使用 RunInference PTransform
來使用 TensorFlow 模型對影像執行推論。它使用一個 側輸入 PCollection
,該側輸入發出 ModelMetadata
來更新模型。
使用側輸入,您可以即時更新模型(在 ModelHandler
組態物件中傳遞),即使 Beam 管道仍在執行。這可以透過利用 Beam 提供的其中一種模式(例如 WatchFilePattern
)來完成,也可以透過設定自訂的側輸入 PCollection
來完成,該側輸入定義模型更新的邏輯。
如需更多關於側輸入的資訊,請參閱 Apache Beam 程式設計指南中的側輸入章節。
此範例使用 WatchFilePattern
作為側輸入。WatchFilePattern
用於根據時間戳記監看符合 file_pattern
的檔案更新。它會發出最新的 ModelMetadata
,此 ModelMetadata
用於 RunInference PTransform
中,以自動更新 ML 模型,而無需停止 Beam 管道。
設定來源
若要讀取影像名稱,請使用 Pub/Sub 主題作為來源。Pub/Sub 主題發出一個 UTF-8
編碼的模型路徑,用於讀取和預處理影像以執行推論。
用於影像分割的模型
為了本範例的目的,請使用以 HDF5 格式儲存的 TensorFlow 模型。
預先處理影像以進行推論
Pub/Sub 主題會發出一個影像路徑。我們需要讀取和預處理影像,才能將其用於 RunInference。read_image
函式用於讀取影像以進行推論。
import io
from PIL import Image
from apache_beam.io.filesystems import FileSystems
import numpy
import tensorflow as tf
def read_image(image_file_name):
with FileSystems().open(image_file_name, 'r') as file:
data = Image.open(io.BytesIO(file.read())).convert('RGB')
img = data.resize((224, 224))
img = numpy.array(img) / 255.0
img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)
return img_tensor
現在,讓我們進入管道程式碼。
管道步驟:
- 從 Pub/Sub 主題取得影像名稱。
- 使用
read_image
函式讀取和預處理影像。 - 將影像傳遞至 RunInference
PTransform
。RunInference 會將model_handler
和model_metadata_pcoll
作為輸入參數。
對於 model_handler
,我們使用 TFModelHandlerTensor。
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
# initialize TFModelHandlerTensor with a .h5 model saved in a directory accessible by the pipeline.
tf_model_handler = TFModelHandlerTensor(model_uri='gs://<your-bucket>/<model_path.h5>')
model_metadata_pcoll
是 RunInference PTransform
的一個側輸入 PCollection
。此側輸入用於更新 model_handler
中的模型,而無需停止 beam 管道。我們將使用 WatchFilePattern
作為側輸入來監看符合 .h5
檔案的 glob 模式。
model_metadata_pcoll
預期一個與 AsSingleton 相容的 ModelMetadata 的 PCollection
。由於管道使用 WatchFilePattern
作為側輸入,它將處理視窗化,並將輸出包裝到 ModelMetadata
中。
管道開始處理資料之後,當您看到從 RunInference PTransform
發出一些輸出時,請將符合 file_pattern
的 .h5
TensorFlow
模型上傳到 Google Cloud Storage 儲存貯體。RunInference 將使用 WatchFilePattern
作為側輸入來更新 TFModelHandlerTensor
的 model_uri
。
注意:側輸入更新頻率是不確定的,並且更新之間可能會間隔較長的時間。
import apache_beam as beam
from apache_beam.ml.inference.utils import WatchFilePattern
from apache_beam.ml.inference.base import RunInference
with beam.Pipeline() as pipeline:
file_pattern = 'gs://<your-bucket>/*.h5'
pubsub_topic = '<topic_emitting_image_names>'
side_input_pcoll = (
pipeline
| "FilePatternUpdates" >> WatchFilePattern(file_pattern=file_pattern))
images_pcoll = (
pipeline
| "ReadFromPubSub" >> beam.io.ReadFromPubSub(topic=pubsub_topic)
| "DecodeBytes" >> beam.Map(lambda x: x.decode('utf-8'))
| "PreProcessImage" >> beam.Map(read_image)
)
inference_pcoll = (
images_pcoll
| "RunInference" >> RunInference(
model_handler=tf_model_handler,
model_metadata_pcoll=side_input_pcoll))
後處理 PredictionResult
物件
推論完成時,RunInference 會輸出一個包含 example
、inference
和 model_id
欄位的 PredictionResult
物件。model_id
用於識別執行推論時所使用的模型。
from apache_beam.ml.inference.base import PredictionResult
class PostProcessor(beam.DoFn):
"""
Process the PredictionResult to get the predicted label and model id used for inference.
"""
def process(self, element: PredictionResult) -> typing.Iterable[str]:
predicted_class = numpy.argmax(element.inference[0], axis=-1)
labels_path = tf.keras.utils.get_file(
'ImageNetLabels.txt',
'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
)
imagenet_labels = numpy.array(open(labels_path).read().splitlines())
predicted_class_name = imagenet_labels[predicted_class]
return predicted_class_name.title(), element.model_id
post_processor_pcoll = (inference_pcoll | "PostProcessor" >> PostProcessor())
執行管道
result = pipeline.run().wait_until_finish()
注意:ModelMetaData
物件的 model_name
將會附加為 RunInference PTransform
計算的指標的前綴。
最後的說明
當您將側輸入與 RunInference PTransform
一起使用時,可以使用此範例作為模式,以自動更新模型而無需停止管道。您可以在 GitHub 上看到類似的 PyTorch 範例。
上次更新於 2024/10/31
您是否找到了您正在尋找的所有內容?
它是否都實用且清楚?您有想要更改的任何內容嗎?請告訴我們!