[DSW Gallery] Metric Learning Example Based on Residual Network

EasyCV metric learning
  Metric learning, also known as similarity learning, has a wide range of applications. It learns a distance metric suited to a given dataset, which can then be used to model and solve practical problems. Distance metric learning is a typical machine learning task and is usually combined with well-known distance-based methods (such as KNN and K-means) to perform classification or clustering, where it typically works very well.
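As a concrete illustration of the KNN case, the sketch below classifies a query embedding by majority vote among its nearest neighbors under Euclidean distance. It is a generic, self-contained sketch rather than EasyCV code; the embeddings and labels are made-up toy values standing in for feature vectors produced by a trained metric model.

```python
from collections import Counter
import math


def euclidean(a, b):
    """Euclidean distance between two equal-length feature vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def knn_predict(query, gallery, labels, k=3):
    """Predict the label of `query` by majority vote among its k nearest gallery embeddings."""
    # Sort gallery indices by distance to the query, closest first.
    order = sorted(range(len(gallery)), key=lambda i: euclidean(query, gallery[i]))
    # Count the labels of the k closest embeddings and return the most frequent one.
    votes = Counter(labels[i] for i in order[:k])
    return votes.most_common(1)[0][0]


# Toy 2-D embeddings: two well-separated classes (hypothetical values).
gallery = [[0.0, 0.1], [0.2, 0.0], [0.1, 0.2], [5.0, 5.1], [5.2, 4.9]]
labels = ["albatross", "albatross", "albatross", "sparrow", "sparrow"]
print(knn_predict([0.15, 0.05], gallery, labels, k=3))  # albatross
```

A good metric model makes this simple vote effective because same-class images end up close together in the embedding space.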
  This article shows how to quickly train a metric learning model with EasyCV in PAI-DSW and run inference with it.
Environment requirements
PAI-PyTorch 1.7/1.8 image, P100 or V100 GPU, 32 GB memory
Install dependencies
Note: The dependencies are preinstalled in the PAI-DSW Docker image, so you can skip steps 1 and 2 there. In a local notebook environment, run steps 1 and 2 to set up the environment.
1. Get the torch and CUDA versions, adjust the mmcv installation command to match them, and install the corresponding versions of mmcv and nvidia-dali.
import torch
import os
os.environ['CUDA']='cu' + torch.version.cuda.replace('.', '')
os.environ['Torch']='torch'+torch.version.__version__.replace('+PAI', '')
!echo $CUDA
!echo $Torch
cu101
torch1.8.1+cu101
# install some python deps
!pip install --upgrade tqdm
!pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
!pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/third_party/nvidia_dali_cuda100-0.25.0-1535750-py3-none-manylinux2014_x86_64.whl
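Rather than editing the mmcv index URL by hand, you could derive it from the versions printed above. The sketch below uses a hypothetical helper, mmcv_index_url, which is not part of any library. It assumes the OpenMMLab index layout .../mmcv/dist/cu<XY>/torch<major>.<minor>.0/index.html, the same layout as the cu101/torch1.8.0 URL in the install command; verify against the OpenMMLab installation docs for your versions.

```python
def mmcv_index_url(torch_version, cuda_version):
    """Build an OpenMMLab mmcv-full wheel index URL from torch/CUDA version strings.

    Assumption: wheels built for torch <major>.<minor>.0 are published under
    torch<major>.<minor>.0 and reused for later patch releases (e.g. 1.8.1).
    """
    # Drop local suffixes such as '+cu101' and keep only major.minor.
    major, minor = torch_version.split("+")[0].split(".")[:2]
    # '10.1' -> 'cu101'
    cuda_tag = "cu" + cuda_version.replace(".", "")
    return (
        "https://download.openmmlab.com/mmcv/dist/"
        f"{cuda_tag}/torch{major}.{minor}.0/index.html"
    )


print(mmcv_index_url("1.8.1+cu101", "10.1"))
# https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
```

In the notebook, mmcv_index_url(torch.__version__, torch.version.cuda) would give the URL to pass to pip install -f. torch.version.cuda is None on CPU-only builds, so this only applies to CUDA builds of PyTorch.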
2. Install the EasyCV algorithm package. Note: The pai-easycv library is preinstalled in the PAI-DSW Docker image, so this step can be skipped there. If an error occurs during training or testing, try updating EasyCV with the following commands.
#pip install pai-easycv
!pip uninstall -y pai-easycv easycv
!pip install http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/pkgs/whl/2022_6/pai_easycv-0.3.0-py3-none-any.whl
3. Verify the installation
from easycv.apis import *
CUB200 Metric Learning
The following example shows how to use the CUB200 dataset and a ResNet50 model to quickly train and evaluate a metric learning model and then run prediction with it.
Data preparation
Download the CUB200 data and extract it to the data/cub200 directory. The directory structure is as follows:
data/cub200
├── images
├── images.txt
├── image_class_labels.txt
├── train_test_split.txt
!mkdir -p data/ && wget https://s3.amazonaws.com/fast-ai-imageclas/CUB_200_2011.tgz && tar -xzf CUB_200_2011.tgz -C data/ && mv data/CUB_200_2011 data/cub200
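To sanity-check the extracted data, the sketch below reads the three CUB-200-2011 metadata files listed in the directory tree, where each line has the form "<image_id> <value>". The helper load_cub_metadata is hypothetical and is not how EasyCV loads the data; many metric-learning setups split CUB by class rather than by train_test_split.txt, so check the config for the split it actually uses.

```python
import os


def load_cub_metadata(root):
    """Return a list of (image_path, class_id, is_train) tuples for a CUB-200-2011 root."""
    def read_pairs(name):
        # Each non-empty line is "<image_id> <value>"; build {image_id: value}.
        with open(os.path.join(root, name)) as f:
            return dict(line.split(maxsplit=1) for line in f if line.strip())

    names = read_pairs("images.txt")
    classes = read_pairs("image_class_labels.txt")
    splits = read_pairs("train_test_split.txt")
    records = []
    for image_id, rel_path in names.items():
        records.append((
            os.path.join(root, "images", rel_path.strip()),
            int(classes[image_id]),            # class ids are 1-based in CUB
            splits[image_id].strip() == "1",   # 1 = training image, 0 = test image
        ))
    return records
```

For example, records = load_cub_metadata("data/cub200") followed by print(len(records), records[0]); the full CUB-200-2011 release contains 11,788 images.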
Train the model
Download the training configuration file. By default, it loads weights from an ImageNet-pretrained model. To use a self-supervised pretrained model instead, download it from the link above and update the corresponding setting in the config.
!rm -rf cub_resnet50_jpg.py
!wget https://raw.githubusercontent.com/alibaba/EasyCV/master/configs/metric_learning/cub_resnet50_jpg.py
Train on a single GPU and evaluate on the validation set. For a quicker run, you can set the total_epoch parameter in cub_resnet50_jpg.py to 10.
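If you prefer to change the epoch count from the notebook instead of an editor, the sketch below rewrites a top-level integer assignment in the config file. The helper set_config_value is hypothetical and assumes the epoch count appears as a simple "name = <int>" line at the start of a line; if the variable is spelled differently in the file (for example total_epochs), it raises ValueError instead of silently doing nothing, so open the file to confirm the exact name.

```python
import re


def set_config_value(path, key, value):
    """Rewrite a top-level `key = <int>` assignment in a Python-style config file.

    Hypothetical helper: only handles integer assignments that start a line.
    Returns the number of assignments rewritten.
    """
    with open(path) as f:
        text = f.read()
    # ^ with MULTILINE matches at the start of every line, so indented
    # (nested) assignments are left untouched.
    pattern = re.compile(r"^(%s\s*=\s*)\d+" % re.escape(key), re.MULTILINE)
    new_text, count = pattern.subn(r"\g<1>%d" % value, text)
    if count == 0:
        raise ValueError("no integer assignment for %r found in %s" % (key, path))
    with open(path, "w") as f:
        f.write(new_text)
    return count
```

Usage: set_config_value("cub_resnet50_jpg.py", "total_epoch", 10), using the variable name as it appears in the downloaded file.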
!python -m torch.distributed.launch --nproc_per_node=1 --master_port=29500 /home/pai/lib/python3.6/site-packages/easycv/tools/train.py cub_resnet50_jpg.py --work_dir work_dirs/metric_learning/cub/r50 --launcher pytorch --fp16
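The training command hard-codes a Python 3.6 site-packages path, which breaks in environments with a different Python version or install prefix. The sketch below locates a file inside an installed package without importing it, using importlib.util.find_spec from the standard library. The helper name package_file is hypothetical, and it assumes easycv is installed as a regular package containing tools/train.py, as the hard-coded path implies.

```python
import importlib.util
import os


def package_file(package, *parts):
    """Return the path of a file inside an installed top-level package without importing it."""
    spec = importlib.util.find_spec(package)
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError("package %r is not installed" % package)
    # For a regular package, spec.origin points at its __init__.py,
    # so its directory is the package root.
    return os.path.join(os.path.dirname(spec.origin), *parts)
```

For example, train_py = package_file("easycv", "tools", "train.py"); IPython expands {train_py} inside a ! shell command, so the launch command can reference {train_py} instead of the fixed path.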
Export the model
After training completes, use the export command to export the model for inference. The exported model contains the preprocessing and postprocessing information required for inference.
# View the .pth files generated by training
!ls work_dirs/metric_learning/cub/r50*
RetrivalTopKEvaluator_R@K=1_best.pth is the checkpoint with the best retrieval Recall@1 (R@K=1) on the validation set during training. Export it as follows:
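Assuming R@K=1 refers to the standard retrieval Recall@K metric, the sketch below shows what it measures: for each sample used as a query, check whether any of its K nearest other samples shares its label, then average over all queries. It uses Euclidean distance on toy embeddings; EasyCV's RetrivalTopKEvaluator may use a different similarity (for example cosine on normalized features), so treat this as an illustration of the metric, not of that evaluator.

```python
import math


def recall_at_k(embeddings, labels, k=1):
    """Fraction of queries whose k nearest neighbours (excluding the query itself)
    contain at least one sample with the same label."""
    def dist(a, b):
        return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

    hits = 0
    for i, query in enumerate(embeddings):
        # Rank every other sample by distance to the query, closest first.
        others = sorted((j for j in range(len(embeddings)) if j != i),
                        key=lambda j: dist(query, embeddings[j]))
        if any(labels[j] == labels[i] for j in others[:k]):
            hits += 1
    return hits / len(embeddings)


# Toy embeddings: the last sample (label 0) sits next to the label-1 cluster.
emb = [[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.0], [0.9, 0.9]]
lab = [0, 0, 1, 1, 0]
print(recall_at_k(emb, lab, k=1))  # 0.8
```

Only the last query misses at K=1; with k=3 its neighbour list reaches a label-0 sample and the score becomes 1.0, which is why Recall@K never decreases as K grows.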
!python -m easycv.tools.export ./cub_resnet50_jpg.py work_dirs/metric_learning/cub/r50/RetrivalTopKEvaluator_R@K=1_best.pth work_dirs/metric_learning/cub/r50/best_export.pth
Predict
Download a test image
!wget http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/fine_grain_cls/cub_raw/images/001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg
Load the exported model and extract features for the test image
import cv2
from easycv.predictors.feature_extractor import TorchFeatureExtractor
# Path to the model exported in the previous step
output_ckpt = 'work_dirs/metric_learning/cub/r50/best_export.pth'
tcls = TorchFeatureExtractor(output_ckpt)
# cv2.imread loads the image in BGR order
img = cv2.imread('Black_Footed_Albatross_0001_796111.jpg')
# input image should be RGB order
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# predict takes a list of images and returns one result per image
output = tcls.predict([img])
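A metric model produces embeddings rather than class labels, so a typical next step is to compare the query's features against a gallery of features from other images. How to pull the feature vector out of output depends on the EasyCV version (inspect output[0]); the sketch below assumes features are plain lists of floats and ranks a hypothetical gallery by cosine similarity.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two non-zero feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_gallery(query, gallery):
    """Return gallery indices sorted from most to least similar to the query."""
    return sorted(range(len(gallery)),
                  key=lambda i: cosine_similarity(query, gallery[i]),
                  reverse=True)


# Toy 2-D features (hypothetical values).
print(rank_gallery([1.0, 0.0], [[0.0, 1.0], [0.9, 0.1], [0.5, 0.5]]))  # [1, 2, 0]
```

The top-ranked gallery images are the retrieval results; taking a majority vote over their labels turns the same ranking into a KNN classifier.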
