├── config
│   ├── loss
│   │   ├── base.yaml
│   │   ├── sft.yaml
│   │   ├── dpo.yaml
│   │   ├── ipo.yaml
│   │   ├── rdpo.yaml
│   │   └── ripo.yaml
│   ├── data_selection
│   │   ├── uniform.yaml
│   │   └── rho_loss_dpo.yaml
│   ├── model
│   │   ├── gpt2-xl.yaml
│   │   ├── gemma-2b-it.yaml
│   │   ├── gpt2-large.yaml
│   │   ├── gptj.yaml
│   │   ├── pythia28.yaml
│   │   ├── gemma-2b.yaml
│   │   ├── bert-tiny.yaml
│   │   └── blank_model.yaml
│   ├── test_experiment
│   │   ├── test_sft.yaml
│   │   ├── test_dpo.yaml
│   │   ├── test_dpo_us.yaml
│   │   └── test_dpo_rho.yaml
│   └── config.yaml
├── alt_requirements.txt
├── src
│   ├── trainers_factory.py
│   ├── groupstuff
│   │   ├── group_dataset.py
│   │   ├── global_opinion_data_processing.py
│   │   ├── global_opinion_data_processing_kfold.py
│   │   └── data_processing.py
│   ├── eval
│   │   ├── win_rate.py
│   │   └── fast_oai.py
│   ├── loss_utils.py
│   ├── utils.py
│   ├── trainers
│   │   ├── paralleltrainer.py
│   │   └── basictrainer.py
│   ├── data_selection.py
│   ├── models.py
│   └── preference_datasets.py
├── scripts
│   ├── run_sft.sh
│   ├── run_multi.sh
│   └── run_multi_robust.sh
├── main_requirements.txt
├── README.md
├── plot_scripts
│   └── visualisations_utils_wandb_api.py
├── train.py
└── LICENSE

/config/loss/base.yaml:
--------------------------------------------------------------------------------
1 | name : base
--------------------------------------------------------------------------------
/config/loss/sft.yaml:
--------------------------------------------------------------------------------
1 | name: sft
--------------------------------------------------------------------------------
/alt_requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rsshyam/GRPO/HEAD/alt_requirements.txt
--------------------------------------------------------------------------------
/config/data_selection/uniform.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.data_selection.UniformRandomSelection
2 | 
3 | #other arguments go here e.g. we can point to model configs etc...
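#
# Note (untested usage sketch, not part of the original config): this file is a Hydra
# config-group option, so uniform random data selection can be switched on from the
# command line, e.g.
#   python train.py loss=dpo loss.beta=0.1 data_selection=uniform selected_batch_size=2
# mirroring config/test_experiment/test_dpo_us.yaml; `_target_` is the class that
# Hydra-style instantiation resolves to (src.data_selection.UniformRandomSelection).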
-------------------------------------------------------------------------------- /config/model/gpt2-xl.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: gpt2-xl 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPT2Block 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 -------------------------------------------------------------------------------- /config/model/gemma-2b-it.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: google/gemma-2b-it 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: null 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 -------------------------------------------------------------------------------- /config/model/gpt2-large.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: gpt2-large 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPT2Block 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 -------------------------------------------------------------------------------- /config/model/gptj.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: EleutherAI/gpt-j-6b 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPTJBlock 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 -------------------------------------------------------------------------------- /config/data_selection/rho_loss_dpo.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data_selection.DPORHOLossSelection 2 | 3 | #model state dict paths 4 | ft_state_dict_path: null 5 | sft_state_dict_path: null 6 | 7 | #We can use the following +model@data_selection.model=tiny-mistral' to use model configs in this element... 
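#
# Note (assumption based on the class name and the RHO-LOSS idea, not stated in the
# original config): this selector presumably scores candidate preference pairs by their
# reducible loss, i.e. the current training loss minus the loss under the model(s) loaded
# from the state-dict paths above, and keeps the highest-scoring pairs in each selected batch.
# Untested usage sketch, mirroring config/test_experiment/test_dpo_rho.yaml:
#   python train.py loss=dpo loss.beta=0.1 data_selection=rho_loss_dpo \
#     data_selection.sft_state_dict_path=<path> data_selection.ft_state_dict_path=<path> \
#     selected_batch_size=2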
-------------------------------------------------------------------------------- /config/model/pythia28.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: EleutherAI/pythia-2.8b 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: GPTNeoXLayer 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float16 9 | 10 | use_lora: True 11 | lora_rank: 8 12 | lora_alpha: 32 13 | lora_dropout: 0.0 14 | lora_target_modules: 15 | - query_key_value -------------------------------------------------------------------------------- /config/test_experiment/test_sft.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: tiny-mistral 4 | 5 | datasets: 6 | - hh 7 | 8 | local_dirs: 9 | - test_outputs 10 | 11 | exp_name: test_sft_tiny_mistral_run 12 | gradient_accumulation_steps: 2 13 | batch_size: 10 14 | eval_batch_size: 10 15 | trainer: BasicTrainer 16 | sample_during_eval: false 17 | 18 | wandb: 19 | enabled: false 20 | 21 | test_dataset: true -------------------------------------------------------------------------------- /config/model/gemma-2b.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: google/gemma-2b 2 | tokenizer_name_or_path: null 3 | archive: null 4 | block_name: null 5 | 6 | policy_dtype: float32 7 | fsdp_policy_mp: null 8 | reference_dtype: float32 9 | 10 | use_lora: True 11 | lora_rank: 8 12 | lora_alpha: 32 13 | lora_dropout: 0.0 14 | lora_target_modules: 15 | - q_proj 16 | - o_proj 17 | - k_proj 18 | - v_proj 19 | - gate_proj 20 | - up_proj 21 | - down_proj -------------------------------------------------------------------------------- /config/test_experiment/test_dpo.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: tiny-mistral 4 | - override /loss: dpo 5 | 6 | datasets: 7 | - hh 8 | 9 | local_dirs: 10 | - test_outputs 11 | 12 | exp_name: test_dpo_tiny_mistral_run 13 | 14 | trainer: BasicTrainer 15 | loss: 16 | beta: 0.1 17 | 18 | sample_during_eval: false 19 | gradient_accumulation_steps: 2 20 | batch_size: 10 21 | eval_batch_size: 10 22 | 23 | wandb: 24 | enabled: false 25 | 26 | test_dataset: true 27 | 28 | -------------------------------------------------------------------------------- /config/loss/dpo.yaml: -------------------------------------------------------------------------------- 1 | # do DPO preference-based training 2 | name: dpo 3 | 4 | # the temperature parameter for DPO; lower values mean we care less about 5 | # the reference model 6 | beta: ??? 7 | 8 | # the noise parameter for conservative DPO; should be in range (0, 0.5); interpreted as 9 | # the fraction of preference pairs that are flipped 10 | # eps=0 is the original DPO loss in the DPO paper 11 | label_smoothing: 0 12 | 13 | # if true, use a uniform (maximum entropy) reference model 14 | reference_free: false -------------------------------------------------------------------------------- /config/loss/ipo.yaml: -------------------------------------------------------------------------------- 1 | # do DPO preference-based training 2 | name: ipo 3 | 4 | # the temperature parameter for DPO; lower values mean we care less about 5 | # the reference model 6 | beta: ??? 
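# Note: in src/loss_utils.py the IPO loss is (h - 1/(2*beta))^2, where h is the policy
# chosen/rejected log-probability ratio minus the reference model's ratio, so a smaller
# beta corresponds to a larger target margin between chosen and rejected responses.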
7 | 8 | # the noise parameter for conservative DPO; should be in range (0, 0.5); interpreted as 9 | # the fraction of preference pairs that are flipped 10 | # eps=0 is the original DPO loss in the DPO paper 11 | label_smoothing: 0 12 | 13 | # if true, use a uniform (maximum entropy) reference model 14 | reference_free: false -------------------------------------------------------------------------------- /config/test_experiment/test_dpo_us.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: tiny-mistral 4 | - override /data_selection: uniform 5 | - override /loss: dpo 6 | 7 | datasets: 8 | - hh 9 | 10 | local_dirs: 11 | - test_outputs 12 | 13 | loss: 14 | beta: 0.1 15 | 16 | exp_name: test_sft_tiny_mistral_run 17 | gradient_accumulation_steps: 2 18 | batch_size: 10 19 | eval_batch_size: 10 20 | trainer: BasicTrainer 21 | sample_during_eval: false 22 | 23 | wandb: 24 | enabled: false 25 | 26 | test_dataset: true 27 | selected_batch_size: 2 -------------------------------------------------------------------------------- /config/test_experiment/test_dpo_rho.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: tiny-mistral 4 | - override /data_selection: rho_loss_dpo 5 | - override /loss: dpo 6 | 7 | datasets: 8 | - hh 9 | 10 | local_dirs: 11 | - test_outputs 12 | 13 | loss: 14 | beta: 0.1 15 | 16 | exp_name: test_dpo_rho_loss_run 17 | gradient_accumulation_steps: 2 18 | batch_size: 10 19 | eval_batch_size: 10 20 | trainer: BasicTrainer 21 | sample_during_eval: false 22 | 23 | wandb: 24 | enabled: false 25 | 26 | test_dataset: true 27 | selected_batch_size: 2 -------------------------------------------------------------------------------- /config/model/bert-tiny.yaml: -------------------------------------------------------------------------------- 1 | name_or_path: prajjwal1/bert-tiny 2 | tokenizer_name_or_path: prajjwal1/bert-tiny 3 | archive: null 4 | 5 | # the name of the module class to wrap with FSDP; should be something like 6 | # e.g. GPT2Block, GPTNeoXLayer, LlamaDecoderLayer, etc. 7 | block_name: BertLayer 8 | 9 | # the dtype for the policy parameters/optimizer state 10 | policy_dtype: float32 11 | 12 | # the mixed precision dtype if using FSDP; defaults to the same as the policy 13 | fsdp_policy_mp: null 14 | 15 | # the dtype for the reference model (which is used for inference only) 16 | reference_dtype: float16 17 | -------------------------------------------------------------------------------- /config/loss/rdpo.yaml: -------------------------------------------------------------------------------- 1 | # do robust DPO preference-based training 2 | name: rdpo 3 | 4 | # the temperature parameter for DPO; lower values mean we care less about 5 | # the reference model 6 | beta: ???
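# Note (assumption from the field names and from scripts/run_multi_robust.sh, where step_size
# is described as the exponential group rate): the fields below configure the group-robust
# variant, with step_size controlling how fast the per-group weights are updated and
# adaptive_step_size/step_factor adjusting that rate during training; the exact update rule
# lives in the group trainer code, which is not shown in this listing.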
7 | 8 | # the noise parameter for conservative DPO; should be in range (0, 0.5); interpreted as 9 | # the fraction of preference pairs that are flipped 10 | # eps=0 is the original DPO loss in the DPO paper 11 | label_smoothing: 0 12 | 13 | # if true, use a uniform (maximum entropy) reference model 14 | reference_free: false 15 | 16 | step_size: 0.01 17 | 18 | importance_sampling: False 19 | 20 | imp_weights: False 21 | 22 | adj: 0 23 | 24 | #weight_decay: 0 25 | dpowts: False 26 | 27 | divide_by_totalcount: True 28 | 29 | adaptive_step_size: False 30 | 31 | step_factor: 0.1 -------------------------------------------------------------------------------- /config/loss/ripo.yaml: -------------------------------------------------------------------------------- 1 | # do DPO preference-based training 2 | name: ripo 3 | 4 | # the temperature parameter for DPO; lower values mean we care less about 5 | # the reference model 6 | beta: ??? 7 | 8 | # the noise parameter for conservative DPO; should be in range (0, 0.5); interpreted as 9 | # the fraction of preference pairs that are flipped 10 | # eps=0 is the original DPO loss in the DPO paper 11 | label_smoothing: 0 12 | 13 | # if true, use a uniform (maximum entropy) reference model 14 | reference_free: false 15 | 16 | step_size: 0.01 17 | 18 | importance_sampling: False 19 | 20 | imp_weights: False 21 | 22 | adj: 0 23 | 24 | #weight_decay: 0 25 | dpowts: False 26 | 27 | divide_by_totalcount: True 28 | 29 | adaptive_step_size: False 30 | 31 | step_factor: 0.1 -------------------------------------------------------------------------------- /config/model/blank_model.yaml: -------------------------------------------------------------------------------- 1 | # the name of the model to use; should be something like 2 | # gpt2-xl or gpt-neo-2.7B or huggyllama/llama-7b 3 | name_or_path: ??? 4 | 5 | # the name of the tokenizer to use; if null, will use the tokenizer from the model 6 | tokenizer_name_or_path: null 7 | 8 | # override pre-trained weights (e.g., from SFT); optional 9 | archive: null 10 | 11 | # the name of the module class to wrap with FSDP; should be something like 12 | # e.g. GPT2Block, GPTNeoXLayer, LlamaDecoderLayer, etc. 
13 | block_name: null 14 | 15 | # the dtype for the policy parameters/optimizer state 16 | policy_dtype: float32 17 | 18 | # the mixed precision dtype if using FSDP; defaults to the same as the policy 19 | fsdp_policy_mp: null 20 | 21 | # the dtype for the reference model (which is used for inference only) 22 | reference_dtype: float16 23 | -------------------------------------------------------------------------------- /src/trainers_factory.py: -------------------------------------------------------------------------------- 1 | # trainer_factory.py 2 | from src.trainers.basictrainer import BasicTrainer 3 | from src.trainers.grouptrainer import GroupTrainer 4 | from src.trainers.paralleltrainer import FSDPTrainer,TensorParallelTrainer 5 | from src.trainers.grouptrainerearlystop import GroupTrainerEarlyStop 6 | 7 | def get_trainer(trainer,policy, config, seed, local_run_dir, reference_model, data_selector, rank, world_size): 8 | if trainer == "BasicTrainer": 9 | return BasicTrainer(policy, config, seed, local_run_dir, reference_model,data_selector, rank, world_size) 10 | elif trainer == "GroupTrainer": 11 | return GroupTrainer(policy, config, seed, local_run_dir, reference_model,data_selector, rank, world_size) 12 | elif trainer == "GroupTrainerEarlyStop": 13 | return GroupTrainerEarlyStop(policy, config, seed, local_run_dir, reference_model,data_selector, rank, world_size) 14 | elif trainer == "parallel_fsdp": 15 | return FSDPTrainer(policy, config, seed, local_run_dir, reference_model,data_selector, rank, world_size) 16 | elif trainer == "parallel_tensor": 17 | return TensorParallelTrainer(policy, config, seed, local_run_dir, reference_model,data_selector, rank, world_size) 18 | else: 19 | raise ValueError("Unknown trainer type") -------------------------------------------------------------------------------- /scripts/run_sft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Default parameters 4 | MODEL="gemma-2b" 5 | DATASETS="goqa_0,goqa_1,goqa_2,goqa_3,goqa_4" 6 | TRAIN_FRAC=0.8 7 | LOSS="sft" 8 | GRADIENT_ACCUMULATION_STEPS=2 9 | BATCH_SIZE=16 10 | EVAL_BATCH_SIZE=8 11 | SAMPLE_DURING_EVAL="False" 12 | TRAINER="GroupTrainer" 13 | LR=1e-4 14 | N_EPOCHS=1 15 | EVAL_EVERY=192 16 | EVAL_TRAIN_EVERY=192 17 | EVAL_ONLY_ONCE="False" 18 | # Parse arguments 19 | while [[ $# -gt 0 ]]; do 20 | key="$1" 21 | case $key in 22 | --model) 23 | MODEL="$2" 24 | shift # past argument 25 | shift # past value 26 | ;; 27 | --datasets) 28 | DATASETS="$2" 29 | shift # past argument 30 | shift # past value 31 | ;; 32 | --train_frac) 33 | TRAIN_FRAC="$2" 34 | shift # past argument 35 | shift # past value 36 | ;; 37 | --loss) 38 | LOSS="$2" 39 | shift # past argument 40 | shift # past value 41 | ;; 42 | --gradient_accumulation_steps) 43 | GRADIENT_ACCUMULATION_STEPS="$2" 44 | shift # past argument 45 | shift # past value 46 | ;; 47 | --batch_size) 48 | BATCH_SIZE="$2" 49 | shift # past argument 50 | shift # past value 51 | ;; 52 | --eval_batch_size) 53 | EVAL_BATCH_SIZE="$2" 54 | shift # past argument 55 | shift # past value 56 | ;; 57 | --sample_during_eval) 58 | SAMPLE_DURING_EVAL="$2" 59 | shift # past argument 60 | shift # past value 61 | ;; 62 | --trainer) 63 | TRAINER="$2" 64 | shift # past argument 65 | shift # past value 66 | ;; 67 | --lr) 68 | LR="$2" 69 | shift # past argument 70 | shift # past value 71 | ;; 72 | --n_epochs) 73 | N_EPOCHS="$2" 74 | shift # past argument 75 | shift # past value 76 | ;; 77 | --eval_every) 78 | EVAL_EVERY="$2" 
79 | shift # past argument 80 | shift # past value 81 | ;; 82 | --eval_train_every) 83 | EVAL_TRAIN_EVERY="$2" 84 | shift # past argument 85 | shift # past value 86 | ;; 87 | *) # unknown option 88 | echo "Unknown option: $1" 89 | exit 1 90 | ;; 91 | esac 92 | done 93 | 94 | python -u train.py model=$MODEL datasets=[$DATASETS] train_frac=$TRAIN_FRAC loss=$LOSS gradient_accumulation_steps=$GRADIENT_ACCUMULATION_STEPS batch_size=$BATCH_SIZE eval_batch_size=$EVAL_BATCH_SIZE sample_during_eval=$SAMPLE_DURING_EVAL trainer=$TRAINER lr=$LR eval_every=$EVAL_EVERY eval_train_every=$EVAL_TRAIN_EVERY -------------------------------------------------------------------------------- /src/groupstuff/group_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset, DataLoader 4 | from torch.utils.data.sampler import WeightedRandomSampler 5 | 6 | class GroupDataset(Dataset): 7 | def __init__(self, dataset, n_groups): 8 | #print(len(dataset),n_groups) 9 | self.dataset = dataset 10 | self.n_groups = n_groups 11 | group_array = [] 12 | resp_array = [] 13 | count_array = [] 14 | for prompt, responses, pairs, sft_target, truncation_mode,id in self: 15 | count_array.append(len(pairs)) 16 | group_array.append(id) 17 | resp_array.append(responses) 18 | self._group_array = torch.LongTensor(group_array) 19 | self._count_array= torch.LongTensor(count_array) 20 | #self._resp_array = torch.LongTensor(resp_array) 21 | #self._group_counts = (torch.arange(self.n_groups).unsqueeze(1)==self._group_array).sum(1).float() 22 | self._group_counts = torch.bincount(self._group_array, weights=self._count_array).float() 23 | print(self._group_counts) 24 | #self._resp_counts = (torch.arange(self.n_classes).unsqueeze(1)==self._resp_array).sum(1).float() 25 | 26 | def __getitem__(self, idx): 27 | return self.dataset[idx] 28 | 29 | def __len__(self): 30 | return len(self.dataset) 31 | 32 | def group_counts(self): 33 | return self._group_counts 34 | 35 | def input_size(self): 36 | for prompt, responses, pairs, sft_target, truncation_mode,id in self: 37 | return prompt.size() 38 | 39 | def get_loader(self): 40 | # Training and reweighting 41 | # When the --robust flag is not set, reweighting changes the loss function 42 | # from the normal ERM (average loss over each training example) 43 | # to a reweighted ERM (weighted average where each (y,c) group has equal weight) . 
44 | # When the --robust flag is set, reweighting does not change the loss function 45 | # since the minibatch is only used for mean gradient estimation for each group separately 46 | #print(len(self),self._group_counts) 47 | group_weights = len(self)/self._group_counts 48 | #print(group_weights,self._group_array) 49 | weights = group_weights[self._group_array] 50 | 51 | # Replacement needs to be set to True, otherwise we'll run out of minority samples 52 | sampler = WeightedRandomSampler(weights, len(self), replacement=True) 53 | 54 | loader = DataLoader( 55 | self, 56 | shuffle=False, 57 | sampler=sampler) 58 | 59 | return loader -------------------------------------------------------------------------------- /main_requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | accelerate==0.20.3 3 | aiohttp==3.8.4 4 | aiosignal==1.3.1 5 | antlr4-python3-runtime==4.9.3 6 | appdirs==1.4.4 7 | asttokens==2.2.1 8 | async-timeout==4.0.2 9 | attrs==23.1.0 10 | backcall==0.2.0 11 | beautifulsoup4==4.12.2 12 | bitsandbytes==0.41.1 13 | bloom-filter2==2.0.0 14 | cachetools==5.3.2 15 | certifi==2023.5.7 16 | charset-normalizer==3.1.0 17 | clarabel==0.6.0 18 | click==8.1.3 19 | cloudpickle==3.0.0 20 | cmake==3.26.4 21 | comm==0.1.3 22 | contourpy==1.2.0 23 | cvxpy==1.4.1 24 | cycler==0.12.1 25 | dask==2023.12.0 26 | datasets==2.12.0 27 | debugpy==1.6.7 28 | decorator==5.1.1 29 | dill==0.3.6 30 | docker-pycreds==0.4.0 31 | ecos==2.0.12 32 | executing==1.2.0 33 | filelock==3.12.2 34 | fonttools==4.46.0 35 | frozenlist==1.3.3 36 | fsspec==2023.6.0 37 | gitdb==4.0.10 38 | GitPython==3.1.31 39 | google-auth==2.26.1 40 | google-auth-oauthlib==1.2.0 41 | grpcio==1.60.0 42 | huggingface-hub==0.15.1 43 | hydra-core==1.3.2 44 | idna==3.4 45 | importlib-metadata==7.0.0 46 | ipdb==0.13.13 47 | ipykernel==6.23.1 48 | ipython==8.14.0 49 | jedi==0.18.2 50 | Jinja2==3.1.2 51 | jupyter_client==8.3.0 52 | jupyter_core==5.3.1 53 | kiwisolver==1.4.5 54 | lit==16.0.6 55 | locket==1.0.0 56 | Markdown==3.5.1 57 | MarkupSafe==2.1.3 58 | matplotlib==3.8.2 59 | matplotlib-inline==0.1.6 60 | mpmath==1.3.0 61 | multidict==6.0.4 62 | multiprocess==0.70.14 63 | neatplot @ git+https://github.com/willieneis/neatplot.git@8a0291aa2b5c83e7f36ab695629d3b147f1645f4 64 | nest-asyncio==1.5.6 65 | networkx==3.1 66 | numpy==1.24.3 67 | nvidia-cublas-cu11==11.10.3.66 68 | nvidia-cuda-cupti-cu11==11.7.101 69 | nvidia-cuda-nvrtc-cu11==11.7.99 70 | nvidia-cuda-runtime-cu11==11.7.99 71 | nvidia-cudnn-cu11==8.5.0.96 72 | nvidia-cufft-cu11==10.9.0.58 73 | nvidia-curand-cu11==10.2.10.91 74 | nvidia-cusolver-cu11==11.4.0.1 75 | nvidia-cusparse-cu11==11.7.4.91 76 | nvidia-nccl-cu11==2.14.3 77 | nvidia-nvtx-cu11==11.7.91 78 | oauthlib==3.2.2 79 | omegaconf==2.3.0 80 | openai==0.28.0 81 | osqp==0.6.3 82 | packaging==23.1 83 | pandas==2.0.2 84 | parso==0.8.3 85 | partd==1.4.1 86 | pathtools==0.1.2 87 | peft @ git+https://github.com/huggingface/peft@ffbb6bcf9c3801c8c9e49491412ce9917defcc78 88 | pexpect==4.8.0 89 | pickleshare==0.7.5 90 | Pillow==10.1.0 91 | platformdirs==3.8.0 92 | prompt-toolkit==3.0.38 93 | protobuf==4.23.3 94 | psutil==5.9.5 95 | ptyprocess==0.7.0 96 | pure-eval==0.2.2 97 | pyarrow==12.0.1 98 | pyasn1==0.5.1 99 | pyasn1-modules==0.3.0 100 | pybind11==2.11.1 101 | Pygments==2.15.1 102 | pyparsing==3.1.1 103 | python-dateutil==2.8.2 104 | python-dotenv==1.0.0 105 | pytz==2023.3 106 | PyYAML==6.0 107 | pyzmq==25.1.0 108 | qdldl==0.1.7.post0 109 | regex==2023.6.3 110 | 
requests==2.31.0
111 | requests-oauthlib==1.3.1
112 | responses==0.18.0
113 | rsa==4.9
114 | safetensors==0.3.1
115 | scipy==1.11.2
116 | scs==3.2.4.post1
117 | sentry-sdk==1.26.0
118 | setproctitle==1.3.2
119 | six==1.16.0
120 | smmap==5.0.0
121 | soupsieve==2.4.1
122 | stack-data==0.6.2
123 | swifter==1.4.0
124 | sympy==1.12
125 | tensor-parallel==1.2.4
126 | tensorboard==2.15.1
127 | tensorboard-data-server==0.7.2
128 | termcolor==2.4.0
129 | tokenizers==0.13.3
130 | tomli==2.0.1
131 | toolz==0.12.0
132 | torch==2.0.1
133 | tornado==6.3.2
134 | tqdm==4.65.0
135 | traitlets==5.9.0
136 | transformers==4.29.2
137 | triton==2.0.0
138 | typing_extensions==4.6.3
139 | tzdata==2023.3
140 | urllib3==2.0.3
141 | wandb==0.15.3
142 | wcwidth==0.2.6
143 | Werkzeug==3.0.1
144 | xxhash==3.2.0
145 | yarl==1.9.2
146 | zipp==3.17.0
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GRPO: Group Robust Preference Optimization
2 | 
3 | This codebase builds upon the DPO codebase publicly available on GitHub: https://github.com/eric-mitchell/direct-preference-optimization
4 | 
5 | ## What is this repo?
6 | 
7 | This repo includes a reference implementation of the GRPO algorithm for training language models from preference data, as described in our paper.
8 | 
9 | 
10 | Similar to DPO, our pipeline has two stages:
11 | 
12 | 1. Run supervised fine-tuning (SFT) on the dataset(s) of interest.
13 | 2. Run robust preference learning (GRIPO) on the model from step 1, using preference data.
14 | 
15 | The important files in this repo are:
16 | - `train.py`: the main entry point for training (either SFT or IPO/GRIPO preference-based training)
17 | - `src/trainers_factory.py`: instantiates the trainer classes from `src/trainers`
18 | - `src/utils.py`: common functions used by multiple methods
19 | - `src/preference_datasets.py`: dataset processing logic for both SFT and IPO/GRIPO preference-based training
20 | 
21 | In this codebase, we specifically use the Gemma-2b model; the configuration used is detailed in `config/model/gemma-2b.yaml`. To download and use the Gemma-2b model, refer to https://huggingface.co/google/gemma-2b. It is a gated model and hence requires requesting access on Hugging Face.
22 | 
23 | Our dataset is the global opinion data from https://huggingface.co/datasets/Anthropic/llm_global_opinions
24 | 
25 | ### Set up environment
26 | 
27 | First, create a virtualenv and install the dependencies. Python 3.10+ is recommended.
28 | 
29 | python3 -m venv env
30 | source env/bin/activate
31 | pip install -r main_requirements.txt
32 | pip install scikit-learn
33 | 
34 | 
35 | In `config.yaml`, set up your wandb details so that results can be visualized there.
36 | 
37 | ## Running SFT
38 | 
39 | sh scripts/run_sft.sh
40 | 
41 | Please run this command to reproduce the SFT model used in our setup.
42 | ## Running IPO/GRIPO
43 | 
44 | To run IPO, one requires the reference policy, i.e. the path to the SFT checkpoint. Run the following command for IPO, with `path-to-sft-file` replaced by the actual path to your SFT-trained policy:
45 | 
46 | sh scripts/run_multi.sh --model_archive path-to-sft-file
47 | 
48 | The exact configurations we used in our IPO training are already set in `scripts/run_multi.sh`.
49 | 
50 | Similarly, for GRIPO:
51 | 
52 | sh scripts/run_multi_robust.sh --model_archive path-to-sft-file
53 | 
54 | Note that these commands were run on a machine with a single 40GB A100 GPU.
Further, we are running single-GPU training, using `GroupTrainerEarlyStop`, which reduces the learning rate when the loss has not improved for a certain (tunable) number of iterations.
55 | 
56 | ## Plotting results
57 | In order to visualize the results, we collect data directly from wandb and plot it. We include plotting scripts in the `plot_scripts` folder that perform this. Kindly change the wandb details and `path-to-sft-file` in the plot scripts to retrieve the plots.
58 | 
59 | `plot_scripts/plot_from_wandb_full_metrics.py` plots all the relevant metrics tracked in our experiments
60 | `plot_scripts/plot_from_wandb_paper_plots.py` reproduces the plots mentioned in the paper
61 | 
62 | ## Citation
63 | Please cite our paper if you find the repo helpful in your work:
64 | 
65 | ```bibtex
66 | @article{ramesh2024grpo,
67 |   title={Group Robust Preference Optimization in Reward-free RLHF},
68 |   author={Shyam Sundhar Ramesh and Iason Chaimalas and Viraj Mehta and Haitham Bou Ammar and
69 |     Pier Giuseppe Sessa and Yifan Hu and Ilija Bogunovic},
70 |   year={2024}
71 | }
72 | ```
73 | 
74 | 
75 | 
76 | 
77 | 
--------------------------------------------------------------------------------
/src/eval/win_rate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 | import pandas as pd
4 | import logging
5 | from tqdm import tqdm
6 | import pickle
7 | from fast_oai import call_chats  # uses the OpenAI API; check that file before using
8 | sys.path.append('..')
9 | sys.path.append('path')  # folderpath --- src or not?
10 | import os
11 | print(sys.path)
12 | #from epinet import get_shuffle_iterator
13 | system_prompt = "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. Output your final verdict by strictly following this format: 'A' if assistant A is better, 'B' if assistant B is better, and 'C' for a tie. Output only that character and do not include any other characters or spaces."
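# Note (summary of the flow implemented below): each CSV in the target directory is expected
# to have the columns [step, prompt, sample, correct response]; for every row, the model
# sample (with the prompt stripped off) and the reference response are sent to the judge
# together with the system prompt above, and the returned 'A'/'B'/'C' verdict is written
# back to the CSV as a 'model_result' column with the values Win / Lose / Tie.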
14 | 15 | user_prompt = "[User Question]\n{prompt}\n[The Start of Assistant A's Answer]\n{sample1}\n[The End of Assistant A's Answer]\n[The Start of Assistant B's Answer]\n{sample2}\n[The End of Assistant B's Answer]\n" 16 | 17 | 18 | def get_user_prompt(row): 19 | prompt = row["prompt"] 20 | sample1 = row['sample_only'] 21 | sample2 = row['correct response'] 22 | return user_prompt.format(prompt=prompt, sample1=sample1, sample2=sample2) 23 | 24 | def main(csv_dir_path, overwrite_model_result=False): 25 | csv_dir_path = Path(csv_dir_path) 26 | 27 | for csv_path in tqdm(csv_dir_path.iterdir()): 28 | print(f'processing {csv_path}') 29 | if not str(csv_path).endswith('.csv'): 30 | continue 31 | df=pd.read_csv(csv_path) 32 | print(df.columns) 33 | if df.columns[0]!='step':#in case the csv file directly starts with data 34 | df = pd.read_csv(csv_path,header=None) 35 | columns=["step", "prompt", "sample","correct response"]#change depending on file creation 36 | df.columns=columns 37 | print(df.columns) 38 | print(df.head) 39 | if 'model_result' in df.columns and not overwrite_model_result: 40 | mask = ~df['model_result'].isin(['Win', 'Lose', 'Tie']) 41 | else: 42 | mask = pd.Series(True, index=df.index) 43 | #print(df.columns) 44 | #keys = list(dataset.keys()) 45 | df['sample_only'] = df.apply(lambda row: row['sample'][len(row['prompt']):], axis=1)##removes prompt part of the sample and stores it in sample_only 46 | #df['sft_target'] = df.apply(lambda row: dataset[row['prompt']]['correct response'], axis=1) 47 | 48 | df['user_prompt'] = df.apply(get_user_prompt, axis=1) 49 | user_prompt_list = df.loc[mask, 'user_prompt'].tolist() 50 | system_prompt_gen = (system_prompt for _ in range(len(user_prompt_list))) 51 | completions = call_chats(zip(system_prompt_gen, user_prompt_list)) 52 | vals = [] 53 | for i, dec in enumerate(completions): 54 | if dec == 'A': 55 | vals.append("Win") 56 | elif dec == 'B': 57 | vals.append("Lose") 58 | elif dec == 'C': 59 | vals.append('Tie') 60 | else: 61 | logging.warning(f"Unexpected decision {dec} on row {i}") 62 | vals.append(dec) 63 | df.loc[mask, 'model_result'] = vals 64 | print(f'writing to {csv_path}') 65 | df.to_csv(csv_path, index=False) 66 | 67 | if __name__ == '__main__': 68 | main(*sys.argv[1:]) -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | # random seed for batch sampling 2 | seed: 0 3 | 4 | # name for this experiment in the local run directory and on wandb 5 | exp_name: ??? 6 | 7 | group_name: ??? 8 | 9 | use_kfoldsplit: False 10 | 11 | # the batch size for training; for FSDP, the batch size per GPU is batch_size / (grad_accumulation_steps * num_gpus) 12 | batch_size: 4 13 | 14 | # the batch size during evaluation and sampling, if enabled 15 | eval_batch_size: 16 16 | 17 | # debug mode (disables wandb, model checkpointing, etc.) 
18 | debug: false 19 | 20 | # the port to use for FSDP 21 | fsdp_port: null 22 | 23 | # which dataset(s) to train on; can pass a list like datasets=[hh,shp] 24 | datasets: 25 | - goqa 26 | 27 | # wandb configuration 28 | wandb: 29 | enabled: true 30 | entity: 'entity_name' 31 | project: 'project_name' 32 | key: null 33 | scheduler_metric: "accuracy" 34 | patience_factor: 1 35 | # to create the local run directory and cache models/datasets, 36 | # we will try each of these directories in order; if none exist, 37 | # we will create the last one and use it 38 | local_dirs: 39 | - /scr-ssd 40 | - /scr 41 | - .cache 42 | 43 | # whether or not to generate samples during evaluation; disable for FSDP/TensorParallel 44 | # is recommended, because they are slow 45 | sample_during_eval: true 46 | 47 | # how many model samples to generate during evaluation 48 | n_eval_model_samples: 16 49 | 50 | # whether to eval at the very beginning of training 51 | do_first_eval: true 52 | 53 | # an OmegaConf resolver that returns the local run directory, calling a function in utils.py 54 | local_run_dir: ${get_local_run_dir_group:${exp_name},${group_name},${local_dirs}} 55 | 56 | # the learning rate 57 | lr: 5e-7 58 | min_lr: 1e-8 59 | 60 | # number of steps to accumulate over for each batch 61 | # (e.g. if batch_size=4 and gradient_accumulation_steps=2, then we will 62 | # accumulate gradients over 2 microbatches of size 2) 63 | gradient_accumulation_steps: 1 64 | 65 | # the maximum gradient norm to clip to 66 | max_grad_norm: 10.0 67 | 68 | # the maximum allowed length for an input (prompt + response) 69 | max_length: 512 70 | 71 | # the maximum allowed length for a prompt 72 | max_prompt_length: 256 73 | 74 | # the number of epochs to train for; if null, must specify n_examples 75 | n_epochs: 1 76 | 77 | # the number of examples to train for; if null, must specify n_epochs 78 | n_examples: null 79 | 80 | # the number of examples to evaluate on (and sample from, if sample_during_eval is true) 81 | n_eval_examples: 128 82 | 83 | # the trainer class to use (e.g. 
BasicTrainer, FSDPTrainer, TensorParallelTrainer) 84 | trainer: BasicTrainer 85 | 86 | # The optimizer to use; we use RMSprop because it works about as well as Adam and is more memory-efficient 87 | optimizer: RMSprop 88 | 89 | # number of linear warmup steps for the learning rate 90 | warmup_steps: 150 91 | 92 | # whether or not to use activation/gradient checkpointing 93 | activation_checkpointing: false 94 | 95 | # evaluate and save model every eval_every steps 96 | eval_every: 96 97 | 98 | # prevent wandb from logging more than once per minimum_log_interval_secs 99 | minimum_log_interval_secs: 1.0 100 | 101 | # ensure an sft model is explicitly provided before running: 102 | assert_sft_step: true 103 | 104 | #test dataset setup: 105 | test_dataset: False 106 | 107 | #Active Learning default settings: 108 | active: False 109 | selected_batch_size: null 110 | 111 | train_frac: 0.8 112 | 113 | #group addition 114 | #robust additions 115 | group_handling: False 116 | #ref_sample: false 117 | eval_train_data: True 118 | eval_train_full: False 119 | # how many model samples to generate during evaluation 120 | #n_eval_model_samples: 128 121 | # whether to eval at the very beginning of training 122 | use_ref: true 123 | #do_first_eval_gen: False 124 | #split: 100 125 | eval_train_every: 96 126 | save_every: 16000 127 | eval_full: False 128 | eval_only_once: False 129 | 130 | weighted_batches: False 131 | #n_eval_metrics: 1 132 | #check_all_responses: False 133 | eval_train_end: True 134 | sep_pairs: False 135 | 136 | max_train_examples: null 137 | 138 | defaults: 139 | - _self_ 140 | - model: blank_model_fp32 # basic model configuration 141 | - loss: sft # which loss function, either sft or dpo (specify loss.beta if using dpo) 142 | - data_selection: null -------------------------------------------------------------------------------- /plot_scripts/visualisations_utils_wandb_api.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import wandb 4 | import itertools 5 | import pandas as pd 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | pd.options.mode.chained_assignment = None # default='warn' 10 | 11 | def download_runs(entity, project, filters, num_samples=1e6): 12 | """ 13 | Download runs via the wandb API 14 | 15 | Parameters 16 | ---------- 17 | entity : str 18 | wandb entity name. 19 | project : str 20 | wandb project name. 21 | filters : dict 22 | filter to select runs in a MongoDB query format. 23 | 24 | Returns 25 | ------- 26 | runs : list 27 | list of the run history retrieved from wandb. 
28 | """ 29 | 30 | api = wandb.Api(timeout=30) 31 | runs = api.runs(entity + "/" + project, filters=filters) 32 | #print(entity,project,filters) 33 | #runs_config = [run.config for run in runs] 34 | runs_hist = [run.history(samples=num_samples) for run in runs] 35 | 36 | return runs_hist #runs_hist, runs_config, 37 | 38 | def process_runs(runs, field, time_field='epoch', agg='mean'): 39 | 40 | processed_runs = list() 41 | 42 | for df_run in runs: 43 | #print(df_run.columns) 44 | #Download the run's history 45 | if field not in df_run.columns: 46 | continue 47 | df = df_run[[field, time_field]] 48 | #print(df[field][:-1]) 49 | #print(df) 50 | subset=df.iloc[:-1][field] 51 | #print(subset,'subset') 52 | is_string_nan_mask = subset.apply(lambda x: isinstance(x, str) and x.lower() == 'nan') 53 | 54 | #print(is_string_nan_mask.any(),'stringnan') 55 | if is_string_nan_mask.any(): 56 | df[field] = df[field].replace('NaN', pd.NA) 57 | # print(df.isna()) 58 | # print(f"Skipping run due to NaN values in field '{field}'") 59 | #print(df[field]) 60 | # continue 61 | _filter = df[field].isna() 62 | df_filtered = df[~_filter] 63 | #print(df_filtered) 64 | #Assert numerical typing: 65 | df_filtered.loc[:, field] = df_filtered[field].astype(np.float64) 66 | 67 | df_filtered = df_filtered.groupby(time_field).agg({field:agg}) 68 | 69 | processed_runs.append(df_filtered) 70 | 71 | return processed_runs 72 | 73 | def process_max_fields(runs, fields, maximum=True, time_field='epoch', x_percent_rmv=None): 74 | 75 | processed_runs = list() 76 | 77 | def remove_x_percent(series): 78 | pass 79 | 80 | for df_run in runs: 81 | 82 | run_results = list() 83 | 84 | for i, field in enumerate(fields): 85 | 86 | df = df_run[[field, time_field]] 87 | _filter = df[time_field].isnull() 88 | df_filtered = df[~_filter] 89 | 90 | if x_percent_rmv is None: 91 | df_filtered = df_filtered.groupby(time_field).agg({field:'mean'}) 92 | else: 93 | df_filtered = df_filtered.groupby(time_field).agg({field:remove_x_percent}) 94 | 95 | 96 | run_results.append(df_filtered) 97 | 98 | df_concat = pd.concat(run_results, axis=1) 99 | 100 | if maximum: 101 | run_results = df_concat.max(axis=1) 102 | else: 103 | run_results = df_concat.min(axis=1) 104 | 105 | processed_runs.append(run_results) 106 | 107 | return processed_runs 108 | 109 | 110 | def group_process_runs(processed_runs, runs): 111 | 112 | assert len(processed_runs) == len(runs),\ 113 | 'processed runs must be the same length as runs' 114 | 115 | #Stack together processed runs: 116 | df = pd.concat(processed_runs, axis=1) 117 | 118 | #Calculate the mean and std 119 | mean = df.mean(axis=1) 120 | std = df.std(axis=1) 121 | 122 | return pd.concat([mean, std], axis=1) 123 | 124 | 125 | def process_and_plot_max_grp_runs(fig, axs, runs, fields): 126 | pass 127 | 128 | def process_and_plot_grp_runs(fig, axs, runs, fields): 129 | pass 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /scripts/run_multi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Default parameters 4 | MODEL="gemma-2b" 5 | DATASETS="goqa_0,goqa_1,goqa_2,goqa_3,goqa_4" 6 | TRAIN_FRAC=0.8 7 | LOSS="ipo" # ipo, ripo, rdpo 8 | GRADIENT_ACCUMULATION_STEPS=2 9 | BATCH_SIZE=16 10 | EVAL_BATCH_SIZE=8 11 | SAMPLE_DURING_EVAL="False" 12 | TRAINER="GroupTrainerEarlyStop" 13 | LR=2e-5 14 | LABEL_SMOOTHING=0 15 | MODEL_ARCHIVE="path-to-sft-policy" 16 | LOSS_BETA=0.01 17 | N_EPOCHS=30 18 | EVAL_EVERY=960 19 | 
EVAL_TRAIN_EVERY=192 20 | NSEEDS=4 21 | STSEED=0 22 | EVAL_ONLY_ONCE="False" 23 | PATIENCE_FACTOR=2 24 | SCHEDULER_METRIC="loss" 25 | USE_KFOLDSPLIT="False" 26 | OPTIMIZER="AdamW" 27 | MIN_LR=0.00000001 28 | # Parse arguments 29 | while [[ $# -gt 0 ]]; do 30 | key="$1" 31 | case $key in 32 | --model) 33 | MODEL="$2" 34 | shift # past argument 35 | shift # past value 36 | ;; 37 | --datasets) 38 | DATASETS="$2" 39 | shift # past argument 40 | shift # past value 41 | ;; 42 | --step_size) 43 | STEP_SIZE="$2" 44 | shift # past argument 45 | shift # past value 46 | ;; 47 | --divide_by_totalcount) 48 | DIVIDE_BY_TOTALCOUNT="$2" 49 | shift # past argument 50 | shift # past value 51 | ;; 52 | --train_frac) 53 | TRAIN_FRAC="$2" 54 | shift # past argument 55 | shift # past value 56 | ;; 57 | --loss) 58 | LOSS="$2" 59 | shift # past argument 60 | shift # past value 61 | ;; 62 | --gradient_accumulation_steps) 63 | GRADIENT_ACCUMULATION_STEPS="$2" 64 | shift # past argument 65 | shift # past value 66 | ;; 67 | --batch_size) 68 | BATCH_SIZE="$2" 69 | shift # past argument 70 | shift # past value 71 | ;; 72 | --eval_batch_size) 73 | EVAL_BATCH_SIZE="$2" 74 | shift # past argument 75 | shift # past value 76 | ;; 77 | --min_lr) 78 | MIN_LR="$2" 79 | shift # past argument 80 | shift # past value 81 | ;; 82 | --sample_during_eval) 83 | SAMPLE_DURING_EVAL="$2" 84 | shift # past argument 85 | shift # past value 86 | ;; 87 | --trainer) 88 | TRAINER="$2" 89 | shift # past argument 90 | shift # past value 91 | ;; 92 | --scheduler_metric) 93 | SCHEDULER_METRIC="$2" 94 | shift # past argument 95 | shift # past value 96 | ;; 97 | --lr) 98 | LR="$2" 99 | shift # past argument 100 | shift # past value 101 | ;; 102 | --label_smoothing) 103 | LABEL_SMOOTHING="$2" 104 | shift # past argument 105 | shift # past value 106 | ;; 107 | --model_archive) 108 | MODEL_ARCHIVE="$2" 109 | shift # past argument 110 | shift # past value 111 | ;; 112 | --optimizer) 113 | OPTIMIZER="$2" 114 | shift # past argument 115 | shift # past value 116 | ;; 117 | --loss_beta) 118 | LOSS_BETA="$2" 119 | shift # past argument 120 | shift # past value 121 | ;; 122 | --n_epochs) 123 | N_EPOCHS="$2" 124 | shift # past argument 125 | shift # past value 126 | ;; 127 | --eval_every) 128 | EVAL_EVERY="$2" 129 | shift # past argument 130 | shift # past value 131 | ;; 132 | --eval_train_every) 133 | EVAL_TRAIN_EVERY="$2" 134 | shift # past argument 135 | shift # past value 136 | ;; 137 | --eval_only_once) 138 | EVAL_ONLY_ONCE="$2" 139 | shift # past argument 140 | shift # past value 141 | ;; 142 | --use_kfoldsplit) 143 | USE_KFOLDSPLIT="$2" 144 | shift # past argument 145 | shift # past value 146 | ;; 147 | --nseeds) 148 | NSEEDS="$2" 149 | shift # past argument 150 | shift # past value 151 | ;; 152 | --stseed) 153 | STSEED="$2" 154 | shift # past argument 155 | shift # past value 156 | ;; 157 | --patience_factor) 158 | PATIENCE_FACTOR="$2" 159 | shift # past argument 160 | shift # past value 161 | ;; 162 | *) # unknown option 163 | echo "Unknown option: $1" 164 | exit 1 165 | ;; 166 | esac 167 | done 168 | 169 | # Loop over seeds 0 to nseeds 170 | for SEED in $(seq $STSEED $NSEEDS) 171 | do 172 | echo "Running training with seed $SEED" 173 | python -u train.py model=$MODEL min_lr=$MIN_LR optimizer=$OPTIMIZER datasets=[$DATASETS] use_kfoldsplit=$USE_KFOLDSPLIT train_frac=$TRAIN_FRAC patience_factor=$PATIENCE_FACTOR scheduler_metric=$SCHEDULER_METRIC loss=$LOSS gradient_accumulation_steps=$GRADIENT_ACCUMULATION_STEPS batch_size=$BATCH_SIZE 
eval_batch_size=$EVAL_BATCH_SIZE sample_during_eval=$SAMPLE_DURING_EVAL trainer=$TRAINER lr=$LR model.archive=$MODEL_ARCHIVE loss.beta=$LOSS_BETA seed=$SEED n_epochs=$N_EPOCHS eval_every=$EVAL_EVERY eval_train_every=$EVAL_TRAIN_EVERY eval_only_once=$EVAL_ONLY_ONCE loss.label_smoothing=$LABEL_SMOOTHING 174 | done 175 | -------------------------------------------------------------------------------- /scripts/run_multi_robust.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Default parameters 4 | MODEL="gemma-2b" 5 | DATASETS="goqa_0,goqa_1,goqa_2,goqa_3,goqa_4" 6 | STEP_SIZES=("0.0000005") # exponential group rate 7 | DIVIDE_BY_TOTALCOUNT="True" 8 | TRAIN_FRAC=0.8 9 | LOSS="ripo" 10 | GRADIENT_ACCUMULATION_STEPS=2 11 | BATCH_SIZE=16 12 | EVAL_BATCH_SIZE=8 13 | SAMPLE_DURING_EVAL="False" 14 | TRAINER="GroupTrainerEarlyStop" 15 | LR=6e-5 16 | LABEL_SMOOTHING=0 17 | MODEL_ARCHIVE="path-to-sft-policy" 18 | LOSS_BETA=0.01 19 | N_EPOCHS=30 20 | EVAL_EVERY=960 21 | EVAL_TRAIN_EVERY=192 22 | EVAL_ONLY_ONCE="False" 23 | NSEEDS=4 24 | PATIENCE_FACTOR=2 25 | SCHEDULER_METRIC="loss" 26 | ADAPTIVE_STEP_SIZE="True" 27 | USE_KFOLDSPLIT="False" 28 | OPTIMIZER="AdamW" 29 | STEP_FACTOR=0.5 30 | MIN_LR=0.00000001 31 | # Parse arguments 32 | while [[ $# -gt 0 ]]; do 33 | key="$1" 34 | case $key in 35 | --model) 36 | MODEL="$2" 37 | shift # past argument 38 | shift # past value 39 | ;; 40 | --datasets) 41 | DATASETS="$2" 42 | shift # past argument 43 | shift # past value 44 | ;; 45 | --step_sizes) 46 | IFS=',' read -r -a STEP_SIZES <<< "$2" # Read comma-separated list into array 47 | shift # past argument 48 | shift # past value 49 | ;; 50 | --divide_by_totalcount) 51 | DIVIDE_BY_TOTALCOUNT="$2" 52 | shift # past argument 53 | shift # past value 54 | ;; 55 | --train_frac) 56 | TRAIN_FRAC="$2" 57 | shift # past argument 58 | shift # past value 59 | ;; 60 | --loss) 61 | LOSS="$2" 62 | shift # past argument 63 | shift # past value 64 | ;; 65 | --gradient_accumulation_steps) 66 | GRADIENT_ACCUMULATION_STEPS="$2" 67 | shift # past argument 68 | shift # past value 69 | ;; 70 | --batch_size) 71 | BATCH_SIZE="$2" 72 | shift # past argument 73 | shift # past value 74 | ;; 75 | --eval_batch_size) 76 | EVAL_BATCH_SIZE="$2" 77 | shift # past argument 78 | shift # past value 79 | ;; 80 | --scheduler_metric) 81 | SCHEDULER_METRIC="$2" 82 | shift # past argument 83 | shift # past value 84 | ;; 85 | --adaptive_step_size) 86 | ADAPTIVE_STEP_SIZE="$2" 87 | shift # past argument 88 | shift # past value 89 | ;; 90 | --sample_during_eval) 91 | SAMPLE_DURING_EVAL="$2" 92 | shift # past argument 93 | shift # past value 94 | ;; 95 | --eval_only_once) 96 | EVAL_ONLY_ONCE="$2" 97 | shift # past argument 98 | shift # past value 99 | ;; 100 | --optimizer) 101 | OPTIMIZER="$2" 102 | shift # past argument 103 | shift # past value 104 | ;; 105 | --trainer) 106 | TRAINER="$2" 107 | shift # past argument 108 | shift # past value 109 | ;; 110 | --lr) 111 | LR="$2" 112 | shift # past argument 113 | shift # past value 114 | ;; 115 | --min_lr) 116 | MIN_LR="$2" 117 | shift # past argument 118 | shift # past value 119 | ;; 120 | --label_smoothing) 121 | LABEL_SMOOTHING="$2" 122 | shift # past argument 123 | shift # past value 124 | ;; 125 | --model_archive) 126 | MODEL_ARCHIVE="$2" 127 | shift # past argument 128 | shift # past value 129 | ;; 130 | --loss_beta) 131 | LOSS_BETA="$2" 132 | shift # past argument 133 | shift # past value 134 | ;; 135 | --step_factor) 136 | STEP_FACTOR="$2" 
137 | shift # past argument 138 | shift # past value 139 | ;; 140 | --n_epochs) 141 | N_EPOCHS="$2" 142 | shift # past argument 143 | shift # past value 144 | ;; 145 | --nseeds) 146 | NSEEDS="$2" 147 | shift # past argument 148 | shift # past value 149 | ;; 150 | --eval_every) 151 | EVAL_EVERY="$2" 152 | shift # past argument 153 | shift # past value 154 | ;; 155 | --eval_train_every) 156 | EVAL_TRAIN_EVERY="$2" 157 | shift # past argument 158 | shift # past value 159 | ;; 160 | --use_kfoldsplit) 161 | USE_KFOLDSPLIT="$2" 162 | shift # past argument 163 | shift # past value 164 | ;; 165 | --patience_factor) 166 | PATIENCE_FACTOR="$2" 167 | shift # past argument 168 | shift # past value 169 | ;; 170 | *) # unknown option 171 | echo "Unknown option: $1" 172 | exit 1 173 | ;; 174 | esac 175 | done 176 | 177 | # Outer loop over step sizes 178 | for STEP_SIZE in "${STEP_SIZES[@]}" 179 | do 180 | echo "Running simulations for step size $STEP_SIZE" 181 | # Loop over seeds 0 to 4 182 | for SEED in $(seq 0 $NSEEDS) 183 | do 184 | echo "Running training with seed $SEED" 185 | python -u train.py model=$MODEL min_lr=$MIN_LR loss.step_factor=$STEP_FACTOR datasets=[$DATASETS] optimizer=$OPTIMIZER loss.step_size=$STEP_SIZE use_kfoldsplit=$USE_KFOLDSPLIT patience_factor=$PATIENCE_FACTOR scheduler_metric=$SCHEDULER_METRIC loss.adaptive_step_size=$ADAPTIVE_STEP_SIZE loss.divide_by_totalcount=$DIVIDE_BY_TOTALCOUNT train_frac=$TRAIN_FRAC loss=$LOSS loss.label_smoothing=$LABEL_SMOOTHING gradient_accumulation_steps=$GRADIENT_ACCUMULATION_STEPS batch_size=$BATCH_SIZE eval_batch_size=$EVAL_BATCH_SIZE sample_during_eval=$SAMPLE_DURING_EVAL trainer=$TRAINER lr=$LR model.archive=$MODEL_ARCHIVE loss.beta=$LOSS_BETA seed=$SEED n_epochs=$N_EPOCHS eval_every=$EVAL_EVERY eval_train_every=$EVAL_TRAIN_EVERY eval_only_once=$EVAL_ONLY_ONCE 186 | done 187 | done 188 | -------------------------------------------------------------------------------- /src/eval/fast_oai.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | from dotenv import load_dotenv 4 | from typing import List, Optional, Tuple 5 | import asyncio 6 | from asyncio import Semaphore, Lock 7 | import logging 8 | from time import time, sleep 9 | 10 | 11 | dotenv_path = os.path.join(os.path.dirname(__file__), '..', '.env') 12 | load_dotenv(dotenv_path) 13 | 14 | 15 | openai.api_key = os.getenv("OPENAI_API_KEY")###requires adding openai_api_key as environment variable manually 16 | if openai.api_key is None: 17 | logging.warning("openai.api_key is None") 18 | 19 | start = None 20 | num_requests = 0 21 | 22 | class TokenBucket: 23 | def __init__(self, rate: int): 24 | # rate is in requests per second 25 | self._rate = rate 26 | self._capacity = rate 27 | self._tokens = self._capacity 28 | self._last_refill = time() 29 | 30 | async def consume(self): 31 | while self._tokens < 1: 32 | self._refill() 33 | await asyncio.sleep(1) # Sleep for some time before trying again 34 | self._tokens -= 1 35 | 36 | def _refill(self): 37 | now = time() 38 | time_passed = now - self._last_refill 39 | refill_amount = time_passed * self._rate 40 | self._tokens = min(self._capacity, self._tokens + refill_amount) 41 | self._last_refill = now 42 | 43 | MaybeTokenBucket = Optional[TokenBucket] 44 | 45 | 46 | async def _call_chat(system_prompt: str, 47 | user_prompt:str, 48 | temperature: float=1., 49 | token_bucket: MaybeTokenBucket=None, 50 | timeout: int=20, 51 | max_retries=5, 52 | model="gpt-3.5-turbo") -> str: 53 | done = 
False 54 | messages= [ 55 | {"role": "system", "content": system_prompt}, 56 | {"role": "user", "content": user_prompt}, 57 | ] 58 | backoff = 1 59 | retries = 0 60 | while not done: 61 | try: 62 | if token_bucket is not None: 63 | await token_bucket.consume() 64 | response = await asyncio.wait_for(openai.ChatCompletion.acreate( 65 | model=model, 66 | messages=messages, 67 | ), timeout=timeout) 68 | completion = response.choices[0].message.content 69 | total_tokens = response.usage.total_tokens 70 | done=True 71 | except asyncio.TimeoutError as e: 72 | if backoff > 128: 73 | print(f"Failed to call chat after {backoff} seconds due to {e}") 74 | completion = None 75 | total_tokens = 0 76 | done = True 77 | await asyncio.sleep(backoff) 78 | backoff *= 2 79 | except Exception as e: 80 | await asyncio.sleep(backoff) 81 | backoff *= 2 82 | backoff = min(backoff, 64) 83 | retries += 1 84 | if retries >= max_retries: 85 | print(f"Failed to call chat after {retries} retries due to {e}:\n\nMessages:{messages}\n\n") 86 | completion = None 87 | total_tokens = 0 88 | done = True 89 | return completion, total_tokens 90 | 91 | 92 | async def _handle_chat(system_prompt: str, user_prompt: str, 93 | token_bucket: TokenBucket, semaphore: Semaphore, lock: Lock, results_counter: dict, 94 | model: str, timeout: int, temperature: float) -> str: 95 | async with semaphore: 96 | completion, toks = await _call_chat(system_prompt=system_prompt, 97 | user_prompt=user_prompt, 98 | temperature=temperature, 99 | token_bucket=token_bucket, 100 | timeout=timeout, 101 | model=model) 102 | 103 | async with lock: # Ensure atomicity of operations 104 | results_counter['num_requests'] += 1 105 | results_counter['tokens'] += toks 106 | print_period = 10 107 | if results_counter['num_requests'] % print_period == 0: 108 | duration = time() - results_counter['start_time'] 109 | duration_min = duration / 60 110 | cost = results_counter['cost_per_ktok'] * results_counter['tokens'] / 1000 111 | print(f"{results_counter['num_requests']=}, {duration=:.2f} rate per min={results_counter['num_requests'] / duration_min:.2f} tokens / request: {results_counter['tokens'] / results_counter['num_requests']:.2f} cost: {cost:.2f}") 112 | 113 | return completion 114 | 115 | 116 | def call_chats(prompts: List[Tuple[str, str]], 117 | model: str="gpt-3.5-turbo", 118 | timeout: int=10, 119 | temperature: float=1.) -> List[str]: 120 | # prompts should be [(system_prompt, user_prompt), ...] 
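    # Note: requests are throttled in two ways below -- the Semaphore caps the number of
    # in-flight requests at max_concurrent_tasks, while the TokenBucket limits the request
    # rate to the per-second quota derived from the per-minute quotas in oai_quotas.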
121 | max_concurrent_tasks = 20 122 | oai_quotas = {'gpt-3.5-turbo': 150, 'gpt-3.5-turbo-16k': 200, 'gpt-4': 200} 123 | oai_costs_per_ktok = {'gpt-3.5-turbo': 0.0015, 'gpt-3.5-turbo-16k': 0.003, 'gpt-4': 0.03} 124 | oai_quota_per_minute = oai_quotas[model] 125 | oai_quota_per_second = oai_quota_per_minute // 60 126 | semaphore = Semaphore(max_concurrent_tasks) 127 | token_bucket = TokenBucket(oai_quota_per_second) 128 | lock = Lock() 129 | results_counter = {'num_requests': 0, 'start_time': time(), 'tokens': 0, 'cost_per_ktok': oai_costs_per_ktok[model]} 130 | async def gather_tasks(): 131 | tasks = [_handle_chat(system_prompt, user_prompt, token_bucket, semaphore, 132 | lock, results_counter, 133 | model, timeout, temperature) for system_prompt, user_prompt in prompts] 134 | return await asyncio.gather(*tasks) 135 | return asyncio.run(gather_tasks()) 136 | 137 | 138 | def test_chats(): 139 | fun_sentences = [ 140 | "The sky is blue today.", 141 | "Ducks quack to communicate.", 142 | "Bananas are my favorite fruit.", 143 | "Chocolate makes everything better.", 144 | "Singing in the rain is fun.", 145 | "Cats have nine lives, they say.", 146 | "The moon is made of cheese.", 147 | "Robots will take over the world.", 148 | "Pineapples belong on pizza.", 149 | "Unicorns are just horses with a twist." 150 | ] 151 | system_prompts = ["You are a pig-latinifiying bot. Please reproduce the user message in Pig Latin"] * len(fun_sentences) 152 | 153 | completions = call_chats(list(zip(system_prompts, fun_sentences))) 154 | for completion, fun_sentence in zip(completions, fun_sentences): 155 | print(f"{fun_sentence=}, {completion=}") 156 | 157 | 158 | if __name__ == '__main__': 159 | test_chats() -------------------------------------------------------------------------------- /src/loss_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | from typing import Dict, List, Union, Tuple 7 | import torch.nn.functional as F 8 | from src.utils import pad_to_length 9 | 10 | def preference_loss(policy_chosen_logps: torch.FloatTensor, 11 | policy_rejected_logps: torch.FloatTensor, 12 | reference_chosen_logps: torch.FloatTensor, 13 | reference_rejected_logps: torch.FloatTensor, 14 | beta: float, 15 | label_smoothing: float = 0.0, 16 | ipo: bool = False, 17 | reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 18 | """Compute the DPO loss for a batch of policy and reference model log probabilities. 19 | 20 | Args: 21 | policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) 22 | policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) 23 | reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) 24 | reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) 25 | beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. 26 | label_smoothing: conservativeness for DPO loss, which assumes that preferences are noisy (flipped with probability label_smoothing) 27 | ipo: If True, use the IPO loss instead of the DPO loss. 
28 | reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. 29 | 30 | Returns: 31 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). 32 | The losses tensor contains the DPO loss for each example in the batch. 33 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. 34 | """ 35 | pi_logratios = policy_chosen_logps - policy_rejected_logps 36 | ref_logratios = reference_chosen_logps - reference_rejected_logps 37 | 38 | if reference_free: 39 | ref_logratios = 0 40 | 41 | logits = pi_logratios - ref_logratios # also known as h_{\pi_\theta}^{y_w,y_l} 42 | 43 | if ipo: 44 | losses = (logits - 1/(2 * beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf 45 | else: 46 | # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) 47 | losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing 48 | 49 | chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach() 50 | rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach() 51 | 52 | return losses, chosen_rewards, rejected_rewards 53 | 54 | 55 | def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False) -> torch.FloatTensor: 56 | """Compute the log probabilities of the given labels under the given logits. 57 | 58 | Args: 59 | logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) 60 | labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) 61 | average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. 62 | 63 | Returns: 64 | A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 65 | """ 66 | assert logits.shape[:-1] == labels.shape 67 | 68 | labels = labels[:, 1:].clone() 69 | logits = logits[:, :-1, :] 70 | loss_mask = (labels != -100) 71 | 72 | # dummy token; we'll ignore the losses on these tokens later 73 | labels[labels == -100] = 0 74 | 75 | per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) 76 | 77 | if average_log_prob: 78 | return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) 79 | else: 80 | return (per_token_logps * loss_mask).sum(-1) 81 | 82 | 83 | def concatenated_inputs(batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: 84 | """Concatenate the chosen and rejected inputs into a single tensor. 85 | 86 | Args: 87 | batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). 88 | 89 | Returns: 90 | A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. 
91 | """ 92 | max_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1]) 93 | concatenated_batch = {} 94 | for k in batch: 95 | if k.startswith('chosen') and isinstance(batch[k], torch.Tensor): 96 | pad_value = -100 if 'labels' in k else 0 97 | concatenated_key = k.replace('chosen', 'concatenated') 98 | concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) 99 | for k in batch: 100 | if k.startswith('rejected') and isinstance(batch[k], torch.Tensor): 101 | pad_value = -100 if 'labels' in k else 0 102 | concatenated_key = k.replace('rejected', 'concatenated') 103 | concatenated_batch[concatenated_key] = torch.cat(( 104 | concatenated_batch[concatenated_key], 105 | pad_to_length(batch[k], max_length, pad_value=pad_value), 106 | ), dim=0) 107 | return concatenated_batch 108 | 109 | def concatenated_forward(model: nn.Module, 110 | batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 111 | """ 112 | Run the given model on the given batch of inputs, concatenating the chosen 113 | and rejected inputs together. We do this to avoid doing two forward passes, 114 | because it's faster for FSDP. 115 | """ 116 | 117 | concatenated_batch = concatenated_inputs(batch) 118 | print(model.device) 119 | all_logits = model(concatenated_batch['concatenated_input_ids'], \ 120 | attention_mask=concatenated_batch['concatenated_attention_mask']).logits.to(torch.float32) 121 | print(all_logits.device) 122 | all_logps = _get_batch_logps(all_logits, concatenated_batch['concatenated_labels'], average_log_prob=False) 123 | print(all_logps.device,'all-logp') 124 | chosen_logps = all_logps[:batch['chosen_input_ids'].shape[0]] 125 | rejected_logps = all_logps[batch['chosen_input_ids'].shape[0]:] 126 | 127 | return chosen_logps, rejected_logps -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cuda.matmul.allow_tf32 = True 3 | import torch.nn as nn 4 | import transformers 5 | from src.utils import get_local_dir, get_local_run_dir, get_local_run_dir_group, disable_dropout, init_distributed, get_open_port 6 | import os 7 | import hydra 8 | import torch.multiprocessing as mp 9 | from omegaconf import OmegaConf, DictConfig 10 | from src.trainers_factory import get_trainer 11 | import wandb 12 | import json 13 | import socket 14 | from typing import Optional, Set 15 | from src.models import ModelGenerator 16 | from src.data_selection import DataSelector 17 | 18 | 19 | #System specific installs: 20 | if os.name != 'nt': 21 | #We can't install resource on windows 22 | import resource 23 | 24 | 25 | OmegaConf.register_new_resolver("get_local_run_dir_group", lambda exp_name, group_name, local_dirs: get_local_run_dir_group(exp_name, group_name, local_dirs)) 26 | 27 | 28 | def worker_main(rank: int, world_size: int, config: DictConfig, policy: nn.Module, 29 | reference_model: Optional[nn.Module] = None, data_selector: Optional[DataSelector] = None): 30 | """Main function for each worker process (may be only 1 for BasicTrainer/TensorParallelTrainer).""" 31 | if 'FSDP' in config.trainer: 32 | init_distributed(rank, world_size, port=config.fsdp_port) 33 | 34 | if config.debug: 35 | wandb.init = lambda *args, **kwargs: None 36 | wandb.log = lambda *args, **kwargs: None 37 | 38 | if rank == 0 and config.wandb.enabled: 39 | os.environ['WANDB_CACHE_DIR'] = 
get_local_dir(config.local_dirs) 40 | wandb.login(key=config.wandb.key) 41 | tags=[f"n_epochs_{config.n_epochs}",f"learning_rate_{config.lr}",f"batch_size_{config.batch_size}"] 42 | if 'po' in config.loss.name: 43 | tags.append(f"beta_{config.loss.beta}") 44 | wandb.init( 45 | group=config.group_name, 46 | entity=config.wandb.entity, 47 | project=config.wandb.project, 48 | config=OmegaConf.to_container(config), 49 | dir=get_local_dir(config.local_dirs), 50 | name=config.exp_name, 51 | tags=tags 52 | ) 53 | 54 | 55 | #TrainerClass = getattr(trainers, config.trainer) 56 | print(f'Creating trainer on process {rank} with world size {world_size}') 57 | trainer=get_trainer(config.trainer,policy, config, config.seed, config.local_run_dir, reference_model=reference_model,data_selector=data_selector, rank=rank, world_size=world_size) 58 | 59 | trainer.train() 60 | trainer.save() 61 | 62 | 63 | @hydra.main(version_base=None, config_path="config", config_name="config") 64 | def main(config: DictConfig): 65 | """Main entry point for training. Validates config, creates/initializes model(s), and kicks off worker process(es).""" 66 | 67 | if config.loss.name in {'dpo','ipo'}: 68 | exp_name=f"{config.loss.name}_beta_{config.loss.beta}_seed_{config.seed}_batch_{config.batch_size}_nepoch_{config.n_epochs}_lr_{config.lr}_avg_fr_vald" 69 | elif config.loss.name in {'sft','base'}: 70 | exp_name=f"{config.loss.name}_seed_{config.seed}_batch_{config.batch_size}_nepoch_{config.n_epochs}_lr_{config.lr}" 71 | elif config.loss.name in {'rdpo','ripo'}: 72 | exp_name=f"{config.loss.name}_beta_{config.loss.beta}_seed_{config.seed}_expstepsize_{config.loss.step_size}_nepoch_{config.n_epochs}_batch_{config.batch_size}_weightedbatch_{config.weighted_batches}" 73 | else: 74 | raise NotImplementedError 75 | config.exp_name=exp_name 76 | # Resolve hydra references, e.g. 
so we don't re-compute the run directory 77 | dataset_group = config.datasets[0].split('_')[0] 78 | #print(f'{dataset_group}_{len(config.datasets)}'+f'tr_frac{config.train_frac}'+f'{config.model.name}_spairs_{config.sep_pairs}_{config.trainer}') 79 | #if config.new_grp_name: 80 | group_indices="_".join(dataset.split('_')[1] for dataset in config.datasets) 81 | group=f'{dataset_group}_{group_indices}'+f'tr_frac{config.train_frac}'+f'{config.model.name_or_path}_spairs_{config.sep_pairs}_{config.trainer}' 82 | config.group_name=group 83 | print(group) 84 | OmegaConf.resolve(config) 85 | 86 | missing_keys: Set[str] = OmegaConf.missing_keys(config) 87 | if missing_keys: 88 | raise ValueError(f"Got missing keys in config:\n{missing_keys}") 89 | 90 | if config.eval_every % config.batch_size != 0: 91 | print('WARNING: eval_every must be divisible by batch_size') 92 | print('Setting eval_every to', config.eval_every - config.eval_every % config.batch_size) 93 | config.eval_every = config.eval_every - config.eval_every % config.batch_size 94 | 95 | if 'FSDP' in config.trainer and config.fsdp_port is None: 96 | free_port = get_open_port() 97 | print('no FSDP port specified; using open port for FSDP:', free_port) 98 | config.fsdp_port = free_port 99 | 100 | print(OmegaConf.to_yaml(config)) 101 | 102 | config_path = os.path.join(config.local_run_dir, 'config.yaml') 103 | with open(config_path, 'w') as f: 104 | OmegaConf.save(config, f) 105 | 106 | print('=' * 80) 107 | print(f'Writing to {socket.gethostname()}:{config.local_run_dir}') 108 | print('=' * 80) 109 | 110 | #CREATES THE POLICY 111 | os.environ['XDG_CACHE_HOME'] = get_local_dir(config.local_dirs) 112 | 113 | print('build data selector') 114 | data_selector = hydra.utils.instantiate(config.get('data_selection', None), 115 | other_config=config, 116 | _recursive_=False) 117 | 118 | print('building policy') 119 | #TODO: Temporary code -> we can store all models in a class and request access to them as needed in the trainer 120 | model_generator = ModelGenerator() 121 | models = model_generator.generate_models(config) 122 | 123 | if config.loss.name == 'sft': 124 | policy = models.get('sft_model', None) 125 | reference_model = None 126 | elif config.loss.name == 'base': 127 | policy = models.get('base_model', None) 128 | reference_model = None 129 | else: 130 | policy = models.get('policy_model', None) 131 | reference_model = models.get('ref_model', None) 132 | 133 | if 'FSDP' in config.trainer and os.name != 'nt': 134 | 135 | world_size = torch.cuda.device_count() 136 | print('starting', world_size, 'processes for FSDP training') 137 | soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) 138 | resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) 139 | print(f'setting RLIMIT_NOFILE soft limit to {hard} from {soft}') 140 | mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model, data_selector), join=True) 141 | else: 142 | print('starting single-process worker') 143 | worker_main(0, 1, config, policy, reference_model, data_selector=data_selector) 144 | 145 | 146 | if __name__ == '__main__': 147 | main() -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import getpass 3 | from datetime import datetime 4 | import torch 5 | import random 6 | import numpy as np 7 | import torch.distributed as dist 8 | import inspect 9 | import importlib.util 10 | import socket 11 | 
import os 12 | from typing import Dict, Union, Type, List 13 | 14 | 15 | def get_open_port(): 16 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 17 | s.bind(('', 0)) # bind to all interfaces and use an OS provided port 18 | return s.getsockname()[1] # return only the port number 19 | 20 | 21 | def get_remote_file(remote_path, local_path=None): 22 | hostname, path = remote_path.split(':') 23 | local_hostname = socket.gethostname() 24 | if hostname == local_hostname or hostname == local_hostname[:local_hostname.find('.')]: 25 | return path 26 | 27 | if local_path is None: 28 | local_path = path 29 | # local_path = local_path.replace('/scr-ssd', '/scr') 30 | if os.path.exists(local_path): 31 | return local_path 32 | local_dir = os.path.dirname(local_path) 33 | os.makedirs(local_dir, exist_ok=True) 34 | 35 | print(f'Copying {hostname}:{path} to {local_path}') 36 | os.system(f'scp {remote_path} {local_path}') 37 | return local_path 38 | 39 | 40 | def rank0_print(*args, **kwargs): 41 | """Print, but only on rank 0.""" 42 | if not dist.is_initialized() or dist.get_rank() == 0: 43 | print(*args, **kwargs) 44 | 45 | 46 | def get_local_dir(prefixes_to_resolve: List[str]) -> str: 47 | """Return the path to the cache directory for this user.""" 48 | for prefix in prefixes_to_resolve: 49 | if os.path.exists(prefix): 50 | return f"{prefix}/{getpass.getuser()}" 51 | os.makedirs(prefix) 52 | return f"{prefix}/{getpass.getuser()}" 53 | 54 | 55 | def get_local_run_dir(exp_name: str, local_dirs: List[str]) -> str: 56 | """Create a local directory to store outputs for this run, and return its path.""" 57 | now = datetime.now() 58 | timestamp = now.strftime("%Y-%m-%d_%H-%M-%S_%f") 59 | run_dir = f"{get_local_dir(local_dirs)}/{exp_name}_{timestamp}" 60 | os.makedirs(run_dir, exist_ok=True) 61 | return run_dir 62 | 63 | def get_local_run_dir_group(exp_name: str, group_name: str, local_dirs: List[str]) -> str: 64 | """Create a local directory to store outputs for this run, and return its path.""" 65 | now = datetime.now() 66 | timestamp = now.strftime("%Y-%m-%d_%H-%M-%S_%f") 67 | run_dir = f"{get_local_dir(local_dirs)}/{group_name}/{exp_name}_{timestamp}" 68 | os.makedirs(run_dir, exist_ok=True) 69 | return run_dir 70 | 71 | 72 | def slice_and_move_batch_for_device(batch: Dict, rank: int, world_size: int, device: str) -> Dict: 73 | """Slice a batch into chunks, and move each chunk to the specified device.""" 74 | chunk_size = len(list(batch.values())[0]) // world_size 75 | start = chunk_size * rank 76 | end = chunk_size * (rank + 1) 77 | sliced = {k: v[start:end] for k, v in batch.items()} 78 | on_device = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in sliced.items()} 79 | return on_device 80 | 81 | 82 | def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: 83 | if tensor.size(dim) >= length: 84 | return tensor 85 | else: 86 | pad_size = list(tensor.shape) 87 | pad_size[dim] = length - tensor.size(dim) 88 | return torch.cat([tensor, pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device)], dim=dim) 89 | 90 | 91 | def all_gather_if_needed(values: torch.Tensor, rank: int, world_size: int) -> torch.Tensor: 92 | """Gather and stack/cat values from all processes, if there are multiple processes.""" 93 | if world_size == 1: 94 | return values 95 | 96 | all_values = [torch.empty_like(values).to(rank) for _ in range(world_size)] 97 | dist.all_gather(all_values, values) 98 | cat_function = torch.cat if 
values.dim() > 0 else torch.stack 99 | return cat_function(all_values, dim=0) 100 | 101 | 102 | def formatted_dict(d: Dict) -> Dict: 103 | """Format a dictionary for printing.""" 104 | return {k: (f"{v:.5g}" if type(v) == float else v) for k, v in d.items()} 105 | 106 | 107 | def disable_dropout(model: torch.nn.Module): 108 | """Disable dropout in a model.""" 109 | for module in model.modules(): 110 | if isinstance(module, torch.nn.Dropout): 111 | module.p = 0 112 | 113 | 114 | def print_gpu_memory(rank: int = None, message: str = ''): 115 | """Print the amount of GPU memory currently allocated for each GPU.""" 116 | if torch.cuda.is_available(): 117 | device_count = torch.cuda.device_count() 118 | for i in range(device_count): 119 | device = torch.device(f'cuda:{i}') 120 | allocated_bytes = torch.cuda.memory_allocated(device) 121 | if allocated_bytes == 0: 122 | continue 123 | print('*' * 40) 124 | print(f'[{message} rank {rank} ] GPU {i}: {allocated_bytes / 1024**2:.2f} MB') 125 | print('*' * 40) 126 | 127 | 128 | def get_block_class_from_model(model: torch.nn.Module, block_class_name: str) -> torch.nn.Module: 129 | """Get the class of a block from a model, using the block's class name.""" 130 | for module in model.modules(): 131 | if module.__class__.__name__ == block_class_name: 132 | return module.__class__ 133 | raise ValueError(f"Could not find block class {block_class_name} in model {model}") 134 | 135 | 136 | def get_block_class_from_model_class_and_block_name(model_class: Type, block_class_name: str) -> Type: 137 | filepath = inspect.getfile(model_class) 138 | assert filepath.endswith('.py'), f"Expected a .py file, got {filepath}" 139 | assert os.path.exists(filepath), f"File {filepath} does not exist" 140 | assert "transformers" in filepath, f"Expected a transformers model, got {filepath}" 141 | 142 | module_name = filepath[filepath.find('transformers'):].replace('/', '.')[:-3] 143 | print(f"Searching in file {filepath}, module {module_name} for class {block_class_name}") 144 | 145 | # Load the module dynamically 146 | spec = importlib.util.spec_from_file_location(module_name, filepath) 147 | module = importlib.util.module_from_spec(spec) 148 | spec.loader.exec_module(module) 149 | 150 | # Get the class dynamically 151 | class_ = getattr(module, block_class_name) 152 | print(f"Found class {class_} in module {module_name}") 153 | return class_ 154 | 155 | 156 | def init_distributed(rank: int, world_size: int, master_addr: str = 'localhost', port: int = 12355, backend: str = 'nccl'): 157 | print(rank, 'initializing distributed') 158 | os.environ["MASTER_ADDR"] = master_addr 159 | os.environ["MASTER_PORT"] = str(port) 160 | dist.init_process_group(backend, rank=rank, world_size=world_size) 161 | torch.cuda.set_device(rank) 162 | 163 | class TemporarilySeededRandom: 164 | def __init__(self, seed): 165 | """Temporarily set the random seed, and then restore it when exiting the context.""" 166 | self.seed = int(seed) 167 | self.stored_state = None 168 | self.stored_np_state = None 169 | 170 | def __enter__(self): 171 | # Store the current random state 172 | self.stored_state = random.getstate() 173 | self.stored_np_state = np.random.get_state() 174 | 175 | # Set the random seed 176 | random.seed(self.seed) 177 | np.random.seed(self.seed) 178 | 179 | def __exit__(self, exc_type, exc_value, traceback): 180 | # Restore the random state 181 | random.setstate(self.stored_state) 182 | np.random.set_state(self.stored_np_state) 183 | 
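# Illustrative usage sketch for the helpers above (shapes and values are made up, not from the training code):
#
#   >>> x = torch.ones(2, 3, dtype=torch.long)
#   >>> pad_to_length(x, 5, pad_value=0).shape      # right-pads along the last dim
#   torch.Size([2, 5])
#
#   >>> batch = {'input_ids': torch.arange(8).reshape(4, 2), 'prompt': ['a', 'b', 'c', 'd']}
#   >>> chunk = slice_and_move_batch_for_device(batch, rank=1, world_size=2, device='cpu')
#   >>> chunk['input_ids'].shape, chunk['prompt']   # second contiguous chunk of the batch
#   (torch.Size([2, 2]), ['c', 'd'])
#
#   >>> with TemporarilySeededRandom(0):
#   ...     _ = np.random.rand()                    # python/numpy RNG state restored on exit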
-------------------------------------------------------------------------------- /src/trainers/paralleltrainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cuda.matmul.allow_tf32 = True 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | import transformers 6 | from omegaconf import DictConfig 7 | 8 | import torch.distributed as dist 9 | from torch.distributed.fsdp import ( 10 | FullyShardedDataParallel as FSDP, 11 | MixedPrecision, 12 | StateDictType, 13 | BackwardPrefetch, 14 | ShardingStrategy, 15 | CPUOffload, 16 | ) 17 | from torch.distributed.fsdp.api import FullStateDictConfig, FullOptimStateDictConfig 18 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 19 | import tensor_parallel as tp 20 | import contextlib 21 | 22 | from src.preference_datasets import get_batch_iterator 23 | from src.utils import ( 24 | slice_and_move_batch_for_device, 25 | formatted_dict, 26 | all_gather_if_needed, 27 | pad_to_length, 28 | get_block_class_from_model, 29 | rank0_print, 30 | get_local_dir, 31 | ) 32 | from src.data_selection import DataSelector 33 | from src.loss_utils import ( 34 | preference_loss, 35 | _get_batch_logps, 36 | concatenated_inputs) 37 | 38 | import numpy as np 39 | import wandb 40 | import tqdm 41 | 42 | import random 43 | import os 44 | from collections import defaultdict 45 | import time 46 | import json 47 | import functools 48 | from typing import Optional, Dict, List, Union, Tuple 49 | 50 | 51 | 52 | from src.trainers.basictrainer import BasicTrainer 53 | 54 | 55 | 56 | class FSDPTrainer(BasicTrainer): 57 | def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str, reference_model: Optional[nn.Module] = None, rank: int = 0, world_size: int = 1): 58 | """A trainer subclass that uses PyTorch FSDP to shard the model across multiple GPUs. 59 | 60 | This trainer will shard both the policy and reference model across all available GPUs. 61 | Models are sharded at the block level, where the block class name is provided in the config. 
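For example, with model.block_name set to GPT2Block, each GPT2Block module becomes its own FSDP unit via transformer_auto_wrap_policy, so parameters are gathered and sharded one transformer block at a time.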
62 | """ 63 | 64 | super().__init__(policy, config, seed, run_dir, reference_model, rank, world_size) 65 | assert config.model.block_name is not None, 'must specify model.block_name (e.g., GPT2Block or GPTNeoXLayer) for FSDP' 66 | 67 | wrap_class = get_block_class_from_model(policy, config.model.block_name) 68 | model_auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={wrap_class},) 69 | 70 | shared_fsdp_kwargs = dict( 71 | auto_wrap_policy=model_auto_wrap_policy, 72 | sharding_strategy=ShardingStrategy.FULL_SHARD, 73 | cpu_offload=CPUOffload(offload_params=False), 74 | backward_prefetch=BackwardPrefetch.BACKWARD_PRE, 75 | device_id=rank, 76 | ignored_modules=None, 77 | limit_all_gathers=False, 78 | use_orig_params=False, 79 | sync_module_states=False 80 | ) 81 | 82 | rank0_print('Sharding policy...') 83 | mp_dtype = getattr(torch, config.model.fsdp_policy_mp) if config.model.fsdp_policy_mp is not None else None 84 | policy_mp_policy = MixedPrecision(param_dtype=mp_dtype, reduce_dtype=mp_dtype, buffer_dtype=mp_dtype) 85 | self.policy = FSDP(policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy) 86 | 87 | if config.activation_checkpointing: 88 | rank0_print('Attempting to enable activation checkpointing...') 89 | try: 90 | # use activation checkpointing, according to: 91 | # https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/ 92 | # 93 | # first, verify we have FSDP activation support ready by importing: 94 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 95 | checkpoint_wrapper, 96 | apply_activation_checkpointing, 97 | CheckpointImpl, 98 | ) 99 | non_reentrant_wrapper = functools.partial( 100 | checkpoint_wrapper, 101 | offload_to_cpu=False, 102 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 103 | ) 104 | except Exception as e: 105 | rank0_print('FSDP activation checkpointing not available:', e) 106 | else: 107 | check_fn = lambda submodule: isinstance(submodule, wrap_class) 108 | rank0_print('Applying activation checkpointing wrapper to policy...') 109 | apply_activation_checkpointing(self.policy, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn) 110 | rank0_print('FSDP activation checkpointing enabled!') 111 | 112 | if config.loss.name in {'dpo', 'ipo'}: 113 | rank0_print('Sharding reference model...') 114 | self.reference_model = FSDP(reference_model, **shared_fsdp_kwargs) 115 | 116 | print('Loaded model on rank', rank) 117 | dist.barrier() 118 | 119 | def clip_gradient(self): 120 | """Clip the gradient norm of the parameters of an FSDP policy, gathering the gradients across all GPUs.""" 121 | return self.policy.clip_grad_norm_(self.config.max_grad_norm).item() 122 | 123 | def save(self, output_dir=None, metrics=None): 124 | """Save policy, optimizer, and scheduler state to disk, gathering from all processes and saving only on the rank 0 process.""" 125 | save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 126 | with FSDP.state_dict_type(self.policy, StateDictType.FULL_STATE_DICT, state_dict_config=save_policy): 127 | policy_state_dict = self.policy.state_dict() 128 | 129 | if self.rank == 0: 130 | self.write_state_dict(self.example_counter, policy_state_dict, metrics, 'policy.pt', output_dir) 131 | del policy_state_dict 132 | dist.barrier() 133 | 134 | save_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True) 135 | with FSDP.state_dict_type(self.policy, StateDictType.FULL_STATE_DICT, 
optim_state_dict_config=save_policy): 136 | optimizer_state_dict = FSDP.optim_state_dict(self.policy, self.optimizer) 137 | 138 | if self.rank == 0: 139 | self.write_state_dict(self.example_counter, optimizer_state_dict, metrics, 'optimizer.pt', output_dir) 140 | del optimizer_state_dict 141 | dist.barrier() 142 | 143 | if self.rank == 0: 144 | scheduler_state_dict = self.scheduler.state_dict() 145 | self.write_state_dict(self.example_counter, scheduler_state_dict, metrics, 'scheduler.pt', output_dir) 146 | dist.barrier() 147 | 148 | 149 | class TensorParallelTrainer(BasicTrainer): 150 | def __init__(self, policy, config, seed, run_dir, reference_model=None, rank=0, world_size=1): 151 | """A trainer subclass that uses TensorParallel to shard the model across multiple GPUs. 152 | 153 | Based on https://github.com/BlackSamorez/tensor_parallel. Note sampling is extremely slow, 154 | see https://github.com/BlackSamorez/tensor_parallel/issues/66. 155 | """ 156 | super().__init__(policy, config, seed, run_dir, reference_model, rank, world_size) 157 | 158 | rank0_print('Sharding policy...') 159 | self.policy = tp.tensor_parallel(policy, sharded=True) 160 | if config.loss.name in {'dpo', 'ipo'}: 161 | rank0_print('Sharding reference model...') 162 | self.reference_model = tp.tensor_parallel(reference_model, sharded=False) 163 | 164 | def save(self, output_dir=None, metrics=None): 165 | """Save (unsharded) policy state to disk.""" 166 | with tp.save_tensor_parallel(self.policy): 167 | policy_state_dict = self.policy.state_dict() 168 | 169 | self.write_state_dict(self.example_counter, policy_state_dict, metrics, 'policy.pt', output_dir) 170 | del policy_state_dict 171 | -------------------------------------------------------------------------------- /src/data_selection.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from abc import ABC, abstractmethod 7 | from collections.abc import Iterable 8 | 9 | from src.utils import ( 10 | get_local_dir, 11 | slice_and_move_batch_for_device) 12 | from src.models import ModelGenerator 13 | from src.loss_utils import ( 14 | preference_loss, 15 | concatenated_forward) 16 | 17 | from omegaconf import DictConfig 18 | from typing import Dict, List, Union 19 | 20 | class DataSelector(ABC): 21 | 22 | """ 23 | Abstract base class for the different data selection functions in our code 24 | base, we can add and adapt this as necessary but it is probably overkill. 25 | """ 26 | 27 | def __init__(self, other_config:str): 28 | self.config = other_config 29 | 30 | def select_top_k(self, vector, k): 31 | 32 | sorted_idx = torch.argsort(vector, descending=True) 33 | 34 | top_x_indices = sorted_idx[:k] 35 | other_indices = sorted_idx[k:] 36 | 37 | return top_x_indices, other_indices 38 | 39 | def subselect_batch(self, batch:dict, selected_idx:torch.tensor, 40 | not_selected_idx:torch.tensor): 41 | """ 42 | Select a subset of the batch, return the selected and not selected subsets. 
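For example, with a batch of 8 examples, selected_idx = torch.tensor([0, 3, 5]) and not_selected_idx covering the remaining indices, the first returned dict contains examples 0, 3 and 5 (re-stacked into tensors for tensor-valued keys) and the second contains the other five; if not_selected_idx is None, the second return value is None.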
43 | 44 | """ 45 | 46 | selected_batch = dict() 47 | not_selected_batch = dict() 48 | 49 | #We could try using this: 50 | #sliced = {k: v[start:end] for k, v in batch.items()} only works for consecutive elements 51 | 52 | for key in batch.keys(): 53 | 54 | key_batch = batch[key] 55 | selected_batch[key] = [key_batch[i] for i in selected_idx.to(dtype=torch.long)] 56 | 57 | #If the batch stores this key as a tensor then re-stack into a tensor: 58 | if isinstance(key_batch, torch.Tensor): 59 | selected_batch[key] = torch.stack(selected_batch[key]) 60 | 61 | 62 | if not_selected_idx is not None: 63 | not_selected_batch[key] = [key_batch[i] for i in not_selected_idx.to(dtype=torch.long)] 64 | 65 | #If the data is stored as a tensor then re-stack into a tensor: 66 | if isinstance(key_batch, torch.Tensor): 67 | not_selected_batch[key] = torch.stack(not_selected_batch[key]) 68 | 69 | else: 70 | not_selected_batch = None 71 | 72 | return selected_batch, not_selected_batch 73 | 74 | @abstractmethod 75 | def select_batch(self, batch:dict, selected_batch_size:int, **kwargs) -> Iterable: 76 | pass 77 | 78 | class UniformRandomSelection(DataSelector): 79 | 80 | """ 81 | Randomly select and return a subset of the input batch. 82 | 83 | """ 84 | 85 | def __init__(self, other_config): 86 | super().__init__(other_config) 87 | 88 | def batch_len(self, batch): 89 | """ 90 | Return the batch length, taken from the value stored under the first key. 91 | 92 | """ 93 | 94 | keys = list(batch.keys()) 95 | 96 | return len(batch[keys[0]]) 97 | 98 | def select_batch(self, batch:Iterable, selected_batch_size:int, 99 | policy:nn.Module=None, ref_policy:nn.Module=None) -> Iterable: 100 | """ 101 | Return the uniformly sampled selected batch and the not-selected batch. 102 | 103 | """ 104 | 105 | blen = self.batch_len(batch) 106 | 107 | if selected_batch_size > blen: 108 | print(f'selected batch size:{selected_batch_size} is greater than batch size:{blen}') 109 | selected_batch_size = blen 110 | 111 | idx = torch.randperm(blen) 112 | 113 | selected, not_selected = self.subselect_batch(batch, idx[:selected_batch_size], 114 | None if selected_batch_size == blen \ 115 | else idx[selected_batch_size:]) 116 | 117 | return selected, not_selected, selected_batch_size 118 | 119 | 120 | class SFTRHOLossSelection(DataSelector): 121 | pass 122 | 123 | class DPORHOLossSelection(DataSelector): 124 | 125 | """ 126 | Selects and returns a subset of the input batch using the RHO-Loss selection 127 | objective: 128 | 129 | RHO(x,y) = L(x,y) - L_ref(x,y) 130 | 131 | There are two options: 132 | 1. Memory-efficient but compute-slow 133 | 134 | We use two forward passes, one to calculate the rho objective without 135 | creating a computation graph and one with .train() which does create a 136 | computation graph but only on the small selected batch. 137 | 138 | 2. Memory-hungry but compute-fast. 139 | 140 | We use a single forward pass that creates a computation graph for the 141 | entire batch; we then use the losses calculated from this to select a 142 | sub-batch, and backprop is then only applied to those elements of the 143 | graph - does multiplying by zero at the loss stage prevent gradients from being 144 | calculated any further? 145 | 146 | We also want to be able to do gradient accumulation steps. 147 | 148 | Aim to implement both? -> might need to adjust depending upon FSDP 149 | How do we test FSDP given locally we only have 1 GPU?
150 | """ 151 | 152 | def __init__(self, ft_state_dict_path, sft_state_dict_path, model, other_config): 153 | 154 | """ 155 | Using the config, create the SFT and fine-tuned (FT) reference models. 156 | """ 157 | 158 | #For FSDP we'll need to take in a model and then shard it on each process 159 | #see FSDP trainer script 160 | super().__init__(other_config) 161 | 162 | #local_dir = get_local_dir(other_config.local_dirs) 163 | local_dir = get_local_dir(self.config.local_dirs) 164 | trainer = self.config.get('trainer', 'BasicTrainer') 165 | 166 | model_generator = ModelGenerator() 167 | sft_ref_model = model_generator.\ 168 | create_policy_from_config(model, trainer=trainer, 169 | local_dirs=local_dir, 170 | reference=True) 171 | 172 | 173 | ft_ref_model = model_generator.\ 174 | create_policy_from_config(model, trainer=trainer, 175 | local_dirs=local_dir, 176 | reference=True) 177 | 178 | self.sft_ref_model = model_generator.load_saved_model( 179 | sft_ref_model, sft_state_dict_path) 180 | 181 | self.ft_ref_model = model_generator.load_saved_model( 182 | ft_ref_model, ft_state_dict_path) 183 | 184 | 185 | def get_batch_preference_loss(self, ft_model: nn.Module, sft_model: nn.Module, 186 | batch: Dict[str, Union[List, torch.LongTensor]], 187 | loss_config: DictConfig): 188 | """Compute the per-example DPO or IPO preference loss for the given batch of inputs (no gradient graph is built).""" 189 | 190 | with torch.no_grad(): 191 | 192 | #Implement gradient accumulation without FSDP compatibility at this point: 193 | accumulated_losses = list() 194 | 195 | for i in range(self.config.gradient_accumulation_steps): 196 | global_microbatch = slice_and_move_batch_for_device(batch, i, 197 | self.config.gradient_accumulation_steps, 0) 198 | 199 | #TODO: Implement FSDP calculation here for RHO-Loss policies etc...
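#The two concatenated_forward calls below score the microbatch with whichever (ft_model, sft_model)
#pair was passed in: ft_model plays the policy role and sft_model the reference role inside
#preference_loss, giving one DPO/IPO loss per example. select_batch calls this method twice -
#once with the frozen (ft_ref_model, sft_ref_model) pair and once with the live (policy, ref_policy)
#pair - and the difference of the two loss vectors is the RHO score used to keep the top-k examples.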
200 | 201 | policy_chosen_logps, policy_rejected_logps = concatenated_forward(ft_model, global_microbatch) 202 | reference_chosen_logps, reference_rejected_logps = concatenated_forward(sft_model, global_microbatch) 203 | 204 | if loss_config.name == 'dpo': 205 | loss_kwargs = {'beta': loss_config.beta, 206 | 'reference_free': loss_config.reference_free, 207 | 'label_smoothing': loss_config.label_smoothing, 208 | 'ipo': False} 209 | elif loss_config.name == 'ipo': 210 | loss_kwargs = {'beta': loss_config.beta, 'ipo': True} 211 | else: 212 | raise ValueError(f'unknown loss {loss_config.name}') 213 | 214 | losses, _, _ = preference_loss( 215 | policy_chosen_logps, policy_rejected_logps, 216 | reference_chosen_logps, reference_rejected_logps, 217 | **loss_kwargs) 218 | 219 | accumulated_losses.append(losses) 220 | 221 | #Return the losses accumulated over grad accum steps 222 | return torch.concat(accumulated_losses).to(device=losses.device) 223 | 224 | def select_batch(self, batch:dict, selected_batch_size:int, 225 | policy:nn.Module, ref_policy:nn.Module) -> Iterable: 226 | 227 | #Calculate the batch length and adjust selected batch size: 228 | blen = len(list(batch.values())[0]) 229 | if selected_batch_size > blen: 230 | print('selected batch size:{selected_batch_size} is greater than batch size:{blen}') 231 | selected_batch_size = blen 232 | 233 | #Calculate the ref model losses for the batch: 234 | ref_model_loss = self.get_batch_preference_loss(ft_model=self.ft_ref_model, 235 | sft_model=self.sft_ref_model, 236 | batch=batch, 237 | loss_config=self.config.loss) 238 | 239 | 240 | #These look wrong in initial implementation 241 | model_loss = self.get_batch_preference_loss(ft_model=policy, 242 | sft_model=ref_policy, 243 | batch=batch, 244 | loss_config=self.config.loss) 245 | 246 | rho_loss = model_loss - ref_model_loss 247 | selected_idx, not_selected_idx = self.select_top_k(rho_loss, selected_batch_size) 248 | 249 | selected, not_selected = self.subselect_batch(batch, selected_idx, 250 | None if selected_batch_size == blen \ 251 | else not_selected_idx) 252 | 253 | return selected, not_selected, selected_batch_size 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/groupstuff/global_opinion_data_processing.py: -------------------------------------------------------------------------------- 1 | 2 | import datasets 3 | import torch 4 | import json 5 | from torch.utils.data import DataLoader, Dataset 6 | from src.utils import get_local_dir, TemporarilySeededRandom 7 | from torch.nn.utils.rnn import pad_sequence 8 | from collections import defaultdict 9 | import tqdm 10 | import random 11 | from bs4 import BeautifulSoup, NavigableString 12 | import numpy as np 13 | from typing import Dict, List, Optional, Iterator, Callable, Union, Tuple 14 | import pandas as pd 15 | import ast 16 | import matplotlib.pyplot as plt 17 | from typing import Literal 18 | 19 | COUNTRIES=[ 20 | 'Nigeria', 21 | 'Egypt', 22 | 'India (Current national sample)', 23 | 'China', 24 | 'Japan', 25 | 'Germany', 26 | 'France', 27 | 'Spain', 28 | 'United States', 29 | 'Canada', 30 | 'Brazil', 31 | 'Argentina', 32 | 'Australia', 33 | 'New Zealand' 34 | ] 35 | 36 | 37 | def load_and_prepare_data(dataset_name: str, split: str, group_filter: List[str], cache_dir: str = None): 38 | try: 39 | # Try to load the dataset from the Hugging Face Hub 40 | dataset = datasets.load_dataset(dataset_name, cache_dir=cache_dir)["train"] 41 | except Exception as e: 42 | print(f"Error loading dataset from Hugging Face Hub: {e}") 43 | print("Loading dataset from local storage...") 44 | dataset = datasets.load_dataset("/home/uceesr4/Group-robust-preference-optimization/llm_global_opinions")["train"] 45 | #dataset = datasets.load_dataset(dataset_name, cache_dir=cache_dir)["train"] 46 | #print(dataset) 47 | df = pd.DataFrame(dataset) 48 | df['qkey'] = df.index 49 | 50 | new_selections = [] 51 | new_rows = [] 52 | new_options = [] 53 | for i in range(len(df)): 54 | if not df.loc[i, "question"] or not df.loc[i, "options"]: 55 | continue 56 | selections_str = "{" + df.loc[i, "selections"].split("{")[1].split("}")[0] + "}" 57 | selections_dict = ast.literal_eval(selections_str) 58 | #print(selections_dict,group_filter) 59 | if group_filter and not any(country in selections_dict for country in group_filter): 60 | #for country in group_filter: 61 | # if not selections_dict[country] or len(selections_dict[country])==0 or np.sum(selections_dict[country]) == 0: 62 | continue 63 | 64 | ##one condition missing, need to be checked later 65 | new_selections.append(selections_dict) 66 | new_rows.append(df.loc[i]) 67 | parsed_options = ast.literal_eval(df.loc[i, "options"]) 68 | new_options.append([str(opt) for opt in parsed_options]) 69 | 70 | return pd.DataFrame(new_rows), new_selections, new_options 71 | 72 | def process_data_frame(df, selections,group_filter, options): 73 | df['selections'] = selections 74 | df['options'] = options 75 | df['selections'] = df['selections'].apply(lambda x: [(k, v) for k, v in x.items()]) # create country - selections tuples 76 | df = df.explode('selections', ignore_index=True) 77 | #print(df['selections']) 78 | df[['group', 'prob_y']] = pd.DataFrame(df['selections'].tolist(), index=df.index) 79 | df = df[df['prob_y'].apply(lambda x: x is not None and len(x) > 0 and np.sum(x) > 0)] 80 | return df[df['group'].isin(group_filter)] 81 | 82 | def plot_questions_by_country(df, title_suffix=""): 83 | count = df['group'].value_counts() 84 | count.plot(kind='bar', figsize=(10, 6), color='skyblue') 85 | plt.title(f'Number of Questions by Country {title_suffix}') 86 | plt.xlabel('Country') 87 | plt.ylabel('Number 
of Questions') 88 | plt.xticks(rotation=45, ha='right') 89 | plt.savefig(f'questions_by_country{title_suffix}.png') 90 | plt.close() 91 | 92 | def create_goqa_data(df,split,train_frac=0.8, multi_pair=False,n_pairs=4): 93 | df_train=df.sample(frac=0.8,random_state=42) 94 | df_test=df.drop(df_train.index) 95 | 96 | df_truetest=df_test.sample(frac=0.5,random_state=42) 97 | df_valtest=df_test.drop(df_truetest.index) 98 | if split=='train': 99 | if train_frac < 0.8: 100 | df=df_train.sample(frac=(train_frac/0.8),random_state=42) 101 | else: 102 | df=df_train 103 | elif split=='test': 104 | df=df_test 105 | elif split=='truetest': 106 | df=df_truetest 107 | elif split=='valtest': 108 | df=df_valtest 109 | else: 110 | print(split) 111 | raise Exception('incorrect split') 112 | grouped = df.groupby('group') 113 | data = defaultdict(lambda: defaultdict(list)) 114 | for group_name, group_data in grouped: 115 | for qkey, sub_group in group_data.groupby('qkey'): 116 | 117 | question = sub_group['question'].values[0] 118 | 119 | # Process options, excluding any invalid ones 120 | options = sub_group['options'].values[0] 121 | #####treat refused option separately-----options = [opt for opt in options if opt != "Refused"] 122 | 123 | # Construct the prompt 124 | prompt = f"Opinion of people in {group_name} on: {question}\nPlease select the best response:" 125 | 126 | # Generate the full prompt with options 127 | letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'[:len(options)] 128 | for opt, letter in zip(options, letters): 129 | prompt += f"\n{letter}. {opt}" 130 | #print(sub_group) 131 | responses = [letter for letter in letters[:len(options)]] 132 | 133 | prob_y = torch.tensor(np.stack(sub_group['prob_y'].values), dtype=torch.float).squeeze() 134 | ranks=torch.argsort(prob_y) 135 | pairs = [(ranks[i], ranks[j]) for i in range(len(ranks)) for j in range(i)] 136 | correct_response_index = ranks[-1] 137 | correct_response = responses[ranks[-1]] 138 | 139 | data[prompt]['sft_target'] = correct_response 140 | data[prompt]['responses'] = responses 141 | if multi_pair: 142 | data[prompt]['pairs']=random.sample(pairs,min(n_pairs,len(pairs))) 143 | else: 144 | wrong_indices = [i for i in range(len(options)) if i != correct_response_index] 145 | if wrong_indices: 146 | wrong_response_index = random.choice(wrong_indices) 147 | data[prompt]['pairs']=[(correct_response_index,wrong_response_index)] 148 | #print(data[prompt]) 149 | #print(len(data)) 150 | return data 151 | 152 | def get_goqa(split: str, train_frac: float = 0.8, group_id: int = None, multi_pair: bool = False,n_pairs: int=4, silent: bool = False, cache_dir: str = None): 153 | if group_id==None: 154 | group_filter = COUNTRIES 155 | else: 156 | group_filter = [COUNTRIES[group_id]] 157 | df, selections, options = load_and_prepare_data("Anthropic/llm_global_opinions", split, group_filter, cache_dir) 158 | df = process_data_frame(df, selections, group_filter, options) 159 | plot_questions_by_country(df, title_suffix=f" {split} with groups {' '.join(group_filter)}") 160 | #print(group_id,group_filter) 161 | return create_goqa_data(df=df,split=split,train_frac=train_frac,multi_pair= multi_pair,n_pairs=n_pairs) 162 | 163 | 164 | 165 | ##test 166 | # Example usage: 167 | #data_train = get_goqa("train", group_filter=["USA", "UK"]) 168 | #data_test = get_goqa("test", multi_response=True) 169 | #print(data_train) 170 | #print(data_test) 171 | 172 | 173 | def create_goqa_data_alt( 174 | df: pd.DataFrame, 175 | split: Literal['train','test'], 176 | train_frac: float = 0.8, 177 
| multi_response: bool = False, 178 | option_mode: Literal['preferred_least_min_gap','balanced'] | None = None, 179 | ) -> dict[list]: 180 | # Sampling train and test sets 181 | df_train = df.sample(frac=0.8, random_state=42) 182 | df_test = df.drop(df_train.index) 183 | 184 | # Selecting the data split 185 | if split == 'train': 186 | if train_frac < 0.8: 187 | df = df_train.sample(frac=(train_frac / 0.8), random_state=42) 188 | else: 189 | df = df_train 190 | elif split == 'test': 191 | df = df_test 192 | else: 193 | raise ValueError('incorrect split value') 194 | 195 | grouped = df.groupby('group') 196 | data = defaultdict(lambda: defaultdict(list)) 197 | 198 | if option_mode == 'preferred_least_min_gap': 199 | group_sizes = grouped.size() 200 | sorted_groups = group_sizes.sort_values() 201 | cumulative_sizes = sorted_groups.cumsum() 202 | half_data_point = cumulative_sizes.iloc[-1] / 2 203 | closest_split_index = (cumulative_sizes - half_data_point).abs().argmin() 204 | preferred_groups = set(sorted_groups.iloc[:closest_split_index + 1].index) 205 | minimal_gap_groups = set(sorted_groups.iloc[closest_split_index + 1:].index) 206 | 207 | # Process each group 208 | for group_name, group_data in grouped: 209 | for qkey, sub_group in group_data.groupby('qkey'): 210 | prob_y = torch.tensor(np.stack(sub_group['prob_y'].values), dtype=torch.float).squeeze() 211 | prompt = f"{sub_group['questions'].values[0]} Opinion of people in country: {group_name} is" 212 | options = sub_group['options'].values[0] 213 | 214 | if option_mode == 'preferred_least_min_gap': 215 | if group_name in preferred_groups: 216 | # Most and least preferred selection 217 | max_index = torch.argmax(prob_y) 218 | min_index = torch.argmin(prob_y) 219 | selected_indices = [max_index, min_index] 220 | else: 221 | # Minimal gap selection 222 | sorted_indices = np.argsort(prob_y) 223 | min_gap = float('inf') 224 | selected_pair = (0, 1) 225 | for i in range(len(prob_y) - 1): 226 | gap = prob_y[sorted_indices[i+1]] - prob_y[sorted_indices[i]] 227 | if gap < min_gap: 228 | min_gap = gap 229 | selected_pair = (sorted_indices[i], sorted_indices[i+1]) 230 | selected_indices = list(selected_pair) 231 | elif option_mode == 'balanced': 232 | selected_indices = np.random.choice(len(options), 2, replace=False) 233 | else: 234 | # Default behavior or other modes can be added here 235 | max_index = torch.argmax(prob_y) 236 | selected_indices = [max_index, np.random.choice([i for i in range(len(options)) if i != max_index])] 237 | 238 | # Adding the correct and selected responses 239 | for index in selected_indices: 240 | if multi_response: 241 | data[prompt]['responses'].append(options[index]) 242 | else: 243 | data[prompt]['pairs'].append((len(data[prompt]['responses']), len(data[prompt]['responses']) + 1)) 244 | data[prompt]['responses'].extend([options[selected_indices[0]], options[index]]) 245 | if multi_response: 246 | for i, option in enumerate(sub_group['options'].values[0]): 247 | if i != correct_response_index: 248 | data[prompt]['pairs'].append((len(data[prompt]['responses']), len(data[prompt]['responses'])+1)) 249 | data[prompt]['responses'].extend([correct_response, option]) 250 | else: 251 | wrong_indices = [i for i in range(len(sub_group['options'].values[0])) if i != correct_response_index] 252 | if wrong_indices: 253 | wrong_option_index = np.random.choice(wrong_indices) 254 | wrong_response = sub_group['options'].values[0][wrong_option_index] 255 | data[prompt]['pairs'].append((len(data[prompt]['responses']), 
len(data[prompt]['responses'])+1)) 256 | data[prompt]['responses'].extend([correct_response, wrong_response]) 257 | return data -------------------------------------------------------------------------------- /src/groupstuff/global_opinion_data_processing_kfold.py: -------------------------------------------------------------------------------- 1 | 2 | import datasets 3 | import torch 4 | import json 5 | from torch.utils.data import DataLoader, Dataset 6 | from src.utils import get_local_dir, TemporarilySeededRandom 7 | from torch.nn.utils.rnn import pad_sequence 8 | from collections import defaultdict 9 | import tqdm 10 | import random 11 | from bs4 import BeautifulSoup, NavigableString 12 | import numpy as np 13 | from typing import Dict, List, Optional, Iterator, Callable, Union, Tuple 14 | import pandas as pd 15 | import ast 16 | import matplotlib.pyplot as plt 17 | from typing import Literal 18 | from sklearn.model_selection import KFold 19 | 20 | COUNTRIES=[ 21 | 'Nigeria', 22 | 'Egypt', 23 | 'India (Current national sample)', 24 | 'China', 25 | 'Japan', 26 | 'Germany', 27 | 'France', 28 | 'Spain', 29 | 'United States', 30 | 'Canada', 31 | 'Brazil', 32 | 'Argentina', 33 | 'Australia', 34 | 'New Zealand' 35 | ] 36 | 37 | 38 | def load_and_prepare_data(dataset_name: str, split: str, group_filter: List[str], cache_dir: str = None): 39 | dataset = datasets.load_dataset(dataset_name, cache_dir=cache_dir)["train"] 40 | #print(dataset) 41 | df = pd.DataFrame(dataset) 42 | df['qkey'] = df.index 43 | 44 | new_selections = [] 45 | new_rows = [] 46 | new_options = [] 47 | for i in range(len(df)): 48 | if not df.loc[i, "question"] or not df.loc[i, "options"]: 49 | continue 50 | selections_str = "{" + df.loc[i, "selections"].split("{")[1].split("}")[0] + "}" 51 | selections_dict = ast.literal_eval(selections_str) 52 | #print(selections_dict,group_filter) 53 | if group_filter and not any(country in selections_dict for country in group_filter): 54 | #for country in group_filter: 55 | # if not selections_dict[country] or len(selections_dict[country])==0 or np.sum(selections_dict[country]) == 0: 56 | continue 57 | 58 | ##one condition missing, need to be checked later 59 | new_selections.append(selections_dict) 60 | new_rows.append(df.loc[i]) 61 | parsed_options = ast.literal_eval(df.loc[i, "options"]) 62 | new_options.append([str(opt) for opt in parsed_options]) 63 | 64 | return pd.DataFrame(new_rows), new_selections, new_options 65 | 66 | def process_data_frame(df, selections,group_filter, options): 67 | df['selections'] = selections 68 | df['options'] = options 69 | df['selections'] = df['selections'].apply(lambda x: [(k, v) for k, v in x.items()]) # create country - selections tuples 70 | df = df.explode('selections', ignore_index=True) 71 | #print(df['selections']) 72 | df[['group', 'prob_y']] = pd.DataFrame(df['selections'].tolist(), index=df.index) 73 | df = df[df['prob_y'].apply(lambda x: x is not None and len(x) > 0 and np.sum(x) > 0)] 74 | return df[df['group'].isin(group_filter)] 75 | 76 | def plot_questions_by_country(df, title_suffix=""): 77 | count = df['group'].value_counts() 78 | count.plot(kind='bar', figsize=(10, 6), color='skyblue') 79 | plt.title(f'Number of Questions by Country {title_suffix}') 80 | plt.xlabel('Country') 81 | plt.ylabel('Number of Questions') 82 | plt.xticks(rotation=45, ha='right') 83 | plt.savefig(f'questions_by_country{title_suffix}.png') 84 | plt.close() 85 | 86 | def create_goqa_data(df,split,train_frac=0.8, multi_pair=False,n_pairs=4,split_idx=0): 87 | 
num_folds = 5 # Example: 5-fold cross-validation 88 | kf = KFold(n_splits=num_folds, shuffle=True, random_state=42) 89 | splits=list(kf.split(df)) 90 | 91 | if split_idx>=num_folds: 92 | raise ValueError('invalid split id') 93 | 94 | train_index, test_index = splits[split_idx] 95 | df_train = df.iloc[train_index] 96 | df_test = df.iloc[test_index] 97 | 98 | #df_train=df.sample(frac=0.8,random_state=42) 99 | #df_test=df.drop(df_train.index) 100 | 101 | df_truetest=df_test.sample(frac=0.5,random_state=42) 102 | df_valtest=df_test.drop(df_truetest.index) 103 | if split=='train': 104 | if train_frac < 0.8: 105 | df=df_train.sample(frac=(train_frac/0.8),random_state=42) 106 | else: 107 | df=df_train 108 | elif split=='test': 109 | df=df_test 110 | elif split=='truetest': 111 | df=df_truetest 112 | elif split=='valtest': 113 | df=df_valtest 114 | else: 115 | print(split) 116 | raise Exception('incorrect split') 117 | grouped = df.groupby('group') 118 | data = defaultdict(lambda: defaultdict(list)) 119 | for group_name, group_data in grouped: 120 | for qkey, sub_group in group_data.groupby('qkey'): 121 | 122 | question = sub_group['question'].values[0] 123 | 124 | # Process options, excluding any invalid ones 125 | options = sub_group['options'].values[0] 126 | #####treat refused option separately-----options = [opt for opt in options if opt != "Refused"] 127 | 128 | # Construct the prompt 129 | prompt = f"Opinion of people in {group_name} on: {question}\nPlease select the best response:" 130 | 131 | # Generate the full prompt with options 132 | letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'[:len(options)] 133 | for opt, letter in zip(options, letters): 134 | prompt += f"\n{letter}. {opt}" 135 | #print(sub_group) 136 | responses = [letter for letter in letters[:len(options)]] 137 | 138 | prob_y = torch.tensor(np.stack(sub_group['prob_y'].values), dtype=torch.float).squeeze() 139 | ranks=torch.argsort(prob_y) 140 | pairs = [(ranks[i], ranks[j]) for i in range(len(ranks)) for j in range(i)] 141 | correct_response_index = ranks[-1] 142 | correct_response = responses[ranks[-1]] 143 | 144 | data[prompt]['sft_target'] = correct_response 145 | data[prompt]['responses'] = responses 146 | if multi_pair: 147 | data[prompt]['pairs']=random.sample(pairs,min(n_pairs,len(pairs))) 148 | else: 149 | wrong_indices = [i for i in range(len(options)) if i != correct_response_index] 150 | if wrong_indices: 151 | wrong_response_index = random.choice(wrong_indices) 152 | data[prompt]['pairs']=[(correct_response_index,wrong_response_index)] 153 | #print(data[prompt]) 154 | #print(len(data)) 155 | return data 156 | 157 | def get_goqa_kfold(split: str, train_frac: float = 0.8, group_id: int = None, multi_pair: bool = False,n_pairs: int=4, silent: bool = False, cache_dir: str = None, split_idx=0): 158 | if group_id==None: 159 | group_filter = COUNTRIES 160 | else: 161 | group_filter = [COUNTRIES[group_id]] 162 | df, selections, options = load_and_prepare_data("Anthropic/llm_global_opinions", split, group_filter, cache_dir) 163 | df = process_data_frame(df, selections, group_filter, options) 164 | plot_questions_by_country(df, title_suffix=f" {split} with groups {' '.join(group_filter)}") 165 | #print(group_id,group_filter) 166 | return create_goqa_data(df=df,split=split,train_frac=train_frac,multi_pair= multi_pair,n_pairs=n_pairs,split_idx=split_idx) 167 | 168 | 169 | 170 | ##test 171 | # Example usage: 172 | #data_train = get_goqa("train", group_filter=["USA", "UK"]) 173 | #data_test = get_goqa("test", multi_response=True) 174 | 
#print(data_train) 175 | #print(data_test) 176 | 177 | 178 | def create_goqa_data_alt( 179 | df: pd.DataFrame, 180 | split: Literal['train','test'], 181 | train_frac: float = 0.8, 182 | multi_response: bool = False, 183 | option_mode: Literal['preferred_least_min_gap','balanced'] | None = None, 184 | ) -> dict[list]: 185 | # Sampling train and test sets 186 | df_train = df.sample(frac=0.8, random_state=42) 187 | df_test = df.drop(df_train.index) 188 | 189 | # Selecting the data split 190 | if split == 'train': 191 | if train_frac < 0.8: 192 | df = df_train.sample(frac=(train_frac / 0.8), random_state=42) 193 | else: 194 | df = df_train 195 | elif split == 'test': 196 | df = df_test 197 | else: 198 | raise ValueError('incorrect split value') 199 | 200 | grouped = df.groupby('group') 201 | data = defaultdict(lambda: defaultdict(list)) 202 | 203 | if option_mode == 'preferred_least_min_gap': 204 | group_sizes = grouped.size() 205 | sorted_groups = group_sizes.sort_values() 206 | cumulative_sizes = sorted_groups.cumsum() 207 | half_data_point = cumulative_sizes.iloc[-1] / 2 208 | closest_split_index = (cumulative_sizes - half_data_point).abs().argmin() 209 | preferred_groups = set(sorted_groups.iloc[:closest_split_index + 1].index) 210 | minimal_gap_groups = set(sorted_groups.iloc[closest_split_index + 1:].index) 211 | 212 | # Process each group 213 | for group_name, group_data in grouped: 214 | for qkey, sub_group in group_data.groupby('qkey'): 215 | prob_y = torch.tensor(np.stack(sub_group['prob_y'].values), dtype=torch.float).squeeze() 216 | prompt = f"{sub_group['questions'].values[0]} Opinion of people in country: {group_name} is" 217 | options = sub_group['options'].values[0] 218 | 219 | if option_mode == 'preferred_least_min_gap': 220 | if group_name in preferred_groups: 221 | # Most and least preferred selection 222 | max_index = torch.argmax(prob_y) 223 | min_index = torch.argmin(prob_y) 224 | selected_indices = [max_index, min_index] 225 | else: 226 | # Minimal gap selection 227 | sorted_indices = np.argsort(prob_y) 228 | min_gap = float('inf') 229 | selected_pair = (0, 1) 230 | for i in range(len(prob_y) - 1): 231 | gap = prob_y[sorted_indices[i+1]] - prob_y[sorted_indices[i]] 232 | if gap < min_gap: 233 | min_gap = gap 234 | selected_pair = (sorted_indices[i], sorted_indices[i+1]) 235 | selected_indices = list(selected_pair) 236 | elif option_mode == 'balanced': 237 | selected_indices = np.random.choice(len(options), 2, replace=False) 238 | else: 239 | # Default behavior or other modes can be added here 240 | max_index = torch.argmax(prob_y) 241 | selected_indices = [max_index, np.random.choice([i for i in range(len(options)) if i != max_index])] 242 | 243 | # Adding the correct and selected responses 244 | for index in selected_indices: 245 | if multi_response: 246 | data[prompt]['responses'].append(options[index]) 247 | else: 248 | data[prompt]['pairs'].append((len(data[prompt]['responses']), len(data[prompt]['responses']) + 1)) 249 | data[prompt]['responses'].extend([options[selected_indices[0]], options[index]]) 250 | if multi_response: 251 | for i, option in enumerate(sub_group['options'].values[0]): 252 | if i != correct_response_index: 253 | data[prompt]['pairs'].append((len(data[prompt]['responses']), len(data[prompt]['responses'])+1)) 254 | data[prompt]['responses'].extend([correct_response, option]) 255 | else: 256 | wrong_indices = [i for i in range(len(sub_group['options'].values[0])) if i != correct_response_index] 257 | if wrong_indices: 258 | wrong_option_index = 
np.random.choice(wrong_indices) 259 | wrong_response = sub_group['options'].values[0][wrong_option_index] 260 | data[prompt]['pairs'].append((len(data[prompt]['responses']), len(data[prompt]['responses'])+1)) 261 | data[prompt]['responses'].extend([correct_response, wrong_response]) 262 | return data -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import json 4 | 5 | import torch 6 | import transformers 7 | from transformers import BitsAndBytesConfig 8 | from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model 9 | from peft.tuners.lora import LoraLayer 10 | 11 | from src.utils import get_local_dir 12 | from omegaconf.listconfig import ListConfig 13 | 14 | class ModelGenerator: 15 | 16 | 17 | def load_saved_model(self, model, model_state_dict_path): 18 | 19 | #Load the state dictionary into memory 20 | state_dict = torch.load(model_state_dict_path, map_location='cpu') 21 | step, metrics = state_dict['step_idx'], state_dict['metrics'] 22 | 23 | #Load state dict into policy and ref model: 24 | print(f'loading pre-trained weights at step {step} from\ 25 | {model_state_dict_path} with metrics {json.dumps(metrics, indent=2)}') 26 | 27 | #load_state_dict moves weights onto the model's device 28 | model.load_state_dict(state_dict['state']) 29 | 30 | return model 31 | 32 | def create_policy_from_config(self, model_config, trainer:str, local_dirs, reference:bool=False): 33 | 34 | model_kwargs = {'device_map': 'balanced'} if trainer in {'BasicTrainer','GroupTrainer','GroupTrainerDebug'} else {} 35 | 36 | if reference: 37 | dtype = model_config.reference_dtype 38 | else: 39 | dtype = model_config.policy_dtype 40 | 41 | bnb_config = BitsAndBytesConfig( 42 | load_in_4bit=True, 43 | bnb_4bit_use_double_quant=True, 44 | bnb_4bit_quant_type="nf4", 45 | bnb_4bit_compute_dtype=torch.bfloat16 46 | ) 47 | 48 | #Load model from Huggingface: 49 | policy = transformers.AutoModelForCausalLM.from_pretrained( 50 | model_config.name_or_path, 51 | cache_dir=local_dirs, 52 | low_cpu_mem_usage=True, 53 | quantization_config=bnb_config, 54 | output_hidden_states=True, 55 | trust_remote_code=True, 56 | **model_kwargs) 57 | 58 | policy.gradient_checkpointing_enable() 59 | 60 | #Setup model with LoRA: 61 | if model_config.use_lora: 62 | policy = prepare_model_for_kbit_training(policy) 63 | 64 | target_modules = model_config.lora_target_modules 65 | 66 | assert isinstance(target_modules, ListConfig) or isinstance(target_modules, list),\ 67 | f'lora_target_modules type:{type(target_modules)} must be type ListConfig or list' 68 | 69 | loraconfig = LoraConfig( 70 | r=model_config.lora_rank, 71 | lora_alpha=model_config.lora_alpha, 72 | target_modules=target_modules, 73 | lora_dropout=model_config.lora_dropout, 74 | bias="none", 75 | task_type="CAUSAL_LM") 76 | 77 | #Apply lora config to policy model: 78 | policy = get_peft_model(policy, loraconfig) 79 | policy = self.manually_map_lora_to_dtype(policy, getattr(torch,dtype)) 80 | 81 | print('Current GPU usage') 82 | 83 | for dev in range(torch.cuda.device_count()): 84 | print(f"dev {dev}, torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(dev)/1024/1024/1024)) 85 | print(f"dev {dev}, torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(dev)/1024/1024/1024)) 86 | print(f"dev {dev}, torch.cuda.max_memory_reserved:
%fGB"%(torch.cuda.max_memory_reserved(dev)/1024/1024/1024)) 87 | 88 | print(f'Loaded model onto device: {policy.device}') 89 | 90 | return policy 91 | 92 | 93 | 94 | def create_policy(self, model_name, dtype, config, use_lora:bool=False, 95 | lora_rank:int=8, lora_alpha:int=32, lora_dropout:float=0.0): 96 | """ 97 | Load a model from huggingface AutoModelForCausalLLM, apply a bitsandbytes 98 | config file and if required setup the model to use lora. 99 | 100 | Parameters 101 | ---------- 102 | model_name : str 103 | Huggingface model name or path 104 | dtype : 105 | float point precision to map the weights of the loaded model to. 106 | use_lora : bool, optional 107 | DESCRIPTION. The default is False. 108 | lora_rank : int, optional 109 | DESCRIPTION. The default is 8. 110 | lora_alpha : int, optional 111 | DESCRIPTION. The default is 32. 112 | lora_dropout : float, optional 113 | DESCRIPTION. The default is 0.0. 114 | 115 | Returns 116 | ------- 117 | policy : nn.Module 118 | Huggingface LLM module setup with dtype precision and lora training weights 119 | 120 | """ 121 | #Setup model and bitsandbytes conifgs: 122 | model_kwargs = {'device_map': 'balanced'} if config.trainer in {'BasicTrainer','GroupTrainer','GroupTrainerDebug'} else {} 123 | 124 | compute_dtype = getattr(torch, dtype) 125 | bnb_config = BitsAndBytesConfig( 126 | load_in_4bit=True, 127 | bnb_4bit_use_double_quant=True, 128 | bnb_4bit_quant_type="nf4", 129 | bnb_4bit_compute_dtype=torch.bfloat16 130 | ) 131 | 132 | #Load model from Huggingface: 133 | policy = transformers.AutoModelForCausalLM.from_pretrained( 134 | model_name, 135 | cache_dir=get_local_dir(config.local_dirs), 136 | low_cpu_mem_usage=True, 137 | quantization_config=bnb_config, 138 | output_hidden_states=True, 139 | trust_remote_code=True, 140 | **model_kwargs) 141 | 142 | policy.gradient_checkpointing_enable() 143 | 144 | #Setup model with LoRA: 145 | if use_lora: 146 | policy = prepare_model_for_kbit_training(policy) 147 | 148 | target_modules = config.model.lora_target_modules 149 | 150 | assert isinstance(target_modules, ListConfig) or isinstance(target_modules, list),\ 151 | f'lora_target_modules type:{type(target_modules)} must be type ListConfig or list' 152 | 153 | loraconfig = LoraConfig( 154 | r=lora_rank, 155 | lora_alpha=lora_alpha, 156 | target_modules=target_modules, 157 | lora_dropout=lora_dropout, 158 | bias="none", 159 | task_type="CAUSAL_LM") 160 | 161 | #Apply lora config to policy model: 162 | policy = get_peft_model(policy, loraconfig) 163 | policy = self.manually_map_lora_to_dtype(policy, getattr(torch,dtype)) 164 | 165 | 166 | print('Current GPU usage') 167 | 168 | for dev in range(torch.cuda.device_count()): 169 | print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(dev)/1024/1024/1024)) 170 | print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(dev)/1024/1024/1024)) 171 | print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(dev)/1024/1024/1024)) 172 | 173 | print(f'Loaded model onto device: {policy.device}') 174 | 175 | return policy 176 | 177 | def manually_map_lora_to_dtype(self, policy, dtype): 178 | """ 179 | Maps a model setup with LoRA layers to the specified dtype. This is used 180 | after the lora config has been applied to a huggingface model loaded with a 181 | specific dtype. 182 | 183 | Parameters 184 | ---------- 185 | policy : TYPE 186 | DESCRIPTION. 187 | dtype : TYPE 188 | DESCRIPTION. 189 | 190 | Returns 191 | ------- 192 | policy : TYPE 193 | DESCRIPTION. 
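The same `policy` module, with LoRA layers, normalization layers, and any remaining float32 weights (including lm_head and embed_tokens) cast to `dtype`.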
194 | 195 | """ 196 | 197 | for name, module in policy.named_modules(): 198 | if isinstance(module, LoraLayer): 199 | module = module.to(dtype) 200 | if 'norm' in name: 201 | module = module.to(dtype) 202 | if hasattr(module, 'weight'): 203 | if module.weight.dtype == torch.float32: 204 | module = module.to(dtype) 205 | if 'lm_head' in name or 'embed_tokens' in name: 206 | if hasattr(module, 'weight'): 207 | if module.weight.dtype == torch.float32: 208 | module = module.to(dtype) 209 | 210 | return policy 211 | 212 | 213 | def generate_models(self, config): 214 | """ 215 | Return a dictionary with the relevant models create and sorted 216 | 217 | Parameters 218 | ---------- 219 | config : dict 220 | Config file containing the relevant parameters 221 | 222 | Raises 223 | ------ 224 | NotImplementedError 225 | The config.loss.name has not been implemented yet 226 | 227 | Returns 228 | ------- 229 | models : dict 230 | A dictionary of policy models 231 | """ 232 | 233 | if config.loss.name == 'sft': 234 | 235 | sft_model = self.create_policy(config.model.name_or_path, 236 | config.model.policy_dtype, 237 | config, 238 | use_lora=config.model.use_lora, 239 | lora_rank=config.model.lora_rank, 240 | lora_alpha=config.model.lora_alpha, 241 | lora_dropout=config.model.lora_dropout) 242 | 243 | models = {'sft_model': sft_model} 244 | 245 | elif config.loss.name == 'base': 246 | 247 | base_model = self.create_policy(config.model.name_or_path, 248 | config.model.policy_dtype, 249 | config, 250 | use_lora=config.model.use_lora, 251 | lora_rank=config.model.lora_rank, 252 | lora_alpha=config.model.lora_alpha, 253 | lora_dropout=config.model.lora_dropout) 254 | 255 | models = {'base_model': base_model} 256 | print(base_model.device,'base-model-device') 257 | elif config.loss.name in ['dpo', 'ipo', 'rdpo', 'ripo']: 258 | 259 | #create main policy: 260 | policy_model = self.create_policy(config.model.name_or_path, 261 | config.model.policy_dtype, 262 | config, 263 | use_lora=config.model.use_lora, 264 | lora_rank=config.model.lora_rank, 265 | lora_alpha=config.model.lora_alpha, 266 | lora_dropout=config.model.lora_dropout) 267 | 268 | #create the reference policy: 269 | ref_model = self.create_policy(config.model.name_or_path, 270 | config.model.reference_dtype, 271 | config, 272 | use_lora=config.model.use_lora, 273 | lora_rank=config.model.lora_rank, 274 | lora_alpha=config.model.lora_alpha, 275 | lora_dropout=config.model.lora_dropout) 276 | 277 | policy_device = policy_model.device 278 | ref_device = ref_model.device 279 | 280 | #Check sft model is asserted: 281 | if config.assert_sft_step: 282 | assert config.model.archive is not None,\ 283 | 'config.model.archive should be provided when training with PO methods' 284 | 285 | #Load the previous model state dict and upload to policy and reference model: 286 | if config.model.archive is not None: 287 | 288 | #Load state dict: 289 | state_dict = torch.load(config.model.archive, map_location='cpu') 290 | step, metrics = state_dict['step_idx'], state_dict['metrics'] 291 | 292 | #Load state dict into policy and ref model: 293 | print(f'loading pre-trained weights at step {step} from\ 294 | {config.model.archive} with metrics {json.dumps(metrics, indent=2)}') 295 | policy_model.load_state_dict(state_dict['state']) 296 | ref_model.load_state_dict(state_dict['state']) 297 | 298 | print('Loaded pretrained weights') 299 | 300 | #Ensure the device hasn't changed at this step: 301 | assert (policy_model.device == policy_device) and (ref_model.device == 
ref_device), \ 302 | 'The policy and reference models device should not change' 303 | 304 | models = {'policy_model': policy_model, 305 | 'ref_model': ref_model} 306 | 307 | else: 308 | raise NotImplementedError( 309 | f'config.loss.name: {config.loss.name} not implemented yet') 310 | 311 | return models 312 | 313 | if __name__ == '__main__': 314 | 315 | #TODO: Write some tests: 316 | print('Model Generator Tests') 317 | 318 | -------------------------------------------------------------------------------- /src/groupstuff/data_processing.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import torch 3 | import ast 4 | import json 5 | #from torch.utils.data import DataLoader, Dataset 6 | #from src.utils import get_local_dir, TemporarilySeededRandom 7 | #from torch.nn.utils.rnn import pad_sequence 8 | from collections import defaultdict 9 | import tqdm 10 | import random 11 | #from bs4 import BeautifulSoup, NavigableString 12 | import numpy as np 13 | from typing import Literal, Dict, List, Optional, Iterator, Callable, Union, Tuple 14 | import pandas as pd 15 | import ast 16 | import matplotlib.pyplot as plt 17 | 18 | COUNTRIES=[ 19 | 'Nigeria', 20 | 'Egypt', 21 | 'India (Current national sample)', 22 | 'China', 23 | 'Japan', 24 | 'Germany', 25 | 'France', 26 | 'Spain', 27 | 'United States', 28 | 'Canada', 29 | 'Brazil', 30 | 'Argentina', 31 | 'Australia', 32 | 'New Zealand' 33 | ] 34 | 35 | def get_dataset(name: str, split: str, silent: bool = False, cache_dir: str = None): 36 | """Load the given dataset by name. Supported by default are 'shp', 'hh', and 'se'.""" 37 | if name == 'shp': 38 | data = get_shp(split, silent=silent, cache_dir=cache_dir) 39 | elif name == 'hh': 40 | data = get_hh(split, silent=silent, cache_dir=cache_dir) 41 | elif name == 'se': 42 | data = get_se(split, silent=silent, cache_dir=cache_dir) 43 | elif name == 'jeopardy': 44 | data = get_jeopardy(split, silent=silent, cache_dir=cache_dir) 45 | elif "jeopardy" in name: 46 | value = name.split("_")[-1] 47 | if value!='final': 48 | value=int(value) 49 | data = get_jeopardy_value(split, value, silent=silent, cache_dir=cache_dir) 50 | elif name.startswith('oqa'): 51 | if name == 'oqa': 52 | data = get_oqa(split, silent=silent, cache_dir=cache_dir) 53 | else: 54 | name, attribute, group = name.split('_') 55 | # should be something like oqa_SEX_male 56 | data = get_oqa_group(split, attribute, group, silent=silent, cache_dir=cache_dir) 57 | elif name == 'hel': 58 | data = get_hel(split, silent=silent, cache_dir=cache_dir) 59 | elif name == 'helon': 60 | data = get_helon(split, silent=silent, cache_dir=cache_dir) 61 | elif name == 'helrej': 62 | data = get_helrej(split, silent=silent, cache_dir=cache_dir) 63 | elif name == 'heltot': 64 | data = get_heltot(split, silent=silent, cache_dir=cache_dir) 65 | elif name == 'har': 66 | data = get_har(split, silent=silent, cache_dir=cache_dir) 67 | elif 'reddit' in name: 68 | group_id = int(name.split('_')[-1]) 69 | split = 'validation' if split == 'test' else split 70 | data = get_reddit(split, group_id, silent=silent, cache_dir=cache_dir) 71 | elif 'hel_' in name: 72 | change_split=name.split('_')[-1] 73 | if 'train' in split: 74 | split=f'train[:{change_split}%]' 75 | #else: 76 | #split=f'test[:{change_split}%]' 77 | data = get_hel(split, silent=silent, cache_dir=cache_dir) 78 | elif name == 'GOqa': 79 | data=get_goqa(split, silent=silent, cache_dir=cache_dir) 80 | elif name == 'GOqMa': 81 | data=get_goqa_multiple(split, 
silent=silent, cache_dir=cache_dir) 82 | elif 'GOqa' in name: 83 | group_id=int(name.split('_')[-1]) 84 | data=get_goqa_group(split, group_id, silent=silent, cache_dir=cache_dir) 85 | elif 'GOqMa' in name: 86 | group_id=int(name.split('_')[-1]) 87 | data=get_goqa_group_multiple(split, group_id, silent=silent, cache_dir=cache_dir) 88 | 89 | else: 90 | raise ValueError(f"Unknown dataset '{name}'") 91 | 92 | assert set(list(data.values())[0].keys()) == {'responses', 'pairs', 'sft_target'}, \ 93 | f"Unexpected keys in dataset: {list(list(data.values())[0].keys())}" 94 | 95 | return data 96 | 97 | def get_oqa( 98 | split: str, 99 | attribute: str, 100 | group: str, 101 | mode: Literal["best-random","best-worst","random"], 102 | multi_pair: bool=False, 103 | n_pairs: int=4, 104 | silent: bool=False, 105 | plot_distr: bool = False, 106 | cache_dir: str = None 107 | ) -> Dict[str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]: 108 | # TODO: cache_dir is unused currently 109 | # TODO: better abstraction of group to other types except SEX 110 | 111 | OQA_ATTRIBUTES = ['SEX'] 112 | OQA_GROUPS = ['Male','Female'] 113 | OQA_RACE_GROUPS = ['Asian','Other','Black','White','Hispanic'] 114 | 115 | ATTRIBUTE = attribute #OQA_ATTRIBUTES[group_id[0]] 116 | GROUP = group #OQA_GROUPS[group_id[1]] 117 | 118 | if split not in ('test', 'train'): 119 | raise ValueError(f'split {split} not recognized (valid: test, train)') 120 | print(f'Loading GPO (OQA) dataset from file...\n') 121 | df = pd.read_csv(f'src/data/{split}_oqa.csv') 122 | 123 | letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" 124 | 125 | def softmax(x): 126 | e_x = np.exp(x - np.max(x)) 127 | return e_x / e_x.sum(axis=0) 128 | 129 | 130 | def make_prompt_and_responses(elt): 131 | # take a row of the csv and return a prompt and responses 132 | question = elt['question'] 133 | options = ast.literal_eval(elt['options']) 134 | options = [opt for opt in options if opt != "Refused"] 135 | attribute = elt['attribute'] 136 | group = elt['group'] 137 | distribution = elt['D_H'] 138 | numbers_str = distribution.strip('[]').split() 139 | numbers_float = [float(x) for x in numbers_str] 140 | distribution = np.array(numbers_float) 141 | 142 | prompt = f"Answer the following question as if you were {attribute} of {group}: {question}\nRespond with a single letter:" 143 | letters_opt = letters[:len(options)] 144 | for opt, letter in zip(options, letters_opt): 145 | prompt += f"\n{letter}. 
{opt}" 146 | responses = [letter for letter in letters_opt] 147 | 148 | best_idx = np.argmax(distribution) 149 | worst_idx = np.argmin(distribution) 150 | 151 | if multi_pair is True: 152 | # pairs given as (correct, wrong) based on explicit user preference (deterministic) 153 | ranks = np.argsort(distribution) 154 | pairs = [tuple(sorted([ranks[i], ranks[j]],reverse=True)) for i in range(len(ranks)) for j in range(i)] 155 | pairs = random.sample(pairs,min(n_pairs,len(pairs))) 156 | else: 157 | # single pair (correct,wrong) is the best-preferred (correct) vs least-preferred (wrong) 158 | correct_response_index = best_idx 159 | if mode=='best-worst': 160 | wrong_response_index = worst_idx 161 | pairs = [(correct_response_index,wrong_response_index)] 162 | elif mode=='best-random': 163 | wrong_indices = [i for i in range(len(options)) if i != correct_response_index] 164 | if len(wrong_indices)>0: 165 | wrong_response_index = random.choice(wrong_indices) 166 | pairs = [(correct_response_index,wrong_response_index)] 167 | else: 168 | pairs = [] 169 | elif mode=='random': # according to Bradley-Terry preference distribution (rewards given in OQA data) 170 | distribution_softmax = softmax(distribution) 171 | pair = np.random.choice(np.arange(len(distribution)), size=2, replace=False, p=distribution_softmax) 172 | pairs = [tuple(pair)] 173 | else: 174 | raise ValueError 175 | 176 | sft_target = options[best_idx] # best-preferred option 177 | return prompt, dict(responses=responses, pairs=pairs, sft_target=sft_target) 178 | 179 | def plot_distribution(all_data: Dict[str, Dict]): 180 | correct_idx = [] 181 | wrong_idx = [] 182 | for prompt in all_data: 183 | for pair in all_data[prompt]['pairs']: 184 | correct_idx.append(pair[0]) 185 | wrong_idx.append(pair[1]) 186 | plt.figure() 187 | plt.bar(np.arange(len(correct_idx)), correct_idx, label='correct') 188 | plt.bar(np.arange(len(wrong_idx)), wrong_idx, label='wrong') 189 | plt.legend() 190 | plt.savefig(f'./src/groupstuff/dataload_plt/oqa_distribution_{ATTRIBUTE}_{GROUP}.png') 191 | 192 | all_data = {} 193 | for idx, row in tqdm.tqdm(df.iterrows(), disable=silent, desc="Processing OQA"): 194 | if row['attribute'] == ATTRIBUTE and row['group'] == GROUP: 195 | prompt, data = make_prompt_and_responses(row) 196 | all_data[prompt] = data 197 | 198 | #print('ALL DATA: ', list(all_data.items())[:10]) 199 | 200 | if plot_distr is True: 201 | plot_distribution(all_data) 202 | 203 | return all_data 204 | 205 | def get_oqa_group(split: str, attribute: str, group: str, silent: bool=False, cache_dir: str = None) -> Dict[str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]: 206 | if split not in ('test', 'train'): 207 | raise ValueError(f'split {split} not recognized (valid: test, train)') 208 | groups = pd.read_csv('data/groups.csv') 209 | # Check if the pair exists in the DataFrame 210 | if not ((groups['attribute'] == attribute) & (groups['group'] == group)).any(): 211 | raise ValueError(f"The pair attribute={attribute}, group={group} is not present in the DataFrame.") 212 | print(f'Loading GPO dataset from file...') 213 | df = pd.read_csv(f'data/{split}_oqa.csv') 214 | letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" 215 | def make_prompt_and_responses(elt): 216 | # take a row of the csv and return a prompt and responses 217 | question = elt['question'] 218 | options = ast.literal_eval(elt['options']) 219 | options = [opt for opt in options if opt != "Refused"] 220 | this_attribute = elt['attribute'] 221 | this_group = elt['group'] 222 | if this_attribute != 
attribute or this_group != group: 223 | return None, None 224 | distribution = elt['D_H'] 225 | numbers_str = distribution.strip('[]').split() 226 | numbers_float = [float(x) for x in numbers_str] 227 | distribution = np.array(numbers_float) 228 | prompt = f"Answer the following question as if you were {attribute} of {group}: {question}\nRespond with a single letter:" 229 | for opt, letter in zip(options, letters): 230 | prompt += f"\n{letter}. {opt}" 231 | responses = [letter for letter in letters[:len(options)]] 232 | ranks = np.argsort(distribution) 233 | pairs = [(ranks[i], ranks[j]) for i in range(len(ranks)) for j in range(i)] 234 | sft_target = responses[ranks[-1]] 235 | return prompt, dict(responses=responses, pairs=pairs, sft_target=sft_target) 236 | 237 | all_data = {} 238 | for idx, row in tqdm.tqdm(df.iterrows(), disable=silent, desc="Processing OQA"): 239 | prompt, data = make_prompt_and_responses(row) 240 | all_data[prompt] = data 241 | return all_data 242 | 243 | 244 | 245 | 246 | def get_jeopardy_value(split: str, value: int, silent: bool = False, cache_dir: str = None) -> Dict[str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]: 247 | if split not in ('test', 'train'): 248 | raise ValueError(f'split {split} not recognized (valid: test, train)') 249 | if value not in (200, 400, 600, 800, 1000, 1200, 1600, 2000, 'dd', 'final'): 250 | raise ValueError(f"Jeopardy! dataset requested with value {value} that isn't present") 251 | print(f'Loading Jeopardy! dataset from file...') 252 | with open(f'data/{split}_jeopardy_data.json', 'r') as f: 253 | data = json.load(f) 254 | ''' 255 | data is of the form 256 | 257 | {'category': 'HISTORY', 'air_date': '2004-12-31', 'question': "'For the last 8 years of his life, Galileo was under house arrest for espousing this man's theory'", 'value': '$200', 'answer': 'Copernicus', 'round': 'Jeopardy!', 'show_number': '4680', 'wrong_answer': 'Kepler'} 258 | ''' 259 | # TODO: will need to iterate on prompts to some extent 260 | def make_prompt_and_responses(elt): 261 | category = elt['category'] 262 | question = elt['question'] 263 | if elt['value'] is None: 264 | elt_value = 'final' 265 | elif elt['value'] not in (200, 400, 600, 800, 1000, 1200, 1600, 2000): 266 | elt_value = 'dd' 267 | else: 268 | elt_value = int(elt['value'].replace("$", "").replace(",", "")) 269 | if elt_value != value: 270 | return None, None 271 | answer = elt['answer'] 272 | wrong_answer = elt['wrong_answer'] 273 | prompt = f'{category}, for {value}: {question}' 274 | # change null token to empty string 275 | # responses = [answer, 'null', wrong_answer] 276 | responses = [answer, "", wrong_answer] 277 | pairs = [(0, 1), (0, 2), (1, 2)] 278 | # take a single sample 279 | pairs = [random.choice(pairs)] 280 | return prompt, dict(responses=responses, pairs=pairs, sft_target=answer) 281 | all_data = {} 282 | for row in tqdm.tqdm(data, desc="Processing Jeopardy!", disable=silent): 283 | prompt, data = make_prompt_and_responses(row) 284 | if prompt is None: 285 | continue 286 | all_data[prompt] = data 287 | return all_data 288 | 289 | 290 | def get_jeopardy(split: str, silent: bool = False, cache_dir: str = None) -> Dict[str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]: 291 | if split not in ('test', 'train'): 292 | raise ValueError(f'split {split} not recognized (valid: test, train)') 293 | print(f'Loading Jeopardy! 
dataset from file...') 294 | with open(f'data/{split}_jeopardy_data.json', 'r') as f: 295 | data = json.load(f) 296 | ''' 297 | data is of the form 298 | 299 | {'category': 'HISTORY', 'air_date': '2004-12-31', 'question': "'For the last 8 years of his life, Galileo was under house arrest for espousing this man's theory'", 'value': '$200', 'answer': 'Copernicus', 'round': 'Jeopardy!', 'show_number': '4680', 'wrong_answer': 'Kepler'} 300 | ''' 301 | # TODO: will need to iterate on prompts to some extent 302 | def make_prompt_and_responses(elt): 303 | category = elt['category'] 304 | question = elt['question'] 305 | value = elt['value'] 306 | answer = elt['answer'] 307 | wrong_answer = elt['wrong_answer'] 308 | prompt = f'{category}, for {value}: {question}' 309 | # change null token to empty string 310 | # responses = [answer, 'null', wrong_answer] 311 | responses = [answer, "", wrong_answer] 312 | pairs = [(0, 1), (0, 2), (1, 2)] 313 | # take a single sample 314 | pairs = [random.choice(pairs)] 315 | return prompt, dict(responses=responses, pairs=pairs, sft_target=answer) 316 | all_data = {} 317 | for row in tqdm.tqdm(data, desc="Processing Jeopardy!", disable=silent): 318 | prompt, data = make_prompt_and_responses(row) 319 | all_data[prompt] = data 320 | return all_data 321 | 322 | def get_reddit(split: str, group_id, silent: bool = False, cache_dir: str = None) -> Dict[str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]: 323 | print(f'Loading Reddit TL;DR dataset ({split} split, group_id {group_id}) from Huggingface...') 324 | dataset = datasets.load_dataset('openai/summarize_from_feedback', 'comparisons', split=split, cache_dir=cache_dir) 325 | 326 | def split_prompt_and_responses(ex): 327 | prompt = ex['info']['post'] 328 | chosen_response = ex['summaries'][ex['choice']]['text'] 329 | rejected_response = ex['summaries'][1 - ex['choice']]['text'] 330 | group = ex['info']['subreddit'] 331 | return prompt, chosen_response, rejected_response, group 332 | 333 | data = defaultdict(lambda: defaultdict(list)) 334 | datapoint_per_group = {'relationships': 52346, 'AskReddit': 12963, 'weddingplanning': 644, 'jobs': 782, 'dating_advice': 2257, 'legaladvice': 1769, 'askwomenadvice': 536, 'offmychest': 1447, 'personalfinance': 1900, 'relationship_advice': 7037, 'loseit': 1015, 'needadvice': 305, 'BreakUps': 622, 'GetMotivated': 253, 'Advice': 1802, 'self': 1402, 'Dogtraining': 381, 'pettyrevenge': 580, 'college': 209, 'cats': 300, 'dogs': 482, 'travel': 319, 'books': 190, 'AskDocs': 331, 'Parenting': 356, 'running': 363, 'Cooking': 161, 'Pets': 205, 'tifu': 1901} 335 | # {52346: 'relationships', 12963: 'AskReddit', 7037: 'relationship_advice', 2257: 'dating_advice', 1901: 'tifu', 1900: 'personalfinance', 1802: 'Advice', 1769: 'legaladvice', 1447: 'offmychest', 1402: 'self', 1015: 'loseit'} 336 | group_ids = {0: 'relationships', 1: 'AskReddit', 2: 'relationship_advice', 3: 'dating_advice', 4: 'tifu', 5: 'personalfinance', 6: 'Advice', 7: 'legaladvice', 8: 'offmychest', 9: 'self', 10: 'loseit'} 337 | uniq_topic = {} 338 | count=0 339 | count_2=0 340 | for row in tqdm.tqdm(dataset, desc='Processing Reddit', disable=silent): 341 | prompt, chosen, rejected, group = split_prompt_and_responses(row) 342 | #print(group) 343 | if group_ids[group_id] != group: 344 | count=count+1 345 | continue 346 | 347 | responses = [chosen, rejected] 348 | n_responses = len(data[prompt]['responses']) 349 | if prompt in data: 350 | count_2+=1 351 | data[prompt]['pairs'].append((n_responses, n_responses + 1)) 352 | 
data[prompt]['responses'].extend(responses) 353 | data[prompt]['sft_target'] = chosen 354 | if group not in uniq_topic: 355 | uniq_topic[group] = 0 356 | uniq_topic[group] += 1 357 | print(count) 358 | print(count_2) 359 | print(len(data)) 360 | return data 361 | 362 | 363 | def get_hh_datasets(split: str, variants: list, silent: bool = False, cache_dir: str = None) -> dict: 364 | """ 365 | Load and merge specific variants of the Anthropic Helpful-Harmless dataset from Huggingface. 366 | 367 | Parameters: 368 | split (str): Dataset split (e.g., 'train', 'test'). 369 | variants (list): List of dataset variants to load (e.g., ['helpful-base', 'helpful-online', 'helpful-rejection-sampled']). 370 | silent (bool): If True, suppress tqdm progress display. 371 | cache_dir (str): Directory for caching the dataset, optional. 372 | 373 | Returns: 374 | dict: A structured dictionary with the combined dataset content formatted for model training or evaluation. 375 | """ 376 | def extract_anthropic_prompt(text): 377 | """Utility function to extract the prompt part from a response.""" 378 | return text.split('\n\nAssistant:')[0] + '\n\nAssistant:' 379 | 380 | def split_prompt_and_responses(ex): 381 | """Splits the dataset entry into prompt and responses.""" 382 | prompt = extract_anthropic_prompt(ex['chosen']) 383 | chosen_response = ex['chosen'][len(prompt):] 384 | rejected_response = ex['rejected'][len(prompt):] 385 | return prompt, chosen_response, rejected_response 386 | 387 | data = defaultdict(lambda: {'responses': [], 'pairs': [], 'sft_target': ''}) 388 | for variant in variants: 389 | print(f'Loading {variant} dataset ({split} split) from Huggingface...') 390 | dataset = datasets.load_dataset('Anthropic/hh-rlhf', data_dir=variant, split=split, cache_dir=cache_dir) 391 | print('done') 392 | 393 | for row in tqdm.tqdm(dataset, desc=f'Processing {variant}', disable=silent): 394 | prompt, chosen, rejected = split_prompt_and_responses(row) 395 | responses = [chosen, rejected] 396 | n_responses = len(data[prompt]['responses']) 397 | data[prompt]['pairs'].append((n_responses, n_responses + 1)) 398 | data[prompt]['responses'].extend(responses) 399 | # Set the sft_target only if not already set (keep the first chosen response encountered) 400 | if not data[prompt]['sft_target']: 401 | data[prompt]['sft_target'] = chosen 402 | 403 | return data 404 | 405 | 406 | def main(): 407 | data = get_oqa('train', 'SEX', 'Male', mode='best-random', plot_distr=False) 408 | #Example of using the function to load and merge three datasets 409 | 410 | #data = load_and_merge_hh_datasets('train', ['helpful-rejection-sampled', 'helpful-online', 'helpful-base']) 411 | 412 | # data = get_oqa('train') 413 | # data = get_jeopardy_value('train', 200) 414 | # data = get_jeopardy_value('train', 'final') 415 | 416 | 417 | if __name__ == "__main__": 418 | main() -------------------------------------------------------------------------------- /src/trainers/basictrainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cuda.matmul.allow_tf32 = True 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | import transformers 6 | from omegaconf import DictConfig 7 | 8 | import torch.distributed as dist 9 | from torch.distributed.fsdp import ( 10 | FullyShardedDataParallel as FSDP, 11 | MixedPrecision, 12 | StateDictType, 13 | BackwardPrefetch, 14 | ShardingStrategy, 15 | CPUOffload, 16 | ) 17 | from torch.distributed.fsdp.api import FullStateDictConfig, 
FullOptimStateDictConfig 18 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 19 | import tensor_parallel as tp 20 | import contextlib 21 | 22 | from src.preference_datasets import get_batch_iterator 23 | from src.utils import ( 24 | slice_and_move_batch_for_device, 25 | formatted_dict, 26 | all_gather_if_needed, 27 | pad_to_length, 28 | get_block_class_from_model, 29 | rank0_print, 30 | get_local_dir, 31 | ) 32 | from src.data_selection import DataSelector 33 | from src.loss_utils import ( 34 | preference_loss, 35 | _get_batch_logps, 36 | concatenated_inputs) 37 | 38 | import numpy as np 39 | import wandb 40 | import tqdm 41 | 42 | import random 43 | import os 44 | from collections import defaultdict 45 | import time 46 | import json 47 | import functools 48 | from typing import Optional, Dict, List, Union, Tuple 49 | import csv 50 | 51 | 52 | class BasicTrainer(object): 53 | def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str, 54 | reference_model: Optional[nn.Module] = None, data_selector: DataSelector = None, 55 | rank: int = 0, world_size: int = 1): 56 | """A trainer for a language model, supporting either SFT or DPO training. 57 | 58 | If multiple GPUs are present, naively splits the model across them, effectively 59 | offering N times available memory, but without any parallel computation. 60 | """ 61 | self.seed = seed 62 | self.rank = rank 63 | self.world_size = world_size 64 | self.config = config 65 | self.run_dir = run_dir 66 | 67 | tokenizer_name_or_path = config.model.tokenizer_name_or_path or config.model.name_or_path 68 | rank0_print(f'Loading tokenizer {tokenizer_name_or_path}') 69 | self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name_or_path, cache_dir=get_local_dir(config.local_dirs)) 70 | if self.tokenizer.pad_token_id is None: 71 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 72 | 73 | data_iterator_kwargs = dict( 74 | names=config.datasets, 75 | tokenizer=self.tokenizer, 76 | shuffle=True, 77 | max_length=config.max_length, 78 | max_prompt_length=config.max_prompt_length, 79 | sft_mode=config.loss.name == 'sft', 80 | test_dataset=config.test_dataset 81 | ) 82 | 83 | self.policy = policy 84 | self.reference_model = reference_model 85 | 86 | self.train_iterator = get_batch_iterator(**data_iterator_kwargs, split='train', n_epochs=config.n_epochs, n_examples=config.n_examples, batch_size=config.batch_size, silent=rank != 0, cache_dir=get_local_dir(config.local_dirs)) 87 | rank0_print('Loaded train data iterator') 88 | self.eval_iterator = get_batch_iterator(**data_iterator_kwargs, split='test', n_examples=config.n_eval_examples, batch_size=config.eval_batch_size, silent=rank != 0, cache_dir=get_local_dir(config.local_dirs)) 89 | self.eval_batches = list(self.eval_iterator) 90 | rank0_print(f'Loaded {len(self.eval_batches)} eval batches of size {config.eval_batch_size}') 91 | 92 | #Use the passed data selector argument 93 | self.data_selector = data_selector 94 | 95 | def get_batch_samples(self, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: 96 | """Generate samples from the policy (and reference model, if doing DPO training) for the given batch of inputs.""" 97 | 98 | # FSDP generation according to https://github.com/pytorch/pytorch/issues/100069 99 | with torch.no_grad(): 100 | ctx = lambda: (FSDP.summon_full_params(self.policy, writeback=False, recurse=False) if 'FSDP' in self.config.trainer else contextlib.nullcontext()) 101 | with ctx(): 102 | policy_output = 
self.policy.generate( 103 | batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], max_length=self.config.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id) 104 | 105 | if self.config.loss.name in {'dpo', 'ipo'}: 106 | ctx = lambda: (FSDP.summon_full_params(self.reference_model, writeback=False, recurse=False) if 'FSDP' in self.config.trainer else contextlib.nullcontext()) 107 | with ctx(): 108 | reference_output = self.reference_model.generate( 109 | batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], max_length=self.config.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id) 110 | 111 | policy_output = pad_to_length(policy_output, self.config.max_length, self.tokenizer.pad_token_id) 112 | policy_output = all_gather_if_needed(policy_output, self.rank, self.world_size) 113 | policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) 114 | 115 | if self.config.loss.name in {'dpo', 'ipo'}: 116 | reference_output = pad_to_length(reference_output, self.config.max_length, self.tokenizer.pad_token_id) 117 | reference_output = all_gather_if_needed(reference_output, self.rank, self.world_size) 118 | reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True) 119 | else: 120 | reference_output_decoded = [] 121 | 122 | return policy_output_decoded, reference_output_decoded 123 | 124 | def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 125 | """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. 126 | 127 | We do this to avoid doing two forward passes, because it's faster for FSDP. 128 | 129 | TODO: Can we get rid of this? 
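Returns a tuple (chosen_logps, rejected_logps): per-example log-probabilities of the chosen and rejected completions under `model` (summed over tokens), recovered by splitting the concatenated result at batch['chosen_input_ids'].shape[0].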
130 | """ 131 | concatenated_batch = concatenated_inputs(batch) 132 | all_logits = model(concatenated_batch['concatenated_input_ids'], attention_mask=concatenated_batch['concatenated_attention_mask']).logits.to(torch.float32) 133 | all_logps = _get_batch_logps(all_logits, concatenated_batch['concatenated_labels'], average_log_prob=False) 134 | chosen_logps = all_logps[:batch['chosen_input_ids'].shape[0]] 135 | rejected_logps = all_logps[batch['chosen_input_ids'].shape[0]:] 136 | return chosen_logps, rejected_logps 137 | 138 | 139 | def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True): 140 | """Compute the SFT or DPO loss and other metrics for the given batch of inputs.""" 141 | 142 | metrics = {} 143 | train_test = 'train' if train else 'eval' 144 | 145 | if loss_config.name in {'dpo', 'ipo', 'rdpo', 'ripo'}: 146 | policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch) 147 | with torch.no_grad(): 148 | reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(self.reference_model, batch) 149 | 150 | if 'dpo' in loss_config.name: 151 | loss_kwargs = {'beta': loss_config.beta, 'reference_free': loss_config.reference_free, 'label_smoothing': loss_config.label_smoothing, 'ipo': False} 152 | elif 'ipo' in loss_config.name: 153 | loss_kwargs = {'beta': loss_config.beta, 'ipo': True} 154 | else: 155 | raise ValueError(f'unknown loss {loss_config.name}') 156 | 157 | losses, chosen_rewards, rejected_rewards = preference_loss( 158 | policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, **loss_kwargs) 159 | 160 | reward_accuracies = (chosen_rewards > rejected_rewards).float() 161 | logps_pol_accuracies= (policy_chosen_logps > policy_rejected_logps).float() 162 | logps_ref_accuracies= (reference_chosen_logps > reference_rejected_logps).float() 163 | 164 | chosen_rewards = all_gather_if_needed(chosen_rewards, self.rank, self.world_size) 165 | rejected_rewards = all_gather_if_needed(rejected_rewards, self.rank, self.world_size) 166 | reward_accuracies = all_gather_if_needed(reward_accuracies, self.rank, self.world_size) 167 | logps_pol_accuracies = all_gather_if_needed(logps_pol_accuracies.detach(), self.rank, self.world_size) 168 | logps_ref_accuracies = all_gather_if_needed(logps_ref_accuracies.detach(), self.rank, self.world_size) 169 | 170 | 171 | metrics[f'rewards_{train_test}/chosen'] = chosen_rewards.cpu().numpy().tolist() 172 | metrics[f'rewards_{train_test}/rejected'] = rejected_rewards.cpu().numpy().tolist() 173 | metrics[f'rewards_{train_test}/accuracies'] = reward_accuracies.cpu().numpy().tolist() 174 | metrics[f'logps_pol_{train_test}/accuracies'] = logps_pol_accuracies.cpu().numpy().tolist() 175 | metrics[f'logps_ref_{train_test}/accuracies'] = logps_ref_accuracies.cpu().numpy().tolist() 176 | metrics[f'rewards_{train_test}/margins'] = (chosen_rewards - rejected_rewards).cpu().numpy().tolist() 177 | 178 | policy_rejected_logps = all_gather_if_needed(policy_rejected_logps.detach(), self.rank, self.world_size) 179 | metrics[f'logps_{train_test}/rejected'] = policy_rejected_logps.cpu().numpy().tolist() 180 | 181 | elif loss_config.name == 'sft': 182 | policy_chosen_logits = self.policy(batch['chosen_input_ids'], attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32) 183 | policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'], average_log_prob=False) 184 | 185 | losses = -policy_chosen_logps 
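# SFT loss: with average_log_prob=False the chosen log-probs are summed over response tokens, so each
# entry of `losses` is the negative log-likelihood of that example's chosen completion; the batch mean
# is taken in the return statement below.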
186 | 187 | elif loss_config.name == 'base': 188 | policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch) 189 | logps_accuracies= (policy_chosen_logps > policy_rejected_logps).float() 190 | losses=-logps_accuracies 191 | 192 | gathered_tensors = { 193 | 'logps_accuracies': all_gather_if_needed(logps_accuracies, self.rank, self.world_size), 194 | 'policy_rejected_logps': policy_rejected_logps.detach() 195 | } 196 | 197 | metrics.update({ 198 | f'{k}_{train_test}': v.cpu().numpy().tolist() for k, v in gathered_tensors.items() 199 | }) 200 | 201 | policy_chosen_logps = all_gather_if_needed(policy_chosen_logps.detach(), self.rank, self.world_size) 202 | metrics[f'logps_{train_test}/chosen'] = policy_chosen_logps.cpu().numpy().tolist() 203 | 204 | all_devices_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size) 205 | metrics[f'loss/{train_test}'] = all_devices_losses.cpu().numpy().tolist() 206 | 207 | return losses.mean(), metrics 208 | 209 | def train(self): 210 | """Begin either SFT or DPO training, with periodic evaluation.""" 211 | 212 | rank0_print(f'Using {self.config.optimizer} optimizer') 213 | self.optimizer = getattr(torch.optim, self.config.optimizer)(self.policy.parameters(), lr=self.config.lr) 214 | self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / (self.config.warmup_steps + 1))) 215 | 216 | torch.manual_seed(self.seed) 217 | np.random.seed(self.seed) 218 | random.seed(self.seed) 219 | 220 | if self.config.loss.name in {'dpo', 'ipo'}: 221 | self.reference_model.eval() 222 | 223 | self.example_counter = 0 224 | self.batch_counter = 0 225 | last_log = None 226 | 227 | for batch in self.train_iterator: 228 | #### BEGIN EVALUATION #### 229 | if self.example_counter % self.config.eval_every == 0 and (self.example_counter > 0 or self.config.do_first_eval): 230 | rank0_print(f'Running evaluation after {self.example_counter} train examples') 231 | self.policy.eval() 232 | 233 | all_eval_metrics = defaultdict(list) 234 | 235 | if self.config.sample_during_eval: 236 | all_policy_samples, all_reference_samples = [], [] 237 | policy_text_table = wandb.Table(columns=["step", "prompt", "sample","correct response"]) 238 | if self.config.loss.name in {'dpo', 'ipo'}: 239 | reference_text_table = wandb.Table(columns=["step", "prompt", "sample","correct response"]) 240 | 241 | for eval_batch in (tqdm.tqdm(self.eval_batches, desc='Computing eval metrics') if self.rank == 0 else self.eval_batches): 242 | local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank, self.world_size, self.rank) 243 | with torch.no_grad(): 244 | _, eval_metrics = self.get_batch_metrics(local_eval_batch, self.config.loss, train=False) 245 | 246 | for k, v in eval_metrics.items(): 247 | all_eval_metrics[k].extend(v) 248 | 249 | if self.config.sample_during_eval: 250 | if self.config.n_eval_model_samples < self.config.eval_batch_size: 251 | rank0_print(f'Warning: n_eval_model_samples ({self.config.n_eval_model_samples}) < eval_batch_size ({self.config.eval_batch_size}). 
Sampling from the first complete eval batch of prompts.') 252 | sample_batches = self.eval_batches[:1] 253 | else: 254 | n_sample_batches = self.config.n_eval_model_samples // self.config.eval_batch_size 255 | sample_batches = self.eval_batches[:n_sample_batches] 256 | for eval_batch in (tqdm.tqdm(sample_batches, desc='Generating samples...') if self.rank == 0 else sample_batches): 257 | local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank, self.world_size, self.rank) 258 | policy_samples, reference_samples = self.get_batch_samples(local_eval_batch) 259 | 260 | all_policy_samples.extend(policy_samples) 261 | all_reference_samples.extend(reference_samples) 262 | 263 | for prompt, sample in zip(eval_batch['prompt'], policy_samples): 264 | policy_text_table.add_data(self.example_counter, prompt, sample) 265 | if self.config.loss.name in {'dpo', 'ipo'}: 266 | for prompt, sample in zip(eval_batch['prompt'], reference_samples): 267 | reference_text_table.add_data(self.example_counter, prompt, sample) 268 | 269 | mean_eval_metrics = {k: sum(v) / len(v) for k, v in all_eval_metrics.items()} 270 | rank0_print(f'eval after {self.example_counter}: {formatted_dict(mean_eval_metrics)}') 271 | if self.config.sample_during_eval: 272 | rank0_print(json.dumps(all_policy_samples[:10], indent=2)) 273 | if self.config.loss.name in {'dpo', 'ipo'}: 274 | rank0_print(json.dumps(all_reference_samples[:10], indent=2)) 275 | 276 | if self.config.wandb.enabled and self.rank == 0: 277 | wandb.log(mean_eval_metrics, step=self.example_counter) 278 | 279 | if self.config.sample_during_eval: 280 | save_path = os.path.join(self.run_dir, f'step-{self.example_counter}_samples') 281 | #rank0_print(f'creating checkpoint to write samples to {output_dir}...') 282 | rank0_print(f'writing table to {save_path}...') 283 | with open(save_path, mode='w', newline='') as file: 284 | writer = csv.writer(file) 285 | writer.writerow(policy_text_table.columns) 286 | for i,row in policy_text_table.iterrows(): 287 | writer.writerow(row) 288 | wandb.log({"policy_samples": policy_text_table}, step=self.example_counter) 289 | if self.config.loss.name in {'dpo', 'ipo'}: 290 | wandb.log({"reference_samples": reference_text_table}, step=self.example_counter) 291 | 292 | if self.example_counter > 0: 293 | if self.config.debug: 294 | rank0_print('skipping save in debug mode') 295 | else: 296 | output_dir = os.path.join(self.run_dir, f'step-{self.example_counter}') 297 | rank0_print(f'creating checkpoint to write to {output_dir}...') 298 | self.save(output_dir, mean_eval_metrics) 299 | if self.config.eval_only_once: 300 | return 301 | #### END EVALUATION #### 302 | 303 | #### POINT SELECTION #### 304 | if self.data_selector is not None: 305 | 306 | selected_batch, not_selected_batch, selected_size = self.data_selector.\ 307 | select_batch(batch, self.config.selected_batch_size, 308 | self.policy, self.reference_model) 309 | batch_size = selected_size 310 | 311 | else: 312 | selected_batch = batch 313 | not_selected_batch = None 314 | batch_size = self.config.batch_size 315 | 316 | #### BEGIN TRAINING #### 317 | 318 | self.policy.train() 319 | 320 | start_time = time.time() 321 | batch_metrics = defaultdict(list) 322 | for microbatch_idx in range(self.config.gradient_accumulation_steps): 323 | global_microbatch = slice_and_move_batch_for_device(selected_batch, microbatch_idx, 324 | self.config.gradient_accumulation_steps, self.rank) 325 | local_microbatch = slice_and_move_batch_for_device(global_microbatch, self.rank,
self.world_size, self.rank) 326 | loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True) 327 | (loss / self.config.gradient_accumulation_steps).backward() 328 | 329 | for k, v in metrics.items(): 330 | batch_metrics[k].extend(v) 331 | 332 | grad_norm = self.clip_gradient() 333 | self.optimizer.step() 334 | self.scheduler.step() 335 | self.optimizer.zero_grad() 336 | 337 | step_time = time.time() - start_time 338 | examples_per_second = batch_size / step_time 339 | batch_metrics['examples_per_second'].append(examples_per_second) 340 | batch_metrics['grad_norm'].append(grad_norm) 341 | 342 | self.batch_counter += 1 343 | self.example_counter += batch_size #n.b. self.config.batch_size exists if needed 344 | 345 | if last_log is None or time.time() - last_log > self.config.minimum_log_interval_secs: 346 | mean_train_metrics = {k: sum(v) / len(v) for k, v in batch_metrics.items()} 347 | mean_train_metrics['counters/examples'] = self.example_counter 348 | mean_train_metrics['counters/updates'] = self.batch_counter 349 | rank0_print(f'train stats after {self.example_counter} examples: {formatted_dict(mean_train_metrics)}') 350 | 351 | if self.config.wandb.enabled and self.rank == 0: 352 | wandb.log(mean_train_metrics, step=self.example_counter) 353 | 354 | last_log = time.time() 355 | else: 356 | rank0_print(f'skipping logging after {self.example_counter} examples to avoid logging too frequently') 357 | #### END TRAINING #### 358 | 359 | #### END OF TRAINING EVAL (can be on the train set) #### 360 | #TODO: Evaluate and save the losses of the training dataset if option requested. 361 | 362 | 363 | 364 | def clip_gradient(self): 365 | """Clip the gradient norm of the parameters of a non-FSDP policy.""" 366 | return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.max_grad_norm).item() 367 | 368 | def write_state_dict(self, step: int, state: Dict[str, torch.Tensor], metrics: Dict, filename: str, dir_name: Optional[str] = None): 369 | """Write a checkpoint to disk.""" 370 | if dir_name is None: 371 | dir_name = os.path.join(self.run_dir, f'LATEST') 372 | 373 | os.makedirs(dir_name, exist_ok=True) 374 | output_path = os.path.join(dir_name, filename) 375 | rank0_print(f'writing checkpoint to {output_path}...') 376 | torch.save({ 377 | 'step_idx': step, 378 | 'state': state, 379 | 'metrics': metrics if metrics is not None else {}, 380 | }, output_path) 381 | 382 | def write_labelled_dataset(self, dataset, labels): 383 | pass 384 | 385 | def save(self, output_dir: Optional[str] = None, metrics: Optional[Dict] = None): 386 | """Save policy, optimizer, and scheduler state to disk.""" 387 | 388 | policy_state_dict = self.policy.state_dict() 389 | self.write_state_dict(self.example_counter, policy_state_dict, metrics, 'policy.pt', output_dir) 390 | del policy_state_dict 391 | 392 | optimizer_state_dict = self.optimizer.state_dict() 393 | self.write_state_dict(self.example_counter, optimizer_state_dict, metrics, 'optimizer.pt', output_dir) 394 | del optimizer_state_dict 395 | 396 | scheduler_state_dict = self.scheduler.state_dict() 397 | self.write_state_dict(self.example_counter, scheduler_state_dict, metrics, 'scheduler.pt', output_dir) 398 | 399 | -------------------------------------------------------------------------------- /src/preference_datasets.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset 4 | from src.utils import 
get_local_dir, TemporarilySeededRandom
5 | from src.groupstuff.group_dataset import GroupDataset
6 | from torch.nn.utils.rnn import pad_sequence
7 | from collections import defaultdict
8 | import tqdm
9 | import random
10 | from bs4 import BeautifulSoup, NavigableString
11 | import numpy as np
12 | from typing import Dict, List, Optional, Iterator, Callable, Union, Tuple
13 | from src.groupstuff.data_processing import get_oqa,get_oqa_group,get_hh_datasets
14 | from src.groupstuff.global_opinion_data_processing import get_goqa
15 | from src.groupstuff.global_opinion_data_processing_kfold import get_goqa_kfold
16 | 
17 | def extract_anthropic_prompt(prompt_and_response):
18 |     """Extract the anthropic prompt from a prompt and response pair."""
19 |     search_term = '\n\nAssistant:'
20 |     search_term_idx = prompt_and_response.rfind(search_term)
21 |     assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
22 |     return prompt_and_response[:search_term_idx + len(search_term)]
23 | 
24 | 
25 | def strip_html_tags(html_string):
26 |     """Strip HTML tags from a string, except for <code> tags (which contain real code in the StackExchange answers)."""
27 |     # Create a BeautifulSoup object
28 |     soup = BeautifulSoup(html_string, 'html.parser')
29 | 
30 |     # Initialize an empty list to store the text
31 |     text = []
32 |     for element in soup.children:
33 |         if isinstance(element, NavigableString):
34 |             continue
35 |         if element.name == 'p':
36 |             text.append(''.join(child.string for child in element.children if isinstance(child, NavigableString)))
37 |         elif element.name == 'pre':
38 |             for code in element.find_all('code'):
39 |                 text.append("<code>" + code.get_text() + "</code>")
40 |         elif element.name == 'code':
41 |             text.append("<code>" + element.get_text() + "</code>")
42 | 
43 |     # Join the text together with newlines in between
44 |     text = "\n\n".join(text)
45 | 
46 |     return text
47 | 
48 | 
49 | def get_se(split, silent=False, cache_dir: str = None) -> Dict[str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]:
50 |     """Load the StackExchange dataset from Huggingface, and return a dict of prompts and responses. See get_hh for the format.
51 | 
52 |     We strip the HTML tags from the responses (except for <code> tags), and we add necessary newlines.
53 | """ 54 | print(f'Loading SE dataset ({split} split) from Huggingface...') 55 | dataset = datasets.load_dataset('HuggingFaceH4/stack-exchange-preferences', cache_dir=cache_dir)['train'] 56 | print('done') 57 | 58 | # shuffle the dataset and select 1% for test 59 | dataset = dataset.shuffle(seed=42) 60 | dataset = dataset.select(range(int(len(dataset) * 0.01))) if split == 'test' else dataset.select( 61 | range(int(len(dataset) * 0.01), len(dataset))) 62 | 63 | def strip_html(x): 64 | x['question'] = strip_html_tags(x['question']) 65 | for a in x['answers']: 66 | a['text'] = strip_html_tags(a['text']) 67 | return x 68 | 69 | dataset = dataset.map(strip_html, num_proc=64) 70 | 71 | data = defaultdict(dict) 72 | for row in tqdm.tqdm(dataset, desc='Processing SE', disable=silent): 73 | prompt = '\n\nHuman: ' + row['question'] + '\n\nAssistant:' 74 | responses = [' ' + a['text'] for a in row['answers']] 75 | scores = [a['pm_score'] for a in row['answers']] 76 | 77 | pairs = [] 78 | for i in range(len(responses)): 79 | for j in range(i + 1, len(responses)): 80 | pairs.append((i, j) if scores[i] > scores[j] else (j, i)) 81 | 82 | data[prompt]['responses'] = responses 83 | data[prompt]['pairs'] = pairs 84 | data[prompt]['sft_target'] = max(responses, key=lambda x: scores[responses.index(x)]) 85 | 86 | return data 87 | 88 | def get_shp(split: str, silent: bool = False, cache_dir: str = None) -> Dict[str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]: 89 | """Load the Stanford Human Preferences dataset from Huggingface and convert it to the necessary format. See hh for the format. 90 | 91 | We filter preference pairs to only keep pairs where the score ratio is at least 2. 92 | For this dataset, the sft_target is the response with the highest score. 93 | """ 94 | print(f'Loading SHP dataset ({split} split) from Huggingface...') 95 | dataset = datasets.load_dataset('stanfordnlp/SHP', split=split, cache_dir=cache_dir) 96 | print('done') 97 | 98 | data = defaultdict(lambda: defaultdict(list)) 99 | for row in tqdm.tqdm(dataset, desc='Processing SHP', disable=silent): 100 | prompt = '\n\nHuman: ' + row['history'] + '\n\nAssistant:' 101 | responses = [' ' + row['human_ref_A'], ' ' + row['human_ref_B']] 102 | scores = [row['score_A'], row['score_B']] 103 | if prompt in data: 104 | n_responses = len(data[prompt]['responses']) 105 | else: 106 | n_responses = 0 107 | score_ratio = max(scores[0] / scores[1], scores[1] / scores[0]) 108 | if score_ratio < 2: 109 | continue 110 | 111 | # according to https://huggingface.co/datasets/stanfordnlp/SHP 112 | data[prompt]['pairs'].append((n_responses, n_responses + 1) if row['labels'] == 1 else (n_responses + 1, n_responses)) 113 | data[prompt]['responses'].extend(responses) 114 | data[prompt]['scores'].extend(scores) 115 | 116 | for prompt in data: 117 | data[prompt]['sft_target'] = max(data[prompt]['responses'], key=lambda x: data[prompt]['scores'][data[prompt]['responses'].index(x)]) 118 | del data[prompt]['scores'] 119 | 120 | return data 121 | 122 | 123 | def get_hh(split: str, silent: bool = False, cache_dir: str = None) -> Dict[str, Dict[str, Union[List[Tuple[int, int]], List[str], str]]]: 124 | """Load the Anthropic Helpful-Harmless dataset from Huggingface and convert it to the necessary format. 125 | 126 | The dataset is converted to a dictionary with the following structure: 127 | { 128 | 'prompt1': { 129 | 'responses': List[str], 130 | 'pairs': List[Tuple[int, int]], 131 | 'sft_target': str 132 | }, 133 | 'prompt2': { 134 | ... 
135 | }, 136 | } 137 | 138 | Prompts should be structured as follows: 139 | \n\nHuman: \n\nAssistant: 140 | Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:. 141 | 142 | For this dataset, the sft_target is just the chosen response. 143 | """ 144 | print(f'Loading HH dataset ({split} split) from Huggingface...') 145 | dataset = datasets.load_dataset('Anthropic/hh-rlhf', split=split, cache_dir=cache_dir) 146 | print('done') 147 | 148 | def split_prompt_and_responses(ex): 149 | prompt = extract_anthropic_prompt(ex['chosen']) 150 | chosen_response = ex['chosen'][len(prompt):] 151 | rejected_response = ex['rejected'][len(prompt):] 152 | return prompt, chosen_response, rejected_response 153 | 154 | data = defaultdict(lambda: defaultdict(list)) 155 | for row in tqdm.tqdm(dataset, desc='Processing HH', disable=silent): 156 | prompt, chosen, rejected = split_prompt_and_responses(row) 157 | responses = [chosen, rejected] 158 | n_responses = len(data[prompt]['responses']) 159 | data[prompt]['pairs'].append((n_responses, n_responses + 1)) 160 | data[prompt]['responses'].extend(responses) 161 | data[prompt]['sft_target'] = chosen 162 | 163 | return data 164 | 165 | 166 | def get_dataset(name: str, split: str, train_frac: float = 0.8, silent: bool = False, cache_dir: str = None, test:bool = False, split_idx: int = None): 167 | """Load the given dataset by name. Supported by default are 'shp', 'hh', and 'se'.""" 168 | if name == 'shp': 169 | data = get_shp(split=split, silent=silent, cache_dir=cache_dir) 170 | elif name == 'hh': 171 | data = get_hh(split=split, silent=silent, cache_dir=cache_dir) 172 | elif name == 'se': 173 | data = get_se(split=split, silent=silent, cache_dir=cache_dir) 174 | elif 'goqma' in name: 175 | group_id=int(name.split('_')[-1]) 176 | data=get_goqa(split=split,train_frac= train_frac,group_id= group_id,multi_pair=True, silent=silent, cache_dir=cache_dir) 177 | elif 'goqa' in name: 178 | group_id=int(name.split('_')[-1]) 179 | if not split_idx: 180 | data=get_goqa(split=split,train_frac=train_frac,group_id= group_id,multi_pair=False, silent=silent, cache_dir=cache_dir) 181 | else: 182 | data=get_goqa_kfold(split=split,train_frac=train_frac,group_id= group_id,multi_pair=False, silent=silent, cache_dir=cache_dir, split_idx=split_idx) 183 | elif 'oqa' in name: 184 | namesplit = name.split('_') # name format e.g. 
"oqa_SEX_Male" here 185 | attribute = namesplit[1] 186 | group = namesplit[2] 187 | data=get_oqa(split=split,attribute=attribute,group=group,mode="best-random",multi_pair=False,silent=silent,cache_dir=cache_dir) 188 | elif name in ['hel','helon','helraj','har']: 189 | data = get_hh_datasets(split, variant=[name], silent=silent, cache_dir=cache_dir) 190 | elif name == 'heltot': 191 | data = get_hh_datasets(split, variant=['hel','helon','helraj'], silent=silent, cache_dir=cache_dir) 192 | else: 193 | raise ValueError(f"Unknown dataset '{name}'") 194 | 195 | assert set(list(data.values())[0].keys()) == {'responses', 'pairs', 'sft_target'}, \ 196 | f"Unexpected keys in dataset: {list(list(data.values())[0].keys())}" 197 | 198 | #If test mode reduce the dataset size to only 10 datapoints 199 | if test: 200 | print('Using test data config') 201 | data_keys, new_data = list(data.keys())[:32], dict() 202 | for key in data_keys: 203 | new_data[key] = data[key] 204 | data = new_data 205 | print('Pruned test data config') 206 | 207 | return data 208 | 209 | def get_collate_fn(tokenizer) -> Callable[[List[Dict]], Dict[str, Union[List, torch.Tensor]]]: 210 | """Returns a collate function for the given tokenizer. 211 | 212 | The collate function takes a list of examples (dicts, where values are lists of 213 | ints [tokens] or strings [the original texts]) and returns a batch of examples, 214 | PyTorch tensors padded to the maximum length. Strings are passed through.""" 215 | def collate_fn(batch): 216 | # first, pad everything to the same length 217 | padded_batch = {} 218 | for k in batch[0].keys(): 219 | if k.endswith('_input_ids') or k.endswith('_attention_mask') or k.endswith('_labels'): 220 | if 'prompt' in k: # adapted from https://stackoverflow.com/questions/73256206 221 | to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] 222 | else: 223 | to_pad = [torch.LongTensor(ex[k]) for ex in batch] 224 | if k.endswith('_input_ids'): 225 | padding_value = tokenizer.pad_token_id 226 | elif k.endswith('_labels'): 227 | padding_value = -100 228 | elif k.endswith('_attention_mask'): 229 | padding_value = 0 230 | else: 231 | raise ValueError(f"Unexpected key in batch '{k}'") 232 | 233 | padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) 234 | if 'prompt' in k: # for the prompt, flip back so padding is on left side 235 | padded_batch[k] = padded_batch[k].flip(dims=[1]) 236 | else: 237 | padded_batch[k] = [ex[k] for ex in batch] 238 | 239 | return padded_batch 240 | return collate_fn 241 | 242 | 243 | def tokenize_batch_element(prompt: str, chosen: str, rejected: str, truncation_mode: str, tokenizer, max_length: int, max_prompt_length: int, group: int=None) -> Dict: 244 | """Tokenize a single batch element. 245 | 246 | At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation 247 | in case the prompt + chosen or prompt + rejected responses is/are too long. First 248 | we truncate the prompt; if we're still too long, we truncate the chosen/rejected. 249 | 250 | We also create the labels for the chosen/rejected responses, which are of length equal to 251 | the sum of the length of the prompt and the chosen/rejected response, with -100 for the 252 | prompt tokens. 
253 | """ 254 | chosen_tokens = tokenizer(chosen, add_special_tokens=False) 255 | rejected_tokens = tokenizer(rejected, add_special_tokens=False) 256 | prompt_tokens = tokenizer(prompt, add_special_tokens=False) 257 | 258 | assert tokenizer.eos_token_id not in prompt_tokens['input_ids'], f"Prompt contains EOS token: {prompt}" 259 | assert tokenizer.eos_token_id not in chosen_tokens['input_ids'], f"Chosen response contains EOS token: {chosen}" 260 | assert tokenizer.eos_token_id not in rejected_tokens['input_ids'], f"Rejected response contains EOS token: {rejected}" 261 | 262 | chosen_tokens['input_ids'].append(tokenizer.eos_token_id) 263 | chosen_tokens['attention_mask'].append(1) 264 | 265 | rejected_tokens['input_ids'].append(tokenizer.eos_token_id) 266 | rejected_tokens['attention_mask'].append(1) 267 | 268 | longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids'])) 269 | 270 | # if combined sequence is too long, truncate the prompt 271 | if len(prompt_tokens['input_ids']) + longer_response_length > max_length: 272 | if truncation_mode == 'keep_start': 273 | prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()} 274 | elif truncation_mode == 'keep_end': 275 | prompt_tokens = {k: v[-max_prompt_length:] for k, v in prompt_tokens.items()} 276 | else: 277 | raise ValueError(f'Unknown truncation mode: {truncation_mode}') 278 | 279 | # if that's still too long, truncate the response 280 | if len(prompt_tokens['input_ids']) + longer_response_length > max_length: 281 | chosen_tokens = {k: v[:max_length - max_prompt_length] for k, v in chosen_tokens.items()} 282 | rejected_tokens = {k: v[:max_length - max_prompt_length] for k, v in rejected_tokens.items()} 283 | 284 | # Create labels 285 | chosen_sequence_tokens = {k: prompt_tokens[k] + chosen_tokens[k] for k in chosen_tokens} 286 | rejected_sequence_tokens = {k: prompt_tokens[k] + rejected_tokens[k] for k in rejected_tokens} 287 | chosen_sequence_tokens['labels'] = chosen_sequence_tokens['input_ids'][:] 288 | chosen_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids']) 289 | rejected_sequence_tokens['labels'] = rejected_sequence_tokens['input_ids'][:] 290 | rejected_sequence_tokens['labels'][:len(prompt_tokens['input_ids'])] = [-100] * len(prompt_tokens['input_ids']) 291 | 292 | batch = {} 293 | 294 | batch['prompt'] = prompt 295 | batch['chosen'] = prompt + chosen 296 | batch['rejected'] = prompt + rejected 297 | batch['chosen_response_only'] = chosen 298 | batch['rejected_response_only'] = rejected 299 | #print('element',group) 300 | #print('element wth group') 301 | if group is not None: 302 | batch['group']=group 303 | 304 | for k, toks in {'chosen': chosen_sequence_tokens, 'rejected': rejected_sequence_tokens, 'prompt': prompt_tokens}.items(): 305 | for type_key, tokens in toks.items(): 306 | if type_key == 'token_type_ids': 307 | continue 308 | batch[f'{k}_{type_key}'] = tokens 309 | 310 | return batch 311 | 312 | 313 | def process_dataset(dataset: Dict, 314 | truncation_mode: str, 315 | sep_pairs: bool, 316 | unique_prompts: bool, 317 | group_handling: bool = False, 318 | group_id: Optional[int] = None): 319 | """ 320 | Process the dataset to prepare data for batching, considering separation of pairs and group handling. 321 | 322 | Args: 323 | dataset: The dataset to process. 324 | truncation_mode: Truncation mode to apply ('keep_start' or 'keep_end'). 325 | sep_pairs: Whether to separate pairs into individual items. 
326 | group_handling: Whether group-specific logic is enabled. 327 | group_id: Optional identifier for the data group. 328 | 329 | Returns: 330 | A list of processed dataset items ready for batching. 331 | """ 332 | flat_data = [] 333 | for prompt, data in dataset.items(): 334 | # Process based on separation of pairs and group handling 335 | if len(data['pairs']) > 1 and sep_pairs: 336 | for pair_index, pair in enumerate(data['pairs']): 337 | responses = [data['responses'][pair[0]], data['responses'][pair[1]]] 338 | # Include group_id if group_handling is True 339 | data_tuple = (prompt, responses, [(0, 1)], data['sft_target'], truncation_mode) 340 | if group_handling: 341 | flat_data.append((*data_tuple, group_id)) 342 | #print((*data_tuple, group_id)) 343 | else: 344 | flat_data.append(data_tuple) 345 | if unique_prompts: 346 | break 347 | else: 348 | data_tuple = (prompt, data['responses'], data['pairs'], data['sft_target'], truncation_mode) 349 | if group_handling: 350 | flat_data.append((*data_tuple, group_id)) 351 | #print((*data_tuple, group_id)) 352 | else: 353 | flat_data.append(data_tuple) 354 | return flat_data 355 | 356 | def transform_weighted_item(prompt, responses, pairs, sft_target, truncation_mode, group_id): 357 | """ 358 | Example transformation function for items from a weighted iterable, 359 | adjusting them to match the structure expected by the processing logic. 360 | """ 361 | # Transform the item to match the expected structure (prompt, responses, pairs, etc.) 362 | # This is placeholder logic and should be replaced with actual transformation code. 363 | prompt=prompt[0]#to remove tuple 364 | #print(prompt,'prompt') 365 | tresponses=[] 366 | for r in responses: 367 | tresponses.append(r[0])#to remove tuple 368 | responses=tresponses 369 | tpairs=[] 370 | for p in pairs: 371 | tpairs.append(tuple(q.item() for q in p))#to remove tensored versions 372 | pairs=tpairs 373 | sft_target=sft_target[0]#to remove tuple 374 | #print(sft_target) 375 | truncation_mode=truncation_mode[0]#to remove tuple 376 | # 377 | #print(truncation_mode) 378 | group_id=group_id.item() 379 | return prompt, responses, pairs, sft_target, truncation_mode, group_id 380 | 381 | def process_batches(flat_data, batch_size, collate_fn, tokenizer, max_length, max_prompt_length, sft_mode, n_examples,n_epochs, split, silent,shuffle, permutation_seeds, unique_prompts, group_handling, weighted,n_groups): 382 | """ 383 | Processes data into batches and yields them. Handles both weighted and non-weighted scenarios. 384 | 385 | Args: 386 | flat_data: The preprocessed data ready for batching. 387 | batch_size: The size of each batch. 388 | collate_fn: Function to collate data points into batches. 389 | tokenizer: The tokenizer to use for processing text. 390 | max_length: Maximum sequence length. 391 | max_prompt_length: Maximum prompt length. 392 | sft_mode: Whether to use SFT mode for tokenization. 393 | n_examples: The number of examples to process. If None, processes all examples. 394 | split: The data split being used (e.g., 'train', 'test'). 395 | silent: If True, does not print progress messages. 396 | is_train: Indicates if the current processing is for training data. 397 | weighted: Indicates if weighted sampling should be used. 
398 | """ 399 | epoch_idx = 0 400 | example_idx = 0 401 | done = False 402 | while not done: 403 | if n_epochs is not None and epoch_idx >= n_epochs: 404 | if not silent: 405 | print(f'Finished generating {n_examples} examples on {split} split') 406 | break 407 | # Shuffle data if required 408 | 409 | if shuffle: 410 | print(next(permutation_seeds),'next seed') 411 | with TemporarilySeededRandom(next(permutation_seeds)): 412 | random.shuffle(flat_data) 413 | 414 | if group_handling and weighted: 415 | # Replace with your actual weighted sampling logic 416 | # For example, use a DataLoader with a WeightedRandomSampler 417 | iterable = GroupDataset(flat_data,n_groups).get_loader() # Assuming this is a DataLoader 418 | else: 419 | iterable = flat_data 420 | 421 | batch = [] 422 | # Process each data point into a batch 423 | for data_point in iterable: 424 | if done: 425 | break 426 | 427 | prompt, responses, pairs, sft_target, truncation_mode = data_point[:5] 428 | #print(data_point) 429 | group_id = data_point[5] if group_handling else None 430 | 431 | # Adjust for weighted scenario unpacking 432 | if group_handling and weighted: 433 | # Example transformation function to match the non-weighted structure 434 | prompt, responses, pairs, sft_target, truncation_mode, group_id = transform_weighted_item(prompt, responses, pairs, sft_target, truncation_mode, group_id) 435 | # Specific logic for SFT mode 436 | if sft_mode: 437 | batch_element = tokenize_batch_element(prompt, sft_target, sft_target, truncation_mode, tokenizer, max_length, max_prompt_length, group=group_id if group_handling else None) 438 | batch_element = {k: v for k, v in batch_element.items() if 'rejected' not in k} 439 | batch.append(batch_element) 440 | example_idx += 1 441 | if len(batch) == batch_size: 442 | #print(batch) 443 | yield collate_fn(batch) 444 | if n_examples is not None and example_idx >= n_examples: 445 | if not silent: 446 | print(f'Finished generating {n_examples} examples on {split} split') 447 | done = True 448 | batch = [] 449 | else: 450 | # Standard processing 451 | for p in pairs: 452 | if done: 453 | break 454 | batch_element = tokenize_batch_element(prompt, responses[p[0]], responses[p[1]], truncation_mode, tokenizer, max_length, max_prompt_length, group=group_id if group_handling else None) 455 | batch.append(batch_element) 456 | example_idx += 1 457 | if len(batch) == batch_size: 458 | yield collate_fn(batch) 459 | if n_examples is not None and example_idx >= n_examples: 460 | if not silent: 461 | print(f'Finished generating {n_examples} examples on {split} split') 462 | done = True 463 | batch = [] 464 | if unique_prompts: 465 | break 466 | 467 | if done: 468 | break 469 | 470 | epoch_idx += 1 471 | 472 | 473 | 474 | def get_batch_iterator(names: List[str], 475 | tokenizer, 476 | split: str = 'train', 477 | batch_size: int = 1, 478 | shuffle: bool = True, 479 | max_length: int = 512, 480 | max_prompt_length: int = 128, 481 | sft_mode: bool = False, 482 | n_epochs: Optional[int] = None, 483 | n_examples: Optional[int] = None, 484 | seed: int = 0, 485 | silent: bool = False, 486 | cache_dir: Optional[str] = None, 487 | test_dataset: bool = False, 488 | group_handling: bool = False, 489 | train_frac: float=0.8, 490 | sep_pairs: bool = False, 491 | weighted: bool = False, 492 | mode: str = 'batch_iterator', 493 | split_idx: int = None) -> Iterator[Dict]: 494 | """Get an iterator over batches of data with optional group handling. 495 | Stops after n_epochs or n_examples, whichever comes first. 
496 | 
497 |     Args:
498 |         names: Names of datasets to use.
499 |         tokenizer: Tokenizer to use.
500 |         split: Which split to use.
501 |         batch_size: Batch size.
502 |         shuffle: Whether to shuffle the data after each epoch.
503 |         max_length: Maximum length of the combined prompt + response.
504 |         max_prompt_length: Maximum length of the prompt.
505 |         sft_mode: Whether to use SFT mode.
506 |         n_epochs: Number of epochs to run for. This or n_examples must be specified.
507 |         n_examples: Number of examples to run for. This or n_epochs must be specified.
508 |         seed: Random seed.
509 |         silent: Whether to silence the progress bar(s).
510 |         cache_dir: Directory to cache the datasets in.
511 |         test_dataset: Flag to indicate if using a test dataset.
512 |         group_handling: Flag to enable group-based data handling.
513 |         sep_pairs: Flag to enable separation of response pairs corresponding to a single prompt.
514 |     """
515 | 
516 |     assert n_epochs is not None or n_examples is not None, "Must specify either n_epochs or n_examples"
517 |     if silent:
518 |         datasets.logging.disable_progress_bar()
519 |         datasets.logging.set_verbosity_error()
520 |     if 'gen' in split:
521 |         split=split.split('_')[0]
522 |         unique_prompts=True
523 |     else:
524 |         unique_prompts=False
525 |     with TemporarilySeededRandom(seed):
526 |         permutation_seeds = iter(np.random.randint(0, 2**32, size=1000000))
527 |     flat_data = []
528 |     group_counts=[]
529 |     for name in names:
530 |         truncation_mode = 'keep_end' if name in ['hh', 'har', 'hel', 'helon', 'helraj', 'heltot'] else 'keep_start'
531 |         if mode=='batch_iterator':
532 |             dataset = get_dataset(name=name, train_frac=train_frac, split=split, silent=silent, cache_dir=cache_dir, test=test_dataset, split_idx=split_idx)
533 |             group_id = names.index(name) if group_handling else None
534 |             flat_data.extend(process_dataset(dataset, truncation_mode, sep_pairs,unique_prompts,group_handling,group_id))
535 |         elif mode=='count_groups':
536 |             g_len=0
537 |             for prompt, data in get_dataset(name=name, train_frac=train_frac, split=split, silent=silent, cache_dir=cache_dir, test=test_dataset, split_idx=split_idx).items():
538 |                 g_len+= 1 if unique_prompts else len(data['pairs'])
539 |             group_counts.append(g_len)
540 | 
541 |     if mode=='count_groups':
542 |         return group_counts
543 | 
544 |     collate_fn = get_collate_fn(tokenizer)
545 |     n_groups=len(names)
546 |     return process_batches(flat_data, batch_size, collate_fn, tokenizer, max_length, max_prompt_length, sft_mode, n_examples, n_epochs, split, silent, shuffle, permutation_seeds, unique_prompts, group_handling, weighted,n_groups)
547 | 
548 | 
549 | def strings_match_up_to_spaces(str_a: str, str_b: str) -> bool:
550 |     """Returns True if str_a and str_b match up to spaces, False otherwise."""
551 |     for idx in range(min(len(str_a), len(str_b)) - 2):
552 |         if str_a[idx] != str_b[idx]:
553 |             if str_a[idx] != ' ' and str_b[idx] != ' ':
554 |                 return False
555 |             else:
556 |                 if str_a[idx] == ' ':
557 |                     str_a = str_a[:idx] + str_a[idx + 1:]
558 |                 else:
559 |                     str_b = str_b[:idx] + str_b[idx + 1:]
560 | 
561 |     return True
--------------------------------------------------------------------------------
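Usage note: get_batch_iterator above is the entry point the trainers use to stream preference data. The following is a minimal sketch of driving it directly, not part of the repository; the GPT-2 tokenizer, the pad-token handling, the 'hh' dataset choice, and the batch/example counts are illustrative assumptions, while the keyword arguments mirror the signature shown above.

    # Minimal usage sketch (assumptions: GPT-2 tokenizer, 'hh' dataset, small sizes).
    import transformers
    from src.preference_datasets import get_batch_iterator

    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS for padding

    # DPO-style preference batches: each batch carries padded chosen/rejected
    # input_ids, attention_mask and labels built by tokenize_batch_element + collate_fn.
    train_iterator = get_batch_iterator(
        names=['hh'],
        tokenizer=tokenizer,
        split='train',
        batch_size=4,
        max_length=512,
        max_prompt_length=128,
        sft_mode=False,
        n_examples=64,
        seed=0,
    )
    for batch in train_iterator:
        print(batch['chosen_input_ids'].shape, batch['rejected_input_ids'].shape)
        break

    # With sft_mode=True the iterator instead tokenizes prompt + sft_target and drops
    # every 'rejected_*' key, which is the form supervised fine-tuning would consume.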