エンティティ別トレーニング
このパイプラインの例は、Beamにおけるエンティティ別トレーニングを示すことを目的としています。エンティティ別トレーニングとは、すべてのエンティティに対して単一のモデルをトレーニングするのではなく、個々のエンティティごとに機械学習モデルをトレーニングするプロセスです。このアプローチでは、各エンティティについて、そのエンティティ固有のデータに基づいて個別のモデルがトレーニングされます。エンティティ別トレーニングは、次のシナリオで役立ちます。
個別のモデルを使用することで、各グループに対してよりパーソナライズされた、きめ細やかな予測が可能になります。各グループには、単一の大きなモデルでは効果的に捉えられない可能性のある、異なる特性、パターン、行動がある可能性があります。
個別のモデルを使用することで、全体のモデルの複雑さを軽減し、効率を高めることもできます。全体のモデルは、すべてのグループにわたるあらゆる可能性のある特性とパターンを考慮しようとするのではなく、個々のグループの特定の特性とパターンにのみ焦点を当てる必要があります。
個別のモデルを使用することで、バイアスと公平性の問題に対処できます。多様なデータセットでトレーニングされた単一のモデルは、特定のグループにはうまく一般化されない可能性があるため、各グループに対する個別のモデルは、バイアスの影響を軽減できます。
このアプローチは、全体の人口の限られたセグメントに固有の問題を検出しやすいため、本番環境で好まれることがよくあります。
より小さなモデルとデータセットを使用する場合、トレーニングと再トレーニングのプロセスはより迅速かつ効率的に完了できます。トレーニングと再トレーニングの両方を並行して行うことができ、結果を待つ時間を短縮できます。さらに、より小さなモデルとデータセットは、リソースをあまり消費しないという利点もあります。そのため、安価なハードウェアで実行できます。
データセット
この例では、成人国勢調査所得データセットを使用しています。このデータセットには、人口統計学的特性、雇用状況、所得レベルなど、個人に関する情報が含まれています。このデータセットには、年齢、教育、職業、週あたりの労働時間などのカテゴリ変数と数値変数の両方が含まれており、個人の所得が50,000米ドルを超えているかどうかの2値ラベルも含まれています。このデータセットの主な目的は、分類タスクに使用することです。モデルは、提供された特徴に基づいて、個人の所得があるしきい値を超えているかどうかを予測します。パイプラインは、入力として`adult.data` CSVファイルが必要です。このファイルはこちらからダウンロードできます。
パイプラインの実行
最初に、必要なパッケージ`apache-beam==2.44.0`、`scikit-learn==1.0.2`、`pandas==1.3.5`をインストールします。GitHubでコードを確認できます。`python per_entity_training.py --input path/to/adult.data`を使用してください。
パイプラインのトレーニング
パイプラインは、次の主要なステップに分割できます。
- 指定された入力パスからデータを読み取ります。
- いくつかの基準に基づいてデータをフィルタリングします。
- 教育レベルに基づいてキーを作成します。
- 生成されたキーに基づいてデータセットをグループ化します。
- データセットを前処理します。
- 教育レベルごとにモデルをトレーニングします。
- トレーニングされたモデルを保存します。
次のコードスニペットには、詳細な手順が含まれています。
with beam.Pipeline(options=pipeline_options) as pipeline:
_ = (
pipeline | "Read Data" >> beam.io.ReadFromText(known_args.input)
| "Split data to make List" >> beam.Map(lambda x: x.split(','))
| "Filter rows" >> beam.Filter(custom_filter)
| "Create Key" >> beam.ParDo(CreateKey())
| "Group by education" >> beam.GroupByKey()
| "Prepare Data" >> beam.ParDo(PrepareDataforTraining())
| "Train Model" >> beam.ParDo(TrainModel())
|
"Save" >> fileio.WriteToFiles(path=known_args.output, sink=ModelSink()))
最終更新日:2024/10/31
お探しのものが見つかりましたか?
すべて役に立ち、分かりやすかったですか?変更したいことはありますか?お知らせください!