LightGBM (Light Gradient Boosting Machine) は、決定木をベースに構築された分散型の勾配ブースティングフレームワークです。PolarDB for MySQL では、外部の ML 環境を必要とせず、CREATE MODEL、EVALUATE、PREDICT 文を使用して SQL から直接 LightGBM モデルを実行できます。
LightGBM は、トレーニングデータのメモリフットプリントを削減し、ノード間の通信コストを低減させ、マルチノード弾性並列クエリ (ePQ) の効率を向上させ、データコンピューティングにおいて線形加速を実現します。
利用シーン
LightGBM は、勾配ブースティング決定木 (GBDT)、ランダムフォレスト、ロジスティック回帰の 3 つのモデルファミリーをカバーしており、以下の用途に適しています。
二項分類:ユーザーがクリックするか購入するかなど、二値の結果を予測します。
多クラス分類:製品カテゴリなど、アイテムを複数のカテゴリのいずれかに割り当てます。
ランキングとソート:予測された関連性に基づいて結果を順序付けします。
例:レコメンデーションシステムにおけるクリック予測
パーソナライズされた製品レコメンデーションシステムでは、過去の行動 (クリック、非クリック、購入) に基づいて、ユーザーがアイテムをクリックまたは購入するかどうかを予測する必要があります。入力特徴量には通常、以下が含まれます。
カテゴリカル特徴量:性別 (
male、female) や製品カテゴリ (clothing、toys、electronics) などの文字列値。数値特徴量:ユーザーアクティビティスコアや製品価格などの整数値または浮動小数点値。
LightGBM モデルの設定
パラメーターを設定する前に、以下の決定事項を順番に検討してください。
学習目的 (`loss`) の選択:2 クラス問題には
binary、多クラス問題にはmulticlass、連続値の出力には回帰バリアントを選択します。弱学習器 (`boosting_type`) の選択:まずデフォルトの
gbdtから始めます。過学習を防ぐ必要がある場合はdartに、トレーニングを高速化したい場合はgossに、線形モデルを使用したい場合はgblinearに切り替えます。木の複雑さ (`num_leaves`, `max_depth`) の設定:過学習を避けるため、
num_leavesは2^max_depth未満に設定します。正則化の調整 (`learning_rate`, `subsample`, `min_samples_leaf`):
learning_rateを低くするとモデルはより安定しますが、より多くの反復 (n_estimators) が必要になります。subsampleを 1 未満に設定すると、モデル作成にサンプルの指定された割合のみが使用されます。不明な場合は AutoML (`automl`) を有効化:
automl=Trueを設定すると、PolarDB が最適なパラメーターの組み合わせを自動的に検索します。
パラメーター
CREATE MODEL 文の model_parameter に対応するパラメーターは以下の通りです。
boosting_type のような文字列値のパラメーターの場合、値を単一引用符で囲みます。例:boosting_type='gbdt'。
| パラメーター | タイプ | デフォルト | 説明 |
|---|---|---|---|
boosting_type | 文字列 | gbdt | 弱学習器の種類。gbdt:勾配ブースティング決定木 (推奨のデフォルト)。rf:ランダムフォレスト。dart:ドロップアウトを使用して過学習を低減します。goss:勾配ベースの片側サンプリング。高速ですが、学習不足を引き起こす可能性があります。gblinear:線形モデル。 |
n_estimators | 整数 | 100 | ブースティングの反復回数。 |
loss | 文字列 | binary | 学習目的。binary:二項分類。multiclass:多クラス分類。regression:L2 正則化回帰。regression_l1:L1 正則化回帰。 |
num_leaves | 整数 | 128 | 木あたりの最大葉数。モデルの複雑さを制御します。 |
max_depth | 整数 | 7 | 木の最大深度。-1 に設定すると、深度の制限がなくなります。木が深いほど多くのパターンを捉えられますが、過学習のリスクがあります。 |
learning_rate | 浮動 | 0.06 | 各反復におけるステップサイズ。 |
max_leaf_nodes | 整数または空白 | 空白 | リーフノードの最大数。空白は制限なしを意味します。 |
min_samples_leaf | 整数 | 20 | リーフノードで必要とされる最小サンプル数。サンプル数がこの値を下回る場合、リーフノード (およびその兄弟ノード) は剪定されます。 |
subsample | 浮動 | 1 | 各反復で使用されるトレーニングサンプルの割合。有効値の範囲は 0~1 です。1 未満の値の場合、モデル作成に指定された割合のサンプルのみが使用されます。 |
max_features | 浮動 | 1 | ノードを分割する際に考慮される特徴量の割合。有効値の範囲は 0~1 です。 |
random_state | 整数 | 1 | 乱数シード。この値を変更すると、木の構築とデータ分割に影響が及び、結果が異なる場合があります。 |
model_type | 文字列 | pkl | トレーニング済みモデルのストレージフォーマット。pkl:PKL ファイル。pmml:PMML (Predictive Model Markup Language) ファイル。完全な木構造が含まれており、確認に役立ちます。 |
n_jobs | 整数 | 4 | トレーニングに使用されるスレッド数。スレッド数が多いほど、トレーニング時間が短縮されます。 |
is_unbalance | ブール値 | False | クラスの不均衡に対処するために、少数派クラスの重み付けを増やすかどうか。True に設定すると、あるクラスのサンプル数が他のクラスより著しく少ない場合に対応します。 |
categorical_feature | 文字列配列 | — | カテゴリカル特徴量の列名をカンマ区切りの文字列で指定します。ほとんどの場合、LightGBM はカテゴリカル特徴量を自動的に検出します。自動検出が不十分な場合にこのパラメーターを上書きします。例:categorical_feature='AirportTo,DayOfWeek'。 |
automl | ブール値 | False | 自動パラメーターチューニングを有効にするかどうか。True に設定すると、loss で指定されたメトリックの改善が停止した時点で早期停止が適用されます。 |
automl_train_tag | 文字列 | — | automl_column 内でトレーニング行を識別するラベル値。 |
automl_test_tag | 文字列 | — | automl_column 内でテスト行を識別するラベル値。トレーニングセットは、テストセットの 4~9 倍のサイズである必要があります。 |
automl_column | 文字列 | — | AutoML のために行をトレーニングセットとテストセットに分割するために使用される列名。設定した場合、パラメーターに automl_ プレフィックスを付けて検索空間を定義します。例:automl_learning_rate='0.05,0.04,0.03,0.01' は 4 つの値で検索します。automl_train_tag と automl_test_tag が必要です。 |
例
以下の例では、db4ai.airlines データセットを使用します。すべての SQL 文では、クエリを PolarDB AI エンジンにルーティングするために /*polar4ai*/ ヒントワードを使用します。
LightGBM モデルの作成
/*polar4ai*/CREATE MODEL airline_gbm WITH
(model_class = 'lightgbm',
x_cols = 'Airline,Flight,AirportFrom,AirportTo,DayOfWeek,Time,Length',
y_cols='Delay',model_parameter=(boosting_type='gbdt'))
AS (SELECT * FROM db4ai.airlines);モデルの評価
/*polar4ai*/SELECT Airline FROM EVALUATE(MODEL airline_gbm,
SELECT * FROM db4ai.airlines LIMIT 20) WITH
(x_cols = 'Airline,Flight,AirportFrom,AirportTo,DayOfWeek,Time,Length',y_cols='Delay',metrics='acc');予測の実行
/*polar4ai*/SELECT Airline FROM PREDICT(MODEL airline_gbm,
SELECT * FROM db4ai.airlines limit 20) WITH
(x_cols = 'Airline,Flight,AirportFrom,AirportTo,DayOfWeek,Time,Length');