├── seqrec
│   ├── models
│   │   ├── __init__.py
│   │   ├── GRU4Rec
│   │   │   ├── config.yaml
│   │   │   └── _model.py
│   │   ├── SASRec
│   │   │   ├── config.yaml
│   │   │   └── _model.py
│   │   └── Embedding2.py
│   ├── default.yaml
│   ├── base.py
│   ├── evaluator.py
│   ├── recdata.py
│   ├── trainer.py
│   ├── runner.py
│   ├── utils.py
│   └── modules.py
├── .gitignore
├── run_LLM2Rec_IEM.sh
├── utils
│   ├── memory.py
│   ├── datasets.py
│   └── llm2vec_encoder.py
├── llm2rec
│   ├── train_mntp_config.json
│   ├── train_simcse_config.json
│   ├── dataset_utils.py
│   ├── recdata
│   │   ├── dataset.py
│   │   ├── ItemTitleData.py
│   │   ├── RecItemData.py
│   │   └── SeqRecData.py
│   ├── run_csft.py
│   ├── run_unsupervised_SimCSE.py
│   └── dataset.py
├── evaluate_with_seqrec.py
├── run_LLM2Rec_CSFT.sh
├── script_eval_baselines.sh
├── README.md
├── script_extract_and_evaluate.sh
├── repeated_evaluate_with_seqrec.py
├── extract_llm_embedding.py
├── baselines
│   ├── model.py
│   └── EasyRecModel.py
└── Baseline_inference.py

/seqrec/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .SASRec._model import SASRec
2 | from .GRU4Rec._model import GRU4Rec
--------------------------------------------------------------------------------
/seqrec/models/GRU4Rec/config.yaml:
--------------------------------------------------------------------------------
1 | lr: 1e-2
2 | loss_type: ce
3 | 
4 | hidden_size: 128
5 | layer_num: 2
6 | 
7 | dropout: 0.3
8 | sample_func: random
9 | 
10 | adapter_dims: [-1] # must end with -1, which refers to hidden_size
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # datasets/
2 | .vscode/
3 | data
4 | data/
5 | data_EasyRec_setting/
6 | 
7 | checkpoints/
8 | log/
9 | runs/
10 | 
11 | __pycache__/
12 | *.ttc
13 | wandb/
14 | output/
15 | 
16 | seqrec/ckpt/
17 | item_info/
18 | item_info
19 | cache/
20 | Results/
21 | 
22 | Rebuttal/
23 | 
24 | *.log
25 | *.pdf
--------------------------------------------------------------------------------
/seqrec/models/SASRec/config.yaml:
--------------------------------------------------------------------------------
1 | # lr: 1e-2
2 | # weight_decay: 3e-4
3 | loss_type: ce
4 | 
5 | hidden_size: 128
6 | layer_num: 2
7 | num_heads: 2
8 | dropout: 0.3
9 | 
10 | # sample_func: batch
11 | sample_func: random
12 | 
13 | adapter_dims: [-1] # must end with -1, which refers to hidden_size
--------------------------------------------------------------------------------
/run_LLM2Rec_IEM.sh:
--------------------------------------------------------------------------------
1 | # Second stage of LLM2Rec training -- Item Embedding Modeling (IEM).
2 | 
3 | model_path="/home/yingzhi/huggingface_data/hub/Qwen2-0.5B" # Replace with your own model path
4 | 
5 | # Stage 2 - Train MNTP
6 | echo "Starting Stage 2 - Train MNTP..."
7 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=29501 ./llm2rec/run_mntp.py ./llm2rec/train_mntp_config.json
8 | 
9 | # Stage 3 - Train SimCSE
10 | echo "Starting Stage 3 - Train SimCSE..."
11 | cp ${model_path}/*token* ./output/iem_stage1/Qwen2-0.5B-AmazonMix6-CSFT/checkpoint-100/ 12 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --master_port=29502 ./llm2rec/run_unsupervised_SimCSE.py ./llm2rec/train_simcse_config.json 13 | -------------------------------------------------------------------------------- /seqrec/models/Embedding2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from dataclasses import dataclass 4 | 5 | @dataclass 6 | class Weight: 7 | data: None 8 | 9 | class Embedding2(nn.Module): 10 | def __init__(self, adapter, embedding): 11 | super().__init__() 12 | self.embedding = embedding 13 | self.adapter = adapter 14 | 15 | def forward(self, indices): 16 | return self.adapter(self.embedding(indices)) 17 | 18 | @property 19 | def weight(self): 20 | return Weight(self.adapter(self.embedding.weight.data)) 21 | # return Weight(10) 22 | 23 | if __name__ == "__main__": 24 | print(Embedding2(None, None).weight.data) -------------------------------------------------------------------------------- /seqrec/default.yaml: -------------------------------------------------------------------------------- 1 | num_proc: 1 2 | cache_dir: seqrec/cache/ # Usually for raw and processed data 3 | log_dir: seqrec/logs/ 4 | tensorboard_log_dir: seqrec/tensorboard/ 5 | ckpt_dir: seqrec/ckpt/ 6 | rand_seed: 2024 7 | reproducibility: True 8 | 9 | max_seq_length: 10 10 | whiten: False 11 | 12 | train_batch_size: 256 13 | eval_batch_size: 32 14 | lr: 1.0e-3 15 | weight_decay: 1.0e-4 16 | warmup_steps: 10000 17 | steps: ~ 18 | # epochs: 150 19 | epochs: 1000 20 | 21 | max_grad_norm: 1.0 # None for no clipping, else a float value 22 | eval_interval: 5 # Evaluate every n epochs 23 | patience: 20 # Early stopping. Stop training after n epochs without improvement. 
Set to None to disable 24 | 25 | topk: [5,10,20] 26 | metrics: [ndcg,recall] 27 | val_metric: recall@20 28 | run_id: Eval # Change this to your customized run id 29 | 30 | 31 | save: True -------------------------------------------------------------------------------- /utils/memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # follow MoCo memory bank implementation 4 | class MemoryBank(): 5 | def __init__(self, size = 128, dim = 128): 6 | self.size = size 7 | self.feature = torch.randn(size, dim, dtype=torch.bfloat16) 8 | self.queue_ptr = torch.zeros(1, dtype=torch.long) 9 | self.K = size 10 | 11 | @torch.no_grad() 12 | def _dequeue_and_enqueue(self, keys): 13 | 14 | batch_size = keys.shape[0] 15 | 16 | ptr = int(self.queue_ptr) 17 | assert self.K % batch_size == 0 # for simplicity 18 | 19 | # replace the keys at ptr (dequeue and enqueue) 20 | self.feature[ptr : ptr + batch_size, :] = keys 21 | ptr = (ptr + batch_size) % self.K # move pointer 22 | 23 | self.queue_ptr[0] = ptr 24 | 25 | def update(self, keys): 26 | self._dequeue_and_enqueue(keys) -------------------------------------------------------------------------------- /llm2rec/train_mntp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "output/Qwen2-0.5B-CSFT-AmazonMix-6/checkpoint-10000", 3 | "dataset_name": "data/AmazonMix-6/5-core/info/item_titles.txt", 4 | "per_device_train_batch_size": 32, 5 | "per_device_eval_batch_size": 32, 6 | "gradient_accumulation_steps": 1, 7 | "do_train": true, 8 | "do_eval": true, 9 | "line_by_line": true, 10 | "max_seq_length": 512, 11 | "mask_token_type": "blank", 12 | "data_collator_type": "default", 13 | "mlm_probability": 0.2, 14 | "overwrite_output_dir": true, 15 | "output_dir": "output/iem_stage1/Qwen2-0.5B-AmazonMix6-CSFT", 16 | "evaluation_strategy": "steps", 17 | "eval_steps": 100, 18 | "save_steps": 100, 19 | "stop_after_n_steps": 1000, 20 | "lora_r": null, 21 | "gradient_checkpointing": true, 22 | "torch_dtype": "bfloat16", 23 | "attn_implementation": "flash_attention_2", 24 | "trust_remote_code": true 25 | } 26 | -------------------------------------------------------------------------------- /seqrec/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class AbstractModel(nn.Module): 5 | def __init__( 6 | self, 7 | config: dict, 8 | ): 9 | super(AbstractModel, self).__init__() 10 | self.config = config 11 | # self.sub_embeddings = nn.Embedding(self.config['hidden_size'], self.config['hidden_size']) 12 | # nn.init.normal_(self.sub_embeddings.weight, 0, 1) 13 | 14 | @property 15 | def n_parameters(self): 16 | total_params = sum(p.numel() for p in self.parameters() if p.requires_grad) 17 | return f'Total number of trainable parameters: {total_params}' 18 | 19 | def calculate_loss(self, batch): 20 | raise NotImplementedError('calculate_loss method must be implemented.') 21 | 22 | def predict(self, batch, n_return_sequences=1): 23 | raise NotImplementedError('predict method must be implemented.') 24 | 25 | def get_embeddings(self, items): 26 | raise NotImplementedError('get item_embeddings must be implemented.') 27 | -------------------------------------------------------------------------------- /llm2rec/train_simcse_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": 
"output/iem_stage1/Qwen2-0.5B-AmazonMix6-CSFT/checkpoint-1000", 3 | "peft_model_name_or_path": null, 4 | "bidirectional": true, 5 | "pooling_mode": "mean", 6 | "dataset_name": "ItemTitles", 7 | "dataset_file_path": "data", 8 | "remove_unused_columns": false, 9 | "learning_rate": 2e-4, 10 | "num_train_epochs": 5, 11 | "warmup_steps": 300, 12 | "per_device_train_batch_size": 256, 13 | "per_device_eval_batch_size": 256, 14 | "gradient_accumulation_steps": 1, 15 | "do_train": true, 16 | "disable_tqdm": false, 17 | "max_seq_length": 512, 18 | "overwrite_output_dir": true, 19 | "output_dir": "output/iem_stage2/Qwen2-0.5B-AmazonMix6-CSFT", 20 | "logging_steps": 1, 21 | "max_grad_norm": 1.0, 22 | "simcse_dropout": 0.2, 23 | 24 | "save_only_model": false, 25 | "save_steps": 100, 26 | "stop_after_n_steps": 1000, 27 | "loss_scale": 10.0, 28 | 29 | "lora_r": null, 30 | "gradient_checkpointing": true, 31 | "torch_dtype": "bfloat16", 32 | "attn_implementation": "flash_attention_2", 33 | "seed": 42 34 | } -------------------------------------------------------------------------------- /llm2rec/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from recdata.RecItemData import RecItemData 2 | from recdata.SeqRecData import SeqRecData 3 | from recdata.ItemTitleData import ItemTitleData 4 | 5 | 6 | def load_dataset(dataset_name, split="validation", file_path=None, **kwargs): 7 | """ 8 | Loads a dataset by name. 9 | 10 | Args: 11 | dataset_name (str): Name of the dataset to load. 12 | split (str): Split of the dataset to load. 13 | file_path (str): Path to the dataset file. 14 | """ 15 | dataset_mapping = { 16 | "ItemRec": RecItemData, 17 | "SeqRec": SeqRecData, 18 | "ItemTitles": ItemTitleData, 19 | } 20 | 21 | if dataset_name.split("_")[0] not in dataset_mapping: 22 | raise NotImplementedError(f"Dataset name {dataset_name} not supported.") 23 | 24 | if split not in ["train", "validation", "test"]: 25 | raise NotImplementedError(f"Split {split} not supported.") 26 | 27 | if "_SeqAug" in dataset_name: 28 | dataset_name = dataset_name.replace("_SeqAug", "") 29 | return dataset_mapping[dataset_name](split=split, file_path=file_path, data_augmentation=True, **kwargs) 30 | else: 31 | return dataset_mapping[dataset_name](split=split, file_path=file_path, **kwargs) 32 | -------------------------------------------------------------------------------- /evaluate_with_seqrec.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from seqrec.runner import Runner 4 | from seqrec.utils import parse_command_line_args 5 | # import os 6 | # os.environ['CUDA_VISIBLE_DEVICES'] = '3' 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--model', type=str, default='SASRec', help='Model name, options: SASRec, GRU4Rec') 12 | parser.add_argument('--dataset', type=str, default='Games_5core', help='Source domain') 13 | parser.add_argument('--exp_type', type=str, default='srec') 14 | parser.add_argument('--embedding', type=str, default='./item_info/Games_5core/LLM2Vec_Qwen2-0.5B-Backbone_title_item_embs.npy', help='Whether to use source domain data') 15 | parser.add_argument('--seq_embedding', type=str, default='', help='whether pre-trained sequence embeddings are used') 16 | 17 | return parser.parse_known_args() 18 | 19 | 20 | if __name__ == '__main__': 21 | args, unparsed_args = parse_args() 22 | command_line_configs = parse_command_line_args(unparsed_args) 23 | args_dict = 
vars(args) 24 | 25 | merged_dict = {**args_dict, **command_line_configs} 26 | 27 | 28 | runner = Runner( 29 | model_name=args.model, 30 | config_dict= merged_dict 31 | ) 32 | runner.run() 33 | 34 | 35 | # CUDA_VISIBLE_DEVICES=3 accelerate launch --main_process_port=12324 main.py --model=PDSRec --sd=T --td=T 36 | 37 | -------------------------------------------------------------------------------- /llm2rec/recdata/dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Union, List 3 | 4 | import torch 5 | 6 | 7 | @dataclass 8 | class DataSample: 9 | id_: int 10 | query: str 11 | positive: str 12 | negative: str = None 13 | task_name: str = None 14 | aug_query: str = None 15 | 16 | 17 | class TrainSample: 18 | """ 19 | Structure for one input example with texts, the label and a unique id 20 | """ 21 | 22 | def __init__( 23 | self, guid: str = "", texts: List[str] = None, label: Union[int, float] = 0 24 | ): 25 | """ 26 | Creates one TrainSample with the given texts, guid and label 27 | 28 | 29 | :param guid 30 | id for the example 31 | :param texts 32 | the texts for the example. 33 | :param label 34 | the label for the example 35 | """ 36 | self.guid = guid 37 | self.texts = texts 38 | self.label = label 39 | 40 | def __str__(self): 41 | return " label: {}, texts: {}".format( 42 | str(self.label), "; ".join(self.texts) 43 | ) 44 | 45 | 46 | class Dataset(torch.utils.data.Dataset): 47 | def load_data(self, file_path: str = None): 48 | raise NotImplementedError() 49 | 50 | def __getitem__(self, index): 51 | raise NotImplementedError() 52 | 53 | def __len__(self): 54 | raise NotImplementedError() -------------------------------------------------------------------------------- /run_LLM2Rec_CSFT.sh: -------------------------------------------------------------------------------- 1 | # First stage of LLM2Rec training -- Collaborative Supervised Fine-Tuning (CSFT). 2 | 3 | model_path="/home/yingzhi/huggingface_data/hub/Qwen2-0.5B" # Replace with your own model path 4 | 5 | for category in "AmazonMix-6" 6 | do 7 | train_file=$(ls -f ./data/${category}/5-core/train/${category}*.csv) 8 | eval_file=$(ls -f ./data/${category}/5-core/valid/${category}*.csv) 9 | echo ${train_file} ${info_file} 10 | 11 | CUDA_VISIBLE_DEVICES=0,1 torchrun --master_port=25649 --nproc_per_node 2 \ 12 | ./llm2rec/run_csft.py \ 13 | --base_model ${model_path} \ 14 | --train_file ${train_file} \ 15 | --eval_file ${eval_file} \ 16 | --output_dir ./output/Qwen2-0.5B-CSFT-${category} \ 17 | --wandb_run_name Qwen2-0.5B-CSFT-${category} \ 18 | --category ${category} \ 19 | --train_from_scratch False \ 20 | --use_lora False 21 | 22 | cp ${model_path}/*token* ./output/Qwen2-0.5B-CSFT-${category}/ 23 | # Also copy tokenizer to the last checkpoint 24 | latest_ckpt=$(ls -d ./output/Qwen2-0.5B-CSFT-${category}/checkpoint-* | sort -V | tail -n 1) 25 | cp ${model_path}/*token* ${latest_ckpt}/ 26 | done 27 | 28 | 29 | 30 | # echo "Starting Stage 2 - Train MNTP..." 31 | # cp /home/yingzhi/huggingface_data/hub/gemma-2b/*token* ./output/gemma-2b-FULL-Amazon6-wo-prompt-SFT-AmazonMix-6/checkpoint-10000/ 32 | # CUDA_VISIBLE_DEVICES=4,5 torchrun --nproc_per_node=2 --master_port=29501 ./llm2rec/run_mntp.py ./llm2rec/Rebuttal_train_mntp_config.json 33 | 34 | # # Stage 3 - Train SimCSE 35 | # echo "Starting Stage 3 - Train SimCSE..." 
36 | # cp /home/yingzhi/huggingface_data/hub/gemma-2b/*token* ./output/mntp/gemma-2b-FULL-SFT-AmazonMix-6/checkpoint-1000/
37 | # CUDA_VISIBLE_DEVICES=4,5 torchrun --nproc_per_node=2 --master_port=29502 run_unsupervised_SimCSE.py ./llm2rec/Rebuttal_train_simcse_config.json
38 | 
--------------------------------------------------------------------------------
/script_eval_baselines.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Define datasets and their corresponding CUDA_VISIBLE_DEVICES
4 | datasets=(
5 |     "Games_5core 0"
6 |     "Arts_5core 0"
7 |     "Movies_5core 0"
8 |     "Goodreads 1"
9 |     "Sports_5core 1"
10 |     "Baby_5core 1"
11 | )
12 | 
13 | # Define the fixed parameters
14 | run_id="Eval_Embeddings"
15 | model="SASRec"
16 | dr=0.3
17 | port_base=12324
18 | lr=1.0e-3
19 | wd=1.0e-4
20 | loss_type="ce"
21 | 
22 | # Run experiments for all datasets simultaneously
23 | for dataset_entry in "${datasets[@]}"; do
24 |     (
25 |         # Split dataset_entry into dataset name and CUDA device
26 |         IFS=' ' read -r dataset cuda_device <<< "$dataset_entry"
27 | 
28 |         # Define embeddings for the current dataset
29 |         embeddings=(
30 |             # "./item_info/${dataset}/BGE_title_item_embs.npy"
31 |             # "./item_info/${dataset}/Blair_title_item_embs.npy"
32 |             # "./item_info/${dataset}/EasyRec_title_item_embs.npy"
33 |             # "./item_info/${dataset}/BERT_title_item_embs.npy"
34 |             "./item_info/${dataset}/RoBERTa_large_sentence_title_item_embs.npy"
35 |             # "./item_info/${dataset}/GTE_7B_title_item_embs.npy"
36 |         )
37 | 
38 |         for embs in "${embeddings[@]}"; do
39 |             echo "Running evaluation with dataset=$dataset, CUDA_VISIBLE_DEVICES=$cuda_device, and embeddings=$embs"
40 |             CUDA_VISIBLE_DEVICES=$cuda_device accelerate launch --main_process_port=$((port_base + cuda_device)) repeated_evaluate_with_seqrec.py \
41 |                 --model=$model \
42 |                 --dataset=$dataset \
43 |                 --lr=$lr \
44 |                 --weight_decay=$wd \
45 |                 --embedding=$embs \
46 |                 --dropout=$dr \
47 |                 --loss_type=$loss_type \
48 |                 --run_id=$run_id
49 |         done
50 |     ) &
51 | done
52 | 
53 | # Wait for all background processes to complete
54 | wait
55 | 
--------------------------------------------------------------------------------
/seqrec/evaluator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | class Evaluator:
5 |     def __init__(self, config):
6 |         self.config = config
7 |         self.metric2func = {
8 |             'recall': self.recall_at_k,
9 |             'ndcg': self.ndcg_at_k
10 |         }
11 | 
12 |         self.eos_token = self.config['eos_token']
13 |         self.maxk = max(config['topk'])
14 | 
15 |     def calculate_pos_index(self, preds, labels):
16 |         preds = preds.detach().cpu()
17 |         labels = labels.detach().cpu()
18 |         assert preds.shape[1] == self.maxk, f"preds.shape[1] = {preds.shape[1]} != {self.maxk}"
19 | 
20 |         pos_index = torch.zeros((preds.shape[0], self.maxk), dtype=torch.bool)
21 |         for i in range(preds.shape[0]):
22 |             cur_label = labels[i].tolist()
23 |             if isinstance(cur_label, list) and self.eos_token in cur_label:  # truncate token-sequence labels at EOS
24 |                 eos_pos = cur_label.index(self.eos_token)
25 |                 cur_label = cur_label[:eos_pos]
26 |             for j in range(self.maxk):
27 |                 cur_pred = preds[i, j].tolist()
28 |                 if cur_pred == cur_label:
29 |                     pos_index[i, j] = True
30 |                     break
31 |         return pos_index
32 | 
33 |     def recall_at_k(self, pos_index, k):
34 |         return pos_index[:, :k].sum(dim=1).cpu().float()
35 | 
36 |     def ndcg_at_k(self, pos_index, k):
37 |         ranks = torch.arange(1, pos_index.shape[-1] + 1).to(pos_index.device)
38 |         dcg = 1.0 / torch.log2(ranks + 1)
39 |         dcg = torch.where(pos_index, dcg, torch.tensor(0.0, dtype=dcg.dtype, device=dcg.device))
40 | 
41 |         return dcg[:, :k].sum(dim=1).cpu().float()
42 | 
43 |     def calculate_metrics(self, preds, labels):
44 |         results = {}
45 |         pos_index = self.calculate_pos_index(preds, labels)
46 |         for metric in self.config['metrics']:
47 |             for k in self.config['topk']:
48 |                 results[f"{metric}@{k}"] = self.metric2func[metric](pos_index, k)
49 |         return results
50 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LLM2Rec: Large Language Models Are Powerful Embedding Models for Sequential Recommendation
2 | 
3 | ## Introduction
4 | This is the code implementation of our KDD'25 paper "LLM2Rec: Large Language Models Are Powerful Embedding Models for Sequential Recommendation".
5 | 
6 | ## Environments
7 | To run the code, the following Python packages are required:
8 | 
9 | - `torch >= 2.6.0`
10 | - `transformers >= 4.44.2`
11 | - `llm2vec == 0.2.3`
12 | - `flash-attn >= 2.7.4`
13 | 
14 | ## Datasets
15 | The zipped datasets used in this paper can be downloaded from this [link](https://drive.google.com/file/d/1GIXWaaaNuUkUtuFy5JTN0OwAQiLGb2z4/view?usp=sharing). Please unzip the dataset files into the `./data` directory.
16 | 
17 | ## Training
18 | 
19 | LLM2Rec follows a two-stage training pipeline:
20 | 
21 | 1. **Collaborative Supervised Fine-Tuning (CSFT)**
22 |    Fine-tunes a pre-trained LLM to capture collaborative filtering (CF) signals, using user interaction sequences as training data.
23 | 
24 | 2. **Item-level Embedding Modeling (IEM)**
25 |    Converts the CF-aware LLM into an embedding generator.
26 | 
27 | ### Run training
28 | 
29 | We provide example shell scripts for training:
30 | 
31 | ```bash
32 | # Stage 1: Collaborative Supervised Fine-Tuning
33 | bash run_LLM2Rec_CSFT.sh
34 | 
35 | # Stage 2: Item-level Embedding Modeling
36 | bash run_LLM2Rec_IEM.sh
37 | ```
38 | 
39 | Please update the device-specific settings (e.g., the path to your locally saved pre-trained LLM) before running the scripts.
40 | 
41 | ## Evaluation
42 | 
43 | We integrate the whole evaluation process -- extracting item embeddings and training downstream sequential recommenders on them -- into one script, which can be executed by
44 | ```bash
45 | bash script_extract_and_evaluate.sh
46 | ```
47 | 
48 | You can set the paths of the saved checkpoints to evaluate in the config section of `script_extract_and_evaluate.sh`; a single-run example is given at the end of this README.
49 | 
50 | 
51 | ## Citation
52 | If you find our repo useful, please consider citing:
53 | ```bibtex
54 | @inproceedings{he2025llm2rec,
55 |   title={LLM2Rec: Large Language Models Are Powerful Embedding Models for Sequential Recommendation},
56 |   author={He, Yingzhi and Liu, Xiaohao and Zhang, An and Ma, Yunshan and Chua, Tat-Seng},
57 |   booktitle={Proceedings of the 31st ACM SIGKDD Conference on Knowledge Discovery and Data Mining V. 2},
58 |   pages={896--907},
59 |   year={2025}
60 | }
61 | ```
62 | 
63 | ## Acknowledgements
64 | 
65 | The code implementation is based on previous repos, including [llm2vec](https://github.com/McGill-NLP/llm2vec), [recbole](https://github.com/RUCAIBox/RecBole), and [DecodingMatters](https://github.com/SAI990323/DecodingMatters).
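
For a quick check on a single dataset and embedding file, the evaluation entry point can also be launched directly. A minimal sketch (the embedding path is the default from `evaluate_with_seqrec.py`; swap in your own extracted `.npy` file and adjust the port/device as needed):

```bash
CUDA_VISIBLE_DEVICES=0 accelerate launch --main_process_port=12324 evaluate_with_seqrec.py \
    --model=SASRec \
    --dataset=Games_5core \
    --embedding=./item_info/Games_5core/LLM2Vec_Qwen2-0.5B-Backbone_title_item_embs.npy
```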
66 | -------------------------------------------------------------------------------- /script_extract_and_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define model settings as tuples: (save_info, model_path, bidirectional) 4 | models_list=( 5 | # "Qwen2-0.5B-Backbone /home/yingzhi/huggingface_data/hub/Qwen2-0.5B" 6 | "Qwen2-0.5B-LLM2Rec-IEM_step500 ./output/iem_stage2/Qwen2-0.5B-AmazonMix6-CSFT/checkpoint-500" 7 | "Qwen2-0.5B-LLM2Rec-IEM_step1000 ./output/iem_stage2/Qwen2-0.5B-AmazonMix6-CSFT/checkpoint-1000" 8 | ) 9 | 10 | # Define datasets as tuples: (dataset_name, cuda_device) 11 | datasets=( 12 | "Games_5core 0" 13 | "Arts_5core 1" 14 | "Movies_5core 0" 15 | "Sports_5core 1" 16 | "Baby_5core 0" 17 | "Goodreads 1" 18 | ) 19 | 20 | # Loop over each dataset setting (parallelized) 21 | for dataset_setting in "${datasets[@]}" 22 | do 23 | ( 24 | # Extract dataset and CUDA device 25 | dataset=$(echo $dataset_setting | awk '{print $1}') 26 | cuda_device=$(echo $dataset_setting | awk '{print $2}') 27 | 28 | extraction_method="title" 29 | 30 | # Ensure the item_info directory exists 31 | mkdir -p "./item_info/${dataset}" 32 | 33 | # Loop over models sequentially for each dataset 34 | for model_setting in "${models_list[@]}" 35 | do 36 | # Split model_setting into save_info, model_path, and bidirectional 37 | save_info=$(echo $model_setting | awk '{print $1}') 38 | model_path=$(echo $model_setting | awk '{print $2}') 39 | bidirectional=1 40 | 41 | # Extract embeddings 42 | CUDA_VISIBLE_DEVICES=$cuda_device python extract_llm_embedding.py --dataset=$dataset \ 43 | --model_path=$model_path \ 44 | --item_prompt_type=$extraction_method \ 45 | --bidirectional=$bidirectional \ 46 | --save_info=$save_info 47 | 48 | # Define hyperparameters for evaluation 49 | # Default hyperparameters setting for SASRec 50 | lr=1.0e-3 51 | wd=1.0e-4 52 | loss_type="ce" 53 | model="SASRec" 54 | dr=0.3 55 | 56 | # Default hyperparameters setting for GRU4Rec 57 | # lr=1.0e-4 58 | # wd=1.0e-4 59 | # loss_type="ce" 60 | # model="GRU4Rec" 61 | # dr=0.3 62 | 63 | run_id="CR" 64 | embs="./item_info/${dataset}/${save_info}_${extraction_method}_item_embs.npy" 65 | 66 | # Random port to avoid conflicts 67 | port=$((12000 + RANDOM % 1000)) 68 | 69 | # Evaluate the model 70 | CUDA_VISIBLE_DEVICES=$cuda_device accelerate launch --main_process_port=$port repeated_evaluate_with_seqrec.py \ 71 | --model=$model \ 72 | --dataset=$dataset \ 73 | --lr=$lr \ 74 | --weight_decay=$wd \ 75 | --embedding=$embs \ 76 | --dropout=$dr \ 77 | --loss_type=$loss_type \ 78 | --run_id=$run_id 79 | 80 | echo "✅ Finished processing dataset: $dataset with model: $save_info on CUDA device: $cuda_device" 81 | 82 | done 83 | ) & # Run each dataset in parallel across available GPUs 84 | done 85 | 86 | # Wait for all parallel dataset processes to finish 87 | wait 88 | 89 | echo "🚀 All dataset experiments completed!" 
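# Usage note (sketch): to benchmark GRU4Rec instead of SASRec, swap the active
# SASRec hyperparameter block above for the commented GRU4Rec defaults
# (model="GRU4Rec", lr=1.0e-4) before launching; the rest of the pipeline is unchanged.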
90 | 
--------------------------------------------------------------------------------
/seqrec/recdata.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import defaultdict
3 | from torch.utils.data import Dataset, DataLoader, Sampler
4 | 
5 | 
6 | class SequenceDataset(Dataset):
7 |     def __init__(self, config, sequences, seq_type=None):
8 |         self.sequences = sequences
9 |         self.config = config
10 |         self.seq_type = seq_type
11 | 
12 |     def __len__(self):
13 |         return len(self.sequences)
14 | 
15 |     def __getitem__(self, idx):
16 |         seq = self.sequences[idx]
17 |         item_seq = seq[:-1]
18 |         labels = seq[-1]
19 |         seq_length = len(item_seq)
20 |         padding_length = self.config['max_seq_length'] - len(item_seq)
21 |         if padding_length > 0:
22 |             item_seq = item_seq + [0] * padding_length  # pad with 0 at the end
23 |         return {
24 |             'item_seqs': torch.tensor(item_seq, dtype=torch.long),
25 |             'labels': torch.tensor(labels, dtype=torch.long),
26 |             'seq_lengths': seq_length,
27 | 
28 |             # The variables below are used for sequential embedding generation. Ignore if not needed.
29 |             'seq_ids': idx,
30 |             'seq_type': self.seq_type
31 |         }
32 | 
33 | 
34 | class NormalRecData:
35 |     def __init__(self, config: dict):
36 |         self.config = config
37 | 
38 |     def load_data(self):
39 |         from pathlib import Path
40 | 
41 |         source_dict = {
42 |             "Goodreads": 'Goodreads/clean',
43 |             "Games_5core": "Video_Games/5-core/downstream",
44 |             "Movies_5core": "Movies_and_TV/5-core/downstream",
45 |             "Arts_5core": "Arts_Crafts_and_Sewing/5-core/downstream",
46 |             "Sports_5core": "Sports_and_Outdoors/5-core/downstream",
47 |             "Baby_5core": "Baby_Products/5-core/downstream",
48 |         }
49 |         self.config['source_dict'] = source_dict
50 | 
51 |         def read_data_from_file(domain, mode=''):
52 |             base_path = Path('data/')
53 |             file_path = base_path / source_dict[domain] / '{}data.txt'.format(mode)
54 |             with file_path.open('r') as file:
55 |                 item_seqs = [list(map(int, line.split()))[-self.config['max_seq_length']-1:] for line in file]
56 | 
57 |             if mode == '':
58 |                 flat_list = [item for sublist in item_seqs for item in sublist]
59 |                 import numpy as np
60 |                 item_num = np.max(flat_list)
61 |                 return item_seqs, item_num
62 |             else:
63 |                 return item_seqs
64 | 
65 |         train_data = []
66 |         valid_data = []
67 |         test_data = []
68 | 
69 |         tmp_item_seqs, total_item_num = read_data_from_file(self.config['dataset'])
70 |         tmp_train_item_seqs, tmp_valid_item_seqs, tmp_test_item_seqs = (
71 |             read_data_from_file(self.config['dataset'], mode='train_'),
72 |             read_data_from_file(self.config['dataset'], mode='val_'),
73 |             read_data_from_file(self.config['dataset'], mode='test_')
74 |         )
75 |         train_data.extend(tmp_train_item_seqs)
76 |         valid_data.extend(tmp_valid_item_seqs)
77 |         test_data.extend(tmp_test_item_seqs)
78 |         select_pool = [1, total_item_num + 1]
79 | 
80 |         return (
81 |             SequenceDataset(self.config, train_data, seq_type='train'),
82 |             SequenceDataset(self.config, valid_data, seq_type='val'),
83 |             SequenceDataset(self.config, test_data, seq_type='test'),
84 |             select_pool,
85 |             total_item_num
86 |         )
87 | 
--------------------------------------------------------------------------------
/repeated_evaluate_with_seqrec.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | from seqrec.runner import Runner
5 | from seqrec.utils import parse_command_line_args
6 | import os
7 | import json
8 | # os.environ['CUDA_VISIBLE_DEVICES'] = '3'
9 | 
10 | 
11 | def parse_args():
12 |     parser = argparse.ArgumentParser()
13 |     parser.add_argument('--model', type=str, default='SASRec', help='Model name, options: SASRec, GRU4Rec')
14 |     parser.add_argument('--dataset', type=str, default='Games_5core', help='Dataset name')
15 |     parser.add_argument('--exp_type', type=str, default='srec')
16 | 
17 |     parser.add_argument('--embedding', type=str, default='', help='Path to pre-trained item embeddings (.npy); empty to train item embeddings from scratch')
18 |     parser.add_argument('--seq_embedding', type=str, default='', help='Path to pre-trained sequence embeddings; empty to disable')
19 | 
20 |     # parser.add_argument('--embedding', type=str, default='./item_info/Baby_AF2021/LLM2Vec_Mix6_SeqAug_Mistral-7B_step600_title_item_embs.npy', help='Whether to use source domain data')
21 |     # parser.add_argument('--seq_embedding', type=str, default='./item_info/Baby_AF2021/LLM2Vec_Mix6_SeqAug_Mistral-7B_step600_{}_seq_embs.npy', help='whether pre-trained sequence embeddings are used')
22 | 
23 |     return parser.parse_known_args()
24 | 
25 | 
26 | 
27 | def calculate_mean_and_std(results):
28 |     metrics = {}
29 |     for result in results:
30 |         for key, value in result.items():
31 |             if key not in metrics:
32 |                 metrics[key] = []
33 |             metrics[key].append(value)
34 | 
35 |     stats = {}
36 |     for metric, values in metrics.items():
37 |         mean = np.mean(values)
38 |         std = np.std(values)
39 |         stats[metric] = (float(mean), float(std))
40 | 
41 |     return stats
42 | 
43 | 
44 | 
45 | if __name__ == '__main__':
46 |     args, unparsed_args = parse_args()
47 |     command_line_configs = parse_command_line_args(unparsed_args)
48 |     args_dict = vars(args)  # convert args to a dict
49 | 
50 |     # merge the two dicts, assuming command_line_configs is also a dict
51 |     merged_dict = {**args_dict, **command_line_configs}
52 | 
53 | 
54 |     exp_seeds = [2024, 2025, 2026]
55 |     test_results = []
56 |     for seed in exp_seeds:
57 |         merged_dict['rand_seed'] = seed
58 | 
59 |         runner = Runner(
60 |             model_name=args.model,
61 |             config_dict=merged_dict
62 |         )
63 |         test_result, exp_config = runner.run()
64 | 
65 |         test_results.append(test_result)
66 | 
67 |     # calculate average and std of test results
68 |     stats = calculate_mean_and_std(test_results)
69 | 
70 |     result_save_dir = f"./Results/{exp_config['dataset']}/{exp_config['model']}/lr_{exp_config['lr']}_dr_{exp_config['dropout']}_time_{datetime.datetime.now().strftime('%b-%d-%Y_%H-%M-%S')}_emb_{exp_config['embedding']}"
71 |     if not os.path.exists(result_save_dir):
72 |         os.makedirs(result_save_dir)
73 |     # write the stats to a local txt file
74 |     with open(f'{result_save_dir}/results.txt', 'a') as f:
75 |         f.write(f"Final Results for {exp_config['model']} on {exp_config['dataset']}:\n")
76 |         for key, value in stats.items():
77 |             f.write(f'{key}: {value}\n')
78 | 
79 |         # write the results of each experiment
80 |         f.write("\n\n")
81 |         f.write("Results of each experiment:\n")
82 |         for i, result in enumerate(test_results):
83 |             f.write(f"Experiment {i+1}:\n")
84 |             for key, value in result.items():
85 |                 f.write(f'{key}: {value}\n')
86 |             f.write("\n")
87 | 
88 |     # save config as pretty json file
89 |     with open(f'{result_save_dir}/config.json', 'w') as f:
90 |         json.dump(merged_dict, f, indent=4)
91 | 
92 |     print("Finished:")
93 |     print(stats)
94 | 
95 | # CUDA_VISIBLE_DEVICES=3 accelerate launch --main_process_port=12324 main.py --model=PDSRec --sd=T --td=T
96 | 
97 | 
--------------------------------------------------------------------------------
/llm2rec/recdata/ItemTitleData.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import os
4 | 
5 | from .dataset import DataSample, TrainSample, Dataset 6 | from accelerate.logging import get_logger 7 | 8 | logger = get_logger(__name__, log_level="INFO") 9 | 10 | 11 | AMAZON_DATASET_NAME_MAPPING = { 12 | "Mix6": "AmazonMix-6/5-core/info/item_titles.txt", 13 | } 14 | 15 | class ItemTitleData(Dataset): 16 | def __init__( 17 | self, 18 | dataset_name: str = "Rec", 19 | split: str = "validation", 20 | file_path: str = "data", 21 | effective_batch_size: int = 32, 22 | shuffle_individual_datasets: bool = True, 23 | separator: str = "!@#$%^&*()", 24 | ): 25 | self.dataset_name = dataset_name 26 | self.split = split 27 | self.effective_batch_size = effective_batch_size 28 | self.shuffle_individual_datasets = shuffle_individual_datasets 29 | self.separator = separator 30 | 31 | self.data = [] 32 | self.load_data(file_path) 33 | 34 | def __len__(self): 35 | return len(self.data) 36 | 37 | def load_data(self, file_path: str = None): 38 | logger.info(f"Loading Rec data from {file_path}...") 39 | # file path is actually a directory 40 | 41 | data_map = {} 42 | all_samples = [] 43 | id_ = 0 44 | for dataset in AMAZON_DATASET_NAME_MAPPING: 45 | logger.info(f"Loading dataset {dataset}...") 46 | if dataset not in data_map: 47 | data_map[dataset] = [] 48 | 49 | dataset_raw_naming = AMAZON_DATASET_NAME_MAPPING[dataset] 50 | dataset_samples = [] 51 | with open(os.path.join(file_path, dataset_raw_naming), "r") as f: 52 | for line in f: 53 | dataset_samples.append(line.strip()) 54 | 55 | for i, sample in enumerate(dataset_samples): 56 | query = self.separator + sample 57 | pos = self.separator + sample 58 | data_map[dataset].append(id_) 59 | 60 | all_samples.append( 61 | DataSample( 62 | id_=id_, 63 | query=query, 64 | positive=pos, 65 | task_name=dataset, 66 | ) 67 | ) 68 | id_ += 1 69 | 70 | # combine split1 and split2 71 | new_data_map = {} 72 | for dataset in data_map: 73 | new_dataset = dataset.replace("_split1", "").replace("_split2", "") 74 | if new_dataset not in new_data_map: 75 | new_data_map[new_dataset] = [] 76 | new_data_map[new_dataset] += data_map[dataset] 77 | data_map = new_data_map 78 | 79 | if self.shuffle_individual_datasets: 80 | for task, samples in data_map.items(): 81 | random.shuffle(samples) 82 | 83 | datasets = list(data_map.keys()) 84 | 85 | logger.info( 86 | f"Batching REC data properly for effective batch size of {self.effective_batch_size}..." 87 | ) 88 | all_batches = [] 89 | for dataset in datasets: 90 | dataset_samples = data_map[dataset] 91 | for i in range(0, len(dataset_samples), self.effective_batch_size): 92 | batch = dataset_samples[i : i + self.effective_batch_size] 93 | if len(batch) == self.effective_batch_size: 94 | all_batches.append(batch) 95 | else: 96 | logger.info(f"Skip 1 batch for dataset {dataset}.") 97 | random.shuffle(all_batches) 98 | 99 | final_idx_order = [] 100 | for batch in all_batches: 101 | for idx in batch: 102 | final_idx_order.append(idx) 103 | 104 | self.data = [all_samples[idx] for idx in final_idx_order] 105 | logger.info(f"Loaded {len(self.data)} samples.") 106 | 107 | def __getitem__(self, index): 108 | sample = self.data[index] 109 | if self.split == "train": 110 | return TrainSample( 111 | texts=[sample.query, sample.positive], label=1.0 112 | ) 113 | elif self.split == "validation": 114 | assert False, "RecData does not have a validation split." 
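
A minimal usage sketch of how this dataset is expected to be constructed through `dataset_utils.load_dataset` (paths and batch size taken from `train_simcse_config.json`; it assumes the working directory is `llm2rec/`, the AmazonMix-6 title file sits under `./data`, and the script runs via `accelerate launch`, since this module logs through `accelerate.logging`):

```python
# Sketch: build the SimCSE training set the same way run_unsupervised_SimCSE.py
# would via dataset_utils.load_dataset (assumed config values shown inline).
from dataset_utils import load_dataset

train_set = load_dataset(
    "ItemTitles",
    split="train",
    file_path="data",           # "dataset_file_path" in train_simcse_config.json
    effective_batch_size=256,   # matches "per_device_train_batch_size" in the config
)
sample = train_set[0]           # TrainSample(texts=[query, positive], label=1.0)
print(len(train_set), sample.texts)
```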
115 | -------------------------------------------------------------------------------- /utils/datasets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import numpy as np 4 | import scipy.sparse as sp 5 | import json 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | 10 | def print_statistics(X, string): 11 | print('>'*10 + string + '>'*10 ) 12 | print('Average interactions', X.sum(1).mean(0).item()) 13 | nonzero_row_indice, nonzero_col_indice = X.nonzero() 14 | unique_nonzero_row_indice = np.unique(nonzero_row_indice) 15 | unique_nonzero_col_indice = np.unique(nonzero_col_indice) 16 | print('Non-zero rows', len(unique_nonzero_row_indice)/X.shape[0]) 17 | print('Non-zero columns', len(unique_nonzero_col_indice)/X.shape[1]) 18 | print('Matrix density', len(nonzero_row_indice)/(X.shape[0]*X.shape[1])) 19 | 20 | 21 | class TrainDataset(Dataset): 22 | def __init__(self, ui_pairs, ui_graph, num_items): 23 | self.ui_pairs = ui_pairs 24 | self.ui_graph = ui_graph 25 | self.num_items = num_items 26 | 27 | 28 | def __getitem__(self, index): 29 | uid, pos_id = self.ui_pairs[index] 30 | neg_id = np.random.randint(self.num_items) 31 | while self.ui_graph[uid, neg_id] == 1: 32 | neg_id = np.random.randint(self.num_items) 33 | return uid, pos_id, int(neg_id) 34 | 35 | 36 | def __len__(self): 37 | return len(self.ui_pairs) 38 | 39 | 40 | class TestDataset(Dataset): 41 | def __init__(self, ui_pairs, ui_graph, ui_graph_train, num_users, num_items): 42 | self.ui_pairs = ui_pairs 43 | self.ui_graph = ui_graph 44 | self.train_mask_ui = ui_graph_train 45 | 46 | self.num_users = num_users 47 | self.num_items = num_items 48 | 49 | 50 | def __getitem__(self, index): 51 | ui_grd = torch.from_numpy(self.ui_graph[index].toarray()).squeeze() 52 | ui_mask = torch.from_numpy(self.train_mask_ui[index].toarray()).squeeze() 53 | return index, ui_grd, ui_mask 54 | 55 | 56 | def __len__(self): 57 | return self.ui_graph.shape[0] 58 | 59 | 60 | class Datasets(): 61 | def __init__(self, conf): 62 | self.path = conf['data_path'] 63 | self.name = conf['dataset'] 64 | batch_size_train = conf['batch_size'] 65 | batch_size_test = conf['test_batch_size'] 66 | self.num_users, self.num_items = self.get_dataset_size() 67 | 68 | ui_pairs_train, ui_graph_train = self.get_graph("train.txt") 69 | ui_pairs_val, ui_graph_val = self.get_graph("valid.txt") 70 | ui_pairs_test, ui_graph_test = self.get_graph("test.txt") 71 | 72 | self.ui_graph_train = ui_graph_train 73 | 74 | self.train_data = TrainDataset(ui_pairs_train, ui_graph_train, self.num_items) 75 | self.val_data = TestDataset(ui_pairs_val, ui_graph_val, ui_graph_train, self.num_users, self.num_items) 76 | self.test_data = TestDataset(ui_pairs_test, ui_graph_test, ui_graph_train, self.num_users, self.num_items) 77 | 78 | self.train_loader = DataLoader(self.train_data, batch_size=batch_size_train, shuffle=True, num_workers=20, drop_last=True) 79 | self.val_loader = DataLoader(self.val_data, batch_size=batch_size_test, shuffle=False, num_workers=10) 80 | self.test_loader = DataLoader(self.test_data, batch_size=batch_size_test, shuffle=False, num_workers=10) 81 | 82 | 83 | def get_dataset_size(self): 84 | data_path = self.path 85 | with open(data_path + "count.json", 'r') as file: 86 | count_info = json.load(file) 87 | 88 | n_users = count_info['#U'] 89 | n_items = count_info['#I'] 90 | return int(n_users), int(n_items) 91 | 92 | 93 | def get_graph(self, filename): 94 | 
data_path = self.path 95 | 96 | ui_pairs = [] 97 | with open(data_path + filename, 'r') as file: 98 | for line in file: 99 | parts = line.strip().split(" ") 100 | user_id = parts[0] 101 | item_ids = parts[1:] 102 | for item_id in item_ids: 103 | ui_pairs.append([int(user_id), int(item_id)]) 104 | ui_pairs = np.array(ui_pairs, dtype=np.int32) 105 | 106 | indice = np.array(ui_pairs, dtype=np.int32) 107 | values = np.ones(len(ui_pairs), dtype=np.float32) 108 | ui_graph = sp.coo_matrix((values, (indice[:, 0], indice[:, 1])), shape=(self.num_users, self.num_items)).tocsr() 109 | return ui_pairs, ui_graph -------------------------------------------------------------------------------- /llm2rec/recdata/RecItemData.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import os 4 | 5 | from .dataset import DataSample, TrainSample, Dataset 6 | from accelerate.logging import get_logger 7 | 8 | logger = get_logger(__name__, log_level="INFO") 9 | 10 | 11 | AMAZON_DATASET_NAME_MAPPING = { 12 | "Arts": "Arts_Crafts_and_Sewing", 13 | "Electronics": "Electronics", 14 | "Home": "Home_and_Kitchen", 15 | "Movies": "Movies_and_TV", 16 | "Tools": "Tools_and_Home_Improvement", 17 | "Games": "Video_Games", 18 | 19 | # "Sports": "Sports_and_Outdoors", 20 | } 21 | NUM_TRAINING_SAMPLES = 100000 22 | 23 | class RecItemData(Dataset): 24 | def __init__( 25 | self, 26 | dataset_name: str = "Rec", 27 | split: str = "validation", 28 | file_path: str = "dataset/llm2vec", 29 | effective_batch_size: int = 32, 30 | shuffle_individual_datasets: bool = True, 31 | separator: str = "!@#$%^&*()", 32 | ): 33 | self.dataset_name = dataset_name 34 | self.split = split 35 | self.effective_batch_size = effective_batch_size 36 | self.shuffle_individual_datasets = shuffle_individual_datasets 37 | self.separator = separator 38 | 39 | self.data = [] 40 | self.load_data(file_path) 41 | 42 | def __len__(self): 43 | return len(self.data) 44 | 45 | def load_data(self, file_path: str = None): 46 | logger.info(f"Loading Rec data from {file_path}...") 47 | # file path is actually a directory 48 | 49 | data_map = {} 50 | all_samples = [] 51 | id_ = 0 52 | for dataset in AMAZON_DATASET_NAME_MAPPING: 53 | logger.info(f"Loading dataset {dataset}...") 54 | if dataset not in data_map: 55 | data_map[dataset] = [] 56 | 57 | dataset_raw_naming = AMAZON_DATASET_NAME_MAPPING[dataset] 58 | with open(os.path.join(file_path, f"{dataset_raw_naming}/training_item_pairs_gap24.jsonl"), "r") as f: 59 | dataset_samples = json.loads(f.read().strip()) 60 | 61 | if len(dataset_samples) > NUM_TRAINING_SAMPLES: 62 | dataset_samples = random.sample(dataset_samples, NUM_TRAINING_SAMPLES) 63 | 64 | for i, sample in enumerate(dataset_samples): 65 | query = self.separator + sample[0] 66 | pos = self.separator + sample[1] 67 | data_map[dataset].append(id_) 68 | 69 | all_samples.append( 70 | DataSample( 71 | id_=id_, 72 | query=query, 73 | positive=pos, 74 | task_name=dataset, 75 | ) 76 | ) 77 | id_ += 1 78 | 79 | # combine split1 and split2 80 | new_data_map = {} 81 | for dataset in data_map: 82 | new_dataset = dataset.replace("_split1", "").replace("_split2", "") 83 | if new_dataset not in new_data_map: 84 | new_data_map[new_dataset] = [] 85 | new_data_map[new_dataset] += data_map[dataset] 86 | data_map = new_data_map 87 | 88 | if self.shuffle_individual_datasets: 89 | for task, samples in data_map.items(): 90 | random.shuffle(samples) 91 | 92 | datasets = list(data_map.keys()) 93 | 94 | logger.info( 95 | 
f"Batching REC data properly for effective batch size of {self.effective_batch_size}..." 96 | ) 97 | all_batches = [] 98 | for dataset in datasets: 99 | dataset_samples = data_map[dataset] 100 | for i in range(0, len(dataset_samples), self.effective_batch_size): 101 | batch = dataset_samples[i : i + self.effective_batch_size] 102 | if len(batch) == self.effective_batch_size: 103 | all_batches.append(batch) 104 | else: 105 | logger.info(f"Skip 1 batch for dataset {dataset}.") 106 | random.shuffle(all_batches) 107 | 108 | final_idx_order = [] 109 | for batch in all_batches: 110 | for idx in batch: 111 | final_idx_order.append(idx) 112 | 113 | self.data = [all_samples[idx] for idx in final_idx_order] 114 | logger.info(f"Loaded {len(self.data)} samples.") 115 | 116 | def __getitem__(self, index): 117 | sample = self.data[index] 118 | if self.split == "train": 119 | return TrainSample( 120 | texts=[sample.query, sample.positive], label=1.0 121 | ) 122 | elif self.split == "validation": 123 | assert False, "RecData does not have a validation split." 124 | -------------------------------------------------------------------------------- /extract_llm_embedding.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | import numpy as np 4 | import os 5 | import os.path as op 6 | import json 7 | import argparse 8 | from utils.llm2vec_encoder import LLM2Vec 9 | token = os.environ.get("HUGGINGFACE_HUB_TOKEN") 10 | print(token) 11 | 12 | # fix random seed 13 | np.random.seed(0) 14 | torch.manual_seed(0) 15 | # os.environ["CUDA_VISIBLE_DEVICES"] = "7" 16 | 17 | 18 | from huggingface_hub import login 19 | # login(token=token) 20 | 21 | 22 | dataset_name_mappings = { 23 | # 5-core filtered datasets 24 | "Games_5core": "Video_Games/5-core/downstream", 25 | "Movies_5core": "Movies_and_TV/5-core/downstream", 26 | "Arts_5core": "Arts_Crafts_and_Sewing/5-core/downstream", 27 | "Sports_5core": "Sports_and_Outdoors/5-core/downstream", 28 | "Baby_5core": "Baby_Products/5-core/downstream", 29 | "Goodreads": "Goodreads/clean", 30 | } 31 | 32 | 33 | class llm2vec_encoder(): 34 | def __init__(self, model_path, peft_model_name_or_path, bidirectional=False): 35 | self.model_path = model_path 36 | 37 | if bidirectional: 38 | self.model = LLM2Vec.from_pretrained( 39 | model_path, 40 | peft_model_name_or_path=peft_model_name_or_path, 41 | device_map="cuda" if torch.cuda.is_available() else "cpu", 42 | torch_dtype=torch.bfloat16, 43 | use_auth_token=token, 44 | ) 45 | else: 46 | self.model = LLM2Vec.from_pretrained( 47 | model_path, 48 | peft_model_name_or_path=peft_model_name_or_path, 49 | device_map="cuda" if torch.cuda.is_available() else "cpu", 50 | torch_dtype=torch.bfloat16, 51 | use_auth_token=token, 52 | 53 | enable_bidirectional=False, 54 | pooling_mode="eos_token", 55 | ) 56 | 57 | def encode(self, sentences, batch_size): 58 | return np.asarray(self.model.encode(sentences, batch_size=batch_size)) 59 | 60 | def encode_with_prompt(self, sentences, batch_size, prompts): 61 | return np.asarray(self.model.encode(sentences, batch_size=batch_size)) 62 | 63 | 64 | def extract_item_embedding_with_prompts(dataset_name, model_path, peft_path, batch_size, prompt_type, bidirectional=False, save_info=None): 65 | # Load data here 66 | raw_dataset_name = dataset_name_mappings[dataset_name] 67 | with open(f"./data/{raw_dataset_name}/item_titles.json", 'r', encoding='utf-8') as file: 68 | item_metadata = json.load(file) 69 | 70 | if dataset_name in 
dataset_name_mappings: 71 | item_ids = [int(int_id) for int_id in item_metadata.keys()] 72 | max_item_id = max(item_ids) 73 | assert 0 not in item_ids, "Item IDs should not contain 0" 74 | 75 | # Add a null item as placeholder for item 0 76 | item_titles = ["Null"] 77 | for i in range(1, max_item_id + 1): 78 | item_titles.append(item_metadata[str(i)]) 79 | 80 | else: 81 | raise ValueError("Invalid dataset name") 82 | 83 | # item_titles = item_titles[:100] 84 | # print(token) 85 | 86 | model = llm2vec_encoder(model_path, peft_model_name_or_path=peft_path, bidirectional=bool(bidirectional)) 87 | 88 | item_infos = np.array(item_titles) 89 | if prompt_type == "direct": 90 | prompts = generate_direct_item_prompt_pog(item_infos) 91 | elif prompt_type == "title": 92 | prompts = generate_title_item_prompt_pog(item_infos) 93 | 94 | item_llama_embeds = model.encode(prompts, batch_size) 95 | 96 | save_path = f"./item_info/{dataset_name}/" 97 | if not os.path.isdir(save_path): 98 | os.makedirs(save_path) 99 | 100 | if save_info is not None: 101 | model_name = f"{save_info}" 102 | else: 103 | model_name = model_path.replace("/", "_") 104 | np.save(op.join(save_path, f"{model_name}_{prompt_type}_item_embs.npy"), item_llama_embeds) 105 | 106 | 107 | def generate_direct_item_prompt_pog(item_info): 108 | instruct = "To recommend this item to users, this item can be described as: " 109 | instructs = np.repeat(instruct, len(item_info)) 110 | prompts = item_info 111 | 112 | outputs = np.concatenate((instructs[:, np.newaxis], prompts[:, np.newaxis]), axis=1) 113 | return outputs 114 | 115 | 116 | def generate_title_item_prompt_pog(item_info): 117 | prompts = item_info 118 | return prompts 119 | 120 | 121 | if __name__ == "__main__": 122 | 123 | parser = argparse.ArgumentParser(description="Extract item embeddings with prompts.") 124 | parser.add_argument('--dataset', type=str, default="Arts_5core", help="Name of the dataset") 125 | parser.add_argument('--batch_size', type=int, default=16, help="Batch size for processing") 126 | parser.add_argument('--model_path', type=str, default='./', help="Path to the model") 127 | parser.add_argument('--peft_path', type=str, default=None, help="Path to the PEFT model") 128 | parser.add_argument('--item_prompt_type', type=str, default="title", help="Type of item prompt") 129 | parser.add_argument('--bidirectional', type=int, default=1, help="Bidirectional model") 130 | parser.add_argument('--save_info', type=str, default="Test-only", help="Save information identifier") 131 | args = parser.parse_args() 132 | 133 | args = parser.parse_args() 134 | extract_item_embedding_with_prompts( 135 | args.dataset, 136 | args.model_path, 137 | args.peft_path, 138 | args.batch_size, 139 | args.item_prompt_type, 140 | args.bidirectional, 141 | args.save_info 142 | ) 143 | -------------------------------------------------------------------------------- /seqrec/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | from torch.optim import AdamW 6 | from .base import AbstractModel 7 | from transformers.optimization import get_scheduler 8 | from collections import defaultdict, OrderedDict 9 | from .utils import get_file_name, get_total_steps 10 | from .evaluator import Evaluator 11 | 12 | 13 | class BaseTrainer(object): 14 | def __init__(self, config: dict, model: AbstractModel): 15 | self.config = config 16 | self.model = model 17 | self.accelerator = config['accelerator'] 
18 | self.evaluator = Evaluator(config) 19 | self.saved_model_ckpt = os.path.join( 20 | self.config['ckpt_dir'], 21 | get_file_name(self.config, suffix='.pth') 22 | ) 23 | os.makedirs(os.path.dirname(self.saved_model_ckpt), exist_ok=True) 24 | self.best_metric = 0 25 | self.best_epoch = 0 26 | self.count = 0 27 | 28 | self.checkpoints_deque = [] 29 | 30 | def train(self, train_dataloader, val_dataloader): 31 | optimizer = AdamW( 32 | self.model.parameters(), 33 | lr=self.config['lr'], 34 | weight_decay=self.config['weight_decay'], 35 | ) 36 | total_n_steps = get_total_steps(self.config, train_dataloader) 37 | # scheduler = get_scheduler( 38 | # name="cosine", 39 | # optimizer=optimizer, 40 | # num_warmup_steps=self.config['warmup_steps'], 41 | # num_training_steps=total_n_steps, 42 | # ) 43 | self.model, optimizer, train_dataloader, val_dataloader = self.accelerator.prepare( 44 | self.model, optimizer, train_dataloader, val_dataloader) 45 | self.config.pop('accelerator') 46 | self.accelerator.init_trackers( 47 | project_name="PreferDiff-Re", 48 | config=self.config 49 | ) 50 | n_epochs = np.ceil(total_n_steps / (len(train_dataloader) * self.accelerator.num_processes)).astype(int) 51 | best_epoch = 0 52 | best_val_score = -1 53 | for epoch in range(n_epochs): 54 | # Training 55 | self.model.train() 56 | total_loss = 0.0 57 | train_progress_bar = tqdm( 58 | train_dataloader, 59 | total=len(train_dataloader), 60 | desc=f"Training - [Epoch {epoch + 1}]", 61 | ) 62 | for batch in train_progress_bar: 63 | optimizer.zero_grad() 64 | outputs = self.model(batch) 65 | loss = outputs['loss'] 66 | self.accelerator.backward(loss) 67 | optimizer.step() 68 | # scheduler.step() 69 | total_loss = total_loss + loss.item() 70 | 71 | self.accelerator.log({"Loss/train_loss": total_loss / len(train_dataloader)}, step=epoch + 1) 72 | 73 | # Evaluation 74 | if (epoch + 1) % self.config['eval_interval'] == 0: 75 | all_results = self.evaluate(val_dataloader, split='val') 76 | if self.accelerator.is_main_process: 77 | for key in all_results: 78 | self.accelerator.log({f"Val_Metric/{key}": all_results[key]}, step=epoch + 1) 79 | print(all_results) 80 | 81 | val_score = all_results[self.config['val_metric']] 82 | if val_score > best_val_score: 83 | best_val_score = val_score 84 | best_epoch = epoch + 1 85 | if self.accelerator.is_main_process: 86 | if self.config['use_ddp']: # unwrap model for saving 87 | unwrapped_model = self.accelerator.unwrap_model(self.model) 88 | torch.save(unwrapped_model.state_dict(), self.saved_model_ckpt) 89 | else: 90 | torch.save(self.model.state_dict(), self.saved_model_ckpt) 91 | print(f'[Epoch {epoch + 1}] Saved model checkpoint to {self.saved_model_ckpt}') 92 | else: 93 | print('Patience for {} Times'.format(epoch + 1 - best_epoch)) 94 | 95 | if self.config['patience'] is not None and epoch + 1 - best_epoch >= self.config['patience']: 96 | print(f'Early stopping at epoch {epoch + 1}') 97 | break 98 | print(f'Best epoch: {best_epoch}, Best val score: {best_val_score}') 99 | 100 | def evaluate(self, dataloader, split='test'): 101 | 102 | self.model.eval() 103 | 104 | all_results = defaultdict(list) 105 | val_progress_bar = tqdm( 106 | dataloader, 107 | total=len(dataloader), 108 | desc=f"Eval - {split}", 109 | ) 110 | for batch in val_progress_bar: 111 | with torch.no_grad(): 112 | batch = {k: v.to(self.accelerator.device) if k != "seq_type" else v for k, v in batch.items()} 113 | if self.config['use_ddp']: # ddp, gather data from all devices for evaluation 114 | preds = 
self.model.module.predict(batch, n_return_sequences=self.evaluator.maxk) 115 | all_preds, all_labels = self.accelerator.gather_for_metrics((preds, batch['labels'])) 116 | results = self.evaluator.calculate_metrics(all_preds, all_labels) 117 | else: 118 | preds = self.model.predict(batch, n_return_sequences=self.evaluator.maxk) 119 | results = self.evaluator.calculate_metrics(preds, batch['labels']) 120 | 121 | for key, value in results.items(): 122 | all_results[key].append(value) 123 | 124 | 125 | output_results = OrderedDict() 126 | for metric in self.config['metrics']: 127 | for k in self.config['topk']: 128 | key = f"{metric}@{k}" 129 | output_results[key] = torch.cat(all_results[key]).mean().item() 130 | return output_results 131 | 132 | def end(self): 133 | """ 134 | Ends the training process and releases any used resources 135 | """ 136 | self.accelerator.end_training() 137 | -------------------------------------------------------------------------------- /baselines/model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("baselines/EasyRec") 3 | import torch 4 | from baselines.EasyRecModel import Easyrec_encoder 5 | import torch.nn.functional as F 6 | from transformers import AutoConfig, AutoModel, AutoTokenizer, RobertaModel, RobertaTokenizer 7 | 8 | 9 | 10 | class EasyRec(torch.nn.Module): 11 | def __init__(self, device): 12 | super().__init__() 13 | self.device = device 14 | self.config = AutoConfig.from_pretrained("hkuds/easyrec-roberta-large") 15 | self.model = Easyrec_encoder.from_pretrained("hkuds/easyrec-roberta-large", config=self.config,).to(self.device) 16 | self.tokenizer = AutoTokenizer.from_pretrained("hkuds/easyrec-roberta-large", use_fast=False,) 17 | 18 | def forward(self, x): 19 | # x is a batch of text sequences 20 | 21 | inputs = self.tokenizer(x.tolist(), padding=True, truncation=True, max_length=512, return_tensors="pt").to(self.device) 22 | with torch.inference_mode(): 23 | embeddings = self.model.encode(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask) 24 | embeddings = F.normalize(embeddings.pooler_output.detach().float(), dim=-1) 25 | 26 | return embeddings 27 | 28 | class Blair(torch.nn.Module): 29 | def __init__(self, device): 30 | super().__init__() 31 | self.device = device 32 | self.tokenizer = AutoTokenizer.from_pretrained("hyp1231/blair-roberta-base") 33 | self.model = AutoModel.from_pretrained("hyp1231/blair-roberta-base").to(self.device) 34 | 35 | def forward(self, x): 36 | # x is a batch of text sequences 37 | 38 | inputs = self.tokenizer(x.tolist(), padding=True, truncation=True, max_length=512, return_tensors="pt").to(self.device) 39 | with torch.no_grad(): 40 | embeddings = self.model(**inputs, return_dict=True).last_hidden_state[:, 0] 41 | embeddings = embeddings / embeddings.norm(dim=1, keepdim=True) 42 | 43 | return embeddings 44 | 45 | 46 | class BGE(torch.nn.Module): 47 | def __init__(self, device): 48 | super().__init__() 49 | self.device = device 50 | self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-en-v1.5') # or BAAI/bge-m3, BAAI/llm-embedder 51 | self.model = AutoModel.from_pretrained('BAAI/bge-large-en-v1.5').to(self.device) 52 | self.model.eval() 53 | 54 | def forward(self, x): 55 | # x is a batch of text sequences 56 | inputs = self.tokenizer(x.tolist(), padding=True, truncation=True, return_tensors='pt').to(self.device) 57 | 58 | with torch.no_grad(): 59 | embeddings = self.model(**inputs)[0][:, 0] 60 | embeddings = 
torch.nn.functional.normalize(embeddings, p=2, dim=1) 61 | 62 | return embeddings 63 | 64 | 65 | class BERT(torch.nn.Module): 66 | def __init__(self, device): 67 | super().__init__() 68 | self.device = device 69 | self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') # Use BERT tokenizer 70 | self.model = AutoModel.from_pretrained('bert-base-uncased').to(self.device) # Load BERT model 71 | self.model.eval() # Set model to evaluation mode 72 | 73 | def forward(self, x): 74 | # x is a batch of text sequences (list of strings) 75 | inputs = self.tokenizer( 76 | x.tolist(), 77 | padding=True, 78 | truncation=True, 79 | return_tensors='pt' # Return PyTorch tensors 80 | ).to(self.device) 81 | 82 | with torch.no_grad(): 83 | # Pass the tokenized inputs through the BERT model 84 | outputs = self.model(**inputs) 85 | cls_embeddings = outputs.last_hidden_state[:, 0, :] 86 | normalized_embeddings = torch.nn.functional.normalize(cls_embeddings, p=2, dim=1) 87 | 88 | return normalized_embeddings 89 | 90 | 91 | class RoBERTa_large_sentence(torch.nn.Module): 92 | def __init__(self, device): 93 | super().__init__() 94 | self.device = device 95 | self.tokenizer = RobertaTokenizer.from_pretrained('roberta-large') # Use RoBERTa tokenizer 96 | self.model = RobertaModel.from_pretrained('roberta-large').to(self.device) # Load RoBERTa model 97 | self.model.eval() # Set model to evaluation mode 98 | 99 | 100 | def mean_pooling(self, model_output, attention_mask): 101 | token_embeddings = model_output[0] #First element of model_output contains all token embeddings 102 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 103 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 104 | 105 | 106 | def forward(self, x): 107 | # x is a batch of text sequences (list of strings) 108 | inputs = self.tokenizer( 109 | x.tolist(), 110 | padding=True, 111 | truncation=True, 112 | return_tensors='pt' # Return PyTorch tensors 113 | ).to(self.device) 114 | 115 | with torch.no_grad(): 116 | # Pass the tokenized inputs through the RoBERTa model 117 | outputs = self.model(**inputs) 118 | sentence_embeddings = self.mean_pooling(outputs, inputs['attention_mask']) 119 | normalized_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) 120 | 121 | return normalized_embeddings 122 | 123 | 124 | 125 | class GTE_7B(torch.nn.Module): 126 | def __init__(self, device): 127 | super().__init__() 128 | self.device = device 129 | self.tokenizer = AutoTokenizer.from_pretrained('Alibaba-NLP/gte-Qwen2-7B-instruct') # Use GTE tokenizer 130 | self.model = AutoModel.from_pretrained('Alibaba-NLP/gte-Qwen2-7B-instruct').to(self.device) # Load GTE model 131 | self.model.eval() # Set model to evaluation mode 132 | 133 | 134 | def last_token_pool(self, last_hidden_states, attention_mask): 135 | sequence_lengths = attention_mask.sum(dim=1) - 1 136 | return last_hidden_states[torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device), sequence_lengths] 137 | 138 | 139 | def forward(self, x): 140 | # x is a batch of text sequences (list of strings) 141 | inputs = self.tokenizer( 142 | x.tolist(), 143 | padding=True, 144 | truncation=True, 145 | return_tensors='pt' # Return PyTorch tensors 146 | ).to(self.device) 147 | 148 | with torch.no_grad(): 149 | # Pass the tokenized inputs through the GTE model 150 | outputs = self.model(**inputs) 151 | pooled_embeddings = self.last_token_pool(outputs.last_hidden_state, 
inputs['attention_mask']) 152 | normalized_embeddings = F.normalize(pooled_embeddings, p=2, dim=1) 153 | return normalized_embeddings 154 | 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /seqrec/models/SASRec/_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from seqrec.base import AbstractModel 5 | from seqrec.modules import TransformerEncoder_v2, get_attention_mask, gather_indexes 6 | from ..Embedding2 import Embedding2 7 | 8 | 9 | 10 | class SASRec(AbstractModel): 11 | def __init__(self, config: dict, pretrained_item_embeddings=None): 12 | super(SASRec, self).__init__(config=config) 13 | self.config = config 14 | self.load_item_embeddings(pretrained_item_embeddings) 15 | 16 | # Initialize embeddings 17 | self.positional_embeddings = nn.Embedding( 18 | num_embeddings=config['max_seq_length'], 19 | embedding_dim=config['hidden_size'] 20 | ) 21 | 22 | self.emb_dropout = nn.Dropout(config['dropout']) 23 | 24 | # Initialize Transformer layers 25 | self.transformer_encoder = TransformerEncoder_v2(config) 26 | 27 | # Initialize loss function 28 | if config['loss_type'] == 'bce': 29 | self.loss_func = nn.BCEWithLogitsLoss() 30 | elif config['loss_type'] == "ce": 31 | self.loss_func = nn.CrossEntropyLoss() 32 | 33 | def load_item_embeddings(self, pretrained_embs): 34 | if pretrained_embs is None: 35 | self.item_embeddings = nn.Embedding( 36 | num_embeddings=self.config['item_num'] + 1, 37 | embedding_dim=self.config['hidden_size'], 38 | padding_idx=0 39 | ) 40 | nn.init.normal_(self.item_embeddings.weight, 0, 1) 41 | 42 | # use pretrained textual embedding with linear mapping as item embedding 43 | else: 44 | more_token = 0 45 | assert pretrained_embs.shape[0] == self.config['item_num'] + 1 46 | self.pretrained_item_embeddings = nn.Embedding.from_pretrained( 47 | torch.cat([ 48 | pretrained_embs, 49 | torch.randn(more_token, pretrained_embs.shape[-1]).to(pretrained_embs.device) 50 | ]), 51 | padding_idx=0 52 | ) 53 | # fix pretrained item embedding 54 | self.pretrained_item_embeddings.weight.requires_grad = False 55 | self.pretrained_item_embeddings.weight[-more_token:].requires_grad = True 56 | 57 | assert self.config['adapter_dims'][-1] == -1 58 | mlp_dims = [self.pretrained_item_embeddings.embedding_dim] + self.config['adapter_dims'] 59 | mlp_dims[-1] = self.config['hidden_size'] 60 | 61 | # create mlp with linears and activations 62 | self.item_embeddings_adapter = nn.Sequential() 63 | self.item_embeddings_adapter.add_module('linear_0', nn.Linear(mlp_dims[0], mlp_dims[1])) 64 | for i in range(1, len(mlp_dims) - 1): 65 | self.item_embeddings_adapter.add_module(f'activation_{i}', nn.ReLU()) 66 | self.item_embeddings_adapter.add_module(f'linear_{i}', nn.Linear(mlp_dims[i], mlp_dims[i + 1])) 67 | 68 | # initialize the adapter 69 | for name, param in self.item_embeddings_adapter.named_parameters(): 70 | if 'weight' in name: 71 | nn.init.xavier_normal_(param) 72 | elif 'bias' in name: 73 | nn.init.constant_(param, 0) 74 | 75 | self.item_embeddings = Embedding2(self.item_embeddings_adapter, self.pretrained_item_embeddings) 76 | 77 | 78 | # Note: to replace item_embedding, we need to modify both get_embeddings and get_all_embeddings functions. 
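    # Illustration of the adapter construction above (derived from this file's code):
    # with the default `adapter_dims: [-1]`, mlp_dims becomes [pretrained_dim, hidden_size],
    # so the adapter collapses to a single nn.Linear(pretrained_dim, hidden_size).
    # A hypothetical `adapter_dims: [512, -1]` with 3072-d pretrained embeddings and
    # hidden_size=128 would instead build Linear(3072, 512) -> ReLU -> Linear(512, 128).
    # Caveat: because more_token == 0, the slice weight[-more_token:] covers the whole
    # embedding matrix, and setting requires_grad on that temporary view has no lasting
    # effect -- the pretrained embeddings stay frozen, which appears to be the intent.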
79 | def get_embeddings(self, items): 80 | return self.item_embeddings(items) 81 | 82 | def get_all_embeddings(self, device=None): 83 | return self.item_embeddings.weight.data 84 | 85 | 86 | def get_representation(self, batch): 87 | inputs_emb = self.get_embeddings(batch['item_seqs']) 88 | inputs_emb += self.positional_embeddings( 89 | torch.arange(self.config['max_seq_length']).to(inputs_emb.device) 90 | ) 91 | seq = self.emb_dropout(inputs_emb) 92 | mask = torch.ne(batch['item_seqs'], 0).float().to(inputs_emb.device) 93 | 94 | mask = get_attention_mask(mask, bidirectional=False) 95 | 96 | seq = self.transformer_encoder(seq, attention_mask=mask) 97 | 98 | output = seq[-1] 99 | output = gather_indexes(output, batch['seq_lengths'] - 1) 100 | return output 101 | 102 | def forward(self, batch): 103 | state_hidden = self.get_representation(batch) 104 | test_item_emb = self.get_all_embeddings(state_hidden.device) 105 | if self.config['loss_type'] == 'bce': 106 | labels_neg = self._generate_negative_samples(batch) 107 | # labels_neg = labels_neg.view(-1, 1) 108 | logits = torch.matmul(state_hidden, test_item_emb.transpose(0, 1)) 109 | pos_scores = torch.gather(logits, 1, batch['labels'].view(-1, 1)) 110 | neg_scores = torch.gather(logits, 1, labels_neg) 111 | pos_labels = torch.ones((batch['labels'].shape[0], 1), device=state_hidden.device) 112 | neg_labels = torch.zeros((batch['labels'].shape[0], labels_neg.shape[1]), device=state_hidden.device) 113 | 114 | scores = torch.cat((pos_scores, neg_scores), dim=1).view(-1, 1) # Shape: (batch_size * (1 + num_neg), 1) 115 | labels = torch.cat((pos_labels, neg_labels), dim=1).view(-1, 1) # Shape: (batch_size * (1 + num_neg), 1) 116 | 117 | loss = self.loss_func(scores, labels) 118 | 119 | elif self.config['loss_type'] == 'ce': 120 | logits = torch.matmul(state_hidden, test_item_emb.transpose(0, 1)) 121 | loss = self.loss_func(logits, batch['labels'].view(-1)) 122 | 123 | return {'loss': loss} 124 | 125 | def predict(self, batch, n_return_sequences=1): 126 | state_hidden = self.get_representation(batch).view(-1, self.config['hidden_size']) 127 | test_item_emb = self.get_all_embeddings(state_hidden.device) 128 | scores = torch.matmul(state_hidden, test_item_emb.transpose(0, 1))[:, 129 | self.config['select_pool'][0]: self.config['select_pool'][1]] 130 | preds = scores.topk(n_return_sequences, dim=-1).indices + self.config['select_pool'][0] 131 | return preds 132 | 133 | def _generate_negative_samples(self, batch): 134 | # if self.config['sample_func'] == 'batch': 135 | # return in_batch_negative_sampling(batch['labels']) 136 | 137 | target_neg = [] 138 | for index in range(len(batch['labels'])): 139 | neg=np.random.randint(self.config['select_pool'][0], self.config['select_pool'][1]) 140 | while neg==batch['labels'][index]: 141 | neg = np.random.randint(self.config['select_pool'][0], self.config['select_pool'][1]) 142 | target_neg.append(neg) 143 | 144 | return torch.LongTensor(target_neg).to(batch['labels'].device).reshape(-1, 1) 145 | -------------------------------------------------------------------------------- /Baseline_inference.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | import time 4 | import numpy as np 5 | import os 6 | import os.path as op 7 | import json 8 | from tqdm import tqdm 9 | from baselines.model import EasyRec, Blair, BGE, BERT, GTE_7B, RoBERTa_large_sentence 10 | # Load model directly 11 | token = os.environ.get("HUGGINGFACE_HUB_TOKEN") 12 | 

13 | 
14 | np.random.seed(0)
15 | torch.manual_seed(0)
16 | 
17 | from huggingface_hub import login
18 | login(token=token)
19 | 
20 | 
21 | dataset_name_mappings = {
22 |     # 5-core filtered datasets
23 |     "Games_5core": "Video_Games/5-core/downstream",
24 |     "Movies_5core": "Movies_and_TV/5-core/downstream",
25 |     "Arts_5core": "Arts_Crafts_and_Sewing/5-core/downstream",
26 | 
27 |     "Sports_5core": "Sports_and_Outdoors/5-core/downstream",
28 |     "Baby_5core": "Baby_Products/5-core/downstream",
29 |     "Goodreads": 'Goodreads/clean',
30 | }
31 | 
32 | 
33 | def extract_item_embeddings(model, dataset_name, batch_size, prompt_type, save_info=None):
34 |     # Load data here
35 |     raw_dataset_name = dataset_name_mappings[dataset_name]
36 |     with open(f"./data/{raw_dataset_name}/item_titles.json", 'r', encoding='utf-8') as file:
37 |         item_metadata = json.load(file)
38 | 
39 |     if dataset_name in dataset_name_mappings:
40 |         item_ids = [int(int_id) for int_id in item_metadata.keys()]
41 |         max_item_id = max(item_ids)
42 |         assert 0 not in item_ids, "Item IDs should not contain 0"
43 | 
44 |         # Add a null item as placeholder for item 0
45 |         item_titles = ["Null"]
46 |         for i in range(1, max_item_id + 1):
47 |             item_titles.append(item_metadata[str(i)])
48 | 
49 |     else:
50 |         raise ValueError("Invalid dataset name")
51 | 
52 |     item_infos = np.array(item_titles)
53 |     if prompt_type == "direct":
54 |         prompts = generate_direct_item_prompt_pog(item_infos)
55 |     elif prompt_type == "title":
56 |         prompts = generate_title_item_prompt_pog(item_infos)
57 |     # elif prompt_type == "summarize":
58 |     #     prompts = generate_summarize_item_prompt(item_infos)
59 | 
60 |     item_llama_embeds = []
61 | 
62 |     if type(model).__name__ != "LLM2VecOri":
63 |         for i in tqdm(range(0, len(prompts), batch_size), desc="Processing batches"):
64 |             x = prompts[i:i+batch_size]
65 |             embeds = model(x)
66 |             item_llama_embeds.append(embeds)
67 |         item_llama_embeds = torch.cat(item_llama_embeds, dim=0).cpu().numpy()
68 |     else:
69 |         item_llama_embeds = model(prompts, batch_size)
70 | 
71 | 
72 |     save_path = f"./item_info/{dataset_name}/"
73 |     os.makedirs(save_path, exist_ok=True)
74 | 
75 |     if save_info is not None:
76 |         model_name = f"{save_info}"
77 |     else:
78 |         model_name = type(model).__name__
79 |     np.save(op.join(save_path, f"{model_name}_{prompt_type}_item_embs.npy"), item_llama_embeds)
80 | 
81 | 
82 | def extract_sequence_embeddings(model, dataset_name, batch_size, prompt_type, save_info=None, max_seq_length=10):
83 |     # Load data here
84 |     raw_dataset_name = dataset_name_mappings[dataset_name]
85 |     with open(f"./data/{raw_dataset_name}/item_titles.json", 'r', encoding='utf-8') as file:
86 |         item_metadata = json.load(file)
87 | 
88 |     if dataset_name in dataset_name_mappings:
89 |         item_ids = [int(int_id) for int_id in item_metadata.keys()]
90 |         max_item_id = max(item_ids)
91 |         assert 0 not in item_ids, "Item IDs should not contain 0"
92 | 
93 |     else:
94 |         raise ValueError("Invalid dataset name")
95 | 
96 | 
97 |     extract_sequence_types = ['train', 'val', 'test']
98 |     for seq_type in extract_sequence_types:
99 |         # Load sequence data
100 |         file_path = f"./data/{raw_dataset_name}/{seq_type}_data.txt"
101 |         with open(file_path, 'r', encoding='utf-8') as file:
102 |             item_seqs = [list(map(int, line.split()))[-max_seq_length-1: -1] for line in file]
103 |         for seq in item_seqs:
104 |             assert len(seq) != 0 and len(seq) <= max_seq_length
105 | 
106 |         sequence_titles = []
107 |         for item_seq in item_seqs:
108 |             sequence_titles_str = "#item {" + "}, #item {".join([item_metadata[str(item_id)] for item_id in item_seq]) + "}"
109 |             title = f"Given the user with the following historical interactions: {sequence_titles_str}, predict the next item that this user would like to interact with."
110 |             sequence_titles.append(title)
111 | 
112 | 
113 |         item_infos = np.array(sequence_titles)
114 |         prompts = item_infos
115 | 
116 |         seq_embeds = []
117 | 
118 |         if type(model).__name__ != "LLM2VecOri":
119 |             for i in tqdm(range(0, len(prompts), batch_size), desc="Processing batches"):
120 |                 x = prompts[i:i+batch_size]
121 |                 embeds = model(x)
122 |                 seq_embeds.append(embeds)
123 |             seq_embeds = torch.cat(seq_embeds, dim=0).cpu().numpy()
124 |         else:
125 |             seq_embeds = model(prompts, batch_size)
126 | 
127 |         # for i in range(0, len(prompts), batch_size):
128 |         #     x = prompts[i:i+batch_size]
129 |         #     embeds = model(x)
130 |         #     seq_embeds.append(embeds)
131 | 
132 |         save_path = f"./item_info/{dataset_name}/"
133 |         if not os.path.isdir(save_path):
134 |             os.makedirs(save_path)
135 | 
136 |         if save_info is not None:
137 |             model_name = f"{save_info}"
138 |         else:
139 |             model_name = type(model).__name__
140 |         np.save(op.join(save_path, f"{model_name}_{seq_type}_seq_embs.npy"), seq_embeds)
141 | 
142 | 
143 | def generate_direct_item_prompt_pog(item_info):
144 |     instruct = "To recommend this fashion item to users, this item can be described as: "
145 |     instructs = np.repeat(instruct, len(item_info))
146 |     prompts = item_info
147 | 
148 |     outputs = np.concatenate((instructs[:, np.newaxis], prompts[:, np.newaxis]), axis=1)
149 |     return outputs
150 | 
151 | 
152 | def generate_title_item_prompt_pog(item_info):
153 |     # instruct = ""
154 |     # instructs = np.repeat(instruct, len(item_info))
155 |     prompts = item_info
156 |     # outputs = np.concatenate((instructs[:, np.newaxis], prompts[:, np.newaxis]), axis=1)
157 |     return prompts
158 | 
159 | def main(
160 |     model_name="blair",  # blair, llm2vec
161 |     mode="item",  # or sequence
162 |     dataset_name = "Yelp",
163 | ):
164 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
165 | 
166 |     if model_name == "easyrec":
167 |         model = EasyRec(device)
168 |     elif model_name == "blair":
169 |         model = Blair(device)
170 |     elif model_name == "bge":
171 |         model = BGE(device)
172 |     elif model_name == "bert":
173 |         model = BERT(device)
174 |     elif model_name == "gte_7b":
175 |         model = GTE_7B(device)
176 |     elif model_name == "roberta_large_sentence":
177 |         model = RoBERTa_large_sentence(device)
178 | 
179 |     # model.to("cuda")
180 | 
181 |     if mode == "item":
182 |         extract_item_embeddings(model, dataset_name, 64, "title")
183 |     elif mode == "sequence":
184 |         extract_sequence_embeddings(model, dataset_name, 8, "title")
185 |     else:
186 |         raise ValueError("Invalid mode")
187 | 
188 | if __name__ == "__main__":
189 | 
190 |     datasets = ["Goodreads"]
191 |     baselines = ["easyrec"]
192 | 
193 |     for dataset in datasets:
194 |         for baseline in baselines:
195 |             print(f"Processing {baseline} on {dataset}")
196 |             main(model_name=baseline, mode="item", dataset_name=dataset)
197 | 
198 | 
--------------------------------------------------------------------------------
/seqrec/models/GRU4Rec/_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from seqrec.base import AbstractModel
5 | from seqrec.modules import in_batch_negative_sampling, extract_axis_1
6 | from ..Embedding2 import Embedding2
7 | 
8 | 
9 | class GRU4Rec(AbstractModel):
10 |     def __init__(self,
config: dict, pretrained_item_embeddings=None): 11 | super(GRU4Rec, self).__init__(config=config) 12 | self.config = config 13 | self.load_item_embeddings(pretrained_item_embeddings) 14 | 15 | self.gru_layers = nn.GRU( 16 | input_size=config['hidden_size'], 17 | hidden_size=config['hidden_size'], 18 | num_layers=config['layer_num'], 19 | bias=False, 20 | batch_first=True, 21 | ) 22 | self.emb_dropout = nn.Dropout(config['dropout']) 23 | 24 | # Initialize loss function 25 | if config['loss_type'] == 'bce': 26 | self.loss_func = nn.BCEWithLogitsLoss() 27 | elif config['loss_type'] == "ce": 28 | self.loss_func = nn.CrossEntropyLoss() 29 | 30 | def load_item_embeddings(self, pretrained_embs): 31 | if pretrained_embs is None: 32 | self.item_embeddings = nn.Embedding( 33 | num_embeddings=self.config['item_num'] + 1, 34 | embedding_dim=self.config['hidden_size'], 35 | padding_idx=0 36 | ) 37 | nn.init.normal_(self.item_embeddings.weight, 0, 1) 38 | else: 39 | more_token = 0 40 | self.pretrained_item_embeddings = nn.Embedding.from_pretrained( 41 | torch.cat([ 42 | pretrained_embs[:self.config['item_num']+1], 43 | torch.randn(more_token, pretrained_embs.shape[-1]).to(pretrained_embs.device) 44 | ]), 45 | padding_idx=0 46 | ) 47 | # fix pretrained item embedding 48 | self.pretrained_item_embeddings.weight.requires_grad = False 49 | self.pretrained_item_embeddings.weight[-more_token:].requires_grad = True 50 | 51 | assert self.config['adapter_dims'][-1] == -1 52 | mlp_dims = [self.pretrained_item_embeddings.embedding_dim] + self.config['adapter_dims'] 53 | mlp_dims[-1] = self.config['hidden_size'] 54 | 55 | # create mlp with linears and activations 56 | self.item_embeddings_adapter = nn.Sequential() 57 | self.item_embeddings_adapter.add_module('linear_0', nn.Linear(mlp_dims[0], mlp_dims[1])) 58 | for i in range(1, len(mlp_dims) - 1): 59 | self.item_embeddings_adapter.add_module(f'activation_{i}', nn.ReLU()) 60 | self.item_embeddings_adapter.add_module(f'linear_{i}', nn.Linear(mlp_dims[i], mlp_dims[i + 1])) 61 | 62 | # initialize the adapter 63 | for name, param in self.item_embeddings_adapter.named_parameters(): 64 | if 'weight' in name: 65 | nn.init.xavier_normal_(param) 66 | elif 'bias' in name: 67 | nn.init.constant_(param, 0) 68 | 69 | self.item_embedding_pretrained = True 70 | 71 | self.item_embeddings = Embedding2(self.item_embeddings_adapter, self.pretrained_item_embeddings) 72 | 73 | if self.config.get('aug', None) == 'sub': 74 | self.category_embedding = nn.Embedding(self.config['sub_head'], 75 | self.config['hidden_size'] // self.config['sub_head']) 76 | nn.init.normal_(self.category_embedding.weight, 0, 1) 77 | 78 | 79 | def get_embeddings(self, items): 80 | if self.config.get('aug', None) == 'sub': 81 | return self.item_embeddings(items) * self.get_sub_embeddings(items) 82 | else: 83 | return self.item_embeddings(items) 84 | 85 | def get_all_embeddings(self, device=None): 86 | if self.config.get('aug', None) == 'sub': 87 | return self.item_embeddings.weight * self.get_sub_embeddings(self.item_embeddings.weight) 88 | else: 89 | return self.item_embeddings.weight.data 90 | 91 | def get_current_embeddings(self, device=None): 92 | if self.config.get('aug', None) == 'sub': 93 | item_embeddings = self.item_embeddings.weight * self.get_sub_embeddings(self.item_embeddings.weight) 94 | return item_embeddings[self.config['select_pool'][0]:self.config['select_pool'][1]] 95 | else: 96 | return self.item_embeddings.weight.data[self.config['select_pool'][0]:self.config['select_pool'][1]] 97 | 98 | 99 | 
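    # NOTE: when config['aug'] == 'sub', the branches above call
    # self.get_sub_embeddings, which is not defined in this file; it is assumed to
    # be provided by AbstractModel (seqrec/base.py) or another mixin, presumably
    # built from self.category_embedding (sub_head chunks of size
    # hidden_size // sub_head). Without the 'sub' augmentation this code path is
    # never taken.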
def get_representation(self, batch):
100 |         inputs_emb = self.get_embeddings(batch['item_seqs'])
101 |         seq = self.emb_dropout(inputs_emb)
102 |         mask = torch.ne(batch['item_seqs'], 0).float().unsqueeze(-1).to(inputs_emb.device)
103 |         seq = seq * mask
104 |         seq, _ = self.gru_layers(seq)
105 |         state_hidden = extract_axis_1(seq, batch['seq_lengths'] - 1).squeeze()
106 |         return state_hidden
107 | 
108 |     def forward(self, batch):
109 |         state_hidden = self.get_representation(batch)
110 |         test_item_emb = self.get_all_embeddings(state_hidden.device)
111 |         if self.config['loss_type'] == 'bce':
112 |             labels_neg = self._generate_negative_samples(batch)
113 |             labels_neg = labels_neg.view(-1, 1)
114 |             logits = torch.matmul(state_hidden, test_item_emb.transpose(0, 1))
115 |             pos_scores = torch.gather(logits, 1, batch['labels'].view(-1, 1))
116 |             neg_scores = torch.gather(logits, 1, labels_neg)
117 |             pos_labels = torch.ones((batch['labels'].view(-1).shape[0], 1))
118 |             neg_labels = torch.zeros((batch['labels'].view(-1).shape[0], 1))
119 | 
120 |             scores = torch.cat((pos_scores, neg_scores), 0)
121 |             labels = torch.cat((pos_labels, neg_labels), 0)
122 |             labels = labels.to(state_hidden.device)
123 |             loss = self.loss_func(scores, labels)
124 | 
125 |         elif self.config['loss_type'] == 'ce':
126 |             logits = torch.matmul(state_hidden, test_item_emb.transpose(0, 1))
127 |             loss = self.loss_func(logits, batch['labels'].view(-1))
128 |         return {'loss': loss}
129 | 
130 |     def predict(self, batch, n_return_sequences=1):
131 |         state_hidden = self.get_representation(batch).view(-1, self.config['hidden_size'])
132 |         test_item_emb = self.get_all_embeddings(state_hidden.device)
133 |         scores = torch.matmul(state_hidden, test_item_emb.transpose(0, 1))[:,
134 |                  self.config['select_pool'][0]: self.config['select_pool'][1]]
135 |         preds = scores.topk(n_return_sequences, dim=-1).indices + self.config['select_pool'][0]
136 |         return preds
137 | 
138 |     def _generate_negative_samples(self, batch):
139 |         if self.config['sample_func'] == 'batch':
140 |             return in_batch_negative_sampling(batch['labels'])
141 | 
142 |         labels_neg = []
143 |         for index in range(len(batch['labels'])):
144 |             # resample until the negative differs from the positive label
145 |             neg = np.random.randint(self.config['select_pool'][0], self.config['select_pool'][1])
146 |             while neg == batch['labels'][index]:
147 |                 neg = np.random.randint(self.config['select_pool'][0], self.config['select_pool'][1])
148 |             labels_neg.append(neg)
149 |         return torch.LongTensor(labels_neg).to(batch['labels'].device).reshape(-1, 1)
--------------------------------------------------------------------------------
/seqrec/runner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import Union
4 | 
5 | from accelerate import Accelerator
6 | from torch.utils.data import DataLoader, Sampler
7 | import wandb
8 | 
9 | from .recdata import NormalRecData
10 | from .base import AbstractModel
11 | 
12 | from .utils import get_config, init_device, init_seed, get_model, get_file_name, diagonalize_and_scale
13 | from .trainer import BaseTrainer
14 | 
15 | 
16 | class Runner:
17 |     def __init__(
18 |         self,
19 |         model_name: Union[str, AbstractModel],
20 |         config_dict: dict = None,
21 |         config_file: str = None,
22 |     ):
23 |         self.config = get_config(
24 |             model_name=model_name,
25 |             config_file=config_file,
26 |             config_dict=config_dict
27 |         )
28 |         print(self.config)
29 | 
30 |         # Automatically set devices and ddp
31 |         self.config['device'], self.config['use_ddp'] = init_device()
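        # init_device() (see seqrec/utils.py) selects CUDA when available and sets
        # use_ddp=True when the WORLD_SIZE env var is present, as torchrun and
        # accelerate launchers set it.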
32 | 33 | wandb.init( 34 | project="LLM2Rec_Eval", # Replace with your project name 35 | name=get_file_name(self.config), # Set the desired run name 36 | ) 37 | self.accelerator = Accelerator(log_with='wandb') 38 | 39 | self.config['accelerator'] = self.accelerator 40 | 41 | init_seed(self.config['rand_seed'], self.config['reproducibility']) 42 | _ = NormalRecData(self.config).load_data() 43 | 44 | self.recdata = { 45 | 'train': _[0], 46 | 'valid': _[1], 47 | 'test': _[2] 48 | } 49 | self.config['select_pool'] = _[3] 50 | self.config['item_num'] = _[4] 51 | self.config['eos_token'] = _[4] + 1 52 | 53 | if self.config['embedding']: 54 | pretrained_item_embeddings = torch.tensor(np.load(self.config['embedding']), dtype=torch.float32).to(self.config['device']) 55 | # judge if "seq_embedding" in config.keys() 56 | if "seq_embedding" in self.config.keys() and self.config['seq_embedding']: 57 | base_seq_embedding_path = self.config['seq_embedding'] 58 | train_seq_embedding_path = base_seq_embedding_path.format("train") 59 | valid_seq_embedding_path = base_seq_embedding_path.format("val") 60 | test_seq_embedding_path = base_seq_embedding_path.format("test") 61 | train_seq_embedding = torch.tensor(np.load(train_seq_embedding_path), dtype=torch.float32).to(self.config['device']) 62 | valid_seq_embedding = torch.tensor(np.load(valid_seq_embedding_path), dtype=torch.float32).to(self.config['device']) 63 | test_seq_embedding = torch.tensor(np.load(test_seq_embedding_path), dtype=torch.float32).to(self.config['device']) 64 | pretrained_item_embeddings = [pretrained_item_embeddings, train_seq_embedding, valid_seq_embedding, test_seq_embedding] 65 | 66 | else: 67 | pretrained_item_embeddings = None 68 | 69 | with self.accelerator.main_process_first(): 70 | self.model = get_model(model_name)(self.config, pretrained_item_embeddings) 71 | 72 | print(self.model) 73 | # print(self.model.n_parameters) 74 | self.trainer = BaseTrainer(self.config, self.model) 75 | 76 | def run(self): 77 | train_dataloader = DataLoader( 78 | self.recdata['train'], 79 | batch_size=self.config['train_batch_size'], 80 | shuffle=True, 81 | ) 82 | val_dataloader = DataLoader( 83 | self.recdata['valid'], 84 | batch_size=self.config['eval_batch_size'], 85 | shuffle=False, 86 | ) 87 | test_dataloader = DataLoader( 88 | self.recdata['test'], 89 | batch_size=self.config['eval_batch_size'], 90 | shuffle=False, 91 | ) 92 | 93 | # skip training for ItemKNN model 94 | if self.config['model'] != 'ItemKNN': 95 | self.trainer.train(train_dataloader, val_dataloader) 96 | 97 | self.accelerator.wait_for_everyone() 98 | self.model = self.accelerator.unwrap_model(self.model) 99 | 100 | if self.config.get('steps', None) != 0: 101 | self.model.load_state_dict(torch.load(self.trainer.saved_model_ckpt)) 102 | else: 103 | """ 104 | SASRec: ckpt/PDSRec-main.py_--model=SASRec_--sd=B_--td=B_--loss_type=ce_--lr=1e-2_--exp_type=lr-Sep-14-2024_09-29-11-ac20ba.pth 105 | DreamRec: ckpt/PDSRec-main.py_--model=DreamRec_--sd=B_--td=B_--hidden_size=3072_--exp_type=dim-Sep-13-2024_15-26-48-9c76db.pth 106 | Ours: ckpt/PDSRec-main.py_--model=PDSRec_--sd=B_--td=B_--loss_type=cosine_--ab=iids_--hidden_size=3072_--exp_type=ab-Sep-14-2024_15-53-12-06c9df.pth 107 | """ 108 | ckpt_dict = { 109 | 'SASRec': 'ckpt/PDSRec-main.py_--model=SASRec_--sd=B_--td=B_--loss_type=ce_--lr=1e-2_--exp_type=lr-Sep-14-2024_09-29-11-ac20ba.pth', 110 | 'DreamRec': 'ckpt/PDSRec-main.py_--model=DreamRec_--sd=B_--td=B_--hidden_size=3072_--exp_type=dim-Sep-13-2024_15-26-48-9c76db.pth', 111 | 'PDSRec': 
'ckpt/PDSRec-main.py_--model=PDSRec_--sd=B_--td=B_--loss_type=cosine_--ab=iids_--hidden_size=3072_--exp_type=ab-Sep-14-2024_15-53-12-06c9df.pth' 112 | } 113 | self.model.load_state_dict(torch.load(ckpt_dict[self.config['model']])) 114 | embeddings = self.model.get_current_embeddings() 115 | embeddings_np = embeddings.cpu().numpy() 116 | import numpy as np 117 | np.save('{}_{}_embeddings.npy'.format(self.config['model'], self.config['dataset']), embeddings_np) 118 | 119 | self.model, test_dataloader = self.accelerator.prepare( 120 | self.model, test_dataloader 121 | ) 122 | if self.accelerator.is_main_process: 123 | print(f'Loaded best model checkpoint from {self.trainer.saved_model_ckpt}') 124 | 125 | if self.config.get('step', None) != 0: 126 | test_results = self.trainer.evaluate(test_dataloader) 127 | print(test_results) 128 | if self.accelerator.is_main_process: 129 | for key in test_results: 130 | self.accelerator.log({f'Test_Metric/{key}': test_results[key]}) 131 | 132 | if self.config['exp_type'] == 'check': 133 | np.save('{}_{}_vis_embeddings.npy'.format(self.config['model'], self.config['dataset']), 134 | np.array(self.model.samples)) 135 | np.save('{}_{}_pred_embeddings.npy'.format(self.config['model'], self.config['dataset']), 136 | np.array(self.model.predict_embeddings.detach().cpu().numpy())) 137 | np.save('{}_{}_target_embeddings.npy'.format(self.config['model'], self.config['dataset']), 138 | np.array(self.model.target_embedding.detach().cpu().numpy())) 139 | 140 | if self.accelerator.is_main_process: 141 | if self.config['save'] is False: 142 | import os 143 | if os.path.exists(self.trainer.saved_model_ckpt): 144 | os.remove(self.trainer.saved_model_ckpt) 145 | print(f"{self.trainer.saved_model_ckpt} has been deleted.") 146 | else: 147 | print(f"{self.trainer.saved_model_ckpt} not found.") 148 | 149 | self.trainer.end() 150 | return test_results, self.config 151 | 152 | -------------------------------------------------------------------------------- /seqrec/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import yaml 4 | import importlib 5 | import datetime 6 | from accelerate.utils import set_seed 7 | from typing import Union, Optional 8 | import torch 9 | 10 | from .base import AbstractModel 11 | 12 | 13 | def init_seed(seed: int, reproducibility: bool): 14 | """ 15 | Initialize random seeds for reproducibility across random functions in numpy, torch, cuda, and cudnn. 16 | 17 | Args: 18 | seed (int): Random seed value. 19 | reproducibility (bool): Whether to enforce reproducibility. 20 | """ 21 | import random 22 | import numpy as np 23 | import torch 24 | 25 | random.seed(seed) 26 | np.random.seed(seed) 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | set_seed(seed) 31 | torch.backends.cudnn.benchmark = not reproducibility 32 | torch.backends.cudnn.deterministic = reproducibility 33 | 34 | 35 | def get_local_time() -> str: 36 | """ 37 | Get the current local time in a specific format. 38 | 39 | Returns: 40 | str: Current time formatted as "Month-Day-Year_Hour-Minute-Second". 41 | """ 42 | return datetime.datetime.now().strftime("%b-%d-%Y_%H-%M-%S") 43 | 44 | 45 | def get_command_line_args_str() -> str: 46 | """ 47 | Get the command line arguments as a single string, with '/' replaced by '|'. 48 | 49 | Returns: 50 | str: The command line arguments. 
51 | """ 52 | return '_'.join(sys.argv).replace('/', '|') 53 | 54 | 55 | def get_model(model_name: Union[str, AbstractModel]) -> AbstractModel: 56 | """ 57 | Retrieve the model class based on the provided model name. 58 | 59 | Args: 60 | model_name (Union[str, AbstractModel]): The name or instance of the model. 61 | 62 | Returns: 63 | AbstractModel: The model class corresponding to the provided model name. 64 | 65 | Raises: 66 | ValueError: If the model name is not found. 67 | """ 68 | if isinstance(model_name, AbstractModel): 69 | return model_name 70 | 71 | try: 72 | model_class = getattr(importlib.import_module('seqrec.models'), model_name) 73 | except AttributeError: 74 | raise ValueError(f'Model "{model_name}" not found.') 75 | 76 | return model_class 77 | 78 | def get_mapper(model_name: str): 79 | """ 80 | Retrieves the mapper for a given model name. 81 | 82 | Args: 83 | model_name (str): The model name. 84 | 85 | Returns: 86 | AbstractMapper: The tokenizer for the given model name. 87 | 88 | Raises: 89 | ValueError: If the tokenizer is not found. 90 | """ 91 | try: 92 | mapper_class = getattr( 93 | importlib.import_module(f'seqrec.models.{model_name}._mapper'), 94 | f'{model_name}Mapper' 95 | ) 96 | except: 97 | raise ValueError(f'Mapper for model "{model_name}" not found.') 98 | return mapper_class 99 | 100 | 101 | def parse_command_line_args(unparsed: list[str]) -> dict: 102 | """ 103 | Parse command line arguments into a dictionary. 104 | 105 | Args: 106 | unparsed (list[str]): List of unparsed command line arguments. 107 | 108 | Returns: 109 | dict: Parsed arguments as key-value pairs. 110 | 111 | Raises: 112 | ValueError: If the argument format is invalid. 113 | """ 114 | args = {} 115 | for arg in unparsed: 116 | if '=' not in arg: 117 | raise ValueError(f"Invalid command line argument: {arg}. Expected format is '--key=value'.") 118 | key, value = arg.split('=') 119 | key = key.lstrip('--') 120 | try: 121 | value = eval(value) 122 | except (NameError, SyntaxError): 123 | pass 124 | args[key] = value 125 | 126 | return args 127 | 128 | 129 | def init_device() -> tuple: 130 | """ 131 | Set the visible devices for training, supporting multiple GPUs. 132 | 133 | Returns: 134 | tuple: A tuple containing the torch device and whether DDP (Distributed Data Parallel) is enabled. 135 | """ 136 | import torch 137 | 138 | use_ddp = bool(os.environ.get("WORLD_SIZE")) 139 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 140 | 141 | return device, use_ddp 142 | 143 | 144 | def get_config( 145 | model_name: Union[str, AbstractModel], 146 | config_file: Union[str, list[str], None], 147 | config_dict: Optional[dict] 148 | ) -> dict: 149 | """ 150 | Get the configuration for a model and dataset. 151 | 152 | Args: 153 | model_name (Union[str, AbstractModel]): The name or instance of the model. 154 | dataset_name (Union[str, AbstractDataset]): The name or instance of the dataset. 155 | config_file (Union[str, list[str], None]): Additional configuration file(s). 156 | config_dict (Optional[dict]): Dictionary of additional configuration options. 157 | 158 | Returns: 159 | dict: The final configuration dictionary. 160 | 161 | Raises: 162 | FileNotFoundError: If any of the specified configuration files are missing. 
163 | """ 164 | final_config = {} 165 | current_path = os.path.dirname(os.path.realpath(__file__)) 166 | config_file_list = [os.path.join(current_path, 'default.yaml')] 167 | 168 | 169 | if isinstance(model_name, str): 170 | config_file_list.append(os.path.join(current_path, f'models/{model_name}/config.yaml')) 171 | final_config['model'] = model_name 172 | else: 173 | final_config['model'] = model_name.__class__.__name__ 174 | 175 | if config_file: 176 | if isinstance(config_file, str): 177 | config_file = [config_file] 178 | config_file_list.extend(config_file) 179 | 180 | for file in config_file_list: 181 | with open(file, 'r') as f: 182 | cur_config = yaml.safe_load(f) 183 | if cur_config: 184 | final_config.update(cur_config) 185 | 186 | if config_dict: 187 | final_config.update(config_dict) 188 | 189 | final_config['run_local_time'] = get_local_time() 190 | return convert_config_dict(final_config) 191 | 192 | def get_total_steps(config, train_dataloader): 193 | """ 194 | Calculate the total number of steps for training based on the given configuration and dataloader. 195 | 196 | Args: 197 | config (dict): The configuration dictionary containing the training parameters. 198 | train_dataloader (DataLoader): The dataloader for the training dataset. 199 | 200 | Returns: 201 | int: The total number of steps for training. 202 | 203 | """ 204 | if config['steps'] is not None: 205 | return config['steps'] 206 | else: 207 | return len(train_dataloader) * config['epochs'] 208 | 209 | def convert_config_dict(config: dict) -> dict: 210 | """ 211 | Convert configuration values in a dictionary to their appropriate types. 212 | 213 | Args: 214 | config (dict): The dictionary containing the configuration values. 215 | 216 | Returns: 217 | dict: The dictionary with converted values. 
218 | """ 219 | for key, value in config.items(): 220 | if isinstance(value, str): 221 | try: 222 | new_value = eval(value) 223 | if new_value is not None and not isinstance(new_value, (str, int, float, bool, list, dict, tuple)): 224 | new_value = value 225 | except (NameError, SyntaxError, TypeError): 226 | new_value = value.lower() == 'true' if value.lower() in ['true', 'false'] else value 227 | config[key] = new_value 228 | 229 | return config 230 | 231 | def get_file_name(config: dict, suffix: str = ''): 232 | import hashlib 233 | config_str = "".join([str(value) for key, value in config.items() if key != 'accelerator']) 234 | md5 = hashlib.md5(config_str.encode(encoding="utf-8")).hexdigest()[:6] 235 | command_line_args = get_command_line_args_str() 236 | logfilename = "{}-{}-{}-{}{}".format( 237 | config["run_id"], command_line_args[:50], config['run_local_time'], md5, suffix 238 | ) 239 | return logfilename 240 | 241 | 242 | def diagonalize_and_scale(e, epsilon=1e-7): 243 | var_e = torch.cov(e.T) 244 | mean_e = torch.mean(e, axis=0) 245 | eigvals, eigvecs = torch.linalg.eigh(var_e) 246 | eigvals = eigvals + epsilon 247 | D = torch.diag(1.0 / torch.sqrt(eigvals)) 248 | O = eigvecs 249 | transformed_e = (e - mean_e) @ O @ D 250 | 251 | return transformed_e -------------------------------------------------------------------------------- /llm2rec/recdata/SeqRecData.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import os 4 | import pandas as pd 5 | 6 | from .dataset import DataSample, TrainSample, Dataset 7 | from accelerate.logging import get_logger 8 | 9 | logger = get_logger(__name__, log_level="INFO") 10 | 11 | AMAZON_TRAIN_DATA_PATH_MAPPING = { 12 | "Arts": "Arts_Crafts_and_Sewing/5-core/train/Arts_Crafts_and_Sewing_5_2014-9-2023-10.csv", 13 | "Electronics": "Electronics/5-core/train/Electronics_5_2016-9-2023-10.csv", 14 | "Home": "Home_and_Kitchen/5-core/train/Home_and_Kitchen_5_2016-9-2023-10.csv", 15 | "Movies": "Movies_and_TV/5-core/train/Movies_and_TV_5_2019-9-2023-10.csv", 16 | "Tools": "Tools_and_Home_Improvement/5-core/train/Tools_and_Home_Improvement_5_2016-9-2023-10.csv", 17 | "Games": "Video_Games/5-core/train/Video_Games_5_1996-9-2023-10.csv", 18 | } 19 | 20 | AMAZON_ITEM_INFO_MAPPING = { 21 | "Arts": "Arts_Crafts_and_Sewing/5-core/downstream/item_titles.json", 22 | "Electronics": "Electronics/5-core/downstream/item_titles.json", 23 | "Home": "Home_and_Kitchen/5-core/downstream/item_titles.json", 24 | "Movies": "Movies_and_TV/5-core/downstream/item_titles.json", 25 | "Tools": "Tools_and_Home_Improvement/5-core/downstream/item_titles.json", 26 | "Games": "Video_Games/5-core/downstream/item_titles.json", 27 | } 28 | 29 | NUM_TRAINING_SAMPLES = 100000 30 | 31 | 32 | class SeqRecData(Dataset): 33 | def __init__( 34 | self, 35 | dataset_name: str = "Rec", 36 | split: str = "validation", 37 | file_path: str = "./data", 38 | effective_batch_size: int = 32, 39 | shuffle_individual_datasets: bool = True, 40 | separator: str = "!@#$%^&*()", 41 | data_augmentation: bool = False, 42 | augmentation_rate: float = 0.2, 43 | ): 44 | self.dataset_name = dataset_name 45 | self.split = split 46 | self.effective_batch_size = effective_batch_size 47 | self.shuffle_individual_datasets = shuffle_individual_datasets 48 | self.separator = separator 49 | self.data_augmentation = data_augmentation 50 | self.augmentation_rate = augmentation_rate 51 | 52 | # list storing all item titles for random negative sampling 53 | 
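        # (populated in load_data below, filtered for None entries right after
        # load_data returns, and consumed by generate_negative_samples at the
        # end of this file)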
self.negative_item_pool = [] 54 | 55 | self.data = [] 56 | self.load_data(file_path) 57 | 58 | # remove NoneType samples 59 | self.negative_item_pool = [item for item in self.negative_item_pool if item is not None] 60 | 61 | 62 | def __len__(self): 63 | return len(self.data) 64 | 65 | def load_data(self, file_path: str = None): 66 | logger.info(f"Loading SeqRec data from {file_path}...") 67 | # file path is actually a directory 68 | 69 | data_map = {} 70 | all_samples = [] 71 | id_ = 0 72 | for dataset in AMAZON_TRAIN_DATA_PATH_MAPPING: 73 | logger.info(f"Loading dataset {dataset}...") 74 | if dataset not in data_map: 75 | data_map[dataset] = [] 76 | 77 | dataset_samples = self.process_data(file_path, dataset, self.data_augmentation, self.augmentation_rate) 78 | 79 | for i, sample in enumerate(dataset_samples): 80 | query = self.separator + sample['query'] 81 | pos = self.separator + sample["positive"] 82 | neg = self.separator + sample["negative"] 83 | 84 | data_map[dataset].append(id_) 85 | self.negative_item_pool.append(neg) 86 | 87 | if not self.data_augmentation: 88 | all_samples.append( 89 | DataSample( 90 | id_=id_, 91 | query=query, 92 | positive=pos, 93 | negative=neg, 94 | task_name=dataset, 95 | ) 96 | ) 97 | else: 98 | aug_query = self.separator + sample["aug_query"] 99 | all_samples.append( 100 | DataSample( 101 | id_=id_, 102 | query=query, 103 | positive=pos, 104 | negative=neg, 105 | task_name=dataset, 106 | aug_query=aug_query, 107 | ) 108 | ) 109 | id_ += 1 110 | 111 | # combine split1 and split2 112 | new_data_map = {} 113 | for dataset in data_map: 114 | new_dataset = dataset.replace("_split1", "").replace("_split2", "") 115 | if new_dataset not in new_data_map: 116 | new_data_map[new_dataset] = [] 117 | new_data_map[new_dataset] += data_map[dataset] 118 | data_map = new_data_map 119 | 120 | if self.shuffle_individual_datasets: 121 | for task, samples in data_map.items(): 122 | random.shuffle(samples) 123 | 124 | datasets = list(data_map.keys()) 125 | 126 | logger.info( 127 | f"Batching REC data properly for effective batch size of {self.effective_batch_size}..." 128 | ) 129 | all_batches = [] 130 | for dataset in datasets: 131 | dataset_samples = data_map[dataset] 132 | for i in range(0, len(dataset_samples), self.effective_batch_size): 133 | batch = dataset_samples[i : i + self.effective_batch_size] 134 | if len(batch) == self.effective_batch_size: 135 | all_batches.append(batch) 136 | else: 137 | logger.info(f"Skip 1 batch for dataset {dataset}.") 138 | random.shuffle(all_batches) 139 | 140 | final_idx_order = [] 141 | for batch in all_batches: 142 | for idx in batch: 143 | final_idx_order.append(idx) 144 | 145 | self.data = [all_samples[idx] for idx in final_idx_order] 146 | logger.info(f"Loaded {len(self.data)} samples.") 147 | 148 | def __getitem__(self, index): 149 | sample = self.data[index] 150 | if self.split == "train": 151 | if not self.data_augmentation: 152 | return TrainSample( 153 | texts=[sample.query, sample.positive, sample.negative], label=1.0 154 | ) 155 | else: 156 | return TrainSample( 157 | texts=[sample.query, sample.positive, sample.negative, sample.aug_query], label=1.0 158 | ) 159 | 160 | elif self.split == "validation": 161 | assert False, "SeqRecData does not have a validation split." 
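    # Usage sketch (paths and batch size are hypothetical, for illustration only):
    #     dataset = SeqRecData(split="train", file_path="./data", effective_batch_size=32)
    #     sample = dataset[0]   # TrainSample with texts=[query, positive, negative]
    #     negs = dataset.generate_negative_samples(8)
    # Because load_data pre-batches sample indices per source dataset, iterating in
    # order with a dataloader batch size equal to effective_batch_size keeps every
    # batch homogeneous in task_name.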
162 | 
163 | 
164 |     def process_data(self, file_path, dataset_name, data_augmentation=False, augmentation_rate=0.2):
165 |         item_info_path = AMAZON_ITEM_INFO_MAPPING[dataset_name]
166 |         item_info_path = os.path.join(file_path, item_info_path)
167 |         with open(item_info_path, "r") as f:
168 |             item_info = json.load(f)
169 |         # keys of item_info are 1-based strings; check that "0" is absent before shifting them
170 |         assert "0" not in item_info
171 | 
172 |         # The item_info is a dictionary with keys starting from 1. We need to change the keys to start from 0
173 |         item_info = {int(k) - 1: v for k, v in item_info.items()}
174 |         candidate_item_ids = list(item_info.keys())
175 | 
176 |         train_data_path = AMAZON_TRAIN_DATA_PATH_MAPPING[dataset_name]
177 |         train_data_path = os.path.join(file_path, train_data_path)
178 |         dataset_samples = pd.read_csv(train_data_path)
179 | 
180 |         # randomly sample a fixed number of samples
181 |         dataset_samples = dataset_samples.sample(n=NUM_TRAINING_SAMPLES, random_state=42)
182 | 
183 |         # iterate through all samples
184 |         samples = []
185 |         for i, row in dataset_samples.iterrows():
186 |             his_ids = eval(row["history_item_id"])
187 |             pos_id = row["item_id"]
188 |             neg_id = random.choice(candidate_item_ids)
189 |             while neg_id == pos_id or neg_id in his_ids:
190 |                 neg_id = random.choice(candidate_item_ids)
191 | 
192 |             his_titles = eval(row["history_item_title"])
193 |             pos_title = row["item_title"]
194 |             neg_title = item_info[neg_id]
195 | 
196 |             if data_augmentation:
197 |                 if len(his_ids) <= 2:
198 |                     aug_his_ids = his_ids
199 |                 else:
200 |                     num_items_to_drop = int(len(his_ids) * augmentation_rate)
201 |                     if num_items_to_drop == 0:
202 |                         num_items_to_drop = 1
203 | 
204 |                     remaining_ids = random.sample(his_ids, len(his_ids) - num_items_to_drop)
205 |                     aug_his_ids = [item_id for item_id in his_ids if item_id in remaining_ids]
206 | 
207 |                 aug_his_titles = [item_info[item_id] for item_id in aug_his_ids]
208 | 
209 | 
210 |             samples.append({
211 |                 "query": ", ".join(his_titles),
212 |                 "positive": pos_title,
213 |                 "negative": neg_title,
214 |                 "aug_query": ", ".join(aug_his_titles) if data_augmentation else None,
215 |             })
216 | 
217 |         return samples
218 | 
219 | 
220 |     def generate_negative_samples(self, num_samples):
221 |         return random.sample(self.negative_item_pool, num_samples)
222 | 
--------------------------------------------------------------------------------
/llm2rec/run_csft.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["NCCL_P2P_DISABLE"] = "1"  # disable NVLink (P2P)
3 | os.environ["NCCL_IB_DISABLE"] = "1"  # disable InfiniBand, if applicable
4 | os.environ["NCCL_NET_GDR_LEVEL"] = "0"  # disable GDR (GPU direct networking)
5 | 
6 | import sys
7 | from typing import List
8 | import numpy as np
9 | import fire
10 | 
11 | import torch
12 | import transformers
13 | from transformers import EarlyStoppingCallback, AutoConfig
14 | from peft import LoraConfig, get_peft_model
15 | from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
16 | from dataclasses import dataclass
17 | import torch.nn as nn
18 | import math
19 | import warnings
20 | from functools import partial
21 | from torch.optim.lr_scheduler import LambdaLR
22 | """
23 | Unused imports:
24 | import torch.nn as nn
25 | import bitsandbytes as bnb
26 | """
27 | from transformers import AutoModelForCausalLM, AutoTokenizer
28 | from dataset import PurePromptDataset
29 | def _get_cosine_schedule_with_warmup_lr_lambda(
30 |     current_step, *, num_warmup_steps, num_training_steps, num_cycles
31 | ):
32 |     if current_step < num_warmup_steps:
33 |         return max(0.1, float(current_step) / float(max(1, num_warmup_steps)))
34 |     progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
35 |     return max(0.1, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
36 | 
37 | def get_cosine_schedule_with_warmup(
38 |     optimizer, num_warmup_steps, num_training_steps, num_cycles: float = 0.5, last_epoch: int = -1
39 | ):
40 | 
41 |     lr_lambda = partial(
42 |         _get_cosine_schedule_with_warmup_lr_lambda,
43 |         num_warmup_steps=num_warmup_steps,
44 |         num_training_steps=num_training_steps,
45 |         num_cycles=num_cycles,
46 |     )
47 |     return LambdaLR(optimizer, lr_lambda, last_epoch)
48 | 
49 | 
50 | 
51 | def train(
52 |     # model/data params
53 |     base_model: str = "/home/yzhe/workspace/huggingface_data/hub/Qwen2-0.5B",  # the only required argument
54 |     train_file: str = "./data/AmazonMix6/5-core/train/AmazonMix-6.csv",
55 |     eval_file: str = "./data/AmazonMix6/5-core/valid/AmazonMix-6.csv",
56 |     output_dir: str = "./output/Test-SFT",
57 |     sample: int = -1,
58 |     seed: int = 0,
59 | 
60 |     # training hyperparams
61 |     batch_size: int = 128,
62 |     micro_batch_size: int = 4,
63 |     num_epochs: int = 10,
64 |     learning_rate: float = 3e-4,
65 |     cutoff_len: int = 1024,
66 |     # llm hyperparams
67 |     train_on_inputs: bool = True,  # if False, masks out inputs in loss
68 |     group_by_length: bool = False,  # faster, but produces an odd training loss curve
69 |     # wandb params
70 |     wandb_project: str = "",
71 |     wandb_run_name: str = "",
72 |     wandb_watch: str = "",  # options: false | gradients | all
73 |     wandb_log_model: str = "",  # options: false | true
74 |     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
75 |     use_lora: bool = False,
76 | 
77 |     local_rank: int = 0,
78 |     deepspeed: str = "./deepspeed.json",
79 |     category: str = "AmazonMix-6",
80 |     K: int = 0,
81 |     version: str = "base",
82 |     train_from_scratch: bool = False,
83 | ):
84 |     os.environ['WANDB_PROJECT'] = wandb_project
85 |     # print(train_file)
86 |     category_dict = {"AmazonMix-6": "items", "Office_Products": "office products", "Books": "books", "Goodreads": "books", "Steam": "games", "CDs_and_Vinyl": "musics", "Toys_and_Games": "toys and games", "Video_Games": "video games", "Musical_Instruments": "music instruments", "Sports_and_Outdoors": "sports and outdoors", "Arts_Crafts_and_Sewing": "arts products", "Movies": "movie", "Industrial_and_Scientific": "industrial and scientific", "Automotive": "automotive products", "Grocery_and_Gourmet_Food": "grocery and gourmet food", "Software": "software", "Pet_Supplies": "pet supply products"}
87 |     print(category)
88 |     category = category_dict[category]
89 |     assert (
90 |         base_model
91 |     ), "Please specify a --base_model, e.g. 
--base_model='decapoda-research/llama-7b-hf'" 92 | gradient_accumulation_steps = batch_size // micro_batch_size 93 | 94 | device_map = "auto" 95 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 96 | ddp = world_size != 1 97 | if ddp: 98 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 99 | gradient_accumulation_steps = gradient_accumulation_steps // world_size 100 | 101 | # uses.environ["WANDB_LOG_MODEL"] = wandb_log_model 102 | # os.environ["WANDB_DISABLED"] = "true" 103 | if not train_from_scratch: 104 | model = AutoModelForCausalLM.from_pretrained( 105 | base_model, 106 | # load_in_8bit=True, 107 | torch_dtype=torch.bfloat16, 108 | # device_map=device_map, 109 | trust_remote_code=True, 110 | ) 111 | else: 112 | config = AutoConfig.from_pretrained(base_model) 113 | model = AutoModelForCausalLM.from_config(config) 114 | print("Training from scratch!") 115 | 116 | if use_lora: 117 | # if base model is a Qwen model, use the Qwen model's config 118 | if "Qwen" in base_model: 119 | print("Using Qwen model") 120 | lora_config = LoraConfig( 121 | r=8, 122 | lora_alpha=16, 123 | target_modules=[ 124 | "q_proj", 125 | "v_proj", 126 | "k_proj", 127 | "o_proj", 128 | "gate_proj", 129 | "up_proj", 130 | "down_proj", 131 | ], # Lora settings for Qwen model 132 | lora_dropout=0.05, 133 | bias="none", 134 | task_type="CAUSAL_LM" 135 | ) 136 | elif "Llama" in base_model: 137 | print("Using Llama model") 138 | lora_config = LoraConfig( 139 | r=8, 140 | lora_alpha=16, 141 | target_modules=[ 142 | "q_proj", 143 | "v_proj", 144 | "k_proj", 145 | "o_proj", 146 | "gate_proj", 147 | "up_proj", 148 | "down_proj", 149 | ], # Lora settings for Llama model 150 | lora_dropout=0.05, 151 | bias="none", 152 | task_type="CAUSAL_LM" 153 | ) 154 | 155 | # Wrap the model with LoRA 156 | model = get_peft_model(model, lora_config) 157 | model.print_trainable_parameters() 158 | 159 | 160 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) 161 | tokenizer.pad_token = tokenizer.eos_token 162 | tokenizer.pad_token_id = tokenizer.eos_token_id 163 | tokenizer.padding_side = "left" 164 | 165 | train_data = PurePromptDataset(train_file=train_file, tokenizer=tokenizer, max_len=cutoff_len, sample=sample, seed=seed, category=category, K = K) 166 | # val_data = PurePromptDataset(train_file=eval_file, tokenizer=tokenizer, max_len=cutoff_len, sample=sample, category=category, K = K) 167 | val_data = PurePromptDataset(train_file=eval_file, tokenizer=tokenizer, max_len=cutoff_len, sample=2000, category=category, K = K) 168 | 169 | print("LOAD DATA FINISHED") 170 | 171 | if resume_from_checkpoint: 172 | # Check the available weights and load them 173 | checkpoint_name = os.path.join( 174 | resume_from_checkpoint, "pytorch_model.bin" 175 | ) # Full checkpoint 176 | 177 | if not ddp and torch.cuda.device_count() > 1: 178 | model.is_parallelizable = True 179 | model.model_parallel = True 180 | from datasets import Dataset as HFDataset 181 | hf_train_dataset = HFDataset.from_dict({k: [v[k] for v in train_data] for k in train_data[0].keys()}) 182 | hf_train_dataset = hf_train_dataset.shuffle(seed=seed) 183 | 184 | hf_val_dataset = HFDataset.from_dict({k: [v[k] for v in val_data] for k in val_data[0].keys()}) 185 | trainer = transformers.Trainer( 186 | # deepspeed=deepspeed, 187 | model=model, 188 | train_dataset=hf_train_dataset, 189 | eval_dataset=hf_val_dataset, 190 | args=transformers.TrainingArguments( 191 | # deepspeed=deepspeed, 192 | run_name=wandb_run_name, 193 | 
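            # Note on the arguments below: max_steps=10000 overrides num_train_epochs,
            # evaluation and checkpointing both run every 2000 steps, and
            # load_best_model_at_end restores the best checkpoint (by eval loss, the
            # default metric) when training finishes.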
per_device_train_batch_size=micro_batch_size, 194 | per_device_eval_batch_size=micro_batch_size, 195 | gradient_accumulation_steps=gradient_accumulation_steps, 196 | warmup_steps=200, 197 | num_train_epochs=num_epochs, 198 | learning_rate=learning_rate, 199 | bf16=True, 200 | logging_steps=1, 201 | optim="adamw_torch", 202 | # evaluation_strategy="epoch", 203 | # save_strategy="epoch", 204 | max_steps=10000, 205 | evaluation_strategy="steps", # Changed from "epoch" to "steps" 206 | eval_steps=2000, # Evaluate every 1000 steps 207 | save_strategy="steps", # Changed from "epoch" to "steps" 208 | save_steps=2000, # Save checkpoint every 1000 steps 209 | 210 | output_dir=output_dir, 211 | save_total_limit=5, 212 | load_best_model_at_end=True, 213 | ddp_find_unused_parameters=False if ddp else None, 214 | group_by_length=group_by_length, 215 | report_to="wandb", 216 | ), 217 | data_collator=transformers.DataCollatorForSeq2Seq( 218 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 219 | ), 220 | callbacks = [EarlyStoppingCallback(early_stopping_patience=5)], 221 | # optimizers=(optimizer, lr_scheduler) 222 | ) 223 | model.config.use_cache = False 224 | trainer.evaluate() 225 | 226 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 227 | 228 | if use_lora: 229 | model.save_pretrained(output_dir, safe_serialization=True) 230 | else: 231 | model.save_pretrained(output_dir) 232 | 233 | 234 | if __name__ == "__main__": 235 | fire.Fire(train) -------------------------------------------------------------------------------- /llm2rec/run_unsupervised_SimCSE.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass, field 3 | import os 4 | 5 | import sys 6 | from typing import Any, Dict, List, Optional, Tuple, Union 7 | 8 | import torch 9 | from torch import nn 10 | 11 | from accelerate import Accelerator, DistributedDataParallelKwargs 12 | from accelerate.logging import get_logger 13 | 14 | import transformers 15 | from transformers import ( 16 | MODEL_FOR_MASKED_LM_MAPPING, 17 | HfArgumentParser, 18 | TrainingArguments, 19 | Trainer, 20 | TrainerCallback, 21 | set_seed, 22 | ) 23 | from transformers.trainer_utils import seed_worker 24 | 25 | from peft import LoraConfig, get_peft_model 26 | 27 | from llm2vec import LLM2Vec 28 | from dataset_utils import load_dataset 29 | from llm2vec.loss.utils import load_loss 30 | 31 | from tqdm import tqdm 32 | 33 | transformers.logging.set_verbosity_error() 34 | 35 | logging.basicConfig( 36 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 37 | datefmt="%Y-%m-%d %H:%M:%S", 38 | level=logging.INFO, 39 | ) 40 | logger = get_logger(__name__, log_level="INFO") 41 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) 42 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 43 | 44 | 45 | def initialize_peft( 46 | model, 47 | lora_r: int = 8, 48 | lora_alpha: int = 16, 49 | lora_dropout: float = 0.05, 50 | lora_modules: Optional[List[str]] = None, 51 | ): 52 | if lora_modules is None and model.config.__class__.__name__ in [ 53 | "LlamaConfig", 54 | "MistralConfig", 55 | "GemmaConfig", 56 | "Qwen2Config", 57 | ]: 58 | lora_modules = [ 59 | "q_proj", 60 | "v_proj", 61 | "k_proj", 62 | "o_proj", 63 | "gate_proj", 64 | "up_proj", 65 | "down_proj", 66 | ] 67 | elif lora_modules is None: 68 | raise ValueError("lora_modules must be specified for this model.") 69 | 70 | config = LoraConfig( 71 | r=lora_r, 72 | 
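        # effective LoRA update scale is lora_alpha / lora_r (16 / 8 = 2 with this
        # function's defaults)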
lora_alpha=lora_alpha, 73 | target_modules=lora_modules, 74 | lora_dropout=lora_dropout, 75 | bias="none", 76 | task_type=None, 77 | ) 78 | 79 | model = get_peft_model(model, config) 80 | print(f"Model's Lora trainable parameters:") 81 | model.print_trainable_parameters() 82 | return model 83 | 84 | 85 | @dataclass 86 | class ModelArguments: 87 | """ 88 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 89 | """ 90 | 91 | model_name_or_path: Optional[str] = field( 92 | default=None, 93 | metadata={ 94 | "help": ( 95 | "The base model checkpoint for weights initialization. Don't set if you want to train a model from scratch." 96 | ) 97 | }, 98 | ) 99 | peft_model_name_or_path: Optional[str] = field( 100 | default=None, 101 | metadata={"help": ("The PEFT model checkpoint to add on top of base model.")}, 102 | ) 103 | bidirectional: Optional[bool] = field( 104 | default=False, 105 | metadata={ 106 | "help": ( 107 | "Whether to enable bidirectional attention in the model. If set to False, the model will use unidirectional attention." 108 | ) 109 | }, 110 | ) 111 | max_seq_length: Optional[int] = field( 112 | default=None, 113 | metadata={ 114 | "help": ( 115 | "The maximum total input sequence length after tokenization. Sequences longer " 116 | "than this will be truncated." 117 | ) 118 | }, 119 | ) 120 | torch_dtype: Optional[str] = field( 121 | default=None, 122 | metadata={ 123 | "help": ( 124 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " 125 | "dtype will be automatically derived from the model's weights." 126 | ), 127 | "choices": ["auto", "bfloat16", "float16", "float32"], 128 | }, 129 | ) 130 | attn_implementation: Optional[str] = field( 131 | default="sdpa", 132 | metadata={ 133 | "help": ("The attention implementation to use in the model."), 134 | "choices": ["eager", "sdpa", "flash_attention_2"], 135 | }, 136 | ) 137 | pooling_mode: Optional[str] = field( 138 | default="mean", 139 | metadata={ 140 | "help": ("The pooling mode to use in the model."), 141 | "choices": ["mean", "weighted_mean", "eos_token"], 142 | }, 143 | ) 144 | 145 | 146 | @dataclass 147 | class DataTrainingArguments: 148 | """ 149 | Arguments pertaining to what data we are going to input our model for training and eval. 150 | """ 151 | 152 | dataset_name: Optional[str] = field( 153 | default=None, 154 | metadata={"help": "The name of the dataset to use. Options: E5"}, 155 | ) 156 | dataset_file_path: Optional[str] = field( 157 | default=None, metadata={"help": "The input training data file or folder."} 158 | ) 159 | # TODO: implement this 160 | max_train_samples: Optional[int] = field( 161 | default=None, 162 | metadata={ 163 | "help": ( 164 | "For debugging purposes or quicker training, truncate the number of training examples to this " 165 | "value if set." 
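# Note: per the TODO above, max_train_samples is parsed but never applied in this
# script. A minimal sketch of the intended truncation (applied to the
# train_examples list built in main()) could be:
#   if data_args.max_train_samples is not None:
#       train_examples = train_examples[:data_args.max_train_samples]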
166 | ) 167 | }, 168 | ) 169 | 170 | 171 | @dataclass 172 | class CustomArguments: 173 | """ 174 | Custom arguments for the script 175 | """ 176 | 177 | simcse_dropout: float = field( 178 | default=0.1, metadata={"help": "The SimCSE dropout rate for the model"} 179 | ) 180 | 181 | lora_dropout: float = field( 182 | default=0.05, metadata={"help": "The dropout rate for lora"} 183 | ) 184 | 185 | lora_r: int = field(default=8, metadata={"help": "The r value for lora"}) 186 | 187 | stop_after_n_steps: int = field( 188 | default=10000, metadata={"help": "Stop training after n steps"} 189 | ) 190 | 191 | experiment_id: Optional[str] = field( 192 | default=None, metadata={"help": "The experiment id"} 193 | ) 194 | 195 | loss_class: Optional[str] = field( 196 | default="HardNegativeNLLLoss", 197 | metadata={ 198 | "help": "The loss class to use for training. Options: HardNegativeNLLLoss" 199 | }, 200 | ) 201 | 202 | loss_scale: float = field( 203 | default=50.0, metadata={"help": "The loss scale for the loss function"} 204 | ) 205 | 206 | 207 | @dataclass 208 | class DefaultCollator: 209 | model: LLM2Vec 210 | 211 | def __init__(self, model: LLM2Vec) -> None: 212 | self.model = model 213 | 214 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: 215 | batch = features 216 | num_texts = len(batch[0].texts) 217 | texts = [[] for _ in range(num_texts)] 218 | labels = [] 219 | 220 | for example in batch: 221 | for idx, text in enumerate(example.texts): 222 | # TODO: Add prepare_for_tokenization here similar to supervised training and see if it impacts performance 223 | texts[idx].append(text) 224 | labels.append(example.label) 225 | labels = torch.tensor(labels) 226 | 227 | sentence_features = [] 228 | for idx in range(num_texts): 229 | tokenized = self.model.tokenize(texts[idx]) 230 | sentence_features.append(tokenized) 231 | 232 | return sentence_features, labels 233 | 234 | 235 | class StopTrainingCallback(TrainerCallback): 236 | def __init__(self, stop_after_n_steps: int): 237 | self.stop_after_n_steps = stop_after_n_steps 238 | 239 | def on_step_end(self, args, state, control, **kwargs): 240 | if state.global_step >= self.stop_after_n_steps: 241 | control.should_training_stop = True 242 | 243 | 244 | class SimCSETrainer(Trainer): 245 | def __init__( 246 | self, 247 | *args, 248 | loss_function=None, 249 | **kwargs, 250 | ) -> None: 251 | super().__init__(*args, **kwargs) 252 | self.loss_function = loss_function 253 | 254 | def compute_loss( 255 | self, 256 | model: nn.Module, 257 | inputs: Dict[str, Union[torch.Tensor, Any]], 258 | return_outputs: bool = False, 259 | ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: 260 | features, labels = inputs 261 | q_reps = self.model(features[0]) 262 | d_reps = self.model(features[1]) 263 | 264 | d_reps_neg = None 265 | if len(features) > 2: 266 | d_reps_neg = self.model(features[2]) 267 | 268 | loss = self.loss_function(q_reps, d_reps, d_reps_neg) 269 | 270 | if return_outputs: 271 | output = torch.cat( 272 | [model(row)["sentence_embedding"][:, None] for row in features], dim=1 273 | ) 274 | return loss, output 275 | 276 | return loss 277 | 278 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 279 | # If we are executing this function, we are the process zero, so we don't check for that. 
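# Here self.model is the LLM2Vec wrapper rather than a bare HF model, so
# self.model.save(output_dir) below is expected to write the (optionally
# PEFT-merged) backbone, the tokenizer, and an llm2vec_config.json capturing the
# pooling/max-length settings -- see LLM2Vec.save() in utils/llm2vec_encoder.py
# for the shape of that logic.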
280 | output_dir = output_dir if output_dir is not None else self.args.output_dir 281 | os.makedirs(output_dir, exist_ok=True) 282 | logger.info(f"Saving model checkpoint to {output_dir}") 283 | 284 | self.model.save(output_dir) 285 | 286 | # Good practice: save your training arguments together with the trained model 287 | torch.save(self.args, os.path.join(output_dir, "training_args.bin")) 288 | 289 | 290 | def main(): 291 | parser = HfArgumentParser( 292 | (ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments) 293 | ) 294 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 295 | # If we pass only one argument to the script and it's the path to a json file, 296 | # let's parse it to get our arguments. 297 | model_args, data_args, training_args, custom_args = parser.parse_json_file( 298 | json_file=os.path.abspath(sys.argv[1]) 299 | ) 300 | else: 301 | ( 302 | model_args, 303 | data_args, 304 | training_args, 305 | custom_args, 306 | ) = parser.parse_args_into_dataclasses() 307 | if training_args.ddp_find_unused_parameters: 308 | kwargs = [ 309 | DistributedDataParallelKwargs( 310 | dim=0, 311 | broadcast_buffers=True, 312 | bucket_cap_mb=25, 313 | find_unused_parameters=True, 314 | check_reduction=False, 315 | gradient_as_bucket_view=False, 316 | ) 317 | ] 318 | else: 319 | kwargs = [] 320 | accelerator = Accelerator(kwargs_handlers=kwargs) 321 | 322 | set_seed(training_args.seed) 323 | 324 | if training_args.gradient_checkpointing: 325 | training_args.gradient_checkpointing_kwargs = {"use_reentrant": False} 326 | 327 | train_dataset = load_dataset( 328 | data_args.dataset_name, 329 | split="train", 330 | file_path=data_args.dataset_file_path, 331 | ) 332 | 333 | train_examples = [ 334 | train_dataset[i] 335 | for i in tqdm( 336 | range(len(train_dataset)), 337 | desc="Loading train examples...", 338 | disable=not accelerator.is_main_process, 339 | ) 340 | ] 341 | 342 | torch_dtype = ( 343 | model_args.torch_dtype 344 | if model_args.torch_dtype in ["auto", None] 345 | else getattr(torch, model_args.torch_dtype) 346 | ) 347 | model = LLM2Vec.from_pretrained( 348 | base_model_name_or_path=model_args.model_name_or_path, 349 | enable_bidirectional=model_args.bidirectional, 350 | peft_model_name_or_path=model_args.peft_model_name_or_path, 351 | merge_peft=True, 352 | pooling_mode=model_args.pooling_mode, 353 | max_length=model_args.max_seq_length, 354 | torch_dtype=torch_dtype, 355 | attn_implementation=model_args.attn_implementation, 356 | attention_dropout=custom_args.simcse_dropout, 357 | ) 358 | 359 | # model organization is LLM2VecModel.model -> HF Model, we have to apply PEFT to the inner model 360 | if custom_args.lora_r is not None: 361 | model.model = initialize_peft( 362 | model.model, 363 | lora_r=custom_args.lora_r, 364 | lora_alpha=2 * custom_args.lora_r, 365 | lora_dropout=custom_args.lora_dropout, 366 | ) 367 | else: 368 | for param in model.model.parameters(): 369 | param.requires_grad = True 370 | 371 | tokenizer = model.tokenizer 372 | 373 | train_loss = load_loss(custom_args.loss_class, scale=custom_args.loss_scale) 374 | 375 | data_collator = DefaultCollator(model) 376 | 377 | print(training_args) 378 | 379 | trainer = SimCSETrainer( 380 | model=model, 381 | args=training_args, 382 | train_dataset=train_examples, 383 | data_collator=data_collator, 384 | tokenizer=tokenizer, 385 | loss_function=train_loss, 386 | ) 387 | 388 | if custom_args.stop_after_n_steps is not None: 389 | 
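# StopTrainingCallback (defined above) flips control.should_training_stop once
# state.global_step reaches the threshold, acting as a hard step cap on top of
# num_train_epochs. A rough near-equivalent, as a sketch, would be
#   TrainingArguments(..., max_steps=custom_args.stop_after_n_steps)
# with the practical difference that max_steps also determines the length of the
# LR schedule, while this callback simply halts training early.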
trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps)) 390 | 391 | trainer.train() 392 | 393 | 394 | if __name__ == "__main__": 395 | main() -------------------------------------------------------------------------------- /baselines/EasyRecModel.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | import scipy.sparse as sp 7 | import torch.nn.functional as F 8 | import torch.distributed as dist 9 | 10 | import transformers 11 | from transformers import RobertaTokenizer 12 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead 13 | from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead 14 | from transformers.activations import gelu 15 | from transformers.file_utils import ( 16 | add_code_sample_docstrings, 17 | add_start_docstrings, 18 | add_start_docstrings_to_model_forward, 19 | replace_return_docstrings, 20 | ) 21 | from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions 22 | 23 | 24 | init = nn.init.xavier_uniform_ 25 | uniformInit = nn.init.uniform 26 | 27 | 28 | """ 29 | EasyRec 30 | """ 31 | def dot_product_scores(q_vectors, ctx_vectors): 32 | r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1)) 33 | return r 34 | 35 | 36 | import torch as t 37 | import torch.nn.functional as F 38 | 39 | def cal_bpr_loss(anc_embeds, pos_embeds, neg_embeds): 40 | pos_preds = (anc_embeds * pos_embeds).sum(-1) 41 | neg_preds = (anc_embeds * neg_embeds).sum(-1) 42 | return t.sum(F.softplus(neg_preds - pos_preds)) 43 | 44 | 45 | def reg_pick_embeds(embeds_list): 46 | reg_loss = 0 47 | for embeds in embeds_list: 48 | reg_loss += embeds.square().sum() 49 | return reg_loss 50 | 51 | 52 | def cal_infonce_loss(embeds1, embeds2, all_embeds2, temp=1.0): 53 | normed_embeds1 = embeds1 / t.sqrt(1e-8 + embeds1.square().sum(-1, keepdim=True)) 54 | normed_embeds2 = embeds2 / t.sqrt(1e-8 + embeds2.square().sum(-1, keepdim=True)) 55 | normed_all_embeds2 = all_embeds2 / t.sqrt(1e-8 + all_embeds2.square().sum(-1, keepdim=True)) 56 | nume_term = -(normed_embeds1 * normed_embeds2 / temp).sum(-1) 57 | deno_term = t.log(t.sum(t.exp(normed_embeds1 @ normed_all_embeds2.T / temp), dim=-1)) 58 | cl_loss = (nume_term + deno_term).sum() 59 | return cl_loss 60 | 61 | 62 | def cal_infonce_loss_spec_nodes(embeds1, embeds2, nodes, temp): 63 | embeds1 = F.normalize(embeds1 + 1e-8, p=2) 64 | embeds2 = F.normalize(embeds2 + 1e-8, p=2) 65 | pckEmbeds1 = embeds1[nodes] 66 | pckEmbeds2 = embeds2[nodes] 67 | nume = t.exp(t.sum(pckEmbeds1 * pckEmbeds2, dim=-1) / temp) 68 | deno = t.exp(pckEmbeds1 @ embeds2.T / temp).sum(-1) + 1e-8 69 | return -t.log(nume / deno).mean() 70 | 71 | 72 | def cal_sce_loss(x, y, alpha): 73 | x = F.normalize(x, p=2, dim=-1) 74 | y = F.normalize(y, p=2, dim=-1) 75 | loss = (1 - (x * y).sum(dim=-1)).pow_(alpha) 76 | loss = loss.mean() 77 | return loss 78 | 79 | 80 | def cal_rank_loss(stu_anc_emb, stu_pos_emb, stu_neg_emb, tea_anc_emb, tea_pos_emb, tea_neg_emb): 81 | stu_pos_score = (stu_anc_emb * stu_pos_emb).sum(dim=-1) 82 | stu_neg_score = (stu_anc_emb * stu_neg_emb).sum(dim=-1) 83 | stu_r_score = F.sigmoid(stu_pos_score - stu_neg_score) 84 | 85 | tea_pos_score = (tea_anc_emb * tea_pos_emb).sum(dim=-1) 86 | tea_neg_score = (tea_anc_emb * tea_neg_emb).sum(dim=-1) 87 | tea_r_score = 
F.sigmoid(tea_pos_score - tea_neg_score) 88 | 89 | rank_loss = -(tea_r_score * t.log(stu_r_score + 1e-8) + (1 - tea_r_score) * t.log(1 - stu_r_score + 1e-8)).mean() 90 | 91 | return rank_loss 92 | 93 | 94 | def reg_params(model): 95 | reg_loss = 0 96 | for W in model.parameters(): 97 | reg_loss += W.norm(2).square() 98 | return reg_loss 99 | 100 | 101 | class MLPLayer(nn.Module): 102 | """ 103 | Head for getting sentence representations over RoBERTa/BERT's CLS representation. 104 | """ 105 | def __init__(self, config): 106 | super().__init__() 107 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 108 | self.activation = nn.Tanh() 109 | 110 | def forward(self, features, **kwargs): 111 | x = self.dense(features) 112 | x = self.activation(x) 113 | return x 114 | 115 | 116 | class Pooler(nn.Module): 117 | """ 118 | Parameter-free poolers to get the sentence embedding 119 | 'cls': [CLS] representation with BERT/RoBERTa's MLP pooler. 120 | 'cls_before_pooler': [CLS] representation without the original MLP pooler. 121 | 'avg': average of the last layers' hidden states at each token. 122 | 'avg_top2': average of the last two layers. 123 | 'avg_first_last': average of the first and the last layers. 124 | """ 125 | def __init__(self, pooler_type): 126 | super().__init__() 127 | self.pooler_type = pooler_type 128 | assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type 129 | 130 | def forward(self, attention_mask, outputs): 131 | last_hidden = outputs.last_hidden_state 132 | pooler_output = outputs.pooler_output 133 | hidden_states = outputs.hidden_states 134 | 135 | if self.pooler_type in ['cls_before_pooler', 'cls']: 136 | return last_hidden[:, 0] 137 | elif self.pooler_type == "avg": 138 | return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1)) 139 | elif self.pooler_type == "avg_first_last": 140 | first_hidden = hidden_states[1] 141 | last_hidden = hidden_states[-1] 142 | pooled_result = ((first_hidden + last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1) 143 | return pooled_result 144 | elif self.pooler_type == "avg_top2": 145 | second_last_hidden = hidden_states[-2] 146 | last_hidden = hidden_states[-1] 147 | pooled_result = ((last_hidden + second_last_hidden) / 2.0 * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(-1).unsqueeze(-1) 148 | return pooled_result 149 | else: 150 | raise NotImplementedError 151 | 152 | 153 | class Similarity(nn.Module): 154 | """ 155 | Dot product or cosine similarity 156 | """ 157 | def __init__(self, temp): 158 | super().__init__() 159 | self.temp = temp 160 | self.cos = nn.CosineSimilarity(dim=-1) 161 | 162 | def forward(self, x, y): 163 | return self.cos(x, y) / self.temp 164 | 165 | 166 | class Easyrec_encoder(RobertaPreTrainedModel): 167 | _keys_to_ignore_on_load_missing = [r"position_ids"] 168 | 169 | def __init__(self, config, *model_args, **model_kargs): 170 | super().__init__(config) 171 | try: 172 | self.model_args = model_kargs["model_args"] 173 | self.roberta = RobertaModel(config, add_pooling_layer=False) 174 | if self.model_args.pooler_type == "cls": 175 | self.mlp = MLPLayer(config) 176 | if self.model_args.do_mlm: 177 | self.lm_head = RobertaLMHead(config) 178 | """ 179 | Contrastive learning class init function. 
180 | """ 181 | self.pooler_type = self.model_args.pooler_type 182 | self.pooler = Pooler(self.pooler_type) 183 | self.sim = Similarity(temp=self.model_args.temp) 184 | self.init_weights() 185 | except: 186 | self.roberta = RobertaModel(config, add_pooling_layer=False) 187 | self.mlp = MLPLayer(config) 188 | self.lm_head = RobertaLMHead(config) 189 | self.pooler_type = 'cls' 190 | self.pooler = Pooler(self.pooler_type) 191 | self.init_weights() 192 | 193 | def forward(self, 194 | user_input_ids=None, 195 | user_attention_mask=None, 196 | pos_item_input_ids=None, 197 | pos_item_attention_mask=None, 198 | neg_item_input_ids=None, 199 | neg_item_attention_mask=None, 200 | token_type_ids=None, 201 | position_ids=None, 202 | head_mask=None, 203 | inputs_embeds=None, 204 | labels=None, 205 | output_attentions=None, 206 | output_hidden_states=None, 207 | return_dict=None, 208 | mlm_input_ids=None, 209 | mlm_attention_mask=None, 210 | mlm_labels=None, 211 | ): 212 | """ 213 | Contrastive learning forward function. 214 | """ 215 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 216 | batch_size = user_input_ids.size(0) 217 | 218 | # Get user embeddings 219 | user_outputs = self.roberta( 220 | input_ids=user_input_ids, 221 | attention_mask=user_attention_mask, 222 | token_type_ids=None, 223 | position_ids=None, 224 | head_mask=None, 225 | inputs_embeds=None, 226 | output_attentions=output_attentions, 227 | output_hidden_states=output_hidden_states, 228 | return_dict=return_dict, 229 | ) 230 | 231 | # Get positive item embeddings 232 | pos_item_outputs = self.roberta( 233 | input_ids=pos_item_input_ids, 234 | attention_mask=pos_item_attention_mask, 235 | token_type_ids=None, 236 | position_ids=None, 237 | head_mask=None, 238 | inputs_embeds=None, 239 | output_attentions=output_attentions, 240 | output_hidden_states=output_hidden_states, 241 | return_dict=return_dict, 242 | ) 243 | 244 | # Get negative item embeddings 245 | neg_item_outputs = self.roberta( 246 | input_ids=neg_item_input_ids, 247 | attention_mask=neg_item_attention_mask, 248 | token_type_ids=None, 249 | position_ids=None, 250 | head_mask=None, 251 | inputs_embeds=None, 252 | output_attentions=output_attentions, 253 | output_hidden_states=output_hidden_states, 254 | return_dict=return_dict, 255 | ) 256 | 257 | # MLM auxiliary objective 258 | if mlm_input_ids is not None: 259 | mlm_outputs = self.roberta( 260 | input_ids=mlm_input_ids, 261 | attention_mask=mlm_attention_mask, 262 | token_type_ids=None, 263 | position_ids=None, 264 | head_mask=None, 265 | inputs_embeds=None, 266 | output_attentions=output_attentions, 267 | output_hidden_states=output_hidden_states, 268 | return_dict=return_dict, 269 | ) 270 | 271 | # Pooling 272 | user_pooler_output = self.pooler(user_attention_mask, user_outputs) 273 | pos_item_pooler_output = self.pooler(pos_item_attention_mask, pos_item_outputs) 274 | neg_item_pooler_output = self.pooler(neg_item_attention_mask, neg_item_outputs) 275 | 276 | # If using "cls", we add an extra MLP layer 277 | # (same as BERT's original implementation) over the representation. 
278 | if self.pooler_type == "cls": 279 | user_pooler_output = self.mlp(user_pooler_output) 280 | pos_item_pooler_output = self.mlp(pos_item_pooler_output) 281 | neg_item_pooler_output = self.mlp(neg_item_pooler_output) 282 | 283 | # Gather all item embeddings if using distributed training 284 | if dist.is_initialized() and self.training: 285 | # Dummy vectors for allgather 286 | user_list = [torch.zeros_like(user_pooler_output) for _ in range(dist.get_world_size())] 287 | pos_item_list = [torch.zeros_like(pos_item_pooler_output) for _ in range(dist.get_world_size())] 288 | neg_item_list = [torch.zeros_like(neg_item_pooler_output) for _ in range(dist.get_world_size())] 289 | # Allgather 290 | dist.all_gather(tensor_list=user_list, tensor=user_pooler_output.contiguous()) 291 | dist.all_gather(tensor_list=pos_item_list, tensor=pos_item_pooler_output.contiguous()) 292 | dist.all_gather(tensor_list=neg_item_list, tensor=neg_item_pooler_output.contiguous()) 293 | 294 | # Since allgather results do not have gradients, we replace the 295 | # current process's corresponding embeddings with original tensors 296 | user_list[dist.get_rank()] = user_pooler_output 297 | pos_item_list[dist.get_rank()] = pos_item_pooler_output 298 | neg_item_list[dist.get_rank()] = neg_item_pooler_output 299 | 300 | # Get full batch embeddings 301 | user_pooler_output = torch.cat(user_list, dim=0) 302 | pos_item_pooler_output = torch.cat(pos_item_list, dim=0) 303 | neg_item_pooler_output = torch.cat(neg_item_list, dim=0) 304 | 305 | cos_sim = self.sim(user_pooler_output.unsqueeze(1), pos_item_pooler_output.unsqueeze(0)) 306 | neg_sim = self.sim(user_pooler_output.unsqueeze(1), neg_item_pooler_output.unsqueeze(0)) 307 | cos_sim = torch.cat([cos_sim, neg_sim], 1) 308 | 309 | labels = torch.arange(cos_sim.size(0)).long().to(self.device) 310 | loss_fct = nn.CrossEntropyLoss() 311 | 312 | loss = loss_fct(cos_sim, labels) 313 | 314 | # Calculate loss for MLM 315 | if mlm_outputs is not None and mlm_labels is not None and self.model_args.do_mlm: 316 | mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1)) 317 | prediction_scores = self.lm_head(mlm_outputs.last_hidden_state) 318 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1)) 319 | loss = loss + self.model_args.mlm_weight * masked_lm_loss 320 | 321 | if not return_dict: 322 | raise NotImplementedError 323 | 324 | return SequenceClassifierOutput( 325 | loss=loss, 326 | logits=cos_sim, 327 | ) 328 | 329 | def encode(self, 330 | input_ids=None, 331 | attention_mask=None, 332 | token_type_ids=None, 333 | position_ids=None, 334 | head_mask=None, 335 | inputs_embeds=None, 336 | labels=None, 337 | output_attentions=None, 338 | output_hidden_states=None, 339 | return_dict=None, 340 | ): 341 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 342 | outputs = self.roberta( 343 | input_ids=input_ids, 344 | attention_mask=attention_mask, 345 | token_type_ids=None, 346 | position_ids=None, 347 | head_mask=None, 348 | inputs_embeds=None, 349 | output_attentions=output_attentions, 350 | output_hidden_states=output_hidden_states, 351 | return_dict=return_dict, 352 | ) 353 | pooler_output = self.pooler(attention_mask, outputs) 354 | if self.pooler_type == "cls": 355 | pooler_output = self.mlp(pooler_output) 356 | if not return_dict: 357 | return (outputs[0], pooler_output) + outputs[2:] 358 | 359 | return BaseModelOutputWithPoolingAndCrossAttentions( 360 | pooler_output=pooler_output, 361 | 
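# Hypothetical usage sketch of the inference helper below (the profile lists and
# the `model` / `tokenizer` names are placeholders, not values from this repo):
#   user_embeds, item_embeds = model.inference(
#       user_profile_list, item_profile_list, "amazon", tokenizer)
#   scores = user_embeds @ item_embeds.T  # cosine similarity, since inference()
#                                         # L2-normalizes both sides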
last_hidden_state=outputs.last_hidden_state, 362 | hidden_states=outputs.hidden_states, 363 | ) 364 | 365 | def inference(self, 366 | user_profile_list, 367 | item_profile_list, 368 | dataset_name, 369 | tokenizer, 370 | infer_batch_size=128 371 | ): 372 | n_user = len(user_profile_list) 373 | profiles = user_profile_list + item_profile_list 374 | n_batch = math.ceil(len(profiles) / infer_batch_size) 375 | text_embeds = [] 376 | for i in tqdm(range(n_batch), desc=f'Encoding Text {dataset_name}'): 377 | batch_profiles = profiles[i * infer_batch_size: (i + 1) * infer_batch_size] 378 | inputs = tokenizer(batch_profiles, padding=True, truncation=True, max_length=512, return_tensors="pt") 379 | for k in inputs: 380 | inputs[k] = inputs[k].to(self.device) 381 | with torch.inference_mode(): 382 | embeds = self.encode( 383 | input_ids=inputs.input_ids, 384 | attention_mask=inputs.attention_mask 385 | ) 386 | text_embeds.append(embeds.pooler_output.detach().cpu()) 387 | text_embeds = torch.concat(text_embeds, dim=0).cuda() 388 | user_embeds = F.normalize(text_embeds[: n_user], dim=-1) 389 | item_embeds = F.normalize(text_embeds[n_user: ], dim=-1) 390 | return user_embeds, item_embeds -------------------------------------------------------------------------------- /llm2rec/dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | from typing import List, Tuple 6 | import json 7 | import random 8 | from tqdm import tqdm 9 | import os 10 | import copy 11 | import pickle 12 | 13 | class Tokenizer: 14 | def __init__(self, tokenizer): 15 | self.tokenizer = tokenizer 16 | self.bos_id: int = self.tokenizer.bos_token_id 17 | self.eos_id: int = self.tokenizer.eos_token_id 18 | 19 | 20 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 21 | assert type(s) is str 22 | t = self.tokenizer.encode(s) 23 | while t[0] == self.bos_id: 24 | t = t[1:] 25 | while t[-1] == self.eos_id: 26 | t = t[:-1] 27 | 28 | if bos and self.bos_id is not None: 29 | t = [self.bos_id] + t 30 | if eos and self.eos_id is not None: 31 | t = t + [self.eos_id] 32 | return t 33 | 34 | def decode(self, t: List[int]) -> str: 35 | return self.tokenizer.decode(t) 36 | 37 | 38 | class PurePromptDataset(Dataset): 39 | def __init__(self, train_file, tokenizer, max_len=2048, sample=-1, test = False, seed=0, category="", K=4, dedup=False): 40 | self.data = pd.read_csv(train_file) 41 | random.seed(seed) 42 | 43 | if not test: 44 | if sample > 0: 45 | self.data = self.data.sample(sample, random_state=seed) 46 | self.tokenizer = Tokenizer(tokenizer) 47 | self.test = test 48 | self.max_len = max_len 49 | self.category = category 50 | self.K = K 51 | self.dedup = dedup 52 | self.instructs = [ 53 | f"", 54 | ] 55 | self.get_inputs() 56 | def __len__(self): 57 | return len(self.data) 58 | 59 | 60 | def generate_example_prompt(self, data_point): 61 | return f"""{data_point["input"]}""" 62 | 63 | def generate_prompt(self, data_point): 64 | return data_point["input"] 65 | 66 | 67 | def get_history(self, row): 68 | row['history_item_title'] = eval(row['history_item_title']) 69 | L = len(row['history_item_title']) 70 | history = "" 71 | for i in range(L): 72 | if i == 0: 73 | history += row['history_item_title'][i] 74 | else: 75 | history += ", " + row['history_item_title'][i] 76 | target_item = str(row['item_title']) 77 | target_item_id = row["item_id"] 78 | last_history_item_id = 
eval(row["history_item_id"])[-1] 79 | return {"input": f"{history}", 80 | "output": target_item + '\n', 81 | "dedup": target_item_id == last_history_item_id} 82 | 83 | def pre(self, idx): 84 | history = self.get_history(self.data.iloc[idx]) 85 | target_item = history['output'] 86 | history['output'] = '' 87 | 88 | prompt = self.generate_prompt(history) 89 | tokens = self.tokenizer.encode(prompt, bos=False, eos=False) 90 | history["input"] = "" 91 | 92 | attention_mask = [1] * len(tokens) 93 | 94 | 95 | if self.test: 96 | return { 97 | "input_ids": tokens, 98 | "attention_mask": attention_mask, 99 | "text": prompt, 100 | # "select_index": select_index, 101 | } 102 | 103 | golden_tokens = self.tokenizer.encode(target_item, bos=False, eos=True) 104 | input_prompt_len = len(tokens) 105 | tokens = tokens + golden_tokens 106 | attention_mask = [1] * len(tokens) 107 | labels = [-100] * input_prompt_len + tokens[input_prompt_len:] 108 | 109 | if len(tokens) >= self.max_len: 110 | print(len(tokens)) 111 | 112 | 113 | return { 114 | "input_ids": tokens[-self.max_len:], 115 | "attention_mask": attention_mask[-self.max_len:], 116 | "labels": labels[-self.max_len:], 117 | } 118 | 119 | 120 | 121 | 122 | def get_inputs(self): 123 | inputs = [] 124 | for i in tqdm(range(len(self.data))): 125 | inputs.append(self.pre(i)) 126 | # print(inputs[-1]) 127 | 128 | self.inputs = inputs 129 | 130 | 131 | def get_all(self): 132 | temp = [] 133 | for i in range(len(self.data)): 134 | temp.append(self.get_history(self.data.iloc[i])) 135 | return temp 136 | 137 | def get_inputs_list(self): 138 | return self.inputs 139 | 140 | def __getitem__(self, idx): 141 | return self.inputs[idx] 142 | 143 | 144 | 145 | class DPODataset(Dataset): 146 | def __init__(self, train_file, info_file, tokenizer, neg_num=3, max_len=2048, sample=-1, test = False, seed=0, category="", dedup=False, negative_sample="cf", dpo=True, hard_negative_file=None): 147 | self.data = pd.read_csv(train_file) 148 | random.seed(seed) 149 | if not test: 150 | if sample > 0: 151 | self.data = self.data.sample(sample, random_state=seed) 152 | self.tokenizer = Tokenizer(tokenizer) 153 | self.test = test 154 | with open(info_file, 'r') as f: 155 | info = f.readlines() 156 | info = ["\"" + _.split('\t')[0].strip(' ') + "\"\n" for _ in info] 157 | self.item_name = info 158 | 159 | with open(hard_negative_file, 'rb') as f: 160 | hard_negative_dict = pickle.load(f) 161 | self.hard_negative_dict = hard_negative_dict 162 | 163 | self.neg_num = neg_num 164 | self.max_len = max_len 165 | self.category = category 166 | self.neg_num = neg_num 167 | self.negative_sample = negative_sample 168 | self.hard_negative_file = hard_negative_file 169 | self.dpo = dpo 170 | # self.K = K 171 | self.dedup = dedup 172 | self.instructs = [ 173 | f"Given a list of {category} the user recetenly enjoy, please write a new {category} that the user may bought", 174 | f"Considering the {category} that has recently captured the user's interest, kindly create a compilation of other {category} that the user might have played prior to this.", 175 | f"Based on the user's current gaming preference, please draft a list of potential {category} they may have experienced beforehand.", 176 | f"Reflecting on the {category} the user has taken pleasure in recently, we request that you formulate a list of {category} that may have preceded the user's current enjoyment.", 177 | f"In light of the recent gaming enjoyment expressed by the user, please assemble a list of {category} that could potentially include 
past titles the user has engaged with.", 178 | f"Taking into account the {category} that has lately provided enjoyment to the user, please put together an inventory of {category} the user might have explored previously.", 179 | f"Given the user's newfound enjoyment of a particular {category}, would you kindly generate a roster of other {category} that might resonate with their past gaming experiences?", 180 | f"In response to the user's recent fondness for a specific {category}, we seek your assistance in listing possible {category} the user may have delighted in earlier.", 181 | f"With respect to the {category} currently enjoyed by the user, please compile a suggestive list of {category} they may have played in the past.", 182 | f"Bearing in mind the {category} that the user has recently been enthralled by, please construct a catalog of other {category} that the user potentially partook in beforehand.", 183 | f"In relation to the user's recent entertainment with a given {category}, it would be appreciated if you could curate a list of {category} that might form part of the user's previous gaming history." 184 | ] 185 | self.get_inputs() 186 | def __len__(self): 187 | return len(self.data) 188 | 189 | 190 | def generate_example_prompt(self, data_point): 191 | return f"""### Example {data_point["idx"]}: 192 | {data_point["input"]} 193 | 194 | ### Response: 195 | {data_point["output"]} 196 | """ 197 | 198 | def generate_prompt(self, data_point): 199 | return f"""### User Input: 200 | {data_point["input"]} 201 | 202 | ### Response: 203 | {data_point["output"]}""" 204 | 205 | def get_history(self, row): 206 | row['history_item_title'] = eval(row['history_item_title']) 207 | L = len(row['history_item_title']) 208 | history = "" 209 | for i in range(L): 210 | if i == 0: 211 | history += "\"" + row['history_item_title'][i] + "\"" 212 | else: 213 | history += ", \"" + row['history_item_title'][i] + "\"" 214 | target_item = str(row['item_title']) 215 | target_item = "\"" + target_item + "\"" 216 | target_item_id = row["item_id"] 217 | last_history_item_id = eval(row["history_item_id"])[-1] 218 | return {"input": f"The user has palyed the following {self.category}s before: {history}", 219 | "output": target_item + '\n', 220 | "dedup": target_item_id == last_history_item_id} 221 | 222 | def pre(self, idx): 223 | instruction = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 
224 | 225 | ### Instruction: 226 | {self.instructs[0]} 227 | """ 228 | # tokens = self.tokenizer.encode(instruction, bos=True, eos=False) 229 | 230 | history = self.get_history(self.data.iloc[idx]) # split the interaction record into its historical interactions and the target item 231 | target_item = history['output'] 232 | history['output'] = '' 233 | # negative_prompt_ids = copy.deepcopy(tokens) 234 | # negative_items = [item for item in self.item_name if item != target_item] 235 | # neg_sam = random.sample(negative_items, self.neg_num) 236 | if not self.test: 237 | negative_samples = [] 238 | if self.negative_sample == "cf": 239 | if idx in self.hard_negative_dict: 240 | hard_negative_data = self.hard_negative_dict[idx] 241 | item = hard_negative_data['negative_item'] 242 | weight = hard_negative_data['predict_score'] 243 | non_zero_idx = np.nonzero(weight)[0] 244 | 245 | cf_weight = hard_negative_data['cf_score'] 246 | cf_weight = np.array(cf_weight) 247 | # get the ranking index of weight, select the items with the least cf_score as negative samples among non_zero_idx 248 | ranking_index = np.argsort(cf_weight[non_zero_idx]) 249 | cf_num = len(non_zero_idx) 250 | # negative_samples = [item[non_zero_idx[ranking_index[i]]] for i in range(cf_num)] 251 | if len(non_zero_idx) < self.neg_num: 252 | # additional_idx is the index of the items with the least cf_score among zero_idx 253 | zero_idx = np.setdiff1d(np.arange(len(weight)), non_zero_idx) 254 | ranking_zero_idx = np.argsort(cf_weight[zero_idx]) 255 | additional_idx = zero_idx[ranking_zero_idx[:self.neg_num - cf_num]] 256 | 257 | # additional_idx = np.random.choice(np.setdiff1d(np.arange(len(weight)), non_zero_idx), self.neg_num - len(non_zero_idx), replace=False) 258 | negative_samples = [item[non_zero_idx[ranking_index[i]]] for i in range(cf_num)] + [item[additional_idx[i]] for i in range(self.neg_num - cf_num)] 259 | # print(negative_samples) 260 | else: 261 | negative_samples = [item[non_zero_idx[ranking_index[i]]] for i in range(self.neg_num)] 262 | 263 | 264 | # print(negative_samples) 265 | else: 266 | return [] 267 | 268 | 269 | 270 | 271 | elif self.negative_sample == "hard": 272 | if idx in self.hard_negative_dict: 273 | hard_negative_data = self.hard_negative_dict[idx] 274 | weight = hard_negative_data['predict_score'] 275 | # find the index of non-zero weight 276 | non_zero_idx = np.nonzero(weight)[0] 277 | 278 | if self.neg_num >= 1: 279 | 280 | if len(non_zero_idx) < self.neg_num: 281 | # randomly sample idx in weight instead of non_zero_idx and add to non_zero_idx 282 | additional_idx = np.random.choice(np.setdiff1d(np.arange(len(weight)), non_zero_idx), self.neg_num - len(non_zero_idx), replace=False) 283 | non_zero_idx = np.concatenate([non_zero_idx, additional_idx]) 284 | 285 | # randomly sample neg_num non-zero weight 286 | num = np.random.choice(non_zero_idx, self.neg_num, replace=False) 287 | negative_samples = [str(hard_negative_data['negative_item'][num[i]]) for i in range(self.neg_num)] 288 | else: 289 | negative_samples = [str(hard_negative_data['negative_item'][non_zero_idx[i]]) for i in range(len(non_zero_idx))] 290 | else: 291 | return [] 292 | 293 | elif self.negative_sample == "random": 294 | if idx in self.hard_negative_dict: 295 | negative_items = [item for item in self.item_name if item != target_item] 296 | negative_samples = random.sample(negative_items, self.neg_num) 297 | else: 298 | return [] 299 | else: 300 | print("negative_sample is not valid") 301 | return [] 302 | 303 | else: 304 | if idx not in self.hard_negative_dict: 305 | return [] 306 | 307 | if not self.test and negative_samples == []: 308 | print("negative_samples is empty") 309 | dic_list = [] 310 | prompt = self.generate_prompt(history) 311 | if not self.dpo: 312 | dic = {"prompt": instruction + prompt} 313 | dic["chosen"] = target_item 314 | for i in range(self.neg_num): 315 | dic[f"rejected{i}"] = negative_samples[i] 316 | # dic[f"weight{i}"] = 1 / self.neg_num 317 | dic[f"weight{i}"] = 1 318 | dic_list.append(dic) 319 | else: 320 | if self.test: 321 | # for reject in self.hard_negative_dict[idx]['negative_item']: 322 | # dic = { 323 | # "prompt": instruction + prompt, 324 | # "chosen": target_item, 325 | # "rejected1": reject, 326 | # "weight1": 1 327 | # } 328 | # dic_list.append(dic) 329 | dic = { 330 | "prompt": instruction + prompt, 331 | "chosen": target_item, 332 | } 333 | for i, reject in enumerate(self.hard_negative_dict[idx]['negative_item']): 334 | dic[f"reject{i+1}"] = reject 335 | dic[f"weight{i+1}"] = 1 336 | # break 337 | dic_list.append(dic) 338 | else: 339 | for reject in negative_samples: 340 | dic = { 341 | "prompt": instruction + prompt, 342 | "chosen": target_item, 343 | "rejected1": reject, 344 | "weight1": 1 345 | } 346 | dic_list.append(dic) 347 | return dic_list 348 | # tokens = tokens + self.tokenizer.encode(prompt, bos=False, eos=False) 349 | # history["input"] = "" 350 | 351 | # attention_mask = [1] * len(tokens) 352 | 353 | 354 | # if self.test: 355 | # return { 356 | # "input_ids": tokens, 357 | # "attention_mask": attention_mask, 358 | 359 | # # "select_index": select_index, 360 | # } 361 | 362 | # golden_tokens = self.tokenizer.encode(target_item, bos=False, eos=True) 363 | # input_prompt_len = len(tokens) 364 | # tokens = tokens + golden_tokens 365 | # attention_mask = [1] * len(tokens) 366 | # labels = [-100] * input_prompt_len + tokens[input_prompt_len:] 367 | 368 | # if len(tokens) >= self.max_len: 369 | # print(len(tokens)) 370 | 371 | 372 | # return { 373 | # "input_ids": tokens[-self.max_len:], 374 | # "attention_mask": attention_mask[-self.max_len:], 375 | # "labels": labels[-self.max_len:], 376 | 377 | # } 378 | 379 | 380 | 381 | 382 | def get_inputs(self): 383 | inputs = [] 384 | for i in tqdm(range(len(self.data))): 385 | inputs.append(self.pre(i)) 386 | # print(inputs[-1]) 387 | 388 | self.inputs = inputs 389 | 390 | 391 | def get_all(self): 392 | temp = [] 393 | for i in range(len(self.data)): 394 | temp.append(self.get_history(self.data.iloc[i])) 395 | return temp 396 | 397 | def get_inputs_list(self): 398 | return self.inputs 399 | 400 | def __getitem__(self, idx): 401 | return self.pre(idx) -------------------------------------------------------------------------------- /utils/llm2vec_encoder.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from functools import partial 5 | from typing import Dict, List, Optional, Union 6 | 7 | import numpy as np 8 | import torch 9 | import torch.multiprocessing as mp 10 | from peft import PeftModel 11 | from torch import Tensor, device, nn 12 | from tqdm.autonotebook import tqdm, trange 13 | from transformers import ( 14 | AutoModel, 15 | AutoConfig, 16 | PretrainedConfig, 17 | AutoTokenizer, 18 | LlamaConfig, 19 | MistralConfig, 20 | GemmaConfig, 21 | Qwen2Config, 22 | ) 23 | 24 | from llm2vec.models import ( 25 | MistralBiModel, 26 | LlamaBiModel, 27 | GemmaBiModel, 28 | Qwen2BiModel, 29 | ) 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | def batch_to_device(batch, target_device: device): 35 
| """ 36 | send a pytorch batch to a device (CPU/GPU) 37 | """ 38 | for key in batch: 39 | if isinstance(batch[key], Tensor): 40 | batch[key] = batch[key].to(target_device) 41 | return batch 42 | 43 | 44 | class LLM2Vec(nn.Module): 45 | def __init__( 46 | self, 47 | model: AutoModel, 48 | tokenizer: AutoTokenizer, 49 | pooling_mode: str = "mean", 50 | max_length: int = 512, 51 | doc_max_length: int = 400, 52 | skip_instruction: bool = True, 53 | ): 54 | super().__init__() 55 | self.model = model 56 | self.tokenizer = tokenizer 57 | self.pooling_mode = pooling_mode 58 | self.skip_instruction = skip_instruction 59 | self.max_length = max_length 60 | self.doc_max_length = doc_max_length 61 | self.config = model.config 62 | 63 | @classmethod 64 | def _get_model_class(cls, config_class_name, enable_bidirectional): 65 | if not enable_bidirectional: 66 | return AutoModel 67 | if config_class_name == "MistralConfig": 68 | return MistralBiModel 69 | elif config_class_name == "LlamaConfig": 70 | return LlamaBiModel 71 | elif config_class_name == "GemmaConfig": 72 | return GemmaBiModel 73 | elif config_class_name == "Qwen2Config": 74 | return Qwen2BiModel 75 | else: 76 | raise ValueError( 77 | f"{config_class_name} is not supported yet with bidirectional models." 78 | ) 79 | 80 | @classmethod 81 | def from_pretrained( 82 | cls, 83 | base_model_name_or_path, 84 | peft_model_name_or_path=None, 85 | merge_peft=False, 86 | enable_bidirectional=True, 87 | **kwargs, 88 | ): 89 | # pop out encoder args 90 | keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"] 91 | encoder_args = { 92 | key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None 93 | } 94 | 95 | tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path) 96 | tokenizer.pad_token = tokenizer.eos_token 97 | tokenizer.padding_side = "left" 98 | 99 | config = AutoConfig.from_pretrained(base_model_name_or_path) 100 | config_class_name = config.__class__.__name__ 101 | 102 | model_class = cls._get_model_class( 103 | config_class_name, enable_bidirectional=enable_bidirectional 104 | ) 105 | model = model_class.from_pretrained(base_model_name_or_path, **kwargs) 106 | 107 | if os.path.isdir(base_model_name_or_path) and os.path.exists( 108 | f"{base_model_name_or_path}/config.json" 109 | ): 110 | with open(f"{base_model_name_or_path}/config.json", "r") as fIn: 111 | config_dict = json.load(fIn) 112 | config = PretrainedConfig.from_dict(config_dict) 113 | model.config._name_or_path = config._name_or_path 114 | 115 | # For special case where config.json and adapter weights are in the same directory 116 | if hasattr(model, "peft_config"): 117 | model = PeftModel.from_pretrained( 118 | model, 119 | base_model_name_or_path, 120 | ) 121 | model = model.merge_and_unload() 122 | 123 | if peft_model_name_or_path is not None: 124 | model = PeftModel.from_pretrained( 125 | model, 126 | peft_model_name_or_path, 127 | ) 128 | if merge_peft: 129 | model = model.merge_and_unload() 130 | 131 | config = {} 132 | config_addr = ( 133 | peft_model_name_or_path 134 | if peft_model_name_or_path is not None 135 | else base_model_name_or_path 136 | ) 137 | if os.path.exists(f"{config_addr}/llm2vec_config.json"): 138 | with open(f"{config_addr}/llm2vec_config.json", "r") as fIn: 139 | llm2vec_config = json.load(fIn) 140 | config.update(llm2vec_config) 141 | 142 | for key, value in encoder_args.items(): 143 | config[key] = value 144 | 145 | return cls(model=model, tokenizer=tokenizer, **config) 146 | 147 | def 
prepare_for_tokenization(self, text): 148 | if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct": 149 | text = ( 150 | "<|start_header_id|>user<|end_header_id|>\n\n" 151 | + text.strip() 152 | + "<|eot_id|>" 153 | ) 154 | return text 155 | if self.model.config._name_or_path in [ 156 | "mistralai/Mistral-7B-Instruct-v0.2", 157 | "meta-llama/Llama-2-7b-chat-hf", 158 | ]: 159 | text = "[INST] " + text.strip() + " [/INST]" 160 | if self.model.config._name_or_path in [ 161 | "google/gemma-2-9b-it", 162 | ]: 163 | text = "<start_of_turn>user\n" + text.strip() + "<end_of_turn>" 164 | if self.model.config._name_or_path in [ 165 | "Qwen/Qwen2-1.5B-Instruct", 166 | "Qwen/Qwen2-7B-Instruct", 167 | "Qwen/Qwen2-0.5B-Instruct", 168 | ]: 169 | text = "<|im_start|>user\n" + text.strip() + "<|im_end|>" 170 | if self.pooling_mode == "eos_token": 171 | if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B": 172 | text = text.strip() + "<|end_of_text|>" 173 | elif isinstance(self.model.config, LlamaConfig) or isinstance( 174 | self.model.config, MistralConfig 175 | ): 176 | text = text.strip() + " </s>" 177 | elif isinstance(self.model.config, GemmaConfig): 178 | text = text.strip() + "<eos>" 179 | elif isinstance(self.model.config, Qwen2Config): 180 | text = text.strip() + "<|endoftext|>" 181 | return text 182 | 183 | def tokenize(self, texts): 184 | texts_2 = [] 185 | original_texts = [] 186 | for text in texts: 187 | t = text.split("!@#$%^&*()") 188 | texts_2.append(t[1] if len(t) > 1 else "") 189 | original_texts.append("".join(t)) 190 | 191 | original = self.tokenizer( 192 | original_texts, 193 | return_tensors="pt", 194 | padding=True, 195 | truncation=True, 196 | max_length=self.max_length, 197 | ) 198 | embed_mask = None 199 | for t_i, t in enumerate(texts_2): 200 | ids = self.tokenizer( 201 | [t], 202 | return_tensors="pt", 203 | padding=True, 204 | truncation=True, 205 | max_length=self.max_length, 206 | add_special_tokens=False, 207 | ) 208 | if embed_mask is None: 209 | e_m = torch.zeros_like(original["attention_mask"][t_i]) 210 | if len(ids["input_ids"][0]) > 0: 211 | e_m[-len(ids["input_ids"][0]) :] = torch.ones( 212 | len(ids["input_ids"][0]) 213 | ) 214 | embed_mask = e_m.unsqueeze(0) 215 | else: 216 | e_m = torch.zeros_like(original["attention_mask"][t_i]) 217 | if len(ids["input_ids"][0]) > 0: 218 | e_m[-len(ids["input_ids"][0]) :] = torch.ones( 219 | len(ids["input_ids"][0]) 220 | ) 221 | embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0) 222 | 223 | original["embed_mask"] = embed_mask 224 | return original 225 | 226 | def _skip_instruction(self, sentence_feature): 227 | assert ( 228 | sentence_feature["attention_mask"].shape 229 | == sentence_feature["embed_mask"].shape 230 | ) 231 | sentence_feature["attention_mask"] = sentence_feature["embed_mask"] 232 | 233 | def forward(self, sentence_feature: Dict[str, Tensor]): 234 | embed_mask = None 235 | if "embed_mask" in sentence_feature: 236 | embed_mask = sentence_feature.pop("embed_mask") 237 | reps = self.model(**sentence_feature) 238 | sentence_feature["embed_mask"] = embed_mask 239 | 240 | return self.get_pooling(sentence_feature, reps.last_hidden_state) 241 | 242 | def get_pooling(self, features, last_hidden_states): # All models padded from left 243 | assert ( 244 | self.tokenizer.padding_side == "left" 245 | ), "Pooling modes are implemented for padding from left."
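# Left-padding worked example for the pooling below (hypothetical batch row):
#   attention_mask = [0, 0, 1, 1, 1] -> seq_length = 3, so "mean" pooling
#   averages last_hidden_states[i, -3:, :]. With skip_instruction=True the
#   attention_mask is first replaced by embed_mask, so instruction tokens are
#   excluded and only the text span after the separator is averaged.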
246 | if self.skip_instruction: 247 | self._skip_instruction(features) 248 | seq_lengths = features["attention_mask"].sum(dim=-1) 249 | if self.pooling_mode == "mean": 250 | return torch.stack( 251 | [ 252 | last_hidden_states[i, -length:, :].mean(dim=0) 253 | for i, length in enumerate(seq_lengths) 254 | ], 255 | dim=0, 256 | ) 257 | elif self.pooling_mode == "weighted_mean": 258 | bs, l, _ = last_hidden_states.shape 259 | complete_weights = torch.zeros(bs, l, device=last_hidden_states.device) 260 | for i, seq_l in enumerate(seq_lengths): 261 | if seq_l > 0: 262 | complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1 263 | complete_weights[i] /= torch.clamp( 264 | complete_weights[i].sum(), min=1e-9 265 | ) 266 | return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1) 267 | elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token": 268 | return last_hidden_states[:, -1] 269 | elif self.pooling_mode == "bos_token": 270 | return last_hidden_states[ 271 | features["input_ids"] == self.tokenizer.bos_token_id 272 | ] 273 | else: 274 | raise ValueError(f"{self.pooling_mode} is not implemented yet.") 275 | 276 | def _convert_to_str(self, instruction, text): 277 | tokenized_q = self.tokenizer( 278 | text, 279 | return_tensors="pt", 280 | padding=True, 281 | truncation=True, 282 | max_length=self.max_length, 283 | add_special_tokens=False, 284 | ) 285 | tokenized_q_length = len(tokenized_q["input_ids"][0]) 286 | 287 | while tokenized_q_length > self.doc_max_length: 288 | reduction_ratio = self.doc_max_length / tokenized_q_length 289 | reduced_length = int(len(text.split()) * reduction_ratio) 290 | text = " ".join(text.split()[:reduced_length]) 291 | tokenized_q = self.tokenizer( 292 | text, 293 | return_tensors="pt", 294 | padding=True, 295 | truncation=True, 296 | max_length=self.max_length, 297 | add_special_tokens=False, 298 | ) 299 | tokenized_q_length = len(tokenized_q["input_ids"][0]) 300 | 301 | return ( 302 | f"{instruction.strip()} !@#$%^&*(){text}" 303 | if instruction 304 | else f"!@#$%^&*(){text}" 305 | ) 306 | 307 | def encode( 308 | self, 309 | sentences: Union[str, List[str]], 310 | batch_size: int = 32, 311 | show_progress_bar: bool = True, 312 | convert_to_numpy: bool = False, 313 | convert_to_tensor: bool = False, 314 | device: Optional[str] = None, 315 | ): 316 | """ 317 | Encode a list of sentences to their respective embeddings. The sentences can be a list of strings or a string. 318 | Args: 319 | sentences: sentence or sentences to encode. 320 | batch_size: batch size for turning sentence tokens into embeddings. 321 | show_progress_bar: whether to show progress bars during encoding steps. 322 | convert_to_numpy: If true, return numpy arrays instead of torch tensors. 323 | convert_to_tensor: If true, return torch tensors (default). 324 | device: torch backend device identifier (e.g., 'cuda', 'cpu','mps' etc.). If not specified, 325 | the default is to use cuda when available, otherwise cpu. Note that only the choice of 'cuda' supports 326 | multiprocessing as currently implemented. 327 | 328 | Returns: embeddings of the sentences. Embeddings are detached and always on the CPU (see _encode implementation). 
329 | 330 | """ 331 | if isinstance(sentences[0], str) and isinstance(sentences[-1], int): 332 | sentences = [sentences] 333 | # required for MEDI version of MTEB 334 | if isinstance(sentences[0], str): 335 | sentences = [[""] + [sentence] for sentence in sentences] 336 | 337 | if device is None: 338 | device = "cuda" if torch.cuda.is_available() else "cpu" 339 | 340 | concatenated_input_texts = [] 341 | for sentence in sentences: 342 | assert isinstance(sentence[0], str) 343 | assert isinstance(sentence[1], str) 344 | concatenated_input_texts.append( 345 | self._convert_to_str(sentence[0], sentence[1]) 346 | ) 347 | sentences = concatenated_input_texts 348 | 349 | self.eval() 350 | 351 | if convert_to_tensor: 352 | convert_to_numpy = False 353 | 354 | length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences]) 355 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx] 356 | all_embeddings = [] 357 | 358 | if torch.cuda.device_count() <= 1: 359 | # This branch also support mps devices 360 | self.to(device) 361 | for start_index in trange( 362 | 0, 363 | len(sentences), 364 | batch_size, 365 | desc="Batches", 366 | disable=not show_progress_bar, 367 | ): 368 | sentences_batch = sentences_sorted[ 369 | start_index : start_index + batch_size 370 | ] 371 | embeddings = self._encode( 372 | sentences_batch, device=device, convert_to_numpy=convert_to_numpy 373 | ) 374 | all_embeddings.append(embeddings) 375 | else: 376 | 377 | num_proc = torch.cuda.device_count() 378 | cuda_compatible_multiprocess = mp.get_context("spawn") 379 | with cuda_compatible_multiprocess.Pool(num_proc) as p: 380 | sentences_batches = [ 381 | sentences_sorted[start_index : start_index + batch_size] 382 | for start_index in range(0, len(sentences), batch_size) 383 | ] 384 | 385 | progress_bar = tqdm( 386 | total=len(sentences_batches), 387 | desc="Batches", 388 | disable=not show_progress_bar, 389 | ) 390 | results = [] 391 | 392 | def update(*args): 393 | progress_bar.update() 394 | 395 | for batch in sentences_batches: 396 | results.append( 397 | p.apply_async( 398 | self._encode, 399 | args=(batch, None, convert_to_numpy, True), 400 | callback=update, 401 | ) 402 | ) 403 | 404 | all_embeddings = [result.get() for result in results] 405 | progress_bar.close() 406 | 407 | all_embeddings = torch.cat(all_embeddings, dim=0) 408 | all_embeddings = all_embeddings[np.argsort(length_sorted_idx)] 409 | all_embeddings = all_embeddings.to(torch.float32) 410 | if convert_to_numpy: 411 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) 412 | return all_embeddings 413 | 414 | def save(self, output_path, merge_before_save=False, save_config=True): 415 | if merge_before_save and isinstance(self.model, PeftModel): 416 | self.model = self.model.merge_and_unload() 417 | # Fixes the issue of saving - https://huggingface.co/McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse/discussions/1 418 | if hasattr(self.model, "_hf_peft_config_loaded"): 419 | self.model._hf_peft_config_loaded = False 420 | 421 | self.model.save_pretrained(output_path) 422 | self.tokenizer.save_pretrained(output_path) 423 | 424 | llm2vec_config = { 425 | "pooling_mode": self.pooling_mode, 426 | "max_length": self.max_length, 427 | "doc_max_length": self.doc_max_length, 428 | "skip_instruction": self.skip_instruction, 429 | } 430 | 431 | if save_config: 432 | os.makedirs(output_path, exist_ok=True) 433 | with open(f"{output_path}/llm2vec_config.json", "w") as fOut: 434 | json.dump(llm2vec_config, fOut, indent=4) 
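# Hypothetical usage sketch (the model name and texts are placeholders):
#   l2v = LLM2Vec.from_pretrained(
#       "Qwen/Qwen2-0.5B", pooling_mode="mean", max_length=512)
#   embs = l2v.encode(["item title 1", "item title 2"], batch_size=32)
#   # embs: float32 tensor of shape (2, hidden_size), returned on CPU
# Plain strings are wrapped as ["", text] pairs; an optional instruction is
# joined to the text with the "!@#$%^&*()" separator that tokenize() splits on.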
435 | 436 | def _encode( 437 | self, 438 | sentences_batch, 439 | device: Optional[str] = None, 440 | convert_to_numpy: bool = False, 441 | multiprocessing=False, 442 | ): 443 | if multiprocessing: 444 | # multiprocessing only supports CUDA devices at this time, so we ignore the value of device 445 | # and use cuda:rank for the device 446 | rank = mp.current_process()._identity[0] 447 | if device is None and torch.cuda.is_available(): 448 | device = f"cuda:{rank % torch.cuda.device_count()}" 449 | 450 | self.to(device) 451 | features = self.tokenize( 452 | [self.prepare_for_tokenization(sentence) for sentence in sentences_batch] 453 | ) 454 | features = batch_to_device(features, device) 455 | 456 | with torch.no_grad(): 457 | embeddings = self.forward(features) 458 | embeddings = embeddings.detach() 459 | embeddings = embeddings.cpu() 460 | 461 | return embeddings 462 | 463 | def _text_length(self, text: Union[List[int], List[List[int]]]): 464 | """ 465 | Help function to get the length for the input text. Text can be either a string (which means a single text) 466 | a list of ints (which means a single tokenized text), or a tuple of list of ints 467 | (representing several text inputs to the model). 468 | """ 469 | if ( 470 | isinstance(text, str) 471 | or (isinstance(text, list) and isinstance(text[0], int)) 472 | or len(text) == 0 473 | ): # Single text, list of ints, or empty 474 | return len(text) 475 | if isinstance(text, dict): # {key: value} case 476 | return len(next(iter(text.values()))) 477 | elif not hasattr(text, "__len__"): # Object has no len() method 478 | return 1 479 | else: 480 | return sum([len(t) for t in text]) 481 | 482 | def resize_token_embeddings( 483 | self, 484 | new_num_tokens: Optional[int] = None, 485 | pad_to_multiple_of: Optional[int] = None, 486 | ) -> nn.Embedding: 487 | return self.model.resize_token_embeddings( 488 | new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of 489 | ) 490 | 491 | def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): 492 | self.model.gradient_checkpointing_enable( 493 | gradient_checkpointing_kwargs=gradient_checkpointing_kwargs 494 | ) 495 | -------------------------------------------------------------------------------- /seqrec/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import copy 7 | 8 | class PositionwiseFeedForward(nn.Module): 9 | def __init__(self, d_in, d_hid, dropout=0.1): 10 | super().__init__() 11 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) 12 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) 13 | self.layer_norm = nn.LayerNorm(d_in) 14 | self.dropout = nn.Dropout(dropout) 15 | 16 | def forward(self, x): 17 | residual = x 18 | output = x.transpose(1, 2) 19 | output = self.w_2(F.relu(self.w_1(output))) 20 | output = output.transpose(1, 2) 21 | output = self.dropout(output) 22 | output = self.layer_norm(output + residual) 23 | return output 24 | 25 | 26 | class SinusoidalPositionEmbeddings(nn.Module): 27 | def __init__(self, dim): 28 | super().__init__() 29 | self.dim = dim 30 | 31 | def forward(self, time): 32 | device = time.device 33 | half_dim = self.dim // 2 34 | embeddings = math.log(10000) / (half_dim - 1) 35 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 36 | embeddings = time[:, None] * embeddings[None, :] 37 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 38 | 
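# Shape walk-through for a hypothetical call with dim=8 and time of shape (B,):
#   half_dim = 4, frequencies exp(-k * ln(10000) / (half_dim - 1)) for k = 0..3,
#   time[:, None] * freqs[None, :] -> (B, 4),
#   cat(sin, cos, dim=-1)          -> (B, 8),
# i.e. the standard transformer sinusoidal position encoding.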
39 | 
40 | 
41 | class MultiHeadAttention(nn.Module):
42 |     def __init__(self, hidden_size, num_units, num_heads, dropout_rate, bidirectional=False):
43 |         super().__init__()
44 |         self.hidden_size = hidden_size
45 |         self.num_heads = num_heads
46 |         assert hidden_size % num_heads == 0
47 | 
48 |         self.linear_q = nn.Linear(hidden_size, num_units)
49 |         self.linear_k = nn.Linear(hidden_size, num_units)
50 |         self.linear_v = nn.Linear(hidden_size, num_units)
51 |         self.dropout = nn.Dropout(dropout_rate)
52 |         self.softmax = nn.Softmax(dim=-1)
53 |         self.bidirectional = bidirectional
54 | 
55 |     def forward(self, queries, keys):
56 |         """
57 |         :param queries: A 3d tensor with shape of [N, T_q, C_q]
58 |         :param keys: A 3d tensor with shape of [N, T_k, C_k]
59 | 
60 |         :return: A 3d tensor with shape of (N, T_q, C)
61 | 
62 |         """
63 |         Q = self.linear_q(queries)  # (N, T_q, C)
64 |         K = self.linear_k(keys)  # (N, T_k, C)
65 |         V = self.linear_v(keys)  # (N, T_k, C)
66 | 
67 |         # Split and Concat
68 |         split_size = self.hidden_size // self.num_heads
69 |         Q_ = torch.cat(torch.split(Q, split_size, dim=2), dim=0)  # (h*N, T_q, C/h)
70 |         K_ = torch.cat(torch.split(K, split_size, dim=2), dim=0)  # (h*N, T_k, C/h)
71 |         V_ = torch.cat(torch.split(V, split_size, dim=2), dim=0)  # (h*N, T_k, C/h)
72 | 
73 |         # Multiplication
74 |         matmul_output = torch.bmm(Q_, K_.transpose(1, 2)) / self.hidden_size ** 0.5  # (h*N, T_q, T_k)
75 | 
76 |         # Key Masking
77 |         key_mask = torch.sign(torch.abs(keys.sum(dim=-1))).repeat(self.num_heads, 1)  # (h*N, T_k)
78 |         key_mask_reshaped = key_mask.unsqueeze(1).repeat(1, queries.shape[1], 1)  # (h*N, T_q, T_k)
79 |         key_paddings = torch.ones_like(matmul_output) * (-2 ** 32 + 1)
80 |         matmul_output_m1 = torch.where(torch.eq(key_mask_reshaped, 0), key_paddings, matmul_output)  # (h*N, T_q, T_k)
81 | 
82 |         if not self.bidirectional:
83 |             # Causality - Future Blinding
84 |             diag_vals = torch.ones_like(matmul_output[0, :, :])  # (T_q, T_k)
85 |             tril = torch.tril(diag_vals)  # (T_q, T_k)
86 |             causality_mask = tril.unsqueeze(0).repeat(matmul_output.shape[0], 1, 1)  # (h*N, T_q, T_k)
87 |             causality_paddings = torch.ones_like(causality_mask) * (-2 ** 32 + 1)
88 |             matmul_output_m2 = torch.where(torch.eq(causality_mask, 0), causality_paddings,
89 |                                            matmul_output_m1)  # (h*N, T_q, T_k)
90 | 
91 |             # Activation
92 |             matmul_output_sm = self.softmax(matmul_output_m2)  # (h*N, T_q, T_k)
93 |         else:
94 |             matmul_output_sm = self.softmax(matmul_output_m1)  # (h*N, T_q, T_k)
95 |         # Query Masking
96 |         query_mask = torch.sign(torch.abs(queries.sum(dim=-1))).repeat(self.num_heads, 1)  # (h*N, T_q)
97 |         query_mask = query_mask.unsqueeze(-1).repeat(1, 1, keys.shape[1])  # (h*N, T_q, T_k)
98 |         matmul_output_qm = matmul_output_sm * query_mask
99 | 
100 |         # Dropout
101 |         matmul_output_dropout = self.dropout(matmul_output_qm)
102 | 
103 |         # Weighted Sum
104 |         output_ws = torch.bmm(matmul_output_dropout, V_)  # (h*N, T_q, C/h)
105 | 
106 |         # Restore Shape
107 |         output = torch.cat(torch.split(output_ws, output_ws.shape[0] // self.num_heads, dim=0), dim=2)  # (N, T_q, C)
108 | 
109 |         # Residual Connection
110 |         output_res = output + queries
111 | 
112 |         return output_res
113 | 
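# Hedged sketch of the causal masking above: position t can only attend to
# positions <= t, and all-zero rows are treated as padding. Shapes illustrative
# (assumes hidden_size == num_units):
#
#     mha = MultiHeadAttention(hidden_size=128, num_units=128, num_heads=2, dropout_rate=0.0)
#     seq = torch.randn(4, 10, 128)   # (N, T, C)
#     out = mha(seq, seq)             # self-attention -> (4, 10, 128)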
114 | class GRUEncoder(nn.Module):
115 |     def __init__(self, config):
116 |         super(GRUEncoder, self).__init__()
117 |         self.gru = nn.GRU(
118 |             input_size=config['hidden_size'],
119 |             hidden_size=config['hidden_size'],
120 |             num_layers=config['layer_num'],
121 |             bias=False,
122 |             batch_first=True)
123 | 
124 |         self._reset_parameters()
125 | 
126 |     def _reset_parameters(self):
127 |         def init_weights(m):
128 |             if isinstance(m, nn.Linear):
129 |                 nn.init.xavier_uniform_(m.weight)
130 |                 if m.bias is not None:
131 |                     nn.init.zeros_(m.bias)
132 |             elif isinstance(m, nn.LayerNorm):
133 |                 nn.init.ones_(m.weight)
134 |                 nn.init.zeros_(m.bias)
135 | 
136 |         self.apply(init_weights)
137 | 
138 |     def forward(self, seq, mask):
139 |         return self.gru(seq * mask)[0]
140 | 
141 | class TransformerEncoder(nn.Module):
142 |     def __init__(self, config):
143 |         super(TransformerEncoder, self).__init__()
144 |         self.ln_1 = nn.LayerNorm(config['hidden_size'])
145 |         self.ln_2 = nn.LayerNorm(config['hidden_size'])
146 |         self.ln_3 = nn.LayerNorm(config['hidden_size'])
147 |         self.mh_attn = MultiHeadAttention(config['hidden_size'], config['hidden_size'], config['num_heads'], config['dropout'], config.get('bidirectional', False))
148 |         self.feed_forward = PositionwiseFeedForward(config['hidden_size'], config['hidden_size'], config['dropout'])
149 | 
150 | 
151 |         self._reset_parameters()
152 | 
153 |     def _reset_parameters(self):
154 |         def init_weights(m):
155 |             if isinstance(m, nn.Linear):
156 |                 nn.init.xavier_uniform_(m.weight)
157 |                 if m.bias is not None:
158 |                     nn.init.zeros_(m.bias)
159 |             elif isinstance(m, nn.LayerNorm):
160 |                 nn.init.ones_(m.weight)
161 |                 nn.init.zeros_(m.bias)
162 | 
163 |         self.apply(init_weights)
164 | 
165 |     def forward(self, seq, mask):
166 |         seq = seq * mask
167 |         seq_normalized = self.ln_1(seq)
168 |         mh_attn_out = self.mh_attn(seq_normalized, seq)
169 |         seq = seq + mh_attn_out  # Residual connection
170 |         ff_out = self.feed_forward(self.ln_2(seq))
171 |         ff_out *= mask
172 |         seq = seq + ff_out
173 |         ff_out = self.ln_3(seq)
174 |         return ff_out
175 | 
176 | 
177 | 
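# Hedged usage sketch for the two sequence encoders above (config keys mirror
# the SASRec/GRU4Rec config.yaml files; the concrete values are illustrative):
#
#     config = {'hidden_size': 128, 'layer_num': 2, 'num_heads': 2, 'dropout': 0.3}
#     enc = TransformerEncoder(config)          # or GRUEncoder(config)
#     seq = torch.randn(4, 10, 128)             # (batch, seq_len, hidden)
#     mask = torch.ones(4, 10, 1)               # 1.0 for real items, 0.0 for padding
#     out = enc(seq, mask)                      # -> (4, 10, 128)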
178 | class FeedForward(nn.Module):
179 |     """
180 |     Point-wise feed-forward layer is implemented by two dense layers.
181 | 
182 |     Args:
183 |         input_tensor (torch.Tensor): the input of the point-wise feed-forward layer
184 | 
185 |     Returns:
186 |         hidden_states (torch.Tensor): the output of the point-wise feed-forward layer
187 | 
188 |     """
189 | 
190 |     def __init__(
191 |         self, hidden_size, inner_size, hidden_dropout_prob, hidden_act, layer_norm_eps
192 |     ):
193 |         super(FeedForward, self).__init__()
194 |         self.dense_1 = nn.Linear(hidden_size, inner_size)
195 |         self.intermediate_act_fn = self.get_hidden_act(hidden_act)
196 | 
197 |         self.dense_2 = nn.Linear(inner_size, hidden_size)
198 |         self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
199 |         self.dropout = nn.Dropout(hidden_dropout_prob)
200 | 
201 |     def get_hidden_act(self, act):
202 |         ACT2FN = {
203 |             "gelu": self.gelu,
204 |             "relu": F.relu,
205 |             "swish": self.swish,
206 |             "tanh": torch.tanh,
207 |             "sigmoid": torch.sigmoid,
208 |         }
209 |         return ACT2FN[act]
210 | 
211 |     def gelu(self, x):
212 |         """Implementation of the gelu activation function.
213 | 
214 |         For information: OpenAI GPT's gelu is slightly different (and gives slightly different results)::
215 | 
216 |             0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
217 | 
218 |         Also see https://arxiv.org/abs/1606.08415
219 |         """
220 |         return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
221 | 
222 |     def swish(self, x):
223 |         return x * torch.sigmoid(x)
224 | 
225 |     def forward(self, input_tensor):
226 |         hidden_states = self.dense_1(input_tensor)
227 |         hidden_states = self.intermediate_act_fn(hidden_states)
228 | 
229 |         hidden_states = self.dense_2(hidden_states)
230 |         hidden_states = self.dropout(hidden_states)
231 |         hidden_states = self.LayerNorm(hidden_states + input_tensor)
232 | 
233 |         return hidden_states
234 | 
235 | 
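# Hedged numeric check of the two gelu variants named in the docstring above
# (the exact erf form used here vs. the tanh approximation; values illustrative):
#
#     x = torch.linspace(-3, 3, 7)
#     exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
#     approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
#     print((exact - approx).abs().max())   # small: on the order of 1e-3 or below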
236 | class MultiHeadAttention_v2(nn.Module):
237 |     """
238 |     Multi-head self-attention layer; an attention-score dropout layer is introduced.
239 | 
240 |     Args:
241 |         input_tensor (torch.Tensor): the input of the multi-head self-attention layer
242 |         attention_mask (torch.Tensor): the attention mask for input tensor
243 | 
244 |     Returns:
245 |         hidden_states (torch.Tensor): the output of the multi-head self-attention layer
246 | 
247 |     """
248 | 
249 |     def __init__(
250 |         self,
251 |         n_heads,
252 |         hidden_size,
253 |         hidden_dropout_prob,
254 |         attn_dropout_prob,
255 |         layer_norm_eps,
256 |     ):
257 |         super(MultiHeadAttention_v2, self).__init__()
258 |         if hidden_size % n_heads != 0:
259 |             raise ValueError(
260 |                 "The hidden size (%d) is not a multiple of the number of attention "
261 |                 "heads (%d)" % (hidden_size, n_heads)
262 |             )
263 | 
264 |         self.num_attention_heads = n_heads
265 |         self.attention_head_size = int(hidden_size / n_heads)
266 |         self.all_head_size = self.num_attention_heads * self.attention_head_size
267 |         self.sqrt_attention_head_size = math.sqrt(self.attention_head_size)
268 | 
269 |         self.query = nn.Linear(hidden_size, self.all_head_size)
270 |         self.key = nn.Linear(hidden_size, self.all_head_size)
271 |         self.value = nn.Linear(hidden_size, self.all_head_size)
272 | 
273 |         self.softmax = nn.Softmax(dim=-1)
274 |         self.attn_dropout = nn.Dropout(attn_dropout_prob)
275 | 
276 |         self.dense = nn.Linear(hidden_size, hidden_size)
277 |         self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
278 |         self.out_dropout = nn.Dropout(hidden_dropout_prob)
279 | 
280 |     def transpose_for_scores(self, x):
281 |         new_x_shape = x.size()[:-1] + (
282 |             self.num_attention_heads,
283 |             self.attention_head_size,
284 |         )
285 |         x = x.view(*new_x_shape)
286 |         return x
287 | 
288 |     def forward(self, input_tensor, attention_mask):
289 |         mixed_query_layer = self.query(input_tensor)
290 |         mixed_key_layer = self.key(input_tensor)
291 |         mixed_value_layer = self.value(input_tensor)
292 | 
293 |         query_layer = self.transpose_for_scores(mixed_query_layer).permute(0, 2, 1, 3)
294 |         key_layer = self.transpose_for_scores(mixed_key_layer).permute(0, 2, 3, 1)
295 |         value_layer = self.transpose_for_scores(mixed_value_layer).permute(0, 2, 1, 3)
296 | 
297 |         # Take the dot product between "query" and "key" to get the raw attention scores.
298 |         attention_scores = torch.matmul(query_layer, key_layer)
299 | 
300 |         attention_scores = attention_scores / self.sqrt_attention_head_size
301 |         # Apply the additive attention mask (precomputed once and shared by all layers).
302 |         # Scores are [batch_size, heads, seq_len, seq_len]; the mask broadcasts from
303 |         # [batch_size, 1, 1, seq_len] (or [batch_size, 1, seq_len, seq_len] for causal masks).
304 |         attention_scores = attention_scores + attention_mask
305 | 
306 |         # Normalize the attention scores to probabilities.
307 |         attention_probs = self.softmax(attention_scores)
308 |         # This is actually dropping out entire tokens to attend to, which might
309 |         # seem a bit unusual, but is taken from the original Transformer paper.
310 | 
311 |         attention_probs = self.attn_dropout(attention_probs)
312 |         context_layer = torch.matmul(attention_probs, value_layer)
313 |         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
314 |         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
315 |         context_layer = context_layer.view(*new_context_layer_shape)
316 |         hidden_states = self.dense(context_layer)
317 |         hidden_states = self.out_dropout(hidden_states)
318 |         hidden_states = self.LayerNorm(hidden_states + input_tensor)
319 | 
320 |         return hidden_states
321 | 
322 | 
323 | 
324 | class TransformerLayer_v2(nn.Module):
325 |     """
326 |     One transformer layer consists of a multi-head self-attention layer and a point-wise feed-forward layer.
327 | 
328 |     Args:
329 |         hidden_states (torch.Tensor): the input of the multi-head self-attention sublayer
330 |         attention_mask (torch.Tensor): the attention mask for the multi-head self-attention sublayer
331 | 
332 |     Returns:
333 |         feedforward_output (torch.Tensor): the output of the point-wise feed-forward sublayer,
334 |             which is the output of the transformer layer.
335 | 
336 |     """
337 | 
338 |     def __init__(
339 |         self,
340 |         n_heads,
341 |         hidden_size,
342 |         intermediate_size,
343 |         hidden_dropout_prob,
344 |         attn_dropout_prob,
345 |         hidden_act,
346 |         layer_norm_eps,
347 |     ):
348 |         super(TransformerLayer_v2, self).__init__()
349 |         self.multi_head_attention = MultiHeadAttention_v2(
350 |             n_heads, hidden_size, hidden_dropout_prob, attn_dropout_prob, layer_norm_eps
351 |         )
352 |         self.feed_forward = FeedForward(
353 |             hidden_size,
354 |             intermediate_size,
355 |             hidden_dropout_prob,
356 |             hidden_act,
357 |             layer_norm_eps,
358 |         )
359 | 
360 |     def forward(self, hidden_states, attention_mask):
361 |         attention_output = self.multi_head_attention(hidden_states, attention_mask)
362 |         feedforward_output = self.feed_forward(attention_output)
363 |         return feedforward_output
364 | 
365 | 
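# Hedged sketch tying TransformerLayer_v2 to the additive mask convention used
# above (get_attention_mask, defined below, produces exactly this 0 / -10000 format;
# the parameter values are illustrative):
#
#     layer = TransformerLayer_v2(n_heads=2, hidden_size=128, intermediate_size=256,
#                                 hidden_dropout_prob=0.3, attn_dropout_prob=0.3,
#                                 hidden_act="gelu", layer_norm_eps=1e-12)
#     item_seq = torch.tensor([[5, 9, 3, 0, 0]])   # 0 = padding
#     attn_mask = get_attention_mask(item_seq)     # (1, 1, 5, 5), causal
#     h = torch.randn(1, 5, 128)
#     out = layer(h, attn_mask)                    # -> (1, 5, 128)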
366 | class TransformerEncoder_v2(nn.Module):
367 |     r"""One TransformerEncoder consists of several TransformerLayers.
368 | 
369 |     Args:
370 |         n_layers(num): num of transformer layers in transformer encoder. Default: 2
371 |         n_heads(num): num of attention heads for multi-head attention layer. Default: 2
372 |         hidden_size(num): the input and output hidden size. Default: 64
373 |         inner_size(num): the dimensionality in feed-forward layer. Default: 256
374 |         hidden_dropout_prob(float): probability of an element to be zeroed. Default: 0.5
375 |         attn_dropout_prob(float): probability of an attention score to be zeroed. Default: 0.5
376 |         hidden_act(str): activation function in feed-forward layer. Default: 'gelu'
377 |             candidates: 'gelu', 'relu', 'swish', 'tanh', 'sigmoid'
378 |         layer_norm_eps(float): a value added to the denominator for numerical stability. Default: 1e-12
379 | 
380 |     """
381 | 
382 |     def __init__(
383 |         self,
384 |         config,
385 |         # n_layers=2,
386 |         # n_heads=2,
387 |         # hidden_size=64,
388 |         # inner_size=256,
389 |         # hidden_dropout_prob=0.5,
390 |         # attn_dropout_prob=0.5,
391 |         # hidden_act="gelu",
392 |         # layer_norm_eps=1e-12,
393 |     ):
394 |         super(TransformerEncoder_v2, self).__init__()
395 |         layer = TransformerLayer_v2(
396 |             config['num_heads'],
397 |             config['hidden_size'],
398 |             256,
399 |             config['dropout'],
400 |             config['dropout'],
401 |             "gelu",
402 |             1e-12,
403 |         )
404 |         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config['layer_num'])])
405 | 
406 |     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
407 |         """
408 |         Args:
409 |             hidden_states (torch.Tensor): the input of the TransformerEncoder
410 |             attention_mask (torch.Tensor): the attention mask for the input hidden_states
411 |             output_all_encoded_layers (Bool): whether to output every transformer layer's hidden states
412 | 
413 |         Returns:
414 |             all_encoder_layers (list): if output_all_encoded_layers is True, a list containing every
415 |                 transformer layer's output; otherwise a list containing only the last layer's output.
416 | 
417 |         """
418 |         all_encoder_layers = []
419 |         for layer_module in self.layer:
420 |             hidden_states = layer_module(hidden_states, attention_mask)
421 |             if output_all_encoded_layers:
422 |                 all_encoder_layers.append(hidden_states)
423 |         if not output_all_encoded_layers:
424 |             all_encoder_layers.append(hidden_states)
425 |         return all_encoder_layers
426 | 
427 | 
428 | def get_attention_mask(item_seq, bidirectional=False):
429 |     """Generate left-to-right uni-directional or bidirectional attention mask for multi-head attention."""
430 |     attention_mask = item_seq != 0
431 |     extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # torch.bool
432 |     if not bidirectional:
433 |         extended_attention_mask = torch.tril(
434 |             extended_attention_mask.expand((-1, -1, item_seq.size(-1), -1))
435 |         )
436 |     extended_attention_mask = torch.where(extended_attention_mask, 0.0, -10000.0)
437 |     return extended_attention_mask
438 | 
439 | 
440 | def gather_indexes(output, gather_index):
441 |     """Gathers the vectors at the specific positions over a minibatch"""
442 |     gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, output.shape[-1])
443 |     output_tensor = output.gather(dim=1, index=gather_index)
444 |     return output_tensor.squeeze(1)
445 | 
446 | 
447 | 
448 | def in_batch_negative_sampling(batch_items):
449 |     B = batch_items.size(0)
450 |     expanded_items = batch_items.unsqueeze(0).repeat(B, 1)
451 |     mask = ~torch.eye(B, dtype=torch.bool, device=batch_items.device)
452 |     neg_items = expanded_items[mask].view(B, B - 1)
453 |     return neg_items
454 | 
455 | 
456 | 
457 | 
458 | 
459 | def in_batch_negative_sampling_sample(batch_items, num_neg=16):
460 |     B = batch_items.size(0)
461 |     expanded_items = batch_items.unsqueeze(0).repeat(B, 1)
462 |     mask = ~torch.eye(B, dtype=torch.bool, device=batch_items.device)
463 |     neg_items = expanded_items.masked_select(mask).view(B, B - 1)
464 |     weights = torch.ones_like(neg_items, dtype=torch.float)
465 |     sample_indices = torch.multinomial(weights, num_neg, replacement=False)
466 |     neg_sampled = neg_items.gather(1, sample_indices)
467 |     return neg_sampled
468 | 
469 | 
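# Hedged example of the in-batch negative samplers above: for each row, the
# other B-1 item ids in the batch serve as negatives (values illustrative):
#
#     batch_items = torch.tensor([11, 22, 33, 44])
#     in_batch_negative_sampling(batch_items)
#     # tensor([[22, 33, 44],
#     #         [11, 33, 44],
#     #         [11, 22, 44],
#     #         [11, 22, 33]])
#     in_batch_negative_sampling_sample(batch_items, num_neg=2).shape   # (4, 2)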
470 | def extract_axis_1(data, indices):
471 |     """
472 |     Extracts elements from axis 1 based on the provided indices.
473 |     """
474 |     return torch.stack([data[i, indices[i], :] for i in range(data.shape[0])], dim=0).unsqueeze(1)
475 | 
476 | def diagonalize_and_scale(e, epsilon=1e-7):
477 |     var_e = torch.cov(e.T)
478 |     mean_e = torch.mean(e, dim=0)
479 |     eigvals, eigvecs = torch.linalg.eigh(var_e)
480 |     eigvals = eigvals + epsilon
481 |     D = torch.diag(1.0 / torch.sqrt(eigvals))
482 |     O = eigvecs
483 |     transformed_e = (e - mean_e) @ O @ D
484 | 
485 |     return transformed_e
486 | 
487 | # Diffusion
488 | 
489 | 
490 | def extract(a, t, x_shape):
491 |     batch_size = t.shape[0]
492 |     out = a.gather(-1, t.cpu())
493 |     return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
494 | 
495 | 
496 | def linear_beta_schedule(timesteps, beta_start, beta_end):
497 |     return torch.linspace(beta_start, beta_end, timesteps)
498 | 
499 | 
500 | 
501 | 
502 | def cosine_beta_schedule(timesteps, s=0.008):
503 |     steps = timesteps + 1
504 |     x = torch.linspace(0, timesteps, steps)
505 |     alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
506 |     alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
507 |     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
508 |     return torch.clip(betas, 0.0001, 0.9999)
509 | 
510 | 
511 | def exp_beta_schedule(timesteps, beta_min=0.1, beta_max=10):
512 |     x = torch.linspace(1, 2 * timesteps + 1, timesteps)
513 |     betas = 1 - torch.exp(- beta_min / timesteps - x * 0.5 * (beta_max - beta_min) / (timesteps * timesteps))
514 |     return betas
515 | 
516 | 
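# Hedged comparison of the beta schedules above: each returns a length-`timesteps`
# tensor of betas for the diffusion forward process (values illustrative):
#
#     T = 1000
#     lin = linear_beta_schedule(T, 1e-4, 0.02)   # betas grow linearly
#     cos = cosine_beta_schedule(T)               # cosine schedule (Nichol & Dhariwal, 2021)
#     exp = exp_beta_schedule(T)
#     print(lin.shape, cos.shape, exp.shape)      # torch.Size([1000]) for all three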
517 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
518 |     """
519 |     Create a beta schedule that discretizes the given alpha_t_bar function,
520 |     which defines the cumulative product of (1-beta) over time from t = [0,1].
521 |     :param num_diffusion_timesteps: the number of betas to produce.
522 |     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
523 |                       produces the cumulative product of (1-beta) up to that
524 |                       part of the diffusion process.
525 |     :param max_beta: the maximum beta to use; use values lower than 1 to
526 |                      prevent singularities.
527 |     """
528 |     betas = []
529 |     for i in range(num_diffusion_timesteps):
530 |         t1 = i / num_diffusion_timesteps
531 |         t2 = (i + 1) / num_diffusion_timesteps
532 |         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
533 |     return np.array(betas)
534 | 
535 | 
552 | class PWLayer(nn.Module):
553 |     """Single Parametric Whitening Layer
554 |     """
555 |     def __init__(self, input_size, output_size, dropout=0.0):
556 |         super(PWLayer, self).__init__()
557 | 
558 |         self.dropout = nn.Dropout(p=dropout)
559 |         self.bias = nn.Parameter(torch.zeros(input_size), requires_grad=True)
560 |         self.lin = nn.Linear(input_size, output_size, bias=False)
561 | 
562 |         self.apply(self._init_weights)
563 | 
564 |     def _init_weights(self, module):
565 |         if isinstance(module, nn.Linear):
566 |             module.weight.data.normal_(mean=0.0, std=0.02)
567 | 
568 |     def forward(self, x):
569 |         return self.lin(self.dropout(x) - self.bias)
570 | 
571 | 
572 | # UniSRec
573 | class MoEAdaptorLayer(nn.Module):
574 |     """MoE-enhanced Adaptor
575 |     """
576 |     def __init__(self, n_exps, layers, dropout=0.0, noise=True):
577 |         super(MoEAdaptorLayer, self).__init__()
578 | 
579 |         self.n_exps = n_exps
580 |         self.noisy_gating = noise
581 | 
582 |         self.experts = nn.ModuleList([PWLayer(layers[0], layers[1], dropout) for i in range(n_exps)])
583 |         self.w_gate = nn.Parameter(torch.zeros(layers[0], n_exps), requires_grad=True)
584 |         self.w_noise = nn.Parameter(torch.zeros(layers[0], n_exps), requires_grad=True)
585 | 
586 |     def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
587 |         clean_logits = x @ self.w_gate
588 |         if self.noisy_gating and train:
589 |             raw_noise_stddev = x @ self.w_noise
590 |             noise_stddev = F.softplus(raw_noise_stddev) + noise_epsilon
591 |             noisy_logits = clean_logits + (torch.randn_like(clean_logits).to(x.device) * noise_stddev)
592 |             logits = noisy_logits
593 |         else:
594 |             logits = clean_logits
595 | 
596 |         gates = F.softmax(logits, dim=-1)
597 |         return gates
598 | 
599 |     def forward(self, x):
600 |         gates = self.noisy_top_k_gating(x, self.training)  # (B, n_E)
601 |         expert_outputs = [self.experts[i](x).unsqueeze(-2) for i in range(self.n_exps)]  # [(B, 1, D)]
602 |         expert_outputs = torch.cat(expert_outputs, dim=-2)
603 |         multiple_outputs = gates.unsqueeze(-1) * expert_outputs
604 |         return multiple_outputs.sum(dim=-2)
605 | 
606 | 
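# Hedged sketch of the UniSRec-style adaptor above: it maps frozen text
# embeddings (e.g. 768-d) into the recommender's hidden space via a softly
# gated mixture of whitening experts (sizes illustrative):
#
#     adaptor = MoEAdaptorLayer(n_exps=8, layers=[768, 128], dropout=0.1)
#     text_emb = torch.randn(32, 768)   # batch of item text embeddings
#     rec_emb = adaptor(text_emb)       # -> (32, 128)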
607 | class PLMEmb:
608 |     def __init__(self, config, plm_embeddings):
609 |         self.item_drop_ratio = config['item_drop_ratio']
610 |         self.plm_embeddings = plm_embeddings
611 | 
612 |     def __call__(self, interaction):
613 |         '''Sequence augmentation and PLM embedding fetching
614 |         '''
615 |         item_seq_len = interaction['item_length']
616 |         item_seq = interaction['item_id_list']
617 | 
618 |         item_emb_seq = self.plm_embeddings(item_seq)
619 |         pos_item_id = interaction['item_id']
620 |         pos_item_emb = self.plm_embeddings(pos_item_id)
621 | 
622 |         mask_p = torch.full_like(item_seq, 1 - self.item_drop_ratio, dtype=torch.float)
623 |         mask = torch.bernoulli(mask_p).to(torch.bool)
624 | 
625 |         # Augmentation
626 |         seq_mask = item_seq.eq(0).to(torch.bool)
627 |         mask = torch.logical_or(mask, seq_mask)
628 |         mask[:, 0] = True
629 |         drop_index = torch.cumsum(mask, dim=1) - 1
630 | 
631 |         item_seq_aug = torch.zeros_like(item_seq).scatter(dim=-1, index=drop_index, src=item_seq)
632 |         item_seq_len_aug = torch.gather(drop_index, 1, (item_seq_len - 1).unsqueeze(1)).squeeze() + 1
633 |         item_emb_seq_aug = self.plm_embeddings(item_seq_aug)
634 | 
635 |         interaction.update({
636 |             'item_emb_list': item_emb_seq,
637 |             'pos_item_emb': pos_item_emb,
638 |             'item_id_list_aug': item_seq_aug,
639 |             'item_length_aug': item_seq_len_aug,
640 |             'item_emb_list_aug': item_emb_seq_aug,
641 |         })
642 | 
643 |         return interaction
644 | 
645 | 
646 | class PointwiseAggregatedAttention(nn.Module):
647 |     def __init__(self, d_model, num_heads):
648 |         super().__init__()
649 |         self.num_heads = num_heads
650 |         self.head_dim = d_model // num_heads
651 | 
652 |         # TODO: add relative attention bias based on time
653 |         self.rab_p = RelativeAttentionBias(num_heads, relative_attention_num_buckets=32,
654 |                                            relative_attention_max_distance=128)
655 | 
656 |     def split_heads(self, x, batch_size):
657 |         x = x.view(batch_size, -1, self.num_heads, self.head_dim)
658 |         return x.permute(0, 2, 1, 3)
659 | 
660 |     def forward(self, v, k, q, mask=None):
661 |         batch_size = q.shape[0]
662 |         q = self.split_heads(q, batch_size)
663 |         k = self.split_heads(k, batch_size)
664 |         v = self.split_heads(v, batch_size)
665 | 
666 |         attention_scores = torch.matmul(q, k.transpose(-2, -1))
667 |         # attention_scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
668 |         rab = self.rab_p(q.shape[2], k.shape[2], device=q.device)
669 | 
670 |         att_w_bias = attention_scores + rab
671 | 
672 |         av = F.silu(att_w_bias) @ v
673 |         return av.transpose(1, 2).flatten(2)
674 | 
675 | 
676 | class RelativeAttentionBias(nn.Module):
677 |     def __init__(self, num_heads, relative_attention_num_buckets, relative_attention_max_distance=128):
678 |         super().__init__()
679 |         self.relative_attention_num_buckets = relative_attention_num_buckets
680 |         self.relative_attention_max_distance = relative_attention_max_distance
681 |         self.relative_attention_bias = nn.Embedding(relative_attention_num_buckets, num_heads)
682 | 
683 |     def forward(self, query_length, key_length, device=None):
684 |         """Compute binned relative position bias"""
685 |         if device is None:
686 |             device = self.relative_attention_bias.weight.device
687 |         context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
688 |         memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
689 |         relative_position = memory_position - context_position  # shape (query_length, key_length)
690 |         relative_position_bucket = self._relative_position_bucket(
691 |             relative_position,  # shape (query_length, key_length)
692 |             bidirectional=False,
693 |             num_buckets=self.relative_attention_num_buckets,
694 |             max_distance=self.relative_attention_max_distance,
695 |         )
696 |         values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
697 |         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
698 |         return values
699 | 
700 |     # https://github.com/huggingface/transformers/blob/6cdbd73e01a9719bfaec07d91fd108e8d932bbbb/src/transformers/models/t5/modeling_t5.py#L384
701 |     @staticmethod
702 |     def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
703 |         """
704 |         Adapted from Mesh Tensorflow:
705 |         https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
706 | 
707 |         Translate relative position to a bucket number for relative attention. The relative position is defined as
708 |         memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
709 |         position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
710 |         small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
711 |         positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same bucket.
712 |         This should allow for more graceful generalization to longer sequences than the model has been trained on.
713 | 
714 |         Args:
715 |             relative_position: an int32 Tensor
716 |             bidirectional: a boolean - whether the attention is bidirectional
717 |             num_buckets: an integer
718 |             max_distance: an integer
719 | 
720 |         Returns:
721 |             a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
722 |         """
723 |         relative_buckets = 0
724 |         if bidirectional:
725 |             num_buckets //= 2
726 |             relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
727 |             relative_position = torch.abs(relative_position)
728 |         else:
729 |             relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
730 |         # now relative_position is in the range [0, inf)
731 | 
732 |         # half of the buckets are for exact increments in positions
733 |         max_exact = num_buckets // 2
734 |         is_small = relative_position < max_exact
735 | 
736 |         # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
737 |         relative_position_if_large = max_exact + (
738 |             torch.log(relative_position.float() / max_exact)
739 |             / math.log(max_distance / max_exact)
740 |             * (num_buckets - max_exact)
741 |         ).to(torch.long)
742 |         relative_position_if_large = torch.min(
743 |             relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
744 |         )
745 | 
746 |         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
747 |         return relative_buckets
748 | 
749 | 
750 | class HSTUBlock(nn.Module):
751 |     def __init__(self, d_model, num_heads, dropout=0.1):
752 |         super().__init__()
753 |         self.f1 = nn.Linear(d_model, d_model * 4)  # Transform and split
754 |         self.pointwise_attn = PointwiseAggregatedAttention(d_model, num_heads)
755 |         self.f2 = nn.Linear(d_model, d_model)
756 |         self.norm = nn.LayerNorm(d_model)
757 | 
758 |     def split(self, x):
759 |         u, v, q, k = x.chunk(4, dim=-1)
760 |         return u, v, q, k
761 | 
762 |     def forward(self, x, mask=None):
763 |         # Pointwise Projection
764 |         if mask is not None:
765 |             x = x * mask
766 |         x_proj = F.silu(self.f1(x))
767 |         u, v, q, k = self.split(x_proj)
768 | 
769 |         # Spatial Aggregation
770 |         av = self.pointwise_attn(v, k, q)
771 | 
772 |         # Pointwise Transformation
773 |         y = self.f2(self.norm(av * u))
774 | 
775 |         return y
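# Hedged usage sketch for HSTUBlock above (d_model must be divisible by
# num_heads; shapes illustrative). Unlike softmax attention, the SiLU-gated
# scores are not normalized; the binned relative-position bias is added to the
# raw scores before gating:
#
#     block = HSTUBlock(d_model=128, num_heads=2)
#     x = torch.randn(4, 10, 128)   # (batch, seq_len, d_model)
#     y = block(x)                  # -> (4, 10, 128)
--------------------------------------------------------------------------------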