This topic describes how to use the pgml extension in AnalyticDB for PostgreSQL V7.0 to build batch and real-time prediction workflows, using a customer churn prediction task from an e-commerce scenario as an example.
Prerequisites
An AnalyticDB for PostgreSQL V7.0 instance of V7.1.1.0 or later is created.
The instance is in elastic storage mode.
The pgml extension is installed on the instance.
Note: The pgml extension does not support GUI-based installation. To install or uninstall the pgml extension, submit a ticket.
Background information
The pgml extension is designed to bring models closer to data. The in-database AI/ML feature loads models into the PostgreSQL backend processes and exposes training, fine-tuning, and inference as user-defined functions (UDFs). Trained and fine-tuned models are stored in heap tables, so you do not need to design separate high-availability or high-reliability solutions, which keeps O&M simple. Because storage and computing are integrated, the pgml extension reduces data transfer overhead and completes model training and service deployment efficiently. The following figure shows the AI/ML training and inference process.
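Because the entire workflow runs through SQL UDFs, a project typically consists of only two kinds of calls: pgml.train() to fit and store a model, and pgml.predict() to run inference with the currently deployed model. The following minimal sketch only outlines the API shape used throughout this topic; the table sample_table, the label column label, the feature columns feature_1 and feature_2, and the project name Demo Project are hypothetical placeholders.
-- Train a classification model on a hypothetical table (sketch).
SELECT * FROM pgml.train(
    project_name  => 'Demo Project',
    task          => 'classification',
    relation_name => 'sample_table',
    y_column_name => 'label'
);
-- Run inference with the deployed model of the project (sketch).
SELECT pgml.predict('Demo Project', (feature_1, feature_2)) FROM sample_table;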
Data presentation and analysis
Sample dataset
In this example, a dataset that contains historical behavior data of customers with customer churn labels is used. For more information, see Ecommerce Customer Churn Analysis and Prediction. Customer churn analysis and prediction can help enterprises develop more effective policies to improve the customer retention rate. The following table describes the fields included in the dataset.
Field | Description |
CustomerID | The unique customer ID. |
Churn | The customer churn label. |
Tenure | The usage duration of the customer. |
PreferredLoginDevice | The preferred logon device of the customer. |
CityTier | The category of the city where the customer resides. |
WarehouseToHome | The distance from the warehouse to the home of the customer. |
PreferredPaymentMode | The preferred payment method of the customer. |
Gender | The gender of the customer. |
HourSpendOnApp | The number of hours spent by the customer on mobile applications or websites. |
NumberOfDeviceRegistered | The total number of devices that are registered to the customer. |
PreferedOrderCat | The preferred order category of the customer in the last month. |
SatisfactionScore | The satisfaction score of the customer for services. |
MaritalStatus | The marital status of the customer. |
NumberOfAddress | The total number of addresses that are added by the customer. |
Complain | Indicates whether the customer raised a complaint in the last month. |
OrderAmountHikeFromlastYear | The percentage increase in the order amount of the customer compared with the previous year. |
CouponUsed | The total number of coupons that were used by the customer in the last month. |
OrderCount | The total number of customer orders in the last month. |
DaySinceLastOrder | The number of days since the most recent order of the customer. |
CashbackAmount | The cashback amount of the customer in the last month. |
Data import
Create a table.
CREATE TABLE raw_data_table (
    CustomerID INTEGER,
    Churn INTEGER,
    Tenure FLOAT,
    PreferredLoginDevice TEXT,
    CityTier INTEGER,
    WarehouseToHome FLOAT,
    PreferredPaymentMode TEXT,
    Gender TEXT,
    HourSpendOnApp FLOAT,
    NumberOfDeviceRegistered INTEGER,
    PreferedOrderCat TEXT,
    SatisfactionScore INTEGER,
    MaritalStatus TEXT,
    NumberOfAddress INTEGER,
    Complain INTEGER,
    OrderAmountHikeFromlastYear FLOAT,
    CouponUsed FLOAT,
    OrderCount FLOAT,
    DaySinceLastOrder FLOAT,
    CashbackAmount FLOAT
);
Download the dataset and use the COPY statement to import the CSV file into the table. Replace /path/to/dataset.csv with the actual path of the dataset.
COPY raw_data_table FROM '/path/to/dataset.csv' DELIMITER ',' CSV HEADER;
Note: We recommend that you use the psql tool to import data. If you use other SDKs to import data, you can use the COPY or INSERT statement. For more information, see the corresponding documentation.
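If the CSV file resides on your client machine rather than on the server, the psql \copy meta-command is a common alternative because it streams the file through the client connection. The following sketch assumes a local client-side file named dataset.csv; adjust the path to your environment.
\copy raw_data_table FROM 'dataset.csv' DELIMITER ',' CSV HEADER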
Data analysis
Check the distribution of null values in the dataset.
-- Dynamically build one UNION ALL query that counts null values per column.
DO $$
DECLARE
    r RECORD;
    sql TEXT := '';
BEGIN
    -- Assemble a subquery for each column of raw_data_table.
    FOR r IN
        SELECT column_name
        FROM information_schema.columns
        WHERE table_name = 'raw_data_table'
    LOOP
        sql := sql ||
            'SELECT ''' || r.column_name || ''' AS column_name, COUNT(*) FILTER (WHERE ' || r.column_name || ' IS NULL) AS null_count FROM raw_data_table UNION ALL ';
    END LOOP;
    -- Remove the trailing ' UNION ALL ' (11 characters).
    sql := LEFT(sql, length(sql) - 11);
    -- Execute the assembled query and print one NOTICE per column.
    FOR r IN EXECUTE sql LOOP
        RAISE NOTICE 'Column: %, Null Count: %', r.column_name, r.null_count;
    END LOOP;
END $$;
Sample result:
NOTICE: Column: customerid, Null Count: 0
NOTICE: Column: churn, Null Count: 0
NOTICE: Column: tenure, Null Count: 264
NOTICE: Column: preferredlogindevice, Null Count: 0
NOTICE: Column: citytier, Null Count: 0
NOTICE: Column: warehousetohome, Null Count: 251
NOTICE: Column: preferredpaymentmode, Null Count: 0
NOTICE: Column: gender, Null Count: 0
NOTICE: Column: hourspendonapp, Null Count: 255
NOTICE: Column: numberofdeviceregistered, Null Count: 0
NOTICE: Column: preferedordercat, Null Count: 0
NOTICE: Column: satisfactionscore, Null Count: 0
NOTICE: Column: maritalstatus, Null Count: 0
NOTICE: Column: numberofaddress, Null Count: 0
NOTICE: Column: complain, Null Count: 0
NOTICE: Column: orderamounthikefromlastyear, Null Count: 265
NOTICE: Column: couponused, Null Count: 256
NOTICE: Column: ordercount, Null Count: 258
NOTICE: Column: daysincelastorder, Null Count: 307
NOTICE: Column: cashbackamount, Null Count: 0
For the fields that contain null values, check their semantics and data distribution to determine preprocessing policies and to guide feature engineering for subsequent training. Perform the following steps:
Create an analysis function that prints the statistics and value distribution of a specified column.
CREATE OR REPLACE FUNCTION print_column_statistics(table_name TEXT, column_name TEXT)
RETURNS VOID AS $$
DECLARE
    sql TEXT;
    distinct_count INTEGER;
    min_value NUMERIC;
    max_value NUMERIC;
    avg_value NUMERIC;
    median_value NUMERIC;
    r RECORD;
BEGIN
    -- Collect basic statistics of the column.
    sql := 'SELECT COUNT(DISTINCT ' || column_name || ') AS distinct_count, ' ||
           'MIN(' || column_name || ') AS min_value, ' ||
           'MAX(' || column_name || ') AS max_value, ' ||
           'AVG(' || column_name || ') AS avg_value, ' ||
           'PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY ' || column_name || ') AS median_value ' ||
           'FROM ' || table_name;
    EXECUTE sql INTO r;
    distinct_count := r.distinct_count;
    min_value := r.min_value;
    max_value := r.max_value;
    avg_value := r.avg_value;
    median_value := r.median_value;
    RAISE NOTICE 'Distinct Count: %', distinct_count;
    -- For low-cardinality columns, also print the value distribution.
    IF distinct_count < 40 THEN
        sql := 'SELECT ' || column_name || ' AS col, COUNT(*) AS count FROM ' || table_name ||
               ' GROUP BY ' || column_name || ' ORDER BY count DESC';
        FOR r IN EXECUTE sql LOOP
            RAISE NOTICE '%: %', r.col, r.count;
        END LOOP;
    END IF;
    RAISE NOTICE 'Min Value: %, Max Value: %, Avg Value: %, Median Value: %', min_value, max_value, avg_value, median_value;
END;
$$ LANGUAGE plpgsql;
Use the analysis function to analyze the fields that contain null values. For example, analyze the Tenure field:
SELECT print_column_statistics('raw_data_table', 'tenure');
Sample result:
NOTICE: Distinct Count: 36
NOTICE: 1: 690
NOTICE: 0: 508
NOTICE: <NULL>: 264
NOTICE: 8: 263
NOTICE: 9: 247
NOTICE: 7: 221
NOTICE: 10: 213
NOTICE: 5: 204
...
NOTICE: Min Value: 0, Max Value: 61, Avg Value: 10.1898993663809, Median Value: 9
Model training
Preprocess data
The preceding analysis shows that the Tenure, WarehouseToHome, HourSpendOnApp, OrderAmountHikeFromlastYear, CouponUsed, OrderCount, and DaySinceLastOrder fields contain null values. To prevent these null values from degrading model performance, you must preprocess them. The following list describes how the null values in each field are analyzed and filled (a query sketch for computing the fill values follows the list):
- Tenure: The values are positively skewed. Fill null values with the median value.
- WarehouseToHome: Extreme values exist because some customers live far away from the warehouse. Fill null values with the median value to keep the distribution centered.
- HourSpendOnApp: The data is relatively symmetrical. Fill null values with the average value.
- OrderAmountHikeFromlastYear: The data distribution is stable. Fill null values with the average value.
- CouponUsed: Null values are assumed to mean that no coupons were used. Fill null values with 0.
- OrderCount: Null values are assumed to mean that no orders were placed. Fill null values with 0.
- DaySinceLastOrder: Null values indicate a long time since the most recent order. Fill null values with the maximum value.
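You do not need to compute the fill values manually, because the preprocess parameter of pgml.train() handles imputation. However, a quick query such as the following sketch can confirm that the chosen statistics are sensible (column names as defined above):
-- Inspect candidate fill values for the columns that contain null values.
SELECT
    PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY Tenure)          AS tenure_median,
    PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY WarehouseToHome) AS warehousetohome_median,
    AVG(HourSpendOnApp)                                          AS hourspendonapp_mean,
    AVG(OrderAmountHikeFromlastYear)                             AS orderamounthike_mean,
    MAX(DaySinceLastOrder)                                       AS daysincelastorder_max
FROM raw_data_table;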
Based on the preceding analysis, the following preprocessing parameters are used. For more information about the preprocessing syntax, see the Use machine learning topic.
{
"tenure": {"impute": "median"},
"warehousetohome": {"impute": "median"},
"hourspendonapp": {"impute": "mean"},
"orderamounthikefromlastyear": {"impute": "mean"},
"couponused": {"impute": "zero"},
"ordercount": {"impute": "zero"},
"daysincelastorder": {"impute": "max"}
}
The CityTier and Complain fields are of the INTEGER type, but their values are categorical labels rather than quantities. Convert these fields to the TEXT type so that the one-hot encoding method is applied to them.
Create a training view
Create a view based on the preceding data processing decisions. This avoids physically modifying the raw data table and makes it easy to add more features later.
CREATE OR REPLACE VIEW train_data_view AS
SELECT
Churn::TEXT,
Tenure,
PreferredLoginDevice,
CityTier::TEXT,
WarehouseToHome,
PreferredPaymentMode,
Gender,
HourSpendOnApp,
NumberOfDeviceRegistered,
PreferedOrderCat,
SatisfactionScore,
MaritalStatus,
NumberOfAddress,
Complain::TEXT,
OrderAmountHikeFromlastYear,
CouponUsed,
OrderCount,
DaySinceLastOrder,
CashbackAmount
FROM
raw_data_table;
Perform feature engineering
Feature engineering is a key step in machine learning and data mining. It processes and transforms raw data to provide additional information, helps models converge, and improves overall performance. The following table describes the derived features used in this example.
Feature | Description |
AvgCashbkPerOrder | The average cashback amount per order. Calculation formula: CashbackAmount/OrderCount. |
AvgHourSpendPerOrder | The average browse time per order. Calculation formula: HourSpendOnApp/OrderCount. |
CouponUsedPerOrder | The average number of coupons used per order. Calculation formula: CouponUsed/OrderCount. |
LogCashbackAmount | The logarithmic transformation of the cashback amount. Calculation formula: log(1 + CashbackAmount). |
Re-create a view based on the preceding features.
CREATE OR REPLACE VIEW train_data_view AS
SELECT
Churn::TEXT,
Tenure,
PreferredLoginDevice,
CityTier::TEXT,
WarehouseToHome,
PreferredPaymentMode,
Gender,
HourSpendOnApp,
NumberOfDeviceRegistered,
PreferedOrderCat,
SatisfactionScore,
MaritalStatus,
NumberOfAddress,
Complain::TEXT,
OrderAmountHikeFromlastYear,
CouponUsed,
OrderCount,
DaySinceLastOrder,
CashbackAmount,
CashbackAmount/OrderCount AS AvgCashbkPerOrder,
HourSpendOnApp/OrderCount AS AvgHourSpendPerOrder,
CouponUsed/OrderCount AS CouponUsedPerOrder,
log(1+CashbackAmount) AS LogCashbackAmount
FROM
raw_data_table;
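Because the derived AvgCashbkPerOrder, AvgHourSpendPerOrder, and CouponUsedPerOrder features divide by OrderCount, it can be worth confirming that no rows contain an OrderCount of 0 before training; rows with a null OrderCount simply yield null ratios, which are then imputed. The following check is an optional sketch:
-- Count rows whose OrderCount would make the per-order ratios undefined.
SELECT
    COUNT(*) FILTER (WHERE OrderCount = 0)     AS zero_order_count_rows,
    COUNT(*) FILTER (WHERE OrderCount IS NULL) AS null_order_count_rows
FROM raw_data_table;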
Train and select a model
Use the pgml.train() function to fit different models to the data and compare the results. In this example, the XGBoost and bagging models are used.
Fit the XGBoost model to the data.
SELECT * FROM pgml.train(
    project_name => 'Customer Churn Prediction Project',  -- The project name.
    task => 'classification',                             -- The task type.
    relation_name => 'train_data_view',                   -- The data source.
    y_column_name => 'churn',                             -- The name of the prediction category column.
    preprocess => '{
        "tenure": {"impute": "median"},
        "warehousetohome": {"impute": "median"},
        "hourspendonapp": {"impute": "mean"},
        "orderamounthikefromlastyear": {"impute": "mean"},
        "couponused": {"impute": "zero"},
        "ordercount": {"impute": "zero"},
        "daysincelastorder": {"impute": "max"},
        "avgcashbkperorder": {"impute": "zero"},
        "avghourspendperorder": {"impute": "zero"},
        "couponusedperorder": {"impute": "zero"},
        "logcashbackamount": {"impute": "min"}
    }',                                                    -- The preprocessing method.
    algorithm => 'xgboost',                                -- The model type.
    runtime => 'python',                                   -- The runtime environment. Set this parameter to python.
    test_size => 0.2                                       -- The ratio of the test set.
);
The following fitting metrics are returned:
-- {"f1": 0.9543147, "precision": 0.96907216, "recall": 0.94, "accuracy": 0.9840142, ...}
Fit the bagging model to data.
-- Bagging classification.
SELECT * FROM pgml.train(
    project_name => 'Customer Churn Prediction Project',  -- The project name.
    task => 'classification',                             -- The task type.
    relation_name => 'train_data_view',                   -- The data source.
    y_column_name => 'churn',                             -- The name of the prediction category column.
    preprocess => '{
        "tenure": {"impute": "median"},
        "warehousetohome": {"impute": "median"},
        "hourspendonapp": {"impute": "mean"},
        "orderamounthikefromlastyear": {"impute": "mean"},
        "couponused": {"impute": "zero"},
        "ordercount": {"impute": "zero"},
        "daysincelastorder": {"impute": "max"},
        "avgcashbkperorder": {"impute": "zero"},
        "avghourspendperorder": {"impute": "zero"},
        "couponusedperorder": {"impute": "zero"},
        "logcashbackamount": {"impute": "min"}
    }',                                                    -- The preprocessing method.
    algorithm => 'bagging',                                -- The model type.
    runtime => 'python',                                   -- The runtime environment. Set this parameter to python.
    test_size => 0.2                                       -- The ratio of the test set.
);
The following fitting metrics are returned:
-- {"f1": 0.9270833, "precision": 0.96216214, "recall": 0.89447236}
You can change the value of the algorithm parameter to compare how well different models fit the dataset. For information about the supported models, see the pgml.algorithm enumeration type table in the Use machine learning topic. Based on the F1 metric, the XGBoost model outperforms the other models, so it is used for the subsequent operations in this example.
Use the grid search method to find the optimal model hyperparameters, and use 5-fold cross-validation to increase the reliability of the results. The following table describes the hyperparameters included in the search.
Hyperparameter | Description |
n_estimators | The number of trees to construct. More trees can improve model performance but increase computing costs. In this example, the candidate values are 100, 200, 300, 400, 500, 1000, and 2000. |
eta | The learning rate, which controls how much each tree contributes to the final prediction. A smaller learning rate makes training more stable but may require a larger n_estimators value. In this example, the candidate values are 0.05, 0.1, and 0.2. |
max_depth | The maximum depth of each tree. A larger depth captures more feature interactions but may lead to overfitting. In this example, the candidate values are 4, 6, 8, and 16. |
SELECT * FROM pgml.train(
    project_name => 'Customer Churn Prediction Project',  -- The project name.
    task => 'classification',                             -- The task type.
    relation_name => 'train_data_view',                   -- The data source.
    y_column_name => 'churn',                             -- The name of the prediction category column.
    preprocess => '{
        "tenure": {"impute": "median"},
        "warehousetohome": {"impute": "median"},
        "hourspendonapp": {"impute": "mean"},
        "orderamounthikefromlastyear": {"impute": "mean"},
        "couponused": {"impute": "zero"},
        "ordercount": {"impute": "zero"},
        "daysincelastorder": {"impute": "max"},
        "avgcashbkperorder": {"impute": "zero"},
        "avghourspendperorder": {"impute": "zero"},
        "couponusedperorder": {"impute": "zero"},
        "logcashbackamount": {"impute": "min"}
    }',                                                    -- The preprocessing method.
    algorithm => 'xgboost',                                -- The model type.
    search_args => '{ "cv": 5 }',                          -- Enables 5-fold cross-validation.
    search => 'grid',                                      -- The grid search method.
    search_params => '{
        "max_depth": [4, 6, 8, 16],
        "n_estimators": [100, 200, 300, 400, 500, 1000, 2000],
        "eta": [0.05, 0.1, 0.2]
    }',
    hyperparams => '{
        "nthread": 16,
        "alpha": 0,
        "lambda": 1
    }',
    runtime => 'python',                                   -- The runtime environment. Set this parameter to python.
    test_size => 0.2                                       -- The ratio of the test set.
);
Sample result:
-- Search result:
-- ... (Details omitted)
INFO: Best Hyperparams: {
  "alpha": 0,
  "lambda": 1,
  "nthread": 16,
  "eta": 0.1,
  "max_depth": 6,
  "n_estimators": 1000
}
INFO: Best f1 Metrics: Number(0.9874088168144226)
The search result shows that the hyperparameter values of {"eta": 0.1, "max_depth": 6, "n_estimators": 1000} provide the optimal model fitting capability. Because the search_args => '{ "cv": 5 }' configuration enables 5-fold cross-validation, each fold trains the model on only 80% of the training data.
Use the optimal hyperparameters to retrain the model on the full training data and verify the results.
SELECT * FROM pgml.train(
    project_name => 'Customer Churn Prediction Project',  -- The project name.
    task => 'classification',                             -- The task type.
    relation_name => 'train_data_view',                   -- The data source.
    y_column_name => 'churn',                             -- The name of the prediction category column.
    preprocess => '{
        "tenure": {"impute": "median"},
        "warehousetohome": {"impute": "median"},
        "hourspendonapp": {"impute": "mean"},
        "orderamounthikefromlastyear": {"impute": "mean"},
        "couponused": {"impute": "zero"},
        "ordercount": {"impute": "zero"},
        "daysincelastorder": {"impute": "max"},
        "avgcashbkperorder": {"impute": "zero"},
        "avghourspendperorder": {"impute": "zero"},
        "couponusedperorder": {"impute": "zero"},
        "logcashbackamount": {"impute": "min"}
    }',                                                    -- The preprocessing method.
    algorithm => 'xgboost',                                -- The model type.
    hyperparams => '{
        "max_depth": 6,
        "n_estimators": 1000,
        "eta": 0.1,
        "nthread": 16,
        "alpha": 0,
        "lambda": 1
    }',
    runtime => 'python',                                   -- The runtime environment. Set this parameter to python.
    test_size => 0.2                                       -- The ratio of the test set.
);
Sample result:
-- Training result:
INFO: Training Model { id: 170, task: classification, algorithm: xgboost, runtime: python }
INFO: Hyperparameter searches: 1, cross validation folds: 1
INFO: Hyperparams: {
  "eta": 0.1,
  "alpha": 0,
  "lambda": 1,
  "nthread": 16,
  "max_depth": 6,
  "n_estimators": 1000
}
INFO: Metrics: {"roc_auc": 0.9751001, "log_loss": 0.19821791, "f1": 0.99258476, "precision": 0.9936373, "recall": 0.9915344, "accuracy": 0.9875666, "mcc": 0.95414394, "fit_time": 0.9980099, "score_time": 0.0085158}
INFO: Comparing to deployed model f1: Some(0.9874088168144226)
INFO: Deploying model id: 170
              project              |      task      | algorithm | deployed
-----------------------------------+----------------+-----------+----------
 Customer Churn Prediction Project | classification | xgboost   | t
The F1 score of the model on the test set reaches 0.99258476.
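To compare all models trained in the project, you can query the metadata tables maintained by the pgml extension. The following sketch assumes that, as in PostgresML, trained models are recorded in a pgml.models table with project_id, algorithm, and metrics columns; verify the exact schema in your instance before relying on it.
-- List the trained models of the project with their metrics (assumes the pgml.models catalog table).
SELECT m.id, m.algorithm, m.metrics
FROM pgml.models m
JOIN pgml.projects p ON m.project_id = p.id
WHERE p.name = 'Customer Churn Prediction Project'
ORDER BY m.id DESC;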
Model deployment
Select the model to be deployed
By default, the pgml extension automatically deploys the model that achieves the highest F1 score during training in the project (for classification tasks). You can query the pgml.deployments table to check the current deployment.
SELECT d.id, d.project_id, d.model_id, p.name, p.task FROM pgml.deployments d
JOIN pgml.projects p on d.project_id = p.id;
Sample result:
id | project_id | model_id | name | task
----+------------+----------+-----------------------------------+----------------
61 | 2 | 170 | Customer Churn Prediction Project | classification
For information about how to deploy other models, see the "Deployment" section of the Use machine learning topic.
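As a rough illustration, redeployment is typically a single function call. The following sketch assumes the extension provides a pgml.deploy() function compatible with PostgresML, which accepts a deployment strategy such as most_recent, best_score, or rollback; confirm the supported signature in the Use machine learning topic before using it.
-- Redeploy a model by strategy (assumes a PostgresML-compatible pgml.deploy() function).
SELECT * FROM pgml.deploy(
    'Customer Churn Prediction Project',
    strategy => 'most_recent'
);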
Use the model
Real-time inference
Real-time inference is suitable for scenarios that require interactive responses. For example, a data analyst who performs case analysis wants a prediction to be returned immediately based on the historical behavior of a customer.
-- Perform real-time inference on a single record.
-- The values follow the column order of train_data_view (Tenure through LogCashbackAmount), excluding the Churn label.
SELECT pgml.predict('Customer Churn Prediction Project',
( 4, 'Mobile Phone'::TEXT, 3, 6,
'Debit Card'::TEXT, 'Female'::TEXT, 3, 3,
'Laptop & Accessory'::TEXT, 2,
'Single'::TEXT, 9 ,
'1'::TEXT, 11, 1, 1, 5, 159.93,
159.93, 3, 1, 2.206637011283536
));
Sample result:
-- Prediction result:
predict
---------
0
(1 row)
Batch inference
Batch inference is suitable for scenarios in which a large amount of data needs to be processed and throughput matters more than response time. Batch inference can improve the utilization of computing resources.
-- Create a view.
CREATE OR REPLACE VIEW predict_data_view AS
SELECT
CustomerID,
Churn::TEXT,
Tenure,
PreferredLoginDevice,
CityTier::TEXT,
WarehouseToHome,
PreferredPaymentMode,
Gender,
HourSpendOnApp,
NumberOfDeviceRegistered,
PreferedOrderCat,
SatisfactionScore,
MaritalStatus,
NumberOfAddress,
Complain::TEXT,
OrderAmountHikeFromlastYear,
CouponUsed,
OrderCount,
DaySinceLastOrder,
CashbackAmount,
CashbackAmount/OrderCount AS AvgCashbkPerOrder,
HourSpendOnApp/OrderCount AS AvgHourSpendPerOrder,
CouponUsed/OrderCount AS CouponUsedPerOrder,
log(1+CashbackAmount) AS LogCashbackAmount
FROM
raw_data_table;
-- ====================================
-- Perform batch inference on multiple records at a time.
SELECT CustomerID, pgml.predict('Customer Churn Prediction Project', (
"tenure",
"preferredlogindevice",
"citytier",
"warehousetohome",
"preferredpaymentmode",
"gender",
"hourspendonapp",
"numberofdeviceregistered",
"preferedordercat",
"satisfactionscore",
"maritalstatus",
"numberofaddress",
"complain",
"orderamounthikefromlastyear",
"couponused",
"ordercount",
"daysincelastorder",
"cashbackamount",
"avgcashbkperorder",
"avghourspendperorder",
"couponusedperorder",
"logcashbackamount"
)) FROM predict_data_view LIMIT 20;
Sample result:
-- Prediction result:
customerid | predict
------------+---------
50005 | 0
50009 | 0
50012 | 0
50013 | 0
50019 | 0
50020 | 0
50022 | 0
50023 | 0
50026 | 0
50031 | 1
50039 | 1
50040 | 0
50043 | 1
50045 | 1
50047 | 0
50048 | 1
50050 | 1
50051 | 1
50052 | 1
50053 | 0
(20 rows)
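For downstream reporting or campaign targeting, batch predictions are often persisted rather than recomputed. The following is a minimal sketch; the result table name churn_prediction_result and the predicted_churn column alias are hypothetical.
-- Persist batch predictions into a result table for downstream use (table name is hypothetical).
CREATE TABLE churn_prediction_result AS
SELECT
    CustomerID,
    pgml.predict('Customer Churn Prediction Project', (
        "tenure", "preferredlogindevice", "citytier", "warehousetohome",
        "preferredpaymentmode", "gender", "hourspendonapp", "numberofdeviceregistered",
        "preferedordercat", "satisfactionscore", "maritalstatus", "numberofaddress",
        "complain", "orderamounthikefromlastyear", "couponused", "ordercount",
        "daysincelastorder", "cashbackamount", "avgcashbkperorder", "avghourspendperorder",
        "couponusedperorder", "logcashbackamount"
    )) AS predicted_churn
FROM predict_data_view;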