├── __init__.py
├── module
│   ├── __init__.py
│   ├── plugin.py
│   ├── common.py
│   ├── encoder.py
│   ├── initializer.py
│   └── attention.py
├── backend
│   ├── __init__.py
│   ├── dataset_hub
│   │   ├── pin_datasets.py
│   │   ├── base_datasets.py
│   │   ├── ecrec_datasets.py
│   │   ├── cigar_datasets.py
│   │   ├── mcrec_dataset.py
│   │   └── ggcn_datasets.py
│   ├── arguments.py
│   ├── model_hub
│   │   ├── mcrec_model.py
│   │   ├── ggcn_model.py
│   │   ├── pinrec_model.py
│   │   └── ecrec_model.py
│   ├── utils.py
│   └── task_backbone.py
├── onnx_test
│   ├── __init__.py
│   ├── mcrec_onnx_model_test.py
│   ├── pinrec_onnx_model_test.py
│   ├── ecrec_onnx_test.py
│   ├── ggcn_onnx_model_test.py
│   └── cigar_onnx_model_test.py
├── train_hub
│   ├── __init__.py
│   ├── model_execution_template.py
│   ├── mcrec.py
│   ├── cigar.py
│   ├── ecrec.py
│   ├── ggcn.py
│   └── pinrec.py
├── GGCN.png
├── MCRec.png
├── cigar.png
├── PINRec.png
├── scripts
│   ├── ecrec
│   │   ├── ecrec_export.sh
│   │   ├── ecrec_inference.sh
│   │   └── ecrec_training.sh
│   ├── ggcn
│   │   ├── ggcn_training.sh
│   │   ├── ggcn_inference.sh
│   │   └── ggcn_export.sh
│   ├── mcrec
│   │   ├── mcrec_training.sh
│   │   ├── mcrec_inference.sh
│   │   └── mcrec_export.sh
│   ├── cigar
│   │   ├── cigar_training.sh
│   │   ├── cigar_inference.sh
│   │   └── cigar_export.sh
│   └── pinrec
│       ├── pinrec_training.sh
│       ├── pinrec_inference.sh
│       ├── pinrec_export.sh
│       └── movielens_conf.json
└── LICENSE

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/module/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/backend/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/onnx_test/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/train_hub/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/GGCN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxi-model/luoxi_models/HEAD/GGCN.png

--------------------------------------------------------------------------------
/MCRec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxi-model/luoxi_models/HEAD/MCRec.png

--------------------------------------------------------------------------------
/cigar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxi-model/luoxi_models/HEAD/cigar.png

--------------------------------------------------------------------------------
/PINRec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxi-model/luoxi_models/HEAD/PINRec.png

--------------------------------------------------------------------------------
/scripts/ecrec/ecrec_export.sh:
--------------------------------------------------------------------------------
echo "begin"
if [ -d "tmp/" ];then
    rm -rf tmp
fi

save=/mnt2/songxiao/ecrec/test
#infer=/mnt2/songxiao/ecrec/test/infer_result.txt
#infer_table=/mnt2/songxiao/memory/book/book_10w_test_new.txt
#data=/mnt2/songxiao/memory/book/book_10w_train_new.txt,/mnt2/songxiao/memory/book/book_10w_test_new.txt



python train_hub/ecrec.py \
    --task-type=onnx_export \
    --onnx_export_path=${save} \
    --load=${save}

--------------------------------------------------------------------------------
/scripts/ecrec/ecrec_inference.sh:
--------------------------------------------------------------------------------
echo "begin"
if [ -d "tmp/" ];then
    rm -rf tmp
fi

save=/mnt2/songxiao/ecrec/test
infer=/mnt2/songxiao/ecrec/test/infer_result.txt
infer_table=/mnt2/songxiao/memory/book/book_10w_test_new.txt
data=/mnt2/songxiao/memory/book/book_10w_train_new.txt,/mnt2/songxiao/memory/book/book_10w_test_new.txt


python train_hub/ecrec.py \
    --eval-iters=135 \
    --save=${save} \
    --task-type=inference \
    --outputs=${infer} \
    --infer_table=${infer_table} \
    --batch-size=128 \
    --load=${save}

--------------------------------------------------------------------------------
/scripts/ggcn/ggcn_training.sh:
--------------------------------------------------------------------------------
echo "begin"
time=$(date "+%Y%m%d%H")
echo ${time}
model_name=GGCN
save=/${model_name}_${time}
data=/mnt2/pytorch/GGCN/ppi,/mnt2/pytorch/GGCN/ppi

python -m torch.distributed.launch --master_port 2849 train_hub/ggcn.py \
    --save-interval=200 \
    --optimizer=adam \
    --lr=0.001 \
    --weight-decay=0.0 \
    --column_len=29 \
    --eval-interval=100 \
    --save=${save} \
    --tables=${data} \
    --log-interval=50 \
    --batch-size=1 \
    --model=${model_name} \
    >${model_name}_${time}.log 2>&1 &

echo ${data}

--------------------------------------------------------------------------------
/scripts/mcrec/mcrec_training.sh:
--------------------------------------------------------------------------------
echo "begin"
time=$(date "+%Y%m%d%H")
echo ${time}

model_prefix=o # c, o, mc
model_name=${model_prefix}rec
save=./${model_name}_${time}
data=./dataset/${model_prefix}rec_training.txt,./dataset/${model_prefix}rec_test.txt

echo ${data}

nohup python -m torch.distributed.launch --master_port 28489 train_hub/mcrec.py \
    --save-interval=1000 \
    --optimizer=adam \
    --lr=0.001 \
    --weight-decay=0.0 \
    --eval-interval=100 \
    --save=${save} \
    --tables=${data} \
    --log-interval=100 \
    --batch-size=512 \
    --model_type=${model_name} \
    >${model_name}_${time}.log 2>&1 &

--------------------------------------------------------------------------------
/scripts/ecrec/ecrec_training.sh:
--------------------------------------------------------------------------------
echo "begin"
if [ -d "tmp/" ];then
    rm -rf tmp
fi

save=/mnt2/songxiao/ecrec/test
#infer=/mnt2/songxiao/ecrec/test/infer_result.txt
#infer_table=/mnt2/songxiao/memory/book/book_10w_test_new.txt
data=/mnt2/songxiao/memory/book/book_10w_train_new.txt,/mnt2/songxiao/memory/book/book_10w_test_new.txt


python train_hub/ecrec.py \
    --save-interval=200 \
    --lr=0.0001 \
    --eval-interval=200 \
    --eval-iters=135 \
    --save=${save} \
    --tables=${data} \
    --task-type=train \
    --num-epochs=2 \
    --weight-decay=0.01 \
    --batch-size=128 \
    --optimizer='adam' \
    --find-unused-parameters
#    --load=${save}
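
A note on the flags above: every script passes `--tables` as a comma-separated "train,eval" pair that the dataset providers split in two. A minimal sketch of the convention in Python (the helper name `split_tables` is hypothetical; the split itself mirrors `train_eval_datasets_provider` in /train_hub/model_execution_template.py below):

    def split_tables(tables):
        # "train.txt,test.txt" -> ("train.txt", "test.txt"); the eval part is optional
        parts = tables.split(",")
        return parts[0], (parts[1] if len(parts) > 1 else None)

    assert split_tables("a.txt,b.txt") == ("a.txt", "b.txt")
    assert split_tables("a.txt") == ("a.txt", None)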
--------------------------------------------------------------------------------
/scripts/ggcn/ggcn_inference.sh:
--------------------------------------------------------------------------------
echo "begin"
time=$(date "+%Y%m%d%H")
echo ${time}
model_name=GGCN
load=/GGCN_2022031401/
data=/mnt2/pytorch/GGCN/ppi,/mnt2/pytorch/GGCN/ppi
rm -rf tmp
outputs=/GGCN_2022031401/result.txt

echo ${data}


python -m torch.distributed.launch --master_port 2849 train_hub/ggcn.py \
    --save-interval=200 \
    --optimizer=adam \
    --lr=0.001 \
    --weight-decay=0.0 \
    --column_len=29 \
    --task-type=inference \
    --eval-interval=100 \
    --load=${load} \
    --outputs=${outputs} \
    --tables=${data} \
    --log-interval=50 \
    --batch-size=1 \
    --model=${model_name} \
    >${model_name}_${time}.log 2>&1 &

--------------------------------------------------------------------------------
/scripts/cigar/cigar_training.sh:
--------------------------------------------------------------------------------
echo "begin"
time=$(date "+%Y%m%d%H")
echo ${time}
if [ -d "tmp/" ];then
    rm -rf tmp
fi

model_name=CIGAR
save=/home/fay.cyf/fay.cyf/cigar/${model_name}_${time}
data=/home/fay.cyf/fay.cyf/cigar/cigar_alimama_train.txt,/home/fay.cyf/fay.cyf/cigar/cigar_alimama_test_10000.txt


echo ${data}

nohup python -m torch.distributed.launch --master_port 20848 train_hub/cigar.py \
    --save-interval=1000 \
    --optimizer=adam \
    --lr=0.001 \
    --weight-decay=0.0 \
    --column_len=29 \
    --eval-interval=100 \
    --save=${save} \
    --tables=${data} \
    --log-interval=100 \
    --batch-size=512 \
    --model=${model_name} \
    >${model_name}_${time}.log 2>&1 &

--------------------------------------------------------------------------------
/scripts/mcrec/mcrec_inference.sh:
--------------------------------------------------------------------------------
echo "begin"
time="2022032305" # you need to specify your log file suffix
echo ${time}

rm -rf tmp

model_prefix=o # c, o, mc
model_name=${model_prefix}rec
load=./${model_name}_${time}
data=./dataset/${model_prefix}rec_training.txt,./dataset/${model_prefix}rec_test.txt
outputs=./${model_name}_${time}/inference_result.txt

echo ${data}

nohup python -m torch.distributed.launch --master_port 28489 train_hub/mcrec.py \
    --save-interval=1000 \
    --optimizer=adam \
    --lr=0.001 \
    --weight-decay=0.0 \
    --eval-interval=100 \
    --task-type=inference \
    --load=${load} \
    --outputs=${outputs} \
    --tables=${data} \
    --log-interval=100 \
    --batch-size=512 \
    --model_type=${model_name} \
    >${model_name}_${time}_inference.log 2>&1 &

--------------------------------------------------------------------------------
/scripts/mcrec/mcrec_export.sh:
--------------------------------------------------------------------------------
echo "begin"
time="2022032305" # you need to specify your log file suffix
echo ${time}

rm -rf tmp

model_prefix=o # o or mc
model_name=${model_prefix}rec
load=./${model_name}_${time}
data=./dataset/${model_prefix}rec_training.txt,./dataset/${model_prefix}rec_test.txt
onnx_output=./${model_name}_${time}/output.onnx

echo ${data}

nohup python -m torch.distributed.launch --master_port 28489 train_hub/mcrec.py \
    --save-interval=1000 \
    --optimizer=adam \
    --lr=0.001 \
    --weight-decay=0.0 \
    --eval-interval=100 \
    --task-type=onnx_export \
    --load=${load} \
    --onnx_export_path=${onnx_output} \
    --tables=${data} \
    --log-interval=100 \
    --batch-size=512 \
    --model_type=${model_name} \
    >${model_name}_${time}_export.log 2>&1 &

--------------------------------------------------------------------------------
/scripts/pinrec/pinrec_training.sh:
--------------------------------------------------------------------------------
echo "begin"

cd ../../

if [ -d "tmp/" ];then
    rm -rf tmp
fi

save=`pwd`/data/output/pin/result

dataset=""
train_file="${dataset}/train.txt"
test_file="${dataset}/test.txt"
data="${train_file},${test_file}"

ARCH_CONF_FILE=`pwd`/scripts/pinrec/movielens_conf.json

python -m torch.distributed.launch --master_port 35213 train_hub/pinrec.py \
    --batch-size=128 \
    --num-epochs=12 \
    --clip-grad=0.0 \
    --train-iters=40000 \
    --log-interval=100 \
    --save-interval=9708 \
    --lr=0.001 \
    --eval-interval=40001 \
    --save=${save} \
    --tables=${data} \
    --model=pinrec \
    --group_num=5 \
    --arch_config=${ARCH_CONF_FILE} \
    --stage_switch_epoch=2 \
    --optimizer=adam \
    --backward-step-contains-in-forward-step

--------------------------------------------------------------------------------
/scripts/cigar/cigar_inference.sh:
--------------------------------------------------------------------------------
echo "begin"
time=$(date "+%Y%m%d%H")
echo ${time}
if [ -d "tmp/" ];then
    rm -rf tmp
fi

model_name=CIGAR
load=/home/fay.cyf/fay.cyf/cigar/CIGAR_2022032315
data=/home/fay.cyf/fay.cyf/cigar/cigar_alimama_train.txt,/home/fay.cyf/fay.cyf/cigar/cigar_alimama_test_10000.txt
outputs=/home/fay.cyf/fay.cyf/cigar/CIGAR_2022032315/result.txt

echo ${data}

nohup python -m torch.distributed.launch --master_port 28486 train_hub/cigar.py \
    --save-interval=10000 \
    --lr=0.001 \
    --optimizer=adam \
    --weight-decay=0.0 \
    --column_len=29 \
    --eval-interval=100 \
    --task-type=inference \
    --load=${load} \
    --outputs=${outputs} \
    --tables=${data} \
    --log-interval=100 \
    --batch-size=512 \
    --model=${model_name} \
    >${model_name}_${time}_infer.log 2>&1 &

--------------------------------------------------------------------------------
/scripts/cigar/cigar_export.sh:
--------------------------------------------------------------------------------
echo "begin"
time=$(date "+%Y%m%d%H")
echo ${time}

model_name=CIGAR
save=/home/fay.cyf/fay.cyf/cigar
data=/home/fay.cyf/fay.cyf/cigar/cigar_alimama_test_10000.txt,/home/fay.cyf/fay.cyf/cigar/cigar_alimama_test_10000.txt
#onnx_export_path=/mnt2/yyang/cigar_onnx/save/
load=/home/fay.cyf/fay.cyf/cigar
load_model_name=CIGAR_2022032315
load_model_path=${load}/${load_model_name}/1000

echo ${data}

python -m torch.distributed.launch --master_port 29493 train_hub/cigar.py \
    --save-interval=10 \
    --optimizer=adam \
    --lr=0.001 \
    --weight-decay=0.0 \
    --column_len=29 \
    --eval-interval=10 \
    --save=${save}/${load_model_name} \
    --tables=${data} \
    --log-interval=10 \
    --batch-size=512 \
    --num-epochs=2 \
    --model=${model_name} \
    --task-type=onnx_export \
    --load=${load_model_path}

--------------------------------------------------------------------------------
/scripts/ggcn/ggcn_export.sh:
--------------------------------------------------------------------------------
echo "begin"
time=$(date "+%Y%m%d%H")
echo ${time}

model_name=GGCN
save=/${model_name}_${time}
data=/mnt2/pytorch/GGCN/ppi,/mnt2/pytorch/GGCN/ppi

echo ${data}

onnx_export_path=/mnt2/pytorch/save/
onnx_model_name=onnx_${model_name}_${time}.onnx
load_model_path=/GGCN_2022031401/3200/


python -m torch.distributed.launch --master_port 2819 train_hub/ggcn.py \
    --save-interval=200 \
    --optimizer=adam \
    --lr=0.001 \
    --weight-decay=0.0 \
    --column_len=29 \
    --eval-interval=100 \
    --save=${save} \
    --onnx_export_path=${onnx_export_path} \
    --onnx_model_name=${onnx_model_name} \
    --tables=${data} \
    --log-interval=50 \
    --batch-size=1 \
    --model=${model_name} \
    --task-type=onnx_export \
    --load_model_path=${load_model_path}

--------------------------------------------------------------------------------
/onnx_test/mcrec_onnx_model_test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import torch

def mock_orec_data():
    num_samples = 2
    mock_data = {}

    hist_seq = [[int(e) for e in "44,172,602,602,163,258,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0".strip().split(",")]]
    mock_data['hist_seq'] = torch.LongTensor(hist_seq).repeat((num_samples, 1))

    cand = [[int("672")]]
    mock_data['cand'] = torch.LongTensor(cand).repeat((num_samples, 1)).squeeze(1)

    prior_score = [[float("0.1192")]]
    mock_data['prior_score'] = torch.Tensor(prior_score).repeat((num_samples, 1)).squeeze(1)

    label = [[int("0")]]
    mock_data['label'] = torch.LongTensor(label).repeat((num_samples, 1)).squeeze(1)

    return mock_data

--------------------------------------------------------------------------------
/scripts/pinrec/pinrec_inference.sh:
--------------------------------------------------------------------------------
echo "begin"

cd ../../

if [ -d "tmp/" ];then
    rm -rf tmp
fi

dataset=""
train_file="${dataset}/train.txt"
test_file="${dataset}/test.txt"
data="${train_file},${test_file}"

load=`pwd`/data/output/pinrec/result
save=`pwd`/data/output/pinrec/result
output=`pwd`/data/output/pinrec/result/rerank.txt

ARCH_CONF_FILE=`pwd`/scripts/pinrec/movielens_conf.json

python -m torch.distributed.launch --master_port 12350 train_hub/pinrec.py \
    --task-type=inference \
    --batch-size=512 \
    --num-epochs=1 \
    --train-iters=40000 \
    --log-interval=100 \
    --save-interval=10000 \
    --lr=0.001 \
    --eval-interval=40001 \
    --load=${load} \
    --save=${save} \
    --tables=${data} \
    --model=pinrec \
    --group_num=5 \
    --arch_config=${ARCH_CONF_FILE} \
    --stage_switch_epoch=2 \
    --optimizer=adam \
    --outputs=${output}

--------------------------------------------------------------------------------
/scripts/pinrec/pinrec_export.sh:
--------------------------------------------------------------------------------
echo "begin"

cd ../../

if [ -d "tmp/" ];then
    rm -rf tmp
fi

dataset=""
train_file="${dataset}/train.txt"
test_file="${dataset}/test.txt"
data="${train_file},${test_file}"

load=`pwd`/data/output/pinrec/result
save=`pwd`/data/output/pinrec/result
output=`pwd`/data/output/pinrec/result/rerank.txt

onnx_export_path=`pwd`/onnx_test/

ARCH_CONF_FILE=`pwd`/scripts/pinrec/movielens_conf.json

python -m torch.distributed.launch --master_port 12350 train_hub/pinrec.py \
    --batch-size=512 \
    --num-epochs=1 \
    --train-iters=40000 \
    --log-interval=100 \
    --save-interval=10000 \
    --lr=0.001 \
    --eval-interval=40001 \
    --load=${load} \
    --save=${save} \
    --tables=${data} \
    --model=pinrec \
    --group_num=5 \
    --arch_config=${ARCH_CONF_FILE} \
    --stage_switch_epoch=2 \
    --optimizer=adam \
    --outputs=${output} \
    --task-type=onnx_export \
    --onnx_export_path=${onnx_export_path}

--------------------------------------------------------------------------------
/scripts/pinrec/movielens_conf.json:
--------------------------------------------------------------------------------
{
  "id_dimension": 32,
  "id_vocab": 4545,
  "classifier": [128, 64],
  "embedding_layer_name_list": ["_id_encoder._embedding_matrix.weight"],
  "model_name_list": [
    "_id_encoder._embedding_matrix.weight",
    "_target_trans.net.0.weight",
    "_target_trans.net.0.bias",
    "_seq_trans.net.0.weight",
    "_seq_trans.net.0.bias",
    "_target_attention._target_key_transform.net.weight",
    "_target_attention._item_key_transform.net.weight",
    "_target_attention._value_transform.net.weight",
    "_classifier.net.0.weight",
    "_classifier.net.0.bias",
    "_classifier.net.2.weight",
    "_classifier.net.2.bias",
    "_classifier.net.4.weight",
    "_classifier.net.4.bias"
  ],
  "feature_extractor_name_list": [
    "_target_trans.net.0.weight",
    "_target_trans.net.0.bias",
    "_seq_trans.net.0.weight",
    "_seq_trans.net.0.bias",
    "_target_attention._target_key_transform.net.weight",
    "_target_attention._item_key_transform.net.weight",
    "_target_attention._value_transform.net.weight"
  ],
  "classifier_name_list": [
    "_classifier.net.0.weight",
    "_classifier.net.0.bias",
    "_classifier.net.2.weight",
    "_classifier.net.2.bias",
    "_classifier.net.4.weight",
    "_classifier.net.4.bias"
  ]
}

--------------------------------------------------------------------------------
/module/plugin.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import torch
from . import initializer


class Plugin(torch.nn.Module):
    def __init__(self, dimension, activation_func):
        super(Plugin, self).__init__()
        self._activation_func = activation_func()
        self.downsampling_layer = torch.nn.Linear(dimension, dimension // 2, True)
        self.upsampling_layer = torch.nn.Linear(dimension // 2, dimension, True)

        initializer.default_lite_plugin_init(self.downsampling_layer)
        initializer.default_lite_plugin_init(self.upsampling_layer)

        module_lst = [
            self.downsampling_layer,
            activation_func(),
            self.upsampling_layer
        ]
        self.net = torch.nn.Sequential(*module_lst)

    def forward(self, x):
        residual = self.net(x)
        return self._activation_func(x + residual)


if __name__ == '__main__':
    plugin_model = Plugin(32, torch.nn.Tanh)
    print("=" * 50)
    print("plugin_model.downsampling_layer", plugin_model.downsampling_layer.weight,
          plugin_model.downsampling_layer.bias, sep="\n")
    print("-" * 50)
    print("plugin_model.upsampling_layer", plugin_model.upsampling_layer.weight,
          plugin_model.upsampling_layer.bias, sep="\n")

--------------------------------------------------------------------------------
/module/common.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import torch
from . import initializer

class StackedDense(torch.nn.Module):
    def __init__(self, in_dimension, units, activation_fns):
        super(StackedDense, self).__init__()

        modules = []
        units = [in_dimension] + list(units)
        for i in range(1, len(units)):
            linear = torch.nn.Linear(units[i - 1], units[i], bias=True)
            initializer.default_weight_init(linear.weight)
            initializer.default_bias_init(linear.bias)
            modules.append(linear)

            if activation_fns[i - 1] is not None:
                modules.append(activation_fns[i - 1]())

        self.net = torch.nn.Sequential(*modules)

    def __setitem__(self, k, v):
        self.k = v

    def forward(self, x):
        return self.net(x)

class Linear(torch.nn.Module):
    def __init__(self, in_dimension, out_dimension, bias):
        super(Linear, self).__init__()
        self.net = torch.nn.Linear(in_dimension, out_dimension, bias)
        initializer.default_weight_init(self.net.weight)
        if bias:
            # the bias is 1-D, so it needs the constant init, not xavier
            initializer.default_bias_init(self.net.bias)

    def __setitem__(self, k, v):
        self.k = v

    def forward(self, x):
        return self.net(x)

--------------------------------------------------------------------------------
/module/encoder.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import torch
from . import initializer

class BaseEncoder(torch.nn.Module):
    def __init__(self):
        super(BaseEncoder, self).__init__()

    def in_dimension(self):
        raise NotImplementedError

    def out_dimension(self):
        raise NotImplementedError

class DenseEncoder(BaseEncoder):
    def __init__(self, in_dimension, out_dimension, activation=torch.nn.Tanh):
        super(DenseEncoder, self).__init__()
        self._in_dimension = in_dimension
        self._out_dimension = out_dimension
        self._fully_connect = torch.nn.Sequential(
            torch.nn.Linear(self._in_dimension, self._out_dimension),
            activation()
        )

    def __setitem__(self, k, v):
        self.k = v

    def forward(self, x):
        return self._fully_connect(x)

    def out_dimension(self):
        return self._out_dimension

    def in_dimension(self):
        return self._in_dimension

class IDEncoder(BaseEncoder):
    def __init__(self, vocab_size, out_dimension):
        super(IDEncoder, self).__init__()
        self._embedding_matrix = torch.nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=out_dimension
        )
        initializer.default_weight_init(self._embedding_matrix.weight)

    def __setitem__(self, k, v):
        self.k = v

    def forward(self, x):
        return self._embedding_matrix(x)

    def out_dimension(self):
        return self._embedding_matrix.embedding_dim

--------------------------------------------------------------------------------
/module/initializer.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import torch

# trunk model init
def default_weight_init(tensor):
    # torch.nn.init.xavier_uniform(tensor)
    torch.nn.init.xavier_uniform_(tensor)

def default_bias_init(tensor):
    torch.nn.init.constant_(tensor, 0)

# lite plugin model init
def default_lite_plugin_init(layer):
    # torch.nn.init.xavier_uniform(layer.weight, gain=0.001)
    torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
    # torch.nn.init.constant_(layer.weight, 0)
    torch.nn.init.constant_(layer.bias, 0)

# naive plugin model init
def default_naive_plugin_init(layer):
    torch.nn.init.constant_(layer.weight, 0)
    torch.nn.init.constant_(layer.bias, 0)

if __name__ == '__main__':
    # model.apply(weight_init_normal)
    dimension = 10
    plugin_layer = torch.nn.Linear(dimension, dimension // 2, True)
    print("-" * 50)
    print("original")
    print("plugin_layer.weight", plugin_layer.weight)
    print("plugin_layer.bias", plugin_layer.bias)
    default_weight_init(plugin_layer.weight)
    default_bias_init(plugin_layer.bias)
    print("-" * 50)
    print("trunk_init")
    print("plugin_layer.weight", plugin_layer.weight)
    print("plugin_layer.bias", plugin_layer.bias)
    default_lite_plugin_init(plugin_layer)
    print("-" * 50)
    print("lite_plugin_init")
    print("plugin_layer.weight", plugin_layer.weight)
    print("plugin_layer.bias", plugin_layer.bias)
    default_naive_plugin_init(plugin_layer)
    print("-" * 50)
    print("naive_plugin_init")
    print("plugin_layer.weight", plugin_layer.weight)
    print("plugin_layer.bias", plugin_layer.bias)

--------------------------------------------------------------------------------
/backend/dataset_hub/pin_datasets.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import torch
from backend.dataset_hub.base_datasets import BaseDataset


class PINDatasetLocal(BaseDataset):
    def __init__(self,
                 args,
                 table_name,
                 shuffle_buffer_size=8194,
                 is_test=False):

        super().__init__(args, table_name, shuffle_buffer_size, is_test)
        self.column_length = len(args.consts.keys())

    def get_total_row_count(self):
        cnt = 0
        for _ in self.reader:
            if len(_) > 0:
                cnt += 1
        self.reader.seek(0)

        return cnt

    def _new_reader(self):
        if self.reader is not None:
            self.reader.close()
        print('self.table_name', self.table_name)
        reader = open(self.table_name, "r")

        return reader

    def _read_record(self):
        try:
            column_l = self.reader.readline().strip().split(";")
            assert len(column_l) == self.column_length, "len(column_l) must be {}, now is {}".format(self.column_length, len(column_l))
        except:
            self.reader.seek(0)
            column_l = self.reader.readline().strip().split(";")
        return column_l

    def __del__(self):
        if self.reader is not None:
            self.reader.close()

    def _parse_item(self, column_l):
        user_id = int(column_l[0])
        target_id = int(column_l[1])
        clk_sequence = torch.LongTensor(list(map(int, column_l[2].split(","))))
        label = int(column_l[3])
        group_id = int(column_l[4])
        ret = {
            "user_id": user_id,
            "target_id": target_id,
            "clk_sequence": clk_sequence,
            "label": label,
            "group_id": group_id
        }
        return ret

--------------------------------------------------------------------------------
/onnx_test/pinrec_onnx_model_test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import numpy as np
import torch
import os
import onnxruntime
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from backend.utils import to_numpy
from backend.model_hub import pinrec_model

consts = pinrec_model.consts


def mock_data():
    batch_size = 2
    seq_lens = 50
    mock_data_dict = {
        consts["FIELD_USER_ID"]: torch.LongTensor(torch.randint(low=0, high=6040, size=[batch_size])),
        consts["FIELD_TARGET_ID"]: torch.LongTensor(torch.randint(low=0, high=3706, size=[batch_size])),
        consts["FIELD_CLK_SEQUENCE"]: torch.LongTensor(torch.randint(low=0, high=3706, size=[batch_size, seq_lens])),
        consts["FIELD_LABEL"]: torch.LongTensor(torch.randint(low=0, high=1, size=[batch_size])),
        consts["FIELD_GROUP_ID"]: torch.LongTensor(torch.randint(low=0, high=4, size=[batch_size]))
    }
    return mock_data_dict


def onnx_test(mock_data_func, onnx_export_path=None, onnx_model_name=None):
    '''
    func: get the result of the onnx model
    mock_data_func: the func that was used in the onnx export
    onnx_export_path: onnx model path
    onnx_model_name: onnx model name
    '''
    data = mock_data_func()
    data = dict((k, to_numpy(v)) for k, v in data.items())
    input_data = []
    for i, (k, v) in enumerate(data.items()):
        input_data.append(np.array(v))
    onnx_model_name = onnx_model_name if onnx_model_name else "model_00.onnx"
    model_path = os.path.join(onnx_export_path, onnx_model_name)
    ort_session = onnxruntime.InferenceSession(model_path)
    ort_inputs = {
        'input.1': data[consts["FIELD_TARGET_ID"]],
        'input.5': data[consts["FIELD_CLK_SEQUENCE"]]
    }
    ort_outs = ort_session.run(None, ort_inputs)
    print(ort_outs)


if __name__ == '__main__':
    onnx_test(mock_data, '/mnt4/lzq/mobilem6/onnx_test', 'onnx_model.onnx')

--------------------------------------------------------------------------------
/module/attention.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import torch
from torch import nn
from . import common

class TargetAttention(nn.Module):
    def __init__(self, key_dimension, value_dimension):
        super(TargetAttention, self).__init__()
        self._target_key_transform = common.Linear(key_dimension, key_dimension, bias=False)
        self._item_key_transform = common.Linear(key_dimension, key_dimension, bias=False)
        self._value_transform = common.Linear(value_dimension, value_dimension, bias=False)
        self._scaler = key_dimension ** 0.5

    def __setitem__(self, k, v):
        self.k = v

    def forward(self, target_key, item_keys, item_values, mask):
        """
        :param target_key: B * D
        :param item_keys: B * L * D
        :param item_values: B * L * D
        :param mask: B * L
        :return:
        """
        assert item_keys.shape[1] == item_values.shape[1]
        assert target_key.shape[-1] == item_keys.shape[-1]
        assert target_key.shape[0] == item_keys.shape[0] == item_values.shape[0]

        target_key = self._target_key_transform(target_key)[:, None, :]
        item_keys = self._item_key_transform(item_keys)
        item_values = self._value_transform(item_values)

        atten_weights = torch.sum(target_key * item_keys, dim=-1, keepdim=True) / self._scaler
        if mask is not None:
            atten_weights += -1e8 * (1 - mask[:, :, None])

        atten_weights = torch.softmax(atten_weights, dim=1)
        return torch.sum(atten_weights * item_values, dim=1)

if __name__ == '__main__':
    target_embed = torch.randn(16, 8)
    item_keys = torch.randn(16, 10, 8)
    item_values = torch.randn(16, 10, 23)

    # TargetAttention only takes key_dimension and value_dimension
    m = TargetAttention(
        key_dimension=8,
        value_dimension=23
    )
    mask = torch.cat([torch.ones([16, 4]), torch.zeros([16, 6])], dim=1)
    data = m(target_embed, item_keys, item_values, mask)
    print(data.shape)

--------------------------------------------------------------------------------
/backend/dataset_hub/base_datasets.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from torch.utils.data import Dataset
import random

class BaseDataset(Dataset):
    def __init__(self,
                 args,
                 table_name,
                 shuffle_buffer_size,
                 is_test):
        self.args = args
        self.table_name = table_name

        self.reader = None
        self.reader = self._new_reader()
        self.fetch_iter = 0
        self.num_to_fetch = self.get_total_row_count()  # row count on a single worker
        self.shuffle_buffer = []
        self.shuffle_size = shuffle_buffer_size
        self.is_test = is_test
        self._total_row_count = -1

    def get_total_row_count(self):
        raise NotImplementedError

    def _new_reader(self):
        raise NotImplementedError

    def _need_reload(self):
        return self.fetch_iter == self.num_to_fetch

    def _read_record(self):
        raise NotImplementedError

    def _read_item(self):
        if self._need_reload():
            self.fetch_iter = 0
            self.reader = self._new_reader()
        self.fetch_iter += 1
        column_l = self._read_record()
        column_l = [
            item.decode(encoding="utf8", errors="ignore") if type(item) == bytes else item
            for item in column_l
        ]
        return column_l

    def _parse_item(self, column_l):
        raise NotImplementedError

    def __getitem__(self, idx):
        if self.is_test:
            return self._parse_item(self._read_item())
        while ((not self._need_reload()) or (len(self.shuffle_buffer) == 0)) and (
                len(self.shuffle_buffer) < self.shuffle_size):
            self.shuffle_buffer.append(self._read_item())
        num_samples = len(self.shuffle_buffer)
        i = random.randint(0, num_samples - 1)
        if i != num_samples - 1:
            self.shuffle_buffer[i], self.shuffle_buffer[-1] = self.shuffle_buffer[-1], self.shuffle_buffer[i]
        ret_item = self.shuffle_buffer.pop(-1)
        ret_item = self._parse_item(ret_item)
        return ret_item

    def __len__(self):
        if self.is_test:
            return self.num_to_fetch
        return 2 ** 30  # fake

--------------------------------------------------------------------------------
/backend/dataset_hub/ecrec_datasets.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import torch
from backend.dataset_hub.base_datasets import BaseDataset


class ECRecDatasetLocal(BaseDataset):
    def __init__(self,
                 args,
                 table_name,
                 shuffle_buffer_size=8194,
                 is_test=False):

        super().__init__(args, table_name, shuffle_buffer_size, is_test)
        self.column_length = args.column_length
        self.maxlen = args.sequence_length

    def get_total_row_count(self):
        cnt = 0
        for _ in self.reader:
            if len(_) > 0:
                cnt += 1
        self.reader.seek(0)

        return cnt

    def _new_reader(self):
        if self.reader is not None:
            self.reader.close()
        print('self.table_name', self.table_name)
        reader = open(self.table_name, "r")

        return reader

    def _read_record(self):
        try:
            column_l = self.reader.readline().strip().split("\t")
            assert len(column_l) == self.column_length, "len(column_l) must be {}, now is {}".format(self.column_length, len(column_l))
        except:
            self.reader.seek(0)
            column_l = self.reader.readline().strip().split("\t")
        return column_l

    def __del__(self):
        if self.reader is not None:
            self.reader.close()

    def _parse_item(self, one_sample):
        '''parse items for each sample'''
        uid = int(one_sample[0])
        item_id = int(one_sample[1])
        cate_id = int(one_sample[2])
        label = int(one_sample[3])

        if one_sample[4] == '':
            hist_item = [0] * self.maxlen
            hist_cate = [0] * self.maxlen
        else:
            hist_item = one_sample[4].split(',')
            hist_item = hist_item[:self.maxlen] if len(hist_item) >= self.maxlen else hist_item + [0] * (self.maxlen - len(hist_item))
            hist_item = list(map(int, hist_item))
            hist_cate = one_sample[5].split(',')
            hist_cate = hist_cate[:self.maxlen] if len(hist_cate) >= self.maxlen else hist_cate + [0] * (self.maxlen - len(hist_cate))
            hist_cate = list(map(int, hist_cate))

        if one_sample[6] == '':
            edge_item = [0] * self.maxlen
            edge_cate = [0] * self.maxlen
        else:
            edge_item = one_sample[6].split(',')
            edge_item = edge_item[:self.maxlen] if len(edge_item) >= self.maxlen else edge_item + [0] * (self.maxlen - len(edge_item))
            edge_item = list(map(int, edge_item))
            edge_cate = one_sample[7].split(',')
            edge_cate = edge_cate[:self.maxlen] if len(edge_cate) >= self.maxlen else edge_cate + [0] * (self.maxlen - len(edge_cate))
            edge_cate = list(map(int, edge_cate))

        seq_len = int(one_sample[8])

        output = {'user_id': uid,
                  'item_id': item_id,
                  'cate_id': cate_id,
                  'label': label,
                  'item_seq': torch.LongTensor(hist_item),
                  'cate_seq': torch.LongTensor(hist_cate),
                  'edge_item_seq': torch.LongTensor(edge_item),
                  'edge_cate_seq': torch.LongTensor(edge_cate),
                  'seq_len': seq_len
                  }

        return output

--------------------------------------------------------------------------------
/onnx_test/ecrec_onnx_test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import torch
import os
import onnxruntime
import sys
sys.path.append('/mnt2/songxiao/device-cloud-modelhub/')
from backend.utils import to_numpy

def mock_data(num_samples=2, data_names=None, data_sizes=None, data_nums=None):
    '''
    num_samples: batch_size
    data_names: comma-separated feature names
    data_sizes: the size of one sample for each feature
    data_nums: dict, the max value for each feature
    '''
    mock_data = {}
    data_names = data_names.split(",")
    data_sizes = data_sizes.split(",")
    assert len(data_names) == len(data_sizes)
    for name, size in zip(data_names, data_sizes):
        if name == "label":
            v = torch.randint(low=0, high=2, size=[num_samples], dtype=torch.int64)
        else:
            high_num = data_nums.get(name, 2) if data_nums else 2
            if int(size) == 1:
                v = torch.randint(low=1, high=high_num, size=[num_samples], dtype=torch.int64)
            else:
                v = torch.randint(low=1, high=high_num, size=[num_samples, int(size)], dtype=torch.int64)
        mock_data[name] = v
    return mock_data

def mock_edge_data():
    data_names = "user_id,item_id,cate_id,edge_item_seq,edge_cate_seq,seq_len,label"
    data_sizes = "1,1,1,100,100,1,1"
    data_nums = {"user_id": 1000,
                 "item_id": 1000,
                 "cate_id": 100,
                 "edge_item_seq": 1000,
                 "edge_cate_seq": 100,
                 "seq_len": 100,
                 "label": 2
                 }
    data = mock_data(2, data_names, data_sizes, data_nums)
    return data

def mock_cloud_data():
    data_names = "user_id,item_id,cate_id,edge_item_seq,edge_cate_seq,seq_len,label,item_seq,cate_seq"
    data_sizes = "1,1,1,100,100,1,1,100,100"
    data_nums = {"user_id": 1000,
                 "item_id": 1000,
                 "cate_id": 100,
                 "edge_item_seq": 1000,
                 "edge_cate_seq": 100,
                 "seq_len": 100,
                 "label": 2,
                 "item_seq": 1000,
                 "cate_seq": 100
                 }
    data = mock_data(2, data_names, data_sizes, data_nums)
    return data

def onnx_test(mock_data_func, onnx_export_path=None, onnx_model_name=None):
    '''
    func: get the result of the onnx model
    mock_data_func: the func that was used in the onnx export
    onnx_export_path: onnx model path
    onnx_model_name: onnx model name
    '''
    data = mock_data_func()
    data = dict((k, to_numpy(v)) for k, v in data.items())
    input_data = []
    for i, (k, v) in enumerate(data.items()):
        input_data.append(v)
    onnx_model_name = onnx_model_name if onnx_model_name else "model_00.onnx"
    model_path = os.path.join(onnx_export_path, onnx_model_name)
    # model_path = "./output/model_00.onnx"
    ort_session = onnxruntime.InferenceSession(model_path)
    ort_inputs = {}
    for s_input in ort_session.get_inputs():
        if str(s_input.name).split(".")[0] == 'input':
            k = str(s_input.name).split(".")[-1]
            in_da = input_data[int(k) - 1]
            ort_inputs[s_input.name] = in_da
    ort_inputs['5'] = data['seq_len']
    ort_inputs['6'] = data['label']

    ort_outs = ort_session.run(None, ort_inputs)
    print(ort_outs)

if __name__ == '__main__':
    onnx_test(mock_edge_data, '/mnt2/songxiao/ecrec/test', 'onnx_model.onnx')

--------------------------------------------------------------------------------
/backend/dataset_hub/cigar_datasets.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import torch
from backend.dataset_hub.base_datasets import *

class CIGARDataset(BaseDataset):
    def __init__(self,
                 args,
                 table_name,
                 shuffle_buffer_size,
                 is_test=False,
                 max_len=-1,
                 max_neighbor=-1):
        super(CIGARDataset, self).__init__(args, table_name, shuffle_buffer_size, is_test)
        self.maxlen = max_len
        self.max_neighbor = max_neighbor

    def _parse_item(self, one_sample):
        output = {}
        #### User feature ####
        user_fea = [int(one_sample[int(x)]) + 1 for x in self.args.user_fea_col_id.split(',')]
        user_fea_name = self.args.user_fea_name.split(',')
        output.update(dict(zip(user_fea_name, user_fea)))
        #### Item feature ####
        item_fea = [int(one_sample[int(x)]) + 1 for x in self.args.item_fea_col_id.split(',')]
        item_fea_name = self.args.item_fea_name.split(',')
        output.update(dict(zip(item_fea_name, item_fea)))
        #### Seq feature ####
        seq_col_id = [int(x) for x in self.args.seq_col_id.split(',')]
        seq_fea_name = [x + '_seq' for x in self.args.item_fea_name.split(',')]
        # padding and truncation
        for i in range(len(seq_col_id)):
            if one_sample[seq_col_id[i]] == '':
                seq_fea = [0] * self.maxlen
            else:
                seq_fea = one_sample[seq_col_id[i]].split(',')
                seq_fea = seq_fea[-self.maxlen:] \
                    if len(seq_fea) >= self.maxlen \
                    else seq_fea + [0] * (self.maxlen - len(seq_fea))
                seq_fea = list(map(int, seq_fea))
            assert len(seq_fea) == self.maxlen
            output[seq_fea_name[i]] = torch.LongTensor(seq_fea)
        # uid, label
        uid, graph_id, label_id = [int(x) for x in self.args.uid_graph_label_col_id.split(',')]
        output['userid'] = int(one_sample[uid])
        output['label'] = float(one_sample[label_id])
        # Graph feature
        # padding and truncation
        if one_sample[graph_id] == '':
            neighbor_ids = [0] * self.max_neighbor
        else:
            neighbor_ids = one_sample[graph_id].split(',')
            neighbor_ids = neighbor_ids[:self.max_neighbor] \
                if len(neighbor_ids) >= self.max_neighbor \
                else neighbor_ids + [0] * (self.max_neighbor - len(neighbor_ids))
            neighbor_ids = list(map(int, neighbor_ids))
        assert len(neighbor_ids) == self.max_neighbor
        output['neighbor_ids'] = torch.LongTensor(neighbor_ids)
        return output

class CIGARDatasetLocal(CIGARDataset):
    def __init__(self,
                 args,
                 table_name,
                 shuffle_buffer_size=8194,
                 is_test=False,
                 max_len=100,
                 max_neighbor=10):

        super(CIGARDatasetLocal, self).__init__(args, table_name, shuffle_buffer_size, is_test, max_len, max_neighbor)

    def get_total_row_count(self):
        cnt = 0
        for _ in self.reader:
            if len(_) > 0:
                cnt += 1
        self.reader.seek(0)
        return cnt

    def _new_reader(self):
        if self.reader is not None:
            self.reader.close()
        print('self.table_name', self.table_name)
        reader = open(self.table_name, "r")

        return reader

    def _read_record(self):
        try:
            column_l = self.reader.readline().strip().split("\t")
            # print(column_l)
            assert len(column_l) == self.args.column_len, "len(column_l) must be %d, now is %d" % (self.args.column_len, len(column_l))
        except:
            self.reader.seek(0)
            column_l = self.reader.readline().strip().split("\t")
        return column_l

    def __del__(self):
        if self.reader is not None:
            self.reader.close()

--------------------------------------------------------------------------------
/backend/arguments.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import argparse
import os
import torch


def add_training_args(parser):
    group = parser.add_argument_group('train', 'training')

    # optimizer args
    group.add_argument('--batch-size', type=int, default=128,
                       help='data batch size')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='gradient clipping')
    group.add_argument('--num-epochs', type=int, default=2,
                       help='num-epochs')
    group.add_argument('--train-iters', type=int, default=1000000,
                       help='total number of iterations to train')
    group.add_argument('--log-interval', type=int, default=100,
                       help='report metrics interval')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after this many new iterations.')
    group.add_argument('--seed', type=int, default=1234,
                       help='random seed')
    group.add_argument('--optimizer', default='adamw',
                       help='optimizer, One of [adamw, adam]')
    group.add_argument("--final-saved-iteration", type=int, default=0,
                       help="iteration of the final saved checkpoint")

    # ddp params
    group.add_argument('--find-unused-parameters', action='store_true',
                       help='find_unused_parameters setting in DDP')

    # backward args
    group.add_argument('--backward-step-contains-in-forward-step', action='store_true',
                       help='backward operations are contained in the forward step')

    # Learning rate.
    group.add_argument('--lr', type=float, default=1.0e-4,
                       help='learning rate')
    group.add_argument('--weight-decay', type=float, default=0.0,
                       help='weight decay coefficient for L2 regularization')

    # model checkpointing args
    group.add_argument('--save', type=str, default=None,
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=5000,
                       help='number of iterations between saves')
    group.add_argument('--load', type=str, default=None,
                       help='Path to a directory containing a model checkpoint.')
    group.add_argument('--task-type', type=str, default='train',
                       help='task type: train, inference')

    # distributed args
    group.add_argument('--distributed-backend', default='nccl',
                       help='which backend to use for distributed '
                            'training. One of [gloo, nccl]')
    group.add_argument('--local_rank', type=int, default=0,
                       help='local rank passed from distributed launcher')
    group.add_argument('--worker-cnt', type=int, default=1,
                       help='number of workers')
    group.add_argument('--gpus-per-node', type=int, default=1,
                       help='number of gpus per node')
    group.add_argument('--entry', type=str, default='pretrain_gpt2.py')

    return parser


def add_data_args(parser):
    group = parser.add_argument_group('data', 'data configurations')

    # add arguments of input and output
    group.add_argument('--tables', type=str, default='', help='input table(data) name (train, valid, test)')
    group.add_argument('--outputs', type=str, default='')

    # onnx model
    group.add_argument("--onnx_export_path", type=str, default=None, help="ONNX model export path")
    group.add_argument("--onnx_model_name", type=str, default='onnx_model.onnx', help="ONNX model name")

    return parser

def add_validation_args(parser):
    group = parser.add_argument_group(title='validation')

    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run for evaluation '
                            'validation/test for.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on '
                            'validation set.')

    return parser

def get_args(add_personalized_args_fn=None):

    parser = argparse.ArgumentParser(description='..')
    parser = add_training_args(parser)
    parser = add_data_args(parser)
    parser = add_validation_args(parser)

    if add_personalized_args_fn is not None:
        parser = add_personalized_args_fn(parser)

    args = parser.parse_args()

    args.cuda = torch.cuda.is_available()

    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv('WORLD_SIZE', '1'))

    return args

--------------------------------------------------------------------------------
/train_hub/model_execution_template.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from backend.task_backbone import task_dispatcher
from torch.utils.data import Dataset

def model_provider(args):
    '''
    Build the model func.
    :param args: user defined arguments dictionary
    :return:
        model: user defined model that implements the torch.nn.Module interface
    '''
    model = None

    return model


def get_batch(data_iterator, args):
    '''
    Batch data processing method for the training job.
    :param data_iterator: data iterator that implements the torch.utils.data.DataLoader interface
    :param args: user defined arguments dictionary
    :return:
        dictionary (python dict()): a dictionary that contains all data used in the model forward step
    '''
    ret_data = []

    return ret_data

def get_inference_batch(data_iterator, args):
    '''
    Batch data processing method for the inference job.
    :param data_iterator: data iterator that implements the torch.utils.data.DataLoader interface
    :param args: user defined arguments dictionary
    :return:
        dictionary (python dict()): a dictionary that contains all data used in the model forward step
    '''
    ret_data = []

    return ret_data

def forward_func(data_iterator, model, args):
    '''
    Model forward step.
    :param data_iterator: data iterator that implements the torch.utils.data.DataLoader interface
    :param model: a model that implements the torch.nn.Module interface and is defined in the model_provider func
    :param args: user defined arguments dictionary
    :return:
        if task_type == 'train':
            return loss: a one-dimensional loss vector that contains every sample's loss
                   stats: other metrics to print to the terminal
        else (task_type == 'inference'):
            return infer_res: a results list that is written to the output files
    '''
    if args.task_type == 'train':
        ret_data = get_batch(data_iterator, args)


        loss, *stats = model(ret_data)

        return loss, stats
    else:
        ret_data = get_inference_batch(data_iterator, args)

        infer_res = model(ret_data)

        infer_res = list(infer_res)
        return infer_res


def train_eval_datasets_provider(args):
    '''
    Build train, valid, and test datasets for the training job.
    :param args: user defined arguments dictionary
    :return:
        train_dataset, valid_dataset: datasets that implement the torch.utils.data.Dataset interface
    '''

    # Build the dataset.
    input_tables = args.tables.split(",")

    # this is just an example; you can build a custom Dataset class that implements the torch.utils.data.Dataset interface
    dataset = Dataset(input_tables[0])

    if len(input_tables) > 1:
        eval_dataset = Dataset(input_tables[1])
        # run on local
        # eval_dataset = SeqToSeqDatasetLocal(args, input_tables[1], text_preprocessor, is_test=True)
    else:
        eval_dataset = None

    return dataset, eval_dataset

def personalized_args_provider(parser):
    '''
    User-defined parameters function
    :param parser: parser, the object of argparse.ArgumentParser
    :return: the parser extended with the user-defined parameters
    '''
    def add_model_config_args(parser):
        """Model arguments"""
        group = parser.add_argument_group('model', 'model configuration')

        # add the exclusive parameters that your model uses
        group.add_argument('--mock', type=float, default=0.1, help='just an example')

        return parser

    return add_model_config_args(parser)

def inference_dataset_provider(args):
    '''
    Build the dataset for the inference job.
    :param args: user defined arguments dictionary
    :return:
        dataset: a dataset that implements the torch.utils.data.Dataset interface
    '''
    input_tables = args.tables.split(",")
    # this is just an example; you can build a custom Dataset class that implements the torch.utils.data.Dataset interface
    dataset = Dataset(input_tables)

    return dataset

def training_post_processing_func(model, args):
    '''
    :param model: the trained model
    :param args: user defined arguments dictionary
    :return: None
    '''
    pass

if __name__ == "__main__":
    task_dispatcher(train_eval_dataset_provider=train_eval_datasets_provider,
                    inference_dataset_provider=inference_dataset_provider,
                    model_provider=model_provider,
                    forward_func=forward_func,
                    personalized_args_provider=personalized_args_provider,
                    training_post_processing_func=training_post_processing_func)

--------------------------------------------------------------------------------
/backend/dataset_hub/mcrec_dataset.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2022 The Luoxi Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import torch
from backend.dataset_hub.base_datasets import *

class CRecDataset(BaseDataset):
    def __init__(self, args, file_name, shuffle_buffer_size=8096, is_test=False):
        """
        function: read the dataset for the cloud-based recommendation model.
        for more details, please refer to the parent class
        """
        super().__init__(args, file_name, shuffle_buffer_size, is_test)

    def get_total_row_count(self):
        cnt = 0
        for _ in self.reader:
            if len(_) > 0:
                cnt += 1
        self.reader.seek(0)

        return cnt

    def _new_reader(self):
        if self.reader is not None:
            self.reader.close()
        print('self.file_name', self.table_name)
        reader = open(self.table_name, "r")

        return reader

    def _read_record(self):
        try:
            column_l = self.reader.readline().strip().split("\t")
            assert len(column_l) == 3, "len(column_l) must be 3, now is %d" % len(column_l)
        except:
            self.reader.seek(0)
            column_l = self.reader.readline().strip().split("\t")

        return column_l

    def _parse_item(self, sample):
        # user feature
        hist_seq = sample[0].strip().split(',')
        hist_seq = list(map(int, hist_seq))

        # Item feature
        cand = int(sample[1])

        # Label
        label = float(sample[2])

        output = {
            "hist_seq": torch.LongTensor(hist_seq),
            "cand": cand,
            "label": label}
        return output

    def __del__(self):
        if self.reader is not None:
            self.reader.close()

class ORecDataset(BaseDataset):
    def __init__(self, args, file_name, shuffle_buffer_size=8096, is_test=False):
        """
        function: read the dataset for the on-device recommendation model.
70 | for more details, please refer to the parent class 71 | """ 72 | super().__init__(args, file_name, shuffle_buffer_size, is_test) 73 | 74 | def get_total_row_count(self): 75 | cnt = 0 76 | for _ in self.reader: 77 | if len(_) > 0: 78 | cnt += 1 79 | self.reader.seek(0) 80 | 81 | return cnt 82 | 83 | def _new_reader(self): 84 | if self.reader is not None: 85 | self.reader.close() 86 | print('self.table_name', self.table_name) 87 | reader = open(self.table_name, "r") 88 | 89 | return reader 90 | 91 | def _read_record(self): 92 | try: 93 | column_l = self.reader.readline().strip().split("\t") 94 | assert len(column_l) == 4, "len(column_l) must be 4, now is %d" % len(column_l) 95 | except: 96 | self.reader.seek(0) 97 | column_l = self.reader.readline().strip().split("\t") 98 | 99 | return column_l 100 | 101 | def _parse_item(self, sample): 102 | # user feature 103 | hist_seq = sample[0].strip().split(',') 104 | hist_seq = list(map(int, hist_seq)) 105 | 106 | # Item feature 107 | cand = int(sample[1]) 108 | 109 | # prior score 110 | prior_score = float(sample[2]) 111 | 112 | # Label 113 | label = float(sample[3]) 114 | 115 | output = { 116 | "hist_seq" : torch.LongTensor(hist_seq), 117 | "cand" : cand, 118 | "prior_score": prior_score, 119 | "label": label} 120 | return output 121 | 122 | def __del__(self): 123 | if self.reader is not None: 124 | self.reader.close() 125 | 126 | class MCRecDataset(BaseDataset): 127 | def __init__(self, args, file_name, shuffle_buffer_size=8096, is_test=False): 128 | """ 129 | function: read the counterfactual dataset for the meta controller. 130 | for more details, please refer to the parent class 131 | """ 132 | super().__init__(args, file_name, shuffle_buffer_size, is_test) 133 | 134 | def get_total_row_count(self): 135 | cnt = 0 136 | for _ in self.reader: 137 | if len(_) > 0: 138 | cnt += 1 139 | self.reader.seek(0) 140 | 141 | return cnt 142 | 143 | def _new_reader(self): 144 | if self.reader is not None: 145 | self.reader.close() 146 | print('self.table_name', self.table_name) 147 | reader = open(self.table_name, "r") 148 | 149 | return reader 150 | 151 | def _read_record(self): 152 | try: 153 | column_l = self.reader.readline().strip().split("\t") 154 | assert len(column_l) == 2, "len(column_l) must be 2, now is %d" % len(column_l) 155 | except: 156 | self.reader.seek(0) 157 | column_l = self.reader.readline().strip().split("\t") 158 | 159 | return column_l 160 | 161 | def _parse_item(self, sample): 162 | # user feature 163 | hist_seq = sample[0].strip().split(',') 164 | hist_seq = list(map(int, hist_seq)) 165 | 166 | # Label 167 | label = float(sample[1]) 168 | 169 | output = { 170 | "hist_seq" : torch.LongTensor(hist_seq), 171 | "label": label} 172 | return output 173 | 174 | def __del__(self): 175 | if self.reader is not None: 176 | self.reader.close() 177 | -------------------------------------------------------------------------------- /backend/model_hub/mcrec_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2022 The Luoxi Team. 3 | # All rights reserved. 4 | # This source code is licensed under the Apache 2.0 license 5 | # found in the LICENSE file in the root directory. 
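# A rough sketch of the tab-separated rows consumed by the dataset classes in
# backend/dataset_hub/mcrec_dataset.py above (the field values are illustrative,
# not taken from a real data file; hist_seq is a comma-separated item-id
# sequence, with id 0 reserved for padding):
#
#     CRecDataset  row: "<hist_seq>\t<cand>\t<label>"                  e.g. "17,42,7\t88\t1.0"
#     ORecDataset  row: "<hist_seq>\t<cand>\t<prior_score>\t<label>"   e.g. "17,42,7\t88\t0.73\t1.0"
#     MCRecDataset row: "<hist_seq>\t<label>"                          e.g. "17,42,7\t2.0"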
6 | 
7 | import torch
8 | import torch.nn as nn
9 | 
10 | class Embedding(nn.Module):
11 |     def __init__(self, num, dim):
12 |         """
13 |         build embedding table
14 |         input schema:
15 |             num: number of the embedding vectors
16 |             dim: dimension of the embedding vector
17 |         output schema:
18 |             return an embedding table for lookup
19 |         """
20 |         super(Embedding, self).__init__()
21 |         self.emb = nn.Embedding(num, dim, padding_idx=0)
22 | 
23 |     def forward(self, idx):
24 |         output = self.emb(idx)
25 |         return output
26 | 
27 | class Attention(nn.Module):
28 |     def __init__(self, dim_i, dim_o):
29 |         """
30 |         build the target-aware attention
31 |         input schema:
32 |             dim_i: the dimension of the input feature vector
33 |             dim_o: the dimension of the output feature vector
34 |         output schema:
35 |             return an aggregated vector computed from the context k, v for the query q
36 |         """
37 |         super(Attention, self).__init__()
38 |         self.Q = nn.Linear(dim_i, dim_o)
39 |         self.K = nn.Linear(dim_i, dim_o)
40 |         self.V = nn.Linear(dim_i, dim_o)
41 | 
42 |     def forward(self, hist_seq_emb, hist_seq_mask, cand_emb):
43 |         q, k, v = self.Q(cand_emb), self.K(hist_seq_emb), self.V(hist_seq_emb)
44 | 
45 |         # q: B x d; k: B x L x d; v: B x L x d
46 |         # hist_seq_mask: B x L
47 |         logits = torch.sum(q.unsqueeze(1) * k, dim=2)
48 |         logits = logits * hist_seq_mask + (1 - hist_seq_mask) * (-2 ** 32.0)  # push padded positions to a large negative value before the softmax
49 |         scores = torch.softmax(logits, dim=1)
50 | 
51 |         output = torch.sum(scores.unsqueeze(2) * v, dim=1)
52 | 
53 |         return output
54 | 
55 | class CRec(nn.Module):
56 |     def __init__(self, args):
57 |         """
58 |         build the cloud-based recommendation model
59 |         input schema:
60 |             args: parameters for the model initialization
61 |             hist_seq: the user historical sequence
62 |             cand: the target item for scoring
63 |             label: click or not
64 |         output:
65 |             return the loss in the "train" mode or prediction in the other mode
66 |         """
67 |         super(CRec, self).__init__()
68 |         self.task_type = args.task_type
69 |         self.dim = args.dim
70 |         self.num = args.num
71 |         self.emb = Embedding(self.num, self.dim)
72 |         self.att = Attention(self.dim, self.dim)
73 |         self.projection = nn.Linear(self.dim, self.dim)
74 |         self.classifier = nn.Linear(self.dim, 2)
75 |         self.loss = nn.CrossEntropyLoss()
76 | 
77 |     def forward(self, hist_seq, cand, label):
78 |         hist_seq_emb = self.emb(hist_seq)
79 |         hist_seq_mask = torch.where(hist_seq == 0, torch.zeros_like(hist_seq), torch.ones_like(hist_seq))
80 |         cand_emb = self.emb(cand)
81 | 
82 |         agg_emb = self.att(hist_seq_emb, hist_seq_mask, cand_emb)
83 |         logits = self.classifier(self.projection(agg_emb))
84 | 
85 |         if self.task_type != 'train':
86 |             pred = torch.softmax(logits, dim=1)[:,1]
87 |             return pred
88 |         else:
89 |             loss = torch.mean(self.loss(logits, label))
90 |             return loss
91 | 
92 | class ORec(nn.Module):
93 |     def __init__(self, args):
94 |         """
95 |         build the on-device recommendation model
96 |         input schema:
97 |             args: parameters for the model initialization
98 |             hist_seq: the user historical sequence
99 |             cand: the target item for scoring
100 |             prior_score: the prior click probability from the cloud-based model; forward() converts it back to logits via the inverse sigmoid -log(1/p - 1), so this model only learns a residual correction
101 |             label: click or not
102 |         output:
103 |             return the loss in the "train" mode or prediction in the other mode
104 |         """
105 |         super(ORec, self).__init__()
106 |         self.task_type = args.task_type
107 |         self.dim = args.dim
108 |         self.num = args.num
109 |         self.emb = Embedding(self.num, self.dim)
110 |         self.att = Attention(self.dim, self.dim)
111 |         self.projection = nn.Linear(self.dim, self.dim)
112 |         self.classifier = nn.Linear(self.dim, 2)
113 | 
self.loss = nn.CrossEntropyLoss() 114 | 115 | def forward(self, hist_seq, cand, prior_score, label): 116 | hist_seq_emb = self.emb(hist_seq) 117 | hist_seq_mask = torch.where(hist_seq == 0, torch.zeros_like(hist_seq), torch.ones_like(hist_seq)) 118 | cand_emb = self.emb(cand) 119 | 120 | agg_emb = self.att(hist_seq_emb, hist_seq_mask, cand_emb) 121 | logits_res = self.classifier(self.projection(agg_emb)) 122 | 123 | # thresholding the interval so that it does not overflow 124 | score = prior_score.unsqueeze(1) 125 | score_0 = 1.0 - score 126 | score_1 = score 127 | score_0 = score_0 * (1.0 - 1e-3) + 1e-4 128 | score_1 = score_1 * (1.0 - 1e-3) + 1e-4 129 | logits_main = torch.cat([-torch.log(1.0/score_0 - 1.0), -torch.log(1.0/score_1 - 1.0)], dim=1) 130 | 131 | logits = logits_res + logits_main 132 | 133 | if self.task_type != "train": 134 | pred = torch.softmax(logits, dim=1)[:,1] 135 | return pred 136 | else: 137 | loss = torch.mean(self.loss(logits, label)) 138 | return loss 139 | 140 | class Controller(nn.Module): 141 | def __init__(self, args): 142 | """ 143 | build the controller model 144 | input schema: 145 | args: parameters for the model initialization 146 | hist_seq: the user historical sequence 147 | label: which mechanism (e.g., the cloud-based session recommendation, the cloud-based refresh or the on-device recommendation) to invoke for recommendation 148 | output: 149 | return the loss in the "train" mode or prediction in the other mode 150 | """ 151 | super(Controller, self).__init__() 152 | self.task_type = args.task_type 153 | self.dim = args.dim 154 | self.num = args.num 155 | self.emb = Embedding(self.num, self.dim) 156 | self.att = Attention(self.dim, self.dim) 157 | self.projection = nn.Linear(self.dim, self.dim) 158 | self.classifier = nn.Linear(self.dim, 3) 159 | self.loss = nn.CrossEntropyLoss() 160 | 161 | def forward(self, hist_seq, label): 162 | hist_seq_emb = self.emb(hist_seq) 163 | hist_seq_mask = torch.where(hist_seq == 0, torch.zeros_like(hist_seq), torch.ones_like(hist_seq)) 164 | # we built a pseudo averaging cand embedding as the query in target-aware attention 165 | cand_emb = torch.sum(hist_seq_emb * hist_seq_mask.unsqueeze(2), dim=1)/torch.sum(hist_seq_mask, dim=1, keepdim=True) 166 | 167 | agg_emb = self.att(hist_seq_emb, hist_seq_mask, cand_emb) 168 | logits = self.classifier(self.projection(agg_emb)) 169 | 170 | if self.task_type != "train": 171 | pred = torch.softmax(logits, dim=1) 172 | return pred 173 | else: 174 | loss = torch.mean(self.loss(logits, label)) 175 | return loss 176 | -------------------------------------------------------------------------------- /backend/model_hub/ggcn_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2022 The Luoxi Team. 3 | # All rights reserved. 4 | # This source code is licensed under the Apache 2.0 license 5 | # found in the LICENSE file in the root directory. 
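# A hedged reading of GraphConvolution below: it follows the GCNII-style update
#
#     H^(l+1) = sigma( ((1 - alpha) * A_hat @ H^(l) + alpha * H^(0))
#                      @ ((1 - theta_l) * I + theta_l * W^(l)) ),
#     theta_l = log(lamda / l + 1)    (written as `theta` in the code),
#
# i.e. an initial-residual connection to H^(0) plus an identity mapping on the
# layer weight; with variant=True the support becomes the concatenation
# [A_hat @ H^(l), H^(0)] while the residual term r is kept separate.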
6 | 7 | import torch.nn as nn 8 | import torch 9 | import math 10 | import numpy as np 11 | import torch.nn.functional as F 12 | from torch.nn.parameter import Parameter 13 | from sklearn.metrics import f1_score 14 | 15 | 16 | class GraphConvolution(nn.Module): 17 | def __init__(self, in_features, out_features, residual=False, variant=False): 18 | """ 19 | :param in_features: input feature dimension 20 | :param out_features: output feature dimension 21 | :return: 22 | """ 23 | super(GraphConvolution, self).__init__() 24 | self.variant = variant 25 | if self.variant: 26 | self.in_features = 2*in_features 27 | else: 28 | self.in_features = in_features 29 | self.out_features = out_features 30 | self.residual = residual 31 | self.weight = Parameter(torch.FloatTensor(self.in_features,self.out_features)) 32 | self.reset_parameters() 33 | 34 | def reset_parameters(self): 35 | stdv = 1. / math.sqrt(self.out_features) 36 | self.weight.data.uniform_(-stdv, stdv) 37 | 38 | def forward(self, input, adj , h0 , lamda, alpha, l): 39 | theta = math.log(lamda/l+1) 40 | hi = torch.spmm(adj, input) 41 | if self.variant: 42 | support = torch.cat([hi,h0],1) 43 | r = (1-alpha)*hi+alpha*h0 44 | else: 45 | support = (1-alpha)*hi+alpha*h0 46 | r = support 47 | output = theta*torch.mm(support, self.weight)+(1-theta)*r 48 | if self.residual: 49 | output = output+input 50 | return output 51 | 52 | class GGCN(nn.Module): 53 | def __init__(self, nfeat, nlayers,nhidden, nclass, dropout, lamda, alpha,variant,args,onnx_step): 54 | super(GGCN, self).__init__() 55 | """ 56 | :param in_features: input feature dimension 57 | :param out_features: output feature dimension 58 | :return: 59 | """ 60 | self.convs = nn.ModuleList() 61 | for _ in range(nlayers): 62 | self.convs.append(GraphConvolution(nhidden, nhidden,variant=variant,residual=True)) 63 | self.fcs = nn.ModuleList() 64 | self.fcs.append(nn.Linear(nfeat, nhidden)) 65 | self.fcs.append(nn.Linear(nhidden, nclass)) 66 | self.act_fn = nn.ReLU() 67 | self.sig = nn.Sigmoid() 68 | self.dropout = dropout 69 | self.alpha = alpha 70 | self.lamda = lamda 71 | self.onnx_step = onnx_step 72 | input_tables = args.tables.split(",") 73 | [self.train_adj_list, self.val_adj_list, self.test_adj_list, self.train_feat, self.val_feat, self.test_feat, 74 | self.train_labels, self.val_labels, self.test_labels, self.train_nodes, self.val_nodes, 75 | self.test_nodes] = torch.load('%s/traingraphL.pth' % input_tables[0]) 76 | 77 | def forward(self, x, adj): 78 | if self.onnx_step: 79 | self.onnx_step = False 80 | return self.onnxModel() 81 | _layers = [] 82 | 83 | x = F.dropout(x, self.dropout, training=self.training) 84 | layer_inner = self.act_fn(self.fcs[0](x)) 85 | _layers.append(layer_inner) 86 | for i,con in enumerate(self.convs): 87 | layer_inner = F.dropout(layer_inner, self.dropout, training=self.training) 88 | layer_inner = self.act_fn(con(layer_inner,adj,_layers[0],self.lamda,self.alpha,i+1)) 89 | layer_inner = F.dropout(layer_inner, self.dropout, training=self.training) 90 | layer_inner = self.sig(self.fcs[-1](layer_inner)) 91 | return layer_inner 92 | 93 | def evalModel(self): 94 | """ 95 | evaluate model with F1_score 96 | input schema: 97 | args: validation nodes, labels, adj, feature 98 | output schema: 99 | loss: BCELoss of model inference 100 | score: F1_score of model inference 101 | """ 102 | loss = 0 103 | score = 0 104 | device = torch.cuda.current_device() 105 | for i in range(2): 106 | nodes = self.val_nodes[i] 107 | labels = self.val_labels[i] 108 | labels = 
labels.to(device) 109 | feat = self.val_feat[i] 110 | adj = self.val_adj_list[i] 111 | feat = feat.to(device) 112 | adj = adj.to(device) 113 | output = self.forward(feat,adj) 114 | lossfn = torch.nn.BCELoss() 115 | loss += lossfn(output[:nodes], labels[:nodes]) 116 | predict = np.where(output[:nodes].data.cpu().numpy() > 0.5, 1, 0) 117 | score += f1_score(labels[:nodes].data.cpu().numpy(), predict, average='micro') 118 | return loss/2, score/2 119 | 120 | def onnxModel(self): 121 | """ 122 | evaluate onnx model with F1_score 123 | input schema: 124 | args: onnx nodes, labels, adj, feature 125 | output schema: 126 | loss: BCELoss of model inference 127 | score: F1_score of model inference 128 | """ 129 | loss = 0 130 | score = 0 131 | device = torch.cuda.current_device() 132 | for i in range(2): 133 | nodes = self.val_nodes[i] 134 | labels = self.val_labels[i] 135 | labels = labels.to(device) 136 | feat = self.val_feat[i] 137 | adj = self.val_adj_list[i] 138 | feat = feat.to(device) 139 | adj = adj.to(device) 140 | output = self.forward(feat,adj) 141 | lossfn = torch.nn.BCELoss() 142 | loss += lossfn(output[:nodes], labels[:nodes]) 143 | predict = np.where(output[:nodes].data.cpu().numpy() > 0.5, 1, 0) 144 | score += f1_score(labels[:nodes].data.cpu().numpy(), predict, average='micro') 145 | return loss/2, score/2 146 | 147 | def inferModel(self): 148 | """ 149 | evaluate model with test data in F1_score 150 | input schema: 151 | args: test nodes, labels, adj, feature 152 | output schema: 153 | loss: BCELoss of model inference 154 | score: F1_score of model inference 155 | """ 156 | loss = 0 157 | score = 0 158 | device = torch.cuda.current_device() 159 | for i in range(2): 160 | nodes = self.test_nodes[i] 161 | labels = self.test_labels[i] 162 | labels = labels.to(device) 163 | feat = self.test_feat[i] 164 | adj = self.test_adj_list[i] 165 | feat = feat.to(device) 166 | adj = adj.to(device) 167 | output = self.forward(feat,adj) 168 | lossfn = torch.nn.BCELoss() 169 | loss += lossfn(output[:nodes], labels[:nodes]) 170 | predict = np.where(output[:nodes].data.cpu().numpy() > 0.5, 1, 0) 171 | score += f1_score(labels[:nodes].data.cpu().numpy(), predict, average='micro') 172 | return loss/2, score/2 -------------------------------------------------------------------------------- /backend/model_hub/pinrec_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2022 The Luoxi Team. 3 | # All rights reserved. 4 | # This source code is licensed under the Apache 2.0 license 5 | # found in the LICENSE file in the root directory. 
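# A minimal usage sketch of the registry defined in this file (the config
# values, group_num, and input tensors are placeholders, not taken from a
# real configuration):
#
#     meta = get_model_meta("pinrec")
#     conf = meta.arch_config_parser({"id_dimension": 8, "id_vocab": 500,
#                                     "classifier": [64, 32], "add_plugin": True})
#     net = meta.model_builder(conf, group_num=4)
#     scores = net({"target_id": target_ids,         # LongTensor, shape B
#                   "clk_sequence": clk_sequences})  # LongTensor, shape B x L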
6 | 
7 | from collections import defaultdict
8 | import torch.nn
9 | import torch.nn as nn
10 | from module import attention, encoder, common, plugin
11 | import logging
12 | 
13 | logger = logging.getLogger(__name__)
14 | 
15 | __all__ = ["ModelMeta", "get_model_meta", "model"]
16 | 
17 | consts = {
18 |     "FIELD_USER_ID": "user_id",
19 |     "FIELD_TARGET_ID": "target_id",
20 |     "FIELD_CLK_SEQUENCE": "clk_sequence",
21 |     "FIELD_LABEL": "label",
22 |     "FIELD_GROUP_ID": "group_id"
23 | }
24 | 
25 | 
26 | class ModelMeta(object):
27 |     def __init__(self, config_parser=None, data_loader=None, model_builder=None):
28 |         """
29 |         Build the parent class that holds a model's config parser, data loader builder and model builder
30 |         """
31 |         self._config_parser = config_parser
32 |         self._data_loader = data_loader
33 |         self._model_builder = model_builder
34 | 
35 |     @property
36 |     def arch_config_parser(self):
37 |         return self._config_parser
38 | 
39 |     def set_arch_config_parser(self, parser):
40 |         self._check(self._config_parser, "Config parser has been set")
41 |         self._config_parser = parser
42 | 
43 |     @property
44 |     def data_loader_builder(self):
45 |         return self._data_loader
46 | 
47 |     def set_data_loader_builder(self, loader):
48 |         self._check(self._data_loader, "Data loader builder has been set")
49 |         self._data_loader = loader
50 | 
51 |     @property
52 |     def model_builder(self):
53 |         return self._model_builder
54 | 
55 |     def set_model_builder(self, model_builder):
56 |         self._check(self._model_builder, "Model builder has been set")
57 |         self._model_builder = model_builder
58 | 
59 |     def _check(self, value, message):
60 |         if value is not None:
61 |             raise ValueError(message)
62 | 
63 |     def __setitem__(self, k, v):
64 |         setattr(self, k, v)  # set the attribute named by k, not a literal attribute 'k'
65 | 
66 | 
67 | # Each model consists of two parts: model and config
68 | class MetaType(object):
69 |     """
70 |     Build model type
71 |     Each model consists of two parts: ConfigParser and ModelBuilder
72 |     """
73 |     ConfigParser = ModelMeta.set_arch_config_parser
74 |     ModelBuilder = ModelMeta.set_model_builder
75 | 
76 | 
77 | class _ModelMetaRegister(object):
78 |     def __init__(self):
79 |         """
80 |         Register different models, to facilitate further extension in the future
81 |         input schema (of __call__):
82 |             name: model name
83 |             setter: which part of the model is defined
84 |         output schema:
85 |             return a decorator that registers the decorated object under the given name
86 |         """
87 |         self._register_map = defaultdict(ModelMeta)
88 | 
89 |     def get(self, name):
90 |         return self._register_map.get(name)
91 | 
92 |     def __call__(self, name, setter):
93 |         model_meta = self._register_map[name]
94 | 
95 |         def _executor(func):
96 |             setter(model_meta, func)
97 |             return func
98 | 
99 |         return _executor
100 | 
101 | 
102 | model = _ModelMetaRegister()
103 | get_model_meta = model.get
104 | 
105 | 
106 | @model("pinrec", MetaType.ModelBuilder)
107 | class DeepInterestNetwork(nn.Module):
108 |     def __init__(self, model_conf, group_num):
109 |         """
110 |         Main algorithm, based on DIN (Deep Interest Network)
111 |         input schema:
112 |             model_conf: the parsed model configuration (a ModelConfig instance)
113 |             group_num: number of user groups
114 |         output schema:
115 |             the forward pass returns a click score
116 |         """
117 |         super(DeepInterestNetwork, self).__init__()
118 | 
119 |         assert isinstance(model_conf, ModelConfig)
120 |         self._plugin_index = 0
121 |         self._group_num = group_num
122 |         self._id_encoder = encoder.IDEncoder(
123 |             model_conf.id_vocab,
124 |             model_conf.id_dimension,
125 |         )
126 |         self._target_emb_plugin = nn.ModuleList(
127 |             [plugin.Plugin(model_conf.id_dimension, torch.nn.Tanh) for i in range(self._group_num)]
128 |         )
129 |         self._target_trans = common.StackedDense(
130 |             model_conf.id_dimension, [model_conf.id_dimension], [torch.nn.Tanh]
131 |         )
132 |         self._seq_trans = common.StackedDense(
133 |             model_conf.id_dimension, [model_conf.id_dimension], [torch.nn.Tanh]
134 |         )
135 |         self._target_attention = attention.TargetAttention(
136 |             key_dimension=model_conf.id_dimension,
137 |             value_dimension=model_conf.id_dimension,
138 |         )
139 |         self._atten_aggregated_embed_plugin = nn.ModuleList(
140 |             [plugin.Plugin(model_conf.id_dimension, torch.nn.Tanh) for i in range(self._group_num)]
141 |         )
142 |         self._classifier = common.StackedDense(
143 |             model_conf.id_dimension * 2,
144 |             model_conf.classifier + [1],
145 |             ([torch.nn.Tanh] * len(model_conf.classifier)) + [None]
146 |         )
147 | 
148 |     def __setitem__(self, k, v):
149 |         setattr(self, k, v)  # set the attribute named by k, not a literal attribute 'k'
150 | 
151 |     def set_plugin_index(self, plugin_index):
152 |         self._plugin_index = plugin_index
153 | 
154 |     def forward(self, features, plugin=True):
155 |         # Encode target item
156 |         # B * D
157 |         target_embed = self._id_encoder(features[consts["FIELD_TARGET_ID"]])
158 |         if plugin:
159 |             target_embed = self._target_emb_plugin[self._plugin_index](target_embed)
160 |         target_embed = self._target_trans(target_embed)
161 | 
162 |         # Encode user historical behaviors
163 |         with torch.no_grad():
164 |             mask = torch.not_equal(features[consts["FIELD_CLK_SEQUENCE"]], 0).to(dtype=torch.float32)
165 |         # B * L * D
166 |         hist_embed = self._id_encoder(features[consts["FIELD_CLK_SEQUENCE"]])
167 |         if plugin:
168 |             hist_embed = self._target_emb_plugin[self._plugin_index](hist_embed)
169 |         hist_embed = self._seq_trans(hist_embed)
170 | 
171 |         # Target attention
172 |         atten_aggregated_embed = self._target_attention(
173 |             target_key=target_embed,
174 |             item_keys=hist_embed,
175 |             item_values=hist_embed,
176 |             mask=mask
177 |         )
178 |         if plugin:
179 |             atten_aggregated_embed = self._atten_aggregated_embed_plugin[self._plugin_index](atten_aggregated_embed)
180 |         classifier_input = torch.cat([target_embed, atten_aggregated_embed], dim=1)
181 |         return self._classifier(classifier_input)
182 | 
183 | 
184 | class ModelConfig(object):
185 |     def __init__(self):
186 |         """
187 |         Default configuration for the DIN-based model
188 |         input schema:
189 |             (none; parse() below fills the fields from a json object read from the configuration file)
190 |         output schema:
191 |             return an instance of the ModelConfig class
192 |         """
193 |         self.id_dimension = 8
194 |         self.id_vocab = 500
195 |         self.classifier = [64, 32]
196 |         self.add_plugin = False
197 | 
198 |     @staticmethod
199 |     @model("pinrec", MetaType.ConfigParser)
200 |     def parse(json_obj):
201 |         conf = ModelConfig()
202 |         conf.id_dimension = json_obj.get("id_dimension")
203 |         conf.id_vocab = json_obj.get("id_vocab")
204 |         conf.classifier = json_obj.get("classifier")
205 |         conf.add_plugin = json_obj.get("add_plugin")
206 | 
207 |         return conf
208 | 
209 | 
--------------------------------------------------------------------------------
/backend/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2022 The Luoxi Team.
3 | # All rights reserved.
4 | # This source code is licensed under the Apache 2.0 license
5 | # found in the LICENSE file in the root directory.
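# parse_arch_config_from_args() at the bottom of this file loads the JSON model
# configuration passed via --arch_config or --arch_config_path. For the pinrec
# parser above it would look roughly like this (illustrative values; see
# scripts/pinrec/movielens_conf.json for the real file):
#
#     {"id_dimension": 8, "id_vocab": 500, "classifier": [64, 32], "add_plugin": true}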
6 | 7 | import torch 8 | import os 9 | from torch.nn.parallel.distributed import DistributedDataParallel 10 | import random 11 | import numpy as np 12 | import time 13 | import json 14 | 15 | class Timer: 16 | def __init__(self): 17 | self.interval = 0.0 18 | self.is_start = False 19 | self.start_time = None 20 | 21 | def start(self): 22 | assert not self.is_start, 'timer has been started' 23 | torch.cuda.synchronize() 24 | self.start_time = time.time() 25 | self.is_start = True 26 | 27 | def stop(self): 28 | assert self.is_start, 'timer is not started' 29 | torch.cuda.synchronize() 30 | self.interval += (time.time() - self.start_time) 31 | 32 | return self.interval 33 | 34 | def reset(self): 35 | self.interval = 0.0 36 | self.is_start = False 37 | 38 | def print_args(args): 39 | print('arguments:', flush=True) 40 | for arg in vars(args): 41 | dots = '-' * (20 - len(arg)) 42 | print(' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True) 43 | 44 | def print_rank_0(message): 45 | if torch.distributed.is_initialized(): 46 | if torch.distributed.get_rank() == 0: 47 | print(message, flush=True) 48 | else: 49 | print(message, flush=True) 50 | 51 | def initialize_distribution_env(args): 52 | assert torch.cuda.is_available(), 'requires CUDA.' 53 | 54 | device = args.rank % torch.cuda.device_count() 55 | if args.local_rank is not None: 56 | device = args.local_rank 57 | print("device id: {}".format(device)) 58 | 59 | torch.cuda.set_device(device) 60 | 61 | init_method = 'tcp://' 62 | master_ip = os.getenv('MASTER_ADDR', 'localhost') 63 | master_port = os.getenv('MASTER_PORT', '6000') 64 | init_method += master_ip + ':' + master_port 65 | torch.distributed.init_process_group( 66 | backend=args.distributed_backend, 67 | world_size=args.world_size, rank=args.rank, 68 | init_method=init_method) 69 | print('args.world_size =', args.world_size, ', args.rank =', args.rank, ', args.local_rank =', args.local_rank) 70 | assert args.rank == torch.distributed.get_rank() 71 | 72 | def set_random_seed(seed=None): 73 | if seed is None: 74 | seed = int(time.time()) 75 | print('set_random_seed', seed) 76 | seed = int(seed + torch.distributed.get_rank()) 77 | random.seed(seed) 78 | np.random.seed(seed) 79 | torch.manual_seed(seed) 80 | torch.cuda.manual_seed(seed) 81 | 82 | def get_checkpoint_tracker_filename(checkpoints_path): 83 | return os.path.join(checkpoints_path, 'latest_iteration.txt') 84 | 85 | def get_saved_iteration(args): 86 | 87 | tracker_filename = get_checkpoint_tracker_filename(args.load) 88 | if not os.path.isfile(tracker_filename): 89 | print_rank_0('could not find {} and will start from random'.format(tracker_filename)) 90 | return 0 91 | iteration = 0 92 | with open(tracker_filename, 'r') as f: 93 | metastring = f.read().strip() 94 | try: 95 | iteration = int(metastring) 96 | except ValueError: 97 | print_rank_0('the first row of {} must be an integer'.format(tracker_filename)) 98 | exit() 99 | 100 | return iteration 101 | 102 | def get_checkpoint_name(checkpoints_path, iteration): 103 | d = '{:d}'.format(iteration) 104 | 105 | return os.path.join(checkpoints_path, d, 'rank_{:02d}_model_states.pt'.format(0)) 106 | 107 | def load_model_state_only(model, args, remove_prefix=None, remap_prefix=None, force_remap=False, load_checkpoint_name=None): 108 | if load_checkpoint_name is None: 109 | iteration = get_saved_iteration(args) 110 | checkpoint_name = get_checkpoint_name(args.load, iteration) 111 | else: 112 | iteration = 0 113 | checkpoint_name = load_checkpoint_name 114 | # Load the 
checkpoint. 115 | sd = torch.load(checkpoint_name, map_location='cpu') 116 | 117 | if isinstance(model, DistributedDataParallel): 118 | model = model.module 119 | model_state = sd['module'] if 'module' in sd else sd 120 | 121 | if remove_prefix: 122 | for load_prefix in remove_prefix: 123 | keys = list(model_state.keys()) 124 | for k in keys: 125 | if k.startswith(load_prefix): 126 | print('Skip loading %s in the checkpoint.' % k) 127 | del model_state[k] 128 | 129 | if remap_prefix: 130 | for var_prefix, load_prefix in remap_prefix.items(): 131 | keys = list(model_state.keys()) 132 | for k in keys: 133 | if k.startswith(load_prefix): 134 | new_k = k.replace(load_prefix, var_prefix) 135 | if new_k in model_state: 136 | print('WARN: param %s already in the checkpoint.' % new_k) 137 | if (new_k not in model_state) or force_remap: 138 | print('Load param %s from %s in the checkpoint.' % (new_k, k)) 139 | model_state[new_k] = model_state[k] 140 | 141 | try: 142 | model.load_state_dict(model_state, strict=True) 143 | except RuntimeError as e: 144 | print(e) 145 | print('> strict load failed, try non-strict load instead') 146 | keys = model.load_state_dict(model_state, strict=False) 147 | print('> non-strict load done') 148 | print(keys) 149 | return iteration 150 | 151 | def ensure_directory_exists(filename): 152 | dirname = os.path.dirname(filename) 153 | if not os.path.exists(dirname): 154 | os.makedirs(dirname) 155 | 156 | def save_checkpoint(iteration, model, optimizer, args): 157 | if isinstance(model, DistributedDataParallel): 158 | model = model.module 159 | 160 | if torch.distributed.get_rank() == 0: 161 | checkpoint_name = get_checkpoint_name(args.save, iteration) 162 | print('rank {} is saving checkpoint at iteration {:7d} to {}'. \ 163 | format(torch.distributed.get_rank(), iteration, checkpoint_name)) 164 | 165 | sd = {} 166 | sd['iteration'] = iteration 167 | sd['module'] = model.state_dict() 168 | 169 | # Optimizer stuff. 
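# (The optimizer state is stored alongside the model weights so training can
# resume exactly. Assuming args.save=<save>, the resulting layout is
# <save>/<iteration>/rank_00_model_states.pt for the state dict, plus
# <save>/latest_iteration.txt for the tracker written below.)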
170 | if optimizer is not None: 171 | sd['optimizer'] = optimizer.state_dict() 172 | 173 | ensure_directory_exists(checkpoint_name) 174 | torch.save(sd, checkpoint_name) 175 | print('Successfully saved {}'.format(checkpoint_name)) 176 | 177 | 178 | torch.distributed.barrier() 179 | 180 | if torch.distributed.get_rank() == 0: 181 | tracker_filename = get_checkpoint_tracker_filename(args.save) 182 | with open(tracker_filename, 'w') as f: 183 | f.write(str(iteration)) 184 | 185 | torch.distributed.barrier() 186 | 187 | def make_local_writer(args): 188 | file = open(args.outputs, 'w') 189 | 190 | return file 191 | 192 | def to_numpy(tensor,data_type=None): 193 | tensor = tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() 194 | if data_type: 195 | if data_type=="int": 196 | data_type=np.int64 197 | else: 198 | data_type = np.float32 199 | tensor=tensor.astype(data_type) 200 | return tensor 201 | 202 | def parse_arch_config_from_args(model_meta, args): 203 | """ 204 | Read or parse arch config 205 | :param model_meta: 206 | :param args: 207 | :return: 208 | """ 209 | if args.arch_config is not None: 210 | with open(args.arch_config) as jsonfile: 211 | raw_arch_config = json.load(jsonfile) 212 | elif args.arch_config_path is not None: 213 | with open(args.arch_config_path, "rt") as reader: 214 | raw_arch_config = json.load(reader) 215 | else: 216 | raise KeyError("Model configuration not found") 217 | 218 | return model_meta.arch_config_parser(raw_arch_config), raw_arch_config -------------------------------------------------------------------------------- /train_hub/mcrec.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | from backend.dataset_hub.mcrec_dataset import CRecDataset, ORecDataset, MCRecDataset 6 | from backend.model_hub.mcrec_model import CRec, ORec, Controller 7 | import torch 8 | from backend.task_backbone import task_dispatcher 9 | from onnx_test.mcrec_onnx_model_test import mock_orec_data 10 | 11 | def model_provider(args): 12 | ''' 13 | Build the model func. 
14 |     :param args: user defined arguments dictionary
15 |     :return:
16 |         model: user defined model that implements the torch.nn.Module interface
17 |     '''
18 | 
19 |     if args.model_type == 'crec':
20 |         model = CRec(args)
21 |     elif args.model_type == 'orec':
22 |         args.dim = int(args.dim / 4) # make the on-device model smaller
23 |         model = ORec(args)
24 |     elif args.model_type == 'mcrec':
25 |         model = Controller(args)
26 |     else:
27 |         model = None
28 | 
29 |     return model
30 | 
31 | def get_batch(data_iterator, args):
32 | 
33 |     data = next(data_iterator)
34 | 
35 |     cuda_device = torch.cuda.current_device()
36 |     data = dict((k, v.to(cuda_device)) for k, v in data.items())
37 | 
38 |     if args.model_type == 'crec':
39 |         hist_seq = data['hist_seq'].long()
40 |         cand = data['cand'].long()
41 |         label = data['label'].long()
42 | 
43 |         return (hist_seq, cand, label)
44 | 
45 |     elif args.model_type == 'orec':
46 |         hist_seq = data['hist_seq'].long()
47 |         cand = data['cand'].long()
48 |         prior_score = data['prior_score'].float()  # prior click probability in [0, 1]; casting to long would truncate it to 0
49 |         label = data['label'].long()
50 | 
51 |         return (hist_seq, cand, prior_score, label)
52 | 
53 |     elif args.model_type == 'mcrec':
54 |         hist_seq = data['hist_seq'].long()
55 |         label = data['label'].long()
56 | 
57 |         return (hist_seq, label)
58 | 
59 |     else:
60 |         return None
61 | 
62 | def get_inference_batch(data_iterator, args):
63 | 
64 |     data = next(data_iterator)
65 | 
66 |     cuda_device = torch.cuda.current_device()
67 |     data = dict((k, v.to(cuda_device)) for k, v in data.items())
68 | 
69 |     if args.model_type == 'crec':
70 |         hist_seq = data['hist_seq'].long()
71 |         cand = data['cand'].long()
72 |         label = data['label'].long()
73 | 
74 |         return (hist_seq, cand, label)
75 | 
76 |     elif args.model_type == 'orec':
77 |         hist_seq = data['hist_seq'].long()
78 |         cand = data['cand'].long()
79 |         prior_score = data['prior_score'].float()  # prior click probability in [0, 1]; casting to long would truncate it to 0
80 |         label = data['label'].long()
81 | 
82 |         return (hist_seq, cand, prior_score, label)
83 | 
84 |     elif args.model_type == 'mcrec':
85 |         hist_seq = data['hist_seq'].long()
86 |         label = data['label'].long()
87 | 
88 |         return (hist_seq, label)
89 | 
90 |     else:
91 |         return None
92 | 
93 | def forward_func(data_iterator, model, args):
94 | 
95 |     if args.task_type == 'train':
96 |         if args.model_type == 'crec':
97 |             hist_seq, cand, label = get_batch(data_iterator, args)
98 | 
99 |             loss = model(hist_seq, cand, label)
100 |         elif args.model_type == 'orec':
101 |             hist_seq, cand, prior_score, label = get_batch(data_iterator, args)
102 | 
103 |             loss = model(hist_seq, cand, prior_score, label)
104 |         elif args.model_type == 'mcrec':
105 |             hist_seq, label = get_batch(data_iterator, args)
106 | 
107 |             loss = model(hist_seq, label)
108 |         else:
109 |             loss = None
110 | 
111 |         return loss, []
112 |     else:
113 |         if args.model_type == 'crec':
114 |             hist_seq, cand, label = get_inference_batch(data_iterator, args)
115 | 
116 |             pred = model(hist_seq, cand, label)
117 |         elif args.model_type == 'orec':
118 |             hist_seq, cand, prior_score, label = get_inference_batch(data_iterator, args)
119 | 
120 |             pred = model(hist_seq, cand, prior_score, label)
121 |         elif args.model_type == 'mcrec':
122 |             hist_seq, label = get_inference_batch(data_iterator, args)
123 | 
124 |             pred = model(hist_seq, label)
125 |         else:
126 |             pred, label = None, None
127 | 
128 |         return pred, label
129 | 
130 | def train_eval_datasets_provider(args):
131 |     '''
132 |     Build train, valid, and test datasets for the training job.
133 | :param args: user defined arguments dictionary 134 | :return: 135 | train_dataset, valid_dataset : dataset that implements the torch.utils.data.Dataset interface 136 | ''' 137 | 138 | # Build the dataset. 139 | input_files = args.tables.split(",") 140 | 141 | if args.model_type == 'crec': 142 | dataset, eval_dataset = CRecDataset, CRecDataset 143 | elif args.model_type == 'orec': 144 | dataset, eval_dataset = ORecDataset, ORecDataset 145 | elif args.model_type == 'mcrec': 146 | dataset, eval_dataset = MCRecDataset, MCRecDataset 147 | else: 148 | dataset, eval_dataset = None, None 149 | 150 | dataset = dataset(args, input_files[0]) 151 | 152 | if len(input_files) > 1: 153 | eval_dataset = eval_dataset(args, input_files[1], is_test=True) 154 | 155 | else: 156 | eval_dataset = None 157 | 158 | return dataset, eval_dataset 159 | 160 | def inference_dataset_provider(args): 161 | ''' 162 | Build train, valid, and test datasets for inference job. 163 | :param args: user defined arguments dictionary 164 | :return: 165 | train_dataset, valid_dataset : dataset that implements the torch.utils.data.Dataset interface 166 | ''' 167 | input_files = args.tables.split(",") 168 | if args.model_type == 'crec': 169 | dataset, eval_dataset = CRecDataset, CRecDataset 170 | elif args.model_type == 'orec': 171 | dataset, eval_dataset = ORecDataset, ORecDataset 172 | elif args.model_type == 'mcrec': 173 | dataset, eval_dataset = MCRecDataset, MCRecDataset 174 | else: 175 | dataset, eval_dataset = None, None 176 | 177 | dataset = eval_dataset(args, input_files[1], is_test=True) 178 | 179 | return dataset 180 | 181 | def onnx_model_export(args): 182 | ''' 183 | :param model: the trained model 184 | :param args: user defined arguments dictionary 185 | :return: None 186 | ''' 187 | print("=====start onnx export======") 188 | import torch.onnx as onnx 189 | from backend.utils import load_model_state_only 190 | def mock_data_provider(args): 191 | if args.model_type == 'orec': 192 | data = mock_orec_data() 193 | else: 194 | data = None 195 | return data 196 | 197 | data = mock_data_provider(args) 198 | cuda_device = torch.cuda.current_device() 199 | data = dict((k, v.to(cuda_device)) for k, v in data.items()) 200 | model = model_provider(args) 201 | print(' >export onnx model number of parameters on rank{}'.format(sum([p.nelement() for p in model.parameters()])), flush=True) 202 | model.cuda(torch.cuda.current_device()) 203 | 204 | load_model_state_only(model, args, remove_prefix=None, remap_prefix=None) 205 | model.eval() 206 | model_path = args.onnx_export_path 207 | 208 | onnx.export(model, (data["hist_seq"], data["cand"], data["prior_score"], data["label"]), model_path, export_params=True, verbose=False, opset_version=12) 209 | print("success save to:", model_path) 210 | 211 | def personalized_args_provider(parser): 212 | ''' 213 | User-defined parameters function 214 | :param parser: parser,the object of argparse.ArgumentParser 215 | :return: a python method where user defines parameters in it 216 | ''' 217 | def add_model_config_args(parser): 218 | """Model arguments""" 219 | group = parser.add_argument_group('model', 'model configuration') 220 | 221 | group.add_argument('--dim', type=int, default=64, help='embedding dim') 222 | group.add_argument('--num', type=int, default=50000, 223 | help='vocab size for items') 224 | group.add_argument('--cpu-optimizer', action='store_true', 225 | help='Run optimizer on CPU') 226 | group.add_argument('--model_type', type=str, help='model type') 227 | 228 | return parser 
229 | 230 | return add_model_config_args(parser) 231 | 232 | if __name__ == "__main__": 233 | task_dispatcher(train_eval_dataset_provider=train_eval_datasets_provider, 234 | inference_dataset_provider=inference_dataset_provider, 235 | model_provider=model_provider, 236 | forward_func=forward_func, 237 | personalized_args_provider=personalized_args_provider, 238 | onnx_model_export_func=onnx_model_export) 239 | -------------------------------------------------------------------------------- /train_hub/cigar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | from backend.dataset_hub.cigar_datasets import CIGARDatasetLocal 6 | from backend.model_hub.cigar_model import * 7 | from backend.task_backbone import task_dispatcher 8 | import warnings 9 | from onnx_test.cigar_onnx_model_test import mock_gnn_data, mock_no_gnn_data 10 | 11 | warnings.filterwarnings('ignore') 12 | 13 | def model_provider(args): 14 | """ 15 | Build the model func. 16 | input schema: 17 | args: user defined arguments dictionary 18 | output schema: 19 | model: user defined model that implements the torch.nn.Module interface 20 | """ 21 | if args.model == 'CIGAR': 22 | model = CIGAR(args=args, device = torch.cuda.current_device()) 23 | elif args.model == 'CIGAR_WO_CDGNN': 24 | model = CIGAR_WO_CDGNN(args=args, device = torch.cuda.current_device()) 25 | elif args.model == 'CIGAR_WO_PN': 26 | model = CIGAR_WO_PN(args=args, device = torch.cuda.current_device()) 27 | elif args.model == 'PNN': 28 | model = PNN(args=args, device = torch.cuda.current_device()) 29 | 30 | return model 31 | 32 | def get_batch(data_iterator, args): 33 | """ 34 | Generate a batch func. 35 | input schema: 36 | data_iterator: data iterator that implements the torch.utils.data.DataLoader interface 37 | args: user defined arguments dictionary 38 | output schema: 39 | dictionary (python dict()): a dictionary that contains all data used in the model forward step 40 | """ 41 | data = next(data_iterator) 42 | 43 | cuda_device = torch.cuda.current_device() 44 | data = dict((k, v.to(cuda_device)) for k, v in data.items()) 45 | 46 | ##################################################################### 47 | return data 48 | ##################################################################### 49 | 50 | def get_inference_batch(data_iterator, args): 51 | data = next(data_iterator) 52 | 53 | cuda_device = torch.cuda.current_device() 54 | data = dict((k, v.to(cuda_device)) for k, v in data.items()) 55 | 56 | return data 57 | 58 | def forward_func(data_iterator, model, args): 59 | """ 60 | Forward step. 
61 | input schema: 62 | data_iterator: data iterator that implements the torch.utils.data.DataLoader interface 63 | model: a model that implements the torch.nn.Module interface and defined in the model_provider func 64 | args: user defined arguments dictionary 65 | output schema: 66 | loss: a one-dimensional loss vector that contains every sample's loss 67 | """ 68 | if args.task_type == 'train': 69 | data = get_batch(data_iterator, args) 70 | loss, *stats = model(data, data["label"]) 71 | return loss, stats 72 | 73 | # elif args.task_type == 'eval': 74 | # data = get_batch(data_iterator, args) 75 | # loss, label, score = model(data, args) 76 | # return loss, label, score 77 | 78 | else: 79 | data = get_batch(data_iterator, args) 80 | infer_res = model(data, data["label"]) 81 | return infer_res 82 | # (unique_id, input_ids, position_ids, token_type_ids, attention_mask, eos_indices, 83 | # gen_label_ids, gen_label_masks, cls_label_ids, cls_label_masks) = get_inference_batch(data_iterator, args) 84 | 85 | # infer_res = model( 86 | # input_ids, position_ids, token_type_ids, attention_mask, eos_indices, 87 | # gen_label_ids, gen_label_masks, cls_label_ids, cls_label_masks, unique_id=unique_id 88 | # ) 89 | # infer_res = list(infer_res) 90 | 91 | def train_eval_datasets_provider(args): 92 | """ 93 | Build train, valid, and test datasets. 94 | input schema: 95 | tokenizer: input sentence samples tokenizer 96 | args: user defined arguments dictionary 97 | output schema: 98 | train_dataset, valid_dataset, test_dataset: dataset that implements the torch.utils.data.Dataset interface 99 | """ 100 | # Build the dataset. 101 | input_tables = args.tables.split(",") 102 | 103 | #run on local 104 | dataset = CIGARDatasetLocal(args, input_tables[0]) 105 | print('train_eval_datasets_provider') 106 | 107 | if len(input_tables) > 1: 108 | #run on local 109 | eval_dataset = CIGARDatasetLocal(args, input_tables[1], is_test=True) 110 | else: 111 | eval_dataset = None 112 | 113 | return dataset, eval_dataset 114 | 115 | def personalized_args_provider(parser): 116 | def add_model_config_args(parser): 117 | """Model arguments""" 118 | 119 | group = parser.add_argument_group('model', 'model configuration') 120 | 121 | group.add_argument("--model", type=str, default='cigar', help="model") 122 | group.add_argument("--kv_dimension", type=int, default=8, help="dimension of each feature field") 123 | group.add_argument("--mem_dimension", type=int, default=40, help="dimension of memory") 124 | group.add_argument("--gnn_layers", type=str, default='40', help="dimension of GNN layer") 125 | group.add_argument("--dim_hidden", type=str, default='128,64,1', help="dimension of prediction layer") 126 | group.add_argument("--prototype_num", type=int, default=5, help="prototype_num") 127 | group.add_argument("--seq_length", type=int, default=100, help="length of sequence") 128 | group.add_argument("--column_len", type=int, default=29, help="length of column") 129 | group.add_argument("--user_fea_name", type=str, 130 | default='cms_segid,cms_group_id,final_gender_code,age_level,pvalue_level,shopping_level,occupation,new_user_class_level', 131 | help="user_fea_name") 132 | group.add_argument("--user_fea_col_id", type=str, default='1,2,3,4,5,6,7,8', help="user_fea_col_id") 133 | group.add_argument("--item_fea_name", type=str, default='adgroup_id,cate_id,campaign_id,customer,brand', help="item_fea_name") 134 | group.add_argument("--item_fea_col_id", type=str, default='20,21,22,23,24', help="item_fea_col_id") 135 | 
group.add_argument("--seq_col_id", type=str, default='13,14,15,16,17', help="seq_col_id") 136 | group.add_argument("--table_size", type=str, default='1150000,100,15,5,10,5,5,5,10,850000,13000,425000,260000,461500', help="embedding table size of uid, user fea and item fea") 137 | group.add_argument("--uid_graph_label_col_id", type=str, default='0,9,28', help="column ID of user, neighbors and label") 138 | 139 | group.add_argument("--onnx_step", type=bool, default=False, help="if onnx_step") 140 | 141 | return parser 142 | 143 | return add_model_config_args(parser) 144 | 145 | def inference_dataset_provider(args): 146 | 147 | input_tables = args.tables.split(",") 148 | 149 | dataset = CIGARDatasetLocal(args, input_tables[1], is_test=True) 150 | 151 | return dataset 152 | 153 | def onnx_model_export(args): 154 | ''' 155 | :param model: the trained model 156 | :param args: user defined arguments dictionary 157 | :return: None 158 | ''' 159 | def mock_data_provider(args): 160 | """ 161 | Build the model func. 162 | input schema: 163 | args: user defined arguments dictionary 164 | output schema: 165 | model: user defined model that implements the torch.nn.Module interface 166 | """ 167 | if args.model == 'CIGAR': 168 | data = mock_gnn_data() 169 | elif args.model == 'CIGAR_WO_CDGNN': 170 | data = mock_no_gnn_data() 171 | elif args.model == 'CIGAR_WO_PN': 172 | data = mock_gnn_data() 173 | elif args.model == 'PNN': 174 | data = mock_no_gnn_data() 175 | 176 | return data 177 | 178 | print("=====start onnx export======") 179 | import torch.onnx as onnx 180 | from backend.utils import load_model_state_only 181 | 182 | data = mock_data_provider(args) 183 | cuda_device = torch.cuda.current_device() 184 | data = dict((k, v.to(cuda_device)) for k, v in data.items()) 185 | y = data["label"] 186 | args.onnx_step = True 187 | args.task_type = "inference" 188 | model = model_provider(args) 189 | model.cuda(torch.cuda.current_device()) 190 | print(' >export onnx model number of parameters on rank{}'.format(sum([p.nelement() for p in model.parameters()])), flush=True) 191 | load_model_state_only(model, args, remove_prefix=None, remap_prefix=None, 192 | load_checkpoint_name=os.path.join(args.load,"rank_00_model_states.pt")) 193 | model.eval() 194 | model_path = os.path.join(args.save,"onnx_model_00.onnx") 195 | 196 | onnx.export(model, (data, y), model_path, export_params=True, verbose=False, opset_version=12) 197 | print("success save to:", model_path) 198 | 199 | if __name__ == "__main__": 200 | #running what task depend on args.task_type's value(train or inference or onnx_export) 201 | task_dispatcher(train_eval_dataset_provider=train_eval_datasets_provider, 202 | inference_dataset_provider=inference_dataset_provider, 203 | model_provider=model_provider, 204 | forward_func=forward_func, 205 | personalized_args_provider=personalized_args_provider, 206 | onnx_model_export_func=onnx_model_export) -------------------------------------------------------------------------------- /train_hub/ecrec.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | from backend.dataset_hub.ecrec_datasets import ECRecDatasetLocal 6 | from backend.model_hub.ecrec_model import ECRec 7 | import torch 8 | from backend.task_backbone import task_dispatcher 9 | from onnx_test.ecrec_onnx_test import mock_cloud_data 10 | 11 | def model_provider(args): 12 | """ 13 | 
Build the model func. 14 | input schema: 15 | args: user defined arguments dictionary 16 | output schema: 17 | model: user defined model that implements the torch.nn.Module interface 18 | """ 19 | 20 | model = ECRec(args=args, device = torch.cuda.current_device()) 21 | 22 | return model 23 | 24 | def get_batch(data_iterator, args): 25 | """ 26 | Generate a batch func. 27 | input schema: 28 | data_iterator: data iterator that implements the torch.utils.data.DataLoader interface 29 | args: user defined arguments dictionary 30 | output schema: 31 | dictionary (python dict()): a dictionary that contains all data used in the model forward step 32 | """ 33 | data = next(data_iterator) 34 | 35 | cuda_device = torch.cuda.current_device() 36 | data = dict((k, v.to(cuda_device)) for k, v in data.items()) 37 | 38 | return data 39 | 40 | def get_inference_batch(data_iterator, args): 41 | data = next(data_iterator) 42 | 43 | cuda_device = torch.cuda.current_device() 44 | data = dict((k, v.to(cuda_device)) for k, v in data.items()) 45 | 46 | return data 47 | 48 | def forward_func(data_iterator, model, args): 49 | """ 50 | Forward step. 51 | input schema: 52 | data_iterator: data iterator that implements the torch.utils.data.DataLoader interface 53 | model: a model that implements the torch.nn.Module interface and defined in the model_provider func 54 | args: user defined arguments dictionary 55 | output schema: 56 | loss: a one-dimensional loss vector that contains every sample's loss 57 | """ 58 | if args.task_type == 'train': 59 | 60 | data = get_batch(data_iterator, args) 61 | loss, score_avg = model(data) 62 | 63 | return loss, [loss] 64 | else: 65 | data = get_batch(data_iterator, args) 66 | loss, score_avg = model(data) 67 | return (score_avg, data['label']) 68 | 69 | def train_eval_datasets_provider(args): 70 | """ 71 | Build train, valid, and test datasets. 72 | input schema: 73 | tokenizer: input sentence samples tokenizer 74 | args: user defined arguments dictionary 75 | output schema: 76 | train_dataset, valid_dataset, test_dataset: dataset that implements the torch.utils.data.Dataset interface 77 | """ 78 | 79 | input_tables = args.tables.split(",") 80 | print('eval table: ', input_tables[1]) 81 | shuffle_buffer_size = 1 82 | 83 | dataset = ECRecDatasetLocal(args, input_tables[0], shuffle_buffer_size) 84 | 85 | if len(input_tables) > 1: 86 | eval_dataset = ECRecDatasetLocal(args, input_tables[1], shuffle_buffer_size, is_test=True) 87 | else: 88 | eval_dataset = None 89 | 90 | return dataset, eval_dataset 91 | 92 | def personalized_args_provider(parser): 93 | def add_model_config_args(parser): 94 | """Model arguments""" 95 | 96 | group = parser.add_argument_group('model', 'model configuration') 97 | 98 | # group.add_argument("--model", type=str, default='model', help="model") 99 | 100 | group.add_argument("--num_item", type=int, default=300437, help="number of item vocabulary. ml: 42876") 101 | 102 | group.add_argument("--num_cat", type=int, default=1921, help="number of category vocabulary. ml: 1505") 103 | 104 | group.add_argument("--num_user", type=int, default=47227, help="number of user vocabulary. 
ml: 9716") 105 | 106 | group.add_argument("--num_head", type=int, default=4, help="number of heads") 107 | 108 | group.add_argument("--d_model", type=int, default=16, help="model dimension") 109 | 110 | group.add_argument("--d_memory", type=int, default=16, help="memory dimension") 111 | 112 | group.add_argument("--length", type=int, default=100, help="length of sequence") 113 | group.add_argument("--seq_length", type=int, default=100, help="length of sequence") 114 | # group.add_argument("--dropout", type=float, default=0.1, help="dropout ratio") 115 | group.add_argument("--drop_rate", type=float, default=0.3, help="drop rate") 116 | # group.add_argument("--temp", type=float, default=1.0, help="temperature") 117 | group.add_argument("--K", type=int, default=4, help="K") 118 | group.add_argument("--l", type = float, default = 10.0, help="lambda value") 119 | group.add_argument("--gate", type=float, default=0.5, help="lambda value") 120 | 121 | group.add_argument("--column_length", type=int, default=9, help="length of column") 122 | group.add_argument("--sequence_length", type=int, default=100, help="length of sequence") 123 | 124 | group.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false") 125 | group.add_argument("--with_bn", type=bool, default=False, help="whether use batch norm") 126 | group.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n") 127 | group.add_argument("--patience", type=int, default=3, help="patience") 128 | group.add_argument("--update_after_train", type=bool, default=True, help="update memory immediately after training") 129 | 130 | group.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value") 131 | group.add_argument("--adam_beta2", type=float, default=0.999, help="adam first beta value") 132 | 133 | group.add_argument("--test_dataset", type=str, default=None, help="test set for evaluate train set") 134 | 135 | group.add_argument("--infer_table", type=str, default='', help="inference data") 136 | 137 | 138 | return parser 139 | 140 | return add_model_config_args(parser) 141 | 142 | def inference_dataset_provider(args): 143 | 144 | input_table = args.infer_table 145 | 146 | shuffle_buffer_size = 1 147 | dataset = ECRecDatasetLocal(args, input_table, shuffle_buffer_size, is_test=True) 148 | 149 | return dataset 150 | 151 | def training_post_processing_func(model, args): 152 | ''' 153 | :param model: the trained model 154 | :param args: user defined arguments dictionary 155 | :return: None 156 | ''' 157 | from backend.utils import save_checkpoint, print_rank_0 158 | from backend.task_backbone import make_inference_data_loader 159 | infer_dataset, _ = train_eval_datasets_provider(args) 160 | def update_dataset_provider(args): 161 | return infer_dataset 162 | infer_data = make_inference_data_loader(args, update_dataset_provider) 163 | infer_data_iterator = iter(infer_data) 164 | model.train() 165 | model.task_type = 'inference' 166 | iter_index = 0 167 | num_workers = torch.distributed.get_world_size() 168 | infer_iters = args.train_iters // args.num_epochs + num_workers 169 | 170 | with torch.no_grad(): 171 | while iter_index < infer_iters: 172 | forward_func(infer_data_iterator, model, args) 173 | iter_index += 1 174 | print_rank_0('memory is updated!') 175 | model.task_type = 'train' 176 | save_checkpoint(args.iteration + 1, model, None, args) 177 | 178 | def onnx_model_export(args): 179 | ''' 180 | :param args: user defined arguments dictionary 181 | 
:return: None 182 | ''' 183 | print("*****running task : export ONNX model*****") 184 | import torch.onnx as onnx 185 | from backend.utils import load_model_state_only 186 | 187 | def mock_data_provider(): 188 | ''' 189 | :param model: the trained model 190 | :param args: user defined arguments dictionary 191 | :return: the mock data for testing the ONNX model 192 | ''' 193 | data = mock_cloud_data() 194 | return data 195 | 196 | data = mock_data_provider() 197 | if args.cuda: 198 | cuda_device = torch.cuda.current_device() 199 | data = dict((k, v.to(cuda_device)) for k, v in data.items()) 200 | args.onnx_step = True 201 | args.task_type = 'inference' 202 | model = model_provider(args) 203 | print(' >export onnx model number of parameters on rank{}'.format(sum([p.nelement() for p in model.parameters()])), flush=True) 204 | if args.cuda: 205 | model.cuda(torch.cuda.current_device()) 206 | 207 | load_model_state_only(model, args, remove_prefix=None, remap_prefix=None) 208 | model.eval() 209 | model_path = os.path.join(args.onnx_export_path, args.onnx_model_name) 210 | 211 | onnx.export(model, data, model_path, export_params=True, verbose=False, opset_version=12) 212 | print('success save to:', model_path) 213 | 214 | if __name__ == "__main__": 215 | task_dispatcher(train_eval_dataset_provider=train_eval_datasets_provider, 216 | inference_dataset_provider=inference_dataset_provider, 217 | model_provider=model_provider, 218 | forward_func=forward_func, 219 | personalized_args_provider=personalized_args_provider, 220 | training_post_processing_func=training_post_processing_func, 221 | onnx_model_export_func=onnx_model_export) -------------------------------------------------------------------------------- /backend/dataset_hub/ggcn_datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2022 The Luoxi Team. 3 | # All rights reserved. 4 | # This source code is licensed under the Apache 2.0 license 5 | # found in the LICENSE file in the root directory. 
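# run_dfs() below recurses once per reachable node, which is why the recursion
# limit is raised near the top of this file. A behaviour-equivalent iterative
# sketch (assuming, as in run_dfs, that adj is a scipy sparse matrix and msk
# is a numpy array of component ids initialised to -1):
#
#     def run_dfs_iterative(adj, msk, u, ind):
#         stack = [u]
#         while stack:
#             v = stack.pop()
#             if msk[v] == -1:
#                 msk[v] = ind
#                 stack.extend(adj[v, :].nonzero()[1])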
6 | 7 | import torch 8 | from backend.dataset_hub.base_datasets import * 9 | import numpy as np 10 | import scipy.sparse as sp 11 | import sys 12 | sys.setrecursionlimit(99999) 13 | 14 | class GGCNDataset(BaseDataset): 15 | def __init__(self, 16 | args, 17 | table_name, 18 | shuffle_buffer_size, 19 | is_test=False, 20 | max_len=-1, 21 | max_neighbor=-1): 22 | super(GGCNDataset, self).__init__(args, table_name, shuffle_buffer_size, is_test) 23 | self.maxlen = max_len 24 | self.max_neighbor = max_neighbor 25 | 26 | def run_dfs(adj, msk, u, ind, nb_nodes): 27 | """ 28 | data preprocess, run dfs_split function for each edge in adj 29 | input schema: 30 | adj, mask, index, nodes 31 | output schema: 32 | ret: nodes for each edge in adj 33 | """ 34 | if msk[u] == -1: 35 | msk[u] = ind 36 | for v in adj[u, :].nonzero()[1]: 37 | run_dfs(adj, msk, v, ind, nb_nodes) 38 | 39 | def dfs_split(adj): 40 | """ 41 | data preprocess, split data with dfs method 42 | input schema: 43 | adj 44 | output schema: 45 | ret: split nodes 46 | """ 47 | nb_nodes = adj.shape[0] 48 | ret = np.full(nb_nodes, -1, dtype=np.int32) 49 | 50 | graph_id = 0 51 | 52 | for i in range(nb_nodes): 53 | if ret[i] == -1: 54 | run_dfs(adj, ret, i, graph_id, nb_nodes) 55 | graph_id += 1 56 | 57 | return ret 58 | 59 | class GGCNDatasetLocal(GGCNDataset): 60 | def __init__(self, 61 | args, 62 | table_name, 63 | shuffle_buffer_size=8194, 64 | is_test=False, 65 | max_len=100, 66 | max_neighbor=10,traintype='train'): 67 | 68 | super(GGCNDatasetLocal,self).__init__(args, table_name, shuffle_buffer_size, is_test, max_len, max_neighbor) 69 | self.table_name = table_name 70 | self.train_adj_list = [] 71 | self.val_adj_list = [] 72 | self.test_adj_list = [] 73 | self.train_feat = 0 74 | self.val_feat = 0 75 | self.test_feat = 0 76 | self.train_labels = 0 77 | self.val_labels = 0 78 | self.test_labels = 0 79 | self.train_nodes = [] 80 | self.val_nodes = [] 81 | self.test_nodes = [] 82 | self.traintype = traintype 83 | if self.traintype == 'train': 84 | [self.train_adj_list, self.val_adj_list, self.test_adj_list, self.train_feat, self.val_feat, self.test_feat, 85 | self.train_labels, self.val_labels, self.test_labels, self.train_nodes, self.val_nodes, self.test_nodes] = torch.load('%s/traingraphL.pth'%table_name) 86 | 87 | self.loadData() 88 | else: 89 | [self.train_adj_list, self.val_adj_list, self.test_adj_list, self.train_feat, self.val_feat, self.test_feat, 90 | self.train_labels, self.val_labels, self.test_labels, self.train_nodes, self.val_nodes, 91 | self.test_nodes] = torch.load('%s/traingraphL.pth' % table_name) 92 | l = [] 93 | # avoid pin error in sparse cuda 94 | for i in self.val_adj_list: 95 | e = i.to_dense() 96 | l.append(e) 97 | self.val_adj_list = l 98 | print('load val data') 99 | 100 | def get_total_row_count(self): 101 | self.caoncat = 160 102 | self.reader.seek(0) 103 | return self.caoncat * 20 104 | 105 | def _new_reader(self): 106 | if self.reader is not None: 107 | self.reader.close() 108 | print('self.table_name', self.table_name) 109 | reader = open('%s/ppi-id_map.json'%self.table_name, "r") 110 | return reader 111 | 112 | def _read_record(self): 113 | try: 114 | column_l = self.reader.readline().strip().split("\t") 115 | # print(column_l) 116 | assert len(column_l) == self.args.column_len, "len(column_l) must be %d, now is %d" % (self.args.column_len, len(column_l)) 117 | except: 118 | self.reader.seek(0) 119 | column_l = self.reader.readline().strip().split("\t") 120 | return column_l 121 | 122 | def find_split(self,adj, 
mapping, ds_label):
123 |         """
124 |         split data into train/val/test following previous works,
125 |         i.e. map each connected sub-graph id to the train, val or test set
126 |         input schema:
127 |             adj, mapping (node -> sub-graph id), ds_label (node -> split flags)
128 |         output schema:
129 |             dict_splits: sub-graph id -> 'train' / 'val' / 'test'
130 |         """
131 |         nb_nodes = adj.shape[0]
132 |         dict_splits = {}
133 |         for i in range(nb_nodes):
134 |             for j in adj[i, :].nonzero()[1]:
135 |                 if mapping[i] == 0 or mapping[j] == 0:
136 |                     dict_splits[0] = None
137 |                 elif mapping[i] == mapping[j]:
138 |                     if ds_label[i]['val'] == ds_label[j]['val'] and ds_label[i]['test'] == ds_label[j]['test']:
139 |
140 |                         if mapping[i] not in dict_splits.keys():
141 |                             if ds_label[i]['val']:
142 |                                 dict_splits[mapping[i]] = 'val'
143 |
144 |                             elif ds_label[i]['test']:
145 |                                 dict_splits[mapping[i]] = 'test'
146 |
147 |                             else:
148 |                                 dict_splits[mapping[i]] = 'train'
149 |
150 |                         else:
151 |                             if ds_label[i]['test']:
152 |                                 ind_label = 'test'
153 |                             elif ds_label[i]['val']:
154 |                                 ind_label = 'val'
155 |                             else:
156 |                                 ind_label = 'train'
157 |                             if dict_splits[mapping[i]] != ind_label:
158 |                                 print('inconsistent labels within a graph, exiting!')
159 |                                 return None
160 |                     else:
161 |                         print('labels of the two nodes differ, exiting!')
162 |                         return None
163 |         return dict_splits
164 |
165 |     def sparse_mx_to_torch_sparse_tensor(self, sparse_mx):
166 |         """
167 |         Convert a scipy sparse matrix to a torch sparse tensor.
168 |         input schema:
169 |             sparse_mx: scipy sparse matrix
170 |         output schema:
171 |             sparse COO tensor built from (indices, values, shape)
172 |         """
173 |         sparse_mx = sparse_mx.tocoo().astype(np.float32)
174 |         indices = torch.from_numpy(
175 |             np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
176 |         values = torch.from_numpy(sparse_mx.data)
177 |         shape = torch.Size(sparse_mx.shape)
178 |         return torch.sparse_coo_tensor(indices, values, shape)  # torch.sparse.FloatTensor is deprecated
179 |
180 |     def sys_normalized_adjacency(self, adj):
181 |         """
182 |         convert adj to the symmetrically normalized adjacency D^-1/2 (A + I) D^-1/2
183 |         input schema:
184 |             adj
185 |         output schema:
186 |             normalized adj (scipy COO matrix)
187 |         """
188 |         adj = sp.coo_matrix(adj)
189 |         adj = adj + sp.eye(adj.shape[0])
190 |         row_sum = np.array(adj.sum(1))
191 |         row_sum = (row_sum == 0) * 1 + row_sum  # give isolated nodes degree 1 to avoid division by zero
192 |         d_inv_sqrt = np.power(row_sum, -0.5).flatten()
193 |         d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
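        # A worked example of the normalization (editor's illustration):
        # with A = [[0, 1], [1, 0]],
        #   A + I     -> [[1, 1], [1, 1]]
        #   row_sum   -> [2, 2], so d_inv_sqrt -> [0.7071, 0.7071]
        #   A_hat     -> D^-1/2 (A + I) D^-1/2 = [[0.5, 0.5], [0.5, 0.5]]
        # The (row_sum == 0) guard above keeps np.power from producing inf for
        # isolated nodes; the isinf line is a second safety net.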
194 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 195 | return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).tocoo() 196 | 197 | def loadData(self): 198 | """ 199 | add data, data augment with adj, feature, labels and nodes index 200 | input schema: 201 | adj, labels, nodes, features 202 | output schema: 203 | augmented adj, labels, nodes, features 204 | """ 205 | i = self.train_adj_list 206 | 207 | newi = [] 208 | for t in range(self.caoncat): 209 | newi = newi + i 210 | 211 | self.train_adj_list = newi 212 | i = self.train_feat 213 | newi = i.repeat([self.caoncat, 1, 1]) 214 | 215 | self.train_feat =newi 216 | i = self.train_labels 217 | newi = i.repeat([self.caoncat, 1, 1]) 218 | 219 | self.train_labels=(newi) 220 | i = self.train_nodes 221 | newi = np.tile(i, (self.caoncat)) 222 | self.train_nodes = newi 223 | 224 | return 225 | 226 | def __getitem__(self, idx): 227 | """ 228 | item for different type, 229 | contain adj, feature, label and nodes index 230 | """ 231 | if self.traintype == 'train': 232 | adj = self.train_adj_list[idx] 233 | feat = self.train_feat[idx] 234 | labels = self.train_labels[idx] 235 | nodes = self.train_nodes[[idx]] 236 | elif self.traintype == 'eval': 237 | adj = self.val_adj_list[idx] 238 | feat = self.val_feat[idx] 239 | labels = self.val_labels[idx] 240 | nodes = self.val_nodes[[idx]] 241 | else: 242 | adj = self.test_adj_list[idx] 243 | feat = self.test_feat[idx] 244 | labels = self.test_labels[idx] 245 | nodes = self.test_nodes[[idx]] 246 | 247 | return feat, adj, labels, nodes, self.traintype 248 | 249 | def __len__(self): 250 | """ 251 | get len for different type 252 | """ 253 | if self.traintype == 'train': 254 | return len(self.train_adj_list) 255 | elif self.traintype == 'eval': 256 | return 2 257 | else: 258 | return 2 259 | 260 | def __del__(self): 261 | if self.reader is not None: 262 | self.reader.close() 263 | -------------------------------------------------------------------------------- /train_hub/ggcn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | from backend.dataset_hub.ggcn_datasets import GGCNDatasetLocal 6 | from backend.model_hub.ggcn_model import * 7 | import torch 8 | from backend.task_backbone import task_dispatcher 9 | import numpy as np 10 | from sklearn.metrics import f1_score 11 | from onnx_test.ggcn_onnx_model_test import mock_gnn_data 12 | 13 | 14 | def model_provider(args): 15 | """ 16 | Build the model func. 17 | input schema: 18 | args: user defined arguments dictionary 19 | output schema: 20 | model: user defined model that implements the torch.nn.Module interface 21 | """ 22 | model = GGCN(nfeat=args.nfeat, 23 | nlayers=args.layer, 24 | nhidden=args.hidden, 25 | nclass=args.nclass, 26 | dropout=args.dropout, 27 | lamda = args.lamda, 28 | alpha=args.alpha, 29 | variant=args.variant,args=args,onnx_step=args.onnx_step).to(torch.cuda.current_device()) 30 | 31 | return model 32 | 33 | def get_batch(data_iterator, args): 34 | """ 35 | Generate a batch func. 
36 |     input schema:
37 |         data_iterator: data iterator that implements the torch.utils.data.DataLoader interface
38 |         args: user defined arguments dictionary
39 |     output schema:
40 |         (feat, adj, labels, nodes, traintype): the tensors used in the model forward step (a tuple here, not a dict)
41 |     """
42 |
43 |
44 |     device = torch.cuda.current_device()
45 |
46 |     feat, adj, labels, nodes, traintype = next(data_iterator)
47 |
48 |     nodes = nodes.to(device)
49 |     labels = labels[0].to(device)
50 |     feat = feat[0].to(device)  # .to_dense()
51 |     adj = adj[0].to(device)
52 |     return feat, adj, labels, nodes, traintype[0]
53 |
54 | def get_inference_batch(data_iterator, args):
55 |     """
56 |     Generate a batch func.
57 |     input schema:
58 |         data_iterator: data iterator that implements the torch.utils.data.DataLoader interface
59 |         args: user defined arguments dictionary
60 |     output schema:
61 |         dictionary (python dict()): a dictionary that contains all data used in the model forward step
62 |     """
63 |     data = next(data_iterator)
64 |
65 |     cuda_device = torch.cuda.current_device()
66 |     data = dict((k, v.to(cuda_device)) for k, v in data.items())
67 |
68 |     return data
69 |
70 | def forward_func(data_iterator, model, args):
71 |     """
72 |     Forward step.
73 |     input schema:
74 |         data_iterator: data iterator that implements the torch.utils.data.DataLoader interface
75 |         model: a model that implements the torch.nn.Module interface and defined in the model_provider func
76 |         args: user defined arguments dictionary
77 |     output schema:
78 |         (loss, stats) for the train/eval splits; a list of inference results for the test split
79 |     """
80 |     if args.task_type == "inference":
81 |         try:
82 |             feat, adj, labels, nodes, traintype = get_batch(data_iterator, args)
83 |         except Exception:  # typically StopIteration once the iterator is exhausted
84 |             return [torch.tensor([0, 0])]  # dummy result so the inference loop can finish cleanly
85 |     else:
86 |         feat, adj, labels, nodes, traintype = get_batch(data_iterator, args)
87 |
88 |     if traintype == 'train':
89 |
90 |         output = model(feat, adj)
91 |         lossfn = torch.nn.BCELoss()
92 |         loss = lossfn(output[:nodes], labels[:nodes])
93 |         predict = np.where(output[:nodes].data.cpu().numpy() > 0.5, 1, 0)
94 |         score = f1_score(labels[:nodes].data.cpu().numpy(), predict, average='micro')
95 |         score = torch.tensor(score).to(torch.cuda.current_device())
96 |
97 |         return loss, [score]
98 |     elif traintype == 'eval':
99 |         # loss and score are computed inside the model for the eval split
100 |         loss, score = model.module.evalModel()
101 |         return loss, score
102 |     elif traintype == 'test':
103 |         loss, score = model.module.inferModel()
104 |         infer_res = [torch.tensor([loss, loss])]
105 |         return infer_res
106 |     else:
107 |         return 0, 0  # unknown traintype; nothing to compute
108 |
109 | def train_eval_datasets_provider(args):
110 |     """
111 |     Build the train and eval datasets.
112 |     input schema:
113 |         args: user defined arguments dictionary
114 |     output schema:
115 |         dataset, eval_dataset: datasets that implement the torch.utils.data.Dataset
116 |         interface (eval_dataset is None when only one table is given)
117 |     """
118 |
119 |     # Build the dataset.
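    # Editor's note (illustrative): --tables is expected to be a comma-separated
    # pair of paths, e.g. --tables=/data/ppi,/data/ppi; index 0 feeds the train
    # split below and index 1, when present, the eval split.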
120 |     input_tables = args.tables.split(",")
121 |
122 |     # run on local
123 |     dataset = GGCNDatasetLocal(args, input_tables[0], traintype='train')
124 |     print('train_eval_datasets_provider')
125 |
126 |     if len(input_tables) > 1:
127 |         # run on local
128 |         eval_dataset = GGCNDatasetLocal(args, input_tables[1], traintype='eval')
129 |     else:
130 |         eval_dataset = None
131 |
132 |     return dataset, eval_dataset
133 |
134 | def personalized_args_provider(parser):
135 |     def add_model_config_args(parser):
136 |         """Model arguments"""
137 |
138 |         group = parser.add_argument_group('model', 'model configuration')
139 |
140 |         group.add_argument("--model", type=str, default='cigar', help="model")
141 |         group.add_argument("--kv_dimension", type=int, default=8, help="dimension of each feature field")
142 |         group.add_argument("--mem_dimension", type=int, default=40, help="dimension of memory")
143 |         group.add_argument("--gnn_layers", type=str, default='40', help="dimension of GNN layer")
144 |         group.add_argument("--dim_hidden", type=str, default='128,64,1', help="dimension of prediction layer")
145 |         group.add_argument("--prototype_num", type=int, default=5, help="prototype_num")
146 |         group.add_argument("--seq_length", type=int, default=100, help="length of sequence")
147 |         group.add_argument("--column_len", type=int, default=29, help="length of column")
148 |         group.add_argument("--user_fea_name", type=str,
149 |                            default='cms_segid,cms_group_id,final_gender_code,age_level,pvalue_level,shopping_level,occupation,new_user_class_level',
150 |                            help="user_fea_name")
151 |         group.add_argument("--user_fea_col_id", type=str, default='1,2,3,4,5,6,7,8', help="user_fea_col_id")
152 |         group.add_argument("--item_fea_name", type=str, default='adgroup_id,cate_id,campaign_id,customer,brand', help="item_fea_name")
153 |         group.add_argument("--item_fea_col_id", type=str, default='20,21,22,23,24', help="item_fea_col_id")
154 |         group.add_argument("--seq_col_id", type=str, default='13,14,15,16,17', help="seq_col_id")
155 |         group.add_argument("--table_size", type=str, default='1150000,100,15,5,10,5,5,5,10,850000,13000,425000,260000,461500', help="embedding table size of uid, user fea and item fea")
156 |         group.add_argument("--uid_graph_label_col_id", type=str, default='0,9,28', help="column ID of user, neighbors and label")
157 |         group.add_argument("--epochs", type=int, default=8000, help='Number of epochs to train.')
158 |         group.add_argument("--wd", type=float, default=0, help='Weight decay (L2 loss on parameters).')
159 |         group.add_argument("--layer", type=int, default=9, help='Number of hidden layers.')
160 |         group.add_argument("--hidden", type=int, default=2048, help='Number of hidden units per layer.')
161 |         group.add_argument("--nfeat", type=int, default=50, help='Number of input features.')
162 |         group.add_argument("--nclass", type=int, default=121, help='Number of classes.')
163 |         group.add_argument("--dropout", type=float, default=0.2, help='Dropout rate (1 - keep probability).')
164 |         group.add_argument("--patience", type=int, default=2000, help='Patience.')
165 |         group.add_argument("--data", default='ppi', help='dataset')
166 |         group.add_argument("--dev", type=int, default=0, help='device id')
167 |         group.add_argument("--alpha", type=float, default=0.5, help='alpha_l')
168 |         group.add_argument("--lamda", type=float, default=1, help='lambda.')
169 |         group.add_argument("--variant", action='store_true', default=False, help='GCN* model.')
170 |         group.add_argument("--test", action='store_true', default=False, help='evaluation on test set.')
171 |
172 | group.add_argument("--onnx_step", type=bool, default=False, help="if onnx_step") 173 | 174 | group.add_argument("--load_model_path", type=str, default='',help="load_model_path") 175 | group.add_argument("--is_gpu", type=bool,default=True,help="if gpu") 176 | group.add_argument("--final_saved_iteration", type=int, default=0, help="if gpu") 177 | return parser 178 | 179 | return add_model_config_args(parser) 180 | 181 | def inference_dataset_provider(args): 182 | """ 183 | Build train, valid, and test datasets for inference. 184 | input schema: 185 | tokenizer: input sentence samples tokenizer 186 | args: user defined arguments dictionary 187 | output schema: 188 | train_dataset, valid_dataset, test_dataset: dataset that implements the torch.utils.data.Dataset interface 189 | """ 190 | input_tables = args.tables.split(",") 191 | 192 | dataset = GGCNDatasetLocal(args, input_tables[1],traintype='test') 193 | 194 | return dataset 195 | 196 | def onnx_model_export(args): 197 | ''' 198 | :param model: the trained model 199 | :param args: user defined arguments dictionary 200 | :return: None 201 | ''' 202 | print("=====start onnx export======") 203 | import torch.onnx as onnx 204 | from backend.utils import load_model_state_only 205 | load_checkpoint_name =os.path.join(args.load_model_path,"rank_00_model_states.pt") 206 | 207 | data = mock_gnn_data(args) 208 | cuda_device = torch.cuda.current_device() 209 | data = dict((k, v.to(cuda_device)) for k, v in data.items()) 210 | y = data["label"] 211 | args.onnx_step = True 212 | args.task_type = "inference" 213 | model = model_provider(args) 214 | print(' >export onnx model number of parameters on rank{}'.format(sum([p.nelement() for p in model.parameters()])), flush=True) 215 | if args.is_gpu: 216 | model.cuda(torch.cuda.current_device()) 217 | 218 | load_model_state_only(model, args, remove_prefix=None, remap_prefix=None, load_checkpoint_name=load_checkpoint_name) 219 | model.eval() 220 | model_path = os.path.join(args.onnx_export_path, args.onnx_model_name) 221 | 222 | onnx.export(model, (data, y), model_path, export_params=True, verbose=False, opset_version=12) 223 | print("success save to:", model_path) 224 | 225 | if __name__ == "__main__": 226 | task_dispatcher(train_eval_dataset_provider=train_eval_datasets_provider, 227 | inference_dataset_provider=inference_dataset_provider, 228 | model_provider=model_provider, 229 | forward_func=forward_func, 230 | personalized_args_provider=personalized_args_provider, 231 | onnx_model_export_func=onnx_model_export) -------------------------------------------------------------------------------- /train_hub/pinrec.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | from backend.dataset_hub.pin_datasets import PINDatasetLocal 6 | from backend.model_hub import pinrec_model 7 | import torch 8 | from torch import nn 9 | from torch.nn.parallel.distributed import DistributedDataParallel 10 | from backend.task_backbone import task_dispatcher 11 | from backend.utils import parse_arch_config_from_args 12 | from onnx_test.pinrec_onnx_model_test import mock_data 13 | 14 | trunk_layer_set = set() 15 | 16 | def model_provider(args): 17 | """ 18 | Build the model func. 
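    (side effect: the names of all non-plugin parameters are recorded in the
    module-level trunk_layer_set, which forward_func later uses to aggregate
    trunk gradients across user groups)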
19 | input schema: 20 | args: user defined arguments dictionary 21 | output schema: 22 | model: user defined model that implements the torch.nn.Module interface 23 | """ 24 | model_plugin = pinrec_model.get_model_meta(args.model) 25 | model_plugin_conf, raw_model_plugin_conf = parse_arch_config_from_args(model_plugin, args) # type: dict 26 | model = model_plugin.model_builder(model_conf=model_plugin_conf, group_num=args.group_num) 27 | for name, parms in model.named_parameters(): 28 | if "plugin" not in name: 29 | trunk_layer_set.add(name) 30 | return model 31 | 32 | def get_batch(data_iterator, args): 33 | """ 34 | Generate a batch func. 35 | input schema: 36 | data_iterator: data iterator that implements the torch.utils.data.DataLoader interface 37 | args: user defined arguments dictionary 38 | output schema: 39 | dictionary (python dict()): a dictionary that contains all data used in the model forward step 40 | """ 41 | data = next(data_iterator) 42 | 43 | user_id = data[args.consts["FIELD_USER_ID"]].long() 44 | target_id = data[args.consts["FIELD_TARGET_ID"]].long() 45 | clk_seq = data[args.consts["FIELD_CLK_SEQUENCE"]].long() 46 | label = data[args.consts["FIELD_LABEL"]].long() 47 | group_id = data[args.consts["FIELD_GROUP_ID"]].long() 48 | 49 | data = { 50 | args.consts["FIELD_USER_ID"]: user_id, 51 | args.consts["FIELD_TARGET_ID"]: target_id, 52 | args.consts["FIELD_CLK_SEQUENCE"]: clk_seq, 53 | args.consts["FIELD_LABEL"]: label, 54 | args.consts["FIELD_GROUP_ID"]: group_id 55 | } 56 | return data 57 | 58 | def get_inference_batch(data_iterator, args): 59 | data = next(data_iterator) 60 | 61 | user_id = data[args.consts["FIELD_USER_ID"]].long() 62 | target_id = data[args.consts["FIELD_TARGET_ID"]].long() 63 | clk_seq = data[args.consts["FIELD_CLK_SEQUENCE"]].long() 64 | label = data[args.consts["FIELD_LABEL"]].long() 65 | group_id = data[args.consts["FIELD_GROUP_ID"]].long() 66 | 67 | data = { 68 | args.consts["FIELD_USER_ID"]: user_id, 69 | args.consts["FIELD_TARGET_ID"]: target_id, 70 | args.consts["FIELD_CLK_SEQUENCE"]: clk_seq, 71 | args.consts["FIELD_LABEL"]: label, 72 | args.consts["FIELD_GROUP_ID"]: group_id 73 | } 74 | 75 | return data 76 | 77 | def forward_func(data_iterator, model, args): 78 | """ 79 | Forward step. 
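    (returns (loss, stats) in train mode, and a flat list of per-sample
    sigmoid scores in inference mode)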
80 | input schema: 81 | data_iterator: data iterator that implements the torch.utils.data.DataLoader interface 82 | model: a model that implements the torch.nn.Module interface and defined in the model_provider func 83 | args: user defined arguments dictionary 84 | output schema: 85 | loss: a one-dimensional loss vector that contains every sample's loss 86 | """ 87 | device = torch.cuda.current_device() 88 | criterion = nn.BCEWithLogitsLoss() 89 | if isinstance(model, DistributedDataParallel): 90 | model = model.module 91 | 92 | if args.task_type == 'train': 93 | # Calculate iters so that switch stage1 to stage2 94 | args.stage_switch_iters = int(args.stage_switch_epoch / args.num_epochs * args.train_iters) 95 | # stage1->trunk model stage2->add plugin model 96 | plugin = True if (args.iteration > args.stage_switch_iters) else False 97 | # Reduce learning rate at stage2 98 | if args.iteration - args.stage_switch_iters == 1: 99 | for group in args.optimizer_runtime.param_groups: 100 | group['lr'] = group['lr'] / 10 101 | batch_data = get_batch(data_iterator, args) 102 | loss = 0.0 103 | stats = [] 104 | device = torch.cuda.current_device() 105 | gradient_dict = {} 106 | for name, parms in model.named_parameters(): 107 | gradient_dict[name] = torch.zeros((args.group_num, parms.view(-1).size()[0])).to(device) 108 | 109 | for group_index in range(args.group_num): 110 | group_index_tensor = torch.LongTensor([group_index]).repeat(batch_data[args.consts["FIELD_GROUP_ID"]].size()) 111 | 112 | if len(batch_data[args.consts["FIELD_LABEL"]][torch.where(batch_data[args.consts["FIELD_GROUP_ID"]] == group_index_tensor)]) == 0: 113 | continue 114 | # set plugin module index according to group index 115 | model.set_plugin_index(group_index) 116 | 117 | logits = model({ 118 | key: value[torch.where(batch_data[args.consts["FIELD_GROUP_ID"]] == group_index_tensor)].to(device) 119 | for key, value in batch_data.items() 120 | if key not in {args.consts["FIELD_USER_ID"], args.consts["FIELD_LABEL"], args.consts["FIELD_GROUP_ID"]} 121 | }, plugin=plugin) 122 | 123 | loss_item = criterion(logits, batch_data[args.consts["FIELD_LABEL"]][torch.where(batch_data[args.consts["FIELD_GROUP_ID"]] == group_index_tensor)].float().view(-1, 1).to(device)) 124 | data_lens = batch_data[args.consts["FIELD_LABEL"]][torch.where(batch_data[args.consts["FIELD_GROUP_ID"]] == group_index_tensor)].size()[0] 125 | loss += loss_item * data_lens 126 | 127 | loss_item.backward() 128 | 129 | for name, parms in model.named_parameters(): 130 | # Record the gradient of the trunk model calculated from each group 131 | if name in trunk_layer_set: 132 | gradient_dict[name][group_index] = parms.grad.view(-1) 133 | 134 | stats.append(loss_item) 135 | # Aggregate the gradient of the trunk model, 136 | # you can also customize other aggregation methods 137 | for name, parms in model.named_parameters(): 138 | if name in trunk_layer_set: 139 | parms.grad = torch.mean(gradient_dict[name], 0).reshape(parms.grad.size()) 140 | 141 | loss = loss / args.batch_size # calculate total loss 142 | return loss, stats 143 | else: 144 | infer_res_list = [] 145 | batch_data = get_inference_batch(data_iterator, args) 146 | 147 | for group_index in range(args.group_num): 148 | group_index_tensor = torch.LongTensor([group_index]).repeat(batch_data[args.consts["FIELD_GROUP_ID"]].size()) 149 | if len(batch_data[args.consts["FIELD_LABEL"]][torch.where(batch_data[args.consts["FIELD_GROUP_ID"]] == group_index_tensor)]) == 0: 150 | continue 151 | # set plugin module index 
according to group index 152 | model.set_plugin_index(group_index) 153 | 154 | infer_res = model({ 155 | key: value[torch.where(batch_data[args.consts["FIELD_GROUP_ID"]] == group_index_tensor)].to(device) 156 | for key, value in batch_data.items() 157 | if key not in {args.consts["FIELD_USER_ID"], args.consts["FIELD_LABEL"], args.consts["FIELD_GROUP_ID"]} 158 | }, plugin=True) # Plugin is true during inference 159 | infer_res = torch.sigmoid(infer_res) 160 | 161 | infer_res_list.extend(infer_res) 162 | 163 | return infer_res_list 164 | 165 | def train_eval_datasets_provider(args): 166 | """ 167 | Build train, valid, and test datasets. 168 | input schema: 169 | tokenizer: input sentence samples tokenizer 170 | args: user defined arguments dictionary 171 | output schema: 172 | train_dataset, valid_dataset, test_dataset: dataset that implements the torch.utils.data.Dataset interface 173 | """ 174 | 175 | # Build the dataset. 176 | input_tables = args.tables.split(",") 177 | args.consts = pinrec_model.consts 178 | # run on local 179 | dataset = PINDatasetLocal(args, input_tables[0], is_test=False) 180 | eval_dataset = None 181 | return dataset, eval_dataset 182 | 183 | def personalized_args_provider(parser): 184 | def add_model_config_args(parser): 185 | """Model arguments""" 186 | 187 | group = parser.add_argument_group('model', 'model configuration') 188 | 189 | parser.add_argument("--model", type=str, help="Model type") 190 | parser.add_argument("--group_num", type=int, default=5, help="Number of user groups") 191 | parser.add_argument("--arch_config_path", type=str, default=None, help="Path of model configs") 192 | parser.add_argument("--arch_config", type=str, default=None, help="base64-encoded model configs") 193 | parser.add_argument('--stage_switch_epoch', 194 | type=int, default=2, 195 | help='Number of training epochs (stage1)') 196 | return parser 197 | 198 | return add_model_config_args(parser) 199 | 200 | def inference_dataset_provider(args): 201 | input_tables = args.tables.split(",") 202 | args.consts = pinrec_model.consts 203 | dataset = PINDatasetLocal(args, input_tables[1], is_test=True) 204 | return dataset 205 | 206 | def onnx_model_export(args): 207 | ''' 208 | :param model: the trained model 209 | :param args: user defined arguments dictionary 210 | :return: None 211 | ''' 212 | print("=====start onnx export======") 213 | import torch.onnx as onnx 214 | from backend.utils import load_model_state_only 215 | 216 | with open(os.path.join(args.load, 'latest_iteration.txt')) as f: 217 | for line in f: 218 | folder = line.strip() 219 | break 220 | load_checkpoint_name = os.path.join(args.load, folder, "rank_00_model_states.pt") 221 | data = mock_data() 222 | cuda_device = torch.cuda.current_device() 223 | 224 | data = dict((k, v.to(cuda_device)) for k, v in data.items()) 225 | plugin = torch.Tensor([1]) 226 | args.onnx_step = True 227 | args.task_type = "inference" 228 | model = model_provider(args) 229 | print(' >export onnx model number of parameters on rank{}'.format(sum([p.nelement() for p in model.parameters()])), flush=True) 230 | model.cuda(torch.cuda.current_device()) 231 | 232 | load_model_state_only(model, args, remove_prefix=None, remap_prefix=None, load_checkpoint_name=load_checkpoint_name) 233 | model.eval() 234 | model_path = os.path.join(args.onnx_export_path, args.onnx_model_name) 235 | 236 | onnx.export(model, (data, plugin), model_path, export_params=True, verbose=False, opset_version=12) 237 | print("success save to:", model_path) 238 | 239 | if __name__ == 
"__main__": 240 | task_dispatcher(train_eval_dataset_provider=train_eval_datasets_provider, 241 | inference_dataset_provider=inference_dataset_provider, 242 | model_provider=model_provider, 243 | forward_func=forward_func, 244 | personalized_args_provider=personalized_args_provider, 245 | onnx_model_export_func=onnx_model_export) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 1999-2022 Alibaba Group Holding Ltd. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /backend/model_hub/ecrec_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2022 The Luoxi Team. 3 | # All rights reserved. 4 | # This source code is licensed under the Apache 2.0 license 5 | # found in the LICENSE file in the root directory. 
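# Editor's note: ECRec keeps per-user memory tables (Embedding.Mem_item and
# Embedding.Mem_cat below) that are read on every forward pass and overwritten
# in place during the memory-update phase (task_type == 'inference' with the
# model in training mode); see ECRec.forward for the full mode contract.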
6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import math 11 | 12 | 13 | class ECRec(nn.Module): 14 | def __init__(self, args, device): 15 | super(ECRec, self).__init__() 16 | self.linear1 = nn.Linear(args.d_model*2 + args.d_memory * 2, 256) 17 | nn.init.xavier_uniform_(self.linear1.weight, gain=1.0) 18 | nn.init.zeros_(self.linear1.bias) 19 | self.linear2 = nn.Linear(256, 128) 20 | nn.init.xavier_uniform_(self.linear2.weight, gain=1.0) 21 | nn.init.zeros_(self.linear2.bias) 22 | self.linear3 = nn.Linear(128, 2) 23 | nn.init.xavier_uniform_(self.linear3.weight, gain=1.0) 24 | nn.init.zeros_(self.linear3.bias) 25 | self.emb = Embedding(args.num_user, args.num_item, args.num_cat, d_model=args.d_model, d_mem=args.d_memory) 26 | if args.task_type != 'inference': 27 | self.agg = Aggregation(d_model=args.d_model * 2, num_head=args.num_head) 28 | self.length = args.length 29 | self.K = args.K 30 | self.drop_rate = args.drop_rate 31 | # self.temp = args.temp 32 | self.device = device 33 | self.prelu1 = nn.PReLU() 34 | self.prelu2 = nn.PReLU() 35 | self.bn = nn.BatchNorm1d(args.d_model*2 + args.d_memory * 2) 36 | self.l = args.l 37 | self.task_type = args.task_type 38 | self.gate_linear1 = nn.Linear(2, 16) 39 | self.gate_linear2 = nn.Linear(16, 1) 40 | self.prelu_gate = nn.PReLU() 41 | self.gate = args.gate 42 | 43 | def forward(self, input, mask_mem = False): 44 | bs = input['user_id'].size(0) 45 | if mask_mem: 46 | x = torch.cat([torch.zeros([bs,1], dtype=int).to(self.device), input['item_seq']], -1) # (B, 1+T) 47 | else: 48 | x = torch.cat([torch.ones([bs,1], dtype=int).to(self.device), input['item_seq']], -1) # (B, 1+T) 49 | mask_raw = (x > 0).unsqueeze(1).unsqueeze(1) # (B, 1, 1, 1+T) 50 | 51 | 52 | res_tmp = self.emb(input) 53 | mem_raw, mem_cat_raw, seq, item, seq_cat, cat, user, dev_seq, dev_seq_cat = \ 54 | res_tmp['mem_item'], res_tmp['mem_cate'], res_tmp['seq_item'], res_tmp['item'], res_tmp['seq_cate'], res_tmp[ 55 | 'cate'], res_tmp['user'], res_tmp['edge_seq_item'], res_tmp['edge_seq_cate'] 56 | ###################################################### 57 | # mem = torch.cat([mem_raw,mem_cat_raw], dim=1) # (B, T, 2D) 58 | mem_all = torch.cat([mem_raw, mem_cat_raw], -1) # (B, 2D) 59 | seq_all = torch.cat([seq, seq_cat], -1) # (B, T, 2D) 60 | item_all = torch.cat([item, cat], -1) # (B, 2D) 61 | 62 | dev_all = torch.cat([dev_seq, dev_seq_cat], -1) # (B, T, 2D) 63 | 64 | ###################################################### 65 | 66 | mask_dev = ((input['edge_item_seq'] > 0) * 1.0) # (B, T) 67 | dev_actual_len = torch.sum(mask_dev, -1, keepdim=True) # (B, 1) 68 | dev_len = torch.where(dev_actual_len > 0, dev_actual_len, torch.ones_like(dev_actual_len)) 69 | dev_all_sum = torch.sum(dev_all * mask_dev.unsqueeze(-1), 1) 70 | dev_all_mean = dev_all_sum / dev_len 71 | 72 | ######################################################## 73 | 74 | if 'seq_len' in input.keys(): # generated by a neural network 75 | seq_len = input['seq_len'] 76 | gate = torch.cat([seq_len.unsqueeze(-1).float(), dev_actual_len], dim=-1) 77 | gate = self.gate_linear2(self.prelu_gate(self.gate_linear1(gate))) 78 | gate = torch.sigmoid(gate) 79 | else: 80 | gate = self.gate # use specific hyperparameter 81 | 82 | scores = [] 83 | ''' 84 | train: task_type=='train', self.training==True 85 | eval: task_type=='train', self.training==False 86 | inference: task_type=='inference', self.training==False 87 | update: task_type=='inference', self.training==True 88 | ''' 89 | if 
self.task_type=='train' and self.training: 90 | mask_meta = torch.ones_like(mask_raw) * (1 - self.drop_rate) 91 | for k in range(self.K): 92 | # random mask 93 | mask = torch.bernoulli(mask_meta) * mask_raw 94 | mem_all, _ = self.agg(mem_all, seq_all, mask, mode='b') 95 | all_feas = torch.cat([gate * mem_all + (1 - gate) * dev_all_mean, item_all], -1) 96 | all_feas = self.bn(all_feas) 97 | x = self.prelu1(self.linear1(all_feas)) 98 | x = self.prelu2(self.linear2(x)) 99 | x = self.linear3(x) 100 | score = F.softmax(x,dim=-1) + 0.00000001 101 | scores.append(score) 102 | 103 | elif not self.training: 104 | if self.task_type=='train': 105 | mem_all, _ = self.agg(mem_all, seq_all, mask_raw, mode='b') 106 | all_feas = torch.cat([gate*mem_all + (1.0-gate)*dev_all_mean, item_all], -1) 107 | all_feas = self.bn(all_feas) 108 | x = self.prelu1(self.linear1(all_feas)) 109 | x = self.prelu2(self.linear2(x)) 110 | x = self.linear3(x) 111 | score = F.softmax(x,dim=-1) + 0.00000001 112 | scores.append(score) 113 | elif self.task_type=='inference' and self.training: 114 | mem_all, _ = self.agg(mem_all, seq_all, mask_raw, mode='b') 115 | d_model = mem_all.size(1) // 2 116 | mem, mem_cat = mem_all[:, :d_model], mem_all[:, d_model:] 117 | # update memory 118 | self.emb.Mem_item.weight.data[torch.LongTensor(input['user_id'].to('cpu'))] = mem 119 | self.emb.Mem_cat.weight.data[torch.LongTensor(input['user_id'].to('cpu'))] = mem_cat 120 | return None, None 121 | 122 | score_avg = sum(scores) / len(scores) 123 | 124 | label = torch.stack([input['label'], 1-input['label']], dim=1) 125 | 126 | 127 | loss_sup = 0.0 # supervised loss 128 | loss_unsup = 0.0 # self-supervised loss 129 | if self.training: 130 | for k in range(self.K): 131 | loss_unsup += F.mse_loss(scores[k], score_avg) 132 | loss_sup += (-torch.mean(torch.log(scores[k]) * label)) 133 | else: 134 | loss_sup += (-torch.mean(torch.log(scores[0]) * label)) 135 | 136 | loss = loss_unsup * self.l + loss_sup 137 | 138 | return loss, score_avg 139 | 140 | class Aggregation(nn.Module): 141 | def __init__(self, d_model, num_head): 142 | super(Aggregation, self).__init__() 143 | self.Q = nn.Linear(d_model, d_model, bias=False) 144 | nn.init.xavier_uniform_(self.Q.weight, gain=1.0) 145 | self.K = nn.Linear(d_model, d_model, bias=False) 146 | nn.init.xavier_uniform_(self.K.weight, gain=1.0) 147 | self.V = nn.Linear(d_model, d_model, bias=False) 148 | nn.init.xavier_uniform_(self.V.weight, gain=1.0) 149 | self.d_h = d_model//num_head 150 | self.num_head = num_head 151 | self.d_model = d_model 152 | # self.ln = nn.LayerNorm(d_model) 153 | 154 | def forward(self, memory, items, mask, mode='agg'): 155 | if mode in ('only', 'aggregation','agg', 'a'): 156 | return self._only_aggregation_multihead(memory, items, mask) 157 | else: 158 | return self._multihead_attention(memory, items, mask) 159 | 160 | def _only_aggregation_multihead(self, memory, items, mask=None, residual=False): 161 | bs = memory.size(0) 162 | seq = torch.cat([memory.unsqueeze(1), items], 1) # (B, 1+T, D) 163 | mem_q = self.Q(memory).view(bs, -1, self.num_head, self.d_h).transpose(1,2) # (B, H, 1, d_h) 164 | seq_k = self.K(seq).view(bs, -1, self.num_head, self.d_h).transpose(1,2) # (B, H, 1+T, d_h) 165 | seq_v = seq.view(bs, -1, self.num_head, self.d_h).transpose(1, 2) # (B, H, 1+T, d_h) 166 | 167 | mem, attn = self._only_aggregation_single(mem_q, seq_k, seq_v, mask) # (B, H, 1, d_h), (B, H, 1, 1+T) 168 | mem = mem.transpose(1,2).contiguous().view(bs, self.d_model) # (B, D) 169 | if residual: 170 | mem 
= torch.add(mem, memory)
171 |         return mem, attn
172 |
173 |     def _only_aggregation_single(self, mem_q, seq_k, seq_v, mask=None, dropout=None):
174 |         scores = torch.matmul(mem_q, seq_k.transpose(-2,-1)) / math.sqrt(self.d_h) # (B, H, 1, 1+T)
175 |         if mask is not None:
176 |             scores = scores.masked_fill(mask==0, -2**32+1)
177 |         p_attn = F.softmax(scores, dim=-1)
178 |         if dropout is not None:
179 |             p_attn = dropout(p_attn)
180 |         return torch.matmul(p_attn, seq_v), p_attn # (B, H, 1, d_h), (B, H, 1, 1+T)
181 |
182 |     def _multihead_attention(self, memory, items, mask=None, residual=True):
183 |         bs = memory.size(0)
184 |         seq = torch.cat([memory.unsqueeze(1), items], 1) # (B, 1+T, D)
185 |         q = self.Q(seq).view(bs, -1, self.num_head, self.d_h).transpose(1,2) # (B, H, 1+T, d_h)
186 |         k = self.K(seq).view(bs, -1, self.num_head, self.d_h).transpose(1,2) # (B, H, 1+T, d_h)
187 |         v = seq.view(bs, -1, self.num_head, self.d_h).transpose(1, 2) # (B, H, 1+T, d_h)
188 |
189 |         seq_new, attn = self._single_attention(q, k, v, mask) # (B, H, 1+T, d_h), (B, H, 1+T, 1+T)
190 |
191 |         seq_new = seq_new.transpose(1,2).contiguous().view(bs, -1, self.d_model) # (B, 1+T, D)
192 |
193 |         mem = seq_new[:,0,:]
194 |
195 |         if residual:
196 |             mem = torch.add(mem, memory)
197 |         return mem, attn
198 |
199 |     def _single_attention(self, q, k, v, mask=None, dropout=None):
200 |         scores = torch.matmul(q, k.transpose(-2,-1)) / math.sqrt(self.d_h) # (B, H, 1+T, 1+T), scaled once by 1/sqrt(d_h)
201 |         if mask is not None:
202 |             scores = scores.masked_fill(mask == 0, -2**32+1)
203 |
204 |         p_attn = F.softmax(scores, dim=-1)
205 |         if dropout is not None:
206 |             p_attn = dropout(p_attn)
207 |         return torch.matmul(p_attn, v), p_attn # (B, H, 1+T, d_h), (B, H, 1+T, 1+T)
208 |
209 | class Embedding(nn.Module):
210 |     def __init__(self, num_user, num_item, num_cat, d_model=32, d_mem=128, mem_init='2'):
211 |         super().__init__()
212 |         if mem_init in ('0', 'zero'):
213 |             self.Mem_item = nn.Embedding.from_pretrained(torch.zeros(num_user, d_mem), freeze=True)
214 |             self.Mem_cat = nn.Embedding.from_pretrained(torch.zeros(num_user, d_mem), freeze=True)
215 |         elif mem_init in ('1', 'one'):
216 |             self.Mem_item = nn.Embedding.from_pretrained(torch.ones(num_user, d_mem) / d_mem, freeze=True)
217 |             self.Mem_cat = nn.Embedding.from_pretrained(torch.ones(num_user, d_mem) / d_mem, freeze=True)
218 |         else:
219 |             self.Mem_item = nn.Embedding(num_user, d_mem)
220 |             nn.init.xavier_uniform_(self.Mem_item.weight, gain=1.0)
221 |             self.Mem_cat = nn.Embedding(num_user, d_mem)
222 |             nn.init.xavier_uniform_(self.Mem_cat.weight, gain=1.0)
223 |
224 |         self.Item = nn.Embedding(num_item, d_model, padding_idx=0)
225 |         nn.init.xavier_uniform_(self.Item.weight, gain=1.0)
226 |         self.Cat = nn.Embedding(num_cat, d_model, padding_idx=0)
227 |         nn.init.xavier_uniform_(self.Cat.weight, gain=1.0)
228 |         self.User = nn.Embedding(num_user, d_model, padding_idx=0)
229 |         nn.init.xavier_uniform_(self.User.weight, gain=1.0)
230 |
231 |     def forward(self, input):
232 |         output = {
233 |             'mem_item': self.Mem_item(input['user_id']),
234 |             'mem_cate': self.Mem_cat(input['user_id']),
235 |             'item': self.Item(input['item_id']),
236 |             'cate': self.Cat(input['cate_id']),
237 |             'user': self.User(input['user_id']),
238 |             'edge_seq_item': self.Item(input['edge_item_seq']),
239 |             'edge_seq_cate': self.Cat(input['edge_cate_seq']),
240 |             'seq_item': self.Item(input['item_seq']),
241 |             'seq_cate': self.Cat(input['cate_seq'])
242 |
243 |         }
244 |         return output
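# --- Editor's sketch (illustrative, not part of the original file): a minimal,
# runnable exercise of the Aggregation module above with mock shapes, assuming
# d_model=64 and num_head=4. The (B, 1, 1, 1+T) mask layout mirrors the one
# built in ECRec.forward, and mode='b' selects the residual multi-head path.
if __name__ == '__main__':
    B, T, d_model, num_head = 2, 5, 64, 4
    agg = Aggregation(d_model=d_model, num_head=num_head)
    memory = torch.randn(B, d_model)       # per-user memory, (B, D)
    items = torch.randn(B, T, d_model)     # behavior-sequence embeddings, (B, T, D)
    mask = torch.ones(B, 1, 1, T + 1)      # attend to the memory slot plus all T items
    mem, attn = agg(memory, items, mask, mode='b')
    print(mem.shape, attn.shape)           # -> torch.Size([2, 64]) torch.Size([2, 4, 6, 6])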
-------------------------------------------------------------------------------- /onnx_test/ggcn_onnx_model_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2022 The Luoxi Team. 3 | # All rights reserved. 4 | # This source code is licensed under the Apache 2.0 license 5 | # found in the LICENSE file in the root directory. 6 | 7 | import torch 8 | from backend.dataset_hub.ggcn_datasets import GGCNDatasetLocal 9 | 10 | def real_cigar_data(): 11 | data_type = {'userid': "int", 12 | 'cms_segid': "int", 13 | 'cms_group_id': "int", 14 | 'final_gender_code': "int", 15 | 'age_level': "int", 16 | 'pvalue_level': "int", 17 | 'shopping_level': "int", 18 | 'occupation': "int", 19 | 'new_user_class_level': "int", 20 | 'adgroup_id': "int", 21 | 'seq_length': "int", 22 | 'item_embedding': "float", 23 | 'mean_embedding': "float", 24 | 'gnn_output': "float", 25 | 'group_score': "float", 26 | 'label': "float", 27 | } 28 | 29 | data = {'userid': [257942, 352928, 296768, 325475, 544880], 30 | 'cms_segid': [8, 79, 1, 1, 1], 31 | 'cms_group_id': [3, 11, 5, 5, 5], 32 | 'final_gender_code': [2, 1, 2, 2, 2], 33 | 'age_level': [3, 5, 5, 5, 5], 34 | 'pvalue_level': [3, 2, 1, 1, 3], 35 | 'shopping_level': [3, 3, 1, 2, 3], 36 | 'occupation': [1, 1, 1, 1, 1], 37 | 'new_user_class_level': [3, 5, 1, 1, 1], 38 | 'adgroup_id': [777487, 9781, 664671, 743740, 844239], 39 | 'seq_length': [1., 0., 28., 2., 3.], 40 | 'item_embedding': [ 41 | [-0.1642, 0.0191, -0.0332, 0.0666, 0.0226, 0.0220, 0.1183, -0.1284, 0.1393, 0.2170, -0.1647, 42 | -0.0316, -0.2213, -0.1487, 0.0263, 0.2481, 0.0220, 0.0649, 0.2545, 0.2207, 0.1615, -0.0936, 43 | -0.0373, 0.2057, -0.0980, -0.0574, 0.3132, 0.0471, -0.1437, 0.1355, -0.0809, -0.0405, -0.0300, 44 | -0.0126, 0.0446, -0.0285, 0.0164, 0.0605, -0.0186, -0.0894], 45 | [-0.6926, 0.0021, -0.2531, 0.1085, 0.0345, 0.2508, 0.2773, -0.3272, 0.0986, 0.2033, 0.0233, 0.0342, 46 | -0.1619, -0.0700, -0.0821, 0.2925, -0.2639, 0.2183, -0.1321, -0.3341, 0.3004, 0.4415, -0.2169, 47 | 0.4756, -0.1051, 0.0158, 0.1106, 0.0990, -0.0461, -0.0206, -0.1337, 0.0419, 0.0617, 0.1032, 0.0272, 48 | -0.0244, 0.1607, 0.0892, -0.0447, -0.0411], 49 | [-0.0928, 0.0016, -0.1122, 0.1218, -0.0621, -0.0516, 0.0155, -0.0864, 0.1379, 0.1051, -0.0403, 50 | -0.0604, -0.1258, -0.2051, 0.0125, 0.0997, 0.0265, 0.0449, 0.0467, -0.0373, -0.0020, 0.0011, 51 | 0.1486, -0.0105, -0.2552, 0.0563, -0.0886, -0.1216, 0.0752, 0.0040, -0.0218, -0.0609, -0.0300, 52 | -0.0126, 0.0446, -0.0285, 0.0164, 0.0605, -0.0186, -0.0894], 53 | [-0.1795, 0.0693, 0.0743, -0.0031, -0.0856, 0.0287, 0.1785, -0.0536, 0.1304, 0.1885, -0.0549, 54 | -0.0532, -0.1828, -0.0833, 0.0737, 0.3236, -0.1278, 0.0281, 0.0973, -0.1154, 0.1054, 0.0590, 55 | 0.1405, 0.0663, -0.1268, 0.0774, -0.0258, -0.0322, -0.0032, -0.1190, -0.1028, 0.0612, -0.0199, 56 | 0.0594, -0.0132, -0.0351, 0.1327, 0.1379, -0.2662, -0.0991], 57 | [-0.0861, 0.0548, 0.0904, 0.0602, -0.0635, 0.0743, 0.1105, 0.0074, 0.0930, 0.1092, 0.0534, 0.0575, 58 | -0.2001, -0.1564, -0.0505, 0.0641, -0.1482, 0.1948, -0.2620, -0.3430, 0.2245, 0.2375, 0.1858, 59 | 0.2586, -0.1011, -0.0609, 0.0232, 0.0289, 0.1379, -0.0487, -0.0158, 0.0918, 0.1813, 0.0669, 60 | -0.0605, -0.1415, -0.0501, 0.0040, 0.0030, -0.0952]], 61 | 'mean_embedding': [ 62 | [4.4465e-02, 1.0445e-01, -2.7625e-01, -2.0569e-01, -6.8474e-02, -5.2394e-02, 1.0456e-01, 1.7022e-01, 63 | 1.3791e-01, 1.0512e-01, -4.0317e-02, -6.0382e-02, -1.2581e-01, -2.0511e-01, 1.2495e-02, 9.9727e-02, 64 | 
-1.2358e-01, 2.7788e-02, -8.0873e-02, -1.1753e-01, 1.0135e-01, 5.6316e-03, -2.0245e-01, 65 | -1.0968e-01, 7.0365e-02, -3.8328e-02, -1.5932e-01, -1.1318e-01, -1.4678e-01, -8.6553e-02, 66 | 6.5388e-02, -1.0718e-01, -1.8357e-02, -1.0338e-01, 8.6135e-02, 3.4204e-02, 7.0397e-02, -2.0644e-02, 67 | 1.0845e-01, 1.4099e-02], 68 | [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 69 | 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 70 | 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 71 | 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 72 | 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], 73 | [4.0462e-02, -8.6966e-02, -1.2895e-01, 6.5544e-02, -3.4426e-02, 1.1975e-01, -1.2570e-01, 2.4530e-02, 74 | 1.4350e-01, 1.1492e-01, -4.2107e-02, -4.2388e-02, -1.4527e-01, -1.9226e-01, 4.5437e-03, 1.1377e-01, 75 | -1.1355e-02, -1.8809e-02, 4.4197e-02, 8.3866e-02, -6.0360e-02, -3.1770e-02, -2.8437e-02, 76 | 1.8450e-01, -1.8495e-01, -8.3106e-02, -3.4705e-04, 7.0822e-02, -9.0474e-02, 8.2293e-02, 77 | -6.0735e-02, 1.1667e-01, 2.2143e-02, -1.3329e-02, -2.7356e-02, -7.3730e-02, 8.0943e-02, 9.9228e-02, 78 | -7.0119e-02, -1.0540e-01], 79 | [2.7340e-02, 1.1924e-01, -5.9607e-02, 1.5895e-01, 2.1646e-02, -7.5258e-02, -7.0717e-02, 7.2195e-03, 80 | 1.5034e-01, 1.6499e-01, -4.7621e-02, -1.4737e-02, -1.7842e-01, -1.0894e-01, 1.5183e-02, 2.2416e-01, 81 | -9.1066e-02, 1.4574e-01, 6.9024e-02, -5.4564e-02, 1.3869e-01, 7.7455e-02, -1.3641e-01, 1.2936e-01, 82 | -1.0755e-01, 4.2002e-02, 6.5072e-02, 1.4974e-02, -7.0939e-03, -8.7588e-02, -5.9885e-02, 4.7056e-02, 83 | 1.7133e-02, -2.1070e-02, -2.4853e-02, -4.2562e-02, 6.3640e-02, 1.0982e-01, -6.8471e-02, 84 | -8.8208e-02], 85 | [-1.2780e-01, 5.0805e-02, -1.1612e-01, 7.9578e-02, 1.0335e-01, 1.4314e-01, 3.3847e-03, -1.7688e-01, 86 | 1.3929e-01, 1.6951e-01, -6.5954e-02, -4.5142e-02, -1.4295e-01, -1.6744e-01, -2.4190e-02, 87 | 1.8433e-01, -3.6416e-02, 2.4685e-01, -3.6875e-02, -5.3090e-02, 8.7769e-02, 1.5820e-01, -1.7054e-01, 88 | 3.4149e-01, -3.6339e-01, -5.0575e-02, 4.6724e-02, 1.6422e-01, -1.8536e-01, -3.9412e-02, 89 | -2.6019e-01, 1.9385e-01, -2.7335e-02, -1.8533e-02, -1.7009e-02, -3.2307e-02, 2.7097e-03, 90 | 3.9065e-02, -8.3710e-03, -5.8574e-02]], 91 | 'gnn_output': [ 92 | [-0.0719, -0.1300, -0.0700, 0.1248, 0.0943, -0.1491, -0.1248, -0.0546, 0.0144, 0.1653, -0.1943, 93 | -0.0897, 0.1093, -0.2755, 0.2900, 0.0476, 0.3434, -0.1982, -0.2286, 0.1124, -0.3105, -0.3145, 94 | 0.0516, -0.0458, 0.0021, 0.0970, 0.1089, -0.0205, 0.0430, 0.0854, 0.2716, -0.0453, 0.2313, -0.0869, 95 | 0.3106, 0.0992, -0.0270, 0.0260, 0.0730, -0.2721], 96 | [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 97 | 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 98 | 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 99 | 0.0000, 0.0000, 0.0000, 0.0000], 100 | [0.0357, -0.2169, -0.2358, -0.1980, -0.0763, -0.0323, 0.0960, -0.3070, 0.0956, 0.0345, -0.0965, 101 | -0.0754, -0.0028, -0.2407, 0.1383, 0.1988, 0.1502, 0.0690, -0.2667, -0.0154, -0.2668, -0.2654, 102 | -0.1630, 0.1116, 0.0551, -0.0283, -0.1580, -0.0829, 0.0186, 0.1518, 0.3620, 0.1359, 0.3065, 103 | -0.0371, 0.3105, -0.0218, -0.0081, -0.1666, 0.2048, -0.0792], 104 | [0.0136, 0.0489, -0.2396, 
0.0582, 0.0072, 0.0309, 0.0462, -0.1681, -0.0556, 0.2680, 0.0539, -0.0192, 105 | -0.0064, -0.3069, 0.2251, -0.0020, 0.3111, 0.1067, -0.1319, 0.1522, -0.2446, -0.3086, -0.0063, 106 | -0.0076, 0.0004, -0.0174, -0.0301, -0.0181, 0.0073, 0.0129, 0.2869, 0.1295, 0.1268, 0.0299, 0.3440, 107 | -0.0242, -0.1371, 0.0263, -0.0792, -0.2131], 108 | [0.0159, 0.0400, -0.2300, 0.1695, 0.1073, 0.0497, 0.0163, -0.1813, -0.0745, 0.3303, 0.0953, -0.0304, 109 | -0.0085, -0.2792, 0.2309, 0.0410, 0.3900, 0.1206, -0.1116, 0.1467, -0.2310, -0.2889, -0.1278, 110 | -0.0606, -0.0874, 0.0693, 0.0017, 0.0307, 0.0867, -0.0641, 0.3025, 0.0540, 0.2110, 0.0215, 0.3501, 111 | 0.1108, -0.1022, 0.0526, -0.0510, -0.1610]], 112 | 'group_score': [0.0406, 0.0049, 0.0396, 0.0399, 0.0755], 113 | 'label': [0., 0., 0., 0., 1.]} 114 | re_data = {} 115 | for k, v in data.items(): 116 | if data_type.get(k) == "int": 117 | re_data[k] = torch.LongTensor(v) 118 | else: 119 | re_data[k] = torch.FloatTensor(v) 120 | return data_type, re_data 121 | 122 | def mock_data(num_samples=1,data_names=None,data_types=None,data_sizes=None,data_nums=None): 123 | ''' 124 | num_samples: batch_size 125 | data_names: 126 | data_types: int or float 127 | data_sizes: the size of one sample 128 | data_nums: dict, the max num of k 129 | ''' 130 | mock_data={} 131 | data_names = data_names.split(",") 132 | data_types = data_types.split(",") 133 | data_sizes = data_sizes.split(",") 134 | assert len(data_names)==len(data_types)==len(data_sizes) 135 | for name,type,size in zip(data_names,data_types,data_sizes): 136 | if type=="int": 137 | high_num = data_nums.get(name,2) if data_nums else 2 138 | v=torch.randint(low=1,high=high_num,size=[num_samples],dtype=torch.int64) 139 | else: 140 | if name=="label": 141 | v=torch.ones([num_samples],dtype=torch.float32,requires_grad=True) 142 | else: 143 | v = torch.randn([num_samples,int(size)],dtype=torch.float32, requires_grad=True) 144 | mock_data[name]=v 145 | return mock_data 146 | 147 | def mock_gnn_dataTable(args): 148 | data_names = "userid,cms_segid,cms_group_id,final_gender_code,age_level,pvalue_level,shopping_level,occupation,new_user_class_level,adgroup_id,seq_length,item_embedding,mean_embedding,gnn_output,group_score,label" 149 | data_types = "int,int,int,int,int,int,int,int,int,int,int,float,float,float,float,float" 150 | data_sizes = "1,1,1,1,1,1,1,1,1,1,1,40,40,40,1,1" 151 | data_nums = {"user_id": 1150000, 152 | "cms_segid": 100, 153 | "cms_group_id": 15, 154 | "final_gender_code": 5, 155 | "age_level": 10, 156 | "pvalue_level": 5, 157 | "shopping_level": 5, 158 | "occupation": 5, 159 | "new_user_class_level": 10, 160 | "adgroup_id": 850000} 161 | data = mock_data(2, data_names, data_types, data_sizes, data_nums) 162 | return data 163 | 164 | def mock_no_gnn_data(args): 165 | data_names = "userid,cms_segid,cms_group_id,final_gender_code,age_level,pvalue_level,shopping_level,occupation,new_user_class_level,adgroup_id,seq_length,item_embedding,mean_embedding,group_score,label" 166 | data_types = "int,int,int,int,int,int,int,int,int,int,int,float,float,float,float" 167 | data_sizes = "1,1,1,1,1,1,1,1,1,1,1,40,40,1,1" 168 | data_nums = {"user_id": 1150000, 169 | "cms_segid": 100, 170 | "cms_group_id": 15, 171 | "final_gender_code": 5, 172 | "age_level": 10, 173 | "pvalue_level": 5, 174 | "shopping_level": 5, 175 | "occupation": 5, 176 | "new_user_class_level": 10, 177 | "adgroup_id": 850000} 178 | data = mock_data(2, data_names, data_types, data_sizes, data_nums) 179 | return data 180 | 181 | def 
mock_gnn_data(args): 182 | input_tables = args.tables.split(",") 183 | dataset = GGCNDatasetLocal(args, input_tables[1], traintype='eval') 184 | data = next(iter(dataset)) 185 | 186 | 187 | return data 188 | 189 | def onnx_test(mock_data_func, label_col_name='label', onnx_export_path=None, onnx_model_name=None): 190 | ''' 191 | func: run the exported onnx model on mocked inputs and print its outputs 192 | ********** 193 | Note that the order of the keys in the returned dict from mock_data_func must be the same as the order of the parameters in the forward method of your model 194 | ********** 195 | mock_data_func: the same data function that was used for the onnx export 196 | onnx_export_path: directory containing the exported onnx model 197 | onnx_model_name: onnx model file name (defaults to "model_00.onnx") 198 | ''' 199 | 200 | def to_numpy(tensor, data_type=None): 201 | tensor = tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() 202 | if data_type: 203 | if data_type == "int": 204 | data_type = np.int64 205 | else: 206 | data_type = np.float32 207 | tensor = tensor.astype(data_type) 208 | return tensor 209 | 210 | data = mock_data_func() 211 | assert isinstance(data, OrderedDict), "returned data from mock_data_func must be an OrderedDict!" 212 | 213 | data = dict((k, to_numpy(v)) for k, v in data.items()) 214 | 215 | y = data[label_col_name].astype(np.float32) 216 | input_data = [] 217 | for v in data.values(): 218 | input_data.append(v) 219 | input_data.append(y) 220 | onnx_model_name = onnx_model_name if onnx_model_name else "model_00.onnx" 221 | model_path = os.path.join(onnx_export_path, onnx_model_name) 222 | ort_session = onnxruntime.InferenceSession(model_path) 223 | ort_inputs = {} 224 | for s_input in ort_session.get_inputs(): 225 | k = str(s_input.name).split(".")[-1]  # exported input names are assumed to end with ".<positional index>" 226 | in_da = input_data[int(k)] 227 | ort_inputs[s_input.name] = in_da 228 | ort_outs = ort_session.run(None, ort_inputs) 229 | print(ort_outs) 230 | if __name__ == '__main__': 231 | onnx_test(mock_gnn_data)  # NOTE: mock_gnn_data expects parsed args, and onnx_export_path must be supplied before this can run 232 | 233 | -------------------------------------------------------------------------------- /onnx_test/cigar_onnx_model_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2022 The Luoxi Team. 3 | # All rights reserved. 4 | # This source code is licensed under the Apache 2.0 license 5 | # found in the LICENSE file in the root directory.
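# Usage sketch (illustrative only; the export directory below is a placeholder, not a real default).
# The checks in this file assume a CIGAR model has already been exported with
# --task-type=onnx_export; "model_00.onnx" is the fallback file name used by onnx_test():
#
#   from onnx_test.cigar_onnx_model_test import onnx_test, real_no_gnn_data
#   onnx_test(real_no_gnn_data,
#             onnx_export_path="<your_onnx_export_dir>",
#             onnx_model_name="model_00.onnx")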
6 | 7 | import numpy as np 8 | import torch 9 | import onnxruntime 10 | from collections import OrderedDict 11 | import os 12 | 13 | def real_data(data_names=None,data_types=None): 14 | data_types = data_types.split(",") 15 | data_names = data_names.split(",") 16 | data_type = OrderedDict() 17 | for name, type in zip(data_names,data_types): 18 | data_type[name]=type 19 | # data_type.update({'userid': "int", 20 | # 'cms_segid': "int", 21 | # 'cms_group_id': "int", 22 | # 'final_gender_code': "int", 23 | # 'age_level': "int", 24 | # 'pvalue_level': "int", 25 | # 'shopping_level': "int", 26 | # 'occupation': "int", 27 | # 'new_user_class_level': "int", 28 | # 'adgroup_id': "int", 29 | # 'seq_length': "int", 30 | # 'item_embedding': "float", 31 | # 'mean_embedding': "float", 32 | # 'gnn_output': "float", 33 | # 'group_score': "float", 34 | # 'label': "float", 35 | # }) 36 | 37 | data = OrderedDict() 38 | data.update({'userid': [257942, 352928, 296768, 325475, 544880], 39 | 'cms_segid': [8, 79, 1, 1, 1], 40 | 'cms_group_id': [3, 11, 5, 5, 5], 41 | 'final_gender_code': [2, 1, 2, 2, 2], 42 | 'age_level': [3, 5, 5, 5, 5], 43 | 'pvalue_level': [3, 2, 1, 1, 3], 44 | 'shopping_level': [3, 3, 1, 2, 3], 45 | 'occupation': [1, 1, 1, 1, 1], 46 | 'new_user_class_level': [3, 5, 1, 1, 1], 47 | 'adgroup_id': [777487, 9781, 664671, 743740, 844239], 48 | 'seq_length': [1., 0., 28., 2., 3.], 49 | 'item_embedding': [ 50 | [-0.1642, 0.0191, -0.0332, 0.0666, 0.0226, 0.0220, 0.1183, -0.1284, 0.1393, 0.2170, -0.1647, 51 | -0.0316, -0.2213, -0.1487, 0.0263, 0.2481, 0.0220, 0.0649, 0.2545, 0.2207, 0.1615, -0.0936, 52 | -0.0373, 0.2057, -0.0980, -0.0574, 0.3132, 0.0471, -0.1437, 0.1355, -0.0809, -0.0405, -0.0300, 53 | -0.0126, 0.0446, -0.0285, 0.0164, 0.0605, -0.0186, -0.0894], 54 | [-0.6926, 0.0021, -0.2531, 0.1085, 0.0345, 0.2508, 0.2773, -0.3272, 0.0986, 0.2033, 0.0233, 0.0342, 55 | -0.1619, -0.0700, -0.0821, 0.2925, -0.2639, 0.2183, -0.1321, -0.3341, 0.3004, 0.4415, -0.2169, 56 | 0.4756, -0.1051, 0.0158, 0.1106, 0.0990, -0.0461, -0.0206, -0.1337, 0.0419, 0.0617, 0.1032, 0.0272, 57 | -0.0244, 0.1607, 0.0892, -0.0447, -0.0411], 58 | [-0.0928, 0.0016, -0.1122, 0.1218, -0.0621, -0.0516, 0.0155, -0.0864, 0.1379, 0.1051, -0.0403, 59 | -0.0604, -0.1258, -0.2051, 0.0125, 0.0997, 0.0265, 0.0449, 0.0467, -0.0373, -0.0020, 0.0011, 60 | 0.1486, -0.0105, -0.2552, 0.0563, -0.0886, -0.1216, 0.0752, 0.0040, -0.0218, -0.0609, -0.0300, 61 | -0.0126, 0.0446, -0.0285, 0.0164, 0.0605, -0.0186, -0.0894], 62 | [-0.1795, 0.0693, 0.0743, -0.0031, -0.0856, 0.0287, 0.1785, -0.0536, 0.1304, 0.1885, -0.0549, 63 | -0.0532, -0.1828, -0.0833, 0.0737, 0.3236, -0.1278, 0.0281, 0.0973, -0.1154, 0.1054, 0.0590, 64 | 0.1405, 0.0663, -0.1268, 0.0774, -0.0258, -0.0322, -0.0032, -0.1190, -0.1028, 0.0612, -0.0199, 65 | 0.0594, -0.0132, -0.0351, 0.1327, 0.1379, -0.2662, -0.0991], 66 | [-0.0861, 0.0548, 0.0904, 0.0602, -0.0635, 0.0743, 0.1105, 0.0074, 0.0930, 0.1092, 0.0534, 0.0575, 67 | -0.2001, -0.1564, -0.0505, 0.0641, -0.1482, 0.1948, -0.2620, -0.3430, 0.2245, 0.2375, 0.1858, 68 | 0.2586, -0.1011, -0.0609, 0.0232, 0.0289, 0.1379, -0.0487, -0.0158, 0.0918, 0.1813, 0.0669, 69 | -0.0605, -0.1415, -0.0501, 0.0040, 0.0030, -0.0952]], 70 | 'mean_embedding': [ 71 | [4.4465e-02, 1.0445e-01, -2.7625e-01, -2.0569e-01, -6.8474e-02, -5.2394e-02, 1.0456e-01, 1.7022e-01, 72 | 1.3791e-01, 1.0512e-01, -4.0317e-02, -6.0382e-02, -1.2581e-01, -2.0511e-01, 1.2495e-02, 9.9727e-02, 73 | -1.2358e-01, 2.7788e-02, -8.0873e-02, -1.1753e-01, 1.0135e-01, 5.6316e-03, -2.0245e-01, 
74 | -1.0968e-01, 7.0365e-02, -3.8328e-02, -1.5932e-01, -1.1318e-01, -1.4678e-01, -8.6553e-02, 75 | 6.5388e-02, -1.0718e-01, -1.8357e-02, -1.0338e-01, 8.6135e-02, 3.4204e-02, 7.0397e-02, -2.0644e-02, 76 | 1.0845e-01, 1.4099e-02], 77 | [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 78 | 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 79 | 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 80 | 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 81 | 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], 82 | [4.0462e-02, -8.6966e-02, -1.2895e-01, 6.5544e-02, -3.4426e-02, 1.1975e-01, -1.2570e-01, 2.4530e-02, 83 | 1.4350e-01, 1.1492e-01, -4.2107e-02, -4.2388e-02, -1.4527e-01, -1.9226e-01, 4.5437e-03, 1.1377e-01, 84 | -1.1355e-02, -1.8809e-02, 4.4197e-02, 8.3866e-02, -6.0360e-02, -3.1770e-02, -2.8437e-02, 85 | 1.8450e-01, -1.8495e-01, -8.3106e-02, -3.4705e-04, 7.0822e-02, -9.0474e-02, 8.2293e-02, 86 | -6.0735e-02, 1.1667e-01, 2.2143e-02, -1.3329e-02, -2.7356e-02, -7.3730e-02, 8.0943e-02, 9.9228e-02, 87 | -7.0119e-02, -1.0540e-01], 88 | [2.7340e-02, 1.1924e-01, -5.9607e-02, 1.5895e-01, 2.1646e-02, -7.5258e-02, -7.0717e-02, 7.2195e-03, 89 | 1.5034e-01, 1.6499e-01, -4.7621e-02, -1.4737e-02, -1.7842e-01, -1.0894e-01, 1.5183e-02, 2.2416e-01, 90 | -9.1066e-02, 1.4574e-01, 6.9024e-02, -5.4564e-02, 1.3869e-01, 7.7455e-02, -1.3641e-01, 1.2936e-01, 91 | -1.0755e-01, 4.2002e-02, 6.5072e-02, 1.4974e-02, -7.0939e-03, -8.7588e-02, -5.9885e-02, 4.7056e-02, 92 | 1.7133e-02, -2.1070e-02, -2.4853e-02, -4.2562e-02, 6.3640e-02, 1.0982e-01, -6.8471e-02, 93 | -8.8208e-02], 94 | [-1.2780e-01, 5.0805e-02, -1.1612e-01, 7.9578e-02, 1.0335e-01, 1.4314e-01, 3.3847e-03, -1.7688e-01, 95 | 1.3929e-01, 1.6951e-01, -6.5954e-02, -4.5142e-02, -1.4295e-01, -1.6744e-01, -2.4190e-02, 96 | 1.8433e-01, -3.6416e-02, 2.4685e-01, -3.6875e-02, -5.3090e-02, 8.7769e-02, 1.5820e-01, -1.7054e-01, 97 | 3.4149e-01, -3.6339e-01, -5.0575e-02, 4.6724e-02, 1.6422e-01, -1.8536e-01, -3.9412e-02, 98 | -2.6019e-01, 1.9385e-01, -2.7335e-02, -1.8533e-02, -1.7009e-02, -3.2307e-02, 2.7097e-03, 99 | 3.9065e-02, -8.3710e-03, -5.8574e-02]], 100 | 'gnn_output': [ 101 | [-0.0719, -0.1300, -0.0700, 0.1248, 0.0943, -0.1491, -0.1248, -0.0546, 0.0144, 0.1653, -0.1943, 102 | -0.0897, 0.1093, -0.2755, 0.2900, 0.0476, 0.3434, -0.1982, -0.2286, 0.1124, -0.3105, -0.3145, 103 | 0.0516, -0.0458, 0.0021, 0.0970, 0.1089, -0.0205, 0.0430, 0.0854, 0.2716, -0.0453, 0.2313, -0.0869, 104 | 0.3106, 0.0992, -0.0270, 0.0260, 0.0730, -0.2721], 105 | [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 106 | 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 107 | 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 108 | 0.0000, 0.0000, 0.0000, 0.0000], 109 | [0.0357, -0.2169, -0.2358, -0.1980, -0.0763, -0.0323, 0.0960, -0.3070, 0.0956, 0.0345, -0.0965, 110 | -0.0754, -0.0028, -0.2407, 0.1383, 0.1988, 0.1502, 0.0690, -0.2667, -0.0154, -0.2668, -0.2654, 111 | -0.1630, 0.1116, 0.0551, -0.0283, -0.1580, -0.0829, 0.0186, 0.1518, 0.3620, 0.1359, 0.3065, 112 | -0.0371, 0.3105, -0.0218, -0.0081, -0.1666, 0.2048, -0.0792], 113 | [0.0136, 0.0489, -0.2396, 0.0582, 0.0072, 0.0309, 0.0462, -0.1681, -0.0556, 0.2680, 0.0539, -0.0192, 114 | 
-0.0064, -0.3069, 0.2251, -0.0020, 0.3111, 0.1067, -0.1319, 0.1522, -0.2446, -0.3086, -0.0063, 115 | -0.0076, 0.0004, -0.0174, -0.0301, -0.0181, 0.0073, 0.0129, 0.2869, 0.1295, 0.1268, 0.0299, 0.3440, 116 | -0.0242, -0.1371, 0.0263, -0.0792, -0.2131], 117 | [0.0159, 0.0400, -0.2300, 0.1695, 0.1073, 0.0497, 0.0163, -0.1813, -0.0745, 0.3303, 0.0953, -0.0304, 118 | -0.0085, -0.2792, 0.2309, 0.0410, 0.3900, 0.1206, -0.1116, 0.1467, -0.2310, -0.2889, -0.1278, 119 | -0.0606, -0.0874, 0.0693, 0.0017, 0.0307, 0.0867, -0.0641, 0.3025, 0.0540, 0.2110, 0.0215, 0.3501, 120 | 0.1108, -0.1022, 0.0526, -0.0510, -0.1610]], 121 | 'group_score': [0.0406, 0.0049, 0.0396, 0.0399, 0.0755], 122 | 'label': [0., 0., 0., 0., 1.]}) 123 | 124 | re_data = OrderedDict() 125 | for k, v in data.items(): 126 | if data_type.get(k) == "int": 127 | re_data[k] = torch.LongTensor(v) 128 | elif data_type.get(k) == "float": 129 | re_data[k] = torch.FloatTensor(v) 130 | else: 131 | continue 132 | return re_data 133 | 134 | def mock_data(num_samples=1, data_names=None, data_types=None, data_sizes=None, data_nums=None): 135 | ''' 136 | num_samples: batch size 137 | data_names: comma-separated feature names 138 | data_types: comma-separated per-feature types, each int or float 139 | data_sizes: the per-sample size of each feature 140 | data_nums: dict mapping a feature name to its maximum id (vocabulary size) 141 | ''' 142 | mock_data = OrderedDict() 143 | data_names = data_names.split(",") 144 | data_types = data_types.split(",") 145 | data_sizes = data_sizes.split(",") 146 | assert len(data_names) == len(data_types) == len(data_sizes) 147 | for name, dtype, size in zip(data_names, data_types, data_sizes): 148 | if dtype == "int": 149 | high_num = data_nums.get(name, 2) if data_nums else 2 150 | v = torch.randint(low=1, high=high_num, size=[num_samples], dtype=torch.int64) 151 | else: 152 | if name == "label" or "score" in name: 153 | v = torch.ones([num_samples], dtype=torch.float32, requires_grad=True) 154 | else: 155 | v = torch.randn([num_samples, int(size)], dtype=torch.float32, requires_grad=True) 156 | mock_data[name] = v 157 | return mock_data 158 | 159 | def mock_gnn_data(): 160 | data_names = "userid,cms_segid,cms_group_id,final_gender_code,age_level,pvalue_level,shopping_level,occupation,new_user_class_level,adgroup_id,seq_length,item_embedding,mean_embedding,gnn_output,group_score,label" 161 | data_types = "int,int,int,int,int,int,int,int,int,int,int,float,float,float,float,float" 162 | data_sizes = "1,1,1,1,1,1,1,1,1,1,1,40,40,40,1,1" 163 | data_nums = {"userid": 1150000,  # key must match the "userid" entry in data_names; a mismatched key silently falls back to the default high of 2 164 | "cms_segid": 100, 165 | "cms_group_id": 15, 166 | "final_gender_code": 5, 167 | "age_level": 10, 168 | "pvalue_level": 5, 169 | "shopping_level": 5, 170 | "occupation": 5, 171 | "new_user_class_level": 10, 172 | "adgroup_id": 850000} 173 | data = mock_data(5, data_names, data_types, data_sizes, data_nums) 174 | return data 175 | 176 | def mock_no_gnn_data(): 177 | data_names = "userid,cms_segid,cms_group_id,final_gender_code,age_level,pvalue_level,shopping_level,occupation,new_user_class_level,adgroup_id,seq_length,item_embedding,mean_embedding,group_score,label" 178 | data_types = "int,int,int,int,int,int,int,int,int,int,int,float,float,float,float" 179 | data_sizes = "1,1,1,1,1,1,1,1,1,1,1,40,40,1,1" 180 | data_nums = {"userid": 1150000, 181 | "cms_segid": 100, 182 | "cms_group_id": 15, 183 | "final_gender_code": 5, 184 | "age_level": 10, 185 | "pvalue_level": 5, 186 | "shopping_level": 5, 187 | "occupation": 5, 188 | "new_user_class_level": 10, 189 | "adgroup_id": 850000} 190 | data = mock_data(5, data_names, data_types, data_sizes, data_nums) 191 | return data 192 | 193 | def real_gnn_data():
194 | data_names = "userid,cms_segid,cms_group_id,final_gender_code,age_level,pvalue_level,shopping_level,occupation,new_user_class_level,adgroup_id,seq_length,item_embedding,mean_embedding,gnn_output,group_score,label" 195 | data_types = "int,int,int,int,int,int,int,int,int,int,int,float,float,float,float,float" 196 | data = real_data(data_names, data_types) 197 | return data 198 | 199 | def real_no_gnn_data(): 200 | data_names = "userid,cms_segid,cms_group_id,final_gender_code,age_level,pvalue_level,shopping_level,occupation,new_user_class_level,adgroup_id,seq_length,item_embedding,mean_embedding,group_score,label" 201 | data_types = "int,int,int,int,int,int,int,int,int,int,int,float,float,float,float" 202 | data = real_data(data_names, data_types) 203 | return data 204 | 205 | def onnx_test(mock_data_func, label_col_name='label', onnx_export_path=None, onnx_model_name=None): 206 | ''' 207 | func: run the exported onnx model on real or mocked inputs and print its outputs 208 | ********** 209 | Note that the order of the keys in the returned dict from mock_data_func must be the same as the order of the parameters in the forward method of your model 210 | ********** 211 | mock_data_func: the same data function that was used for the onnx export 212 | onnx_export_path: directory containing the exported onnx model 213 | onnx_model_name: onnx model file name (defaults to "model_00.onnx") 214 | ''' 215 | 216 | def to_numpy(tensor, data_type=None): 217 | tensor = tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() 218 | if data_type: 219 | if data_type == "int": 220 | data_type = np.int64 221 | else: 222 | data_type = np.float32 223 | tensor = tensor.astype(data_type) 224 | return tensor 225 | 226 | data = mock_data_func() 227 | assert isinstance(data, OrderedDict), "returned data from mock_data_func must be an OrderedDict!" 228 | 229 | data = dict((k, to_numpy(v)) for k, v in data.items()) 230 | 231 | y = data[label_col_name].astype(np.float32) 232 | input_data = [] 233 | for v in data.values(): 234 | input_data.append(v) 235 | input_data.append(y) 236 | onnx_model_name = onnx_model_name if onnx_model_name else "model_00.onnx" 237 | model_path = os.path.join(onnx_export_path, onnx_model_name) 238 | # model_path = "./output/model_00.onnx" 239 | ort_session = onnxruntime.InferenceSession(model_path) 240 | ort_inputs = {} 241 | for s_input in ort_session.get_inputs(): 242 | k = str(s_input.name).split(".")[-1]  # exported input names are assumed to end with ".<positional index>" 243 | in_da = input_data[int(k)] 244 | ort_inputs[s_input.name] = in_da 245 | ort_outs = ort_session.run(None, ort_inputs) 246 | print(ort_outs) 247 | 248 | if __name__ == '__main__': 249 | onnx_test(real_no_gnn_data, onnx_export_path="/mnt2/yyang/cigar_onnx/save", onnx_model_name="onnx_PNN_2022031511.onnx") -------------------------------------------------------------------------------- /backend/task_backbone.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2022 The Luoxi Team. 3 | # All rights reserved. 4 | # This source code is licensed under the Apache 2.0 license 5 | # found in the LICENSE file in the root directory.
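# Usage sketch (illustrative; the my_* provider names are hypothetical, not part of this file).
# A train_hub entry point typically wires its providers into task_dispatcher(), defined
# toward the end of this file:
#
#   task_dispatcher(train_eval_dataset_provider=my_train_eval_datasets,  # args -> (train_dataset, eval_dataset)
#                   inference_dataset_provider=my_inference_dataset,     # args -> inference_dataset
#                   model_provider=my_model_builder,                     # args -> torch.nn.Module
#                   forward_func=my_forward_step)                        # (data_iterator, model, args) -> (loss, stats)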
6 | 7 | import torch 8 | from .arguments import get_args 9 | import os 10 | import time 11 | from backend.utils import print_args, load_model_state_only, initialize_distribution_env, set_random_seed, Timer 12 | from backend.utils import save_checkpoint, print_rank_0, make_local_writer 13 | from torch.nn.parallel.distributed import DistributedDataParallel 14 | from torch import nn 15 | from datetime import datetime 16 | from torch.utils.data.distributed import DistributedSampler 17 | 18 | def make_inference_data_loader(args, inference_dataset_provider): 19 | print('make inference data loaders') 20 | 21 | num_workers = torch.distributed.get_world_size() 22 | 23 | test_data_set = inference_dataset_provider(args) 24 | 25 | total_samples = test_data_set.get_total_row_count() 26 | samples_per_iter = args.batch_size * num_workers 27 | args.infer_iters = (total_samples + samples_per_iter - 1) // samples_per_iter 28 | print('[inference]> samples=%d workers=%d batch=%d samples_per_iter=%d infer_iters=%d' % ( 29 | total_samples, num_workers, args.batch_size, samples_per_iter, args.infer_iters)) 30 | 31 | batch_sampler = DistributedSampler(dataset=test_data_set, shuffle=False)  # keep sample order deterministic when writing predictions 32 | infer_data_loader = torch.utils.data.DataLoader(test_data_set, 33 | batch_size=args.batch_size, 34 | sampler=batch_sampler, 35 | drop_last=False) 36 | 37 | return infer_data_loader 38 | 39 | def make_data_loader(args, train_eval_data_provider): 40 | print('make data loaders') 41 | 42 | num_workers = torch.distributed.get_world_size() 43 | 44 | train_data_set, eval_data_set = train_eval_data_provider(args) 45 | 46 | if args.num_epochs > 0: 47 | 48 | total_samples = train_data_set.get_total_row_count() 49 | samples_per_iter = args.batch_size * num_workers 50 | args.train_iters = (args.num_epochs * total_samples + samples_per_iter - 1) // samples_per_iter 51 | 52 | if eval_data_set is not None: 53 | eval_samples = eval_data_set.get_total_row_count() 54 | args.eval_iters = (eval_samples + args.batch_size - 1) // args.batch_size 55 | else: 56 | args.eval_iters = 0 57 | 58 | print('train_iters set to', args.train_iters, 'based on args.num_epochs =', args.num_epochs) 59 | print('***num_epochs=%d samples=%d workers=%d batch=%d samples_per_iter=%d train_iters=%d eval_iters=%d' % ( 60 | args.num_epochs, total_samples, num_workers, args.batch_size, samples_per_iter, args.train_iters, args.eval_iters)) 61 | 62 | batch_sampler = DistributedSampler(dataset=train_data_set) 63 | train_data_loader = torch.utils.data.DataLoader(train_data_set, 64 | batch_size=args.batch_size, 65 | sampler=batch_sampler, 66 | drop_last=True) 67 | 68 | if eval_data_set is not None: 69 | eval_data_loader = torch.utils.data.DataLoader(eval_data_set, 70 | num_workers=0, 71 | pin_memory=True, 72 | batch_size=args.batch_size) 73 | else: 74 | eval_data_loader = None 75 | 76 | return train_data_loader, eval_data_loader 77 | 78 | def get_model(args, 79 | model_provider, 80 | is_gpu=True): 81 | 82 | model = model_provider(args) 83 | 84 | print_rank_0('number of parameters on rank : {}'.format( 85 | sum([p.nelement() for p in model.parameters()]))) 86 | 87 | if is_gpu: 88 | model.cuda(torch.cuda.current_device()) 89 | 90 | # DistributedDataParallel is already imported at module level 91 | i = torch.cuda.current_device() 92 | model = DistributedDataParallel(model, device_ids=[i], output_device=i, find_unused_parameters=args.find_unused_parameters) 93 | else: 94 | raise NotImplementedError 95 | 96 | return model 97 | 98 | def get_params_for_weight_decay(module): 99 | weight_decay_params = {'params':
[]} 100 | no_weight_decay_params = {'params': [], 'weight_decay': 0.0} 101 | for module_ in module.modules(): 102 | if isinstance(module_, (torch.nn.LayerNorm)): 103 | no_weight_decay_params['params'].extend( 104 | [p for p in list(module_._parameters.values()) 105 | if p is not None]) 106 | else: 107 | weight_decay_params['params'].extend( 108 | [p for n, p in list(module_._parameters.items()) 109 | if p is not None and n != 'bias']) 110 | no_weight_decay_params['params'].extend( 111 | [p for n, p in list(module_._parameters.items()) 112 | if p is not None and n == 'bias']) 113 | 114 | return weight_decay_params, no_weight_decay_params 115 | 116 | def get_optimizer_param_groups(model): 117 | while isinstance(model, (DistributedDataParallel)): 118 | model = model.module 119 | param_groups = get_params_for_weight_decay(model) 120 | 121 | return param_groups 122 | 123 | def get_optimizer(param_groups, args): 124 | assert args.optimizer in ('adam', 'adamw'), 'optimizer must be adam or adamw!' 125 | 126 | if args.optimizer == 'adamw': 127 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay) 128 | else: 129 | optimizer = torch.optim.Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay) 130 | 131 | args.optimizer_runtime = optimizer 132 | 133 | return optimizer 134 | 135 | def setup_model_and_optimizer(args, model): 136 | param_groups = get_optimizer_param_groups(model) 137 | optimizer = get_optimizer(param_groups, args) 138 | 139 | return optimizer 140 | 141 | 142 | def load_model_setup_optimizer(args, model, need_optimizer = True): 143 | optimizer = None 144 | if args.load is not None: 145 | local_device_id = torch.distributed.get_rank() % args.gpus_per_node 146 | if local_device_id == 0: 147 | os.makedirs('tmp/') 148 | else: 149 | while not os.path.exists('tmp/done_{}.txt'.format(local_device_id - 1)): 150 | time.sleep(1) 151 | 152 | load_model_state_only(model, args, remove_prefix=None, remap_prefix=None) 153 | 154 | if need_optimizer: 155 | optimizer = setup_model_and_optimizer(args, model=model) 156 | 157 | fout = open('tmp/done_{}.txt'.format(local_device_id), 'w') 158 | fout.close() 159 | if local_device_id < (args.gpus_per_node - 1): 160 | while not os.path.exists('tmp/done_{}.txt'.format(args.gpus_per_node - 1)): 161 | time.sleep(1) 162 | elif need_optimizer: 163 | optimizer = setup_model_and_optimizer(args, model=model) 164 | 165 | torch.distributed.barrier() 166 | 167 | return model, optimizer 168 | 169 | def train_step(forward_step, data_iterator, model, optimizer, args): 170 | stats_reduced = loss_reduced = None 171 | for i in range(1): 172 | if i == 0: 173 | optimizer.zero_grad() 174 | 175 | loss, stats = forward_step(data_iterator, model, args) 176 | stats = torch.stack([loss.detach()] + [s.detach() for s in stats]) 177 | 178 | if not args.backward_step_contains_in_forward_step: 179 | loss.backward() 180 | 181 | stats_reduced = stats.clone() 182 | 183 | torch.distributed.all_reduce(stats_reduced.data) 184 | stats_reduced.data = stats_reduced.data / args.world_size 185 | loss_reduced = stats_reduced[0] 186 | 187 | if args.clip_grad > 0: 188 | nn.utils.clip_grad.clip_grad_norm_(model.parameters(), args.clip_grad) 189 | 190 | optimizer.step() 191 | 192 | return loss_reduced, stats_reduced 193 | 194 | def evaluate(forward_step, model, data_iterator, args): 195 | model.eval() 196 | 197 | with torch.no_grad(): 198 | 199 | sum_loss = 0 200 | sum_samples = 0 201 | iteration = 0 202 | 203 | while iteration < args.eval_iters: 204 | try: 205 | 
loss, stats = forward_step(data_iterator, model, args) 206 | except StopIteration: 207 | break 208 | 209 | loss = loss.item() 210 | 211 | sum_loss += loss 212 | 213 | sum_samples += 1 214 | 215 | iteration += 1 216 | 217 | torch.distributed.barrier() 218 | 219 | result = torch.tensor([sum_loss, sum_samples], dtype=torch.float32).to(torch.cuda.current_device())  # a float tensor is needed here; a LongTensor would truncate the accumulated loss to an integer 220 | torch.distributed.all_reduce(result) 221 | sum_loss = result[0].item() 222 | sum_samples = int(result[1].item()) 223 | 224 | if args.local_rank == 0: 225 | print('[evaluation]global eval loss = %.4f (%d / %d)' % (sum_loss / sum_samples, args.eval_iters, sum_samples)) 226 | 227 | def train(forward_step_func, model, optimizer, 228 | train_data_iterator, valid_data_set, args): 229 | model.train() 230 | 231 | # Tracking loss. 232 | sum_iter = sum_loss = 0 233 | sum_stats = None 234 | 235 | timer = Timer() 236 | timer.reset() 237 | timer.start() 238 | while args.iteration < args.train_iters: 239 | loss, stats = train_step(forward_step_func, train_data_iterator, model, optimizer, args) 240 | 241 | loss = loss.item() 242 | stats = stats.data.detach().tolist() 243 | 244 | args.iteration = args.iteration + 1 245 | 246 | sum_iter += 1 247 | sum_loss += loss 248 | if sum_stats is None: 249 | sum_stats = [0.0] * len(stats) 250 | sum_stats = [a + b for a, b in zip(sum_stats, stats)] 251 | 252 | if args.iteration % args.log_interval == 0: 253 | use_time = timer.stop() 254 | per_iteration_use_time = use_time * 1000.0 / sum_iter 255 | 256 | if args.local_rank == 0: 257 | report_iteration_metrics(per_iteration_use_time, sum_loss / sum_iter, args.iteration, args.train_iters) 258 | print('stats: [', ', '.join(['%.6f' % (x / sum_iter) for x in sum_stats]), ']') 259 | 260 | sum_iter = sum_loss = 0 261 | sum_stats = None 262 | 263 | if args.iteration % args.save_interval == 0: 264 | args.final_saved_iteration = args.iteration 265 | save_checkpoint(args.iteration, model, optimizer, args) 266 | 267 | if valid_data_set is not None and args.iteration % args.eval_interval == 0: 268 | evaluate(forward_step_func, model, iter(valid_data_set), args) 269 | model.train() 270 | 271 | if args.iteration % args.log_interval == 0: 272 | timer.reset() 273 | timer.start() 274 | 275 | def parse(torch_tensor): 276 | arr = torch_tensor.cpu().numpy() 277 | 278 | if len(arr.shape) == 0: 279 | arr = [arr] 280 | 281 | return ','.join(str(item) for item in arr) 282 | 283 | def inference(forward_step_func, model, writer, 284 | infer_data_iterator, args): 285 | model.eval() 286 | 287 | iter_index = 0 288 | samples = 0 289 | 290 | with torch.no_grad(): 291 | while iter_index < args.infer_iters: 292 | stats = forward_step_func(infer_data_iterator, model, args) 293 | samples += stats[0].shape[0] 294 | 295 | iter_index += 1 296 | 297 | for j in range(stats[0].shape[0]): 298 | ret = [parse(item[j]) for item in stats] 299 | 300 | writer.write('\t'.join(ret) + '\n') 301 | 302 | if iter_index % args.log_interval == 0: 303 | print_rank_0("%d samples inference finished!"
%(samples)) 304 | samples = 0 305 | 306 | writer.close() 307 | 308 | def report_iteration_metrics(per_iteration_use_time, loss, step, total_step): 309 | log_string = '\n' 310 | log_string += str(datetime.now()) 311 | log_string += ' iteration %d(all:%d) ||' % (step, total_step) 312 | log_string += ' loss %.6f ||' % loss 313 | log_string += ' time per iteration (ms) %.1f |' % per_iteration_use_time 314 | 315 | print(log_string) 316 | 317 | def task_dispatcher(train_eval_dataset_provider=None, 318 | inference_dataset_provider=None, 319 | model_provider=None, 320 | forward_func=None, 321 | personalized_args_provider=None, 322 | training_post_processing_func=None, 323 | onnx_model_export_func=None): 324 | ''' 325 | task dispatcher; currently supports training, inference, and onnx export tasks 326 | :param train_eval_dataset_provider: a function that takes args and returns the train & valid datasets (training only) 327 | :param inference_dataset_provider: a function that takes args and returns the inference dataset (inference only) 328 | :param model_provider: a function that takes args and returns the user-defined model 329 | :param forward_func: a function that takes a data iterator, the model, and args, and returns 330 | 1) the loss and other metrics that need to be logged if args.task_type = train 331 | 2) a list containing prediction data that needs to be written to local files if args.task_type = inference 332 | :param personalized_args_provider: a function that lets users define model-specific parameters 333 | :param training_post_processing_func: a function that lets users run post-processing after the training task finishes 334 | :param onnx_model_export_func: the onnx model export function (onnx_export only) 335 | :return: 336 | ''' 337 | 338 | args = get_args(personalized_args_provider) 339 | 340 | # Pytorch distributed. 341 | initialize_distribution_env(args) 342 | 343 | # Random seeds for reproducibility.
344 | set_random_seed(seed=args.seed) 345 | 346 | assert args.task_type in ('train', 'inference', 'onnx_export'), 'task type must be train, inference, or onnx_export' 347 | if args.task_type == 'train': 348 | print_rank_0('*****running task : train*****') 349 | assert train_eval_dataset_provider is not None and model_provider is not None and forward_func is not None, \ 350 | '[train_eval_dataset_provider, model_provider, forward_func] cannot be None for the training task' 351 | training_backbone(train_eval_dataset_provider, model_provider, forward_func, training_post_processing_func, args) 352 | elif args.task_type == 'inference': 353 | print_rank_0('*****running task : inference*****') 354 | assert inference_dataset_provider is not None and model_provider is not None and forward_func is not None, \ 355 | '[inference_dataset_provider, model_provider, forward_func] cannot be None for the inference task' 356 | inference_backbone(inference_dataset_provider, model_provider, forward_func, args) 357 | elif args.task_type == 'onnx_export': 358 | onnx_export_backbone(onnx_model_export_func, args) 359 | 360 | 361 | def onnx_export_backbone(onnx_model_export_func, 362 | args): 363 | ''' 364 | onnx export backbone 365 | :param onnx_model_export_func: the onnx model export function 366 | :param args: an arguments namespace 367 | :return: 368 | ''' 369 | onnx_model_export_func(args) 370 | 371 | def inference_backbone(inference_dataset_provider, 372 | model_provider, 373 | forward_func, 374 | args): 375 | ''' 376 | inference task backbone 377 | :param inference_dataset_provider: a function that takes args and returns the inference dataset 378 | :param model_provider: a function that takes args and returns the user-defined model 379 | :param forward_func: a function that takes a data iterator, the model, and args, and returns a list containing prediction data that needs to be written to local files 380 | :param args: an arguments namespace 381 | :return: 382 | ''' 383 | test_data = make_inference_data_loader(args, inference_dataset_provider) 384 | print('Inference data preparation done.') 385 | 386 | model = get_model(args, model_provider) 387 | 388 | model, _ = load_model_setup_optimizer(args, model) 389 | 390 | writer = make_local_writer(args) 391 | 392 | inference(forward_func, model, writer, iter(test_data), args) 393 | 394 | def training_backbone(train_eval_dataset_provider, 395 | model_provider, 396 | forward_func, 397 | training_post_processing_func, 398 | args): 399 | ''' 400 | training task backbone 401 | :param train_eval_dataset_provider: a function that takes args and returns the train & valid datasets 402 | :param model_provider: a function that takes args and returns the user-defined model 403 | :param forward_func: a function that takes a data iterator, the model, and args, and returns the loss and other metrics that need to be logged 404 | :param training_post_processing_func: an optional hook invoked with (model, args) after training finishes 405 | :param args: an arguments namespace 406 | ''' 407 | model = get_model(args, model_provider) 408 | 409 | train_data, eval_data = make_data_loader(args, train_eval_dataset_provider) 410 | print('Data preparation done.') 411 | 412 | args.iteration = 0 413 | train_data_iterator = iter(train_data) 414 | 415 | if torch.distributed.get_rank() == 0: 416 | print_args(args) 417 | 418 | model, optimizer = load_model_setup_optimizer(args, model) 419 | 420 | train(forward_func, model, optimizer, train_data_iterator, eval_data, args) 421 | 422 | if training_post_processing_func is not None: 423 |
training_post_processing_func(model, args) --------------------------------------------------------------------------------