├── README.md ├── data_utils.py ├── deepspeed.conf ├── dpo_utils.py ├── main.py ├── requirements.txt └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # RefDPO: Understanding Reference Policies in Direct Preference Optimization 2 | 3 | The repository contains the training scripts, datasets, and model checkpoints for the paper ["Understanding Reference Policies in Direct Preference Optimization"](https://arxiv.org/abs/2407.13709). 4 | 5 | ## Quick Links 6 | 7 | - [Installation](#installation) 8 | - [Running the code](#running-the-code) 9 | - [Datasets](#datasets) 10 | - [Experimental Results](#experimental-results) 11 | - [RQ1: What Is the Optimal KL Constraint Strength for DPO?](#rq1-what-is-the-optimal-kl-constraint-strength-for-dpo) 12 | - [RQ2: Is a Reference Policy Necessary for Effective Preference Learning?](#rq2-is-a-reference-policy-necessary-for-effective-preference-learning) 13 | - [RQ3: Does DPO Benefit from Stronger Reference Policies?](#rq3-does-dpo-benefit-from-stronger-reference-policies) 14 | - Resource collection on Huggingface: [RefDPO](https://huggingface.co/collections/yale-nlp/refdpo-669987117dd799b55ac5b552) 15 | 16 | ## Installation 17 | 18 | Our code base is built on Huggingface's Transformers library, DeepSpeed, and PyTorch. 19 | To install the required dependencies, run the following command: 20 | 21 | ```bash 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | Our code base is adapted from the [open-instruct](https://github.com/allenai/open-instruct) repository. 26 | 27 | ## Running the code 28 | 29 | To run the code, you can use the following command (it assumes there are 8 GPUs available): 30 | 31 | ```bash 32 | accelerate launch \ 33 | --mixed_precision bf16 \ 34 | --num_machines 1 \ 35 | --num_processes 8 \ 36 | --use_deepspeed \ 37 | --deepspeed_config_file deepspeed.conf \ 38 | main.py \ 39 | --cuda \ 40 | --dataset 'yale-nlp/RefDPO' \ 41 | --data_split 'mistral' \ 42 | --epoch 3 \ 43 | --beta 0.1 \ 44 | --dpo_weight 1.0 \ 45 | --model_type 'HuggingFaceH4/mistral-7b-sft-beta' \ 46 | --insert_eos \ 47 | -l 48 | ``` 49 | 50 | Each argument is explained in the `main.py` file. 51 | 52 | To run the training *without* the reference model, you can use the `--ref_free` flag: 53 | 54 | ```bash 55 | accelerate launch \ 56 | --mixed_precision bf16 \ 57 | --num_machines 1 \ 58 | --num_processes 8 \ 59 | --use_deepspeed \ 60 | --deepspeed_config_file deepspeed.conf \ 61 | main.py \ 62 | --cuda \ 63 | --dataset 'yale-nlp/RefDPO' \ 64 | --data_split 'mistral' \ 65 | --epoch 3 \ 66 | --beta 10.0 \ 67 | --dpo_weight 1.0 \ 68 | --ref_free \ 69 | --model_type 'HuggingFaceH4/mistral-7b-sft-beta' \ 70 | --insert_eos \ 71 | -l 72 | ``` 73 | 74 | ### Code structure 75 | 76 | - `main.py`: The main file to run the code. 77 | - `data_utils.py`: The data processing utilities. 78 | - `utils.py`: Utility functions. 79 | - `deepspeed.conf`: The deepspeed configuration file. 80 | - `dpo_utils.py`: The DPO utilities. 81 | 82 | ### Resource requirements 83 | 84 | Our training requires 8 GPUs with 48GB of memory each. We use the `deepspeed` library to distribute the training across multiple GPUs. 85 | 86 | 87 | ## Datasets 88 | 89 | We have made the datasets used in the paper available on Huggingface's dataset hub: [yale-nlp/RefDPO](https://huggingface.co/datasets/yale-nlp/RefDPO). 90 | It contains 5 different datasets.
91 | Each dataset is built upon the [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset, specifically its binarized version [ultrafeedback_binarized_cleaned](https://huggingface.co/datasets/allenai/ultrafeedback_binarized_cleaned) converted from [ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized). 92 | The datasets contain **pre-computed log-probabilities** of the reference policy/model for the output pairs in the UltraFeedback dataset. 93 | 94 | | Dataset | Reference Model | Description | 95 | |---------|-----------------|-------------| 96 | | `mistral` | [HuggingFaceH4/mistral-7b-sft-beta](https://huggingface.co/HuggingFaceH4/mistral-7b-sft-beta) | The log-probabilities are computed using the Mistral-7B-SFT model. | 97 | | `tulu2` | [allenai/tulu-2-7b](https://huggingface.co/allenai/tulu-2-7b) | The log-probabilities are computed using the Tulu-2-7B model. | 98 | | `mistral_prior` | [HuggingFaceH4/mistral-7b-sft-beta](https://huggingface.co/HuggingFaceH4/mistral-7b-sft-beta) | The **prior** (unconditional) log-probabilities are computed using the Mistral-7B-SFT model. | 99 | | `mistralv2` | [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) | The log-probabilities are computed using the Mistral-7B-Instruct-v0.2 model. | 100 | | `llama3` | [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) | The log-probabilities are computed using the Meta-Llama-3-70B-Instruct model. | 101 | 102 | 103 | ## Experimental Results 104 | 105 | Below are the model checkpoints for the models trained in the paper. 106 | 107 | 108 | ### RQ1: What Is the Optimal KL Constraint Strength for DPO? 109 | 110 | Below are the model checkpoints fine-tuned with DPO from [mistral-7b-sft](https://huggingface.co/HuggingFaceH4/mistral-7b-sft-beta) and [tulu-2-7b](https://huggingface.co/allenai/tulu-2-7b). The models are fine-tuned with different KL constraint strengths (**$\beta$**). 111 | 112 | The checkpoints are available on Huggingface's model hub. They are evaluated using the length-controlled AlpacaEval2 score [reference](https://arxiv.org/abs/2404.04475). 
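For reference, the released resources can be loaded directly with the Huggingface libraries. The snippet below is a minimal sketch (it is not part of the training code in this repository); it assumes the `datasets` and `transformers` packages are installed, and uses the `mistral` data split together with the $\beta = 0.01$ Mistral checkpoint from the table below as examples:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Preference pairs with pre-computed reference log-probabilities;
# the configuration name matches the --data_split argument of main.py.
data = load_dataset("yale-nlp/RefDPO", "mistral")

# One of the released DPO checkpoints (fine-tuned from mistral-7b-sft with beta = 0.01).
model_id = "yale-nlp/mistral-7b-dpo-beta-0.01"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
```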
113 | 114 | 115 | #### Checkpoints fine-tuned from [mistral-7b-sft](https://huggingface.co/HuggingFaceH4/mistral-7b-sft-beta) 116 | 117 | 118 | | $\beta$ | HF Checkpoint | AlpacaEval2 LC-Score | 119 | |-------|---------------|-----------------| 120 | | 0.1 | [yale-nlp/mistral-7b-dpo-beta-0.1](https://huggingface.co/yale-nlp/mistral-7b-dpo-beta-0.1) | 14.03 | 121 | | 0.05 | [yale-nlp/mistral-7b-dpo-beta-0.05](https://huggingface.co/yale-nlp/mistral-7b-dpo-beta-0.05) | 13.29 | 122 | | 0.02 | [yale-nlp/mistral-7b-dpo-beta-0.02](https://huggingface.co/yale-nlp/mistral-7b-dpo-beta-0.02) | 16.06 | 123 | | 0.01 | [yale-nlp/mistral-7b-dpo-beta-0.01](https://huggingface.co/yale-nlp/mistral-7b-dpo-beta-0.01) | **16.25** | 124 | | 0.005 | [yale-nlp/mistral-7b-dpo-beta-0.005](https://huggingface.co/yale-nlp/mistral-7b-dpo-beta-0.005) | 12.36 | 125 | 126 | 127 | #### Checkpoints fine-tuned from [tulu-2-7b](https://huggingface.co/allenai/tulu-2-7b) 128 | 129 | | $\beta$ | HF Checkpoint | AlpacaEval2 LC-Score | 130 | |-------|---------------|-----------------| 131 | | 0.1 | [yale-nlp/tulu2-7b-dpo-beta-0.1](https://huggingface.co/yale-nlp/tulu2-7b-dpo-beta-0.1) | 9.38 | 132 | | 0.05 | [yale-nlp/tulu2-7b-dpo-beta-0.05](https://huggingface.co/yale-nlp/tulu2-7b-dpo-beta-0.05) | 9.96 | 133 | | 0.02 | [yale-nlp/tulu2-7b-dpo-beta-0.02](https://huggingface.co/yale-nlp/tulu2-7b-dpo-beta-0.02) | **10.46** | 134 | | 0.01 | [yale-nlp/tulu2-7b-dpo-beta-0.01](https://huggingface.co/yale-nlp/tulu2-7b-dpo-beta-0.01) | 7.86 | 135 | | 0.005 | [yale-nlp/tulu2-7b-dpo-beta-0.005](https://huggingface.co/yale-nlp/tulu2-7b-dpo-beta-0.005) | [degenerate] | 136 | 137 | 138 | ### RQ2: Is a Reference Policy Necessary for Effective Preference Learning? 139 | 140 | Below are the optimal checkpoints fine-tuned from [mistral-7b-sft](https://huggingface.co/HuggingFaceH4/mistral-7b-sft-beta) with three different reward parameterizations. 141 | 142 | | Reward Parameterization | HF Checkpoint | AlpacaEval2 LC-Score | $\beta$ | 143 | |-------|---------------|-----------------| ---- | 144 | | $\beta \cdot \log \frac{p_\theta(y\|x)}{p_{\mathrm{ref}}(y\|x)}$ (DPO) | [yale-nlp/mistral-7b-dpo-beta-0.01](https://huggingface.co/yale-nlp/mistral-7b-dpo-beta-0.01) | 16.25 | 0.01 | 145 | | $\beta \cdot \log p_\theta(y\|x)$ (Posterior Probability) | [yale-nlp/mistral-probability](https://huggingface.co/yale-nlp/mistral-probability) | 12.84 | 100.0 | 146 | | $\beta \cdot \log p_\theta(x\|y)$ (Likelihood Function) | [yale-nlp/mistral-likelihood](https://huggingface.co/yale-nlp/mistral-likelihood) | 13.63 | 0.01 | 147 | 148 | ### RQ3: Does DPO Benefit from Stronger Reference Policies? 149 | 150 | Here we use two stronger reference models, [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) and [Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct), for DPO training. 151 | 152 | 153 | #### Checkpoints fine-tuned from [mistral-7b-sft](https://huggingface.co/HuggingFaceH4/mistral-7b-sft-beta) using [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) as the reference policy.
154 | 155 | 156 | | $\beta$ | HF Checkpoint | AlpacaEval2 LC-Score | 157 | |-------|---------------|-----------------| 158 | | 10.0 | [yale-nlp/mistral-7b-dpo-mistralv2-7b-beta-10.0](https://huggingface.co/yale-nlp/mistral-7b-dpo-mistralv2-7b-beta-10.0) | 18.74 | 159 | | 1.00 | [yale-nlp/mistral-7b-dpo-mistralv2-7b-beta-1.0](https://huggingface.co/yale-nlp/mistral-7b-dpo-mistralv2-7b-beta-1.0) | **20.25** | 160 | | 0.10 | [yale-nlp/mistral-7b-dpo-mistralv2-7b-beta-0.1](https://huggingface.co/yale-nlp/mistral-7b-dpo-mistralv2-7b-beta-0.1) | 19.58 | 161 | | 0.01 | [yale-nlp/mistral-7b-dpo-mistralv2-7b-beta-0.01](https://huggingface.co/yale-nlp/mistral-7b-dpo-mistralv2-7b-beta-0.01) | 17.18 | 162 | | 0.005 | [yale-nlp/mistral-7b-dpo-mistralv2-7b-beta-0.005](https://huggingface.co/yale-nlp/mistral-7b-dpo-mistralv2-7b-beta-0.005) | 15.34 | 163 | 164 | 165 | #### Checkpoints fine-tuned from [mistral-7b-sft](https://huggingface.co/HuggingFaceH4/mistral-7b-sft-beta) using [Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) as the reference policy. 166 | 167 | 168 | | $\beta$ | HF Checkpoint | AlpacaEval2 LC-Score | 169 | |-------|---------------|-----------------| 170 | | 10.0 | [yale-nlp/mistral-7b-dpo-llama3-70b-beta-10.0](https://huggingface.co/yale-nlp/mistral-7b-dpo-llama3-70b-beta-10.0) | 13.29 | 171 | | 1.00 | [yale-nlp/mistral-7b-dpo-llama3-70b-beta-1.0](https://huggingface.co/yale-nlp/mistral-7b-dpo-llama3-70b-beta-1.0) | 9.59 | 172 | | 0.10 | [yale-nlp/mistral-7b-dpo-llama3-70b-beta-0.1](https://huggingface.co/yale-nlp/mistral-7b-dpo-llama3-70b-beta-0.1) | 10.99 | 173 | | 0.01 | [yale-nlp/mistral-7b-dpo-llama3-70b-beta-0.01](https://huggingface.co/yale-nlp/mistral-7b-dpo-llama3-70b-beta-0.01) | **15.37** | 174 | | 0.005 | [yale-nlp/mistral-7b-dpo-llama3-70b-beta-0.005](https://huggingface.co/yale-nlp/mistral-7b-dpo-llama3-70b-beta-0.005) | 11.70 | 175 | 176 | 177 | #### Checkpoints fine-tuned from [tulu-2-7b](https://huggingface.co/allenai/tulu-2-7b) using [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) as the reference policy. 178 | 179 | 180 | | $\beta$ | HF Checkpoint | AlpacaEval2 LC-Score | 181 | |-------|---------------|-----------------| 182 | | 10.0 | [yale-nlp/tulu2-7b-dpo-mistralv2-7b-beta-10.0](https://huggingface.co/yale-nlp/tulu2-7b-dpo-mistralv2-7b-beta-10.0) | 7.61 | 183 | | 1.00 | [yale-nlp/tulu2-7b-dpo-mistralv2-7b-beta-1.0](https://huggingface.co/yale-nlp/tulu2-7b-dpo-mistralv2-7b-beta-1.0) | **7.85** | 184 | | 0.10 | [yale-nlp/tulu2-7b-dpo-mistralv2-7b-beta-0.1](https://huggingface.co/yale-nlp/tulu2-7b-dpo-mistralv2-7b-beta-0.1) | [degenerate] | 185 | | 0.01 | [yale-nlp/tulu2-7b-dpo-mistralv2-7b-beta-0.01](https://huggingface.co/yale-nlp/tulu2-7b-dpo-mistralv2-7b-beta-0.01) | [degenerate] | 186 | | 0.005 | [yale-nlp/tulu2-7b-dpo-mistralv2-7b-beta-0.005](https://huggingface.co/yale-nlp/tulu2-7b-dpo-mistralv2-7b-beta-0.005) | [degenerate] | 187 | 188 | 189 | #### Checkpoints fine-tuned from [tulu-2-7b](https://huggingface.co/allenai/tulu-2-7b) using [Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) as the reference policy. 
190 | 191 | 192 | | $\beta$ | HF Checkpoint | AlpacaEval2 LC-Score | 193 | |-------|---------------|-----------------| 194 | | 10.0 | [yale-nlp/tulu2-7b-dpo-llama3-70b-beta-10.0](https://huggingface.co/yale-nlp/tulu2-7b-dpo-llama3-70b-beta-10.0) | 9.79 | 195 | | 1.00 | [yale-nlp/tulu2-7b-dpo-llama3-70b-beta-1.0](https://huggingface.co/yale-nlp/tulu2-7b-dpo-llama3-70b-beta-1.0) | **11.17** | 196 | | 0.10 | [yale-nlp/tulu2-7b-dpo-llama3-70b-beta-0.1](https://huggingface.co/yale-nlp/tulu2-7b-dpo-llama3-70b-beta-0.1) | 10.31 | 197 | | 0.01 | [yale-nlp/tulu2-7b-dpo-llama3-70b-beta-0.01](https://huggingface.co/yale-nlp/tulu2-7b-dpo-llama3-70b-beta-0.01) | 9.16 | 198 | | 0.005 | [yale-nlp/tulu2-7b-dpo-llama3-70b-beta-0.005](https://huggingface.co/yale-nlp/tulu2-7b-dpo-llama3-70b-beta-0.005) | 3.29 | 199 | 200 | 201 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | import copy 4 | 5 | 6 | def to_cuda(batch, gpuid): 7 | for n in batch: 8 | if n != "data": 9 | batch[n] = batch[n].to(gpuid) 10 | 11 | 12 | class PreferenceBaseDataset(Dataset): 13 | def __init__(self, data, tokenizer, max_len=2048, is_test=False, insert_eos=False): 14 | """data format: each example has 'chosen' and 'rejected' fields, each a list of chat messages""" 15 | self.data = data 16 | self.tokenizer = copy.deepcopy(tokenizer) 17 | self.max_len = max_len 18 | self.is_test = is_test 19 | self.num = len(self.data) 20 | self.insert_eos = insert_eos 21 | 22 | def __len__(self): 23 | return self.num 24 | 25 | def encode_with_messages_format(self, example): 26 | """ 27 | from https://github.com/allenai/open-instruct/blob/main/open_instruct/dpo_tune.py#L252 28 | Here we assume each example has a rejected and chosen field, both of which are a list of messages. 29 | Each message is a dict with 'role' and 'content' fields. 30 | We concatenate all messages with the roles as delimiters and tokenize them together. 31 | We assume only the last message is different, and the prompt is contained in the list of messages.
32 | """ 33 | chosen_messages = example["chosen"] 34 | rejected_messages = example["rejected"] 35 | if len(chosen_messages) == 0: 36 | raise ValueError("chosen messages field is empty.") 37 | if len(rejected_messages) == 0: 38 | raise ValueError("rejected messages field is empty.") 39 | eos_insert = self.tokenizer.eos_token if self.insert_eos else "" 40 | def _concat_messages(messages): 41 | message_text = "" 42 | for message in messages: 43 | if message["role"] == "system": 44 | message_text += "<|system|>\n" + message["content"].strip() + eos_insert + "\n" 45 | elif message["role"] == "user": 46 | message_text += "<|user|>\n" + message["content"].strip() + eos_insert + "\n" 47 | elif message["role"] == "assistant": 48 | message_text += ( 49 | "<|assistant|>\n" 50 | + message["content"].strip() 51 | + self.tokenizer.eos_token 52 | + "\n" 53 | ) 54 | else: 55 | raise ValueError("Invalid role: {}".format(message["role"])) 56 | return message_text 57 | 58 | def encode_messages(messages): 59 | example_text = _concat_messages(messages).strip() 60 | tokenized_example = self.tokenizer( 61 | example_text, 62 | return_tensors="pt", 63 | max_length=self.max_len, 64 | truncation=True, 65 | ) 66 | input_ids = tokenized_example.input_ids 67 | masks = torch.ones_like(input_ids) 68 | 69 | # mask the non-assistant part for avoiding loss 70 | for message_idx, message in enumerate(messages): 71 | if message["role"] != "assistant": 72 | if message_idx == 0: 73 | message_start_idx = 0 74 | else: 75 | message_start_idx = self.tokenizer( 76 | _concat_messages(messages[:message_idx]), 77 | return_tensors="pt", 78 | max_length=self.max_len, 79 | truncation=True, 80 | ).input_ids.shape[1] 81 | if ( 82 | message_idx < len(messages) - 1 83 | and messages[message_idx + 1]["role"] == "assistant" 84 | ): 85 | # here we also ignore the role of the assistant 86 | messages_so_far = ( 87 | _concat_messages(messages[: message_idx + 1]) 88 | + "<|assistant|>\n" 89 | ) 90 | else: 91 | messages_so_far = _concat_messages(messages[: message_idx + 1]) 92 | message_end_idx = self.tokenizer( 93 | messages_so_far, 94 | return_tensors="pt", 95 | max_length=self.max_len, 96 | truncation=True, 97 | ).input_ids.shape[1] 98 | masks[:, message_start_idx:message_end_idx] = 0 99 | 100 | if message_end_idx >= self.max_len: 101 | break 102 | 103 | return { 104 | "input_ids": input_ids.flatten(), 105 | "masks": masks.flatten(), 106 | } 107 | 108 | chosen_encoded = encode_messages(chosen_messages) 109 | rejected_encoded = encode_messages(rejected_messages) 110 | 111 | return { 112 | "chosen_input_ids": chosen_encoded["input_ids"], 113 | "chosen_masks": chosen_encoded["masks"], 114 | "rejected_input_ids": rejected_encoded["input_ids"], 115 | "rejected_masks": rejected_encoded["masks"], 116 | } 117 | 118 | def __getitem__(self, idx): 119 | data = self.data[idx] 120 | encoded = self.encode_with_messages_format(data) 121 | if self.is_test: 122 | encoded["data"] = data 123 | return encoded 124 | 125 | 126 | def collate_preference_base(batch, pad_token_id, is_test=False): 127 | def pad(X, padding, max_len=-1, pad_side="left"): 128 | assert pad_side in ["left", "right"] 129 | if max_len < 0: 130 | max_len = max(x.size(0) for x in X) 131 | result = torch.ones(len(X), max_len, dtype=X[0].dtype) * padding 132 | attention_mask = torch.zeros(len(X), max_len, dtype=X[0].dtype) 133 | for i, x in enumerate(X): 134 | if pad_side == "left": 135 | result[i, -x.size(0) :] = x 136 | attention_mask[i, -x.size(0) :] = 1 137 | else: 138 | result[i, : x.size(0)] = x 
139 | attention_mask[i, : x.size(0)] = 1 140 | return result, attention_mask 141 | 142 | # pad chosen 143 | chosen_input_ids, chosen_attention_mask = pad( 144 | [x["chosen_input_ids"] for x in batch], pad_token_id, pad_side="left" 145 | ) 146 | chosen_masks, _ = pad([x["chosen_masks"] for x in batch], 0, pad_side="left") 147 | 148 | # pad rejected 149 | rejected_input_ids, rejected_attention_mask = pad( 150 | [x["rejected_input_ids"] for x in batch], pad_token_id, pad_side="left" 151 | ) 152 | rejected_masks, _ = pad([x["rejected_masks"] for x in batch], 0, pad_side="left") 153 | 154 | # concatenate 155 | input_ids = torch.unbind(chosen_input_ids) + torch.unbind(rejected_input_ids) 156 | attention_mask = torch.unbind(chosen_attention_mask) + torch.unbind(rejected_attention_mask) 157 | masks = torch.unbind(chosen_masks) + torch.unbind(rejected_masks) 158 | 159 | # right pad now 160 | input_ids, _attention_mask = pad(input_ids, pad_token_id, pad_side="right") 161 | attention_mask, _ = pad(attention_mask, 0, pad_side="right") 162 | attention_mask = attention_mask * _attention_mask 163 | masks, _ = pad(masks, 0, pad_side="right") 164 | 165 | result = { 166 | "input_ids": input_ids, 167 | "masks": masks, 168 | "attention_mask": attention_mask, 169 | } 170 | if is_test: 171 | result["data"] = [x["data"] for x in batch] 172 | result["chosen_input_ids"] = [x["chosen_input_ids"] for x in batch] 173 | result["rejected_input_ids"] = [x["rejected_input_ids"] for x in batch] 174 | return result 175 | 176 | 177 | class PreferenceDataset(PreferenceBaseDataset): 178 | def __getitem__(self, idx): 179 | data = self.data[idx] 180 | encoded = self.encode_with_messages_format(data) 181 | encoded["chosen_logprob"] = data["chosen_logprob"] 182 | encoded["rejected_logprob"] = data["rejected_logprob"] 183 | if self.is_test: 184 | encoded["data"] = data 185 | return encoded 186 | 187 | 188 | def collate_preference(batch, pad_token_id, is_test=False): 189 | results = collate_preference_base(batch, pad_token_id, is_test=is_test) 190 | chosen_logprob = torch.tensor([x["chosen_logprob"] for x in batch]) 191 | rejected_logprob = torch.tensor([x["rejected_logprob"] for x in batch]) 192 | results["chosen_logprob"] = chosen_logprob 193 | results["rejected_logprob"] = rejected_logprob 194 | return results -------------------------------------------------------------------------------- /deepspeed.conf: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": true 4 | }, 5 | "zero_optimization": { 6 | "stage": 3, 7 | "overlap_comm": true, 8 | "contiguous_gradients": true, 9 | "sub_group_size": 1e9, 10 | "reduce_bucket_size": "auto", 11 | "stage3_prefetch_bucket_size": "auto", 12 | "stage3_param_persistence_threshold": "auto", 13 | "stage3_max_live_parameters": 1e9, 14 | "stage3_max_reuse_distance": 1e9, 15 | "stage3_gather_16bit_weights_on_model_save": true 16 | }, 17 | "gradient_accumulation_steps": "auto", 18 | "gradient_clipping": 1.0, 19 | "steps_per_print": 1e5, 20 | "train_batch_size": "auto", 21 | "train_micro_batch_size_per_gpu": "auto", 22 | "wall_clock_breakdown": false 23 | } -------------------------------------------------------------------------------- /dpo_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | DPO utils 3 | Adapted from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py 4 | ''' 5 | import torch 6 | import torch.nn.functional as F 7 | from typing import 
Tuple 8 | 9 | 10 | def dpo_loss(policy_chosen_logps: torch.FloatTensor, 11 | policy_rejected_logps: torch.FloatTensor, 12 | reference_chosen_logps: torch.FloatTensor, 13 | reference_rejected_logps: torch.FloatTensor, 14 | beta: float, 15 | reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 16 | """Compute the DPO loss for a batch of policy and reference model log probabilities. 17 | 18 | Args: 19 | policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) 20 | policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) 21 | reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) 22 | reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) 23 | beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. 24 | reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. 25 | 26 | Returns: 27 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). 28 | The losses tensor contains the DPO loss for each example in the batch. 29 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. 30 | """ 31 | pi_logratios = policy_chosen_logps - policy_rejected_logps 32 | ref_logratios = reference_chosen_logps - reference_rejected_logps 33 | 34 | if reference_free: 35 | ref_logratios = 0 36 | 37 | logits = pi_logratios - ref_logratios 38 | 39 | losses = -F.logsigmoid(beta * logits) 40 | chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach() 41 | rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach() 42 | 43 | return losses, chosen_rewards, rejected_rewards -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import argparse 4 | import numpy as np 5 | import os 6 | import random 7 | from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup 8 | from utils import Recorder 9 | from data_utils import collate_preference, PreferenceDataset 10 | from torch.utils.data import DataLoader 11 | from functools import partial 12 | from datetime import datetime 13 | from datasets import load_dataset 14 | from dpo_utils import dpo_loss as dpo_loss_fn 15 | from tqdm import tqdm 16 | from accelerate import Accelerator 17 | from accelerate.utils import set_seed 18 | import math 19 | import deepspeed 20 | from deepspeed.accelerator import get_accelerator 21 | 22 | 23 | def base_setting(args): 24 | args.batch_size = getattr(args, 'batch_size', 4) # batch size on one gpu, one step 25 | args.report_freq = getattr(args, "report_freq", 10) # report frequency 26 | args.accumulate_step = getattr(args, "accumulate_step", 1) # accumulate gradients steps 27 | args.warmup_ratio = getattr(args, "warmup_ratio", 0.1) # warmup steps 28 | args.grad_norm = getattr(args, "grad_norm", 0) # gradient norm 29 | args.seed = getattr(args, "seed", 18890426) # random seed 30 | args.pretrained = getattr(args, "pretrained", None) # pretrained model path 31 | args.max_lr = getattr(args, "max_lr", 
5e-7) # max learning rate 32 | args.max_len = getattr(args, "max_len", 2048) # max length of input 33 | args.device = getattr(args, "device", "auto") # device 34 | args.allow_tf32 = getattr(args, "allow_tf32", True) # allow tf32 35 | args.mixed_precision = getattr(args, "mixed_precision", True) # mixed precision 36 | args.gradient_checkpointing = getattr(args, "gradient_checkpointing", True) # gradient checkpointing 37 | args.use_flash_attention = getattr(args, "use_flash_attention", True) # use flash attention 38 | args.empty_cache = getattr(args, "empty_cache", False) # flush cache 39 | 40 | 41 | def test(dataloader, model, args, is_master): 42 | model.eval() 43 | batch_cnt = 0 44 | all_loss = 0 45 | all_pos_logits, all_neg_logits = 0, 0 46 | with torch.no_grad(): 47 | # scoring 48 | for batch in tqdm(dataloader, total=len(dataloader), disable=not is_master, desc="evaluating"): 49 | input_ids = batch["input_ids"] 50 | attention_mask = batch["attention_mask"] 51 | output = model( 52 | input_ids=input_ids, 53 | attention_mask=attention_mask, 54 | output_hidden_states=False, 55 | use_cache=False, 56 | ) 57 | output = output[0] 58 | output = output[:, :-1] # truncate last logit 59 | labels = input_ids[:, 1:] # shift labels 60 | output = output.to(torch.float32) 61 | logits = torch.log_softmax(output, dim=-1) 62 | logits = logits.gather(2, labels.unsqueeze(2)).squeeze(2) 63 | masks = batch["masks"][:, 1:] # actual mask 64 | logits = logits * masks 65 | batch_size = logits.size(0) // 2 66 | logits = logits.sum(dim=1) 67 | pos_logits, neg_logits = logits[:batch_size], logits[batch_size:] 68 | pos_ref_logits = batch["chosen_logprob"] 69 | neg_ref_logits = batch["rejected_logprob"] 70 | dpo_loss, _, _ = dpo_loss_fn(pos_logits, neg_logits, pos_ref_logits, neg_ref_logits, args.beta, args.ref_free) 71 | dpo_loss = dpo_loss.mean() 72 | loss = args.dpo_weight * dpo_loss 73 | all_loss += loss.detach().float() 74 | all_pos_logits += pos_logits.mean().detach().float() 75 | all_neg_logits += neg_logits.mean().detach().float() 76 | batch_cnt += 1 77 | loss = all_loss / batch_cnt 78 | pos_logits = all_pos_logits / batch_cnt 79 | neg_logits = all_neg_logits / batch_cnt 80 | results = {"loss": loss, "pos_logits": pos_logits, "neg_logits": neg_logits} 81 | model.train() 82 | return results 83 | 84 | 85 | def run(args): 86 | base_setting(args) 87 | # task initialization 88 | torch.manual_seed(args.seed) 89 | torch.cuda.manual_seed_all(args.seed) 90 | np.random.seed(args.seed) 91 | random.seed(args.seed) 92 | set_seed(args.seed) 93 | if args.allow_tf32: 94 | torch.backends.cuda.matmul.allow_tf32 = True 95 | # build tokenizer 96 | tokenizer = AutoTokenizer.from_pretrained(args.model_type, use_fast=False) 97 | if getattr(args, "reset_pad_token", False): # reset pad token (optional flag) 98 | tokenizer.pad_token = None 99 | # add a dedicated pad token (following open-instruct) 100 | tokenizer.add_special_tokens({"pad_token": "<pad>"}) 101 | # build dataloader 102 | collate_fn = partial(collate_preference, pad_token_id=tokenizer.pad_token_id, is_test=False) 103 | train_data = load_dataset(args.dataset, args.data_split)["train"] 104 | val_data = load_dataset(args.dataset, args.data_split)["val"] 105 | 106 | train_set = PreferenceDataset(train_data, tokenizer=tokenizer, max_len=args.max_len, is_test=False, insert_eos=args.insert_eos) 107 | val_set = PreferenceDataset(val_data, tokenizer=tokenizer, max_len=args.max_len, is_test=False, insert_eos=args.insert_eos) 108 | dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn) 109 |
val_dataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn) 110 | # build model 111 | model_path = args.pretrained if args.pretrained is not None else args.model_type 112 | if len(args.model_pt) > 0: 113 | model_path = args.model_pt 114 | 115 | if args.mixed_precision: 116 | accelerator = Accelerator(gradient_accumulation_steps=args.accumulate_step, mixed_precision="bf16") 117 | else: 118 | accelerator = Accelerator(gradient_accumulation_steps=args.accumulate_step) 119 | 120 | accelerator.wait_for_everyone() 121 | 122 | is_master = accelerator.is_main_process 123 | now = datetime.now() 124 | date = now.strftime("%y-%m-%d") 125 | if is_master: 126 | id = len(os.listdir("./cache")) 127 | while os.path.exists(os.path.join("./cache", f"{date}-{id}")): 128 | id += 1 129 | recorder = Recorder(id, args.log) 130 | else: 131 | id = 0 132 | 133 | id = torch.tensor(id).to(accelerator.device).float() 134 | id = accelerator.gather(id).sum().item() 135 | fpath = os.path.join("./cache", f"{date}-{int(id)}") 136 | 137 | 138 | if args.use_flash_attention: 139 | model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation="flash_attention_2") 140 | else: 141 | model = AutoModelForCausalLM.from_pretrained(model_path) 142 | 143 | embeddings = model.get_input_embeddings() 144 | with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None): 145 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8) 146 | 147 | model.config.pad_token_id = tokenizer.pad_token_id 148 | model.model.embed_tokens.padding_idx = tokenizer.pad_token_id 149 | model.config.vocab_size = len(tokenizer) 150 | model.train() 151 | 152 | if args.gradient_checkpointing: 153 | model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) 154 | 155 | optimizer = optim.AdamW(model.parameters(), lr=args.max_lr) 156 | actual_batch_size = args.batch_size * args.accumulate_step * accelerator.num_processes 157 | total_steps = math.ceil(len(dataloader) * args.epoch / actual_batch_size * accelerator.num_processes * args.batch_size) 158 | warmup_steps = int(args.warmup_ratio * total_steps) 159 | if is_master: 160 | recorder.print(f"total steps: {total_steps}") 161 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps) 162 | if is_master: 163 | recorder.write_config(args, [model], __file__) 164 | minimum_loss = 1e5 165 | all_step_cnt = 0 166 | model, optimizer, dataloader, val_dataloader, scheduler = accelerator.prepare( 167 | model, optimizer, dataloader, val_dataloader, scheduler 168 | ) 169 | 170 | def save_with_accelerate(model, model_name): 171 | accelerator.wait_for_everyone() 172 | state_dict = accelerator.get_state_dict(model) 173 | accelerator.wait_for_everyone() 174 | unwrapped_model = accelerator.unwrap_model(model) 175 | accelerator.wait_for_everyone() 176 | if args.log: 177 | if is_master: 178 | unwrapped_model.save_pretrained(os.path.join(fpath, model_name), state_dict=state_dict, safe_serialization=True) 179 | accelerator.wait_for_everyone() 180 | 181 | 182 | for epoch in range(args.epoch): 183 | optimizer.zero_grad() 184 | step_cnt = 0 185 | epoch_step = 0 186 | avg_loss = 0 187 | avg_pos_logits, avg_neg_logits = 0, 0 188 | for (i, batch) in tqdm(enumerate(dataloader), total=len(dataloader), disable=not is_master): 189 | with accelerator.accumulate(model): 190 | step_cnt += 1 191 | # forward pass 192 | input_ids = batch["input_ids"] 193 | 
attention_mask = batch["attention_mask"] 194 | output = model( 195 | input_ids=input_ids, 196 | attention_mask=attention_mask, 197 | output_hidden_states=False, 198 | use_cache=False, 199 | ) 200 | output = output[0] 201 | output = output[:, :-1] # truncate last logit 202 | labels = input_ids[:, 1:] # shift labels 203 | output = output.to(torch.float32) 204 | logits = torch.log_softmax(output, dim=-1) 205 | logits = logits.gather(2, labels.unsqueeze(2)).squeeze(2) 206 | masks = batch["masks"][:, 1:] # actual mask 207 | masks = masks.float() 208 | logits = logits * masks 209 | batch_size = logits.size(0) // 2 210 | logits = logits.sum(dim=1) 211 | pos_logits, neg_logits = logits[:batch_size], logits[batch_size:] 212 | pos_ref_logits = batch["chosen_logprob"] 213 | neg_ref_logits = batch["rejected_logprob"] 214 | dpo_loss, _, _ = dpo_loss_fn(pos_logits, neg_logits, pos_ref_logits, neg_ref_logits, args.beta, args.ref_free) 215 | dpo_loss = dpo_loss.mean() 216 | loss = args.dpo_weight * dpo_loss 217 | avg_loss += loss.detach().float() / args.accumulate_step 218 | avg_pos_logits += pos_logits.mean().detach().float() / args.accumulate_step 219 | avg_neg_logits += neg_logits.mean().detach().float() / args.accumulate_step 220 | accelerator.backward(loss) 221 | # updating 222 | optimizer.step() 223 | optimizer.zero_grad() 224 | scheduler.step() 225 | if args.empty_cache: 226 | get_accelerator().empty_cache() 227 | lr = optimizer.param_groups[0]['lr'] 228 | if accelerator.sync_gradients: 229 | if step_cnt == args.accumulate_step: 230 | step_cnt = 0 231 | epoch_step += 1 232 | all_step_cnt += 1 233 | if all_step_cnt % args.report_freq == 0 and all_step_cnt > 0 and step_cnt == 0: 234 | # report stats 235 | avg_loss = accelerator.gather(avg_loss).mean().item() 236 | avg_pos_logits = accelerator.gather(avg_pos_logits).mean().item() 237 | avg_neg_logits = accelerator.gather(avg_neg_logits).mean().item() 238 | if is_master: 239 | print("id: %d"%id) 240 | recorder.print("epoch: %d, batch: %d, avg loss: %.6f"%(epoch+1, epoch_step, avg_loss / args.report_freq)) 241 | recorder.print(f"learning rate: {lr:.10f}") 242 | recorder.plot( 243 | "loss", 244 | { 245 | "loss": avg_loss / args.report_freq, 246 | }, 247 | all_step_cnt 248 | ) 249 | recorder.plot( 250 | "logits", 251 | { 252 | "pos_logits": avg_pos_logits / args.report_freq, 253 | "neg_logits": avg_neg_logits / args.report_freq, 254 | }, 255 | all_step_cnt 256 | ) 257 | recorder.plot("lr", {"lr": lr}, all_step_cnt) 258 | recorder.print() 259 | avg_loss = 0 260 | avg_pos_logits, avg_neg_logits = 0, 0 261 | 262 | 263 | if (all_step_cnt % args.eval_interval == 0 and all_step_cnt > 0 and step_cnt == 0) or (i == len(dataloader) - 1): 264 | result = test(val_dataloader, model, args, is_master) 265 | overall_loss = result["loss"] 266 | overall_loss = accelerator.gather(overall_loss).mean().item() 267 | eval_pos_logits = accelerator.gather(result["pos_logits"]).mean().item() 268 | eval_neg_logits = accelerator.gather(result["neg_logits"]).mean().item() 269 | if overall_loss < minimum_loss: 270 | minimum_loss = overall_loss 271 | save_with_accelerate(model, "model") 272 | if is_master: 273 | recorder.print("best overall loss - epoch: %d"%(epoch)) 274 | if is_master: 275 | recorder.print("loss: %.6f"%(overall_loss)) 276 | recorder.plot( 277 | "loss", 278 | { 279 | "val_loss": result["loss"], 280 | }, 281 | all_step_cnt 282 | ) 283 | recorder.print(f"pos logits: {eval_pos_logits:.6f}, neg logits: {eval_neg_logits:.6f}") 284 | recorder.plot( 285 | "logits", 286 | { 
287 | "val_pos_logits": eval_pos_logits, 288 | "val_neg_logits": eval_neg_logits 289 | }, 290 | all_step_cnt 291 | ) 292 | save_with_accelerate(model, f"model_cur") 293 | 294 | 295 | def main(): 296 | parser = argparse.ArgumentParser(description='Parameters') 297 | parser.add_argument("-l", "--log", action="store_true", help="logging") 298 | parser.add_argument("--model_pt", default="", type=str, help="model path, if given, load model from this path") 299 | parser.add_argument("--epoch", type=int, default=3, help="number of epochs") 300 | parser.add_argument("--dataset", type=str, help="dataset") 301 | parser.add_argument("--data_split", type=str, help="data_split") 302 | parser.add_argument("--beta", type=float, help="beta") 303 | parser.add_argument("--ref_free", action="store_true", help="reference free") 304 | parser.add_argument("--model_type", type=str, help="model type") 305 | parser.add_argument("--insert_eos", action="store_true", help="insert eos") 306 | parser.add_argument("--dpo_weight", type=float, default=1, help="dpo weight") 307 | parser.add_argument("--eval_interval", type=int, default=500, help="evaluation interval") 308 | args = parser.parse_args() 309 | run(args) 310 | 311 | 312 | if __name__ == "__main__": 313 | main() 314 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.9.3 2 | aiosignal==1.3.1 3 | alpaca_eval==0.6 4 | annotated-types==0.6.0 5 | anyio==4.3.0 6 | async-timeout==4.0.3 7 | attrs==23.2.0 8 | certifi==2024.2.2 9 | charset-normalizer==3.3.2 10 | datasets==2.18.0 11 | dill==0.3.8 12 | distro==1.9.0 13 | exceptiongroup==1.2.0 14 | filelock==3.13.3 15 | fire==0.6.0 16 | frozenlist==1.4.1 17 | fsspec==2024.2.0 18 | h11==0.14.0 19 | httpcore==1.0.4 20 | httpx==0.27.0 21 | huggingface-hub==0.22.1 22 | idna==3.6 23 | joblib==1.3.2 24 | multidict==6.0.5 25 | multiprocess==0.70.16 26 | numpy==1.26.4 27 | openai==1.14.3 28 | packaging==24.0 29 | pandas==2.2.1 30 | patsy==0.5.6 31 | pyarrow==15.0.2 32 | pyarrow-hotfix==0.6 33 | pydantic==2.6.4 34 | pydantic_core==2.16.3 35 | python-dateutil==2.9.0.post0 36 | python-dotenv==1.0.1 37 | pytz==2024.1 38 | PyYAML==6.0.1 39 | regex==2023.12.25 40 | requests==2.31.0 41 | scikit-learn==1.4.1.post1 42 | scipy==1.12.0 43 | six==1.16.0 44 | sniffio==1.3.1 45 | termcolor==2.4.0 46 | threadpoolctl==3.4.0 47 | tiktoken==0.6.0 48 | tqdm==4.66.2 49 | typing_extensions==4.10.0 50 | tzdata==2024.1 51 | urllib3==2.2.1 52 | xxhash==3.4.1 53 | yarl==1.9.4 54 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.tensorboard import SummaryWriter 4 | from datetime import datetime 5 | 6 | 7 | class Recorder(): 8 | def __init__(self, id, log=True, base_dir="cache"): 9 | self.log = log 10 | now = datetime.now() 11 | date = now.strftime("%y-%m-%d") 12 | self.dir = os.path.join(base_dir, f"{date}-{id}") 13 | if self.log: 14 | os.mkdir(self.dir) 15 | self.f = open(os.path.join(self.dir, "log.txt"), "w") 16 | self.writer = SummaryWriter(os.path.join(self.dir, "log"), flush_secs=60) 17 | 18 | def write_config(self, args, models, name): 19 | if self.log: 20 | with open(os.path.join(self.dir, "config.txt"), "w") as f: 21 | print(name, file=f) 22 | print(args, file=f) 23 | print(file=f) 24 | for (i, x) in enumerate(models): 25 | 
print(x, file=f) 26 | print(file=f) 27 | print(args) 28 | print() 29 | for (i, x) in enumerate(models): 30 | print(x) 31 | print() 32 | 33 | def print(self, x=None): 34 | if x is not None: 35 | print(x, flush=True) 36 | else: 37 | print(flush=True) 38 | if self.log: 39 | if x is not None: 40 | print(x, file=self.f, flush=True) 41 | else: 42 | print(file=self.f, flush=True) 43 | 44 | def plot(self, tag, values, step): 45 | if self.log: 46 | self.writer.add_scalars(tag, values, step) 47 | 48 | 49 | def __del__(self): 50 | if self.log: 51 | self.f.close() 52 | self.writer.close() 53 | 54 | def save(self, model, name): 55 | if self.log: 56 | torch.save(model.state_dict(), os.path.join(self.dir, name)) 57 | 58 | def save_pretrained(self, model, name, **kwargs): 59 | if self.log: 60 | model.save_pretrained(os.path.join(self.dir, name), **kwargs) --------------------------------------------------------------------------------