├── .gitignore ├── README.md ├── data └── amazon │ └── .gitkeep ├── eval.py ├── main.py ├── models ├── a_llmrec_model.py ├── llm4rec.py └── recsys_model.py ├── pre_train ├── ctrl │ └── model_ctrl.py └── sasrec │ ├── data_preprocess.py │ ├── main.py │ ├── model.py │ └── utils.py ├── requirements.txt ├── train_model.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # models 2 | *.pth 3 | *.pt 4 | *.json.gz 5 | *.bin 6 | *.model 7 | 8 | # data 9 | *.json 10 | *.txt 11 | */data/* 12 | 13 | # etc, cache 14 | .DS_Store 15 | __pycache__/ 16 | *.out 17 | *.pyc 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A-LLMRec : Large Language Models meet Collaborative Filtering: An Efficient All-round LLM-based Recommender System 2 | 3 | The source code for the paper "A-LLMRec : Large Language Models meet Collaborative Filtering: An Efficient All-round LLM-based Recommender System", accepted at **KDD 2024**. 4 | 5 | ## Overview 6 | In this [paper](https://arxiv.org/abs/2404.11343), we propose an efficient all-round LLM-based recommender system, called A-LLMRec (All-round LLM-based Recommender system). The main idea is to enable an LLM to directly leverage the collaborative knowledge contained in a pre-trained collaborative filtering recommender system (CF-RecSys) so that the emergent abilities of the LLM can be exploited jointly with that knowledge. As a result, A-LLMRec can outperform existing methods in various scenarios, including warm/cold, few-shot, cold-user, and cross-domain scenarios. 
7 | 8 | ## Env Setting 9 | ``` 10 | conda create -n [env name] python=3.10 pip 11 | conda install pytorch==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia 12 | conda install numpy=1.26.3 13 | conda install tqdm 14 | conda install pytz 15 | conda install transformers=4.32.1 16 | pip install sentence-transformers==2.2.2 17 | conda install conda-forge::accelerate=0.25.0 18 | conda install conda-forge::bitsandbytes=0.42.0 19 | ``` 20 | 21 | ## Dataset 22 | Download the [2018 Amazon Review dataset](https://cseweb.ucsd.edu/~jmcauley/datasets/amazon_v2/) for the experiments. Download both the metadata and the review files and place them in the data/amazon directory. 23 | 24 | ## Pre-train CF-RecSys (SASRec) 25 | ``` 26 | cd pre_train/sasrec 27 | python main.py --device=cuda --dataset Movies_and_TV 28 | ``` 29 | 30 | ## A-LLMRec Train 31 | - train stage1 32 | ``` 33 | cd ../../ 34 | python main.py --pretrain_stage1 --rec_pre_trained_data Movies_and_TV 35 | ``` 36 | 37 | - train stage2 38 | ``` 39 | python main.py --pretrain_stage2 --rec_pre_trained_data Movies_and_TV 40 | ``` 41 | 42 | To run in a multi-GPU setting, select the devices with the CUDA_VISIBLE_DEVICES environment variable and add the '--multi_gpu' argument. 43 | - e.g. CUDA_VISIBLE_DEVICES=0,1 python main.py ... --multi_gpu (no spaces around '=') 44 | 45 | 46 | 47 | ## Evaluation 48 | The inference stage appends the prompts, ground-truth answers and LLM recommendations to the "recommendation_output.txt" file (remove any previous file before a new run). To evaluate the results, run eval.py. 
49 | 50 | ``` 51 | python main.py --inference --rec_pre_trained_data Movies_and_TV 52 | python eval.py 53 | ``` 54 | -------------------------------------------------------------------------------- /data/amazon/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghdtjr/A-LLMRec/e216bbc3d0c9eec548f2d9667bdeca8019f29af1/data/amazon/.gitkeep -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def get_answers_predictions(file_path): 4 | answers = [] 5 | llm_predictions = [] 6 | with open(file_path, 'r') as f: 7 | for line in f: 8 | if 'Answer:' == line[:len('Answer:')]: 9 | answer = line.replace('Answer:', '').strip()[1:-1].lower() 10 | answers.append(answer) 11 | if 'LLM:' == line[:len('LLM:')]: 12 | llm_prediction = line.replace('LLM', '').strip().lower() 13 | try: 14 | llm_prediction = llm_prediction.replace("\"item title\" : ", '') 15 | start = llm_prediction.find('"') 16 | end = llm_prediction.rfind('"') 17 | 18 | if (start + end < start) or (start + end < end): 19 | print(1/0) 20 | 21 | llm_prediction = llm_prediction[start+1:end] 22 | except Exception as e: 23 | print() 24 | 25 | llm_predictions.append(llm_prediction) 26 | 27 | return answers, llm_predictions 28 | 29 | def evaluate(answers, llm_predictions, k=1): 30 | NDCG = 0.0 31 | HT = 0.0 32 | predict_num = len(answers) 33 | print(predict_num) 34 | for answer, prediction in zip(answers, llm_predictions): 35 | if k > 1: 36 | rank = prediction.index(answer) 37 | if rank < k: 38 | NDCG += 1 / np.log2(rank + 1) 39 | HT += 1 40 | elif k == 1: 41 | if answer in prediction: 42 | NDCG += 1 43 | HT += 1 44 | 45 | return NDCG / predict_num, HT / predict_num 46 | 47 | if __name__ == "__main__": 48 | inferenced_file_path = './recommendation_output.txt' 49 | answers, llm_predictions = 
get_answers_predictions(inferenced_file_path) 50 | print(len(answers), len(llm_predictions)) 51 | assert(len(answers) == len(llm_predictions)) 52 | 53 | ndcg, ht = evaluate(answers, llm_predictions, k=1) 54 | print(f"ndcg at 1: {ndcg}") 55 | print(f"hit at 1: {ht}") -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | 5 | from utils import * 6 | from train_model import * 7 | 8 | from pre_train.sasrec.data_preprocess import preprocess 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | 13 | # GPU train options 14 | parser.add_argument("--multi_gpu", action='store_true') 15 | parser.add_argument('--gpu_num', type=int, default=0) 16 | 17 | # model setting 18 | parser.add_argument("--llm", type=str, default='opt', help='flan_t5, opt, vicuna') 19 | parser.add_argument("--recsys", type=str, default='sasrec') 20 | 21 | # dataset setting 22 | parser.add_argument("--rec_pre_trained_data", type=str, default='Movies_and_TV') 23 | 24 | # train phase setting 25 | parser.add_argument("--pretrain_stage1", action='store_true') 26 | parser.add_argument("--pretrain_stage2", action='store_true') 27 | parser.add_argument("--inference", action='store_true') 28 | 29 | # hyperparameters options 30 | parser.add_argument('--batch_size1', default=32, type=int) 31 | parser.add_argument('--batch_size2', default=2, type=int) 32 | parser.add_argument('--batch_size_infer', default=2, type=int) 33 | parser.add_argument('--maxlen', default=50, type=int) 34 | parser.add_argument('--num_epochs', default=10, type=int) 35 | parser.add_argument("--stage1_lr", type=float, default=0.0001) 36 | parser.add_argument("--stage2_lr", type=float, default=0.0001) 37 | 38 | args = parser.parse_args() 39 | 40 | args.device = 'cuda:' + str(args.gpu_num) 41 | 42 | if args.pretrain_stage1: 43 | train_model_phase1(args) 
44 | elif args.pretrain_stage2: 45 | train_model_phase2(args) 46 | elif args.inference: 47 | inference(args) -------------------------------------------------------------------------------- /models/a_llmrec_model.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pickle 3 | 4 | import torch 5 | from torch.cuda.amp import autocast as autocast 6 | import torch.nn as nn 7 | import numpy as np 8 | 9 | from models.recsys_model import * 10 | from models.llm4rec import * 11 | from sentence_transformers import SentenceTransformer 12 | 13 | 14 | class two_layer_mlp(nn.Module): 15 | def __init__(self, dims): 16 | super().__init__() 17 | self.fc1 = nn.Linear(dims, 128) 18 | self.fc2 = nn.Linear(128, dims) 19 | self.sigmoid = nn.Sigmoid() 20 | 21 | def forward(self, x): 22 | x = self.fc1(x) 23 | x = self.sigmoid(x) 24 | x1 = self.fc2(x) 25 | return x, x1 26 | 27 | class A_llmrec_model(nn.Module): 28 | def __init__(self, args): 29 | super().__init__() 30 | rec_pre_trained_data = args.rec_pre_trained_data 31 | self.args = args 32 | self.device = args.device 33 | 34 | with open(f'./data/amazon/{args.rec_pre_trained_data}_text_name_dict.json.gz','rb') as ft: 35 | self.text_name_dict = pickle.load(ft) 36 | 37 | self.recsys = RecSys(args.recsys, rec_pre_trained_data, self.device) 38 | self.item_num = self.recsys.item_num 39 | self.rec_sys_dim = self.recsys.hidden_units 40 | self.sbert_dim = 768 41 | 42 | self.mlp = two_layer_mlp(self.rec_sys_dim) 43 | if args.pretrain_stage1: 44 | self.sbert = SentenceTransformer('nq-distilbert-base-v1') 45 | self.mlp2 = two_layer_mlp(self.sbert_dim) 46 | 47 | self.mse = nn.MSELoss() 48 | 49 | self.maxlen = args.maxlen 50 | self.NDCG = 0 51 | self.HIT = 0 52 | self.rec_NDCG = 0 53 | self.rec_HIT = 0 54 | self.lan_NDCG=0 55 | self.lan_HIT=0 56 | self.num_user = 0 57 | self.yes = 0 58 | 59 | self.bce_criterion = torch.nn.BCEWithLogitsLoss() 60 | 61 | if args.pretrain_stage2 or 
args.inference: 62 | self.llm = llm4rec(device=self.device, llm_model=args.llm) 63 | 64 | self.log_emb_proj = nn.Sequential( 65 | nn.Linear(self.rec_sys_dim, self.llm.llm_model.config.hidden_size), 66 | nn.LayerNorm(self.llm.llm_model.config.hidden_size), 67 | nn.LeakyReLU(), 68 | nn.Linear(self.llm.llm_model.config.hidden_size, self.llm.llm_model.config.hidden_size) 69 | ) 70 | nn.init.xavier_normal_(self.log_emb_proj[0].weight) 71 | nn.init.xavier_normal_(self.log_emb_proj[3].weight) 72 | 73 | self.item_emb_proj = nn.Sequential( 74 | nn.Linear(128, self.llm.llm_model.config.hidden_size), 75 | nn.LayerNorm(self.llm.llm_model.config.hidden_size), 76 | nn.GELU(), 77 | nn.Linear(self.llm.llm_model.config.hidden_size, self.llm.llm_model.config.hidden_size) 78 | ) 79 | nn.init.xavier_normal_(self.item_emb_proj[0].weight) 80 | nn.init.xavier_normal_(self.item_emb_proj[3].weight) 81 | 82 | def save_model(self, args, epoch1=None, epoch2=None): 83 | out_dir = f'./models/saved_models/' 84 | create_dir(out_dir) 85 | out_dir += f'{args.rec_pre_trained_data}_{args.recsys}_{epoch1}_' 86 | if args.pretrain_stage1: 87 | torch.save(self.sbert.state_dict(), out_dir + 'sbert.pt') 88 | torch.save(self.mlp.state_dict(), out_dir + 'mlp.pt') 89 | torch.save(self.mlp2.state_dict(), out_dir + 'mlp2.pt') 90 | 91 | out_dir += f'{args.llm}_{epoch2}_' 92 | if args.pretrain_stage2: 93 | torch.save(self.log_emb_proj.state_dict(), out_dir + 'log_proj.pt') 94 | torch.save(self.item_emb_proj.state_dict(), out_dir + 'item_proj.pt') 95 | 96 | def load_model(self, args, phase1_epoch=None, phase2_epoch=None): 97 | out_dir = f'./models/saved_models/{args.rec_pre_trained_data}_{args.recsys}_{phase1_epoch}_' 98 | 99 | mlp = torch.load(out_dir + 'mlp.pt', map_location = args.device) 100 | self.mlp.load_state_dict(mlp) 101 | del mlp 102 | for name, param in self.mlp.named_parameters(): 103 | param.requires_grad = False 104 | 105 | if args.inference: 106 | out_dir += f'{args.llm}_{phase2_epoch}_' 107 | 108 
| log_emb_proj_dict = torch.load(out_dir + 'log_proj.pt', map_location = args.device) 109 | self.log_emb_proj.load_state_dict(log_emb_proj_dict) 110 | del log_emb_proj_dict 111 | 112 | item_emb_proj_dict = torch.load(out_dir + 'item_proj.pt', map_location = args.device) 113 | self.item_emb_proj.load_state_dict(item_emb_proj_dict) 114 | del item_emb_proj_dict 115 | 116 | def find_item_text(self, item, title_flag=True, description_flag=True): 117 | t = 'title' 118 | d = 'description' 119 | t_ = 'No Title' 120 | d_ = 'No Description' 121 | if title_flag and description_flag: 122 | return [f'"{self.text_name_dict[t].get(i,t_)}, {self.text_name_dict[d].get(i,d_)}"' for i in item] 123 | elif title_flag and not description_flag: 124 | return [f'"{self.text_name_dict[t].get(i,t_)}"' for i in item] 125 | elif not title_flag and description_flag: 126 | return [f'"{self.text_name_dict[d].get(i,d_)}"' for i in item] 127 | 128 | def find_item_text_single(self, item, title_flag=True, description_flag=True): 129 | t = 'title' 130 | d = 'description' 131 | t_ = 'No Title' 132 | d_ = 'No Description' 133 | if title_flag and description_flag: 134 | return f'"{self.text_name_dict[t].get(item,t_)}, {self.text_name_dict[d].get(item,d_)}"' 135 | elif title_flag and not description_flag: 136 | return f'"{self.text_name_dict[t].get(item,t_)}"' 137 | elif not title_flag and description_flag: 138 | return f'"{self.text_name_dict[d].get(item,d_)}"' 139 | 140 | def get_item_emb(self, item_ids): 141 | with torch.no_grad(): 142 | item_embs = self.recsys.model.item_emb(torch.LongTensor(item_ids).to(self.device)) 143 | item_embs, _ = self.mlp(item_embs) 144 | 145 | return item_embs 146 | 147 | def forward(self, data, optimizer=None, batch_iter=None, mode='phase1'): 148 | if mode == 'phase1': 149 | self.pre_train_phase1(data, optimizer, batch_iter) 150 | if mode == 'phase2': 151 | self.pre_train_phase2(data, optimizer, batch_iter) 152 | if mode =='generate': 153 | self.generate(data) 154 | 155 | 
def pre_train_phase1(self,data,optimizer, batch_iter): 156 | epoch, total_epoch, step, total_step = batch_iter 157 | 158 | self.sbert.train() 159 | optimizer.zero_grad() 160 | 161 | u, seq, pos, neg = data 162 | indices = [self.maxlen*(i+1)-1 for i in range(u.shape[0])] 163 | 164 | with torch.no_grad(): 165 | log_emb, pos_emb, neg_emb = self.recsys.model(u, seq, pos, neg, mode='item') 166 | 167 | log_emb_ = log_emb[indices] 168 | pos_emb_ = pos_emb[indices] 169 | neg_emb_ = neg_emb[indices] 170 | pos_ = pos.reshape(pos.size)[indices] 171 | neg_ = neg.reshape(neg.size)[indices] 172 | 173 | start_inx = 0 174 | end_inx = 60 175 | iterss = 0 176 | mean_loss = 0 177 | bpr_loss = 0 178 | gt_loss = 0 179 | rc_loss = 0 180 | text_rc_loss = 0 181 | original_loss = 0 182 | while start_inx < len(log_emb_): 183 | log_emb = log_emb_[start_inx:end_inx] 184 | pos_emb = pos_emb_[start_inx:end_inx] 185 | neg_emb = neg_emb_[start_inx:end_inx] 186 | 187 | pos__ = pos_[start_inx:end_inx] 188 | neg__ = neg_[start_inx:end_inx] 189 | 190 | start_inx = end_inx 191 | end_inx += 60 192 | iterss +=1 193 | 194 | pos_text = self.find_item_text(pos__) 195 | neg_text = self.find_item_text(neg__) 196 | 197 | pos_token = self.sbert.tokenize(pos_text) 198 | pos_text_embedding= self.sbert({'input_ids':pos_token['input_ids'].to(self.device),'attention_mask':pos_token['attention_mask'].to(self.device)})['sentence_embedding'] 199 | neg_token = self.sbert.tokenize(neg_text) 200 | neg_text_embedding= self.sbert({'input_ids':neg_token['input_ids'].to(self.device),'attention_mask':neg_token['attention_mask'].to(self.device)})['sentence_embedding'] 201 | 202 | pos_text_matching, pos_proj = self.mlp(pos_emb) 203 | neg_text_matching, neg_proj = self.mlp(neg_emb) 204 | 205 | pos_text_matching_text, pos_text_proj = self.mlp2(pos_text_embedding) 206 | neg_text_matching_text, neg_text_proj = self.mlp2(neg_text_embedding) 207 | 208 | pos_logits, neg_logits = (log_emb*pos_proj).mean(axis=1), 
(log_emb*neg_proj).mean(axis=1) 209 | pos_labels, neg_labels = torch.ones(pos_logits.shape, device=pos_logits.device), torch.zeros(neg_logits.shape, device=pos_logits.device) 210 | 211 | loss = self.bce_criterion(pos_logits, pos_labels) 212 | loss += self.bce_criterion(neg_logits, neg_labels) 213 | 214 | matching_loss = self.mse(pos_text_matching,pos_text_matching_text) + self.mse(neg_text_matching,neg_text_matching_text) 215 | reconstruction_loss = self.mse(pos_proj,pos_emb) + self.mse(neg_proj,neg_emb) 216 | text_reconstruction_loss = self.mse(pos_text_proj,pos_text_embedding.data) + self.mse(neg_text_proj,neg_text_embedding.data) 217 | 218 | total_loss = loss + matching_loss + 0.5*reconstruction_loss + 0.2*text_reconstruction_loss 219 | total_loss.backward() 220 | optimizer.step() 221 | 222 | mean_loss += total_loss.item() 223 | bpr_loss += loss.item() 224 | gt_loss += matching_loss.item() 225 | rc_loss += reconstruction_loss.item() 226 | text_rc_loss += text_reconstruction_loss.item() 227 | 228 | print("loss in epoch {}/{} iteration {}/{}: {} / BPR loss: {} / Matching loss: {} / Item reconstruction: {} / Text reconstruction: {}".format(epoch, total_epoch, step, total_step, mean_loss/iterss, bpr_loss/iterss, gt_loss/iterss, rc_loss/iterss, text_rc_loss/iterss)) 229 | 230 | def make_interact_text(self, interact_ids, interact_max_num): 231 | interact_item_titles_ = self.find_item_text(interact_ids, title_flag=True, description_flag=False) 232 | interact_text = [] 233 | if interact_max_num == 'all': 234 | for title in interact_item_titles_: 235 | interact_text.append(title + '[HistoryEmb]') 236 | else: 237 | for title in interact_item_titles_[-interact_max_num:]: 238 | interact_text.append(title + '[HistoryEmb]') 239 | interact_ids = interact_ids[-interact_max_num:] 240 | 241 | interact_text = ','.join(interact_text) 242 | return interact_text, interact_ids 243 | 244 | def make_candidate_text(self, interact_ids, candidate_num, target_item_id, target_item_title): 
245 | neg_item_id = [] 246 | while len(neg_item_id)<50: 247 | t = np.random.randint(1, self.item_num+1) 248 | if not (t in interact_ids or t in neg_item_id): 249 | neg_item_id.append(t) 250 | random.shuffle(neg_item_id) 251 | 252 | candidate_ids = [target_item_id] 253 | candidate_text = [target_item_title + '[CandidateEmb]'] 254 | 255 | for neg_candidate in neg_item_id[:candidate_num - 1]: 256 | candidate_text.append(self.find_item_text_single(neg_candidate, title_flag=True, description_flag=False) + '[CandidateEmb]') 257 | candidate_ids.append(neg_candidate) 258 | 259 | random_ = np.random.permutation(len(candidate_text)) 260 | candidate_text = np.array(candidate_text)[random_] 261 | candidate_ids = np.array(candidate_ids)[random_] 262 | 263 | return ','.join(candidate_text), candidate_ids 264 | 265 | def pre_train_phase2(self, data, optimizer, batch_iter): 266 | epoch, total_epoch, step, total_step = batch_iter 267 | 268 | optimizer.zero_grad() 269 | u, seq, pos, neg = data 270 | mean_loss = 0 271 | 272 | text_input = [] 273 | text_output = [] 274 | interact_embs = [] 275 | candidate_embs = [] 276 | self.llm.eval() 277 | 278 | with torch.no_grad(): 279 | log_emb = self.recsys.model(u,seq,pos,neg, mode = 'log_only') 280 | 281 | for i in range(len(u)): 282 | target_item_id = pos[i][-1] 283 | target_item_title = self.find_item_text_single(target_item_id, title_flag=True, description_flag=False) 284 | 285 | interact_text, interact_ids = self.make_interact_text(seq[i][seq[i]>0], 10) 286 | candidate_num = 20 287 | candidate_text, candidate_ids = self.make_candidate_text(seq[i][seq[i]>0], candidate_num, target_item_id, target_item_title) 288 | 289 | input_text = '' 290 | input_text += ' is a user representation.' 
291 | 292 | if self.args.rec_pre_trained_data == 'Movies_and_TV': 293 | input_text += 'This user has watched ' 294 | elif self.args.rec_pre_trained_data == 'Video_Games': 295 | input_text += 'This user has played ' 296 | elif self.args.rec_pre_trained_data == 'Luxury_Beauty' or self.args.rec_pre_trained_data == 'Toys_and_Games': 297 | input_text += 'This user has bought ' 298 | 299 | input_text += interact_text 300 | 301 | if self.args.rec_pre_trained_data == 'Movies_and_TV': 302 | input_text +=' in the previous. Recommend one next movie for this user to watch next from the following movie title set, ' 303 | elif self.args.rec_pre_trained_data == 'Video_Games': 304 | input_text +=' in the previous. Recommend one next game for this user to play next from the following game title set, ' 305 | elif self.args.rec_pre_trained_data == 'Luxury_Beauty' or self.args.rec_pre_trained_data == 'Toys_and_Games': 306 | input_text +=' in the previous. Recommend one next item for this user to buy next from the following item title set, ' 307 | 308 | input_text += candidate_text 309 | input_text += '. 
The recommendation is ' 310 | 311 | text_input.append(input_text) 312 | text_output.append(target_item_title) 313 | 314 | interact_embs.append(self.item_emb_proj(self.get_item_emb(interact_ids))) 315 | candidate_embs.append(self.item_emb_proj(self.get_item_emb(candidate_ids))) 316 | 317 | samples = {'text_input': text_input, 'text_output': text_output, 'interact': interact_embs, 'candidate':candidate_embs} 318 | log_emb = self.log_emb_proj(log_emb) 319 | loss_rm = self.llm(log_emb, samples) 320 | loss_rm.backward() 321 | optimizer.step() 322 | mean_loss += loss_rm.item() 323 | print("A-LLMRec model loss in epoch {}/{} iteration {}/{}: {}".format(epoch, total_epoch, step, total_step, mean_loss)) 324 | 325 | def generate(self, data): 326 | u, seq, pos, neg, rank = data 327 | 328 | answer = [] 329 | text_input = [] 330 | interact_embs = [] 331 | candidate_embs = [] 332 | with torch.no_grad(): 333 | log_emb = self.recsys.model(u,seq,pos,neg, mode = 'log_only') 334 | for i in range(len(u)): 335 | target_item_id = pos[i] 336 | target_item_title = self.find_item_text_single(target_item_id, title_flag=True, description_flag=False) 337 | 338 | interact_text, interact_ids = self.make_interact_text(seq[i][seq[i]>0], 10) 339 | candidate_num = 20 340 | candidate_text, candidate_ids = self.make_candidate_text(seq[i][seq[i]>0], candidate_num, target_item_id, target_item_title) 341 | 342 | input_text = '' 343 | input_text += ' is a user representation.' 344 | if self.args.rec_pre_trained_data == 'Movies_and_TV': 345 | input_text += 'This user has watched ' 346 | elif self.args.rec_pre_trained_data == 'Video_Games': 347 | input_text += 'This user has played ' 348 | elif self.args.rec_pre_trained_data == 'Luxury_Beauty' or self.args.rec_pre_trained_data == 'Toys_and_Games': 349 | input_text += 'This user has bought ' 350 | 351 | input_text += interact_text 352 | 353 | if self.args.rec_pre_trained_data == 'Movies_and_TV': 354 | input_text +=' in the previous. 
Recommend one next movie for this user to watch next from the following movie title set, ' 355 | elif self.args.rec_pre_trained_data == 'Video_Games': 356 | input_text +=' in the previous. Recommend one next game for this user to play next from the following game title set, ' 357 | elif self.args.rec_pre_trained_data == 'Luxury_Beauty' or self.args.rec_pre_trained_data == 'Toys_and_Games': 358 | input_text +=' in the previous. Recommend one next item for this user to buy next from the following item title set, ' 359 | 360 | input_text += candidate_text 361 | input_text += '. The recommendation is ' 362 | 363 | answer.append(target_item_title) 364 | text_input.append(input_text) 365 | 366 | interact_embs.append(self.item_emb_proj(self.get_item_emb(interact_ids))) 367 | candidate_embs.append(self.item_emb_proj(self.get_item_emb(candidate_ids))) 368 | 369 | log_emb = self.log_emb_proj(log_emb) 370 | atts_llm = torch.ones(log_emb.size()[:-1], dtype=torch.long).to(self.device) 371 | atts_llm = atts_llm.unsqueeze(1) 372 | log_emb = log_emb.unsqueeze(1) 373 | 374 | with torch.no_grad(): 375 | self.llm.llm_tokenizer.padding_side = "left" 376 | llm_tokens = self.llm.llm_tokenizer( 377 | text_input, 378 | padding="longest", 379 | return_tensors="pt" 380 | ).to(self.device) 381 | 382 | with torch.cuda.amp.autocast(): 383 | inputs_embeds = self.llm.llm_model.get_input_embeddings()(llm_tokens.input_ids) 384 | 385 | llm_tokens, inputs_embeds = self.llm.replace_hist_candi_token(llm_tokens, inputs_embeds, interact_embs, candidate_embs) 386 | 387 | attention_mask = llm_tokens.attention_mask 388 | inputs_embeds = torch.cat([log_emb, inputs_embeds], dim=1) 389 | attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1) 390 | 391 | outputs = self.llm.llm_model.generate( 392 | inputs_embeds=inputs_embeds, 393 | attention_mask=attention_mask, 394 | do_sample=False, 395 | top_p=0.9, 396 | temperature=1, 397 | num_beams=1, 398 | max_length=512, 399 | min_length=1, 400 | 
pad_token_id=self.llm.llm_tokenizer.eos_token_id, 401 | repetition_penalty=1.5, 402 | length_penalty=1, 403 | num_return_sequences=1, 404 | ) 405 | 406 | outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id) 407 | output_text = self.llm.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True) 408 | output_text = [text.strip() for text in output_text] 409 | 410 | for i in range(len(text_input)): 411 | f = open(f'./recommendation_output.txt','a') 412 | f.write(text_input[i]) 413 | f.write('\n\n') 414 | 415 | f.write('Answer: '+ answer[i]) 416 | f.write('\n\n') 417 | 418 | f.write('LLM: '+str(output_text[i])) 419 | f.write('\n\n') 420 | f.close() 421 | 422 | return output_text -------------------------------------------------------------------------------- /models/llm4rec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import AutoTokenizer, OPTForCausalLM 5 | 6 | class llm4rec(nn.Module): 7 | def __init__( 8 | self, 9 | device, 10 | llm_model="", 11 | max_output_txt_len=256, 12 | ): 13 | super().__init__() 14 | self.device = device 15 | 16 | if llm_model == 'opt': 17 | self.llm_model = OPTForCausalLM.from_pretrained("facebook/opt-6.7b", torch_dtype=torch.float16, load_in_8bit=True, device_map=self.device) 18 | self.llm_tokenizer = AutoTokenizer.from_pretrained("facebook/opt-6.7b", use_fast=False) 19 | # self.llm_model = OPTForCausalLM.from_pretrained("facebook/opt-6.7b", torch_dtype=torch.float16, device_map=self.device) 20 | else: 21 | raise Exception(f'{llm_model} is not supported') 22 | 23 | self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 24 | self.llm_tokenizer.add_special_tokens({'bos_token': ''}) 25 | self.llm_tokenizer.add_special_tokens({'eos_token': ''}) 26 | self.llm_tokenizer.add_special_tokens({'unk_token': ''}) 27 | self.llm_tokenizer.add_special_tokens({'additional_special_tokens': 
['[UserRep]','[HistoryEmb]','[CandidateEmb]']}) 28 | 29 | self.llm_model.resize_token_embeddings(len(self.llm_tokenizer)) 30 | 31 | for _, param in self.llm_model.named_parameters(): 32 | param.requires_grad = False 33 | 34 | self.max_output_txt_len = max_output_txt_len 35 | 36 | def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts): 37 | input_part_targets_len = [] 38 | llm_tokens = {"input_ids": [], "attention_mask": []} 39 | for i in range(input_ids.size(0)): 40 | this_input_ones = input_atts[i].sum() 41 | input_part_targets_len.append(this_input_ones) 42 | llm_tokens['input_ids'].append( 43 | torch.cat([ 44 | input_ids[i][:this_input_ones], 45 | output_ids[i][1:], 46 | input_ids[i][this_input_ones:] 47 | ]) 48 | ) 49 | llm_tokens['attention_mask'].append( 50 | torch.cat([ 51 | input_atts[i][:this_input_ones], 52 | output_atts[i][1:], 53 | input_atts[i][this_input_ones:] 54 | ]) 55 | ) 56 | llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids']) 57 | llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask']) 58 | return llm_tokens, input_part_targets_len 59 | 60 | def replace_hist_candi_token(self, llm_tokens, inputs_embeds, interact_embs, candidate_embs): 61 | if len(interact_embs) == 0: 62 | return llm_tokens, inputs_embeds 63 | history_token_id = self.llm_tokenizer("[HistoryEmb]", return_tensors="pt", add_special_tokens=False).input_ids.item() 64 | candidate_token_id = self.llm_tokenizer("[CandidateEmb]", return_tensors="pt", add_special_tokens=False).input_ids.item() 65 | 66 | for inx in range(len(llm_tokens["input_ids"])): 67 | idx_tensor=(llm_tokens["input_ids"][inx]==history_token_id).nonzero().view(-1) 68 | for idx, item_emb in zip(idx_tensor, interact_embs[inx]): 69 | inputs_embeds[inx][idx]=item_emb 70 | 71 | idx_tensor=(llm_tokens["input_ids"][inx]==candidate_token_id).nonzero().view(-1) 72 | for idx, item_emb in zip(idx_tensor, candidate_embs[inx]): 73 | inputs_embeds[inx][idx]=item_emb 74 | return 
llm_tokens, inputs_embeds 75 | 76 | def forward(self, log_emb, samples): 77 | atts_llm = torch.ones(log_emb.size()[:-1], dtype=torch.long).to(self.device) 78 | atts_llm = atts_llm.unsqueeze(1) 79 | 80 | text_output_tokens = self.llm_tokenizer( 81 | [t + self.llm_tokenizer.eos_token for t in samples['text_output']], 82 | return_tensors="pt", 83 | padding="longest", 84 | truncation=False, 85 | ).to(self.device) 86 | 87 | text_input_tokens = self.llm_tokenizer( 88 | samples['text_input'], 89 | return_tensors="pt", 90 | padding="longest", 91 | truncation=False, 92 | ).to(self.device) 93 | 94 | llm_tokens, input_part_targets_len = self.concat_text_input_output( 95 | text_input_tokens.input_ids, 96 | text_input_tokens.attention_mask, 97 | text_output_tokens.input_ids, 98 | text_output_tokens.attention_mask, 99 | ) 100 | 101 | targets = llm_tokens['input_ids'].masked_fill(llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100) 102 | 103 | for i, l in enumerate(input_part_targets_len): 104 | targets[i][:l] = -100 105 | 106 | empty_targets = (torch.ones(atts_llm.size(), dtype=torch.long).to(self.device).fill_(-100)) 107 | 108 | targets = torch.cat([empty_targets, targets], dim=1) 109 | 110 | inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids']) 111 | llm_tokens, inputs_embeds = self.replace_hist_candi_token(llm_tokens, inputs_embeds, samples['interact'], samples['candidate']) 112 | attention_mask = llm_tokens['attention_mask'] 113 | 114 | log_emb = log_emb.unsqueeze(1) 115 | inputs_embeds = torch.cat([log_emb, inputs_embeds], dim=1) 116 | attention_mask = torch.cat([atts_llm, llm_tokens['attention_mask']], dim=1) 117 | 118 | with torch.cuda.amp.autocast(): 119 | outputs = self.llm_model( 120 | inputs_embeds=inputs_embeds, 121 | attention_mask=attention_mask, 122 | return_dict=True, 123 | labels=targets, 124 | ) 125 | loss = outputs.loss 126 | 127 | return loss -------------------------------------------------------------------------------- 
/models/recsys_model.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import logging 3 | import os 4 | import glob 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.distributed as dist 9 | import torch.nn.functional as F 10 | 11 | from utils import * 12 | from pre_train.sasrec.model import SASRec 13 | 14 | 15 | def load_checkpoint(recsys, pre_trained): 16 | path = f'pre_train/{recsys}/{pre_trained}/' 17 | 18 | pth_file_path = find_filepath(path, '.pth') 19 | assert len(pth_file_path) == 1, 'There are more than two models in this dir. You need to remove other model files.\n' 20 | kwargs, checkpoint = torch.load(pth_file_path[0], map_location="cpu") 21 | logging.info("load checkpoint from %s" % pth_file_path[0]) 22 | 23 | return kwargs, checkpoint 24 | 25 | class RecSys(nn.Module): 26 | def __init__(self, recsys_model, pre_trained_data, device): 27 | super().__init__() 28 | kwargs, checkpoint = load_checkpoint(recsys_model, pre_trained_data) 29 | kwargs['args'].device = device 30 | model = SASRec(**kwargs) 31 | model.load_state_dict(checkpoint) 32 | 33 | for p in model.parameters(): 34 | p.requires_grad = False 35 | 36 | self.item_num = model.item_num 37 | self.user_num = model.user_num 38 | self.model = model.to(device) 39 | self.hidden_units = kwargs['args'].hidden_units 40 | 41 | def forward(): 42 | print('forward') -------------------------------------------------------------------------------- /pre_train/ctrl/model_ctrl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sentence_transformers import SentenceTransformer 4 | import pickle 5 | import random 6 | class PointWiseFeedForward(torch.nn.Module): 7 | def __init__(self, hidden_units, dropout_rate): 8 | 9 | super(PointWiseFeedForward, self).__init__() 10 | 11 | self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1) 12 | self.dropout1 = 
torch.nn.Dropout(p=dropout_rate) 13 | self.relu = torch.nn.ReLU() 14 | self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1) 15 | self.dropout2 = torch.nn.Dropout(p=dropout_rate) 16 | 17 | def forward(self, inputs): 18 | outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2)))))) 19 | outputs = outputs.transpose(-1, -2) # as Conv1D requires (N, C, Length) 20 | outputs += inputs 21 | return outputs 22 | 23 | class SASRec_CTRL(torch.nn.Module): 24 | def __init__(self, user_num, item_num, args): 25 | super(SASRec_CTRL, self).__init__() 26 | 27 | self.kwargs = {'user_num': user_num, 'item_num':item_num, 'args':args} 28 | self.user_num = user_num 29 | self.item_num = item_num 30 | self.dev = args.device 31 | self.description = args.use_description 32 | 33 | self.item_emb = torch.nn.Embedding(self.item_num+1, args.hidden_units, padding_idx=0) 34 | self.pos_emb = torch.nn.Embedding(args.maxlen, args.hidden_units) 35 | self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate) 36 | 37 | #Modality encoder 38 | self.sbert = SentenceTransformer('nq-distilbert-base-v1') 39 | 40 | #W_text in equation 3 41 | self.projection = torch.nn.Linear(768,args.hidden_units) 42 | 43 | #W_tab in equation 2 44 | self.projection2 = torch.nn.Linear(args.hidden_units,args.hidden_units) 45 | 46 | # Fine-grained align Wm in equation 7 47 | self.finegrain1_1 = torch.nn.Linear(args.hidden_units,args.hidden_units) 48 | self.finegrain1_2 = torch.nn.Linear(args.hidden_units,args.hidden_units) 49 | self.finegrain1_3 = torch.nn.Linear(args.hidden_units,args.hidden_units) 50 | self.finegrain2_1 = torch.nn.Linear(args.hidden_units,args.hidden_units) 51 | self.finegrain2_2 = torch.nn.Linear(args.hidden_units,args.hidden_units) 52 | self.finegrain2_3 = torch.nn.Linear(args.hidden_units,args.hidden_units) 53 | 54 | self.final_layer = torch.nn.Linear(args.hidden_units,args.hidden_units) 55 | 56 | #Backbone network 57 | self.attention_layernorms = 
torch.nn.ModuleList() # to be Q for self-attention 58 | self.attention_layers = torch.nn.ModuleList() 59 | self.forward_layernorms = torch.nn.ModuleList() 60 | self.forward_layers = torch.nn.ModuleList() 61 | self.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8) 62 | self.args =args 63 | self.bce_criterion = torch.nn.BCEWithLogitsLoss() # torch.nn.BCELoss() 64 | 65 | #Load textual meta data 66 | with open(f'./data/Movies_and_TV_meta.json.gz','rb') as ft: 67 | self.text_name_dict = pickle.load(ft) 68 | 69 | #Backbone network 70 | for _ in range(args.num_blocks): 71 | new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8) 72 | self.attention_layernorms.append(new_attn_layernorm) 73 | 74 | new_attn_layer = torch.nn.MultiheadAttention(args.hidden_units, 75 | args.num_heads, 76 | args.dropout_rate) 77 | self.attention_layers.append(new_attn_layer) 78 | 79 | new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8) 80 | self.forward_layernorms.append(new_fwd_layernorm) 81 | 82 | new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate) 83 | self.forward_layers.append(new_fwd_layer) 84 | 85 | 86 | # Find Text Data 87 | def find_item_text(self, item, title_flag=True, description_flag=True): 88 | t = 'title' 89 | d = 'description' 90 | t_ = 'No Title' 91 | d_ = 'No Description' 92 | if title_flag and description_flag: 93 | return [f'"Title:{self.text_name_dict[t].get(i,t_)}, Description:{self.text_name_dict[d].get(i,d_)}"' for i in item] 94 | elif title_flag and not description_flag: 95 | return [f'"Title:{self.text_name_dict[t].get(i,t_)}"' for i in item] 96 | elif not title_flag and description_flag: 97 | return [f'"Description:{self.text_name_dict[d].get(i,d_)}"' for i in item] 98 | 99 | def log2feats(self, log_seqs): 100 | seqs = self.item_emb(torch.LongTensor(log_seqs).to(self.dev)) 101 | seqs *= self.item_emb.embedding_dim ** 0.5 102 | positions = np.tile(np.array(range(log_seqs.shape[1])), [log_seqs.shape[0], 
1]) 103 | seqs += self.pos_emb(torch.LongTensor(positions).to(self.dev)) 104 | seqs = self.emb_dropout(seqs) 105 | 106 | timeline_mask = torch.BoolTensor(log_seqs == 0).to(self.dev) 107 | seqs *= ~timeline_mask.unsqueeze(-1) 108 | 109 | tl = seqs.shape[1] 110 | attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.dev)) 111 | 112 | for i in range(len(self.attention_layers)): 113 | seqs = torch.transpose(seqs, 0, 1) 114 | Q = self.attention_layernorms[i](seqs) 115 | mha_outputs, _ = self.attention_layers[i](Q, seqs, seqs, 116 | attn_mask=attention_mask) 117 | seqs = Q + mha_outputs 118 | seqs = torch.transpose(seqs, 0, 1) 119 | 120 | seqs = self.forward_layernorms[i](seqs) 121 | seqs = self.forward_layers[i](seqs) 122 | seqs *= ~timeline_mask.unsqueeze(-1) 123 | 124 | log_feats = self.last_layernorm(seqs) # (U, T, C) -> (U, -1, C) 125 | return log_feats 126 | 127 | def forward(self, user_ids, log_seqs, pos_seqs, neg_seqs, mode='default', pretrain=True, opt = None): 128 | 129 | # Get Collaborative model embedding (i.e., get Tab embedding in CTRL) 130 | log_feats = self.log2feats(log_seqs) 131 | 132 | # Cross-modal contrastive learning 133 | if pretrain: 134 | total_loss = 0 135 | iterss = 0 136 | log_feats = log_feats.reshape(-1,log_feats.shape[2]) 137 | log_feats = log_feats[log_seqs.reshape(log_seqs.size)>0] 138 | text_list = [] 139 | 140 | # Get Textual Data 141 | for l in log_seqs: 142 | ll = l[l>0] 143 | for i in range(len(ll)): 144 | to_text = ll[:i+1] 145 | text = "This is a a user, who has recently watched " + '|'.join(self.find_item_text(to_text, description_flag=False)) 146 | text += '. 
This is a movie, title is ' + ','.join(self.find_item_text(to_text, description_flag=self.description)) 147 | print(text) 148 | text_list.append(text) 149 | 150 | # Embed textual data using Semantic Model 151 | token = self.sbert.tokenize(text_list) 152 | text_embedding= self.sbert({'input_ids':token['input_ids'].to(log_feats.device),'attention_mask':token['attention_mask'].to(log_feats.device)})['sentence_embedding'] 153 | 154 | # Projection - Equation 2 and 3 155 | text_embedding = self.projection(text_embedding) 156 | log_feats = self.projection2(log_feats) 157 | 158 | start_idx = 0 159 | end_idx = 32 160 | loss = 0 161 | 162 | # Cross-modal Contrasive Learning (Batch samples, auto-regressive - SASRec) 163 | while start_idx h^tab_m (i.e., use output of collaborative model) 174 | log_fine1 = self.finegrain1_1(log) 175 | log_fine2 = self.finegrain1_2(log) 176 | log_fine3 = self.finegrain1_3(log) 177 | 178 | # m-th sub-representation (m=3) -> h^text_m (i.e., use output of semantic model) 179 | text_fine1 = self.finegrain2_1(text_) 180 | text_fine2 = self.finegrain2_2(text_) 181 | text_fine3 = self.finegrain2_3(text_) 182 | 183 | sim_mat1 = torch.matmul(log_fine1, text_fine1.T).unsqueeze(0) 184 | sim_mat2 = torch.matmul(log_fine1, text_fine2.T).unsqueeze(0) 185 | sim_mat3 = torch.matmul(log_fine1, text_fine3.T).unsqueeze(0) 186 | 187 | # Maximum similarity 188 | results1 = torch.cat([sim_mat1,sim_mat2,sim_mat3],dim=0).max(axis=0)[0] 189 | 190 | sim_mat4 = torch.matmul(log_fine2, text_fine1.T).unsqueeze(0) 191 | sim_mat5 = torch.matmul(log_fine2, text_fine2.T).unsqueeze(0) 192 | sim_mat6 = torch.matmul(log_fine2, text_fine3.T).unsqueeze(0) 193 | 194 | # Maximum similarity 195 | results2 = torch.cat([sim_mat4,sim_mat5,sim_mat6],dim=0).max(axis=0)[0] 196 | 197 | 198 | 199 | sim_mat7 = torch.matmul(log_fine3, text_fine1.T).unsqueeze(0) 200 | sim_mat8 = torch.matmul(log_fine3, text_fine2.T).unsqueeze(0) 201 | sim_mat9 = torch.matmul(log_fine3, text_fine3.T).unsqueeze(0) 
202 | 203 | # Maximum similarity 204 | results3 = torch.cat([sim_mat7,sim_mat8,sim_mat9],dim=0).max(axis=0)[0] 205 | 206 | # Get fine-grained similarity over all sub-representations 207 | results = results1 + results2 + results3 208 | 209 | # Maximum similarity 210 | # Get fine-grained similarity over all sub-representations 211 | test_results1 = torch.cat([sim_mat1,sim_mat4,sim_mat7],dim=0).max(axis=0)[0] 212 | test_results2 = torch.cat([sim_mat2,sim_mat5,sim_mat8],dim=0).max(axis=0)[0] 213 | test_results3 = torch.cat([sim_mat3,sim_mat6,sim_mat9],dim=0).max(axis=0)[0] 214 | test_results = test_results1 + test_results2 + test_results3 215 | 216 | 217 | # tabular2textual Equation 5 218 | pos_labels = torch.ones(results.diag().shape, device=log_feats.device) 219 | neg_labels = torch.zeros(results[~torch.eye(len(text_),dtype=bool)].shape, device=log_feats.device) 220 | 221 | cal += self.bce_criterion(results.diag(), pos_labels) 222 | cal+=self.bce_criterion(results[~torch.eye(len(results),dtype=bool)], neg_labels) 223 | 224 | # textual2tabular Equation 4 225 | pos_labels = torch.ones(test_results.diag().shape, device=log_feats.device) 226 | neg_labels = torch.zeros(test_results[~torch.eye(len(text_),dtype=bool)].shape, device=log_feats.device) 227 | 228 | cal += self.bce_criterion(test_results.diag(), pos_labels) 229 | cal+=self.bce_criterion(test_results[~torch.eye(len(test_results),dtype=bool)], neg_labels) 230 | 231 | # Equation 6 232 | loss += (cal/2) 233 | 234 | opt.zero_grad() 235 | loss.backward() 236 | opt.step() 237 | total_loss += loss.item() 238 | return total_loss 239 | 240 | else: 241 | # Supervised Fine-tuning 242 | log_feats = self.final_layer(log_feats) 243 | if mode == 'log_only': 244 | log_feats = log_feats[:, -1, :] 245 | return log_feats 246 | 247 | pos_embs = self.item_emb(torch.LongTensor(pos_seqs).to(self.dev)) 248 | neg_embs = self.item_emb(torch.LongTensor(neg_seqs).to(self.dev)) 249 | 250 | pos_logits = (log_feats * pos_embs).sum(dim=-1) 251 
| neg_logits = (log_feats * neg_embs).sum(dim=-1) 252 | 253 | if self.args.pretrain_stage == True: 254 | return log_feats.reshape(-1,log_feats.shape[2]), pos_embs.reshape(-1,log_feats.shape[2]), neg_embs.reshape(-1,log_feats.shape[2]) 255 | else: 256 | return pos_logits, neg_logits 257 | 258 | def predict(self, user_ids, log_seqs, item_indices): 259 | log_feats = self.log2feats(log_seqs) 260 | log_feats = self.final_layer(log_feats) 261 | 262 | final_feat = log_feats[:, -1, :] 263 | 264 | item_embs = self.item_emb(torch.LongTensor(item_indices).to(self.dev)) 265 | 266 | logits = item_embs.matmul(final_feat.unsqueeze(-1)).squeeze(-1) 267 | 268 | 269 | return logits 270 | -------------------------------------------------------------------------------- /pre_train/sasrec/data_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import gzip 4 | import json 5 | import pickle 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | 9 | def parse(path): 10 | g = gzip.open(path, 'rb') 11 | for l in tqdm(g): 12 | yield json.loads(l) 13 | 14 | def preprocess(fname): 15 | countU = defaultdict(lambda: 0) 16 | countP = defaultdict(lambda: 0) 17 | line = 0 18 | 19 | file_path = f'../../data/amazon/{fname}.json.gz' 20 | 21 | # counting interactions for each user and item 22 | for l in parse(file_path): 23 | line += 1 24 | if ('Beauty' in fname) or ('Toys' in fname): 25 | if l['overall'] < 3: 26 | continue 27 | asin = l['asin'] 28 | rev = l['reviewerID'] 29 | time = l['unixReviewTime'] 30 | 31 | countU[rev] += 1 32 | countP[asin] += 1 33 | 34 | usermap = dict() 35 | usernum = 0 36 | itemmap = dict() 37 | itemnum = 0 38 | User = dict() 39 | review_dict = {} 40 | name_dict = {'title':{}, 'description':{}} 41 | 42 | f = open(f'../../data/amazon/meta_{fname}.json', 'r') 43 | json_data = f.readlines() 44 | f.close() 45 | data_list = [json.loads(line[:-1]) for line in json_data] 46 | meta_dict = 
{} 47 | for l in data_list: 48 | meta_dict[l['asin']] = l 49 | 50 | for l in parse(file_path): 51 | line += 1 52 | asin = l['asin'] 53 | rev = l['reviewerID'] 54 | time = l['unixReviewTime'] 55 | 56 | threshold = 5 57 | if ('Beauty' in fname) or ('Toys' in fname): 58 | threshold = 4 59 | 60 | if countU[rev] < threshold or countP[asin] < threshold: 61 | continue 62 | 63 | if rev in usermap: 64 | userid = usermap[rev] 65 | else: 66 | usernum += 1 67 | userid = usernum 68 | usermap[rev] = userid 69 | User[userid] = [] 70 | 71 | if asin in itemmap: 72 | itemid = itemmap[asin] 73 | else: 74 | itemnum += 1 75 | itemid = itemnum 76 | itemmap[asin] = itemid 77 | User[userid].append([time, itemid]) 78 | 79 | 80 | if itemmap[asin] in review_dict: 81 | try: 82 | review_dict[itemmap[asin]]['review'][usermap[rev]] = l['reviewText'] 83 | except: 84 | a = 0 85 | try: 86 | review_dict[itemmap[asin]]['summary'][usermap[rev]] = l['summary'] 87 | except: 88 | a = 0 89 | else: 90 | review_dict[itemmap[asin]] = {'review': {}, 'summary':{}} 91 | try: 92 | review_dict[itemmap[asin]]['review'][usermap[rev]] = l['reviewText'] 93 | except: 94 | a = 0 95 | try: 96 | review_dict[itemmap[asin]]['summary'][usermap[rev]] = l['summary'] 97 | except: 98 | a = 0 99 | try: 100 | if len(meta_dict[asin]['description']) ==0: 101 | name_dict['description'][itemmap[asin]] = 'Empty description' 102 | else: 103 | name_dict['description'][itemmap[asin]] = meta_dict[asin]['description'][0] 104 | name_dict['title'][itemmap[asin]] = meta_dict[asin]['title'] 105 | except: 106 | a =0 107 | 108 | with open(f'../../data/amazon/{fname}_text_name_dict.json.gz', 'wb') as tf: 109 | pickle.dump(name_dict, tf) 110 | 111 | for userid in User.keys(): 112 | User[userid].sort(key=lambda x: x[0]) 113 | 114 | print(usernum, itemnum) 115 | 116 | f = open(f'../../data/amazon/{fname}.txt', 'w') 117 | for user in User.keys(): 118 | for i in User[user]: 119 | f.write('%d %d\n' % (user, i[1])) 120 | f.close() 121 | 
# ---- /pre_train/sasrec/main.py ----
import os
import time
import torch
import argparse

from model import SASRec
from data_preprocess import *
from utils import *

from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--maxlen', default=50, type=int)
parser.add_argument('--hidden_units', default=50, type=int)
parser.add_argument('--num_blocks', default=2, type=int)
parser.add_argument('--num_epochs', default=200, type=int)
parser.add_argument('--num_heads', default=1, type=int)
parser.add_argument('--dropout_rate', default=0.5, type=float)
parser.add_argument('--l2_emb', default=0.0, type=float)
parser.add_argument('--device', default='cpu', type=str)
parser.add_argument('--inference_only', default=False, action='store_true')
parser.add_argument('--state_dict_path', default=None, type=str)

# NOTE(review): parsed at import time (outside the __main__ guard), so merely
# importing this module consumes sys.argv.
args = parser.parse_args()

if __name__ == '__main__':

    # Build ../../data/amazon/<dataset>.txt from the raw Amazon files, then
    # split every user's sequence into train / valid / test.
    preprocess(args.dataset)
    dataset = data_partition(args.dataset)

    [user_train, user_valid, user_test, usernum, itemnum] = dataset
    print('user num:', usernum, 'item num:', itemnum)
    num_batch = len(user_train) // args.batch_size
    cc = 0.0
    for u in user_train:
        cc += len(user_train[u])
    print('average sequence length: %.2f' % (cc / len(user_train)))

    # dataloader: background worker processes that prefetch training batches
    sampler = WarpSampler(user_train, usernum, itemnum, batch_size=args.batch_size, maxlen=args.maxlen, n_workers=3)
    # model init
    model = SASRec(usernum, itemnum, args).to(args.device)

    for name, param in model.named_parameters():
        try:
            torch.nn.init.xavier_normal_(param.data)
        except ValueError:
            # Xavier init needs >= 2 dims; 1-D params (biases, LayerNorm) keep their defaults.
            pass

    model.train()

    epoch_start_idx = 1
    if args.state_dict_path is not None:
        try:
            kwargs, checkpoint = torch.load(args.state_dict_path, map_location=torch.device(args.device))
            kwargs['args'].device = args.device
            model = SASRec(**kwargs).to(args.device)
            model.load_state_dict(checkpoint)
            # Resume after the epoch encoded in the file name ("...epoch=<N>.lr=...").
            tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6:]
            epoch_start_idx = int(tail[:tail.find('.')]) + 1
        except Exception:
            # Catch Exception rather than everything so Ctrl-C / SystemExit still propagate.
            print('failed loading state_dicts, pls check file path: ', end="")
            print(args.state_dict_path)
            print('pdb enabled for your quick check, pls type exit() if you do not need it')
            import pdb
            pdb.set_trace()

    if args.inference_only:
        model.eval()
        with torch.no_grad():
            t_test = evaluate(model, dataset, args)
        # ``evaluate`` counts a hit only at rank 0 among 20 candidates, i.e. @1 metrics.
        print('test (NDCG@1: %.4f, HR@1: %.4f)' % (t_test[0], t_test[1]))

    bce_criterion = torch.nn.BCEWithLogitsLoss()
    adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

    T = 0.0
    t0 = time.time()

    for epoch in tqdm(range(epoch_start_idx, args.num_epochs + 1)):
        if args.inference_only:
            break
        for step in range(num_batch):
            u, seq, pos, neg = sampler.next_batch()
            u, seq, pos, neg = np.array(u), np.array(seq), np.array(pos), np.array(neg)
            pos_logits, neg_logits = model(u, seq, pos, neg)
            pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(neg_logits.shape, device=args.device)

            adam_optimizer.zero_grad()
            # Only real (non-padding) positions contribute to the loss.
            indices = np.where(pos != 0)
            loss = bce_criterion(pos_logits[indices], pos_labels[indices])
            loss += bce_criterion(neg_logits[indices], neg_labels[indices])
            for param in model.item_emb.parameters():
                loss += args.l2_emb * torch.norm(param)
            loss.backward()
            adam_optimizer.step()
            if step % 100 == 0:
                print("loss in epoch {} iteration {}: {}".format(epoch, step, loss.item()))  # expected 0.4~0.6 after init few epochs

        if epoch % 20 == 0 or epoch == 1:
            model.eval()
            t1 = time.time() - t0
            T += t1
            print('Evaluating', end='')
            # Evaluation needs no autograd graph.
            with torch.no_grad():
                t_test = evaluate(model, dataset, args)
                t_valid = evaluate_valid(model, dataset, args)
            print('\n')
            print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f), test (NDCG@1: %.4f, HR@1: %.4f)'
                  % (epoch, T, t_valid[0], t_valid[1], t_test[0], t_test[1]))

            print(str(t_valid) + ' ' + str(t_test) + '\n')
            t0 = time.time()
            model.train()

        if epoch == args.num_epochs:
            folder = args.dataset
            fname = 'SASRec.epoch={}.lr={}.layer={}.head={}.hidden={}.maxlen={}.pth'
            fname = fname.format(args.num_epochs, args.lr, args.num_blocks, args.num_heads, args.hidden_units, args.maxlen)
            os.makedirs(folder, exist_ok=True)
            # Saved as [constructor kwargs, state_dict] so loaders can rebuild the model.
            torch.save([model.kwargs, model.state_dict()], os.path.join(folder, fname))

    sampler.close()
    print("Done")

# ---- /pre_train/sasrec/model.py ----
import numpy as np
import torch


class PointWiseFeedForward(torch.nn.Module):
    """Position-wise two-layer feed-forward block with a residual connection.

    Kernel-size-1 ``Conv1d`` layers act as per-position linear maps over the
    hidden dimension.
    """

    def __init__(self, hidden_units, dropout_rate):
        super().__init__()

        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        # Conv1d expects (N, C, L); SASRec passes (batch, seq_len, hidden), so swap the last two dims.
        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2)
        outputs += inputs
        return outputs


class SASRec(torch.nn.Module):
    """Self-attentive sequential recommender (SASRec) backbone.

    Args:
        user_num: Number of users (stored only; the model has no user embedding).
        item_num: Number of items; embedding row 0 is reserved for padding.
        args: Namespace providing ``device``, ``hidden_units``, ``maxlen``,
            ``dropout_rate``, ``num_blocks`` and ``num_heads``.
    """

    def __init__(self, user_num, item_num, args):
        super().__init__()

        # Constructor arguments, saved next to the state_dict so checkpoints can rebuild the model.
        self.kwargs = {'user_num': user_num, 'item_num': item_num, 'args': args}
        self.user_num = user_num
        self.item_num = item_num
        self.dev = args.device

        self.item_emb = torch.nn.Embedding(self.item_num + 1, args.hidden_units, padding_idx=0)
        self.pos_emb = torch.nn.Embedding(args.maxlen, args.hidden_units)
        self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate)

        self.attention_layernorms = torch.nn.ModuleList()
        self.attention_layers = torch.nn.ModuleList()
        self.forward_layernorms = torch.nn.ModuleList()
        self.forward_layers = torch.nn.ModuleList()

        self.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)

        self.args = args

        for _ in range(args.num_blocks):
            new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.attention_layernorms.append(new_attn_layernorm)

            new_attn_layer = torch.nn.MultiheadAttention(args.hidden_units,
                                                         args.num_heads,
                                                         args.dropout_rate)
            self.attention_layers.append(new_attn_layer)

            new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
            self.forward_layernorms.append(new_fwd_layernorm)

            new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate)
            self.forward_layers.append(new_fwd_layer)

    def log2feats(self, log_seqs):
        """Encode padded item-id sequences (batch, seq_len) into per-position
        features of shape (batch, seq_len, hidden_units)."""
        seqs = self.item_emb(torch.LongTensor(log_seqs).to(self.dev))
        seqs *= self.item_emb.embedding_dim ** 0.5
        positions = np.tile(np.array(range(log_seqs.shape[1])), [log_seqs.shape[0], 1])
        seqs += self.pos_emb(torch.LongTensor(positions).to(self.dev))
        seqs = self.emb_dropout(seqs)

        # Zero out padding positions (item id 0).
        timeline_mask = torch.BoolTensor(log_seqs == 0).to(self.dev)
        seqs *= ~timeline_mask.unsqueeze(-1)

        # Causal mask: True above the diagonal blocks attention to later positions.
        tl = seqs.shape[1]
        attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.dev))

        for i in range(len(self.attention_layers)):
            # MultiheadAttention is used sequence-first here (batch_first not set).
            seqs = torch.transpose(seqs, 0, 1)
            Q = self.attention_layernorms[i](seqs)
            mha_outputs, _ = self.attention_layers[i](Q, seqs, seqs,
                                                      attn_mask=attention_mask)

            # The residual is taken from the layer-normed query Q, not the raw input.
            seqs = Q + mha_outputs
            seqs = torch.transpose(seqs, 0, 1)

            seqs = self.forward_layernorms[i](seqs)
            seqs = self.forward_layers[i](seqs)
            seqs *= ~timeline_mask.unsqueeze(-1)

        log_feats = self.last_layernorm(seqs)
        return log_feats

    def forward(self, user_ids, log_seqs, pos_seqs, neg_seqs, mode='default'):
        """Score positive/negative next items for every position.

        ``mode='log_only'`` returns only the last-position features;
        ``mode='item'`` returns flattened (features, pos_embs, neg_embs);
        otherwise returns ``(pos_logits, neg_logits)``. ``user_ids`` is unused.
        """
        log_feats = self.log2feats(log_seqs)
        if mode == 'log_only':
            log_feats = log_feats[:, -1, :]
            return log_feats

        pos_embs = self.item_emb(torch.LongTensor(pos_seqs).to(self.dev))
        neg_embs = self.item_emb(torch.LongTensor(neg_seqs).to(self.dev))

        pos_logits = (log_feats * pos_embs).sum(dim=-1)
        neg_logits = (log_feats * neg_embs).sum(dim=-1)

        if mode == 'item':
            return log_feats.reshape(-1, log_feats.shape[2]), pos_embs.reshape(-1, log_feats.shape[2]), neg_embs.reshape(-1, log_feats.shape[2])
        else:
            return pos_logits, neg_logits

    def predict(self, user_ids, log_seqs, item_indices):
        """Dot-product scores of ``item_indices`` against each sequence's last-position
        feature; for 1-D ``item_indices`` the result is (batch, len(item_indices))."""
        log_feats = self.log2feats(log_seqs)

        final_feat = log_feats[:, -1, :]

        item_embs = self.item_emb(torch.LongTensor(item_indices).to(self.dev))

        logits = item_embs.matmul(final_feat.unsqueeze(-1)).squeeze(-1)

        return logits

# ---- /pre_train/sasrec/utils.py ----
import sys
import copy
import torch
import random
import numpy as np
from collections import defaultdict
from multiprocessing import Process, Queue
import os
from datetime import datetime
from pytz import timezone
from torch.utils.data import Dataset


# sampler for batch generation
def random_neq(l, r, s):
    """Draw an integer from ``[l, r)`` that is not contained in ``s``.

    Uses NumPy's global RNG and redraws until a value outside ``s`` shows up
    (so it never terminates if every value in the range is in ``s``).
    """
    while True:
        candidate = np.random.randint(l, r)
        if candidate not in s:
            return candidate


def sample_function(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED):
    """Worker loop that endlessly pushes training batches onto ``result_queue``.

    Each batch is ``zip(*samples)``, i.e. parallel tuples of user ids, input
    sequences, positive targets and negative targets. Users with at most one
    training item are never drawn.
    """
    def draw_one():
        uid = np.random.randint(1, usernum + 1)
        while len(user_train[uid]) <= 1:
            uid = np.random.randint(1, usernum + 1)

        history = user_train[uid]
        seq = np.zeros([maxlen], dtype=np.int32)
        pos = np.zeros([maxlen], dtype=np.int32)
        neg = np.zeros([maxlen], dtype=np.int32)

        target = history[-1]
        slot = maxlen - 1
        seen = set(history)
        # Walk the history backwards, right-aligning it; each slot's target is the following item.
        for item in reversed(history[:-1]):
            seq[slot] = item
            pos[slot] = target
            if target != 0:
                neg[slot] = random_neq(1, itemnum + 1, seen)
            target = item
            slot -= 1
            if slot == -1:
                break

        return (uid, seq, pos, neg)

    np.random.seed(SEED)
    while True:
        batch = [draw_one() for _ in range(batch_size)]
        result_queue.put(zip(*batch))


class WarpSampler(object):
    """Pool of daemon worker processes prefetching batches built by ``sample_function``."""

    def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10, n_workers=1):
        self.result_queue = Queue(maxsize=n_workers * 10)
        self.processors = []
        for _ in range(n_workers):
            worker = Process(
                target=sample_function,
                args=(User, usernum, itemnum, batch_size, maxlen,
                      self.result_queue, np.random.randint(2e9)),
            )
            worker.daemon = True
            self.processors.append(worker)
            worker.start()

    def next_batch(self):
        """Block until a worker has produced a batch and return it."""
        return self.result_queue.get()

    def close(self):
        """Terminate and reap every worker process."""
        for worker in self.processors:
            worker.terminate()
            worker.join()
# DataSet for ddp
class SeqDataset(Dataset):
    """Map-style training dataset producing one SASRec sample per user.

    Index ``idx`` maps to user id ``idx + 1``, so ``user_train`` is expected to
    contain every id in ``1..num_user``.
    """

    def __init__(self, user_train, num_user, num_item, max_len):
        self.user_train = user_train
        self.num_user = num_user
        self.num_item = num_item
        self.max_len = max_len
        print("Initializing with num_user:", num_user)

    def __len__(self):
        return self.num_user

    def __getitem__(self, idx):
        """Return ``(user_id, seq, pos, neg)`` for user ``idx + 1``.

        ``seq`` is the training history without its last item, ``pos[t]`` is the
        item following ``seq[t]`` and ``neg[t]`` one item (from ``random_neq``)
        the user never interacted with. All are int32 arrays of length
        ``max_len``, right-aligned and zero-padded on the left.
        """
        user_id = idx + 1
        seq = np.zeros([self.max_len], dtype=np.int32)
        pos = np.zeros([self.max_len], dtype=np.int32)
        neg = np.zeros([self.max_len], dtype=np.int32)

        nxt = self.user_train[user_id][-1]
        length_idx = self.max_len - 1

        # The user's set of training items (never sampled as negatives).
        ts = set(self.user_train[user_id])
        for i in reversed(self.user_train[user_id][:-1]):
            seq[length_idx] = i
            pos[length_idx] = nxt
            if nxt != 0:
                neg[length_idx] = random_neq(1, self.num_item + 1, ts)
            nxt = i
            length_idx -= 1
            if length_idx == -1:
                break

        return user_id, seq, pos, neg


def _right_aligned(items, maxlen):
    """Return the last ``maxlen`` entries of ``items`` right-aligned in a
    zero-padded int32 array of length ``maxlen`` (most recent item last)."""
    seq = np.zeros([maxlen], dtype=np.int32)
    tail = list(items)[-maxlen:] if maxlen > 0 else []
    if tail:
        seq[maxlen - len(tail):] = tail
    return seq


class SeqDataset_Inference(Dataset):
    """Map-style dataset of evaluation samples for the users listed in ``use_user``.

    Each sample pairs the user's history (training items followed by the
    validation item) with the test item and ``num_neg`` sampled negatives.
    """

    def __init__(self, user_train, user_valid, user_test, use_user, num_item, max_len, num_neg=3):
        self.user_train = user_train
        self.user_valid = user_valid
        self.user_test = user_test
        self.num_user = len(use_user)
        self.num_item = num_item
        self.max_len = max_len
        self.use_user = use_user
        # Number of negative candidates drawn per user (3 by default, as before).
        self.num_neg = num_neg
        print("Initializing with num_user:", self.num_user)

    def __len__(self):
        return self.num_user

    def __getitem__(self, idx):
        """Return ``(user_id, seq, pos, neg)``: ``pos`` is the test item and ``neg``
        an array of ``num_neg`` items absent from the history and != ``pos``."""
        user_id = self.use_user[idx]
        # The validation item is the most recent interaction before the test item.
        history = list(self.user_train[user_id]) + list(self.user_valid[user_id][:1])
        seq = _right_aligned(history, self.max_len)
        pos = self.user_test[user_id][0]

        # Exclude padding, the input history and the target itself so the
        # candidate list can never contain the positive twice.
        rated = set(history)
        rated.add(0)
        rated.add(pos)
        neg = []
        for _ in range(self.num_neg):
            t = np.random.randint(1, self.num_item + 1)
            while t in rated:
                t = np.random.randint(1, self.num_item + 1)
            neg.append(t)
        neg = np.array(neg)
        return user_id, seq, pos, neg


# train/val/test data generation
def data_partition(fname, path=None):
    """Load ``user item`` lines and split each user's sequence into train/valid/test.

    Users with fewer than 3 interactions keep everything in train; otherwise the
    last item goes to test, the second-to-last to validation, the rest to train.
    Per-user order follows the file order. User/item ids are assumed to start at 1.

    Args:
        fname: Dataset name; the default path is ``../../data/amazon/<fname>.txt``.
        path: Explicit file path, overriding the default.

    Returns:
        ``[user_train, user_valid, user_test, usernum, itemnum]`` where the first
        three map user -> list of items and the counts are the maximum ids seen.
    """
    usernum = 0
    itemnum = 0
    User = defaultdict(list)
    user_train = {}
    user_valid = {}
    user_test = {}

    if path is None:
        path = '../../data/amazon/%s.txt' % fname
    # ``with`` closes the file even when a malformed line raises.
    with open(path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            u, i = map(int, fields)
            usernum = max(u, usernum)
            itemnum = max(i, itemnum)
            User[u].append(i)

    for user, items in User.items():
        if len(items) < 3:
            user_train[user] = items
            user_valid[user] = []
            user_test[user] = []
        else:
            user_train[user] = items[:-2]
            user_valid[user] = [items[-2]]
            user_test[user] = [items[-1]]
    return [user_train, user_valid, user_test, usernum, itemnum]


def _evaluate_ranking(model, dataset, args, target, num_neg, topk):
    """Sampled-ranking evaluation shared by ``evaluate`` and ``evaluate_valid``.

    For each user (at most 10000, sampled), the target item is scored with
    ``model.predict`` against ``num_neg`` random items that are neither in the
    input history nor the target; a hit counts when its 0-based rank < ``topk``.

    Args:
        target: ``'test'`` (history = train + validation item, target = test item)
            or ``'valid'`` (history = train items, target = validation item).

    Returns:
        ``(NDCG@topk, HR@topk)`` averaged over evaluated users, or ``(0.0, 0.0)``
        when no user qualifies.
    """
    # The dataset is only read, so no defensive deep copy is needed.
    train, valid, test, usernum, itemnum = dataset
    targets = test if target == 'test' else valid

    NDCG = 0.0
    HT = 0.0
    valid_user = 0.0

    if usernum > 10000:
        users = random.sample(range(1, usernum + 1), 10000)
    else:
        users = range(1, usernum + 1)

    for u in users:
        train_items = train.get(u, [])
        target_items = targets.get(u, [])
        if len(train_items) < 1 or len(target_items) < 1:
            continue

        history = list(train_items)
        if target == 'test':
            # The validation item directly precedes the test item.
            history.extend(valid.get(u, [])[:1])
        seq = _right_aligned(history, args.maxlen)

        target_item = target_items[0]
        rated = set(history)
        rated.add(0)
        rated.add(target_item)
        item_idx = [target_item]
        for _ in range(num_neg):
            t = np.random.randint(1, itemnum + 1)
            while t in rated:
                t = np.random.randint(1, itemnum + 1)
            item_idx.append(t)

        # Negate so an ascending argsort puts the highest score first.
        predictions = -model.predict(*[np.array(l) for l in [[u], [seq], item_idx]])
        predictions = predictions[0]

        rank = predictions.argsort().argsort()[0].item()

        valid_user += 1

        if rank < topk:
            NDCG += 1 / np.log2(rank + 2)
            HT += 1
        if valid_user % 100 == 0:
            print('.', end="")
            sys.stdout.flush()

    if valid_user == 0:
        return 0.0, 0.0
    return NDCG / valid_user, HT / valid_user


# evaluate on test set
def evaluate(model, dataset, args, num_neg=19, topk=1):
    """Evaluate on the test split: each user's test item vs ``num_neg`` negatives.

    With the defaults (20 candidates, ``topk=1``) this returns NDCG@1 and HR@1.
    """
    return _evaluate_ranking(model, dataset, args, 'test', num_neg, topk)


# evaluate on val set
def evaluate_valid(model, dataset, args, num_neg=100, topk=10):
    """Evaluate on the validation split: each user's validation item vs
    ``num_neg`` negatives; with the defaults returns NDCG@10 and HR@10."""
    return _evaluate_ranking(model, dataset, args, 'valid', num_neg, topk)
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=5.1=1_gnu 6 | torch=2.1.2=pypi_0 7 | tqdm=4.65.0=pypi_0 8 | pytz=2023.3.post1=pypi_0 9 | numpy=1.26.3=pypi_0 10 | accelerate=0.25.0=pyhd8ed1ab_0 11 | bitsandbytes=0.42.0=pypi_0 12 | transformers=4.32.1=pypi_0 13 | sentence-transformers=2.2.2=pyhd8ed1ab_0 -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import time 5 | import os 6 | 7 | from tqdm import tqdm 8 | 9 | import torch.multiprocessing as mp 10 | from torch.utils.data import DataLoader 11 | from torch.utils.data.distributed import DistributedSampler 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | from torch.distributed import init_process_group, destroy_process_group 14 | 15 | from models.a_llmrec_model import * 16 | from pre_train.sasrec.utils import data_partition, SeqDataset, SeqDataset_Inference 17 | 18 | 19 | def setup_ddp(rank, world_size): 20 | os.environ ["MASTER_ADDR"] = "localhost" 21 | os.environ ["MASTER_PORT"] = "12355" 22 | init_process_group(backend="nccl", rank=rank, world_size=world_size) 23 | torch.cuda.set_device(rank) 24 | 25 | def train_model_phase1(args): 26 | print('A-LLMRec start train phase-1\n') 27 | if args.multi_gpu: 28 | world_size = torch.cuda.device_count() 29 | mp.spawn(train_model_phase1_, args=(world_size, args), nprocs=world_size) 30 | else: 31 | train_model_phase1_(0, 0, args) 32 | 33 | def train_model_phase2(args): 34 | print('A-LLMRec strat train phase-2\n') 35 | if args.multi_gpu: 36 | world_size = torch.cuda.device_count() 37 
| mp.spawn(train_model_phase2_, args=(world_size, args), nprocs=world_size) 38 | else: 39 | train_model_phase2_(0, 0, args) 40 | 41 | def inference(args): 42 | print('A-LLMRec start inference\n') 43 | if args.multi_gpu: 44 | world_size = torch.cuda.device_count() 45 | mp.spawn(inference_, args=(world_size, args), nprocs=world_size) 46 | else: 47 | inference_(0,0,args) 48 | 49 | def train_model_phase1_(rank, world_size, args): 50 | if args.multi_gpu: 51 | setup_ddp(rank, world_size) 52 | args.device = 'cuda:' + str(rank) 53 | 54 | model = A_llmrec_model(args).to(args.device) 55 | 56 | # preprocess data 57 | dataset = data_partition(args.rec_pre_trained_data, path=f'./data/amazon/{args.rec_pre_trained_data}.txt') 58 | [user_train, user_valid, user_test, usernum, itemnum] = dataset 59 | print('user num:', usernum, 'item num:', itemnum) 60 | num_batch = len(user_train) // args.batch_size1 61 | cc = 0.0 62 | for u in user_train: 63 | cc += len(user_train[u]) 64 | print('average sequence length: %.2f' % (cc / len(user_train))) 65 | # Init Dataloader, Model, Optimizer 66 | train_data_set = SeqDataset(user_train, usernum, itemnum, args.maxlen) 67 | if args.multi_gpu: 68 | train_data_loader = DataLoader(train_data_set, batch_size = args.batch_size1, sampler=DistributedSampler(train_data_set, shuffle=True), pin_memory=True) 69 | model = DDP(model, device_ids = [args.device], static_graph=True) 70 | else: 71 | train_data_loader = DataLoader(train_data_set, batch_size = args.batch_size1, pin_memory=True) 72 | 73 | adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.stage1_lr, betas=(0.9, 0.98)) 74 | 75 | epoch_start_idx = 1 76 | T = 0.0 77 | model.train() 78 | t0 = time.time() 79 | for epoch in tqdm(range(epoch_start_idx, args.num_epochs + 1)): 80 | if args.multi_gpu: 81 | train_data_loader.sampler.set_epoch(epoch) 82 | for step, data in enumerate(train_data_loader): 83 | u, seq, pos, neg = data 84 | u, seq, pos, neg = u.numpy(), seq.numpy(), pos.numpy(), neg.numpy() 
85 | model([u,seq,pos,neg], optimizer=adam_optimizer, batch_iter=[epoch,args.num_epochs + 1,step,num_batch], mode='phase1') 86 | if step % max(10,num_batch//100) ==0: 87 | if rank ==0: 88 | if args.multi_gpu: model.module.save_model(args, epoch1=epoch) 89 | else: model.save_model(args, epoch1=epoch) 90 | if rank == 0: 91 | if args.multi_gpu: model.module.save_model(args, epoch1=epoch) 92 | else: model.save_model(args, epoch1=epoch) 93 | 94 | print('train time :', time.time() - t0) 95 | if args.multi_gpu: 96 | destroy_process_group() 97 | return 98 | 99 | def train_model_phase2_(rank,world_size,args): 100 | if args.multi_gpu: 101 | setup_ddp(rank, world_size) 102 | args.device = 'cuda:'+str(rank) 103 | random.seed(0) 104 | 105 | model = A_llmrec_model(args).to(args.device) 106 | phase1_epoch = 10 107 | model.load_model(args, phase1_epoch=phase1_epoch) 108 | 109 | dataset = data_partition(args.rec_pre_trained_data, path=f'./data/amazon/{args.rec_pre_trained_data}.txt') 110 | [user_train, user_valid, user_test, usernum, itemnum] = dataset 111 | print('user num:', usernum, 'item num:', itemnum) 112 | num_batch = len(user_train) // args.batch_size2 113 | cc = 0.0 114 | for u in user_train: 115 | cc += len(user_train[u]) 116 | print('average sequence length: %.2f' % (cc / len(user_train))) 117 | # Init Dataloader, Model, Optimizer 118 | train_data_set = SeqDataset(user_train, usernum, itemnum, args.maxlen) 119 | if args.multi_gpu: 120 | train_data_loader = DataLoader(train_data_set, batch_size = args.batch_size2, sampler=DistributedSampler(train_data_set, shuffle=True), pin_memory=True) 121 | model = DDP(model, device_ids = [args.device], static_graph=True) 122 | else: 123 | train_data_loader = DataLoader(train_data_set, batch_size = args.batch_size2, pin_memory=True, shuffle=True) 124 | adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.stage2_lr, betas=(0.9, 0.98)) 125 | 126 | epoch_start_idx = 1 127 | T = 0.0 128 | model.train() 129 | t0 = time.time() 130 | 
for epoch in tqdm(range(epoch_start_idx, args.num_epochs + 1)): 131 | if args.multi_gpu: 132 | train_data_loader.sampler.set_epoch(epoch) 133 | for step, data in enumerate(train_data_loader): 134 | u, seq, pos, neg = data 135 | u, seq, pos, neg = u.numpy(), seq.numpy(), pos.numpy(), neg.numpy() 136 | model([u,seq,pos,neg], optimizer=adam_optimizer, batch_iter=[epoch,args.num_epochs + 1,step,num_batch], mode='phase2') 137 | if step % max(10,num_batch//100) ==0: 138 | if rank ==0: 139 | if args.multi_gpu: model.module.save_model(args, epoch1=phase1_epoch, epoch2=epoch) 140 | else: model.save_model(args, epoch1=phase1_epoch, epoch2=epoch) 141 | if rank == 0: 142 | if args.multi_gpu: model.module.save_model(args, epoch1=phase1_epoch, epoch2=epoch) 143 | else: model.save_model(args, epoch1=phase1_epoch, epoch2=epoch) 144 | 145 | print('phase2 train time :', time.time() - t0) 146 | if args.multi_gpu: 147 | destroy_process_group() 148 | return 149 | 150 | def inference_(rank, world_size, args): 151 | if args.multi_gpu: 152 | setup_ddp(rank, world_size) 153 | args.device = 'cuda:' + str(rank) 154 | 155 | model = A_llmrec_model(args).to(args.device) 156 | phase1_epoch = 10 157 | phase2_epoch = 5 158 | model.load_model(args, phase1_epoch=phase1_epoch, phase2_epoch=phase2_epoch) 159 | 160 | dataset = data_partition(args.rec_pre_trained_data, path=f'./data/amazon/{args.rec_pre_trained_data}.txt') 161 | [user_train, user_valid, user_test, usernum, itemnum] = dataset 162 | print('user num:', usernum, 'item num:', itemnum) 163 | num_batch = len(user_train) // args.batch_size_infer 164 | cc = 0.0 165 | for u in user_train: 166 | cc += len(user_train[u]) 167 | print('average sequence length: %.2f' % (cc / len(user_train))) 168 | model.eval() 169 | 170 | if usernum>10000: 171 | users = random.sample(range(1, usernum + 1), 10000) 172 | else: 173 | users = range(1, usernum + 1) 174 | 175 | user_list = [] 176 | for u in users: 177 | if len(user_train[u]) < 1 or len(user_test[u]) < 1: 
continue 178 | user_list.append(u) 179 | 180 | inference_data_set = SeqDataset_Inference(user_train, user_valid, user_test, user_list, itemnum, args.maxlen) 181 | 182 | if args.multi_gpu: 183 | inference_data_loader = DataLoader(inference_data_set, batch_size = args.batch_size_infer, sampler=DistributedSampler(inference_data_set, shuffle=True), pin_memory=True) 184 | model = DDP(model, device_ids = [args.device], static_graph=True) 185 | else: 186 | inference_data_loader = DataLoader(inference_data_set, batch_size = args.batch_size_infer, pin_memory=True) 187 | 188 | for _, data in enumerate(inference_data_loader): 189 | u, seq, pos, neg = data 190 | u, seq, pos, neg = u.numpy(), seq.numpy(), pos.numpy(), neg.numpy() 191 | model([u,seq,pos,neg, rank], mode='generate') -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from pytz import timezone 4 | 5 | def create_dir(directory): 6 | if not os.path.exists(directory): 7 | os.makedirs(directory) 8 | 9 | # ex. target_word: .csv / in target_path find 123.csv file 10 | def find_filepath(target_path, target_word): 11 | file_paths = [] 12 | for file in os.listdir(target_path): 13 | if os.path.isfile(os.path.join(target_path, file)): 14 | if target_word in file: 15 | file_paths.append(target_path + file) 16 | 17 | return file_paths 18 | 19 | 20 | 21 | --------------------------------------------------------------------------------