すべてのプロダクト
Search
ドキュメントセンター

Artificial Intelligence Recommendation:DropoutNet モデルのトレーニングとデプロイ

最終更新日:Nov 09, 2025

コールドスタート DropoutNet アルゴリズムの詳細については、「コールドスタート推薦のための DropoutNet モデルの詳細な分析と改善」をご参照ください。

オフライン トレーニングサンプルの準備

テンプレートを使用して SQL コードを生成し、オフライン トレーニングサンプルを構築できます。

テンプレート構成:

{
  "cold_start_recall": {
    "model_name": "cold_start",
    "model_type": "dropoutnet",
    "label": {
      "name": "is_click",
      "selection": "max(if(event=\"click\", 1, 0))",  // positive label: 1 if the item was ever clicked in the window, else 0
      "type": "CLASSIFICATION"
    },
    "train_days": 14
  }
}

DropoutNet モデルのトレーニング

Platform for AI (PAI) コマンドを実行してモデルをトレーニングできます。

-- Train the DropoutNet model with EasyRec on MaxCompute (PAI).
-- fine_tune_checkpoint warm-starts training from the previous day's model;
-- the cluster uses 1 parameter server and 9 workers.
pai -name easy_rec_ext
-Dcmd='train'
-Dconfig='oss://${bucket}/EasyRec/sv_dropout_net/sv_dropoutnet.config'
-Dtrain_tables='odps://${project}/tables/dwd_samples_for_dropoutnet/dt=${bizdate}'
-Deval_tables='odps://${project}/tables/dwd_sv_cold_start_samples/dt=${bizdate}'
-Dboundary_table='odps://${project}/tables/cold_start_feature_binning'
-Dmodel_dir='oss://${bucket}/EasyRec/sv_dropout_net/${bizdate}'
-Dedit_config_json='{"train_config.fine_tune_checkpoint":"oss://${bucket}/EasyRec/sv_dropout_net/${yesterday}/"}'
-Dbuckets='oss://${bucket}/'
-Darn='acs:ram::XXXXXXXXXXX:role/aliyunodpspaidefaultrole'
-DossHost='oss-cn-beijing-internal.aliyuncs.com'
-Dcluster='{
  \"ps\": {
  \"count\" : 1,
  \"cpu\" : 800
  },
  \"worker\" : {
  \"count\" : 9,
  \"cpu\" : 800
  }
}';

モデルをユーザー埋め込み子モデルとアイテム埋め込み子モデルに分割できます。

-- Split the trained model into the user-embedding and item-embedding
-- sub-models; each sub-model is exported to its own OSS directory.
pai -name tensorflow1120_cpu_ext
-Dscript='oss://${bucket}/EasyRec/sv_dropout_net/split_model_pai_v2.py'
-Dbuckets='oss://${bucket}/'
-Darn='acs:ram::XXXXXXXXXXXX:role/aliyunodpspaidefaultrole'
-DossHost='oss-cn-beijing-internal.aliyuncs.com'
-DuserDefinedParameters='--model_dir=oss://${bucket}/EasyRec/sv_dropout_net/${bizdate}/export/final --user_model_dir=oss://${bucket}/EasyRec/sv_dropout_net/${bizdate}/export/user --item_model_dir=oss://${bucket}/EasyRec/sv_dropout_net/${bizdate}/export/item';

モデル サービスのデプロイ

デプロイメントスクリプト

# Deploy the DropoutNet item sub-model as an EAS online service.
# $1: partition date (bizdate) of the exported model to deploy.
bizdate=$1
cat << EOF > eas_config.json
{
    "name": "sv_dropoutnet",
    "metadata": {
        "cpu": 2,
        "instance": 1,
        "memory": 6000
    },
    "processor": "tensorflow_cpu",
    "model_path": "oss://${bucket}/EasyRec/sv_dropout_net/${bizdate}/export/item/"
}
EOF

# Create the service (first deployment).
 /home/admin/usertools/tools/eascmd \
 -i ${accessId} \
 -k ${accessKey} \
 -e pai-eas.cn-beijing.aliyuncs.com create eas_config.json

# Update the service — use this instead of "create" once the service exists.
#/home/admin/usertools/tools/eascmd \
# -i ${accessId} \
# -k ${accessKey} \
# -e pai-eas.cn-beijing.aliyuncs.com \
# modify sv_dropoutnet -s eas_config.json

# Show the service status.
echo "-------------------View the service-------------------"
/home/admin/usertools/tools/eascmd \
-i ${accessId} \
-k ${accessKey} \
-e pai-eas.cn-beijing.aliyuncs.com desc sv_dropoutnet

リアルタイム機能の計算

1. Flink のアイテムテーブルの binlog に接続する

Flink では、Hologres アイテムテーブルの binlog と新しいアイテムのビューを作成できます。

-- Read the Hologres item table's binlog as a Flink source.
-- An event-time column (ets) is derived from the binlog timestamp
-- (microseconds -> seconds) with a 5-minute watermark for out-of-order events.
create TEMPORARY table item_table_binlog (
  hg_binlog_lsn BIGINT,
  hg_binlog_event_type BIGINT,
  hg_binlog_timestamp_us BIGINT,
  itemId bigint,
  ...
  createTime TIMESTAMP,
  ets AS TO_TIMESTAMP(FROM_UNIXTIME(hg_binlog_timestamp_us / 1000000)),
  WATERMARK FOR ets AS ets - INTERVAL '5' MINUTE
) with (
  'connector'='hologres',
  'endpoint' = 'hgpostcn-cn-XXXXXX-cn-beijing-vpc.hologres.aliyuncs.com:80',
  'username' = '${username}',
  'password' = '${password}',  -- fixed placeholder typo (was '${passowrod}')
  'dbname' = '${dbname}',
  'tablename' = 'item_table',  -- the item table
  'binlog' = 'true',
  'binlogMaxRetryTimes' = '10',
  'binlogRetryIntervalMs' = '500',
  'binlogBatchReadSize' = '256',
  'startTime' = '2022-03-03 00:00:00'
);

-- View of "new" items: rows inserted or updated in the binlog whose items
-- were created within the last 24 hours.
-- NOTE(review): the source table was originally written as smart_video_binlog,
-- which is not defined on this page; changed to item_table_binlog to match the
-- binlog table created in the previous step — confirm against your setup.
CREATE TEMPORARY VIEW if NOT EXISTS new_item_view
AS
SELECT itemId, ..., createTime,  -- item ID, creation time, and other columns
  PROCTIME() AS proc_time, ets
FROM item_table_binlog
WHERE hg_binlog_event_type IN (5, 7)  -- INSERT=5, AFTER_UPDATE=7
AND createTime >= CURRENT_TIMESTAMP - INTERVAL '24' HOUR  -- created within the last 24 hours
;

2. アイテム機能を結合する

Hologres アイテム特徴テーブルを作成し、次に Flink で一時テーブルを作成してシンク先のテーブルとして使用できます。

-- Flink sink table backed by the Hologres cold-start feature table.
create TEMPORARY table item_cold_start_feature (  -- item cold-start features
  itemId bigint,
  ...
  update_time bigint
) with (
  'connector'='hologres',
  'dbname'='${dbname}',
  'tablename'='sv_rec.sv_cold_start_feature',  -- cold-start feature table
  'username'='${username}',  -- unified placeholder (was '${user_name}')
  'password'='${password}',
  'endpoint'='hgpostcn-cn-xxxxxxxxxx-cn-beijing-vpc.hologres.aliyuncs.com:80',
  'mutatetype'='insertorupdate'  -- upsert on the primary key
);

-- Join the new-item stream with the author / sign / text-embedding dimension
-- tables (processing-time temporal lookup joins) and upsert the combined
-- cold-start features into Hologres.
-- NOTE(review): update_time is selected before name_embedding/tag_embedding,
-- while the sink DDL lists update_time as the last column — confirm the real
-- column order of sv_rec.sv_cold_start_feature.
INSERT INTO item_cold_start_feature
SELECT
  v.itemId,  -- item ID
  v.userId AS author,  -- author (creator) ID
  s.primaryId AS primary_type,  -- primary content type
  v.title,  -- title
  TIMESTAMPDIFF(DAY, v.createTime, LOCALTIMESTAMP) AS pub_days,  -- days since publication
  v.duration,  -- playback duration
  v.sourceType as source_type,  -- source type
  v.inTimeOrNot as intimeornot,  -- whether the item is timely
  v.is_prop,  -- whether it is a prop
  COALESCE(s.gradeScore, v.gradeScore) AS grade_score,  -- grade score
  v.width,  -- width
  v.height,  -- height
  v.firstPublishSongOrNot AS is_first_publish_song,  -- whether this is a first song release
  COALESCE(v.topic_id, '') as topic_id,  -- topic ID
  t.cate_name1,  -- level-1 category name
  t.cate_name2,  -- level-2 category name
  t.video_tags,  -- video tags
  au.author_gender,  -- author gender
  au.author_level,  -- author level
  au.author_is_member,  -- whether the author is a member
  au.author_city,  -- author city
  au.author_type,  -- author type
  au.author_fans_num,  -- author fan count
  au.author_visitor_num,  -- author visitor count
  au.author_billboard_num,  -- author billboard count
  au.author_av_ct,  -- author average watch count
  au.author_sv_ct,  -- author short-video count
  au.author_play_ct,  -- author play count
  au.author_play_avg_ct,  -- author average play count
  au.author_like_ct,  -- author like count
  au.author_download_ct,  -- author download count
  au.family_hot_ranking,  -- family hot ranking
  au.author_diamond_ct,  -- author diamond count
  au.author_flower_ct,  -- author flower count
  CAST(STR_TO_MAP(au.author_sv_type_play_ct_1, ',', ':')[CAST(s.primaryId as VARCHAR)] AS bigint) AS author_sv_type_play_ct_1,  -- plays of this item's type, last 1 day
  CAST(STR_TO_MAP(au.author_sv_type_play_ct_7, ',', ':')[CAST(s.primaryId as VARCHAR)] AS bigint) AS author_sv_type_play_ct_7,  -- plays of this item's type, last 7 days
  CAST(STR_TO_MAP(au.author_sv_type_play_ct_15, ',', ':')[CAST(s.primaryId as VARCHAR)] AS bigint) AS author_sv_type_play_ct_15,  -- plays of this item's type, last 15 days
  au.author_play_ct_1,  -- author play count, last 1 day
  au.author_play_ct_7,  -- author play count, last 7 days
  au.author_play_ct_15,  -- author play count, last 15 days
  au.author_like_ct_1,  -- author like count, last 1 day
  au.author_like_ct_7,  -- author like count, last 7 days
  au.author_like_ct_15,  -- author like count, last 15 days
  au.author_comment_ct_1,  -- author comment count, last 1 day
  au.author_comment_ct_7,  -- author comment count, last 7 days
  au.author_comment_ct_15,  -- author comment count, last 15 days
  au.author_share_ct_1,  -- author share count, last 1 day
  au.author_share_ct_7,  -- author share count, last 7 days
  au.author_share_ct_15,  -- author share count, last 15 days
  au.author_tags,  -- author tags
  TIMESTAMPDIFF(DAY, au.author_last_live_time, LOCALTIMESTAMP) AS author_last_live_days,  -- days since the author's last live stream
  UNIX_TIMESTAMP() as update_time,  -- update timestamp (unix seconds)
  t.name_embedding,  -- title text embedding
  t.tag_embedding  -- tag text embedding
FROM new_item_view AS v
LEFT JOIN author_feature FOR SYSTEM_TIME AS OF v.proc_time as au  -- lookup join: author features
ON v.userId = au.author_id
LEFT JOIN smart_video_sign FOR SYSTEM_TIME AS OF v.proc_time as s  -- lookup join: video sign (category/score)
ON v.smartVideoId = s.svid
LEFT JOIN video_name_tag_embedding FOR SYSTEM_TIME AS OF v.proc_time as t  -- lookup join: title/tag embeddings
ON v.smartVideoId = t.svid
;

3. 新しいアイテム埋め込みの生成

アイテム埋め込み用の Hologres テーブルと、シンク先のテーブルとして機能する Flink 一時テーブルを作成できます。

-- Flink sink table for the generated item embeddings, backed by Hologres.
create TEMPORARY table item_dropoutnet_embedding (  -- item DropoutNet embeddings
  itemId    bigint,  -- item ID
  embedding ARRAY<FLOAT>,  -- DropoutNet item embedding vector
  update_time bigint  -- update timestamp (unix seconds)
) with (
  'connector'='hologres',
  'dbname'='${dbname}',
  'tablename'='sv_rec.sv_dropoutnet_embedding',  -- DropoutNet embedding table
  'username'='${username}',
  'password'='${password}',
  'endpoint'='hgpostcn-cn-xxxxxxxxxxxx-cn-beijing-vpc.hologres.aliyuncs.com:80',
  'mutatetype'='insertorreplace',  -- replace the whole row on key conflict
  'field_delimiter'=','
);

DropoutNet モデルの Elastic Algorithm Service (EAS) サービスを呼び出すユーザー定義関数 (UDF) を開発できます。その後、Flink SQL で UDF を呼び出してリアルタイムでアイテム埋め込みを生成し、オンラインで使用するために Hologres に保存できます。

-- Generate item embeddings in real time: for each changed row in the
-- title/tag embedding binlog stream, look up the cold-start features and call
-- the DropoutNet EAS service through the InvokeEasUdf UDF, then upsert the
-- embedding into Hologres.
-- NOTE(review): the table and key names here (video_name_tag_embedding_hi,
-- sv_cold_start_feature, f.svid) do not match item_cold_start_feature /
-- itemId used earlier on this page — presumably adapted from another
-- project; confirm against the actual schema.
INSERT INTO item_dropoutnet_embedding
SELECT
  f.svid,
  InvokeEasUdf(
    'sv_dropoutnet',
    '${endpoint}',
    '${token}',
    f.primary_type,
    f.title,
    f.pub_days,
    f.duration,
    f.source_type,
    f.intimeornot,
    f.is_prop,
    f.grade_score,
    f.width,
    f.height,
    f.is_first_publish_song,
    f.topic_id,
    COALESCE(t.cate_name1, f.cate_name1),
    COALESCE(t.cate_name2, f.cate_name2),
    COALESCE(t.video_tags, f.video_tags),
    f.author_gender,
    f.author_level,
    f.author_is_member,
    f.author_city,
    f.author_type,
    f.author_fans_num,
    f.author_visitor_num,
    f.author_billboard_num,
    f.author_av_ct,
    f.author_sv_ct,
    f.author_play_ct,
    f.author_play_avg_ct,
    f.author_like_ct,
    f.author_download_ct,
    f.family_hot_ranking,
    f.author_diamond_ct,
    f.author_flower_ct,
    f.author_sv_type_play_ct_1,
    f.author_sv_type_play_ct_7,
    f.author_sv_type_play_ct_15,
    f.author_play_ct_1,
    f.author_play_ct_7,
    f.author_play_ct_15,
    f.author_like_ct_1,
    f.author_like_ct_7,
    f.author_like_ct_15,
    f.author_comment_ct_1,
    f.author_comment_ct_7,
    f.author_comment_ct_15,
    f.author_share_ct_1,
    f.author_share_ct_7,
    f.author_share_ct_15,
    f.author_tags,
    f.author_last_live_days,
    COALESCE(t.name_embedding, f.name_embedding),
    COALESCE(t.tag_embedding, f.tag_embedding)
  ) as embedding,
  UNIX_TIMESTAMP() as update_time
FROM video_name_tag_embedding_hi as t  -- binlog stream of the embedding table
JOIN sv_cold_start_feature FOR SYSTEM_TIME AS OF t.proc_time as f  -- lookup join: cold-start features
ON t.svid = f.svid and t.hg_binlog_event_type IN (5, 7);

次のコードは、EAS サービスを呼び出すための Flink UDF の例を示しています。

package com.alibaba.pairec.udf;

import com.aliyun.openservices.eas.predict.http.HttpConfig;
import com.aliyun.openservices.eas.predict.http.PredictClient;
import com.aliyun.openservices.eas.predict.request.TFDataType;
import com.aliyun.openservices.eas.predict.request.TFRequest;
import com.aliyun.openservices.eas.predict.response.TFResponse;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.log4j.Logger;
import java.util.*;
import java.util.stream.Collectors;


/**
 * Flink scalar UDF that calls the DropoutNet item sub-model deployed on
 * Alibaba Cloud PAI-EAS and returns the item embedding vector.
 *
 * <p>Null numeric features are defaulted to 0 and null string features to ""
 * before being fed to the model. The EAS client is a lazily initialized,
 * process-wide singleton; the modelName/endpoint/token of the FIRST call win —
 * later calls with different values silently reuse the existing client.
 */
public class InvokeEasUdf extends ScalarFunction {
    private volatile static PredictClient client;
    private static final Logger logger = Logger.getLogger(InvokeEasUdf.class);

    /**
     * Returns the shared EAS prediction client, creating and configuring it on
     * first use (double-checked locking on the class object).
     *
     * <p>The client is fully configured on a local reference before being
     * published to the volatile field, so other threads can never observe a
     * partially configured client (the original assigned the field first and
     * ran the setters afterwards).
     *
     * @param modelName EAS service name
     * @param endpoint  EAS endpoint host
     * @param token     EAS access token
     * @return the shared client instance
     */
    public static PredictClient getClient(String modelName, String endpoint, String token) {
        if (null == client) {
            synchronized (InvokeEasUdf.class) {
                if (null == client) {
                    PredictClient created = new PredictClient(new HttpConfig());
                    created.setToken(token);
                    created.setEndpoint(endpoint);
                    created.setModelName(modelName);
                    created.setIsCompressed(false);
                    client = created;  // publish only after full configuration
                }
            }
        }
        return client;
    }

    /** Feeds a scalar INT64 feature; null defaults to 0. */
    private static void addLongFeed(TFRequest request, String name, Long value) {
        request.addFeed(name, TFDataType.DT_INT64, new long[]{1}, new long[]{value == null ? 0L : value});
    }

    /** Feeds a scalar DOUBLE feature; null defaults to 0.0. */
    private static void addDoubleFeed(TFRequest request, String name, Double value) {
        request.addFeed(name, TFDataType.DT_DOUBLE, new long[]{1}, new double[]{value == null ? 0.0 : value});
    }

    /** Feeds a scalar STRING feature; null defaults to "". */
    private static void addStringFeed(TFRequest request, String name, String value) {
        request.addFeed(name, TFDataType.DT_STRING, new long[]{1}, new String[]{value == null ? "" : value});
    }

    /**
     * Builds the TensorFlow Serving request: one feed per item feature under
     * the "serving_default" signature, fetching the "item_emb" output.
     */
    public static TFRequest buildPredictRequest(
            Long primary_type,
            String title,
            Long pub_days,
            Double duration,
            Long source_type,
            Long intimeornot,
            Long is_prop,
            Long grade_score,
            Long width,
            Long height,
            Long is_first_publish_song,
            String topic_id,
            String cate_name1,
            String cate_name2,
            String video_tags,
            Long author_gender,
            Long author_level,
            Long author_is_member,
            String author_city,
            String author_type,
            Long author_fans_num,
            Long author_visitor_num,
            Long author_billboard_num,
            Long author_av_ct,
            Long author_sv_ct,
            Long author_play_ct,
            Long author_play_avg_ct,
            Long author_like_ct,
            Long author_download_ct,
            Long family_hot_ranking,
            Long author_diamond_ct,
            Long author_flower_ct,
            Long author_sv_type_play_ct_1,
            Long author_sv_type_play_ct_7,
            Long author_sv_type_play_ct_15,
            Long author_play_ct_1,
            Long author_play_ct_7,
            Long author_play_ct_15,
            Long author_like_ct_1,
            Long author_like_ct_7,
            Long author_like_ct_15,
            Long author_comment_ct_1,
            Long author_comment_ct_7,
            Long author_comment_ct_15,
            Long author_share_ct_1,
            Long author_share_ct_7,
            Long author_share_ct_15,
            String author_tags,
            Long author_last_live_days,
            String name_embedding,
            String tag_embedding
    ) {
        TFRequest request = new TFRequest();
        request.setSignatureName("serving_default");

        addLongFeed(request, "author_av_ct", author_av_ct);
        addLongFeed(request, "author_billboard_num", author_billboard_num);
        addStringFeed(request, "author_city", author_city);
        addLongFeed(request, "author_comment_ct_1", author_comment_ct_1);
        addLongFeed(request, "author_comment_ct_7", author_comment_ct_7);
        addLongFeed(request, "author_comment_ct_15", author_comment_ct_15);
        addLongFeed(request, "author_diamond_ct", author_diamond_ct);
        addLongFeed(request, "author_download_ct", author_download_ct);
        addLongFeed(request, "author_fans_num", author_fans_num);
        addLongFeed(request, "author_flower_ct", author_flower_ct);
        addLongFeed(request, "author_gender", author_gender);
        addLongFeed(request, "author_is_member", author_is_member);
        addLongFeed(request, "author_last_live_days", author_last_live_days);
        addLongFeed(request, "author_level", author_level);
        addLongFeed(request, "author_like_ct", author_like_ct);
        addLongFeed(request, "author_like_ct_1", author_like_ct_1);
        addLongFeed(request, "author_like_ct_15", author_like_ct_15);
        addLongFeed(request, "author_like_ct_7", author_like_ct_7);
        addLongFeed(request, "author_play_avg_ct", author_play_avg_ct);
        addLongFeed(request, "author_play_ct", author_play_ct);
        addLongFeed(request, "author_play_ct_1", author_play_ct_1);
        addLongFeed(request, "author_play_ct_15", author_play_ct_15);
        addLongFeed(request, "author_play_ct_7", author_play_ct_7);
        addLongFeed(request, "author_share_ct_1", author_share_ct_1);
        addLongFeed(request, "author_share_ct_15", author_share_ct_15);
        addLongFeed(request, "author_share_ct_7", author_share_ct_7);
        addLongFeed(request, "author_sv_ct", author_sv_ct);
        addLongFeed(request, "author_sv_type_play_ct_1", author_sv_type_play_ct_1);
        addLongFeed(request, "author_sv_type_play_ct_15", author_sv_type_play_ct_15);
        addLongFeed(request, "author_sv_type_play_ct_7", author_sv_type_play_ct_7);
        addStringFeed(request, "author_tags", author_tags);
        addStringFeed(request, "author_type", author_type);
        addLongFeed(request, "author_visitor_num", author_visitor_num);
        addStringFeed(request, "cate_name1", cate_name1);
        addStringFeed(request, "cate_name2", cate_name2);
        addDoubleFeed(request, "duration", duration);
        addLongFeed(request, "family_hot_ranking", family_hot_ranking);
        addLongFeed(request, "grade_score", grade_score);
        addLongFeed(request, "height", height);
        addLongFeed(request, "intimeornot", intimeornot);
        addLongFeed(request, "is_first_publish_song", is_first_publish_song);
        addLongFeed(request, "is_prop", is_prop);
        addLongFeed(request, "primary_type", primary_type);
        addLongFeed(request, "pub_days", pub_days);
        addLongFeed(request, "source_type", source_type);
        addStringFeed(request, "title", title);
        addStringFeed(request, "topic_id", topic_id);
        addStringFeed(request, "video_tags", video_tags);
        addLongFeed(request, "width", width);
        addStringFeed(request, "name_embedding", name_embedding);
        addStringFeed(request, "tag_embedding", tag_embedding);
        request.addFetch("item_emb");
        return request;
    }

    // Best-effort shutdown of the shared client.
    // NOTE(review): finalize() is deprecated and not guaranteed to run;
    // consider overriding ScalarFunction#close() instead.
    protected void finalize() {
        if (null != client) {
            client.shutdown();
        }
    }

    /**
     * UDF entry point: builds the feature request, calls the EAS service, and
     * parses the returned comma-separated "item_emb" string into floats.
     *
     * @return the item embedding vector, or an empty list if the call fails
     */
    public List<Float> eval(String modelName, String endpoint, String token,
                       Long primary_type,
                       String title,
                       Long pub_days,
                       Double duration,
                       Long source_type,
                       Long intimeornot,
                       Long is_prop,
                       Long grade_score,
                       Long width,
                       Long height,
                       Long is_first_publish_song,
                       String topic_id,
                       String cate_name1,
                       String cate_name2,
                       String video_tags,
                       Long author_gender,
                       Long author_level,
                       Long author_is_member,
                       String author_city,
                       String author_type,
                       Long author_fans_num,
                       Long author_visitor_num,
                       Long author_billboard_num,
                       Long author_av_ct,
                       Long author_sv_ct,
                       Long author_play_ct,
                       Long author_play_avg_ct,
                       Long author_like_ct,
                       Long author_download_ct,
                       Long family_hot_ranking,
                       Long author_diamond_ct,
                       Long author_flower_ct,
                       Long author_sv_type_play_ct_1,
                       Long author_sv_type_play_ct_7,
                       Long author_sv_type_play_ct_15,
                       Long author_play_ct_1,
                       Long author_play_ct_7,
                       Long author_play_ct_15,
                       Long author_like_ct_1,
                       Long author_like_ct_7,
                       Long author_like_ct_15,
                       Long author_comment_ct_1,
                       Long author_comment_ct_7,
                       Long author_comment_ct_15,
                       Long author_share_ct_1,
                       Long author_share_ct_7,
                       Long author_share_ct_15,
                       String author_tags,
                       Long author_last_live_days,
                       String name_embedding,
                       String tag_embedding
    ) {
        PredictClient predictor = getClient(modelName, endpoint, token);
        TFRequest request = buildPredictRequest(
                primary_type,
                title,
                pub_days,
                duration,
                source_type,
                intimeornot,
                is_prop,
                grade_score,
                width,
                height,
                is_first_publish_song,
                topic_id,
                cate_name1,
                cate_name2,
                video_tags,
                author_gender,
                author_level,
                author_is_member,
                author_city,
                author_type,
                author_fans_num,
                author_visitor_num,
                author_billboard_num,
                author_av_ct,
                author_sv_ct,
                author_play_ct,
                author_play_avg_ct,
                author_like_ct,
                author_download_ct,
                family_hot_ranking,
                author_diamond_ct,
                author_flower_ct,
                author_sv_type_play_ct_1,
                author_sv_type_play_ct_7,
                author_sv_type_play_ct_15,
                author_play_ct_1,
                author_play_ct_7,
                author_play_ct_15,
                author_like_ct_1,
                author_like_ct_7,
                author_like_ct_15,
                author_comment_ct_1,
                author_comment_ct_7,
                author_comment_ct_15,
                author_share_ct_1,
                author_share_ct_7,
                author_share_ct_15,
                author_tags,
                author_last_live_days,
                name_embedding,
                tag_embedding
        );
        TFResponse response;
        try {
            response = predictor.predict(request);
            List<String> result = response.getStringVals("item_emb");
            String embedding = result.get(0);
            String[] emb = embedding.split(",");
            return Arrays.stream(emb).map(Float::valueOf).collect(Collectors.toList());
        } catch (Exception e) {
            // Log with the stack trace and degrade gracefully to an empty embedding.
            logger.error("call eas failed.", e);
            return Collections.emptyList();  // typed empty list instead of raw EMPTY_LIST
        }
    }

    /** Manual smoke test against a live EAS endpoint (sample credentials). */
    public static void main(String[] args) {
        InvokeEasUdf udf = new InvokeEasUdf();
        List<Float> emb = udf.eval("sv_dropoutnet",
            "1103287870424018.cn-beijing.pai-eas.aliyuncs.com",
                "NDg4OGIwZGU2MjAzNzljMGZkNjI2ZWUxZWEzZjM4ZGYyNmU2ZWVmZA==",
            90L,
            "#2021\u001DActing\u001DAwards\u001D",
            0L, 72800.0, 4L, 0L, 0L, 5L,
            576L, 1024L, 1L, "97388",
            "Music", "Song", "Beauty\u001DSong\u001DMusic",
            0L, 6L, 1L, "", "", 0L, 3L,
            0L, 0L, 0L, 2L, 6L, 2L,
            0L, 0L, 0L, 0L, 0L, 0L,
            0L, 0L, 1L, 1L, 1L,
            0L, 0L, 0L, 0L, 0L,
            0L, 0L, 0L, "", 0L, "", ""
        );

        System.out.println(emb);
    }
}

次の Maven 依存関係を追加できます。

    <dependencies>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-clients_2.12</artifactId>
            <version>${flink.version}</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-java</artifactId>
            <version>${flink.version}</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-streaming-java_2.12</artifactId>
            <version>${flink.version}</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-table-common</artifactId>
            <version>${flink.version}</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-table</artifactId>
            <version>${flink.version}</version>
            <type>pom</type>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>log4j</groupId>
            <artifactId>log4j</artifactId>
            <version>1.2.17</version>
        </dependency>
        <dependency>
            <groupId>com.aliyun.openservices.eas</groupId>
            <artifactId>eas-sdk</artifactId>
            <version>2.0.3</version>
        </dependency>
    </dependencies>

ユーザー埋め込みベクターの準備

ユーザー特徴をオフラインで計算し、前のステップのユーザー子モデルを使用してユーザー埋め込みベクターを生成できます。

-- Offline prediction: run the exported model over the user feature table to
-- generate one "user_emb" embedding string per userid.
-- (Invalid inline "//" comments were removed: odpscmd does not support them.)
pai -name easy_rec_ext
-Dcmd='predict'
-Dconfig='oss://${bucket}/EasyRec/sv_dropout_net/sv_dropoutnet.config'
-Doutput_table='odps://${project}/tables/dropoutnet_user_embedding/dt=${bizdate}'
-Dinput_table='odps://${project}/tables/dropoutnet_user_features/dt=${bizdate}'
-Dsaved_model_dir='oss://${bucket}/EasyRec/sv_dropout_net/${bizdate}/export/final'
-Dreserved_cols="userid"
-Doutput_cols="user_emb string"
-Dmodel_outputs="user_emb"
-Dbuckets='oss://${bucket}/'
-Darn='acs:ram::XXXXXXXXXX:role/aliyunodpspaidefaultrole'
-DossHost='oss-cn-beijing-internal.aliyuncs.com'
-Dcluster='{
    \"worker\" : {
        \"count\" : 8,
        \"cpu\" : 600
    }
}';

最後に、ユーザー埋め込みベクターを Hologres にインポートできます。

リコール結果として上位 N 個のアイテムを取得

推薦サービスでは、Hologres ベクター検索エンジンを使用して、ユーザー埋め込みベクターに最も近い距離にある上位 N 個のアイテムをクエリできます。

// GetRetrieveSql builds the Hologres vector-retrieval SQL that returns the
// top-N items whose embeddings are closest to the given user embedding.
// It returns the SQL text and its bound arguments.
func (r *ItemColdStartRecall) GetRetrieveSql(userEmb string) (string, []interface{}) {
    sb := sqlbuilder.PostgreSQL.NewSelectBuilder()
    vecIndex := sb.Args.Add(userEmb) // bind the user embedding as a query parameter
    dotProduct := fmt.Sprintf("pm_approx_inner_product_distance(%s,%s)", r.VectorEmbeddingField, vecIndex)  // approximate inner-product distance
    sb.Select(r.VectorKeyField, sb.As(dotProduct, "distance"))  // select the item key and its distance
    sb.From(r.VectorTable)  // the vector table
    sb.OrderBy("distance").Desc()  // larger inner product = more similar, so sort descending
    sb.Limit(r.recallCount)  // keep only the top N items
    return sb.Build()
}