├── README.md
├── requirements.txt
├── run.sh
└── src
├── hf_train.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | [![Contributors][contributors-shield]][contributors-url]
6 | [![Forks][forks-shield]][forks-url]
7 | [![Stargazers][stars-shield]][stars-url]
8 | [![Issues][issues-shield]][issues-url]
9 | [![LinkedIn][linkedin-shield]][linkedin-url]
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | Direct Preference Optimization from scratch in PyTorch
18 |
19 |
20 |
21 |
22 | Report Bug
23 | ·
24 | Request Feature
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 | ## About The Project
33 |
34 | This project is an implementation of Direct Preference Optimization (DPO), an alternative to RLHF for aligning Large Language Models (LLMs) with human preferences. The algorithm is described in the research paper
35 | [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290).
36 |
37 | DPO is a promising and efficient technique for fine-tuning Large Language Models (LLMs) to align them with human preferences. Compared to traditional Reinforcement Learning from Human Feedback (RLHF), it eliminates the need for a separate reward model and simplifies the training process, leading to better stability and computational efficiency.
38 |
39 | The key insight in Direct Preference Optimization is replacing the complex reward modeling process in RLHF with a simple loss function that directly optimizes for human preferences in closed form. Given a preference dataset, it increases the log probability of the tokens in the human-preferred responses and decreases the log probability of the tokens in the human-dispreferred responses. In effect, the model learns an implicit reward function that is directly optimized for human preferences. This reformulation makes the process much simpler and more efficient than RLHF: it requires no separate reward model, and it is more stable because it avoids reinforcement learning algorithms such as PPO for fine-tuning.
40 |
41 | The DPO loss function is defined as follows:
42 |
43 | $$
44 | L_\text{DPO}(\pi_{\theta}; \pi_\text{ref}) = -E_{(x, y_w, y_l)\sim D}\left[\log \sigma \left(
45 | \beta \log \frac{\pi_{\theta}(y_w\mid x)}{\pi_\text{ref}(y_w\mid x)} \thinspace
46 | {- \beta \log \frac{\pi_{\theta}(y_l\mid x)}{\pi_\text{ref}(y_l\mid x)}}\right)\right]
47 | $$
48 |
49 | where:
50 |
51 | - $\pi_{\theta}$ is the language model we want to fine-tune
52 | - $\pi_\text{ref}$ is a reference model, usually a frozen version of the original pre-trained language model
53 | - $D$ is the dataset of preferences
54 | - $x$ is a sample prompt from the dataset $D$
55 | - $y_w$ is the human-preferred response to the prompt $x$
56 | - $y_l$ is the human-dispreferred response to the prompt $x$
57 | - $\beta$ is a hyperparameter that controls the amount of divergence from the reference model $\pi_\text{ref}$
58 |
59 | The DPO loss function can be broken down into two main terms. The first term represents the log probability of the human-preferred response $y_w$: it aims to maximize the probability of $y_w$ under the model $\pi_{\theta}$, relative to the reference model $\pi_{\text{ref}}$. The division by $\pi_{\text{ref}}$ serves as a regularizing factor, ensuring that fine-tuning does not cause the model to deviate excessively from its original training. Maximizing this term increases the likelihood that $\pi_{\theta}$ generates responses similar to $y_w$ for inputs like $x$, reinforcing the human preference patterns. Conversely, the second term minimizes the log probability of the human-dispreferred response $y_l$, reducing the model's tendency to generate $y_l$-type responses, as indicated by the negative sign.
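
Equivalently, the quantity inside the sigmoid is a difference of implicit rewards, where each response's reward is the $\beta$-scaled log-ratio between the fine-tuned model and the reference model:

$$
\hat{r}_{\theta}(x, y) = \beta \log \frac{\pi_{\theta}(y\mid x)}{\pi_\text{ref}(y\mid x)}, \qquad
L_\text{DPO}(\pi_{\theta}; \pi_\text{ref}) = -E_{(x, y_w, y_l)\sim D}\left[\log \sigma \left(\hat{r}_{\theta}(x, y_w) - \hat{r}_{\theta}(x, y_l)\right)\right]
$$

so minimizing the loss pushes the implicit reward of the preferred response $y_w$ above that of the dispreferred response $y_l$.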
60 |
61 | The hyperparameter $\beta$, typically set between 0.1 and 0.5, controls how strongly divergence from the reference model $\pi_\text{ref}$ is penalized: higher values keep the fine-tuned model closer to the reference, while lower values allow larger deviations from its behavior. The loss is then simply averaged across the dataset $D$, or a batch of samples from it, giving the final DPO objective that we minimize with gradient descent to fine-tune the language model.
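
In code, the objective itself takes only a few lines. The sketch below mirrors `calculate_DPO_loss` in `src/train.py`; the helper name `dpo_loss` is just for illustration, and the inputs are assumed to be the (length-averaged) log probabilities of each response under the fine-tuned and reference models:

```python
import torch.nn.functional as F

def dpo_loss(model_chosen_logprob, model_rejected_logprob,
             ref_chosen_logprob, ref_rejected_logprob, beta=0.1):
    # Implicit rewards: beta-scaled log-ratios between the fine-tuned and reference models
    chosen_reward = beta * (model_chosen_logprob - ref_chosen_logprob)
    rejected_reward = beta * (model_rejected_logprob - ref_rejected_logprob)
    # Negative log-sigmoid of the reward margin, averaged over the batch
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()
```

The full implementation in `src/train.py` additionally returns the reward accuracy and reward margin, which are logged to Weights & Biases during training.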
62 |
63 |
64 | For a detailed explanation, you can check my blog post [Unveiling the Hidden Reward System in Language Models: A Dive into DPO](https://allam.vercel.app/post/dpo/)
65 |
66 |
67 | [contributors-shield]: https://img.shields.io/github/contributors/ahmed-alllam/Direct-Preference-Optimization.svg?style=for-the-badge
68 | [contributors-url]: https://github.com/ahmed-alllam/Direct-Preference-Optimization/graphs/contributors
69 | [forks-shield]: https://img.shields.io/github/forks/ahmed-alllam/Direct-Preference-Optimization.svg?style=for-the-badge
70 | [forks-url]: https://github.com/ahmed-alllam/Direct-Preference-Optimization/network/members
71 | [stars-shield]: https://img.shields.io/github/stars/ahmed-alllam/Direct-Preference-Optimization.svg?style=for-the-badge
72 | [stars-url]: https://github.com/ahmed-alllam/Direct-Preference-Optimization/stargazers
73 | [issues-shield]: https://img.shields.io/github/issues/ahmed-alllam/Direct-Preference-Optimization.svg?style=for-the-badge
74 | [issues-url]: https://github.com/ahmed-alllam/Direct-Preference-Optimization/issues
75 | [license-shield]: https://img.shields.io/github/license/ahmed-alllam/Direct-Preference-Optimization.svg?style=for-the-badge
76 | [license-url]: https://github.com/ahmed-alllam/Direct-Preference-Optimization/blob/master/LICENSE.txt
77 | [linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=for-the-badge&logo=linkedin&colorB=555
78 | [linkedin-url]: https://linkedin.com/in/ahmed-e-allam
79 | [product-screenshot]: images/screenshot.png
80 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | argparse==1.4.0
2 | numpy==1.26.3
3 | torch==2.1.2
4 | datasets==2.16.1
5 | transformers==4.37.0
6 | wandb==0.16.2
7 | tqdm==4.66.1
8 | trl==0.7.10
9 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | python src/train.py \
4 | --epochs 10 \
5 | --batch_size 64 \
6 | --max_length 512 \
7 | --lr 1e-6 \
8 | --beta 0.1 \
9 | --seed 2003 \
10 | --model_name "microsoft/phi-2" \
11 | --dataset_name "jondurbin/truthy-dpo-v0.1" \
12 | --wandb_project "truthy-dpo"
13 |
--------------------------------------------------------------------------------
/src/hf_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import numpy as np
4 |
5 | import torch
6 |
7 | from datasets import load_dataset
8 | from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
9 |
10 | from trl import DPOTrainer
11 |
12 | import wandb
13 |
14 | def seed_everything(seed=2003):
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed_all(seed)
17 | np.random.seed(seed)
18 | random.seed(seed)
19 | torch.backends.cudnn.deterministic = True
20 |
21 | def preprocess_data(item):  # format each sample in phi-2's "Instruct: ...\nOutput: ..." prompt style
22 | return {
23 | 'prompt': 'Instruct: ' + item['prompt'] + '\n',
24 | 'chosen': 'Output: ' + item['chosen'],
25 | 'rejected': 'Output: ' + item['rejected']
26 | }
27 |
28 | def train(model, ref_model, dataset, tokenizer, beta, training_args):
29 | model.train()
30 | ref_model.eval()
31 |
32 | dpo_trainer = DPOTrainer(
33 | model,
34 | ref_model,
35 | beta=beta,
36 | train_dataset=dataset,
37 | tokenizer=tokenizer,
38 | args=training_args,
39 | max_length=1024,
40 | max_prompt_length=512
41 | )
42 |
43 | dpo_trainer.train()
44 |
45 | def main():
46 | parser = argparse.ArgumentParser()
47 |
48 | parser.add_argument("--epochs", type=int, default=1)
49 | parser.add_argument("--beta", type=float, default=0.1)
50 | parser.add_argument("--batch_size", type=int, default=4)
51 | parser.add_argument("--lr", type=float, default=1e-6)
52 | parser.add_argument("--seed", type=int, default=2003)
53 | parser.add_argument("--model_name", type=str, default="microsoft/phi-2")
54 | parser.add_argument("--dataset_name", type=str, default="jondurbin/truthy-dpo-v0.1")
55 | parser.add_argument("--wandb_project", type=str, default="truthy-dpo")
56 |
57 | args = parser.parse_args()
58 |
59 | seed_everything(args.seed)
60 |
61 | wandb.login()
62 | wandb.init(project=args.wandb_project, config=args)
63 |
64 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65 |
66 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
67 | tokenizer.pad_token = tokenizer.eos_token
68 | model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)
69 | ref_model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)
70 |
71 | dataset = load_dataset(args.dataset_name, split="train")
72 | dataset = dataset.map(preprocess_data)
73 |
74 | training_args = TrainingArguments(
75 | learning_rate=args.lr,
76 | num_train_epochs=args.epochs,
77 | per_device_train_batch_size=args.batch_size,
78 | report_to="wandb",
79 | output_dir='./results',
80 | logging_steps=10,
81 | remove_unused_columns=False,
82 | )
83 |
84 | train(model, ref_model, dataset, tokenizer, args.beta, training_args)
85 |
86 | model.save_pretrained("model-HF-DPO.pt")
87 |
88 | if __name__ == "__main__":
89 | main()
90 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import numpy as np
4 | from functools import partial
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.optim import AdamW
10 |
11 | from torch.utils.data import DataLoader
12 | from datasets import load_dataset
13 | from transformers import AutoTokenizer, AutoModelForCausalLM
14 |
15 | import wandb
16 | from tqdm import tqdm
17 |
18 | def seed_everything(seed=2003):
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | np.random.seed(seed)
22 | random.seed(seed)
23 | torch.backends.cudnn.deterministic = True
24 |
25 | def calculate_DPO_loss(model_preferred_logprob, model_dispreferred_logprob,
26 | ref_preferred_logprob, ref_dispreferred_logprob,
27 | beta=0.5):
28 |
29 | preferred_relative_logprob = model_preferred_logprob - ref_preferred_logprob
30 | dispreferred_relative_logprob = model_dispreferred_logprob - ref_dispreferred_logprob
31 |
32 | reward_accuracies = (preferred_relative_logprob > dispreferred_relative_logprob).float().mean()
33 | reward_margins = (preferred_relative_logprob - dispreferred_relative_logprob).mean()
34 |
35 | loss = -F.logsigmoid(beta * (preferred_relative_logprob - dispreferred_relative_logprob)).mean()
36 |
37 | return loss, preferred_relative_logprob.mean(), dispreferred_relative_logprob.mean(), reward_accuracies, reward_margins
38 |
39 | def get_log_prob(logits, labels, prompt_lengths):
40 | log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)  # logits at position t predict the token at position t + 1
41 | token_log_probs = torch.gather(log_probs, -1, labels[:, 1:].unsqueeze(-1)).squeeze(-1)
42 |
43 | batch_size, seq_len = token_log_probs.shape
44 | response_mask = torch.arange(seq_len, device=labels.device).unsqueeze(0) >= (prompt_lengths.unsqueeze(1) - 1)  # keep only positions that predict response tokens
45 | response_mask = response_mask.float()
46 |
47 | response_log_probs = (token_log_probs * response_mask).sum(dim=-1)
48 | response_lengths = response_mask.sum(dim=-1).clamp(min=1)
49 | return response_log_probs / response_lengths
50 |
51 | def collate_fn(batch, tokenizer, max_length, device):
52 | prompt_encodings = tokenizer(
53 | ['Instruct: ' + item['prompt'] + '\n' for item in batch],
54 | padding='max_length',
55 | truncation=True,
56 | max_length=max_length,
57 | return_tensors='pt'
58 | )
59 |
60 | chosen_encodings = tokenizer(
61 | ['Output: ' + item['chosen'] for item in batch],
62 | padding='max_length',
63 | truncation=True,
64 | max_length=max_length,
65 | return_tensors='pt'
66 | )
67 |
68 | rejected_encodings = tokenizer(
69 | ['Output: ' + item['rejected'] for item in batch],
70 | padding='max_length',
71 | truncation=True,
72 | max_length=max_length,
73 | return_tensors='pt'
74 | )
75 |
76 | prompt_preferred_ids = torch.cat([
77 | prompt_encodings.input_ids,
78 | chosen_encodings.input_ids
79 | ], dim=-1).to(device)
80 |
81 | prompt_dispreferred_ids = torch.cat([
82 | prompt_encodings.input_ids,
83 | rejected_encodings.input_ids
84 | ], dim=-1).to(device)
85 |
86 | prompt_preferred_mask = torch.cat([
87 | prompt_encodings.attention_mask,
88 | chosen_encodings.attention_mask
89 | ], dim=-1).to(device)
90 |
91 | prompt_dispreferred_mask = torch.cat([
92 | prompt_encodings.attention_mask,
93 | rejected_encodings.attention_mask
94 | ], dim=-1).to(device)
95 |
96 | prompt_lengths = prompt_encodings.attention_mask.sum(dim=-1).to(device)  # number of real (non-padding) prompt tokens
97 |
98 | return {
99 | 'prompt_preferred_ids': prompt_preferred_ids,
100 | 'prompt_dispreferred_ids': prompt_dispreferred_ids,
101 | 'prompt_preferred_mask': prompt_preferred_mask,
102 | 'prompt_dispreferred_mask': prompt_dispreferred_mask,
103 | 'prompt_lengths': prompt_lengths
104 | }
105 |
106 | def train(model, ref_model, tokenizer, optimizer, train_dataloader, epochs=1, beta=0.1):
107 | model.train()
108 | ref_model.eval()
109 |
110 | for epoch in range(epochs):
111 | for batch in tqdm(train_dataloader):
112 | optimizer.zero_grad()
113 |
114 | model_preferred_logits = model(
115 | input_ids=batch['prompt_preferred_ids'],
116 | attention_mask=batch['prompt_preferred_mask']
117 | ).logits
118 |
119 | model_preferred_logprob = get_log_prob(
120 | model_preferred_logits,
121 | batch['prompt_preferred_ids'],
122 | batch['prompt_lengths']
123 | )
124 |
125 | model_dispreferred_logits = model(
126 | input_ids=batch['prompt_dispreferred_ids'],
127 | attention_mask=batch['prompt_dispreferred_mask']
128 | ).logits
129 |
130 | model_dispreferred_logprob = get_log_prob(
131 | model_dispreferred_logits,
132 | batch['prompt_dispreferred_ids'],
133 | batch['prompt_lengths']
134 | )
135 |
136 | with torch.no_grad():
137 | ref_preferred_logits = ref_model(
138 | input_ids=batch['prompt_preferred_ids'],
139 | attention_mask=batch['prompt_preferred_mask']
140 | ).logits
141 |
142 | ref_preferred_logprob = get_log_prob(
143 | ref_preferred_logits,
144 | batch['prompt_preferred_ids'],
145 | batch['prompt_lengths']
146 | )
147 |
148 | ref_dispreferred_logits = ref_model(
149 | input_ids=batch['prompt_dispreferred_ids'],
150 | attention_mask=batch['prompt_dispreferred_mask']
151 | ).logits
152 |
153 | ref_dispreferred_logprob = get_log_prob(
154 | ref_dispreferred_logits,
155 | batch['prompt_dispreferred_ids'],
156 | batch['prompt_lengths']
157 | )
158 |
159 | loss, preferred_relative_logprob, dispreferred_relative_logprob, reward_accuracies, reward_margins = calculate_DPO_loss(
160 | model_preferred_logprob,
161 | model_dispreferred_logprob,
162 | ref_preferred_logprob,
163 | ref_dispreferred_logprob,
164 | beta=beta
165 | )
166 |
167 | loss.backward()
168 | optimizer.step()
169 |
170 | wandb.log({
171 | 'loss': loss.item(),
172 | 'preferred_relative_logprob': preferred_relative_logprob.item(),
173 | 'dispreferred_relative_logprob': dispreferred_relative_logprob.item(),
174 | 'reward_accuracy': reward_accuracies.item(),
175 | 'reward_margin': reward_margins.item()
176 | })
177 |
178 | def main():
179 | parser = argparse.ArgumentParser()
180 |
181 | parser.add_argument("--epochs", type=int, default=1)
182 | parser.add_argument("--beta", type=float, default=0.1)
183 | parser.add_argument("--batch_size", type=int, default=4)
184 | parser.add_argument("--max_length", type=int, default=512)
185 | parser.add_argument("--lr", type=float, default=1e-6)
186 | parser.add_argument("--seed", type=int, default=2003)
187 | parser.add_argument("--model_name", type=str, default="microsoft/phi-2")
188 | parser.add_argument("--dataset_name", type=str, default="jondurbin/truthy-dpo-v0.1")
189 | parser.add_argument("--wandb_project", type=str, default="truthy-dpo")
190 |
191 | args = parser.parse_args()
192 |
193 | seed_everything(args.seed)
194 |
195 | wandb.login()
196 | wandb.init(project=args.wandb_project, config=args)
197 |
198 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
199 |
200 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
201 | tokenizer.pad_token = tokenizer.eos_token
202 | model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)
203 | ref_model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)
204 |
205 | ref_model.requires_grad_(False)
206 |
207 | optimizer = AdamW(model.parameters(), lr=args.lr)
208 |
209 | dataset = load_dataset(args.dataset_name, split="train")
210 | collate = partial(collate_fn, tokenizer=tokenizer, max_length=args.max_length, device=device)
211 | train_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate)
212 |
213 | train(model, ref_model, tokenizer, optimizer, train_dataloader, epochs=args.epochs, beta=args.beta)
214 |
215 | model.save_pretrained("model-DPO")
216 |
217 | if __name__ == "__main__":
218 | main()
219 |
--------------------------------------------------------------------------------