AnalyticDB: Use CatBoost to train GBDT models

Last Updated: Apr 08, 2025

In data mining scenarios such as advertisement click prediction, prediction of in-game payment or player churn, and automatic email classification, you need to train classification models on historical data to predict future behavior. You can use AnalyticDB for MySQL Spark together with CatBoost to train Gradient Boosting Decision Tree (GBDT) models for data classification and prediction.

Prerequisites

Procedure

Step 1: Prepare and upload Maven dependencies to OSS

  1. Use one of the following methods to obtain Maven dependencies:

    • Download the JAR package of Maven dependencies.

    • Configure Maven dependencies in the pom.xml file in IntelliJ IDEA. Sample code:

      <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
               xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
          <modelVersion>4.0.0</modelVersion>
      
          <groupId>com.aliyun.adb.spark</groupId>
          <artifactId>CatBoostDemo</artifactId>
          <version>1.0</version>
          <packaging>jar</packaging>
      
          <properties>
              <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
          </properties>
      
          <dependencies>
              <dependency>
                  <groupId>ai.catboost</groupId>
                  <artifactId>catboost-spark_3.5_2.12</artifactId>
                  <version>1.2.7</version>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-launcher_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-core_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-sql_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>com.aliyun.oss</groupId>
                  <artifactId>aliyun-sdk-oss</artifactId>
                  <version>3.16.2</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-mllib_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
          </dependencies>
      
          <build>
              <plugins>
                  <plugin>
                      <groupId>net.alchim31.maven</groupId>
                      <artifactId>scala-maven-plugin</artifactId>
                      <version>4.4.0</version>
                      <executions>
                          <execution>
                              <goals>
                                  <goal>compile</goal>
                                  <goal>testCompile</goal>
                              </goals>
                          </execution>
                      </executions>
                  </plugin>
                  <plugin>
                      <groupId>org.apache.maven.plugins</groupId>
                      <artifactId>maven-shade-plugin</artifactId>
                      <version>3.1.1</version>
                      <configuration>
                          <createDependencyReducedPom>false</createDependencyReducedPom>
                      </configuration>
                      <executions>
                          <execution>
                              <phase>package</phase>
                              <goals>
                                  <goal>shade</goal>
                              </goals>
                          </execution>
                      </executions>
                  </plugin>
              </plugins>
          </build>
      </project>
  2. If you configured the Maven dependencies in the pom.xml file in IntelliJ IDEA, run the mvn clean package -DskipTests command to package the dependencies into a JAR file. If you downloaded the JAR package of Maven dependencies directly, skip this step.

  3. Upload the Maven dependencies to OSS.
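
    For example, you can use the ossutil command-line tool to upload the package. The following command is a minimal sketch that assumes ossutil is installed and configured with valid credentials. The local path corresponds to the JAR package built from the sample pom.xml file, and the OSS path is a placeholder that must match the jars parameter that you specify in Step 3.

      ossutil cp target/CatBoostDemo-1.0.jar oss://testBucketName/GBDT/CatBoostDemo-1.0.jar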

Step 2: Write and upload an application to OSS

  1. Write an application.

    Scala application

    package com.aliyun.adb
    
    import ai.catboost.spark.{CatBoostClassificationModel, CatBoostClassifier, Pool}
    import org.apache.spark.ml.linalg.{SQLDataTypes, Vectors}
    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    import org.apache.spark.sql.{Row, SparkSession}
    
    object CatBoostDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession
          .builder()
          .appName("CatBoost Example")
          .getOrCreate()
    
        val srcDataSchema = Seq(
          StructField("features", SQLDataTypes.VectorType),
          StructField("label", StringType)
        )
    
        val trainData = Seq(
          Row(Vectors.dense(0.1, 0.2, 0.11), "1"),
          Row(Vectors.dense(0.97, 0.82, 0.33), "2"),
          Row(Vectors.dense(0.13, 0.22, 0.23), "1"),
          Row(Vectors.dense(0.8, 0.62, 0.0), "0")
        )
    
        val trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema))
        val trainPool = new Pool(trainDf)
    
        val evalData = Seq(
          Row(Vectors.dense(0.22, 0.33, 0.9), "2"),
          Row(Vectors.dense(0.11, 0.1, 0.21), "0"),
          Row(Vectors.dense(0.77, 0.0, 0.0), "1")
        )
    
        val evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))
        val evalPool = new Pool(evalDf)
    
        val classifier = new CatBoostClassifier
    
        // train a model
        val model = classifier.fit(trainPool, Array[Pool](evalPool))
    
        // apply the model
        val predictions = model.transform(evalPool.data)
        println("predictions")
        predictions.show()
    
        // save the model as a local file in CatBoost native format
        val savedNativeModelPath = "./multiclass_model.cbm"
        model.saveNativeModel(savedNativeModelPath)
    
        // load the model as a local file in CatBoost native format
        val loadedNativeModel = CatBoostClassificationModel.loadNativeModel(savedNativeModelPath)
    
        val predictionsFromLoadedNativeModel = loadedNativeModel.transform(evalPool.data)
        println("predictionsFromLoadedNativeModel")
        predictionsFromLoadedNativeModel.show()
        System.exit(0)
      }
    }
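
    The sample code saves the trained model to a local path on the Spark driver, where it is typically not persisted after a serverless Spark job ends. The following snippet is a minimal sketch of how you might additionally upload the saved .cbm file to OSS by using the aliyun-sdk-oss dependency that is declared in the sample pom.xml file. The endpoint, credentials, bucket name, and object key are hypothetical placeholders that you must replace with your own values.

    // Minimal sketch: upload the locally saved CatBoost native model file to OSS.
    // The endpoint, credentials, bucket name, and object key below are placeholders.
    import com.aliyun.oss.OSSClientBuilder

    val endpoint = "oss-cn-hangzhou.aliyuncs.com"   // placeholder OSS endpoint
    val accessKeyId = "<yourAccessKeyId>"           // placeholder credential
    val accessKeySecret = "<yourAccessKeySecret>"   // placeholder credential
    val ossClient = new OSSClientBuilder().build(endpoint, accessKeyId, accessKeySecret)
    try {
      // Upload the file that was written by model.saveNativeModel(savedNativeModelPath).
      ossClient.putObject("testBucketName", "GBDT/multiclass_model.cbm", new java.io.File(savedNativeModelPath))
    } finally {
      ossClient.shutdown()
    }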
    

    Python application

    Important

    CatBoost depends on the runtime context that becomes available only after the SparkSession is initialized. Therefore, you can execute the import catboost_spark statement only after you create a SparkSession object. If you execute the statement earlier, the Maven dependencies may fail to load.

    from pyspark.sql import Row, SparkSession
    from pyspark.ml.linalg import Vectors, VectorUDT
    from pyspark.sql.types import *
    
    
    spark = SparkSession.builder.getOrCreate()
    
    import catboost_spark
    
    srcDataSchema = [
        StructField("features", VectorUDT()),
        StructField("label", StringType())
    ]
    
    trainData = [
        Row(Vectors.dense(0.1, 0.2, 0.11), "1"),
        Row(Vectors.dense(0.97, 0.82, 0.33), "2"),
        Row(Vectors.dense(0.13, 0.22, 0.23), "1"),
        Row(Vectors.dense(0.8, 0.62, 0.0), "0")
    ]
    
    trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema))
    trainPool = catboost_spark.Pool(trainDf)
    
    evalData = [
        Row(Vectors.dense(0.22, 0.33, 0.9), "2"),
        Row(Vectors.dense(0.11, 0.1, 0.21), "0"),
        Row(Vectors.dense(0.77, 0.0, 0.0), "1")
    ]
    
    evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))
    evalPool = catboost_spark.Pool(evalDf)
    
    classifier = catboost_spark.CatBoostClassifier()
    
    # train a model
    model = classifier.fit(trainPool, evalDatasets=[evalPool])
    
    # apply the model
    predictions = model.transform(evalPool.data)
    predictions.show()
    
    # save the model as a local file in CatBoost native format
    savedNativeModelPath = './multiclass_model.cbm'
    model.saveNativeModel(savedNativeModelPath)
    # load the model as a local file in CatBoost native format
    
    loadedNativeModel = catboost_spark.CatBoostClassificationModel.loadNativeModel(savedNativeModelPath)
    
    predictionsFromLoadedNativeModel = loadedNativeModel.transform(evalPool.data)
    predictionsFromLoadedNativeModel.show()
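
    After training, the model can be applied to new data in the same way that the evaluation pool is scored above. The following snippet is a minimal sketch that reuses the srcDataSchema defined earlier; the feature values are hypothetical, and the label column contains a placeholder value because the label is not used at prediction time.

    # Minimal sketch: score new, previously unseen samples with the trained model.
    # The feature values are hypothetical; the label value is only a placeholder.
    newData = [
        Row(Vectors.dense(0.5, 0.4, 0.3), "0"),
        Row(Vectors.dense(0.05, 0.9, 0.6), "0")
    ]
    newDf = spark.createDataFrame(spark.sparkContext.parallelize(newData), StructType(srcDataSchema))
    newPredictions = model.transform(newDf)
    newPredictions.show()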
    
  2. If you write a Scala application, package the application into a JAR file. If you write a Python application, skip this step.

  3. Upload the JAR package or the .py file to OSS.
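
    For example, a Python application can be uploaded with the ossutil command-line tool in the same way as the Maven dependencies in Step 1. The file name and OSS path below are placeholders that must match the file parameter that you specify in Step 3.

      ossutil cp catboost_spark_demo.py oss://testBucketName/GBDT/catboost_spark_demo.py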

Step 3: Submit a Spark job

  1. Log on to the AnalyticDB for MySQL console. In the upper-left corner of the console, select a region. In the left-side navigation pane, click Clusters. On the Clusters page, click an edition tab. Find the cluster that you want to manage and click the cluster ID.

  2. In the left-side navigation pane, choose Job Development > Spark JAR Development.

  3. Select a job resource group and the Spark job type. In this example, the Batch type is used.

  4. In the code editor, enter one of the following job configurations based on the type of application that you wrote in Step 2, and then click Run Now.

    Scala application

     {
       "name": "CatBoostDemo",
       "file": "oss://testBucketName/original-CatBoostDemo-1.0.jar",
       "jars": "oss://testBucketName/GBDT/CatBoostDemo-1.0.jar",
       "ClassName": "com.aliyun.adb.CatBoostDemo",
       "conf": {
         "spark.driver.resourceSpec": "large",
         "spark.executor.instances": 2,
         "spark.executor.resourceSpec": "medium",
         "spark.executor.memoryOverhead": "4096",
         "spark.task.cpus": 2,
         "spark.adb.version": "3.5"
       }
     }

    Python application

     {
       "name": "CatBoostDemo",
       "file": "oss://testBucketName/GBDT/catboost_spark_demo.py",
       "jars": "oss://testBucketName/GBDT/CatBoostDemo-1.0.jar",
       "pyFiles": "oss://testBucketName/GBDT/CatBoostDemo-1.0.jar",
       "conf": {
         "spark.driver.resourceSpec": "large",
         "spark.executor.instances": 2,
         "spark.executor.resourceSpec": "medium",
         "spark.executor.memoryOverhead": "4096",
         "spark.task.cpus": 2,
         "spark.adb.version": "3.5"
       }
     }

    Note

    The Python application in this topic calls methods that are contained in the JAR package. Therefore, you must add the jars parameter to the Spark job configuration to specify the OSS path of the Maven dependencies obtained in Step 1.

    The following list describes the parameters.

    • name (optional): the name of the Spark application.

    • file (required): the OSS path of the Scala or Python application written in Step 2.

    • jars (required): the OSS path of the Maven dependencies prepared in Step 1.

    • ClassName (required for Scala applications): the name of the entry class of the Scala application.

    • pyFiles (required for Python applications): the OSS path of the Maven dependencies prepared in Step 1.

    • spark.adb.version (required): the Spark version, which must be set to 3.5.

    • spark.task.cpus (required): the number of CPU cores that corresponds to the Spark executor resource specifications. This setting ensures that only one CatBoost worker process runs on each Spark executor. For example, if the spark.executor.resourceSpec parameter is set to medium, set this parameter to 2.

    • Other conf parameters (optional): the configuration parameters that are required for the Spark application, which are similar to those of Apache Spark. Specify the parameters in the key: value format and separate multiple parameters with commas (,). For information about the configuration parameters that differ from those of Apache Spark or that are specific to AnalyticDB for MySQL, see Spark application configuration parameters.

  5. (Optional) On the Applications tab, find the Spark job and click Logs in the Actions column to view the running results of the job.