Platform for AI: Quantization

Last Updated: Sep 06, 2024

Machine Learning Platform for AI (PAI)-Blade supports the INT8 quantization of TensorFlow and PyTorch models on a GPU device or a client device. This topic describes how to use PAI-Blade to quantize a model on a GPU device.

Background information

Quantization is one of the most commonly used methods to compress models. It replaces 32-bit floating-point numbers with fixed-point integers of a smaller bit width, which reduces memory access overhead and increases instruction throughput. Quantization requires support from the underlying computing hardware.
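
The following sketch illustrates the arithmetic behind symmetric INT8 quantization. It is for illustration only and does not reflect the internal implementation of PAI-Blade:

# Illustrative symmetric INT8 quantization of a tensor (not PAI-Blade internals).
import numpy as np

x = np.random.randn(4, 4).astype(np.float32)
# Map the observed float range onto the signed 8-bit range [-127, 127].
scale = np.abs(x).max() / 127.0
x_int8 = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
# Dequantize to recover an approximation of the original values.
x_dequant = x_int8.astype(np.float32) * scale
print(np.abs(x - x_dequant).max())  # Quantization error.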

Quantize a TensorFlow model

For more information about how to optimize a TensorFlow model, see Optimize a TensorFlow model. To enable quantization, specify optimization_level='o2' when you optimize the model by using PAI-Blade. If the GPU device that you use supports INT8 quantization and quantization can accelerate the model, PAI-Blade performs quantization by default.

  • If you do not provide a calibration dataset, PAI-Blade performs INT8 quantization in online mode, and the quantization parameters are computed at runtime.

  • To achieve greater acceleration, we recommend that you provide a calibration dataset, which PAI-Blade uses to compute the quantization parameters ahead of time. If a calibration dataset is provided, PAI-Blade automatically performs INT8 quantization in offline mode. Both call patterns are sketched below.
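
The two modes differ only in whether you pass the calib_data argument to blade.optimize. A minimal sketch, assuming that graph_def, test_data, and calib_data have been prepared as described in the following steps:

import blade

# Online mode: no calibration dataset. Quantization parameters are computed at runtime.
optimized_model, opt_spec, report = blade.optimize(
    model=graph_def,
    optimization_level='o2',
    device_type='gpu',
    test_data=test_data
)

# Offline mode: the calibration dataset is used to precompute quantization parameters.
optimized_model, opt_spec, report = blade.optimize(
    model=graph_def,
    optimization_level='o2',
    device_type='gpu',
    test_data=test_data,
    calib_data=calib_data
)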

The calibration dataset used to quantize a TensorFlow model is a list of feed_dict arguments. For meaningful quantization parameters, the entries should be representative samples of your real input data. The following sample code, which uses all-ones placeholders for brevity, provides an example:

# Prepare a calibration dataset. 
import numpy as np
calib_data = list()
for i in range(10):
    # All values in the feed_dict arguments must be of the np.ndarray type. 
    feed_dict = {'input:0': np.ones((32, 224, 224, 3), dtype=np.float32)}
    calib_data.append(feed_dict)

You can perform the following steps to quantize a TensorFlow model:

  1. Download the sample model, test data, and calibration dataset.

    wget https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/test_public_model/bbs/tf_resnet50_v1.5/frozen.pb
    wget https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/test_public_model/bbs/tf_resnet50_v1.5/test_bc32.npy
    wget https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/test_public_model/bbs/tf_resnet50_v1.5/calib_data_test_bc32.npy
  2. Load the model to be quantized and the corresponding data.

    import numpy as np
    import tensorflow as tf
    # Load the model. 
    graph_def = tf.GraphDef()
    with open('frozen.pb', 'rb') as f:
        graph_def.ParseFromString(f.read())
    # Load the test data. 
    test_data = np.load('test_bc32.npy', allow_pickle=True, encoding='bytes').item()
    # Load the calibration dataset. 
    calib_data = np.load('calib_data_test_bc32.npy', allow_pickle=True, encoding='bytes')
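
    The feed keys in test_data and calib_data must match placeholder tensors in the graph. As a quick sanity check (a sketch, not part of the required procedure), you can list the placeholder ops in the loaded graph:

    # List the input placeholders of the frozen graph.
    for node in graph_def.node:
        if node.op == 'Placeholder':
            print(node.name)  # The feed key is this name plus ':0', for example 'input:0'.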
  3. Perform INT8 quantization in offline mode.

    import blade
    optimized_model, opt_spec, report = blade.optimize(
        model=graph_def,
        optimization_level='o2',
        device_type='gpu',
        test_data=test_data,
        calib_data=calib_data
    )
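
    After the call returns, you can print the report to inspect which optimization items, including quantization, were applied:

    print(report)  # A summary of the optimizations that were applied.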
  4. Verify the precision of the quantized model.

    After the quantization is complete, you can use a complete test dataset to check whether the precision of the quantized model has dropped significantly. If the precision meets your requirements, you do not need to run the following code. Otherwise, you can modify the quantization configuration to reduce the loss of precision. For more information about the quantization configuration, see blade.Config.

    When you quantize a TensorFlow model on a GPU device, weight_adjustment is the only configurable feature. If you set the weight_adjustment key to 'true', PAI-Blade automatically adjusts the model weights to reduce the loss of precision. The following sample code provides an example on how to enable the feature:

    quant_config = {
        'weight_adjustment': 'true'  # Default value: 'false'. 
    }
    optimized_model, opt_spec, report = blade.optimize(
        model=graph_def,
        optimization_level='o2',
        device_type='gpu',
        test_data=test_data,
        calib_data=calib_data,
        config=blade.Config(quant_config=quant_config)
    )
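
    As a rough precision check, you can run the original and the optimized graphs on the same feed_dict and compare the outputs. The following is a sketch: the output tensor name 'output:0' is a hypothetical placeholder for your model's actual output name, and the optimized graph is run inside the returned opt_spec context so that the optimizations take effect:

    import numpy as np
    import tensorflow as tf

    def run_graph(graph_def, feed_dict, fetch='output:0'):
        # 'output:0' is a hypothetical output tensor name; replace it with your model's.
        with tf.Graph().as_default():
            tf.import_graph_def(graph_def, name='')
            with tf.Session() as sess:
                return sess.run(fetch, feed_dict=feed_dict)

    baseline = run_graph(graph_def, test_data)
    # Run the optimized graph inside the opt_spec context.
    with opt_spec:
        quantized = run_graph(optimized_model, test_data)
    print(np.abs(baseline - quantized).max())  # Maximum absolute deviation.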

Quantize a PyTorch model

Similar to the quantization of a TensorFlow model, you only need to specify optimization_level='o2' to enable quantization when you optimize a PyTorch model by using PAI-Blade. However, a PyTorch model can be quantized only in offline mode. Therefore, you must provide a calibration dataset, which is used to compute the quantization parameters, when you enable quantization for a PyTorch model.

The calibration dataset used to quantize a PyTorch model is a list that contains multiple groups of input data. The following sample code provides an example:

# Prepare a calibration dataset. 
import torch
calib_data = list()
for i in range(10):
    # In practice, use representative samples of your real input data instead of all-ones tensors. 
    image = torch.ones(32, 3, 224, 224)
    calib_data.append(image)

You can perform the following steps to quantize a PyTorch model:

  1. Download the sample model, test data, and calibration dataset.

    wget https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/test_public_model/bbs/pt_resnet50_v1.5/traced_model.pt
    wget https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/test_public_model/bbs/pt_resnet50_v1.5/test_bc32.pth
    wget https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/test_public_model/bbs/pt_resnet50_v1.5/calib_data_test_bc32.pth
  2. Load the model to be quantized and the corresponding data.

    import torch
    # Load the model. 
    pt_model = torch.jit.load('traced_model.pt')
    # Load the test data. 
    test_data = torch.load('test_bc32.pth')
    # Load the calibration dataset. 
    calib_data = torch.load('calib_data_test_bc32.pth')
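
    The exact layout of test_data depends on how the file was saved. The following smoke test is a sketch that assumes test_data is a list of input tuples, mirroring the calibration dataset:

    # Run the traced model on one test sample.
    with torch.no_grad():
        out = pt_model(*test_data[0])
    print(out.shape)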
  3. Perform INT8 quantization in offline mode.

    import blade
    optimized_model, opt_spec, report = blade.optimize(
        model=pt_model,
        optimization_level='o2',
        device_type='gpu',
        test_data=test_data,
        calib_data=calib_data
    )
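
    As with TensorFlow, you can print the report to inspect the applied optimizations and roughly compare the quantized model against the original. The following sketch makes the same assumption about the layout of test_data:

    print(report)  # A summary of the optimizations that were applied.
    with torch.no_grad():
        baseline = pt_model(*test_data[0])
        # Some optimizations take effect only inside the opt_spec context.
        with opt_spec:
            quantized = optimized_model(*test_data[0])
    print((baseline - quantized).abs().max())  # Maximum absolute deviation.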