├── config ├── loss │ ├── sft.yaml │ └── tdpo.yaml ├── model │ ├── gpt2-xl.yaml │ ├── gpt2-large.yaml │ ├── gptj.yaml │ ├── llama7b.yaml │ ├── pythia28.yaml │ ├── pythia69.yaml │ └── blank_model.yaml └── config.yaml ├── figs ├── TDPO_vs_DPO.png └── IMDb_experiment.png ├── requirements.txt ├── train.py ├── README.md ├── utils.py ├── LICENSE ├── preference_datasets.py └── trainers.py /config/loss/sft.yaml: -------------------------------------------------------------------------------- 1 | name: sft -------------------------------------------------------------------------------- /figs/TDPO_vs_DPO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vance0124/Token-level-Direct-Preference-Optimization/HEAD/figs/TDPO_vs_DPO.png -------------------------------------------------------------------------------- /figs/IMDb_experiment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vance0124/Token-level-Direct-Preference-Optimization/HEAD/figs/IMDb_experiment.png -------------------------------------------------------------------------------- /config/model/gpt2-xl.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: gpt2-xl 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPT2Block 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 -------------------------------------------------------------------------------- /config/model/gpt2-large.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: gpt2-large 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPT2Block 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 -------------------------------------------------------------------------------- /config/model/gptj.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: EleutherAI/gpt-j-6b 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPTJBlock 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 -------------------------------------------------------------------------------- /config/model/llama7b.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: huggyllama/llama-7b 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: LlamaDecoderLayer 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 -------------------------------------------------------------------------------- /config/model/pythia28.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: EleutherAI/pythia-2.8b 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPTNeoXLayer 5 | 6 | policy_dtype: bfloat16 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 -------------------------------------------------------------------------------- /config/model/pythia69.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: EleutherAI/pythia-6.9b 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPTNeoXLayer 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | ipykernel==6.23.1 2 | numpy==1.24.3 3 | tokenizers==0.13.3 4 | torch==2.0.1 5 | tqdm==4.65.0 6 | transformers==4.29.2 7 | datasets==2.12.0 8 | beautifulsoup4==4.12.2 9 | wandb==0.15.3 10 | hydra-core==1.3.2 11 | tensor-parallel==1.2.4 -------------------------------------------------------------------------------- /config/loss/tdpo.yaml: -------------------------------------------------------------------------------- 1 | # do TDPO preference-based training 2 | name: tdpo 3 | if_tdpo2: true 4 | 5 | # alpha adjusts the impact of the sequential KL divergence term (used by TDPO2); beta is the 6 | # temperature parameter for TDPO; lower values of beta mean we care less about the reference model 7 | alpha: 0.5 8 | beta: 0.1 9 | 10 | # if true, use a uniform (maximum entropy) reference model 11 | reference_free: false -------------------------------------------------------------------------------- /config/model/blank_model.yaml: -------------------------------------------------------------------------------- 1 | # the name of the model to use; should be something like 2 | # gpt2-xl or gpt-neo-2.7B or huggyllama/llama-7b 3 | name_or_path: ??? 4 | 5 | # the name of the tokenizer to use; if null, will use the tokenizer from the model 6 | tokenizer_name_or_path: null 7 | 8 | # override pre-trained weights (e.g., from SFT); optional 9 | archive: null 10 | 11 | # the name of the module class to wrap with FSDP; should be something like 12 | # GPT2Block, GPTNeoXLayer, LlamaDecoderLayer, etc. 13 | block_name: null 14 | 15 | # the dtype for the policy parameters/optimizer state 16 | policy_dtype: float32 17 | 18 | # the mixed precision dtype if using FSDP; defaults to the same as the policy 19 | fsdp_policy_mp: null 20 | 21 | # the dtype for the reference model (which is used for inference only) 22 | reference_dtype: float16 23 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | # random seed for batch sampling 2 | seed: 0 3 | 4 | # name for this experiment in the local run directory and on wandb 5 | exp_name: ??? 6 | 7 | # the batch size for training; for FSDP, the batch size per GPU is batch_size / (gradient_accumulation_steps * num_gpus) 8 | batch_size: 4 9 | 10 | # the batch size during evaluation and sampling, if enabled 11 | eval_batch_size: 16 12 | 13 | # debug mode (disables wandb, model checkpointing, etc.)
14 | debug: false 15 | 16 | # the port to use for FSDP 17 | fsdp_port: null 18 | 19 | # which dataset(s) to train on; can pass a list like datasets=[hh,shp] 20 | datasets: 21 | - hh 22 | 23 | # wandb configuration 24 | wandb: 25 | enabled: true 26 | entity: null 27 | project: "token-level-direct-preference-optimization" 28 | 29 | # to create the local run directory and cache models/datasets, 30 | # we will try each of these directories in order; if none exist, 31 | # we will create the last one and use it 32 | local_dirs: 33 | - /scr-ssd 34 | - /scr 35 | - .cache 36 | 37 | # whether or not to generate samples during evaluation; disabling this is recommended 38 | # for FSDP/TensorParallel, because sampling is slow with those trainers 39 | sample_during_eval: true 40 | 41 | # how many model samples to generate during evaluation 42 | n_eval_model_samples: 16 43 | 44 | # whether to eval at the very beginning of training 45 | do_first_eval: true 46 | 47 | # an OmegaConf resolver that returns the local run directory, calling a function in utils.py 48 | local_run_dir: ${get_local_run_dir:${exp_name},${local_dirs}} 49 | 50 | # the learning rate 51 | lr: 5e-6 52 | 53 | # number of steps to accumulate over for each batch 54 | # (e.g. if batch_size=4 and gradient_accumulation_steps=2, then we will 55 | # accumulate gradients over 2 microbatches of size 2) 56 | gradient_accumulation_steps: 1 57 | 58 | # the maximum gradient norm to clip to 59 | max_grad_norm: 10.0 60 | 61 | # the maximum allowed length for an input (prompt + response) 62 | max_length: 512 63 | 64 | # the maximum allowed length for a prompt 65 | max_prompt_length: 256 66 | 67 | # the number of epochs to train for; if null, must specify n_examples 68 | n_epochs: 1 69 | 70 | # the number of examples to train for; if null, must specify n_epochs 71 | n_examples: null 72 | 73 | # the number of examples to evaluate on (and sample from, if sample_during_eval is true) 74 | n_eval_examples: 256 75 | 76 | # the trainer class to use (e.g.
BasicTrainer, FSDPTrainer, TensorParallelTrainer) 77 | trainer: BasicTrainer 78 | 79 | # The optimizer to use; we use RMSprop because it works about as well as Adam and is more memory-efficient 80 | optimizer: RMSprop 81 | 82 | # number of linear warmup steps for the learning rate 83 | warmup_steps: 150 84 | 85 | # whether or not to use activation/gradient checkpointing 86 | activation_checkpointing: false 87 | 88 | # evaluate and save model every eval_every steps 89 | eval_every: 20_000 90 | 91 | # prevent wandb from logging more than once per minimum_log_interval_secs 92 | minimum_log_interval_secs: 1.0 93 | 94 | defaults: 95 | - _self_ 96 | - model: blank_model # basic model configuration 97 | - loss: sft # which loss function, either sft or tdpo (specify loss.alpha and loss.beta if using tdpo) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.backends.cuda.matmul.allow_tf32 = True 4 | import torch.nn as nn 5 | import transformers 6 | from utils import get_local_dir, get_local_run_dir, disable_dropout, init_distributed, get_open_port 7 | import os 8 | import hydra 9 | import torch.multiprocessing as mp 10 | from omegaconf import OmegaConf, DictConfig 11 | import trainers 12 | import wandb 13 | import json 14 | import socket 15 | from typing import Optional, Set 16 | import resource 17 | 18 | OmegaConf.register_new_resolver("get_local_run_dir", lambda exp_name, local_dirs: get_local_run_dir(exp_name, local_dirs)) 19 | 20 | 21 | def worker_main(rank: int, world_size: int, config: DictConfig, policy: nn.Module, reference_model: Optional[nn.Module] = None): 22 | """Main function for each worker process (may be only 1 for BasicTrainer/TensorParallelTrainer).""" 23 | if 'FSDP' in config.trainer: 24 | init_distributed(rank, world_size, port=config.fsdp_port) 25 | 26 | if config.debug: 27 | wandb.init = lambda *args, **kwargs: None 28 | wandb.log = lambda *args, **kwargs: None 29 | 30 | if rank == 0 and config.wandb.enabled: 31 | os.environ['WANDB_CACHE_DIR'] = get_local_dir(config.local_dirs) 32 | wandb.init( 33 | entity=config.wandb.entity, 34 | project=config.wandb.project, 35 | config=OmegaConf.to_container(config), 36 | dir=get_local_dir(config.local_dirs), 37 | name=config.exp_name, 38 | ) 39 | 40 | TrainerClass = getattr(trainers, config.trainer) 41 | print(f'Creating trainer on process {rank} with world size {world_size}') 42 | trainer = TrainerClass(policy, config, config.seed, config.local_run_dir, reference_model=reference_model, rank=rank, world_size=world_size) 43 | 44 | trainer.train() 45 | trainer.save() 46 | 47 | 48 | @hydra.main(version_base=None, config_path="config", config_name="config") 49 | def main(config: DictConfig): 50 | """Main entry point for training. Validates config, creates/initializes model(s), and kicks off worker process(es).""" 51 | 52 | # Resolve hydra references, e.g.
so we don't re-compute the run directory 53 | OmegaConf.resolve(config) 54 | 55 | missing_keys: Set[str] = OmegaConf.missing_keys(config) 56 | if missing_keys: 57 | raise ValueError(f"Got missing keys in config:\n{missing_keys}") 58 | 59 | if config.eval_every % config.batch_size != 0: 60 | print('WARNING: eval_every must be divisible by batch_size') 61 | print('Setting eval_every to', config.eval_every - config.eval_every % config.batch_size) 62 | config.eval_every = config.eval_every - config.eval_every % config.batch_size 63 | 64 | if 'FSDP' in config.trainer and config.fsdp_port is None: 65 | free_port = get_open_port() 66 | print('no FSDP port specified; using open port for FSDP:', free_port) 67 | config.fsdp_port = free_port 68 | 69 | print(OmegaConf.to_yaml(config)) 70 | 71 | config_path = os.path.join(config.local_run_dir, 'config.yaml') 72 | with open(config_path, 'w') as f: 73 | OmegaConf.save(config, f) 74 | 75 | print('=' * 80) 76 | print(f'Writing to {socket.gethostname()}:{config.local_run_dir}') 77 | print('=' * 80) 78 | 79 | os.environ['XDG_CACHE_HOME'] = get_local_dir(config.local_dirs) 80 | print('building policy') 81 | model_kwargs = {'device_map': 'balanced'} if config.trainer == 'BasicTrainer' else {} 82 | policy_dtype = getattr(torch, config.model.policy_dtype) 83 | policy = transformers.AutoModelForCausalLM.from_pretrained( 84 | config.model.name_or_path, cache_dir=get_local_dir(config.local_dirs), low_cpu_mem_usage=True, torch_dtype=policy_dtype, **model_kwargs) 85 | disable_dropout(policy) 86 | 87 | if config.loss.name == 'tdpo': 88 | print('building reference model') 89 | reference_model_dtype = getattr(torch, config.model.reference_dtype) 90 | reference_model = transformers.AutoModelForCausalLM.from_pretrained( 91 | config.model.name_or_path, cache_dir=get_local_dir(config.local_dirs), low_cpu_mem_usage=True, torch_dtype=reference_model_dtype, **model_kwargs) 92 | disable_dropout(reference_model) 93 | else: 94 | reference_model = None 95 | 96 | if config.model.archive is not None: 97 | state_dict = torch.load(config.model.archive, map_location='cpu') 98 | step, metrics = state_dict['step_idx'], state_dict['metrics'] 99 | print(f'loading pre-trained weights at step {step} from {config.model.archive} with metrics {json.dumps(metrics, indent=2)}') 100 | policy.load_state_dict(state_dict['state']) 101 | if config.loss.name == 'tdpo': 102 | reference_model.load_state_dict(state_dict['state']) 103 | print('loaded pre-trained weights') 104 | 105 | if 'FSDP' in config.trainer: 106 | world_size = torch.cuda.device_count() 107 | print('starting', world_size, 'processes for FSDP training') 108 | soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) 109 | resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) 110 | print(f'setting RLIMIT_NOFILE soft limit to {hard} from {soft}') 111 | mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model), join=True) 112 | else: 113 | print('starting single-process worker') 114 | worker_main(0, 1, config, policy, reference_model) 115 | 116 | 117 | if __name__ == '__main__': 118 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TDPO: Token-level Direct Preference Optimization 2 | 3 | This repo contains a reference implementation of the TDPO algorithm for training language models from preference data, as described in the paper [_Token-level Direct Preference 
Optimization_](https://arxiv.org/pdf/2404.11999.pdf) (ICML 2024). Our implementation is based on [DPO](https://github.com/eric-mitchell/direct-preference-optimization), and follows the same usage guidelines. 4 | 5 | 6 | 7 |
8 | Comparison of Loss Functions for $\mathrm{DPO}$, $\mathrm{TDPO}_1$ and $\mathrm{TDPO}_2$ Methods. 9 |
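For reference, the losses compared in the figure above can be written compactly from their implementation in `tdpo_loss` in `trainers.py` (the shorthand $u(x, y)$ and $\mathrm{sg}(\cdot)$ below is ours, introduced only to summarize that code). Writing $u(x, y) = \log\frac{\pi_{\theta}(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}$ for the sequence-level log-ratio and $D_{\mathrm{SeqKL}}({x}, {y};\pi_{\mathrm{ref}}\|\pi_{\theta})$ for the sequential KL divergence:

$$\mathcal{L}_{\mathrm{TDPO}_1} = -\log\sigma\Big(\beta\Big[\big(u(x, y_w) - u(x, y_l)\big) - \big(D_{\mathrm{SeqKL}}({x}, {y}_l;\pi_{\mathrm{ref}}\|\pi_{\theta}) - D_{\mathrm{SeqKL}}({x}, {y}_w;\pi_{\mathrm{ref}}\|\pi_{\theta})\big)\Big]\Big)$$

$$\mathcal{L}_{\mathrm{TDPO}_2} = -\log\sigma\Big(\beta\Big[\big(u(x, y_w) - u(x, y_l)\big) - \alpha\big(D_{\mathrm{SeqKL}}({x}, {y}_l;\pi_{\mathrm{ref}}\|\pi_{\theta}) - \mathrm{sg}\big(D_{\mathrm{SeqKL}}({x}, {y}_w;\pi_{\mathrm{ref}}\|\pi_{\theta})\big)\big)\Big]\Big)$$

where $\sigma$ is the sigmoid function and $\mathrm{sg}(\cdot)$ denotes stop-gradient (`.detach()` in the code). Dropping the sequential KL terms recovers the DPO loss.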
10 | 11 | 12 | 13 | The TDPO pipeline has two stages: 14 | 15 | 1. Run supervised fine-tuning (SFT) on the dataset(s) of interest. Generally, $(x, y_w)$ from the preference dataset is directly used as the supervised fine-tuning target. 16 | 2. Run preference learning on the model from step 1, using preference data (ideally from the same distribution as the SFT examples). The dataset is generally composed of $\mathcal{D} = \{(x, y_w, y_l)_i\}_{i=1}^N$, where $x$ represents the prompt, and $y_w$ and $y_l$ denote the preferred and dispreferred completions, respectively. 17 | 18 | During training, we generally train for **one epoch** in the **SFT** stage, while in the **RL Fine-tuning** stage we run **multiple epochs** (e.g., three) to enhance the performance of our algorithm. 19 | 20 | 21 | 22 | The files in this repo are: 23 | 24 | - `train.py`: the main entry point for training (either SFT or TDPO preference-based training) 25 | - `trainers.py`: the trainer classes (implementing the training loop as well as the multi-GPU logic) 26 | - `utils.py`: some convenience functions used by multiple other files 27 | - `preference_datasets.py`: dataset processing logic for both SFT and TDPO preference-based training; **this is where you'll need to make some additions to train on your own data** 28 | 29 | 30 | 31 | The code here supports any causal HuggingFace model; look at our examples in `config/model` to add your own. Adding your own datasets is also easy; see `preference_datasets.py` for where to add them. 32 | 33 | 34 | 35 | ## Example 36 | 37 | Let's work through a complete example training a Pythia 2.8B model on the Anthropic-HH dataset. 38 | 39 | ### Step 1: Set up environment 40 | 41 | python3 -m venv env 42 | source env/bin/activate 43 | pip install -r requirements.txt 44 | 45 | 46 | 47 | ### Step 2: Run SFT 48 | 49 | python -u train.py model=pythia28 datasets=[hh] loss=sft exp_name=anthropic_tdpo_pythia28 gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false model.fsdp_policy_mp=bfloat16 50 | 51 | 52 | 53 | ### Step 3: Run TDPO 54 | 55 | For running **TDPO2**, we recommend the following command: 56 | 57 | python -u train.py model=pythia28 datasets=[hh] loss=tdpo loss.alpha=0.5 loss.beta=0.1 exp_name=anthropic_tdpo_pythia28 gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false model.fsdp_policy_mp=bfloat16 model.archive=/path/to/archive/from/sft/LATEST/policy.pt 58 | 59 | 60 | 61 | To run **TDPO1**, we only need to pass the additional parameter `loss.if_tdpo2=false`: 62 | 63 | ~~~ 64 | python -u train.py model=pythia28 datasets=[hh] loss=tdpo loss.beta=0.1 loss.if_tdpo2=false exp_name=anthropic_tdpo_pythia28 gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false model.fsdp_policy_mp=bfloat16 model.archive=/path/to/archive/from/sft/LATEST/policy.pt 65 | ~~~ 66 | 67 | When the learning rate (**lr**) is low, we recommend the **TDPO1** algorithm; for higher learning rates, the **TDPO2** algorithm is preferable. 68 | 69 | 70 | 71 | We have included the training curves from wandb [here](https://wandb.ai/492277267/tdpo_demos). We have also provided a comparison with **DPO** on the IMDb experiment, as shown below. 72 | 73 | ![The experiment on IMDb dataset. (a) represents the frontier of expected reward and KL divergence with respect to the reference model.
We implemented DPO, $\mathrm{TDPO}_1$, and different versions of $\mathrm{TDPO}_2$ with respect to the parameter $\alpha$. Both $\mathrm{TDPO}_1$ and $\mathrm{TDPO}_2$ outperform DPO in terms of the frontier, with $\mathrm{TDPO}_2$ showing further improvement over $\mathrm{TDPO}_1$. This demonstrates the effectiveness of our analysis and modifications. (b) and (c) present the progression of sequential KL divergence on the preferred and dispreferred response subsets over training steps, respectively. (d) illustrates the difference between the sequential KL divergence on the dispreferred responses subset and that on the preferred responses subset throughout the training process, namely $margin=|D_{\mathrm{SeqKL}}({x}, {y}_w;\pi_{\mathrm{ref}}\|\pi_{\theta}) - D_{\mathrm{SeqKL}}({x}, {y}_l;\pi_{\mathrm{ref}}\|\pi_{\theta})|$. $\mathrm{TDPO}_2$ exhibits superior regulation of KL divergence compared to the $\mathrm{TDPO}_1$ and DPO algorithms.](figs/IMDb_experiment.png) 74 | 75 | For more experimental details and information, please refer to our paper. 76 | 77 | 78 | 79 | ## Acknowledgements 80 | 81 | Many thanks to the authors of [DPO](https://github.com/eric-mitchell/direct-preference-optimization) for their valuable contributions to the RLHF community. For more detailed information, please refer to the [DPO repository](https://github.com/eric-mitchell/direct-preference-optimization). 82 | 83 | 84 | 85 | ## Citing TDPO 86 | 87 | If TDPO or this repository is useful in your research, you can use the following BibTeX entry to cite our paper: 88 | 89 | ~~~ 90 | @article{zeng2024token, 91 | title={Token-level Direct Preference Optimization}, 92 | author={Zeng, Yongcheng and Liu, Guoqing and Ma, Weiyu and Yang, Ning and Zhang, Haifeng and Wang, Jun}, 93 | journal={arXiv preprint arXiv:2404.11999}, 94 | year={2024} 95 | } 96 | ~~~ 97 | 98 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import getpass 3 | from datetime import datetime 4 | import torch 5 | import random 6 | import numpy as np 7 | import torch.distributed as dist 8 | import inspect 9 | import importlib.util 10 | import socket 11 | import os 12 | from typing import Dict, Union, Type, List 13 | 14 | 15 | def get_open_port(): 16 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 17 | s.bind(('', 0)) # bind to all interfaces and use an OS provided port 18 | return s.getsockname()[1] # return only the port number 19 | 20 | 21 | def get_remote_file(remote_path, local_path=None): 22 | hostname, path = remote_path.split(':') 23 | local_hostname = socket.gethostname() 24 | if hostname == local_hostname or hostname == local_hostname[:local_hostname.find('.')]: 25 | return path 26 | 27 | if local_path is None: 28 | local_path = path 29 | # local_path = local_path.replace('/scr-ssd', '/scr') 30 | if os.path.exists(local_path): 31 | return local_path 32 | local_dir = os.path.dirname(local_path) 33 | os.makedirs(local_dir, exist_ok=True) 34 | 35 | print(f'Copying {hostname}:{path} to {local_path}') 36 | os.system(f'scp {remote_path} {local_path}') 37 | return local_path 38 | 39 | 40 | def rank0_print(*args, **kwargs): 41 | """Print, but only on rank 0.""" 42 | if not dist.is_initialized() or dist.get_rank() == 0: 43 | print(*args, **kwargs) 44 | 45 | 46 | def get_local_dir(prefixes_to_resolve: List[str]) -> str: 47 | """Return the path to the cache directory for this user.""" 48 | for prefix in
prefixes_to_resolve: 49 | if os.path.exists(prefix): 50 | return f"{prefix}/{getpass.getuser()}" 51 | os.makedirs(prefix) 52 | return f"{prefix}/{getpass.getuser()}" 53 | 54 | 55 | def get_local_run_dir(exp_name: str, local_dirs: List[str]) -> str: 56 | """Create a local directory to store outputs for this run, and return its path.""" 57 | now = datetime.now() 58 | timestamp = now.strftime("%Y-%m-%d_%H-%M-%S_%f") 59 | run_dir = f"{get_local_dir(local_dirs)}/{exp_name}_{timestamp}" 60 | os.makedirs(run_dir, exist_ok=True) 61 | return run_dir 62 | 63 | 64 | def slice_and_move_batch_for_device(batch: Dict, rank: int, world_size: int, device: str) -> Dict: 65 | """Slice a batch into chunks, and move each chunk to the specified device.""" 66 | chunk_size = len(list(batch.values())[0]) // world_size 67 | start = chunk_size * rank 68 | end = chunk_size * (rank + 1) 69 | sliced = {k: v[start:end] for k, v in batch.items()} 70 | on_device = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in sliced.items()} 71 | return on_device 72 | 73 | 74 | def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: 75 | if tensor.size(dim) >= length: 76 | return tensor 77 | else: 78 | pad_size = list(tensor.shape) 79 | pad_size[dim] = length - tensor.size(dim) 80 | return torch.cat([tensor, pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device)], dim=dim) 81 | 82 | 83 | def all_gather_if_needed(values: torch.Tensor, rank: int, world_size: int) -> torch.Tensor: 84 | """Gather and stack/cat values from all processes, if there are multiple processes.""" 85 | if world_size == 1: 86 | return values 87 | 88 | all_values = [torch.empty_like(values).to(rank) for _ in range(world_size)] 89 | dist.all_gather(all_values, values) 90 | cat_function = torch.cat if values.dim() > 0 else torch.stack 91 | return cat_function(all_values, dim=0) 92 | 93 | 94 | def formatted_dict(d: Dict) -> Dict: 95 | """Format a dictionary for printing.""" 96 | return {k: (f"{v:.5g}" if type(v) == float else v) for k, v in d.items()} 97 | 98 | 99 | def disable_dropout(model: torch.nn.Module): 100 | """Disable dropout in a model.""" 101 | for module in model.modules(): 102 | if isinstance(module, torch.nn.Dropout): 103 | module.p = 0 104 | 105 | 106 | def print_gpu_memory(rank: int = None, message: str = ''): 107 | """Print the amount of GPU memory currently allocated for each GPU.""" 108 | if torch.cuda.is_available(): 109 | device_count = torch.cuda.device_count() 110 | for i in range(device_count): 111 | device = torch.device(f'cuda:{i}') 112 | allocated_bytes = torch.cuda.memory_allocated(device) 113 | if allocated_bytes == 0: 114 | continue 115 | print('*' * 40) 116 | print(f'[{message} rank {rank} ] GPU {i}: {allocated_bytes / 1024 ** 2:.2f} MB') 117 | print('*' * 40) 118 | 119 | 120 | def get_block_class_from_model(model: torch.nn.Module, block_class_name: str) -> torch.nn.Module: 121 | """Get the class of a block from a model, using the block's class name.""" 122 | for module in model.modules(): 123 | if module.__class__.__name__ == block_class_name: 124 | return module.__class__ 125 | raise ValueError(f"Could not find block class {block_class_name} in model {model}") 126 | 127 | 128 | def get_block_class_from_model_class_and_block_name(model_class: Type, block_class_name: str) -> Type: 129 | filepath = inspect.getfile(model_class) 130 | assert filepath.endswith('.py'), f"Expected a .py file, got {filepath}" 131 | assert os.path.exists(filepath), 
f"File {filepath} does not exist" 132 | assert "transformers" in filepath, f"Expected a transformers model, got {filepath}" 133 | 134 | module_name = filepath[filepath.find('transformers'):].replace('/', '.')[:-3] 135 | print(f"Searching in file {filepath}, module {module_name} for class {block_class_name}") 136 | 137 | # Load the module dynamically 138 | spec = importlib.util.spec_from_file_location(module_name, filepath) 139 | module = importlib.util.module_from_spec(spec) 140 | spec.loader.exec_module(module) 141 | 142 | # Get the class dynamically 143 | class_ = getattr(module, block_class_name) 144 | print(f"Found class {class_} in module {module_name}") 145 | return class_ 146 | 147 | 148 | def init_distributed(rank: int, world_size: int, master_addr: str = 'localhost', port: int = 12355, backend: str = 'nccl'): 149 | print(rank, 'initializing distributed') 150 | os.environ["MASTER_ADDR"] = master_addr 151 | os.environ["MASTER_PORT"] = str(port) 152 | dist.init_process_group(backend, rank=rank, world_size=world_size) 153 | torch.cuda.set_device(rank) 154 | 155 | 156 | class TemporarilySeededRandom: 157 | def __init__(self, seed): 158 | """Temporarily set the random seed, and then restore it when exiting the context.""" 159 | self.seed = int(seed) 160 | self.stored_state = None 161 | self.stored_np_state = None 162 | 163 | def __enter__(self): 164 | # Store the current random state 165 | self.stored_state = random.getstate() 166 | self.stored_np_state = np.random.get_state() 167 | 168 | # Set the random seed 169 | random.seed(self.seed) 170 | np.random.seed(self.seed) 171 | 172 | def __exit__(self, exc_type, exc_value, traceback): 173 | # Restore the random state 174 | random.setstate(self.stored_state) 175 | np.random.set_state(self.stored_np_state) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /preference_datasets.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset 4 | from utils import get_local_dir, TemporarilySeededRandom 5 | from torch.nn.utils.rnn import pad_sequence 6 | from collections import defaultdict 7 | import tqdm 8 | import random 9 | from bs4 import BeautifulSoup, NavigableString 10 | import numpy as np 11 | from typing import Dict, List, Optional, Iterator, Callable, Union, Tuple 12 | import os 13 | 14 | 15 | def extract_anthropic_prompt(prompt_and_response): 16 | """Extract the anthropic prompt from a prompt and response pair.""" 17 | search_term = '\n\nAssistant:' 18 | search_term_idx = prompt_and_response.rfind(search_term) 19 | assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'" 20 | return prompt_and_response[:search_term_idx + len(search_term)] 21 | 22 | 23 | def strip_html_tags(html_string): 24 | """Strip HTML tags from a string, except for <code> tags (which contain real code in the StackExchange answers).""" 25 | # Create a BeautifulSoup object 26 | soup = BeautifulSoup(html_string, 'html.parser') 27 | 28 | # Initialize an empty list to store the text 29 | text = [] 30 | for element in soup.children: 31 | if isinstance(element, NavigableString): 32 | continue 33 | if element.name == 'p': 34 | text.append(''.join(child.string for child in element.children if isinstance(child, NavigableString))) 35 | elif element.name == 'pre': 36 | for code in element.find_all('code'): 37 | text.append("<code>" + code.get_text() + "</code>") 38 | elif element.name == 'code': 39 | text.append("<code>" + element.get_text() + "</code>") 40 | 41 | # Join the text together with newlines in between 42 | text = "\n\n".join(text) 43 | 44 | return text 45 | 46 | 47 | def get_se(split, silent=False, cache_dir: str = None) -> Dict[ 48 | str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]: 49 | """Load the StackExchange dataset from Huggingface, and return a dict of prompts and responses. See get_hh for the format. 50 | 51 | We strip the HTML tags from the responses (except for <code> tags), and we add necessary newlines.
52 | """ 53 | print(f'Loading SE dataset ({split} split) from Huggingface...') 54 | dataset = datasets.load_dataset('HuggingFaceH4/stack-exchange-preferences', cache_dir=cache_dir)['train'] 55 | print('done') 56 | 57 | # shuffle the dataset and select 1% for test 58 | dataset = dataset.shuffle(seed=42) 59 | dataset = dataset.select(range(int(len(dataset) * 0.01))) if split == 'test' else dataset.select( 60 | range(int(len(dataset) * 0.01), len(dataset))) 61 | 62 | def strip_html(x): 63 | x['question'] = strip_html_tags(x['question']) 64 | for a in x['answers']: 65 | a['text'] = strip_html_tags(a['text']) 66 | return x 67 | 68 | dataset = dataset.map(strip_html, num_proc=64) 69 | 70 | data = defaultdict(dict) 71 | for row in tqdm.tqdm(dataset, desc='Processing SE', disable=silent): 72 | prompt = '\n\nHuman: ' + row['question'] + '\n\nAssistant:' 73 | responses = [' ' + a['text'] for a in row['answers']] 74 | scores = [a['pm_score'] for a in row['answers']] 75 | 76 | pairs = [] 77 | for i in range(len(responses)): 78 | for j in range(i + 1, len(responses)): 79 | pairs.append((i, j) if scores[i] > scores[j] else (j, i)) 80 | 81 | data[prompt]['responses'] = responses 82 | data[prompt]['pairs'] = pairs 83 | data[prompt]['sft_target'] = max(responses, key=lambda x: scores[responses.index(x)]) 84 | 85 | return data 86 | 87 | 88 | def get_shp(split: str, silent: bool = False, cache_dir: str = None) -> Dict[ 89 | str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]: 90 | """Load the Stanford Human Preferences dataset from Huggingface and convert it to the necessary format. See hh for the format. 91 | 92 | We filter preference pairs to only keep pairs where the score ratio is at least 2. 93 | For this dataset, the sft_target is the response with the highest score. 94 | """ 95 | print(f'Loading SHP dataset ({split} split) from Huggingface...') 96 | dataset = datasets.load_dataset('stanfordnlp/SHP', split=split, cache_dir=cache_dir) 97 | print('done') 98 | 99 | data = defaultdict(lambda: defaultdict(list)) 100 | for row in tqdm.tqdm(dataset, desc='Processing SHP', disable=silent): 101 | prompt = '\n\nHuman: ' + row['history'] + '\n\nAssistant:' 102 | responses = [' ' + row['human_ref_A'], ' ' + row['human_ref_B']] 103 | scores = [row['score_A'], row['score_B']] 104 | if prompt in data: 105 | n_responses = len(data[prompt]['responses']) 106 | else: 107 | n_responses = 0 108 | score_ratio = max(scores[0] / scores[1], scores[1] / scores[0]) 109 | if score_ratio < 2: 110 | continue 111 | 112 | # according to https://huggingface.co/datasets/stanfordnlp/SHP 113 | data[prompt]['pairs'].append( 114 | (n_responses, n_responses + 1) if row['labels'] == 1 else (n_responses + 1, n_responses)) 115 | data[prompt]['responses'].extend(responses) 116 | data[prompt]['scores'].extend(scores) 117 | 118 | for prompt in data: 119 | data[prompt]['sft_target'] = max(data[prompt]['responses'], 120 | key=lambda x: data[prompt]['scores'][data[prompt]['responses'].index(x)]) 121 | del data[prompt]['scores'] 122 | 123 | return data 124 | 125 | 126 | def get_hh(split: str, silent: bool = False, cache_dir: str = None) -> Dict[ 127 | str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]: 128 | """Load the Anthropic Helpful-Harmless dataset from Huggingface and convert it to the necessary format. 
129 | 130 | The dataset is converted to a dictionary with the following structure: 131 | { 132 | 'prompt1': { 133 | 'responses': List[str], 134 | 'pairs': List[Tuple[int, int]], 135 | 'sft_target': str 136 | }, 137 | 'prompt2': { 138 | ... 139 | }, 140 | } 141 | 142 | Prompts should be structured as follows: 143 | \n\nHuman: <prompt>\n\nAssistant: 144 | Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:. 145 | 146 | For this dataset, the sft_target is just the chosen response. 147 | """ 148 | print(f'Loading HH dataset ({split} split) from Huggingface...') 149 | dataset = datasets.load_dataset('Anthropic/hh-rlhf', split=split, cache_dir=cache_dir) 150 | print('done') 151 | 152 | def split_prompt_and_responses(ex): 153 | prompt = extract_anthropic_prompt(ex['chosen']) 154 | chosen_response = ex['chosen'][len(prompt):] 155 | rejected_response = ex['rejected'][len(prompt):] 156 | return prompt, chosen_response, rejected_response 157 | 158 | data = defaultdict(lambda: defaultdict(list)) 159 | for row in tqdm.tqdm(dataset, desc='Processing HH', disable=silent): 160 | prompt, chosen, rejected = split_prompt_and_responses(row) 161 | responses = [chosen, rejected] 162 | n_responses = len(data[prompt]['responses']) 163 | data[prompt]['pairs'].append((n_responses, n_responses + 1)) 164 | data[prompt]['responses'].extend(responses) 165 | data[prompt]['sft_target'] = chosen 166 | 167 | return data 168 | 169 | 170 | def get_dataset(name: str, split: str, silent: bool = False, cache_dir: str = None): 171 | """Load the given dataset by name. Supported by default are 'shp', 'hh', and 'se'.""" 172 | if name == 'shp': 173 | data = get_shp(split, silent=silent, cache_dir=cache_dir) 174 | elif name == 'hh': 175 | data = get_hh(split, silent=silent, cache_dir=cache_dir) 176 | elif name == 'se': 177 | data = get_se(split, silent=silent, cache_dir=cache_dir) 178 | else: 179 | raise ValueError(f"Unknown dataset '{name}'") 180 | 181 | assert set(list(data.values())[0].keys()) == {'responses', 'pairs', 'sft_target'}, \ 182 | f"Unexpected keys in dataset: {list(list(data.values())[0].keys())}" 183 | 184 | return data 185 | 186 | 187 | def get_collate_fn(tokenizer) -> Callable[[List[Dict]], Dict[str, Union[List, torch.Tensor]]]: 188 | """Returns a collate function for the given tokenizer. 189 | 190 | The collate function takes a list of examples (dicts, where values are lists of 191 | ints [tokens] or strings [the original texts]) and returns a batch of examples, 192 | PyTorch tensors padded to the maximum length.
Strings are passed through.""" 193 | 194 | def collate_fn(batch): 195 | # first, pad everything to the same length 196 | padded_batch = {} 197 | for k in batch[0].keys(): 198 | if k.endswith('_input_ids') or k.endswith('_attention_mask') or k.endswith('_labels'): 199 | if 'prompt' in k: # adapted from https://stackoverflow.com/questions/73256206 200 | to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] 201 | else: 202 | to_pad = [torch.LongTensor(ex[k]) for ex in batch] 203 | if k.endswith('_input_ids'): 204 | padding_value = tokenizer.pad_token_id 205 | elif k.endswith('_labels'): 206 | padding_value = -100 207 | elif k.endswith('_attention_mask'): 208 | padding_value = 0 209 | else: 210 | raise ValueError(f"Unexpected key in batch '{k}'") 211 | 212 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) 213 | if 'prompt' in k: # for the prompt, flip back so padding is on left side 214 | padded_batch[k] = padded_batch[k].flip(dims=[1]) 215 | else: 216 | padded_batch[k] = [ex[k] for ex in batch] 217 | 218 | return padded_batch 219 | 220 | return collate_fn 221 | 222 | 223 | def tokenize_batch_element(prompt: str, chosen: str, rejected: str, truncation_mode: str, tokenizer, max_length: int, 224 | max_prompt_length: int) -> Dict: 225 | """Tokenize a single batch element. 226 | 227 | At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation 228 | in case the prompt + chosen or prompt + rejected responses is/are too long. First 229 | we truncate the prompt; if we're still too long, we truncate the chosen/rejected. 230 | 231 | We also create the labels for the chosen/rejected responses, which are of length equal to 232 | the sum of the length of the prompt and the chosen/rejected response, with -100 for the 233 | prompt tokens. 
234 | """ 235 | chosen_tokens = tokenizer(chosen, add_special_tokens=False) 236 | rejected_tokens = tokenizer(rejected, add_special_tokens=False) 237 | prompt_tokens = tokenizer(prompt, add_special_tokens=False) 238 | 239 | assert tokenizer.eos_token_id not in prompt_tokens['input_ids'], f"Prompt contains EOS token: {prompt}" 240 | assert tokenizer.eos_token_id not in chosen_tokens['input_ids'], f"Chosen response contains EOS token: {chosen}" 241 | assert tokenizer.eos_token_id not in rejected_tokens['input_ids'], f"Rejected response contains EOS token: {rejected}" 242 | 243 | chosen_tokens['input_ids'].append(tokenizer.eos_token_id) 244 | chosen_tokens['attention_mask'].append(1) 245 | 246 | rejected_tokens['input_ids'].append(tokenizer.eos_token_id) 247 | rejected_tokens['attention_mask'].append(1) 248 | 249 | longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids'])) 250 | 251 | # if combined sequence is too long, truncate the prompt 252 | if len(prompt_tokens['input_ids']) + longer_response_length > max_length: 253 | if truncation_mode == 'keep_start': 254 | prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()} 255 | elif truncation_mode == 'keep_end': 256 | prompt_tokens = {k: v[-max_prompt_length:] for k, v in prompt_tokens.items()} 257 | else: 258 | raise ValueError(f'Unknown truncation mode: {truncation_mode}') 259 | 260 | # if that's still too long, truncate the response 261 | if len(prompt_tokens['input_ids']) + longer_response_length > max_length: 262 | chosen_tokens = {k: v[:max_length - max_prompt_length] for k, v in chosen_tokens.items()} 263 | rejected_tokens = {k: v[:max_length - max_prompt_length] for k, v in rejected_tokens.items()} 264 | 265 | # Create labels 266 | chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens} 267 | rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens} 268 | chosen_sequence_tokens['labels'] = chosen_sequence_tokens['input_ids'][:] 269 | chosen_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids']) 270 | rejected_sequence_tokens['labels'] = rejected_sequence_tokens['input_ids'][:] 271 | rejected_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids']) 272 | 273 | batch = {} 274 | 275 | batch['prompt'] = prompt 276 | batch['chosen'] = prompt + chosen 277 | batch['rejected'] = prompt + rejected 278 | batch['chosen_response_only'] = chosen 279 | batch['rejected_response_only'] = rejected 280 | 281 | for k, toks in {'chosen': chosen_sequence_tokens, 'rejected': rejected_sequence_tokens, 282 | 'prompt': prompt_tokens}.items(): 283 | for type_key, tokens in toks.items(): 284 | if type_key == 'token_type_ids': 285 | continue 286 | batch[f'{k}_{type_key}'] = tokens 287 | 288 | return batch 289 | 290 | 291 | def get_batch_iterator(names: List[str], 292 | tokenizer, 293 | split: str = 'train', 294 | batch_size: int = 1, 295 | shuffle: bool = True, 296 | max_length: int = 512, 297 | max_prompt_length: int = 128, 298 | sft_mode: bool = False, 299 | n_epochs: Optional[int] = None, 300 | n_examples: Optional[int] = None, 301 | seed: int = 0, 302 | silent: bool = False, 303 | cache_dir: Optional[str] = None) -> Iterator[Dict]: 304 | """Get an iterator over batches of data. Stops after n_epochs or n_examples, whichever comes first. 305 | 306 | Args: 307 | names: Names of datasets to use. 308 | tokenizer: Tokenizer to use. 
309 | split: Which split to use. 310 | batch_size: Batch size. 311 | shuffle: Whether to shuffle the data after each epoch. 312 | max_length: Maximum length of the combined prompt + response. 313 | max_prompt_length: Maximum length of the prompt. 314 | sft_mode: Whether to use SFT mode (i.e., return sft_target instead of chosen/rejected). In sft mode, we just return chosen_input_ids, but they contain the sft_target. 315 | n_epochs: Number of epochs to run for. This or n_examples must be specified. 316 | n_examples: Number of examples to run for. This or n_epochs must be specified. 317 | seed: Random seed. 318 | silent: Whether to silence the progress bar(s). 319 | cache_dir: Directory to cache the datasets in. 320 | """ 321 | assert n_epochs is not None or n_examples is not None, "Must specify either n_epochs or n_examples" 322 | if silent: 323 | datasets.logging.disable_progress_bar() 324 | datasets.logging.set_verbosity_error() 325 | 326 | with TemporarilySeededRandom(seed): 327 | permutation_seeds = iter(np.random.randint(0, 2 ** 31, size=1000000)) 328 | flat_data = [] 329 | for name in names: 330 | truncation_mode = 'keep_end' if name == 'hh' else 'keep_start' 331 | for prompt, data in get_dataset(name, split, silent=silent, cache_dir=cache_dir).items(): 332 | flat_data.append((prompt, data['responses'], data['pairs'], data['sft_target'], truncation_mode)) 333 | 334 | collate_fn = get_collate_fn(tokenizer) 335 | 336 | epoch_idx = 0 337 | example_idx = 0 338 | done = False 339 | while True: 340 | if n_epochs is not None and epoch_idx >= n_epochs: 341 | if not silent: 342 | print(f'Finished generating {n_epochs} epochs on {split} split') 343 | break 344 | if shuffle: 345 | with TemporarilySeededRandom(next(permutation_seeds)): 346 | random.shuffle(flat_data) 347 | 348 | batch = [] 349 | for prompt, responses, pairs, sft_target, truncation_mode in flat_data: 350 | if done: 351 | break 352 | if sft_mode: 353 | batch_element = tokenize_batch_element(prompt, sft_target, sft_target, truncation_mode, tokenizer, max_length, max_prompt_length) 354 | batch_element = {k: v for k, v in batch_element.items() if 'rejected' not in k} 355 | batch.append(batch_element) 356 | example_idx += 1 357 | if len(batch) == batch_size: 358 | yield collate_fn(batch) 359 | if n_examples is not None and example_idx >= n_examples: 360 | if not silent: 361 | print(f'Finished generating {n_examples} examples on {split} split') 362 | done = True 363 | 364 | batch = [] 365 | else: 366 | for p in pairs: 367 | if done: 368 | break 369 | batch_element = tokenize_batch_element(prompt, responses[p[0]], responses[p[1]], truncation_mode, tokenizer, max_length, max_prompt_length) 370 | batch.append(batch_element) 371 | example_idx += 1 372 | if len(batch) == batch_size: 373 | yield collate_fn(batch) 374 | if n_examples is not None and example_idx >= n_examples: 375 | if not silent: 376 | print(f'FINISHED {n_examples} EXAMPLES on {split} split') 377 | done = True 378 | batch = [] 379 | if done: 380 | break 381 | 382 | epoch_idx += 1 383 | 384 | 385 | def strings_match_up_to_spaces(str_a: str, str_b: str) -> bool: 386 | """Returns True if str_a and str_b match up to spaces, False otherwise.""" 387 | for idx in range(min(len(str_a), len(str_b)) - 2): 388 | if str_a[idx] != str_b[idx]: 389 | if str_a[idx] != ' ' and str_b[idx] != ' ': 390 | return False 391 | else: 392 | if str_a[idx] == ' ': 393 | str_a = str_a[:idx] + str_a[idx + 1:] 394 | else: 395 | str_b = str_b[:idx] + str_b[idx + 1:] 396 | 397 | return True 
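As noted in the README, `preference_datasets.py` is where you add your own data. The sketch below is a minimal illustration of a new loader; the loader name `get_mycorpus`, the file path, and the JSON field names are hypothetical placeholders, and only the returned `responses`/`pairs`/`sft_target` structure (checked by the assertion in `get_dataset`) is actually required.

~~~
import json
from collections import defaultdict
from typing import Dict, List, Tuple, Union


def get_mycorpus(split: str, silent: bool = False, cache_dir: str = None) -> Dict[str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]:
    """Load a local preference file (one JSON object per line with hypothetical
    fields 'prompt', 'chosen', 'rejected') into the format used by get_hh."""
    data = defaultdict(lambda: defaultdict(list))
    with open(f'mycorpus_{split}.jsonl') as f:  # hypothetical path
        for line in f:
            row = json.loads(line)
            prompt = '\n\nHuman: ' + row['prompt'] + '\n\nAssistant:'
            n_responses = len(data[prompt]['responses'])
            # the first index of each pair is the preferred response, as in get_hh/get_se
            data[prompt]['pairs'].append((n_responses, n_responses + 1))
            data[prompt]['responses'].extend([' ' + row['chosen'], ' ' + row['rejected']])
            data[prompt]['sft_target'] = ' ' + row['chosen']
    return data
~~~

A matching `elif name == 'mycorpus':` branch in `get_dataset` (and `datasets=[mycorpus]` on the command line) would then route training to the new loader.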
-------------------------------------------------------------------------------- /trainers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.backends.cuda.matmul.allow_tf32 = True 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | import transformers 7 | from omegaconf import DictConfig 8 | 9 | import torch.distributed as dist 10 | from torch.distributed.fsdp import ( 11 | FullyShardedDataParallel as FSDP, 12 | MixedPrecision, 13 | StateDictType, 14 | BackwardPrefetch, 15 | ShardingStrategy, 16 | CPUOffload, 17 | ) 18 | from torch.distributed.fsdp.api import FullStateDictConfig, FullOptimStateDictConfig 19 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 20 | import tensor_parallel as tp 21 | import contextlib 22 | 23 | from preference_datasets import get_batch_iterator 24 | from utils import ( 25 | slice_and_move_batch_for_device, 26 | formatted_dict, 27 | all_gather_if_needed, 28 | pad_to_length, 29 | get_block_class_from_model, 30 | rank0_print, 31 | get_local_dir, 32 | ) 33 | import numpy as np 34 | import wandb 35 | import tqdm 36 | 37 | import random 38 | import os 39 | from collections import defaultdict 40 | import time 41 | import json 42 | import functools 43 | from typing import Optional, Dict, List, Union, Tuple 44 | 45 | 46 | def tdpo_loss(chosen_logps_margin: torch.FloatTensor, 47 | rejected_logps_margin: torch.FloatTensor, 48 | chosen_position_kl: torch.FloatTensor, 49 | rejected_position_kl: torch.FloatTensor, 50 | beta: float, alpha: float = 0.5, if_tdpo2: bool = True) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 51 | """Compute the TDPO loss for a batch of policy and reference model log probabilities. 52 | 53 | Args: 54 | chosen_logps_margin: The difference of log probabilities between the policy model and the reference model for the chosen responses. Shape: (batch_size,) 55 | rejected_logps_margin: The difference of log probabilities between the policy model and the reference model for the rejected responses. Shape: (batch_size,) 56 | chosen_position_kl: The sequential KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,) 57 | rejected_position_kl: The sequential KL divergence between the policy model and the reference model for the rejected responses. Shape: (batch_size,) 58 | beta: Temperature parameter for the TDPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. 59 | alpha: Parameter for the TDPO2 loss, used to adjust the impact of the sequential KL divergence term. 60 | if_tdpo2: Whether to use TDPO2 (default True); if False, use TDPO1. 61 | 62 | Returns: 63 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). 64 | The losses tensor contains the TDPO loss for each example in the batch. 65 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
66 | """ 67 | 68 | chosen_values = chosen_logps_margin + chosen_position_kl 69 | rejected_values = rejected_logps_margin + rejected_position_kl 70 | 71 | chosen_rejected_logps_margin = chosen_logps_margin - rejected_logps_margin 72 | 73 | 74 | if not if_tdpo2: 75 | logits = chosen_rejected_logps_margin - (rejected_position_kl - chosen_position_kl) # tdpo1 76 | else: 77 | logits = chosen_rejected_logps_margin - alpha * (rejected_position_kl - chosen_position_kl.detach()) # tdpo2 78 | losses = -F.logsigmoid(beta * logits) 79 | 80 | chosen_rewards = beta * chosen_values.detach() 81 | rejected_rewards = beta * rejected_values.detach() 82 | 83 | return losses, chosen_rewards, rejected_rewards 84 | 85 | 86 | def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, 87 | average_log_prob: bool = False) -> torch.FloatTensor: 88 | """Compute the log probabilities of the given labels under the given logits. 89 | 90 | Args: 91 | logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) 92 | labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) 93 | average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. 94 | 95 | Returns: 96 | A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 97 | """ 98 | assert logits.shape[:-1] == labels.shape 99 | 100 | labels = labels[:, 1:].clone() 101 | logits = logits[:, :-1, :] 102 | loss_mask = (labels != -100) 103 | 104 | # dummy token; we'll ignore the losses on these tokens later 105 | labels[labels == -100] = 0 106 | 107 | per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) 108 | 109 | if average_log_prob: 110 | return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) 111 | else: 112 | return (per_token_logps * loss_mask).sum(-1) 113 | 114 | 115 | def _tdpo_get_batch_logps(logits: torch.FloatTensor, reference_logits: torch.FloatTensor, labels: torch.LongTensor, 116 | average_log_prob: bool = False): 117 | """Compute the kl divergence/log probabilities of the given labels under the given logits. 118 | 119 | Args: 120 | logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) 121 | reference_logits: Logits of the reference model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) 122 | labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) 123 | average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. 124 | 125 | Returns: 126 | Several tensors of shape (batch_size,) containing the average/sum kl divergence/log probabilities of the given labels under the given logits. 
127 | """ 128 | assert logits.shape[:-1] == labels.shape 129 | assert reference_logits.shape[:-1] == labels.shape 130 | 131 | labels = labels[:, 1:].clone() 132 | logits = logits[:, :-1, :] 133 | reference_logits = reference_logits[:, :-1, :] 134 | 135 | loss_mask = (labels != -100) 136 | 137 | # dummy token; we'll ignore the losses on these tokens later 138 | labels[labels == -100] = 0 139 | 140 | vocab_logps = logits.log_softmax(-1) 141 | 142 | reference_vocab_ps = reference_logits.softmax(-1) 143 | reference_vocab_logps = reference_vocab_ps.log() 144 | 145 | per_position_kl = (reference_vocab_ps * (reference_vocab_logps - vocab_logps)).sum(-1) 146 | per_token_logps = torch.gather(vocab_logps, dim=2, index=labels.unsqueeze(2)).squeeze(2) 147 | per_reference_token_logps = torch.gather(reference_vocab_logps, dim=2, index=labels.unsqueeze(2)).squeeze(2) 148 | 149 | logps_margin = per_token_logps - per_reference_token_logps 150 | 151 | if average_log_prob: 152 | return (logps_margin * loss_mask).sum(-1) / loss_mask.sum(-1), \ 153 | (per_position_kl * loss_mask).sum(-1) / loss_mask.sum(-1), \ 154 | (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) 155 | else: 156 | return (logps_margin * loss_mask).sum(-1), \ 157 | (per_position_kl * loss_mask).sum(-1), \ 158 | (per_token_logps * loss_mask).sum(-1) 159 | 160 | 161 | def concatenated_inputs(batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: 162 | """Concatenate the chosen and rejected inputs into a single tensor. 163 | 164 | Args: 165 | batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). 166 | 167 | Returns: 168 | A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. 169 | """ 170 | max_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1]) 171 | concatenated_batch = {} 172 | for k in batch: 173 | if k.startswith('chosen') and isinstance(batch[k], torch.Tensor): 174 | pad_value = -100 if 'labels' in k else 0 175 | concatenated_key = k.replace('chosen', 'concatenated') 176 | concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) 177 | for k in batch: 178 | if k.startswith('rejected') and isinstance(batch[k], torch.Tensor): 179 | pad_value = -100 if 'labels' in k else 0 180 | concatenated_key = k.replace('rejected', 'concatenated') 181 | concatenated_batch[concatenated_key] = torch.cat(( 182 | concatenated_batch[concatenated_key], 183 | pad_to_length(batch[k], max_length, pad_value=pad_value), 184 | ), dim=0) 185 | return concatenated_batch 186 | 187 | 188 | class BasicTrainer(object): 189 | def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str, 190 | reference_model: Optional[nn.Module] = None, rank: int = 0, world_size: int = 1): 191 | """A trainer for a language model, supporting either SFT or TDPO training. 192 | 193 | If multiple GPUs are present, naively splits the model across them, effectively 194 | offering N times available memory, but without any parallel computation. 
195 | """ 196 | self.seed = seed 197 | self.rank = rank 198 | self.world_size = world_size 199 | self.config = config 200 | self.run_dir = run_dir 201 | 202 | tokenizer_name_or_path = config.model.tokenizer_name_or_path or config.model.name_or_path 203 | rank0_print(f'Loading tokenizer {tokenizer_name_or_path}') 204 | self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name_or_path, 205 | cache_dir=get_local_dir(config.local_dirs)) 206 | if self.tokenizer.pad_token_id is None: 207 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 208 | 209 | data_iterator_kwargs = dict( 210 | names=config.datasets, 211 | tokenizer=self.tokenizer, 212 | shuffle=True, 213 | max_length=config.max_length, 214 | max_prompt_length=config.max_prompt_length, 215 | sft_mode=config.loss.name == 'sft', 216 | ) 217 | 218 | self.policy = policy 219 | self.reference_model = reference_model 220 | 221 | self.train_iterator = get_batch_iterator(**data_iterator_kwargs, split='train', n_epochs=config.n_epochs, 222 | n_examples=config.n_examples, batch_size=config.batch_size, 223 | silent=rank != 0, cache_dir=get_local_dir(config.local_dirs)) 224 | rank0_print(f'Loaded train data iterator') 225 | self.eval_iterator = get_batch_iterator(**data_iterator_kwargs, split='test', n_examples=config.n_eval_examples, 226 | batch_size=config.eval_batch_size, silent=rank != 0, 227 | cache_dir=get_local_dir(config.local_dirs)) 228 | self.eval_batches = list(self.eval_iterator) 229 | rank0_print(f'Loaded {len(self.eval_batches)} eval batches of size {config.eval_batch_size}') 230 | 231 | def get_batch_samples(self, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: 232 | """Generate samples from the policy (and reference model, if doing TDPO training) for the given batch of inputs.""" 233 | 234 | # FSDP generation according to https://github.com/pytorch/pytorch/issues/100069 235 | ctx = lambda: (FSDP.summon_full_params(self.policy, writeback=False, 236 | recurse=False) if 'FSDP' in self.config.trainer else contextlib.nullcontext()) 237 | with ctx(): 238 | policy_output = self.policy.generate( 239 | batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], 240 | max_length=self.config.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id) 241 | 242 | if self.config.loss.name == 'tdpo': 243 | ctx = lambda: (FSDP.summon_full_params(self.reference_model, writeback=False, 244 | recurse=False) if 'FSDP' in self.config.trainer else contextlib.nullcontext()) 245 | with ctx(): 246 | reference_output = self.reference_model.generate( 247 | batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], 248 | max_length=self.config.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id) 249 | 250 | policy_output = pad_to_length(policy_output, self.config.max_length, self.tokenizer.pad_token_id) 251 | policy_output = all_gather_if_needed(policy_output, self.rank, self.world_size) 252 | policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) 253 | 254 | if self.config.loss.name == 'tdpo': 255 | reference_output = pad_to_length(reference_output, self.config.max_length, self.tokenizer.pad_token_id) 256 | reference_output = all_gather_if_needed(reference_output, self.rank, self.world_size) 257 | reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True) 258 | else: 259 | reference_output_decoded = [] 260 | 261 | return policy_output_decoded, reference_output_decoded 262 | 263 | def 
tdpo_concatenated_forward(self, model: nn.Module, reference_model: nn.Module, 264 | batch: Dict[str, Union[List, torch.LongTensor]]): 265 | """Run the policy model and the reference model on the given batch of inputs, concatenating the chosen and rejected inputs together. 266 | 267 | We do this to avoid doing two forward passes, because it's faster for FSDP. 268 | """ 269 | concatenated_batch = concatenated_inputs(batch) 270 | all_logits = model(concatenated_batch['concatenated_input_ids'], 271 | attention_mask=concatenated_batch['concatenated_attention_mask']).logits.to(torch.float32) 272 | with torch.no_grad(): 273 | reference_all_logits = reference_model(concatenated_batch['concatenated_input_ids'], 274 | attention_mask=concatenated_batch[ 275 | 'concatenated_attention_mask']).logits.to(torch.float32) 276 | all_logps_margin, all_position_kl, all_logps = _tdpo_get_batch_logps(all_logits, reference_all_logits, concatenated_batch['concatenated_labels'], average_log_prob=False) 277 | 278 | chosen_logps_margin = all_logps_margin[:batch['chosen_input_ids'].shape[0]] 279 | rejected_logps_margin = all_logps_margin[batch['chosen_input_ids'].shape[0]:] 280 | chosen_position_kl = all_position_kl[:batch['chosen_input_ids'].shape[0]] 281 | rejected_position_kl = all_position_kl[batch['chosen_input_ids'].shape[0]:] 282 | 283 | chosen_logps = all_logps[:batch['chosen_input_ids'].shape[0]].detach() 284 | rejected_logps = all_logps[batch['chosen_input_ids'].shape[0]:].detach() 285 | 286 | return chosen_logps_margin, rejected_logps_margin, chosen_position_kl, rejected_position_kl, \ 287 | chosen_logps, rejected_logps 288 | 289 | def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True): 290 | """Compute the SFT or TDPO loss and other metrics for the given batch of inputs.""" 291 | 292 | metrics = {} 293 | train_test = 'train' if train else 'eval' 294 | 295 | if loss_config.name == 'tdpo': 296 | chosen_logps_margin, rejected_logps_margin, chosen_position_kl, rejected_position_kl, policy_chosen_logps, policy_rejected_logps\ 297 | = self.tdpo_concatenated_forward(self.policy, self.reference_model, batch) 298 | losses, chosen_rewards, rejected_rewards = tdpo_loss(chosen_logps_margin, rejected_logps_margin, 299 | chosen_position_kl, rejected_position_kl, 300 | beta=loss_config.beta, alpha=loss_config.alpha, if_tdpo2=loss_config.if_tdpo2) 301 | 302 | reward_accuracies = (chosen_rewards > rejected_rewards).float() 303 | 304 | chosen_rewards = all_gather_if_needed(chosen_rewards, self.rank, self.world_size) 305 | rejected_rewards = all_gather_if_needed(rejected_rewards, self.rank, self.world_size) 306 | reward_accuracies = all_gather_if_needed(reward_accuracies, self.rank, self.world_size) 307 | 308 | metrics[f'rewards_{train_test}/chosen'] = chosen_rewards.cpu().numpy().tolist() 309 | metrics[f'rewards_{train_test}/rejected'] = rejected_rewards.cpu().numpy().tolist() 310 | metrics[f'rewards_{train_test}/accuracies'] = reward_accuracies.cpu().numpy().tolist() 311 | metrics[f'rewards_{train_test}/margins'] = (chosen_rewards - rejected_rewards).cpu().numpy().tolist() 312 | 313 | all_device_chosen_position_kl = all_gather_if_needed(chosen_position_kl.detach(), self.rank, self.world_size) 314 | all_device_rejected_position_kl = all_gather_if_needed(rejected_position_kl.detach(), self.rank, self.world_size) 315 | 316 | metrics[f'kl_{train_test}/chosen'] = all_device_chosen_position_kl.cpu().numpy().tolist() 317 | metrics[f'kl_{train_test}/rejected'] = 
all_device_rejected_position_kl.cpu().numpy().tolist() 318 | metrics[f'kl_{train_test}/margin'] = (all_device_chosen_position_kl - all_device_rejected_position_kl).cpu().numpy().tolist() 319 | 320 | policy_rejected_logps = all_gather_if_needed(policy_rejected_logps.detach(), self.rank, self.world_size) 321 | metrics[f'logps_{train_test}/rejected'] = policy_rejected_logps.cpu().numpy().tolist() 322 | 323 | elif loss_config.name == 'sft': 324 | policy_chosen_logits = self.policy(batch['chosen_input_ids'], 325 | attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32) 326 | policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'], average_log_prob=False) 327 | 328 | losses = -policy_chosen_logps 329 | 330 | policy_chosen_logps = all_gather_if_needed(policy_chosen_logps.detach(), self.rank, self.world_size) 331 | metrics[f'logps_{train_test}/chosen'] = policy_chosen_logps.cpu().numpy().tolist() 332 | 333 | all_devices_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size) 334 | metrics[f'loss/{train_test}'] = all_devices_losses.cpu().numpy().tolist() 335 | 336 | return losses.mean(), metrics 337 | 338 | def train(self): 339 | """Begin either SFT or TDPO training, with periodic evaluation.""" 340 | 341 | rank0_print(f'Using {self.config.optimizer} optimizer') 342 | self.optimizer = getattr(torch.optim, self.config.optimizer)(self.policy.parameters(), lr=self.config.lr) 343 | self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda step: min(1.0, 344 | (step + 1) / ( 345 | self.config.warmup_steps + 1))) 346 | 347 | torch.manual_seed(self.seed) 348 | np.random.seed(self.seed) 349 | random.seed(self.seed) 350 | 351 | if self.config.loss.name == 'tdpo': 352 | self.reference_model.eval() 353 | 354 | self.example_counter = 0 355 | self.batch_counter = 0 356 | last_log = None 357 | 358 | for batch in self.train_iterator: 359 | #### BEGIN EVALUATION #### 360 | if self.example_counter % self.config.eval_every == 0 and ( 361 | self.example_counter > 0 or self.config.do_first_eval): 362 | rank0_print(f'Running evaluation after {self.example_counter} train examples') 363 | self.policy.eval() 364 | 365 | all_eval_metrics = defaultdict(list) 366 | if self.config.sample_during_eval: 367 | all_policy_samples, all_reference_samples = [], [] 368 | policy_text_table = wandb.Table(columns=["step", "prompt", "sample"]) 369 | if self.config.loss.name in 'tdpo': 370 | reference_text_table = wandb.Table(columns=["step", "prompt", "sample"]) 371 | 372 | for eval_batch in ( 373 | tqdm.tqdm(self.eval_batches, desc='Computing eval metrics') if self.rank == 0 else self.eval_batches): 374 | local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank, self.world_size, 375 | self.rank) 376 | with torch.no_grad(): 377 | _, eval_metrics = self.get_batch_metrics(local_eval_batch, self.config.loss, train=False) 378 | 379 | for k, v in eval_metrics.items(): 380 | all_eval_metrics[k].extend(v) 381 | 382 | if self.config.sample_during_eval: 383 | if self.config.n_eval_model_samples < self.config.eval_batch_size: 384 | rank0_print( 385 | f'Warning: n_eval_model_samples ({self.config.n_eval_model_samples}) < eval_batch_size ({self.config.eval_batch_size}). 
Sampling from the first complete eval batch of prompts.') 386 | sample_batches = self.eval_batches[:1] 387 | else: 388 | n_sample_batches = self.config.n_eval_model_samples // self.config.eval_batch_size 389 | sample_batches = self.eval_batches[:n_sample_batches] 390 | for eval_batch in ( 391 | tqdm.tqdm(sample_batches, desc='Generating samples...') if self.rank == 0 else sample_batches): 392 | local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank, self.world_size, 393 | self.rank) 394 | policy_samples, reference_samples = self.get_batch_samples(local_eval_batch) 395 | 396 | all_policy_samples.extend(policy_samples) 397 | all_reference_samples.extend(reference_samples) 398 | 399 | for prompt, sample in zip(eval_batch['prompt'], policy_samples): 400 | policy_text_table.add_data(self.example_counter, prompt, sample) 401 | if self.config.loss.name == 'tdpo': 402 | for prompt, sample in zip(eval_batch['prompt'], reference_samples): 403 | reference_text_table.add_data(self.example_counter, prompt, sample) 404 | 405 | mean_eval_metrics = {k: sum(v) / len(v) for k, v in all_eval_metrics.items()} 406 | rank0_print(f'eval after {self.example_counter}: {formatted_dict(mean_eval_metrics)}') 407 | if self.config.sample_during_eval: 408 | rank0_print(json.dumps(all_policy_samples[:10], indent=2)) 409 | if self.config.loss.name == 'tdpo': 410 | rank0_print(json.dumps(all_reference_samples[:10], indent=2)) 411 | 412 | if self.config.wandb.enabled and self.rank == 0: 413 | wandb.log(mean_eval_metrics, step=self.example_counter) 414 | 415 | if self.config.sample_during_eval: 416 | wandb.log({"policy_samples": policy_text_table}, step=self.example_counter) 417 | if self.config.loss.name == 'tdpo': 418 | wandb.log({"reference_samples": reference_text_table}, step=self.example_counter) 419 | 420 | if self.example_counter > 0: 421 | if self.config.debug: 422 | rank0_print('skipping save in debug mode') 423 | else: 424 | output_dir = os.path.join(self.run_dir, f'step-{self.example_counter}') 425 | rank0_print(f'creating checkpoint to write to {output_dir}...') 426 | self.save(output_dir, mean_eval_metrics) 427 | #### END EVALUATION #### 428 | 429 | #### BEGIN TRAINING #### 430 | self.policy.train() 431 | 432 | start_time = time.time() 433 | batch_metrics = defaultdict(list) 434 | for microbatch_idx in range(self.config.gradient_accumulation_steps): 435 | global_microbatch = slice_and_move_batch_for_device(batch, microbatch_idx, 436 | self.config.gradient_accumulation_steps, self.rank) 437 | local_microbatch = slice_and_move_batch_for_device(global_microbatch, self.rank, self.world_size, 438 | self.rank) 439 | loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True) 440 | (loss / self.config.gradient_accumulation_steps).backward() 441 | 442 | for k, v in metrics.items(): 443 | batch_metrics[k].extend(v) 444 | 445 | grad_norm = self.clip_gradient() 446 | self.optimizer.step() 447 | self.scheduler.step() 448 | self.optimizer.zero_grad() 449 | 450 | step_time = time.time() - start_time 451 | examples_per_second = self.config.batch_size / step_time 452 | batch_metrics['examples_per_second'].append(examples_per_second) 453 | batch_metrics['grad_norm'].append(grad_norm) 454 | 455 | self.batch_counter += 1 456 | self.example_counter += self.config.batch_size 457 | 458 | if last_log is None or time.time() - last_log > self.config.minimum_log_interval_secs: 459 | mean_train_metrics = {k: sum(v) / len(v) for k, v in batch_metrics.items()} 460 | 
mean_train_metrics['counters/examples'] = self.example_counter 461 | mean_train_metrics['counters/updates'] = self.batch_counter 462 | rank0_print(f'train stats after {self.example_counter} examples: {formatted_dict(mean_train_metrics)}') 463 | 464 | if self.config.wandb.enabled and self.rank == 0: 465 | wandb.log(mean_train_metrics, step=self.example_counter) 466 | 467 | last_log = time.time() 468 | else: 469 | rank0_print(f'skipping logging after {self.example_counter} examples to avoid logging too frequently') 470 | #### END TRAINING #### 471 | 472 | def clip_gradient(self): 473 | """Clip the gradient norm of the parameters of a non-FSDP policy.""" 474 | return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.max_grad_norm).item() 475 | 476 | def write_state_dict(self, step: int, state: Dict[str, torch.Tensor], metrics: Dict, filename: str, 477 | dir_name: Optional[str] = None): 478 | """Write a checkpoint to disk.""" 479 | if dir_name is None: 480 | dir_name = os.path.join(self.run_dir, f'LATEST') 481 | 482 | os.makedirs(dir_name, exist_ok=True) 483 | output_path = os.path.join(dir_name, filename) 484 | rank0_print(f'writing checkpoint to {output_path}...') 485 | torch.save({ 486 | 'step_idx': step, 487 | 'state': state, 488 | 'metrics': metrics if metrics is not None else {}, 489 | }, output_path) 490 | 491 | def save(self, output_dir: Optional[str] = None, metrics: Optional[Dict] = None): 492 | """Save policy, optimizer, and scheduler state to disk.""" 493 | 494 | policy_state_dict = self.policy.state_dict() 495 | self.write_state_dict(self.example_counter, policy_state_dict, metrics, 'policy.pt', output_dir) 496 | del policy_state_dict 497 | 498 | optimizer_state_dict = self.optimizer.state_dict() 499 | self.write_state_dict(self.example_counter, optimizer_state_dict, metrics, 'optimizer.pt', output_dir) 500 | del optimizer_state_dict 501 | 502 | scheduler_state_dict = self.scheduler.state_dict() 503 | self.write_state_dict(self.example_counter, scheduler_state_dict, metrics, 'scheduler.pt', output_dir) 504 | 505 | 506 | class FSDPTrainer(BasicTrainer): 507 | def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str, 508 | reference_model: Optional[nn.Module] = None, rank: int = 0, world_size: int = 1): 509 | """A trainer subclass that uses PyTorch FSDP to shard the model across multiple GPUs. 510 | 511 | This trainer will shard both the policy and reference model across all available GPUs. 512 | Models are sharded at the block level, where the block class name is provided in the config. 
513 | """ 514 | 515 | super().__init__(policy, config, seed, run_dir, reference_model, rank, world_size) 516 | assert config.model.block_name is not None, 'must specify model.block_name (e.g., GPT2Block or GPTNeoXLayer) for FSDP' 517 | 518 | wrap_class = get_block_class_from_model(policy, config.model.block_name) 519 | model_auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={wrap_class}, ) 520 | 521 | shared_fsdp_kwargs = dict( 522 | auto_wrap_policy=model_auto_wrap_policy, 523 | sharding_strategy=ShardingStrategy.FULL_SHARD, 524 | cpu_offload=CPUOffload(offload_params=False), 525 | backward_prefetch=BackwardPrefetch.BACKWARD_PRE, 526 | device_id=rank, 527 | ignored_modules=None, 528 | limit_all_gathers=False, 529 | use_orig_params=False, 530 | sync_module_states=False 531 | ) 532 | 533 | rank0_print('Sharding policy...') 534 | mp_dtype = getattr(torch, config.model.fsdp_policy_mp) if config.model.fsdp_policy_mp is not None else None 535 | policy_mp_policy = MixedPrecision(param_dtype=mp_dtype, reduce_dtype=mp_dtype, buffer_dtype=mp_dtype) 536 | self.policy = FSDP(policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy) 537 | 538 | if config.activation_checkpointing: 539 | rank0_print('Attempting to enable activation checkpointing...') 540 | try: 541 | # use activation checkpointing, according to: 542 | # https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/ 543 | # 544 | # first, verify we have FSDP activation support ready by importing: 545 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 546 | checkpoint_wrapper, 547 | apply_activation_checkpointing, 548 | CheckpointImpl, 549 | ) 550 | non_reentrant_wrapper = functools.partial( 551 | checkpoint_wrapper, 552 | offload_to_cpu=False, 553 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 554 | ) 555 | except Exception as e: 556 | rank0_print('FSDP activation checkpointing not available:', e) 557 | else: 558 | check_fn = lambda submodule: isinstance(submodule, wrap_class) 559 | rank0_print('Applying activation checkpointing wrapper to policy...') 560 | apply_activation_checkpointing(self.policy, checkpoint_wrapper_fn=non_reentrant_wrapper, 561 | check_fn=check_fn) 562 | rank0_print('FSDP activation checkpointing enabled!') 563 | 564 | if config.loss.name == 'tdpo': 565 | rank0_print('Sharding reference model...') 566 | self.reference_model = FSDP(reference_model, **shared_fsdp_kwargs) 567 | 568 | print('Loaded model on rank', rank) 569 | dist.barrier() 570 | 571 | def clip_gradient(self): 572 | """Clip the gradient norm of the parameters of an FSDP policy, gathering the gradients across all GPUs.""" 573 | return self.policy.clip_grad_norm_(self.config.max_grad_norm).item() 574 | 575 | def save(self, output_dir=None, metrics=None): 576 | """Save policy, optimizer, and scheduler state to disk, gathering from all processes and saving only on the rank 0 process.""" 577 | save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 578 | with FSDP.state_dict_type(self.policy, StateDictType.FULL_STATE_DICT, state_dict_config=save_policy): 579 | policy_state_dict = self.policy.state_dict() 580 | 581 | if self.rank == 0: 582 | self.write_state_dict(self.example_counter, policy_state_dict, metrics, 'policy.pt', output_dir) 583 | del policy_state_dict 584 | dist.barrier() 585 | 586 | save_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True) 587 | with FSDP.state_dict_type(self.policy, 
StateDictType.FULL_STATE_DICT, optim_state_dict_config=save_policy): 588 | optimizer_state_dict = FSDP.optim_state_dict(self.policy, self.optimizer) 589 | 590 | if self.rank == 0: 591 | self.write_state_dict(self.example_counter, optimizer_state_dict, metrics, 'optimizer.pt', output_dir) 592 | del optimizer_state_dict 593 | dist.barrier() 594 | 595 | if self.rank == 0: 596 | scheduler_state_dict = self.scheduler.state_dict() 597 | self.write_state_dict(self.example_counter, scheduler_state_dict, metrics, 'scheduler.pt', output_dir) 598 | dist.barrier() 599 | 600 | 601 | class TensorParallelTrainer(BasicTrainer): 602 | def __init__(self, policy, config, seed, run_dir, reference_model=None, rank=0, world_size=1): 603 | """A trainer subclass that uses TensorParallel to shard the model across multiple GPUs. 604 | 605 | Based on https://github.com/BlackSamorez/tensor_parallel. Note sampling is extremely slow, 606 | see https://github.com/BlackSamorez/tensor_parallel/issues/66. 607 | """ 608 | super().__init__(policy, config, seed, run_dir, reference_model, rank, world_size) 609 | 610 | rank0_print('Sharding policy...') 611 | self.policy = tp.tensor_parallel(policy, sharded=True) 612 | if config.loss.name == 'tdpo': 613 | rank0_print('Sharding reference model...') 614 | self.reference_model = tp.tensor_parallel(reference_model, sharded=False) 615 | 616 | def save(self, output_dir=None, metrics=None): 617 | """Save (unsharded) policy state to disk.""" 618 | with tp.save_tensor_parallel(self.policy): 619 | policy_state_dict = self.policy.state_dict() 620 | 621 | self.write_state_dict(self.example_counter, policy_state_dict, metrics, 'policy.pt', output_dir) 622 | del policy_state_dict 623 | --------------------------------------------------------------------------------
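A toy, self-contained invocation of tdpo_loss from trainers.py, useful for checking output shapes and the TDPO1/TDPO2 switch. This is a sketch only: it assumes trainers.py (and its dependencies from requirements.txt) is importable, and the margin and KL values below are arbitrary made-up numbers rather than outputs of a real model.

# Illustrative sketch only (not part of the repository files).
import torch
from trainers import tdpo_loss

# Pretend batch of two preference pairs.
chosen_logps_margin = torch.tensor([0.8, -0.2])    # log pi(y_w|x) - log pi_ref(y_w|x)
rejected_logps_margin = torch.tensor([-0.5, 0.1])  # log pi(y_l|x) - log pi_ref(y_l|x)
chosen_position_kl = torch.tensor([0.3, 0.4])      # summed per-position KL(pi_ref || pi) on the chosen response
rejected_position_kl = torch.tensor([0.6, 0.2])    # summed per-position KL(pi_ref || pi) on the rejected response

losses, chosen_rewards, rejected_rewards = tdpo_loss(
    chosen_logps_margin, rejected_logps_margin,
    chosen_position_kl, rejected_position_kl,
    beta=0.1, alpha=0.5, if_tdpo2=True,   # if_tdpo2=False gives the TDPO1 objective instead
)

print(losses.mean().item())                                        # scalar training loss for the toy batch
print((chosen_rewards > rejected_rewards).float().mean().item())   # reward accuracy, as logged by the trainer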