AnalyticDB: Use LightGBM to train GBDT models

Last Updated: Apr 08, 2025

In data mining scenarios such as advertisement click prediction, prediction of payment or churn for game users, and automatic email classification, you need to train classification models on historical data to predict subsequent behavior. You can use AnalyticDB for MySQL Spark together with LightGBM to train Gradient Boosting Decision Tree (GBDT) models for data classification and prediction. Compared with XGBoost and CatBoost running on a single machine, LightGBM on Spark leverages distributed computing to efficiently process TB-scale data. This topic describes how to use LightGBM to train GBDT models for effective data classification and prediction.

Prerequisites

  • An AnalyticDB for MySQL cluster is created, and a job resource group is created for the cluster.

  • An Object Storage Service (OSS) bucket is created to store the Maven dependencies, the application, and the trained model.

Procedure

Step 1: Prepare and upload Maven dependencies to OSS

  1. Use one of the following methods to obtain Maven dependencies:

    • Download the JAR package of Maven dependencies.

    • Configure Maven dependencies in the pom.xml file in IntelliJ IDEA. Sample code:

      <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
               xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
          <modelVersion>4.0.0</modelVersion>
      
          <groupId>com.aliyun.adb.spark</groupId>
          <artifactId>LightgbmDemo</artifactId>
          <version>1.0</version>
          <packaging>jar</packaging>
      
          <properties>
              <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
          </properties>
      
          <dependencies>
              <!-- https://mvnrepository.com/artifact/com.microsoft.azure/synapseml -->
              <dependency>
                  <groupId>com.microsoft.azure</groupId>
                  <artifactId>synapseml_2.12</artifactId>
                  <version>1.0.8</version>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-launcher_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-core_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-sql_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-mllib_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
          </dependencies>
      
          <build>
              <plugins>
                  <plugin>
                      <groupId>net.alchim31.maven</groupId>
                      <artifactId>scala-maven-plugin</artifactId>
                      <version>4.4.0</version>
                      <executions>
                          <execution>
                              <goals>
                                  <goal>compile</goal>
                                  <goal>testCompile</goal>
                              </goals>
                          </execution>
                      </executions>
                  </plugin>
                  <plugin>
                      <groupId>org.apache.maven.plugins</groupId>
                      <artifactId>maven-shade-plugin</artifactId>
                      <version>3.1.1</version>
                      <configuration>
                          <createDependencyReducedPom>false</createDependencyReducedPom>
                      </configuration>
                      <executions>
                          <execution>
                              <phase>package</phase>
                              <goals>
                                  <goal>shade</goal>
                              </goals>
                          </execution>
                      </executions>
                  </plugin>
              </plugins>
          </build>
      </project>
      
  2. If you configure Maven dependencies in the pom.xml file in IntelliJ IDEA, run the mvn clean package -DskipTests command to package the dependencies. The maven-shade-plugin generates a shaded JAR file (LightgbmDemo-1.0.jar) that bundles the SynapseML dependency and retains the unshaded original-LightgbmDemo-1.0.jar. If you download the JAR package of Maven dependencies, skip this step.

  3. Upload the Maven dependencies to OSS. You can use the OSS console, ossutil, or an OSS SDK, as shown in the following sketch.
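    The following Python sketch shows one possible way to upload the JAR file by using the OSS Python SDK (oss2). The AccessKey pair, endpoint, bucket name, and file paths are placeholders that you must replace with your own values.

    # Minimal sketch: upload the dependency JAR to OSS with the oss2 SDK.
    # All values below (AccessKey pair, endpoint, bucket name, and paths) are placeholders.
    import oss2

    # Credentials of an account or RAM user that is allowed to write to the bucket.
    auth = oss2.Auth("<yourAccessKeyId>", "<yourAccessKeySecret>")
    # Use the OSS endpoint of the region where your bucket resides.
    bucket = oss2.Bucket(auth, "https://oss-cn-hangzhou.aliyuncs.com", "testBucketName")

    # Upload the shaded JAR file that contains the SynapseML dependency.
    bucket.put_object_from_file("LightgbmDemo-1.0.jar", "target/LightgbmDemo-1.0.jar")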

Step 2: Write and upload an application to OSS

  1. Write an application.

    Scala application

    package com.aliyun.adb
    
    import com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier
    import org.apache.spark.sql.{Dataset, Row, SparkSession}
    
    object GBDTdemo {
      def main(args: Array[String]): Unit = {
        // Register a Spark session.
        val spark = SparkSession
          .builder()
          .appName("lightgbm Example")
          .getOrCreate()
    
        // Read training data.
        val dataPath = s"${sys.env.getOrElse("SPARK_HOME", "/opt/spark")}/data/mllib/sample_multiclass_classification_data.txt"
        val data: Dataset[Row] = spark.read.format("libsvm").load(dataPath)
        data.show()
    
        val classifier = new LightGBMClassifier()
        classifier.setLabelCol("label")
        classifier.setFeaturesCol("features")
    
        // Optionally specify categorical feature columns. For example, if the 0th and 1st columns were categorical,
        // you could configure them as shown below. The sample data does not contain categorical feature columns,
        // so these lines are commented out in this example.
        // classifier.setCategoricalSlotIndexes(Array(0, 1))
        // classifier.setCategoricalSlotNames(Array.apply("enumCol1", "enumCol2"))
    
        // Split the data into training and validation sets.
        val Array(trainData, validData) = data.randomSplit(Array(0.6, 0.4))
        val model = classifier.fit(trainData)
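        // Save the trained model to OSS in the native LightGBM format. Replace bucket_name with the name of your OSS bucket.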
        model.saveNativeModel("oss://bucket_name/model", overwrite = true)
    
        // Run predictions on the validation data.
        val predictions = model.transform(validData)
        // Show the prediction results.
        predictions.show()
        // Print accuracy.
        val accuracy = predictions.filter("label == prediction").count().toDouble / predictions.count()
        println(s"Accuracy: $accuracy")
        System.exit(0)
      }
    }
    

    Python application

    import os
    
    from pyspark.sql import SparkSession
    
    if __name__ == '__main__':
        # Init spark
        spark = SparkSession.builder.appName("lightgbm_spark_train").getOrCreate()
        # Read data
        f = os.environ.get("SPARK_HOME") + "/data/mllib/sample_multiclass_classification_data.txt"
        df = spark.read.format("libsvm").load(f)
        # Split data
        train, test = df.randomSplit([0.8, 0.2])
        # Train model
        from synapse.ml.lightgbm import LightGBMClassifier
        model = LightGBMClassifier(learningRate=0.3,
                                   numIterations=20,
                                   numLeaves=4).fit(train)
        # Predict
        prediction = model.transform(test)
        prediction.show()
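        # Optionally evaluate accuracy and save the trained model, similar to the Scala example.
        # This is a sketch: replace oss://bucket_name/model with an OSS directory that you can write to.
        from pyspark.ml.evaluation import MulticlassClassificationEvaluator
        accuracy = MulticlassClassificationEvaluator(metricName="accuracy").evaluate(prediction)
        print("Accuracy: {}".format(accuracy))
        model.saveNativeModel("oss://bucket_name/model")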
        # Stop spark
        spark.stop()
        
  2. If you write a Scala application, package the application into a JAR file. If you write a Python application, skip this step.

  3. Upload the JAR package or the .py file to OSS.

Step 3: Submit a Spark job

  1. Log on to the AnalyticDB for MySQL console. In the upper-left corner of the console, select a region. In the left-side navigation pane, click Clusters. On the Clusters page, click an edition tab. Find the cluster that you want to manage and click the cluster ID.

  2. In the left-side navigation pane, choose Job Development > Spark JAR Development.

  3. Select a job resource group and the Spark job type. In this example, the Batch type is used.

  4. In the code editor, enter the following job configuration based on the application that you wrote in Step 2, and then click Run Now.

    Scala application

    {
        "name": "LightdbmDemo",
        "file": "oss://testBucketName/original-LightgbmDemo-1.0.jar",
        "jars": "oss://testBucketName/LightgbmDemo-1.0.jar",
        "ClassName": "com.aliyun.adb.GBDTdemo",
        "conf": {
            "spark.driver.resourceSpec": "large",
            "spark.executor.instances": 2,
            "spark.executor.resourceSpec": "medium",
            "spark.adb.version": "3.5"
        }
    }

    Python application

    {
      "name": "CatBoostDemo",
      "file": "oss://testBucketName/GBDT/lightgbm_spark_20241227.py",
      "jars": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar",
      "pyFiles": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar",
      "conf": {
        "spark.driver.resourceSpec": "large",
        "spark.executor.instances": 2,
        "spark.executor.resourceSpec": "medium",
        "spark.adb.version": "3.5"
      }
    }
    Note

    The Python application in this topic calls methods that are contained in the JAR package. Therefore, you must add the jars parameter to the Spark job code to specify the OSS path of the Maven dependencies that you obtained in Step 1.

    The following table describes the parameters.

    Parameter | Required | Description
    --------- | -------- | -----------
    name | No | The name of the Spark application.
    file | Yes | Scala: the OSS path of the Scala application written in Step 2. Python: the OSS path of the Python application written in Step 2.
    jars | Yes | The OSS path of the Maven dependencies prepared in Step 1.
    ClassName | Required for Scala applications | The name of the entry class of the Scala application.
    pyFiles | Required for Python applications | The OSS path of the Maven dependencies prepared in Step 1.
    spark.adb.version | Yes | The Spark version. Set the value to 3.5.
    Other conf parameters | No | The configuration parameters required for the Spark application, which are similar to those of Apache Spark. The parameters must be in the key: value format. Separate multiple parameters with commas (,). For information about the configuration parameters that differ from those of Apache Spark or that are specific to AnalyticDB for MySQL, see Spark application configuration parameters.

  5. (Optional) On the Applications tab, find the Spark job and click Logs in the Actions column to view the running results of the job.