すべてのプロダクト
Search
ドキュメントセンター

AnalyticDB:ケース: 画像ベースの検索システムを構築する

最終更新日:Sep 24, 2024

このトピックでは、AnalyticDB for PostgreSQLベクトルデータベースを使用してイメージベースの検索システムを構築する方法について説明します。

背景情報

画像ベースの検索は私たちの生活の中でよく使われます。 美しいドレスやお気に入りのスニーカーをテレビで見せたい場合は、画像を撮ってから淘宝網にアップロードして製品をすばやく見つけることができます。 映画の名前を知りたい場合は、検索エンジンの画像検索ボックスに映画のスクリーンショットを貼り付けることができます。 画像ベースの検索では、多数の人間の画像から人の画像をすばやく見つけることができます。 検索エンジンによって提供される画像ベースの検索機能に加えて、AnalyticDB for PostgreSQLベクトルデータベースによって提供されるベクトル検索機能を使用して、SQL構文に基づいて画像ベースの検索システムを構築できます。

イメージベース検索の概要

画像ベースの検索は、逆画像検索とも呼ばれる。 これは、コンテンツベースの画像検索技術である。 画像ベースの検索システムは、多数の画像から、問い合わせ対象の画像にコンテンツが最も近いレコードを返すことができる。 たとえば、商品を含む画像を指定した場合、画像ベースの検索システムは、商品と同じまたは類似した主要なオブジェクトを含む画像を返します。 人間の顔を含む画像を提供すると、画像ベースの検索システムは、人間の顔と同様の顔の特徴を共有する画像を返します。

イメージベースの検索は、2つのコアモジュールに依存します。

  • 特徴抽出モジュール: ソース画像から視覚的特徴を抽出して、高次元特徴ベクトルを取得する。 画像の特徴ベクトルがソース画像の特徴ベクトルに近いほど、画像はソース画像により類似する。

  • ベクトル検索モジュール: 多数の特徴ベクトルからクエリベクトルに最も近いk個のレコードを見つけて返す。

画像特徴抽出

一般的なフィーチャ抽出アルゴリズムでは、Visual Geometry Group (VGG) 、ResNet、Transformerなどの深層学習モデルをバックボーンネットワークとして使用し、さまざまなメソッドを使用してフィーチャを生成します。 フィーチャを生成するには、一般的に3つの方法があります。

  • 方法1: VGGなどの分類モデルの分類レイヤーの前のレイヤーを出力フィーチャとして使用します。 これは最も単純な方法であり、画像ベースの検索シナリオでは高い再現率をもたらさない。

  • 方法2: モデルの中間層内のフィーチャに対して、畳み込みの局所的最大活性化 (RMAC) およびGeMなどの特別なプーリング操作を実行し、次いで、フィーチャの寸法を減少させる。

  • 方法3: 特定の損失関数を使用して、データセット上で事前トレーニングされたモデルをトレーニングし、特徴を抽出する。 例えば、製品特徴抽出モデルは、より正確な方法で異なる製品の視覚的特徴を抽出するために、類似の製品のデータセットに対して訓練される必要がある。

ビジネスシナリオに適した方法を選択して、画像フィーチャを抽出し、フィーチャベクトルを生成できます。

ベクトル検索

ベクトル検索は、最近傍探索 (NNS) とも呼ばれる。 それは、多数の特徴ベクトルの中から問い合わせベクトルに最も近いk個のレコードを見つけて返す。 クエリベクトルとデータベース内のすべてのベクトルとの間の距離を計算し、距離の結果を並べ替えることができます。 しかし、この方法は時間がかかり、大量のデータの要件を満たすことができません。

実際のアプリケーションシナリオでは、近似最近傍 (ANN: approximate nearest neighbor) 検索を一般的に使用して、クエリベクトルにおそらく最も近いデータを高速ではあるが精度が低い方法で返す。

次のいずれかの方法を使用して、ANN検索を実行できます。

  • 地域に敏感なハッシュ (LSH) に基づく

  • 製品の量子化に基づく

  • 画像に基づく

AnalyticDB for PostgreSQLベクトルデータベースを使用して画像で検索する

ステップ1: 特徴ベクトルの抽出

この例では、次のツールが使用されています。

  • プログラミング言語: Python 3.8。

  • ディープラーニングのフレームワーク: Pytorch。

  • データセット: CIFAR100。 データセットは、それぞれが600画像を含む100のカテゴリからなる。

  • ネットワーク: 事前にトレーニングされたSqueezeNet。 SqueezeNetは軽量で、1,000次元の特徴ベクトルを生成します。

説明

Jupyter Notebookを使用して次のコードを実行することを推奨します。

  1. Python環境を作成します。

    # We recommend that you use Anaconda to create a Python environment. 
    conda create -n adbpg_env python=3.8
    conda activate adbpg_env
    
    pip install torchvision
    pip install matplotlib
    pip install psycopg2cffi
  2. CIFAR100データセットをダウンロードして前処理します。

    import torch
    import torchvision
    
    from torchvision.transforms import (
        Compose, 
        Resize, 
        CenterCrop, 
        ToTensor, 
        Normalize
    )
    
    preprocess = Compose([
        Resize(256),
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    DATA_DIRECTORY = "/Users/XXX/Desktop/vector/CIFAR"
    datasets = {
        "CIFAR100": torchvision.datasets.CIFAR100(DATA_DIRECTORY, transform=preprocess, download=True)
    }
  3. (オプション) ダウンロードしたデータセットを照会します。

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import ImageGrid
    
    def show_images_from_full_dataset(dset, num_rows, num_cols, indices):        
        im_arrays = np.take(dset.data, indices, axis=0)
        labels = map(dset.classes.__getitem__, np.take(dset.targets, indices))
    
        fig = plt.figure(figsize=(10, 10))
        grid = ImageGrid(
            fig, 
            111,
            nrows_ncols=(num_rows, num_cols),
            axes_pad=0.3)
        for ax, im_array, label in zip(grid, im_arrays, labels):
            ax.imshow(im_array)
            ax.set_title(label)
            ax.axis("off")
    
    dataset = datasets["CIFAR100"]
    show_images_from_full_dataset(dataset, 4, 8, [i for i in range(0, 32)])

    p684529.png

  4. Squeezenet1_1モデルを使用して、すべての画像の特徴ベクトルをバッチで生成し、特徴ベクトルファイルに保存します。 この例では、/Users/XXX/Desktop/vector/features/CIFAR100/featuresが特徴ベクトルファイルのパスです。

    # Prepare data. 
    BATCH_SIZE = 100
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
    
    # Download the model. 
    model = torchvision.models.squeezenet1_1(pretrained=True).eval()
    
    # Extract the feature vectors and write them to the features_file_path file. 
    features_file_path = "/Users/XXX/Desktop/vector/features/CIFAR100/features"
    feature_file = open(features_file_path, 'w')
    img_id = 0
    for batch_number, batch in enumerate(dataloader):
        with torch.no_grad():
            batch_imgs = batch[0]  # 0: images
            batch_labels = batch[1]  # 1: labels
            vector_values = model(batch_imgs).tolist()
    
            for i in range(len(vector_values)):
                img_label = dataset.classes[batch_labels[i].item()]
                # print(img_label)
                feature_file.write(str(img_id) + "|" + img_label + "|")
                
                vector_value = vector_values[i]
                assert len(vector_value) == 1000
    
                for j in range(len(vector_value)):
                    if j == 0:
                        feature_file.write("{")
                        feature_file.write(str(vector_value[j]) + ",")
                    elif j == len(vector_value) - 1:
                        feature_file.write(str(vector_value[j]))
                        feature_file.write("}")
                    else:
                        feature_file.write(str(vector_value[j]) + ",")
                feature_file.write("\n")
                
                img_id = img_id + 1
            print("finished extract feature vector for batch: ", batch_number)
    feature_file.close()

    画像の特徴ベクトルは、次の形式で表示されます。

    [2.67548513424756,2.186723470687866,2.376999616622925,2.3993351459503174,2.833254337310791,
    4.141584873199463,1.0177937746047974,2.0199387073516846,2.436871512298584,1.465838789939880,
    4,10.196249008178711,3.3932418823242188,6.087968826293945,7.661309242248535,7.66005373001098,
    6,5.481011390686035,7.513026237487795,5.552321434020996,4.685927867889404,5.635070323944092,...]

ステップ2: AnalyticDB for PostgreSQLベクトルデータベースにデータをインポートし、データを照会する

  1. テーブルを作成し、テーブルにベクトルインデックスを作成します。 この例では、Pythonのpsycopg2cffiライブラリを使用してベクターデータベースに接続します。

    重要

    データベースのベクトル機能を有効にする場合は、

    チケットを起票します。

    import os
    import psycopg2cffi
    
    # Configure the temporary environment variables. 
    # os.environ["PGHOST"] = "XX.XXX.XX.XXX"
    # os.environ["PGPORT"] = "XXXXX"
    # os.environ["PGDATABASE"] = "adbpg_test"
    # os.environ["PGUSER"] = "adbpg_test"
    # os.environ["PGPASSWORD"] = "adbpg_test"
    
    connection = psycopg2cffi.connect(
        host=os.environ.get("PGHOST", "XX.XXX.XX.XXX"),
        port=os.environ.get("PGPORT", "XXXXX"),
        database=os.environ.get("PGDATABASE", "adbpg_test"),
        user=os.environ.get("PGUSER", "adbpg_test"),
        password=os.environ.get("PGPASSWORD", "adbpg_test")
    )
    
    cursor = connection.cursor()
    
    # Specify an SQL statement to create a table. 
    create_table_sql = """
    CREATE TABLE IF NOT EXISTS public.image_search (
        id INTEGER NOT NULL,
        class TEXT,
        image_vector REAL[],
        PRIMARY KEY(id)
    ) DISTRIBUTED BY(id);
    """
    
    # Specify an SQL statement to change the storage format of the vector column to PLAIN. 
    alter_vector_storage_sql = """
    ALTER TABLE public.image_search ALTER COLUMN image_vector SET STORAGE PLAIN;
    """
    
    # Specify an SQL statement to create a vector index. 
    create_indexes_sql = """
    CREATE INDEX ON public.image_search USING ann (image_vector) WITH (dim = '1000', hnsw_m = '100', pq_enable='0');
    """
    
    # Execute the preceding SQL statements. 
    cursor.execute(create_table_sql)
    cursor.execute(alter_vector_storage_sql)
    cursor.execute(create_indexes_sql)
    connection.commit()
  2. データセットに含まれる画像の特徴ベクトルをテーブルにインポートします。

    import io
    
    # Define a generator function to process the data in the file line by line. 
    def process_file(file_path):
        with open(file_path, 'r') as file:
            for line in file:
                yield line
    
    # Specify a COPY statement to import data. 
    copy_command = """
    COPY public.image_search (id, class, image_vector)
    FROM STDIN WITH (DELIMITER '|');
    """
    
    # Prepare the feature vector file. 
    features_file_path = "/Users/XXX/Desktop/vector/features/CIFAR100/features"
    
    # Execute the COPY statement. 
    modified_lines = io.StringIO(''.join(list(process_file(features_file_path))))
    cursor.copy_expert(copy_command, modified_lines)
    connection.commit()
  3. 特徴ベクトルファイルに含まれる画像の特徴ベクトルに基づいて画像を検索します。 この例では、IDが4999である画像が使用される。

    def query_analyticdb(collection_name, vector_name, query_embedding, top_k=20):
        # Specify an SQL statement to return images whose feature vectors are the closest to the query vector and calculate the similarity to the query vector. 
        query_sql = f"""
        SELECT id, class, l2_distance({vector_name},Array{query_embedding}::real[]) AS similarity
        FROM {collection_name}
        ORDER BY {vector_name} <-> Array{query_embedding}::real[]
        LIMIT {top_k};
        """
    
        # Execute the preceding SQL statement. 
        connection = psycopg2cffi.connect(
            host=os.environ.get("PGHOST", "XX.XXX.XX.XXX"),
            port=os.environ.get("PGPORT", "XXXXX"),
            database=os.environ.get("PGDATABASE", "adbpg_test"),
            user=os.environ.get("PGUSER", "adbpg_test"),
            password=os.environ.get("PGPASSWORD", "adbpg_test")
        )
    
        cursor = connection.cursor()
        cursor.execute(query_sql)
        results = cursor.fetchall()
        
        return results
      
      # Select a piece of data as the query vector. 
    def select_feature(file_path, expect_id):
        with open(file_path, 'r') as file:
            for line in file:
                datas = line.split('|')
                if datas[0] == str(expect_id):
                    vec = '[' + datas[2][1:-2] + ']'
                    return vec
        raise ValueError(f"no id = {expect_id}")
    
    file_path = "/Users/lizhenjing/Desktop/vector/features/CIFAR100/features"
    
    # Select an image whose ID is 4999. 
    query_vector = select_feature(file_path, 4999)
    # Display this image. 
    # show_images_from_full_dataset(dataset, 1, 1, [4999], figsize=(1, 1))
    # print(query_vector)
    
    # Execute the query to display the query results. 
    results = query_analyticdb("image_search", "image_vector", query_vector)

    次の図は、IDが4999の画像を示しています。

    p684560.png

  4. クエリ結果に対応する画像を表示します。

    説明

    AnalyticDB for PostgreSQLベクトルデータベースは、クエリを高速化するためのANN検索機能を提供します。

    # Obtain the image IDs from the query results in the previous step. 
    indices = []
    for item in results:
        indices.append(item[0])
    print(indices)
    
    # Display the images. 
    show_images_from_full_dataset(dataset, 4, 5, indices)

    次の図は、返されたイメージを示しています。p684568.png