すべてのプロダクト
Search
ドキュメントセンター

AnalyticDB:CatBoost を使用して GBDT モデルをトレーニングする

最終更新日:Apr 08, 2025

広告のクリック予測、ゲームユーザーの支払いまたは解約予測、自動メール分類などのデータマイニングシナリオでは、既存データに基づいて分類モデルをトレーニングし、後続の動作を予測する必要があります。AnalyticDB for MySQL Spark と CatBoost を使用して、データの分類と予測のための勾配ブースティング決定木(GBDT)モデルをトレーニングできます。

前提条件

手順

ステップ 1:Maven 依存関係を準備して OSS にアップロードする

  1. 次のいずれかの方法を使用して、Maven 依存関係を取得します。

    • Maven 依存関係の JAR パッケージをダウンロードする。

    • pom.xml ファイルの Maven 依存関係を IntelliJ IDEA で構成します。サンプルコード:

      <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
               xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
          <modelVersion>4.0.0</modelVersion>
      
          <groupId>com.aliyun.adb.spark</groupId>
          <artifactId>CatBoostDemo</artifactId>
          <version>1.0</version>
          <packaging>jar</packaging>
      
          <properties>
              <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
          </properties>
      
          <dependencies>
              <dependency>
                  <groupId>ai.catboost</groupId>
                  <artifactId>catboost-spark_3.5_2.12</artifactId>
                  <version>1.2.7</version>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-launcher_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-core_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-sql_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>com.aliyun.oss</groupId>
                  <artifactId>aliyun-sdk-oss</artifactId>
                  <version>3.16.2</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-mllib_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
          </dependencies>
      
          <build>
              <plugins>
                  <plugin>
                      <groupId>net.alchim31.maven</groupId>
                      <artifactId>scala-maven-plugin</artifactId>
                      <version>4.4.0</version>
                      <executions>
                          <execution>
                              <goals>
                                  <goal>compile</goal>
                                  <goal>testCompile</goal>
                              </goals>
                          </execution>
                      </executions>
                  </plugin>
                  <plugin>
                      <groupId>org.apache.maven.plugins</groupId>
                      <artifactId>maven-shade-plugin</artifactId>
                      <version>3.1.1</version>
                      <configuration>
                          <createDependencyReducedPom>false</createDependencyReducedPom>
                      </configuration>
                      <executions>
                          <execution>
                              <phase>package</phase>
                              <goals>
                                  <goal>shade</goal>
                              </goals>
                          </execution>
                      </executions>
                  </plugin>
              </plugins>
          </build>
      </project>
  2. IntelliJ IDEA の pom.xml ファイルで Maven 依存関係を構成した場合は、pom.xml ファイル(mvn clean package -DskipTests コマンドを実行して依存関係をパッケージ化します。リンクから Maven 依存関係のパッケージをダウンロードした場合は、この手順をスキップします。依存関係。Maven 依存関係のパッケージをリンクからダウンロードした場合は、このステップをスキップします。

  3. Maven 依存関係を OSS にアップロードする。

ステップ 2:プログラムを作成して OSS にアップロードするOSS へ

  1. プログラムを作成する。

    Scala

    package com.aliyun.adb
    
    import ai.catboost.spark.{CatBoostClassificationModel, CatBoostClassifier, Pool}
    import org.apache.spark.ml.linalg.{SQLDataTypes, Vectors}
    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    import org.apache.spark.sql.{Row, SparkSession}
    
    object CatBoostDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession
          .builder()
          .appName("CatBoost Example") // CatBoost の例
          .getOrCreate()
    
        val srcDataSchema = Seq(
          StructField("features", SQLDataTypes.VectorType),
          StructField("label", StringType)
        )
    
        val trainData = Seq(
          Row(Vectors.dense(0.1, 0.2, 0.11), "1"),
          Row(Vectors.dense(0.97, 0.82, 0.33), "2"),
          Row(Vectors.dense(0.13, 0.22, 0.23), "1"),
          Row(Vectors.dense(0.8, 0.62, 0.0), "0")
        )
    
        val trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema))
        val trainPool = new Pool(trainDf)
    
        val evalData = Seq(
          Row(Vectors.dense(0.22, 0.33, 0.9), "2"),
          Row(Vectors.dense(0.11, 0.1, 0.21), "0"),
          Row(Vectors.dense(0.77, 0.0, 0.0), "1")
        )
    
        val evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))
        val evalPool = new Pool(evalDf)
    
        val classifier = new CatBoostClassifier
    
        // モデルをトレーニングする
        val model = classifier.fit(trainPool, Array[Pool](evalPool))
    
        // モデルを適用する
        val predictions = model.transform(evalPool.data)
        println("predictions") // 予測
        predictions.show()
    
        // モデルを CatBoost ネイティブ形式のローカルファイルとして保存する
        val savedNativeModelPath = "./multiclass_model.cbm"
        model.saveNativeModel(savedNativeModelPath)
    
        // モデルを CatBoost ネイティブ形式のローカルファイルとしてロードする
        val loadedNativeModel = CatBoostClassificationModel.loadNativeModel(savedNativeModelPath)
    
        val predictionsFromLoadedNativeModel = loadedNativeModel.transform(evalPool.data)
        println("predictionsFromLoadedNativeModel") // ロードされたネイティブモデルからの予測
        predictionsFromLoadedNativeModel.show()
        System.exit(0)
      }
    }
    

    Python

    重要

    CatBoost は、SparkSession の初期化後のコンテキスト環境に依存します。この場合、SparkSession オブジェクトを作成した後にのみ、import catboost_spark 文を実行できます。事前に文を実行すると、Maven 依存関係のロードに失敗する可能性があります。

    from pyspark.sql import Row,SparkSession
    from pyspark.ml.linalg import Vectors, VectorUDT
    from pyspark.sql.types import *
    
    
    spark = SparkSession.builder.getOrCreate()
    
    import catboost_spark
    
    srcDataSchema = [
        StructField("features", VectorUDT()),
        StructField("label", StringType())
    ]
    
    trainData = [
        Row(Vectors.dense(0.1, 0.2, 0.11), "1"),
        Row(Vectors.dense(0.97, 0.82, 0.33), "2"),
        Row(Vectors.dense(0.13, 0.22, 0.23), "1"),
        Row(Vectors.dense(0.8, 0.62, 0.0), "0")
    ]
    
    trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema))
    trainPool = catboost_spark.Pool(trainDf)
    
    evalData = [
        Row(Vectors.dense(0.22, 0.33, 0.9), "2"),
        Row(Vectors.dense(0.11, 0.1, 0.21), "0"),
        Row(Vectors.dense(0.77, 0.0, 0.0), "1")
    ]
    
    evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))
    evalPool = catboost_spark.Pool(evalDf)
    
    classifier = catboost_spark.CatBoostClassifier()
    
    # モデルをトレーニングする
    model = classifier.fit(trainPool, evalDatasets=[evalPool])
    
    # モデルを適用する
    predictions = model.transform(evalPool.data)
    predictions.show()
    
    # モデルを CatBoost ネイティブ形式のローカルファイルとして保存する
    savedNativeModelPath = './multiclass_model.cbm'
    model.saveNativeModel(savedNativeModelPath)
    # モデルを CatBoost ネイティブ形式のローカルファイルとしてロードする
    
    loadedNativeModel = catboost_spark.CatBoostClassificationModel.loadNativeModel(savedNativeModelPath)
    
    predictionsFromLoadedNativeModel = loadedNativeModel.transform(evalPool.data)
    predictionsFromLoadedNativeModel.show()
    
  2. Scala プログラムを作成した場合は、プログラムを JAR ファイルにパッケージ化します。Python プログラムを作成した場合は、この手順をスキップします。

  3. JAR パッケージまたは .py ファイルを OSS にアップロードする。

ステップ 3:Spark ジョブを送信する

  1. AnalyticDB for MySQL コンソールにログインします。コンソールの左上隅で、リージョンを選択します。左側のナビゲーションウィンドウで、クラスターリストをクリックします。クラスターリスト ページで、エディションタブをクリックします。管理するクラスタを見つけて、クラスタ ID をクリックします。

  2. 左側のナビゲーションウィンドウで、[ジョブ開発] > [Spark JAR 開発] を選択します。

  3. ジョブリソースグループと Spark ジョブタイプを選択します。この例では、[バッチ] タイプが使用されています。

  4. ステップ 2 で作成したプログラムに基づいて、コードエディタに次のコードを入力し、[今すぐ実行] をクリックします。

    Scala プログラム

     {
      "name": "CatBoostDemo",
      "file": "oss://testBucketName/original-LightgbmDemo-1.0.jar",
      "jars": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar",
      "ClassName":"com.aliyun.adb.CatBoostDemo",
        "conf": {
            "spark.driver.resourceSpec": "large",
            "spark.executor.instances": 2,
            "spark.executor.resourceSpec": "medium",
            "spark.executor.memoryOverhead": "4096",
            "spark.task.cpus": 2,
            "spark.adb.version": "3.5"
        }

    Python プログラム

     {
      "name": "CatBoostDemo",
      "file": "oss://testBucketName/GBDT/lightgbm_spark_20241227.py",
      "jars": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar",
      "pyFiles": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar",
        "conf": {
            "spark.driver.resourceSpec": "large",
            "spark.executor.instances": 2,
            "spark.executor.resourceSpec": "medium",
            "spark.executor.memoryOverhead": "4096",
            "spark.task.cpus": 2,
            "spark.adb.version": "3.5"
        }
    説明

    このトピックの Python プログラムは、JAR パッケージに含まれるメソッドを呼び出す必要があります。この場合、Spark JAR ジョブコードに jars パラメータを追加して、ステップ 1 で取得した Maven 依存関係の OSS パスを指定する必要があります。

    次の表にパラメータを示します。

    パラメータ

    必須

    説明

    name

    いいえ

    Spark ジョブの名前。

    file

    はい

    • Scala:ステップ 2 で作成した Scala プログラムの OSS パス。

    • Python:ステップ 2 で作成した Python プログラムの OSS パス。

    jars

    はい

    ステップ 1 で準備した Maven 依存関係の OSS パス。

    ClassName

    特定の条件を満たす場合、はい

    Scala プログラムのエントリクラスの名前。このパラメータは、Scala プログラムを送信する場合に必須です。

    pyFiles

    特定の条件を満たす場合、はい

    ステップ 1 で準備した Maven 依存関係の OSS パス。このパラメータは、Python プログラムを送信する場合に必須です。

    spark.adb.version

    はい

    Spark のバージョン。 3.5 に設定する必要があります。

    spark.task.cpus

    はい

    Spark エグゼキューターリソース仕様に対応する CPU コア数。このパラメータは、各 Spark エグゼキューターで 1 つの CatBoost ワーカープロセスのみが実行されるようにします。

    たとえば、spark.executor.resourceSpec パラメータが medium に設定されている場合、このパラメータを 2 に設定する必要があります。

    その他の conf パラメータ

    いいえ

    Spark アプリケーションに必要な構成パラメータ。Apache Spark のパラメータと似ています。パラメータは key: value 形式である必要があります。複数のパラメータはカンマ(,)で区切ります。Apache Spark のパラメータとは異なる構成パラメータ、または AnalyticDB for MySQL 固有の構成パラメータについては、Spark アプリケーションの構成パラメータをご参照ください。

  5. (オプション) [アプリケーション] タブで、Spark ジョブを見つけ、[ログ][アクション] 列でクリックして、ジョブの実行結果を表示します。