AnalyticDB: Customer churn prediction

Last Updated: Nov 15, 2024

This topic describes how to use the pgml extension in AnalyticDB for PostgreSQL V7.0 to build batch and real-time prediction workflows, using a customer churn prediction task in an e-commerce scenario as an example.

Prerequisites

  • An AnalyticDB for PostgreSQL V7.0 instance whose minor version is V7.1.1.0 or later is created.

  • The instance is in elastic storage mode.

  • The pgml extension is installed on the instance.

    Note

    The pgml extension does not support GUI-based installation. To install or uninstall the pgml extension, submit a ticket.
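
    After the pgml extension is installed, you can confirm that it is available by querying the pg_extension catalog. For example:

    -- Returns one row if the pgml extension is installed.
    SELECT extname, extversion FROM pg_extension WHERE extname = 'pgml';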

Background information

The pgml extension is designed to bring models closer to data. The in-database AI/ML feature loads models into the backend processes of PostgreSQL and uses user-defined functions (UDFs) to perform training, fine-tuning, and inference. The trained models are stored in ordinary heap tables, so you do not need to design separate high-availability or high-reliability solutions, which keeps O&M simple. Because storage and computing are integrated, the pgml extension reduces data transfer overhead and completes model training and service deployment efficiently. The following figure shows the AI/ML training and inference process.

[Figure: AI/ML training and inference process]
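
Because the models and their metadata are stored in ordinary tables, you can inspect them with plain SQL. For example, the pgml.projects and pgml.deployments tables that are queried later in this topic can be examined directly:

-- Models and their metadata live in regular tables in the pgml schema.
SELECT name, task FROM pgml.projects;
SELECT project_id, model_id FROM pgml.deployments;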

Data presentation and analysis

Sample dataset

In this example, a dataset that contains historical behavior data of customers with customer churn labels is used. For more information, see Ecommerce Customer Churn Analysis and Prediction. Customer churn analysis and prediction can help enterprises develop more effective policies to improve the customer retention rate. The following table describes the fields included in the dataset.

Field | Description
------+------------
CustomerID | The unique customer ID.
Churn | The customer churn label.
Tenure | The usage duration of the customer.
PreferredLoginDevice | The preferred logon device of the customer.
CityTier | The tier of the city where the customer resides.
WarehouseToHome | The distance from the warehouse to the home of the customer.
PreferredPaymentMode | The preferred payment method of the customer.
Gender | The gender of the customer.
HourSpendOnApp | The number of hours spent by the customer on mobile applications or websites.
NumberOfDeviceRegistered | The total number of devices that are registered to the customer.
PreferedOrderCat | The preferred order category of the customer in the last month.
SatisfactionScore | The satisfaction score of the customer for services.
MaritalStatus | The marital status of the customer.
NumberOfAddress | The total number of addresses that are added by the customer.
Complain | Specifies whether the customer raised complaints in the last month.
OrderAmountHikeFromlastYear | The percentage increase in the order amount of the customer compared with the previous year.
CouponUsed | The total number of coupons that were used by the customer in the last month.
OrderCount | The total number of customer orders in the last month.
DaySinceLastOrder | The number of days since the most recent order of the customer.
CashbackAmount | The cashback amount of the customer in the last month.

Data import

  1. Create a table.

    CREATE TABLE raw_data_table (
        CustomerID INTEGER,
        Churn INTEGER,
        Tenure FLOAT,
        PreferredLoginDevice TEXT,
        CityTier INTEGER,
        WarehouseToHome FLOAT,
        PreferredPaymentMode TEXT,
        Gender TEXT,
        HourSpendOnApp FLOAT,
        NumberOfDeviceRegistered INTEGER,
        PreferedOrderCat TEXT,
        SatisfactionScore INTEGER,
        MaritalStatus TEXT,
        NumberOfAddress INTEGER,
        Complain INTEGER,
        OrderAmountHikeFromlastYear FLOAT,
        CouponUsed FLOAT,
        OrderCount FLOAT,
        DaySinceLastOrder FLOAT,
        CashbackAmount FLOAT
    );
  2. Download the dataset and use the COPY statement to import the CSV dataset into the table. Replace /path/to/dataset.csv with the actual path of the dataset file.

    COPY raw_data_table FROM '/path/to/dataset.csv' DELIMITER ',' CSV HEADER;
    Note

    We recommend that you use the psql tool to import data. If you use another client or SDK to import data, you can use the COPY or INSERT statement. For more information, see the corresponding documentation.
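
    For example, a minimal client-side import that uses the psql \copy meta-command, which reads the file from the machine where psql runs, might look like the following. The file path is illustrative.

    \copy raw_data_table FROM 'dataset.csv' DELIMITER ',' CSV HEADER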

Data analysis

Check the distribution of null values in the dataset.

DO $$
DECLARE
    r RECORD;
    SQL TEXT := '';
BEGIN
    FOR r IN 
        SELECT column_name 
        FROM information_schema.columns 
        WHERE table_name = 'raw_data_table'
    LOOP
        SQL := SQL || 
            'SELECT ''' || r.column_name || ''' AS column_name, COUNT(*) FILTER (WHERE ' || r.column_name || ' IS NULL) AS null_count FROM raw_data_table UNION ALL ';
    END LOOP;

    SQL := LEFT(SQL, length(SQL) - 11); -- Strip the trailing ' UNION ALL '.
    
    FOR r IN EXECUTE SQL LOOP
        RAISE NOTICE 'Column: %, Null Count: %', r.column_name, r.null_count;
    END LOOP;
END $$;

Sample result:

NOTICE:  Column: customerid, Null Count: 0
NOTICE:  Column: churn, Null Count: 0
NOTICE:  Column: tenure, Null Count: 264
NOTICE:  Column: preferredlogindevice, Null Count: 0
NOTICE:  Column: citytier, Null Count: 0
NOTICE:  Column: warehousetohome, Null Count: 251
NOTICE:  Column: preferredpaymentmode, Null Count: 0
NOTICE:  Column: gender, Null Count: 0
NOTICE:  Column: hourspendonapp, Null Count: 255
NOTICE:  Column: numberofdeviceregistered, Null Count: 0
NOTICE:  Column: preferedordercat, Null Count: 0
NOTICE:  Column: satisfactionscore, Null Count: 0
NOTICE:  Column: maritalstatus, Null Count: 0
NOTICE:  Column: numberofaddress, Null Count: 0
NOTICE:  Column: complain, Null Count: 0
NOTICE:  Column: orderamounthikefromlastyear, Null Count: 265
NOTICE:  Column: couponused, Null Count: 256
NOTICE:  Column: ordercount, Null Count: 258
NOTICE:  Column: daysincelastorder, Null Count: 307
NOTICE:  Column: cashbackamount, Null Count: 0

For fields that contain null values, you must check the semantics and data distribution of the fields to determine preprocessing policies and perform feature engineering for subsequent training. Perform the following steps:

  1. Create an analysis function that reports the statistics of a column, such as the distinct count, value distribution, minimum, maximum, average, and median.

    CREATE OR REPLACE FUNCTION print_column_statistics(table_name TEXT, column_name TEXT)
    RETURNS VOID AS $$
    DECLARE
        SQL TEXT;
        distinct_count INTEGER;
        min_value NUMERIC;
        max_value NUMERIC;
        avg_value NUMERIC;
        median_value NUMERIC;
        r RECORD;
    BEGIN
        SQL := 'SELECT 
                    COUNT(DISTINCT ' || column_name || ') AS distinct_count,
                    MIN(' || column_name || ') AS min_value,
                    MAX(' || column_name || ') AS max_value,
                    AVG(' || column_name || ') AS avg_value,
                    PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY ' || column_name || ') AS median_value
                FROM ' || table_name;
    
        EXECUTE SQL INTO r;
    
        distinct_count := r.distinct_count;
        min_value := r.min_value;
        max_value := r.max_value;
        avg_value := r.avg_value;
        median_value := r.median_value;
    
        RAISE NOTICE 'Distinct Count: %', distinct_count;
    
        IF distinct_count < 40 THEN
            SQL := 'SELECT ' || column_name || ' AS col, COUNT(*) AS count FROM ' || table_name || 
                   ' GROUP BY ' || column_name || ' ORDER BY count DESC';
    
            FOR r IN EXECUTE SQL LOOP
                RAISE NOTICE '%: %', r.col, r.count;
            END LOOP;
        END IF;
    
        RAISE NOTICE 'Min Value: %, Max Value: %, Avg Value: %, Median Value: %', 
                     min_value, max_value, avg_value, median_value;
    END;
    $$ LANGUAGE plpgsql;
  2. Use the analysis function to examine the value distribution of a field that contains null values. The following example checks the Tenure field.

    SELECT print_column_statistics('raw_data_table', 'tenure');

    Sample result:

    NOTICE:  Distinct Count: 36
    NOTICE:  1: 690
    NOTICE:  0: 508
    NOTICE:  <NULL>: 264
    NOTICE:  8: 263
    NOTICE:  9: 247
    NOTICE:  7: 221
    NOTICE:  10: 213
    NOTICE:  5: 204
    ...
    NOTICE:  Min Value: 0, Max Value: 61, Avg Value: 10.1898993663809, Median Value: 9
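
    You can run the same check on the other fields that contain null values, for example:

    SELECT print_column_statistics('raw_data_table', 'warehousetohome');
    SELECT print_column_statistics('raw_data_table', 'daysincelastorder');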

Model training

Preprocess data

The preceding analysis shows that the Tenure, WarehouseToHome, HourSpendOnApp, OrderAmountHikeFromlastYear, CouponUsed, OrderCount, and DaySinceLastOrder fields contain null values. To avoid degrading model performance, you must handle these null values. The following list describes the analysis result and null-handling method for each field:

  • In the Tenure field, the distribution of non-null values is positively skewed. In this case, fill null values with the median value.

  • In the WarehouseToHome field, extreme values exist. For example, specific customers live far away from the warehouse. In this case, fill null values with the median value to limit the influence of the extreme values.

  • In the HourSpendOnApp field, the data is relatively symmetric. In this case, fill null values with the mean value.

  • In the OrderAmountHikeFromlastYear field, the data distribution is stable. In this case, fill null values with the mean value.

  • In the CouponUsed field, null values are assumed to indicate that no coupons were used. In this case, fill null values with 0.

  • In the OrderCount field, null values are assumed to indicate that no orders were placed. In this case, fill null values with 0.

  • In the DaySinceLastOrder field, null values indicate an extended period of time since the most recent order. In this case, fill null values with the maximum value.

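You can verify these distribution assumptions in SQL before you decide on an imputation strategy. The following query is a minimal check that compares the mean and the median of the Tenure column; a mean that is noticeably larger than the median suggests positive skew. Adjust the column name to check the other fields.

-- Compare the mean and the median to gauge the skew of a column.
SELECT AVG(tenure) AS mean_value,
       PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY tenure) AS median_value
FROM raw_data_table
WHERE tenure IS NOT NULL;
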
Based on the preceding decisions, the following preprocessing parameters are used for training. For more information, see Use machine learning.

{
  "tenure": {"impute": "median"},
  "warehousetohome": {"impute": "median"},
  "hourspendonapp": {"impute": "mean"},
  "orderamounthikefromlastyear": {"impute": "mean"},
  "couponused": {"impute": "zero"},
  "ordercount": {"impute": "zero"},
  "daysincelastorder": {"impute": "max"}
}
Note

The CityTier and Complain fields are of the INTEGER type, but their values are categorical labels rather than continuous quantities. Cast the fields to the TEXT type so that they are one-hot encoded during training.

Create a training view

To avoid physically modifying the raw data table and to make it easier to add more features later, create a view based on the preceding data processing results.

CREATE OR REPLACE VIEW train_data_view AS
SELECT
  Churn::TEXT,
  Tenure,
  PreferredLoginDevice,
  CityTier::TEXT,
  WarehouseToHome,
  PreferredPaymentMode,
  Gender,
  HourSpendOnApp,
  NumberOfDeviceRegistered,
  PreferedOrderCat,
  SatisfactionScore,
  MaritalStatus,
  NumberOfAddress,
  Complain::TEXT,
  OrderAmountHikeFromlastYear,
  CouponUsed,
  OrderCount,
  DaySinceLastOrder,
  CashbackAmount
FROM
  raw_data_table;

Perform feature engineering

Feature engineering is a key step in machine learning and data mining. It processes and transforms raw data to provide additional information, accelerate model convergence, and improve overall performance. The following table describes the derived features used in this example.

Feature | Description
--------+------------
AvgCashbkPerOrder | The average cashback amount per order. Calculation formula: CashbackAmount/OrderCount.
AvgHourSpendPerOrder | The average browsing time per order. Calculation formula: HourSpendOnApp/OrderCount.
CouponUsedPerOrder | The average number of coupons used per order. Calculation formula: CouponUsed/OrderCount.
LogCashbackAmount | The logarithmic transformation of the cashback amount. Calculation formula: log(1 + CashbackAmount).

Re-create the training view to include the preceding features.

CREATE OR REPLACE VIEW train_data_view AS
SELECT
  Churn::TEXT,
  Tenure,
  PreferredLoginDevice,
  CityTier::TEXT,
  WarehouseToHome,
  PreferredPaymentMode,
  Gender,
  HourSpendOnApp,
  NumberOfDeviceRegistered,
  PreferedOrderCat,
  SatisfactionScore,
  MaritalStatus,
  NumberOfAddress,
  Complain::TEXT,
  OrderAmountHikeFromlastYear,
  CouponUsed,
  OrderCount,
  DaySinceLastOrder,
  CashbackAmount,
  CashbackAmount/OrderCount AS AvgCashbkPerOrder,
  HourSpendOnApp/OrderCount AS AvgHourSpendPerOrder,
  CouponUsed/OrderCount AS CouponUsedPerOrder,
  log(1+CashbackAmount) AS LogCashbackAmount
FROM
  raw_data_table;
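
Before training, you can spot-check a few rows of the view to confirm that the derived features are computed as expected:

-- Inspect the derived features for a few rows.
SELECT OrderCount, CashbackAmount, AvgCashbkPerOrder, AvgHourSpendPerOrder, CouponUsedPerOrder, LogCashbackAmount
FROM train_data_view
LIMIT 5;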

Train and select a model

  1. Use the pgml.train() function to fit different models to data and verify the results. In this example, the XGBoost and bagging models are used.

    • Fit the XGBoost model to data.

      SELECT * FROM pgml.train(
          project_name => 'Customer Churn Prediction Project', -- The project name.
          task => 'classification', -- The task type.
          relation_name => 'train_data_view', -- The data source.
          y_column_name => 'churn', -- The name of the prediction category column.
          preprocess => '{
                  "tenure": {"impute": "median"},
                  "warehousetohome": {"impute": "median"},
                  "hourspendonapp": {"impute": "mean"},
                  "orderamounthikefromlastyear": {"impute": "mean"},
                  "couponused": {"impute": "zero"},
                  "ordercount": {"impute": "zero"},
                  "daysincelastorder": {"impute": "max"},
                  "avgcashbkperorder": {"impute": "zero"},
                  "avghourspendperorder": {"impute": "zero"},
                  "couponusedperorder": {"impute": "zero"},
                  "logcashbackamount": {"impute": "min"}
              }', -- The preprocessing method.
          algorithm => 'xgboost', -- The model type.
          runtime => 'python', -- The runtime environment. Set this parameter to python.
          test_size => 0.2 -- The ratio of the test set.
      );

      The following fitting metrics are returned:

      -- {"f1": 0.9543147, "precision": 0.96907216, "recall": 0.94, "accuracy": 0.9840142, ...}
    • Fit the bagging model to data.

      -- Fit the bagging classification model.
      SELECT * FROM pgml.train(
          project_name => 'Customer Churn Prediction Project', -- The project name.
          task => 'classification', -- The task type.
          relation_name => 'train_data_view', -- The data source.
          y_column_name => 'churn', -- The name of the prediction category column.
          preprocess => '{
                  "tenure": {"impute": "median"},
                  "warehousetohome": {"impute": "median"},
                  "hourspendonapp": {"impute": "mean"},
                  "orderamounthikefromlastyear": {"impute": "mean"},
                  "couponused": {"impute": "zero"},
                  "ordercount": {"impute": "zero"},
                  "daysincelastorder": {"impute": "max"},
                  "avgcashbkperorder": {"impute": "zero"},
                  "avghourspendperorder": {"impute": "zero"},
                  "couponusedperorder": {"impute": "zero"},
                  "logcashbackamount": {"impute": "min"}
              }', -- The preprocessing method.
          algorithm => 'bagging', -- The model type.
          runtime => 'python', -- The runtime environment. Set this parameter to python.
          test_size => 0.2 -- The ratio of the test set.
      );

      The following fitting metrics are returned:

      -- {"f1": 0.9270833, "precision": 0.96216214, "recall": 0.89447236}

    You can change the value of the algorithm parameter to compare how well different models fit the dataset. For information about the supported models, see the pgml.algorithm enumeration type table in the Use machine learning topic. Based on the F1 metric, the XGBoost model outperforms the other models, so it is used for the subsequent operations in this example.

  2. Use the grid search method to find the optimal model hyperparameters, and use 5-fold cross-validation to increase the reliability of the results. The following table describes the hyperparameters that are searched.

    Hyperparameter | Description
    ---------------+------------
    n_estimators | The number of trees to construct. More trees can improve model performance but increase computing costs. In this example, the values 100, 200, 300, 400, 500, 1000, and 2000 are searched to find the optimal balance point.
    eta | The learning rate, which specifies how much each tree contributes to the final prediction. A smaller learning rate makes training more stable but may require a larger n_estimators value. In this example, the values 0.05, 0.1, and 0.2 are searched to balance training stability and efficiency.
    max_depth | The maximum depth of each tree. A larger depth captures more feature interactions but may lead to overfitting. In this example, the values 4, 6, 8, and 16 are searched.

    SELECT * FROM pgml.train(
        project_name => 'Customer Churn Prediction Project', -- The project name.
        task => 'classification', -- The task type.
        relation_name => 'train_data_view', -- The data source.
        y_column_name => 'churn', -- The name of the prediction category column.
        preprocess => '{
                "tenure": {"impute": "median"},
                "warehousetohome": {"impute": "median"},
                "hourspendonapp": {"impute": "mean"},
                "orderamounthikefromlastyear": {"impute": "mean"},
                "couponused": {"impute": "zero"},
                "ordercount": {"impute": "zero"},
                "daysincelastorder": {"impute": "max"},
                "avgcashbkperorder": {"impute": "zero"},
                "avghourspendperorder": {"impute": "zero"},
                "couponusedperorder": {"impute": "zero"},
                "logcashbackamount": {"impute": "min"}
            }', -- The preprocessing method.
        algorithm => 'xgboost', -- The model type.
        search_args => '{ "cv": 5 }', -- Enables 5-fold cross-validation.
        search => 'grid', -- The grid search method.
        search_params => '{
            "max_depth": [4, 6, 8, 16], 
            "n_estimators": [100, 200, 300, 400, 500, 1000, 2000],
            "eta": [0.05, 0.1, 0.2]
        }',
        hyperparams => '{
            "nthread": 16,
            "alpha": 0,
            "lambda": 1
        }',
        runtime => 'python', -- The runtime environment. Set this parameter to python.
        test_size => 0.2 -- The ratio of the test set.
    );

    Sample result:

    -- Search result:
    -- ... (Details omitted)
    INFO:  Best Hyperparams: {
      "alpha": 0,
      "lambda": 1,
      "nthread": 16,
      "eta": 0.1,
      "max_depth": 6,
      "n_estimators": 1000
    }
    INFO:  Best f1 Metrics: Number(0.9874088168144226)

    The search result shows that the hyperparameter values of {"eta": 0.1, "max_depth": 6, "n_estimators": 1000} provide the optimal model fitting capability. The search_args => '{ "cv": 5 }' setting enables 5-fold cross-validation, so each fold trains on 80% of the data and validates on the remaining 20%.

  3. Use the optimal hyperparameters to train the model on the full dataset and verify the results.

    SELECT * FROM pgml.train(
        project_name => 'Customer Churn Prediction Project', -- The project name.
        task => 'classification', -- The task type.
        relation_name => 'train_data_view', -- The data source.
        y_column_name => 'churn', -- The name of the prediction category column.
        preprocess => '{
                "tenure": {"impute": "median"},
                "warehousetohome": {"impute": "median"},
                "hourspendonapp": {"impute": "mean"},
                "orderamounthikefromlastyear": {"impute": "mean"},
                "couponused": {"impute": "zero"},
                "ordercount": {"impute": "zero"},
                "daysincelastorder": {"impute": "max"},
                "avgcashbkperorder": {"impute": "zero"},
                "avghourspendperorder": {"impute": "zero"},
                "couponusedperorder": {"impute": "zero"},
                "logcashbackamount": {"impute": "min"}
            }', -- The preprocessing method.
        algorithm => 'xgboost', -- The model type.
        hyperparams => '{
            "max_depth": 6,
            "n_estimators": 1000,
            "eta": 0.1,
            "nthread": 16,
            "alpha": 0,
            "lambda": 1
        }',
        runtime => 'python', -- The runtime environment. Set this parameter to python.
        test_size => 0.2 -- The ratio of the test set.
    );

    Sample result:

    -- Training result:
    INFO:  Training Model { id: 170, task: classification, algorithm: xgboost, runtime: python }
    INFO:  Hyperparameter searches: 1, cross validation folds: 1
    INFO:  Hyperparams: {
      "eta": 0.1,
      "alpha": 0,
      "lambda": 1,
      "nthread": 16,
      "max_depth": 6,
      "n_estimators": 1000
    }
    INFO:  Metrics: {"roc_auc": 0.9751001, "log_loss": 0.19821791, "f1": 0.99258476, "precision": 0.9936373, "recall": 0.9915344, "accuracy": 0.9875666, "mcc": 0.95414394, "fit_time": 0.9980099, "score_time": 0.0085158}
    INFO:  Comparing to deployed model f1: Some(0.9874088168144226)
    INFO:  Deploying model id: 170
                  project              |      task      | algorithm | deployed
    -----------------------------------+----------------+-----------+----------
     Customer Churn Prediction Project | classification | xgboost   | t

    The F1 score of the model on the test set reaches 0.99258476.

Model deployment

Select the model to be deployed

For classification tasks, the pgml extension automatically deploys the model that achieves the highest F1 score during training in the project. You can query the pgml.deployments table to check the current deployment.

SELECT d.id, d.project_id, d.model_id, p.name, p.task FROM pgml.deployments d 
JOIN pgml.projects p on d.project_id = p.id;

Sample result:

 id | project_id | model_id |               name                |      task
----+------------+----------+-----------------------------------+----------------
 61 |          2 |      170 | Customer Churn Prediction Project | classification

For information about how to deploy other models, see the "Deployment" section of the Use machine learning topic.
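
If your instance also provides a pgml.models catalog table (an assumption based on the open source PostgresML extension that pgml follows; verify the exact schema in the Use machine learning topic), you can list the models that were trained in the project together with their metrics before you decide which model to deploy:

-- Assumption: pgml.models stores one row per trained model and includes a JSON metrics column.
SELECT m.id, m.algorithm, m.metrics->>'f1' AS f1
FROM pgml.models m
JOIN pgml.projects p ON m.project_id = p.id
WHERE p.name = 'Customer Churn Prediction Project'
ORDER BY m.id;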

Use the model

Real-time inference

Real-time inference is suitable for scenarios that require real-time interactive responses. For example, a data analyst who performs case analysis wants the prediction result to be returned immediately based on the historical behavior of a customer.

-- Perform real-time inference on a single data record.
SELECT pgml.predict('Customer Churn Prediction Project', 
( 4, 'Mobile Phone'::TEXT, 3, 6, 
'Debit Card'::TEXT, 'Female'::TEXT, 3, 3, 
'Laptop & Accessory'::TEXT, 2, 
'Single'::TEXT, 9 ,
'1'::TEXT, 11, 1, 1, 5, 159.93, 
159.93, 3, 1, 2.206637011283536
));

Sample result:

-- Prediction result:
 predict
---------
       0
(1 row)

Batch inference

Batch inference is suitable for scenarios in which a large amount of data needs to be processed and throughput takes precedence over response time. Batch inference improves the utilization of computing resources.

-- Create a view.
CREATE OR REPLACE VIEW predict_data_view AS
SELECT
  CustomerID,
  Churn::TEXT,
  Tenure,
  PreferredLoginDevice,
  CityTier::TEXT,
  WarehouseToHome,
  PreferredPaymentMode,
  Gender,
  HourSpendOnApp,
  NumberOfDeviceRegistered,
  PreferedOrderCat,
  SatisfactionScore,
  MaritalStatus,
  NumberOfAddress,
  Complain::TEXT,
  OrderAmountHikeFromlastYear,
  CouponUsed,
  OrderCount,
  DaySinceLastOrder,
  CashbackAmount,
  CashbackAmount/OrderCount AS AvgCashbkPerOrder,
  HourSpendOnApp/OrderCount AS AvgHourSpendPerOrder,
  CouponUsed/OrderCount AS CouponUsedPerOrder,
  log(1+CashbackAmount) AS LogCashbackAmount
FROM
  raw_data_table;

-- ====================================
-- Perform batch inference on multiple data records at a time.
 
SELECT CustomerID, pgml.predict('Customer Churn Prediction Project', (
  "tenure",
  "preferredlogindevice",
  "citytier",
  "warehousetohome",
  "preferredpaymentmode",
  "gender",
  "hourspendonapp",
  "numberofdeviceregistered",
  "preferedordercat",
  "satisfactionscore",
  "maritalstatus",
  "numberofaddress",
  "complain",
  "orderamounthikefromlastyear",
  "couponused",
  "ordercount",
  "daysincelastorder",
  "cashbackamount",
  "avgcashbkperorder",
  "avghourspendperorder",
  "couponusedperorder",
  "logcashbackamount"
)) FROM predict_data_view limit 20;

Sample result:

-- Prediction result:
 customerid | predict
------------+---------
      50005 |       0
      50009 |       0
      50012 |       0
      50013 |       0
      50019 |       0
      50020 |       0
      50022 |       0
      50023 |       0
      50026 |       0
      50031 |       1
      50039 |       1
      50040 |       0
      50043 |       1
      50045 |       1
      50047 |       0
      50048 |       1
      50050 |       1
      50051 |       1
      50052 |       1
      50053 |       0
(20 rows)
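
To persist the batch prediction results for downstream analysis, you can wrap the batch inference query in a CREATE TABLE AS statement. The following sketch uses a hypothetical table name, churn_predictions:

-- Write the batch predictions to a table. The table name is illustrative.
CREATE TABLE churn_predictions AS
SELECT CustomerID, pgml.predict('Customer Churn Prediction Project', (
  "tenure",
  "preferredlogindevice",
  "citytier",
  "warehousetohome",
  "preferredpaymentmode",
  "gender",
  "hourspendonapp",
  "numberofdeviceregistered",
  "preferedordercat",
  "satisfactionscore",
  "maritalstatus",
  "numberofaddress",
  "complain",
  "orderamounthikefromlastyear",
  "couponused",
  "ordercount",
  "daysincelastorder",
  "cashbackamount",
  "avgcashbkperorder",
  "avghourspendperorder",
  "couponusedperorder",
  "logcashbackamount"
)) AS predict
FROM predict_data_view;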