Platform For AI: Use PAI-Blade and TorchScript custom C++ operators to optimize a RetinaNet model

Last Updated: Feb 27, 2024

To improve the post-processing efficiency of an object detection model, you can use TorchScript custom C++ operators to build the post-processing network that is otherwise implemented in Python, export the model, and then use Machine Learning Platform for AI (PAI)-Blade to optimize it. This topic describes how to use TorchScript custom C++ operators to build the post-processing network of an object detection model and how to use PAI-Blade to optimize the model.

Background information

RetinaNet is a one-stage object detection network. Its basic structure consists of a backbone, multiple subnetworks, and Non-Maximum Suppression (NMS), which is a post-processing algorithm. RetinaNet is implemented in many training frameworks, and Detectron2 is a typical one. You can call the scripting_with_instances function of Detectron2 to export a RetinaNet model and then use PAI-Blade to optimize the model. For more information, see Use PAI-Blade to optimize a RetinaNet model that is in the Detectron2 framework.

In most cases, the post-processing logic of an object detection model includes decoding and filtering boxes and performing NMS. If this logic is implemented in Python, it runs inefficiently. Instead, you can use TorchScript custom C++ operators to build the post-processing network, export the model, and then use PAI-Blade to optimize it.
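To illustrate the kind of per-box logic involved, the following is a minimal pure-Python sketch of greedy, class-aware NMS. It is not the code of the nms_cuda operator used later in this topic, and the names iou and greedy_nms are hypothetical. The loop compares each candidate with every box already kept, so in the Python interpreter its cost grows roughly with the square of the number of candidates. This is the work that a CUDA custom operator moves onto the GPU.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: Sequence[Box], scores: Sequence[float],
               classes: Sequence[int], iou_thresh: float,
               max_dets: int) -> List[int]:
    """Return the indices of the kept boxes, highest score first.

    A box is suppressed only by a higher-scoring kept box of the same class.
    """
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        suppressed = any(
            classes[k] == classes[i] and iou(boxes[i], boxes[k]) > iou_thresh
            for k in keep
        )
        if not suppressed:
            keep.append(i)
            # Stop once the per-image detection budget is reached.
            if len(keep) == max_dets:
                break
    return keep
```

The iou_thresh and max_dets parameters play the same roles as test_nms_thresh and max_detections_per_image in the adapter_forward code of Step 1.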

Limits

The environment used for the procedure in this topic must meet the following version requirements:

  • System environment: Python 3.6 or later, GNU Compiler Collection (GCC) 5.4 or later, NVIDIA Tesla T4, CUDA 10.2, and cuDNN 8.0.5.39

  • Framework: PyTorch 1.8.1 or later, and Detectron2 0.4.1 or later

  • Inference optimization tool: PAI-Blade V3.16.0 or later
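If you want to check an environment against these minimum versions, a small helper such as the following can compare version strings, including PyTorch-style local suffixes such as +cu102. This is a hypothetical sketch that uses only the standard library; parse_version and meets_minimum are not part of PyTorch, Detectron2, or PAI-Blade.

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '1.8.1+cu102' into (1, 8, 1) and '10.2.0' into (10, 2, 0).

    The local suffix after '+' is dropped, and parsing stops at the first
    component that does not start with a digit.
    """
    public = version.split("+", 1)[0]
    parts = []
    for piece in public.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """Return True if `installed` is the same as or newer than `minimum`."""
    a, b = parse_version(installed), parse_version(minimum)
    # Pad the shorter tuple with zeros so that '3.16' compares equal to '3.16.0'.
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))
```

In the target environment, you could call it as meets_minimum(torch.__version__, "1.8.1") or meets_minimum(detectron2.__version__, "0.4.1").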

Procedure

To use PAI-Blade and custom C++ operators to optimize a RetinaNet model, perform the following steps:

  1. Step 1: Create a PyTorch model that contains TorchScript custom C++ operators

    Use TorchScript custom C++ operators to build the post-processing network of the RetinaNet model.

  2. Step 2: Export a TorchScript model

    Use the TracingAdapter class or the scripting_with_instances function of Detectron2 to export the RetinaNet model.

  3. Step 3: Use PAI-Blade to optimize the model

    Call the blade.optimize method to optimize the model and save the optimized model.

  4. Step 4: Load and run the optimized model

    If the optimized model passes the performance testing and meets your expectations, load the optimized model for inference.

Step 1: Create a PyTorch model that contains TorchScript custom C++ operators

PAI-Blade is seamlessly integrated with TorchScript custom C++ operators. This step describes how to use these operators to build the post-processing network of the RetinaNet model. For more information about TorchScript custom C++ operators, see EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORS. The post-processing logic in this topic is taken from NVIDIA's open source project. For more information, see retinanet-examples. This example uses the core code of that project to show how to develop and implement custom operators.

  1. Download the sample code and decompress the downloaded package.

    wget -nv https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/tutorials/retinanet_example/retinanet-examples.tar.gz -O retinanet-examples.tar.gz
    tar xvfz retinanet-examples.tar.gz 1>/dev/null
  2. Compile TorchScript custom C++ operators.

    PyTorch provides three methods to compile custom operators: Building with CMake, Building with JIT Compilation, and Building with Setuptools. For more information, see EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORS. The three methods suit different scenarios, so select one based on your needs. This example uses Building with JIT Compilation for simplicity. Because is_python_module is set to False, torch.utils.cpp_extension.load loads the compiled library into the process as a plain dynamic library instead of importing it as a Python module, and the operators it registers become available under torch.ops. The following sample code provides an example:

    import torch.utils.cpp_extension
    import os
    codebase="retinanet-examples"
    sources=['csrc/extensions.cpp',
             'csrc/cuda/decode.cu',
             'csrc/cuda/nms.cu',]
    sources = [os.path.join(codebase,src) for src in sources]
    torch.utils.cpp_extension.load(
        name="custom",
        sources=sources,
        build_directory=codebase,
        extra_include_paths=['/usr/local/TensorRT/include/', '/usr/local/cuda/include/', '/usr/local/cuda/include/thrust/system/cuda/detail'],
        extra_cflags=['-std=c++14', '-O2', '-Wall'],
        extra_cuda_cflags=[
            '-std=c++14', '--expt-extended-lambda',
            '--use_fast_math', '-Xcompiler', '-Wall,-fno-gnu-unique',
            '-gencode=arch=compute_75,code=sm_75',],
        is_python_module=False,
        with_cuda=True,
        verbose=False,
    )

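    After you run the preceding code, the compiled custom.so file is stored in the retinanet-examples directory. The -gencode=arch=compute_75,code=sm_75 flag builds the CUDA code for compute capability 7.5, which is the architecture of the NVIDIA Tesla T4. Change this flag if you use a different GPU.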

  3. Use the custom C++ operators to build the post-processing network of the RetinaNet model.

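    For simplicity, this example replaces RetinaNet.forward with adapter_forward, which uses the decode_cuda and nms_cuda custom C++ operators to build the post-processing network. The replacement applies only to the running Python process, so run this code in the same session as Step 2 and Step 3. The following sample code provides an example: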

    import os
    import torch
    from typing import Tuple, Dict, List, Optional
    codebase="retinanet-examples"
    torch.ops.load_library(os.path.join(codebase, 'custom.so'))
    
    decode_cuda = torch.ops.retinanet.decode
    nms_cuda = torch.ops.retinanet.nms
    
    # The main part of this function is the same as RetinaNet.forward. However, the post-processing logic is implemented by the decode_cuda and nms_cuda custom operators. 
    def adapter_forward(self, batched_inputs: Tuple[Dict[str, torch.Tensor]]):
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        features = [features[f] for f in self.head_in_features]
        cls_heads, box_heads = self.head(features)
        cls_heads = [cls.sigmoid() for cls in cls_heads]
        box_heads = [b.contiguous() for b in box_heads]
    
        # Build the post-processing network. 
        strides = [images.tensor.shape[-1] // cls_head.shape[-1] for cls_head in cls_heads]
        decoded = [
            decode_cuda(
                cls_head,
                box_head,
                anchor.view(-1),
                stride,
                self.test_score_thresh,
                self.test_topk_candidates,
            )
            for stride, cls_head, box_head, anchor in zip(
                strides, cls_heads, box_heads, self.cell_anchors
            )
        ]
    
        # Concatenate the scores, boxes, and classes of all feature levels.
        scores = torch.cat([d[0] for d in decoded], 1)
        boxes = torch.cat([d[1] for d in decoded], 1)
        classes = torch.cat([d[2] for d in decoded], 1)
    
        # Implement non-maximum suppression. 
        return nms_cuda(scores, boxes, classes, self.test_nms_thresh, self.max_detections_per_image)
    
    from detectron2.modeling.meta_arch import retinanet
    
    # Replace RetinaNet.forward with adapter_forward. 
    retinanet.RetinaNet.forward = adapter_forward
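For each feature level, decode_cuda receives the class scores, the box regression outputs, the flattened cell anchors, the stride, the score threshold (test_score_thresh), and the candidate limit (test_topk_candidates). The following pure-Python sketch shows what such a decode step computes for one level. It assumes Detectron2's default box parameterization for RetinaNet (center offsets scaled by the anchor size, log-space width and height, regression weights of 1.0) and Detectron2's default scale clamp. The CUDA kernel in retinanet-examples may differ in details such as pixel-offset conventions, and all function names here are hypothetical.

```python
import math
from typing import List, Tuple

Box = Tuple[float, float, float, float]    # (x1, y1, x2, y2)
Delta = Tuple[float, float, float, float]  # (dx, dy, dw, dh)

# Clamp for dw/dh so that exp() cannot overflow (Detectron2's default value).
SCALE_CLAMP = math.log(1000.0 / 16)


def shift_cell_anchors(cell_anchors: List[Box], stride: int,
                       grid_h: int, grid_w: int) -> List[Box]:
    """Place the zero-centered cell anchors at every feature-map location."""
    anchors = []
    for y in range(grid_h):
        for x in range(grid_w):
            sx, sy = x * stride, y * stride
            for (x1, y1, x2, y2) in cell_anchors:
                anchors.append((x1 + sx, y1 + sy, x2 + sx, y2 + sy))
    return anchors


def apply_deltas(anchor: Box, delta: Delta) -> Box:
    """Decode one box-regression output relative to its anchor."""
    w, h = anchor[2] - anchor[0], anchor[3] - anchor[1]
    cx, cy = anchor[0] + 0.5 * w, anchor[1] + 0.5 * h
    dx, dy, dw, dh = delta
    pcx, pcy = cx + dx * w, cy + dy * h
    pw = w * math.exp(min(dw, SCALE_CLAMP))
    ph = h * math.exp(min(dh, SCALE_CLAMP))
    return (pcx - 0.5 * pw, pcy - 0.5 * ph, pcx + 0.5 * pw, pcy + 0.5 * ph)


def decode_level(scores: List[List[float]], deltas: List[Delta],
                 anchors: List[Box], score_thresh: float,
                 topk: int) -> List[Tuple[float, Box, int]]:
    """Threshold the scores, keep the top-k candidates, and decode their boxes.

    scores[a][c] is the sigmoid score of class c for anchor a.
    Returns (score, box, class_id) tuples, highest score first.
    """
    candidates = [(s, a, c)
                  for a, row in enumerate(scores)
                  for c, s in enumerate(row)
                  if s > score_thresh]
    candidates.sort(key=lambda t: t[0], reverse=True)
    return [(s, apply_deltas(anchors[a], deltas[a]), c)
            for s, a, c in candidates[:topk]]
```

In adapter_forward, the stride of each level is derived from the padded input width and the feature-map width, for example 640 // 80 = 8 for a 640-pixel-wide input and an 80-pixel-wide feature map.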

Step 2: Export a TorchScript model

Detectron2 is an open source training framework built by Facebook AI Research (FAIR). It implements object detection and segmentation algorithms and is flexible, extensible, and configurable. Because of this flexibility, exporting a TorchScript model directly with torch.jit.trace or torch.jit.script may fail or produce incorrect results. To export a deployable TorchScript model, use the TracingAdapter class or the scripting_with_instances function that Detectron2 provides. For more information, see Usage.

PAI-Blade can optimize TorchScript models that are exported by either tracing or scripting. This example uses the scripting_with_instances function to export the model. The following sample code provides an example:

import torch
import numpy as np

from torch import Tensor
from torch.testing import assert_allclose

from detectron2 import model_zoo
from detectron2.export import scripting_with_instances
from detectron2.structures import Boxes
from detectron2.data.detection_utils import read_image

# Call the scripting_with_instances method to export the RetinaNet model. 
def load_retinanet(config_path):
    model = model_zoo.get(config_path, trained=True).eval()
    # Add a cell_anchors attribute that holds contiguous copies of the cell anchors. adapter_forward passes these anchors to decode_cuda.
    model.cell_anchors = [c.contiguous() for c in model.anchor_generator.cell_anchors]
    fields = {
        "pred_boxes": Boxes,
        "scores": Tensor,
        "pred_classes": Tensor,
    }
    script_model = scripting_with_instances(model, fields)
    return model, script_model

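# Download a sample image first, for example:
# wget http://images.cocodataset.org/val2017/000000439715.jpg -q -O input.jpg
# Read the image in BGR order, which is the default input format of Detectron2 models, and convert it from HWC to CHW.
img = read_image('./input.jpg', format='BGR')
img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))

# Run the original model and the exported model, and check that their outputs match. 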
pytorch_model, script_model = load_retinanet("COCO-Detection/retinanet_R_50_FPN_3x.yaml")
with torch.no_grad():
    batched_inputs = [{"image": img.float()}]
    pred1 = pytorch_model(batched_inputs)
    pred2 = script_model(batched_inputs)

assert_allclose(pred1[0], pred2[0])

Step 3: Use PAI-Blade to optimize the model

  1. Call the blade.optimize method of PAI-Blade.

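    Call the blade.optimize method to optimize the model. Run the following code in the same Python session as Step 2, because it uses the script_model and batched_inputs variables created there. For more information about the blade.optimize method, see Optimize a PyTorch model. The following sample code provides an example: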

    import os
    import blade
    import torch
    
    # Load the dynamic-link library of custom C++ operators. 
    codebase="retinanet-examples"
    torch.ops.load_library(os.path.join(codebase, 'custom.so'))
    
    blade_config = blade.Config()
    # Disable the accuracy check for FP16 optimization.
    blade_config.gpu_config.disable_fp16_accuracy_check = True
    
    # test_data is a list of tuples. Each tuple holds the positional inputs of one forward call, here the batched_inputs list from Step 2.
    test_data = [(batched_inputs,)]
    
    with blade_config:
        optimized_model, opt_spec, report = blade.optimize(
            script_model,  # The TorchScript model exported in the previous step.
            'o1',  # The optimization level of PAI-Blade. In this example, the optimization level is o1.
            device_type='gpu',  # The type of the device on which the model runs. In this example, the device is a GPU.
            test_data=test_data,  # The test data that PAI-Blade uses for optimization and testing.
        )
  2. Display the optimization report and save the optimized model.

    The model optimized by using PAI-Blade is still a TorchScript model. After the optimization is complete, you can run the following code to display the optimization report and save the optimized model:

    # Display the optimization report. 
    print("Report: {}".format(report))
    # Save the optimized model. 
    torch.jit.save(script_model, 'script_model.pt')
    torch.jit.save(optimized_model, 'optimized.pt')

    The following output is an example of the optimization report. For more information about the parameters in the report, see Optimization report.

    Report: {
      "software_context": [
        {
          "software": "pytorch",
          "version": "1.8.1+cu102"
        },
        {
          "software": "cuda",
          "version": "10.2.0"
        }
      ],
      "hardware_context": {
        "device_type": "gpu",
        "microarchitecture": "T4"
      },
      "user_config": "",
      "diagnosis": {
        "model": "unnamed.pt",
        "test_data_source": "user provided",
        "shape_variation": "undefined",
        "message": "Unable to deduce model inputs information (data type, shape, value range, etc.)",
        "test_data_info": "0 shape: (3, 480, 640) data type: float32"
      },
      "optimizations": [
        {
          "name": "PtTrtPassFp16",
          "status": "effective",
          "speedup": "3.92",
          "pre_run": "40.72 ms",
          "post_run": "10.39 ms"
        }
      ],
      "overall": {
        "baseline": "40.64 ms",
        "optimized": "10.41 ms",
        "speedup": "3.90"
      },
      "model_info": {
        "input_format": "torch_script"
      },
      "compatibility_list": [
        {
          "device_type": "gpu",
          "microarchitecture": "T4"
        }
      ],
      "model_sdk": {}
    }
  3. Test the performance of the original model and the optimized model.

    The following sample code provides an example on how to test the performance of the models:

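    import time
    import torch
    
    @torch.no_grad()
    def benchmark(model, inp):
        # Warm up so that one-time initialization does not skew the measurement.
        for _ in range(100):
            model(inp)
        # CUDA kernels run asynchronously, so wait for all queued work before and after timing.
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(200):
            model(inp)
        torch.cuda.synchronize()
        elapsed_ms = (time.time() - start) * 1000
        # Print the average latency of one run, in milliseconds.
        print("Latency: {:.2f}".format(elapsed_ms / 200))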
    
    # Test the latency of the original model. 
    benchmark(script_model, batched_inputs)
    # Test the latency of the optimized model. 
    benchmark(optimized_model, batched_inputs)

    The following results are for reference only:

    Latency: 40.65
    Latency: 10.46

    The preceding results show that, averaged over 200 runs after 100 warm-up runs, the latency of the original model is 40.65 ms and the latency of the optimized model is 10.46 ms, which is a speedup of about 3.9x.
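The benchmark function above prints its result. If you want the numbers back, for example to compute the speedup in code, a framework-agnostic variant could return the mean latency instead. In this sketch, mean_latency_ms and speedup are hypothetical helpers, and the optional sync hook stands in for torch.cuda.synchronize so that the timer waits for asynchronous GPU work. time.perf_counter is used because it is a monotonic, high-resolution clock.

```python
import time
from typing import Callable, Optional


def mean_latency_ms(fn: Callable[[], object], warmup: int = 100,
                    iters: int = 200,
                    sync: Optional[Callable[[], None]] = None) -> float:
    """Return the average wall-clock latency of fn() in milliseconds.

    sync, if given, is called right before the timer starts and right
    before it stops, for example torch.cuda.synchronize for GPU models.
    """
    # Warm-up runs are executed but not timed.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) * 1000 / iters


def speedup(baseline_ms: float, optimized_ms: float) -> float:
    """Ratio of baseline latency to optimized latency."""
    return baseline_ms / optimized_ms
```

With the models from this step, you could call mean_latency_ms(lambda: optimized_model(batched_inputs), sync=torch.cuda.synchronize) inside a torch.no_grad() block. Applied to the reference results, speedup(40.65, 10.46) is about 3.89.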

Step 4: Load and run the optimized model

  1. Optional: During the trial period, set the following environment variable to prevent the program from exiting unexpectedly because of an authentication failure:

    export BLADE_AUTH_USE_COUNTING=1
  2. Configure authentication for PAI-Blade.

    export BLADE_REGION=<region>
    export BLADE_TOKEN=<token>

    Replace the placeholders with your own values:

    • <region>: the region where you use PAI-Blade. You can join the DingTalk group of PAI-Blade users to obtain the regions where PAI-Blade can be used.

    • <token>: the authentication token that is required to use PAI-Blade. You can join the DingTalk group of PAI-Blade users to obtain the authentication token.

  3. Load and run the optimized model.

    The model optimized by PAI-Blade is still a TorchScript model, so you can load it with torch.jit.load. Before you load it, import blade.runtime.torch and load the dynamic-link library of the custom C++ operators, as shown in the following code:

    import blade.runtime.torch
    import detectron2
    import torch
    import numpy as np
    import os
    from detectron2.data.detection_utils import read_image
    from torch.testing import assert_allclose
    
    # Load the dynamic-link library of custom C++ operators. 
    codebase="retinanet-examples"
    torch.ops.load_library(os.path.join(codebase, 'custom.so'))
    
    script_model = torch.jit.load('script_model.pt')
    optimized_model = torch.jit.load('optimized.pt')
    
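    img = read_image('./input.jpg', format='BGR')
    img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
    
    # Run the exported model and the optimized model, and check that their outputs are close. The tolerances are relaxed because the optimized model uses FP16 (PtTrtPassFp16 in the optimization report).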
    with torch.no_grad():
        batched_inputs = [{"image": img.float()}]
        pred1 = script_model(batched_inputs)
        pred2 = optimized_model(batched_inputs)
    
    assert_allclose(pred1[0], pred2[0], rtol=1e-3, atol=1e-2)
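The assert_allclose check in the preceding code passes when every element satisfies |actual - expected| <= atol + rtol * |expected|. The following pure-Python sketch spells out that rule with the same tolerances (rtol=1e-3, atol=1e-2). within_tolerance is a hypothetical helper for illustration, not a PyTorch API.

```python
from typing import Sequence


def within_tolerance(actual: Sequence[float], expected: Sequence[float],
                     rtol: float = 1e-3, atol: float = 1e-2) -> bool:
    """Element-wise check of |actual - expected| <= atol + rtol * |expected|."""
    if len(actual) != len(expected):
        raise ValueError("actual and expected must have the same length")
    return all(abs(a - e) <= atol + rtol * abs(e)
               for a, e in zip(actual, expected))
```

Because the allowed difference grows with the magnitude of the expected value, a value near 500 may differ by up to 0.01 + 0.001 * 500 = 0.51, while a value near 0.9 may differ by only about 0.011.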