WatchFilePattern を使用して RunInference で ML モデルを自動更新する

この例のパイプラインでは、RunInference PTransform を使用して、TensorFlow モデルを使用して画像に対して推論を実行します。モデルを更新する ModelMetadata を出力する サイド入力 PCollection を使用します。

サイド入力を使用すると、Beam パイプラインがまだ実行中でも、モデル(ModelHandler 構成オブジェクトで渡されます)をリアルタイムで更新できます。これは、WatchFilePattern などの Beam が提供するパターンのいずれかを利用するか、モデル更新のロジックを定義するカスタムサイド入力 PCollection を構成することで実行できます。

サイド入力の詳細については、Apache Beam プログラミングガイドのサイド入力セクションを参照してください。

この例では、サイド入力として WatchFilePattern を使用します。WatchFilePattern は、タイムスタンプに基づいて file_pattern に一致するファイルの更新を監視するために使用されます。RunInference PTransform で使用される最新のModelMetadata を出力し、Beam パイプラインを停止することなく ML モデルを自動的に更新します。

ソースの設定

画像名を読み取るには、ソースとして Pub/Sub トピックを使用します。Pub/Sub トピックは、推論を実行するために画像を読み取り、事前処理するために使用される UTF-8 でエンコードされたモデルパスを出力します。

画像セグメンテーションのモデル

この例では、HDF5 形式で保存された TensorFlow モデルを使用します。

推論用の画像の事前処理

Pub/Sub トピックは、画像パスを出力します。RunInference で使用するために画像を読み取り、事前処理する必要があります。read_image 関数は、推論用の画像を読み取るために使用されます。

import io
from PIL import Image
from apache_beam.io.filesystems import FileSystems
import numpy
import tensorflow as tf

def read_image(image_file_name):
  with FileSystems().open(image_file_name, 'r') as file:
    data = Image.open(io.BytesIO(file.read())).convert('RGB')
  img = data.resize((224, 224))
  img = numpy.array(img) / 255.0
  img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)
  return img_tensor

それでは、パイプラインコードを見ていきましょう。

パイプラインステップ:

  1. Pub/Sub トピックから画像名を取得します。
  2. read_image 関数を使用して画像を読み取り、事前処理します。
  3. 画像を RunInference PTransform に渡します。RunInference は、入力パラメーターとして model_handlermodel_metadata_pcoll を受け取ります。

model_handler には、TFModelHandlerTensor を使用します。

from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
# initialize TFModelHandlerTensor with a .h5 model saved in a directory accessible by the pipeline.
tf_model_handler = TFModelHandlerTensor(model_uri='gs://<your-bucket>/<model_path.h5>')

model_metadata_pcoll は、RunInference PTransform への サイド入力 PCollection です。このサイド入力は、beam パイプラインを停止することなく、model_handler 内のモデルを更新するために使用されます。.h5 ファイルに一致する glob パターンを監視するためのサイド入力として WatchFilePattern を使用します。

model_metadata_pcoll は、AsSingleton と互換性のある ModelMetadata の PCollection を予期します。パイプラインはサイド入力として WatchFilePattern を使用するため、ウィンドウ処理を行い、出力を ModelMetadata にラップします。

パイプラインがデータの処理を開始し、RunInference PTransform からいくつかの出力が出力されたら、file_pattern に一致する .h5 TensorFlow モデルを Google Cloud Storage バケットにアップロードします。RunInference は、サイド入力として WatchFilePattern を使用して、TFModelHandlerTensormodel_uri を更新します。

: サイド入力の更新頻度は非決定的であり、更新間隔が長くなる可能性があります。

import apache_beam as beam
from apache_beam.ml.inference.utils import WatchFilePattern
from apache_beam.ml.inference.base import RunInference
with beam.Pipeline() as pipeline:

  file_pattern = 'gs://<your-bucket>/*.h5'
  pubsub_topic = '<topic_emitting_image_names>'

  side_input_pcoll = (
    pipeline
    | "FilePatternUpdates" >> WatchFilePattern(file_pattern=file_pattern))

  images_pcoll = (
    pipeline
    | "ReadFromPubSub" >> beam.io.ReadFromPubSub(topic=pubsub_topic)
    | "DecodeBytes" >> beam.Map(lambda x: x.decode('utf-8'))
    | "PreProcessImage" >> beam.Map(read_image)
  )

  inference_pcoll = (
    images_pcoll
    | "RunInference" >> RunInference(
    model_handler=tf_model_handler,
    model_metadata_pcoll=side_input_pcoll))

PredictionResult オブジェクトの後処理

推論が完了すると、RunInference は、exampleinference、および model_id フィールドを含む PredictionResult オブジェクトを出力します。model_id は、推論を実行するために使用されるモデルを識別するために使用されます。

from apache_beam.ml.inference.base import PredictionResult

class PostProcessor(beam.DoFn):
  """
  Process the PredictionResult to get the predicted label and model id used for inference.
  """
  def process(self, element: PredictionResult) -> typing.Iterable[str]:
    predicted_class = numpy.argmax(element.inference[0], axis=-1)
    labels_path = tf.keras.utils.get_file(
        'ImageNetLabels.txt',
        'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
    )
    imagenet_labels = numpy.array(open(labels_path).read().splitlines())
    predicted_class_name = imagenet_labels[predicted_class]
    return predicted_class_name.title(), element.model_id

post_processor_pcoll = (inference_pcoll | "PostProcessor" >> PostProcessor())

パイプラインの実行

result = pipeline.run().wait_until_finish()

: ModelMetaData オブジェクトの model_name は、RunInference PTransform によって計算されたメトリクスのプレフィックスとして付加されます。

最後に

パイプラインを停止することなくモデルを自動更新するために、RunInference PTransform でサイド入力を使用する場合、この例をパターンとして使用できます。PyTorch の同様の例は GitHub で確認できます。