├── utils
│   ├── __init__.py
│   ├── optimizer.py
│   ├── logger.py
│   ├── utils.py
│   └── parser.py
├── unifiedssr_env.yml
├── data.py
├── predict.py
├── README.md
├── metrics.py
├── train.py
├── pretrain.py
├── modules.py
├── model.py
├── LICENSE
└── batch.py
/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /unifiedssr_env.yml: -------------------------------------------------------------------------------- 1 | name: unifiedssr 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.8.17 7 | - pytorch=1.13.1 8 | - numpy=1.24.4 9 | - pandas=2.0.3 10 | - scikit-learn=1.3.0 11 | - matplotlib=3.7.1 12 | - nltk=3.8.1 13 | - joblib=1.3.0 14 | -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | class NoamOpt: 2 | "Optimizer wrapper that implements the Noam learning-rate schedule." 3 | 4 | def __init__(self, model_size, factor, warmup, optimizer): 5 | self.optimizer = optimizer 6 | self._step = 0 7 | self.warmup = warmup 8 | self.factor = factor 9 | self.model_size = model_size 10 | self._rate = 0 11 | 12 | def step(self): 13 | "Update parameters and rate" 14 | self._step += 1 15 | rate = self.rate() 16 | for p in self.optimizer.param_groups: 17 | p['lr'] = rate 18 | self._rate = rate 19 | self.optimizer.step() 20 | return rate 21 | 22 | def rate(self, step=None): 23 | "Compute the learning rate for a given step." 24 | if step is None: 25 | step = self._step 26 | return self.factor * \ 27 | (self.model_size ** (-0.5) * 28 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 29 | 30 | 31 | if __name__ == "__main__": 32 | # for test only 33 | pass -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import joblib 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class UnifiedDataset(Dataset): 8 | def __init__(self, phase, tasks, data_root='', verbose=False): 9 | assert os.path.exists(data_root), f'Data root {data_root} does not exist.' 10 | assert phase in ['pretrain', 'finetune'], 'Phase must be pretrain or finetune.' 11 | assert set(tasks).issubset({'recommendation', 'search'}), 'Tasks must be a subset of {recommendation, search}.'
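# data format (inferred from the collate functions in batch.py): each {phase}_{task}.pkl holds a list of per-user sample dicts with keys 'uid', 'flag' (the task name), 'pid_list', and, for search tasks, 'qry_list'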
12 | self.data_root = data_root 13 | self.phase = phase 14 | self.tasks = tasks 15 | self.tasks_num_sample = [] 16 | self.user_seq = [] 17 | 18 | for task in tasks: 19 | data = joblib.load(os.path.join(data_root, f'{phase}_{task}.pkl')) 20 | self.tasks_num_sample.append(len(data)) 21 | self.user_seq.extend(data) 22 | 23 | if verbose: 24 | self.print_info() 25 | 26 | def __len__(self): 27 | return len(self.user_seq) 28 | 29 | def __getitem__(self, index): 30 | return self.user_seq[index] 31 | 32 | def print_info(self): 33 | logging.info(f'current data path: {self.data_root}') 34 | logging.info(f'current phase: {self.phase}') 35 | logging.info(f'current tasks: {self.tasks}') 36 | for num_sample, task in zip(self.tasks_num_sample, self.tasks): 37 | logging.info(f"the number of samples for task {task}: {num_sample}") 38 | 39 | 40 | if __name__ == '__main__': 41 | # for test only 42 | pass 43 | 44 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | def create_log_id(dir_path, name): 6 | log_count = 0 7 | file_path = os.path.join(dir_path, '{}_{:d}.log'.format(name, log_count)) 8 | while os.path.exists(file_path): 9 | log_count += 1 10 | file_path = os.path.join(dir_path, '{}_{:d}.log'.format(name, log_count)) 11 | return log_count 12 | 13 | 14 | def logging_config(folder=None, 15 | name=None, 16 | level=logging.DEBUG, 17 | console_level=logging.DEBUG, 18 | no_console=True): 19 | 20 | if not os.path.exists(folder): 21 | os.makedirs(folder) 22 | for handler in logging.root.handlers: 23 | logging.root.removeHandler(handler) 24 | logging.root.handlers = [] 25 | logpath = os.path.join(folder, name + ".log") 26 | print("All logs will be saved to %s" % logpath) 27 | 28 | logging.root.setLevel(level) 29 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 30 | logfile = logging.FileHandler(logpath) 31 | logfile.setLevel(level) 32 | logfile.setFormatter(formatter) 33 | logging.root.addHandler(logfile) 34 | 35 | if not no_console: 36 | logconsole = logging.StreamHandler() 37 | logconsole.setLevel(console_level) 38 | logconsole.setFormatter(formatter) 39 | logging.root.addHandler(logconsole) 40 | return folder 41 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch 4 | 5 | 6 | def save_model(model, model_dir, current_epoch, last_best_epoch=None): 7 | if not os.path.exists(model_dir): 8 | os.makedirs(model_dir) 9 | 10 | if last_best_epoch is not None and current_epoch != last_best_epoch: 11 | model_state_file = os.path.join(model_dir, 'model_best.pth') 12 | else: 13 | model_state_file = os.path.join(model_dir, 'model_{}.pth'.format(current_epoch)) 14 | 15 | torch.save({'model_state_dict': model.state_dict(), 'epoch': current_epoch}, model_state_file) 16 | 17 | 18 | def load_model(model, model_path): 19 | checkpoint = torch.load(model_path, map_location=torch.device('cpu')) 20 | 21 | try: 22 | model.load_state_dict(checkpoint['model_state_dict'], strict=False) 23 | # model.load_state_dict(checkpoint['model_state_dict']) 24 | except RuntimeError: 25 | state_dict = OrderedDict() 26 | for k, v in checkpoint['model_state_dict'].items(): 27 | k_ = k[7:] 28 | # remove the 'module.' prefix of a DistributedDataParallel instance
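# ('module.' has 7 characters, hence k[7:]; keys carry this prefix when the checkpoint was saved from a DataParallel/DistributedDataParallel-wrapped model)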
29 | state_dict[k_] = v 30 | model.load_state_dict(state_dict) 31 | 32 | model.eval() 33 | return model 34 | 35 | 36 | def degrade_saved_model(model_path): 37 | checkpoint = torch.load(model_path, map_location=torch.device('cpu')) 38 | save_path = os.path.join(os.path.dirname(model_path), 'degrade_version') 39 | 40 | if not os.path.exists(save_path): 41 | os.makedirs(save_path) 42 | 43 | torch.save(checkpoint, os.path.join(save_path, os.path.basename(model_path)), _use_new_zipfile_serialization=False) 44 | 45 | 46 | def early_stopping(cur_scores, best_scores, stopping_count, patience=100, logger=None): 47 | update_flag = False 48 | for cur_score, best_score in zip(cur_scores, best_scores): 49 | if cur_score > best_score: 50 | update_flag = True 51 | 52 | if update_flag: 53 | stopping_count = 0 54 | best_scores = cur_scores 55 | else: stopping_count += 1 56 | 57 | if stopping_count >= patience: 58 | if logger: 59 | logger.info("Early stopping is triggered at step: {}".format(stopping_count)) 60 | should_stop = True 61 | else: 62 | if logger: 63 | logger.info("Current stopping count: {}".format(stopping_count)) 64 | should_stop = False 65 | return best_scores, stopping_count, should_stop 66 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from utils.parser import parse_args 2 | from utils.logger import create_log_id, logging_config 3 | from utils.utils import load_model 4 | from metrics import evaluate_product 5 | from data import UnifiedDataset 6 | from batch import collate_test 7 | from model import get_model 8 | import numpy as np 9 | import pandas as pd 10 | import os, logging, random 11 | import torch 12 | from torch.utils.data import DataLoader 13 | 14 | def predict_product(args): 15 | # log 16 | log_name = f'log_test_{args.tasks[0]}' 17 | log_save_id = create_log_id(args.save_dir, name=log_name) 18 | logging_config(folder=args.save_dir, name='{}_{:d}'.format(log_name, log_save_id), no_console=False) 19 | logging.info(args) 20 | 21 | # GPU / CPU 22 | args.use_cuda = args.use_cuda & torch.cuda.is_available() 23 | device = torch.device("cuda:{}".format(args.cuda_idx) if args.use_cuda else "cpu") 24 | 25 | # load data 26 | data = UnifiedDataset(args.phase, args.tasks, args.data_root, verbose=True) 27 | 28 | data_loader = DataLoader(data, 29 | shuffle=False, 30 | batch_size=args.test_batch_size, 31 | collate_fn=lambda x: collate_test(x, args)) 32 | 33 | # load model 34 | model = get_model(args) 35 | model = load_model(model, args.trained_model_path).to(device) 36 | 37 | # evaluate 38 | hits, ndcgs = evaluate_product(model, data_loader, len(data), args, device) 39 | for k_idx, topk in enumerate(args.k_list): 40 | logging.info( 41 | 'Evaluation (K={}): HR {:.4f} NDCG {:.4f}'.format(topk, hits[k_idx], ndcgs[k_idx])) 42 | 43 | # initialize metrics 44 | result_save_file = os.path.join(args.save_dir, 'test_results.csv') 45 | init_metrics = pd.DataFrame(['HR@{}'.format(k) for k in args.k_list] + 46 | ['NDCG@{}'.format(k) for k in args.k_list]).transpose() 47 | init_metrics.to_csv(result_save_file, mode='a', header=False, sep='\t', index=False) 48 | metrics = pd.DataFrame(hits.tolist() + ndcgs.tolist()).transpose() 49 | metrics.to_csv(result_save_file, mode='a', header=False, sep='\t', index=False) 50 | return hits, ndcgs 51 | 52 | 53 | if __name__ == "__main__": 54 | args = parse_args() 55 | 56 | # Seed
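# seeds the Python, NumPy, and CPU PyTorch RNGs; torch.cuda.manual_seed_all is not called, so GPU runs may not be fully deterministic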
57 | random.seed(args.seed) 58 | np.random.seed(args.seed) 59 | torch.manual_seed(args.seed) 60 | 61 | # Evaluation 62 | args.phase = 'finetune' 63 | args.tasks = ['recommendation'] # 'recommendation', 'search' 64 | 65 | pretrain_dir = 'models/Amazon_Clothing/pretrain_recommendation_search/finetune_recommendation' 66 | # pretrain_dir = 'models/Amazon_Clothing/pretrain_recommendation_search/finetune_search' 67 | args.trained_model_path = os.path.join(pretrain_dir, 'model.pth') 68 | args.save_dir = pretrain_dir 69 | predict_product(args) 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # (WWW '24) UnifiedSSR: A Unified Framework of Sequential Search and Recommendation 2 | 3 | This is the PyTorch implementation of UnifiedSSR for joint learning of user behavior histories in both search and recommendation scenarios. 4 | 5 | ## Environments 6 | 7 | - python=3.8.17 8 | - pytorch=1.13.1 9 | - numpy=1.24.4 10 | - pandas=2.0.3 11 | - scikit-learn=1.3.0 12 | - matplotlib=3.7.1 13 | - nltk=3.8.1 14 | - joblib=1.3.0 15 | 16 | You can create the environment via `conda env create -f unifiedssr_env.yml`. 17 | 18 | ## Run the Codes 19 | 20 | 1. Pretrain: Customize parameters in `utils/parser.py`, and then run `pretrain.py` to pretrain the model. The pretrained model will be saved in `models/`. 21 | 2. Finetune: Modify `args.trained_model_path` in `train.py` to specify the path to the pretrained model, and then run `train.py` to finetune the model. 22 | 3. Evaluate: Modify `args.tasks` and `args.trained_model_path` in `predict.py` to specify the task and the path to the trained model, and then run `predict.py` to evaluate the model. Note that evaluation can only be conducted on one task at a time (see the Quick Start section below). 23 | 24 | We provide a pretrained model and task-specific finetuned models for the Amazon-CL dataset as follows: 25 | 1. Pretrained model: `models/Amazon_Clothing/pretrain_recommendation_search/model.pth` 26 | 2. Finetuned model for search: `models/Amazon_Clothing/pretrain_recommendation_search/finetune_search/model.pth` 27 | 3. Finetuned model for recommendation: `models/Amazon_Clothing/pretrain_recommendation_search/finetune_recommendation/model.pth` 28 | 29 | ## Datasets 30 | 31 | * Original Dataset 32 | * The Amazon dataset can be found [here](https://nijianmo.github.io/amazon/index.html). 33 | * The JDsearch dataset can be found [here](https://github.com/rucliujn/JDsearch). 34 | * Preprocess 35 | * Use the provided preprocessed Amazon-CL dataset in `datasets/Amazon_Clothing`. 36 | * Feel free to contact the authors for more details on the data preprocessing. 37 | 38 | 39 | Note: The dataset and model files are large. Please download them from [Google Drive](https://drive.google.com/drive/folders/1GShl2vju5_uXHRgcd1UZinJhmgmDzSw_?usp=share_link) and place them in the project folder.
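
## Quick Start

For reference, a minimal end-to-end run with the default arguments in `utils/parser.py` looks like the following (the checkpoint paths set inside the scripts are illustrative; pretraining saves to a timestamped directory under `models/`):

```
# 1. pretrain on both tasks
python pretrain.py
# 2. finetune on one task (set args.trained_model_path in train.py to the pretrained checkpoint first)
python train.py
# 3. evaluate (set args.tasks and args.trained_model_path in predict.py first)
python predict.py
```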
40 | 41 | ## Citation 42 | 43 | If you find our codes helpful, please kindly cite the following paper: 44 | 45 | ``` 46 | @article{unifiedssr, 47 | author = {Jiayi Xie and 48 | Shang Liu and 49 | Gao Cong and 50 | Zhenzhong Chen}, 51 | title = {UnifiedSSR: {A} Unified Framework of Sequential Search and Recommendation}, 52 | journal = {CoRR}, 53 | volume = {abs/2310.13921}, 54 | year = {2023}, 55 | url = {https://doi.org/10.48550/arXiv.2310.13921}, 56 | doi = {10.48550/ARXIV.2310.13921}, 57 | eprinttype = {arXiv}, 58 | eprint = {2310.13921} 59 | } 60 | ``` -------------------------------------------------------------------------------- /utils/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import os 4 | import sys 5 | sys.path.append(os.getcwd()) 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | 11 | # basic 12 | parser.add_argument('--seed', type=int, default=888) 13 | parser.add_argument('--use_cuda', type=bool, default=True) 14 | parser.add_argument('--cuda_idx', type=int, default=0) 15 | 16 | # pretrain and train 17 | parser.add_argument('--num_epoch', type=int, default=100) 18 | parser.add_argument('--train_batch_size', type=int, default=128) 19 | parser.add_argument('--train_num_neg', type=int, default=4) 20 | parser.add_argument('--start_epoch_idx', type=int, default=1) 21 | parser.add_argument('--learning_rate', type=float, default=None) 22 | parser.add_argument('--opt_factor', type=float, default=1) 23 | parser.add_argument('--opt_warmup', type=int, default=4000) 24 | parser.add_argument('--print_every', type=int, default=8, 25 | help='Iteration interval of printing loss.') 26 | parser.add_argument('--save_every', type=int, default=4, 27 | help='Epoch interval of saving the model.') 28 | parser.add_argument('--evaluate_every', type=int, default=4, 29 | help='Epoch interval of evaluation.') 30 | 31 | # validation and test 32 | parser.add_argument('--test_batch_size', type=int, default=50) 33 | parser.add_argument('--test_num_neg', type=int, default=99) 34 | parser.add_argument('--test_neg', type=bool, default=True) 35 | parser.add_argument('--k_list', type=int, nargs='+', default=[5, 10]) 36 | 37 | # model 38 | parser.add_argument('--corr_factor', type=float, default=0.1) 39 | parser.add_argument('--num_head', type=int, default=4, 40 | choices=[1, 2, 4, 8]) 41 | parser.add_argument('--enc_num_layer', type=int, default=2, 42 | choices=[1, 2, 3]) 43 | parser.add_argument('--sub_seq_num', type=int, default=2, 44 | choices=[1, 2, 3, 4]) 45 | parser.add_argument('--emb_size', type=int, default=32, 46 | choices=[16, 32, 48, 64, 80]) 47 | parser.add_argument('--hid_size', type=int, default=None) 48 | parser.add_argument('--dropout', type=float, default=0.5) 49 | parser.add_argument('--trained_model_path', type=str, default='') 50 | 51 | # data 52 | parser.add_argument('--data_name', type=str, default='Amazon_Clothing', choices=['JDsearch', 'Amazon_Clothing', 'Amazon_Electronics']) 53 | parser.add_argument('--data_root', type=str, default='./datasets') 54 | parser.add_argument('--padding_value', type=int, default=0) 55 | parser.add_argument('--query_max_len', type=int, default=50) 56 | 57 | args = parser.parse_args() 58 | 59 | args.data_root = os.path.join(args.data_root, args.data_name) 60 | data_meta_path = os.path.join(args.data_root, 'meta.csv') 61 | user_num, product_num, term_num = pd.read_csv(data_meta_path, sep='\t').values.squeeze() 62 | args.user_vocab = user_num + 1
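# each vocabulary is num + 1 so that id 0 stays reserved as the padding value (args.padding_value defaults to 0; negative sampling in batch.py draws from [1, vocab))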
63 | args.product_vocab = product_num + 1 64 | args.term_vocab = term_num + 1 65 | args.bos_id = args.term_vocab + 1 # Begin-of-Sentence 66 | args.eos_id = args.term_vocab # End-of-Sentence 67 | 68 | return args 69 | 70 | 71 | if __name__ == "__main__": 72 | # for test only 73 | pass -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def evaluate_product(trained_model, data_loader, total_sample_num, args, device='cuda'): 6 | trained_model.eval() 7 | k_list = args.k_list or [5, 10] 8 | hit_at_k = np.zeros(len(k_list)) # float accumulators; an int array would silently truncate the fractional NDCG sums 9 | ndcg_at_k = np.zeros(len(k_list)) 10 | 11 | with torch.no_grad(): 12 | for idx, (cur_task, batch_data) in enumerate(data_loader): 13 | for key in batch_data.keys(): 14 | if isinstance(batch_data[key], list): 15 | batch_data[key] = [[i.to(device) for i in data] for data in batch_data[key]] 16 | else: 17 | batch_data[key] = batch_data[key].to(device) 18 | 19 | out = trained_model.forward(cur_task, batch_data) 20 | 21 | if cur_task == 'recommendation': 22 | if args.corr_factor > 0: 23 | sub_seq_wins = trained_model.get_sub_seq_wins(out) 24 | out, _, _ = trained_model.intra_corr_loss(out, sub_seq_wins, batch_data['pids_mask']) 25 | 26 | if args.test_neg: # rank candidates 27 | out = trained_model.next_product_predict(out, batch_data['pid_last_idx'], batch_data['pid_pred']) 28 | tgt = torch.zeros((out.shape[0], 1)).long() 29 | else: # rank all products 30 | out = trained_model.next_product_predict(out, batch_data['pid_last_idx']) 31 | tgt = batch_data['pid_pred'][:,0].unsqueeze(-1).to('cpu') 32 | elif cur_task == 'search': 33 | p_out, q_out = out 34 | p_out = p_out[:, 1:, :] 35 | q_out = q_out[:, 1:, :] 36 | if args.corr_factor > 0: 37 | mask = batch_data['pids_mask'][:, :, 1:] 38 | p_sub_seq_wins = trained_model.get_sub_seq_wins(p_out) 39 | q_sub_seq_wins = trained_model.get_sub_seq_wins(q_out) 40 | p_out, q_out, _ = trained_model.inter_corr_loss(p_out, p_sub_seq_wins, q_out, q_sub_seq_wins, mask) 41 | 42 | if args.test_neg: # rank candidates 43 | out = trained_model.next_product_search(p_out, q_out, batch_data['pid_last_idx'], batch_data['pid_pred']) 44 | tgt = torch.zeros((out.shape[0], 1)).long() 45 | else: # rank all products 46 | out = trained_model.next_product_search(p_out, q_out, batch_data['pid_last_idx']) 47 | tgt = batch_data['pid_pred'][:,0].unsqueeze(-1).to('cpu') 48 | 49 | _, out_rank = torch.sort(out, descending=True) 50 | if out_rank.device.type == 'cuda': 51 | out_rank = out_rank.cpu() 52 | 53 | for idx_k, k in enumerate(k_list): 54 | hit_at_k[idx_k] += hit_at_k_per_batch(out_rank, tgt, k) 55 | ndcg_at_k[idx_k] += ndcg_at_k_per_batch(out_rank, tgt, k) 56 | 57 | hit_at_k = hit_at_k / total_sample_num 58 | ndcg_at_k = ndcg_at_k / total_sample_num 59 | return hit_at_k, ndcg_at_k 60 | 61 | 62 | def hit_at_k_per_batch(pred, tgt, k): 63 | hits_num = 0 64 | for i in range(len(tgt)): 65 | tgt_set = set(tgt[i].numpy()) 66 | pred_set = set(pred[i][:k].numpy()) 67 | hits_num += len(tgt_set & pred_set) 68 | return hits_num 69 | 70 | 71 | def recall_at_k_per_batch(pred, tgt, k): 72 | sum_recall = 0.
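# recall@k per batch (not used by evaluate_product above, kept for reference): |target ∩ top-k| / |target|, summed over samples with a non-empty target set; returns (num_sample, sum_recall) so the caller can average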
73 | num_sample = 0 74 | for i in range(len(tgt)): 75 | tgt_set = set(tgt[i].numpy()) 76 | pred_set = set(pred[i][:k].numpy()) 77 | if len(tgt_set) != 0: 78 | sum_recall += len(tgt_set & pred_set) / float(len(tgt_set)) 79 | num_sample += 1 80 | return num_sample, sum_recall 81 | 82 | 83 | def ndcg_at_k_per_batch(pred, tgt, k): 84 | ndcg_score = 0. 85 | for i in range(len(tgt)): 86 | sample_pred = pred[i, :k].numpy() 87 | sample_tgt = tgt[i].numpy() 88 | ndcg_score += ndcg_at_k_per_sample(sample_pred, sample_tgt) 89 | return ndcg_score 90 | 91 | 92 | def ndcg_at_k_per_sample(pred, tgt, method=1): 93 | r = np.zeros_like(pred, dtype=np.float32) 94 | ideal_r = np.zeros_like(pred, dtype=np.float32) 95 | for i, v in enumerate(pred): 96 | if v in tgt and v not in pred[:i]: 97 | r[i] = 1. 98 | ideal_r[:len(tgt)] = 1. 99 | 100 | idcg = dcg_at_k_per_sample(ideal_r, method) 101 | if not idcg: 102 | return 0. 103 | return dcg_at_k_per_sample(r, method) / idcg 104 | 105 | 106 | def dcg_at_k_per_sample(r, method=1): 107 | if r.size: 108 | if method == 0: 109 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 110 | elif method == 1: # binary relevance: only relevant results (r = 1) contribute to the sum 111 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 112 | else: 113 | raise ValueError('method must be 0 or 1.') 114 | return 0. 115 | 116 | 117 | if __name__ == "__main__": 118 | # for test only 119 | pass -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from utils.parser import parse_args 2 | from utils.logger import create_log_id, logging_config 3 | from utils.optimizer import NoamOpt 4 | from utils.utils import save_model, load_model, early_stopping 5 | from metrics import evaluate_product 6 | from data import UnifiedDataset 7 | from batch import BatchSampler, collate_train, collate_val 8 | from model import get_model 9 | from pretrain import run_epoch 10 | import numpy as np 11 | import pandas as pd 12 | import os, time, logging, random 13 | import torch 14 | import torch.nn as nn 15 | from torch.utils.data import DataLoader 16 | 17 | 18 | def train(args): 19 | # log 20 | log_name = 'log_train' 21 | log_save_id = create_log_id(args.save_dir, name=log_name) 22 | logging_config(folder=args.save_dir, name='{}_{:d}'.format(log_name, log_save_id), no_console=False) 23 | logging.info(args) 24 | 25 | # GPU / CPU 26 | args.use_cuda = args.use_cuda & torch.cuda.is_available() 27 | device = torch.device("cuda:{}".format(args.cuda_idx) if args.use_cuda else "cpu") 28 | 29 | # load data 30 | data = UnifiedDataset(args.phase, args.tasks, args.data_root, verbose=True) 31 | 32 | batch_sampler = BatchSampler(data, args.train_batch_size) 33 | data_loader = DataLoader(data, 34 | batch_sampler=batch_sampler, 35 | collate_fn=lambda x: collate_train(x, args)) 36 | batch_num = len(data_loader) 37 | 38 | # construct model 39 | model = get_model(args) 40 | model.to(device) 41 | logging.info(model) 42 | 43 | if os.path.isfile(args.trained_model_path): 44 | logging.info("Loading pre-trained model: {}".format(args.trained_model_path)) 45 | model = load_model(model, args.trained_model_path) 46 | else: 47 | logging.info('Parameters initializing ...') 48 | for p in model.parameters(): 49 | if p.dim() > 1: 50 | nn.init.xavier_uniform_(p) 51 | model.seq_partition.reset_offset() 52 | 53 | # define optimizer 54 | optimizer = NoamOpt(args.emb_size, args.opt_factor, args.opt_warmup, 55 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
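# Noam schedule (see utils/optimizer.py): lr = opt_factor * emb_size^-0.5 * min(step^-0.5, step * opt_warmup^-1.5); the lr=0 passed to Adam is a placeholder that NoamOpt overwrites on every step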
56 | logging.info(optimizer) 57 | 58 | # train 59 | init_metrics = pd.DataFrame(['Epoch_idx'] + ['HR@{}'.format(k) for k in args.k_list] + 60 | ['NDCG@{}'.format(k) for k in args.k_list]).transpose() 61 | init_metrics.to_csv(os.path.join(args.save_dir, 'train_results.csv'), mode='a', header=False, 62 | sep='\t', index=False) 63 | cur_best_scores = [0., 0.] 64 | stopping_count = 0 65 | should_stop = False 66 | best_epoch_idx = -1 67 | best_result = -np.inf 68 | best_results = [] 69 | 70 | assert len(args.tasks) == 1, 'Please specify a single downstream task to train & test.' 71 | start_epoch_idx = args.start_epoch_idx or 1 72 | for epoch_idx in range(start_epoch_idx, args.num_epoch + start_epoch_idx): 73 | # train and save model 74 | run_epoch(args, model, data_loader, optimizer, epoch_idx, batch_num, device) 75 | 76 | # evaluate 77 | if (epoch_idx % args.evaluate_every) == 0: 78 | time3 = time.time() 79 | val_data_loader = DataLoader(data, 80 | shuffle=False, 81 | batch_size=args.test_batch_size, 82 | collate_fn=lambda x: collate_val(x, args)) 83 | 84 | hits, ndcgs = evaluate_product(model, val_data_loader, len(data), args, device) 85 | for k_idx, topk in enumerate(args.k_list): 86 | logging.info('Evaluation (K={}): Epoch {:04d} | Total Time {:.1f}s | HR {:.4f} NDCG {:.4f}'.format( 87 | topk, epoch_idx, time.time() - time3, hits[k_idx], ndcgs[k_idx])) 88 | 89 | cur_best_scores, stopping_count, should_stop = early_stopping([hits[-1], ndcgs[-1]], cur_best_scores, stopping_count, 3, logging) 90 | 91 | # save the best result 92 | if ndcgs[0] > best_result: 93 | best_result = ndcgs[0] 94 | best_results = hits.tolist() + ndcgs.tolist() 95 | save_model(model, args.save_dir, epoch_idx, best_epoch_idx) 96 | best_epoch_idx = epoch_idx 97 | 98 | metrics = pd.DataFrame([epoch_idx] + hits.tolist() + ndcgs.tolist()).transpose() 99 | metrics.to_csv(os.path.join(args.save_dir, 'train_results.csv'), mode='a', header=False, sep='\t', 100 | index=False) 101 | 102 | if should_stop: 103 | break 104 | 105 | best_metrics = pd.DataFrame([best_epoch_idx] + best_results).transpose() 106 | best_metrics.to_csv(os.path.join(args.save_dir, 'train_results.csv'), mode='a', header=False, sep='\t', index=False) 107 | 108 | 109 | if __name__ == '__main__': 110 | args = parse_args() 111 | 112 | # Seed 113 | random.seed(args.seed) 114 | np.random.seed(args.seed) 115 | torch.manual_seed(args.seed) 116 | 117 | # Finetune and Evaluation 118 | args.phase = 'finetune' 119 | args.tasks = ['recommendation'] # 'recommendation', 'search' 120 | 121 | pretrain_dir = 'models/Amazon_Clothing/pretrain_recommendation_search' 122 | args.trained_model_path = os.path.join(pretrain_dir, 'model.pth') 123 | args.save_dir = os.path.join(pretrain_dir, f'{"_".join([args.phase] + args.tasks)}/{time.strftime("%Y%m%d_%H%M%S")}') 124 | train(args) 125 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | from utils.parser import parse_args 2 | from utils.logger import create_log_id, logging_config 3 | from utils.optimizer import NoamOpt 4 | from utils.utils import save_model, load_model 5 | from data import UnifiedDataset 6 | from batch import BatchSampler, collate_pretrain 7 | from model import get_model 8 | import os, time, logging, random 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils.data import DataLoader 13 | 14 | 15 | def run_epoch(args, model,
data_loader, optimizer, epoch_idx, batch_num, device): 16 | time1 = time.time() 17 | model.train() 18 | total_out_loss = 0. 19 | total_corr_loss = 0. 20 | total_batch_loss = 0. 21 | for idx, (cur_task, batch_data) in enumerate(data_loader): 22 | batch_idx = idx + 1 23 | time2 = time.time() 24 | for key in batch_data.keys(): 25 | if isinstance(batch_data[key], list): 26 | batch_data[key] = [[i.to(device) for i in data] for data in batch_data[key]] 27 | else: 28 | batch_data[key] = batch_data[key].to(device) 29 | out = model.forward(cur_task, batch_data) 30 | 31 | if cur_task == 'recommendation': 32 | if args.corr_factor > 0: 33 | sub_seq_wins = model.get_sub_seq_wins(out) 34 | out, _, corr_loss = model.intra_corr_loss(out, sub_seq_wins, batch_data['pids_mask']) 35 | else: 36 | corr_loss = torch.tensor(0.) 37 | out_loss = model.loss(out.view(-1, out.size(-1)), 38 | batch_data['pids_mask'].view(-1), 39 | batch_data['pids_tgt'].view(-1), 40 | batch_data['pids_neg'].view(-1, args.train_num_neg)) 41 | else: # cur_task == 'search': 42 | p_out, q_out = out 43 | p_out = p_out[:, 1:, :] 44 | q_out = q_out[:, 1:, :] 45 | mask = batch_data['pids_mask'][:, :, 1:] 46 | 47 | if args.corr_factor > 0: 48 | p_sub_seq_wins = model.get_sub_seq_wins(p_out) 49 | q_sub_seq_wins = model.get_sub_seq_wins(q_out) 50 | p_out, q_out, corr_loss = model.inter_corr_loss(p_out, p_sub_seq_wins, q_out, q_sub_seq_wins, mask) 51 | else: 52 | corr_loss = torch.tensor(0.) 53 | out_loss = model.loss(p_out.reshape(-1, p_out.size(-1)), 54 | q_out.reshape(-1, q_out.size(-1)), 55 | mask.reshape(-1), 56 | batch_data['pids_tgt'].view(-1), 57 | batch_data['pids_neg'].view(-1, args.train_num_neg)) 58 | 59 | batch_loss = out_loss + args.corr_factor * corr_loss 60 | batch_loss.backward() 61 | cur_lr = optimizer.step() 62 | optimizer.optimizer.zero_grad() 63 | total_out_loss += out_loss.item() 64 | total_corr_loss += corr_loss.item() 65 | total_batch_loss += batch_loss.item() 66 | if (batch_idx % args.print_every) == 0: 67 | logging.info( 68 | 'Training: Epoch {:04d} Iter {:04d} / {:04d} | Current Task {} | Time {:.1f}s | L_Rate {:.5f}'.format( 69 | epoch_idx, batch_idx, batch_num, cur_task, time.time() - time2, cur_lr)) 70 | logging.info( 71 | 'Training: Iter Loss {:.4f} | Out Loss {:.4f} | Corr Loss {:.4f}'.format(batch_loss.item(), 72 | out_loss.item(), 73 | corr_loss.item())) 74 | logging.info( 75 | 'Training: Iter Mean Loss {:.4f} | Out Mean Loss {:.4f} | Corr Mean Loss {:.4f}'.format( 76 | total_batch_loss / batch_idx, total_out_loss / batch_idx, total_corr_loss / batch_idx)) 77 | 78 | logging.info( 79 | 'Training: Epoch {:04d} Total Iter {:04d} | Total Time {:.1f}s'.format(epoch_idx, batch_num, 80 | time.time() - time1)) 81 | logging.info( 82 | 'Training: Iter Mean Loss {:.4f} | Out Mean Loss {:.4f} | Corr Mean Loss {:.4f}'.format( 83 | total_batch_loss / batch_num, total_out_loss / batch_num, total_corr_loss / batch_num)) 84 | 85 | # save model 86 | if (epoch_idx % args.save_every) == 0: 87 | save_model(model, args.save_dir, epoch_idx) 88 | 89 | 90 | def pretrain(args): 91 | # log 92 | log_name = 'log_pretrain' 93 | log_save_id = create_log_id(args.save_dir, name=log_name) 94 | logging_config(folder=args.save_dir, name='{}_{:d}'.format(log_name, log_save_id), no_console=False) 95 | logging.info(args) 96 | 97 | # GPU / CPU 98 | args.use_cuda = args.use_cuda & torch.cuda.is_available() 99 | device = torch.device("cuda:{}".format(args.cuda_idx) if args.use_cuda else "cpu") 100 | 101 | # load data 102 | data = 
UnifiedDataset(args.phase, args.tasks, args.data_root, verbose=True) 103 | 104 | batch_sampler = BatchSampler(data, args.train_batch_size) 105 | data_loader = DataLoader(data, 106 | batch_sampler=batch_sampler, 107 | collate_fn=lambda x: collate_pretrain(x, args)) 108 | batch_num = len(data_loader) 109 | 110 | # construct model 111 | model = get_model(args) 112 | model.to(device) 113 | logging.info(model) 114 | 115 | if os.path.isfile(args.trained_model_path): 116 | logging.info("Loading pre-trained model: {}".format(args.trained_model_path)) 117 | model = load_model(model, args.trained_model_path) 118 | else: 119 | logging.info('Parameters initializing ...') 120 | for p in model.parameters(): 121 | if p.dim() > 1: 122 | nn.init.xavier_uniform_(p) 123 | model.seq_partition.reset_offset() 124 | 125 | # define optimizer 126 | optimizer = NoamOpt(args.emb_size, args.opt_factor, args.opt_warmup, 127 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 128 | logging.info(optimizer) 129 | 130 | start_epoch_idx = args.start_epoch_idx or 1 131 | for epoch_idx in range(start_epoch_idx, args.num_epoch + start_epoch_idx): 132 | # train and save model 133 | run_epoch(args, model, data_loader, optimizer, epoch_idx, batch_num, device) 134 | 135 | 136 | if __name__ == '__main__': 137 | args = parse_args() 138 | 139 | # Seed 140 | random.seed(args.seed) 141 | np.random.seed(args.seed) 142 | torch.manual_seed(args.seed) 143 | 144 | # Pretrain 145 | args.phase = 'pretrain' 146 | args.tasks = ['recommendation', 'search'] 147 | args.save_dir = f'models/{args.data_name}/{"_".join([args.phase] + args.tasks)}/{time.strftime("%Y%m%d_%H%M%S")}/' 148 | pretrain(args) -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import math, copy 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def clones(module, num_layer): 9 | return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layer)]) 10 | 11 | 12 | def attention(query, key, value, mask=None, dropout=None): 13 | d_k = query.size(-1) 14 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 15 | if mask is not None: 16 | scores = scores.masked_fill(mask == 0, -1e9) 17 | p_attn = F.softmax(scores, dim=-1) 18 | if dropout is not None: 19 | p_attn = dropout(p_attn) 20 | return torch.matmul(p_attn, value), p_attn 21 | 22 | 23 | class PositionwiseFeedForward(nn.Module): 24 | def __init__(self, emb_size, hid_size, dropout=0.1): 25 | super(PositionwiseFeedForward, self).__init__() 26 | self.w_1 = nn.Linear(emb_size, hid_size) 27 | self.w_2 = nn.Linear(hid_size, emb_size) 28 | self.dropout = nn.Dropout(dropout) 29 | 30 | def forward(self, x): 31 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 32 | 33 | 34 | class MultiHeadAttention(nn.Module): 35 | def __init__(self, num_head, emb_size, dropout=0.1): 36 | super(MultiHeadAttention, self).__init__() 37 | assert emb_size % num_head == 0 38 | self.d_k = emb_size // num_head 39 | self.num_head = num_head 40 | self.linears = clones(nn.Linear(emb_size, emb_size), 4) 41 | self.attn = None 42 | self.dropout = nn.Dropout(p=dropout) 43 | 44 | def forward(self, query, key, value, mask=None): 45 | if mask is not None: 46 | # Same mask applied to all h heads.
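# shape note: the padding mask arrives as [BS, 1, SeqLen]; unsqueeze(1) below turns it into [BS, 1, 1, SeqLen] so it broadcasts over every head's [SeqLen, SeqLen] score matrix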
47 | mask = mask.unsqueeze(1) 48 | nbatches = query.size(0) 49 | 50 | query, key, value = [l(x).view(nbatches, -1, self.num_head, self.d_k).transpose(1, 2) for l, x in 51 | zip(self.linears, (query, key, value))] 52 | x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout) 53 | x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.num_head * self.d_k) 54 | return self.linears[-1](x) 55 | 56 | 57 | class SublayerConnection(nn.Module): 58 | def __init__(self, size, dropout): 59 | super(SublayerConnection, self).__init__() 60 | self.norm = LayerNorm(size) 61 | self.dropout = nn.Dropout(dropout) 62 | 63 | def forward(self, x, sublayer): 64 | return x + self.dropout(sublayer(self.norm(x))) 65 | 66 | 67 | class LayerNorm(nn.Module): 68 | def __init__(self, size, eps=1e-6): 69 | super(LayerNorm, self).__init__() 70 | self.a_2 = nn.Parameter(torch.ones(size)) 71 | self.b_2 = nn.Parameter(torch.zeros(size)) 72 | self.eps = eps 73 | 74 | def forward(self, x): 75 | mean = x.mean(-1, keepdim=True) 76 | std = x.std(-1, keepdim=True) 77 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 78 | 79 | 80 | class SequencePartition(nn.Module): 81 | def __init__(self, sub_seq_num, emb_size, is_uniform=False): 82 | super(SequencePartition, self).__init__() 83 | self.sub_seq_num = sub_seq_num 84 | self.offset_activation = nn.ReLU() 85 | self.proj = nn.AdaptiveAvgPool1d(sub_seq_num) 86 | self.offset_predictor = nn.Linear(emb_size, 2, bias=False) 87 | self.sub_seq_coder = SubsequenceCoder(sub_seq_num, is_uniform, width_bias=torch.tensor(5./3.).sqrt().log()) 88 | 89 | def forward(self, x): 90 | src = x 91 | x = x.permute(0, 2, 1) 92 | x = self.proj(x) 93 | x = x.permute(0, 2, 1) 94 | pred_offset = self.offset_predictor(self.offset_activation(x)) 95 | sub_seq_wins = self.sub_seq_coder(pred_offset) 96 | sub_seq_wins = sub_seq_wins * src.size(1) 97 | return sub_seq_wins 98 | 99 | def reset_offset(self): 100 | nn.init.constant_(self.offset_predictor.weight, 0) 101 | if hasattr(self.offset_predictor, "bias") and self.offset_predictor.bias is not None: 102 | nn.init.constant_(self.offset_predictor.bias, 0) 103 | 104 | 105 | class SubsequenceCoder(nn.Module): 106 | def __init__(self, sub_seq_num, is_uniform, weights=(1., 1.), width_bias=None): 107 | super(SubsequenceCoder, self).__init__() 108 | self.sub_seq_num = sub_seq_num 109 | self.is_uniform = is_uniform 110 | self._generate_anchor() 111 | self.weights = weights # 2d: center coordinate and length 112 | self.width_bias = None 113 | if width_bias is not None: 114 | self.width_bias = nn.Parameter(width_bias) 115 | 116 | def _generate_anchor(self): 117 | anchors = [] 118 | sub_seq_stride = 1. / self.sub_seq_num 119 | for i in range(self.sub_seq_num): 120 | anchors.append((0.5 + i) * sub_seq_stride) 121 | anchors = torch.as_tensor(anchors) 122 | self.register_buffer("anchor", anchors) 123 | 124 | def forward(self, pred_offset): 125 | if self.is_uniform: 126 | ref_x = self.anchor.unsqueeze(0) 127 | windows = torch.zeros_like(pred_offset) 128 | width = 1 / self.sub_seq_num 129 | windows[:, :, 0] = ref_x - width / 2 130 | windows[:, :, 1] = ref_x + width / 2 131 | else: 132 | if self.width_bias is not None: 133 | pred_offset[:, :, -1] = pred_offset[:, :, -1] + self.width_bias 134 | windows = self.decode(pred_offset) 135 | 136 | windows = windows.clamp(min=0., max=1.) 137 | return windows 138 | 139 | def decode(self, rel_codes): 140 | windows = self.anchor 141 | point = 1. 
/ self.sub_seq_num 142 | w_x, w_width = self.weights 143 | 144 | dx = torch.tanh(rel_codes[:, :, 0] / w_x) * point 145 | dw = F.relu(torch.tanh(rel_codes[:, :, -1] / w_width)) * point 146 | 147 | pred_windows = torch.zeros_like(rel_codes) 148 | ref_x = windows.unsqueeze(0) 149 | pred_windows[:, :, 0] = ref_x + dx - dw 150 | pred_windows[:, :, -1] = ref_x + dx + dw 151 | pred_windows = pred_windows.clamp(min=0., max=1.) 152 | return pred_windows 153 | 154 | 155 | class SiameseEncoder(nn.Module): 156 | def __init__(self, layer, num_layer): 157 | super(SiameseEncoder, self).__init__() 158 | self.layers = clones(layer, num_layer) 159 | self.norm = LayerNorm(layer.emb_size) 160 | 161 | def forward(self, src1, src2, mask): 162 | for layer in self.layers: 163 | src1 = layer(src1, src2, mask) # feed each layer's output into the next 164 | return self.norm(src1) 165 | 166 | 167 | class SiameseEncoderLayer(nn.Module): 168 | def __init__(self, emb_size, hid_size, num_head, dropout): 169 | super(SiameseEncoderLayer, self).__init__() 170 | self.emb_size = emb_size 171 | self.self_attn = MultiHeadAttention(num_head, emb_size) 172 | self.cross_attn = MultiHeadAttention(num_head, emb_size) 173 | self.feed_forward = PositionwiseFeedForward(emb_size, hid_size, dropout) 174 | self.sublayer = clones(SublayerConnection(emb_size, dropout), 3) 175 | 176 | def forward(self, src1, src2, src_mask): 177 | src1 = self.sublayer[0](src1, lambda src1: self.self_attn(src1, src1, src1, src_mask)) 178 | src1 = self.sublayer[1](src1, lambda src1: self.cross_attn(src1, src2, src2, src_mask)) 179 | return self.sublayer[2](src1, self.feed_forward) 180 | 181 | 182 | class PositionalEncoding(nn.Module): 183 | def __init__(self, emb_size, dropout, max_len=5000): 184 | super(PositionalEncoding, self).__init__() 185 | self.dropout = nn.Dropout(p=dropout) 186 | 187 | # Compute the positional encodings once in log space.
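# i.e., PE(pos, 2i) = sin(pos / 10000^(2i/emb_size)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/emb_size)); the exp(-log(10000) * 2i / emb_size) form below computes the frequency term stably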
188 | pe = torch.zeros(max_len, emb_size) 189 | position = torch.arange(0., max_len).unsqueeze(1) 190 | div_term = torch.exp(torch.arange(0., emb_size, 2) * 191 | -(math.log(10000.0) / emb_size)) 192 | pe[:, 0::2] = torch.sin(position * div_term) 193 | pe[:, 1::2] = torch.cos(position * div_term) 194 | pe = pe.unsqueeze(0) 195 | self.register_buffer('pe', pe) 196 | 197 | def forward(self, x): 198 | x = x + self.pe[:, :x.size(-2)] 199 | return self.dropout(x) 200 | 201 | 202 | class Embeddings(nn.Module): 203 | def __init__(self, vocab, emb_size): 204 | super(Embeddings, self).__init__() 205 | self.lut = nn.Embedding(vocab, emb_size) 206 | self.emb_size = emb_size 207 | 208 | def forward(self, x): 209 | return self.lut(x) * math.sqrt(self.emb_size) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from modules import * 2 | 3 | 4 | def get_model(args): 5 | model = UnifiedSSR(u_vocab=args.user_vocab, 6 | p_vocab=args.product_vocab, 7 | t_vocab=args.term_vocab, 8 | emb_size=args.emb_size, 9 | hid_size=args.hid_size, 10 | sub_seq_num=args.sub_seq_num, 11 | enc_num_layer=args.enc_num_layer, 12 | num_head=args.num_head, 13 | tasks=args.tasks, 14 | dropout=args.dropout) 15 | return model 16 | 17 | 18 | class UnifiedSSR(nn.Module): 19 | def __init__(self, u_vocab, p_vocab, t_vocab, emb_size, hid_size, sub_seq_num, enc_num_layer, 20 | num_head, tasks, padding_value=0, dropout=0.1): 21 | super(UnifiedSSR, self).__init__() 22 | hid_size = hid_size or emb_size * 2 23 | self.tasks = tasks 24 | self.p_vocab = p_vocab 25 | self.t_vocab = t_vocab 26 | self.sub_seq_num = sub_seq_num 27 | self.emb_size = emb_size 28 | self.padding_value = padding_value 29 | self.u_embed = Embeddings(u_vocab, emb_size) 30 | self.p_embed = Embeddings(p_vocab, emb_size) 31 | self.position = PositionalEncoding(emb_size, dropout) 32 | self.q_t_embed = Embeddings(t_vocab + 2, emb_size) # additional bos, eos 33 | self.encoder = SiameseEncoder(SiameseEncoderLayer(emb_size, hid_size, num_head, dropout), enc_num_layer) 34 | self.seq_partition = SequencePartition(sub_seq_num, emb_size) 35 | self.next_product_search_w = nn.Parameter(torch.tensor(0.5)) 36 | self.loss = None 37 | 38 | def forward(self, task, inputs): 39 | """ 40 | Shape: 41 | (task == 'recommendation' -> next product prediction) 42 | :return p_enc: [BS, Seq Max Len, Emb Size] 43 | 44 | (task == 'search' -> next product retrieval) 45 | :return p_enc: [BS, Seq Max Len, Emb Size] 46 | :return q_enc: [BS, Seq Max Len, Emb Size] 47 | """ 48 | if task == 'recommendation': 49 | self.loss = self.next_product_predict_loss 50 | p_rep = self.position(self.p_embed(inputs['pids_in']) + self.u_embed(inputs['uid']).unsqueeze(1)) 51 | p_enc = self.encoder(p_rep, p_rep, inputs['pids_mask']) 52 | return p_enc 53 | else: # task == 'search' 54 | self.loss = self.next_product_search_loss 55 | p_rep = self.position(self.p_embed(inputs['pids_in']) + self.u_embed(inputs['uid']).unsqueeze(1)) 56 | q_rep = [[torch.mean(self.q_t_embed(qry), dim=0) for qry in qrys] for qrys in inputs['qrys_in']] 57 | q_rep = torch.stack([torch.stack(q_rep_t) for q_rep_t in q_rep]) 58 | q_rep = self.position(q_rep + self.u_embed(inputs['uid']).unsqueeze(1)) 59 | p_enc = self.encoder(p_rep, q_rep, inputs['pids_mask']) 60 | q_enc = self.encoder(q_rep, p_rep, inputs['qrys_in_mask']) 61 | return p_enc, q_enc 62 | 63 | def next_product_predict_loss(self, seq_emb, mask, p_pos, p_negs): 64 | 
p_pos_emb = self.p_embed(p_pos) 65 | # p_pos_emb [BS*MaxLen, EmbSize] 66 | p_pos_logits = torch.sum(p_pos_emb * seq_emb, -1) 67 | # p_pos_logits [BS*MaxLen] 68 | 69 | p_negs_emb = self.p_embed(p_negs) 70 | # p_negs_emb [BS*MaxLen, NumNeg, EmbSize] 71 | p_negs_logits = torch.sum(p_negs_emb * seq_emb.unsqueeze(1).repeat(1, p_negs_emb.size(1), 1), -1) 72 | # p_negs_logits [BS*MaxLen, NumNeg] 73 | 74 | loss = - torch.sum( 75 | torch.log(p_pos_logits.sigmoid() + 1e-24) * mask + 76 | torch.log(1 - p_negs_logits.sigmoid() + 1e-24).sum(-1) * mask 77 | ) / mask.sum() 78 | 79 | return loss 80 | 81 | def next_product_predict(self, seq_emb, last_idx, p_pred=None): 82 | last_idx = last_idx.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, seq_emb.size(-1)) 83 | seq_last_out = seq_emb.gather(1, last_idx).squeeze(1) # squeeze(1) keeps the batch dim even when BS == 1 84 | if p_pred is not None: 85 | p_emb = self.p_embed(p_pred) 86 | # p_emb [BS, NumNeg+1, EmbSize] 87 | seq_last_out = seq_last_out.unsqueeze(1).repeat(1, p_emb.size(1), 1) 88 | # seq_last_out [BS, NumNeg+1 or PVocab, EmbSize] 89 | pred_logits = torch.sum(p_emb * seq_last_out, -1) 90 | # pred_logits [BS, NumNeg+1 or PVocab] 91 | return pred_logits 92 | else: 93 | p_emb = self.p_embed.lut.weight * math.sqrt(self.emb_size) 94 | # p_emb = self.p_embed.lut.weight 95 | # p_emb [PVocab, EmbSize] 96 | p_emb = p_emb.unsqueeze(0).repeat(seq_emb.size(0), 1, 1) 97 | # p_emb [BS, PVocab, EmbSize] 98 | p_emb_chunks = [p_emb[:, i:i + 5000] for i in range(0, p_emb.size(1), 5000)] 99 | pred_logits = [] 100 | for p_emb_chunk in p_emb_chunks: 101 | if p_emb.device.type == 'cuda': 102 | pred_logits.append( 103 | torch.sum(p_emb_chunk * seq_last_out.unsqueeze(1).repeat(1, p_emb_chunk.size(1), 1), -1).cpu()) 104 | else: 105 | pred_logits.append( 106 | torch.sum(p_emb_chunk * seq_last_out.unsqueeze(1).repeat(1, p_emb_chunk.size(1), 1), -1)) 107 | return torch.cat(pred_logits, dim=1) 108 | 109 | def next_product_search_loss(self, p_seq_emb, q_seq_emb, mask, p_pos, p_negs): 110 | p_pos_emb = self.p_embed(p_pos) 111 | # p_pos_emb [BS*MaxLen, EmbSize] 112 | p_negs_emb = self.p_embed(p_negs) 113 | # p_negs_emb [BS*MaxLen, NumNeg, EmbSize] 114 | 115 | p_pos_sc = torch.sum(p_pos_emb * p_seq_emb, -1) 116 | # p_pos_sc [BS*MaxLen] 117 | p_negs_sc = torch.sum(p_negs_emb * p_seq_emb.unsqueeze(1).repeat(1, p_negs_emb.size(1), 1), -1) 118 | # p_negs_sc [BS*MaxLen, NumNeg] 119 | p_loss = - torch.sum( 120 | torch.log(p_pos_sc.sigmoid() + 1e-24) * mask + 121 | torch.log(1 - p_negs_sc.sigmoid() + 1e-24).sum(-1) * mask 122 | ) / mask.sum() 123 | 124 | q_pos_sc = torch.sum(p_pos_emb * q_seq_emb, -1) 125 | # q_pos_sc [BS*MaxLen] 126 | q_negs_sc = torch.sum(p_negs_emb * q_seq_emb.unsqueeze(1).repeat(1, p_negs_emb.size(1), 1), -1) 127 | # q_negs_sc [BS*MaxLen, NumNeg] 128 | q_loss = - torch.sum( 129 | torch.log(q_pos_sc.sigmoid() + 1e-24) * mask + 130 | torch.log(1 - q_negs_sc.sigmoid() + 1e-24).sum(-1) * mask 131 | ) / mask.sum() 132 | 133 | self.next_product_search_w.data = self.next_product_search_w.clamp(min=0.1, max=0.9) 134 | return self.next_product_search_w * p_loss + (1 - self.next_product_search_w) * q_loss 135 | 136 | def next_product_search(self, p_seq_emb, q_seq_emb, last_idx, p_pred=None): 137 | last_idx = last_idx.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, p_seq_emb.size(-1)) 138 | p_seq_last_out = p_seq_emb.gather(1, last_idx).squeeze(1) # [BS, EmbSize] 139 | q_seq_last_out = q_seq_emb.gather(1, last_idx).squeeze(1) # [BS, EmbSize] 140 | 141 | if p_pred is not None: 142 | p_emb = self.p_embed(p_pred) 143 | # p_emb [BS, NumNeg+1,
EmbSize] 144 | 145 | p_seq_last_out = p_seq_last_out.unsqueeze(1).repeat(1, p_emb.size(1), 1) 146 | # p_seq_last_out [BS, NumNeg+1, EmbSize] 147 | q_seq_last_out = q_seq_last_out.unsqueeze(1).repeat(1, p_emb.size(1), 1) 148 | # q_seq_last_out [BS, NumNeg+1, EmbSize] 149 | p_pred_logits = torch.sum(p_emb * p_seq_last_out, -1) # [BS, NumNeg+1] 150 | q_pred_logits = torch.sum(p_emb * q_seq_last_out, -1) # [BS, NumNeg+1] 151 | return self.next_product_search_w * p_pred_logits + (1 - self.next_product_search_w) * q_pred_logits 152 | else: 153 | p_emb = self.p_embed.lut.weight * math.sqrt(self.emb_size) # [PVocab, EmbSize] 154 | # p_emb = self.p_embed.lut.weight # [PVocab, EmbSize] 155 | p_emb = p_emb.unsqueeze(0).repeat(p_seq_emb.size(0), 1, 1) 156 | # p_emb [BS, PVocab, EmbSize] 157 | p_emb_chunks = [p_emb[:, i:i + 5000] for i in range(0, p_emb.size(1), 5000)] 158 | pred_logits = [] 159 | for p_emb_chunk in p_emb_chunks: 160 | p_pred_logits_ = torch.sum( 161 | p_emb_chunk * p_seq_last_out.unsqueeze(1).repeat(1, p_emb_chunk.size(1), 1), -1) 162 | q_pred_logits_ = torch.sum( 163 | p_emb_chunk * q_seq_last_out.unsqueeze(1).repeat(1, p_emb_chunk.size(1), 1), -1) 164 | pred_logits_ = self.next_product_search_w * p_pred_logits_ + (1 - self.next_product_search_w) * q_pred_logits_ 165 | if p_emb.device.type == 'cuda': 166 | pred_logits.append(pred_logits_.cpu()) 167 | else: 168 | pred_logits.append(pred_logits_) 169 | return torch.cat(pred_logits, dim=1) 170 | 171 | def get_sub_seq_wins(self, emb): 172 | sub_seq_wins = self.seq_partition(emb) # [BS, Sub Seq Num, 2] 173 | return sub_seq_wins 174 | 175 | def intra_corr_loss(self, emb, sub_seq_wins, mask): 176 | len_idx = torch.arange(emb.size(1), device=emb.device).unsqueeze(0) # [1, Seq Max Len] 177 | sub_mask = (sub_seq_wins[:, :, 0:1] <= len_idx) & (len_idx <= sub_seq_wins[:, :, 1:2]) 178 | sub_mask = sub_mask & mask 179 | # sub_mask [BS, Sub Seq Num, Seq Max Len] 180 | sub_mask = sub_mask.unsqueeze(-1).expand(-1, -1, -1, emb.size(-1)) 181 | # sub_mask [BS, Sub Seq Num, Seq Max Len, Emb Size] 182 | sub_seq_rep = emb.unsqueeze(1) * sub_mask.float() 183 | emb = emb + sub_seq_rep.sum(dim=1) / (sub_mask.sum(dim=1) + 1e-10) 184 | 185 | sub_seq_rep = sub_seq_rep.sum(dim=-2) / (sub_mask.sum(dim=-2) + 1e-10) 186 | intra_corr = F.cosine_similarity(sub_seq_rep.unsqueeze(2), sub_seq_rep.unsqueeze(1), dim=-1) 187 | intra_corr = torch.abs(intra_corr) 188 | corr_mask = torch.triu(torch.ones((1, sub_seq_rep.size(1), sub_seq_rep.size(1)), device=sub_seq_rep.device), 189 | diagonal=1).bool() 190 | corr_mask = corr_mask & ~torch.triu( 191 | torch.ones((1, sub_seq_rep.size(1), sub_seq_rep.size(1)), device=sub_seq_rep.device), diagonal=2).bool() 192 | intra_corr = intra_corr * corr_mask.float() 193 | intra_corr_loss = intra_corr.sum() / (intra_corr.nonzero().size(0) + 1e-10) 194 | return emb, sub_seq_rep, intra_corr_loss 195 | 196 | def inter_corr_loss(self, p_emb, p_sub_seq_wins, q_emb, q_sub_seq_wins, mask): 197 | p_emb, p_sub_seq_rep, p_intra_corr_loss = self.intra_corr_loss(p_emb, p_sub_seq_wins, mask) 198 | q_emb, q_sub_seq_rep, q_intra_corr_loss = self.intra_corr_loss(q_emb, q_sub_seq_wins, mask) 199 | inter_corr = F.cosine_similarity(p_sub_seq_rep, q_sub_seq_rep, dim=-1) 200 | inter_corr_loss = inter_corr.sum() / (inter_corr.nonzero().size(0) + 1e-10) 201 | inter_corr_loss = p_intra_corr_loss + q_intra_corr_loss - inter_corr_loss 202 | return p_emb, q_emb, inter_corr_loss 203 | -------------------------------------------------------------------------------- 
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. 
Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/batch.py:
--------------------------------------------------------------------------------
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Sampler
from torch.nn.utils.rnn import pad_sequence


def subsequent_mask(size):
    """
    Build a (1, size, size) causal attention mask, where size = max_len.
    Entry (i, j) is True iff j <= i, i.e. each position may attend only
    to itself and earlier positions (the lower triangle).
    """
    return torch.tril(torch.ones((1, size, size)), diagonal=0).bool()


def neg_sampling(batch_data, num_neg, vocab):
    """
    Sample `num_neg` distinct negatives per positive id, uniformly from
    {1, ..., vocab - 1}, excluding the positive itself (index 0 is reserved
    for padding). Nested lists are handled recursively, so the output
    mirrors the nesting of the input.
    """
    batch_neg = []
    for data in batch_data:
        if isinstance(data, list):
            batch_neg.append(neg_sampling(data, num_neg, vocab))
        else:
            sampled_negs = []
            for _ in range(num_neg):
                neg = np.random.randint(1, vocab)
                while neg == data or neg in sampled_negs:
                    neg = np.random.randint(1, vocab)
                sampled_negs.append(neg)
            batch_neg.append(sampled_negs)
    return batch_neg


def collate_pretrain(batch_data, args):
    num_neg = args.train_num_neg
    p_vocab = args.product_vocab
    padding_value = args.padding_value
    uids = [int(data['uid']) for data in batch_data]
    uids = torch.tensor(uids).long()
    if batch_data[0]['flag'] == 'recommendation':
        # Shifted input/target pairs: predict pid_list[t] from pid_list[:t].
        pids_in = [torch.tensor(data['pid_list'][:-1]).long() for data in batch_data]
        pids_in = pad_sequence(pids_in, batch_first=True, padding_value=padding_value)
        pids_out = [data['pid_list'][1:] for data in batch_data]
        pids_out_neg = neg_sampling(pids_out, num_neg, p_vocab)
        pids_out = pad_sequence([torch.tensor(pid_out).long() for pid_out in pids_out],
                                batch_first=True, padding_value=padding_value)
        pids_out_neg = pad_sequence([torch.tensor(pid_neg).long() for pid_neg in pids_out_neg],
                                    batch_first=True, padding_value=padding_value)
        pids_mask = (pids_in != padding_value).unsqueeze(-2)
        return 'recommendation', {'uid': uids,
                                  'pids_in': pids_in,
                                  'pids_mask': pids_mask,
                                  'pids_tgt': pids_out,
                                  'pids_neg': pids_out_neg}

    else:  # 'search'
        pids_in = [torch.tensor(data['pid_list'][:-1]).long() for data in batch_data]
        pids_in = pad_sequence(pids_in, batch_first=True, padding_value=padding_value)
        # Prepend one padding column so the item inputs line up with the
        # query sequence built below.
        pids_in = F.pad(pids_in, (1, 0), value=padding_value)
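        # Toy illustration of the alignment (hypothetical values): for
        # pid_list = [p1, p2, p3], pids_in becomes [PAD, p1, p2], matching the
        # length of the full qry_list kept below (assuming one query per
        # interaction), so the query issued at step t is paired only with the
        # items interacted with strictly before step t.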
        pids_out = [data['pid_list'][1:] for data in batch_data]
        pids_out_neg = neg_sampling(pids_out, num_neg, p_vocab)
        pids_out = pad_sequence([torch.tensor(pid_out).long() for pid_out in pids_out], batch_first=True, padding_value=padding_value)
        pids_out_neg = pad_sequence([torch.tensor(pid_neg).long() for pid_neg in pids_out_neg], batch_first=True, padding_value=padding_value)
        pids_mask = (pids_in != padding_value).unsqueeze(-2)

        qrys_in = [data['qry_list'] for data in batch_data]  # [BS, Seq Len, Qry Len]
        qrys_in_seq_max_len = max(len(qrys) for qrys in qrys_in)
        # Sequence-level mask: 1 for real queries, 0 for padded slots. Note the
        # comparison with padding_value below assumes padding_value == 0.
        qrys_in_mask = torch.stack([torch.cat([torch.ones(1)] * len(qrys) + [torch.zeros(1)] * (qrys_in_seq_max_len - len(qrys))) for qrys in qrys_in])
        qrys_in_mask = (qrys_in_mask != padding_value).unsqueeze(-2)
        # qrys_in_mask: [BS, 1, Seq Max Len]
        qrys_in = [[torch.tensor(qry).long() for qry in qrys] + [torch.zeros(1).long()] * (qrys_in_seq_max_len - len(qrys)) for qrys in qrys_in]  # pad with dummy queries
        # qrys_in: [BS, Seq Max Len, Qry Len]

        return 'search', {'uid': uids,
                          'pids_in': pids_in,
                          'pids_mask': pids_mask,
                          'pids_tgt': pids_out,
                          'pids_neg': pids_out_neg,
                          'qrys_in': qrys_in,
                          'qrys_in_mask': qrys_in_mask}


def collate_train(batch_data, args):
    num_neg = args.train_num_neg
    p_vocab = args.product_vocab
    padding_value = args.padding_value
    uids = [int(data['uid']) for data in batch_data]
    uids = torch.tensor(uids).long()
    if batch_data[0]['flag'] == 'recommendation':
        # Leave-one-out split: the last two interactions are held out for
        # validation and test, so training inputs stop at pid_list[:-3] and
        # the targets are the corresponding next items, pid_list[1:-2].
        pids_in = [torch.tensor(data['pid_list'][:-3]).long() for data in batch_data]
        pids_in = pad_sequence(pids_in, batch_first=True, padding_value=padding_value)
        pids_out = [data['pid_list'][1:-2] for data in batch_data]
        pids_out_neg = neg_sampling(pids_out, num_neg, p_vocab)
        pids_out = pad_sequence([torch.tensor(pid_out).long() for pid_out in pids_out], batch_first=True, padding_value=padding_value)
        pids_out_neg = pad_sequence([torch.tensor(pid_neg).long() for pid_neg in pids_out_neg], batch_first=True, padding_value=padding_value)
        pids_mask = (pids_in != padding_value).unsqueeze(-2)
        return 'recommendation', {'uid': uids,
                                  'pids_in': pids_in,
                                  'pids_mask': pids_mask,
                                  'pids_tgt': pids_out,
                                  'pids_neg': pids_out_neg}

    else:  # 'search'
        pids_in = [torch.tensor(data['pid_list'][:-3]).long() for data in batch_data]
        pids_in = pad_sequence(pids_in, batch_first=True, padding_value=padding_value)
        pids_in = F.pad(pids_in, (1, 0), value=padding_value)
        pids_out = [data['pid_list'][1:-2] for data in batch_data]
        pids_out_neg = neg_sampling(pids_out, num_neg, p_vocab)
        pids_out = pad_sequence([torch.tensor(pid_out).long() for pid_out in pids_out], batch_first=True, padding_value=padding_value)
        pids_out_neg = pad_sequence([torch.tensor(pid_neg).long() for pid_neg in pids_out_neg], batch_first=True, padding_value=padding_value)
        pids_mask = (pids_in != padding_value).unsqueeze(-2)

        qrys_in = [data['qry_list'][:-2] for data in batch_data]
        qrys_in_seq_max_len = max(len(qrys) for qrys in qrys_in)
        qrys_in_mask = torch.stack([torch.cat([torch.ones(1)] * len(qrys) + [torch.zeros(1)] * (qrys_in_seq_max_len - len(qrys))) for qrys in qrys_in])
        qrys_in_mask = (qrys_in_mask != padding_value).unsqueeze(-2)
        qrys_in = [[torch.tensor(qry).long() for qry in qrys] + [torch.zeros(1).long()] * (qrys_in_seq_max_len - len(qrys)) for qrys in qrys_in]
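        # qrys_in is left as a nested list ([BS][Seq Max Len] of 1-D query
        # tensors) rather than a single padded tensor, since individual
        # queries have ragged token lengths; downstream code presumably
        # embeds and pools each query before use.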

        return 'search', {'uid': uids,
                          'pids_in': pids_in,
                          'pids_mask': pids_mask,
                          'pids_tgt': pids_out,
                          'pids_neg': pids_out_neg,
                          'qrys_in': qrys_in,
                          'qrys_in_mask': qrys_in_mask}


def collate_val(batch_data, args):
    num_neg = args.test_num_neg
    p_vocab = args.product_vocab
    padding_value = args.padding_value
    uids = [int(data['uid']) for data in batch_data]
    uids = torch.tensor(uids).long()
    if batch_data[0]['flag'] == 'recommendation':
        # Validation predicts the second-to-last item, pid_list[-2], from the
        # prefix pid_list[:-2]; the last item stays held out for test.
        pids_in = [torch.tensor(data['pid_list'][:-2]).long() for data in batch_data]
        pids_in = pad_sequence(pids_in, batch_first=True, padding_value=padding_value)
        pids_mask = (pids_in != padding_value).unsqueeze(-2)
        pid_last_idx = torch.tensor([data['length'] - 3 for data in batch_data])

        pid_tgt = [data['pid_list'][-2] for data in batch_data]
        pid_tgt_neg = neg_sampling(pid_tgt, num_neg, p_vocab)
        pid_tgt = torch.tensor(pid_tgt).long()
        pid_tgt_neg = torch.tensor(pid_tgt_neg).long()
        # Candidate list for ranking: the true target first, then the negatives.
        pid_pred = torch.cat((pid_tgt.unsqueeze(-1), pid_tgt_neg), -1)

        return 'recommendation', {'uid': uids,
                                  'pids_in': pids_in,
                                  'pids_mask': pids_mask,
                                  'pid_pred': pid_pred,
                                  'pid_last_idx': pid_last_idx}

    else:  # 'search'
        pids_in = [torch.tensor(data['pid_list'][:-2]).long() for data in batch_data]
        pids_in = pad_sequence(pids_in, batch_first=True, padding_value=padding_value)
        pids_in = F.pad(pids_in, (1, 0), value=padding_value)
        pid_tgt = [data['pid_list'][-2] for data in batch_data]
        pid_tgt_neg = neg_sampling(pid_tgt, num_neg, p_vocab)
        pid_tgt = torch.tensor(pid_tgt).long()
        pid_tgt_neg = torch.tensor(pid_tgt_neg).long()
        pid_pred = torch.cat((pid_tgt.unsqueeze(-1), pid_tgt_neg), -1)
        pids_mask = (pids_in != padding_value).unsqueeze(-2)
        pid_last_idx = torch.tensor([data['length'] - 3 for data in batch_data])

        qrys_in = [data['qry_list'][:-1] for data in batch_data]
        qrys_in_seq_max_len = max(len(qrys) for qrys in qrys_in)
        qrys_in_mask = torch.stack([torch.cat([torch.ones(1)] * len(qrys) + [torch.zeros(1)] * (qrys_in_seq_max_len - len(qrys))) for qrys in qrys_in])
        qrys_in_mask = (qrys_in_mask != padding_value).unsqueeze(-2)
        qrys_in = [[torch.tensor(qry).long() for qry in qrys] + [torch.zeros(1).long()] * (qrys_in_seq_max_len - len(qrys)) for qrys in qrys_in]

        return 'search', {'uid': uids,
                          'pids_in': pids_in,
                          'pids_mask': pids_mask,
                          'pid_pred': pid_pred,
                          'pid_last_idx': pid_last_idx,
                          'qrys_in': qrys_in,
                          'qrys_in_mask': qrys_in_mask}


def collate_test(batch_data, args):
    num_neg = args.test_num_neg
    p_vocab = args.product_vocab
    padding_value = args.padding_value
    uids = [int(data['uid']) for data in batch_data]
    uids = torch.tensor(uids).long()
    if batch_data[0]['flag'] == 'recommendation':
        # Test predicts the last item, pid_list[-1], from pid_list[:-1].
        pids_in = [torch.tensor(data['pid_list'][:-1]).long() for data in batch_data]
        pids_in = pad_sequence(pids_in, batch_first=True, padding_value=padding_value)
        pids_mask = (pids_in != padding_value).unsqueeze(-2)
        pid_last_idx = torch.tensor([data['length'] - 2 for data in batch_data])

        pid_tgt = [data['pid_list'][-1] for data in batch_data]
        pid_tgt_neg = neg_sampling(pid_tgt, num_neg, p_vocab)
        pid_tgt = torch.tensor(pid_tgt).long()
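        # As in collate_val, the target is ranked against num_neg sampled
        # negatives: pid_pred below holds 1 + num_neg candidate ids per user,
        # true item first, so ranking metrics can read off its position.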
        pid_tgt_neg = torch.tensor(pid_tgt_neg).long()
        pid_pred = torch.cat((pid_tgt.unsqueeze(-1), pid_tgt_neg), -1)

        return 'recommendation', {'uid': uids,
                                  'pids_in': pids_in,
                                  'pids_mask': pids_mask,
                                  'pid_pred': pid_pred,
                                  'pid_last_idx': pid_last_idx}

    else:  # 'search'
        pids_in = [torch.tensor(data['pid_list'][:-1]).long() for data in batch_data]
        pids_in = pad_sequence(pids_in, batch_first=True, padding_value=padding_value)
        pids_in = F.pad(pids_in, (1, 0), value=padding_value)
        pid_tgt = [data['pid_list'][-1] for data in batch_data]
        pid_tgt_neg = neg_sampling(pid_tgt, num_neg, p_vocab)
        pid_tgt = torch.tensor(pid_tgt).long()
        pid_tgt_neg = torch.tensor(pid_tgt_neg).long()
        pid_pred = torch.cat((pid_tgt.unsqueeze(-1), pid_tgt_neg), -1)
        pids_mask = (pids_in != padding_value).unsqueeze(-2)
        pid_last_idx = torch.tensor([data['length'] - 2 for data in batch_data])

        qrys_in = [data['qry_list'] for data in batch_data]
        qrys_in_seq_max_len = max(len(qrys) for qrys in qrys_in)
        qrys_in_mask = torch.stack([torch.cat([torch.ones(1)] * len(qrys) + [torch.zeros(1)] * (qrys_in_seq_max_len - len(qrys))) for qrys in qrys_in])
        qrys_in_mask = (qrys_in_mask != padding_value).unsqueeze(-2)
        qrys_in = [[torch.tensor(qry).long() for qry in qrys] + [torch.zeros(1).long()] * (qrys_in_seq_max_len - len(qrys)) for qrys in qrys_in]

        return 'search', {'uid': uids,
                          'pids_in': pids_in,
                          'pids_mask': pids_mask,
                          'pid_pred': pid_pred,
                          'pid_last_idx': pid_last_idx,
                          'qrys_in': qrys_in,
                          'qrys_in_mask': qrys_in_mask}


class BatchSampler(Sampler):
    """
    Yields homogeneous batches: every batch contains either recommendation
    samples or search samples, never a mixture, so each collate function can
    dispatch on the first sample's 'flag'. Within each task, samples are
    randomly assigned to pools of batch_size * 100 and sorted by length
    inside each pool, which keeps padding small while still randomizing
    batch composition and order.
    """

    def __init__(self, data_source, batch_size):
        super(BatchSampler, self).__init__(data_source)
        self.recommendation_data_len = []
        self.search_data_len = []
        self.batch_size = batch_size
        self.pool_size = batch_size * 100
        for idx, data in enumerate(data_source):
            if data['flag'] == 'recommendation':
                self.recommendation_data_len.append((idx, data['length']))
            else:  # data['flag'] == 'search'
                self.search_data_len.append((idx, data['length']))

    def __iter__(self):
        batch_idxes_list = []
        if self.recommendation_data_len:
            recommendation_ori_idx, recommendation_len = zip(*self.recommendation_data_len)
            recommendation_idx = zip(recommendation_len, np.random.permutation(len(self.recommendation_data_len)), recommendation_ori_idx)
            # Bucket by shuffled position, then sort by length within each bucket.
            recommendation_idx = sorted(recommendation_idx, key=lambda e: (e[1] // self.pool_size, e[0]), reverse=True)
            for i in range(0, len(recommendation_idx), self.batch_size):
                batch_idxes_list.append([recommendation_idx_[2] for recommendation_idx_ in recommendation_idx[i:i + self.batch_size]])
        if self.search_data_len:
            search_ori_idx, search_len = zip(*self.search_data_len)
            search_idx = zip(search_len, np.random.permutation(len(self.search_data_len)), search_ori_idx)
            search_idx = sorted(search_idx, key=lambda e: (e[1] // self.pool_size, e[0]), reverse=True)
            for i in range(0, len(search_idx), self.batch_size):
                batch_idxes_list.append([search_idx_[2] for search_idx_ in search_idx[i:i + self.batch_size]])
        random.shuffle(batch_idxes_list)
        for batch_idxes in batch_idxes_list:
            yield batch_idxes

    def __len__(self):
        # Each task contributes ceil(num_samples / batch_size) batches.
        recommendation_batches = (len(self.recommendation_data_len) + self.batch_size - 1) // self.batch_size
        search_batches = (len(self.search_data_len) + self.batch_size - 1) // self.batch_size
        return recommendation_batches + search_batches


if __name__ == "__main__":
    # for test only
    pass
--------------------------------------------------------------------------------
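As a quick sanity check of the pieces above, here is a minimal smoke test. It is not part of the repository: the toy samples, the Namespace fields, and the vocabulary size are illustrative assumptions (only the fields that collate_test actually reads are set). It wires BatchSampler and collate_test from batch.py into a standard torch DataLoader:

import torch
from argparse import Namespace
from functools import partial
from torch.utils.data import DataLoader

# Hypothetical toy samples in the dict format the collate functions expect.
toy_data = [
    {'uid': 1, 'flag': 'recommendation', 'pid_list': [3, 7, 2, 9, 4], 'length': 5},
    {'uid': 2, 'flag': 'recommendation', 'pid_list': [5, 1, 8], 'length': 3},
]
# Hypothetical args: 4 test negatives, 10 product ids, padding index 0.
args = Namespace(test_num_neg=4, product_vocab=10, padding_value=0)

loader = DataLoader(toy_data,
                    batch_sampler=BatchSampler(toy_data, batch_size=2),
                    collate_fn=partial(collate_test, args=args))
for flag, batch in loader:
    # Prints: recommendation torch.Size([2, 4]) torch.Size([2, 5])
    print(flag, batch['pids_in'].shape, batch['pid_pred'].shape)

collate_pretrain and collate_train are wired the same way but read args.train_num_neg instead of args.test_num_neg.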