├── config ├── loss │ ├── sft.yaml │ ├── ift.yaml │ ├── orpo.yaml │ ├── dpo.yaml │ └── kto.yaml ├── model │ ├── qwen2-0_5b.yaml │ ├── llama7b.yaml │ ├── zephyr7b-sft.yaml │ ├── mistral7b.yaml │ └── blank_model.yaml └── config.yaml ├── requirements.txt ├── commands ├── run_mistral_ift.sh └── run_zephyr_ift.sh ├── README.md └── src ├── train.py ├── utils.py ├── dataloader.py └── trainers.py /config/loss/sft.yaml: -------------------------------------------------------------------------------- 1 | name: sft -------------------------------------------------------------------------------- /config/model/qwen2-0_5b.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: /root/pubmodels/Qwen2-0.5B 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: Qwen2DecoderLayer 5 | 6 | policy_dtype: bfloat16 7 | fsdp_policy_mp: bfloat16 8 | reference_dtype: bfloat16 9 | 10 | lora: false -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ipykernel==6.23.1 2 | numpy==1.24.3 3 | tokenizers==0.13.3 4 | torch==2.0.1 5 | tqdm==4.65.0 6 | transformers==4.29.2 7 | datasets==2.12.0 8 | beautifulsoup4==4.12.2 9 | wandb==0.15.3 10 | hydra-core==1.3.2 11 | tensor-parallel==1.2.4 12 | torch-discounted-cumsum==1.1.0 -------------------------------------------------------------------------------- /config/model/llama7b.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: /root/pubmodels/transformers/llama-2/llama-2-7b-chat-hf 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: LlamaDecoderLayer 5 | 6 | policy_dtype: bfloat16 7 | fsdp_policy_mp: bfloat16 8 | reference_dtype: bfloat16 9 | 10 | lora: true -------------------------------------------------------------------------------- /config/model/zephyr7b-sft.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: /root/pubmodels/transformers/chat-models/mistral-7b-sft-beta 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: MistralDecoderLayer 5 | 6 | policy_dtype: bfloat16 7 | fsdp_policy_mp: bfloat16 8 | reference_dtype: bfloat16 9 | 10 | lora: false -------------------------------------------------------------------------------- /config/loss/ift.yaml: -------------------------------------------------------------------------------- 1 | # do DPO preference-based training 2 | name: ift 3 | 4 | # Temporal Residual Connection 5 | min_lambda: 0.2 6 | max_lambda: 0.2 7 | lambda_disturb: null 8 | disturb_std: 0.5 9 | lambda_schedule: null 10 | 11 | # Relation Propagation 12 | gamma: 0.95 13 | propagation_type: loss 14 | propagation_side: left 15 | propagation_norm: L1 -------------------------------------------------------------------------------- /config/model/mistral7b.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: /root/pubmodels/transformers/mistralai/Mistral-7B-v0.1 2 | tokenizer_name_or_path: /root/pubmodels/transformers/chat-models/mistral-7b-sft-beta 3 | archive: null 4 | block_name: MistralDecoderLayer 5 | 6 | policy_dtype: bfloat16 7 | fsdp_policy_mp: bfloat16 8 | reference_dtype: bfloat16 9 | 10 | lora: false -------------------------------------------------------------------------------- /config/loss/orpo.yaml: 
-------------------------------------------------------------------------------- 1 | # do DPO preference-based training 2 | name: orpo 3 | 4 | # the temperature parameter for DPO; lower values mean we care less about 5 | # the reference model 6 | beta: 0.25 7 | # alpha: 1.0 8 | # the noise parameter for conservative DPO; should be in range (0, 0.5); interpreted as 9 | # the fraction of preference pairs that are flipped 10 | # eps=0 is the original DPO loss in the DPO paper 11 | label_smoothing: 0 12 | 13 | # if true, use a uniform (maximum entropy) reference model 14 | reference_free: false -------------------------------------------------------------------------------- /config/loss/dpo.yaml: -------------------------------------------------------------------------------- 1 | # do DPO preference-based training 2 | name: dpo 3 | 4 | # the temperature parameter for DPO; lower values mean we care less about 5 | # the reference model 6 | beta: ??? 7 | gamma: 1.0 8 | # the noise parameter for conservative DPO; should be in range (0, 0.5); interpreted as 9 | # the fraction of preference pairs that are flipped 10 | # eps=0 is the original DPO loss in the DPO paper 11 | label_smoothing: 0 12 | 13 | # if true, use a uniform (maximum entropy) reference model 14 | reference_free: false 15 | rejected_free: false -------------------------------------------------------------------------------- /config/loss/kto.yaml: -------------------------------------------------------------------------------- 1 | # do DPO preference-based training 2 | name: kto 3 | 4 | # the temperature parameter for DPO; lower values mean we care less about 5 | # the reference model 6 | beta: 0.1 7 | # alpha: 1.0 8 | # the noise parameter for conservative DPO; should be in range (0, 0.5); interpreted as 9 | # the fraction of preference pairs that are flipped 10 | # eps=0 is the original DPO loss in the DPO paper 11 | label_smoothing: 0 12 | 13 | # if true, use a uniform (maximum entropy) reference model 14 | reference_free: false 15 | 16 | desirable_weight: 1.0 17 | undesirable_weight: 1.0 -------------------------------------------------------------------------------- /config/model/blank_model.yaml: -------------------------------------------------------------------------------- 1 | # the name of the model to use; should be something like 2 | # gpt2-xl or gpt-neo-2.7B or huggyllama/llama-7b 3 | name_or_path: ??? 4 | 5 | # the name of the tokenizer to use; if null, will use the tokenizer from the model 6 | tokenizer_name_or_path: null 7 | 8 | # override pre-trained weights (e.g., from SFT); optional 9 | archive: null 10 | 11 | # the name of the module class to wrap with FSDP; should be something like 12 | # e.g. GPT2Block, GPTNeoXLayer, LlamaDecoderLayer, etc. 
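# the model configs bundled in this repo use Qwen2DecoderLayer, LlamaDecoderLayer, and MistralDecoderLayer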
13 | block_name: null 14 | 15 | # the dtype for the policy parameters/optimizer state 16 | policy_dtype: float32 17 | 18 | # the mixed precision dtype if using FSDP; defaults to the same as the policy 19 | fsdp_policy_mp: null 20 | 21 | # the dtype for the reference model (which is used for inference only) 22 | reference_dtype: float16 23 | 24 | # whether to use LoRA (Layer-wise Recalibration); defaults to false 25 | lora: false -------------------------------------------------------------------------------- /commands/run_mistral_ift.sh: -------------------------------------------------------------------------------- 1 | # Equivalent to SFT, but will be slower than setting loss=sft and loss.name=sft 2 | python -u \ 3 | src/train.py \ 4 | model=mistral7b \ 5 | datasets=[ultrachat,ultrafeedback] \ 6 | exp_name=ultra_ift_mistral7b_fsdp \ 7 | loss=ift \ 8 | loss.name=ift \ 9 | loss.gamma=0.00 \ 10 | loss.min_lambda=0.0 \ 11 | loss.max_lambda=0.0 \ 12 | n_epochs=3 \ 13 | n_examples=61136 \ 14 | batch_size=512 \ 15 | gradient_accumulation_steps=64 \ 16 | eval_batch_size=32 \ 17 | lr=5e-7 \ 18 | warmup_ratio=0.15 \ 19 | max_prompt_length=1024 \ 20 | max_length=1024 \ 21 | trainer=FSDPTrainer \ 22 | optimizer=RMSprop \ 23 | lr_scheduler=cosine 24 | 25 | # IFT 26 | python -u \ 27 | src/train.py \ 28 | model=mistral7b \ 29 | datasets=[ultrachat,ultrafeedback] \ 30 | exp_name=ultra_ift_mistral7b_fsdp \ 31 | loss=ift \ 32 | loss.name=ift \ 33 | loss.gamma=0.95 \ 34 | loss.min_lambda=0.2 \ 35 | loss.max_lambda=0.2 \ 36 | loss.propagation_type=loss \ 37 | loss.propagation_norm=L1 \ 38 | loss.propagation_side=left \ 39 | n_epochs=3 \ 40 | n_examples=61136 \ 41 | batch_size=512 \ 42 | gradient_accumulation_steps=64 \ 43 | eval_batch_size=32 \ 44 | lr=5e-7 \ 45 | warmup_ratio=0.15 \ 46 | max_prompt_length=1024 \ 47 | max_length=1024 \ 48 | trainer=FSDPTrainer \ 49 | optimizer=RMSprop \ 50 | lr_scheduler=cosine 51 | 52 | # IFT with trained checkpoint 53 | python -u \ 54 | src/train.py \ 55 | model=mistral7b \ 56 | datasets=[ultrachat,ultrafeedback] \ 57 | exp_name=ultra_ift_mistral7b_fsdp \ 58 | loss=ift \ 59 | loss.name=ift \ 60 | loss.gamma=0.95 \ 61 | loss.min_lambda=0.2 \ 62 | loss.max_lambda=0.2 \ 63 | loss.propagation_type=loss \ 64 | loss.propagation_norm=L1 \ 65 | loss.propagation_side=left \ 66 | n_epochs=3 \ 67 | n_examples=61136 \ 68 | batch_size=512 \ 69 | gradient_accumulation_steps=64 \ 70 | eval_batch_size=32 \ 71 | lr=5e-7 \ 72 | warmup_ratio=0.15 \ 73 | max_prompt_length=1024 \ 74 | max_length=1024 \ 75 | trainer=FSDPTrainer \ 76 | optimizer=RMSprop \ 77 | lr_scheduler=cosine \ 78 | checkpoint_path=path/to/checkpoint 79 | -------------------------------------------------------------------------------- /commands/run_zephyr_ift.sh: -------------------------------------------------------------------------------- 1 | # Equivalent to SFT, but will be slower than setting loss=sft and loss.name=sft 2 | python -u \ 3 | src/train.py \ 4 | model=zephyr7b-sft \ 5 | datasets=[ultrafeedback] \ 6 | exp_name=ultrafeedback_ift_zephyr7b-sft_fsdp_sample \ 7 | loss=ift \ 8 | loss.name=ift \ 9 | loss.gamma=0.00 \ 10 | loss.min_lambda=0.0 \ 11 | loss.max_lambda=0.0 \ 12 | n_epochs=3 \ 13 | n_examples=61136 \ 14 | batch_size=512 \ 15 | gradient_accumulation_steps=64 \ 16 | eval_batch_size=32 \ 17 | lr=5e-7 \ 18 | warmup_ratio=0.15 \ 19 | max_prompt_length=1024 \ 20 | max_length=1024 \ 21 | trainer=FSDPTrainer \ 22 | optimizer=RMSprop \ 23 | lr_scheduler=cosine 24 | 25 | # IFT 26 | python -u \ 27 | src/train.py \ 
28 | model=zephyr7b-sft \ 29 | datasets=[ultrafeedback] \ 30 | exp_name=ultrafeedback_ift_zephyr7b-sft_fsdp_sample \ 31 | loss=ift \ 32 | loss.name=ift \ 33 | loss.gamma=0.95 \ 34 | loss.min_lambda=0.2 \ 35 | loss.max_lambda=0.2 \ 36 | loss.propagation_type=loss \ 37 | loss.propagation_norm=L1 \ 38 | loss.propagation_side=left \ 39 | n_epochs=3 \ 40 | n_examples=61136 \ 41 | batch_size=512 \ 42 | gradient_accumulation_steps=64 \ 43 | eval_batch_size=32 \ 44 | lr=5e-7 \ 45 | warmup_ratio=0.15 \ 46 | max_prompt_length=1024 \ 47 | max_length=1024 \ 48 | trainer=FSDPTrainer \ 49 | optimizer=RMSprop \ 50 | lr_scheduler=cosine 51 | 52 | # IFT with trained checkpoint 53 | python -u \ 54 | src/train.py \ 55 | model=zephyr7b-sft \ 56 | datasets=[ultrafeedback] \ 57 | exp_name=ultrafeedback_ift_zephyr7b-sft_fsdp_sample \ 58 | loss=ift \ 59 | loss.name=ift \ 60 | loss.gamma=0.95 \ 61 | loss.min_lambda=0.2 \ 62 | loss.max_lambda=0.2 \ 63 | loss.propagation_type=loss \ 64 | loss.propagation_norm=L1 \ 65 | loss.propagation_side=left \ 66 | n_epochs=3 \ 67 | n_examples=61136 \ 68 | batch_size=512 \ 69 | gradient_accumulation_steps=64 \ 70 | eval_batch_size=32 \ 71 | lr=5e-7 \ 72 | warmup_ratio=0.15 \ 73 | max_prompt_length=1024 \ 74 | max_length=1024 \ 75 | trainer=FSDPTrainer \ 76 | optimizer=RMSprop \ 77 | lr_scheduler=cosine \ 78 | checkpoint_path=path/to/checkpoint 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IFT: Intuitve Fine-Tuning 2 | 3 | ## Overview 4 | 5 | This repository contains the code for the paper "Intuitive Fine-Tuning: Towards Simplifying Alignment into a Single Process". 6 | 7 | The code is based on the [eric-mitchell/direct-preference-optimization](https://github.com/eric-mitchell/direct-preference-optimization) repository. 8 | 9 | ## Setup 10 | 11 | pip install -r requirements.txt 12 | 13 | ## Running IFT 14 | 15 | bash commands/run_mistral_ift.sh 16 | 17 | ## Hyperparameters 18 | 19 | * `Temporal Residual Connection`: 20 | * `lambda_schedule`: The schedule mode of `lambda`. The default value is set to `null`, which means the static mode. `linear` mode is also provided for the dynamic mode. 21 | * `min_lambda` & `max_lambda`: The minimum value of `lambda`. The default value of both is set to 0.2, which means the static mode. If the `lambda_schedule` is set to `linear`, the `min_lambda` and `max_lambda` will be used to control the start and end value of `lambda` during training. 22 | * `lambda_disturb`: The disturbance distribution of `lambda`. The default value is set to `null`, which means no disturbance. `normal` mode is also provided for the disturbance distribution. 23 | * `disturb_std`: The standard deviation of the `lambda_disturb`. This hyperparameter is only worked when the `lambda_disturb` is not `null`. 24 | 25 | * `Relation Propagation`: 26 | * `gamma`: The decay factor of the Relation Propagation. The default value is set to 0.95. 27 | * `propagation_type`: The variable attribute to Relation Propagation. The default value is set to `loss`. `mask` and `logps` are also provided for the variable attribute. 28 | * `propagation_side`: The side of the Relation Propagation. The default value is set to `left`. `right` is also provided for the side of the Relation Propagation. 29 | * `propagation_norm`: The normalization mode of the Relation Propagation. The default value is set to `L1`. 
`L2`, `softmax` and `log` are also provided for the normalization mode. 30 | 31 | # Citing IFT 32 | 33 | If you find IFT useful in your research, please consider citing the following paper: 34 | 35 | @article{ 36 | hua2024intuitive, 37 | title={Intuitive Fine-Tuning: Towards Simplifying Alignment into a Single Process}, 38 | author={Hua, Ermo and Qi, Biqing and Zhang, Kaiyan and Yu, Yue and Ding, Ning and Lv, Xingtai and Tian, Kai and Zhou, Bowen}, 39 | journal={arXiv preprint arXiv:2405.11870}, 40 | year={2024} 41 | } -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | # random seed for batch sampling 2 | seed: 0 3 | 4 | # name for this experiment in the local run directory and on wandb 5 | exp_name: default 6 | 7 | # the batch size for training; for FSDP, the batch size per GPU is batch_size / (grad_accumulation_steps * num_gpus) 8 | batch_size: 4 9 | 10 | # the batch size during evaluation and sampling, if enabled 11 | eval_batch_size: 16 12 | 13 | # debug mode (disables wandb, model checkpointing, etc.) 14 | debug: false 15 | 16 | # the port to use for FSDP 17 | fsdp_port: null 18 | 19 | # the port to use for DeepSpeed 20 | deepspeed_port: null 21 | 22 | # which dataset(s) to train on; can pass a list like datasets=[hh,shp] 23 | datasets: 24 | - hh 25 | 26 | # wandb configuration 27 | wandb: 28 | enabled: true 29 | entity: null 30 | project: "intuitive-fine-tuning" 31 | 32 | # to create the local run directory and cache models/datasets, 33 | # we will try each of these directories in order; if none exist, 34 | # we will create the last one and use it 35 | local_dirs: 36 | - /scr-ssd 37 | - /scr 38 | - .cache 39 | 40 | # whether or not to generate samples during evaluation; disable for FSDP/TensorParallel 41 | # is recommended, because they are slow 42 | sample_during_eval: false 43 | 44 | # how many model samples to generate during evaluation 45 | n_eval_model_samples: 16 46 | 47 | # whether to eval at the very beginning of training 48 | do_first_eval: true 49 | 50 | # an OmegaConf resolver that returns the local run directory, calling a function in utils.py 51 | local_run_dir: ${get_local_run_dir:${exp_name},${local_dirs}} 52 | 53 | checkpoint_path: null 54 | # the learning rate 55 | lr: 5e-7 56 | 57 | # number of steps to accumulate over for each batch 58 | # (e.g. if batch_size=4 and gradient_accumulation_steps=2, then we will 59 | # accumulate gradients over 2 microbatches of size 2) 60 | gradient_accumulation_steps: 1 61 | 62 | # the maximum gradient norm to clip to 63 | max_grad_norm: 10.0 64 | 65 | # the maximum allowed length for an input (prompt + response) 66 | max_length: 512 67 | 68 | # the maximum allowed length for a prompt 69 | max_prompt_length: 256 70 | 71 | # the number of epochs to train for; if null, must specify n_examples 72 | n_epochs: 1 73 | 74 | # the number of examples to train for; if null, must specify n_epochs 75 | n_examples: null 76 | 77 | # the number of examples to evaluate on (and sample from, if sample_during_eval is true) 78 | n_eval_examples: 256 79 | 80 | # the trainer class to use (e.g. 
BasicTrainer, FSDPTrainer, TensorParallelTrainer) 81 | trainer: BasicTrainer 82 | 83 | # The optimizer to use; we use RMSprop because it works about as well as Adam and is more memory-efficient 84 | optimizer: RMSprop 85 | 86 | # The scheduler to use; we use a linear in dpo, cosine in sft, both of which don't do weight decay 87 | lr_scheduler: linear 88 | 89 | # number of linear warmup steps for the learning rate 90 | warmup_steps: 150 91 | 92 | # the ratio of warmup steps to total steps 93 | warmup_ratio: 0.1 94 | 95 | # whether or not to use activation/gradient checkpointing 96 | activation_checkpointing: false 97 | 98 | # evaluate and save model every eval_every steps 99 | eval_every: 20_000 100 | 101 | # prevent wandb from logging more than once per minimum_log_interval_secs 102 | minimum_log_interval_secs: 1.0 103 | 104 | flash_attn: false 105 | 106 | defaults: 107 | - _self_ 108 | - model: blank_model_fp32 # basic model configuration 109 | - loss: sft # which loss function, either sft or dpo (specify loss.beta if using dpo) 110 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cuda.matmul.allow_tf32 = True 3 | import torch.nn as nn 4 | import transformers 5 | import peft 6 | from utils import get_local_dir, get_local_run_dir, disable_dropout, init_distributed, get_open_port 7 | import os 8 | import hydra 9 | import torch.multiprocessing as mp 10 | from omegaconf import OmegaConf, DictConfig 11 | import trainers 12 | import wandb 13 | import json 14 | import socket 15 | from typing import Optional, Set 16 | import resource 17 | 18 | # import deepspeed 19 | import argparse 20 | 21 | OmegaConf.register_new_resolver("get_local_run_dir", lambda exp_name, local_dirs: get_local_run_dir(exp_name, local_dirs)) 22 | 23 | 24 | def worker_main(rank: int, world_size: int, config: DictConfig, policy: nn.Module, policy_weak: nn.Module, reference_model: Optional[nn.Module] = None): 25 | """Main function for each worker process (may be only 1 for BasicTrainer/TensorParallelTrainer).""" 26 | if 'FSDP' in config.trainer: 27 | init_distributed(rank, world_size, port=config.fsdp_port) 28 | # elif 'DeepSpeed' in config.trainer: 29 | # deepspeed.init_distributed(dist_backend="nccl", rank=rank, world_size=world_size)#, distributed_port=config.deepspeed_port) 30 | 31 | if config.debug: 32 | wandb.init = lambda *args, **kwargs: None 33 | wandb.log = lambda *args, **kwargs: None 34 | 35 | if rank == 0 and config.wandb.enabled: 36 | os.environ['WANDB_CACHE_DIR'] = get_local_dir(config.local_dirs) 37 | wandb.init( 38 | entity=config.wandb.entity, 39 | project=config.wandb.project, 40 | config=OmegaConf.to_container(config), 41 | dir=get_local_dir(config.local_dirs), 42 | name=config.exp_name, 43 | ) 44 | 45 | TrainerClass = getattr(trainers, config.trainer) 46 | print(f'Creating trainer on process {rank} with world size {world_size}') 47 | trainer = TrainerClass(policy, config, config.seed, config.local_run_dir, policy_weak=policy_weak, reference_model=reference_model, rank=rank, world_size=world_size) 48 | 49 | trainer.train() 50 | trainer.save() 51 | 52 | 53 | @hydra.main(version_base=None, config_path="../config", config_name="config") 54 | def main(config: DictConfig): 55 | """Main entry point for training. Validates config, creates/initializes model(s), and kicks off worker process(es).""" 56 | # Resolve hydra references, e.g. 
so we don't re-compute the run directory 57 | OmegaConf.resolve(config) 58 | 59 | missing_keys: Set[str] = OmegaConf.missing_keys(config) 60 | if missing_keys: 61 | raise ValueError(f"Got missing keys in config:\n{missing_keys}") 62 | 63 | if config.eval_every % config.batch_size != 0: 64 | print('WARNING: eval_every must be divisible by batch_size') 65 | print('Setting eval_every to', config.eval_every - config.eval_every % config.batch_size) 66 | config.eval_every = config.eval_every - config.eval_every % config.batch_size 67 | 68 | if 'FSDP' in config.trainer and config.fsdp_port is None: 69 | free_port = get_open_port() 70 | print('no FSDP port specified; using open port for FSDP:', free_port) 71 | config.fsdp_port = free_port 72 | elif 'DeepSpeed' in config.trainer and config.deepspeed_port is None: 73 | free_port = get_open_port() 74 | print('no DeepSpeed port specified; using open port for DeepSpeed:', free_port) 75 | config.deepspeed_port = free_port 76 | 77 | print(OmegaConf.to_yaml(config)) 78 | 79 | config_path = os.path.join(config.local_run_dir, 'config.yaml') 80 | with open(config_path, 'w') as f: 81 | OmegaConf.save(config, f) 82 | 83 | print('=' * 80) 84 | print(f'Writing to {socket.gethostname()}:{config.local_run_dir}') 85 | print('=' * 80) 86 | 87 | os.environ['XDG_CACHE_HOME'] = get_local_dir(config.local_dirs) 88 | print('building policy') 89 | model_kwargs = {'device_map': 'balanced'} if config.trainer == 'BasicTrainer' else {} 90 | policy_dtype = getattr(torch, config.model.policy_dtype) 91 | policy = transformers.AutoModelForCausalLM.from_pretrained( 92 | config.model.name_or_path, 93 | cache_dir=get_local_dir(config.local_dirs), 94 | low_cpu_mem_usage=True, 95 | torch_dtype=policy_dtype, 96 | trust_remote_code=True, 97 | attn_implementation="flash_attention_2" if config.flash_attn else None, 98 | **model_kwargs) 99 | disable_dropout(policy) 100 | 101 | if config.loss.name in {'dpo', 'ipo', 'kto', 'orpo'} and config.loss.reference_free is False: 102 | print('building reference model') 103 | reference_model_dtype = getattr(torch, config.model.reference_dtype) 104 | reference_model = transformers.AutoModelForCausalLM.from_pretrained( 105 | config.model.name_or_path, cache_dir=get_local_dir(config.local_dirs), low_cpu_mem_usage=True, torch_dtype=reference_model_dtype, trust_remote_code=True, **model_kwargs) 106 | disable_dropout(reference_model) 107 | 108 | policy_weak = None 109 | else: 110 | policy_weak = None 111 | reference_model = None 112 | 113 | if config.model.archive is not None: 114 | state_dict = torch.load(config.model.archive, map_location='cpu') 115 | step, metrics = state_dict['step_idx'], state_dict['metrics'] 116 | print(f'loading pre-trained weights at step {step} from {config.model.archive} with metrics {json.dumps(metrics, indent=2)}') 117 | policy.load_state_dict(state_dict['state']) 118 | if config.loss.name in {'dpo', 'ipo', 'mypo'}: 119 | reference_model.load_state_dict(state_dict['state']) 120 | print('loaded pre-trained weights') 121 | 122 | if config.model.lora is True: 123 | loftq_config = peft.LoftQConfig(loftq_bits=4, 124 | loftq_iter=1) 125 | lora_config = peft.LoraConfig(init_lora_weights="loftq", 126 | loftq_config=loftq_config) 127 | policy = peft.get_peft_model(policy, lora_config) 128 | 129 | if 'FSDP' in config.trainer: 130 | world_size = torch.cuda.device_count() 131 | print('starting', world_size, 'processes for FSDP training') 132 | soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) 133 | 
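# raise the soft file-descriptor limit to the hard cap before mp.spawn; sharing model tensors across spawned workers can open many file handles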
resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) 134 | print(f'setting RLIMIT_NOFILE soft limit to {hard} from {soft}') 135 | mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, policy_weak, reference_model), join=True) 136 | elif 'DeepSpeed' in config.trainer: 137 | world_size = torch.cuda.device_count() 138 | print('starting', world_size, 'processes for DeepSpeed training') 139 | soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) 140 | resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) 141 | print(f'setting RLIMIT_NOFILE soft limit to {hard} from {soft}') 142 | worker_main(config.local_rank, world_size, config, policy, reference_model) 143 | else: 144 | print('starting single-process worker') 145 | worker_main(0, 1, config, policy, policy_weak, reference_model) 146 | 147 | if __name__ == '__main__': 148 | main() 149 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import getpass 3 | from datetime import datetime 4 | import torch 5 | import random 6 | import numpy as np 7 | import torch.distributed as dist 8 | # import deepspeed 9 | 10 | import inspect 11 | import importlib.util 12 | import socket 13 | import os 14 | from typing import Dict, Union, Type, List 15 | 16 | # print only on rank 0 17 | def rank0_print(*args, **kwargs): 18 | """Print, but only on rank 0.""" 19 | if not dist.is_initialized() or dist.get_rank() == 0: 20 | print(*args, **kwargs) 21 | 22 | def get_open_port(): 23 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 24 | s.bind(('', 0)) # bind to all interfaces and use an OS provided port 25 | return s.getsockname()[1] # return only the port number 26 | 27 | 28 | def get_remote_file(remote_path, local_path=None): 29 | hostname, path = remote_path.split(':') 30 | local_hostname = socket.gethostname() 31 | if hostname == local_hostname or hostname == local_hostname[:local_hostname.find('.')]: 32 | return path 33 | 34 | if local_path is None: 35 | local_path = path 36 | # local_path = local_path.replace('/scr-ssd', '/scr') 37 | if os.path.exists(local_path): 38 | return local_path 39 | local_dir = os.path.dirname(local_path) 40 | os.makedirs(local_dir, exist_ok=True) 41 | 42 | print(f'Copying {hostname}:{path} to {local_path}') 43 | os.system(f'scp {remote_path} {local_path}') 44 | return local_path 45 | 46 | 47 | def rank0_print(*args, **kwargs): 48 | """Print, but only on rank 0.""" 49 | if not dist.is_initialized() or dist.get_rank() == 0: 50 | print(*args, **kwargs) 51 | 52 | 53 | def get_local_dir(prefixes_to_resolve: List[str]) -> str: 54 | """Return the path to the cache directory for this user.""" 55 | for prefix in prefixes_to_resolve: 56 | if os.path.exists(prefix): 57 | return f"{prefix}/{getpass.getuser()}" 58 | os.makedirs(prefix) 59 | return f"{prefix}/{getpass.getuser()}" 60 | 61 | 62 | def get_local_run_dir(exp_name: str, local_dirs: List[str]) -> str: 63 | """Create a local directory to store outputs for this run, and return its path.""" 64 | now = datetime.now() 65 | timestamp = now.strftime("%Y-%m-%d_%H-%M-%S_%f") 66 | run_dir = f"{get_local_dir(local_dirs)}/{exp_name}_{timestamp}" 67 | os.makedirs(run_dir, exist_ok=True) 68 | return run_dir 69 | 70 | 71 | def slice_and_move_batch_for_device(batch: Dict, rank: int, world_size: int, device: str) -> Dict: 72 | """Slice a batch into chunks, and move each chunk to the specified device.""" 73 | chunk_size = 
len(list(batch.values())[0]) // world_size 74 | start = chunk_size * rank 75 | end = chunk_size * (rank + 1) 76 | sliced = {k: v[start:end] for k, v in batch.items()} 77 | on_device = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in sliced.items()} 78 | return on_device 79 | 80 | 81 | def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: 82 | if tensor.size(dim) >= length: 83 | return tensor 84 | else: 85 | pad_size = list(tensor.shape) 86 | pad_size[dim] = length - tensor.size(dim) 87 | return torch.cat([tensor, pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device)], dim=dim) 88 | 89 | 90 | def all_gather_if_needed(values: torch.Tensor, rank: int, world_size: int) -> torch.Tensor: 91 | """Gather and stack/cat values from all processes, if there are multiple processes.""" 92 | if world_size == 1: 93 | return values 94 | 95 | all_values = [torch.empty_like(values).to(rank) for _ in range(world_size)] 96 | dist.all_gather(all_values, values) 97 | cat_function = torch.cat if values.dim() > 0 else torch.stack 98 | return cat_function(all_values, dim=0) 99 | 100 | 101 | def formatted_dict(d: Dict) -> Dict: 102 | """Format a dictionary for printing.""" 103 | return {k: (f"{v:.5g}" if type(v) == float else v) for k, v in d.items()} 104 | 105 | 106 | def disable_dropout(model: torch.nn.Module): 107 | """Disable dropout in a model.""" 108 | for module in model.modules(): 109 | if isinstance(module, torch.nn.Dropout): 110 | module.p = 0 111 | 112 | def delete_dict(d: Dict): 113 | """Delete all items inside the dict.""" 114 | for k in list(d.keys()): 115 | del d[k] 116 | 117 | def print_gpu_memory(rank: int = None, message: str = ''): 118 | """Print the amount of GPU memory currently allocated for each GPU.""" 119 | if torch.cuda.is_available(): 120 | device_count = torch.cuda.device_count() 121 | for i in range(device_count): 122 | device = torch.device(f'cuda:{i}') 123 | allocated_bytes = torch.cuda.memory_allocated(device) 124 | if allocated_bytes == 0: 125 | continue 126 | print('*' * 40) 127 | print(f'[{message} rank {rank} ] GPU {i}: {allocated_bytes / 1024**2:.2f} MB') 128 | print('*' * 40) 129 | 130 | 131 | def get_block_class_from_model(model: torch.nn.Module, block_class_name: str) -> torch.nn.Module: 132 | """Get the class of a block from a model, using the block's class name.""" 133 | for module in model.modules(): 134 | if module.__class__.__name__ == block_class_name: 135 | return module.__class__ 136 | raise ValueError(f"Could not find block class {block_class_name} in model {model}") 137 | 138 | 139 | def get_block_class_from_model_class_and_block_name(model_class: Type, block_class_name: str) -> Type: 140 | filepath = inspect.getfile(model_class) 141 | assert filepath.endswith('.py'), f"Expected a .py file, got {filepath}" 142 | assert os.path.exists(filepath), f"File {filepath} does not exist" 143 | assert "transformers" in filepath, f"Expected a transformers model, got {filepath}" 144 | 145 | module_name = filepath[filepath.find('transformers'):].replace('/', '.')[:-3] 146 | print(f"Searching in file {filepath}, module {module_name} for class {block_class_name}") 147 | 148 | # Load the module dynamically 149 | spec = importlib.util.spec_from_file_location(module_name, filepath) 150 | module = importlib.util.module_from_spec(spec) 151 | spec.loader.exec_module(module) 152 | 153 | # Get the class dynamically 154 | class_ = getattr(module, block_class_name) 155 | print(f"Found class 
{class_} in module {module_name}") 156 | return class_ 157 | 158 | 159 | def init_distributed(rank: int, world_size: int, master_addr: str = 'localhost', port: int = 12355, backend: str = 'nccl'): 160 | print(rank, 'initializing distributed') 161 | os.environ["MASTER_ADDR"] = master_addr 162 | os.environ["MASTER_PORT"] = str(port) 163 | dist.init_process_group(backend, rank=rank, world_size=world_size) 164 | torch.cuda.set_device(rank) 165 | 166 | 167 | class TemporarilySeededRandom: 168 | def __init__(self, seed): 169 | """Temporarily set the random seed, and then restore it when exiting the context.""" 170 | self.seed = seed 171 | self.stored_state = None 172 | self.stored_np_state = None 173 | 174 | def __enter__(self): 175 | # Store the current random state 176 | self.stored_state = random.getstate() 177 | self.stored_np_state = np.random.get_state() 178 | 179 | # Set the random seed 180 | random.seed(self.seed) 181 | np.random.seed(self.seed) 182 | 183 | def __exit__(self, exc_type, exc_value, traceback): 184 | # Restore the random state 185 | random.setstate(self.stored_state) 186 | np.random.set_state(self.stored_np_state) 187 | -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import torch 3 | from torch.nn.utils.rnn import pad_sequence 4 | from collections import defaultdict 5 | import numpy as np 6 | 7 | import tqdm 8 | import re 9 | import random 10 | from typing import Dict, List, Optional, Tuple 11 | from dataclasses import dataclass, field 12 | from utils import rank0_print, get_local_dir, TemporarilySeededRandom 13 | import time 14 | import math 15 | 16 | @dataclass 17 | class Example: 18 | """ 19 | Class for an example in a preference or SFT dataset. If you want each prompt to be uniquely associated with an Example instance, save it in a dict. 20 | """ 21 | prompt: str = '' # prompt text 22 | chosen: str = '' # the chosen text 23 | rejected: str = '' # the rejected text # if truncation needed, keep the beginning (keep_start) or end (keep_end) (only override default for SHP) # the unformatted prompt (needed to recover instruction for AlpacaEval) 24 | 25 | class Dataset: 26 | """ 27 | A collection of Example instances, indexed by prompt. 
28 | """ 29 | def __init__(self, 30 | name, 31 | truncation_mode: str = 'keep_start'): 32 | self.name = name 33 | self.data = defaultdict(Example) 34 | self.truncation_mode = truncation_mode 35 | 36 | def __setitem__(self, key, value): 37 | if not isinstance(key, str): 38 | raise KeyError("key must be a string") 39 | 40 | if not isinstance(value, Example): 41 | raise ValueError("value must be a Example") 42 | 43 | self.data[key] = value 44 | 45 | def __getitem__(self, key): 46 | return self.data[key] 47 | 48 | def __len__(self): 49 | return len(self.data) 50 | 51 | def __iter__(self): 52 | return iter(self.data) 53 | 54 | def get_orca(split, preprocess, silent, cache_dir, n_examples=None) -> Dataset: 55 | dataset = datasets.load_dataset('/root/pubdatasets/orca_dpo_pairs', split=split, cache_dir=cache_dir) 56 | 57 | data = Dataset(name='orca', truncation_mode='keep_start') 58 | count = 0 59 | for row in tqdm.tqdm(dataset, desc='Processing Orca', disable=silent): 60 | title = row['prompt'] 61 | 62 | prompt = row['chosen'][:-1] 63 | if prompt[0]['role'] != 'system': 64 | prompt.insert(0, {"role": "system", "content": ""}) 65 | 66 | chosen = row['chosen'][-1:] 67 | rejected = row['rejected'][-1:] 68 | 69 | prompt, chosen, rejected = preprocess(prompt, chosen, rejected, truncation_mode='keep_start') 70 | 71 | data[title].prompt = prompt 72 | data[title].chosen = chosen 73 | data[title].rejected = rejected 74 | 75 | count += 1 76 | if n_examples is not None and count >= n_examples: 77 | break 78 | 79 | return data 80 | 81 | def get_ultrafeedback(split, preprocess, silent, cache_dir, n_examples=None) -> Dataset: 82 | split += "_prefs" 83 | 84 | dataset = datasets.load_dataset('/root/pubdatasets/UltraFeedback/others/HuggingFaceH4/ultrafeedback_binarized', split=split, cache_dir=cache_dir) 85 | 86 | data = Dataset(name='ultrafeedback', truncation_mode='keep_start') 87 | count = 0 88 | for row in tqdm.tqdm(dataset, desc='Processing UltraFeedback', disable=silent): 89 | title = row['prompt'] 90 | 91 | prompt = row['chosen'][:-1] 92 | if prompt[0]['role'] != 'system': 93 | prompt.insert(0, {"role": "system", "content": ""}) 94 | chosen = row['chosen'][-1:] 95 | rejected = row['rejected'][-1:] 96 | 97 | # prompt, chosen, rejected = preprocess(prompt, chosen, rejected, truncation_mode='keep_start') 98 | data[title].type = "pairwise" 99 | data[title].prompt = prompt 100 | data[title].chosen = chosen 101 | data[title].rejected = rejected 102 | 103 | count += 1 104 | if n_examples is not None and count >= n_examples: 105 | break 106 | 107 | return data 108 | 109 | def get_ultrachat(split, preprocess, silent, cache_dir, n_examples=None) -> Dataset: 110 | split += "_sft" 111 | 112 | dataset = datasets.load_dataset('/root/pubdatasets/UltraChat/others/HuggingFaceH4/ultrachat_200k', split=split, cache_dir=cache_dir) 113 | 114 | data = Dataset(name='ultrachat', truncation_mode='keep_start') 115 | count = 0 116 | for row in tqdm.tqdm(dataset, desc='Processing UltraChat', disable=silent): 117 | title = row['prompt'] 118 | 119 | prompt = row['messages'][:-1] 120 | if prompt[0]['role'] != 'system': 121 | prompt.insert(0, {"content": "", "role": "system"}) 122 | messages = row['messages'][-1:] 123 | 124 | 125 | # prompt, messages = preprocess(prompt, messages, truncation_mode='keep_start') 126 | data[title].type = "single" 127 | data[title].prompt = prompt 128 | data[title].chosen = messages 129 | data[title].rejected = messages 130 | 131 | count += 1 132 | if n_examples is not None and count >= n_examples: 133 | 
break 134 | 135 | return data 136 | 137 | def get_orcamath(split, preprocess, silent, cache_dir, n_examples=None) -> Dataset: 138 | # dataset = datasets.load_dataset('microsoft/orca-math-word-problems-200k', split=split, cache_dir=cache_dir) 139 | dataset = datasets.load_dataset('parquet', 140 | data_dir='/root/pubdatasets/orca-math/microsoft/orca-math-word-problems-200k/data', 141 | data_files={'train': 'train-00000-of-00001.parquet'}, 142 | cache_dir=cache_dir) 143 | dataset = dataset['train'] 144 | 145 | if split == 'test': 146 | dataset = [{'question': q, 'answer': a} for q, a in zip(dataset[:1000]['question'], dataset[:1000]['answer'])] 147 | 148 | data = Dataset(name='orcamath', truncation_mode='keep_start') 149 | count = 0 150 | for row in tqdm.tqdm(dataset, desc='Processing OrcaMath', disable=silent): 151 | title = row['question'] 152 | 153 | prompt = [{"role": "system", "content": ""}, 154 | {"role": "user", "content": row['question']}] 155 | 156 | messages = [{"role": "assistant", "content": row['answer']}] 157 | 158 | prompt, chosen, rejected = preprocess(prompt, chosen, rejected, truncation_mode='keep_start') 159 | 160 | data[title].prompt = prompt 161 | data[title].chosen = messages 162 | data[title].rejected = messages 163 | 164 | count += 1 165 | if n_examples is not None and count >= n_examples: 166 | break 167 | 168 | return data 169 | 170 | class DataLoader: 171 | """ 172 | The base data loader class, similar to the one from the DPO repo. 173 | Subclass this and overwrite the __iter__ method as needed, since the batcch elements will be different depending 174 | on whether you're doing SFT, aligning with a pairwise loss like DPO, or alignment with a unary loss like KTO. 175 | """ 176 | def __init__(self, 177 | names: List[str], # e.g., ['shp', 'oasst']; should have get_{name} method in this file 178 | tokenizer, # Huggingface tokenizer object 179 | split: str = 'train', 180 | loss_name: str = 'dpo', 181 | batch_size: int = 1, 182 | max_length: int = 512, # max length of prompt + response 183 | max_prompt_length: int = 128, # max length of prompt alone 184 | n_epochs: Optional[int] = None, 185 | n_examples: Optional[int] = None, 186 | seed: int = 0, 187 | shuffle: bool = True, 188 | silent: bool = False, 189 | cache_dir: Optional[str] = None, 190 | **kwargs): 191 | 192 | torch.manual_seed(seed) 193 | random.seed(seed) 194 | 195 | self.tokenizer = tokenizer 196 | if self.tokenizer.__class__.__name__ == 'Qwen2TokenizerFast': 197 | self.tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" 198 | 199 | self.build_assistant_starter(tokenizer=tokenizer) 200 | 201 | self.split = split 202 | self.batch_size = batch_size 203 | self.max_length = max_length 204 | self.max_prompt_length = max_prompt_length 205 | self.seed = seed 206 | self.shuffle = shuffle 207 | self.silent = silent 208 | self.cache_dir = cache_dir 209 | self.kwargs = kwargs 210 | 211 | assert n_epochs is not None or n_examples is not None, "Must specify either n_epochs or n_examples" 212 | self.n_epochs = n_epochs 213 | self.epoch_idx = 0 214 | self.n_examples = n_examples 215 | 216 | self.names = names 217 | self.full_data = self.flatten_data() #*debug 218 | 219 | self.truncation_mode = "keep_start"# dataset.truncation_mode 220 | 221 | def build_assistant_starter(self, tokenizer): 222 | message = [{"content": "#"*10, "role": "assistant"}] 223 | 
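# render a sentinel assistant message ("##########") through the chat template and keep everything before the
# sentinel: that prefix is the assistant-turn header, whose text/ids/length are cached so tokenize_batch_element
# can mask only the header tokens while leaving the assistant response itself unmasked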
starter_text = tokenizer.apply_chat_template(message, tokenize=False) 224 | position = starter_text.find("#"*10) 225 | 226 | text = starter_text[:position] 227 | ids = tokenizer(starter_text[:position])['input_ids'] 228 | length = len(ids) 229 | 230 | self.assistant_starter = { 231 | "text": text, 232 | "ids": ids, 233 | "length": length} 234 | 235 | def collate_fn(self, batch: List[Dict]) -> Dict: 236 | """ 237 | Collate function for the dataloader. 238 | """ 239 | padded_batch = {} 240 | for k in batch[0].keys(): 241 | if k.endswith('_input_ids') or k.endswith('_attention_mask') or k.endswith('_labels'): 242 | if 'prompt' in k: # adapted from https://stackoverflow.com/questions/73256206 243 | to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] 244 | else: 245 | to_pad = [torch.LongTensor(ex[k]) for ex in batch] 246 | if k.endswith('_input_ids'): 247 | padding_value = self.tokenizer.pad_token_id 248 | elif k.endswith('_labels'): 249 | padding_value = -100 250 | elif k.endswith('_attention_mask'): 251 | padding_value = 0 252 | else: 253 | raise ValueError(f"Unexpected key in batch '{k}'") 254 | 255 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) 256 | if 'prompt' in k: # for the prompt, flip back so padding is on left side 257 | padded_batch[k] = padded_batch[k].flip(dims=[1]) 258 | else: 259 | padded_batch[k] = [ex[k] for ex in batch] 260 | 261 | return padded_batch 262 | 263 | def flatten_data(self): 264 | """ 265 | Flatten the data into a list of examples. 266 | """ 267 | flat_data = [] 268 | with TemporarilySeededRandom(self.seed): 269 | for name in self.names: 270 | dataset = globals()[f"get_{name}"](self.split, self.preprocess, self.silent, self.cache_dir, self.n_examples) 271 | for prompt, example in dataset.data.items(): 272 | flat_data.append(example) 273 | return flat_data 274 | 275 | 276 | def __iter__(self): 277 | """ 278 | """ 279 | with TemporarilySeededRandom(self.seed): 280 | permutation_seeds = iter(np.random.randint(0, 2**32, size=1000000)) 281 | # flat_data = self.flatten_data() #*debug 282 | epoch_idx = 0 283 | example_idx = 0 284 | done = False 285 | 286 | while True: 287 | if done: break 288 | 289 | if self.shuffle: 290 | with TemporarilySeededRandom(next(permutation_seeds)): 291 | random.shuffle(self.full_data) # otherwise, will be frontloaded with prompts in same domain 292 | # random.shuffle(flat_data) #*debug 293 | 294 | batch = [] 295 | for example in self.full_data: 296 | # for example in flat_data: #*debug 297 | batch_element = self.get_element(example) 298 | if batch_element is not None: 299 | batch.append(batch_element) 300 | example_idx += 1 301 | if len(batch) == self.batch_size: 302 | # example_idx += len(batch) 303 | yield self.collate_fn(batch) 304 | batch = [] 305 | 306 | if self.split != "train" and self.n_examples is not None: 307 | if example_idx >= self.n_examples * len(self.names): 308 | rank0_print(f'Finished generating {self.n_examples * len(self.names)} examples on {self.split} split') 309 | done = True 310 | break 311 | 312 | epoch_idx += 1 313 | if self.n_epochs is not None and epoch_idx >= self.n_epochs: 314 | done = True 315 | break 316 | 317 | def __len__(self): 318 | """ 319 | The length of the dataloader. 
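With n_examples set, this is ceil(n_examples * num_datasets / batch_size), scaled by n_epochs on the train split; otherwise it is ceil(len(full_data) / batch_size) * n_epochs.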
320 | """ 321 | if self.n_examples is not None: 322 | if self.split == 'train': 323 | return math.ceil(self.n_examples * len(self.names) / self.batch_size) * self.n_epochs 324 | else: 325 | return math.ceil(self.n_examples * len(self.names) / self.batch_size) 326 | else: 327 | return math.ceil(len(self.full_data) / self.batch_size) * self.n_epochs 328 | # return math.ceil(self.total_length / self.batch_size) * self.n_epochs #*debug 329 | 330 | def get_element(self, example: Example) -> Dict: 331 | """ 332 | Get a single batch element. 333 | """ 334 | raise NotImplementedError 335 | 336 | def preprocess(self, *args, **kwargs): 337 | """ 338 | Preprocess the prompt. 339 | """ 340 | raise NotImplementedError 341 | 342 | def tokenize_batch_element(self, 343 | prompt: str, 344 | generation: str, 345 | prefix: str='target') -> Dict: 346 | """ 347 | Tokenize a single batch element and truncate if prompt + generation is too long. Batch element is turned into Pytorch 348 | tensors in self.collate. Create the labels for the generation, which are of length equal to the sum of the length of 349 | the prompt and the generation, with -100 for the prompt tokens. 350 | 351 | Args: 352 | - prompt: the input/instruction text 353 | - generation: output text 354 | - truncation_mode: one of 'keep_start'/'keep_end' (truncate end/beginning of combined text respectively) 355 | - prefix: the prefix corresponding to the generation (e.g., 'chosen', 'rejected', 'target') 356 | 357 | Returns: 358 | A dict of the tokenized prompt, tokenized generation, and the concatenation of the two on all relevant elements 359 | (e.g., tokens, attention mask, etc.). The generation elements will have keys starting with '{prefix}_' and the 360 | concatenated elements will have keys starting with '{prefix}_combined_'. 361 | """ 362 | 363 | messages = prompt + generation 364 | conversation = self.tokenizer.apply_chat_template(messages, tokenize=False) 365 | all_input_ids = self.tokenizer( 366 | conversation, 367 | return_tensors='pt', 368 | padding='max_length', 369 | max_length=self.max_length, 370 | truncation=True 371 | ) 372 | 373 | input_ids = all_input_ids.input_ids[0] 374 | attention_mask = all_input_ids.attention_mask[0] 375 | target = input_ids.clone() 376 | 377 | cur_len = 0 378 | for item in messages: 379 | tokens = self.tokenizer.apply_chat_template([item]) 380 | 381 | if item['role'] != 'assistant': 382 | next_len = min(cur_len+len(tokens), len(target)) 383 | else: 384 | next_len = min(cur_len+self.assistant_starter["length"], len(target)) 385 | # tokens.append(-100) # TODO: add the \n token. 
This is for Mistral specifically 386 | 387 | target[cur_len:next_len] = torch.ones(next_len-cur_len) * -100 388 | 389 | cur_len += len(tokens) 390 | if cur_len >= len(target): 391 | break 392 | 393 | if cur_len < len(target): 394 | target[cur_len:] = torch.ones(len(target)-cur_len) * -100 395 | 396 | # if True: 397 | # rank0_print("#"*10+" input_ids "+"#"*10) 398 | # rank0_print(f"{self.tokenizer.decode(input_ids)}\n") 399 | # # rank0_print([f"{self.tokenizer.decode(input_ids)}\n"]) 400 | # rank0_print("#"*10+" labels "+"#"*10) 401 | # rank0_print(f"{self.tokenizer.decode(torch.where(target==-100, self.tokenizer.pad_token_id, target))}\n") 402 | # rank0_print("#"*50) 403 | # exit() 404 | 405 | # if all of the tokens are masked, return None 406 | # it is possible that the first user prompt is too long 407 | if torch.all(target == -100): 408 | return None 409 | 410 | return { 411 | f"{prefix}_input_ids": input_ids, 412 | f"{prefix}_labels": target, 413 | f"{prefix}_attention_mask": attention_mask, 414 | } 415 | 416 | def get_batch_element(self, example: Example) -> Dict: 417 | """ 418 | Get a single batch element. 419 | """ 420 | raise NotImplementedError 421 | 422 | class SFTDataLoader(DataLoader): 423 | """ 424 | A data loader for SFT. 425 | """ 426 | def get_element(self, example: Example) -> Dict: 427 | """ 428 | Get a single batch element. 429 | """ 430 | batch_element = self.tokenize_batch_element( 431 | example.prompt, 432 | example.chosen, 433 | prefix='target' 434 | ) 435 | 436 | return batch_element 437 | 438 | def preprocess(self, prompt, generation, truncation_mode='keep_start'): 439 | """ 440 | Preprocess the prompt. 441 | """ 442 | return prompt, generation 443 | 444 | 445 | def tokenize_batch_element(self, 446 | prompt: str, 447 | generation: str, 448 | prefix: str='target') -> Dict: 449 | 450 | return super().tokenize_batch_element(prompt, generation, prefix=prefix) 451 | 452 | class IFTDataLoader(SFTDataLoader): 453 | def __init__(self, *args, **kwargs): 454 | super().__init__(*args, **kwargs) 455 | 456 | def tokenize_batch_element(self, 457 | prompt: str, 458 | generation: str, 459 | prefix: str="target") -> Dict: 460 | 461 | return super().tokenize_batch_element(prompt, generation) 462 | 463 | class ORPODataLoader(DataLoader): 464 | def get_element(self, example: Example) -> Dict: 465 | chosen_element = self.tokenize_batch_element( 466 | example.prompt, 467 | example.chosen, 468 | prefix='chosen' 469 | ) 470 | rejected_element = self.tokenize_batch_element( 471 | example.prompt, 472 | example.rejected, 473 | prefix='rejected' 474 | ) 475 | # 拼接两个字典 476 | if chosen_element is None or rejected_element is None: 477 | batch_element = None 478 | else: 479 | batch_element = {**chosen_element, **rejected_element} 480 | 481 | return batch_element 482 | 483 | def preprocess(self, prompt, chosen, rejected, truncation_mode='keep_start'): 484 | return prompt, chosen, rejected 485 | 486 | class KTODataLoader(DataLoader): 487 | def __len__(self): 488 | return super().__len__() 489 | def get_element(self, example: Example) -> Dict: 490 | batch_element = self.tokenize_batch_element( 491 | example.prompt, 492 | example.chosen, 493 | prefix='chosen' 494 | ) 495 | 496 | return batch_element 497 | 498 | def preprocess(self, prompt, chosen, rejected, truncation_mode='keep_start'): 499 | return prompt, chosen, rejected 500 | 501 | def flatten_data(self): 502 | flat_data = [] 503 | with TemporarilySeededRandom(self.seed): 504 | for name in self.names: 505 | dataset = 
globals()[f"get_{name}"](self.split, self.preprocess, self.silent, self.cache_dir, self.n_examples) 506 | for prompt, example in dataset.data.items(): 507 | if example.type == "single": 508 | flat_data.append((example, 'chosen')) 509 | else: 510 | flat_data.append((example, 'chosen')) 511 | flat_data.append((Example(prompt=example.prompt, chosen=example.rejected, rejected=example.chosen), 'rejected')) 512 | 513 | return flat_data 514 | 515 | def __iter__(self): 516 | with TemporarilySeededRandom(self.seed): 517 | permutation_seeds = iter(np.random.randint(0, 2**32, size=1000000)) 518 | 519 | epoch_idx = 0 520 | example_idx = 0 521 | done = False 522 | 523 | while True: 524 | if done: break 525 | 526 | if self.shuffle: 527 | with TemporarilySeededRandom(next(permutation_seeds)): 528 | random.shuffle(self.full_data) # otherwise, will be frontloaded with prompts in same domain 529 | # random.shuffle(flat_data) #*debug 530 | 531 | batch = [] 532 | example_queue = [] 533 | for example, status in self.full_data: 534 | batch_element = self.get_element(example) 535 | 536 | if batch_element is not None: 537 | batch_element['status'] = status 538 | example_queue.append(example) 539 | batch.append(batch_element) 540 | example_idx += 1 541 | 542 | if len(batch) >= self.batch_size: 543 | indices = list(range(1, len(batch))) + [0] 544 | for i in range(self.batch_size): 545 | batch[i].update(self.tokenize_batch_element( 546 | example_queue[i].prompt, 547 | example_queue[indices[i]].chosen, 548 | prefix='rejected' 549 | )) 550 | example_queue = [] 551 | 552 | yield self.collate_fn(batch[:self.batch_size]) 553 | batch = [] 554 | 555 | if self.split != "train" and self.n_examples is not None: 556 | if example_idx >= self.n_examples * len(self.names): 557 | rank0_print(f'Finished generating {self.n_examples * len(self.names)} examples on {self.split} split') 558 | done = True 559 | break 560 | 561 | epoch_idx += 1 562 | if self.n_epochs is not None and epoch_idx >= self.n_epochs: 563 | done = True 564 | break 565 | 566 | class DPODataLoader(DataLoader): 567 | """ 568 | A data loader for DPO. 569 | """ 570 | def get_element(self, example: Example) -> Dict: 571 | """ 572 | Get a single batch element. 
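Unlike the SFT loaders, both the chosen and the rejected responses are tokenized against the same prompt and returned in a single dict.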
573 | """ 574 | batch_element = self.tokenize_batch_element( 575 | example.prompt, 576 | example.chosen, 577 | example.rejected, 578 | ) 579 | 580 | return batch_element 581 | 582 | def tokenize_batch_element(self, 583 | prompt: str, 584 | chosen: str, 585 | rejected: str,) -> Dict: 586 | """ 587 | """ 588 | 589 | prompt_messages = self.tokenizer.apply_chat_template(prompt, tokenize=False) 590 | chosen_messages = self.tokenizer.apply_chat_template(chosen, tokenize=False) 591 | rejected_messages = self.tokenizer.apply_chat_template(rejected, tokenize=False) 592 | 593 | prompt_tokens = self.tokenizer(prompt_messages, add_special_tokens=False) # added in 02.29 594 | chosen_tokens = self.tokenizer(chosen_messages, add_special_tokens=False) 595 | rejected_tokens = self.tokenizer(rejected_messages, add_special_tokens=False) 596 | 597 | longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids'])) 598 | 599 | # if combined sequence is too long, first truncate prompt 600 | if (len(prompt_tokens['input_ids']) + longer_response_length > self.max_length) and (len(prompt_tokens['input_ids']) > self.max_prompt_length): 601 | if self.truncation_mode == 'keep_start': 602 | prompt_tokens = {k: v[:self.max_prompt_length] for k, v in prompt_tokens.items()} 603 | elif self.truncation_mode == 'keep_end': 604 | prompt_tokens = {k: v[-self.max_prompt_length:] for k, v in prompt_tokens.items()} 605 | else: 606 | raise ValueError(f'Unknown truncation mode: {self.truncation_mode}') 607 | 608 | if (len(prompt_tokens['input_ids']) + longer_response_length > self.max_length): 609 | response_length = self.max_length - self.max_prompt_length # TODO: check 610 | chosen_tokens = {k: v[:response_length] for k, v in chosen_tokens.items()} 611 | rejected_tokens = {k: v[:response_length] for k, v in rejected_tokens.items()} 612 | 613 | batch_element = {} 614 | 615 | batch_element.update({f'chosen_{k}': prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens}) 616 | batch_element[f'chosen_labels'] = batch_element[f'chosen_input_ids'][:] 617 | batch_element[f'chosen_labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids']) 618 | 619 | batch_element.update({f'rejected_{k}': prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens}) 620 | batch_element[f'rejected_labels'] = batch_element[f'rejected_input_ids'][:] 621 | batch_element[f'rejected_labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids']) 622 | 623 | return batch_element 624 | 625 | def preprocess(self, prompt, chosen, rejected, truncation_mode='keep_start'): 626 | """ 627 | Preprocess the prompt. 628 | """ 629 | 630 | return prompt, chosen, rejected 631 | 632 | class AnalyseDataLoader(DataLoader): 633 | def __init__(self, *args, **kwargs): 634 | super().__init__(*args, **kwargs) 635 | 636 | def get_element(self, example: Example) -> Dict: 637 | """ 638 | Get a single batch element. 639 | """ 640 | batch_element = self.tokenize_batch_element( 641 | example.prompt, 642 | example.chosen, 643 | example.rejected, 644 | ) 645 | 646 | return batch_element 647 | 648 | def preprocess(self, prompt, chosen, rejected, truncation_mode='keep_start'): 649 | """ 650 | Preprocess the prompt. 
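This is a no-op: the prompt, chosen, and rejected messages are returned unchanged.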
651 | """ 652 | return prompt, chosen, rejected 653 | 654 | def tokenize_batch_element(self, prompt: str, chosen: str, rejected: str) -> Dict: 655 | 656 | prompt_messages = self.tokenizer.apply_chat_template(prompt, tokenize=False) 657 | chosen_messages = self.tokenizer.apply_chat_template(chosen, tokenize=False) 658 | rejected_messages = self.tokenizer.apply_chat_template(rejected, tokenize=False) 659 | 660 | prompt_tokens = self.tokenizer( 661 | prompt_messages, 662 | max_length=self.max_prompt_length, 663 | truncation=True, 664 | add_special_tokens=False 665 | ) 666 | max_response_length = self.max_length - len(prompt_tokens['input_ids']) 667 | chosen_tokens = self.tokenizer( 668 | chosen_messages, 669 | max_length=max_response_length, 670 | truncation=True, 671 | add_special_tokens=False 672 | ) 673 | rejected_tokens = self.tokenizer( 674 | rejected_messages, 675 | max_length=max_response_length, 676 | truncation=True, 677 | add_special_tokens=False 678 | ) 679 | 680 | return dict( 681 | prompt_input_ids=prompt_tokens['input_ids'], 682 | chosen_input_ids=chosen_tokens['input_ids'], 683 | rejected_input_ids=rejected_tokens['input_ids'], 684 | ) 685 | 686 | 687 | if __name__ == "__main__": 688 | from transformers import AutoTokenizer 689 | from utils import slice_and_move_batch_for_device 690 | tokenizer = AutoTokenizer.from_pretrained( 691 | "/root/pubmodels/transformers/chat-models/mistral-7b-sft-beta", 692 | cache_dir=".cache/", 693 | truncation_side='right', 694 | padding_side='right') 695 | 696 | tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" 697 | if tokenizer.pad_token is None: 698 | tokenizer.pad_token = tokenizer.eos_token 699 | 700 | loader = SFTDataLoader( 701 | names=['ultrachat','ultrafeedback'], 702 | tokenizer=tokenizer, 703 | split='train', 704 | batch_size=16, 705 | max_length=2048, 706 | max_prompt_length=2048, 707 | n_epochs=3, 708 | seed=0, 709 | shuffle=False, 710 | silent=False, 711 | cache_dir=".cache/" 712 | ) 713 | 714 | gradient_accumulation_steps = 2 715 | world_size = 8 716 | 717 | count = 0 718 | start_time = time.time() 719 | check_time = start_time 720 | for batch in tqdm.tqdm(loader): 721 | if count % 100 == 0: 722 | tmp_time = time.time() 723 | print(f"Time: {tmp_time - start_time}") 724 | check_time = tmp_time 725 | count += 1 726 | 727 | end_time = time.time() 728 | print(f"Time: {end_time - start_time}") -------------------------------------------------------------------------------- /src/trainers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import transformers 5 | from omegaconf import DictConfig 6 | 7 | import torch.distributed as dist 8 | from torch.distributed.fsdp import ( 9 | FullyShardedDataParallel as FSDP, 10 | MixedPrecision, 11 | StateDictType, 12 | BackwardPrefetch, 13 | ShardingStrategy, 14 | CPUOffload, 15 | ) 16 | from torch.distributed.fsdp.api import FullStateDictConfig, FullOptimStateDictConfig 17 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 18 | 19 | from torch_discounted_cumsum import 
discounted_cumsum_right, discounted_cumsum_left 20 | 21 | import tensor_parallel as tp 22 | import contextlib 23 | # import deepspeed 24 | 25 | # from preference_datasets import DataLoader, get_batch_iterator 26 | from dataloader import DataLoader, SFTDataLoader, DPODataLoader, IFTDataLoader, ORPODataLoader 27 | 28 | from utils import ( 29 | slice_and_move_batch_for_device, 30 | formatted_dict, 31 | all_gather_if_needed, 32 | pad_to_length, 33 | get_block_class_from_model, 34 | rank0_print, 35 | get_local_dir, 36 | delete_dict, 37 | disable_dropout, 38 | ) 39 | import numpy as np 40 | import wandb 41 | import tqdm 42 | import matplotlib.pyplot as plt 43 | import math 44 | 45 | import gc 46 | import random 47 | import os 48 | import argparse 49 | from collections import defaultdict 50 | import time 51 | import json 52 | import functools 53 | from typing import Optional, Dict, List, Union, Tuple 54 | 55 | torch.backends.cuda.matmul.allow_tf32 = True 56 | 57 | def linear_with_warmup( 58 | current_step: int, *, num_warmup_steps: int, num_training_steps: int 59 | ): 60 | """ 61 | Copied from transformers.optimization._get_linear_schedule_with_warmup_lr_lambda 62 | """ 63 | if current_step < num_warmup_steps: 64 | return float(current_step) / float(max(1, num_warmup_steps)) 65 | return max( 66 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 67 | ) 68 | 69 | def cosine_with_warmup( 70 | current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5 71 | ): 72 | """ 73 | Copied from transformers.optimization._get_cosine_schedule_with_warmup_lr_lambda 74 | """ 75 | if current_step < num_warmup_steps: 76 | return float(current_step) / float(max(1, num_warmup_steps)) 77 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 78 | 79 | return max( 80 | 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) 81 | ) 82 | 83 | def linear_warmup( 84 | current_step: int, *, num_warmup_steps: int, num_training_steps: int 85 | ): 86 | if current_step < num_warmup_steps: 87 | return float(current_step) / float(max(1, num_warmup_steps)) 88 | else: 89 | return 1.0 90 | 91 | def preference_loss(policy_chosen_logps: torch.FloatTensor, 92 | policy_rejected_logps: torch.FloatTensor, 93 | reference_chosen_logps: torch.FloatTensor, 94 | reference_rejected_logps: torch.FloatTensor, 95 | beta: float, 96 | gamma: float = 1.0, 97 | label_smoothing: float = 0.0, 98 | loss_name: str = 'dpo', 99 | reference_free: bool = False, 100 | rejected_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 101 | """Compute the DPO loss for a batch of policy and reference model log probabilities. 102 | 103 | Args: 104 | policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) 105 | policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) 106 | reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) 107 | reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) 108 | beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. 
109 | label_smoothing: conservativeness for DPO loss, which assumes that preferences are noisy (flipped with probability label_smoothing) 110 | loss_name: which loss to compute; 'dpo' is the (optionally label-smoothed) DPO loss and 'ipo' is the IPO loss. rejected_free: if True, zero out the rejected log probabilities so that only the chosen response contributes. 111 | reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. 112 | 113 | Returns: 114 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). 115 | The losses tensor contains the DPO loss for each example in the batch. 116 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. 117 | """ 118 | if rejected_free: 119 | policy_rejected_logps = torch.zeros_like(policy_rejected_logps).to(policy_chosen_logps.device) 120 | reference_rejected_logps = torch.zeros_like(reference_rejected_logps).to(policy_chosen_logps.device) 121 | pi_logratios = policy_chosen_logps - policy_rejected_logps 122 | ref_logratios = reference_chosen_logps - reference_rejected_logps 123 | 124 | if reference_free: 125 | ref_logratios = 0 126 | 127 | logits = pi_logratios - ref_logratios # also known as h_{\pi_\theta}^{y_w,y_l} 128 | 129 | if loss_name == 'dpo': 130 | # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) 131 | losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing 132 | elif loss_name == 'ipo': 133 | # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf 134 | losses = (logits - 1/(2 * beta)) ** 2 135 | 136 | chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach() 137 | rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach() 138 | 139 | return losses, chosen_rewards, rejected_rewards 140 | 141 | 142 | def _get_batch_logps(logits: torch.FloatTensor, 143 | labels: torch.LongTensor, 144 | average_log_prob: bool = False, 145 | per_token_prob: bool = False) -> torch.FloatTensor: 146 | """Compute the log probabilities of the given labels under the given logits. 147 | 148 | Args: 149 | logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) 150 | labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) 151 | average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. per_token_prob: If True, return the masked per-token log probabilities instead of a per-sequence sum or average. 152 | 153 | Returns: 154 | A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. If per_token_prob is True, the shape is (batch_size, sequence_length - 1). 155 | """ 156 | assert logits.shape[:-1] == labels.shape 157 | 158 | labels = labels[:, 1:].clone() 159 | logits = logits[:, :-1, :] 160 | loss_mask = (labels != -100) 161 | 162 | # dummy token; we'll ignore the losses on these tokens later 163 | labels[labels == -100] = 0 164 | per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) 165 | 166 | if per_token_prob: 167 | return per_token_logps * loss_mask 168 | elif average_log_prob: 169 | return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) 170 | else: 171 | return (per_token_logps * loss_mask).sum(-1) 172 | 173 | 174 | def concatenated_inputs(batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: 175 | """Concatenate the chosen and rejected inputs into a single tensor.
176 | 177 | Args: 178 | batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). 179 | 180 | Returns: 181 | A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. 182 | """ 183 | max_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1]) 184 | concatenated_batch = {} 185 | for k in batch: 186 | if k.startswith('chosen') and isinstance(batch[k], torch.Tensor): 187 | pad_value = -100 if 'labels' in k else 0 188 | concatenated_key = k.replace('chosen', 'concatenated') 189 | concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) 190 | for k in batch: 191 | if k.startswith('rejected') and isinstance(batch[k], torch.Tensor): 192 | pad_value = -100 if 'labels' in k else 0 193 | concatenated_key = k.replace('rejected', 'concatenated') 194 | concatenated_batch[concatenated_key] = torch.cat(( 195 | concatenated_batch[concatenated_key], 196 | pad_to_length(batch[k], max_length, pad_value=pad_value), 197 | ), dim=0) 198 | return concatenated_batch 199 | 200 | class BasicTrainer(object): 201 | def __init__(self, 202 | policy: nn.Module, 203 | config: DictConfig, 204 | seed: int, 205 | run_dir: str, 206 | policy_weak: Optional[nn.Module] = None, 207 | reference_model: Optional[nn.Module] = None, 208 | truncation_side="right", 209 | padding_side="right", 210 | rank: int = 0, 211 | world_size: int = 1 212 | ): 213 | """A trainer for a language model, supporting either SFT or DPO training. 214 | 215 | If multiple GPUs are present, naively splits the model across them, effectively 216 | offering N times available memory, but without any parallel computation. 217 | """ 218 | self.seed = seed 219 | self.rank = rank 220 | self.world_size = world_size 221 | self.config = config 222 | self.run_dir = run_dir 223 | self.debug = config.debug 224 | assert self.config.batch_size % self.config.gradient_accumulation_steps == 0, 'batch_size must be divisible by gradient_accumulation_steps' 225 | 226 | tokenizer_name_or_path = config.model.tokenizer_name_or_path or config.model.name_or_path 227 | rank0_print(f'Loading tokenizer {tokenizer_name_or_path}') 228 | self.tokenizer = transformers.AutoTokenizer.from_pretrained( 229 | tokenizer_name_or_path, 230 | truncation_side=truncation_side, 231 | padding_side=padding_side, 232 | cache_dir=get_local_dir(config.local_dirs)) 233 | 234 | if self.tokenizer.chat_template is None: 235 | self.tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" 236 | else: 237 | rank0_print(f'chat_template: {self.tokenizer.chat_template}') 238 | 239 | if self.tokenizer.pad_token_id is None: 240 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 241 | 242 | data_iterator_kwargs = dict( 243 | names=config.datasets, 244 | tokenizer=self.tokenizer, 245 | shuffle=True, 246 | max_length=config.max_length, 247 | max_prompt_length=config.max_prompt_length, 248 | loss_name=config.loss.name, 249 | seed=seed, 250 | silent=rank != 0, 251 | cache_dir=get_local_dir(config.local_dirs), 252 | ) 253 | 254 | self.policy = 
policy 255 | self.policy_weak = policy_weak 256 | self.reference_model = reference_model 257 | if self.config.loss.name in {'ift'}: 258 | self.embed_tokens = nn.Embedding(num_embeddings=self.policy.model.embed_tokens.num_embeddings, 259 | embedding_dim=self.policy.model.embed_tokens.embedding_dim, 260 | padding_idx=self.policy.model.embed_tokens.padding_idx, 261 | device=torch.device(self.rank), 262 | dtype=self.policy.model.embed_tokens.weight.dtype) 263 | self.embed_tokens.load_state_dict(self.policy.model.embed_tokens.state_dict()) 264 | self.embed_tokens.requires_grad_(False) 265 | 266 | self.train_loader = globals()[f'{config.loss.name.upper()}DataLoader']( 267 | split='train', 268 | batch_size=config.batch_size, 269 | n_epochs=config.n_epochs, 270 | n_examples=config.n_examples, 271 | **data_iterator_kwargs 272 | ) 273 | rank0_print(f'Loaded train data iterator') 274 | 275 | self.eval_loader = globals()[f'{config.loss.name.upper()}DataLoader']( 276 | split='test', 277 | batch_size=config.eval_batch_size, 278 | n_examples=config.n_eval_examples, 279 | **data_iterator_kwargs 280 | ) 281 | self.eval_batches = list(self.eval_loader) 282 | rank0_print(f'Loaded {len(self.eval_batches)} eval batches of size {config.eval_batch_size}') 283 | 284 | self.train_iterations = len(self.train_loader) 285 | self.warmup_steps = math.ceil(self.train_iterations * config.warmup_ratio) if config.warmup_ratio is not None else config.warmup_steps 286 | rank0_print(f'Using {self.warmup_steps} warmup steps') 287 | 288 | self.gamma = config.loss.gamma 289 | self.min_lambda = config.loss.min_lambda 290 | self.max_lambda = config.loss.max_lambda 291 | self.lambda_schedule = config.loss.lambda_schedule 292 | self.lambda_disturb = config.loss.lambda_disturb 293 | self.disturb_std = config.loss.disturb_std 294 | 295 | if self.lambda_disturb == "normal": 296 | self.noise = [] 297 | for _ in range(self.train_iterations): 298 | self.noise.append(torch.randn(1).item()) 299 | 300 | def get_batch_samples(self, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: 301 | """Generate samples from the policy (and reference model, if doing DPO training) for the given batch of inputs.""" 302 | 303 | # FSDP generation according to https://github.com/pytorch/pytorch/issues/100069 304 | ctx = lambda: (FSDP.summon_full_params(self.policy, writeback=False, recurse=False) if 'FSDP' in self.config.trainer else contextlib.nullcontext()) 305 | with ctx(): 306 | policy_output = self.policy.generate( 307 | batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], max_length=self.config.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id) 308 | 309 | if self.config.loss.name in {'dpo', 'ipo', 'mypo'}: 310 | ctx = lambda: (FSDP.summon_full_params(self.reference_model, writeback=False, recurse=False) if 'FSDP' in self.config.trainer else contextlib.nullcontext()) 311 | with ctx(): 312 | reference_output = self.reference_model.generate( 313 | batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], max_length=self.config.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id) 314 | 315 | policy_output = pad_to_length(policy_output, self.config.max_length, self.tokenizer.pad_token_id) 316 | policy_output = all_gather_if_needed(policy_output, self.rank, self.world_size) 317 | policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) 318 | 319 | if self.config.loss.name in {'dpo', 'ipo', 'mypo'}: 320 | reference_output = 
pad_to_length(reference_output, self.config.max_length, self.tokenizer.pad_token_id) 321 | reference_output = all_gather_if_needed(reference_output, self.rank, self.world_size) 322 | reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True) 323 | else: 324 | reference_output_decoded = [] 325 | 326 | return policy_output_decoded, reference_output_decoded 327 | 328 | def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 329 | """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. 330 | 331 | We do this to avoid doing two forward passes, because it's faster for FSDP. 332 | """ 333 | concatenated_batch = concatenated_inputs(batch) 334 | all_logits = model(concatenated_batch['concatenated_input_ids'], attention_mask=concatenated_batch['concatenated_attention_mask']).logits.to(torch.float32) 335 | all_logps = _get_batch_logps(all_logits, concatenated_batch['concatenated_labels'], average_log_prob=False) 336 | chosen_logps = all_logps[:batch['chosen_input_ids'].shape[0]] 337 | rejected_logps = all_logps[batch['chosen_input_ids'].shape[0]:] 338 | return chosen_logps, rejected_logps 339 | 340 | 341 | def get_sample_function(self, logits: torch.FloatTensor, sample: str = 'greedy', k: int = 2, temperature: float = 0.7) -> torch.LongTensor: 342 | """Get the labels for the given logits using the given sampling strategy.""" 343 | with torch.no_grad(): 344 | if sample == 'greedy': 345 | token_sample = torch.argmax(logits, dim=-1) 346 | elif sample == 'topk': 347 | token_sample = torch.topk(logits, k=k, dim=-1).indices[..., -1] 348 | elif sample == 'nucleus': 349 | batch_size, seq_len, vocab_size = logits.shape 350 | logits = logits.view(-1, vocab_size) 351 | token_sample = torch.multinomial(F.softmax(logits / temperature, dim=-1), num_samples=1).squeeze(-1) 352 | token_sample = token_sample.view(batch_size, seq_len) 353 | else: 354 | raise ValueError(f'unknown sample {sample}') 355 | 356 | return token_sample 357 | 358 | def get_cumsum_weight( 359 | self, 360 | logps, 361 | loss_mask, 362 | gamma=1, 363 | propagation_type='loss', 364 | propagation_norm='L1', 365 | propagation_side='right') -> torch.FloatTensor: 366 | 367 | if gamma == 0: 368 | return torch.ones_like(logps) 369 | 370 | if propagation_type == 'mask': 371 | cumsum_item = loss_mask 372 | elif propagation_type == 'loss': 373 | cumsum_item = -logps / loss_mask.sum(-1).unsqueeze(-1) 374 | elif propagation_type == 'logps': 375 | cumsum_item = -logps 376 | else: 377 | raise ValueError(f'unknown propagation_type {propagation_type}') 378 | 379 | cumsum_item[loss_mask == 0] = 0 380 | 381 | if propagation_side == 'right': 382 | if gamma == 1: 383 | cumsum_weight = torch.cumsum(cumsum_item, dim=-1) 384 | else: 385 | cumsum_weight = discounted_cumsum_left(cumsum_item, gamma=gamma) 386 | 387 | cumsum_weight[loss_mask == 0] = 1e6 388 | cumsum_weight += cumsum_weight[cumsum_weight.nonzero(as_tuple=True)].min() 389 | 390 | if propagation_norm == 'L1': # sharp level 1 391 | cumsum_weight = 1 / (cumsum_weight) 392 | elif propagation_norm == 'L2': # sharp level 2 393 | cumsum_weight = 1 / (cumsum_weight) ** 2 394 | elif propagation_norm == 'softmax': # sharp level 3 395 | cumsum_weight = torch.softmax(1/cumsum_weight, dim=-1) 396 | elif propagation_norm == 'log': # sharp level 0 397 | cumsum_weight = 1 / torch.log(cumsum_weight + 1) 398 | else: 399 | raise ValueError(f'unknown 
propagation_norm {propagation_norm}') 400 | 401 | elif propagation_side == 'left': 402 | if gamma == 1: 403 | cumsum_weight = torch.flip(torch.cumsum(torch.flip(cumsum_item, [-1]), dim=-1), [-1]) 404 | else: 405 | cumsum_weight = discounted_cumsum_right(cumsum_item, gamma=gamma) 406 | cumsum_weight[loss_mask == 0] = 0 407 | 408 | if propagation_norm == 'L1': # sharp level 2 409 | cumsum_weight = cumsum_weight 410 | elif propagation_norm == 'L2': # sharp level 3 411 | cumsum_weight = cumsum_weight ** 2 412 | elif propagation_norm == 'softmax': # sharp level 1 413 | cumsum_weight = torch.softmax(cumsum_weight, dim=-1) 414 | elif propagation_norm == 'log': # sharp level 0 415 | cumsum_weight = torch.log(cumsum_weight + 1) 416 | else: 417 | raise ValueError(f'unknown propagation_norm {propagation_norm}') 418 | 419 | else: 420 | raise ValueError(f'unknown propagation_side {propagation_side}') 421 | 422 | cumsum_weight[loss_mask == 0] = 0 423 | cumsum_weight *= (loss_mask.sum(-1, keepdim=True) / cumsum_weight.sum(-1, keepdim=True)) 424 | 425 | return cumsum_weight 426 | 427 | def update_lambda( 428 | self, 429 | step_idx: int 430 | ) -> float: 431 | """Get the gamma value for the given step index.""" 432 | if self.lambda_schedule: 433 | schedule = linear_warmup( 434 | step_idx, 435 | num_warmup_steps=self.config.warmup_steps, 436 | num_training_steps=self.train_iterations) 437 | self._lambda = self.min_lambda + (self.max_lambda - self.min_lambda) * schedule 438 | else: 439 | self._lambda = self.max_lambda 440 | 441 | if self.lambda_disturb: 442 | self._lambda += self.noise[step_idx] * self.disturb_std * self._lambda 443 | 444 | self._lambda = torch.clamp( 445 | input=torch.tensor(self._lambda), 446 | min=0.0, 447 | max=1.0).item() 448 | 449 | return self._lambda 450 | 451 | def debug_inputs(self, inputs, device=0): 452 | input_ids = inputs['target_input_ids'] 453 | attention_mask = inputs['target_attention_mask'] 454 | labels = inputs['target_labels'] 455 | loss_mask = labels[:, 1:] != -100 456 | 457 | if device == "all": 458 | for input_id, label, mask in zip(input_ids, labels, attention_mask): 459 | print("#"*10+" input_ids "+"#"*10) 460 | print(f"{self.tokenizer.decode(input_id)}\n") 461 | print("#"*10+" labels "+"#"*10) 462 | print(f"{self.tokenizer.decode(torch.where(label==-100, 0, label))}\n") 463 | print("#"*10+" attention_mask "+"#"*10) 464 | print(f"{self.tokenizer.decode(torch.where(mask==0, 0, input_id))}\n") 465 | 466 | print(len(input_id)) 467 | print(len(label)) 468 | print(len(attention_mask)) 469 | 470 | with open(f"{self.args.output_dir}/{torch.distributed.get_rank()}.json", "w", encoding="utf-8") as file: 471 | data_dict = { 472 | "input_ids": self.tokenizer.decode(input_id), 473 | "labels": self.tokenizer.decode(torch.where(label==-100, 0, label)), 474 | "attention_mask": self.tokenizer.decode(torch.where(mask==0, 0, input_id)) 475 | } 476 | json.dump(data_dict, file, indent=4, ensure_ascii=False) 477 | 478 | 479 | elif torch.distributed.get_rank() == device: 480 | for input_id, label, mask in zip(input_ids, labels, attention_mask): 481 | print("#"*10+" input_ids "+"#"*10) 482 | print(f"{self.tokenizer.decode(input_id)}\n") 483 | print("#"*10+" labels "+"#"*10) 484 | print(f"{self.tokenizer.decode(torch.where(label==-100, 0, label))}\n") 485 | print("#"*10+" attention_mask "+"#"*10) 486 | print(f"{self.tokenizer.decode(torch.where(mask==0, 0, input_id))}\n") 487 | 488 | print(len(input_id)) 489 | print(len(label)) 490 | print(len(attention_mask)) 491 | 492 | exit() 493 | 494 
| def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True): 495 | """Compute the SFT or DPO loss and other metrics for the given batch of inputs.""" 496 | 497 | metrics = {} 498 | train_test = 'train' if train else 'eval' 499 | 500 | if loss_config.name in {'dpo', 'ipo'}: 501 | policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch) 502 | with torch.no_grad(): 503 | reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(self.reference_model, batch) 504 | 505 | if loss_config.name == 'dpo': 506 | loss_kwargs = {'loss_name': loss_config.name, 'beta': loss_config.beta, 507 | 'reference_free': loss_config.reference_free, 'rejected_free': loss_config.rejected_free, 'label_smoothing': loss_config.label_smoothing} 508 | elif loss_config.name == 'ipo': 509 | loss_kwargs = {'loss_name': loss_config.name, 'beta': loss_config.beta} 510 | else: 511 | raise ValueError(f'unknown loss {loss_config.name}') 512 | 513 | losses, chosen_rewards, rejected_rewards = preference_loss( 514 | policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, **loss_kwargs) 515 | 516 | reward_accuracies = (chosen_rewards > rejected_rewards).float() 517 | 518 | chosen_rewards = all_gather_if_needed(chosen_rewards, self.rank, self.world_size) 519 | rejected_rewards = all_gather_if_needed(rejected_rewards, self.rank, self.world_size) 520 | reward_accuracies = all_gather_if_needed(reward_accuracies, self.rank, self.world_size) 521 | 522 | metrics[f'rewards_{train_test}/chosen'] = chosen_rewards.cpu().numpy().tolist() 523 | metrics[f'rewards_{train_test}/rejected'] = rejected_rewards.cpu().numpy().tolist() 524 | metrics[f'rewards_{train_test}/accuracies'] = reward_accuracies.cpu().numpy().tolist() 525 | metrics[f'rewards_{train_test}/margins'] = (chosen_rewards - rejected_rewards).cpu().numpy().tolist() 526 | 527 | policy_rejected_logps = all_gather_if_needed(policy_rejected_logps.detach(), self.rank, self.world_size) 528 | metrics[f'logps_{train_test}/rejected'] = policy_rejected_logps.cpu().numpy().tolist() 529 | 530 | policy_chosen_logps = all_gather_if_needed(policy_chosen_logps.detach(), self.rank, self.world_size) 531 | metrics[f'logps_{train_test}/chosen'] = policy_chosen_logps.cpu().numpy().tolist() 532 | 533 | all_devices_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size) 534 | metrics[f'loss/{train_test}'] = all_devices_losses.cpu().numpy().tolist() 535 | losses = losses.mean() 536 | 537 | elif loss_config.name == 'sft': 538 | # with KL divergence 539 | input_ids = batch['target_input_ids'] 540 | attention_mask = batch['target_attention_mask'] 541 | labels = batch['target_labels'] 542 | 543 | logits = self.policy(input_ids, attention_mask=attention_mask).logits 544 | logps = _get_batch_logps(logits, labels, average_log_prob=True) 545 | losses = -logps.mean() 546 | 547 | metrics[f'loss/{train_test}'] = losses.cpu().numpy().tolist() 548 | 549 | elif loss_config.name == 'ift': 550 | input_ids = batch['target_input_ids'] 551 | attention_mask = batch['target_attention_mask'] 552 | labels = batch['target_labels'] 553 | loss_mask = (labels[:, 1:] != -100) 554 | 555 | # self.debug_inputs(batch) 556 | 557 | inputs_embeds = self.embed_tokens(input_ids) 558 | 559 | with torch.no_grad(): 560 | logits = self.policy(inputs_embeds=inputs_embeds, attention_mask=attention_mask).logits 561 | logps = _get_batch_logps(logits, labels, per_token_prob=True) 562 | 
losses_sft = -logps.sum(-1) / loss_mask.sum(-1) 563 | all_devices_losses_sft = all_gather_if_needed(losses_sft.mean().detach(), self.rank, self.world_size) 564 | metrics[f'loss/{train_test}_sft'] = all_devices_losses_sft.cpu().numpy().tolist() 565 | 566 | _lambda = self.update_lambda(step_idx=self.batch_counter) 567 | 568 | metrics[f'lambda/{train_test}'] = _lambda 569 | 570 | tokens_further = torch.cat((input_ids[:, 0].unsqueeze(-1), self.get_sample_function(logits)[:, :-1]), dim=-1) 571 | input_ids_further = torch.where(labels==-100, input_ids, tokens_further) 572 | attention_mask_further = attention_mask 573 | 574 | inputs_embeds_further = self.embed_tokens(input_ids_further) 575 | inputs_embeds_further = (1 - _lambda) * inputs_embeds + _lambda * inputs_embeds_further 576 | logits_further = self.policy(inputs_embeds=inputs_embeds_further, attention_mask=attention_mask_further).logits 577 | 578 | logps_further = _get_batch_logps(logits_further, labels, per_token_prob=True) 579 | 580 | with torch.no_grad(): 581 | losses_further = -logps_further.sum(-1) / loss_mask.sum(-1) 582 | all_devices_losses_further = all_gather_if_needed(losses_further.mean().detach(), self.rank, self.world_size) 583 | metrics[f'loss/{train_test}_further'] = all_devices_losses_further.cpu().numpy().tolist() 584 | 585 | cumsum_weight = self.get_cumsum_weight( 586 | logps=logps_further, 587 | loss_mask=loss_mask, 588 | gamma=loss_config.gamma, 589 | propagation_type=loss_config.propagation_type, 590 | propagation_norm=loss_config.propagation_norm, 591 | propagation_side=loss_config.propagation_side 592 | ) 593 | 594 | losses = ((-logps_further * cumsum_weight).sum(-1) / loss_mask.sum(-1)).mean() 595 | 596 | all_devices_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size) 597 | metrics[f'loss/{train_test}'] = all_devices_losses.cpu().numpy().tolist() 598 | 599 | elif loss_config.name == 'orpo': 600 | logps_chosen_sum, logps_rejected_sum = self.concatenated_forward(self.policy, batch) 601 | 602 | loss_mask_chosen = (batch['chosen_labels'][:, 1:] != -100) 603 | loss_mask_rejected = (batch['rejected_labels'][:, 1:] != -100) 604 | 605 | logps_chosen = logps_chosen_sum / loss_mask_chosen.sum(-1) 606 | logps_rejected = logps_rejected_sum / loss_mask_rejected.sum(-1) 607 | 608 | loss_chosen = -logps_chosen 609 | 610 | log_odds = (logps_chosen - logps_rejected) - (torch.log(1 - torch.exp(logps_chosen)) - torch.log(1 - torch.exp(logps_rejected))) 611 | sig_ratio = torch.sigmoid(log_odds) 612 | ratio = torch.log(sig_ratio) 613 | 614 | beta = loss_config.beta 615 | losses = (loss_chosen - beta * ratio).mean() 616 | 617 | all_devices_losses_sft = all_gather_if_needed(loss_chosen.mean().detach(), self.rank, self.world_size) 618 | metrics[f'loss/{train_test}_sft'] = all_devices_losses_sft.cpu().numpy().tolist() 619 | 620 | all_devices_loss = all_gather_if_needed(losses.mean().detach(), self.rank, self.world_size) 621 | metrics[f'loss/{train_test}'] = all_devices_loss.cpu().numpy().tolist() 622 | 623 | return losses, metrics 624 | 625 | def get_batch_embeddings(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 626 | """Compute the embeddings of the chosen and rejected responses for the given batch of inputs.""" 627 | # todo: check whether the function is correct 628 | with torch.no_grad(): 629 | policy_chosen_embeddings = self.policy.get_input_embeddings()(batch['chosen_input_ids']) 630 | policy_rejected_embeddings = 
self.policy.get_input_embeddings()(batch['rejected_input_ids']) 631 | 632 | if self.config.loss.name in {'dpo', 'ipo', 'mypo'}: 633 | with torch.no_grad(): 634 | reference_chosen_embeddings = self.reference_model.get_input_embeddings()(batch['chosen_input_ids']) 635 | reference_rejected_embeddings = self.reference_model.get_input_embeddings()(batch['rejected_input_ids']) 636 | elif self.config.loss.name == 'sft': 637 | reference_chosen_embeddings = torch.zeros_like(policy_chosen_embeddings) 638 | reference_rejected_embeddings = torch.zeros_like(policy_rejected_embeddings) 639 | else: 640 | raise ValueError(f'unknown loss {self.config.loss.name}') 641 | 642 | return policy_chosen_embeddings, policy_rejected_embeddings, reference_chosen_embeddings, reference_rejected_embeddings 643 | 644 | def train(self): 645 | """Begin either SFT or DPO training, with periodic evaluation.""" 646 | torch.manual_seed(self.seed) 647 | np.random.seed(self.seed) 648 | random.seed(self.seed) 649 | 650 | if self.config.loss.name in {'dpo', 'ipo'}: 651 | self.reference_model.eval() 652 | elif self.config.loss.name == 'gpo': 653 | self.policy_weak.eval() 654 | 655 | self.example_counter = 0 656 | self.batch_counter = 0 657 | last_log = None 658 | 659 | if self.config.checkpoint_path is not None: 660 | self.load_checkpoint(checkpoint_path=self.config.checkpoint_path) 661 | 662 | self.step_idx = self.checkpoint_example_idx if self.config.checkpoint_path is not None else 0 663 | 664 | for batch in tqdm.tqdm(self.train_loader, total=self.train_iterations, desc='Training'): 665 | if self.config.checkpoint_path is not None: 666 | if self.checkpoint_batch_idx > self.batch_counter: 667 | self.batch_counter += 1 668 | self.example_counter += self.config.batch_size 669 | continue 670 | #### BEGIN EVALUATION #### 671 | if self.example_counter % self.config.eval_every == 0 and (self.example_counter > 0 or self.config.do_first_eval): 672 | rank0_print(f'Running evaluation after {self.example_counter} train examples') 673 | self.policy.eval() 674 | 675 | all_eval_metrics = defaultdict(list) 676 | if self.config.sample_during_eval: 677 | all_policy_samples, all_reference_samples = [], [] 678 | policy_text_table = wandb.Table(columns=["step", "prompt", "sample"]) 679 | if self.config.loss.name in {'dpo', 'ipo', 'mypo'}: 680 | reference_text_table = wandb.Table(columns=["step", "prompt", "sample"]) 681 | 682 | for eval_batch in (tqdm.tqdm(self.eval_batches, desc='Computing eval metrics') if self.rank == 0 else self.eval_batches): 683 | local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank, self.world_size, self.rank) 684 | with torch.no_grad(): 685 | loss, eval_metrics = self.get_batch_metrics(local_eval_batch, self.config.loss, train=False) 686 | 687 | for k, v in eval_metrics.items(): 688 | try: 689 | all_eval_metrics[k].extend(v) 690 | except: 691 | all_eval_metrics[k].append(v) 692 | 693 | if self.config.sample_during_eval: 694 | if self.config.n_eval_model_samples < self.config.eval_batch_size: 695 | rank0_print(f'Warning: n_eval_model_samples ({self.config.n_eval_model_samples}) < eval_batch_size ({self.config.eval_batch_size}). 
Sampling from the first complete eval batch of prompts.') 696 | sample_batches = self.eval_batches[:1] 697 | else: 698 | n_sample_batches = self.config.n_eval_model_samples // self.config.eval_batch_size 699 | sample_batches = self.eval_batches[:n_sample_batches] 700 | for eval_batch in (tqdm.tqdm(sample_batches, desc='Generating samples...') if self.rank == 0 else sample_batches): 701 | local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank, self.world_size, self.rank) 702 | policy_samples, reference_samples = self.get_batch_samples(local_eval_batch) 703 | 704 | all_policy_samples.extend(policy_samples) 705 | all_reference_samples.extend(reference_samples) 706 | 707 | for prompt, sample in zip(eval_batch['prompt'], policy_samples): 708 | policy_text_table.add_data(self.example_counter, prompt, sample) 709 | if self.config.loss.name in {'dpo', 'ipo', 'mypo'}: 710 | for prompt, sample in zip(eval_batch['prompt'], reference_samples): 711 | reference_text_table.add_data(self.example_counter, prompt, sample) 712 | 713 | mean_eval_metrics = {k: sum(v) / len(v) for k, v in all_eval_metrics.items()} 714 | 715 | rank0_print(f'eval after {self.example_counter}: {formatted_dict(mean_eval_metrics)}') 716 | 717 | if self.config.sample_during_eval: 718 | rank0_print(json.dumps(all_policy_samples[:10], indent=2)) 719 | if self.config.loss.name in {'dpo', 'ipo', 'mypo'}: 720 | rank0_print(json.dumps(all_reference_samples[:10], indent=2)) 721 | 722 | if self.config.wandb.enabled and self.rank == 0: 723 | wandb.log(mean_eval_metrics, step=self.example_counter) 724 | 725 | if self.config.sample_during_eval: 726 | wandb.log({"policy_samples": policy_text_table}, step=self.example_counter) 727 | if self.config.loss.name in {'dpo', 'ipo', 'mypo'}: 728 | wandb.log({"reference_samples": reference_text_table}, step=self.example_counter) 729 | 730 | if self.example_counter > 0: 731 | if self.config.debug: 732 | rank0_print('skipping save in debug mode') 733 | else: 734 | output_dir = os.path.join(self.run_dir, f'step-{self.example_counter}') 735 | rank0_print(f'creating checkpoint to write to {output_dir}...') 736 | self.save(output_dir, mean_eval_metrics) 737 | #### END EVALUATION #### 738 | 739 | #### BEGIN TRAINING #### 740 | self.policy.train() 741 | start_time = time.time() 742 | batch_metrics = defaultdict(list) 743 | 744 | for microbatch_idx in range(self.config.gradient_accumulation_steps): 745 | global_microbatch = slice_and_move_batch_for_device(batch, microbatch_idx, self.config.gradient_accumulation_steps, 'cpu') 746 | local_microbatch = slice_and_move_batch_for_device(global_microbatch, self.rank, self.world_size, self.rank) 747 | 748 | loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True) 749 | 750 | loss.backward() 751 | 752 | for k, v in metrics.items(): 753 | try: 754 | batch_metrics[k].extend(v) 755 | except: 756 | batch_metrics[k].append(v) 757 | 758 | # gather the gradients 759 | for p in self.policy.parameters(): 760 | dist.all_reduce(p.grad.data, op=dist.ReduceOp.AVG) 761 | p.grad.data /= self.config.gradient_accumulation_steps 762 | 763 | grad_norm = self.clip_gradient() 764 | 765 | dist.barrier() 766 | 767 | self.optimizer.step() 768 | self.lr_scheduler.step() 769 | self.optimizer.zero_grad() 770 | 771 | step_time = time.time() - start_time 772 | examples_per_second = self.config.batch_size / step_time 773 | learning_rate = self.optimizer.param_groups[0]['lr'] 774 | 775 | batch_metrics['examples_per_second'].append(examples_per_second) 776 
| batch_metrics['grad_norm'].append(grad_norm) 777 | batch_metrics['loss/train'].append(loss.item()) 778 | batch_metrics['learning_rate'].append(learning_rate) 779 | 780 | rank0_print(f'batch {self.batch_counter}: loss {loss.item()}, grad_norm {grad_norm}, lr {learning_rate}, {examples_per_second} examples/s') 781 | 782 | self.batch_counter += 1 783 | self.example_counter += self.config.batch_size 784 | 785 | if self.config.loss.name in {'ift'}: 786 | self.embed_tokens.weight.data = self.policy.state_dict()['model.embed_tokens.weight'].clone().to(self.embed_tokens.weight.dtype) 787 | dist.barrier() 788 | 789 | if last_log is None or time.time() - last_log > self.config.minimum_log_interval_secs: 790 | mean_train_metrics = {k: sum(v) / len(v) for k, v in batch_metrics.items()} 791 | mean_train_metrics['counters/examples'] = self.example_counter 792 | mean_train_metrics['counters/updates'] = self.batch_counter 793 | 794 | if self.config.wandb.enabled and self.rank == 0: 795 | wandb.log(mean_train_metrics, step=self.example_counter) 796 | 797 | last_log = time.time() 798 | else: 799 | rank0_print(f'skipping logging after {self.example_counter} examples to avoid logging too frequently') 800 | 801 | self.step_idx += 1 802 | #### END TRAINING #### 803 | 804 | def clip_gradient(self): 805 | """Clip the gradient norm of the parameters of a non-FSDP policy.""" 806 | return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.max_grad_norm).item() 807 | 808 | def write_state_dict(self, step: int, state: Dict[str, torch.Tensor], metrics: Dict, filename: str, dir_name: Optional[str] = None): 809 | """Write a checkpoint to disk.""" 810 | if dir_name is None: 811 | dir_name = os.path.join(self.run_dir, f'LATEST') 812 | 813 | os.makedirs(dir_name, exist_ok=True) 814 | output_path = os.path.join(dir_name, filename) 815 | rank0_print(f'writing checkpoint to {output_path}...') 816 | torch.save({ 817 | 'step_idx': step, 818 | 'state': state, 819 | 'metrics': metrics if metrics is not None else {}, 820 | }, output_path) 821 | 822 | def load_checkpoint(self, checkpoint_path: str): 823 | """Load a checkpoint from disk.""" 824 | rank0_print(f'Loading checkpoint from {checkpoint_path}') 825 | policy = torch.load(f'{checkpoint_path}/policy.pt', map_location='cpu') 826 | self.policy.load_state_dict(policy['state']) 827 | optimizer = torch.load(f'{checkpoint_path}/optimizer.pt', map_location='cpu') 828 | self.optimizer.load_state_dict(optimizer['state']) 829 | lr_scheduler = torch.load(f'{checkpoint_path}/scheduler.pt', map_location='cpu') 830 | self.lr_scheduler.load_state_dict(lr_scheduler['state']) 831 | 832 | self.checkpoint_example_idx = policy['step_idx'] 833 | self.checkpoint_batch_idx = policy['step_idx'] // self.config.batch_size 834 | rank0_print(f'Loaded checkpoint from {checkpoint_path} at step {self.checkpoint_example_idx}') 835 | 836 | def save(self, output_dir: Optional[str] = None, metrics: Optional[Dict] = None): 837 | """Save policy, optimizer, and scheduler state to disk.""" 838 | # if self.config.loss.fuse_mode == 'embeds': 839 | # self.policy.model.embed_tokens.load_state_dict(self.embed_tokens.state_dict()) 840 | policy_state_dict = self.policy.state_dict() 841 | self.write_state_dict(self.example_counter, policy_state_dict, metrics, 'policy.pt', output_dir) 842 | del policy_state_dict 843 | 844 | optimizer_state_dict = self.optimizer.state_dict() 845 | self.write_state_dict(self.example_counter, optimizer_state_dict, metrics, 'optimizer.pt', output_dir) 846 | del 
optimizer_state_dict 847 | 848 | scheduler_state_dict = self.lr_scheduler.state_dict() 849 | self.write_state_dict(self.example_counter, scheduler_state_dict, metrics, 'scheduler.pt', output_dir) 850 | 851 | def watch_grad(self): 852 | """Watch the gradient of the policy during training, but don't make parameter updates.""" 853 | self.optimizer = getattr(torch.optim, self.config.optimizer)(self.policy.parameters(), lr=self.config.lr) 854 | self.lr_scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=0.1, total_iters=self.train_iterations) 855 | 856 | self.grad_list = [] 857 | self.policy.train() 858 | for batch in tqdm.tqdm(self.train_loader, total=self.train_iterations): 859 | local_microbatch = slice_and_move_batch_for_device(batch, self.rank, self.world_size, self.rank) 860 | loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True) 861 | loss.backward() 862 | grad_norm = self.clip_gradient() 863 | 864 | 865 | self.grad_list.append(grad_norm) 866 | out_json = {"grad_norm": grad_norm, 867 | "prompt": batch['prompt'][0], 868 | "chosen": batch['chosen_response_only'][0], 869 | "rejected": batch['rejected_response_only'][0] 870 | } 871 | with open(f'figure/{self.config.datasets}_{self.config.model}_grad_norm.jsonl', 'a') as f: 872 | f.write(json.dumps(out_json) + "\n") 873 | with open(f'figure/{self.config.datasets}_{self.config.model}_grad_norm_indent.jsonl', 'a') as f: 874 | f.write(json.dumps(out_json, indent=2) + "\n") 875 | 876 | self.optimizer.zero_grad() 877 | 878 | plt.plot(list(range(1, self.train_iterations+1)), self.grad_list) 879 | plt.xlabel('Iterations') 880 | plt.ylabel('Gradient Norm') 881 | plt.title('Gradient Norm of the Policy') 882 | plt.savefig('figure/grad_norm.png') 883 | 884 | def watch_embedding(self): 885 | """Watch the embedding of the policy during training, but don't make parameter updates.""" 886 | # todo: finish the function 887 | self.optimizer = getattr(torch.optim, self.config.optimizer)(self.policy.parameters(), lr=self.config.lr) 888 | self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / (self.config.warmup_steps + 1))) 889 | self.policy.train() 890 | for batch in tqdm.tqdm(self.train_loader, total=self.train_iterations): 891 | local_microbatch = slice_and_move_batch_for_device(batch, self.rank, self.world_size, self.rank) 892 | 893 | 894 | class FSDPTrainer(BasicTrainer): 895 | 896 | def __init__(self, 897 | policy: nn.Module, 898 | config: DictConfig, 899 | seed: int, 900 | run_dir: str, 901 | policy_weak: Optional[nn.Module] = None, 902 | reference_model: Optional[nn.Module] = None, 903 | truncation_side: str = 'right', 904 | padding_side: str = 'right', 905 | rank: int = 0, 906 | world_size: int = 1 907 | ): 908 | """A trainer subclass that uses PyTorch FSDP to shard the model across multiple GPUs. 909 | 910 | This trainer will shard both the policy and reference model across all available GPUs. 911 | Models are sharded at the block level, where the block class name is provided in the config.
912 | """ 913 | 914 | super().__init__(policy=policy, 915 | config=config, 916 | seed=seed, 917 | run_dir=run_dir, 918 | policy_weak=policy_weak, 919 | reference_model=reference_model, 920 | truncation_side=truncation_side, 921 | padding_side=padding_side, 922 | rank=rank, 923 | world_size=world_size) 924 | assert config.model.block_name is not None, 'must specify model.block_name (e.g., GPT2Block or GPTNeoXLayer) for FSDP' 925 | 926 | assert self.config.batch_size % self.config.gradient_accumulation_steps % self.world_size == 0, 'batch_size must be divisible by gradient_accumulation_steps and world_size' 927 | 928 | wrap_class = get_block_class_from_model(policy, config.model.block_name) 929 | model_auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={wrap_class},) 930 | 931 | shared_fsdp_kwargs = dict( 932 | auto_wrap_policy=model_auto_wrap_policy, 933 | sharding_strategy=ShardingStrategy.FULL_SHARD, 934 | cpu_offload=CPUOffload(offload_params=False), 935 | backward_prefetch=BackwardPrefetch.BACKWARD_PRE, 936 | device_id=rank, 937 | ignored_modules=None, # if not (self.config.loss.fuse_mode == 'embeds') else [policy.model.embed_tokens], 938 | limit_all_gathers=False, # TODO: make sure whether the gradient accumulation is affected by this setting 939 | use_orig_params=False, 940 | sync_module_states=False 941 | ) 942 | self.shared_fsdp_kwargs = shared_fsdp_kwargs 943 | 944 | rank0_print('Sharding policy...') 945 | mp_dtype = getattr(torch, config.model.fsdp_policy_mp) if config.model.fsdp_policy_mp is not None else None 946 | policy_mp_policy = MixedPrecision(param_dtype=mp_dtype, reduce_dtype=mp_dtype, buffer_dtype=mp_dtype) 947 | self.policy = FSDP(policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy) 948 | 949 | if config.activation_checkpointing: 950 | rank0_print('Attempting to enable activation checkpointing...') 951 | try: 952 | # use activation checkpointing, according to: 953 | # https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/ 954 | # 955 | # first, verify we have FSDP activation support ready by importing: 956 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 957 | checkpoint_wrapper, 958 | apply_activation_checkpointing, 959 | CheckpointImpl, 960 | ) 961 | non_reentrant_wrapper = functools.partial( 962 | checkpoint_wrapper, 963 | offload_to_cpu=False, 964 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 965 | ) 966 | except Exception as e: 967 | rank0_print('FSDP activation checkpointing not available:', e) 968 | else: 969 | check_fn = lambda submodule: isinstance(submodule, wrap_class) 970 | rank0_print('Applying activation checkpointing wrapper to policy...') 971 | apply_activation_checkpointing(self.policy, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn) 972 | rank0_print('FSDP activation checkpointing enabled!') 973 | 974 | if config.loss.name in {'sft', 'dpo', 'ipo'}: 975 | if config.loss.reference_free is False: 976 | rank0_print('Sharding reference model...') 977 | self.reference_model = FSDP(reference_model, **shared_fsdp_kwargs) 978 | else: 979 | self.reference_model = None 980 | 981 | self.policy_weak = None 982 | 983 | 984 | print('Loaded model on rank', rank) 985 | dist.barrier() 986 | 987 | # have to initialize the optimizer and scheduler after __init__ to avoid conflict with FSDP 988 | rank0_print(f'Using {self.config.optimizer} optimizer') 989 | self.optimizer = getattr(torch.optim, 
self.config.optimizer)(self.policy.parameters(), self.config.lr) 990 | 991 | rank0_print(f'Using {self.config.lr_scheduler} learning rate scheduler') 992 | if self.config.lr_scheduler == 'linear': 993 | self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, 994 | lr_lambda=lambda step: min(1.0, (step + 1) / (self.warmup_steps + 1)) 995 | ) 996 | elif self.config.lr_scheduler == 'cosine': 997 | self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, 998 | lr_lambda=lambda step: cosine_with_warmup(step, 999 | num_warmup_steps=self.warmup_steps, 1000 | num_training_steps=self.train_iterations) 1001 | ) 1002 | else: 1003 | raise ValueError(f'unknown lr_scheduler {self.config.lr_scheduler}') 1004 | 1005 | 1006 | 1007 | def clip_gradient(self): 1008 | """Clip the gradient norm of the parameters of an FSDP policy, gathering the gradients across all GPUs.""" 1009 | return self.policy.clip_grad_norm_(self.config.max_grad_norm).item() 1010 | 1011 | def load_checkpoint(self, checkpoint_path: str): 1012 | """Load a checkpoint from disk.""" 1013 | rank0_print(f'Loading checkpoint from {self.config.checkpoint_path}') 1014 | policy = torch.load(f'{self.config.checkpoint_path}/policy.pt', map_location='cpu') 1015 | self.policy.load_state_dict(policy['state']) 1016 | 1017 | optimizer = torch.load(f'{self.config.checkpoint_path}/optimizer.pt', map_location='cpu') 1018 | self.optimizer.load_state_dict(FSDP.optim_state_dict_to_load(self.policy, self.optimizer, optimizer['state'])) 1019 | 1020 | lr_scheduler = torch.load(f'{self.config.checkpoint_path}/scheduler.pt', map_location='cpu') 1021 | self.lr_scheduler.load_state_dict(lr_scheduler['state']) 1022 | 1023 | self.checkpoint_example_idx = policy['step_idx'] 1024 | self.checkpoint_batch_idx = policy['step_idx'] // self.config.batch_size 1025 | rank0_print(f'Loaded checkpoint from {self.config.checkpoint_path} at step {self.checkpoint_example_idx}') 1026 | 1027 | def save(self, output_dir=None, metrics=None): 1028 | """Save policy, optimizer, and scheduler state to disk, gathering from all processes and saving only on the rank 0 process.""" 1029 | save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 1030 | with FSDP.state_dict_type(self.policy, StateDictType.FULL_STATE_DICT, state_dict_config=save_policy): 1031 | policy_state_dict = self.policy.state_dict() 1032 | 1033 | if self.rank == 0: 1034 | self.write_state_dict(self.example_counter, policy_state_dict, metrics, 'policy.pt', output_dir) 1035 | del policy_state_dict 1036 | dist.barrier() 1037 | 1038 | save_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True) 1039 | with FSDP.state_dict_type(self.policy, StateDictType.FULL_STATE_DICT, optim_state_dict_config=save_policy): 1040 | optimizer_state_dict = FSDP.optim_state_dict(self.policy, self.optimizer) 1041 | 1042 | if self.rank == 0: 1043 | self.write_state_dict(self.example_counter, optimizer_state_dict, metrics, 'optimizer.pt', output_dir) 1044 | del optimizer_state_dict 1045 | dist.barrier() 1046 | 1047 | if self.rank == 0: 1048 | scheduler_state_dict = self.lr_scheduler.state_dict() 1049 | self.write_state_dict(self.example_counter, scheduler_state_dict, metrics, 'scheduler.pt', output_dir) 1050 | dist.barrier() 1051 | --------------------------------------------------------------------------------
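A note on the schedule helpers in src/trainers.py: linear_with_warmup and cosine_with_warmup return a multiplicative factor on the base learning rate, so they are meant to be handed to torch.optim.lr_scheduler.LambdaLR (FSDPTrainer wires up the cosine variant with a lambda). A minimal standalone sketch, assuming trainers.py's module-level dependencies are installed and using an arbitrary toy model, learning rate, and step counts:

import functools
import torch
from trainers import cosine_with_warmup

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-7)

# cosine_with_warmup(step, *, num_warmup_steps, num_training_steps) returns a factor in [0, 1]
lr_lambda = functools.partial(cosine_with_warmup,
                              num_warmup_steps=100,
                              num_training_steps=1000)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

for step in range(1000):
    optimizer.step()      # loss.backward() omitted in this toy loop
    scheduler.step()
    if step in (0, 99, 500, 999):
        print(step, optimizer.param_groups[0]['lr'])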
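preference_loss operates purely on per-sequence log probabilities, so it can be sanity-checked in isolation. A small sketch with random inputs (the batch size and beta below are arbitrary); it assumes src/ is on the path so trainers is importable:

import torch
from trainers import preference_loss

batch_size = 4
# per-sequence log probabilities, shape (batch_size,)
policy_chosen = -torch.rand(batch_size) * 10
policy_rejected = -torch.rand(batch_size) * 10
reference_chosen = -torch.rand(batch_size) * 10
reference_rejected = -torch.rand(batch_size) * 10

losses, chosen_rewards, rejected_rewards = preference_loss(
    policy_chosen, policy_rejected, reference_chosen, reference_rejected,
    beta=0.25, label_smoothing=0.0, loss_name='dpo')

print(losses.shape)                                        # torch.Size([4])
print((chosen_rewards > rejected_rewards).float().mean())  # reward accuracy on this batch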
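_get_batch_logps shifts the labels by one position, masks every -100 label, and gathers per-token log probabilities before reducing. The tiny worked example below (made-up vocabulary of size 5, first two targets masked as prompt tokens) shows the three return modes and that masked positions contribute nothing to the reductions:

import torch
from trainers import _get_batch_logps

batch, seq, vocab = 1, 5, 5
torch.manual_seed(0)
logits = torch.randn(batch, seq, vocab)
# the first two target tokens belong to the prompt and are masked with -100
labels = torch.tensor([[-100, -100, 2, 4, 1]])

sum_logps = _get_batch_logps(logits, labels)                         # shape (1,)
avg_logps = _get_batch_logps(logits, labels, average_log_prob=True)  # shape (1,)
tok_logps = _get_batch_logps(logits, labels, per_token_prob=True)    # shape (1, 4)

# masked positions are zeroed out, so the per-token view sums back to the sequence sum
assert torch.isclose(tok_logps.sum(-1), sum_logps).all()
print(sum_logps, avg_logps, tok_logps)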
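get_cumsum_weight relies on the CUDA kernels from torch-discounted-cumsum. As used in this file, discounted_cumsum_left is treated as the left-to-right recurrence y_t = x_t + gamma * y_{t-1} (it reduces to torch.cumsum at gamma = 1) and discounted_cumsum_right as the right-to-left recurrence y_t = x_t + gamma * y_{t+1} (the flipped cumsum at gamma = 1); that reading is inferred from the gamma == 1 branches above, not from the library's documentation. A slow pure-PyTorch reference of those assumed recurrences, handy for checking the weighting on CPU (the _ref names are mine):

import torch

def discounted_cumsum_left_ref(x: torch.Tensor, gamma: float) -> torch.Tensor:
    """y[..., t] = x[..., t] + gamma * y[..., t-1] (accumulate left-to-right)."""
    out = torch.empty_like(x)
    running = torch.zeros_like(x[..., 0])
    for t in range(x.shape[-1]):
        running = x[..., t] + gamma * running
        out[..., t] = running
    return out

def discounted_cumsum_right_ref(x: torch.Tensor, gamma: float) -> torch.Tensor:
    """y[..., t] = x[..., t] + gamma * y[..., t+1] (accumulate right-to-left)."""
    out = torch.empty_like(x)
    running = torch.zeros_like(x[..., 0])
    for t in reversed(range(x.shape[-1])):
        running = x[..., t] + gamma * running
        out[..., t] = running
    return out

x = torch.ones(2, 6)
# at gamma = 1 both references collapse to the plain (flipped) cumulative sums used above
assert torch.allclose(discounted_cumsum_left_ref(x, 1.0), torch.cumsum(x, dim=-1))
assert torch.allclose(discounted_cumsum_right_ref(x, 1.0),
                      torch.flip(torch.cumsum(torch.flip(x, [-1]), dim=-1), [-1]))
print(discounted_cumsum_left_ref(x, 0.95)[0])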
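The ift branch of get_batch_metrics runs a no-grad forward pass, samples the model's own next-token predictions, and feeds back a convex combination of the teacher-forced embeddings and the embeddings of those sampled tokens (the temporal residual connection, weighted by the lambda from update_lambda). The sketch below restates just that interpolation step; model stands in for the policy and uses get_input_embeddings() rather than the frozen embed_tokens copy the trainer keeps, and the lam default is a placeholder:

import torch

@torch.no_grad()
def sample_next_tokens(logits: torch.Tensor) -> torch.Tensor:
    # greedy sampling, as in get_sample_function(sample='greedy')
    return torch.argmax(logits, dim=-1)

def ift_mixed_embeds(model, input_ids, attention_mask, labels, lam: float = 0.2):
    """Return interpolated input embeddings for the second ('further') forward pass."""
    embed = model.get_input_embeddings()
    inputs_embeds = embed(input_ids)

    with torch.no_grad():
        logits = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask).logits

    # shift the sampled tokens right by one so position t is fed the model's guess for t
    sampled = torch.cat((input_ids[:, :1], sample_next_tokens(logits)[:, :-1]), dim=-1)
    # keep ground-truth tokens wherever the label is masked (prompt positions)
    mixed_ids = torch.where(labels == -100, input_ids, sampled)

    # temporal residual connection: convex combination in embedding space
    return (1 - lam) * inputs_embeds + lam * embed(mixed_ids)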
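For the orpo branch, the objective is the chosen-response NLL plus beta times the negative log-sigmoid of the log-odds ratio between the chosen and rejected responses (both length-averaged). A compact restatement from precomputed averaged log probabilities; the eps clamp guarding against log(0) is added here for numerical safety and is not in the in-repo version:

import torch
import torch.nn.functional as F

def orpo_loss(logps_chosen_avg, logps_rejected_avg, beta=0.25, eps=1e-7):
    """logps_*_avg: length-averaged log probs of each response, shape (batch_size,)."""
    loss_sft = -logps_chosen_avg
    # log odds(y) = log p(y) - log(1 - p(y)), with p(y) = exp(averaged log prob)
    log_odds = (logps_chosen_avg - logps_rejected_avg) - (
        torch.log1p(-torch.exp(logps_chosen_avg).clamp(max=1 - eps))
        - torch.log1p(-torch.exp(logps_rejected_avg).clamp(max=1 - eps)))
    loss_or = -F.logsigmoid(log_odds)
    return (loss_sft + beta * loss_or).mean()

print(orpo_loss(torch.tensor([-0.8, -1.2]), torch.tensor([-1.5, -1.1])))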