├── README.md
├── requirements.txt
├── run.sh
└── src
    ├── hf_train.py
    └── train.py

/README.md:
--------------------------------------------------------------------------------
[![Contributors][contributors-shield]][contributors-url]
[![Forks][forks-shield]][forks-url]
[![Stargazers][stars-shield]][stars-url]
[![Issues][issues-shield]][issues-url]
[![LinkedIn][linkedin-shield]][linkedin-url]

# Direct Preference Optimization from scratch in PyTorch

[Report Bug][issues-url] · [Request Feature][issues-url]

## About The Project

This project is an implementation of Direct Preference Optimization, an alternative to RLHF for aligning Large Language Models (LLMs) with human preferences. The algorithm is described in the research paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290).

Direct Preference Optimization (DPO) is a promising and efficient technique for fine-tuning Large Language Models to align with human preferences. Compared to traditional Reinforcement Learning From Human Feedback (RLHF), DPO eliminates the need for a separate reward model and simplifies the training process, leading to better stability and computational efficiency.

The key insight of DPO is that the complex reward-modeling stage of RLHF can be replaced by a simple loss function that optimizes for human preferences directly, in closed form. Given a preference dataset, the loss increases the log probability of the tokens in the human-preferred response and decreases the log probability of the tokens in the human-dispreferred response; the language model thereby defines an implicit reward function that is optimized directly on the preference data. This makes training much simpler and more efficient than RLHF, since no separate reward model is required, and more stable, since no reinforcement-learning algorithm such as PPO is involved.

The DPO loss function is defined as follows:

$$
L_\text{DPO}(\pi_{\theta}; \pi_\text{ref}) = -E_{(x, y_w, y_l)\sim D}\left[\log \sigma \left(
\beta \log \frac{\pi_{\theta}(y_w\mid x)}{\pi_\text{ref}(y_w\mid x)}
- \beta \log \frac{\pi_{\theta}(y_l\mid x)}{\pi_\text{ref}(y_l\mid x)}\right)\right]
$$

where:

- $\pi_{\theta}$ is the language model we want to fine-tune
- $\pi_\text{ref}$ is a reference model, usually a frozen copy of the original pre-trained language model
- $D$ is the dataset of preferences
- $x$ is a sample prompt from the dataset $D$
- $y_w$ is the human-preferred response to the prompt $x$
- $y_l$ is the human-dispreferred response to the prompt $x$
- $\beta$ is a hyperparameter that controls how far the model may diverge from the reference model $\pi_\text{ref}$

The DPO loss can be broken down into two main terms. The first term is the log probability of the human-preferred response $y_w$ under the model $\pi_{\theta}$, relative to the reference model $\pi_{\text{ref}}$. The division by $\pi_{\text{ref}}$ acts as a regularizer, ensuring that fine-tuning does not push the model too far from its original training. Maximizing this term increases the likelihood that $\pi_{\theta}$ generates responses similar to $y_w$ for inputs like $x$, reinforcing the human preference patterns. Conversely, the second term, which enters with a negative sign, minimizes the relative log probability of the human-dispreferred response $y_l$, reducing the model's tendency to generate responses of that kind.
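
As a rough sketch, the core of this objective is only a few lines of PyTorch. The snippet below mirrors `calculate_DPO_loss` in `src/train.py`; the argument names are illustrative, and each input is assumed to already hold the per-sequence log probability of the corresponding response under either the policy or the frozen reference model:

```python
import torch.nn.functional as F

def dpo_loss(policy_preferred_logprob, policy_dispreferred_logprob,
             ref_preferred_logprob, ref_dispreferred_logprob, beta=0.1):
    # Log-ratios of the policy against the frozen reference model
    preferred_logratio = policy_preferred_logprob - ref_preferred_logprob
    dispreferred_logratio = policy_dispreferred_logprob - ref_dispreferred_logprob
    # -log sigmoid(beta * (preferred log-ratio - dispreferred log-ratio)), averaged over the batch
    return -F.logsigmoid(beta * (preferred_logratio - dispreferred_logratio)).mean()
```

In `src/train.py`, the per-sequence log probabilities are computed by `get_log_prob`, which averages the token log probabilities over the response portion of each sequence.
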
The hyperparameter $\beta$, typically set between 0.1 and 0.5, controls the amount of divergence from the reference model $\pi_\text{ref}$, allowing controlled adjustments to the model's outputs while preventing significant deviations from the reference model's behavior. The loss is then simply averaged over the dataset $D$, or over a batch of samples from it, giving the final DPO objective that we can minimize with gradient descent to fine-tune the language model.

For a detailed explanation, you can check my blog post [Unveiling the Hidden Reward System in Language Models: A Dive into DPO](https://allam.vercel.app/post/dpo/).

[contributors-shield]: https://img.shields.io/github/contributors/ahmed-alllam/Direct-Preference-Optimization.svg?style=for-the-badge
[contributors-url]: https://github.com/ahmed-alllam/Direct-Preference-Optimization/graphs/contributors
[forks-shield]: https://img.shields.io/github/forks/ahmed-alllam/Direct-Preference-Optimization.svg?style=for-the-badge
[forks-url]: https://github.com/ahmed-alllam/Direct-Preference-Optimization/network/members
[stars-shield]: https://img.shields.io/github/stars/ahmed-alllam/Direct-Preference-Optimization.svg?style=for-the-badge
[stars-url]: https://github.com/ahmed-alllam/Direct-Preference-Optimization/stargazers
[issues-shield]: https://img.shields.io/github/issues/ahmed-alllam/Direct-Preference-Optimization.svg?style=for-the-badge
[issues-url]: https://github.com/ahmed-alllam/Direct-Preference-Optimization/issues
[license-shield]: https://img.shields.io/github/license/ahmed-alllam/Direct-Preference-Optimization.svg?style=for-the-badge
[license-url]: https://github.com/ahmed-alllam/Direct-Preference-Optimization/blob/master/LICENSE.txt
[linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=for-the-badge&logo=linkedin&colorB=555
[linkedin-url]: https://linkedin.com/in/ahmed-e-allam
[product-screenshot]: images/screenshot.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
argparse==1.4.0
numpy==1.26.3
torch==2.1.2
datasets==2.16.1
transformers==4.37.0
wandb==0.16.2
tqdm==4.66.1
trl==0.7.10
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env sh

python src/train.py \
    --epochs 10 \
    --batch_size 64 \
    --max_length 512 \
    --lr 1e-6 \
    --beta 0.1 \
    --seed 2003 \
    --model_name "microsoft/phi-2" \
    --dataset_name "jondurbin/truthy-dpo-v0.1" \
    --wandb_project "truthy-dpo"
--------------------------------------------------------------------------------
/src/hf_train.py:
--------------------------------------------------------------------------------
import argparse
import random
import numpy as np

import torch

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

from trl import DPOTrainer

import wandb

def seed_everything(seed=2003):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

def preprocess_data(item):
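    # Map each dataset row to the prompt/chosen/rejected fields that trl's DPOTrainer expects,
    # following phi-2's "Instruct:" / "Output:" prompt format.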
    return {
        'prompt': 'Instruct: ' + item['prompt'] + '\n',
        'chosen': 'Output: ' + item['chosen'],
        'rejected': 'Output: ' + item['rejected']
    }

def train(model, ref_model, dataset, tokenizer, beta, training_args):
    model.train()
    ref_model.eval()

    dpo_trainer = DPOTrainer(
        model,
        ref_model,
        beta=beta,
        train_dataset=dataset,
        tokenizer=tokenizer,
        args=training_args,
        max_length=1024,
        max_prompt_length=512
    )

    dpo_trainer.train()

def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--beta", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-6)
    parser.add_argument("--seed", type=int, default=2003)
    parser.add_argument("--model_name", type=str, default="microsoft/phi-2")
    parser.add_argument("--dataset_name", type=str, default="jondurbin/truthy-dpo-v0.1")
    parser.add_argument("--wandb_project", type=str, default="truthy-dpo")

    args = parser.parse_args()

    seed_everything(args.seed)

    wandb.login()
    wandb.init(project=args.wandb_project, config=args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)
    ref_model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)

    dataset = load_dataset(args.dataset_name, split="train")
    dataset = dataset.map(preprocess_data)

    training_args = TrainingArguments(
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        report_to="wandb",
        output_dir='./results',
        logging_steps=10,
        remove_unused_columns=False,
    )

    train(model, ref_model, dataset, tokenizer, args.beta, training_args)

    # save_pretrained writes a directory of model files, so use a plain directory name
    model.save_pretrained("model-HF-DPO")

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
import argparse
import random
import numpy as np
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

import wandb
from tqdm import tqdm

def seed_everything(seed=2003):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

def calculate_DPO_loss(model_preferred_logprob, model_dispreferred_logprob,
                       ref_preferred_logprob, ref_dispreferred_logprob,
                       beta=0.5):

    preferred_relative_logprob = model_preferred_logprob - ref_preferred_logprob
    dispreferred_relative_logprob = model_dispreferred_logprob - ref_dispreferred_logprob

    reward_accuracies = (preferred_relative_logprob > dispreferred_relative_logprob).float().mean()
    reward_margins = (preferred_relative_logprob - dispreferred_relative_logprob).mean()

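    # DPO objective: -log sigmoid(beta * (preferred log-ratio - dispreferred log-ratio)),
    # averaged over the batch. Minimizing it raises the policy's relative log probability
    # of the preferred response and lowers it for the dispreferred one.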
    loss = -F.logsigmoid(beta * (preferred_relative_logprob - dispreferred_relative_logprob)).mean()

    return loss, preferred_relative_logprob.mean(), dispreferred_relative_logprob.mean(), reward_accuracies, reward_margins

def get_log_prob(logits, labels, attention_mask, prompt_lengths):
    # Causal LMs predict token t from the logits at position t - 1, so shift logits and labels by one.
    logits = logits[:, :-1, :]
    labels = labels[:, 1:]
    attention_mask = attention_mask[:, 1:]

    log_probs = F.log_softmax(logits, dim=-1)
    token_log_probs = torch.gather(log_probs, -1, labels.unsqueeze(-1)).squeeze(-1)

    # Keep only response tokens: positions after the prompt that are not padding.
    batch_size, seq_len = labels.shape
    positions = torch.arange(seq_len, device=labels.device).unsqueeze(0) + 1
    response_mask = (positions >= prompt_lengths.unsqueeze(1)).float() * attention_mask.float()

    # Length-normalized log probability of the response tokens
    response_log_probs = (token_log_probs * response_mask).sum(dim=-1)
    response_lengths = response_mask.sum(dim=-1).clamp(min=1)
    return response_log_probs / response_lengths

def collate_fn(batch, tokenizer, max_length, device):
    prompt_encodings = tokenizer(
        ['Instruct: ' + item['prompt'] + '\n' for item in batch],
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )

    chosen_encodings = tokenizer(
        ['Output: ' + item['chosen'] for item in batch],
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )

    rejected_encodings = tokenizer(
        ['Output: ' + item['rejected'] for item in batch],
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )

    prompt_preferred_ids = torch.cat([
        prompt_encodings.input_ids,
        chosen_encodings.input_ids
    ], dim=-1).to(device)

    prompt_dispreferred_ids = torch.cat([
        prompt_encodings.input_ids,
        rejected_encodings.input_ids
    ], dim=-1).to(device)

    prompt_preferred_mask = torch.cat([
        prompt_encodings.attention_mask,
        chosen_encodings.attention_mask
    ], dim=-1).to(device)

    prompt_dispreferred_mask = torch.cat([
        prompt_encodings.attention_mask,
        rejected_encodings.attention_mask
    ], dim=-1).to(device)

    # Number of real (non-padding) prompt tokens per example, moved to the same
    # device as the other tensors so it can be compared against them later.
    prompt_lengths = prompt_encodings.attention_mask.sum(dim=-1).to(device)

    return {
        'prompt_preferred_ids': prompt_preferred_ids,
        'prompt_dispreferred_ids': prompt_dispreferred_ids,
        'prompt_preferred_mask': prompt_preferred_mask,
        'prompt_dispreferred_mask': prompt_dispreferred_mask,
        'prompt_lengths': prompt_lengths
    }

def train(model, ref_model, tokenizer, optimizer, train_dataloader, epochs=1, beta=0.1):
    model.train()
    ref_model.eval()

    for epoch in range(epochs):
        for batch in tqdm(train_dataloader):
            optimizer.zero_grad()

            model_preferred_logits = model(
                input_ids=batch['prompt_preferred_ids'],
                attention_mask=batch['prompt_preferred_mask']
            ).logits

            model_preferred_logprob = get_log_prob(
                model_preferred_logits,
                batch['prompt_preferred_ids'],
                batch['prompt_preferred_mask'],
                batch['prompt_lengths']
            )

            model_dispreferred_logits = model(
                input_ids=batch['prompt_dispreferred_ids'],
                attention_mask=batch['prompt_dispreferred_mask']
            ).logits

            model_dispreferred_logprob = get_log_prob(
                model_dispreferred_logits,
                batch['prompt_dispreferred_ids'],
                batch['prompt_dispreferred_mask'],
                batch['prompt_lengths']
            )

            # Reference-model log-probs are computed without gradients; they only anchor the policy.
            with torch.no_grad():
                ref_preferred_logits = ref_model(
                    input_ids=batch['prompt_preferred_ids'],
                    attention_mask=batch['prompt_preferred_mask']
                ).logits

                ref_preferred_logprob = get_log_prob(
                    ref_preferred_logits,
                    batch['prompt_preferred_ids'],
                    batch['prompt_preferred_mask'],
                    batch['prompt_lengths']
                )

                ref_dispreferred_logits = ref_model(
                    input_ids=batch['prompt_dispreferred_ids'],
                    attention_mask=batch['prompt_dispreferred_mask']
                ).logits

                ref_dispreferred_logprob = get_log_prob(
                    ref_dispreferred_logits,
                    batch['prompt_dispreferred_ids'],
                    batch['prompt_dispreferred_mask'],
                    batch['prompt_lengths']
                )

            loss, preferred_relative_logprob, dispreferred_relative_logprob, reward_accuracies, reward_margins = calculate_DPO_loss(
                model_preferred_logprob,
                model_dispreferred_logprob,
                ref_preferred_logprob,
                ref_dispreferred_logprob,
                beta=beta
            )

            loss.backward()
            optimizer.step()

            wandb.log({
                'loss': loss.item(),
                'preferred_relative_logprob': preferred_relative_logprob.item(),
                'dispreferred_relative_logprob': dispreferred_relative_logprob.item(),
                'reward_accuracy': reward_accuracies.item(),
                'reward_margin': reward_margins.item()
            })

def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--beta", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--lr", type=float, default=1e-6)
    parser.add_argument("--seed", type=int, default=2003)
    parser.add_argument("--model_name", type=str, default="microsoft/phi-2")
    parser.add_argument("--dataset_name", type=str, default="jondurbin/truthy-dpo-v0.1")
    parser.add_argument("--wandb_project", type=str, default="truthy-dpo")

    args = parser.parse_args()

    seed_everything(args.seed)

    wandb.login()
    wandb.init(project=args.wandb_project, config=args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)
    ref_model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)

    # The reference model stays frozen; it only provides the log-probs that anchor the policy.
    ref_model.requires_grad_(False)

    optimizer = AdamW(model.parameters(), lr=args.lr)

    dataset = load_dataset(args.dataset_name, split="train")
    collate = partial(collate_fn, tokenizer=tokenizer, max_length=args.max_length, device=device)
    train_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate)

    train(model, ref_model, tokenizer, optimizer, train_dataloader, epochs=args.epochs, beta=args.beta)

    model.save_pretrained("model-DPO")

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------