├── README.md
├── baselines
│   ├── D3
│   │   ├── D3_evaluate.py
│   │   ├── LogitProcesser.py
│   │   └── evaluate2.sh
│   ├── DMPO
│   │   ├── dmpo_trainer.py
│   │   └── utils.py
│   ├── Re-weighting
│   │   └── RW_SFT.py
│   ├── SDPO
│   │   ├── softmax_dpo_trainer.py
│   │   └── utils.py
│   └── Semantic_sampling_rosePO
│       └── Semantic_sampling_rosePO.py
├── data
│   └── Goodreads
│       ├── test.json
│       ├── train.json
│       └── valid.json
├── environment.yml
├── eval
│   ├── Goodreads
│   │   ├── embeddings.pt
│   │   ├── genre_dict.json
│   │   ├── id2name.json
│   │   ├── name2genre.json
│   │   └── name2id.json
│   ├── MovieLens
│   │   ├── embeddings.pt
│   │   ├── genre_dict.json
│   │   ├── id2name.json
│   │   ├── name2genre.json
│   │   └── name2id.json
│   ├── evaluate.py
│   └── inference.py
├── figs
│   └── method.png
├── shell
│   ├── SFT.sh
│   ├── SPRec.sh
│   └── eval_single_file.sh
└── train
    ├── data_generate.py
    ├── dpo.py
    └── sft.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # SPRec: Self-Play to Debias LLM-based Recommendation
3 |
4 |
5 |
6 |
7 |
8 | This repository provides the official PyTorch implementation and reproduction code for the paper "SPRec: Self-Play to Debias LLM-based Recommendation".
9 |
10 | ## Installation
11 |
12 | 1. Clone this git repository and change into its directory.
13 |
14 | 2. A new [conda environment](https://docs.conda.io/projects/conda/en/latest/user-guide/concepts/environments.html) is suggested; create it from the provided `environment.yml`:
15 |
16 | ```bash
17 | conda env create -f environment.yml
18 | ```
19 |
20 | 3. Activate the newly created environment.
21 |
22 | ```bash
23 | conda activate SPRec
24 | ```
25 |
26 |
27 | ## Quick Start
28 |
29 | Due to GitHub's file size limitations, this repository only includes the minimal sample dataset **Goodreads**, in `./data/Goodreads` and `./eval/Goodreads`, for reproduction purposes. The other datasets used in our experiments (**MovieLens**, **CDs and Vinyl**, and **Steam**) are available at [Datasets](https://zenodo.org/records/14900102?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6IjMwYTA1OWM4LWRjZTctNDJmNC1iOWY2LTRjZWQyZjZiNjY5ZCIsImRhdGEiOnt9LCJyYW5kb20iOiI2ZTYyZDZkZTFlNDM5NjA2ZGMwMTA2YWIxMjdjMDJmNCJ9.g2bckZWGA77AEg9EBARxN45rmXYfGD8RuRzy41CZACcDh2XESWxAGD3b91ecu_FEbmYQSzR5qBTH0xvQC_Lw2Q). If you wish to use a different dataset, please ensure that it is processed into the same format.
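For reference, each entry in `train.json` / `valid.json` / `test.json` is a record with `instruction`, `input`, and `output` fields (these field names are the ones read by the training scripts; the concrete wording and titles below are made-up placeholders):

```json
{
    "instruction": "Given the user's reading history, recommend the next book.",
    "input": "The user has read the following books: \"Book A\", \"Book B\", \"Book C\".",
    "output": "\"Book D\"\n"
}
```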
30 |
31 | In addition, to ensure that SPRec does not see more training data over its multiple iterations than the baseline methods do, we recommend sampling the training dataset beforehand to limit its size. The sample dataset we provide has already been sampled down to 5,000 entries; you can subsample it further to control the total amount of data SPRec is exposed to during training.
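A minimal sketch of such pre-sampling (the file paths and subset size are illustrative; adjust them to your layout):

```python
import json
import random

random.seed(0)  # fix the seed so every method trains on the same subset

with open("./data/Goodreads/train.json") as f:
    data = json.load(f)

subset = random.sample(data, k=2000)  # e.g., keep 2,000 of the 5,000 entries

with open("./data/Goodreads/train_sampled.json", "w") as f:
    json.dump(subset, f, indent=4)
```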
32 |
33 |
34 | ### How to Train Using SPRec Framework
35 |
36 | 1. **SFT Training**:
37 | Before using the SPRec training framework, you need to run SFT to fine-tune your base model for alignment with the recommendation task. Use the following command to perform SFT training:
38 | ```bash
39 | bash ./shell/SFT.sh 0 1 2 3 # Specify your GPUs, e.g., 0 1 2 3
40 | ```
41 | 2. **SPRec Training**:
42 | After completing SFT training, use the following command to perform SPRec training:
43 | ```bash
44 | bash ./shell/SPRec.sh 0 1 2 3 5 # Specify your GPUs, e.g., 0 1 2 3, and the number of iterations, e.g., 5
45 | ```
46 | Once the above commands have been executed, the evaluation results for top-1 and top-5 recommendations are saved as `eval_top1.json` and `eval_top5.json` in the corresponding model directory.
47 | 
48 | ## Baseline Implementations Acknowledgement
49 | This repository also includes implementations of the baseline methods from our paper for research comparison. We sincerely acknowledge the original authors for their foundational work.
50 | 
51 | If you find this repository helpful, we kindly request citing our paper:
52 | ```
53 | @article{gao2024sprec,
54 |   title={SPRec: Self-Play to Debias LLM-based Recommendation},
55 |   author={Gao, Chongming and Chen, Ruijun and Yuan, Shuai and Huang, Kexin and Yu, Yuanqing and He, Xiangnan},
56 |   journal={arXiv preprint arXiv:2412.09243},
57 |   year={2024}
58 | }
59 | ```
60 | 
--------------------------------------------------------------------------------
/baselines/D3/D3_evaluate.py:
--------------------------------------------------------------------------------
1 |
2 | import pandas as pd
3 | import fire
4 | import torch
5 | import json
6 | import os
7 | from peft import PeftModel
8 | from transformers import AutoTokenizer
9 | from transformers import GenerationConfig
10 | from transformers import LlamaForCausalLM
11 | from transformers import LogitsProcessorList, TemperatureLogitsWarper
12 | from dataset import D3Dataset
13 | from LogitProcesser import CFEnhancedLogitsProcessor
14 | 
15 | if torch.cuda.is_available():
16 | device = "cuda"
17 | else:
18 | device = "cpu"
19 | P = 998244353
20 | MOD = int(1e9 + 9)
21 | import numpy as np
22 |
23 | def get_hash(x):
24 | x = [str(_) for _ in x]
25 | return '-'.join(x)
26 |
27 |
28 |
29 | def main(
30 | base_model: str = "",
31 | train_file: str = "",
32 | info_file: str = "",
33 | category: str = "",
34 | logits_file: str=None,
35 | lora_weights:str = "",
36 | test_data_path: str = "data/test.json",
37 | result_json_data: str = "temp.json",
38 | batch_size: int = 1,
39 | K: int = 0,
40 | seed: int = 0,
41 | temperature: float=1.0,
42 | guidance_scale: float=1.0,
43 | length_penalty: float=1.0
44 | ):
45 | category_dict = {"Office_Products": "office products", "Goodreads": "books", "Steam": "games", "CDs_and_Vinyl": "musics", "Toys_and_Games": "toys and games", "Video_Games": "video games", "Musical_Instruments": "music instruments", "Sports_and_Outdoors": "sports and outdoors", "Pet_Supplies": "pet supplies", "Arts_Crafts_and_Sewing": "arts products", "STEAM": "games" ,"MovieLens":"movies"}
46 | category = category_dict[category]
47 | model = LlamaForCausalLM.from_pretrained(
48 | base_model,
49 | load_in_8bit=False,
50 | torch_dtype=torch.float16,
51 | device_map="auto",
52 | )
53 | model = PeftModel.from_pretrained(
54 | model,
55 | lora_weights,
56 | torch_dtype=torch.float16,
57 | device_map="auto"
58 | )
59 | with open(info_file, 'r') as f:
60 | info = f.readlines()
61 | print(info)
62 | info = ["\"" + _.split('\t')[0].strip(' ') + "\"\n" for _ in info]
63 | item_name = info
64 | info = [f'''### Response:
65 | {_}''' for _ in info]
66 |
67 | tokenizer = AutoTokenizer.from_pretrained(base_model)
68 | if base_model.lower().find("llama") > -1:
69 | prefixID = [tokenizer(_).input_ids[1:] for _ in info]
70 | else:
71 | prefixID = [tokenizer(_).input_ids for _ in info]
72 |
73 | hash_dict = dict()
74 | sasrec_dict = dict()
75 | for index, ID in enumerate(prefixID):
76 | ID.append(tokenizer.eos_token_id)
77 | for i in range(4, len(ID)):
78 | if i == 4:
79 | hash_number = get_hash(ID[:i])
80 | else:
81 | hash_number = get_hash(ID[4:i])
82 | if hash_number not in hash_dict:
83 | hash_dict[hash_number] = set()
84 | sasrec_dict[hash_number] = set()
85 | hash_dict[hash_number].add(ID[i])
86 | sasrec_dict[hash_number].add(index)
87 | hash_number = get_hash(ID[4:])
88 | if hash_number not in sasrec_dict:
89 | sasrec_dict[hash_number] = set()
90 | sasrec_dict[hash_number].add(index)
91 |
92 | for key in hash_dict.keys():
93 | hash_dict[key] = list(hash_dict[key])
94 | for key in sasrec_dict.keys():
95 | sasrec_dict[key] = list(sasrec_dict[key])
96 |
97 | def prefix_allowed_tokens_fn(batch_id, input_ids):
98 | hash_number = get_hash(input_ids)
99 | if hash_number in hash_dict:
100 | return hash_dict[hash_number]
101 | return []
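# Worked example of the prefix trie above (hypothetical token IDs): suppose two
# item titles tokenize, after the 4 header tokens of "### Response:\n", to
# [5, 6, 7] and [5, 6, 9]. Then hash_dict maps the hashed header prefix to {5},
# "5" -> {6}, "5-6" -> {7, 9}, and "5-6-7" -> {eos}, so prefix_allowed_tokens_fn
# restricts every decoding step to continuations of some valid item title.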
102 |
103 | tokenizer.pad_token = tokenizer.eos_token
104 | tokenizer.pad_token_id = tokenizer.eos_token_id
105 | tokenizer.padding_side = "left"
106 | val_dataset=D3Dataset(train_file=test_data_path, tokenizer=tokenizer,max_len=2560, category=category, test=True,K=K, seed=seed)
107 |
108 |
109 | if logits_file is not None:
110 | if not logits_file.endswith(".npy"):
111 | logits_file = None
112 |
113 | if logits_file is not None:
114 | logits = np.load(logits_file)
115 | sasrec_logits = torch.tensor(logits).softmax(dim = -1)
116 | sasrec_logits = sasrec_logits[val_dataset.data['Unnamed: 0'].tolist()]
117 |
118 | encodings = [val_dataset.__getitem__(i) for i in range(len(val_dataset))]
119 | test_data = val_dataset.get_all()
120 |
121 | model.config.pad_token_id = model.config.eos_token_id = tokenizer.eos_token_id
122 | model.config.bos_token_id = tokenizer.bos_token_id
123 |
124 | model.eval()
125 |
126 | def evaluate(
127 | encodings,
128 | cf_logits,
129 | temperature=1.0,
130 | num_beams=1,
131 | max_new_tokens=32,
132 | top_p=0.9,
133 | top_k=40,
134 | guidance_scale=0.8,
135 | length_penalty=1.0,
136 | **kwargs,
137 | ):
138 | maxLen = max([len(_["input_ids"]) for _ in encodings])
139 |
140 | padding_encodings = {"input_ids": []}
141 |
142 | for _ in encodings:
143 | L = len(_["input_ids"])
144 | padding_encodings["input_ids"].append([tokenizer.pad_token_id] * (maxLen - L) + _["input_ids"])
145 |
146 | generation_config = GenerationConfig(
147 | num_beams=num_beams,
148 | temperature = temperature,
149 | #length_penalty=length_penalty,
150 | top_p=top_p,
151 | top_k=top_k,
152 | num_return_sequences=num_beams,
153 | pad_token_id = model.config.pad_token_id,
154 | eos_token_id = model.config.eos_token_id,
155 | max_new_tokens = max_new_tokens,
156 | **kwargs
157 | )
158 | with torch.no_grad():
159 | ccc = CFEnhancedLogitsProcessor(
160 | guidance_scale=guidance_scale,
161 | cf_logits=cf_logits,
162 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
163 | cf_dict=sasrec_dict,
164 | unconditional_ids=None,
165 | model=model,
166 | tokenizer=tokenizer,
167 | num_beams=num_beams
168 | )
169 | logits_processor = LogitsProcessorList([TemperatureLogitsWarper(temperature=temperature), ccc])
170 | # align the logits
171 | generation_output = model.generate(
172 | torch.tensor(padding_encodings["input_ids"]).to(device),
173 | generation_config=generation_config,
174 | return_dict_in_generate=True,
175 | output_scores=True,
176 | logits_processor=logits_processor,
177 | )
178 | s = generation_output.sequences[:, maxLen:]  # inputs are left-padded to maxLen, so generated tokens start there
179 | sequence_scores = [[0 for i in range(len(generation_output.scores))] for _ in range(num_beams)]
180 | #for i in range(num_beams):
181 | #for j in range(L, len(generation_output.sequences[i])):
182 | #if num_beams > 1:
183 | #beam_index = generation_output.beam_indices[i][j - L]
184 | #if beam_index != -1:
185 | #sequence_scores[i][j - L] = generation_output.scores[j - L][beam_index][generation_output.sequences[i][j]].item()
186 |
187 | #scores = generation_output.sequences_scores.tolist()
188 | scores = [1958.0]  # placeholder value; predict_score is commented out below and not saved
189 | output = tokenizer.batch_decode(s, skip_special_tokens=True)
190 | output = [_.split("Response:")[-1] for _ in output]
191 | real_outputs = [output[i * num_beams: (i + 1) * num_beams] for i in range(len(output) // num_beams)]
192 | real_scores = [scores[i * num_beams: (i + 1) * num_beams] for i in range(len(scores) // num_beams)]
193 | return real_outputs, real_scores, sequence_scores
194 |
195 | model = model.to(device)
196 |
197 | from tqdm import tqdm
198 | outputs = []
199 | new_encodings = []
200 | BLOCK = (len(encodings) + batch_size - 1) // batch_size
201 | for i in range(BLOCK):
202 | new_encodings.append(encodings[i * batch_size: (i + 1) * batch_size])
203 | Flg=True
204 | scores = []
205 | seq_scores = []
206 | import random
207 | for idx, encodings in enumerate(tqdm(new_encodings)):
208 | if logits_file is not None:
209 | output, score, seq_score = evaluate(encodings, sasrec_logits[idx].to(device), temperature=temperature, guidance_scale=guidance_scale, length_penalty=length_penalty)
210 | else:
211 | output, score, seq_score = evaluate(encodings, cf_logits=None, temperature=temperature, guidance_scale=guidance_scale, length_penalty=length_penalty)
212 | if idx == 0:
213 | print(output)
214 | print(score)
215 | outputs = outputs + output
216 | scores = scores+ score
217 | seq_scores.append(seq_score)
218 |
219 | for i, test in enumerate(test_data):
220 | test["predict"] = outputs[i]
221 | #test["predict_score"] = scores[i]
222 | #test["predict_seq_score"] = seq_scores[i]
223 |
224 | for i in range(len(test_data)):
225 | if 'dedup' in test_data[i]:
226 | test_data[i].pop('dedup')
227 |
228 | with open(result_json_data, 'w') as f:
229 | json.dump(test_data, f, indent=4)
230 |
231 | if __name__ == '__main__':
232 | fire.Fire(main)
233 |
234 |
235 |
236 |
237 |
--------------------------------------------------------------------------------
/baselines/D3/LogitProcesser.py:
--------------------------------------------------------------------------------
1 | from transformers.generation import LogitsProcessor
2 | from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
3 | import math
4 | import numpy as np
5 | import torch
6 |
7 | from transformers.utils import add_start_docstrings
8 |
9 | LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
10 | Args:
11 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
12 | Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
13 | scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
14 | Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
15 | search or log softmax for each vocabulary token when using beam search
16 |
17 | Return:
18 | `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
19 |
20 | """
21 |
22 | class PrefixConstrainedLogitsProcessor(LogitsProcessor):
23 |
24 | def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
25 | self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
26 | self._num_beams = num_beams
27 |
28 | @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
29 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
30 | mask = torch.full_like(scores, -math.inf)
31 | for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
32 | for beam_id, sent in enumerate(beam_sent):
33 | prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
34 | if len(prefix_allowed_tokens) == 0:
35 | raise ValueError(
36 | f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
37 | f"This means that the constraint is unsatisfiable. Please check your implementation"
38 | f"of `prefix_allowed_tokens_fn` "
39 | )
40 | mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0
41 |
42 | scores_processed = scores + mask
43 | return scores_processed
44 |
45 |
46 | def get_hash(x):
47 | x = [str(_) for _ in x]
48 | return '-'.join(x)
49 |
50 | class CFEnhancedLogitsProcessor(LogitsProcessor):
51 |
52 | def __init__(
53 | self,
54 | tokenizer,
55 | model,
56 | cf_logits,
57 | cf_dict,
58 | guidance_scale: float,
59 | prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
60 | num_beams: int,
61 | unconditional_ids: Optional[torch.LongTensor] = None,
62 | unconditional_attention_mask: Optional[torch.LongTensor] = None,
63 | use_cache: Optional[bool] = True,
64 | ):
65 | self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
66 | self.model = model
67 | self.unconditional_context = {
68 | "input_ids": unconditional_ids,
69 | "attention_mask": unconditional_attention_mask,
70 | "use_cache": use_cache,
71 | "past_key_values": None,
72 | "first_pass": True,
73 | }
74 | self._num_beams = num_beams
75 | self.guidance_scale = guidance_scale
76 | self.tokenizer = tokenizer
77 | self.cf_logits = cf_logits
78 | self.cf_dict = cf_dict
79 | self.count=0
80 |
81 |
82 | @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
83 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
84 | scores = torch.nn.functional.log_softmax(scores, dim=-1)
85 | mask = torch.full_like(scores, -1000000)
86 | cf_score = torch.full_like(scores, 1.0)
87 | for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
88 | for beam_id, sent in enumerate(beam_sent):
89 | if self.count == 0:
90 | hash_key = sent[-4:]
91 | else:
92 | hash_key=sent[-self.count:]
93 | hash_key = hash_key.tolist()
94 | prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, hash_key)
95 |
96 | if len(prefix_allowed_tokens) == 0:
97 | continue
98 | mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0
99 |
100 | temp = []
101 | if self.cf_logits is not None:
102 | # print(self.cf_logits)
103 | for allow_token in prefix_allowed_tokens:
104 | if self.count == 0:
105 | cf_key = [allow_token]
106 | else:
107 | cf_key = hash_key + [allow_token]
108 | if get_hash(cf_key) in self.cf_dict:
109 | hash_value = self.cf_dict[get_hash(cf_key)]
110 | else:
111 | continue
112 |
113 | sublogits = self.cf_logits[hash_value]
114 | temp.append(sublogits.sum() + 1e-20) # max or sum
115 | temp = torch.tensor(temp)
116 | temp = temp / temp.sum()
117 | cf_score[batch_id * self._num_beams + beam_id].scatter_(dim = -1, index=torch.tensor(prefix_allowed_tokens).to(cf_score.device), src=temp.to(cf_score.device))
118 | cf_score = torch.log(cf_score)
119 | cf_score = cf_score + mask
120 | self.count += 1
121 |
122 | if self.guidance_scale == 1:
123 | scores = scores + mask
124 | return scores
125 |
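# Classifier-free-guidance-style mixing; a sketch of the line below:
#   out = gamma * log p_LM + (1 - gamma) * log p_CF   (up to the shared mask)
# with gamma = guidance_scale, log p_LM = scores and log p_CF = cf_score, so
# gamma > 1 pushes generation toward the language model and away from the
# collaborative-filtering prior, while gamma < 1 does the opposite.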
126 | scores = scores + mask
127 | out = self.guidance_scale * (scores - cf_score) + cf_score
128 |
129 | return out
130 |
--------------------------------------------------------------------------------
/baselines/D3/evaluate2.sh:
--------------------------------------------------------------------------------
1 | train_file=YOUR_TRAIN_FILE_PATH
2 | info_file=YOUR_INFO_FILE_PATH
3 | for category in "Goodreads"
4 | do
5 |     cudalist="7"
6 |     for i in ${cudalist}
7 |     do
8 |         echo $i
9 |         CUDA_VISIBLE_DEVICES=$i python ./D3_evaluate.py \
10 |             --base_model ./output_dir/${category}/ \
11 |             --train_file ${train_file} \
12 |             --info_file ${info_file} \
13 |             --category ${category} \
14 |             --test_data_path ./temp/${category}_base/${i}.csv \
15 |             --result_json_data ./temp/${category}_base/${i}.json \
16 |             --length_penalty 0.0 \
17 |             --logits_file YOUR_LOGITS_FILE_PATH
18 |     done
19 |     wait
20 |     python ./code/merge.py --input_path ./temp/${category}_base --output_path ./output_dir/${category}/final_result.json
21 |     python ./code/calc.py --path ./output_dir/${category}/final_result.json --item_path ${info_file}
22 | done
23 | 
--------------------------------------------------------------------------------
/baselines/DMPO/dmpo_trainer.py:
--------------------------------------------------------------------------------
1 | # DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import warnings
16 | from collections import defaultdict
17 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
18 | import importlib
19 |
20 |
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 | from datasets import Dataset
25 | from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
26 | from transformers.trainer_callback import TrainerCallback
27 |
28 | from .utils import DPODataCollatorWithPadding, pad_to_length
29 |
30 |
31 | def is_peft_available():
32 | return importlib.util.find_spec("peft") is not None
33 |
34 | if is_peft_available():
35 | from peft import get_peft_model, prepare_model_for_kbit_training
36 |
37 |
38 | class DPOTrainer(Trainer):
39 | r"""
40 | Initialize DPOTrainer.
41 |
42 | Args:
43 | model (`transformers.PreTrainedModel`):
44 | The model to train, preferably an `AutoModelForCausalLM`.
45 | ref_model (`PreTrainedModelWrapper`):
46 | Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss.
47 | beta (`float`, defaults to 0.1):
48 | The beta factor in DPO loss. Higher beta means less divergence from the initial policy.
49 | args (`transformers.TrainingArguments`):
50 | The arguments to use for training.
51 | data_collator (`transformers.DataCollator`):
52 | The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
53 | which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
54 | label_pad_token_id (`int`, defaults to `-100`):
55 | The label pad token id. This argument is required if you want to use the default data collator.
56 | padding_value (`int`, defaults to `0`):
57 | The padding value. This argument is required if you want to use the default data collator.
58 | truncation_mode (`str`, defaults to `keep_end`):
59 | The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
60 | train_dataset (`datasets.Dataset`):
61 | The dataset to use for training.
62 | eval_dataset (`datasets.Dataset`):
63 | The dataset to use for evaluation.
64 | tokenizer (`transformers.PreTrainedTokenizerBase`):
65 | The tokenizer to use for training. This argument is required if you want to use the default data collator.
66 | model_init (`Callable[[], transformers.PreTrainedModel]`):
67 | The model initializer to use for training. If None is specified, the default model initializer will be used.
68 | callbacks (`List[transformers.TrainerCallback]`):
69 | The callbacks to use for training.
70 | optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
71 | The optimizer and scheduler to use for training.
72 | preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
73 | The function to use to preprocess the logits before computing the metrics.
74 | max_length (`int`, defaults to `None`):
75 | The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
76 | max_prompt_length (`int`, defaults to `None`):
77 | The maximum length of the prompt. This argument is required if you want to use the default data collator.
78 | peft_config (`Dict`, defaults to `None`):
79 | The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
80 | """
81 |
82 | def __init__(
83 | self,
84 | model: Union[PreTrainedModel, nn.Module] = None,
85 | ref_model: Union[PreTrainedModel, nn.Module] = None,
86 | beta: float = 0.1,
87 | args: TrainingArguments = None,
88 | data_collator: Optional[DataCollator] = None,
89 | label_pad_token_id: int = -100,
90 | padding_value: int = 0,
91 | truncation_mode: str = "keep_end",
92 | train_dataset: Optional[Dataset] = None,
93 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
94 | tokenizer: Optional[PreTrainedTokenizerBase] = None,
95 | model_init: Optional[Callable[[], PreTrainedModel]] = None,
96 | callbacks: Optional[List[TrainerCallback]] = None,
97 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
98 | None,
99 | None,
100 | ),
101 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
102 | max_length: Optional[int] = None,
103 | max_prompt_length: Optional[int] = None,
104 | peft_config: Optional[Dict] = None,
105 | ):
106 | if not is_peft_available() and peft_config is not None:
107 | raise ValueError(
108 | "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
109 | )
110 | elif is_peft_available() and peft_config is not None:
111 | if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
112 | model = prepare_model_for_kbit_training(model)
113 | model = get_peft_model(model, peft_config)
114 |
115 | if data_collator is None:
116 | if tokenizer is None:
117 | raise ValueError(
118 | "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding"
119 | )
120 | if max_length is None:
121 | warnings.warn(
122 | "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init"
123 | " it will be set to `512` by default, but you should do it yourself in the future.",
124 | UserWarning,
125 | )
126 | max_length = 512
127 | if max_prompt_length is None:
128 | warnings.warn(
129 | "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init"
130 | " it will be set to `128` by default, but you should do it yourself in the future.",
131 | UserWarning,
132 | )
133 | max_prompt_length = 128
134 |
135 | data_collator = DPODataCollatorWithPadding(
136 | tokenizer,
137 | max_length=max_length,
138 | max_prompt_length=max_prompt_length,
139 | label_pad_token_id=label_pad_token_id,
140 | padding_value=padding_value,
141 | truncation_mode=truncation_mode,
142 | )
143 |
144 | if args.remove_unused_columns:
145 | args.remove_unused_columns = False
146 | # warn users
147 | warnings.warn(
148 | "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
149 | " we have set it for you, but you should do it yourself in the future.",
150 | UserWarning,
151 | )
152 |
153 | self.use_dpo_data_collator = True
154 | else:
155 | self.use_dpo_data_collator = False
156 |
157 | self.label_pad_token_id = label_pad_token_id
158 | self.padding_value = padding_value
159 |
160 | self.beta = beta
161 | self.ref_model = ref_model
162 |
163 | self._stored_metrics = defaultdict(lambda: defaultdict(list))
164 |
165 | super().__init__(
166 | model,
167 | args,
168 | data_collator,
169 | train_dataset,
170 | eval_dataset,
171 | tokenizer,
172 | model_init,
173 | None,
174 | callbacks,
175 | optimizers,
176 | preprocess_logits_for_metrics,
177 | )
178 |
179 | # Since we inherit from trainer we always have access to an accelerator
180 | if hasattr(self, "accelerator"):
181 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
182 | else:
183 | raise AttributeError(
184 | "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
185 | )
186 |
187 | def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
188 | """Concatenate the chosen and rejected inputs into a single tensor.
189 |
190 | Args:
191 | batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
192 |
193 | Returns:
194 | A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
195 | """
196 | # concatenate the chosen and rejected responses
197 | rejected_max_len = max([batch[key].shape[1] for key in batch if key.startswith("rejected") and key.endswith("_input_ids")])
198 | max_length = max(batch["chosen_input_ids"].shape[1], rejected_max_len)
199 | concatenated_batch = {}
200 | for k in batch:
201 | if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
202 | pad_value = self.label_pad_token_id if "labels" in k else self.padding_value
203 | concatenated_key = k.replace("chosen", "concatenated")
204 | concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
205 | for k in batch:
206 | if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
207 | pad_value = self.label_pad_token_id if "labels" in k else self.padding_value
208 | # concatenated_key = k.replace("rejected", "concatenated")
209 | prefix = k.split("_")[0]
210 | concatenated_key = "concatenated" + k[len(prefix):]
211 | concatenated_batch[concatenated_key] = torch.cat(
212 | (
213 | concatenated_batch[concatenated_key],
214 | pad_to_length(batch[k], max_length, pad_value=pad_value),
215 | ),
216 | dim=0,
217 | ).to(self.accelerator.device)
218 | return concatenated_batch
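# Example with hypothetical keys: given chosen_input_ids, rejected1_input_ids
# and rejected2_input_ids, each of shape (batch_size, len_i), the result holds
# concatenated_input_ids of shape (3 * batch_size, max_i len_i), with the
# chosen rows first; concatenated_forward later recovers each group by slicing
# in steps of batch_size.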
219 |
220 | def dpo_loss(
221 | self,
222 | policy_chosen_logps: torch.FloatTensor,
223 | policy_rejected_logps: Dict[str, torch.FloatTensor],
224 | reference_chosen_logps: torch.FloatTensor,
225 | reference_rejected_logps: Dict[str, torch.FloatTensor],
226 | reference_free: bool = False,
227 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
228 | """Compute the DPO loss for a batch of policy and reference model log probabilities.
229 |
230 | Args:
231 | policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
232 | policy_rejected_logps: Dict mapping each rejected-response key to the policy model's log probabilities for that response, each of shape (batch_size,)
233 | reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
234 | reference_rejected_logps: Dict mapping each rejected-response key to the reference model's log probabilities for that response, each of shape (batch_size,)
235 | beta: Taken from `self.beta`. Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.
236 | reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
237 |
238 | Returns:
239 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
240 | The losses tensor contains the DPO loss for each example in the batch.
241 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
242 | """
243 | # pi_logratios = policy_chosen_logps - policy_rejected_logps
244 | # for key in policy_rejected_logps:
245 | # ref_logratios = reference_chosen_logps - reference_rejected_logps
246 | chosen_logratios = policy_chosen_logps - reference_chosen_logps
247 | # print(f"chosen:{chosen_logratios}")
248 | rejected_logratios = {}
249 | for key in policy_rejected_logps:
250 | rejected_logratios[key] = policy_rejected_logps[key] - reference_rejected_logps[key]
251 | # print(f"{key}_logratios:{rejected_logratios[key].shape}")
252 | # if reference_free:
253 | # ref_logratios = 0
254 |
255 | # logits = pi_logratios - ref_logratios
256 | # temp = sum(torch.exp(self.beta * (rejected_logratios[key] - chosen_logratios)) for key in rejected_logratios)
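# A sketch of what the computation below evaluates, where y_w is the chosen
# response, y_k ranges over the rejected responses, and
# delta(y) = log pi(y|x) - log pi_ref(y|x):
#   loss = -log sigmoid( beta * sum_k [ delta(y_w) - delta(y_k) ] )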
257 | temp = torch.exp(self.beta * sum(rejected_logratios[key] - chosen_logratios for key in rejected_logratios))
258 |
259 | temp1 = -torch.log(temp)
260 | losses = -F.logsigmoid(temp1)
261 | # losses = -F.logsigmoid(self.beta * logits)
262 | rejected_rewards = {}
263 | chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
264 | for key in policy_rejected_logps:
265 | rejected_rewards[key] = self.beta * (policy_rejected_logps[key] - reference_rejected_logps[key]).detach()
266 |
267 | return losses, chosen_rewards, rejected_rewards
268 |
269 | def _get_batch_logps(
270 | self,
271 | logits: torch.FloatTensor,
272 | labels: torch.LongTensor,
273 | average_log_prob: bool = False,
274 | ) -> torch.FloatTensor:
275 | """Compute the log probabilities of the given labels under the given logits.
276 |
277 | Args:
278 | logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
279 | labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
280 | average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
281 |
282 | Returns:
283 | A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
284 | """
285 | if logits.shape[:-1] != labels.shape:
286 | raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
287 |
288 | labels = labels[:, 1:].clone()
289 | logits = logits[:, :-1, :]
290 | loss_mask = labels != self.label_pad_token_id
291 |
292 | # dummy token; we'll ignore the losses on these tokens later
293 | labels[labels == self.label_pad_token_id] = 0
294 |
295 | per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
296 |
297 | if average_log_prob:
298 | return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
299 | else:
300 | return (per_token_logps * loss_mask).sum(-1)
301 |
302 | def concatenated_forward(
303 | self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
304 | ) -> Tuple[torch.FloatTensor, Dict[str, torch.FloatTensor], torch.FloatTensor, Dict[str, torch.FloatTensor]]:
305 | """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
306 |
307 | We do this to avoid doing two forward passes, because it's faster for FSDP.
308 | """
309 | concatenated_batch = self.concatenated_inputs(batch)
310 | # print(concatenated_batch["concatenated_input_ids"].shape)
311 | all_logits = model(
312 | concatenated_batch["concatenated_input_ids"],
313 | attention_mask=concatenated_batch["concatenated_attention_mask"],
314 | ).logits.to(torch.float32)
315 | all_logps = self._get_batch_logps(
316 | all_logits,
317 | concatenated_batch["concatenated_labels"],
318 | average_log_prob=False,
319 | )
320 | chosen_logps = all_logps[: batch["chosen_input_ids"].shape[0]]
321 | step = batch["chosen_input_ids"].shape[0]
322 | rejected_logps = {}
323 | cnt = 0
324 | for key in batch:
325 | if key.startswith("rejected") and key.endswith("_input_ids"):
326 | cnt += 1
327 | rejected_logps[f"rejected{cnt}"] = all_logps[step*cnt : step*(cnt+1)]
328 |
329 | chosen_logits = all_logits[: batch["chosen_input_ids"].shape[0]]
330 | rejected_logits = {}
331 | cnt = 0
332 | for key in batch:
333 | if key.startswith("rejected") and key.endswith("_input_ids"):
334 | cnt += 1
335 | rejected_logits[f"rejected{cnt}"] = all_logits[step*cnt : step*(cnt+1)]
336 | return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
337 |
338 | def get_batch_metrics(
339 | self,
340 | model,
341 | batch: Dict[str, Union[List, torch.LongTensor]],
342 | train_eval: Literal["train", "eval"] = "train",
343 | ):
344 | """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
345 | metrics = {}
346 |
347 | (
348 | policy_chosen_logps,
349 | policy_rejected_logps,
350 | policy_chosen_logits,
351 | policy_rejected_logits,
352 | ) = self.concatenated_forward(model, batch)
353 | with torch.no_grad():
354 | (
355 | reference_chosen_logps,
356 | reference_rejected_logps,
357 | _,
358 | _,
359 | ) = self.concatenated_forward(self.ref_model, batch)
360 |
361 | losses, chosen_rewards, rejected_rewards = self.dpo_loss(
362 | policy_chosen_logps,
363 | policy_rejected_logps,
364 | reference_chosen_logps,
365 | reference_rejected_logps,
366 | )
367 |
368 | # reward_accuracies records the fraction of examples whose chosen reward exceeds every rejected reward
369 | reward_accuracies = None
370 | for key in rejected_rewards:
371 | if reward_accuracies is None:
372 | reward_accuracies = (chosen_rewards > rejected_rewards[key]).float()
373 | else:
374 | reward_accuracies *= (chosen_rewards > rejected_rewards[key]).float()
375 |
376 | prefix = "eval_" if train_eval == "eval" else ""
377 | metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().numpy().mean()
378 | for key in rejected_rewards:
379 | metrics[f"{prefix}rewards/{key}"] = rejected_rewards[key].cpu().numpy().mean()
380 | metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().numpy().mean()
381 | for key in rejected_rewards:
382 | metrics[f"{prefix}rewards/margins-{key}"] = (chosen_rewards - rejected_rewards[key]).cpu().numpy().mean()
383 | for key in policy_rejected_logps:
384 | metrics[f"{prefix}logps/rejected-{key}"] = policy_rejected_logps[key].detach().cpu().numpy().mean()
385 | metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean()
386 | for key in policy_rejected_logits:
387 | metrics[f"{prefix}logits/rejected-{key}"] = policy_rejected_logits[key].detach().cpu().numpy().mean()
388 | metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean()
389 |
390 | return losses.mean(), metrics
391 |
392 | def compute_loss(
393 | self,
394 | model: Union[PreTrainedModel, nn.Module],
395 | inputs: Dict[str, Union[torch.Tensor, Any]],
396 | return_outputs=False,
397 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
398 | # print(inputs.keys())
399 | # print(inputs)
400 | if not self.use_dpo_data_collator:
401 | warnings.warn(
402 | "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
403 | "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
404 | )
405 | loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
406 |
407 | # force log the metrics
408 | if self.accelerator.is_main_process:
409 | self.store_metrics(metrics, train_eval="train")
410 |
411 | if return_outputs:
412 | return (loss, metrics)
413 | return loss
414 |
415 | def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
416 | """Generate samples from the model and reference model for the given batch of inputs."""
417 |
418 | policy_output = model.generate(
419 | batch["prompt_input_ids"],
420 | attention_mask=batch["prompt_attention_mask"],
421 | max_length=self.config.max_length,
422 | do_sample=True,
423 | pad_token_id=self.tokenizer.pad_token_id,
424 | )
425 |
426 | reference_output = self.ref_model.generate(
427 | batch["prompt_input_ids"],
428 | attention_mask=batch["prompt_attention_mask"],
429 | max_length=self.config.max_length,
430 | do_sample=True,
431 | pad_token_id=self.tokenizer.pad_token_id,
432 | )
433 |
434 | policy_output = pad_to_length(policy_output, self.config.max_length, self.tokenizer.pad_token_id)
435 | policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
436 |
437 | reference_output = pad_to_length(reference_output, self.config.max_length, self.tokenizer.pad_token_id)
438 | reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)
439 |
440 | return policy_output_decoded, reference_output_decoded
441 |
442 | def prediction_step(
443 | self,
444 | model: Union[PreTrainedModel, nn.Module],
445 | inputs: Dict[str, Union[torch.Tensor, Any]],
446 | prediction_loss_only: bool,
447 | ignore_keys: Optional[List[str]] = None,
448 | ):
449 | if not self.use_dpo_data_collator:
450 | warnings.warn(
451 | "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
452 | "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
453 | )
454 | if ignore_keys is None:
455 | if hasattr(model, "config"):
456 | ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
457 | else:
458 | ignore_keys = []
459 |
460 | with torch.no_grad():
461 | loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval")
462 |
463 | # force log the metrics
464 | if self.accelerator.is_main_process:
465 | self.store_metrics(metrics, train_eval="eval")
466 |
467 | if prediction_loss_only:
468 | return (loss.detach(), None, None)
469 |
470 | # logits for the chosen and rejected samples from model
471 | logits_dict = {
472 | "eval_logits/chosen": metrics["eval_logits/chosen"],  # key as stored by get_batch_metrics
473 | # "eval_logits/rejected": metrics["eval_logits/rejected"],
474 | }
475 | logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
476 | logits = torch.stack(logits).mean(axis=1)
477 | labels = torch.zeros(logits.shape[0])
478 |
479 | return (loss.detach(), logits, labels)
480 |
481 | def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
482 | for key, value in metrics.items():
483 | self._stored_metrics[train_eval][key].append(value)
484 |
485 | def log(self, logs: Dict[str, float]) -> None:
486 | """
487 | Log `logs` on the various objects watching training, including stored metrics.
488 |
489 | Args:
490 | logs (`Dict[str, float]`):
491 | The values to log.
492 | """
493 | # logs either has 'loss' or 'eval_loss'
494 | train_eval = "train" if "loss" in logs else "eval"
495 | # Add averaged stored metrics to logs
496 | for key, metrics in self._stored_metrics[train_eval].items():
497 | logs[key] = torch.tensor(metrics).mean().item()
498 | del self._stored_metrics[train_eval]
499 | return super().log(logs)
500 |
501 |
--------------------------------------------------------------------------------
/baselines/DMPO/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import warnings
4 | from dataclasses import dataclass
5 | from typing import Any, Dict, List, Optional, Union
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn.utils.rnn import pad_sequence
10 | from torch.utils.data import IterableDataset
11 | from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerBase, TrainerCallback
12 |
13 | @dataclass
14 | class DPODataCollatorWithPadding:
15 | r"""
16 | DPO DataCollator class that pads the inputs to the maximum length of the batch.
17 | Args:
18 | tokenizer (`PreTrainedTokenizerBase`):
19 | The tokenizer used for encoding the data.
20 | padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
21 | padding_strategy to pass to the tokenizer.
22 | max_length (`Optional[int]`, `optional`, defaults to `None`):
23 | The maximum length of the sequence to be processed.
24 | max_prompt_length (`Optional[int]`, `optional`, defaults to `None`):
25 | The maximum length of the prompt to be processed.
26 | label_pad_token_id (`int`, defaults to -100):
27 | The label used for masking.
28 | padding_value (`int`, defaults to 0):
29 | The value used for padding.
30 | truncation_mode: (`str`, defaults to "keep_end"):
31 | The truncation mode to use when truncating the prompt + chosen/rejected responses.
32 | """
33 | tokenizer: PreTrainedTokenizerBase
34 | padding: Union[bool, str] = True
35 | max_length: Optional[int] = None
36 | max_prompt_length: Optional[int] = None
37 | label_pad_token_id: int = -100
38 | padding_value: int = 0
39 | truncation_mode: str = "keep_end"
40 |
41 | def tokenize_batch_element(
42 | self,
43 | prompt: str,
44 | chosen: str,
45 | rejected: Dict[str, str],
46 | ) -> Dict:
47 | """Tokenize a single batch element.
48 |
49 | At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
50 | in case the prompt + chosen or prompt + rejected responses is/are too long. First
51 | we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
52 |
53 | We also create the labels for the chosen/rejected responses, which are of length equal to
54 | the sum of the length of the prompt and the chosen/rejected response, with
55 | label_pad_token_id for the prompt tokens.
56 | """
57 | chosen_tokens = self.tokenizer(chosen, add_special_tokens=False)
58 | prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
59 | rejected_tokens = {}
60 | for key in rejected:
61 | rejected_tokens[key] = self.tokenizer(rejected[key], add_special_tokens=False)
62 |
63 | assert self.tokenizer.eos_token_id not in prompt_tokens["input_ids"], f"Prompt contains EOS token: {prompt}"
64 | assert (
65 | self.tokenizer.eos_token_id not in chosen_tokens["input_ids"]
66 | ), f"Chosen response contains EOS token: {chosen}"
67 | assert (
68 | all([self.tokenizer.eos_token_id not in rejected_tokens[key]["input_ids"] for key in rejected_tokens])
69 | ), f"Rejected response contains EOS token: {rejected}"
70 |
71 | chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
72 | chosen_tokens["attention_mask"].append(1)
73 | for key in rejected_tokens:
74 | rejected_tokens[key]["input_ids"].append(self.tokenizer.eos_token_id)
75 | rejected_tokens[key]["attention_mask"].append(1)
76 | max_rejected_len = max([len(rejected_tokens[key]["input_ids"]) for key in rejected_tokens])
77 | longer_response_length = max(len(chosen_tokens["input_ids"]), max_rejected_len)
78 |
79 | # if combined sequence is too long, truncate the prompt
80 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
81 | if self.truncation_mode == "keep_start":
82 | prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
83 | elif self.truncation_mode == "keep_end":
84 | prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
85 | else:
86 | raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
87 |
88 | # if that's still too long, truncate the response
89 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
90 | chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()}
91 | rejected_tokens = {key: {k: v[: self.max_length - self.max_prompt_length] for k, v in toks.items()} for key, toks in rejected_tokens.items()}  # rejected_tokens maps each key to a tokenized dict, so truncate the inner lists
92 |
93 | # Create labels
94 | chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
95 | rejected_sequence_tokens = {}
96 | # rejected_tokens: Dict[str, Dict]
97 | for key in rejected_tokens:
98 | rejected_sequence_tokens[key] = {k: prompt_tokens[k] + rejected_tokens[key][k] for k in rejected_tokens[key]}
99 | chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
100 | chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
101 | prompt_tokens["input_ids"]
102 | )
103 | for key in rejected_sequence_tokens:
104 | rejected_sequence_tokens[key]["labels"] = rejected_sequence_tokens[key]["input_ids"][:]
105 | rejected_sequence_tokens[key]["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
106 | prompt_tokens["input_ids"]
107 | )
108 |
109 | batch = {}
110 |
111 | batch["prompt"] = prompt
112 | batch["chosen"] = prompt + chosen
113 | for key in rejected:
114 | batch[key] = prompt + rejected[key]
115 | batch["chosen_response_only"] = chosen
116 | for key in rejected:
117 | batch[f"{key}_response_only"] = rejected[key]
118 |
119 | for k, toks in {
120 | "chosen": chosen_sequence_tokens,
121 | # "rejected": rejected_sequence_tokens,
122 | "prompt": prompt_tokens,
123 | }.items():
124 | for type_key, tokens in toks.items():
125 | if type_key == "token_type_ids":
126 | continue
127 | batch[f"{k}_{type_key}"] = tokens
128 | # rejected_sequence_tokens: Dict[str, Dict]
129 | for k, toks in rejected_sequence_tokens.items():
130 | for type_key, tokens in toks.items():
131 | if type_key == "token_type_ids":
132 | continue
133 | batch[f"{k}_{type_key}"] = tokens
134 |
135 | return batch
136 |
137 | def collate(self, batch):
138 | # first, pad everything to the same length
139 | padded_batch = {}
140 | for k in batch[0].keys():
141 | if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
142 | # adapted from https://stackoverflow.com/questions/73256206
143 | if "prompt" in k:
144 | to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
145 | else:
146 | to_pad = [torch.LongTensor(ex[k]) for ex in batch]
147 | if k.endswith("_input_ids"):
148 | padding_value = self.tokenizer.pad_token_id
149 | elif k.endswith("_labels"):
150 | padding_value = self.label_pad_token_id
151 | elif k.endswith("_attention_mask"):
152 | padding_value = self.padding_value
153 | else:
154 | raise ValueError(f"Unexpected key in batch '{k}'")
155 |
156 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
157 | # for the prompt, flip back so padding is on left side
158 | if "prompt" in k:
159 | padded_batch[k] = padded_batch[k].flip(dims=[1])
160 | else:
161 | padded_batch[k] = [ex[k] for ex in batch]
162 |
163 | return padded_batch
164 |
165 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
166 | tokenized_batch = []
167 |
168 | for feature in features:
169 | prompt = feature["prompt"]
170 | chosen = feature["chosen"]
171 | rejected = {}
172 | for key in feature:
173 | if key.startswith("rejected"):
174 | rejected[key] = feature[key]
175 |
176 | batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
177 | tokenized_batch.append(batch_element)
178 |
179 | # return collated batch
180 | return self.collate(tokenized_batch)
181 |
182 | def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
183 | if tensor.size(dim) >= length:
184 | return tensor
185 | else:
186 | pad_size = list(tensor.shape)
187 | pad_size[dim] = length - tensor.size(dim)
188 | return torch.cat(
189 | [
190 | tensor,
191 | pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device),
192 | ],
193 | dim=dim,
194 | )
--------------------------------------------------------------------------------
/baselines/Re-weighting/RW_SFT.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import warnings
4 | import re
5 | import wandb
6 | from typing import List, Optional
7 | import datasets
8 | from tqdm import tqdm
9 | from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments,BitsAndBytesConfig
10 | from datasets import load_dataset
11 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig
12 | from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType, PeftModel
13 | from transformers import LlamaForCausalLM, LlamaTokenizer
14 | # from utils import find_all_linear_names, print_trainable_parameters
15 | import pandas as pd
16 | from accelerate import Accelerator
17 | import numpy as np
18 | import torch
19 | import bitsandbytes as bnb
20 | import fire
21 | import json
22 |
23 | def read_json(json_file:str) -> dict:
24 |     with open(json_file, 'r') as f:
25 |         return json.load(f)
26 |
27 | def gh_tr(category:str,test_data,name2genre:dict,genre_dict:dict):
28 | for data in tqdm(test_data,desc="Processing category data......"):
29 | input = data['input']
30 | names = re.findall(r'"([^"]+)"', input)
31 | for name in names:
32 | if name in name2genre:
33 | genres = name2genre[name]
34 | else:
35 | continue
36 | for genre in genres:
37 | if genre in genre_dict:
38 | genre_dict[genre] += 1/len(genres)
39 | gh = [genre_dict[x] for x in genre_dict]
40 | gh_normalize = [x/sum(gh) for x in gh]
41 | return gh_normalize
42 |
43 | def gh_ta(category:str,test_data,name2genre:dict,genre_dict:dict):
44 | for data in tqdm(test_data,desc="Processing category data......"):
45 | input = data['output']
46 | names = re.findall(r'"([^"]+)"', input)
47 | for name in names:
48 | if name in name2genre:
49 | genres = name2genre[name]
50 | else:
51 | # print(f"Not exist in name2genre:{name}")
52 | continue
53 | for genre in genres:
54 | if genre in genre_dict:
55 | genre_dict[genre] += 1/len(genres)
56 | gh = [genre_dict[x] for x in genre_dict]
57 | gh_normalize = [x/sum(gh) for x in gh]
58 | return gh_normalize
59 |
60 | def weight_dict(category:str,test_data,name2genre:dict,genre_dict:dict):
61 |     # pass copies so the two histograms accumulate independently
62 |     # and the caller's genre_dict is left untouched
63 |     GH_tr = gh_tr(category,test_data,name2genre,dict(genre_dict))
64 |     GH_ta = gh_ta(category,test_data,name2genre,dict(genre_dict))
65 |     weights = {}
66 |     for idx, genre in enumerate(genre_dict):
67 |         weights[genre] = GH_tr[idx] / GH_ta[idx]
68 | 
69 |     return weights
70 |
71 | def cal_weight(category:str,test_data,name2genre:dict,genre_dict:dict):
72 | weights = []
73 | w_dict = weight_dict(category,test_data,name2genre,genre_dict)
74 | print(f"Length of data:{len(test_data)}")
75 | for data in tqdm(test_data,desc="Processing category data......"):
76 | weight = []
77 | target_item = data['output'].strip("\n").strip("\"")
78 | if target_item in name2genre :
79 | genres = name2genre[target_item]
80 | for genre in genres:
81 | if genre in genre_dict:
82 | weight.append(w_dict[genre])
83 | if len(weight)>0:
84 | weight = sum(weight) / len(weight)
85 | weights.append(weight)
86 | else:
87 | weights.append(1)
88 | else:
89 | weights.append(1)
90 | print(f"Length of weights:{len(weights)}")
91 | return weights
92 |
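# Worked example of the re-weighting above (hypothetical numbers): if the
# genre-ratio dict gives w_dict = {"fiction": 0.8, "poetry": 1.4} and a target
# item belongs to both genres, its sample weight is (0.8 + 1.4) / 2 = 1.1,
# up-weighting items whose genres are rarer in targets than in histories.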
93 | class IFTrainer(SFTTrainer):
94 | def compute_loss(self, model, inputs, return_outputs=False):
95 | weights = inputs.pop("weight")
96 | labels = inputs.pop("labels")
97 | outputs = model(**inputs)
98 |
99 | if self.args.past_index >= 0:
100 | self._past = outputs[self.args.past_index]
101 |
102 | logits = outputs.get("logits")
103 |
104 | shift_logits = logits[..., :-1, :].contiguous()
105 | shift_labels = labels[..., 1:].contiguous()
106 |
107 | # Flatten the tokens
108 | loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
109 | shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
110 | shift_labels = shift_labels.view(-1)
111 | # Enable model parallelism
112 | shift_labels = shift_labels.to(shift_logits.device)
113 |
114 | loss = torch.mean(weights * torch.mean(loss_fct(shift_logits, shift_labels).view(weights.shape[0], -1), dim=-1))  # per-sample token mean, then weight each sample
115 |
116 |
117 | return (loss, outputs) if return_outputs else loss
118 |
119 | def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
120 | if not self.args.remove_unused_columns:
121 | return dataset
122 | self._set_signature_columns_if_needed()
123 | signature_columns = self._signature_columns
124 | signature_columns.append("weight")
125 | ignored_columns = list(set(dataset.column_names) - set(signature_columns))
126 | if len(ignored_columns) > 0:
127 | dset_description = "" if description is None else f"in the {description} set"
128 |
129 | columns = [k for k in signature_columns if k in dataset.column_names]
130 | x = dataset.remove_columns(ignored_columns)
131 | return x
132 |
133 | def _prepare_non_packed_dataloader(
134 | self,
135 | tokenizer,
136 | dataset,
137 | dataset_text_field,
138 | max_seq_length,
139 | formatting_func=None,
140 | add_special_tokens=True,
141 | remove_unused_columns=True,
142 | ):
143 | use_formatting_func = formatting_func is not None and dataset_text_field is None
144 | self._dataset_sanity_checked = False
145 |
146 | # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
147 | def tokenize(element):
148 | outputs = tokenizer(
149 | element[dataset_text_field] if not use_formatting_func else formatting_func(element),
150 | add_special_tokens=add_special_tokens,
151 | truncation=True,
152 | padding=False,
153 | max_length=max_seq_length,
154 | return_overflowing_tokens=False,
155 | return_length=False,
156 | )
157 |
158 | if use_formatting_func and not self._dataset_sanity_checked:
159 | if not isinstance(formatting_func(element), list):
160 | raise ValueError(
161 | "The `formatting_func` should return a list of processed strings since it can lead to silent bugs."
162 | )
163 | else:
164 | self._dataset_sanity_checked = True
165 |
166 | return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"], "weight": element['weight']}
167 |
168 | signature_columns = ["input_ids", "labels", "attention_mask","weight"]
169 |
170 | extra_columns = list(set(dataset.column_names) - set(signature_columns))
171 |
172 | if not remove_unused_columns and len(extra_columns) > 0:
173 | warnings.warn(
174 | "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with the default collator and yield to errors. If you want to "
175 | f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns."
176 | )
177 |
178 | tokenized_dataset = dataset.map(
179 | tokenize,
180 | batched=True,
181 | remove_columns=dataset.column_names if remove_unused_columns else None,
182 | num_proc=self.dataset_num_proc,
183 | batch_size=self.dataset_batch_size,
184 | )
185 |
186 | return tokenized_dataset
187 |
188 | from transformers import DataCollatorWithPadding
189 | import torch
190 |
191 | def train(
192 | # path
193 | output_dir="",
194 | base_model ="",
195 | train_dataset="",
196 | valid_dataset="",
197 | train_sample_size:int = 1024,
198 | resume_from_checkpoint: str = "base_model", # either training checkpoint or final adapter
199 | # wandb config
200 | wandb_project: str = "",
201 | wandb_name: str = "", # the name of the wandb run
202 | # training hyperparameters
203 | gradient_accumulation_steps: int = 1,
204 | batch_size: int = 8,
205 | num_train_epochs: int = 5,
206 | learning_rate: float = 2e-5,
207 | cutoff_len: int = 512,
208 | eval_step = 0.05,
209 | category: str = "CDs_and_Vinyl",
210 | seed = 0
211 | ):
212 | os.environ['WANDB_PROJECT'] = wandb_project
213 |
214 | def formatting_prompts_func(examples):
215 | output_text = []
216 | for i in range(len(examples["instruction"])):
217 | instruction = examples["instruction"][i]
218 | input_text = examples["input"][i]
219 | response = examples["output"][i]
220 |
221 | if len(input_text) >= 2:
222 | text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
223 |
224 | ### Instruction:
225 | {instruction}
226 |
227 | ### Input:
228 | {input_text}
229 |
230 | ### Response:
231 | {response}
232 | '''
233 | else:
234 | text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
235 |
236 | ### Instruction:
237 | {instruction}
238 |
239 | ### Response:
240 | {response}
241 | '''
242 | output_text.append(text)
243 |
244 | return output_text
245 |
246 |     def get_train_weight(data, name2genre: dict, genre_dict: dict, w_dict):
247 |         # Average the pre-computed genre weights of the target item; fall back to 1
248 |         # when the item or its genres are unknown.
249 |         weight = []
250 |         target_item = data['output'].strip("\n").strip("\"")
251 |         if target_item in name2genre:
252 |             genres = name2genre[target_item]
253 |             for genre in genres:
254 |                 if genre in genre_dict:
255 |                     weight.append(w_dict[genre])
256 |             if len(weight) > 0:
257 |                 return {"weight": sum(weight) / len(weight)}
258 |         return {"weight": 1}
261 |
262 |
263 |     name2genre = read_json(f"./eval/{category}/name2genre.json")
264 |     genre_dict = read_json(f"./eval/{category}/genre_dict.json")
265 |     train_json = read_json(train_dataset)  # load once and reuse for the weight statistics
266 |     w_dict = weight_dict(category, test_data=train_json, name2genre=name2genre, genre_dict=genre_dict)
267 |     train_weights = cal_weight(category, test_data=train_json, name2genre=name2genre, genre_dict=genre_dict)
268 | 
269 |     val_sample_size = int(train_sample_size / 8)
270 |     dataset = load_dataset('json', data_files=train_dataset)
271 |     dataset = {"train": dataset['train'].select(range(train_sample_size + val_sample_size))}
272 |     dataset['train'] = dataset['train'].map(lambda x: get_train_weight(x, name2genre=name2genre, genre_dict=genre_dict, w_dict=w_dict))
273 |     print("Features:{}".format(dataset["train"].features))
274 |     train_val_split = dataset['train'].train_test_split(train_size=train_sample_size, test_size=val_sample_size)
275 |     train_data = train_val_split['train']
276 |     val_data = train_val_split['test']
278 |
279 |
280 | bnb_config = BitsAndBytesConfig(
281 | # load_in_8bit=True,
282 | load_in_4bit=True,
283 | bnb_4bit_quant_type="nf4",
284 | bnb_4bit_compute_dtype=torch.bfloat16,
285 | bnb_4bit_use_double_quant=False,
286 | )
287 |
288 | device_index = Accelerator().process_index
289 | device_map = {"": device_index}
290 |
291 |     model = LlamaForCausalLM.from_pretrained(base_model, device_map=device_map,
292 |                                               quantization_config=bnb_config)
293 | model.config.use_cache = False
294 | model = prepare_model_for_kbit_training(model)
295 |
296 | if 'Llama-3' in base_model:
297 | tokenizer = AutoTokenizer.from_pretrained(base_model)
298 | else:
299 | tokenizer = LlamaTokenizer.from_pretrained(base_model)
300 | # tokenizer.pad_token = tokenizer.eos_token
301 | # tokenizer.padding_side = "right"
302 |     tokenizer.pad_token_id = 0
303 | tokenizer.padding_side = "left" # Fix weird overflow issue with fp16 training
304 |
305 | if resume_from_checkpoint!="base_model":
306 | model = PeftModel.from_pretrained(model, resume_from_checkpoint,
307 | is_trainable=True)
308 | else:
309 | peft_config = LoraConfig(
310 | inference_mode=False,
311 | r=64,
312 | lora_alpha=32,
313 | target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
314 | lora_dropout=0.05,
315 | # bias="none",
316 | task_type="CAUSAL_LM",
317 | )
318 | model = get_peft_model(model, peft_config)
319 |
320 | model.print_trainable_parameters()
321 |
322 | training_args = SFTConfig(
323 | per_device_train_batch_size=batch_size,
324 | gradient_accumulation_steps=gradient_accumulation_steps,
325 |         gradient_checkpointing=True,
326 |         max_grad_norm=0.3,
327 | num_train_epochs=num_train_epochs,
328 | learning_rate=learning_rate,
329 | bf16=True,
330 | save_strategy="steps",
331 | save_steps=eval_step,
332 | save_total_limit=100,
333 | load_best_model_at_end=True,
334 | evaluation_strategy="steps",
335 | eval_steps=eval_step,
336 | logging_steps=1,
337 | output_dir=output_dir,
338 | optim="paged_adamw_32bit",
339 |         remove_unused_columns=True,
340 | lr_scheduler_type="cosine",
341 | warmup_ratio=0.05,
342 | report_to="wandb",
343 | run_name=wandb_name,
344 | gradient_checkpointing_kwargs={'use_reentrant': True},
345 | save_only_model=True,
346 |         ddp_find_unused_parameters=False,  # set to False because there are no unused parameters in the forward pass
347 | )
348 | trainer = IFTrainer(
349 | model,
350 | train_dataset=train_data,
351 | eval_dataset=val_data,
352 | tokenizer=tokenizer,
353 | formatting_func=formatting_prompts_func,
354 | max_seq_length=cutoff_len,
355 |         args=training_args,
356 |         # data_collator=data_collator
357 | )
358 |
359 | trainer.train()
360 | trainer.save_model(output_dir)
361 |
362 | output_dir = os.path.join(output_dir, "final_checkpoint")
363 | trainer.model.save_pretrained(output_dir)
364 | tokenizer.save_pretrained(output_dir)
365 |
366 | if __name__ == "__main__":
367 | fire.Fire(train)
--------------------------------------------------------------------------------
/baselines/SDPO/softmax_dpo_trainer.py:
--------------------------------------------------------------------------------
1 | # DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import warnings
16 | from collections import defaultdict
17 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
18 | import importlib
19 |
20 |
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 | from datasets import Dataset
25 | from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
26 | from transformers.trainer_callback import TrainerCallback
27 |
28 | from .utils import DPODataCollatorWithPadding, pad_to_length
29 |
30 |
31 | def is_peft_available():
32 | return importlib.util.find_spec("peft") is not None
33 |
34 | if is_peft_available():
35 | from peft import get_peft_model, prepare_model_for_kbit_training
36 |
37 |
38 | class DPOTrainer(Trainer):
39 | r"""
40 | Initialize DPOTrainer.
41 |
42 | Args:
43 | model (`transformers.PreTrainedModel`):
44 |             The model to train, preferably a causal language model such as `AutoModelForCausalLM`.
45 | ref_model (`PreTrainedModelWrapper`):
46 |             Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss.
47 | beta (`float`, defaults to 0.1):
48 | The beta factor in DPO loss. Higher beta means less divergence from the initial policy.
49 | args (`transformers.TrainingArguments`):
50 | The arguments to use for training.
51 | data_collator (`transformers.DataCollator`):
52 | The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
53 | which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
54 | label_pad_token_id (`int`, defaults to `-100`):
55 | The label pad token id. This argument is required if you want to use the default data collator.
56 | padding_value (`int`, defaults to `0`):
57 | The padding value. This argument is required if you want to use the default data collator.
58 | truncation_mode (`str`, defaults to `keep_end`):
59 | The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
60 | train_dataset (`datasets.Dataset`):
61 | The dataset to use for training.
62 | eval_dataset (`datasets.Dataset`):
63 | The dataset to use for evaluation.
64 | tokenizer (`transformers.PreTrainedTokenizerBase`):
65 | The tokenizer to use for training. This argument is required if you want to use the default data collator.
66 | model_init (`Callable[[], transformers.PreTrainedModel]`):
67 | The model initializer to use for training. If None is specified, the default model initializer will be used.
68 | callbacks (`List[transformers.TrainerCallback]`):
69 | The callbacks to use for training.
70 | optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
71 | The optimizer and scheduler to use for training.
72 | preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
73 | The function to use to preprocess the logits before computing the metrics.
74 | max_length (`int`, defaults to `None`):
75 | The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
76 | max_prompt_length (`int`, defaults to `None`):
77 | The maximum length of the prompt. This argument is required if you want to use the default data collator.
78 | peft_config (`Dict`, defaults to `None`):
79 | The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
80 | """
81 |
82 | def __init__(
83 | self,
84 | model: Union[PreTrainedModel, nn.Module] = None,
85 | ref_model: Union[PreTrainedModel, nn.Module] = None,
86 | beta: float = 0.1,
87 | args: TrainingArguments = None,
88 | data_collator: Optional[DataCollator] = None,
89 | label_pad_token_id: int = -100,
90 | padding_value: int = 0,
91 | truncation_mode: str = "keep_end",
92 | train_dataset: Optional[Dataset] = None,
93 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
94 | tokenizer: Optional[PreTrainedTokenizerBase] = None,
95 | model_init: Optional[Callable[[], PreTrainedModel]] = None,
96 | callbacks: Optional[List[TrainerCallback]] = None,
97 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
98 | None,
99 | None,
100 | ),
101 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
102 | max_length: Optional[int] = None,
103 | max_prompt_length: Optional[int] = None,
104 | peft_config: Optional[Dict] = None,
105 | ):
106 | if not is_peft_available() and peft_config is not None:
107 | raise ValueError(
108 | "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
109 | )
110 | elif is_peft_available() and peft_config is not None:
111 | if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
112 | model = prepare_model_for_kbit_training(model)
113 | model = get_peft_model(model, peft_config)
114 |
115 | if data_collator is None:
116 | if tokenizer is None:
117 | raise ValueError(
118 |                     "A tokenizer must be specified when using the default DPODataCollatorWithPadding"
119 | )
120 | if max_length is None:
121 | warnings.warn(
122 |                     "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init;"
123 |                     " it will be set to `512` by default, but you should do it yourself in the future.",
124 | UserWarning,
125 | )
126 | max_length = 512
127 | if max_prompt_length is None:
128 | warnings.warn(
129 |                     "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init;"
130 |                     " it will be set to `128` by default, but you should do it yourself in the future.",
131 | UserWarning,
132 | )
133 | max_prompt_length = 128
134 |
135 | data_collator = DPODataCollatorWithPadding(
136 | tokenizer,
137 | max_length=max_length,
138 | max_prompt_length=max_prompt_length,
139 | label_pad_token_id=label_pad_token_id,
140 | padding_value=padding_value,
141 | truncation_mode=truncation_mode,
142 | )
143 |
144 | if args.remove_unused_columns:
145 | args.remove_unused_columns = False
146 | # warn users
147 | warnings.warn(
148 |                 "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments;"
149 |                 " we have set it for you, but you should do it yourself in the future.",
150 | UserWarning,
151 | )
152 |
153 | self.use_dpo_data_collator = True
154 | else:
155 | self.use_dpo_data_collator = False
156 |
157 |         self.label_pad_token_id = label_pad_token_id
158 |         self.padding_value = padding_value
    |         self.max_length = max_length  # stored for generation in get_batch_samples
159 |
160 | self.beta = beta
161 | self.ref_model = ref_model
162 |
163 | self._stored_metrics = defaultdict(lambda: defaultdict(list))
164 |
165 | super().__init__(
166 | model,
167 | args,
168 | data_collator,
169 | train_dataset,
170 | eval_dataset,
171 | tokenizer,
172 | model_init,
173 | None,
174 | callbacks,
175 | optimizers,
176 | preprocess_logits_for_metrics,
177 | )
178 |
179 | # Since we inherit from trainer we always have access to an accelerator
180 | if hasattr(self, "accelerator"):
181 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
182 | else:
183 | raise AttributeError(
184 | "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
185 | )
186 |
187 | def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
188 | """Concatenate the chosen and rejected inputs into a single tensor.
189 |
190 | Args:
191 |             batch: A batch of data. Must contain the key 'chosen_input_ids' and one or more 'rejectedN_input_ids' keys, which are tensors of shape (batch_size, sequence_length).
192 |
193 | Returns:
194 | A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
195 | """
196 |         # Concatenate the chosen response with every rejected response into one batch
197 | rejected_max_len = max([batch[key].shape[1] for key in batch if key.startswith("rejected") and key.endswith("_input_ids")])
198 | max_length = max(batch["chosen_input_ids"].shape[1], rejected_max_len)
199 | concatenated_batch = {}
200 | for k in batch:
201 | if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
202 | pad_value = self.label_pad_token_id if "labels" in k else self.padding_value
203 | concatenated_key = k.replace("chosen", "concatenated")
204 | concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
205 | for k in batch:
206 | if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
207 | pad_value = self.label_pad_token_id if "labels" in k else self.padding_value
208 |                 # e.g. "rejected1_input_ids" -> "concatenated_input_ids": every rejected
    |                 # response is appended to the same concatenated tensor along the batch dim
209 |                 prefix = k.split("_")[0]
210 |                 concatenated_key = "concatenated" + k[len(prefix):]
211 | concatenated_batch[concatenated_key] = torch.cat(
212 | (
213 | concatenated_batch[concatenated_key],
214 | pad_to_length(batch[k], max_length, pad_value=pad_value),
215 | ),
216 | dim=0,
217 | ).to(self.accelerator.device)
218 | return concatenated_batch
219 |
220 | def dpo_loss(
221 | self,
222 | policy_chosen_logps: torch.FloatTensor,
223 | policy_rejected_logps: Dict[str, torch.FloatTensor],
224 | reference_chosen_logps: torch.FloatTensor,
225 | reference_rejected_logps: Dict[str, torch.FloatTensor],
226 | reference_free: bool = False,
227 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
228 | """Compute the DPO loss for a batch of policy and reference model log probabilities.
229 |
230 | Args:
231 |             policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
232 |             policy_rejected_logps: Dict mapping each rejected-response key to the policy model's log probabilities, each of shape (batch_size,).
233 |             reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
234 |             reference_rejected_logps: Dict mapping each rejected-response key to the reference model's log probabilities, each of shape (batch_size,).
235 | beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.
236 | reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
237 |
238 | Returns:
239 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
240 | The losses tensor contains the DPO loss for each example in the batch.
241 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
242 | """
243 | # pi_logratios = policy_chosen_logps - policy_rejected_logps
244 | # for key in policy_rejected_logps:
245 | # ref_logratios = reference_chosen_logps - reference_rejected_logps
246 | chosen_logratios = policy_chosen_logps - reference_chosen_logps
247 | # print(f"chosen:{chosen_logratios}")
248 | rejected_logratios = {}
249 | for key in policy_rejected_logps:
250 | rejected_logratios[key] = policy_rejected_logps[key] - reference_rejected_logps[key]
251 | # print(f"{key}_logratios:{rejected_logratios[key].shape}")
252 | # if reference_free:
253 | # ref_logratios = 0
254 |
255 | # logits = pi_logratios - ref_logratios
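    |         # S-DPO loss over multiple rejected responses:
    |         #   loss = -log sigmoid( -log sum_j exp( beta * (r_j - r_w) ) )
    |         # with r_j = log pi(y_j)/pi_ref(y_j) for each rejected y_j and
    |         #      r_w = log pi(y_w)/pi_ref(y_w) for the chosen y_w.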
256 | temp = sum(torch.exp(self.beta * (rejected_logratios[key] - chosen_logratios)) for key in rejected_logratios)
257 | temp1 = -torch.log(temp)
258 | losses = -F.logsigmoid(temp1)
259 | # losses = -F.logsigmoid(self.beta * logits)
260 | rejected_rewards = {}
261 | chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
262 | for key in policy_rejected_logps:
263 | rejected_rewards[key] = self.beta * (policy_rejected_logps[key] - reference_rejected_logps[key]).detach()
264 |
265 | return losses, chosen_rewards, rejected_rewards
266 |
267 | def _get_batch_logps(
268 | self,
269 | logits: torch.FloatTensor,
270 | labels: torch.LongTensor,
271 | average_log_prob: bool = False,
272 | ) -> torch.FloatTensor:
273 | """Compute the log probabilities of the given labels under the given logits.
274 |
275 | Args:
276 | logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
277 | labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
278 | average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
279 |
280 | Returns:
281 | A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
282 | """
283 | if logits.shape[:-1] != labels.shape:
284 | raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
285 |
286 | labels = labels[:, 1:].clone()
287 | logits = logits[:, :-1, :]
288 | loss_mask = labels != self.label_pad_token_id
289 |
290 | # dummy token; we'll ignore the losses on these tokens later
291 | labels[labels == self.label_pad_token_id] = 0
292 |
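    |         # log-softmax over the vocabulary, then gather the log-probability that
    |         # the model assigned to each actual label token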
293 | per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
294 |
295 | if average_log_prob:
296 | return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
297 | else:
298 | return (per_token_logps * loss_mask).sum(-1)
299 |
300 | def concatenated_forward(
301 | self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
302 | ) -> Tuple[torch.FloatTensor, Dict[str, torch.FloatTensor], torch.FloatTensor, Dict[str, torch.FloatTensor]]:
303 | """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
304 |
305 | We do this to avoid doing two forward passes, because it's faster for FSDP.
306 | """
307 | concatenated_batch = self.concatenated_inputs(batch)
308 | # print(concatenated_batch["concatenated_input_ids"].shape)
309 | all_logits = model(
310 | concatenated_batch["concatenated_input_ids"],
311 | attention_mask=concatenated_batch["concatenated_attention_mask"],
312 | ).logits.to(torch.float32)
313 | all_logps = self._get_batch_logps(
314 | all_logits,
315 | concatenated_batch["concatenated_labels"],
316 | average_log_prob=False,
317 | )
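    |         # all_logps is laid out as [chosen | rejected1 | rejected2 | ...] along the
    |         # batch dimension in blocks of size `step`; slice the blocks apart below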
318 | chosen_logps = all_logps[: batch["chosen_input_ids"].shape[0]]
319 | step = batch["chosen_input_ids"].shape[0]
320 | rejected_logps = {}
321 | cnt = 0
322 | for key in batch:
323 | if key.startswith("rejected") and key.endswith("_input_ids"):
324 | cnt += 1
325 | rejected_logps[f"rejected{cnt}"] = all_logps[step*cnt : step*(cnt+1)]
326 |
327 | chosen_logits = all_logits[: batch["chosen_input_ids"].shape[0]]
328 | rejected_logits = {}
329 | cnt = 0
330 | for key in batch:
331 | if key.startswith("rejected") and key.endswith("_input_ids"):
332 | cnt += 1
333 | rejected_logits[f"rejected{cnt}"] = all_logits[step*cnt : step*(cnt+1)]
334 | return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
335 |
336 | def get_batch_metrics(
337 | self,
338 | model,
339 | batch: Dict[str, Union[List, torch.LongTensor]],
340 | train_eval: Literal["train", "eval"] = "train",
341 | ):
342 | """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
343 | metrics = {}
344 |
345 | (
346 | policy_chosen_logps,
347 | policy_rejected_logps,
348 | policy_chosen_logits,
349 | policy_rejected_logits,
350 | ) = self.concatenated_forward(model, batch)
351 | with torch.no_grad():
352 | (
353 | reference_chosen_logps,
354 | reference_rejected_logps,
355 | _,
356 | _,
357 | ) = self.concatenated_forward(self.ref_model, batch)
358 |
359 | losses, chosen_rewards, rejected_rewards = self.dpo_loss(
360 | policy_chosen_logps,
361 | policy_rejected_logps,
362 | reference_chosen_logps,
363 | reference_rejected_logps,
364 | )
365 |
366 |         # reward_accuracies: fraction of examples whose chosen reward exceeds every rejected reward
367 | reward_accuracies = None
368 | for key in rejected_rewards:
369 | if reward_accuracies is None:
370 | reward_accuracies = (chosen_rewards > rejected_rewards[key]).float()
371 | else:
372 | reward_accuracies *= (chosen_rewards > rejected_rewards[key]).float()
373 |
374 | prefix = "eval_" if train_eval == "eval" else ""
375 | metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().numpy().mean()
376 | for key in rejected_rewards:
377 | metrics[f"{prefix}rewards/{key}"] = rejected_rewards[key].cpu().numpy().mean()
378 | metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().numpy().mean()
379 | for key in rejected_rewards:
380 | metrics[f"{prefix}rewards/margins-{key}"] = (chosen_rewards - rejected_rewards[key]).cpu().numpy().mean()
381 | for key in policy_rejected_logps:
382 | metrics[f"{prefix}logps/rejected-{key}"] = policy_rejected_logps[key].detach().cpu().numpy().mean()
383 | metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean()
384 | for key in policy_rejected_logits:
385 | metrics[f"{prefix}logits/rejected-{key}"] = policy_rejected_logits[key].detach().cpu().numpy().mean()
386 | metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean()
387 |
388 | return losses.mean(), metrics
389 |
390 | def compute_loss(
391 | self,
392 | model: Union[PreTrainedModel, nn.Module],
393 | inputs: Dict[str, Union[torch.Tensor, Any]],
394 | return_outputs=False,
395 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
396 | # print(inputs.keys())
397 | # print(inputs)
398 | if not self.use_dpo_data_collator:
399 | warnings.warn(
400 | "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
401 | "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
402 | )
403 | loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
404 |
405 | # force log the metrics
406 | if self.accelerator.is_main_process:
407 | self.store_metrics(metrics, train_eval="train")
408 |
409 | if return_outputs:
410 | return (loss, metrics)
411 | return loss
412 |
413 | def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
414 | """Generate samples from the model and reference model for the given batch of inputs."""
415 |
416 |         policy_output = model.generate(
417 |             batch["prompt_input_ids"],
418 |             attention_mask=batch["prompt_attention_mask"],
419 |             max_length=self.max_length,
420 |             do_sample=True,
421 |             pad_token_id=self.tokenizer.pad_token_id,
422 |         )
423 | 
424 |         reference_output = self.ref_model.generate(
425 |             batch["prompt_input_ids"],
426 |             attention_mask=batch["prompt_attention_mask"],
427 |             max_length=self.max_length,
428 |             do_sample=True,
429 |             pad_token_id=self.tokenizer.pad_token_id,
430 |         )
431 | 
432 |         policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
433 |         policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
434 | 
435 |         reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id)
436 |         reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)
437 |
438 | return policy_output_decoded, reference_output_decoded
439 |
440 | def prediction_step(
441 | self,
442 | model: Union[PreTrainedModel, nn.Module],
443 | inputs: Dict[str, Union[torch.Tensor, Any]],
444 | prediction_loss_only: bool,
445 | ignore_keys: Optional[List[str]] = None,
446 | ):
447 | if not self.use_dpo_data_collator:
448 | warnings.warn(
449 | "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
450 | "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
451 | )
452 | if ignore_keys is None:
453 | if hasattr(model, "config"):
454 | ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
455 | else:
456 | ignore_keys = []
457 |
458 | with torch.no_grad():
459 | loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval")
460 |
461 | # force log the metrics
462 | if self.accelerator.is_main_process:
463 | self.store_metrics(metrics, train_eval="eval")
464 |
465 | if prediction_loss_only:
466 | return (loss.detach(), None, None)
467 |
468 |         # logits for the chosen samples from the model; the metrics computed above
469 |         # are prefixed with "eval_" because get_batch_metrics ran with train_eval="eval"
470 |         logits_dict = {
471 |             "eval_logits/chosen": metrics["eval_logits/chosen"],
472 |         }
473 | logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
474 | logits = torch.stack(logits).mean(axis=1)
475 | labels = torch.zeros(logits.shape[0])
476 |
477 | return (loss.detach(), logits, labels)
478 |
479 | def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
480 | for key, value in metrics.items():
481 | self._stored_metrics[train_eval][key].append(value)
482 |
483 | def log(self, logs: Dict[str, float]) -> None:
484 | """
485 | Log `logs` on the various objects watching training, including stored metrics.
486 |
487 | Args:
488 | logs (`Dict[str, float]`):
489 | The values to log.
490 | """
491 | # logs either has 'loss' or 'eval_loss'
492 | train_eval = "train" if "loss" in logs else "eval"
493 | # Add averaged stored metrics to logs
494 | for key, metrics in self._stored_metrics[train_eval].items():
495 | logs[key] = torch.tensor(metrics).mean().item()
496 | del self._stored_metrics[train_eval]
497 | return super().log(logs)
498 |
499 |
--------------------------------------------------------------------------------
/baselines/SDPO/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import warnings
4 | from dataclasses import dataclass
5 | from typing import Any, Dict, List, Optional, Union
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn.utils.rnn import pad_sequence
10 | from torch.utils.data import IterableDataset
11 | from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerBase, TrainerCallback
12 |
13 | @dataclass
14 | class DPODataCollatorWithPadding:
15 | r"""
16 | DPO DataCollator class that pads the inputs to the maximum length of the batch.
17 | Args:
18 | tokenizer (`PreTrainedTokenizerBase`):
19 | The tokenizer used for encoding the data.
20 | padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
21 | padding_strategy to pass to the tokenizer.
22 | max_length (`Optional[int]`, `optional`, defaults to `None`):
23 | The maximum length of the sequence to be processed.
24 | max_prompt_length (`Optional[int]`, `optional`, defaults to `None`):
25 | The maximum length of the prompt to be processed.
26 | label_pad_token_id (`int`, defaults to -100):
27 | The label used for masking.
28 | padding_value (`int`, defaults to 0):
29 | The value used for padding.
30 | truncation_mode: (`str`, defaults to "keep_end"):
31 | The truncation mode to use when truncating the prompt + chosen/rejected responses.
32 | """
33 | tokenizer: PreTrainedTokenizerBase
34 | padding: Union[bool, str] = True
35 | max_length: Optional[int] = None
36 | max_prompt_length: Optional[int] = None
37 | label_pad_token_id: int = -100
38 | padding_value: int = 0
39 | truncation_mode: str = "keep_end"
40 |
41 | def tokenize_batch_element(
42 | self,
43 | prompt: str,
44 | chosen: str,
45 | rejected: Dict[str, str],
46 | ) -> Dict:
47 | """Tokenize a single batch element.
48 |
49 | At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
50 | in case the prompt + chosen or prompt + rejected responses is/are too long. First
51 | we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
52 |
53 | We also create the labels for the chosen/rejected responses, which are of length equal to
54 | the sum of the length of the prompt and the chosen/rejected response, with
55 | label_pad_token_id for the prompt tokens.
56 | """
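    |         # The element built below holds the raw strings ("prompt", "chosen",
    |         # "chosen_response_only", plus "rejectedK"/"rejectedK_response_only" per
    |         # rejected response) and tokenized "*_input_ids"/"*_attention_mask" fields,
    |         # with "*_labels" for the chosen and rejected sequences.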
57 | chosen_tokens = self.tokenizer(chosen, add_special_tokens=False)
58 | prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
59 | rejected_tokens = {}
60 | for key in rejected:
61 | rejected_tokens[key] = self.tokenizer(rejected[key], add_special_tokens=False)
62 |
63 | assert self.tokenizer.eos_token_id not in prompt_tokens["input_ids"], f"Prompt contains EOS token: {prompt}"
64 | assert (
65 | self.tokenizer.eos_token_id not in chosen_tokens["input_ids"]
66 | ), f"Chosen response contains EOS token: {chosen}"
67 | assert (
68 | all([self.tokenizer.eos_token_id not in rejected_tokens[key]["input_ids"] for key in rejected_tokens])
69 | ), f"Rejected response contains EOS token: {rejected}"
70 |
71 | chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
72 | chosen_tokens["attention_mask"].append(1)
73 | for key in rejected_tokens:
74 | rejected_tokens[key]["input_ids"].append(self.tokenizer.eos_token_id)
75 | rejected_tokens[key]["attention_mask"].append(1)
76 | max_rejected_len = max([len(rejected_tokens[key]["input_ids"]) for key in rejected_tokens])
77 | longer_response_length = max(len(chosen_tokens["input_ids"]), max_rejected_len)
78 |
79 | # if combined sequence is too long, truncate the prompt
80 | if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
81 | if self.truncation_mode == "keep_start":
82 | prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
83 | elif self.truncation_mode == "keep_end":
84 | prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
85 | else:
86 | raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
87 |
88 | # if that's still too long, truncate the response
89 |         if len(prompt_tokens["input_ids"]) + longer_response_length > self.max_length:
90 |             chosen_tokens = {k: v[: self.max_length - self.max_prompt_length] for k, v in chosen_tokens.items()}
   |             # rejected_tokens maps each rejected key to its own token dict, so truncate one level deeper
91 |             rejected_tokens = {
   |                 key: {k: v[: self.max_length - self.max_prompt_length] for k, v in toks.items()}
   |                 for key, toks in rejected_tokens.items()
   |             }
92 |
93 | # Create labels
94 | chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
95 | rejected_sequence_tokens = {}
96 | # rejected_tokens: Dict[str, Dict]
97 | for key in rejected_tokens:
98 | rejected_sequence_tokens[key] = {k: prompt_tokens[k] + rejected_tokens[key][k] for k in rejected_tokens[key]}
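    |         # prompt positions are masked with label_pad_token_id so that only response
    |         # tokens contribute to the sequence log-probs in _get_batch_logps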
99 | chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
100 | chosen_sequence_tokens["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
101 | prompt_tokens["input_ids"]
102 | )
103 | for key in rejected_sequence_tokens:
104 | rejected_sequence_tokens[key]["labels"] = rejected_sequence_tokens[key]["input_ids"][:]
105 | rejected_sequence_tokens[key]["labels"][: len(prompt_tokens["input_ids"])] = [self.label_pad_token_id] * len(
106 | prompt_tokens["input_ids"]
107 | )
108 |
109 | batch = {}
110 |
111 | batch["prompt"] = prompt
112 | batch["chosen"] = prompt + chosen
113 | for key in rejected:
114 | batch[key] = prompt + rejected[key]
115 | batch["chosen_response_only"] = chosen
116 | for key in rejected:
117 | batch[f"{key}_response_only"] = rejected[key]
118 |
119 | for k, toks in {
120 | "chosen": chosen_sequence_tokens,
121 | # "rejected": rejected_sequence_tokens,
122 | "prompt": prompt_tokens,
123 | }.items():
124 | for type_key, tokens in toks.items():
125 | if type_key == "token_type_ids":
126 | continue
127 | batch[f"{k}_{type_key}"] = tokens
128 | # rejected_sequence_tokens: Dict[str, Dict]
129 | for k, toks in rejected_sequence_tokens.items():
130 | for type_key, tokens in toks.items():
131 | if type_key == "token_type_ids":
132 | continue
133 | batch[f"{k}_{type_key}"] = tokens
134 |
135 | return batch
136 |
137 | def collate(self, batch):
138 | # first, pad everything to the same length
139 | padded_batch = {}
140 | for k in batch[0].keys():
141 | if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
142 | # adapted from https://stackoverflow.com/questions/73256206
143 | if "prompt" in k:
144 | to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
145 | else:
146 | to_pad = [torch.LongTensor(ex[k]) for ex in batch]
147 | if k.endswith("_input_ids"):
148 | padding_value = self.tokenizer.pad_token_id
149 | elif k.endswith("_labels"):
150 | padding_value = self.label_pad_token_id
151 | elif k.endswith("_attention_mask"):
152 | padding_value = self.padding_value
153 | else:
154 | raise ValueError(f"Unexpected key in batch '{k}'")
155 |
156 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
157 | # for the prompt, flip back so padding is on left side
158 | if "prompt" in k:
159 | padded_batch[k] = padded_batch[k].flip(dims=[1])
160 | else:
161 | padded_batch[k] = [ex[k] for ex in batch]
162 |
163 | return padded_batch
164 |
165 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
166 | tokenized_batch = []
167 |
168 | for feature in features:
169 | prompt = feature["prompt"]
170 | chosen = feature["chosen"]
171 | rejected = {}
172 | for key in feature:
173 | if key.startswith("rejected"):
174 | rejected[key] = feature[key]
175 |
176 | batch_element = self.tokenize_batch_element(prompt, chosen, rejected)
177 | tokenized_batch.append(batch_element)
178 |
179 | # return collated batch
180 | return self.collate(tokenized_batch)
181 |
182 | def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
183 | if tensor.size(dim) >= length:
184 | return tensor
185 | else:
186 | pad_size = list(tensor.shape)
187 | pad_size[dim] = length - tensor.size(dim)
188 | return torch.cat(
189 | [
190 | tensor,
191 | pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device),
192 | ],
193 | dim=dim,
194 | )
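    | 
    | # e.g. pad_to_length(torch.tensor([[1, 2]]), 4, pad_value=0) -> tensor([[1, 2, 0, 0]])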
--------------------------------------------------------------------------------
/baselines/Semantic_sampling_rosePO/Semantic_sampling_rosePO.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | import re
4 | from sentence_transformers import SentenceTransformer
5 | import torch
6 | import torch.nn.functional as F
7 | 
8 | model = SentenceTransformer('./models/paraphrase-MiniLM-L3-v2')
9 | 
10 | def read_json(json_file: str) -> dict:
11 |     with open(json_file, 'r') as f:
12 |         return json.load(f)
13 | 
14 | def export_to_json(file_path: str, dic):
15 |     with open(file_path, 'w') as f:
16 |         json.dump(dic, f, indent=2)
17 | 
18 | def process_batch(batch):
19 |     # `embeddings` and `id2name` are set per category in the loop below
20 |     results = []
21 |     for data in batch:
22 |         input_text = data['input']
23 |         names = re.findall(r'"([^"]+)"', input_text)
24 |         name_embeddings = torch.tensor(model.encode(names), device="cuda")
25 |         # average each candidate item's similarity to the history items, then
26 |         # take the least similar item as the semantic (hard negative) sample
27 |         cosine_similarity = F.cosine_similarity(name_embeddings[:, None, :], embeddings[None, :, :], dim=-1)
28 |         similarity = cosine_similarity.mean(dim=0)
29 |         min_sim, min_index = similarity.min(dim=-1)
30 |         semantic_item = id2name[str(min_index.item())]
31 |         data['semantic'] = f"\"{semantic_item}\"\n"
32 |         results.append(data)
33 |     return results
34 | 
35 | # attach a semantically dissimilar item to every training example
36 | for category in ["CDs_and_Vinyl"]:
37 |     embeddings = torch.load(f"../eval/{category}/embeddings.pt").to('cuda')
38 |     id2name = read_json(f"../eval/{category}/id2name.json")
39 |     train_data = read_json(f"./{category}/train.json")
40 |     batch_size = 64
41 |     batched_data = [train_data[i:i+batch_size] for i in range(0, len(train_data), batch_size)]
42 |     final_data = []
43 |     for batch in tqdm(batched_data, desc=f"Processing {category} train data......"):
44 |         final_data.extend(process_batch(batch))
45 |     export_to_json(f"../data/{category}/train_semantic.json", final_data)
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: SPRec
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
8 | - https://repo.anaconda.com/pkgs/main
9 | - https://repo.anaconda.com/pkgs/r
10 | dependencies:
11 | - _libgcc_mutex=0.1=main
12 | - _openmp_mutex=5.1=1_gnu
13 | - blas=1.0=mkl
14 | - brotli-python=1.0.9=py38h6a678d5_8
15 | - bzip2=1.0.8=h5eee18b_6
16 | - ca-certificates=2024.9.24=h06a4308_0
17 | - charset-normalizer=3.3.2=pyhd3eb1b0_0
18 | - cuda-cudart=12.1.105=0
19 | - cuda-cupti=12.1.105=0
20 | - cuda-libraries=12.1.0=0
21 | - cuda-nvrtc=12.1.105=0
22 | - cuda-nvtx=12.1.105=0
23 | - cuda-opencl=12.4.127=0
24 | - cuda-runtime=12.1.0=0
25 | - cudatoolkit=11.8.0=h6a678d5_0
26 | - ffmpeg=4.3=hf484d3e_0
27 | - filelock=3.13.1=py38h06a4308_0
28 | - freetype=2.12.1=h4a9f257_0
29 | - gmp=6.2.1=h295c915_3
30 | - gmpy2=2.1.2=py38heeb90bb_0
31 | - gnutls=3.6.15=he1e5248_0
32 | - intel-openmp=2023.1.0=hdb19cb5_46306
33 | - jpeg=9e=h5eee18b_3
34 | - lame=3.100=h7b6447c_0
35 | - lcms2=2.12=h3be6417_0
36 | - ld_impl_linux-64=2.38=h1181459_1
37 | - lerc=3.0=h295c915_0
38 | - libcublas=12.1.0.26=0
39 | - libcufft=11.0.2.4=0
40 | - libcufile=1.9.1.3=0
41 | - libcurand=10.3.5.147=0
42 | - libcusolver=11.4.4.55=0
43 | - libcusparse=12.0.2.55=0
44 | - libdeflate=1.17=h5eee18b_1
45 | - libffi=3.4.4=h6a678d5_0
46 | - libgcc-ng=11.2.0=h1234567_1
47 | - libgomp=11.2.0=h1234567_1
48 | - libiconv=1.16=h5eee18b_3
49 | - libidn2=2.3.4=h5eee18b_0
50 | - libjpeg-turbo=2.0.0=h9bf148f_0
51 | - libnpp=12.0.2.50=0
52 | - libnvjitlink=12.1.105=0
53 | - libnvjpeg=12.1.1.14=0
54 | - libpng=1.6.39=h5eee18b_0
55 | - libstdcxx-ng=11.2.0=h1234567_1
56 | - libtasn1=4.19.0=h5eee18b_0
57 | - libtiff=4.5.1=h6a678d5_0
58 | - libunistring=0.9.10=h27cfd23_0
59 | - libwebp-base=1.3.2=h5eee18b_0
60 | - llvm-openmp=14.0.6=h9e868ea_0
61 | - lz4-c=1.9.4=h6a678d5_1
62 | - markupsafe=2.1.3=py38h5eee18b_0
63 | - mkl=2023.1.0=h213fc3f_46344
64 | - mkl-service=2.4.0=py38h5eee18b_1
65 | - mkl_fft=1.3.8=py38h5eee18b_0
66 | - mkl_random=1.2.4=py38hdb19cb5_0
67 | - mpc=1.1.0=h10f8cd9_1
68 | - mpfr=4.0.2=hb69a4c5_1
69 | - mpmath=1.3.0=py38h06a4308_0
70 | - ncurses=6.4=h6a678d5_0
71 | - nettle=3.7.3=hbbd107a_1
72 | - networkx=3.1=py38h06a4308_0
73 | - openh264=2.1.1=h4ff587b_0
74 | - openjpeg=2.5.2=he7f1fd0_0
75 | - openssl=3.0.12=h7f8727e_0
76 | - pip=23.3=py38h06a4308_0
77 | - pysocks=1.7.1=py38h06a4308_0
78 | - python=3.8.18=h955ad1f_0
79 | - pytorch=2.1.0=py3.8_cuda12.1_cudnn8.9.2_0
80 | - pytorch-cuda=12.1=ha16c6d3_5
81 | - pytorch-mutex=1.0=cuda
82 | - pyyaml=6.0.1=py38h5eee18b_0
83 | - readline=8.2=h5eee18b_0
84 | - setuptools=68.0.0=py38h06a4308_0
85 | - sqlite=3.41.2=h5eee18b_0
86 | - tbb=2021.8.0=hdb19cb5_0
87 | - tk=8.6.12=h1ccaba5_0
88 | - torchtriton=2.1.0=py38
89 | - typing_extensions=4.11.0=py38h06a4308_0
90 | - wheel=0.41.2=py38h06a4308_0
91 | - xz=5.4.2=h5eee18b_0
92 | - yaml=0.2.5=h7b6447c_0
93 | - zlib=1.2.13=h5eee18b_0
94 | - zstd=1.5.5=hc292b87_0
95 | - pip:
96 | - accelerate==1.0.1
97 | - aiofiles==23.2.1
98 | - aiohttp==3.9.0
99 | - aiosignal==1.3.1
100 | - altair==5.1.2
101 | - annotated-types==0.6.0
102 | - anyio==3.7.1
103 | - appdirs==1.4.4
104 | - argon2-cffi==23.1.0
105 | - argon2-cffi-bindings==21.2.0
106 | - arrow==1.3.0
107 | - asttokens==2.4.1
108 | - async-lru==2.0.4
109 | - async-timeout==4.0.3
110 | - attrs==23.1.0
111 | - babel==2.13.1
112 | - backcall==0.2.0
113 | - beautifulsoup4==4.12.2
114 | - bitsandbytes==0.43.1
115 | - bitsandbytes-cuda116==0.26.0.post2
116 | - black==23.11.0
117 | - bleach==6.1.0
118 | - certifi==2023.11.17
119 | - cffi==1.16.0
120 | - click==8.1.7
121 | - cmake==3.27.7
122 | - colorama==0.4.6
123 | - comm==0.2.0
124 | - contourpy==1.1.1
125 | - cycler==0.12.1
126 | - datasets==2.15.0
127 | - debugpy==1.8.0
128 | - decorator==5.1.1
129 | - defusedxml==0.7.1
130 | - dill==0.3.7
131 | - docker-pycreds==0.4.0
132 | - docopt==0.6.2
133 | - docstring-parser==0.16
134 | - eval-type-backport==0.2.0
135 | - exceptiongroup==1.2.0
136 | - executing==2.0.1
137 | - fastapi==0.104.1
138 | - fasteners==0.19
139 | - fastjsonschema==2.19.0
140 | - ffmpy==0.3.1
141 | - fire==0.5.0
142 | - fonttools==4.45.0
143 | - fqdn==1.5.1
144 | - frozenlist==1.4.0
145 | - fschat==0.2.34
146 | - fsspec==2023.10.0
147 | - gitdb==4.0.11
148 | - gitpython==3.1.43
149 | - gradio==3.50.2
150 | - gradio-client==0.6.1
151 | - h11==0.14.0
152 | - httpcore==1.0.2
153 | - httpx==0.25.1
154 | - huggingface-hub==0.25.0
155 | - idna==3.4
156 | - ijson==3.3.0
157 | - imageio==2.35.1
158 | - importlib-metadata==6.8.0
159 | - importlib-resources==6.1.1
160 | - ipykernel==6.27.1
161 | - ipython==8.12.3
162 | - ipywidgets==8.1.1
163 | - isoduration==20.11.0
164 | - jedi==0.19.1
165 | - jinja2==3.1.2
166 | - joblib==1.3.2
167 | - json5==0.9.14
168 | - jsonpointer==2.4
169 | - jsonschema==4.20.0
170 | - jsonschema-specifications==2023.11.1
171 | - jupyter==1.0.0
172 | - jupyter-client==8.6.0
173 | - jupyter-console==6.6.3
174 | - jupyter-core==5.5.0
175 | - jupyter-events==0.9.0
176 | - jupyter-lsp==2.2.1
177 | - jupyter-server==2.11.1
178 | - jupyter-server-terminals==0.4.4
179 | - jupyterlab==4.0.9
180 | - jupyterlab-pygments==0.3.0
181 | - jupyterlab-server==2.25.2
182 | - jupyterlab-widgets==3.0.9
183 | - kiwisolver==1.4.5
184 | - lazy-loader==0.4
185 | - lit==17.0.5
186 | - loguru==0.7.2
187 | - loralib==0.1.2
188 | - markdown-it-py==3.0.0
189 | - markdown2==2.4.12
190 | - matplotlib==3.7.4
191 | - matplotlib-inline==0.1.6
192 | - mdurl==0.1.2
193 | - mistune==3.0.2
194 | - multidict==6.0.4
195 | - multiprocess==0.70.15
196 | - mypy-extensions==1.0.0
197 | - nbclient==0.9.0
198 | - nbconvert==7.11.0
199 | - nbformat==5.9.2
200 | - nest-asyncio==1.5.8
201 | - nh3==0.2.15
202 | - nltk==3.8.1
203 | - notebook==7.0.6
204 | - notebook-shim==0.2.3
205 | - numpy==1.24.4
206 | - nvidia-cublas-cu11==11.10.3.66
207 | - nvidia-cublas-cu12==12.1.3.1
208 | - nvidia-cuda-cupti-cu11==11.7.101
209 | - nvidia-cuda-cupti-cu12==12.1.105
210 | - nvidia-cuda-nvrtc-cu11==11.7.99
211 | - nvidia-cuda-nvrtc-cu12==12.1.105
212 | - nvidia-cuda-runtime-cu11==11.7.99
213 | - nvidia-cuda-runtime-cu12==12.1.105
214 | - nvidia-cudnn-cu11==8.5.0.96
215 | - nvidia-cudnn-cu12==8.9.2.26
216 | - nvidia-cufft-cu11==10.9.0.58
217 | - nvidia-cufft-cu12==11.0.2.54
218 | - nvidia-curand-cu11==10.2.10.91
219 | - nvidia-curand-cu12==10.3.2.106
220 | - nvidia-cusolver-cu11==11.4.0.1
221 | - nvidia-cusolver-cu12==11.4.5.107
222 | - nvidia-cusparse-cu11==11.7.4.91
223 | - nvidia-cusparse-cu12==12.1.0.106
224 | - nvidia-nccl-cu11==2.14.3
225 | - nvidia-nccl-cu12==2.18.1
226 | - nvidia-nvjitlink-cu12==12.3.101
227 | - nvidia-nvtx-cu11==11.7.91
228 | - nvidia-nvtx-cu12==12.1.105
229 | - opencv-python==4.10.0.84
230 | - orjson==3.9.10
231 | - overrides==7.4.0
232 | - packaging==23.2
233 | - pandas==2.0.3
234 | - pandocfilters==1.5.0
235 | - parso==0.8.3
236 | - pathspec==0.11.2
237 | - peft==0.11.0
238 | - pexpect==4.8.0
239 | - pickleshare==0.7.5
240 | - pillow==10.1.0
241 | - pipreqs==0.5.0
242 | - pkgutil-resolve-name==1.3.10
243 | - platformdirs==4.0.0
244 | - prometheus-client==0.19.0
245 | - prompt-toolkit==3.0.41
246 | - protobuf==3.19.0
247 | - psutil==5.9.6
248 | - ptyprocess==0.7.0
249 | - pure-eval==0.2.2
250 | - pyarrow==14.0.1
251 | - pyarrow-hotfix==0.6
252 | - pycparser==2.21
253 | - pydantic==1.10.13
254 | - pydantic-core==2.14.3
255 | - pydub==0.25.1
256 | - pygments==2.17.2
257 | - pyparsing==3.1.1
258 | - python-dateutil==2.8.2
259 | - python-json-logger==2.0.7
260 | - python-multipart==0.0.6
261 | - pytz==2023.3.post1
262 | - pyzmq==25.1.1
263 | - qtconsole==5.5.1
264 | - qtpy==2.4.1
265 | - referencing==0.31.0
266 | - regex==2023.10.3
267 | - requests==2.31.0
268 | - rfc3339-validator==0.1.4
269 | - rfc3986-validator==0.1.1
270 | - rich==13.7.0
271 | - rpds-py==0.13.1
272 | - safetensors==0.4.5
273 | - scikit-image==0.21.0
274 | - scikit-learn==1.3.2
275 | - scipy==1.10.1
276 | - seaborn==0.13.0
277 | - semantic-version==2.10.0
278 | - send2trash==1.8.2
279 | - sentence-transformers==2.2.2
280 | - sentencepiece==0.1.99
281 | - sentry-sdk==1.45.0
282 | - setproctitle==1.3.3
283 | - shellingham==1.5.4
284 | - shortuuid==1.0.11
285 | - shtab==1.7.1
286 | - six==1.16.0
287 | - smmap==5.0.1
288 | - sniffio==1.3.0
289 | - some-package==0.1
290 | - soupsieve==2.5
291 | - stack-data==0.6.3
292 | - starlette==0.27.0
293 | - svgwrite==1.4.3
294 | - sympy==1.12
295 | - termcolor==2.3.0
296 | - terminado==0.18.0
297 | - threadpoolctl==3.2.0
298 | - tifffile==2023.7.10
299 | - tiktoken==0.5.2
300 | - tinycss2==1.2.1
301 | - tokenize-rt==5.2.0
302 | - tokenizers==0.19.1
303 | - tomli==2.0.1
304 | - tomlkit==0.12.0
305 | - toolz==0.12.0
306 | - torch==2.0.1
307 | - torchaudio==2.0.2
308 | - torchvision==0.15.2
309 | - tornado==6.4
310 | - tqdm==4.66.1
311 | - traitlets==5.13.0
312 | - transformers==4.44.2
313 | - triton==2.0.0
314 | - trl==0.9.2
315 | - typer==0.9.0
316 | - types-python-dateutil==2.8.19.14
317 | - typing-extensions==4.8.0
318 | - tyro==0.8.3
319 | - tzdata==2023.3
320 | - uri-template==1.3.0
321 | - urllib3==2.1.0
322 | - uvicorn==0.24.0.post1
323 | - wandb==0.16.6
324 | - wavedrom==2.0.3.post3
325 | - wcwidth==0.2.12
326 | - webcolors==1.13
327 | - webencodings==0.5.1
328 | - websocket-client==1.6.4
329 | - websockets==11.0.3
330 | - widgetsnbextension==4.0.9
331 | - xxhash==3.4.1
332 | - yarg==0.1.9
333 | - yarl==1.9.3
334 | - zipp==3.17.0
335 | prefix: # Specify your prefix
336 |
--------------------------------------------------------------------------------
/eval/Goodreads/embeddings.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RegionCh/SPRec/96dff9b5a42c6cc227d07fcee6b3d8813865bac0/eval/Goodreads/embeddings.pt
--------------------------------------------------------------------------------
/eval/Goodreads/genre_dict.json:
--------------------------------------------------------------------------------
1 | {
2 | "fiction": 0,
3 | "romance": 0,
4 | "young-adult": 0,
5 | "fantasy, paranormal": 0,
6 | "mystery, thriller, crime": 0,
7 | "history, historical fiction, biography": 0,
8 | "children": 0
9 | }
--------------------------------------------------------------------------------
/eval/MovieLens/embeddings.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RegionCh/SPRec/96dff9b5a42c6cc227d07fcee6b3d8813865bac0/eval/MovieLens/embeddings.pt
--------------------------------------------------------------------------------
/eval/MovieLens/genre_dict.json:
--------------------------------------------------------------------------------
1 | {
2 |   "Action": 0,
3 |   "Adventure": 0,
4 |   "Sci-Fi": 0,
5 |   "Thriller": 0,
6 |   "Drama": 0,
7 |   "Comedy": 0,
8 |   "Fantasy": 0,
9 |   "Crime": 0,
10 |   "Romance": 0
11 | }
--------------------------------------------------------------------------------
/eval/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
3 | import transformers
4 | import torch
5 | import re
6 | import math
7 | import json
8 | from peft import PeftModel
9 | import argparse
10 | import pandas as pd
11 | from collections import Counter
12 | from sentence_transformers import SentenceTransformer
13 | parse = argparse.ArgumentParser()
14 | parse.add_argument("--input_dir", type=str, default="./", help="path to the inference result json")
15 | parse.add_argument("--model", type=str, default="SPRec", help="model name recorded in the experiment csv")
16 | parse.add_argument("--exp_csv", type=str, default=None, help="optional csv file to append the metrics to")
17 | parse.add_argument("--output_dir", type=str, default="./", help="path to write the eval result json")
18 | parse.add_argument("--topk", type=str, default="./", help="top-k cutoff for the metrics")
19 | parse.add_argument("--gamma", type=float, default=0.0, help="gamma")
20 | parse.add_argument("--category", type=str, default="CDs_and_Vinyl", help="dataset category, e.g. CDs_and_Vinyl")
21 | args = parse.parse_args()
22 | def read_json(json_file: str) -> dict:
23 |     with open(json_file, 'r') as f:
24 |         return json.load(f)
25 | category = args.category
26 | id2name = read_json(f"./eval/{category}/id2name.json")
27 | name2id = read_json(f"./eval/{category}/name2id.json")
28 | embeddings = torch.load(f"./eval/{category}/embeddings.pt")
29 | name2genre = read_json(f"./eval/{category}/name2genre.json")
30 | genre_dict = read_json(f"./eval/{category}/genre_dict.json")
31 | def batch(items, batch_size=1):
32 |     # yield successive batch_size-sized chunks of items
33 |     chunk_size = (len(items) - 1) // batch_size + 1
34 |     for i in range(chunk_size):
35 |         yield items[batch_size * i: batch_size * (i + 1)]
35 |
36 | def sum_of_first_i_keys(sorted_dic, i):
37 |     # sums the first i values (occurrence counts) of the sorted dict
38 |     top_counts = list(sorted_dic.values())[:i]
39 |     return sum(top_counts)
39 |
40 | def gh(category:str,test_data):
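    |     # gh = normalized genre distribution of the items in the users' interaction
    |     # histories; it is compared against the distribution of predicted items (gp)
    |     # to compute the MGU/DGU bias metrics below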
41 | notin_count = 0
42 | in_count = 0
43 | name2genre=read_json(f"./eval/{category}/name2genre.json")
44 | genre_dict = read_json(f"./eval/{category}/genre_dict.json")
45 | for data in tqdm(test_data,desc="Processing category data......"):
46 | input = data['input']
47 | names = re.findall(r'"([^"]+)"', input)
48 | for name in names:
49 | if name in name2genre:
50 | in_count += 1
51 | genres = name2genre[name]
52 | else:
53 | notin_count += 1
54 | # print(f"Not exist in name2genre:{name}")
55 | continue
56 | select_genres = []
57 | for genre in genres:
58 | if genre in genre_dict:
59 | select_genres.append(genre)
60 | if(len(select_genres)>0):
61 | for genre in select_genres:
62 | genre_dict[genre] += 1/len(select_genres)
63 | gh = [genre_dict[x] for x in genre_dict]
64 | gh_normalize = [x/sum(gh) for x in gh]
65 | print(f"InCount:{in_count}\nNotinCount:{notin_count}")
66 | return gh_normalize
67 |
68 |
69 | result_json = args.input_dir
70 | f = open(result_json, 'r')
71 | test_data = json.load(f)
72 | total = 0
73 | # Specify your sentence-embedding model here
74 | model = SentenceTransformer('/data/chenruijun/code/models/paraphrase-MiniLM-L3-v2')
75 |
76 | from tqdm import tqdm
77 | embeddings = torch.tensor(embeddings).cuda()
78 | text = []
79 | for i, data in tqdm(enumerate(test_data)):
80 |     if len(data["predict"]) > 0:
81 |         if len(data['predict'][0]) == 0:
82 |             text.append("NAN")
83 |             print("Empty prediction!")
84 |         else:
85 |             # take the first quoted item name; fall back to the first line
86 |             match = re.search(r'"([^"]*)', data['predict'][0])
87 |             if match:
88 |                 name = match.group(1)
89 |                 text.append(name)
90 |             else:
91 |                 text.append(data['predict'][0].split('\n', 1)[0])
92 |     else:
   |         # append a placeholder so `text` stays aligned with test_data
   |         text.append("NAN")
   |         print("Empty prediction list!")
93 |
94 | predict_embeddings = []
95 | for i, batch_input in tqdm(enumerate(batch(text, 8))):
96 | predict_embeddings.append(torch.tensor(model.encode(batch_input)))
97 | predict_embeddings = torch.cat(predict_embeddings, dim=0).cuda()
99 | dist = torch.cdist(predict_embeddings, embeddings, p=2)
100 | batch_size = 1
101 | num_batches = (dist.size(0) + batch_size - 1) // batch_size
102 | rank_list = []
103 | for i in tqdm(range(num_batches), desc="Processing Batches"):
104 | start_idx = i * batch_size
105 | end_idx = min((i + 1) * batch_size, dist.size(0))
106 | batch_dist = dist[start_idx:end_idx]
107 |
108 | batch_rank = batch_dist.argsort(dim=-1).argsort(dim=-1)
109 |     torch.cuda.empty_cache()
110 | rank_list.append(batch_rank)
111 |
112 | rank_list = torch.cat(rank_list, dim=0)
113 |
114 | NDCG = []
115 | HR = []
116 | diversity = []
117 | diversity_dic = {}
118 | MGU_genre = []
119 | DGU_genre = []
120 | pop_count = {}
121 | genre_count = {}
122 | notin = 0
123 | notin_count = 0
124 | in_count = 0
125 | topk_list = [int(args.topk)]
126 | diversity_set = set()
127 | for topk in topk_list:
128 | S_ndcg = 0
129 | S_hr = 0
130 | for i in tqdm(range(len(test_data)),desc="Calculating Metrics......"):
131 | rank = rank_list[i]
132 | # Target id
133 | target_name = test_data[i]['output']
134 | predict_name = test_data[i]['predict'][0]
135 | target_name = target_name.strip().strip('"')
136 | if target_name in name2id:
137 | target_id = name2id[target_name]
138 | total += 1
139 | else:
140 | continue
141 |
142 | rankId = rank[target_id]
143 |
144 | # NDCG & HR
145 |         if rankId < topk:
146 |             # natural-log DCG; converted to log2 by the 1/math.log(2) factor below
147 |             S_ndcg += 1 / math.log(rankId + 2)
148 |             S_hr += 1
149 | 
150 |         # genre statistics of the predicted item; this block was garbled in the dump
151 |         # and is reconstructed to mirror the gh() logic above
152 |         predict_name = predict_name.strip().strip('"')
153 |         if predict_name in name2genre:
154 |             genres = name2genre[predict_name]
155 |             select_genres = []
156 |             for genre in genres:
157 |                 if genre in genre_dict:
158 |                     select_genres.append(genre)
159 |             if len(select_genres) > 0:
162 | for genre in select_genres:
163 | genre_dict[genre] += 1/len(select_genres)
164 | else:
165 | notin += 1
166 |
167 |
168 | # diversity
169 | for i in range(topk):
170 | diversity_set.add(torch.argwhere(rank==i).item())
171 | if torch.argwhere(rank==i).item() in diversity_dic:
172 | diversity_dic[torch.argwhere(rank==i).item()] += 1
173 | else:
174 | diversity_dic[torch.argwhere(rank==i).item()] = 1
175 |
176 |
177 | NDCG.append(S_ndcg / len(test_data) / (1 / math.log(2)))
178 | HR.append(S_hr / len(test_data))
179 | diversity.append(len(diversity_set))
180 | genre = args.category
181 |
182 | gh_genre = gh(category,test_data)
183 | #
184 | print(len(gh_genre))
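    | # gp = genre distribution of the predictions; DGU = spread (max - min) of gp - gh,
    | # MGU = mean absolute deviation |gp - gh| over genres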
185 | gp_genre = [genre_dict[x] for x in genre_dict]
186 | gp_genre = [x/sum(gp_genre) for x in gp_genre]
187 | dis_genre = [gp_genre[i]-gh_genre[i] for i in range(len(gh_genre))]
188 | DGU_genre = max(dis_genre)-min(dis_genre)
189 | dis_abs_genre = [abs(x) for x in dis_genre]
190 | MGU_genre = sum(dis_abs_genre) / len(dis_genre)
192 |
193 | gp_dict = {}
194 | i=0
195 | for key in genre_dict:
196 | gp_dict[key] = dis_abs_genre[i]
197 | i += 1
198 | print(f"gp_dict:{gp_dict}")
199 | print(f"NDCG:{NDCG}")
200 | print(f"HR:{HR}")
201 | div_ratio = diversity[0] / (total*topk)
202 | print(f"DGU:{DGU_genre}")
203 | print(f"MGU:{MGU_genre}")
204 | print(f"DivRatio:{div_ratio}")
205 |
206 | eval_dic = {}
207 | eval_dic["model"] = args.input_dir
208 | # eval_dic["Dis_genre"] = dis_abs_genre
209 | eval_dic['NDCG'] = NDCG
210 | eval_dic["HR"] = HR
211 | eval_dic["diversity"] = diversity
212 | eval_dic["DivRatio"] = div_ratio
213 | eval_dic['DGU'] = DGU_genre
214 | eval_dic["MGU"] = MGU_genre
215 |
216 | file_path = args.output_dir
217 | if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
218 | with open(file_path, 'r') as file:
219 | try:
220 | data = json.load(file)
221 | except json.JSONDecodeError:
222 | data = []
223 | else:
224 | data = []
225 | sorted_dic = dict(sorted(diversity_dic.items(), key=lambda item: item[1], reverse=True))
226 | # ORRatio: share of recommendations concentrated on the 3 most-recommended items
227 | eval_dic["ORRatio"] = sum_of_first_i_keys(sorted_dic, 3) / (topk*total)
228 | print(f"ORRatio:{eval_dic['ORRatio']}")
229 | #print(dict(sorted(diversity_dic.items(), key=lambda item: item[1])))
230 | data.append(eval_dic)
231 |
232 |
233 | with open(args.output_dir, 'w') as file:
234 | json.dump(data, file, separators=(',', ': '), indent=2)
235 |
236 | def update_csv(dataset_name, model_name, metrics_dict, csv_file):
237 | df = pd.read_csv(csv_file)
238 |
239 | required_columns = ["Dataset", "Model"]
240 | if not all(col in df.columns for col in required_columns):
241 | raise ValueError("The CSV file must contain 'Dataset' and 'Model' columns")
242 |
243 | condition = (df["Dataset"] == dataset_name) & (df["Model"] == model_name)
244 | if not condition.any():
245 | new_row = {col: None for col in df.columns}
246 | new_row["Dataset"] = dataset_name
247 | new_row["Model"] = model_name
248 |
249 | new_row_df = pd.DataFrame([new_row])
250 | df = pd.concat([df, new_row_df], ignore_index=True)
251 |
252 | condition = (df["Dataset"] == dataset_name) & (df["Model"] == model_name)
253 |
254 | for metric, value in metrics_dict.items():
255 | if metric not in df.columns:
256 | print(f"Note: metric '{metric}' is not a column in the CSV file; added it and initialized to 0.")
257 | df[metric] = 0
258 | df.loc[condition, metric] = value
259 |
260 | df.to_csv(csv_file, index=False)
261 | print(f"CSV file updated: {csv_file}")
262 |
263 | if args.exp_csv is not None:
264 | metric_dic = {}
265 | metric_dic[f"MGU@{args.topk}"] = eval_dic["MGU"]
266 | metric_dic[f"DGU@{args.topk}"] = eval_dic["DGU"]
267 | metric_dic[f"DivRatio@{args.topk}"] = eval_dic["DivRatio"]
268 | metric_dic[f"ORRatio@{args.topk}"] = sum_of_first_i_keys(sorted_dic,3) / (topk*total)
269 | if args.topk == '5':
270 | metric_dic[f"NDCG@{args.topk}"] = eval_dic["NDCG"]
271 | update_csv(category,args.model,metric_dic,args.exp_csv)
--------------------------------------------------------------------------------
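The grounding step in `evaluate.py` is worth spelling out: free-text model outputs are mapped into the item space by embedding both the prediction and every catalog title, ranking catalog items by L2 distance, and reading off the rank of the ground-truth item for NDCG@k and HR@k. A minimal, self-contained sketch; the `catalog`, `prediction`, and `target` values are made up, and the encoder is pulled from the Hugging Face hub rather than the local path used above:

```python
import math
import torch
from sentence_transformers import SentenceTransformer

# Illustrative inputs; the real pipeline reads these from eval/<dataset>/.
catalog = ["The Hobbit", "Dune", "Neuromancer"]          # all item titles
prediction = '"Dune"'                                    # raw model output
target = "Dune"                                          # ground-truth item

model = SentenceTransformer("paraphrase-MiniLM-L3-v2")   # any sentence encoder works
item_emb = torch.tensor(model.encode(catalog))
pred_emb = torch.tensor(model.encode([prediction.strip('"')]))

# Rank catalog items by L2 distance to the prediction (rank 0 = closest).
dist = torch.cdist(pred_emb, item_emb, p=2)
rank = dist.argsort(dim=-1).argsort(dim=-1)[0]

target_rank = rank[catalog.index(target)].item()
topk = 5
hit = int(target_rank < topk)
ndcg = (1 / math.log2(target_rank + 2)) if hit else 0.0
print(f"rank={target_rank}, HR@{topk}={hit}, NDCG@{topk}={ndcg:.4f}")
```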
/eval/inference.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import fire
4 | import gradio as gr
5 | import torch
6 | torch.set_num_threads(1)
7 | import transformers
8 | import json
9 | import os
10 | os.environ['OPENBLAS_NUM_THREADS'] = '1'
11 | os.environ['OMP_NUM_THREADS'] = '1'
12 | from peft import PeftModel
13 | from transformers import GenerationConfig, LlamaTokenizer
14 | from transformers import LlamaForCausalLM,AutoTokenizer
15 |
16 |
17 | if torch.cuda.is_available():
18 | device = "cuda"
19 | else:
20 | device = "cpu"
21 |
22 | try:
23 | if torch.backends.mps.is_available():
24 | device = "mps"
25 | except: # noqa: E722
26 | pass
27 |
28 |
29 | def main(
30 | load_8bit: bool = False,
31 | base_model: str = "",
32 | lora_weights: str = "tloen/alpaca-lora-7b",
33 | test_data_path: str = "data/test.json",
34 | result_json_data: str = "temp.json",
35 | batch_size: int=32,
36 | num_beams: int=1
37 | ):
38 | assert (
39 | base_model
40 | ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
41 |
42 | #tokenizer = LlamaTokenizer.from_pretrained(base_model)
43 | tokenizer = AutoTokenizer.from_pretrained(base_model)
44 | tokenizer.pad_token_id = tokenizer.eos_token_id
45 | load_8bit = False  # 8-bit loading is force-disabled here, regardless of the CLI flag
46 | if device == "cuda":
47 | model = LlamaForCausalLM.from_pretrained(
48 | base_model,
49 | load_in_8bit=load_8bit,
50 | torch_dtype=torch.float16,
51 | device_map="auto",
52 | )
53 | model = PeftModel.from_pretrained(
54 | model,
55 | lora_weights,
56 | torch_dtype=torch.float16,
57 | device_map="auto"
58 | )
59 | tokenizer.padding_side = "left"
60 |
61 | model.eval()
62 |
63 | def evaluate(
64 | instructions,
65 | inputs=None,
66 | temperature=1.0,
67 | top_p=0.9,
68 | top_k=40,
69 | num_beams=num_beams,
70 | max_new_tokens=32,
71 | **kwargs,
72 | ):
73 | prompt = [generate_prompt(instruction, input) for instruction, input in zip(instructions, inputs)]
74 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
75 | generation_config = GenerationConfig(
76 | temperature=temperature,
77 | top_p=top_p,
78 | top_k=top_k,
79 | num_beams=num_beams,
80 | num_return_sequences=num_beams,
81 | **kwargs,
82 | )
83 | with torch.no_grad():
84 | generation_output = model.generate(
85 | **inputs,
86 | generation_config=generation_config,
87 | return_dict_in_generate=True,
88 | output_scores=True,
89 | max_new_tokens=max_new_tokens,
90 | pad_token_id = tokenizer.eos_token_id
91 | )
92 | s = generation_output.sequences
93 | output = tokenizer.batch_decode(s, skip_special_tokens=True)
94 | output = [_.split('Response:\n')[-1] for _ in output]
95 | real_outputs = [output[i * num_beams: (i + 1) * num_beams] for i in range(len(output) // num_beams)]
96 | return real_outputs
97 |
98 |
99 | outputs = []
100 | tokenizer.pad_token_id = tokenizer.eos_token_id
101 | from tqdm import tqdm
102 | with open(test_data_path, 'r') as f:
103 | test_data = json.load(f)
104 | instructions = [_['instruction'] for _ in test_data]
105 | inputs = [_['input'] for _ in test_data]
106 | def batch(lst, batch_size=batch_size):
107 | chunk_size = (len(lst) - 1) // batch_size + 1
108 | for i in range(chunk_size):
109 | yield lst[batch_size * i: batch_size * (i + 1)]
110 | for i, batch_pair in tqdm(enumerate(zip(batch(instructions), batch(inputs)))):
111 | batch_instructions, batch_inputs = batch_pair
112 | output = evaluate(batch_instructions, batch_inputs)
113 | outputs = outputs + output
114 |
115 | for i, test in tqdm(enumerate(test_data)):
116 | test_data[i]['predict'] = outputs[i]
117 |
118 |
119 | with open(result_json_data, 'w') as f:
120 | json.dump(test_data, f, indent=4)
121 |
122 | def generate_prompt(instruction, input=None):
123 | if input:
124 | return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
125 |
126 | ### Instruction:
127 | {instruction}
128 |
129 | ### Input:
130 | {input}
131 |
132 | ### Response:
133 | """
134 | else:
135 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
136 |
137 | ### Instruction:
138 | {instruction}
139 |
140 | ### Response:
141 | """
142 |
143 |
144 | if __name__ == "__main__":
145 | fire.Fire(main)
--------------------------------------------------------------------------------
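One detail in `inference.py` worth noting: with `num_return_sequences=num_beams`, `generate` emits `num_beams` sequences per prompt, flattened into a single list, so the decoded outputs must be regrouped per input. A small sketch with dummy strings:

```python
# With num_beams = 2 and 3 prompts, generate() returns 6 flattened sequences:
num_beams = 2
output = ["p0-beam0", "p0-beam1", "p1-beam0", "p1-beam1", "p2-beam0", "p2-beam1"]

# Regroup so real_outputs[i] holds all beams for prompt i.
real_outputs = [output[i * num_beams: (i + 1) * num_beams]
                for i in range(len(output) // num_beams)]
assert real_outputs[1] == ["p1-beam0", "p1-beam1"]
```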
/figs/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RegionCh/SPRec/96dff9b5a42c6cc227d07fcee6b3d8813865bac0/figs/method.png
--------------------------------------------------------------------------------
/shell/SFT.sh:
--------------------------------------------------------------------------------
1 | base_model="" # Specify your base model here
2 | gpu1=$1; gpu2=$2; gpu3=$3; gpu4=$4
3 | sample=4096
4 | for category in "MovieLens" "Goodreads" "CDs_and_Vinyl" "Steam"
5 | do
6 | echo ---------------------- SFT for category $category starting! ----------------------
7 | train_dataset="./data/${category}/train.json"
8 | valid_dataset="./data/${category}/valid.json"
9 | output_dir="./models/SFT_${sample}/${category}"
10 | mkdir -p $output_dir
11 | CUDA_VISIBLE_DEVICES=$gpu1,$gpu2,$gpu3,$gpu4 python ./train/sft.py \
12 | --output_dir $output_dir\
13 | --base_model $base_model \
14 | --train_dataset $train_dataset \
15 | --valid_dataset $valid_dataset \
16 | --train_sample_size $sample \
17 | --wandb_project SFT_${category}_${sample} \
18 | --wandb_name SFT_${category}_${sample} \
19 | --gradient_accumulation_steps 4 \
20 | --batch_size 4 \
21 | --num_train_epochs 4 \
22 | --learning_rate 0.0003 \
23 | --cutoff_len 512
24 |
25 | bash ./shell/eval_single_file.sh $gpu1 $gpu2 $gpu3 $gpu4 \
26 | $base_model \
27 | $output_dir \
28 | $category
29 |
30 | done
31 |
32 |
--------------------------------------------------------------------------------
/shell/SPRec.sh:
--------------------------------------------------------------------------------
1 | gpu1=$1; gpu2=$2; gpu3=$3; gpu4=$4; its=$5
2 | train_sample_size=2048;valid_sample_size=256
3 | base_model="/data/chenruijun/code/models/Llama-3.2-1B-Instruct"
4 | batch_size=4
5 | lr=0.00002
6 | # Only change the parameters above if needed
7 | for category in "MovieLens" "Goodreads" "CDs_and_Vinyl" "Video_Games" "Steam"
8 | do
9 | lora_weights="./models/SFT_4096/${category}" # matches the output_dir written by shell/SFT.sh
10 | output_dir="./models/SPRec/${category}_${train_sample_size}_${lr}"
11 | wandb_project="SPRec_${category}_${lr}_${train_sample_size}"
12 | echo ----------------- Training Parameters -----------------
13 | echo "GPU: $gpu1,$gpu2,$gpu3,$gpu4"
14 | echo "Iterations: $its"
15 | echo "Train Sample Size: $train_sample_size"
16 | echo "Valid Sample Size: $valid_sample_size"
17 | echo "Base Model: $base_model"
18 | echo "LoRA Weights: $lora_weights"
19 | echo "Category: $category"
20 | echo "Learning Rate: $lr"
21 |
22 | for ((i=0;i<$its;i++))
23 | do
24 | echo ----------------- Iteration$i starts! -----------------
25 | it_output_dir="${output_dir}/it${i}/"
26 | dpo_train_data_path="${it_output_dir}/data/dpo_train.jsonl"
27 | dpo_valid_data_path="${it_output_dir}/data/dpo_valid.jsonl"
28 | sft_train_data_path="${it_output_dir}/data/sft_train.jsonl"
29 | sft_valid_data_path="${it_output_dir}/data/sft_valid.jsonl"
30 | mkdir -p $it_output_dir
31 | mkdir -p "${it_output_dir}/data"
32 | touch "${dpo_train_data_path}"
33 | touch "${dpo_valid_data_path}"
34 | touch "${sft_train_data_path}"
35 | touch "${sft_valid_data_path}"
36 | # Data Generation
37 | CUDA_VISIBLE_DEVICES=$gpu1,$gpu2,$gpu3,$gpu4 python ./train/data_generate.py \
38 | --train_json_file ./data/${category}/train.json \
39 | --valid_json_file ./data/${category}/valid.json \
40 | --result_json_dpo_data_train $dpo_train_data_path \
41 | --result_json_dpo_data_valid $dpo_valid_data_path \
42 | --result_json_sft_data_train $sft_train_data_path \
43 | --result_json_sft_data_valid $sft_valid_data_path \
44 | --base_model $base_model \
45 | --lora_weights $lora_weights \
46 | --batch_size 64 \
47 | --train_sample_size $train_sample_size \
48 | --valid_sample_size $valid_sample_size
49 | # SFT
50 | wandb_name="iteration${i}_SFT"
51 | SFT_path="${it_output_dir}SFT"
52 | mkdir -p $SFT_path
53 | CUDA_VISIBLE_DEVICES=$gpu1,$gpu2,$gpu3,$gpu4 python ./train/sft.py \
54 | --resume_from_checkpoint $lora_weights \
55 | --output_dir $SFT_path \
56 | --base_model $base_model \
57 | --train_dataset $sft_train_data_path \
58 | --valid_dataset $sft_valid_data_path \
59 | --train_sample_size $train_sample_size \
60 | --wandb_project $wandb_project \
61 | --wandb_name $wandb_name \
62 | --gradient_accumulation_steps 4 \
63 | --batch_size $batch_size \
64 | --num_train_epochs 1 \
65 | --learning_rate $lr \
66 | --cutoff_len 512
67 | # Evaluate SFT model
68 | lora_weights=$SFT_path
69 | bash ./shell/eval_single_file.sh $gpu1 $gpu2 $gpu3 $gpu4 \
70 | $base_model \
71 | $lora_weights \
72 | $category
73 | # DPO
74 | wandb_name="iteration${i}_DPO"
75 | DPO_path="${it_output_dir}DPO/"
76 | mkdir -p $DPO_path
77 | CUDA_VISIBLE_DEVICES=$gpu1,$gpu2,$gpu3,$gpu4 python ./train/dpo.py \
78 | --train_dataset $dpo_train_data_path \
79 | --val_dataset $dpo_valid_data_path \
80 | --output_dir $DPO_path \
81 | --base_model $base_model \
82 | --resume_from_checkpoint $lora_weights \
83 | --wandb_name $wandb_name \
84 | --wandb_project $wandb_project \
85 | --batch_size 2 \
86 | --gradient_accumulation_steps 4 \
87 | --learning_rate $lr \
88 | --cutoff_len 512 \
89 | --num_epochs 1
90 | # Evaluate DPO model
91 | lora_weights=$DPO_path
92 | bash ./shell/eval_single_file.sh $gpu1 $gpu2 $gpu3 $gpu4 \
93 | $base_model \
94 | $lora_weights \
95 | $category
96 |
97 | done
98 | echo SPRec for category ${category} has successfully completed!
99 | done
--------------------------------------------------------------------------------
/shell/eval_single_file.sh:
--------------------------------------------------------------------------------
1 | # Usage: bash ./shell/eval_single_file.sh <gpu1> <gpu2> <gpu3> <gpu4> <base_model> <lora_weights> <category>
2 | base_model=$5
3 | lora_weights=$6
4 | category=$7
5 | # Only change the parameters above if needed
6 | echo -------------------------------------- Evaluation started! --------------------------------------
7 |
8 | gpu1=$1; gpu2=$2; gpu3=$3; gpu4=$4
9 | test_json="./data/$category/test.json"
10 | result_json="${lora_weights}/test_result.json"
11 | touch $result_json
12 | CUDA_VISIBLE_DEVICES=$gpu1,$gpu2,$gpu3,$gpu4 python ./eval/inference.py \
13 | --base_model $base_model \
14 | --lora_weights $lora_weights \
15 | --test_data_path $test_json \
16 | --result_json_data $result_json \
17 | --num_beams 1
18 | echo Result for model "$lora_weights" is saved in $result_json!
19 | eval_result_json="${lora_weights}/eval_top1.json"
20 | CUDA_VISIBLE_DEVICES=$1 python ./eval/evaluate.py \
21 | --input_dir $result_json \
22 | --output_dir $eval_result_json \
23 | --topk 1 \
24 | --gamma 0 \
25 | --category $category
26 | echo Metrics for model "$lora_weights" are saved in $eval_result_json!
27 | eval_result_json="${lora_weights}/eval_top5.json"
28 | CUDA_VISIBLE_DEVICES=$1 python ./eval/evaluate.py \
29 | --input_dir $result_json \
30 | --output_dir $eval_result_json \
31 | --topk 5 \
32 | --gamma 0 \
33 | --category $category
34 | echo Metrics for model "$lora_weights" are saved in $eval_result_json!
35 |
36 | echo -------------------------------------- Evaluation finished! --------------------------------------
--------------------------------------------------------------------------------
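The two debiasing metrics this script asks `evaluate.py` to report, MGU and DGU, compare the genre distribution of the model's recommendations against the genre distribution of historical interactions. A toy computation with made-up distributions, mirroring the `dis_genre` logic in `eval/evaluate.py`:

```python
# Toy genre distributions (made up): p = recommendations, h = history.
p = {"Fantasy": 0.6, "Sci-Fi": 0.3, "Romance": 0.1}
h = {"Fantasy": 0.4, "Sci-Fi": 0.4, "Romance": 0.2}

gaps = [p[g] - h[g] for g in p]              # per-genre deviation
MGU = sum(abs(x) for x in gaps) / len(gaps)  # mean absolute deviation
DGU = max(gaps) - min(gaps)                  # spread between extremes
print(f"MGU={MGU:.3f}, DGU={DGU:.3f}")       # MGU=0.133, DGU=0.300
```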
/train/data_generate.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import sys
4 | import fire
5 | import gradio as gr
6 | import numpy as np
7 | import torch
8 | torch.set_num_threads(1)
9 | from sentence_transformers import SentenceTransformer
10 | import random
11 | import transformers
12 | from tqdm import tqdm
13 | import json
14 | import os
15 | os.environ['OPENBLAS_NUM_THREADS'] = '1'
16 | os.environ['OMP_NUM_THREADS'] = '1'
17 | from peft import PeftModel
18 | from transformers import GenerationConfig,AutoTokenizer
19 | from transformers import LlamaForCausalLM
20 | if torch.cuda.is_available():
21 | device = "cuda"
22 | else:
23 | device = "cpu"
24 |
25 | def main(
26 | train_json_file : str = "",
27 | valid_json_file : str = "",
28 | result_json_dpo_data_train: str = "",
29 | result_json_dpo_data_valid: str = "",
30 | result_json_sft_data_train: str = "",
31 | result_json_sft_data_valid: str = "",
32 | base_model: str = "",
33 | lora_weights: str = "",
34 | batch_size:int = 4,
35 | train_sample_size:int = 1024,
36 | valid_sample_size:int = 128,
37 | load_8bit: bool = False,
38 | random_neg: bool = False,
39 | ):
40 |
41 | # generate responses from model
42 | tokenizer = AutoTokenizer.from_pretrained(base_model)
43 | tokenizer.pad_token_id = tokenizer.eos_token_id
44 | load_8bit = False
45 | if device == "cuda":
46 | model = LlamaForCausalLM.from_pretrained(
47 | base_model,
48 | load_in_8bit=load_8bit,
49 | torch_dtype=torch.float16,
50 | device_map="auto",
51 | )
52 | model = PeftModel.from_pretrained(
53 | model,
54 | lora_weights,
55 | torch_dtype=torch.float16,
56 | device_map="auto"
57 | )
58 | tokenizer.padding_side = "left"
59 |
60 | model.eval()
61 |
62 | #emb_model = SentenceTransformer('/data/chenruijun/code/models/paraphrase-MiniLM-L3-v2')
63 |
64 | def evaluate(
65 | instructions,
66 | inputs=None,
67 | temperature=0,
68 | top_p=0.9,
69 | top_k=40,
70 | num_beams=1,
71 | max_new_tokens=128,
72 | **kwargs,
73 | ):
74 | prompt = [generate_prompt(instruction, input) for instruction, input in zip(instructions, inputs)]
75 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
76 | generation_config = GenerationConfig(
77 | temperature=temperature,
78 | top_p=top_p,
79 | top_k=top_k,
80 | num_beams=num_beams,
81 | num_return_sequences=num_beams,
82 | **kwargs,
83 | )
84 | with torch.no_grad():
85 | generation_output = model.generate(
86 | **inputs,
87 | generation_config=generation_config,
88 | return_dict_in_generate=True,
89 | output_scores=True,
90 | max_new_tokens=max_new_tokens,
91 | pad_token_id = tokenizer.eos_token_id
92 | )
93 | s = generation_output.sequences
94 | output = tokenizer.batch_decode(s, skip_special_tokens=True)
95 | output = [_.split('Response:\n')[-1] for _ in output]
96 | real_outputs = [output[i * num_beams: (i + 1) * num_beams] for i in range(len(output) // num_beams)]
97 | return real_outputs
98 |
99 | outputs = []
100 | tokenizer.pad_token_id = tokenizer.eos_token_id
101 |
102 | with open(train_json_file, 'r') as f:
103 | train_data = json.load(f)
104 | train_data = random.sample(train_data, train_sample_size)
105 | sft_train_data = train_data
106 | with open(valid_json_file, 'r') as f:
107 | valid_data = json.load(f)
108 | valid_data = random.sample(valid_data, valid_sample_size)
109 | sft_valid_data = valid_data
110 | with open(result_json_sft_data_train, 'w') as f:
111 | for item in sft_train_data:
112 | json.dump(item, f)
113 | f.write('\n')
114 | with open(result_json_sft_data_valid, 'w') as f:
115 | for item in sft_valid_data:
116 | json.dump(item, f)
117 | f.write('\n')
118 | data = train_data + valid_data
119 | instructions = [_['instruction'] for _ in data]
120 | inputs = [_['input'] for _ in data]
121 | def batch(lst, batch_size=batch_size):
122 | chunk_size = (len(lst) - 1) // batch_size + 1
123 | for i in range(chunk_size):
124 | yield lst[batch_size * i: batch_size * (i + 1)]
125 | for i, batch_pair in tqdm(enumerate(zip(batch(instructions), batch(inputs)))):
126 | batch_instructions, batch_inputs = batch_pair
127 | output = evaluate(batch_instructions, batch_inputs)
128 | outputs = outputs + output
129 |
130 | for i, test in tqdm(enumerate(data)):
131 | data[i]['predict'] = outputs[i]
132 |
133 | dpo_data = []
134 |
135 | for data_point in data:
136 | dpo_case = {}
137 | dpo_case['prompt'] = data_point['instruction'] + data_point['input']
138 | dpo_case['chosen'] = data_point['output']
139 | pattern = r'"(.*?)"'
140 | item_names = re.findall(pattern, data_point['predict'][0])
141 | formatted_item_names = [f'\"{item}\"' for item in item_names]
142 | if len(formatted_item_names) > 0:
143 | dpo_case['rejected'] = formatted_item_names[0]+"\n"
144 | else:
145 | dpo_case['rejected'] = "\n"
146 | dpo_data.append(dpo_case)
147 |
148 | # random.shuffle(dpo_data)
149 | dpo_train_data = dpo_data[:train_sample_size]
150 | dpo_valid_data = dpo_data[train_sample_size:]
151 |
152 |
153 | with open(result_json_dpo_data_train, 'w') as f:
154 | for item in dpo_train_data:
155 | json.dump(item, f)
156 | f.write('\n')
157 |
158 | with open(result_json_dpo_data_valid, 'w') as f:
159 | for item in dpo_valid_data:
160 | json.dump(item, f)
161 | f.write('\n')
162 |
163 |
164 | def generate_prompt(instruction, input=None):
165 | if input:
166 | return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
167 |
168 | ### Instruction:
169 | {instruction}
170 |
171 | ### Input:
172 | {input}
173 |
174 | ### Response:
175 | """
176 | else:
177 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
178 |
179 | ### Instruction:
180 | {instruction}
181 |
182 | ### Response:
183 | """
184 |
185 |
186 | if __name__ == "__main__":
187 | fire.Fire(main)
188 |
--------------------------------------------------------------------------------
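The self-play pair construction in `data_generate.py` hinges on the quoted-title regex: the first item the model itself recommends becomes the `rejected` response, while the ground-truth next item stays `chosen`. A worked example; the `data_point` below is made up:

```python
import re

data_point = {
    "instruction": "Recommend the next book for this user.",
    "input": 'History: "Dune", "The Hobbit"',
    "output": '"Neuromancer"\n',                 # ground-truth next item
    "predict": ['"Snow Crash" is a great fit'],  # model's own generation
}

pattern = r'"(.*?)"'
item_names = re.findall(pattern, data_point["predict"][0])  # ['Snow Crash']
rejected = f'"{item_names[0]}"\n' if item_names else "\n"

dpo_case = {
    "prompt": data_point["instruction"] + data_point["input"],
    "chosen": data_point["output"],  # true next item
    "rejected": rejected,            # the model's self-generated item
}
print(dpo_case)
```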
/train/dpo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import re
4 | import random
5 |
6 | from peft import get_peft_model, LoraConfig, TaskType, PeftModel, prepare_model_for_kbit_training
7 | from transformers import AutoTokenizer, TrainingArguments, AutoModelForCausalLM, BitsAndBytesConfig
8 | from datasets import load_dataset, load_from_disk
9 | from trl import DPOTrainer, DPOConfig
10 |
11 | # from utils import find_all_linear_names, print_trainable_parameters
12 | from transformers import LlamaForCausalLM, LlamaTokenizer
13 | import torch.nn.functional as F
14 |
15 | import bitsandbytes as bnb
16 | from accelerate import Accelerator
17 | import fire
18 |
19 |
20 | def main(
21 | train_dataset = "",
22 | val_dataset = "",
23 | load_8bit: bool = True,
24 | base_model: str = "",
25 | gradient_accumulation_steps: int = 4,
26 | output_dir: str = "",
27 | wandb_project: str = "self_play",
28 | wandb_name: str = "", # the name of the wandb run
29 | batch_size:int = 2,
30 | num_epochs:int = 1,
31 | alpha:float = 1.5,
32 | learning_rate: float = 1e-5,
33 | cutoff_len: int = 512,
34 | eval_step = 0.05,
35 | resume_from_checkpoint: str = "base_model", # either a LoRA checkpoint path or "base_model"
36 | seed = 99
37 | ):
38 |
39 | os.environ['WANDB_PROJECT'] = wandb_project
40 |
41 | train_dataset = load_dataset("json", data_files=train_dataset)
42 | train_data = train_dataset["train"].shuffle(seed=seed)
43 | val_dataset = load_dataset("json", data_files=val_dataset)
44 | val_data = val_dataset["train"].shuffle(seed=seed)
45 |
46 | device_index = Accelerator().process_index
47 | device_map = {"": device_index}
48 |
49 | bnb_config = BitsAndBytesConfig(
50 | # load_in_8bit=True,
51 | load_in_4bit=True,
52 | bnb_4bit_quant_type="nf4",
53 | bnb_4bit_compute_dtype=torch.bfloat16,
54 | bnb_4bit_use_double_quant=False,
55 | )
56 |
57 |
58 |
59 |
60 | model = AutoModelForCausalLM.from_pretrained(
61 | base_model,
62 | device_map=device_map,
63 | quantization_config=bnb_config
64 | )
65 | model.config.use_cache = False
66 | model = prepare_model_for_kbit_training(model)
67 |
68 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
69 | tokenizer.pad_token = tokenizer.eos_token
70 | tokenizer.padding_side = "right"
71 |
72 | if resume_from_checkpoint!="base_model":
73 | model = PeftModel.from_pretrained(
74 | model,
75 | resume_from_checkpoint,
76 | is_trainable=True
77 | )
78 | else:
79 | peft_config = LoraConfig(
80 | inference_mode=False,
81 | r=16,
82 | lora_alpha=32,
83 | target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
84 | lora_dropout=0.05,
85 | bias="none",
86 | task_type="CAUSAL_LM",
87 | )
88 | model = get_peft_model(model, peft_config)
89 |
90 | model.print_trainable_parameters()
91 |
92 | model_ref = AutoModelForCausalLM.from_pretrained(
93 | base_model,
94 | device_map=device_map,
95 | quantization_config=bnb_config
96 | )
97 |
98 | if resume_from_checkpoint != "base_model":
99 | reference_model = PeftModel.from_pretrained(model_ref, resume_from_checkpoint)
100 | else:
101 | reference_model = model_ref
102 |
103 |
104 | training_args = DPOConfig(
105 | per_device_train_batch_size=batch_size,
106 | per_device_eval_batch_size=batch_size,
107 | gradient_accumulation_steps=gradient_accumulation_steps,
108 | warmup_steps=20,
109 | num_train_epochs=num_epochs,
110 | learning_rate=learning_rate,
111 | bf16=True,
112 | logging_steps=1,
113 | optim="adamw_torch",
114 | evaluation_strategy="steps",
115 | save_strategy="steps",
116 | output_dir=output_dir,
117 | save_total_limit=1,
118 | load_best_model_at_end=True,
119 | )
120 |
121 | dpo_trainer = DPOTrainer(
122 | model,
123 | reference_model,
124 | args=training_args,
125 | beta=0.1,
126 | train_dataset=train_data,
127 | eval_dataset=val_data,
128 | tokenizer=tokenizer,
129 | max_prompt_length=cutoff_len,
130 | max_length=cutoff_len,
131 | )
132 |
133 |
134 | dpo_trainer.train()
135 | dpo_trainer.save_model(output_dir)
136 |
137 |
138 | print("DPO training is done")
139 |
140 | if __name__ == "__main__":
141 | fire.Fire(main)
142 |
--------------------------------------------------------------------------------
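For reference, `DPOTrainer` with `beta=0.1` optimizes the standard DPO objective: widen the gap between the policy's and the frozen reference's log-probability ratios on chosen versus rejected responses. A minimal sketch of the per-batch loss under that standard formulation (the log-probabilities below are placeholders; trl computes them internally from the sequences):

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor,
             ref_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    # -log sigmoid(beta * [(log pi(y_w) - log pi(y_l)) - (log ref(y_w) - log ref(y_l))])
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()

# Placeholder log-probs for a batch of 2 preference pairs:
loss = dpo_loss(torch.tensor([-5.0, -4.0]), torch.tensor([-6.0, -7.0]),
                torch.tensor([-5.5, -4.5]), torch.tensor([-5.5, -6.0]))
print(loss.item())
```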
/train/sft.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import re
4 | import wandb
5 |
6 | from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments,BitsAndBytesConfig
7 | from datasets import load_dataset
8 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig
9 | from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType, PeftModel
10 | from transformers import LlamaForCausalLM, LlamaTokenizer
11 | # from utils import find_all_linear_names, print_trainable_parameters
12 | import random
13 | from accelerate import Accelerator
14 |
15 |
16 | import bitsandbytes as bnb
17 | import fire
18 |
19 |
20 | def train(
21 | # path
22 | output_dir="",
23 | base_model ="",
24 | train_dataset="",
25 | valid_dataset="",
26 | train_sample_size:int = 1024,
27 | resume_from_checkpoint: str = "base_model", # either training checkpoint or final adapter
28 | # wandb config
29 | wandb_project: str = "",
30 | wandb_name: str = "", # the name of the wandb run
31 | # training hyperparameters
32 | gradient_accumulation_steps: int = 1,
33 | batch_size: int = 8,
34 | num_train_epochs: int = 5,
35 | learning_rate: float = 2e-5,
36 | cutoff_len: int = 512,
37 | eval_step = 0.05,
38 | seed=0
39 | ):
40 | os.environ['WANDB_PROJECT'] = wandb_project
41 |
42 | def formatting_prompts_func(examples):
43 | output_text = []
44 | for i in range(len(examples["instruction"])):
45 | instruction = examples["instruction"][i]
46 | input_text = examples["input"][i]
47 | response = examples["output"][i]
48 |
49 | if len(input_text) >= 2:
50 | text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
51 |
52 | ### Instruction:
53 | {instruction}
54 |
55 | ### Input:
56 | {input_text}
57 |
58 | ### Response:
59 | {response}
60 | '''
61 | else:
62 | text = f'''Below is an instruction that describes a task. Write a response that appropriately completes the request.
63 |
64 | ### Instruction:
65 | {instruction}
66 |
67 | ### Response:
68 | {response}
69 | '''
70 | output_text.append(text)
71 |
72 | return output_text
73 |
74 | train_dataset = load_dataset("json", data_files=train_dataset)
75 | train_data = train_dataset["train"].shuffle(seed=seed).select(range(train_sample_size))
76 | val_dataset = load_dataset("json", data_files=valid_dataset)
77 | val_data = val_dataset["train"].shuffle(seed=seed).select(range(int(train_sample_size/8)))
78 |
79 | bnb_config = BitsAndBytesConfig(
80 | # load_in_8bit=True,
81 | load_in_4bit=True,
82 | bnb_4bit_quant_type="nf4",
83 | bnb_4bit_compute_dtype=torch.bfloat16,
84 | bnb_4bit_use_double_quant=False,
85 | )
86 |
87 | device_index = Accelerator().process_index
88 | device_map = {"": device_index}
89 | #device_map = "auto"
90 | model = AutoModelForCausalLM.from_pretrained(
91 | base_model,
92 | device_map=device_map,
93 | quantization_config=bnb_config
94 | )
95 | model.config.use_cache = False
96 | model = prepare_model_for_kbit_training(model)
97 |
98 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
99 | tokenizer.pad_token = tokenizer.eos_token
100 | tokenizer.padding_side = "right"
101 |
102 | if resume_from_checkpoint!="base_model":
103 | model = PeftModel.from_pretrained(
104 | model,
105 | resume_from_checkpoint,
106 | is_trainable=True
107 | )
108 | else:
109 | peft_config = LoraConfig(
110 | inference_mode=False,
111 | r=16,
112 | lora_alpha=32,
113 | target_modules=['k_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
114 | lora_dropout=0.05,
115 | bias="none",
116 | task_type="CAUSAL_LM",
117 | )
118 | model = get_peft_model(model, peft_config)
119 |
120 | model.print_trainable_parameters()
121 |
122 | training_args = SFTConfig(
123 | per_device_train_batch_size=batch_size,
124 | per_device_eval_batch_size=batch_size,
125 | gradient_accumulation_steps=gradient_accumulation_steps,
126 | warmup_steps=20,
127 | num_train_epochs=num_train_epochs,
128 | learning_rate=learning_rate,
129 | bf16=True,
130 | logging_steps=1,
131 | optim="adamw_torch",
132 | evaluation_strategy="steps",
133 | save_strategy="steps",
134 | output_dir=output_dir,
135 | save_total_limit=1,
136 | load_best_model_at_end=True,
137 | report_to=None,
138 | )
139 |
140 | trainer = SFTTrainer(
141 | model,
142 | train_dataset=train_data,
143 | eval_dataset=val_data,
144 | tokenizer=tokenizer,
145 | formatting_func=formatting_prompts_func,
146 | max_seq_length=cutoff_len,
147 | args=training_args
148 | )
149 |
150 | trainer.train()
151 | trainer.save_model(output_dir)
152 |
153 | output_dir = os.path.join(output_dir, "final_model")
154 | trainer.model.save_pretrained(output_dir,safe_serialization=False)
155 | tokenizer.save_pretrained(output_dir)
156 |
157 | if __name__ == "__main__":
158 | fire.Fire(train)
--------------------------------------------------------------------------------
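To make the SFT data path concrete, this is what `formatting_prompts_func` renders for one record (the record below is made up); `SFTTrainer` then truncates these strings to `cutoff_len` tokens:

```python
example = {
    "instruction": ["Recommend the next book for this user."],
    "input": ['History: "Dune", "The Hobbit"'],
    "output": ['"Neuromancer"'],
}

# Mirrors the template in train/sft.py for records with a non-empty input.
text = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{example["instruction"][0]}

### Input:
{example["input"][0]}

### Response:
{example["output"][0]}
"""
print(text)
```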