├── README.md ├── baselines ├── D3 │ ├── D3_evaluate.py │ ├── LogitProcesser.py │ └── evaluate2.sh ├── DMPO │ ├── dmpo_trainer.py │ └── utils.py ├── Re-weighting │ └── RW_SFT.py ├── SDPO │ ├── softmax_dpo_trainer.py │ └── utils.py └── Semantic_sampling_rosePO │ └── Semantic_sampling_rosePO.py ├── data └── Goodreads │ ├── test.json │ ├── train.json │ └── valid.json ├── environment.yml ├── eval ├── Goodreads │ ├── embeddings.pt │ ├── genre_dict.json │ ├── id2name.json │ ├── name2genre.json │ └── name2id.json ├── MovieLens │ ├── embeddings.pt │ ├── genre_dict.json │ ├── id2name.json │ ├── name2genre.json │ └── name2id.json ├── evaluate.py └── inference.py ├── figs └── method.png ├── shell ├── SFT.sh ├── SPRec.sh └── eval_single_file.sh └── train ├── data_generate.py ├── dpo.py └── sft.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | # SPRec: Self-Play to Debias LLM-based Recommendation 3 | 4 |
5 | introduction 6 |
7 | 8 | This repository provides the official PyTorch implementation and reproduction for the paper titled "SPRec: Self-Play to Debias LLM-based Recommendation" 9 | 10 | ## Installation 11 | 12 | 1. Clone this git repository and change directory to this repository: 13 | 14 | 2. A new [conda environment](https://docs.conda.io/projects/conda/en/latest/user-guide/concepts/environments.html) is suggested. 15 | 16 | ```bash 17 | conda env create -f environment.yml 18 | ``` 19 | 20 | 3. Activate the newly created environment. 21 | 22 | ```bash 23 | conda activate SPRec 24 | ``` 25 | 26 | 27 | ## Quick Start 28 | 29 | Due to GitHub's file size limitations, we have uploaded the minimal sample dataset **Goodreads** in `./data/Goodreads` and `./eval/Goodreads` for reproduction purposes. Additionally, the datasets used in our experiments—**MovieLens**, **CDs and Vinyl**, and **Steam**—have been uploaded to [Datasets](https://zenodo.org/records/14900102?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6IjMwYTA1OWM4LWRjZTctNDJmNC1iOWY2LTRjZWQyZjZiNjY5ZCIsImRhdGEiOnt9LCJyYW5kb20iOiI2ZTYyZDZkZTFlNDM5NjA2ZGMwMTA2YWIxMjdjMDJmNCJ9.g2bckZWGA77AEg9EBARxN45rmXYfGD8RuRzy41CZACcDh2XESWxAGD3b91ecu_FEbmYQSzR5qBTH0xvQC_Lw2Q). If you wish to use a different dataset, please ensure that it is processed into a similar format. 30 | 31 | Besides, to ensure that SPRec does not encounter more training data during multiple iterations compared to other baseline methods, it is recommended to sample the training dataset beforehand to limit its size. The sample dataset we provide has already been sampled and contains 5,000 entries. You can further sample it according to your requirements to control the total amount of data SPRec is exposed to during training. 32 | 33 | 34 | ### How to Train Using SPRec Framework 35 | 36 | 1. **SFT Training**: 37 | Before using the SPRec training framework, you need to run SFT to fine-tune your base model for alignment with the recommendation task. 
Use the following command to perform SFT training: 38 | ```bash 39 | bash ./shell/SFT.sh 0 1 2 3 # Specify your GPUs, e.g., 0 1 2 3 40 | ``` 41 | 2. **SPRec Training**: 42 | After completing SFT training, use the following command to perform SPRec training: 43 | ```bash 44 | bash ./shell/SPRec.sh 0 1 2 3 5 # Specify your GPUs, e.g., 0 1 2 3, and the number of iterations, e.g., 5 45 | ``` 46 | Once the above commands have finished, the evaluation results for top-1 and top-5 recommendations are saved as `eval_top1.json` and `eval_top5.json` in the corresponding model directory. 47 | 48 | ## Baseline Implementations Acknowledgement 49 | This repository also includes implementations of the baseline methods compared in our paper. We sincerely acknowledge the original authors for their foundational work. 50 | 51 | If you find this repository helpful, please cite our paper: 52 | ``` 53 | @article{gao2024sprec, 54 | title={SPRec: Self-Play to Debias LLM-based Recommendation}, 55 | author={Gao, Chongming and Chen, Ruijun and Yuan, Shuai and Huang, Kexin and Yu, Yuanqing and He, Xiangnan}, 56 | journal={arXiv preprint arXiv:2412.09243}, 57 | year={2024} 58 | } 59 | ``` 60 | -------------------------------------------------------------------------------- /baselines/D3/D3_evaluate.py: -------------------------------------------------------------------------------- 1 | 2 | import pandas as pd 3 | import fire 4 | import torch 5 | import json 6 | import os 7 | from peft import PeftModel 8 | from transformers import GenerationConfig, AutoTokenizer 9 | from transformers import AutoModelForCausalLM 10 | from dataset import D3Dataset 11 | from transformers import LogitsProcessorList, TemperatureLogitsWarper 12 | from transformers import GenerationConfig, LlamaTokenizer 13 | from transformers import LlamaForCausalLM,AutoTokenizer 14 | from LogitProcesser import CFEnhancedLogitsProcessor 15 | if torch.cuda.is_available(): 16 | device = "cuda" 17 | else: 18 | device = "cpu" 19 | 
P = 998244353 20 | MOD = int(1e9 + 9) 21 | import numpy as np 22 | 23 | def get_hash(x): 24 | x = [str(_) for _ in x] 25 | return '-'.join(x) 26 | 27 | 28 | 29 | def main( 30 | base_model: str = "", 31 | train_file: str = "", 32 | info_file: str = "", 33 | category: str = "", 34 | logits_file: str=None, 35 | lora_weights:str = "", 36 | test_data_path: str = "data/test.json", 37 | result_json_data: str = "temp.json", 38 | batch_size: int = 1, 39 | K: int = 0, 40 | seed: int = 0, 41 | temperature: float=1.0, 42 | guidance_scale: float=1.0, 43 | length_penalty: float=1.0 44 | ): 45 | category_dict = {"Office_Products": "office products", "Goodreads": "books", "Steam": "games", "CDs_and_Vinyl": "musics", "Toys_and_Games": "toys and games", "Video_Games": "video games", "Musical_Instruments": "music instruments", "Sports_and_Outdoors": "sports and outdoors", "Pet_Supplies": "pet supplies", "Arts_Crafts_and_Sewing": "arts products", "STEAM": "games" ,"MovieLens":"movies"} 46 | category = category_dict[category] 47 | model = LlamaForCausalLM.from_pretrained( 48 | base_model, 49 | load_in_8bit=False, 50 | torch_dtype=torch.float16, 51 | device_map="auto", 52 | ) 53 | model = PeftModel.from_pretrained( 54 | model, 55 | lora_weights, 56 | torch_dtype=torch.float16, 57 | device_map="auto" 58 | ) 59 | with open(info_file, 'r') as f: 60 | info = f.readlines() 61 | print(info) 62 | info = ["\"" + _.split('\t')[0].strip(' ') + "\"\n" for _ in info] 63 | item_name = info 64 | info = [f'''### Response: 65 | {_}''' for _ in info] 66 | 67 | tokenizer = AutoTokenizer.from_pretrained(base_model) 68 | if base_model.lower().find("llama") > -1: 69 | prefixID = [tokenizer(_).input_ids[1:] for _ in info] 70 | else: 71 | prefixID = [tokenizer(_).input_ids for _ in info] 72 | 73 | hash_dict = dict() 74 | sasrec_dict = dict() 75 | for index, ID in enumerate(prefixID): 76 | ID.append(tokenizer.eos_token_id) 77 | for i in range(4, len(ID)): 78 | if i == 4: 79 | hash_number = get_hash(ID[:i]) 80 
| else: 81 | hash_number = get_hash(ID[4:i]) 82 | if hash_number not in hash_dict: 83 | hash_dict[hash_number] = set() 84 | sasrec_dict[hash_number] = set() 85 | hash_dict[hash_number].add(ID[i]) 86 | sasrec_dict[hash_number].add(index) 87 | hash_number = get_hash(ID[4:]) 88 | if hash_number not in sasrec_dict: 89 | sasrec_dict[hash_number] = set() 90 | sasrec_dict[hash_number].add(index) 91 | 92 | for key in hash_dict.keys(): 93 | hash_dict[key] = list(hash_dict[key]) 94 | for key in sasrec_dict.keys(): 95 | sasrec_dict[key] = list(sasrec_dict[key]) 96 | 97 | def prefix_allowed_tokens_fn(batch_id, input_ids): 98 | hash_number = get_hash(input_ids) 99 | if hash_number in hash_dict: 100 | return hash_dict[hash_number] 101 | return [] 102 | 103 | tokenizer.pad_token = tokenizer.eos_token 104 | tokenizer.pad_token_id = tokenizer.eos_token_id 105 | tokenizer.padding_side = "left" 106 | val_dataset=D3Dataset(train_file=test_data_path, tokenizer=tokenizer,max_len=2560, category=category, test=True,K=K, seed=seed) 107 | 108 | 109 | if logits_file is not None: 110 | if not logits_file.endswith(".npy"): 111 | logits_file = None 112 | 113 | if logits_file is not None: 114 | logits = np.load(logits_file) 115 | sasrec_logits = torch.tensor(logits).softmax(dim = -1) 116 | sasrec_logits = sasrec_logits[val_dataset.data['Unnamed: 0'].tolist()] 117 | 118 | encodings = [val_dataset.__getitem__(i) for i in range(len(val_dataset))] 119 | test_data = val_dataset.get_all() 120 | 121 | model.config.pad_token_id = model.config.eos_token_id = tokenizer.eos_token_id 122 | model.config.bos_token_id = tokenizer.bos_token_id 123 | 124 | model.eval() 125 | 126 | def evaluate( 127 | encodings, 128 | cf_logits, 129 | temperature=1.0, 130 | num_beams=1, 131 | max_new_tokens=32, 132 | top_p=0.9, 133 | top_k=40, 134 | guidance_scale=0.8, 135 | length_penalty=1.0, 136 | **kwargs, 137 | ): 138 | maxLen = max([len(_["input_ids"]) for _ in encodings]) 139 | 140 | padding_encodings = {"input_ids": []} 
141 | 142 | for _ in encodings: 143 | L = len(_["input_ids"]) 144 | padding_encodings["input_ids"].append([tokenizer.pad_token_id] * (maxLen - L) + _["input_ids"]) 145 | 146 | generation_config = GenerationConfig( 147 | num_beams=num_beams, 148 | temperature = temperature, 149 | #length_penalty=length_penalty, 150 | top_p=top_p, 151 | top_k=top_k, 152 | num_return_sequences=num_beams, 153 | pad_token_id = model.config.pad_token_id, 154 | eos_token_id = model.config.eos_token_id, 155 | max_new_tokens = max_new_tokens, 156 | **kwargs 157 | ) 158 | with torch.no_grad(): 159 | ccc = CFEnhancedLogitsProcessor( 160 | guidance_scale=guidance_scale, 161 | cf_logits=cf_logits, 162 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 163 | cf_dict=sasrec_dict, 164 | unconditional_ids=None, 165 | model=model, 166 | tokenizer=tokenizer, 167 | num_beams=num_beams 168 | ) 169 | logits_processor = LogitsProcessorList([TemperatureLogitsWarper(temperature=temperature), ccc]) 170 | # logits 对应上 171 | generation_output = model.generate( 172 | torch.tensor(padding_encodings["input_ids"]).to(device), 173 | generation_config=generation_config, 174 | return_dict_in_generate=True, 175 | output_scores=True, 176 | logits_processor=logits_processor, 177 | ) 178 | s = generation_output.sequences[:, L:] 179 | sequence_scores = [[0 for i in range(len(generation_output.scores))] for _ in range(num_beams)] 180 | #for i in range(num_beams): 181 | #for j in range(L, len(generation_output.sequences[i])): 182 | #if num_beams > 1: 183 | #beam_index = generation_output.beam_indices[i][j - L] 184 | #if beam_index != -1: 185 | #sequence_scores[i][j - L] = generation_output.scores[j - L][beam_index][generation_output.sequences[i][j]].item() 186 | 187 | #scores = generation_output.sequences_scores.tolist() 188 | scores = [1958.0] 189 | output = tokenizer.batch_decode(s, skip_special_tokens=True) 190 | output = [_.split("Response:")[-1] for _ in output] 191 | real_outputs = [output[i * num_beams: (i + 1) * 
num_beams] for i in range(len(output) // num_beams)] 192 | real_scores = [scores[i * num_beams: (i + 1) * num_beams] for i in range(len(scores) // num_beams)] 193 | return real_outputs, real_scores, sequence_scores 194 | 195 | model = model.to(device) 196 | 197 | from tqdm import tqdm 198 | outputs = [] 199 | new_encodings = [] 200 | BLOCK = (len(encodings) + batch_size - 1) // batch_size 201 | for i in range(BLOCK): 202 | new_encodings.append(encodings[i * batch_size: (i + 1) * batch_size]) 203 | Flg=True 204 | scores = [] 205 | seq_scores = [] 206 | import random 207 | for idx, encodings in enumerate(tqdm(new_encodings)): 208 | if logits_file is not None: 209 | output, score, seq_score = evaluate(encodings, sasrec_logits[idx].to(device), temperature=temperature, guidance_scale=guidance_scale, length_penalty=length_penalty) 210 | else: 211 | output, score, seq_score = evaluate(encodings, cf_logits=None, temperature=temperature, guidance_scale=guidance_scale, length_penalty=length_penalty) 212 | if idx == 0: 213 | print(output) 214 | print(score) 215 | outputs = outputs + output 216 | scores = scores+ score 217 | seq_scores.append(seq_score) 218 | 219 | for i, test in enumerate(test_data): 220 | test["predict"] = outputs[i] 221 | #test["predict_score"] = scores[i] 222 | #test["predict_seq_score"] = seq_scores[i] 223 | 224 | for i in range(len(test_data)): 225 | if 'dedup' in test_data[i]: 226 | test_data[i].pop('dedup') 227 | 228 | with open(result_json_data, 'w') as f: 229 | json.dump(test_data, f, indent=4) 230 | 231 | if __name__ == '__main__': 232 | fire.Fire(main) 233 | 234 | 235 | 236 | 237 | -------------------------------------------------------------------------------- /baselines/D3/LogitProcesser.py: -------------------------------------------------------------------------------- 1 | from transformers.generation import LogitsProcessor 2 | from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union 3 | import math 4 | import numpy as np 5 | 
import torch

from transformers.utils import add_start_docstrings

LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
            search or log softmax for each vocabulary token when using beam search

    Return:
        `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.

"""


class PrefixConstrainedLogitsProcessor(LogitsProcessor):
    """Limit each beam's next token to those returned by a prefix callback.

    Tokens not returned by ``prefix_allowed_tokens_fn(batch_id, sent)`` get a
    score of ``-inf``. An empty allow-list raises ``ValueError``.
    """

    def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        mask = torch.full_like(scores, -math.inf)
        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
            for beam_id, sent in enumerate(beam_sent):
                prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
                if len(prefix_allowed_tokens) == 0:
                    raise ValueError(
                        f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}. "
                        f"This means that the constraint is unsatisfiable. Please check your implementation "
                        f"of `prefix_allowed_tokens_fn` "
                    )
                mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0

        scores_processed = scores + mask
        return scores_processed


def get_hash(x):
    """Return the elements of ``x`` joined by ``-`` as a string key (``[1, 2]`` -> ``"1-2"``)."""
    x = [str(_) for _ in x]
    return '-'.join(x)


class CFEnhancedLogitsProcessor(LogitsProcessor):
    """Prefix-constrained decoding with optional collaborative-filtering (CF) guidance.

    On every call the scores are log-softmaxed, and tokens outside the allow-list
    from ``prefix_allowed_tokens_fn`` get ``-1000000`` added. On the first call
    the allow-list is keyed on the last 4 tokens of each beam; after that it is
    keyed on the last ``count`` tokens, i.e. the tokens generated so far.

    When ``cf_logits`` is given and ``guidance_scale != 1``, each allowed token
    gets a CF weight: the sum of ``cf_logits`` over the items listed in
    ``cf_dict[get_hash(prefix + [token])]``, normalized over the allowed tokens.
    The output is then ``guidance_scale * (scores - log cf) + log cf``.

    ``self.count`` advances on every call, so an instance should serve a single
    ``generate()`` call. ``model``, ``tokenizer`` and the unconditional-context
    arguments are stored but not used by ``__call__``.
    """

    def __init__(
        self,
        tokenizer,
        model,
        cf_logits,
        cf_dict,
        guidance_scale: float,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
        num_beams: int,
        unconditional_ids: Optional[torch.LongTensor] = None,
        unconditional_attention_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = True,
    ):
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self.model = model
        self.unconditional_context = {
            "input_ids": unconditional_ids,
            "attention_mask": unconditional_attention_mask,
            "use_cache": use_cache,
            "past_key_values": None,
            "first_pass": True,
        }
        self._num_beams = num_beams
        self.guidance_scale = guidance_scale
        self.tokenizer = tokenizer
        self.cf_logits = cf_logits
        self.cf_dict = cf_dict
        self.count = 0

    def _fill_cf_row(self, row_scores, hash_key, allowed_tokens):
        """Write the normalized CF weight of each allowed token into ``row_scores`` in place.

        Tokens without a ``cf_dict`` entry keep their current value, so the
        scatter ``index`` and ``src`` always have the same length.
        """
        tokens, mass = [], []
        for token in allowed_tokens:
            cf_key = [token] if self.count == 0 else hash_key + [token]
            item_ids = self.cf_dict.get(get_hash(cf_key))
            if item_ids is None:
                continue
            tokens.append(token)
            mass.append(self.cf_logits[item_ids].sum() + 1e-20)  # max or sum
        if not tokens:
            return
        weights = torch.stack(mass)
        weights = weights / weights.sum()
        row_scores.scatter_(
            dim=-1,
            index=torch.tensor(tokens, device=row_scores.device),
            src=weights.to(device=row_scores.device, dtype=row_scores.dtype),
        )

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        mask = torch.full_like(scores, -1000000)
        cf_score = torch.full_like(scores, 1.0)
        # The CF term only changes the result when guidance_scale != 1.
        use_cf = self.cf_logits is not None and self.guidance_scale != 1
        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
            for beam_id, sent in enumerate(beam_sent):
                row = batch_id * self._num_beams + beam_id
                # sent[-0:] would be the whole sequence, hence the explicit first-step case.
                hash_key = (sent[-4:] if self.count == 0 else sent[-self.count:]).tolist()
                prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, hash_key)
                if len(prefix_allowed_tokens) == 0:
                    continue
                mask[row, prefix_allowed_tokens] = 0
                if use_cf:
                    self._fill_cf_row(cf_score[row], hash_key, prefix_allowed_tokens)
        cf_score = torch.log(cf_score)
        cf_score = cf_score + mask
        self.count += 1

        scores = scores + mask
        if self.guidance_scale == 1:
            return scores
        return self.guidance_scale * (scores - cf_score) + cf_score


# (File-dump residue: the start of the next file, kept verbatim as comments.)
# --------------------------------------------------------------------------------
# /baselines/D3/evaluate2.sh:
# --------------------------------------------------------------------------------
# 1 | for category in "Goodreads"
# 2 | do
# 3 | cudalist="7"
# 4 | for i in ${cudalist}
# 5 | do
# 6 | echo $i
# 7 | CUDA_VISIBLE_DEVICES=$i python ./evaluate.py \
# 8 | --base_model ./output_dir/${category}/ \
# 9 | --train_file ${train_file} \
# 10 | --info_file ${info_file} \
# 11 | --category ${category} \
# 12 | --test_data_path ./temp/${category}_base/${i}.csv \
# 13 | --result_json_data ./temp/${category}_base/${i}.json \
# 14 | --length_penalty 0.0 \
# 15 | --logits_file YOUR_LOGITS_FILE_PATH
# 16 | done
# 17 | wait
# 18 | python ./code/merge.py --input_path ./temp/${category}_base --output_path ./output_dir/${category}/final_result.json
# 19 | python ./code/calc.py --path ./output_dir/${category}/final_result.json --item_path ${info_file}
# 20 |
done 21 | -------------------------------------------------------------------------------- /baselines/DMPO/dmpo_trainer.py: -------------------------------------------------------------------------------- 1 | # DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import warnings 16 | from collections import defaultdict 17 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union 18 | import importlib 19 | 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | from datasets import Dataset 25 | from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments 26 | from transformers.trainer_callback import TrainerCallback 27 | 28 | from .utils import DPODataCollatorWithPadding, pad_to_length 29 | 30 | 31 | def is_peft_available(): 32 | return importlib.util.find_spec("peft") is not None 33 | 34 | if is_peft_available(): 35 | from peft import get_peft_model, prepare_model_for_kbit_training 36 | 37 | 38 | class DPOTrainer(Trainer): 39 | r""" 40 | Initialize DPOTrainer. 41 | 42 | Args: 43 | model (`transformers.PreTrainedModel`): 44 | The model to train, preferably an `AutoModelForSequenceClassification`. 
45 | ref_model (`PreTrainedModelWrapper`): 46 | Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. 47 | beta (`float`, defaults to 0.1): 48 | The beta factor in DPO loss. Higher beta means less divergence from the initial policy. 49 | args (`transformers.TrainingArguments`): 50 | The arguments to use for training. 51 | data_collator (`transformers.DataCollator`): 52 | The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used 53 | which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. 54 | label_pad_token_id (`int`, defaults to `-100`): 55 | The label pad token id. This argument is required if you want to use the default data collator. 56 | padding_value (`int`, defaults to `0`): 57 | The padding value. This argument is required if you want to use the default data collator. 58 | truncation_mode (`str`, defaults to `keep_end`): 59 | The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator. 60 | train_dataset (`datasets.Dataset`): 61 | The dataset to use for training. 62 | eval_dataset (`datasets.Dataset`): 63 | The dataset to use for evaluation. 64 | tokenizer (`transformers.PreTrainedTokenizerBase`): 65 | The tokenizer to use for training. This argument is required if you want to use the default data collator. 66 | model_init (`Callable[[], transformers.PreTrainedModel]`): 67 | The model initializer to use for training. If None is specified, the default model initializer will be used. 68 | callbacks (`List[transformers.TrainerCallback]`): 69 | The callbacks to use for training. 70 | optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): 71 | The optimizer and scheduler to use for training. 
72 | preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): 73 | The function to use to preprocess the logits before computing the metrics. 74 | max_length (`int`, defaults to `None`): 75 | The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. 76 | max_prompt_length (`int`, defaults to `None`): 77 | The maximum length of the prompt. This argument is required if you want to use the default data collator. 78 | peft_config (`Dict`, defaults to `None`): 79 | The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. 80 | """ 81 | 82 | def __init__( 83 | self, 84 | model: Union[PreTrainedModel, nn.Module] = None, 85 | ref_model: Union[PreTrainedModel, nn.Module] = None, 86 | beta: float = 0.1, 87 | args: TrainingArguments = None, 88 | data_collator: Optional[DataCollator] = None, 89 | label_pad_token_id: int = -100, 90 | padding_value: int = 0, 91 | truncation_mode: str = "keep_end", 92 | train_dataset: Optional[Dataset] = None, 93 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, 94 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 95 | model_init: Optional[Callable[[], PreTrainedModel]] = None, 96 | callbacks: Optional[List[TrainerCallback]] = None, 97 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( 98 | None, 99 | None, 100 | ), 101 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, 102 | max_length: Optional[int] = None, 103 | max_prompt_length: Optional[int] = None, 104 | peft_config: Optional[Dict] = None, 105 | ): 106 | if not is_peft_available() and peft_config is not None: 107 | raise ValueError( 108 | "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" 109 | ) 110 | elif is_peft_available() and peft_config is not None: 111 
| if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): 112 | model = prepare_model_for_kbit_training(model) 113 | model = get_peft_model(model, peft_config) 114 | 115 | if data_collator is None: 116 | if tokenizer is None: 117 | raise ValueError( 118 | "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding" 119 | ) 120 | if max_length is None: 121 | warnings.warn( 122 | "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init" 123 | " it will be set to `512` by default, but you should do it yourself in the future.", 124 | UserWarning, 125 | ) 126 | max_length = 512 127 | if max_prompt_length is None: 128 | warnings.warn( 129 | "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init" 130 | " it will be set to `128` by default, but you should do it yourself in the future.", 131 | UserWarning, 132 | ) 133 | max_prompt_length = 128 134 | 135 | data_collator = DPODataCollatorWithPadding( 136 | tokenizer, 137 | max_length=max_length, 138 | max_prompt_length=max_prompt_length, 139 | label_pad_token_id=label_pad_token_id, 140 | padding_value=padding_value, 141 | truncation_mode=truncation_mode, 142 | ) 143 | 144 | if args.remove_unused_columns: 145 | args.remove_unused_columns = False 146 | # warn users 147 | warnings.warn( 148 | "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" 149 | " we have set it for you, but you should do it yourself in the future.", 150 | UserWarning, 151 | ) 152 | 153 | self.use_dpo_data_collator = True 154 | else: 155 | self.use_dpo_data_collator = False 156 | 157 | self.label_pad_token_id = label_pad_token_id 158 | self.padding_value = padding_value 159 | 160 | self.beta = beta 161 | self.ref_model = ref_model 162 | 163 | self._stored_metrics = defaultdict(lambda: defaultdict(list)) 164 | 165 | super().__init__( 166 | model, 167 
| args, 168 | data_collator, 169 | train_dataset, 170 | eval_dataset, 171 | tokenizer, 172 | model_init, 173 | None, 174 | callbacks, 175 | optimizers, 176 | preprocess_logits_for_metrics, 177 | ) 178 | 179 | # Since we inherit from trainer we always have access to an accelerator 180 | if hasattr(self, "accelerator"): 181 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) 182 | else: 183 | raise AttributeError( 184 | "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." 185 | ) 186 | 187 | def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: 188 | """Concatenate the chosen and rejected inputs into a single tensor. 189 | 190 | Args: 191 | batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). 192 | 193 | Returns: 194 | A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. 
195 | """ 196 | # 把 chosen 和 rejected response 拼接起来 197 | rejected_max_len = max([batch[key].shape[1] for key in batch if key.startswith("rejected") and key.endswith("_input_ids")]) 198 | max_length = max(batch["chosen_input_ids"].shape[1], rejected_max_len) 199 | concatenated_batch = {} 200 | for k in batch: 201 | if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): 202 | pad_value = self.label_pad_token_id if "labels" in k else self.padding_value 203 | concatenated_key = k.replace("chosen", "concatenated") 204 | concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) 205 | for k in batch: 206 | if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): 207 | pad_value = self.label_pad_token_id if "labels" in k else self.padding_value 208 | # concatenated_key = k.replace("rejected", "concatenated") 209 | prefix = k.split("_")[0] 210 | concatenated_key = "concatenated" + k[len(prefix):] 211 | concatenated_batch[concatenated_key] = torch.cat( 212 | ( 213 | concatenated_batch[concatenated_key], 214 | pad_to_length(batch[k], max_length, pad_value=pad_value), 215 | ), 216 | dim=0, 217 | ).to(self.accelerator.device) 218 | return concatenated_batch 219 | 220 | def dpo_loss( 221 | self, 222 | policy_chosen_logps: torch.FloatTensor, 223 | policy_rejected_logps: Dict[str, torch.FloatTensor], 224 | reference_chosen_logps: torch.FloatTensor, 225 | reference_rejected_logps: Dict[str, torch.FloatTensor], 226 | reference_free: bool = False, 227 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 228 | """Compute the DPO loss for a batch of policy and reference model log probabilities. 229 | 230 | Args: 231 | policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) 232 | policy_rejected_logps: Log probabilities of the policy model for the rejected responses. 
Shape: (batch_size,) 233 | reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) 234 | reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) 235 | beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. 236 | reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. 237 | 238 | Returns: 239 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). 240 | The losses tensor contains the DPO loss for each example in the batch. 241 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. 242 | """ 243 | # pi_logratios = policy_chosen_logps - policy_rejected_logps 244 | # for key in policy_rejected_logps: 245 | # ref_logratios = reference_chosen_logps - reference_rejected_logps 246 | chosen_logratios = policy_chosen_logps - reference_chosen_logps 247 | # print(f"chosen:{chosen_logratios}") 248 | rejected_logratios = {} 249 | for key in policy_rejected_logps: 250 | rejected_logratios[key] = policy_rejected_logps[key] - reference_rejected_logps[key] 251 | # print(f"{key}_logratios:{rejected_logratios[key].shape}") 252 | # if reference_free: 253 | # ref_logratios = 0 254 | 255 | # logits = pi_logratios - ref_logratios 256 | # temp = sum(torch.exp(self.beta * (rejected_logratios[key] - chosen_logratios)) for key in rejected_logratios) 257 | temp = torch.exp(self.beta * sum(rejected_logratios[key] - chosen_logratios for key in rejected_logratios)) 258 | 259 | temp1 = -torch.log(temp) 260 | losses = -F.logsigmoid(temp1) 261 | # losses = -F.logsigmoid(self.beta * logits) 262 | rejected_rewards = {} 263 | chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() 264 | for key in 
policy_rejected_logps: 265 | rejected_rewards[key] = self.beta * (policy_rejected_logps[key] - reference_rejected_logps[key]).detach() 266 | 267 | return losses, chosen_rewards, rejected_rewards 268 | 269 | def _get_batch_logps( 270 | self, 271 | logits: torch.FloatTensor, 272 | labels: torch.LongTensor, 273 | average_log_prob: bool = False, 274 | ) -> torch.FloatTensor: 275 | """Compute the log probabilities of the given labels under the given logits. 276 | 277 | Args: 278 | logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) 279 | labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) 280 | average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. 281 | 282 | Returns: 283 | A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
284 | """ 285 | if logits.shape[:-1] != labels.shape: 286 | raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") 287 | 288 | labels = labels[:, 1:].clone() 289 | logits = logits[:, :-1, :] 290 | loss_mask = labels != self.label_pad_token_id 291 | 292 | # dummy token; we'll ignore the losses on these tokens later 293 | labels[labels == self.label_pad_token_id] = 0 294 | 295 | per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) 296 | 297 | if average_log_prob: 298 | return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) 299 | else: 300 | return (per_token_logps * loss_mask).sum(-1) 301 | 302 | def concatenated_forward( 303 | self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] 304 | ) -> Tuple[torch.FloatTensor, Dict[str, torch.FloatTensor], torch.FloatTensor, Dict[str, torch.FloatTensor]]: 305 | """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. 306 | 307 | We do this to avoid doing two forward passes, because it's faster for FSDP. 
308 | """ 309 | concatenated_batch = self.concatenated_inputs(batch) 310 | # print(concatenated_batch["concatenated_input_ids"].shape) 311 | all_logits = model( 312 | concatenated_batch["concatenated_input_ids"], 313 | attention_mask=concatenated_batch["concatenated_attention_mask"], 314 | ).logits.to(torch.float32) 315 | all_logps = self._get_batch_logps( 316 | all_logits, 317 | concatenated_batch["concatenated_labels"], 318 | average_log_prob=False, 319 | ) 320 | chosen_logps = all_logps[: batch["chosen_input_ids"].shape[0]] 321 | step = batch["chosen_input_ids"].shape[0] 322 | rejected_logps = {} 323 | cnt = 0 324 | for key in batch: 325 | if key.startswith("rejected") and key.endswith("_input_ids"): 326 | cnt += 1 327 | rejected_logps[f"rejected{cnt}"] = all_logps[step*cnt : step*(cnt+1)] 328 | 329 | chosen_logits = all_logits[: batch["chosen_input_ids"].shape[0]] 330 | rejected_logits = {} 331 | cnt = 0 332 | for key in batch: 333 | if key.startswith("rejected") and key.endswith("_input_ids"): 334 | cnt += 1 335 | rejected_logits[f"rejected{cnt}"] = all_logits[step*cnt : step*(cnt+1)] 336 | return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) 337 | 338 | def get_batch_metrics( 339 | self, 340 | model, 341 | batch: Dict[str, Union[List, torch.LongTensor]], 342 | train_eval: Literal["train", "eval"] = "train", 343 | ): 344 | """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" 345 | metrics = {} 346 | 347 | ( 348 | policy_chosen_logps, 349 | policy_rejected_logps, 350 | policy_chosen_logits, 351 | policy_rejected_logits, 352 | ) = self.concatenated_forward(model, batch) 353 | with torch.no_grad(): 354 | ( 355 | reference_chosen_logps, 356 | reference_rejected_logps, 357 | _, 358 | _, 359 | ) = self.concatenated_forward(self.ref_model, batch) 360 | 361 | losses, chosen_rewards, rejected_rewards = self.dpo_loss( 362 | policy_chosen_logps, 363 | policy_rejected_logps, 364 | reference_chosen_logps, 365 | 
reference_rejected_logps, 366 | ) 367 | 368 | # reward_accuracies 记录 chosen 比所有 rejected 的收益都大的比例是多少 369 | reward_accuracies = None 370 | for key in rejected_rewards: 371 | if reward_accuracies is None: 372 | reward_accuracies = (chosen_rewards > rejected_rewards[key]).float() 373 | else: 374 | reward_accuracies *= (chosen_rewards > rejected_rewards[key]).float() 375 | 376 | prefix = "eval_" if train_eval == "eval" else "" 377 | metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().numpy().mean() 378 | for key in rejected_rewards: 379 | metrics[f"{prefix}rewards/{key}"] = rejected_rewards[key].cpu().numpy().mean() 380 | metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().numpy().mean() 381 | for key in rejected_rewards: 382 | metrics[f"{prefix}rewards/margins-{key}"] = (chosen_rewards - rejected_rewards[key]).cpu().numpy().mean() 383 | for key in policy_rejected_logps: 384 | metrics[f"{prefix}logps/rejected-{key}"] = policy_rejected_logps[key].detach().cpu().numpy().mean() 385 | metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean() 386 | for key in policy_rejected_logits: 387 | metrics[f"{prefix}logits/rejected-{key}"] = policy_rejected_logits[key].detach().cpu().numpy().mean() 388 | metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean() 389 | 390 | return losses.mean(), metrics 391 | 392 | def compute_loss( 393 | self, 394 | model: Union[PreTrainedModel, nn.Module], 395 | inputs: Dict[str, Union[torch.Tensor, Any]], 396 | return_outputs=False, 397 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: 398 | # print(inputs.keys()) 399 | # print(inputs) 400 | if not self.use_dpo_data_collator: 401 | warnings.warn( 402 | "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " 403 | "DPODataCollatorWithPadding - you might see unexpected behavior. 
Alternatively, you can implement your own prediction_step method if you are using a custom data collator" 404 | ) 405 | loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train") 406 | 407 | # force log the metrics 408 | if self.accelerator.is_main_process: 409 | self.store_metrics(metrics, train_eval="train") 410 | 411 | if return_outputs: 412 | return (loss, metrics) 413 | return loss 414 | 415 | def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: 416 | """Generate samples from the model and reference model for the given batch of inputs.""" 417 | 418 | policy_output = model.generate( 419 | batch["prompt_input_ids"], 420 | attention_mask=batch["prompt_attention_mask"], 421 | max_length=self.config.max_length, 422 | do_sample=True, 423 | pad_token_id=self.tokenizer.pad_token_id, 424 | ) 425 | 426 | reference_output = self.ref_model.generate( 427 | batch["prompt_input_ids"], 428 | attention_mask=batch["prompt_attention_mask"], 429 | max_length=self.config.max_length, 430 | do_sample=True, 431 | pad_token_id=self.tokenizer.pad_token_id, 432 | ) 433 | 434 | policy_output = pad_to_length(policy_output, self.config.max_length, self.tokenizer.pad_token_id) 435 | policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) 436 | 437 | reference_output = pad_to_length(reference_output, self.config.max_length, self.tokenizer.pad_token_id) 438 | reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True) 439 | 440 | return policy_output_decoded, reference_output_decoded 441 | 442 | def prediction_step( 443 | self, 444 | model: Union[PreTrainedModel, nn.Module], 445 | inputs: Dict[str, Union[torch.Tensor, Any]], 446 | prediction_loss_only: bool, 447 | ignore_keys: Optional[List[str]] = None, 448 | ): 449 | if not self.use_dpo_data_collator: 450 | warnings.warn( 451 | "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a 
datacollator that is different than " 452 | "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" 453 | ) 454 | if ignore_keys is None: 455 | if hasattr(model, "config"): 456 | ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) 457 | else: 458 | ignore_keys = [] 459 | 460 | with torch.no_grad(): 461 | loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval") 462 | 463 | # force log the metrics 464 | if self.accelerator.is_main_process: 465 | self.store_metrics(metrics, train_eval="eval") 466 | 467 | if prediction_loss_only: 468 | return (loss.detach(), None, None) 469 | 470 | # logits for the chosen and rejected samples from model 471 | logits_dict = { 472 | "logits_test/chosen": metrics["logits_test/chosen"], 473 | # "logits_test/rejected": metrics["logits_test/rejected"], 474 | } 475 | logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys) 476 | logits = torch.stack(logits).mean(axis=1) 477 | labels = torch.zeros(logits.shape[0]) 478 | 479 | return (loss.detach(), logits, labels) 480 | 481 | def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: 482 | for key, value in metrics.items(): 483 | self._stored_metrics[train_eval][key].append(value) 484 | 485 | def log(self, logs: Dict[str, float]) -> None: 486 | """ 487 | Log `logs` on the various objects watching training, including stored metrics. 488 | 489 | Args: 490 | logs (`Dict[str, float]`): 491 | The values to log. 
492 | """ 493 | # logs either has 'loss' or 'eval_loss' 494 | train_eval = "train" if "loss" in logs else "eval" 495 | # Add averaged stored metrics to logs 496 | for key, metrics in self._stored_metrics[train_eval].items(): 497 | logs[key] = torch.tensor(metrics).mean().item() 498 | del self._stored_metrics[train_eval] 499 | return super().log(logs) 500 | 501 | -------------------------------------------------------------------------------- /baselines/DMPO/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import warnings 4 | from dataclasses import dataclass 5 | from typing import Any, Dict, List, Optional, Union 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn.utils.rnn import pad_sequence 10 | from torch.utils.data import IterableDataset 11 | from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerBase, TrainerCallback 12 | 13 | @dataclass 14 | class DPODataCollatorWithPadding: 15 | r""" 16 | DPO DataCollator class that pads the inputs to the maximum length of the batch. 17 | Args: 18 | tokenizer (`PreTrainedTokenizerBase`): 19 | The tokenizer used for encoding the data. 20 | padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`): 21 | padding_strategy to pass to the tokenizer. 22 | max_length (`Optional[int]`, `optional`, defaults to `None`): 23 | The maximum length of the sequence to be processed. 24 | max_prompt_length (`Optional[int]`, `optional`, defaults to `None`): 25 | The maximum length of the prompt to be processed. 26 | label_pad_token_id (`int`, defaults to -100): 27 | The label used for masking. 28 | padding_value (`int`, defaults to 0): 29 | The value used for padding. 30 | truncation_mode: (`str`, defaults to "keep_end"): 31 | The truncation mode to use when truncating the prompt + chosen/rejected responses. 
32 | """ 33 | tokenizer: PreTrainedTokenizerBase 34 | padding: Union[bool, str] = True 35 | max_length: Optional[int] = None 36 | max_prompt_length: Optional[int] = None 37 | label_pad_token_id: int = -100 38 | padding_value: int = 0 39 | truncation_mode: str = "keep_end" 40 | 41 | def tokenize_batch_element( 42 | self, 43 | prompt: str, 44 | chosen: str, 45 | rejected: Dict[str, str], 46 | ) -> Dict: 47 | """Tokenize a single batch element. 48 | 49 | At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation 50 | in case the prompt + chosen or prompt + rejected responses is/are too long. First 51 | we truncate the prompt; if we're still too long, we truncate the chosen/rejected. 52 | 53 | We also create the labels for the chosen/rejected responses, which are of length equal to 54 | the sum of the length of the prompt and the chosen/rejected response, with 55 | label_pad_token_id for the prompt tokens. 56 | """ 57 | chosen_tokens = self.tokenizer(chosen, add_special_tokens=False) 58 | prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) 59 | rejected_tokens = {} 60 | for key in rejected: 61 | rejected_tokens[key] = self.tokenizer(rejected[key], add_special_tokens=False) 62 | 63 | assert self.tokenizer.eos_token_id not in prompt_tokens["input_ids"], f"Prompt contains EOS token: {prompt}" 64 | assert ( 65 | self.tokenizer.eos_token_id not in chosen_tokens["input_ids"] 66 | ), f"Chosen response contains EOS token: {chosen}" 67 | assert ( 68 | all([self.tokenizer.eos_token_id not in rejected_tokens[key]["input_ids"] for key in rejected_tokens]) 69 | ), f"Rejected response contains EOS token: {rejected}" 70 | 71 | chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) 72 | chosen_tokens["attention_mask"].append(1) 73 | for key in rejected_tokens: 74 | rejected_tokens[key]["input_ids"].append(self.tokenizer.eos_token_id) 75 | rejected_tokens[key]["attention_mask"].append(1) 76 | max_rejected_len = 
max([len(rejected_tokens[key]["input_ids"]) for key in rejected_tokens]) 77 | longer_response_length = max(len(chosen_tokens["input_ids"]), max_rejected_len) 78 | 79 | # if combined sequence is too long, truncate the prompt 80 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: 81 | if self.truncation_mode == "keep_start": 82 | prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()} 83 | elif self.truncation_mode == "keep_end": 84 | prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()} 85 | else: 86 | raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") 87 | 88 | # if that's still too long, truncate the response 89 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: 90 | chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()} 91 | rejected_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items()} 92 | 93 | # Create labels 94 | chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens} 95 | rejected_sequence_tokens = {} 96 | # rejected_tokens: Dict[str, Dict] 97 | for key in rejected_tokens: 98 | rejected_sequence_tokens[key] = {k: prompt_tokens[k] + rejected_tokens[key][k] for k in rejected_tokens[key]} 99 | chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] 100 | chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( 101 | prompt_tokens["input_ids"] 102 | ) 103 | for key in rejected_sequence_tokens: 104 | rejected_sequence_tokens[key]["labels"] = rejected_sequence_tokens[key]["input_ids"][:] 105 | rejected_sequence_tokens[key]["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( 106 | prompt_tokens["input_ids"] 107 | ) 108 | 109 | batch = {} 110 | 111 | batch["prompt"] = prompt 112 | batch["chosen"] = prompt + chosen 113 | for 
key in rejected: 114 | batch[key] = prompt + rejected[key] 115 | batch["chosen_response_only"] = chosen 116 | for key in rejected: 117 | batch[f"{key}_response_only"] = rejected[key] 118 | 119 | for k, toks in { 120 | "chosen": chosen_sequence_tokens, 121 | # "rejected": rejected_sequence_tokens, 122 | "prompt": prompt_tokens, 123 | }.items(): 124 | for type_key, tokens in toks.items(): 125 | if type_key == "token_type_ids": 126 | continue 127 | batch[f"{k}_{type_key}"] = tokens 128 | # rejected_sequence_tokens: Dict[str, Dict] 129 | for k, toks in rejected_sequence_tokens.items(): 130 | for type_key, tokens in toks.items(): 131 | if type_key == "token_type_ids": 132 | continue 133 | batch[f"{k}_{type_key}"] = tokens 134 | 135 | return batch 136 | 137 | def collate(self, batch): 138 | # first, pad everything to the same length 139 | padded_batch = {} 140 | for k in batch[0].keys(): 141 | if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): 142 | # adapted from https://stackoverflow.com/questions/73256206 143 | if "prompt" in k: 144 | to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] 145 | else: 146 | to_pad = [torch.LongTensor(ex[k]) for ex in batch] 147 | if k.endswith("_input_ids"): 148 | padding_value = self.tokenizer.pad_token_id 149 | elif k.endswith("_labels"): 150 | padding_value = self.label_pad_token_id 151 | elif k.endswith("_attention_mask"): 152 | padding_value = self.padding_value 153 | else: 154 | raise ValueError(f"Unexpected key in batch '{k}'") 155 | 156 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) 157 | # for the prompt, flip back so padding is on left side 158 | if "prompt" in k: 159 | padded_batch[k] = padded_batch[k].flip(dims=[1]) 160 | else: 161 | padded_batch[k] = [ex[k] for ex in batch] 162 | 163 | return padded_batch 164 | 165 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 166 | tokenized_batch = [] 167 | 168 | for feature in 
features: 169 | prompt = feature["prompt"] 170 | chosen = feature["chosen"] 171 | rejected = {} 172 | for key in feature: 173 | if key.startswith("rejected"): 174 | rejected[key] = feature[key] 175 | 176 | batch_element = self.tokenize_batch_element(prompt, chosen, rejected) 177 | tokenized_batch.append(batch_element) 178 | 179 | # return collated batch 180 | return self.collate(tokenized_batch) 181 | 182 | def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: 183 | if tensor.size(dim) >= length: 184 | return tensor 185 | else: 186 | pad_size = list(tensor.shape) 187 | pad_size[dim] = length - tensor.size(dim) 188 | return torch.cat( 189 | [ 190 | tensor, 191 | pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), 192 | ], 193 | dim=dim, 194 | ) -------------------------------------------------------------------------------- /baselines/Re-weighting/RW_SFT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import warnings 4 | import re 5 | import wandb 6 | from typing import List, Optional 7 | import datasets 8 | from tqdm import tqdm 9 | from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments,BitsAndBytesConfig 10 | from datasets import load_dataset 11 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig 12 | from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType, PeftModel 13 | from transformers import LlamaForCausalLM, LlamaTokenizer 14 | # from utils import find_all_linear_names, print_trainable_parameters 15 | import pandas as pd 16 | from accelerate import Accelerator 17 | import numpy as np 18 | import torch 19 | import bitsandbytes as bnb 20 | import fire 21 | import json 22 | 23 | def read_json(json_file:str) -> dict: 24 | f = open(json_file, 'r') 25 | return json.load(f) 26 | 27 | def 
gh_tr(category:str,test_data,name2genre:dict,genre_dict:dict): 28 | for data in tqdm(test_data,desc="Processing category data......"): 29 | input = data['input'] 30 | names = re.findall(r'"([^"]+)"', input) 31 | for name in names: 32 | if name in name2genre: 33 | genres = name2genre[name] 34 | else: 35 | continue 36 | for genre in genres: 37 | if genre in genre_dict: 38 | genre_dict[genre] += 1/len(genres) 39 | gh = [genre_dict[x] for x in genre_dict] 40 | gh_normalize = [x/sum(gh) for x in gh] 41 | return gh_normalize 42 | 43 | def gh_ta(category:str,test_data,name2genre:dict,genre_dict:dict): 44 | for data in tqdm(test_data,desc="Processing category data......"): 45 | input = data['output'] 46 | names = re.findall(r'"([^"]+)"', input) 47 | for name in names: 48 | if name in name2genre: 49 | genres = name2genre[name] 50 | else: 51 | # print(f"Not exist in name2genre:{name}") 52 | continue 53 | for genre in genres: 54 | if genre in genre_dict: 55 | genre_dict[genre] += 1/len(genres) 56 | gh = [genre_dict[x] for x in genre_dict] 57 | gh_normalize = [x/sum(gh) for x in gh] 58 | return gh_normalize 59 | 60 | def weight_dict(category:str,test_data,name2genre:dict,genre_dict:dict): 61 | GH_tr = gh_tr(category,test_data,name2genre,genre_dict) 62 | GH_ta = gh_ta(category,test_data,name2genre,genre_dict) 63 | weight_dict = {} 64 | idx = 0 65 | for category in genre_dict: 66 | weight_dict[category] = GH_tr[idx] / GH_ta[idx] 67 | idx += 1 68 | 69 | return weight_dict 70 | 71 | def cal_weight(category:str,test_data,name2genre:dict,genre_dict:dict): 72 | weights = [] 73 | w_dict = weight_dict(category,test_data,name2genre,genre_dict) 74 | print(f"Length of data:{len(test_data)}") 75 | for data in tqdm(test_data,desc="Processing category data......"): 76 | weight = [] 77 | target_item = data['output'].strip("\n").strip("\"") 78 | if target_item in name2genre : 79 | genres = name2genre[target_item] 80 | for genre in genres: 81 | if genre in genre_dict: 82 | 
weight.append(w_dict[genre]) 83 | if len(weight)>0: 84 | weight = sum(weight) / len(weight) 85 | weights.append(weight) 86 | else: 87 | weights.append(1) 88 | else: 89 | weights.append(1) 90 | print(f"Length of weights:{len(weights)}") 91 | return weights 92 | 93 | class IFTrainer(SFTTrainer): 94 | def compute_loss(self, model, inputs, return_outputs=False): 95 | weights = inputs.pop("weight") 96 | labels = inputs.pop("labels") 97 | outputs = model(**inputs) 98 | 99 | if self.args.past_index >= 0: 100 | self._past = outputs[self.args.past_index] 101 | 102 | logits = outputs.get("logits") 103 | 104 | shift_logits = logits[..., :-1, :].contiguous() 105 | shift_labels = labels[..., 1:].contiguous() 106 | 107 | # Flatten the tokens 108 | loss_fct = torch.nn.CrossEntropyLoss(reduction="none") 109 | shift_logits = shift_logits.view(-1, self.model.config.vocab_size) 110 | shift_labels = shift_labels.view(-1) 111 | # Enable model parallelism 112 | shift_labels = shift_labels.to(shift_logits.device) 113 | 114 | loss = torch.mean(weights * torch.mean(loss_fct(shift_logits, shift_labels).view(weights.shape[0], -1))) 115 | 116 | 117 | return (loss, outputs) if return_outputs else loss 118 | 119 | def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): 120 | if not self.args.remove_unused_columns: 121 | return dataset 122 | self._set_signature_columns_if_needed() 123 | signature_columns = self._signature_columns 124 | signature_columns.append("weight") 125 | ignored_columns = list(set(dataset.column_names) - set(signature_columns)) 126 | if len(ignored_columns) > 0: 127 | dset_description = "" if description is None else f"in the {description} set" 128 | 129 | columns = [k for k in signature_columns if k in dataset.column_names] 130 | x = dataset.remove_columns(ignored_columns) 131 | return x 132 | 133 | def _prepare_non_packed_dataloader( 134 | self, 135 | tokenizer, 136 | dataset, 137 | dataset_text_field, 138 | max_seq_length, 139 | 
formatting_func=None, 140 | add_special_tokens=True, 141 | remove_unused_columns=True, 142 | ): 143 | use_formatting_func = formatting_func is not None and dataset_text_field is None 144 | self._dataset_sanity_checked = False 145 | 146 | # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt 147 | def tokenize(element): 148 | outputs = tokenizer( 149 | element[dataset_text_field] if not use_formatting_func else formatting_func(element), 150 | add_special_tokens=add_special_tokens, 151 | truncation=True, 152 | padding=False, 153 | max_length=max_seq_length, 154 | return_overflowing_tokens=False, 155 | return_length=False, 156 | ) 157 | 158 | if use_formatting_func and not self._dataset_sanity_checked: 159 | if not isinstance(formatting_func(element), list): 160 | raise ValueError( 161 | "The `formatting_func` should return a list of processed strings since it can lead to silent bugs." 162 | ) 163 | else: 164 | self._dataset_sanity_checked = True 165 | 166 | return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"], "weight": element['weight']} 167 | 168 | signature_columns = ["input_ids", "labels", "attention_mask","weight"] 169 | 170 | extra_columns = list(set(dataset.column_names) - set(signature_columns)) 171 | 172 | if not remove_unused_columns and len(extra_columns) > 0: 173 | warnings.warn( 174 | "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with the default collator and yield to errors. If you want to " 175 | f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns." 
176 | ) 177 | 178 | tokenized_dataset = dataset.map( 179 | tokenize, 180 | batched=True, 181 | remove_columns=dataset.column_names if remove_unused_columns else None, 182 | num_proc=self.dataset_num_proc, 183 | batch_size=self.dataset_batch_size, 184 | ) 185 | 186 | return tokenized_dataset 187 | 188 | from transformers import DataCollatorWithPadding 189 | import torch 190 | 191 | def train( 192 | # path 193 | output_dir="", 194 | base_model ="", 195 | train_dataset="", 196 | valid_dataset="", 197 | train_sample_size:int = 1024, 198 | resume_from_checkpoint: str = "base_model", # either training checkpoint or final adapter 199 | # wandb config 200 | wandb_project: str = "", 201 | wandb_name: str = "", # the name of the wandb run 202 | # training hyperparameters 203 | gradient_accumulation_steps: int = 1, 204 | batch_size: int = 8, 205 | num_train_epochs: int = 5, 206 | learning_rate: float = 2e-5, 207 | cutoff_len: int = 512, 208 | eval_step = 0.05, 209 | category: str = "CDs_and_Vinyl", 210 | seed = 0 211 | ): 212 | os.environ['WANDB_PROJECT'] = wandb_project 213 | 214 | def formatting_prompts_func(examples): 215 | output_text = [] 216 | for i in range(len(examples["instruction"])): 217 | instruction = examples["instruction"][i] 218 | input_text = examples["input"][i] 219 | response = examples["output"][i] 220 | 221 | if len(input_text) >= 2: 222 | text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 223 | 224 | ### Instruction: 225 | {instruction} 226 | 227 | ### Input: 228 | {input_text} 229 | 230 | ### Response: 231 | {response} 232 | ''' 233 | else: 234 | text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 
235 | 236 | ### Instruction: 237 | {instruction} 238 | 239 | ### Response: 240 | {response} 241 | ''' 242 | output_text.append(text) 243 | 244 | return output_text 245 | 246 | def get_train_weight(data, category:str,test_data,name2genre:dict,genre_dict:dict,w_dict): 247 | weight = [] 248 | target_item = data['output'].strip("\n").strip("\"") 249 | if target_item in name2genre : 250 | genres = name2genre[target_item] 251 | for genre in genres: 252 | if genre in genre_dict: 253 | weight.append(w_dict[genre]) 254 | if len(weight)>0: 255 | weight = sum(weight) / len(weight) 256 | return {"weight":weight} 257 | else: 258 | return {"weight":1} 259 | else: 260 | return {"weight":1} 261 | 262 | 263 | name2genre = read_json(f"./eval/{category}/name2genre.json") 264 | genre_dict = read_json(f"./eval/{category}/genre_dict.json") 265 | w_dict = weight_dict(category,test_data=read_json(train_dataset),name2genre=name2genre,genre_dict=genre_dict) 266 | train_weights = cal_weight(category,test_data=read_json(train_dataset),name2genre=name2genre,genre_dict=genre_dict) 267 | 268 | val_sample_size = int(train_sample_size / 8) 269 | dataset = load_dataset('json', data_files=train_dataset) 270 | dataset = {"train": dataset['train'].select(range(train_sample_size+val_sample_size))} 271 | #weights = get_train_weight(dataset['train'],train_weights) 272 | dataset['train'] = dataset['train'].map(lambda x: get_train_weight(x, category,test_data=read_json(train_dataset),name2genre=name2genre,genre_dict=genre_dict,w_dict=w_dict)) 273 | print("Features:{}".format(dataset["train"].features)) 274 | train_val_split = dataset['train'].train_test_split(train_size=train_sample_size, test_size=val_sample_size) 275 | train_data = train_val_split['train'] 276 | print("Features:{}".format(train_data.features)) 277 | val_data = train_val_split['test'] 278 | 279 | 280 | bnb_config = BitsAndBytesConfig( 281 | # load_in_8bit=True, 282 | load_in_4bit=True, 283 | bnb_4bit_quant_type="nf4", 284 | 
bnb_4bit_compute_dtype=torch.bfloat16, 285 | bnb_4bit_use_double_quant=False, 286 | ) 287 | 288 | device_index = Accelerator().process_index 289 | device_map = {"": device_index} 290 | 291 | model = LlamaForCausalLM.from_pretrained(base_model, device_map=device_map, \ 292 | quantization_config=bnb_config) 293 | model.config.use_cache = False 294 | model = prepare_model_for_kbit_training(model) 295 | 296 | if 'Llama-3' in base_model: 297 | tokenizer = AutoTokenizer.from_pretrained(base_model) 298 | else: 299 | tokenizer = LlamaTokenizer.from_pretrained(base_model) 300 | # tokenizer.pad_token = tokenizer.eos_token 301 | # tokenizer.padding_side = "right" 302 | tokenizer.pad_token_id = (0) 303 | tokenizer.padding_side = "left" # Fix weird overflow issue with fp16 training 304 | 305 | if resume_from_checkpoint!="base_model": 306 | model = PeftModel.from_pretrained(model, resume_from_checkpoint, 307 | is_trainable=True) 308 | else: 309 | peft_config = LoraConfig( 310 | inference_mode=False, 311 | r=64, 312 | lora_alpha=32, 313 | target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'], 314 | lora_dropout=0.05, 315 | # bias="none", 316 | task_type="CAUSAL_LM", 317 | ) 318 | model = get_peft_model(model, peft_config) 319 | 320 | model.print_trainable_parameters() 321 | 322 | training_args = SFTConfig( 323 | per_device_train_batch_size=batch_size, 324 | gradient_accumulation_steps=gradient_accumulation_steps, 325 | gradient_checkpointing =True, 326 | max_grad_norm= 0.3, 327 | num_train_epochs=num_train_epochs, 328 | learning_rate=learning_rate, 329 | bf16=True, 330 | save_strategy="steps", 331 | save_steps=eval_step, 332 | save_total_limit=100, 333 | load_best_model_at_end=True, 334 | evaluation_strategy="steps", 335 | eval_steps=eval_step, 336 | logging_steps=1, 337 | output_dir=output_dir, 338 | optim="paged_adamw_32bit", 339 | remove_unused_columns= True, 340 | lr_scheduler_type="cosine", 341 | warmup_ratio=0.05, 342 | 
report_to="wandb", 343 | run_name=wandb_name, 344 | gradient_checkpointing_kwargs={'use_reentrant': True}, 345 | save_only_model=True, 346 | ddp_find_unused_parameters=False, # should set to False becuase there are no unused parameters in the forward process 347 | ) 348 | trainer = IFTrainer( 349 | model, 350 | train_dataset=train_data, 351 | eval_dataset=val_data, 352 | tokenizer=tokenizer, 353 | formatting_func=formatting_prompts_func, 354 | max_seq_length=cutoff_len, 355 | args=training_args 356 | #data_collator=data_callator 357 | ) 358 | 359 | trainer.train() 360 | trainer.save_model(output_dir) 361 | 362 | output_dir = os.path.join(output_dir, "final_checkpoint") 363 | trainer.model.save_pretrained(output_dir) 364 | tokenizer.save_pretrained(output_dir) 365 | 366 | if __name__ == "__main__": 367 | fire.Fire(train) -------------------------------------------------------------------------------- /baselines/SDPO/softmax_dpo_trainer.py: -------------------------------------------------------------------------------- 1 | # DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | import warnings 16 | from collections import defaultdict 17 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union 18 | import importlib 19 | 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | from datasets import Dataset 25 | from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments 26 | from transformers.trainer_callback import TrainerCallback 27 | 28 | from .utils import DPODataCollatorWithPadding, pad_to_length 29 | 30 | 31 | def is_peft_available(): 32 | return importlib.util.find_spec("peft") is not None 33 | 34 | if is_peft_available(): 35 | from peft import get_peft_model, prepare_model_for_kbit_training 36 | 37 | 38 | class DPOTrainer(Trainer): 39 | r""" 40 | Initialize DPOTrainer. 41 | 42 | Args: 43 | model (`transformers.PreTrainedModel`): 44 | The model to train, preferably an `AutoModelForSequenceClassification`. 45 | ref_model (`PreTrainedModelWrapper`): 46 | Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. 47 | beta (`float`, defaults to 0.1): 48 | The beta factor in DPO loss. Higher beta means less divergence from the initial policy. 49 | args (`transformers.TrainingArguments`): 50 | The arguments to use for training. 51 | data_collator (`transformers.DataCollator`): 52 | The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used 53 | which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. 54 | label_pad_token_id (`int`, defaults to `-100`): 55 | The label pad token id. This argument is required if you want to use the default data collator. 56 | padding_value (`int`, defaults to `0`): 57 | The padding value. This argument is required if you want to use the default data collator. 
58 | truncation_mode (`str`, defaults to `keep_end`): 59 | The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator. 60 | train_dataset (`datasets.Dataset`): 61 | The dataset to use for training. 62 | eval_dataset (`datasets.Dataset`): 63 | The dataset to use for evaluation. 64 | tokenizer (`transformers.PreTrainedTokenizerBase`): 65 | The tokenizer to use for training. This argument is required if you want to use the default data collator. 66 | model_init (`Callable[[], transformers.PreTrainedModel]`): 67 | The model initializer to use for training. If None is specified, the default model initializer will be used. 68 | callbacks (`List[transformers.TrainerCallback]`): 69 | The callbacks to use for training. 70 | optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): 71 | The optimizer and scheduler to use for training. 72 | preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): 73 | The function to use to preprocess the logits before computing the metrics. 74 | max_length (`int`, defaults to `None`): 75 | The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. 76 | max_prompt_length (`int`, defaults to `None`): 77 | The maximum length of the prompt. This argument is required if you want to use the default data collator. 78 | peft_config (`Dict`, defaults to `None`): 79 | The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. 
80 | """ 81 | 82 | def __init__( 83 | self, 84 | model: Union[PreTrainedModel, nn.Module] = None, 85 | ref_model: Union[PreTrainedModel, nn.Module] = None, 86 | beta: float = 0.1, 87 | args: TrainingArguments = None, 88 | data_collator: Optional[DataCollator] = None, 89 | label_pad_token_id: int = -100, 90 | padding_value: int = 0, 91 | truncation_mode: str = "keep_end", 92 | train_dataset: Optional[Dataset] = None, 93 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, 94 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 95 | model_init: Optional[Callable[[], PreTrainedModel]] = None, 96 | callbacks: Optional[List[TrainerCallback]] = None, 97 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( 98 | None, 99 | None, 100 | ), 101 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, 102 | max_length: Optional[int] = None, 103 | max_prompt_length: Optional[int] = None, 104 | peft_config: Optional[Dict] = None, 105 | ): 106 | if not is_peft_available() and peft_config is not None: 107 | raise ValueError( 108 | "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" 109 | ) 110 | elif is_peft_available() and peft_config is not None: 111 | if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): 112 | model = prepare_model_for_kbit_training(model) 113 | model = get_peft_model(model, peft_config) 114 | 115 | if data_collator is None: 116 | if tokenizer is None: 117 | raise ValueError( 118 | "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding" 119 | ) 120 | if max_length is None: 121 | warnings.warn( 122 | "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init" 123 | " it will be set to `512` by default, but you should do it yourself in the future.", 124 | UserWarning, 125 | ) 126 | max_length 
= 512 127 | if max_prompt_length is None: 128 | warnings.warn( 129 | "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init" 130 | " it will be set to `128` by default, but you should do it yourself in the future.", 131 | UserWarning, 132 | ) 133 | max_prompt_length = 128 134 | 135 | data_collator = DPODataCollatorWithPadding( 136 | tokenizer, 137 | max_length=max_length, 138 | max_prompt_length=max_prompt_length, 139 | label_pad_token_id=label_pad_token_id, 140 | padding_value=padding_value, 141 | truncation_mode=truncation_mode, 142 | ) 143 | 144 | if args.remove_unused_columns: 145 | args.remove_unused_columns = False 146 | # warn users 147 | warnings.warn( 148 | "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" 149 | " we have set it for you, but you should do it yourself in the future.", 150 | UserWarning, 151 | ) 152 | 153 | self.use_dpo_data_collator = True 154 | else: 155 | self.use_dpo_data_collator = False 156 | 157 | self.label_pad_token_id = label_pad_token_id 158 | self.padding_value = padding_value 159 | 160 | self.beta = beta 161 | self.ref_model = ref_model 162 | 163 | self._stored_metrics = defaultdict(lambda: defaultdict(list)) 164 | 165 | super().__init__( 166 | model, 167 | args, 168 | data_collator, 169 | train_dataset, 170 | eval_dataset, 171 | tokenizer, 172 | model_init, 173 | None, 174 | callbacks, 175 | optimizers, 176 | preprocess_logits_for_metrics, 177 | ) 178 | 179 | # Since we inherit from trainer we always have access to an accelerator 180 | if hasattr(self, "accelerator"): 181 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) 182 | else: 183 | raise AttributeError( 184 | "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." 
185 | ) 186 | 187 | def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: 188 | """Concatenate the chosen and rejected inputs into a single tensor. 189 | 190 | Args: 191 | batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). 192 | 193 | Returns: 194 | A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. 195 | """ 196 | # 把 chosen 和 rejected response 拼接起来 197 | rejected_max_len = max([batch[key].shape[1] for key in batch if key.startswith("rejected") and key.endswith("_input_ids")]) 198 | max_length = max(batch["chosen_input_ids"].shape[1], rejected_max_len) 199 | concatenated_batch = {} 200 | for k in batch: 201 | if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): 202 | pad_value = self.label_pad_token_id if "labels" in k else self.padding_value 203 | concatenated_key = k.replace("chosen", "concatenated") 204 | concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) 205 | for k in batch: 206 | if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): 207 | pad_value = self.label_pad_token_id if "labels" in k else self.padding_value 208 | # concatenated_key = k.replace("rejected", "concatenated") 209 | prefix = k.split("_")[0] 210 | concatenated_key = "concatenated" + k[len(prefix):] 211 | concatenated_batch[concatenated_key] = torch.cat( 212 | ( 213 | concatenated_batch[concatenated_key], 214 | pad_to_length(batch[k], max_length, pad_value=pad_value), 215 | ), 216 | dim=0, 217 | ).to(self.accelerator.device) 218 | return concatenated_batch 219 | 220 | def dpo_loss( 221 | self, 222 | policy_chosen_logps: torch.FloatTensor, 223 | policy_rejected_logps: Dict[str, torch.FloatTensor], 224 | reference_chosen_logps: torch.FloatTensor, 225 | reference_rejected_logps: Dict[str, torch.FloatTensor], 226 | reference_free: bool = 
False, 227 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 228 | """Compute the DPO loss for a batch of policy and reference model log probabilities. 229 | 230 | Args: 231 | policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) 232 | policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) 233 | reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) 234 | reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) 235 | beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. 236 | reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. 237 | 238 | Returns: 239 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). 240 | The losses tensor contains the DPO loss for each example in the batch. 241 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. 
242 | """ 243 | # pi_logratios = policy_chosen_logps - policy_rejected_logps 244 | # for key in policy_rejected_logps: 245 | # ref_logratios = reference_chosen_logps - reference_rejected_logps 246 | chosen_logratios = policy_chosen_logps - reference_chosen_logps 247 | # print(f"chosen:{chosen_logratios}") 248 | rejected_logratios = {} 249 | for key in policy_rejected_logps: 250 | rejected_logratios[key] = policy_rejected_logps[key] - reference_rejected_logps[key] 251 | # print(f"{key}_logratios:{rejected_logratios[key].shape}") 252 | # if reference_free: 253 | # ref_logratios = 0 254 | 255 | # logits = pi_logratios - ref_logratios 256 | temp = sum(torch.exp(self.beta * (rejected_logratios[key] - chosen_logratios)) for key in rejected_logratios) 257 | temp1 = -torch.log(temp) 258 | losses = -F.logsigmoid(temp1) 259 | # losses = -F.logsigmoid(self.beta * logits) 260 | rejected_rewards = {} 261 | chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() 262 | for key in policy_rejected_logps: 263 | rejected_rewards[key] = self.beta * (policy_rejected_logps[key] - reference_rejected_logps[key]).detach() 264 | 265 | return losses, chosen_rewards, rejected_rewards 266 | 267 | def _get_batch_logps( 268 | self, 269 | logits: torch.FloatTensor, 270 | labels: torch.LongTensor, 271 | average_log_prob: bool = False, 272 | ) -> torch.FloatTensor: 273 | """Compute the log probabilities of the given labels under the given logits. 274 | 275 | Args: 276 | logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) 277 | labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) 278 | average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. 
279 | 280 | Returns: 281 | A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 282 | """ 283 | if logits.shape[:-1] != labels.shape: 284 | raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") 285 | 286 | labels = labels[:, 1:].clone() 287 | logits = logits[:, :-1, :] 288 | loss_mask = labels != self.label_pad_token_id 289 | 290 | # dummy token; we'll ignore the losses on these tokens later 291 | labels[labels == self.label_pad_token_id] = 0 292 | 293 | per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) 294 | 295 | if average_log_prob: 296 | return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) 297 | else: 298 | return (per_token_logps * loss_mask).sum(-1) 299 | 300 | def concatenated_forward( 301 | self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] 302 | ) -> Tuple[torch.FloatTensor, Dict[str, torch.FloatTensor], torch.FloatTensor, Dict[str, torch.FloatTensor]]: 303 | """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. 304 | 305 | We do this to avoid doing two forward passes, because it's faster for FSDP. 
306 | """ 307 | concatenated_batch = self.concatenated_inputs(batch) 308 | # print(concatenated_batch["concatenated_input_ids"].shape) 309 | all_logits = model( 310 | concatenated_batch["concatenated_input_ids"], 311 | attention_mask=concatenated_batch["concatenated_attention_mask"], 312 | ).logits.to(torch.float32) 313 | all_logps = self._get_batch_logps( 314 | all_logits, 315 | concatenated_batch["concatenated_labels"], 316 | average_log_prob=False, 317 | ) 318 | chosen_logps = all_logps[: batch["chosen_input_ids"].shape[0]] 319 | step = batch["chosen_input_ids"].shape[0] 320 | rejected_logps = {} 321 | cnt = 0 322 | for key in batch: 323 | if key.startswith("rejected") and key.endswith("_input_ids"): 324 | cnt += 1 325 | rejected_logps[f"rejected{cnt}"] = all_logps[step*cnt : step*(cnt+1)] 326 | 327 | chosen_logits = all_logits[: batch["chosen_input_ids"].shape[0]] 328 | rejected_logits = {} 329 | cnt = 0 330 | for key in batch: 331 | if key.startswith("rejected") and key.endswith("_input_ids"): 332 | cnt += 1 333 | rejected_logits[f"rejected{cnt}"] = all_logits[step*cnt : step*(cnt+1)] 334 | return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) 335 | 336 | def get_batch_metrics( 337 | self, 338 | model, 339 | batch: Dict[str, Union[List, torch.LongTensor]], 340 | train_eval: Literal["train", "eval"] = "train", 341 | ): 342 | """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" 343 | metrics = {} 344 | 345 | ( 346 | policy_chosen_logps, 347 | policy_rejected_logps, 348 | policy_chosen_logits, 349 | policy_rejected_logits, 350 | ) = self.concatenated_forward(model, batch) 351 | with torch.no_grad(): 352 | ( 353 | reference_chosen_logps, 354 | reference_rejected_logps, 355 | _, 356 | _, 357 | ) = self.concatenated_forward(self.ref_model, batch) 358 | 359 | losses, chosen_rewards, rejected_rewards = self.dpo_loss( 360 | policy_chosen_logps, 361 | policy_rejected_logps, 362 | reference_chosen_logps, 363 | 
reference_rejected_logps, 364 | ) 365 | 366 | # reward_accuracies 记录 chosen 比所有 rejected 的收益都大的比例是多少 367 | reward_accuracies = None 368 | for key in rejected_rewards: 369 | if reward_accuracies is None: 370 | reward_accuracies = (chosen_rewards > rejected_rewards[key]).float() 371 | else: 372 | reward_accuracies *= (chosen_rewards > rejected_rewards[key]).float() 373 | 374 | prefix = "eval_" if train_eval == "eval" else "" 375 | metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().numpy().mean() 376 | for key in rejected_rewards: 377 | metrics[f"{prefix}rewards/{key}"] = rejected_rewards[key].cpu().numpy().mean() 378 | metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().numpy().mean() 379 | for key in rejected_rewards: 380 | metrics[f"{prefix}rewards/margins-{key}"] = (chosen_rewards - rejected_rewards[key]).cpu().numpy().mean() 381 | for key in policy_rejected_logps: 382 | metrics[f"{prefix}logps/rejected-{key}"] = policy_rejected_logps[key].detach().cpu().numpy().mean() 383 | metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean() 384 | for key in policy_rejected_logits: 385 | metrics[f"{prefix}logits/rejected-{key}"] = policy_rejected_logits[key].detach().cpu().numpy().mean() 386 | metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean() 387 | 388 | return losses.mean(), metrics 389 | 390 | def compute_loss( 391 | self, 392 | model: Union[PreTrainedModel, nn.Module], 393 | inputs: Dict[str, Union[torch.Tensor, Any]], 394 | return_outputs=False, 395 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: 396 | # print(inputs.keys()) 397 | # print(inputs) 398 | if not self.use_dpo_data_collator: 399 | warnings.warn( 400 | "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " 401 | "DPODataCollatorWithPadding - you might see unexpected behavior. 
Alternatively, you can implement your own prediction_step method if you are using a custom data collator" 402 | ) 403 | loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train") 404 | 405 | # force log the metrics 406 | if self.accelerator.is_main_process: 407 | self.store_metrics(metrics, train_eval="train") 408 | 409 | if return_outputs: 410 | return (loss, metrics) 411 | return loss 412 | 413 | def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: 414 | """Generate samples from the model and reference model for the given batch of inputs.""" 415 | 416 | policy_output = model.generate( 417 | batch["prompt_input_ids"], 418 | attention_mask=batch["prompt_attention_mask"], 419 | max_length=self.config.max_length, 420 | do_sample=True, 421 | pad_token_id=self.tokenizer.pad_token_id, 422 | ) 423 | 424 | reference_output = self.ref_model.generate( 425 | batch["prompt_input_ids"], 426 | attention_mask=batch["prompt_attention_mask"], 427 | max_length=self.config.max_length, 428 | do_sample=True, 429 | pad_token_id=self.tokenizer.pad_token_id, 430 | ) 431 | 432 | policy_output = pad_to_length(policy_output, self.config.max_length, self.tokenizer.pad_token_id) 433 | policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) 434 | 435 | reference_output = pad_to_length(reference_output, self.config.max_length, self.tokenizer.pad_token_id) 436 | reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True) 437 | 438 | return policy_output_decoded, reference_output_decoded 439 | 440 | def prediction_step( 441 | self, 442 | model: Union[PreTrainedModel, nn.Module], 443 | inputs: Dict[str, Union[torch.Tensor, Any]], 444 | prediction_loss_only: bool, 445 | ignore_keys: Optional[List[str]] = None, 446 | ): 447 | if not self.use_dpo_data_collator: 448 | warnings.warn( 449 | "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a 
datacollator that is different than " 450 | "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" 451 | ) 452 | if ignore_keys is None: 453 | if hasattr(model, "config"): 454 | ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) 455 | else: 456 | ignore_keys = [] 457 | 458 | with torch.no_grad(): 459 | loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval") 460 | 461 | # force log the metrics 462 | if self.accelerator.is_main_process: 463 | self.store_metrics(metrics, train_eval="eval") 464 | 465 | if prediction_loss_only: 466 | return (loss.detach(), None, None) 467 | 468 | # logits for the chosen and rejected samples from model 469 | logits_dict = { 470 | "logits_test/chosen": metrics["logits_test/chosen"], 471 | # "logits_test/rejected": metrics["logits_test/rejected"], 472 | } 473 | logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys) 474 | logits = torch.stack(logits).mean(axis=1) 475 | labels = torch.zeros(logits.shape[0]) 476 | 477 | return (loss.detach(), logits, labels) 478 | 479 | def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: 480 | for key, value in metrics.items(): 481 | self._stored_metrics[train_eval][key].append(value) 482 | 483 | def log(self, logs: Dict[str, float]) -> None: 484 | """ 485 | Log `logs` on the various objects watching training, including stored metrics. 486 | 487 | Args: 488 | logs (`Dict[str, float]`): 489 | The values to log. 
490 | """ 491 | # logs either has 'loss' or 'eval_loss' 492 | train_eval = "train" if "loss" in logs else "eval" 493 | # Add averaged stored metrics to logs 494 | for key, metrics in self._stored_metrics[train_eval].items(): 495 | logs[key] = torch.tensor(metrics).mean().item() 496 | del self._stored_metrics[train_eval] 497 | return super().log(logs) 498 | 499 | -------------------------------------------------------------------------------- /baselines/SDPO/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import warnings 4 | from dataclasses import dataclass 5 | from typing import Any, Dict, List, Optional, Union 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn.utils.rnn import pad_sequence 10 | from torch.utils.data import IterableDataset 11 | from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerBase, TrainerCallback 12 | 13 | @dataclass 14 | class DPODataCollatorWithPadding: 15 | r""" 16 | DPO DataCollator class that pads the inputs to the maximum length of the batch. 17 | Args: 18 | tokenizer (`PreTrainedTokenizerBase`): 19 | The tokenizer used for encoding the data. 20 | padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`): 21 | padding_strategy to pass to the tokenizer. 22 | max_length (`Optional[int]`, `optional`, defaults to `None`): 23 | The maximum length of the sequence to be processed. 24 | max_prompt_length (`Optional[int]`, `optional`, defaults to `None`): 25 | The maximum length of the prompt to be processed. 26 | label_pad_token_id (`int`, defaults to -100): 27 | The label used for masking. 28 | padding_value (`int`, defaults to 0): 29 | The value used for padding. 30 | truncation_mode: (`str`, defaults to "keep_end"): 31 | The truncation mode to use when truncating the prompt + chosen/rejected responses. 
32 | """ 33 | tokenizer: PreTrainedTokenizerBase 34 | padding: Union[bool, str] = True 35 | max_length: Optional[int] = None 36 | max_prompt_length: Optional[int] = None 37 | label_pad_token_id: int = -100 38 | padding_value: int = 0 39 | truncation_mode: str = "keep_end" 40 | 41 | def tokenize_batch_element( 42 | self, 43 | prompt: str, 44 | chosen: str, 45 | rejected: Dict[str, str], 46 | ) -> Dict: 47 | """Tokenize a single batch element. 48 | 49 | At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation 50 | in case the prompt + chosen or prompt + rejected responses is/are too long. First 51 | we truncate the prompt; if we're still too long, we truncate the chosen/rejected. 52 | 53 | We also create the labels for the chosen/rejected responses, which are of length equal to 54 | the sum of the length of the prompt and the chosen/rejected response, with 55 | label_pad_token_id for the prompt tokens. 56 | """ 57 | chosen_tokens = self.tokenizer(chosen, add_special_tokens=False) 58 | prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) 59 | rejected_tokens = {} 60 | for key in rejected: 61 | rejected_tokens[key] = self.tokenizer(rejected[key], add_special_tokens=False) 62 | 63 | assert self.tokenizer.eos_token_id not in prompt_tokens["input_ids"], f"Prompt contains EOS token: {prompt}" 64 | assert ( 65 | self.tokenizer.eos_token_id not in chosen_tokens["input_ids"] 66 | ), f"Chosen response contains EOS token: {chosen}" 67 | assert ( 68 | all([self.tokenizer.eos_token_id not in rejected_tokens[key]["input_ids"] for key in rejected_tokens]) 69 | ), f"Rejected response contains EOS token: {rejected}" 70 | 71 | chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) 72 | chosen_tokens["attention_mask"].append(1) 73 | for key in rejected_tokens: 74 | rejected_tokens[key]["input_ids"].append(self.tokenizer.eos_token_id) 75 | rejected_tokens[key]["attention_mask"].append(1) 76 | max_rejected_len = 
max([len(rejected_tokens[key]["input_ids"]) for key in rejected_tokens]) 77 | longer_response_length = max(len(chosen_tokens["input_ids"]), max_rejected_len) 78 | 79 | # if combined sequence is too long, truncate the prompt 80 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: 81 | if self.truncation_mode == "keep_start": 82 | prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()} 83 | elif self.truncation_mode == "keep_end": 84 | prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()} 85 | else: 86 | raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") 87 | 88 | # if that's still too long, truncate the response 89 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length: 90 | chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()} 91 | rejected_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in rejected_tokens.items()} 92 | 93 | # Create labels 94 | chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens} 95 | rejected_sequence_tokens = {} 96 | # rejected_tokens: Dict[str, Dict] 97 | for key in rejected_tokens: 98 | rejected_sequence_tokens[key] = {k: prompt_tokens[k] + rejected_tokens[key][k] for k in rejected_tokens[key]} 99 | chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] 100 | chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( 101 | prompt_tokens["input_ids"] 102 | ) 103 | for key in rejected_sequence_tokens: 104 | rejected_sequence_tokens[key]["labels"] = rejected_sequence_tokens[key]["input_ids"][:] 105 | rejected_sequence_tokens[key]["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len( 106 | prompt_tokens["input_ids"] 107 | ) 108 | 109 | batch = {} 110 | 111 | batch["prompt"] = prompt 112 | batch["chosen"] = prompt + chosen 113 | for 
key in rejected: 114 | batch[key] = prompt + rejected[key] 115 | batch["chosen_response_only"] = chosen 116 | for key in rejected: 117 | batch[f"{key}_response_only"] = rejected[key] 118 | 119 | for k, toks in { 120 | "chosen": chosen_sequence_tokens, 121 | # "rejected": rejected_sequence_tokens, 122 | "prompt": prompt_tokens, 123 | }.items(): 124 | for type_key, tokens in toks.items(): 125 | if type_key == "token_type_ids": 126 | continue 127 | batch[f"{k}_{type_key}"] = tokens 128 | # rejected_sequence_tokens: Dict[str, Dict] 129 | for k, toks in rejected_sequence_tokens.items(): 130 | for type_key, tokens in toks.items(): 131 | if type_key == "token_type_ids": 132 | continue 133 | batch[f"{k}_{type_key}"] = tokens 134 | 135 | return batch 136 | 137 | def collate(self, batch): 138 | # first, pad everything to the same length 139 | padded_batch = {} 140 | for k in batch[0].keys(): 141 | if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): 142 | # adapted from https://stackoverflow.com/questions/73256206 143 | if "prompt" in k: 144 | to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] 145 | else: 146 | to_pad = [torch.LongTensor(ex[k]) for ex in batch] 147 | if k.endswith("_input_ids"): 148 | padding_value = self.tokenizer.pad_token_id 149 | elif k.endswith("_labels"): 150 | padding_value = self.label_pad_token_id 151 | elif k.endswith("_attention_mask"): 152 | padding_value = self.padding_value 153 | else: 154 | raise ValueError(f"Unexpected key in batch '{k}'") 155 | 156 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) 157 | # for the prompt, flip back so padding is on left side 158 | if "prompt" in k: 159 | padded_batch[k] = padded_batch[k].flip(dims=[1]) 160 | else: 161 | padded_batch[k] = [ex[k] for ex in batch] 162 | 163 | return padded_batch 164 | 165 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 166 | tokenized_batch = [] 167 | 168 | for feature in 
features: 169 | prompt = feature["prompt"] 170 | chosen = feature["chosen"] 171 | rejected = {} 172 | for key in feature: 173 | if key.startswith("rejected"): 174 | rejected[key] = feature[key] 175 | 176 | batch_element = self.tokenize_batch_element(prompt, chosen, rejected) 177 | tokenized_batch.append(batch_element) 178 | 179 | # return collated batch 180 | return self.collate(tokenized_batch) 181 | 182 | def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: 183 | if tensor.size(dim) >= length: 184 | return tensor 185 | else: 186 | pad_size = list(tensor.shape) 187 | pad_size[dim] = length - tensor.size(dim) 188 | return torch.cat( 189 | [ 190 | tensor, 191 | pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), 192 | ], 193 | dim=dim, 194 | ) -------------------------------------------------------------------------------- /baselines/Semantic_sampling_rosePO/Semantic_sampling_rosePO.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import re 4 | from sentence_transformers import SentenceTransformer 5 | import torch 6 | import torch.nn.functional as F 7 | def process_batch(batch): 8 | results = [] 9 | for data in batch: 10 | input = data['input'] 11 | names = re.findall(r'"([^"]+)"', input) 12 | name_embeddings = torch.tensor([model.encode(name) for name in names], device="cuda") 13 | cosine_similarity = F.cosine_similarity(name_embeddings[:, None, :], embeddings[None, :, :], dim=-1) 14 | similarity = cosine_similarity.mean(dim=0) 15 | min_sim, min_index = similarity.min(dim=-1) 16 | semantic_item = id2name[str(min_index.item())] 17 | data['semantic'] = f"\"{semantic_item}\"\n" 18 | results.append(data) 19 | return results 20 | model = SentenceTransformer('./models/paraphrase-MiniLM-L3-v2') 21 | def read_json(json_file:str) -> dict: 22 | f = open(json_file, 'r') 23 | return json.load(f) 24 | def 
export_to_json(file_path:str,dic): 25 | f = open(file_path, 'w') 26 | json.dump(dic,f,indent=2) 27 | # semantic item 28 | for category in ["CDs_and_Vinyl"]: 29 | embeddings = torch.load(f"../eval/{category}/embeddings.pt").to('cuda') 30 | id2name = read_json(f"../eval/{category}/id2name.json") 31 | train_data = read_json(f"./{category}/train.json") 32 | batch_size = 64 33 | batched_data = [train_data[i:i+batch_size] for i in range(0, len(train_data), batch_size)] 34 | final_data = [] 35 | for batch in tqdm(batched_data, desc=f"Processing {category} train data......"): 36 | final_data.extend(process_batch(batch)) 37 | export_to_json(f"../data/{category}/train_semantic.json",train_data) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: SPRec 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 8 | - https://repo.anaconda.com/pkgs/main 9 | - https://repo.anaconda.com/pkgs/r 10 | dependencies: 11 | - _libgcc_mutex=0.1=main 12 | - _openmp_mutex=5.1=1_gnu 13 | - blas=1.0=mkl 14 | - brotli-python=1.0.9=py38h6a678d5_8 15 | - bzip2=1.0.8=h5eee18b_6 16 | - ca-certificates=2024.9.24=h06a4308_0 17 | - charset-normalizer=3.3.2=pyhd3eb1b0_0 18 | - cuda-cudart=12.1.105=0 19 | - cuda-cupti=12.1.105=0 20 | - cuda-libraries=12.1.0=0 21 | - cuda-nvrtc=12.1.105=0 22 | - cuda-nvtx=12.1.105=0 23 | - cuda-opencl=12.4.127=0 24 | - cuda-runtime=12.1.0=0 25 | - cudatoolkit=11.8.0=h6a678d5_0 26 | - ffmpeg=4.3=hf484d3e_0 27 | - filelock=3.13.1=py38h06a4308_0 28 | - freetype=2.12.1=h4a9f257_0 29 | - gmp=6.2.1=h295c915_3 30 | - gmpy2=2.1.2=py38heeb90bb_0 31 | - gnutls=3.6.15=he1e5248_0 32 | - intel-openmp=2023.1.0=hdb19cb5_46306 33 | - jpeg=9e=h5eee18b_3 34 | - lame=3.100=h7b6447c_0 35 | - lcms2=2.12=h3be6417_0 36 | - 
ld_impl_linux-64=2.38=h1181459_1 37 | - lerc=3.0=h295c915_0 38 | - libcublas=12.1.0.26=0 39 | - libcufft=11.0.2.4=0 40 | - libcufile=1.9.1.3=0 41 | - libcurand=10.3.5.147=0 42 | - libcusolver=11.4.4.55=0 43 | - libcusparse=12.0.2.55=0 44 | - libdeflate=1.17=h5eee18b_1 45 | - libffi=3.4.4=h6a678d5_0 46 | - libgcc-ng=11.2.0=h1234567_1 47 | - libgomp=11.2.0=h1234567_1 48 | - libiconv=1.16=h5eee18b_3 49 | - libidn2=2.3.4=h5eee18b_0 50 | - libjpeg-turbo=2.0.0=h9bf148f_0 51 | - libnpp=12.0.2.50=0 52 | - libnvjitlink=12.1.105=0 53 | - libnvjpeg=12.1.1.14=0 54 | - libpng=1.6.39=h5eee18b_0 55 | - libstdcxx-ng=11.2.0=h1234567_1 56 | - libtasn1=4.19.0=h5eee18b_0 57 | - libtiff=4.5.1=h6a678d5_0 58 | - libunistring=0.9.10=h27cfd23_0 59 | - libwebp-base=1.3.2=h5eee18b_0 60 | - llvm-openmp=14.0.6=h9e868ea_0 61 | - lz4-c=1.9.4=h6a678d5_1 62 | - markupsafe=2.1.3=py38h5eee18b_0 63 | - mkl=2023.1.0=h213fc3f_46344 64 | - mkl-service=2.4.0=py38h5eee18b_1 65 | - mkl_fft=1.3.8=py38h5eee18b_0 66 | - mkl_random=1.2.4=py38hdb19cb5_0 67 | - mpc=1.1.0=h10f8cd9_1 68 | - mpfr=4.0.2=hb69a4c5_1 69 | - mpmath=1.3.0=py38h06a4308_0 70 | - ncurses=6.4=h6a678d5_0 71 | - nettle=3.7.3=hbbd107a_1 72 | - networkx=3.1=py38h06a4308_0 73 | - openh264=2.1.1=h4ff587b_0 74 | - openjpeg=2.5.2=he7f1fd0_0 75 | - openssl=3.0.12=h7f8727e_0 76 | - pip=23.3=py38h06a4308_0 77 | - pysocks=1.7.1=py38h06a4308_0 78 | - python=3.8.18=h955ad1f_0 79 | - pytorch=2.1.0=py3.8_cuda12.1_cudnn8.9.2_0 80 | - pytorch-cuda=12.1=ha16c6d3_5 81 | - pytorch-mutex=1.0=cuda 82 | - pyyaml=6.0.1=py38h5eee18b_0 83 | - readline=8.2=h5eee18b_0 84 | - setuptools=68.0.0=py38h06a4308_0 85 | - sqlite=3.41.2=h5eee18b_0 86 | - tbb=2021.8.0=hdb19cb5_0 87 | - tk=8.6.12=h1ccaba5_0 88 | - torchtriton=2.1.0=py38 89 | - typing_extensions=4.11.0=py38h06a4308_0 90 | - wheel=0.41.2=py38h06a4308_0 91 | - xz=5.4.2=h5eee18b_0 92 | - yaml=0.2.5=h7b6447c_0 93 | - zlib=1.2.13=h5eee18b_0 94 | - zstd=1.5.5=hc292b87_0 95 | - pip: 96 | - accelerate==1.0.1 97 | - 
aiofiles==23.2.1 98 | - aiohttp==3.9.0 99 | - aiosignal==1.3.1 100 | - altair==5.1.2 101 | - annotated-types==0.6.0 102 | - anyio==3.7.1 103 | - appdirs==1.4.4 104 | - argon2-cffi==23.1.0 105 | - argon2-cffi-bindings==21.2.0 106 | - arrow==1.3.0 107 | - asttokens==2.4.1 108 | - async-lru==2.0.4 109 | - async-timeout==4.0.3 110 | - attrs==23.1.0 111 | - babel==2.13.1 112 | - backcall==0.2.0 113 | - beautifulsoup4==4.12.2 114 | - bitsandbytes==0.43.1 115 | - bitsandbytes-cuda116==0.26.0.post2 116 | - black==23.11.0 117 | - bleach==6.1.0 118 | - certifi==2023.11.17 119 | - cffi==1.16.0 120 | - click==8.1.7 121 | - cmake==3.27.7 122 | - colorama==0.4.6 123 | - comm==0.2.0 124 | - contourpy==1.1.1 125 | - cycler==0.12.1 126 | - datasets==2.15.0 127 | - debugpy==1.8.0 128 | - decorator==5.1.1 129 | - defusedxml==0.7.1 130 | - dill==0.3.7 131 | - docker-pycreds==0.4.0 132 | - docopt==0.6.2 133 | - docstring-parser==0.16 134 | - eval-type-backport==0.2.0 135 | - exceptiongroup==1.2.0 136 | - executing==2.0.1 137 | - fastapi==0.104.1 138 | - fasteners==0.19 139 | - fastjsonschema==2.19.0 140 | - ffmpy==0.3.1 141 | - fire==0.5.0 142 | - fonttools==4.45.0 143 | - fqdn==1.5.1 144 | - frozenlist==1.4.0 145 | - fschat==0.2.34 146 | - fsspec==2023.10.0 147 | - gitdb==4.0.11 148 | - gitpython==3.1.43 149 | - gradio==3.50.2 150 | - gradio-client==0.6.1 151 | - h11==0.14.0 152 | - httpcore==1.0.2 153 | - httpx==0.25.1 154 | - huggingface-hub==0.25.0 155 | - idna==3.4 156 | - ijson==3.3.0 157 | - imageio==2.35.1 158 | - importlib-metadata==6.8.0 159 | - importlib-resources==6.1.1 160 | - ipykernel==6.27.1 161 | - ipython==8.12.3 162 | - ipywidgets==8.1.1 163 | - isoduration==20.11.0 164 | - jedi==0.19.1 165 | - jinja2==3.1.2 166 | - joblib==1.3.2 167 | - json5==0.9.14 168 | - jsonpointer==2.4 169 | - jsonschema==4.20.0 170 | - jsonschema-specifications==2023.11.1 171 | - jupyter==1.0.0 172 | - jupyter-client==8.6.0 173 | - jupyter-console==6.6.3 174 | - jupyter-core==5.5.0 175 | - 
jupyter-events==0.9.0 176 | - jupyter-lsp==2.2.1 177 | - jupyter-server==2.11.1 178 | - jupyter-server-terminals==0.4.4 179 | - jupyterlab==4.0.9 180 | - jupyterlab-pygments==0.3.0 181 | - jupyterlab-server==2.25.2 182 | - jupyterlab-widgets==3.0.9 183 | - kiwisolver==1.4.5 184 | - lazy-loader==0.4 185 | - lit==17.0.5 186 | - loguru==0.7.2 187 | - loralib==0.1.2 188 | - markdown-it-py==3.0.0 189 | - markdown2==2.4.12 190 | - matplotlib==3.7.4 191 | - matplotlib-inline==0.1.6 192 | - mdurl==0.1.2 193 | - mistune==3.0.2 194 | - multidict==6.0.4 195 | - multiprocess==0.70.15 196 | - mypy-extensions==1.0.0 197 | - nbclient==0.9.0 198 | - nbconvert==7.11.0 199 | - nbformat==5.9.2 200 | - nest-asyncio==1.5.8 201 | - nh3==0.2.15 202 | - nltk==3.8.1 203 | - notebook==7.0.6 204 | - notebook-shim==0.2.3 205 | - numpy==1.24.4 206 | - nvidia-cublas-cu11==11.10.3.66 207 | - nvidia-cublas-cu12==12.1.3.1 208 | - nvidia-cuda-cupti-cu11==11.7.101 209 | - nvidia-cuda-cupti-cu12==12.1.105 210 | - nvidia-cuda-nvrtc-cu11==11.7.99 211 | - nvidia-cuda-nvrtc-cu12==12.1.105 212 | - nvidia-cuda-runtime-cu11==11.7.99 213 | - nvidia-cuda-runtime-cu12==12.1.105 214 | - nvidia-cudnn-cu11==8.5.0.96 215 | - nvidia-cudnn-cu12==8.9.2.26 216 | - nvidia-cufft-cu11==10.9.0.58 217 | - nvidia-cufft-cu12==11.0.2.54 218 | - nvidia-curand-cu11==10.2.10.91 219 | - nvidia-curand-cu12==10.3.2.106 220 | - nvidia-cusolver-cu11==11.4.0.1 221 | - nvidia-cusolver-cu12==11.4.5.107 222 | - nvidia-cusparse-cu11==11.7.4.91 223 | - nvidia-cusparse-cu12==12.1.0.106 224 | - nvidia-nccl-cu11==2.14.3 225 | - nvidia-nccl-cu12==2.18.1 226 | - nvidia-nvjitlink-cu12==12.3.101 227 | - nvidia-nvtx-cu11==11.7.91 228 | - nvidia-nvtx-cu12==12.1.105 229 | - opencv-python==4.10.0.84 230 | - orjson==3.9.10 231 | - overrides==7.4.0 232 | - packaging==23.2 233 | - pandas==2.0.3 234 | - pandocfilters==1.5.0 235 | - parso==0.8.3 236 | - pathspec==0.11.2 237 | - peft==0.11.0 238 | - pexpect==4.8.0 239 | - pickleshare==0.7.5 240 | - 
pillow==10.1.0 241 | - pipreqs==0.5.0 242 | - pkgutil-resolve-name==1.3.10 243 | - platformdirs==4.0.0 244 | - prometheus-client==0.19.0 245 | - prompt-toolkit==3.0.41 246 | - protobuf==3.19.0 247 | - psutil==5.9.6 248 | - ptyprocess==0.7.0 249 | - pure-eval==0.2.2 250 | - pyarrow==14.0.1 251 | - pyarrow-hotfix==0.6 252 | - pycparser==2.21 253 | - pydantic==1.10.13 254 | - pydantic-core==2.14.3 255 | - pydub==0.25.1 256 | - pygments==2.17.2 257 | - pyparsing==3.1.1 258 | - python-dateutil==2.8.2 259 | - python-json-logger==2.0.7 260 | - python-multipart==0.0.6 261 | - pytz==2023.3.post1 262 | - pyzmq==25.1.1 263 | - qtconsole==5.5.1 264 | - qtpy==2.4.1 265 | - referencing==0.31.0 266 | - regex==2023.10.3 267 | - requests==2.31.0 268 | - rfc3339-validator==0.1.4 269 | - rfc3986-validator==0.1.1 270 | - rich==13.7.0 271 | - rpds-py==0.13.1 272 | - safetensors==0.4.5 273 | - scikit-image==0.21.0 274 | - scikit-learn==1.3.2 275 | - scipy==1.10.1 276 | - seaborn==0.13.0 277 | - semantic-version==2.10.0 278 | - send2trash==1.8.2 279 | - sentence-transformers==2.2.2 280 | - sentencepiece==0.1.99 281 | - sentry-sdk==1.45.0 282 | - setproctitle==1.3.3 283 | - shellingham==1.5.4 284 | - shortuuid==1.0.11 285 | - shtab==1.7.1 286 | - six==1.16.0 287 | - smmap==5.0.1 288 | - sniffio==1.3.0 289 | - some-package==0.1 290 | - soupsieve==2.5 291 | - stack-data==0.6.3 292 | - starlette==0.27.0 293 | - svgwrite==1.4.3 294 | - sympy==1.12 295 | - termcolor==2.3.0 296 | - terminado==0.18.0 297 | - threadpoolctl==3.2.0 298 | - tifffile==2023.7.10 299 | - tiktoken==0.5.2 300 | - tinycss2==1.2.1 301 | - tokenize-rt==5.2.0 302 | - tokenizers==0.19.1 303 | - tomli==2.0.1 304 | - tomlkit==0.12.0 305 | - toolz==0.12.0 306 | - torch==2.0.1 307 | - torchaudio==2.0.2 308 | - torchvision==0.15.2 309 | - tornado==6.4 310 | - tqdm==4.66.1 311 | - traitlets==5.13.0 312 | - transformers==4.44.2 313 | - triton==2.0.0 314 | - trl==0.9.2 315 | - typer==0.9.0 316 | - types-python-dateutil==2.8.19.14 317 
| - typing-extensions==4.8.0 318 | - tyro==0.8.3 319 | - tzdata==2023.3 320 | - uri-template==1.3.0 321 | - urllib3==2.1.0 322 | - uvicorn==0.24.0.post1 323 | - wandb==0.16.6 324 | - wavedrom==2.0.3.post3 325 | - wcwidth==0.2.12 326 | - webcolors==1.13 327 | - webencodings==0.5.1 328 | - websocket-client==1.6.4 329 | - websockets==11.0.3 330 | - widgetsnbextension==4.0.9 331 | - xxhash==3.4.1 332 | - yarg==0.1.9 333 | - yarl==1.9.3 334 | - zipp==3.17.0 335 | prefix: # Specify your prefix 336 | -------------------------------------------------------------------------------- /eval/Goodreads/embeddings.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RegionCh/SPRec/96dff9b5a42c6cc227d07fcee6b3d8813865bac0/eval/Goodreads/embeddings.pt -------------------------------------------------------------------------------- /eval/Goodreads/genre_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "fiction": 0, 3 | "romance": 0, 4 | "young-adult": 0, 5 | "fantasy, paranormal": 0, 6 | "mystery, thriller, crime": 0, 7 | "history, historical fiction, biography": 0, 8 | "children": 0 9 | } -------------------------------------------------------------------------------- /eval/MovieLens/embeddings.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RegionCh/SPRec/96dff9b5a42c6cc227d07fcee6b3d8813865bac0/eval/MovieLens/embeddings.pt -------------------------------------------------------------------------------- /eval/MovieLens/genre_dict.json: -------------------------------------------------------------------------------- 1 | {"Action": 0, "Adventure": 0, "Sci-Fi": 0, "Thriller": 0, "Drama": 0, "Comedy": 0, "Fantasy": 0, "Crime": 0,"Romance": 0} -------------------------------------------------------------------------------- /eval/evaluate.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer 3 | import transformers 4 | import torch 5 | import re 6 | import math 7 | import json 8 | from peft import PeftModel 9 | import argparse 10 | import pandas as pd 11 | from collections import Counter 12 | from sentence_transformers import SentenceTransformer 13 | parse = argparse.ArgumentParser() 14 | parse.add_argument("--input_dir",type=str, default="./", help="result file") 15 | parse.add_argument("--model",type=str, default="SPRec", help="result file") 16 | parse.add_argument("--exp_csv",type=str, default=None, help="result file") 17 | parse.add_argument("--output_dir",type=str, default="./", help="eval_result") 18 | parse.add_argument("--topk",type=str, default="./", help="topk") 19 | parse.add_argument("--gamma",type=float,default=0.0,help="gamma") 20 | parse.add_argument("--category",type=str,default="CDs_and_Vinyl",help="gamma") 21 | args = parse.parse_args() 22 | def read_json(json_file:str) -> dict: 23 | f = open(json_file, 'r') 24 | return json.load(f) 25 | category = args.category 26 | id2name = read_json(f"./eval/{category}/id2name.json") 27 | name2id = read_json(f"./eval/{category}/name2id.json") 28 | embeddings = torch.load(f"./eval/{category}/embeddings.pt") 29 | name2genre = read_json(f"./eval/{category}/name2genre.json") 30 | genre_dict = read_json(f"./eval/{category}/genre_dict.json") 31 | def batch(list, batch_size=1): 32 | chunk_size = (len(list) - 1) // batch_size + 1 33 | for i in range(chunk_size): 34 | yield list[batch_size * i: batch_size * (i + 1)] 35 | 36 | def sum_of_first_i_keys(sorted_dic, i): 37 | keys = list(sorted_dic.values())[:i] 38 | return sum(keys) 39 | 40 | def gh(category:str,test_data): 41 | notin_count = 0 42 | in_count = 0 43 | name2genre=read_json(f"./eval/{category}/name2genre.json") 44 | genre_dict = read_json(f"./eval/{category}/genre_dict.json") 
45 | for data in tqdm(test_data,desc="Processing category data......"): 46 | input = data['input'] 47 | names = re.findall(r'"([^"]+)"', input) 48 | for name in names: 49 | if name in name2genre: 50 | in_count += 1 51 | genres = name2genre[name] 52 | else: 53 | notin_count += 1 54 | # print(f"Not exist in name2genre:{name}") 55 | continue 56 | select_genres = [] 57 | for genre in genres: 58 | if genre in genre_dict: 59 | select_genres.append(genre) 60 | if(len(select_genres)>0): 61 | for genre in select_genres: 62 | genre_dict[genre] += 1/len(select_genres) 63 | gh = [genre_dict[x] for x in genre_dict] 64 | gh_normalize = [x/sum(gh) for x in gh] 65 | print(f"InCount:{in_count}\nNotinCount:{notin_count}") 66 | return gh_normalize 67 | 68 | 69 | result_json = args.input_dir 70 | f = open(result_json, 'r') 71 | test_data = json.load(f) 72 | total = 0 73 | # Identify your sentence-embedding model 74 | model = SentenceTransformer('/data/chenruijun/code/models/paraphrase-MiniLM-L3-v2') 75 | 76 | from tqdm import tqdm 77 | embeddings = torch.tensor(embeddings).cuda() 78 | text = [] 79 | for i,_ in tqdm(enumerate(test_data)): 80 | if(len(_["predict"])>0): 81 | if(len(_['predict'][0])==0): 82 | text.append("NAN") 83 | print("Empty prediction!") 84 | else: 85 | match = re.search(r'"([^"]*)', _['predict'][0]) 86 | if match: 87 | name = match.group(1) 88 | text.append(name) 89 | else: 90 | text.append(_['predict'][0].split('\n', 1)[0]) 91 | else: 92 | print("Empty:") 93 | 94 | predict_embeddings = [] 95 | for i, batch_input in tqdm(enumerate(batch(text, 8))): 96 | predict_embeddings.append(torch.tensor(model.encode(batch_input))) 97 | predict_embeddings = torch.cat(predict_embeddings, dim=0).cuda() 98 | predict_embeddings.size() 99 | dist = torch.cdist(predict_embeddings, embeddings, p=2) 100 | batch_size = 1 101 | num_batches = (dist.size(0) + batch_size - 1) // batch_size 102 | rank_list = [] 103 | for i in tqdm(range(num_batches), desc="Processing Batches"): 104 | start_idx 
= i * batch_size 105 | end_idx = min((i + 1) * batch_size, dist.size(0)) 106 | batch_dist = dist[start_idx:end_idx] 107 | 108 | batch_rank = batch_dist.argsort(dim=-1).argsort(dim=-1) 109 | torch.cuda.empty_cache () 110 | rank_list.append(batch_rank) 111 | 112 | rank_list = torch.cat(rank_list, dim=0) 113 | 114 | NDCG = [] 115 | HR = [] 116 | diversity = [] 117 | diversity_dic = {} 118 | MGU_genre = [] 119 | DGU_genre = [] 120 | pop_count = {} 121 | genre_count = {} 122 | notin = 0 123 | notin_count = 0 124 | in_count = 0 125 | topk_list = [int(args.topk)] 126 | diversity_set = set() 127 | for topk in topk_list: 128 | S_ndcg = 0 129 | S_hr = 0 130 | for i in tqdm(range(len(test_data)),desc="Calculating Metrics......"): 131 | rank = rank_list[i] 132 | # Target id 133 | target_name = test_data[i]['output'] 134 | predict_name = test_data[i]['predict'][0] 135 | target_name = target_name.strip().strip('"') 136 | if target_name in name2id: 137 | target_id = name2id[target_name] 138 | total += 1 139 | else: 140 | continue 141 | 142 | rankId = rank[target_id] 143 | 144 | # NDCG & HR 145 | if(rankId0): 162 | for genre in select_genres: 163 | genre_dict[genre] += 1/len(select_genres) 164 | else: 165 | notin += 1 166 | 167 | 168 | # diversity 169 | for i in range(topk): 170 | diversity_set.add(torch.argwhere(rank==i).item()) 171 | if torch.argwhere(rank==i).item() in diversity_dic: 172 | diversity_dic[torch.argwhere(rank==i).item()] += 1 173 | else: 174 | diversity_dic[torch.argwhere(rank==i).item()] = 1 175 | 176 | 177 | NDCG.append(S_ndcg / len(test_data) / (1 / math.log(2))) 178 | HR.append(S_hr / len(test_data)) 179 | diversity.append(len(diversity_set)) 180 | genre = args.category 181 | 182 | gh_genre = gh(category,test_data) 183 | # 184 | print(len(gh_genre)) 185 | gp_genre = [genre_dict[x] for x in genre_dict] 186 | gp_genre = [x/sum(gp_genre) for x in gp_genre] 187 | dis_genre = [gp_genre[i]-gh_genre[i] for i in range(len(gh_genre))] 188 | DGU_genre = 
max(dis_genre)-min(dis_genre) 189 | dis_abs_genre = [abs(x) for x in dis_genre] 190 | MGU_genre = sum(dis_abs_genre) / len(dis_genre) 191 | i=0 192 | 193 | gp_dict = {} 194 | i=0 195 | for key in genre_dict: 196 | gp_dict[key] = dis_abs_genre[i] 197 | i += 1 198 | print(f"gp_dict:{gp_dict}") 199 | print(f"NDCG:{NDCG}") 200 | print(f"HR:{HR}") 201 | div_ratio = diversity[0] / (total*topk) 202 | print(f"DGU:{DGU_genre}") 203 | print(f"MGU:{MGU_genre}") 204 | print(f"DivRatio:{div_ratio}") 205 | 206 | eval_dic = {} 207 | eval_dic["model"] = args.input_dir 208 | # eval_dic["Dis_genre"] = dis_abs_genre 209 | eval_dic['NDCG'] = NDCG 210 | eval_dic["HR"] = HR 211 | eval_dic["diversity"] = diversity 212 | eval_dic["DivRatio"] = div_ratio 213 | eval_dic['DGU'] = DGU_genre 214 | eval_dic["MGU"] = MGU_genre 215 | 216 | file_path = args.output_dir 217 | if os.path.exists(file_path) and os.path.getsize(file_path) > 0: 218 | with open(file_path, 'r') as file: 219 | try: 220 | data = json.load(file) 221 | except json.JSONDecodeError: 222 | data = [] 223 | else: 224 | data = [] 225 | sorted_dic = dict(sorted(diversity_dic.items(), key=lambda item: item[1],reverse=True)) 226 | count = 0 227 | i=0 228 | eval_dic["ORRatio"] = sum_of_first_i_keys(sorted_dic,3) / (topk*total) 229 | print(f"ORRatio:{sum_of_first_i_keys(sorted_dic,3) / (topk*total)}") 230 | #print(dict(sorted(diversity_dic.items(), key=lambda item: item[1]))) 231 | data.append(eval_dic) 232 | print(count) 233 | with open(args.output_dir, 'w') as file: 234 | json.dump(data, file,separators=(',', ': '),indent=2) 235 | 236 | def update_csv(dataset_name, model_name, metrics_dict, csv_file): 237 | df = pd.read_csv(csv_file) 238 | 239 | required_columns = ["Dataset", "Model"] 240 | if not all(col in df.columns for col in required_columns): 241 | raise ValueError("CSV 文件必须包含 'Dataset' 和 'Model' 列") 242 | 243 | condition = (df["Dataset"] == dataset_name) & (df["Model"] == model_name) 244 | if not condition.any(): 245 | new_row = 
{col: None for col in df.columns} 246 | new_row["Dataset"] = dataset_name 247 | new_row["Model"] = model_name 248 | 249 | new_row_df = pd.DataFrame([new_row]) 250 | df = pd.concat([df, new_row_df], ignore_index=True) 251 | 252 | condition = (df["Dataset"] == dataset_name) & (df["Model"] == model_name) 253 | 254 | for metric, value in metrics_dict.items(): 255 | if metric not in df.columns: 256 | print(f"注意:指标 '{metric}' 不在 CSV 文件列中,已添加该列并初始化为0。") 257 | df[metric] = 0 258 | df.loc[condition, metric] = value 259 | 260 | df.to_csv(csv_file, index=False) 261 | print(f"CSV 文件已更新:{csv_file}") 262 | 263 | if args.exp_csv != None: 264 | metric_dic = {} 265 | metric_dic[f"MGU@{args.topk}"] = eval_dic["MGU"] 266 | metric_dic[f"DGU@{args.topk}"] = eval_dic["DGU"] 267 | metric_dic[f"DivRatio@{args.topk}"] = eval_dic["DivRatio"] 268 | metric_dic[f"ORRatio@{args.topk}"] = sum_of_first_i_keys(sorted_dic,3) / (topk*total) 269 | if args.topk == '5': 270 | metric_dic[f"NDCG@{args.topk}"] = eval_dic["NDCG"] 271 | update_csv(category,args.model,metric_dic,args.exp_csv) -------------------------------------------------------------------------------- /eval/inference.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import fire 4 | import gradio as gr 5 | import torch 6 | torch.set_num_threads(1) 7 | import transformers 8 | import json 9 | import os 10 | os.environ['OPENBLAS_NUM_THREADS'] = '1' 11 | os.environ['OMP_NUM_THREADS'] = '1' 12 | from peft import PeftModel 13 | from transformers import GenerationConfig, LlamaTokenizer 14 | from transformers import LlamaForCausalLM,AutoTokenizer 15 | 16 | 17 | if torch.cuda.is_available(): 18 | device = "cuda" 19 | else: 20 | device = "cpu" 21 | 22 | try: 23 | if torch.backends.mps.is_available(): 24 | device = "mps" 25 | except: # noqa: E722 26 | pass 27 | 28 | 29 | def main( 30 | load_8bit: bool = False, 31 | base_model: str = "", 32 | lora_weights: str = "tloen/alpaca-lora-7b", 33 | 
test_data_path: str = "data/test.json", 34 | result_json_data: str = "temp.json", 35 | batch_size: int=32, 36 | num_beams: int=1 37 | ): 38 | assert ( 39 | base_model 40 | ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'" 41 | 42 | #tokenizer = LlamaTokenizer.from_pretrained(base_model) 43 | tokenizer = AutoTokenizer.from_pretrained(base_model) 44 | tokenizer.pad_token_id = tokenizer.eos_token_id 45 | load_8bit = False 46 | if device == "cuda": 47 | model = LlamaForCausalLM.from_pretrained( 48 | base_model, 49 | load_in_8bit=load_8bit, 50 | torch_dtype=torch.float16, 51 | device_map="auto", 52 | ) 53 | model = PeftModel.from_pretrained( 54 | model, 55 | lora_weights, 56 | torch_dtype=torch.float16, 57 | device_map="auto" 58 | ) 59 | tokenizer.padding_side = "left" 60 | 61 | model.eval() 62 | 63 | def evaluate( 64 | instructions, 65 | inputs=None, 66 | temperature=1.0, 67 | top_p=0.9, 68 | top_k=40, 69 | num_beams=num_beams, 70 | max_new_tokens=32, 71 | **kwargs, 72 | ): 73 | prompt = [generate_prompt(instruction, input) for instruction, input in zip(instructions, inputs)] 74 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device) 75 | generation_config = GenerationConfig( 76 | temperature=temperature, 77 | top_p=top_p, 78 | top_k=top_k, 79 | num_beams=num_beams, 80 | num_return_sequences=num_beams, 81 | **kwargs, 82 | ) 83 | with torch.no_grad(): 84 | generation_output = model.generate( 85 | **inputs, 86 | generation_config=generation_config, 87 | return_dict_in_generate=True, 88 | output_scores=True, 89 | max_new_tokens=max_new_tokens, 90 | pad_token_id = tokenizer.eos_token_id 91 | ) 92 | s = generation_output.sequences 93 | output = tokenizer.batch_decode(s, skip_special_tokens=True) 94 | output = [_.split('Response:\n')[-1] for _ in output] 95 | real_outputs = [output[i * num_beams: (i + 1) * num_beams] for i in range(len(output) // num_beams)] 96 | return real_outputs 97 | 98 | 99 | 
outputs = [] 100 | tokenizer.pad_token_id = tokenizer.eos_token_id 101 | from tqdm import tqdm 102 | with open(test_data_path, 'r') as f: 103 | test_data = json.load(f) 104 | instructions = [_['instruction'] for _ in test_data] 105 | inputs = [_['input'] for _ in test_data] 106 | def batch(list, batch_size=batch_size): 107 | chunk_size = (len(list) - 1) // batch_size + 1 108 | for i in range(chunk_size): 109 | yield list[batch_size * i: batch_size * (i + 1)] 110 | for i, batch in tqdm(enumerate(zip(batch(instructions), batch(inputs)))): 111 | instructions, inputs = batch 112 | output = evaluate(instructions, inputs) 113 | outputs = outputs + output 114 | 115 | for i, test in tqdm(enumerate(test_data)): 116 | test_data[i]['predict'] = outputs[i] 117 | 118 | 119 | with open(result_json_data, 'w') as f: 120 | json.dump(test_data, f, indent=4) 121 | 122 | def generate_prompt(instruction, input=None): 123 | if input: 124 | return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 125 | 126 | ### Instruction: 127 | {instruction} 128 | 129 | ### Input: 130 | {input} 131 | 132 | ### Response: 133 | """ 134 | else: 135 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 
136 | 137 | ### Instruction: 138 | {instruction} 139 | 140 | ### Response: 141 | """ 142 | 143 | 144 | if __name__ == "__main__": 145 | fire.Fire(main) -------------------------------------------------------------------------------- /figs/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RegionCh/SPRec/96dff9b5a42c6cc227d07fcee6b3d8813865bac0/figs/method.png -------------------------------------------------------------------------------- /shell/SFT.sh: -------------------------------------------------------------------------------- 1 | base_model="" # Specify your base model here 2 | gpu1=$1; gpu2=$2; gpu3=$3; gpu4=$4 3 | sample=4096 4 | for category in "MovieLens" "Goodreads" "CDs_and_Vinyl" "Steam" 5 | do 6 | echo ---------------------- SFT for category $category starting! ---------------------- 7 | train_dataset="./data/${category}/train.json" 8 | valid_dataset="./data/${category}/valid.json" 9 | output_dir="./models/SFT_${sample}/${category}" 10 | mkdir -p $output_dir 11 | CUDA_VISIBLE_DEVICES=$gpu1,$gpu2,$gpu3,$gpu4 python ./train/sft.py \ 12 | --output_dir $output_dir\ 13 | --base_model $base_model \ 14 | --train_dataset $train_dataset \ 15 | --valid_dataset $valid_dataset \ 16 | --train_sample_size $sample \ 17 | --wandb_project SFT_${category}_${sample} \ 18 | --wandb_name SFT_${category}_${sample} \ 19 | --gradient_accumulation_steps 4 \ 20 | --batch_size 4 \ 21 | --num_train_epochs 4 \ 22 | --learning_rate 0.0003 \ 23 | --cutoff_len 512 24 | 25 | bash ./shell/eval_single_file.sh $gpu1 $gpu2 $gpu3 $gpu4 \ 26 | $base_model \ 27 | $output_dir \ 28 | $category \ 29 | $topk 30 | done 31 | 32 | -------------------------------------------------------------------------------- /shell/SPRec.sh: -------------------------------------------------------------------------------- 1 | gpu1=$1; gpu2=$2; gpu3=$3; gpu4=$4; its=$5 2 | train_sample_size=2048;valid_sample_size=256 3 | 
base_model="/data/chenruijun/code/models/Llama-3.2-1B-Instruct" 4 | batch_size=4 5 | lr=0.00002 6 | # Only change the parameters above if needed 7 | for category in "MovieLens" "Goodreads" "CDs_and_Vinyl" "Video_Games" "Steam" 8 | do 9 | lora_weights="./models/SFT_model_4096/${category}" 10 | output_dir="./models/SPRec/${category}_${train_sample_size}_${lr}" 11 | wandb_project="SPRec_${category}_${lr}_${train_sample_size}" 12 | echo ----------------- Training Parameters ----------------- 13 | echo "GPU: $gpu1,$gpu2,$gpu3,$gpu4" 14 | echo "Iterations: $its" 15 | echo "Train Sample Size: $train_sample_size" 16 | echo "Valid Sample Size: $valid_sample_size" 17 | echo "Base Model: $base_model" 18 | echo "LoRA Weights: $lora_weights" 19 | echo "Category: $category" 20 | echo "Learning Rate: $lr" 21 | 22 | for ((i=0;i<$its;i++)) 23 | do 24 | echo ----------------- Iteration$i starts! ----------------- 25 | it_output_dir="${output_dir}/it${i}/" 26 | dpo_train_data_path="${it_output_dir}/data/dpo_train.jsonl" 27 | dpo_valid_data_path="${it_output_dir}/data/dpo_valid.jsonl" 28 | sft_train_data_path="${it_output_dir}/data/sft_train.jsonl" 29 | sft_valid_data_path="${it_output_dir}/data/sft_valid.jsonl" 30 | mkdir -p $it_output_dir 31 | mkdir -p "${it_output_dir}/data" 32 | touch "${dpo_train_data_path}" 33 | touch "${dpo_valid_data_path}" 34 | touch "${sft_train_data_path}" 35 | touch "${sft_valid_data_path}" 36 | # Data Generation 37 | CUDA_VISIBLE_DEVICES=$gpu1,$gpu2,$gpu3,$gpu4 python ./train/data_generate.py \ 38 | --train_json_file ./data/${category}/train.json \ 39 | --valid_json_file ./data/${category}/valid.json \ 40 | --result_json_dpo_data_train $dpo_train_data_path \ 41 | --result_json_dpo_data_valid $dpo_valid_data_path \ 42 | --result_json_sft_data_train $sft_train_data_path \ 43 | --result_json_sft_data_valid $sft_valid_data_path \ 44 | --base_model $base_model \ 45 | --lora_weights $lora_weights \ 46 | --batch_size 64 \ 47 | --train_sample_size 
$train_sample_size \ 48 | --valid_sample_size $valid_sample_size \ 49 | # SFT 50 | wandb_name="iteration${i}_SFT" 51 | SFT_path="${it_output_dir}SFT" 52 | mkdir -p $SFT_path 53 | CUDA_VISIBLE_DEVICES=$gpu1,$gpu2,$gpu3,$gpu4 python ./train/sft.py \ 54 | --resume_from_checkpoint $lora_weights \ 55 | --output_dir $SFT_path \ 56 | --base_model $base_model \ 57 | --train_dataset $sft_train_data_path \ 58 | --valid_dataset $sft_valid_data_path \ 59 | --train_sample_size $train_sample_size \ 60 | --wandb_project $wandb_project \ 61 | --wandb_name $wandb_name \ 62 | --gradient_accumulation_steps 4 \ 63 | --batch_size $batch_size \ 64 | --num_train_epochs 1 \ 65 | --learning_rate $lr \ 66 | --cutoff_len 512 \ 67 | # Evaluate SFT model 68 | lora_weights=$SFT_path 69 | bash ./shell/eval_single_file.sh $gpu1 $gpu2 $gpu3 $gpu4 \ 70 | $base_model \ 71 | $lora_weights \ 72 | $category 73 | # DPO 74 | wandb_name="iteration${i}_DPO" 75 | DPO_path="${it_output_dir}DPO/" 76 | mkdir -p $DPO_path 77 | CUDA_VISIBLE_DEVICES=$gpu1,$gpu2,$gpu3,$gpu4 python ./train/dpo.py \ 78 | --train_dataset $dpo_train_data_path \ 79 | --val_dataset $dpo_valid_data_path \ 80 | --output_dir $DPO_path \ 81 | --base_model $base_model \ 82 | --resume_from_checkpoint $lora_weights \ 83 | --wandb_name $wandb_name \ 84 | --wandb_project $wandb_project \ 85 | --batch_size 2 \ 86 | --gradient_accumulation_steps 4 \ 87 | --learning_rate $lr \ 88 | --cutoff_len 512 \ 89 | --num_epochs 1 90 | # Evaluate DPO model 91 | lora_weights=$DPO_path 92 | bash ./shell/eval_single_file.sh $gpu1 $gpu2 $gpu3 $gpu4 \ 93 | $base_model \ 94 | $lora_weights \ 95 | $category 96 | 97 | done 98 | echo SPRec for category ${category} has successfully completed! 
99 | done -------------------------------------------------------------------------------- /shell/eval_single_file.sh: -------------------------------------------------------------------------------- 1 | # bash ./shell/eval_single_file.sh 0 2 4 5 2 | base_model=$5 3 | lora_weights=$6 4 | category=$7 5 | # Only change the parameters above if needed 6 | echo -------------------------------------- Evaluation started! -------------------------------------- 7 | 8 | gpu1=$1; gpu2=$2; gpu3=$3; gpu4=$4 9 | test_json="./data/$category/test.json" 10 | result_json="${lora_weights}/test_result.json" 11 | touch $result_json 12 | CUDA_VISIBLE_DEVICES=$gpu1,$gpu2,$gpu3,$gpu4 python ./eval/inference.py \ 13 | --base_model $base_model \ 14 | --lora_weights $lora_weights \ 15 | --test_data_path $test_json \ 16 | --result_json_data $result_json \ 17 | --num_beams 1 18 | echo Result for model "$lora_weights" is created in $result_json! 19 | eval_result_json="${lora_weights}/eval_top1.json" 20 | CUDA_VISIBLE_DEVICES=$1 python ./eval/evaluate.py \ 21 | --input_dir $result_json \ 22 | --output_dir $eval_result_json \ 23 | --topk 1 \ 24 | --gamma 0 \ 25 | --category $category 26 | echo Metrics for model "$lora_weights" is created in $eval_result_json! 27 | eval_result_json="${lora_weights}/eval_top5.json" 28 | CUDA_VISIBLE_DEVICES=$1 python ./eval/evaluate.py \ 29 | --input_dir $result_json \ 30 | --output_dir $eval_result_json \ 31 | --topk 5 \ 32 | --gamma 0 \ 33 | --category $category 34 | echo Metrics for model "$lora_weights" is created in $eval_result_json! 35 | 36 | echo -------------------------------------- Evaluation finished! 
-------------------------------------- -------------------------------------------------------------------------------- /train/data_generate.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import sys 4 | import fire 5 | import gradio as gr 6 | import numpy as np 7 | import torch 8 | torch.set_num_threads(1) 9 | from sentence_transformers import SentenceTransformer 10 | import random 11 | import transformers 12 | from tqdm import tqdm 13 | import json 14 | import os 15 | os.environ['OPENBLAS_NUM_THREADS'] = '1' 16 | os.environ['OMP_NUM_THREADS'] = '1' 17 | from peft import PeftModel 18 | from transformers import GenerationConfig,AutoTokenizer 19 | from transformers import LlamaForCausalLM 20 | if torch.cuda.is_available(): 21 | device = "cuda" 22 | else: 23 | device = "cpu" 24 | 25 | def main( 26 | train_json_file : str = "", 27 | valid_json_file : str = "", 28 | result_json_dpo_data_train: str = "", 29 | result_json_dpo_data_valid: str = "", 30 | result_json_sft_data_train: str = "", 31 | result_json_sft_data_valid: str = "", 32 | base_model: str = "", 33 | lora_weights: str = "", 34 | batch_size:int = 4, 35 | train_sample_size:int = 1024, 36 | valid_sample_size:int = 128, 37 | load_8bit: bool = False, 38 | random_neg: bool = False, 39 | ): 40 | 41 | # generate responses from model 42 | tokenizer = AutoTokenizer.from_pretrained(base_model) 43 | tokenizer.pad_token_id = tokenizer.eos_token_id 44 | load_8bit = False 45 | if device == "cuda": 46 | model = LlamaForCausalLM.from_pretrained( 47 | base_model, 48 | load_in_8bit=load_8bit, 49 | torch_dtype=torch.float16, 50 | device_map="auto", 51 | ) 52 | model = PeftModel.from_pretrained( 53 | model, 54 | lora_weights, 55 | torch_dtype=torch.float16, 56 | device_map="auto" 57 | ) 58 | tokenizer.padding_side = "left" 59 | 60 | model.eval() 61 | 62 | #emb_model = SentenceTransformer('/data/chenruijun/code/models/paraphrase-MiniLM-L3-v2') 63 | 64 | def 
evaluate( 65 | instructions, 66 | inputs=None, 67 | temperature=0, 68 | top_p=0.9, 69 | top_k=40, 70 | num_beams=1, 71 | max_new_tokens=128, 72 | **kwargs, 73 | ): 74 | prompt = [generate_prompt(instruction, input) for instruction, input in zip(instructions, inputs)] 75 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device) 76 | generation_config = GenerationConfig( 77 | temperature=temperature, 78 | top_p=top_p, 79 | top_k=top_k, 80 | num_beams=num_beams, 81 | num_return_sequences=num_beams, 82 | **kwargs, 83 | ) 84 | with torch.no_grad(): 85 | generation_output = model.generate( 86 | **inputs, 87 | generation_config=generation_config, 88 | return_dict_in_generate=True, 89 | output_scores=True, 90 | max_new_tokens=max_new_tokens, 91 | pad_token_id = tokenizer.eos_token_id 92 | ) 93 | s = generation_output.sequences 94 | output = tokenizer.batch_decode(s, skip_special_tokens=True) 95 | output = [_.split('Response:\n')[-1] for _ in output] 96 | real_outputs = [output[i * num_beams: (i + 1) * num_beams] for i in range(len(output) // num_beams)] 97 | return real_outputs 98 | 99 | outputs = [] 100 | tokenizer.pad_token_id = tokenizer.eos_token_id 101 | 102 | with open(train_json_file, 'r') as f: 103 | train_data = json.load(f) 104 | train_data = random.sample(train_data, train_sample_size) 105 | sft_train_data = train_data 106 | with open(valid_json_file, 'r') as f: 107 | valid_data = json.load(f) 108 | valid_data = random.sample(valid_data, valid_sample_size) 109 | sft_valid_data = valid_data 110 | with open(result_json_sft_data_train, 'w') as f: 111 | for item in sft_train_data: 112 | json.dump(item, f) 113 | f.write('\n') 114 | with open(result_json_sft_data_valid, 'w') as f: 115 | for item in sft_valid_data: 116 | json.dump(item, f) 117 | f.write('\n') 118 | data = train_data + valid_data 119 | instructions = [_['instruction'] for _ in data] 120 | inputs = [_['input'] for _ in data] 121 | def batch(list, batch_size=batch_size): 
122 | chunk_size = (len(list) - 1) // batch_size + 1 123 | for i in range(chunk_size): 124 | yield list[batch_size * i: batch_size * (i + 1)] 125 | for i, batch in tqdm(enumerate(zip(batch(instructions), batch(inputs)))): 126 | instructions, inputs = batch 127 | output = evaluate(instructions, inputs) 128 | outputs = outputs + output 129 | 130 | for i, test in tqdm(enumerate(data)): 131 | data[i]['predict'] = outputs[i] 132 | 133 | dpo_data = [] 134 | 135 | for data_point in data: 136 | dpo_case = {} 137 | dpo_case['prompt'] = data_point['instruction'] + data_point['input'] 138 | dpo_case['chosen'] = data_point['output'] 139 | pattern = r'"(.*?)"' 140 | item_names = re.findall(pattern, data_point['predict'][0]) 141 | formatted_item_names = [f'\"{item}\"' for item in item_names] 142 | if len(formatted_item_names) > 0: 143 | dpo_case['rejected'] = formatted_item_names[0]+"\n" 144 | else: 145 | dpo_case['rejected'] = "\n" 146 | dpo_data.append(dpo_case) 147 | 148 | # random.shuffle(dpo_data) 149 | dpo_train_data = dpo_data[:train_sample_size] 150 | dpo_valid_data = dpo_data[train_sample_size:] 151 | 152 | 153 | with open(result_json_dpo_data_train, 'w') as f: 154 | for item in dpo_train_data: 155 | json.dump(item, f) 156 | f.write('\n') 157 | 158 | with open(result_json_dpo_data_valid, 'w') as f: 159 | for item in dpo_valid_data: 160 | json.dump(item, f) 161 | f.write('\n') 162 | 163 | 164 | def generate_prompt(instruction, input=None): 165 | if input: 166 | return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 167 | 168 | ### Instruction: 169 | {instruction} 170 | 171 | ### Input: 172 | {input} 173 | 174 | ### Response: 175 | """ 176 | else: 177 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 
178 | 179 | ### Instruction: 180 | {instruction} 181 | 182 | ### Response: 183 | """ 184 | 185 | 186 | if __name__ == "__main__": 187 | fire.Fire(main) 188 | -------------------------------------------------------------------------------- /train/dpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import re 4 | import random 5 | 6 | from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType, PeftModel 7 | from transformers import AutoTokenizer, TrainingArguments, AutoModelForCausalLM, BitsAndBytesConfig 8 | from datasets import load_dataset, load_from_disk 9 | from trl import DPOTrainer, DPOConfig 10 | from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model 11 | # from utils import find_all_linear_names, print_trainable_parameters 12 | from transformers import LlamaForCausalLM, LlamaTokenizer 13 | import torch.nn.functional as F 14 | import torch 15 | import bitsandbytes as bnb 16 | from accelerate import Accelerator 17 | import fire 18 | 19 | 20 | def main( 21 | train_dataset = "", 22 | val_dataset = "", 23 | load_8bit: bool = True, 24 | base_model: str = "", 25 | gradient_accumulation_steps: int = 4, 26 | output_dir: str = "", 27 | wandb_project: str = "self_play", 28 | wandb_name: str = "", # the name of the wandb run 29 | batch_size:int = 2, 30 | num_epochs:int = 1, 31 | alpha:float = 1.5, 32 | learning_rate: float = 1e-5, 33 | cutoff_len: int = 512, 34 | eval_step = 0.05, 35 | resume_from_checkpoint:bool = False, 36 | seed = 99 37 | ): 38 | 39 | os.environ['WANDB_PROJECT'] = wandb_project 40 | 41 | train_dataset = load_dataset("json", data_files=train_dataset) 42 | train_data = train_dataset["train"].shuffle(seed=seed) 43 | val_dataset = load_dataset("json", data_files=val_dataset) 44 | val_data = val_dataset["train"].shuffle(seed=seed) 45 | 46 | device_index = Accelerator().process_index 47 | device_map = {"": device_index} 48 | 
49 | bnb_config = BitsAndBytesConfig( 50 | # load_in_8bit=True, 51 | load_in_4bit=True, 52 | bnb_4bit_quant_type="nf4", 53 | bnb_4bit_compute_dtype=torch.bfloat16, 54 | bnb_4bit_use_double_quant=False, 55 | ) 56 | 57 | device_index = Accelerator().process_index 58 | device_map = {"": device_index} 59 | 60 | model = AutoModelForCausalLM.from_pretrained( 61 | base_model, 62 | device_map=device_map, 63 | quantization_config=bnb_config 64 | ) 65 | model.config.use_cache = False 66 | model = prepare_model_for_kbit_training(model) 67 | 68 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) 69 | tokenizer.pad_token = tokenizer.eos_token 70 | tokenizer.padding_side = "right" 71 | 72 | if resume_from_checkpoint!="base_model": 73 | model = PeftModel.from_pretrained( 74 | model, 75 | resume_from_checkpoint, 76 | is_trainable=True 77 | ) 78 | else: 79 | peft_config = LoraConfig( 80 | inference_mode=False, 81 | r=16, 82 | lora_alpha=32, 83 | target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'], 84 | lora_dropout=0.05, 85 | bias="none", 86 | task_type="CAUSAL_LM", 87 | ) 88 | model = get_peft_model(model, peft_config) 89 | 90 | model.print_trainable_parameters() 91 | 92 | model_ref = AutoModelForCausalLM.from_pretrained( 93 | base_model, 94 | device_map=device_map, 95 | quantization_config=bnb_config 96 | ) 97 | 98 | if resume_from_checkpoint: 99 | reference_model = PeftModel.from_pretrained(model_ref, resume_from_checkpoint) 100 | else: 101 | reference_model = model_ref 102 | 103 | 104 | training_args = DPOConfig( 105 | per_device_train_batch_size=batch_size, 106 | per_device_eval_batch_size=batch_size, 107 | gradient_accumulation_steps=gradient_accumulation_steps, 108 | warmup_steps=20, 109 | num_train_epochs=num_epochs, 110 | learning_rate=learning_rate, 111 | bf16=True, 112 | logging_steps=1, 113 | optim="adamw_torch", 114 | evaluation_strategy="steps", 115 | save_strategy="steps", 116 | 
output_dir=output_dir, 117 | save_total_limit=1, 118 | load_best_model_at_end=True, 119 | ) 120 | 121 | dpo_trainer = DPOTrainer( 122 | model, 123 | reference_model, 124 | args=training_args, 125 | beta=0.1, 126 | train_dataset=train_data, 127 | eval_dataset=val_data, 128 | tokenizer=tokenizer, 129 | max_prompt_length=cutoff_len, 130 | max_length=cutoff_len, 131 | ) 132 | 133 | 134 | dpo_trainer.train() 135 | dpo_trainer.save_model(output_dir) 136 | 137 | 138 | print("DPO training is done") 139 | 140 | if __name__ == "__main__": 141 | fire.Fire(main) 142 | -------------------------------------------------------------------------------- /train/sft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import re 4 | import wandb 5 | 6 | from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments,BitsAndBytesConfig 7 | from datasets import load_dataset 8 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig 9 | from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType, PeftModel 10 | from transformers import LlamaForCausalLM, LlamaTokenizer 11 | # from utils import find_all_linear_names, print_trainable_parameters 12 | import random 13 | from accelerate import Accelerator 14 | 15 | import torch 16 | import bitsandbytes as bnb 17 | import fire 18 | 19 | 20 | def train( 21 | # path 22 | output_dir="", 23 | base_model ="", 24 | train_dataset="", 25 | valid_dataset="", 26 | train_sample_size:int = 1024, 27 | resume_from_checkpoint: str = "base_model", # either training checkpoint or final adapter 28 | # wandb config 29 | wandb_project: str = "", 30 | wandb_name: str = "", # the name of the wandb run 31 | # training hyperparameters 32 | gradient_accumulation_steps: int = 1, 33 | batch_size: int = 8, 34 | num_train_epochs: int = 5, 35 | learning_rate: float = 2e-5, 36 | cutoff_len: int = 512, 37 | eval_step = 0.05, 38 
| seed=0 39 | ): 40 | os.environ['WANDB_PROJECT'] = wandb_project 41 | 42 | def formatting_prompts_func(examples): 43 | output_text = [] 44 | for i in range(len(examples["instruction"])): 45 | instruction = examples["instruction"][i] 46 | input_text = examples["input"][i] 47 | response = examples["output"][i] 48 | 49 | if len(input_text) >= 2: 50 | text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 51 | 52 | ### Instruction: 53 | {instruction} 54 | 55 | ### Input: 56 | {input_text} 57 | 58 | ### Response: 59 | {response} 60 | ''' 61 | else: 62 | text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 63 | 64 | ### Instruction: 65 | {instruction} 66 | 67 | ### Response: 68 | {response} 69 | ''' 70 | output_text.append(text) 71 | 72 | return output_text 73 | 74 | train_dataset = load_dataset("json", data_files=train_dataset) 75 | train_data = train_dataset["train"].shuffle(seed=seed).select(range(train_sample_size)) 76 | val_dataset = load_dataset("json", data_files=valid_dataset) 77 | val_data = val_dataset["train"].shuffle(seed=seed).select(range(int(train_sample_size/8))) 78 | 79 | bnb_config = BitsAndBytesConfig( 80 | # load_in_8bit=True, 81 | load_in_4bit=True, 82 | bnb_4bit_quant_type="nf4", 83 | bnb_4bit_compute_dtype=torch.bfloat16, 84 | bnb_4bit_use_double_quant=False, 85 | ) 86 | 87 | device_index = Accelerator().process_index 88 | device_map = {"": device_index} 89 | #device_map = "auto" 90 | model = AutoModelForCausalLM.from_pretrained( 91 | base_model, 92 | device_map=device_map, 93 | quantization_config=bnb_config 94 | ) 95 | model.config.use_cache = False 96 | model = prepare_model_for_kbit_training(model) 97 | 98 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) 99 | tokenizer.pad_token = 
tokenizer.eos_token 100 | tokenizer.padding_side = "right" 101 | 102 | if resume_from_checkpoint!="base_model": 103 | model = PeftModel.from_pretrained( 104 | model, 105 | resume_from_checkpoint, 106 | is_trainable=True 107 | ) 108 | else: 109 | peft_config = LoraConfig( 110 | inference_mode=False, 111 | r=16, 112 | lora_alpha=32, 113 | target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'], 114 | lora_dropout=0.05, 115 | bias="none", 116 | task_type="CAUSAL_LM", 117 | ) 118 | model = get_peft_model(model, peft_config) 119 | 120 | model.print_trainable_parameters() 121 | 122 | training_args = SFTConfig( 123 | per_device_train_batch_size=batch_size, 124 | per_device_eval_batch_size=batch_size, 125 | gradient_accumulation_steps=gradient_accumulation_steps, 126 | warmup_steps=20, 127 | num_train_epochs=num_train_epochs, 128 | learning_rate=learning_rate, 129 | bf16=True, 130 | logging_steps=1, 131 | optim="adamw_torch", 132 | evaluation_strategy="steps", 133 | save_strategy="steps", 134 | output_dir=output_dir, 135 | save_total_limit=1, 136 | load_best_model_at_end=True, 137 | report_to=None, 138 | ) 139 | 140 | trainer = SFTTrainer( 141 | model, 142 | train_dataset=train_data, 143 | eval_dataset=val_data, 144 | tokenizer=tokenizer, 145 | formatting_func=formatting_prompts_func, 146 | max_seq_length=cutoff_len, 147 | args=training_args 148 | ) 149 | 150 | trainer.train() 151 | trainer.save_model(output_dir) 152 | 153 | output_dir = os.path.join(output_dir, "final_model") 154 | trainer.model.save_pretrained(output_dir,safe_serialization=False) 155 | tokenizer.save_pretrained(output_dir) 156 | 157 | if __name__ == "__main__": 158 | fire.Fire(train) --------------------------------------------------------------------------------