├── config
│   ├── loss
│   │   ├── sft.yaml
│   │   └── tdpo.yaml
│   ├── model
│   │   ├── gpt2-xl.yaml
│   │   ├── gpt2-large.yaml
│   │   ├── gptj.yaml
│   │   ├── llama7b.yaml
│   │   ├── pythia28.yaml
│   │   ├── pythia69.yaml
│   │   └── blank_model.yaml
│   └── config.yaml
├── figs
│   ├── TDPO_vs_DPO.png
│   └── IMDb_experiment.png
├── requirements.txt
├── train.py
├── README.md
├── utils.py
├── LICENSE
├── preference_datasets.py
└── trainers.py

/config/loss/sft.yaml:
--------------------------------------------------------------------------------
1 | name: sft
--------------------------------------------------------------------------------
/figs/TDPO_vs_DPO.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vance0124/Token-level-Direct-Preference-Optimization/HEAD/figs/TDPO_vs_DPO.png
--------------------------------------------------------------------------------
/figs/IMDb_experiment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vance0124/Token-level-Direct-Preference-Optimization/HEAD/figs/IMDb_experiment.png
--------------------------------------------------------------------------------
/config/model/gpt2-xl.yaml:
--------------------------------------------------------------------------------
1 | name_or_path: gpt2-xl
2 | tokenizer_name_or_path: null
3 | archive: null
4 | block_name: GPT2Block
5 |
6 | policy_dtype: float32
7 | fsdp_policy_mp: null
8 | reference_dtype: float16
--------------------------------------------------------------------------------
/config/model/gpt2-large.yaml:
--------------------------------------------------------------------------------
1 | name_or_path: gpt2-large
2 | tokenizer_name_or_path: null
3 | archive: null
4 | block_name: GPT2Block
5 |
6 | policy_dtype: float32
7 | fsdp_policy_mp: null
8 | reference_dtype: float16
--------------------------------------------------------------------------------
/config/model/gptj.yaml:
--------------------------------------------------------------------------------
1 | name_or_path: EleutherAI/gpt-j-6b
2 | tokenizer_name_or_path: null
3 | archive: null
4 | block_name: GPTJBlock
5 |
6 | policy_dtype: float32
7 | fsdp_policy_mp: null
8 | reference_dtype: float16
--------------------------------------------------------------------------------
/config/model/llama7b.yaml:
--------------------------------------------------------------------------------
1 | name_or_path: huggyllama/llama-7b
2 | tokenizer_name_or_path: null
3 | archive: null
4 | block_name: LlamaDecoderLayer
5 |
6 | policy_dtype: float32
7 | fsdp_policy_mp: null
8 | reference_dtype: float16
--------------------------------------------------------------------------------
/config/model/pythia28.yaml:
--------------------------------------------------------------------------------
1 | name_or_path: EleutherAI/pythia-2.8b
2 | tokenizer_name_or_path: null
3 | archive: null
4 | block_name: GPTNeoXLayer
5 |
6 | policy_dtype: bfloat16
7 | fsdp_policy_mp: null
8 | reference_dtype: float16
--------------------------------------------------------------------------------
/config/model/pythia69.yaml:
--------------------------------------------------------------------------------
1 | name_or_path: EleutherAI/pythia-6.9b
2 | tokenizer_name_or_path: null
3 | archive: null
4 | block_name: GPTNeoXLayer
5 |
6 | policy_dtype: float32
7 | fsdp_policy_mp: null
8 | reference_dtype: float16
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipykernel==6.23.1
2 | numpy==1.24.3
3 | tokenizers==0.13.3
4 | torch==2.0.1
5 | tqdm==4.65.0
6 | transformers==4.29.2
7 | datasets==2.12.0
8 | beautifulsoup4==4.12.2
9 | wandb==0.15.3
10 | hydra-core==1.3.2
11 | tensor-parallel==1.2.4
--------------------------------------------------------------------------------
/config/loss/tdpo.yaml:
--------------------------------------------------------------------------------
1 | # do TDPO preference-based training
2 | name: tdpo
3 | if_tdpo2: false
4 |
5 | # beta is the temperature parameter for TDPO; lower values mean we care less about
6 | # the reference model. alpha weights the sequential KL divergence term (used by TDPO2)
7 | alpha: 0.5
8 | beta: 0.1
9 |
10 | # if true, use a uniform (maximum entropy) reference model
11 | reference_free: false
--------------------------------------------------------------------------------
/config/model/blank_model.yaml:
--------------------------------------------------------------------------------
1 | # the name of the model to use; should be something like
2 | # gpt2-xl or gpt-neo-2.7B or huggyllama/llama-7b
3 | name_or_path: ???
4 |
5 | # the name of the tokenizer to use; if null, will use the tokenizer from the model
6 | tokenizer_name_or_path: null
7 |
8 | # override pre-trained weights (e.g., from SFT); optional
9 | archive: null
10 |
11 | # the name of the module class to wrap with FSDP; should be something like
12 | # GPT2Block, GPTNeoXLayer, LlamaDecoderLayer, etc.
13 | block_name: null
14 |
15 | # the dtype for the policy parameters/optimizer state
16 | policy_dtype: float32
17 |
18 | # the mixed precision dtype if using FSDP; defaults to the same as the policy
19 | fsdp_policy_mp: null
20 |
21 | # the dtype for the reference model (which is used for inference only)
22 | reference_dtype: float16
23 |
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | # random seed for batch sampling
2 | seed: 0
3 |
4 | # name for this experiment in the local run directory and on wandb
5 | exp_name: ???
6 |
7 | # the batch size for training; for FSDP, the batch size per GPU is batch_size / (grad_accumulation_steps * num_gpus)
8 | batch_size: 4
9 |
10 | # the batch size during evaluation and sampling, if enabled
11 | eval_batch_size: 16
12 |
13 | # debug mode (disables wandb, model checkpointing, etc.)
14 | debug: false
15 |
16 | # the port to use for FSDP
17 | fsdp_port: null
18 |
19 | # which dataset(s) to train on; can pass a list like datasets=[hh,shp]
20 | datasets:
21 | - hh
22 |
23 | # wandb configuration
24 | wandb:
25 | enabled: true
26 | entity: null
27 | project: "token-level-direct-preference-optimization"
28 |
29 | # to create the local run directory and cache models/datasets,
30 | # we will try each of these directories in order; if none exist,
31 | # we will create the last one and use it
32 | local_dirs:
33 | - /scr-ssd
34 | - /scr
35 | - .cache
36 |
37 | # whether or not to generate samples during evaluation; disabling this for
38 | # FSDP/TensorParallel is recommended, because sampling with them is slow
39 | sample_during_eval: true
40 |
41 | # how many model samples to generate during evaluation
42 | n_eval_model_samples: 16
43 |
44 | # whether to eval at the very beginning of training
45 | do_first_eval: true
46 |
47 | # an OmegaConf resolver that returns the local run directory, calling a function in utils.py
48 | local_run_dir: ${get_local_run_dir:${exp_name},${local_dirs}}
49 |
50 | # the learning rate
51 | lr: 5e-6
52 |
53 | # number of steps to accumulate over for each batch
54 | # (e.g. if batch_size=4 and gradient_accumulation_steps=2, then we will
55 | # accumulate gradients over 2 microbatches of size 2)
56 | gradient_accumulation_steps: 1
57 |
58 | # the maximum gradient norm to clip to
59 | max_grad_norm: 10.0
60 |
61 | # the maximum allowed length for an input (prompt + response)
62 | max_length: 512
63 |
64 | # the maximum allowed length for a prompt
65 | max_prompt_length: 256
66 |
67 | # the number of epochs to train for; if null, must specify n_examples
68 | n_epochs: 1
69 |
70 | # the number of examples to train for; if null, must specify n_epochs
71 | n_examples: null
72 |
73 | # the number of examples to evaluate on (and sample from, if sample_during_eval is true)
74 | n_eval_examples: 256
75 |
76 | # the trainer class to use (e.g. BasicTrainer, FSDPTrainer, TensorParallelTrainer)
77 | trainer: BasicTrainer
78 |
79 | # The optimizer to use; we use RMSprop because it works about as well as Adam and is more memory-efficient
80 | optimizer: RMSprop
81 |
82 | # number of linear warmup steps for the learning rate
83 | warmup_steps: 150
84 |
85 | # whether or not to use activation/gradient checkpointing
86 | activation_checkpointing: false
87 |
88 | # evaluate and save the model every eval_every training examples
89 | eval_every: 20_000
90 |
91 | # prevent wandb from logging more than once per minimum_log_interval_secs
92 | minimum_log_interval_secs: 1.0
93 |
94 | defaults:
95 | - _self_
96 | - model: blank_model_fp32 # basic model configuration
97 | - loss: sft # which loss function, either sft or tdpo (specify loss.alpha and loss.beta if using tdpo)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | torch.backends.cuda.matmul.allow_tf32 = True
4 | import torch.nn as nn
5 | import transformers
6 | from utils import get_local_dir, get_local_run_dir, disable_dropout, init_distributed, get_open_port
7 | import os
8 | import hydra
9 | import torch.multiprocessing as mp
10 | from omegaconf import OmegaConf, DictConfig
11 | import trainers
12 | import wandb
13 | import json
14 | import socket
15 | from typing import Optional, Set
16 | import resource
17 |
18 | OmegaConf.register_new_resolver("get_local_run_dir", lambda exp_name, local_dirs: get_local_run_dir(exp_name, local_dirs))
19 |
20 |
21 | def worker_main(rank: int, world_size: int, config: DictConfig, policy: nn.Module, reference_model: Optional[nn.Module] = None):
22 | """Main function for each worker process (may be only 1 for BasicTrainer/TensorParallelTrainer)."""
23 | if 'FSDP' in config.trainer:
24 | init_distributed(rank, world_size, port=config.fsdp_port)
25 |
26 | if config.debug:
27 | wandb.init = lambda *args, **kwargs: None
28 | wandb.log = lambda *args, **kwargs: None
29 |
30 | if rank == 0 and config.wandb.enabled:
31 | os.environ['WANDB_CACHE_DIR'] = get_local_dir(config.local_dirs)
32 | wandb.init(
33 | entity=config.wandb.entity,
34 | project=config.wandb.project,
35 | config=OmegaConf.to_container(config),
36 | dir=get_local_dir(config.local_dirs),
37 | name=config.exp_name,
38 | )
39 |
40 | TrainerClass = getattr(trainers, config.trainer)
41 | print(f'Creating trainer on process {rank} with world size {world_size}')
42 | trainer = TrainerClass(policy, config, config.seed, config.local_run_dir, reference_model=reference_model, rank=rank, world_size=world_size)
43 |
44 | trainer.train()
45 | trainer.save()
46 |
47 |
48 | @hydra.main(version_base=None, config_path="config", config_name="config")
49 | def main(config: DictConfig):
50 | """Main entry point for training. Validates config, creates/initializes model(s), and kicks off worker process(es)."""
51 |
52 | # Resolve hydra references, e.g. so we don't re-compute the run directory
53 | OmegaConf.resolve(config)
54 |
55 | missing_keys: Set[str] = OmegaConf.missing_keys(config)
56 | if missing_keys:
57 | raise ValueError(f"Got missing keys in config:\n{missing_keys}")
58 |
59 | if config.eval_every % config.batch_size != 0:
60 | print('WARNING: eval_every must be divisible by batch_size')
61 | print('Setting eval_every to', config.eval_every - config.eval_every % config.batch_size)
62 | config.eval_every = config.eval_every - config.eval_every % config.batch_size
63 |
64 | if 'FSDP' in config.trainer and config.fsdp_port is None:
65 | free_port = get_open_port()
66 | print('no FSDP port specified; using open port for FSDP:', free_port)
67 | config.fsdp_port = free_port
68 |
69 | print(OmegaConf.to_yaml(config))
70 |
71 | config_path = os.path.join(config.local_run_dir, 'config.yaml')
72 | with open(config_path, 'w') as f:
73 | OmegaConf.save(config, f)
74 |
75 | print('=' * 80)
76 | print(f'Writing to {socket.gethostname()}:{config.local_run_dir}')
77 | print('=' * 80)
78 |
79 | os.environ['XDG_CACHE_HOME'] = get_local_dir(config.local_dirs)
80 | print('building policy')
81 | model_kwargs = {'device_map': 'balanced'} if config.trainer == 'BasicTrainer' else {}
82 | policy_dtype = getattr(torch, config.model.policy_dtype)
83 | policy = transformers.AutoModelForCausalLM.from_pretrained(
84 | config.model.name_or_path, cache_dir=get_local_dir(config.local_dirs), low_cpu_mem_usage=True, torch_dtype=policy_dtype, **model_kwargs)
85 | disable_dropout(policy)
86 |
87 | if config.loss.name == 'tdpo':
88 | print('building reference model')
89 | reference_model_dtype = getattr(torch, config.model.reference_dtype)
90 | reference_model = transformers.AutoModelForCausalLM.from_pretrained(
91 | config.model.name_or_path, cache_dir=get_local_dir(config.local_dirs), low_cpu_mem_usage=True, torch_dtype=reference_model_dtype, **model_kwargs)
92 | disable_dropout(reference_model)
93 | else:
94 | reference_model = None
95 |
96 | if config.model.archive is not None:
97 | state_dict = torch.load(config.model.archive, map_location='cpu')
98 | step, metrics = state_dict['step_idx'], state_dict['metrics']
99 | print(f'loading pre-trained weights at step {step} from {config.model.archive} with metrics {json.dumps(metrics, indent=2)}')
100 | policy.load_state_dict(state_dict['state'])
101 | if config.loss.name == 'tdpo':
102 | reference_model.load_state_dict(state_dict['state'])
103 | print('loaded pre-trained weights')
104 |
105 | if 'FSDP' in config.trainer:
106 | world_size = torch.cuda.device_count()
107 | print('starting', world_size, 'processes for FSDP training')
108 | soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
109 | resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
110 | print(f'setting RLIMIT_NOFILE soft limit to {hard} from {soft}')
111 | mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model), join=True)
112 | else:
113 | print('starting single-process worker')
114 | worker_main(0, 1, config, policy, reference_model)
115 |
116 |
117 | if __name__ == '__main__':
118 | main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TDPO: Token-level Direct Preference Optimization
2 |
3 | This repo contains a reference implementation of the TDPO algorithm for training language models from preference data, as described in the paper [_Token-level Direct Preference Optimization_](https://arxiv.org/pdf/2404.11999.pdf) (ICML 2024). Our implementation is based on [DPO](https://github.com/eric-mitchell/direct-preference-optimization), and follows the same usage guidelines.
4 |
5 |
6 |
7 |
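As a hedged illustration of the usage pattern the README refers to (the same Hydra-driven entry point as DPO), the sketch below composes the configuration that `train.py` receives, with hypothetical overrides. It assumes it is run from the repository root with the dependencies from `requirements.txt` installed; the experiment name and override values are examples, not recommendations.

```python
# Illustrative only: build the Hydra config consumed by train.py, with example overrides.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="config"):
    cfg = compose(
        config_name="config",
        overrides=[
            "exp_name=tdpo_hh_demo",   # hypothetical experiment name
            "model=pythia28",          # one of the provided model configs
            "loss=tdpo",               # use config/loss/tdpo.yaml instead of sft
            "loss.alpha=0.5",
            "loss.beta=0.1",
        ],
    )

print(OmegaConf.to_yaml(cfg.loss))       # name, if_tdpo2, alpha, beta, reference_free
print(cfg.exp_name, cfg.batch_size, cfg.trainer)
```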
--------------------------------------------------------------------------------
/preference_datasets.py:
--------------------------------------------------------------------------------
23 | def strip_html_tags(html_string):
24 | """Strip HTML tags from a string, except for <code> tags (which contain real code in the StackExchange answers)."""
25 | # Create a BeautifulSoup object
26 | soup = BeautifulSoup(html_string, 'html.parser')
27 |
28 | # Initialize an empty list to store the text
29 | text = []
30 | for element in soup.children:
31 | if isinstance(element, NavigableString):
32 | continue
33 | if element.name == 'p':
34 | text.append(''.join(child.string for child in element.children if isinstance(child, NavigableString)))
35 | elif element.name == 'pre':
36 | for code in element.find_all('code'):
37 | text.append("<code>" + code.get_text() + "</code>")
38 | elif element.name == 'code':
39 | text.append("<code>" + element.get_text() + "</code>")
40 |
41 | # Join the text together with newlines in between
42 | text = "\n\n".join(text)
43 |
44 | return text
45 |
46 |
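# Illustrative note (not from the original source): for an answer body like
#     '<p>Use a regex:</p><pre><code>re.sub(r"a", "b", s)</code></pre>'
# strip_html_tags returns
#     'Use a regex:\n\n<code>re.sub(r"a", "b", s)</code>'
# i.e. paragraph text is kept as plain text and code blocks are re-wrapped in <code> tags.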
47 | def get_se(split, silent=False, cache_dir: str = None) -> Dict[
48 | str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]:
49 | """Load the StackExchange dataset from Huggingface, and return a dict of prompts and responses. See get_hh for the format.
50 |
51 | We strip the HTML tags from the responses (except for <code> tags), and we add necessary newlines.
52 | """
53 | print(f'Loading SE dataset ({split} split) from Huggingface...')
54 | dataset = datasets.load_dataset('HuggingFaceH4/stack-exchange-preferences', cache_dir=cache_dir)['train']
55 | print('done')
56 |
57 | # shuffle the dataset and select 1% for test
58 | dataset = dataset.shuffle(seed=42)
59 | dataset = dataset.select(range(int(len(dataset) * 0.01))) if split == 'test' else dataset.select(
60 | range(int(len(dataset) * 0.01), len(dataset)))
61 |
62 | def strip_html(x):
63 | x['question'] = strip_html_tags(x['question'])
64 | for a in x['answers']:
65 | a['text'] = strip_html_tags(a['text'])
66 | return x
67 |
68 | dataset = dataset.map(strip_html, num_proc=64)
69 |
70 | data = defaultdict(dict)
71 | for row in tqdm.tqdm(dataset, desc='Processing SE', disable=silent):
72 | prompt = '\n\nHuman: ' + row['question'] + '\n\nAssistant:'
73 | responses = [' ' + a['text'] for a in row['answers']]
74 | scores = [a['pm_score'] for a in row['answers']]
75 |
76 | pairs = []
77 | for i in range(len(responses)):
78 | for j in range(i + 1, len(responses)):
79 | pairs.append((i, j) if scores[i] > scores[j] else (j, i))
80 |
81 | data[prompt]['responses'] = responses
82 | data[prompt]['pairs'] = pairs
83 | data[prompt]['sft_target'] = max(responses, key=lambda x: scores[responses.index(x)])
84 |
85 | return data
86 |
87 |
88 | def get_shp(split: str, silent: bool = False, cache_dir: str = None) -> Dict[
89 | str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]:
90 | """Load the Stanford Human Preferences dataset from Huggingface and convert it to the necessary format. See hh for the format.
91 |
92 | We filter preference pairs to only keep pairs where the score ratio is at least 2.
93 | For this dataset, the sft_target is the response with the highest score.
94 | """
95 | print(f'Loading SHP dataset ({split} split) from Huggingface...')
96 | dataset = datasets.load_dataset('stanfordnlp/SHP', split=split, cache_dir=cache_dir)
97 | print('done')
98 |
99 | data = defaultdict(lambda: defaultdict(list))
100 | for row in tqdm.tqdm(dataset, desc='Processing SHP', disable=silent):
101 | prompt = '\n\nHuman: ' + row['history'] + '\n\nAssistant:'
102 | responses = [' ' + row['human_ref_A'], ' ' + row['human_ref_B']]
103 | scores = [row['score_A'], row['score_B']]
104 | if prompt in data:
105 | n_responses = len(data[prompt]['responses'])
106 | else:
107 | n_responses = 0
108 | score_ratio = max(scores[0] / scores[1], scores[1] / scores[0])
109 | if score_ratio < 2:
110 | continue
111 |
112 | # according to https://huggingface.co/datasets/stanfordnlp/SHP
113 | data[prompt]['pairs'].append(
114 | (n_responses, n_responses + 1) if row['labels'] == 1 else (n_responses + 1, n_responses))
115 | data[prompt]['responses'].extend(responses)
116 | data[prompt]['scores'].extend(scores)
117 |
118 | for prompt in data:
119 | data[prompt]['sft_target'] = max(data[prompt]['responses'],
120 | key=lambda x: data[prompt]['scores'][data[prompt]['responses'].index(x)])
121 | del data[prompt]['scores']
122 |
123 | return data
124 |
125 |
126 | def get_hh(split: str, silent: bool = False, cache_dir: str = None) -> Dict[
127 | str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]:
128 | """Load the Anthropic Helpful-Harmless dataset from Huggingface and convert it to the necessary format.
129 |
130 | The dataset is converted to a dictionary with the following structure:
131 | {
132 | 'prompt1': {
133 | 'responses': List[str],
134 | 'pairs': List[Tuple[int, int]],
135 | 'sft_target': str
136 | },
137 | 'prompt2': {
138 | ...
139 | },
140 | }
141 |
142 | Prompts should be structured as follows:
143 | \n\nHuman: <prompt>\n\nAssistant:
144 | Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
145 |
146 | For this dataset, the sft_target is just the chosen response.
147 | """
148 | print(f'Loading HH dataset ({split} split) from Huggingface...')
149 | dataset = datasets.load_dataset('Anthropic/hh-rlhf', split=split, cache_dir=cache_dir)
150 | print('done')
151 |
152 | def split_prompt_and_responses(ex):
153 | prompt = extract_anthropic_prompt(ex['chosen'])
154 | chosen_response = ex['chosen'][len(prompt):]
155 | rejected_response = ex['rejected'][len(prompt):]
156 | return prompt, chosen_response, rejected_response
157 |
158 | data = defaultdict(lambda: defaultdict(list))
159 | for row in tqdm.tqdm(dataset, desc='Processing HH', disable=silent):
160 | prompt, chosen, rejected = split_prompt_and_responses(row)
161 | responses = [chosen, rejected]
162 | n_responses = len(data[prompt]['responses'])
163 | data[prompt]['pairs'].append((n_responses, n_responses + 1))
164 | data[prompt]['responses'].extend(responses)
165 | data[prompt]['sft_target'] = chosen
166 |
167 | return data
168 |
169 |
170 | def get_dataset(name: str, split: str, silent: bool = False, cache_dir: str = None):
171 | """Load the given dataset by name. Supported by default are 'shp', 'hh', and 'se'."""
172 | if name == 'shp':
173 | data = get_shp(split, silent=silent, cache_dir=cache_dir)
174 | elif name == 'hh':
175 | data = get_hh(split, silent=silent, cache_dir=cache_dir)
176 | elif name == 'se':
177 | data = get_se(split, silent=silent, cache_dir=cache_dir)
178 | else:
179 | raise ValueError(f"Unknown dataset '{name}'")
180 |
181 | assert set(list(data.values())[0].keys()) == {'responses', 'pairs', 'sft_target'}, \
182 | f"Unexpected keys in dataset: {list(list(data.values())[0].keys())}"
183 |
184 | return data
185 |
186 |
187 | def get_collate_fn(tokenizer) -> Callable[[List[Dict]], Dict[str, Union[List, torch.Tensor]]]:
188 | """Returns a collate function for the given tokenizer.
189 |
190 | The collate function takes a list of examples (dicts, where values are lists of
191 | ints [tokens] or strings [the original texts]) and returns a batch of examples,
192 | PyTorch tensors padded to the maximum length. Strings are passed through."""
193 |
194 | def collate_fn(batch):
195 | # first, pad everything to the same length
196 | padded_batch = {}
197 | for k in batch[0].keys():
198 | if k.endswith('_input_ids') or k.endswith('_attention_mask') or k.endswith('_labels'):
199 | if 'prompt' in k: # adapted from https://stackoverflow.com/questions/73256206
200 | to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
201 | else:
202 | to_pad = [torch.LongTensor(ex[k]) for ex in batch]
203 | if k.endswith('_input_ids'):
204 | padding_value = tokenizer.pad_token_id
205 | elif k.endswith('_labels'):
206 | padding_value = -100
207 | elif k.endswith('_attention_mask'):
208 | padding_value = 0
209 | else:
210 | raise ValueError(f"Unexpected key in batch '{k}'")
211 |
212 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
213 | if 'prompt' in k: # for the prompt, flip back so padding is on left side
214 | padded_batch[k] = padded_batch[k].flip(dims=[1])
215 | else:
216 | padded_batch[k] = [ex[k] for ex in batch]
217 |
218 | return padded_batch
219 |
220 | return collate_fn
221 |
222 |
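# Illustrative note (not from the original source): because prompt_* sequences are reversed
# before padding and flipped back afterwards, prompts end up left-padded while
# chosen_*/rejected_* sequences are right-padded. E.g. for prompt token lists
#     [[5, 6, 7], [8, 9]]  with pad_token_id=0
# collate_fn produces prompt_input_ids
#     [[5, 6, 7], [0, 8, 9]]
# so the prompt tokens are right-aligned and generation can continue directly from them.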
223 | def tokenize_batch_element(prompt: str, chosen: str, rejected: str, truncation_mode: str, tokenizer, max_length: int,
224 | max_prompt_length: int) -> Dict:
225 | """Tokenize a single batch element.
226 |
227 | At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
228 | in case the prompt + chosen or prompt + rejected responses is/are too long. First
229 | we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
230 |
231 | We also create the labels for the chosen/rejected responses, which are of length equal to
232 | the sum of the length of the prompt and the chosen/rejected response, with -100 for the
233 | prompt tokens.
234 | """
235 | chosen_tokens = tokenizer(chosen, add_special_tokens=False)
236 | rejected_tokens = tokenizer(rejected, add_special_tokens=False)
237 | prompt_tokens = tokenizer(prompt, add_special_tokens=False)
238 |
239 | assert tokenizer.eos_token_id not in prompt_tokens['input_ids'], f"Prompt contains EOS token: {prompt}"
240 | assert tokenizer.eos_token_id not in chosen_tokens['input_ids'], f"Chosen response contains EOS token: {chosen}"
241 | assert tokenizer.eos_token_id not in rejected_tokens['input_ids'], f"Rejected response contains EOS token: {rejected}"
242 |
243 | chosen_tokens['input_ids'].append(tokenizer.eos_token_id)
244 | chosen_tokens['attention_mask'].append(1)
245 |
246 | rejected_tokens['input_ids'].append(tokenizer.eos_token_id)
247 | rejected_tokens['attention_mask'].append(1)
248 |
249 | longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids']))
250 |
251 | # if combined sequence is too long, truncate the prompt
252 | if len(prompt_tokens['input_ids']) + longer_response_length > max_length:
253 | if truncation_mode == 'keep_start':
254 | prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()}
255 | elif truncation_mode == 'keep_end':
256 | prompt_tokens = {k: v[-max_prompt_length:] for k, v in prompt_tokens.items()}
257 | else:
258 | raise ValueError(f'Unknown truncation mode: {truncation_mode}')
259 |
260 | # if that's still too long, truncate the response
261 | if len(prompt_tokens['input_ids']) + longer_response_length > max_length:
262 | chosen_tokens = {k: v[:max_length - max_prompt_length] for k, v in chosen_tokens.items()}
263 | rejected_tokens = {k: v[:max_length - max_prompt_length] for k, v in rejected_tokens.items()}
264 |
265 | # Create labels
266 | chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}
267 | rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}
268 | chosen_sequence_tokens['labels'] = chosen_sequence_tokens['input_ids'][:]
269 | chosen_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids'])
270 | rejected_sequence_tokens['labels'] = rejected_sequence_tokens['input_ids'][:]
271 | rejected_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids'])
272 |
273 | batch = {}
274 |
275 | batch['prompt'] = prompt
276 | batch['chosen'] = prompt + chosen
277 | batch['rejected'] = prompt + rejected
278 | batch['chosen_response_only'] = chosen
279 | batch['rejected_response_only'] = rejected
280 |
281 | for k, toks in {'chosen': chosen_sequence_tokens, 'rejected': rejected_sequence_tokens,
282 | 'prompt': prompt_tokens}.items():
283 | for type_key, tokens in toks.items():
284 | if type_key == 'token_type_ids':
285 | continue
286 | batch[f'{k}_{type_key}'] = tokens
287 |
288 | return batch
289 |
290 |
291 | def get_batch_iterator(names: List[str],
292 | tokenizer,
293 | split: str = 'train',
294 | batch_size: int = 1,
295 | shuffle: bool = True,
296 | max_length: int = 512,
297 | max_prompt_length: int = 128,
298 | sft_mode: bool = False,
299 | n_epochs: Optional[int] = None,
300 | n_examples: Optional[int] = None,
301 | seed: int = 0,
302 | silent: bool = False,
303 | cache_dir: Optional[str] = None) -> Iterator[Dict]:
304 | """Get an iterator over batches of data. Stops after n_epochs or n_examples, whichever comes first.
305 |
306 | Args:
307 | names: Names of datasets to use.
308 | tokenizer: Tokenizer to use.
309 | split: Which split to use.
310 | batch_size: Batch size.
311 | shuffle: Whether to shuffle the data after each epoch.
312 | max_length: Maximum length of the combined prompt + response.
313 | max_prompt_length: Maximum length of the prompt.
314 | sft_mode: Whether to use SFT mode (i.e., return sft_target instead of chosen/rejected). In sft mode, we just return chosen_input_ids, but they contain the sft_target.
315 | n_epochs: Number of epochs to run for. This or n_examples must be specified.
316 | n_examples: Number of examples to run for. This or n_epochs must be specified.
317 | seed: Random seed.
318 | silent: Whether to silence the progress bar(s).
319 | cache_dir: Directory to cache the datasets in.
320 | """
321 | assert n_epochs is not None or n_examples is not None, "Must specify either n_epochs or n_examples"
322 | if silent:
323 | datasets.logging.disable_progress_bar()
324 | datasets.logging.set_verbosity_error()
325 |
326 | with TemporarilySeededRandom(seed):
327 | permutation_seeds = iter(np.random.randint(0, 2 ** 31, size=1000000))
328 | flat_data = []
329 | for name in names:
330 | truncation_mode = 'keep_end' if name == 'hh' else 'keep_start'
331 | for prompt, data in get_dataset(name, split, silent=silent, cache_dir=cache_dir).items():
332 | flat_data.append((prompt, data['responses'], data['pairs'], data['sft_target'], truncation_mode))
333 |
334 | collate_fn = get_collate_fn(tokenizer)
335 |
336 | epoch_idx = 0
337 | example_idx = 0
338 | done = False
339 | while True:
340 | if n_epochs is not None and epoch_idx >= n_epochs:
341 | if not silent:
342 | print(f'Finished generating {n_epochs} epochs on {split} split')
343 | break
344 | if shuffle:
345 | with TemporarilySeededRandom(next(permutation_seeds)):
346 | random.shuffle(flat_data)
347 |
348 | batch = []
349 | for prompt, responses, pairs, sft_target, truncation_mode in flat_data:
350 | if done:
351 | break
352 | if sft_mode:
353 | batch_element = tokenize_batch_element(prompt, sft_target, sft_target, truncation_mode, tokenizer, max_length, max_prompt_length)
354 | batch_element = {k: v for k, v in batch_element.items() if 'rejected' not in k}
355 | batch.append(batch_element)
356 | example_idx += 1
357 | if len(batch) == batch_size:
358 | yield collate_fn(batch)
359 | if n_examples is not None and example_idx >= n_examples:
360 | if not silent:
361 | print(f'Finished generating {n_examples} examples on {split} split')
362 | done = True
363 |
364 | batch = []
365 | else:
366 | for p in pairs:
367 | if done:
368 | break
369 | batch_element = tokenize_batch_element(prompt, responses[p[0]], responses[p[1]], truncation_mode, tokenizer, max_length, max_prompt_length)
370 | batch.append(batch_element)
371 | example_idx += 1
372 | if len(batch) == batch_size:
373 | yield collate_fn(batch)
374 | if n_examples is not None and example_idx >= n_examples:
375 | if not silent:
376 | print(f'FINISHED {n_examples} EXAMPLES on {split} split')
377 | done = True
378 | batch = []
379 | if done:
380 | break
381 |
382 | epoch_idx += 1
383 |
384 |
385 | def strings_match_up_to_spaces(str_a: str, str_b: str) -> bool:
386 | """Returns True if str_a and str_b match up to spaces, False otherwise."""
387 | for idx in range(min(len(str_a), len(str_b)) - 2):
388 | if str_a[idx] != str_b[idx]:
389 | if str_a[idx] != ' ' and str_b[idx] != ' ':
390 | return False
391 | else:
392 | if str_a[idx] == ' ':
393 | str_a = str_a[:idx] + str_a[idx + 1:]
394 | else:
395 | str_b = str_b[:idx] + str_b[idx + 1:]
396 |
397 | return True
--------------------------------------------------------------------------------
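Before moving on to the trainers, here is a hedged sketch (not part of the repo) of how a single preference example flows through `tokenize_batch_element` and `get_collate_fn` above. It assumes the repository root is on the import path and that the `gpt2` tokenizer can be downloaded; the prompt and response strings are made up.

```python
# Illustrative only: tokenize one preference pair and collate it into a batch of size 1.
import transformers
from preference_datasets import tokenize_batch_element, get_collate_fn

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token_id = tokenizer.eos_token_id  # gpt2 ships without a pad token

element = tokenize_batch_element(
    prompt='\n\nHuman: What is 2 + 2?\n\nAssistant:',
    chosen=' 2 + 2 = 4.',
    rejected=' Five.',
    truncation_mode='keep_end',   # the mode used for the hh dataset
    tokenizer=tokenizer,
    max_length=512,
    max_prompt_length=256,
)
batch = get_collate_fn(tokenizer)([element])

# prompt_* tensors are left-padded, chosen_*/rejected_* right-padded; labels mask the
# prompt positions with -100 so only response tokens contribute to the loss.
print(batch['chosen_input_ids'].shape)
print(batch['chosen_labels'][0][:8])
```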
/trainers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | torch.backends.cuda.matmul.allow_tf32 = True
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | import transformers
7 | from omegaconf import DictConfig
8 |
9 | import torch.distributed as dist
10 | from torch.distributed.fsdp import (
11 | FullyShardedDataParallel as FSDP,
12 | MixedPrecision,
13 | StateDictType,
14 | BackwardPrefetch,
15 | ShardingStrategy,
16 | CPUOffload,
17 | )
18 | from torch.distributed.fsdp.api import FullStateDictConfig, FullOptimStateDictConfig
19 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
20 | import tensor_parallel as tp
21 | import contextlib
22 |
23 | from preference_datasets import get_batch_iterator
24 | from utils import (
25 | slice_and_move_batch_for_device,
26 | formatted_dict,
27 | all_gather_if_needed,
28 | pad_to_length,
29 | get_block_class_from_model,
30 | rank0_print,
31 | get_local_dir,
32 | )
33 | import numpy as np
34 | import wandb
35 | import tqdm
36 |
37 | import random
38 | import os
39 | from collections import defaultdict
40 | import time
41 | import json
42 | import functools
43 | from typing import Optional, Dict, List, Union, Tuple
44 |
45 |
46 | def tdpo_loss(chosen_logps_margin: torch.FloatTensor,
47 | rejected_logps_margin: torch.FloatTensor,
48 | chosen_position_kl: torch.FloatTensor,
49 | rejected_position_kl: torch.FloatTensor,
50 | beta: float, alpha: float = 0.5, if_tdpo2: bool = True) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
51 | """Compute the TDPO loss for a batch of policy and reference model log probabilities.
52 |
53 | Args:
54 | chosen_logps_margin: The difference of log probabilities between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
55 | rejected_logps_margin: The difference of log probabilities between the policy model and the reference model for the rejected responses. Shape: (batch_size,)
56 | chosen_position_kl: The sequential KL divergence between the reference model and the policy model for the chosen responses. Shape: (batch_size,)
57 | rejected_position_kl: The sequential KL divergence between the reference model and the policy model for the rejected responses. Shape: (batch_size,)
58 | beta: Temperature parameter for the TDPO loss, typically in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.
59 | alpha: Weight of the sequential KL divergence term; only used by TDPO2.
60 | if_tdpo2: Whether to use TDPO2 (default True); if False, use TDPO1.
61 |
62 | Returns:
63 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
64 | The losses tensor contains the TDPO loss for each example in the batch.
65 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
66 | """
67 |
68 | chosen_values = chosen_logps_margin + chosen_position_kl
69 | rejected_values = rejected_logps_margin + rejected_position_kl
70 |
71 | chosen_rejected_logps_margin = chosen_logps_margin - rejected_logps_margin
72 |
73 |
74 | if not if_tdpo2:
75 | logits = chosen_rejected_logps_margin - (rejected_position_kl - chosen_position_kl) # tdpo1
76 | else:
77 | logits = chosen_rejected_logps_margin - alpha * (rejected_position_kl - chosen_position_kl.detach()) # tdpo2
78 | losses = -F.logsigmoid(beta * logits)
79 |
80 | chosen_rewards = beta * chosen_values.detach()
81 | rejected_rewards = beta * rejected_values.detach()
82 |
83 | return losses, chosen_rewards, rejected_rewards
84 |
85 |
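# Illustrative note (not from the original source): with hypothetical values
#     chosen_logps_margin - rejected_logps_margin = 0.6,
#     chosen_position_kl = 0.3, rejected_position_kl = 0.6, alpha = 0.5,
# TDPO1 gives logits = 0.6 - (0.6 - 0.3) = 0.30, while TDPO2 gives
# logits = 0.6 - 0.5 * (0.6 - 0.3) = 0.45; TDPO2 also detaches the chosen-response KL,
# so gradients flow through the rejected-response KL but not the chosen one.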
86 | def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor,
87 | average_log_prob: bool = False) -> torch.FloatTensor:
88 | """Compute the log probabilities of the given labels under the given logits.
89 |
90 | Args:
91 | logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
92 | labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length)
93 | average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
94 |
95 | Returns:
96 | A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
97 | """
98 | assert logits.shape[:-1] == labels.shape
99 |
100 | labels = labels[:, 1:].clone()
101 | logits = logits[:, :-1, :]
102 | loss_mask = (labels != -100)
103 |
104 | # dummy token; we'll ignore the losses on these tokens later
105 | labels[labels == -100] = 0
106 |
107 | per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
108 |
109 | if average_log_prob:
110 | return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
111 | else:
112 | return (per_token_logps * loss_mask).sum(-1)
113 |
114 |
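# Illustrative note (not from the original source): the slicing in _get_batch_logps above
# (and repeated in _tdpo_get_batch_logps below) aligns next-token prediction — logits at
# position t score the label at position t+1 — so labels are shifted left by one and the
# final logit column is dropped before gathering per-token log probabilities.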
115 | def _tdpo_get_batch_logps(logits: torch.FloatTensor, reference_logits: torch.FloatTensor, labels: torch.LongTensor,
116 | average_log_prob: bool = False):
117 | """Compute the kl divergence/log probabilities of the given labels under the given logits.
118 |
119 | Args:
120 | logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
121 | reference_logits: Logits of the reference model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
122 | labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length)
123 | average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
124 |
125 | Returns:
126 | Three tensors of shape (batch_size,): the policy-vs-reference log-probability margins, the sequential KL divergences, and the policy log probabilities of the given labels, each averaged or summed over the non-masked tokens.
127 | """
128 | assert logits.shape[:-1] == labels.shape
129 | assert reference_logits.shape[:-1] == labels.shape
130 |
131 | labels = labels[:, 1:].clone()
132 | logits = logits[:, :-1, :]
133 | reference_logits = reference_logits[:, :-1, :]
134 |
135 | loss_mask = (labels != -100)
136 |
137 | # dummy token; we'll ignore the losses on these tokens later
138 | labels[labels == -100] = 0
139 |
140 | vocab_logps = logits.log_softmax(-1)
141 |
142 | reference_vocab_ps = reference_logits.softmax(-1)
143 | reference_vocab_logps = reference_vocab_ps.log()
144 |
145 | per_position_kl = (reference_vocab_ps * (reference_vocab_logps - vocab_logps)).sum(-1)
146 | per_token_logps = torch.gather(vocab_logps, dim=2, index=labels.unsqueeze(2)).squeeze(2)
147 | per_reference_token_logps = torch.gather(reference_vocab_logps, dim=2, index=labels.unsqueeze(2)).squeeze(2)
148 |
149 | logps_margin = per_token_logps - per_reference_token_logps
150 |
151 | if average_log_prob:
152 | return (logps_margin * loss_mask).sum(-1) / loss_mask.sum(-1), \
153 | (per_position_kl * loss_mask).sum(-1) / loss_mask.sum(-1), \
154 | (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
155 | else:
156 | return (logps_margin * loss_mask).sum(-1), \
157 | (per_position_kl * loss_mask).sum(-1), \
158 | (per_token_logps * loss_mask).sum(-1)
159 |
160 |
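# Illustrative note (not from the original source): per_position_kl above is
# KL(pi_ref(. | x, y_<t) || pi_theta(. | x, y_<t)) computed at every response position t
# and then summed (or averaged) over the non-masked positions; this is the sequential KL
# divergence that tdpo_loss consumes as chosen_position_kl / rejected_position_kl.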
161 | def concatenated_inputs(batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
162 | """Concatenate the chosen and rejected inputs into a single tensor.
163 |
164 | Args:
165 | batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
166 |
167 | Returns:
168 | A dictionary containing the concatenated inputs under 'concatenated_' keys (e.g. 'concatenated_input_ids', 'concatenated_attention_mask', 'concatenated_labels').
169 | """
170 | max_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1])
171 | concatenated_batch = {}
172 | for k in batch:
173 | if k.startswith('chosen') and isinstance(batch[k], torch.Tensor):
174 | pad_value = -100 if 'labels' in k else 0
175 | concatenated_key = k.replace('chosen', 'concatenated')
176 | concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
177 | for k in batch:
178 | if k.startswith('rejected') and isinstance(batch[k], torch.Tensor):
179 | pad_value = -100 if 'labels' in k else 0
180 | concatenated_key = k.replace('rejected', 'concatenated')
181 | concatenated_batch[concatenated_key] = torch.cat((
182 | concatenated_batch[concatenated_key],
183 | pad_to_length(batch[k], max_length, pad_value=pad_value),
184 | ), dim=0)
185 | return concatenated_batch
186 |
187 |
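# Illustrative note (not from the original source): for chosen_input_ids of shape (B, Lc)
# and rejected_input_ids of shape (B, Lr), the concatenated_* tensors have shape
# (2B, max(Lc, Lr)); rows [:B] are the chosen sequences and rows [B:] the rejected ones,
# which is how tdpo_concatenated_forward splits the model outputs below.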
188 | class BasicTrainer(object):
189 | def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str,
190 | reference_model: Optional[nn.Module] = None, rank: int = 0, world_size: int = 1):
191 | """A trainer for a language model, supporting either SFT or TDPO training.
192 |
193 | If multiple GPUs are present, naively splits the model across them, effectively
194 | offering N times available memory, but without any parallel computation.
195 | """
196 | self.seed = seed
197 | self.rank = rank
198 | self.world_size = world_size
199 | self.config = config
200 | self.run_dir = run_dir
201 |
202 | tokenizer_name_or_path = config.model.tokenizer_name_or_path or config.model.name_or_path
203 | rank0_print(f'Loading tokenizer {tokenizer_name_or_path}')
204 | self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name_or_path,
205 | cache_dir=get_local_dir(config.local_dirs))
206 | if self.tokenizer.pad_token_id is None:
207 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
208 |
209 | data_iterator_kwargs = dict(
210 | names=config.datasets,
211 | tokenizer=self.tokenizer,
212 | shuffle=True,
213 | max_length=config.max_length,
214 | max_prompt_length=config.max_prompt_length,
215 | sft_mode=config.loss.name == 'sft',
216 | )
217 |
218 | self.policy = policy
219 | self.reference_model = reference_model
220 |
221 | self.train_iterator = get_batch_iterator(**data_iterator_kwargs, split='train', n_epochs=config.n_epochs,
222 | n_examples=config.n_examples, batch_size=config.batch_size,
223 | silent=rank != 0, cache_dir=get_local_dir(config.local_dirs))
224 | rank0_print(f'Loaded train data iterator')
225 | self.eval_iterator = get_batch_iterator(**data_iterator_kwargs, split='test', n_examples=config.n_eval_examples,
226 | batch_size=config.eval_batch_size, silent=rank != 0,
227 | cache_dir=get_local_dir(config.local_dirs))
228 | self.eval_batches = list(self.eval_iterator)
229 | rank0_print(f'Loaded {len(self.eval_batches)} eval batches of size {config.eval_batch_size}')
230 |
231 | def get_batch_samples(self, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
232 | """Generate samples from the policy (and reference model, if doing TDPO training) for the given batch of inputs."""
233 |
234 | # FSDP generation according to https://github.com/pytorch/pytorch/issues/100069
235 | ctx = lambda: (FSDP.summon_full_params(self.policy, writeback=False,
236 | recurse=False) if 'FSDP' in self.config.trainer else contextlib.nullcontext())
237 | with ctx():
238 | policy_output = self.policy.generate(
239 | batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'],
240 | max_length=self.config.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id)
241 |
242 | if self.config.loss.name == 'tdpo':
243 | ctx = lambda: (FSDP.summon_full_params(self.reference_model, writeback=False,
244 | recurse=False) if 'FSDP' in self.config.trainer else contextlib.nullcontext())
245 | with ctx():
246 | reference_output = self.reference_model.generate(
247 | batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'],
248 | max_length=self.config.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id)
249 |
250 | policy_output = pad_to_length(policy_output, self.config.max_length, self.tokenizer.pad_token_id)
251 | policy_output = all_gather_if_needed(policy_output, self.rank, self.world_size)
252 | policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
253 |
254 | if self.config.loss.name == 'tdpo':
255 | reference_output = pad_to_length(reference_output, self.config.max_length, self.tokenizer.pad_token_id)
256 | reference_output = all_gather_if_needed(reference_output, self.rank, self.world_size)
257 | reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)
258 | else:
259 | reference_output_decoded = []
260 |
261 | return policy_output_decoded, reference_output_decoded
262 |
263 | def tdpo_concatenated_forward(self, model: nn.Module, reference_model: nn.Module,
264 | batch: Dict[str, Union[List, torch.LongTensor]]):
265 | """Run the policy model and the reference model on the given batch of inputs, concatenating the chosen and rejected inputs together.
266 |
267 | We do this to avoid doing two forward passes, because it's faster for FSDP.
268 | """
269 | concatenated_batch = concatenated_inputs(batch)
270 | all_logits = model(concatenated_batch['concatenated_input_ids'],
271 | attention_mask=concatenated_batch['concatenated_attention_mask']).logits.to(torch.float32)
272 | with torch.no_grad():
273 | reference_all_logits = reference_model(concatenated_batch['concatenated_input_ids'],
274 | attention_mask=concatenated_batch[
275 | 'concatenated_attention_mask']).logits.to(torch.float32)
276 | all_logps_margin, all_position_kl, all_logps = _tdpo_get_batch_logps(all_logits, reference_all_logits, concatenated_batch['concatenated_labels'], average_log_prob=False)
277 |
278 | chosen_logps_margin = all_logps_margin[:batch['chosen_input_ids'].shape[0]]
279 | rejected_logps_margin = all_logps_margin[batch['chosen_input_ids'].shape[0]:]
280 | chosen_position_kl = all_position_kl[:batch['chosen_input_ids'].shape[0]]
281 | rejected_position_kl = all_position_kl[batch['chosen_input_ids'].shape[0]:]
282 |
283 | chosen_logps = all_logps[:batch['chosen_input_ids'].shape[0]].detach()
284 | rejected_logps = all_logps[batch['chosen_input_ids'].shape[0]:].detach()
285 |
286 | return chosen_logps_margin, rejected_logps_margin, chosen_position_kl, rejected_position_kl, \
287 | chosen_logps, rejected_logps
288 |
289 | def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True):
290 | """Compute the SFT or TDPO loss and other metrics for the given batch of inputs."""
291 |
292 | metrics = {}
293 | train_test = 'train' if train else 'eval'
294 |
295 | if loss_config.name == 'tdpo':
296 | chosen_logps_margin, rejected_logps_margin, chosen_position_kl, rejected_position_kl, policy_chosen_logps, policy_rejected_logps\
297 | = self.tdpo_concatenated_forward(self.policy, self.reference_model, batch)
298 | losses, chosen_rewards, rejected_rewards = tdpo_loss(chosen_logps_margin, rejected_logps_margin,
299 | chosen_position_kl, rejected_position_kl,
300 | beta=loss_config.beta, alpha=loss_config.alpha, if_tdpo2=loss_config.if_tdpo2)
301 |
302 | reward_accuracies = (chosen_rewards > rejected_rewards).float()
303 |
304 | chosen_rewards = all_gather_if_needed(chosen_rewards, self.rank, self.world_size)
305 | rejected_rewards = all_gather_if_needed(rejected_rewards, self.rank, self.world_size)
306 | reward_accuracies = all_gather_if_needed(reward_accuracies, self.rank, self.world_size)
307 |
308 | metrics[f'rewards_{train_test}/chosen'] = chosen_rewards.cpu().numpy().tolist()
309 | metrics[f'rewards_{train_test}/rejected'] = rejected_rewards.cpu().numpy().tolist()
310 | metrics[f'rewards_{train_test}/accuracies'] = reward_accuracies.cpu().numpy().tolist()
311 | metrics[f'rewards_{train_test}/margins'] = (chosen_rewards - rejected_rewards).cpu().numpy().tolist()
312 |
313 | all_device_chosen_position_kl = all_gather_if_needed(chosen_position_kl.detach(), self.rank, self.world_size)
314 | all_device_rejected_position_kl = all_gather_if_needed(rejected_position_kl.detach(), self.rank, self.world_size)
315 |
316 | metrics[f'kl_{train_test}/chosen'] = all_device_chosen_position_kl.cpu().numpy().tolist()
317 | metrics[f'kl_{train_test}/rejected'] = all_device_rejected_position_kl.cpu().numpy().tolist()
318 | metrics[f'kl_{train_test}/margin'] = (all_device_chosen_position_kl - all_device_rejected_position_kl).cpu().numpy().tolist()
319 |
320 | policy_rejected_logps = all_gather_if_needed(policy_rejected_logps.detach(), self.rank, self.world_size)
321 | metrics[f'logps_{train_test}/rejected'] = policy_rejected_logps.cpu().numpy().tolist()
322 |
323 | elif loss_config.name == 'sft':
324 | policy_chosen_logits = self.policy(batch['chosen_input_ids'],
325 | attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32)
326 | policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'], average_log_prob=False)
327 |
328 | losses = -policy_chosen_logps
329 |
330 | policy_chosen_logps = all_gather_if_needed(policy_chosen_logps.detach(), self.rank, self.world_size)
331 | metrics[f'logps_{train_test}/chosen'] = policy_chosen_logps.cpu().numpy().tolist()
332 |
333 | all_devices_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size)
334 | metrics[f'loss/{train_test}'] = all_devices_losses.cpu().numpy().tolist()
335 |
336 | return losses.mean(), metrics
337 |
338 | def train(self):
339 | """Begin either SFT or TDPO training, with periodic evaluation."""
340 |
341 | rank0_print(f'Using {self.config.optimizer} optimizer')
342 | self.optimizer = getattr(torch.optim, self.config.optimizer)(self.policy.parameters(), lr=self.config.lr)
343 | self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda step: min(1.0,
344 | (step + 1) / (
345 | self.config.warmup_steps + 1)))
346 |
347 | torch.manual_seed(self.seed)
348 | np.random.seed(self.seed)
349 | random.seed(self.seed)
350 |
351 | if self.config.loss.name == 'tdpo':
352 | self.reference_model.eval()
353 |
354 | self.example_counter = 0
355 | self.batch_counter = 0
356 | last_log = None
357 |
358 | for batch in self.train_iterator:
359 | #### BEGIN EVALUATION ####
360 | if self.example_counter % self.config.eval_every == 0 and (
361 | self.example_counter > 0 or self.config.do_first_eval):
362 | rank0_print(f'Running evaluation after {self.example_counter} train examples')
363 | self.policy.eval()
364 |
365 | all_eval_metrics = defaultdict(list)
366 | if self.config.sample_during_eval:
367 | all_policy_samples, all_reference_samples = [], []
368 | policy_text_table = wandb.Table(columns=["step", "prompt", "sample"])
369 | if self.config.loss.name == 'tdpo':
370 | reference_text_table = wandb.Table(columns=["step", "prompt", "sample"])
371 |
372 | for eval_batch in (
373 | tqdm.tqdm(self.eval_batches, desc='Computing eval metrics') if self.rank == 0 else self.eval_batches):
374 | local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank, self.world_size,
375 | self.rank)
376 | with torch.no_grad():
377 | _, eval_metrics = self.get_batch_metrics(local_eval_batch, self.config.loss, train=False)
378 |
379 | for k, v in eval_metrics.items():
380 | all_eval_metrics[k].extend(v)
381 |
382 | if self.config.sample_during_eval:
383 | if self.config.n_eval_model_samples < self.config.eval_batch_size:
384 | rank0_print(
385 | f'Warning: n_eval_model_samples ({self.config.n_eval_model_samples}) < eval_batch_size ({self.config.eval_batch_size}). Sampling from the first complete eval batch of prompts.')
386 | sample_batches = self.eval_batches[:1]
387 | else:
388 | n_sample_batches = self.config.n_eval_model_samples // self.config.eval_batch_size
389 | sample_batches = self.eval_batches[:n_sample_batches]
390 | for eval_batch in (
391 | tqdm.tqdm(sample_batches, desc='Generating samples...') if self.rank == 0 else sample_batches):
392 | local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank, self.world_size,
393 | self.rank)
394 | policy_samples, reference_samples = self.get_batch_samples(local_eval_batch)
395 |
396 | all_policy_samples.extend(policy_samples)
397 | all_reference_samples.extend(reference_samples)
398 |
399 | for prompt, sample in zip(eval_batch['prompt'], policy_samples):
400 | policy_text_table.add_data(self.example_counter, prompt, sample)
401 | if self.config.loss.name == 'tdpo':
402 | for prompt, sample in zip(eval_batch['prompt'], reference_samples):
403 | reference_text_table.add_data(self.example_counter, prompt, sample)
404 |
405 | mean_eval_metrics = {k: sum(v) / len(v) for k, v in all_eval_metrics.items()}
406 | rank0_print(f'eval after {self.example_counter}: {formatted_dict(mean_eval_metrics)}')
407 | if self.config.sample_during_eval:
408 | rank0_print(json.dumps(all_policy_samples[:10], indent=2))
409 | if self.config.loss.name == 'tdpo':
410 | rank0_print(json.dumps(all_reference_samples[:10], indent=2))
411 |
412 | if self.config.wandb.enabled and self.rank == 0:
413 | wandb.log(mean_eval_metrics, step=self.example_counter)
414 |
415 | if self.config.sample_during_eval:
416 | wandb.log({"policy_samples": policy_text_table}, step=self.example_counter)
417 | if self.config.loss.name == 'tdpo':
418 | wandb.log({"reference_samples": reference_text_table}, step=self.example_counter)
419 |
420 | if self.example_counter > 0:
421 | if self.config.debug:
422 | rank0_print('skipping save in debug mode')
423 | else:
424 | output_dir = os.path.join(self.run_dir, f'step-{self.example_counter}')
425 | rank0_print(f'creating checkpoint to write to {output_dir}...')
426 | self.save(output_dir, mean_eval_metrics)
427 | #### END EVALUATION ####
428 |
429 | #### BEGIN TRAINING ####
430 | self.policy.train()
431 |
432 | start_time = time.time()
433 | batch_metrics = defaultdict(list)
434 | for microbatch_idx in range(self.config.gradient_accumulation_steps):
435 | global_microbatch = slice_and_move_batch_for_device(batch, microbatch_idx,
436 | self.config.gradient_accumulation_steps, self.rank)
437 | local_microbatch = slice_and_move_batch_for_device(global_microbatch, self.rank, self.world_size,
438 | self.rank)
439 | loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True)
440 | (loss / self.config.gradient_accumulation_steps).backward()
441 |
442 | for k, v in metrics.items():
443 | batch_metrics[k].extend(v)
444 |
445 | grad_norm = self.clip_gradient()
446 | self.optimizer.step()
447 | self.scheduler.step()
448 | self.optimizer.zero_grad()
449 |
450 | step_time = time.time() - start_time
451 | examples_per_second = self.config.batch_size / step_time
452 | batch_metrics['examples_per_second'].append(examples_per_second)
453 | batch_metrics['grad_norm'].append(grad_norm)
454 |
455 | self.batch_counter += 1
456 | self.example_counter += self.config.batch_size
457 |
458 | if last_log is None or time.time() - last_log > self.config.minimum_log_interval_secs:
459 | mean_train_metrics = {k: sum(v) / len(v) for k, v in batch_metrics.items()}
460 | mean_train_metrics['counters/examples'] = self.example_counter
461 | mean_train_metrics['counters/updates'] = self.batch_counter
462 | rank0_print(f'train stats after {self.example_counter} examples: {formatted_dict(mean_train_metrics)}')
463 |
464 | if self.config.wandb.enabled and self.rank == 0:
465 | wandb.log(mean_train_metrics, step=self.example_counter)
466 |
467 | last_log = time.time()
468 | else:
469 | rank0_print(f'skipping logging after {self.example_counter} examples to avoid logging too frequently')
470 | #### END TRAINING ####
471 |
472 | def clip_gradient(self):
473 | """Clip the gradient norm of the parameters of a non-FSDP policy."""
474 | return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.max_grad_norm).item()
475 |
476 | def write_state_dict(self, step: int, state: Dict[str, torch.Tensor], metrics: Dict, filename: str,
477 | dir_name: Optional[str] = None):
478 | """Write a checkpoint to disk."""
479 | if dir_name is None:
480 | dir_name = os.path.join(self.run_dir, f'LATEST')
481 |
482 | os.makedirs(dir_name, exist_ok=True)
483 | output_path = os.path.join(dir_name, filename)
484 | rank0_print(f'writing checkpoint to {output_path}...')
485 | torch.save({
486 | 'step_idx': step,
487 | 'state': state,
488 | 'metrics': metrics if metrics is not None else {},
489 | }, output_path)
490 |
491 | def save(self, output_dir: Optional[str] = None, metrics: Optional[Dict] = None):
492 | """Save policy, optimizer, and scheduler state to disk."""
493 |
494 | policy_state_dict = self.policy.state_dict()
495 | self.write_state_dict(self.example_counter, policy_state_dict, metrics, 'policy.pt', output_dir)
496 | del policy_state_dict
497 |
498 | optimizer_state_dict = self.optimizer.state_dict()
499 | self.write_state_dict(self.example_counter, optimizer_state_dict, metrics, 'optimizer.pt', output_dir)
500 | del optimizer_state_dict
501 |
502 | scheduler_state_dict = self.scheduler.state_dict()
503 | self.write_state_dict(self.example_counter, scheduler_state_dict, metrics, 'scheduler.pt', output_dir)
504 |
505 |
506 | class FSDPTrainer(BasicTrainer):
507 | def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str,
508 | reference_model: Optional[nn.Module] = None, rank: int = 0, world_size: int = 1):
509 | """A trainer subclass that uses PyTorch FSDP to shard the model across multiple GPUs.
510 |
511 | This trainer will shard both the policy and reference model across all available GPUs.
512 | Models are sharded at the block level, where the block class name is provided in the config.
513 | """
514 |
515 | super().__init__(policy, config, seed, run_dir, reference_model, rank, world_size)
516 | assert config.model.block_name is not None, 'must specify model.block_name (e.g., GPT2Block or GPTNeoXLayer) for FSDP'
517 |
518 | wrap_class = get_block_class_from_model(policy, config.model.block_name)
519 | model_auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={wrap_class}, )
520 |
521 | shared_fsdp_kwargs = dict(
522 | auto_wrap_policy=model_auto_wrap_policy,
523 | sharding_strategy=ShardingStrategy.FULL_SHARD,
524 | cpu_offload=CPUOffload(offload_params=False),
525 | backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
526 | device_id=rank,
527 | ignored_modules=None,
528 | limit_all_gathers=False,
529 | use_orig_params=False,
530 | sync_module_states=False
531 | )
532 |
533 | rank0_print('Sharding policy...')
534 | mp_dtype = getattr(torch, config.model.fsdp_policy_mp) if config.model.fsdp_policy_mp is not None else None
535 | policy_mp_policy = MixedPrecision(param_dtype=mp_dtype, reduce_dtype=mp_dtype, buffer_dtype=mp_dtype)
536 | self.policy = FSDP(policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy)
537 |
538 | if config.activation_checkpointing:
539 | rank0_print('Attempting to enable activation checkpointing...')
540 | try:
541 | # use activation checkpointing, according to:
542 | # https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/
543 | #
544 | # first, verify we have FSDP activation support ready by importing:
545 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
546 | checkpoint_wrapper,
547 | apply_activation_checkpointing,
548 | CheckpointImpl,
549 | )
550 | non_reentrant_wrapper = functools.partial(
551 | checkpoint_wrapper,
552 | offload_to_cpu=False,
553 | checkpoint_impl=CheckpointImpl.NO_REENTRANT,
554 | )
555 | except Exception as e:
556 | rank0_print('FSDP activation checkpointing not available:', e)
557 | else:
558 | check_fn = lambda submodule: isinstance(submodule, wrap_class)
559 | rank0_print('Applying activation checkpointing wrapper to policy...')
560 | apply_activation_checkpointing(self.policy, checkpoint_wrapper_fn=non_reentrant_wrapper,
561 | check_fn=check_fn)
562 | rank0_print('FSDP activation checkpointing enabled!')
563 |
564 | if config.loss.name == 'tdpo':
565 | rank0_print('Sharding reference model...')
566 | self.reference_model = FSDP(reference_model, **shared_fsdp_kwargs)
567 |
568 | print('Loaded model on rank', rank)
569 | dist.barrier()
570 |
571 | def clip_gradient(self):
572 | """Clip the gradient norm of the parameters of an FSDP policy, gathering the gradients across all GPUs."""
573 | return self.policy.clip_grad_norm_(self.config.max_grad_norm).item()
574 |
575 | def save(self, output_dir=None, metrics=None):
576 | """Save policy, optimizer, and scheduler state to disk, gathering from all processes and saving only on the rank 0 process."""
577 | save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
578 | with FSDP.state_dict_type(self.policy, StateDictType.FULL_STATE_DICT, state_dict_config=save_policy):
579 | policy_state_dict = self.policy.state_dict()
580 |
581 | if self.rank == 0:
582 | self.write_state_dict(self.example_counter, policy_state_dict, metrics, 'policy.pt', output_dir)
583 | del policy_state_dict
584 | dist.barrier()
585 |
586 | save_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
587 | with FSDP.state_dict_type(self.policy, StateDictType.FULL_STATE_DICT, optim_state_dict_config=save_policy):
588 | optimizer_state_dict = FSDP.optim_state_dict(self.policy, self.optimizer)
589 |
590 | if self.rank == 0:
591 | self.write_state_dict(self.example_counter, optimizer_state_dict, metrics, 'optimizer.pt', output_dir)
592 | del optimizer_state_dict
593 | dist.barrier()
594 |
595 | if self.rank == 0:
596 | scheduler_state_dict = self.scheduler.state_dict()
597 | self.write_state_dict(self.example_counter, scheduler_state_dict, metrics, 'scheduler.pt', output_dir)
598 | dist.barrier()
599 |
600 |
601 | class TensorParallelTrainer(BasicTrainer):
602 | def __init__(self, policy, config, seed, run_dir, reference_model=None, rank=0, world_size=1):
603 | """A trainer subclass that uses TensorParallel to shard the model across multiple GPUs.
604 |
605 | Based on https://github.com/BlackSamorez/tensor_parallel. Note sampling is extremely slow,
606 | see https://github.com/BlackSamorez/tensor_parallel/issues/66.
607 | """
608 | super().__init__(policy, config, seed, run_dir, reference_model, rank, world_size)
609 |
610 | rank0_print('Sharding policy...')
611 | self.policy = tp.tensor_parallel(policy, sharded=True)
612 | if config.loss.name == 'tdpo':
613 | rank0_print('Sharding reference model...')
614 | self.reference_model = tp.tensor_parallel(reference_model, sharded=False)
615 |
616 | def save(self, output_dir=None, metrics=None):
617 | """Save (unsharded) policy state to disk."""
618 | with tp.save_tensor_parallel(self.policy):
619 | policy_state_dict = self.policy.state_dict()
620 |
621 | self.write_state_dict(self.example_counter, policy_state_dict, metrics, 'policy.pt', output_dir)
622 | del policy_state_dict
623 |
--------------------------------------------------------------------------------
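As a closing illustration (not part of the repository), here is a minimal sketch of calling `tdpo_loss` from `trainers.py` with hand-made margins and KL values to compare the TDPO1 and TDPO2 variants. It assumes the dependencies in `requirements.txt` are installed and the repository root is on the import path; all numbers are arbitrary.

```python
# Illustrative only: compare TDPO1 and TDPO2 on a single fabricated example.
import torch
from trainers import tdpo_loss

chosen_logps_margin = torch.tensor([0.8])    # log pi_theta - log pi_ref on the chosen response
rejected_logps_margin = torch.tensor([0.2])
chosen_position_kl = torch.tensor([0.3])     # sequential KL(ref || policy) on the chosen response
rejected_position_kl = torch.tensor([0.6])

for if_tdpo2 in (False, True):
    losses, chosen_rewards, rejected_rewards = tdpo_loss(
        chosen_logps_margin, rejected_logps_margin,
        chosen_position_kl, rejected_position_kl,
        beta=0.1, alpha=0.5, if_tdpo2=if_tdpo2)
    print('TDPO2' if if_tdpo2 else 'TDPO1',
          'loss:', round(losses.item(), 4),
          'reward margin:', round((chosen_rewards - rejected_rewards).item(), 4))
```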