WatchFilePattern を使用して RunInference で ML モデルを自動更新する
この例のパイプラインでは、RunInference PTransform
を使用して、TensorFlow モデルを使用して画像に対して推論を実行します。モデルを更新する ModelMetadata
を出力する サイド入力 PCollection
を使用します。
サイド入力を使用すると、Beam パイプラインがまだ実行中でも、モデル(ModelHandler
構成オブジェクトで渡されます)をリアルタイムで更新できます。これは、WatchFilePattern
などの Beam が提供するパターンのいずれかを利用するか、モデル更新のロジックを定義するカスタムサイド入力 PCollection
を構成することで実行できます。
サイド入力の詳細については、Apache Beam プログラミングガイドのサイド入力セクションを参照してください。
この例では、サイド入力として WatchFilePattern
を使用します。WatchFilePattern
は、タイムスタンプに基づいて file_pattern
に一致するファイルの更新を監視するために使用されます。RunInference PTransform
で使用される最新のModelMetadata
を出力し、Beam パイプラインを停止することなく ML モデルを自動的に更新します。
ソースの設定
画像名を読み取るには、ソースとして Pub/Sub トピックを使用します。Pub/Sub トピックは、推論を実行するために画像を読み取り、事前処理するために使用される UTF-8
でエンコードされたモデルパスを出力します。
画像セグメンテーションのモデル
この例では、HDF5 形式で保存された TensorFlow モデルを使用します。
推論用の画像の事前処理
Pub/Sub トピックは、画像パスを出力します。RunInference で使用するために画像を読み取り、事前処理する必要があります。read_image
関数は、推論用の画像を読み取るために使用されます。
import io
from PIL import Image
from apache_beam.io.filesystems import FileSystems
import numpy
import tensorflow as tf
def read_image(image_file_name):
with FileSystems().open(image_file_name, 'r') as file:
data = Image.open(io.BytesIO(file.read())).convert('RGB')
img = data.resize((224, 224))
img = numpy.array(img) / 255.0
img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)
return img_tensor
それでは、パイプラインコードを見ていきましょう。
パイプラインステップ:
- Pub/Sub トピックから画像名を取得します。
read_image
関数を使用して画像を読み取り、事前処理します。- 画像を RunInference
PTransform
に渡します。RunInference は、入力パラメーターとしてmodel_handler
とmodel_metadata_pcoll
を受け取ります。
model_handler
には、TFModelHandlerTensor を使用します。
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
# initialize TFModelHandlerTensor with a .h5 model saved in a directory accessible by the pipeline.
tf_model_handler = TFModelHandlerTensor(model_uri='gs://<your-bucket>/<model_path.h5>')
model_metadata_pcoll
は、RunInference PTransform
への サイド入力 PCollection
です。このサイド入力は、beam パイプラインを停止することなく、model_handler
内のモデルを更新するために使用されます。.h5
ファイルに一致する glob パターンを監視するためのサイド入力として WatchFilePattern
を使用します。
model_metadata_pcoll
は、AsSingleton と互換性のある ModelMetadata の PCollection
を予期します。パイプラインはサイド入力として WatchFilePattern
を使用するため、ウィンドウ処理を行い、出力を ModelMetadata
にラップします。
パイプラインがデータの処理を開始し、RunInference PTransform
からいくつかの出力が出力されたら、file_pattern
に一致する .h5
TensorFlow
モデルを Google Cloud Storage バケットにアップロードします。RunInference は、サイド入力として WatchFilePattern
を使用して、TFModelHandlerTensor
の model_uri
を更新します。
注: サイド入力の更新頻度は非決定的であり、更新間隔が長くなる可能性があります。
import apache_beam as beam
from apache_beam.ml.inference.utils import WatchFilePattern
from apache_beam.ml.inference.base import RunInference
with beam.Pipeline() as pipeline:
file_pattern = 'gs://<your-bucket>/*.h5'
pubsub_topic = '<topic_emitting_image_names>'
side_input_pcoll = (
pipeline
| "FilePatternUpdates" >> WatchFilePattern(file_pattern=file_pattern))
images_pcoll = (
pipeline
| "ReadFromPubSub" >> beam.io.ReadFromPubSub(topic=pubsub_topic)
| "DecodeBytes" >> beam.Map(lambda x: x.decode('utf-8'))
| "PreProcessImage" >> beam.Map(read_image)
)
inference_pcoll = (
images_pcoll
| "RunInference" >> RunInference(
model_handler=tf_model_handler,
model_metadata_pcoll=side_input_pcoll))
PredictionResult
オブジェクトの後処理
推論が完了すると、RunInference は、example
、inference
、および model_id
フィールドを含む PredictionResult
オブジェクトを出力します。model_id
は、推論を実行するために使用されるモデルを識別するために使用されます。
from apache_beam.ml.inference.base import PredictionResult
class PostProcessor(beam.DoFn):
"""
Process the PredictionResult to get the predicted label and model id used for inference.
"""
def process(self, element: PredictionResult) -> typing.Iterable[str]:
predicted_class = numpy.argmax(element.inference[0], axis=-1)
labels_path = tf.keras.utils.get_file(
'ImageNetLabels.txt',
'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
)
imagenet_labels = numpy.array(open(labels_path).read().splitlines())
predicted_class_name = imagenet_labels[predicted_class]
return predicted_class_name.title(), element.model_id
post_processor_pcoll = (inference_pcoll | "PostProcessor" >> PostProcessor())
パイプラインの実行
result = pipeline.run().wait_until_finish()
注: ModelMetaData
オブジェクトの model_name
は、RunInference PTransform
によって計算されたメトリクスのプレフィックスとして付加されます。
最後に
パイプラインを停止することなくモデルを自動更新するために、RunInference PTransform
でサイド入力を使用する場合、この例をパターンとして使用できます。PyTorch の同様の例は GitHub で確認できます。
最終更新日: 2024/10/31
お探しのものはすべて見つかりましたか?
すべてが役に立ち、明確でしたか?何か変更したいことはありますか?お知らせください!