├── src ├── utils │ ├── __init__.py │ ├── reward.py │ └── utils.py ├── data │ ├── raw_data │ │ ├── __init__.py │ │ ├── stack_exchange_paired.py │ │ ├── shp.py │ │ ├── summarize_from_feedback.py │ │ ├── utils.py │ │ ├── safe_rlhf.py │ │ ├── hh_rlhf.py │ │ ├── helpsteer.py │ │ └── ultrafeedback.py │ └── configs.py ├── tools │ └── merge_peft_adapter.py └── trainer │ ├── modpo_trainer.py │ ├── sft_trainer.py │ ├── rm_trainer.py │ └── dpo_trainer.py ├── scripts ├── accelerate_configs │ ├── multi_gpu.yaml │ ├── deepspeed_zero1.yaml │ ├── deepspeed_zero2.yaml │ ├── deepspeed_zero3.yaml │ └── fsdp_llama.yaml ├── modpo │ ├── beavertails │ │ ├── utils │ │ │ ├── score.sh │ │ │ ├── gen.sh │ │ │ ├── score.py │ │ │ ├── gen.py │ │ │ └── score_model.py │ │ ├── run.sh │ │ ├── README.md │ │ └── modpo.py │ └── summarize_w_length_penalty │ │ ├── run.sh │ │ ├── README.md │ │ └── modpo.py └── examples │ ├── sft │ ├── run.sh │ └── sft.py │ ├── rm │ ├── run.sh │ └── rm.py │ └── dpo │ ├── run.sh │ └── dpo.py ├── requirements.txt ├── .gitignore └── README.md /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from src.utils.utils import * 2 | # from src.utils.reward import * 3 | -------------------------------------------------------------------------------- /scripts/accelerate_configs/multi_gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: MULTI_GPU 3 | downcast_bf16: 'no' 4 | machine_rank: 0 5 | num_machines: 1 6 | num_processes: 8 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.34.1 2 | accelerate==0.23.0 3 | peft==0.5.0 4 | datasets==2.14.5 5 | sentencepiece==0.1.99 6 | trl==0.7.4 7 | wandb==0.15.12 8 | scipy==1.11.3 9 | tqdm==4.66.1 10 | tyro==0.6.0 11 | -------------------------------------------------------------------------------- /scripts/accelerate_configs/deepspeed_zero1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | downcast_bf16: 'no' 4 | deepspeed_config: 5 | deepspeed_multinode_launcher: standard 6 | zero3_init_flag: false 7 | zero_stage: 1 8 | machine_rank: 0 9 | num_machines: 1 10 | num_processes: 8 11 | -------------------------------------------------------------------------------- /scripts/accelerate_configs/deepspeed_zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | downcast_bf16: 'no' 4 | deepspeed_config: 5 | deepspeed_multinode_launcher: standard 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | machine_rank: 0 11 | num_machines: 1 12 | num_processes: 8 13 | -------------------------------------------------------------------------------- /scripts/accelerate_configs/deepspeed_zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | downcast_bf16: 'no' 4 | deepspeed_config: 5 | deepspeed_multinode_launcher: standard 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: true 9 | zero3_save_16bit_model: true 10 | zero_stage: 3 11 | machine_rank: 0 12 
| num_machines: 1 13 | num_processes: 8 14 | -------------------------------------------------------------------------------- /src/data/raw_data/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import RawDatasetPreprocessor 2 | from .hh_rlhf import HhRlhfRDP 3 | from .safe_rlhf import ( 4 | PKUSafeRlhfRDP, PKUSafeRlhf10KRDP, 5 | ) 6 | from .shp import SHPRDP 7 | from .stack_exchange_paired import StackExchangePairedRDP 8 | from .summarize_from_feedback import SummarizeFromFeedbackRDP 9 | from .helpsteer import HelpSteerRDP 10 | from .ultrafeedback import UltraFeedbackRDP 11 | -------------------------------------------------------------------------------- /scripts/accelerate_configs/fsdp_llama.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: FSDP 3 | downcast_bf16: 'no' 4 | fsdp_config: 5 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 6 | fsdp_backward_prefetch_policy: BACKWARD_PRE 7 | fsdp_offload_params: false 8 | fsdp_sharding_strategy: 1 9 | fsdp_state_dict_type: FULL_STATE_DICT 10 | fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer 11 | machine_rank: 0 12 | num_machines: 1 13 | num_processes: 2 14 | -------------------------------------------------------------------------------- /scripts/modpo/beavertails/utils/score.sh: -------------------------------------------------------------------------------- 1 | # sh scripts/modpo/beavertails/utils/score.sh 2 | 3 | # score 4 | for input_dir in $(find output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/lm -type 'd' | grep "gen"); do 5 | output_dir=${input_dir/"gen"/"score"} 6 | echo "#### scoring responses from dir ${input_dir} ####" 7 | PYTHONPATH=. python3 scripts/modpo/beavertails/utils/score.py \ 8 | --input_dir ${input_dir} \ 9 | --output_dir ${output_dir} 10 | done 11 | 12 | # aggregate scores 13 | output_path="output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/score.csv" 14 | echo "dir, mean reward, mean cost" >> ${output_path} 15 | for input_path in $(find output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/lm | grep "score/mean.csv"); do 16 | echo -n "${input_path}, " >> ${output_path} 17 | cat ${input_path} | tail -1 >> ${output_path} 18 | done 19 | -------------------------------------------------------------------------------- /scripts/modpo/beavertails/utils/gen.sh: -------------------------------------------------------------------------------- 1 | # sh scripts/modpo/beavertails/utils/gen.sh 2 | sft_model_name="PKU-Alignment/alpaca-7b-reproduced" 3 | prompt_template="BEGINNING OF CONVERSATION: USER: {raw_prompt} ASSISTANT:" 4 | dataset_name="PKU-Alignment/PKU-SafeRLHF-10K" 5 | max_length=512 6 | 7 | for dir in $(find output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/lm -mindepth 1 -maxdepth 1 -type 'd'); do 8 | adapter_model_name="${dir}/best_checkpoint" 9 | output_dir="${dir}/gen" 10 | echo "#### generating responses from ckpt ${adapter_model_name} ####" 11 | PYTHONPATH=. 
python3 scripts/modpo/beavertails/utils/gen.py \ 12 | --sft_model_name ${sft_model_name} \ 13 | --adapter_model_name ${adapter_model_name} \ 14 | --prompt_template "${prompt_template}" \ 15 | --dataset_name "${dataset_name}-safer" \ 16 | --output_dir ${output_dir} \ 17 | --max_length ${max_length} 18 | done 19 | -------------------------------------------------------------------------------- /scripts/examples/sft/run.sh: -------------------------------------------------------------------------------- 1 | # sh scripts/examples/sft/run.sh 2 | LAUNCH="accelerate launch --config_file scripts/accelerate_configs/multi_gpu.yaml --num_processes=8" 3 | 4 | base_model_name="meta-llama/Llama-2-7b-hf" 5 | dataset_name="PKU-Alignment/PKU-SafeRLHF-10K-safer" 6 | max_length=512 7 | sanity_check=False 8 | chosen_only=False 9 | output_dir="./output" 10 | 11 | # SFT 12 | sft_run_name="${dataset_name}/sft" 13 | PYTHONPATH=. $LAUNCH scripts/examples/sft/sft.py \ 14 | --base_model_name ${base_model_name} \ 15 | --dataset_name ${dataset_name} \ 16 | --sanity_check ${sanity_check} \ 17 | --max_length ${max_length} \ 18 | --chosen_only ${chosen_only} \ 19 | --training_args.output_dir "${output_dir}/${sft_run_name}" \ 20 | --training_args.run_name ${sft_run_name} \ 21 | --peft_config.target_modules q_proj k_proj v_proj o_proj 22 | 23 | sft_model_name="${output_dir}/${sft_run_name}/merged_checkpoint" 24 | 25 | # Merge SFT LoRA weights 26 | PYTHONPATH=. python src/tools/merge_peft_adapter.py \ 27 | --adapter_model_name "${output_dir}/${sft_run_name}/best_checkpoint" \ 28 | --base_model_name ${base_model_name} \ 29 | --output_name ${sft_model_name} 30 | -------------------------------------------------------------------------------- /scripts/examples/rm/run.sh: -------------------------------------------------------------------------------- 1 | # sh scripts/examples/rm/run.sh 2 | LAUNCH="accelerate launch --config_file scripts/accelerate_configs/multi_gpu.yaml --num_processes=8" 3 | 4 | base_model_name="meta-llama/Llama-2-7b-hf" 5 | dataset_name="PKU-Alignment/PKU-SafeRLHF-10K-safer" 6 | max_length=512 7 | sanity_check=False 8 | output_dir="./output" 9 | 10 | # SFT 11 | sft_run_name="${dataset_name}/sft" 12 | PYTHONPATH=. $LAUNCH scripts/examples/sft/sft.py \ 13 | --base_model_name ${base_model_name} \ 14 | --dataset_name ${dataset_name} \ 15 | --sanity_check ${sanity_check} \ 16 | --max_length ${max_length} \ 17 | --training_args.output_dir "${output_dir}/${sft_run_name}" \ 18 | --training_args.run_name ${sft_run_name} \ 19 | --peft_config.target_modules q_proj k_proj v_proj o_proj 20 | 21 | sft_model_name="${output_dir}/${sft_run_name}/merged_checkpoint" 22 | 23 | # Merge SFT LoRA weights 24 | PYTHONPATH=. python src/tools/merge_peft_adapter.py \ 25 | --adapter_model_name "${output_dir}/${sft_run_name}/best_checkpoint" \ 26 | --base_model_name ${base_model_name} \ 27 | --output_name ${sft_model_name} 28 | 29 | # RM 30 | rm_run_name="${dataset_name}/rm" 31 | PYTHONPATH=.
$LAUNCH scripts/examples/rm/rm.py \ 32 | --sft_model_name ${sft_model_name} \ 33 | --dataset_name ${dataset_name} \ 34 | --sanity_check ${sanity_check} \ 35 | --max_length ${max_length} \ 36 | --training_args.output_dir "${output_dir}/${rm_run_name}" \ 37 | --training_args.run_name ${rm_run_name} \ 38 | --peft_config.target_modules q_proj k_proj v_proj o_proj 39 | -------------------------------------------------------------------------------- /src/data/raw_data/stack_exchange_paired.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional 3 | 4 | from datasets import load_dataset 5 | 6 | from .utils import RawDatasetPreprocessor 7 | 8 | 9 | @dataclass 10 | class StackExchangePairedRDP(RawDatasetPreprocessor): 11 | path: Optional[str] = "lvwerra/stack-exchange-paired" 12 | 13 | def _get_raw_dataset(self, split): 14 | if split == "train": 15 | return load_dataset(self.path, split="train") 16 | elif split == "validation": 17 | return load_dataset(self.path, split="validation").train_test_split(test_size=0.5, seed=0)['train'] 18 | elif split == "test": 19 | return load_dataset(self.path, split="validation").train_test_split(test_size=0.5, seed=0)['test'] 20 | else: 21 | raise NotImplementedError 22 | 23 | def _dataset_to_preference_formatter(self, example) -> Dict[str, str]: 24 | return { 25 | "raw_prompt": example["question"], 26 | "prompt": self.prompt_template.format(raw_prompt=example["question"]), 27 | "chosen": example["response_j"], 28 | "rejected": example["response_k"], 29 | } 30 | 31 | 32 | if __name__ == '__main__': 33 | path = "lvwerra/stack-exchange-paired" 34 | train_dataset = StackExchangePairedRDP().get_preference_dataset(split="train") 35 | validation_dataset = StackExchangePairedRDP().get_preference_dataset(split="validation") 36 | test_dataset = StackExchangePairedRDP().get_preference_dataset(split="test") 37 | -------------------------------------------------------------------------------- /src/data/raw_data/shp.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional 3 | 4 | from datasets import load_dataset 5 | 6 | from .utils import RawDatasetPreprocessor 7 | 8 | @dataclass 9 | class SHPRDP(RawDatasetPreprocessor): 10 | path: Optional[str] = "stanfordnlp/SHP" 11 | """ 12 | labels: the preference label -- it is 1 if A is preferred to B; 0 if B is preferred to A. 
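For example, when labels == 1, human_ref_A is used as "chosen" and human_ref_B as "rejected" in _dataset_to_preference_formatter below.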
13 | """ 14 | def _get_raw_dataset(self, split): 15 | if split == "train": 16 | return load_dataset(self.path, split="train") 17 | elif split == "validation": 18 | return load_dataset(self.path, split="validation") 19 | elif split == "test": 20 | return load_dataset(self.path, split="test") 21 | else: 22 | raise NotImplementedError 23 | 24 | def _dataset_to_preference_formatter(self, example) -> Dict[str, str]: 25 | return { 26 | "raw_prompt": example["history"], 27 | "prompt": self.prompt_template.format(raw_prompt=example["history"]), 28 | "chosen": example["human_ref_A"] if example["labels"] == 1 else example["human_ref_B"], 29 | "rejected": example["human_ref_B"] if example["labels"] == 1 else example["human_ref_A"], 30 | } 31 | 32 | 33 | if __name__ == '__main__': 34 | train_dataset = SHPRDP().get_preference_dataset(split="train") 35 | validation_dataset = SHPRDP().get_preference_dataset(split="validation") 36 | test_dataset = SHPRDP().get_preference_dataset(split="test") 37 | 38 | sft_train_dataset = SHPRDP().get_sft_dataset(split="train") 39 | -------------------------------------------------------------------------------- /scripts/examples/dpo/run.sh: -------------------------------------------------------------------------------- 1 | # sh scripts/examples/dpo/run.sh 2 | LAUNCH="accelerate launch --config_file scripts/accelerate_configs/multi_gpu.yaml --num_processes=8" 3 | 4 | base_model_name="meta-llama/Llama-2-7b-hf" 5 | dataset_name="PKU-Alignment/PKU-SafeRLHF-10K-safer" 6 | max_length=512 7 | chosen_only=False 8 | sanity_check=False 9 | output_dir="./output" 10 | 11 | # SFT 12 | sft_run_name="${dataset_name}/sft" 13 | PYTHONPATH=. $LAUNCH scripts/examples/sft/sft.py \ 14 | --base_model_name ${base_model_name} \ 15 | --dataset_name ${dataset_name} \ 16 | --sanity_check ${sanity_check} \ 17 | --max_length ${max_length} \ 18 | --chosen_only ${chosen_only} \ 19 | --training_args.output_dir "${output_dir}/${sft_run_name}" \ 20 | --training_args.run_name ${sft_run_name} \ 21 | --peft_config.target_modules q_proj k_proj v_proj o_proj 22 | 23 | sft_model_name="${output_dir}/${sft_run_name}/merged_checkpoint" 24 | 25 | # Merge SFT LoRA weights 26 | PYTHONPATH=. python src/tools/merge_peft_adapter.py \ 27 | --adapter_model_name "${output_dir}/${sft_run_name}/best_checkpoint" \ 28 | --base_model_name ${base_model_name} \ 29 | --output_name ${sft_model_name} 30 | 31 | # DPO 32 | dpo_run_name="${dataset_name}/dpo" 33 | PYTHONPATH=. 
$LAUNCH scripts/examples/dpo/dpo.py \ 34 | --sft_model_name ${sft_model_name} \ 35 | --dataset_name ${dataset_name} \ 36 | --sanity_check ${sanity_check} \ 37 | --max_length ${max_length} \ 38 | --training_args.output_dir "${output_dir}/${dpo_run_name}" \ 39 | --training_args.run_name ${dpo_run_name} \ 40 | --peft_config.target_modules q_proj k_proj v_proj o_proj 41 | -------------------------------------------------------------------------------- /src/data/raw_data/summarize_from_feedback.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional 3 | 4 | from datasets import load_dataset 5 | 6 | from .utils import RawDatasetPreprocessor 7 | 8 | 9 | @dataclass 10 | class SummarizeFromFeedbackRDP(RawDatasetPreprocessor): 11 | path: Optional[str] = "openai/summarize_from_feedback" 12 | 13 | def _get_raw_dataset(self, split): 14 | if split == "train": 15 | return load_dataset(self.path, 'comparisons', split="train") 16 | elif split == "validation": 17 | return load_dataset(self.path, 'comparisons', split="validation").train_test_split(test_size=0.5, seed=0)['train'] 18 | elif split == "test": 19 | return load_dataset(self.path, 'comparisons', split="validation").train_test_split(test_size=0.5, seed=0)['test'] 20 | else: 21 | raise NotImplementedError 22 | 23 | def _dataset_to_preference_formatter(self, example) -> Dict[str, str]: 24 | return { 25 | "raw_prompt": example["info"]["post"], 26 | "prompt": self.prompt_template.format(raw_prompt=example["info"]["post"]), 27 | "chosen": example["summaries"][example["choice"]]["text"], 28 | "rejected": example["summaries"][1-example["choice"]]["text"], 29 | } 30 | 31 | 32 | if __name__ == '__main__': 33 | train_dataset = SummarizeFromFeedbackRDP().get_preference_dataset(split="train") 34 | validation_dataset = SummarizeFromFeedbackRDP().get_preference_dataset(split="validation") 35 | test_dataset = SummarizeFromFeedbackRDP().get_preference_dataset(split="test") 36 | -------------------------------------------------------------------------------- /src/data/raw_data/utils.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Dict, Optional 4 | 5 | from datasets import Dataset, concatenate_datasets 6 | 7 | from src.utils import print_local_main 8 | 9 | DEFAULT_PROMPT_TEMPLATE = "\n\nHuman:\n{raw_prompt}\n\nAssistant:\n" 10 | 11 | @dataclass 12 | class RawDatasetPreprocessor(ABC): 13 | path: Optional[str] = None 14 | prompt_template: Optional[str] = DEFAULT_PROMPT_TEMPLATE 15 | sanity_check: Optional[bool] = False 16 | num_proc: Optional[int] = 4 17 | 18 | @abstractmethod 19 | def _get_raw_dataset(self, split) -> Dataset: 20 | raise NotImplementedError 21 | 22 | @abstractmethod 23 | def _dataset_to_preference_formatter(self, example) -> Dict[str, str]: 24 | # return { 25 | # "raw_prompt": str, # optional, useful for generation 26 | # "prompt": str, 27 | # "chosen": str, 28 | # "rejected": str, 29 | # } 30 | raise NotImplementedError 31 | 32 | def get_preference_dataset(self, split) -> Dataset: 33 | """ 34 | return a dataset of texts with three keys "prompt", "chosen", "rejected", ("raw_prompt", optional) 35 | """ 36 | dataset = self._get_raw_dataset(split) 37 | if self.sanity_check: dataset = dataset.select(range(min(len(dataset), 100))) 38 | print_local_main("mapping dataset to standard format...") 39 | return 
dataset.map(self._dataset_to_preference_formatter, num_proc=self.num_proc, remove_columns=dataset.column_names) 40 | 41 | def get_sft_dataset(self, split, chosen_only=True): 42 | """ 43 | return a dataset of texts with two keys "prompt", "response", ("raw_prompt", optional) 44 | """ 45 | print_local_main("mapping preference to sft...") 46 | dataset = self.get_preference_dataset(split) 47 | chosen_only_dataset = dataset.remove_columns("rejected").rename_column("chosen", "response") 48 | if chosen_only: 49 | return chosen_only_dataset 50 | rejected_only_dataset = dataset.remove_columns("chosen").rename_column("rejected", "response") 51 | return concatenate_datasets([chosen_only_dataset, rejected_only_dataset]) 52 | -------------------------------------------------------------------------------- /src/data/raw_data/safe_rlhf.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from abc import ABC 3 | from typing import Dict, Literal, Optional 4 | 5 | from datasets import load_dataset 6 | 7 | from .utils import RawDatasetPreprocessor 8 | 9 | @dataclass 10 | class PKUSafeRlhfRDPBase(RawDatasetPreprocessor, ABC): 11 | dimension: Literal["safer", "better"] = "better" 12 | 13 | def _dataset_to_preference_formatter(self, example) -> Dict[str, str]: 14 | chosen_idx = example[f"{self.dimension}_response_id"] 15 | return { 16 | "raw_prompt": example["prompt"], 17 | "prompt": self.prompt_template.format(raw_prompt=example["prompt"]), 18 | "chosen": example[f"response_{chosen_idx}"], 19 | "rejected": example[f"response_{1-chosen_idx}"], 20 | } 21 | 22 | @dataclass 23 | class PKUSafeRlhfRDP(PKUSafeRlhfRDPBase): 24 | path: Optional[str] = "PKU-Alignment/PKU-SafeRLHF" 25 | 26 | def _get_raw_dataset(self, split): 27 | if split == "train": 28 | return load_dataset(self.path, split="train").train_test_split(test_size=0.1, seed=0)["train"] 29 | elif split == "validation": 30 | return load_dataset(self.path, split="train").train_test_split(test_size=0.1, seed=0)["test"] 31 | elif split == "test": 32 | return load_dataset(self.path, split="test") 33 | else: 34 | raise NotImplementedError 35 | 36 | 37 | @dataclass 38 | class PKUSafeRlhf10KRDP(PKUSafeRlhfRDPBase): 39 | path: Optional[str] = "PKU-Alignment/PKU-SafeRLHF-10K" 40 | 41 | def _get_raw_dataset(self, split): 42 | if split == "train": 43 | return load_dataset(self.path, split="train").train_test_split(test_size=0.1, seed=0)["train"] 44 | elif split == "validation": 45 | return load_dataset(self.path, split="train").train_test_split(test_size=0.1, seed=0)["test"] 46 | elif split == "test": 47 | raise NotImplementedError("PKU-Alignment/PKU-SafeRLHF-10K is for development, no test set available.") 48 | else: 49 | raise NotImplementedError 50 | 51 | 52 | if __name__ == '__main__': 53 | safer10k_train_dataset = PKUSafeRlhf10KRDP(dimension="safer").get_preference_dataset(split="train") 54 | better10k_train_dataset = PKUSafeRlhf10KRDP(dimension="better").get_preference_dataset(split="train") 55 | breakpoint() 56 | -------------------------------------------------------------------------------- /scripts/modpo/beavertails/utils/score.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | import os 4 | import math 5 | 6 | import torch 7 | import tyro 8 | import tqdm 9 | from transformers import AutoTokenizer 10 | from datasets import Dataset, load_dataset 11 | 12 | from src.utils 
import ( 13 | disable_progress_bar_non_local_main 14 | ) 15 | from scripts.modpo.beavertails.utils.score_model import LlamaForScore 16 | 17 | disable_progress_bar_non_local_main() 18 | 19 | 20 | @dataclass 21 | class ScriptArguments: 22 | 23 | input_dir: Optional[str] = field(default=None, metadata={"help": "output path for generations"}) 24 | output_dir: Optional[str] = field(default=None, metadata={"help": "output path for generations"}) 25 | 26 | 27 | if __name__ == "__main__": 28 | 29 | script_args = tyro.cli(ScriptArguments) 30 | 31 | reward = LlamaForScore.from_pretrained('PKU-Alignment/beaver-7b-v1.0-reward', torch_dtype=torch.bfloat16, device_map='auto') 32 | cost = LlamaForScore.from_pretrained('PKU-Alignment/beaver-7b-v1.0-cost', torch_dtype=torch.bfloat16, device_map='auto') 33 | tokenizer = AutoTokenizer.from_pretrained('PKU-Alignment/beaver-7b-v1.0-reward') 34 | 35 | generation = load_dataset(script_args.input_dir, split="train") 36 | 37 | results = [] 38 | with torch.no_grad(): 39 | for prompt_response in tqdm.tqdm(generation['prompt_response']): 40 | input = tokenizer(prompt_response, return_tensors="pt") 41 | reward_output = reward(input["input_ids"].cuda(), input["attention_mask"].cuda()) 42 | cost_output = cost(input["input_ids"].cuda(), input["attention_mask"].cuda()) 43 | results.append({ 44 | "prompt_response": prompt_response, 45 | "reward": reward_output.end_scores.item(), 46 | "cost": cost_output.end_scores.item(), 47 | }) 48 | 49 | # raw 50 | dataset = Dataset.from_list(results) 51 | dataset.to_json(os.path.join(script_args.output_dir, "raw.jsonl")) 52 | 53 | # mean 54 | rewards = [result["reward"] for result in results] 55 | costs = [result["cost"] for result in results] 56 | mean_reward = sum(rewards) / len(rewards) 57 | mean_cost = sum(costs) / len(costs) 58 | with open(os.path.join(script_args.output_dir, "mean.csv"), "w") as f: 59 | f.write("mean reward, mean cost\n") 60 | f.write(f"{mean_reward}, {mean_cost}\n") 61 | -------------------------------------------------------------------------------- /src/tools/merge_peft_adapter.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | import torch 5 | from peft import PeftConfig, PeftModel 6 | from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser 7 | 8 | 9 | @dataclass 10 | class ScriptArguments: 11 | """ 12 | The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the 13 | merged model. 
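Example invocation (the adapter and output paths below are placeholders; scripts/examples/sft/run.sh shows a real call after SFT training):

        PYTHONPATH=. python src/tools/merge_peft_adapter.py \
            --adapter_model_name <run_output_dir>/best_checkpoint \
            --base_model_name meta-llama/Llama-2-7b-hf \
            --output_name <run_output_dir>/merged_checkpoint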
14 | """ 15 | 16 | adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"}) 17 | base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"}) 18 | output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"}) 19 | dtype: Optional[str] = field(default="fp16", metadata={"help": "dtype"}) 20 | push_to_hub: Optional[bool] = field(default=False, metadata={"help": "the merged model name"}) 21 | 22 | 23 | parser = HfArgumentParser(ScriptArguments) 24 | script_args = parser.parse_args_into_dataclasses()[0] 25 | assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge" 26 | assert script_args.base_model_name is not None, "please provide the name of the Base model" 27 | assert script_args.output_name is not None, "please provide the output name of the merged model" 28 | 29 | str2dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float} 30 | 31 | peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name) 32 | if peft_config.task_type == "SEQ_CLS": 33 | # The sequence classification task is used for the reward model in PPO 34 | model = AutoModelForSequenceClassification.from_pretrained( 35 | script_args.base_model_name, num_labels=1, torch_dtype=str2dtype[script_args.dtype] 36 | ) 37 | else: 38 | model = AutoModelForCausalLM.from_pretrained( 39 | script_args.base_model_name, return_dict=True, torch_dtype=str2dtype[script_args.dtype] 40 | ) 41 | 42 | tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name) 43 | 44 | # Load the PEFT model 45 | model = PeftModel.from_pretrained(model, script_args.adapter_model_name) 46 | model.eval() 47 | 48 | model = model.merge_and_unload() 49 | 50 | model.save_pretrained(f"{script_args.output_name}") 51 | tokenizer.save_pretrained(f"{script_args.output_name}") 52 | if script_args.push_to_hub: 53 | model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False) 54 | -------------------------------------------------------------------------------- /src/data/configs.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Dict 3 | 4 | from src.data.raw_data.helpsteer import HelpSteerRDP 5 | 6 | from .raw_data import ( 7 | RawDatasetPreprocessor, 8 | HhRlhfRDP, 9 | PKUSafeRlhfRDP, PKUSafeRlhf10KRDP, 10 | SHPRDP, 11 | StackExchangePairedRDP, 12 | SummarizeFromFeedbackRDP, 13 | HelpSteerRDP, 14 | UltraFeedbackRDP, 15 | ) 16 | from .raw_data.utils import DEFAULT_PROMPT_TEMPLATE 17 | 18 | 19 | REAL_DATASET_CONFIGS: Dict[str, RawDatasetPreprocessor] = { 20 | ##### hh-rlhf (https://huggingface.co/datasets/Anthropic/hh-rlhf) ##### 21 | "Anthropic/hh-rlhf": HhRlhfRDP, 22 | 23 | ##### PKU-SafeRLHF (https://huggingface.co/datasets/PKU-Alignment/PKU-SafeRLHF) ##### 24 | **{ 25 | f"PKU-Alignment/PKU-SafeRLHF-{dimension}": partial(PKUSafeRlhfRDP, dimension=dimension) 26 | for dimension in ["safer", "better"] 27 | }, 28 | **{ 29 | f"PKU-Alignment/PKU-SafeRLHF-10K-{dimension}": partial(PKUSafeRlhf10KRDP, dimension=dimension) 30 | for dimension in ["safer", "better"] 31 | }, 32 | 33 | ##### stack-exchange-paired (https://huggingface.co/datasets/lvwerra/stack-exchange-paired) ##### 34 | "lvwerra/stack-exchange-paired": StackExchangePairedRDP, 35 | 36 | ##### SHP (https://huggingface.co/datasets/stanfordnlp/SHP) ##### 37 | "stanfordnlp/SHP": SHPRDP, 38 | 39 | ##### summarize_from_feedback 
(https://huggingface.co/datasets/openai/summarize_from_feedback) ##### 40 | "openai/summarize_from_feedback": SummarizeFromFeedbackRDP, 41 | 42 | ##### UltraFeedback (https://huggingface.co/datasets/openbmb/UltraFeedback) ##### 43 | "OpenBMB/UltraFeedback": UltraFeedbackRDP, 44 | **{ 45 | f"OpenBMB/UltraFeedback-{dimension}": partial(UltraFeedbackRDP, dimension=dimension) 46 | for dimension in ["overall", "instruction_following", "honesty", "truthfulness", "helpfulness"] 47 | }, 48 | 49 | ##### HelpSteer (https://huggingface.co/datasets/nvidia/HelpSteer) ##### 50 | "nvidia/HelpSteer": HelpSteerRDP, 51 | **{ 52 | f"nvidia/HelpSteer-pairwise-{dimension}": partial(HelpSteerRDP, dimension=dimension) 53 | for dimension in ["overall", "helpfulness", "correctness", "coherence", "complexity", "verbosity"] 54 | }, 55 | } 56 | 57 | 58 | # !WARNING: Synthetic datasets are WIP. These configs are just placeholders 59 | SYNTHETIC_DATASET_CONFIGS = { 60 | 61 | } 62 | 63 | 64 | # MERGE two dicts 65 | DATASET_CONFIGS = {**REAL_DATASET_CONFIGS, **SYNTHETIC_DATASET_CONFIGS} 66 | -------------------------------------------------------------------------------- /src/data/raw_data/hh_rlhf.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py#L82 2 | from dataclasses import dataclass 3 | from typing import Dict, Optional 4 | 5 | from datasets import load_dataset 6 | 7 | from .utils import RawDatasetPreprocessor 8 | 9 | 10 | def preprocess_anthropic_prompt_and_response(prompt_and_response): 11 | prompt_and_response = prompt_and_response.replace("\n\nHuman: ", "\n\nHuman:\n") 12 | prompt_and_response = prompt_and_response.replace("\n\nAssistant: ", "\n\nAssistant:\n") 13 | return prompt_and_response 14 | 15 | def extract_anthropic_prompt_and_response(prompt_and_response): 16 | """Extract the anthropic prompt from a prompt and response pair.""" 17 | search_term = "\n\nAssistant:\n" 18 | search_term_idx = prompt_and_response.rfind(search_term) 19 | assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'" 20 | return prompt_and_response[: search_term_idx + len(search_term)] 21 | 22 | 23 | @dataclass 24 | class HhRlhfRDP(RawDatasetPreprocessor): 25 | path: Optional[str] = "Anthropic/hh-rlhf" 26 | 27 | # def __post_init__(self): 28 | # assert self.prompt_template == "\n\nHuman: {prompt}\n\nAssistant:" 29 | 30 | def _get_raw_dataset(self, split): 31 | if split == "train": 32 | return load_dataset(self.path, split="train").train_test_split(test_size=0.1, seed=0)["train"] 33 | elif split == "validation": 34 | return load_dataset(self.path, split="train").train_test_split(test_size=0.1, seed=0)["test"] 35 | elif split == "test": 36 | return load_dataset(self.path, split="test") 37 | else: 38 | raise NotImplementedError 39 | 40 | def _dataset_to_preference_formatter(self, example) -> Dict[str, str]: 41 | example["chosen"] = preprocess_anthropic_prompt_and_response(example["chosen"]) 42 | example["rejected"] = preprocess_anthropic_prompt_and_response(example["rejected"]) 43 | prompt = extract_anthropic_prompt_and_response(example["chosen"]) 44 | return { 45 | "prompt": prompt, 46 | "chosen": example["chosen"][len(prompt) :], 47 | "rejected": example["rejected"][len(prompt) :], 48 | } 49 | 50 | 51 | if __name__ == "__main__": 52 | train_dataset = HhRlhfRDP(num_proc=1).get_preference_dataset(split="train") 53 | validation_dataset = 
HhRlhfRDP(num_proc=1).get_preference_dataset(split="validation") 54 | test_dataset = HhRlhfRDP(num_proc=1).get_preference_dataset(split="test") 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .idea 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | applications/DeepSpeed-Chat/data 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | cache/ 134 | .vscode/ 135 | package/ 136 | output*/ 137 | tmp/ 138 | dataset/ 139 | wandb/ 140 | -------------------------------------------------------------------------------- /scripts/modpo/beavertails/run.sh: -------------------------------------------------------------------------------- 1 | # sh scripts/modpo/beavertails/run.sh 2 | LAUNCH="accelerate launch --config_file scripts/accelerate_configs/multi_gpu.yaml --num_processes=8" 3 | 4 | sft_model_name="PKU-Alignment/alpaca-7b-reproduced" 5 | prompt_template="BEGINNING OF CONVERSATION: USER: {raw_prompt} ASSISTANT:" 6 | dataset_name="PKU-Alignment/PKU-SafeRLHF-10K" 7 | sanity_check=False 8 | output_dir="./output" 9 | max_length=512 10 | per_device_train_batch_size=6 11 | per_device_eval_batch_size=6 12 | gradient_accumulation_steps=2 13 | learning_rate=5e-4 14 | 15 | # Reward Modeling: Run DPO on safe preferences to train a safe reward model that encourages safe response 16 | rm_run_name="${dataset_name}/modpo/rm/safer" 17 | PYTHONPATH=. $LAUNCH scripts/examples/dpo/dpo.py \ 18 | --sft_model_name ${sft_model_name} \ 19 | --prompt_template "${prompt_template}" \ 20 | --dataset_name "${dataset_name}-safer" \ 21 | --sanity_check ${sanity_check} \ 22 | --max_length ${max_length} \ 23 | --training_args.output_dir "${output_dir}/${rm_run_name}" \ 24 | --training_args.run_name ${rm_run_name} \ 25 | --training_args.per_device_train_batch_size ${per_device_train_batch_size} \ 26 | --training_args.per_device_eval_batch_size ${per_device_eval_batch_size} \ 27 | --training_args.gradient_accumulation_steps ${gradient_accumulation_steps} \ 28 | --training_args.learning_rate ${learning_rate} \ 29 | --peft_config.r 64 \ 30 | --peft_config.target_modules q_proj k_proj v_proj o_proj \ 31 | --peft_config.lora_alpha 1 \ 32 | --peft_config.lora_dropout 0 33 | 34 | # Language Modeling: Run MODPO on helpful preferences, with the safe reward as margin, to train language models that are both helpful and safe 35 | # r = (w)r_better + (1-w)r_safer 36 | for w in 0.1 0.5 0.9; do 37 | lm_run_name="${dataset_name}/modpo/lm/($w)better+(1-$w)safer" 38 | PYTHONPATH=. 
$LAUNCH scripts/modpo/beavertails/modpo.py \ 39 | --sft_model_name ${sft_model_name} \ 40 | --margin_reward_model_name "${output_dir}/${rm_run_name}/best_checkpoint" \ 41 | --prompt_template "${prompt_template}" \ 42 | --dataset_name "${dataset_name}-better" \ 43 | --sanity_check ${sanity_check} \ 44 | --w ${w} \ 45 | --max_length ${max_length} \ 46 | --training_args.output_dir "${output_dir}/${lm_run_name}" \ 47 | --training_args.run_name ${lm_run_name} \ 48 | --training_args.per_device_train_batch_size ${per_device_train_batch_size} \ 49 | --training_args.per_device_eval_batch_size ${per_device_eval_batch_size} \ 50 | --training_args.gradient_accumulation_steps ${gradient_accumulation_steps} \ 51 | --training_args.learning_rate ${learning_rate} \ 52 | --peft_config.r 64 \ 53 | --peft_config.target_modules q_proj k_proj v_proj o_proj \ 54 | --peft_config.lora_alpha 1 \ 55 | --peft_config.lora_dropout 0 56 | done 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MODPO: Multi-Objective Direct Preference Optimization 2 | 3 | Code release for [Beyond One-Preference-Fits-All Alignment: Multi-Objective Direct Preference Optimization](https://arxiv.org/pdf/2310.03708.pdf). 4 | 5 | TL;DR: Compared to [DPO loss](https://github.com/ZHZisZZ/modpo/blob/main/src/trainer/dpo_trainer.py#L413), [MODPO loss](https://github.com/ZHZisZZ/modpo/blob/main/src/trainer/modpo_trainer.py#L142) includes [a margin](https://github.com/ZHZisZZ/modpo/blob/main/src/trainer/modpo_trainer.py#L151-L152) to steer language models by multiple objectives. 6 | 7 | ## Installation 8 | 9 | ```bash 10 | conda create -n modpo python=3.10 11 | conda activate modpo 12 | pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118 13 | pip install -r requirements.txt 14 | # (optional) pip install flash-attn==2.3.2 --no-build-isolation 15 | ``` 16 | 17 | ## Running MODPO 18 | 19 | This repository includes two MODPO examples: 20 | 21 | - Safety alignment ([`scripts/modpo/beavertails`](https://github.com/ZHZisZZ/modpo/blob/main/scripts/modpo/beavertails)): Balances different values such as safety vs. helpfulness. 22 | 23 | - Summarization with length penalty ([`scripts/modpo/summarize_w_length_penalty`](https://github.com/ZHZisZZ/modpo/blob/main/scripts/modpo/summarize_w_length_penalty)): Reduces length bias (verbosity) in summarization. 24 | 25 | ## Other examples 26 | 27 | This repository also contains other off-the-shelf tuning recipes: 28 | 29 | - SFT (Supervised Fine-tuning): [`scripts/examples/sft/run.sh`](https://github.com/ZHZisZZ/modpo/blob/main/scripts/examples/sft/run.sh) 30 | - RM (Reward Modeling): [`scripts/examples/rm/run.sh`](https://github.com/ZHZisZZ/modpo/blob/main/scripts/examples/rm/run.sh) 31 | - DPO (Direct Preference Optimization): [`scripts/examples/dpo/run.sh`](https://github.com/ZHZisZZ/modpo/blob/main/scripts/examples/dpo/run.sh) 32 | 33 | To implement new alignment algorithms, please add new trainers at [`src/trainer`](https://github.com/ZHZisZZ/modpo/blob/main/src/trainer). 34 | 35 | 36 | ## Customized datasets 37 | 38 | For supported datasets, refer to [`REAL_DATASET_CONFIGS(src/data/configs.py)`](https://github.com/ZHZisZZ/modpo/blob/main/src/data/configs.py#L19). 
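Each entry in `REAL_DATASET_CONFIGS` maps a dataset name to a `RawDatasetPreprocessor` subclass that converts raw examples into the standard `prompt`/`chosen`/`rejected` preference format. The sketch below shows roughly what such a preprocessor looks like; the dataset name `my-org/my-preferences` and its columns `question`, `good_response`, `bad_response` are hypothetical placeholders, not a supported dataset:

```python
from dataclasses import dataclass
from typing import Dict, Optional

from datasets import load_dataset

from src.data.raw_data import RawDatasetPreprocessor


@dataclass
class MyPreferencesRDP(RawDatasetPreprocessor):
    # hypothetical dataset name; replace with your own Hub dataset or local path
    path: Optional[str] = "my-org/my-preferences"

    def _get_raw_dataset(self, split):
        # reuse the dataset's own splits; create one with train_test_split if it has none
        if split in ("train", "validation", "test"):
            return load_dataset(self.path, split=split)
        raise NotImplementedError

    def _dataset_to_preference_formatter(self, example) -> Dict[str, str]:
        # map the dataset's own columns onto the standard preference keys
        return {
            "raw_prompt": example["question"],
            "prompt": self.prompt_template.format(raw_prompt=example["question"]),
            "chosen": example["good_response"],
            "rejected": example["bad_response"],
        }
```

Registering the class under a new key in `REAL_DATASET_CONFIGS` (e.g. `"my-org/my-preferences": MyPreferencesRDP`) then makes it selectable via `--dataset_name`.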
39 | To train on your datasets, add them under [`src/data/raw_data`](https://github.com/ZHZisZZ/modpo/blob/main/src/data/raw_data) and modify [`REAL_DATASET_CONFIGS(src/data/configs.py)`](https://github.com/ZHZisZZ/modpo/blob/main/src/data/configs.py#L19) accordingly. Please see [`src/data/raw_data/shp`](https://github.com/ZHZisZZ/modpo/blob/main/src/data/raw_data/shp.py) for an example. 40 | 41 | ## Reference 42 | 43 | ``` 44 | @inproceedings{zhou2024beyond, 45 | title={Beyond one-preference-fits-all alignment: Multi-objective direct preference optimization}, 46 | author={Zhou, Zhanhui and Liu, Jie and Shao, Jing and Yue, Xiangyu and Yang, Chao and Ouyang, Wanli and Qiao, Yu}, 47 | booktitle={Findings of the Association for Computational Linguistics ACL 2024}, 48 | pages={10586--10613}, 49 | year={2024} 50 | } 51 | ``` 52 | 53 | -------------------------------------------------------------------------------- /scripts/modpo/summarize_w_length_penalty/run.sh: -------------------------------------------------------------------------------- 1 | # sh scripts/modpo/summarize_w_length_penalty/run.sh 2 | LAUNCH="accelerate launch --config_file scripts/accelerate_configs/multi_gpu.yaml --num_processes=8" 3 | 4 | base_model_name="meta-llama/Llama-2-7b-hf" 5 | prompt_template="\n\nParagraph:\n{raw_prompt}\n\nTL;DR:\n" 6 | dataset_name="openai/summarize_from_feedback" 7 | sanity_check=False 8 | output_dir="./output" 9 | max_length=512 10 | per_device_train_batch_size=3 11 | per_device_eval_batch_size=6 12 | gradient_accumulation_steps=4 13 | learning_rate=5e-4 14 | 15 | # Supervised Fine-Tuning: Run SFT with LoRA 16 | sft_run_name="${dataset_name}/sft" 17 | chosen_only=True 18 | PYTHONPATH=. $LAUNCH scripts/examples/sft/sft.py \ 19 | --base_model_name ${base_model_name} \ 20 | --prompt_template "${prompt_template}" \ 21 | --dataset_name ${dataset_name} \ 22 | --sanity_check ${sanity_check} \ 23 | --max_length ${max_length} \ 24 | --chosen_only ${chosen_only} \ 25 | --training_args.output_dir "${output_dir}/${sft_run_name}" \ 26 | --training_args.run_name ${sft_run_name} \ 27 | --training_args.num_train_epochs 1 \ 28 | --training_args.per_device_train_batch_size ${per_device_train_batch_size} \ 29 | --training_args.per_device_eval_batch_size ${per_device_eval_batch_size} \ 30 | --training_args.gradient_accumulation_steps ${gradient_accumulation_steps} \ 31 | --training_args.learning_rate ${learning_rate} \ 32 | --peft_config.r 64 \ 33 | --peft_config.target_modules q_proj k_proj v_proj o_proj \ 34 | --peft_config.lora_alpha 1 \ 35 | --peft_config.lora_dropout 0.05 36 | 37 | # Supervised Fine-Tuning: Merge SFT LoRA weights 38 | PYTHONPATH=. python src/tools/merge_peft_adapter.py \ 39 | --adapter_model_name "${output_dir}/${sft_run_name}/best_checkpoint" \ 40 | --base_model_name ${base_model_name} \ 41 | --dtype bf16 \ 42 | --output_name "${output_dir}/${sft_run_name}/merged_checkpoint" 43 | 44 | # Language Model: Run MODPO on human preferences, with length penalty as margin 45 | # r = r_prefernce + w*length_penalty 46 | w=0.1 47 | lm_run_name="${dataset_name}/modpo/lm/preference+($w)*length_penalty" 48 | PYTHONPATH=. 
$LAUNCH scripts/modpo/summarize_w_length_penalty/modpo.py \ 49 | --sft_model_name "${output_dir}/${sft_run_name}/best_checkpoint" \ 50 | --prompt_template "${prompt_template}" \ 51 | --dataset_name "${dataset_name}" \ 52 | --sanity_check ${sanity_check} \ 53 | --w ${w} \ 54 | --max_length ${max_length} \ 55 | --training_args.output_dir "${output_dir}/${lm_run_name}" \ 56 | --training_args.run_name ${lm_run_name} \ 57 | --training_args.num_train_epochs 2 \ 58 | --training_args.per_device_train_batch_size ${per_device_train_batch_size} \ 59 | --training_args.per_device_eval_batch_size ${per_device_eval_batch_size} \ 60 | --training_args.gradient_accumulation_steps ${gradient_accumulation_steps} \ 61 | --training_args.learning_rate ${learning_rate} \ 62 | --peft_config.r 64 \ 63 | --peft_config.target_modules q_proj k_proj v_proj o_proj \ 64 | --peft_config.lora_alpha 1 \ 65 | --peft_config.lora_dropout 0 66 | -------------------------------------------------------------------------------- /scripts/modpo/summarize_w_length_penalty/README.md: -------------------------------------------------------------------------------- 1 | # Summarize with Length Penalty 2 | 3 | This directory illustrates how we use MODPO to reduce length bias (reduce response verbosity) in summarization on the [TL;DR dataset](https://huggingface.co/datasets/openai/summarize_from_feedback). 4 | This is a simplified version of the long-form QA experiments from the [MODPO paper](https://arxiv.org/pdf/2310.03708.pdf). 5 | 6 | ## Automatic Pipeline 7 | 8 | To optimize a language model for human preferences with length penalty, run: 9 | ``` 10 | sh scripts/modpo/summarize_w_length_penalty/run.sh 11 | ``` 12 | 13 | 14 | ## Annotated Pipeline 15 | 16 | 1. Supervised Fine-Tuning: 17 | 1. Run SFT with LoRA: 18 | ```sh 19 | PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/multi_gpu.yaml --num_processes=8 \ 20 | scripts/examples/sft/sft.py \ 21 | --base_model_name "meta-llama/Llama-2-7b-hf" \ 22 | --prompt_template "\n\nParagraph:\n{raw_prompt}\n\nTL;DR:\n" \ 23 | --dataset_name "openai/summarize_from_feedback" \ 24 | --max_length 512 \ 25 | --chosen_only True \ 26 | --training_args.output_dir "./output/openai/summarize_from_feedback/sft" \ 27 | --training_args.run_name "openai/summarize_from_feedback/sft" \ 28 | --training_args.num_train_epochs 1 \ 29 | --training_args.per_device_train_batch_size 3 \ 30 | --training_args.per_device_eval_batch_size 6 \ 31 | --training_args.gradient_accumulation_steps 4 \ 32 | --training_args.learning_rate 5e-4 \ 33 | --peft_config.r 64 \ 34 | --peft_config.target_modules q_proj k_proj v_proj o_proj \ 35 | --peft_config.lora_alpha 1 \ 36 | --peft_config.lora_dropout 0.05 37 | ``` 38 | 2. Merge SFT LoRA weights: 39 | ```sh 40 | PYTHONPATH=. python src/tools/merge_peft_adapter.py \ 41 | --adapter_model_name "./output/openai/summarize_from_feedback/sft/best_checkpoint" \ 42 | --base_model_name "meta-llama/Llama-2-7b-hf" \ 43 | --dtype bf16 \ 44 | --output_name "./output/openai/summarize_from_feedback/sft/merged_checkpoint" 45 | ``` 46 | 47 | 2. Language Modeling: Run MODPO on human preferences, with length penalty as margin: 48 | ```sh 49 | PYTHONPATH=. 
accelerate launch --config_file scripts/accelerate_configs/multi_gpu.yaml --num_processes=8 \ 50 | scripts/modpo/summarize_w_length_penalty/modpo.py \ 51 | --sft_model_name "./output/openai/summarize_from_feedback/sft/merged_checkpoint" \ 52 | --prompt_template "\n\nParagraph:\n{raw_prompt}\n\nTL;DR:\n" \ 53 | --dataset_name "openai/summarize_from_feedback" \ 54 | --w 0.1 \ 55 | --max_length 512 \ 56 | --training_args.output_dir "./output/openai/summarize_from_feedback/modpo/lm/preference+(0.1)*length_penalty" \ 57 | --training_args.run_name "openai/summarize_from_feedback/modpo/lm/preference+(0.1)*length_penalty" \ 58 | --training_args.num_train_epochs 2 \ 59 | --training_args.per_device_train_batch_size 3 \ 60 | --training_args.per_device_eval_batch_size 6 \ 61 | --training_args.gradient_accumulation_steps 4 \ 62 | --training_args.learning_rate 5e-4 \ 63 | --peft_config.r 64 \ 64 | --peft_config.target_modules q_proj k_proj v_proj o_proj \ 65 | --peft_config.lora_alpha 1 \ 66 | --peft_config.lora_dropout 0 67 | ``` -------------------------------------------------------------------------------- /src/utils/reward.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractclassmethod 2 | from dataclasses import dataclass, asdict 3 | from typing import Any, Text, List, Dict, Optional 4 | 5 | import torch 6 | from accelerate import Accelerator 7 | from transformers import PreTrainedTokenizerBase, PreTrainedModel 8 | 9 | from src.utils import get_batch_logps, prepare_input 10 | from src.data.configs import DEFAULT_PROMPT_TEMPLATE 11 | 12 | 13 | @dataclass 14 | class RewardWrapperInput: 15 | raw_prompt: List[str] 16 | response: List[str] 17 | 18 | 19 | @dataclass 20 | class RewardWrapperBase(ABC): 21 | @abstractclassmethod 22 | def __call__(self, inputs: Any) -> Any: 23 | raise NotImplementedError 24 | 25 | 26 | @dataclass 27 | class RewardWrapperList(RewardWrapperBase): 28 | reward_wrapper_list: List[RewardWrapperBase] 29 | 30 | def __call__(self, inputs: Any) -> List[torch.Tensor]: 31 | outputs_list = [] 32 | for reward_wrapper in self.reward_wrapper_list: 33 | outputs_list.append(reward_wrapper(inputs)) 34 | return outputs_list 35 | 36 | def map(self, func): 37 | for i in range(len(self.reward_wrapper_list)): 38 | self.reward_wrapper_list[i] = func(self.reward_wrapper_list[i]) 39 | return self 40 | 41 | def __len__(self): 42 | return len(self.reward_wrapper_list) 43 | 44 | 45 | @dataclass 46 | class ImplicitRewardWrapper(RewardWrapperBase): 47 | """ 48 | An implicit reward model parameterized as r(x,y) = logp(y|x)-logp_{ref}(y|x) 49 | """ 50 | model: PreTrainedModel 51 | ref_model: PreTrainedModel 52 | tokenizer: PreTrainedTokenizerBase 53 | prompt_template: Optional[str] = DEFAULT_PROMPT_TEMPLATE 54 | beta: Optional[bool] = 0.1 55 | average_log_prob: Optional[bool] = False 56 | label_pad_token_id: Optional[int] = -100 57 | 58 | @torch.no_grad() 59 | def __call__(self, inputs: RewardWrapperInput) -> torch.Tensor: 60 | from src.trainer.sft_trainer import SFTDataMapFunc, SFTDataCollatorWithPadding 61 | inputs = asdict(inputs) 62 | inputs["prompt"] = [self.prompt_template.format( 63 | raw_prompt=raw_prompt) for raw_prompt in inputs["raw_prompt"]] 64 | tokens = SFTDataMapFunc(tokenizer=self.tokenizer)(inputs) 65 | batch = SFTDataCollatorWithPadding(tokenizer=self.tokenizer)( 66 | [{k:v[i] for k,v in tokens.items()} for i in range(len(inputs["prompt"]))] 67 | ) 68 | batch = prepare_input(batch) 69 | policy_all_logps = 
self.forward(self.model, batch) 70 | ref_all_logps = self.forward(self.ref_model, batch) 71 | return self.beta * (policy_all_logps - ref_all_logps) 72 | 73 | @torch.no_grad() 74 | def forward(self, model: PreTrainedModel, batch: Dict[Text, torch.Tensor]) -> torch.Tensor: 75 | all_logits = model( 76 | input_ids=batch["input_ids"], 77 | attention_mask=batch["attention_mask"], 78 | ).logits.to(torch.float32) 79 | all_logps = get_batch_logps( 80 | all_logits, 81 | batch["labels"], 82 | average_log_prob=self.average_log_prob, 83 | label_pad_token_id=self.label_pad_token_id, 84 | ) 85 | return all_logps 86 | 87 | 88 | if __name__ == "__main__": 89 | from transformers import AutoModelForCausalLM, AutoTokenizer 90 | from accelerate import Accelerator 91 | 92 | model = AutoModelForCausalLM.from_pretrained( 93 | "meta-llama/Llama-2-7b-hf", 94 | use_flash_attention_2=True, # flash attn 95 | torch_dtype=torch.bfloat16, 96 | device_map={"": Accelerator().local_process_index}, 97 | ) 98 | 99 | tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", trust_remote_code=True) 100 | tokenizer.pad_token = tokenizer.eos_token 101 | tokenizer.padding_side = "right" 102 | 103 | implicit_reward = ImplicitRewardWrapper( 104 | model=model, 105 | ref_model=model, 106 | tokenizer=tokenizer, 107 | ) 108 | 109 | implicit_reward({"raw_prompt": ["who are you", "hi"], "response": ["i am your dad", "goodbye"]}) 110 | breakpoint() 111 | -------------------------------------------------------------------------------- /scripts/modpo/beavertails/utils/gen.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | import os 4 | import math 5 | 6 | import torch 7 | import tyro 8 | import tqdm 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | from datasets import Dataset 11 | from peft import PeftModel 12 | 13 | from src.data.configs import DATASET_CONFIGS, DEFAULT_PROMPT_TEMPLATE 14 | from src.utils import ( 15 | print_local_main, disable_progress_bar_non_local_main, set_seeds 16 | ) 17 | 18 | disable_progress_bar_non_local_main() 19 | 20 | 21 | @dataclass 22 | class ScriptArguments: 23 | 24 | sft_model_name: str = field(metadata={"help": "the sft model name"}) 25 | adapter_model_name: str = field(default=None, metadata={"help": "lora name"}) 26 | use_flash_attention_2: Optional[bool] = field(default=False, metadata={"help": "whether to use flash attention 2"}) 27 | prompt_template: Optional[str] = field(default=DEFAULT_PROMPT_TEMPLATE, metadata={"help": "the prompt template"}) 28 | dataset_name: Optional[str] = field(default="PKU-Alignment/PKU-SafeRLHF-10K", metadata={"help": "the dataset name"}) 29 | dataset_caching: Optional[bool] = field(default=False, metadata={"help": "used cached dataset"}) 30 | 31 | output_dir: Optional[str] = field(default=None, metadata={"help": "output path for generations"}) 32 | eval_size: Optional[int] = field(default=200, metadata={"help": "number of prompts for generations"}) 33 | max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"}) 34 | batch_size: Optional[int] = field(default=8) 35 | rank: Optional[int] = field(default=0) 36 | world_size: Optional[int] = field(default=1) 37 | seed: Optional[int] = field(default=0) 38 | 39 | 40 | if __name__ == "__main__": 41 | 42 | script_args = tyro.cli(ScriptArguments) 43 | set_seeds(script_args.seed) 44 | 45 | # base model 46 | print_local_main("loading model...") 47 | 
sft_model = AutoModelForCausalLM.from_pretrained( 48 | script_args.sft_model_name, 49 | use_flash_attention_2=script_args.use_flash_attention_2, # flash attn 50 | torch_dtype=torch.bfloat16, # necessary for llama2, otherwise will be cast to float32 51 | device_map="auto", 52 | ) 53 | if script_args.adapter_model_name: 54 | model = PeftModel.from_pretrained(sft_model, script_args.adapter_model_name) 55 | else: 56 | model = sft_model # sft 57 | 58 | # tokenizer: left padding for generation 59 | tokenizer = AutoTokenizer.from_pretrained(script_args.sft_model_name, trust_remote_code=True) 60 | tokenizer.pad_token = tokenizer.eos_token 61 | tokenizer.padding_side = "left" 62 | 63 | # dataset 64 | if not script_args.dataset_caching: 65 | from datasets import disable_caching 66 | disable_caching() 67 | rdp = DATASET_CONFIGS[script_args.dataset_name](prompt_template=script_args.prompt_template) 68 | eval_dataset = rdp.get_sft_dataset(split="validation").select(range(script_args.eval_size)) 69 | 70 | split_size = math.ceil(len(eval_dataset) /script_args.world_size) 71 | eval_dataset = eval_dataset.select(range( 72 | script_args.rank*split_size, 73 | min((script_args.rank+1)*split_size, len(eval_dataset)) 74 | )) 75 | output_path = os.path.join( 76 | script_args.output_dir, 77 | f"{str(script_args.rank+1).zfill(5)}-of-{str(script_args.world_size).zfill(5)}.jsonl" 78 | ) 79 | 80 | results = [] 81 | for idx in tqdm.tqdm(range(0, len(eval_dataset), script_args.batch_size)): 82 | batch = eval_dataset[idx: idx+script_args.batch_size] 83 | prompt_tokenized = tokenizer( 84 | batch["prompt"], 85 | return_tensors="pt", 86 | padding=True, 87 | ) 88 | output_tokenized = model.generate( 89 | input_ids=prompt_tokenized["input_ids"].cuda(), 90 | attention_mask=prompt_tokenized["attention_mask"].cuda(), 91 | max_length=script_args.max_length, 92 | ) 93 | output = tokenizer.batch_decode(output_tokenized, skip_special_tokens=True) 94 | for sample in output: 95 | results.append({'prompt_response': sample}) 96 | 97 | 98 | dataset = Dataset.from_list(results) 99 | dataset.to_json(output_path) 100 | -------------------------------------------------------------------------------- /scripts/modpo/beavertails/README.md: -------------------------------------------------------------------------------- 1 | # Safety Alignment 2 | 3 | This directory demonstrates the use of MODPO to train language models that balance helpfulness and safety. For a comprehensive experimental setup, refer to the [MODPO paper](https://arxiv.org/pdf/2310.03708.pdf). 4 | 5 | ## Automatic Pipeline 6 | 7 | To train and evaluate language models with varying balances of helpfulness and safety: 8 | 9 | 1. Training: 10 | ``` 11 | sh scripts/modpo/beavertails/run.sh 12 | ``` 13 | 14 | 2. Evaluation: 15 | ``` 16 | sh scripts/modpo/beavertails/utils/gen.sh 17 | sh scripts/modpo/beavertails/utils/score.sh 18 | ``` 19 | 20 | View the results in ``output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/score.csv``. Note that a higher mean reward indicates better helpfulness, whereas a higher mean cost suggests increased harmfulness. 21 | 22 | 23 | ## Annotated Pipeline 24 | 25 | ### Training 26 | 27 | Two steps to train a language model that is both helpful and safe: 28 | 29 | 1. Reward Modeling: Run DPO on safe preferences to train a safe reward model that encourages safe responses: 30 | ```sh 31 | PYTHONPATH=.
accelerate launch --config_file scripts/accelerate_configs/multi_gpu.yaml --num_processes=8 \ 32 | scripts/examples/dpo/dpo.py \ 33 | --sft_model_name "PKU-Alignment/alpaca-7b-reproduced" \ 34 | --prompt_template "BEGINNING OF CONVERSATION: USER: {raw_prompt} ASSISTANT:" \ 35 | --dataset_name "PKU-Alignment/PKU-SafeRLHF-10K-safer" \ 36 | --max_length 512 \ 37 | --training_args.output_dir "./output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/rm/safer" \ 38 | --training_args.run_name "PKU-Alignment/PKU-SafeRLHF-10K/modpo/rm/safer" \ 39 | --training_args.per_device_train_batch_size 6 \ 40 | --training_args.per_device_eval_batch_size 6 \ 41 | --training_args.gradient_accumulation_steps 2 \ 42 | --training_args.learning_rate 5e-4 \ 43 | --peft_config.r 64 \ 44 | --peft_config.target_modules q_proj k_proj v_proj o_proj \ 45 | --peft_config.lora_alpha 1 \ 46 | --peft_config.lora_dropout 0 47 | ``` 48 | 49 | 2. Language Modeling: Run MODPO on helpful preferences, with the safe reward as margin, to train a language model that is both helpful and safe: 50 | ```sh 51 | PYTHONPATH=. accelerate launch --config_file scripts/accelerate_configs/multi_gpu.yaml --num_processes=8 \ 52 | scripts/modpo/beavertails/modpo.py \ 53 | --sft_model_name "PKU-Alignment/alpaca-7b-reproduced" \ 54 | --margin_reward_model_name "./output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/rm/safer/best_checkpoint" \ 55 | --prompt_template "BEGINNING OF CONVERSATION: USER: {raw_prompt} ASSISTANT:" \ 56 | --dataset_name "PKU-Alignment/PKU-SafeRLHF-10K-better" \ 57 | --max_length 512 \ 58 | --w 0.5 \ 59 | --training_args.output_dir "./output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/lm/(0.5)better+(1-0.5)safer" \ 60 | --training_args.run_name "PKU-SafeRLHF-10K/modpo/lm/(0.5)better+(1-0.5)safer" \ 61 | --training_args.per_device_train_batch_size 6 \ 62 | --training_args.per_device_eval_batch_size 6 \ 63 | --training_args.gradient_accumulation_steps 2 \ 64 | --training_args.learning_rate 5e-4 \ 65 | --peft_config.r 64 \ 66 | --peft_config.target_modules q_proj k_proj v_proj o_proj \ 67 | --peft_config.lora_alpha 1 \ 68 | --peft_config.lora_dropout 0 69 | ``` 70 | 71 | 72 | ### Evaluation 73 | 74 | Two steps to evaluate the trained language model: 75 | 76 | 1. Generation: 77 | - Generate with one GPU: 78 | ```sh 79 | # dataset can be either "PKU-Alignment/PKU-SafeRLHF-10K-safer" or "PKU-Alignment/PKU-SafeRLHF-10K-better"; only prompts are used here for generation. 80 | PYTHONPATH=. python3 scripts/modpo/beavertails/utils/gen.py \ 81 | --sft_model_name "PKU-Alignment/alpaca-7b-reproduced" \ 82 | --adapter_model_name "./output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/lm/(0.5)better+(1-0.5)safer/best_checkpoint" \ 83 | --prompt_template "BEGINNING OF CONVERSATION: USER: {raw_prompt} ASSISTANT:" \ 84 | --dataset_name "PKU-Alignment/PKU-SafeRLHF-10K-safer" \ 85 | --output_dir "./output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/lm/(0.5)better+(1-0.5)safer/gen" \ 86 | --max_length 512 87 | ``` 88 | - Generate with multiple GPUs in parallel (2 GPUs for example): 89 | ```sh 90 | for rank in 0 1; do 91 | CUDA_VISIBLE_DEVICES=${rank} PYTHONPATH=. 
python3 scripts/modpo/beavertails/utils/gen.py \ 92 | --sft_model_name "PKU-Alignment/alpaca-7b-reproduced" \ 93 | --adapter_model_name "./output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/lm/(0.5)better+(1-0.5)safer/best_checkpoint" \ 94 | --prompt_template "BEGINNING OF CONVERSATION: USER: {raw_prompt} ASSISTANT:" \ 95 | --dataset_name "PKU-Alignment/PKU-SafeRLHF-10K-safer" \ 96 | --output_dir "./output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/lm/(0.5)better+(1-0.5)safer/gen" \ 97 | --max_length 512 \ 98 | --rank ${rank} \ 99 | --world_size 2 & 100 | done 101 | 102 | ``` 103 | 104 | 2. Scoring: 105 | ```sh 106 | PYTHONPATH=. python3 scripts/modpo/beavertails/utils/score.py \ 107 | --input_dir "./output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/lm/(0.5)better+(1-0.5)safer/gen" \ 108 | --output_dir "./output/PKU-Alignment/PKU-SafeRLHF-10K/modpo/lm/(0.5)better+(1-0.5)safer/score" 109 | ``` 110 | -------------------------------------------------------------------------------- /scripts/examples/rm/rm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | 5 | import torch 6 | import tyro 7 | from accelerate import Accelerator 8 | from peft import LoraConfig 9 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments 10 | 11 | from src.trainer.rm_trainer import RewardTrainer 12 | from src.data.configs import DATASET_CONFIGS, DEFAULT_PROMPT_TEMPLATE 13 | from src.utils import print_local_main, disable_progress_bar_non_local_main, param_sharding_enabled, set_seeds 14 | 15 | disable_progress_bar_non_local_main() 16 | 17 | @dataclass 18 | class ScriptArguments: 19 | 20 | sft_model_name: str = field(metadata={"help": "the sft model name"}) 21 | use_flash_attention_2: Optional[bool] = field(default=False, metadata={"help": "whether to use flash attention 2"}) 22 | prompt_template: Optional[str] = field(default=DEFAULT_PROMPT_TEMPLATE, metadata={"help": "the prompt template"}) 23 | dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"}) 24 | dataset_caching: Optional[bool] = field(default=False, metadata={"help": "used cached dataset"}) 25 | sanity_check: Optional[bool] = field(default=False, metadata={"help": "whether to conduct sanity check"}) 26 | 27 | max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"}) 28 | num_proc: Optional[int] = field(default=4, metadata={"help": "num_proc for dataset.map"}) 29 | 30 | training_args: TrainingArguments = field( 31 | default_factory=lambda: TrainingArguments( 32 | output_dir="./output/dev/reward", 33 | overwrite_output_dir=True, 34 | seed=42, 35 | 36 | per_device_train_batch_size=4, 37 | per_device_eval_batch_size=4, 38 | gradient_accumulation_steps=2, 39 | learning_rate=1e-4, 40 | lr_scheduler_type="cosine", 41 | warmup_steps=0.1, 42 | weight_decay=0.05, 43 | fp16=True, 44 | remove_unused_columns=False, 45 | run_name="dev_rm", 46 | report_to="wandb", 47 | 48 | num_train_epochs=3, 49 | logging_steps=10, 50 | save_steps=0.25, 51 | eval_steps=0.25, 52 | eval_delay=0.25, 53 | evaluation_strategy="steps", 54 | save_total_limit=3, 55 | load_best_model_at_end=True, 56 | ) 57 | ) 58 | 59 | peft: Optional[bool] = field(default=True, metadata={"help": "whether to use peft for training"}) 60 | peft_config: LoraConfig = field( 61 | default_factory=lambda: LoraConfig( 62 | r=16, 63 | lora_alpha=32, 64 | lora_dropout=0.05, 65 | 
bias="none", 66 | task_type="SEQ_CLS", 67 | modules_to_save=["score"], # maybe optional 68 | ) 69 | ) 70 | 71 | script_args = tyro.cli(ScriptArguments) 72 | set_seeds(script_args.training_args.seed) 73 | if not script_args.peft: 74 | script_args.peft_config = None 75 | 76 | # base model 77 | print_local_main("loading model...") 78 | reward_model = AutoModelForSequenceClassification.from_pretrained( 79 | script_args.sft_model_name, 80 | use_flash_attention_2=script_args.use_flash_attention_2, # flash attn 81 | torch_dtype=torch.bfloat16, 82 | num_labels=1, 83 | **({"device_map": {"": Accelerator().local_process_index}} if not param_sharding_enabled() else {}), 84 | ) 85 | reward_model.score = reward_model.score.float() # score head is not trainabled if loaded with torch.float16 86 | reward_model.config.update({ 87 | "use_cache": False, 88 | "pad_token_id": reward_model.config.eos_token_id 89 | }) 90 | print_local_main(reward_model) 91 | print_local_main(script_args.peft_config) 92 | 93 | # tokenizer 94 | tokenizer = AutoTokenizer.from_pretrained(script_args.sft_model_name, trust_remote_code=True) 95 | tokenizer.pad_token = tokenizer.eos_token 96 | tokenizer.padding_side = "right" 97 | 98 | # dataset 99 | if not script_args.dataset_caching: 100 | from datasets import disable_caching 101 | disable_caching() 102 | rdp = DATASET_CONFIGS[script_args.dataset_name]( 103 | prompt_template=script_args.prompt_template, 104 | sanity_check=script_args.sanity_check, 105 | ) 106 | train_dataset = rdp.get_preference_dataset(split="train") 107 | eval_dataset = rdp.get_preference_dataset(split="validation") 108 | 109 | # get ready for training 110 | print_local_main("start training...") 111 | trainer = RewardTrainer( 112 | model=reward_model, 113 | args=script_args.training_args, 114 | train_dataset=train_dataset, 115 | eval_dataset=eval_dataset, 116 | tokenizer=tokenizer, 117 | peft_config=script_args.peft_config, 118 | max_length=script_args.max_length, 119 | num_proc=script_args.num_proc, 120 | ) 121 | if Accelerator().is_local_main_process and script_args.peft_config: 122 | trainer.model.print_trainable_parameters() 123 | trainer.train() 124 | 125 | save_name = "best_checkpoint" if script_args.training_args.load_best_model_at_end else "final_checkpoint" 126 | trainer.model.save_pretrained(os.path.join(script_args.training_args.output_dir, save_name)) 127 | trainer.tokenizer.save_pretrained(os.path.join(script_args.training_args.output_dir, save_name)) 128 | -------------------------------------------------------------------------------- /scripts/examples/dpo/dpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | 5 | import torch 6 | import tyro 7 | from accelerate import Accelerator 8 | from peft import LoraConfig 9 | from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments 10 | 11 | from src.trainer.dpo_trainer import DPOTrainer 12 | from src.data.configs import DATASET_CONFIGS, DEFAULT_PROMPT_TEMPLATE 13 | from src.utils import print_local_main, disable_progress_bar_non_local_main, param_sharding_enabled, set_seeds 14 | 15 | disable_progress_bar_non_local_main() 16 | 17 | 18 | @dataclass 19 | class ScriptArguments: 20 | 21 | sft_model_name: str = field(metadata={"help": "the sft model name"}) 22 | use_flash_attention_2: Optional[bool] = field(default=False, metadata={"help": "whether to use flash attention 2"}) 23 | prompt_template: Optional[str] = 
field(default=DEFAULT_PROMPT_TEMPLATE, metadata={"help": "the prompt template"}) 24 | dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"}) 25 | dataset_caching: Optional[bool] = field(default=False, metadata={"help": "used cached dataset"}) 26 | sanity_check: Optional[bool] = field(default=False, metadata={"help": "whether to conduct sanity check"}) 27 | 28 | beta: Optional[float] = field(default=0.1, metadata={"help": "beta for kl control"}) 29 | max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"}) 30 | num_proc: Optional[int] = field(default=4, metadata={"help": "num_proc for dataset.map"}) 31 | generate_during_eval: Optional[bool] = field(default=True, metadata={"help": "whether to generate during evaluation"}) 32 | 33 | training_args: TrainingArguments = field( 34 | default_factory=lambda: TrainingArguments( 35 | output_dir="./output/dev/dpo", 36 | overwrite_output_dir=True, 37 | seed=42, 38 | 39 | per_device_train_batch_size=4, 40 | per_device_eval_batch_size=4, 41 | gradient_accumulation_steps=2, 42 | learning_rate=1e-4, 43 | lr_scheduler_type="cosine", 44 | warmup_steps=0.1, 45 | weight_decay=0.05, 46 | fp16=True, 47 | remove_unused_columns=False, 48 | run_name="dev_dpo", 49 | report_to="wandb", 50 | 51 | num_train_epochs=3, 52 | logging_steps=10, 53 | save_steps=0.25, 54 | eval_steps=0.25, 55 | eval_delay=0.25, 56 | evaluation_strategy="steps", 57 | save_total_limit=3, 58 | load_best_model_at_end=True, 59 | ) 60 | ) 61 | 62 | peft: Optional[bool] = field(default=True, metadata={"help": "whether to use peft for training"}) 63 | peft_config: LoraConfig = field( 64 | default_factory=lambda: LoraConfig( 65 | r=16, 66 | lora_alpha=32, 67 | lora_dropout=0.05, 68 | bias="none", 69 | task_type="CAUSAL_LM", 70 | ) 71 | ) 72 | 73 | script_args = tyro.cli(ScriptArguments) 74 | set_seeds(script_args.training_args.seed) 75 | if not script_args.peft: 76 | script_args.peft_config = None 77 | 78 | # base model 79 | print_local_main("loading model...") 80 | sft_model = AutoModelForCausalLM.from_pretrained( 81 | script_args.sft_model_name, 82 | use_flash_attention_2=script_args.use_flash_attention_2, # flash attn 83 | torch_dtype=torch.bfloat16, # necessary for llama2, otherwise will be cast to float32 84 | **({"device_map": {"": Accelerator().local_process_index}} if not param_sharding_enabled() else {}), 85 | ) 86 | sft_model.config.update({ 87 | "use_cache": False, 88 | "pad_token_id": sft_model.config.eos_token_id 89 | }) 90 | print_local_main(sft_model) 91 | print_local_main(script_args.peft_config) 92 | 93 | # tokenizer 94 | tokenizer = AutoTokenizer.from_pretrained(script_args.sft_model_name, trust_remote_code=True) 95 | tokenizer.pad_token = tokenizer.eos_token 96 | tokenizer.padding_side = "right" 97 | 98 | # dataset 99 | if not script_args.dataset_caching: 100 | from datasets import disable_caching 101 | disable_caching() 102 | rdp = DATASET_CONFIGS[script_args.dataset_name]( 103 | prompt_template=script_args.prompt_template, 104 | sanity_check=script_args.sanity_check, 105 | ) 106 | train_dataset = rdp.get_preference_dataset(split="train") 107 | eval_dataset = rdp.get_preference_dataset(split="validation") 108 | 109 | # get ready for training 110 | print_local_main("start training...") 111 | trainer = DPOTrainer( 112 | model=sft_model, 113 | beta=script_args.beta, 114 | args=script_args.training_args, 115 | train_dataset=train_dataset, 116 | eval_dataset=eval_dataset, 117 | tokenizer=tokenizer, 
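# note (for reference): with the sigmoid loss type, DPOTrainer minimizes -logsigmoid(beta * ((logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected))); MODPOTrainer (src/trainer/modpo_trainer.py) extends this objective with margin rewards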
118 | peft_config=script_args.peft_config, 119 | max_length=script_args.max_length, 120 | num_proc=script_args.num_proc, 121 | generate_during_eval=script_args.generate_during_eval, 122 | ) 123 | if Accelerator().is_local_main_process and script_args.peft_config: 124 | trainer.model.print_trainable_parameters() 125 | trainer.train() 126 | 127 | save_name = "best_checkpoint" if script_args.training_args.load_best_model_at_end else "final_checkpoint" 128 | trainer.model.save_pretrained(os.path.join(script_args.training_args.output_dir, save_name)) 129 | trainer.tokenizer.save_pretrained(os.path.join(script_args.training_args.output_dir, save_name)) 130 | -------------------------------------------------------------------------------- /src/data/raw_data/helpsteer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from dataclasses import dataclass 3 | from datasets import load_dataset 4 | from typing import Literal, Optional 5 | 6 | from .utils import RawDatasetPreprocessor 7 | from src.utils import print_local_main 8 | 9 | 10 | def helpsteer_transform_to_preference(batched_sample): 11 | def chosen_id(score_0, score_1): 12 | if score_0 < score_1: 13 | return 1 14 | elif score_0 > score_1: 15 | return 0 16 | else: 17 | return -1 18 | 19 | finegrained_dimensions = ("helpfulness", "correctness", "coherence", "complexity", "verbosity") 20 | dimensions = finegrained_dimensions + ("overall",) 21 | 22 | debatched_sample = [{k:batched_sample[k][i] for k in batched_sample.keys()} for i in range(len(batched_sample["prompt"]))] 23 | 24 | new_batched_sample = { 25 | "prompt": [], 26 | "response_0": [], 27 | "response_1": [], 28 | **{f"{dimension}_chosen_id": [] for dimension in dimensions} 29 | } 30 | mini_debatch = [] 31 | for i, sample in enumerate(debatched_sample): 32 | mini_debatch.append(sample) 33 | if i != len(debatched_sample) - 1 and sample["prompt"] == debatched_sample[i+1]["prompt"]: 34 | continue 35 | 36 | for j in range(len(mini_debatch)): 37 | for k in range(j+1, len(mini_debatch)): 38 | new_batched_sample["prompt"].append(mini_debatch[j]["prompt"]) 39 | new_batched_sample["response_0"].append(mini_debatch[j]["response"]) 40 | new_batched_sample["response_1"].append(mini_debatch[k]["response"]) 41 | new_batched_sample["overall_chosen_id"].append( 42 | chosen_id( 43 | sum(mini_debatch[j][dimension] for dimension in finegrained_dimensions), 44 | sum(mini_debatch[k][dimension] for dimension in finegrained_dimensions), 45 | ) 46 | ) 47 | for dimension in finegrained_dimensions: 48 | new_batched_sample[f"{dimension}_chosen_id"].append( 49 | chosen_id( 50 | mini_debatch[j][dimension], 51 | mini_debatch[k][dimension], 52 | ) 53 | ) 54 | 55 | mini_debatch = [] 56 | 57 | return new_batched_sample 58 | 59 | 60 | @dataclass 61 | class HelpSteerRDP(RawDatasetPreprocessor): 62 | path: Optional[str] = "nvidia/HelpSteer" 63 | # None for sft 64 | dimension: Optional[Literal["overall", "helpfulness", "correctness", "coherence", "complexity", "verbosity"]] = None 65 | 66 | def _get_raw_dataset(self, split): 67 | if split == "train": 68 | return load_dataset(self.path, split="train") 69 | elif split == "validation": 70 | return load_dataset(self.path, split="validation") 71 | elif split == "test": 72 | raise NotImplementedError("test split not implemented for helpsteer") 73 | else: 74 | raise NotImplementedError 75 | 76 | def _dataset_to_preference_formatter(self, example) -> Dict[str, str]: 77 | chosen_id = 
example[f"{self.dimension}_chosen_id"] 78 | return { 79 | "raw_prompt": example["prompt"], 80 | "prompt": self.prompt_template.format(raw_prompt=example["prompt"]), 81 | "chosen": example[f"response_{chosen_id}"], 82 | "rejected": example[f"response_{1-chosen_id}"], 83 | } 84 | 85 | def get_preference_dataset(self, split): 86 | assert self.dimension, "preference dimension has to be specified" 87 | dataset = self._get_raw_dataset(split) 88 | if self.sanity_check: 89 | dataset = dataset.select(range(min(len(dataset), 100))) 90 | print_local_main("mapping raw dataset to preference...") 91 | dataset = dataset.map( 92 | helpsteer_transform_to_preference, 93 | batched=True, 94 | num_proc=self.num_proc, 95 | remove_columns=dataset.column_names, 96 | ) 97 | print_local_main("filtering preference...") 98 | dataset = dataset.filter(lambda x: x[f"{self.dimension}_chosen_id"] != -1) 99 | print_local_main("mapping dataset to standard format...") 100 | return dataset.map(self._dataset_to_preference_formatter, num_proc=self.num_proc, remove_columns=dataset.column_names) 101 | 102 | def get_sft_dataset(self, split, **kwargs): 103 | if self.dimension: 104 | return super().get_sft_dataset(split, **kwargs) 105 | dataset = self._get_raw_dataset(split) 106 | if self.sanity_check: 107 | dataset = dataset.select(range(min(len(dataset), 100))) 108 | print_local_main("mapping raw dataset to sft...") 109 | return dataset.map( 110 | lambda sample: { 111 | "raw_prompt": sample["prompt"], 112 | "prompt": self.prompt_template.format(raw_prompt=sample["prompt"]), 113 | "response": sample["response"], 114 | }, 115 | num_proc=self.num_proc, 116 | remove_columns=dataset.column_names, 117 | ) 118 | 119 | 120 | if __name__ == "__main__": 121 | num_proc = 4 122 | helpful_dataset = HelpSteerRDP(dimension="helpfulness", num_proc=num_proc).get_preference_dataset(split="train") 123 | overall_dataset = HelpSteerRDP(dimension="overall", num_proc=num_proc).get_preference_dataset(split="train") 124 | sft_dataset = HelpSteerRDP(num_proc=num_proc).get_sft_dataset(split="train") 125 | breakpoint() 126 | -------------------------------------------------------------------------------- /scripts/examples/sft/sft.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | 5 | import torch 6 | import tyro 7 | from accelerate import Accelerator 8 | from peft import LoraConfig 9 | from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments 10 | 11 | from src.trainer.sft_trainer import SFTTrainer 12 | from src.data.configs import DATASET_CONFIGS, DEFAULT_PROMPT_TEMPLATE 13 | from src.utils import print_local_main, disable_progress_bar_non_local_main, param_sharding_enabled, set_seeds 14 | 15 | disable_progress_bar_non_local_main() 16 | 17 | @dataclass 18 | class ScriptArguments: 19 | 20 | base_model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the base model name"}) 21 | use_flash_attention_2: Optional[bool] = field(default=False, metadata={"help": "whether to use flash attention 2"}) 22 | prompt_template: Optional[str] = field(default=DEFAULT_PROMPT_TEMPLATE, metadata={"help": "the prompt template"}) 23 | dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"}) 24 | dataset_caching: Optional[bool] = field(default=False, metadata={"help": "used cached dataset"}) 25 | sanity_check: Optional[bool] = field(default=False, 
metadata={"help": "whether to conduct sanity check"}) 26 | 27 | max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"}) 28 | chosen_only: Optional[bool] = field(default=True, metadata={"help": "whether to train only on preferred response"}) 29 | completion_only: Optional[bool] = field(default=True, metadata={"help": "whether to train only on completion"}) 30 | num_proc: Optional[int] = field(default=4, metadata={"help": "num_proc for dataset.map"}) 31 | generate_during_eval: Optional[bool] = field(default=True, metadata={"help": "whether to generate during evaluation"}) 32 | 33 | training_args: TrainingArguments = field( 34 | default_factory=lambda: TrainingArguments( 35 | output_dir="./output/dev/sft", 36 | overwrite_output_dir=True, 37 | seed=42, 38 | 39 | per_device_train_batch_size=4, 40 | per_device_eval_batch_size=4, 41 | gradient_accumulation_steps=2, 42 | learning_rate=1e-4, 43 | lr_scheduler_type="cosine", 44 | warmup_steps=0.1, 45 | weight_decay=0.05, 46 | fp16=True, 47 | remove_unused_columns=False, 48 | run_name="dev_sft", 49 | report_to="wandb", 50 | 51 | num_train_epochs=3, 52 | logging_steps=10, 53 | save_steps=0.25, 54 | eval_steps=0.25, 55 | eval_delay=0.25, 56 | evaluation_strategy="steps", 57 | save_total_limit=3, 58 | load_best_model_at_end=True, 59 | ) 60 | ) 61 | 62 | peft: Optional[bool] = field(default=True, metadata={"help": "whether to use peft for training"}) 63 | peft_config: LoraConfig = field( 64 | default_factory=lambda: LoraConfig( 65 | r=16, 66 | lora_alpha=32, 67 | lora_dropout=0.05, 68 | bias="none", 69 | task_type="CAUSAL_LM", 70 | ) 71 | ) 72 | 73 | script_args = tyro.cli(ScriptArguments) 74 | set_seeds(script_args.training_args.seed) 75 | if not script_args.peft: 76 | script_args.peft_config = None 77 | 78 | # base model 79 | print_local_main("loading model...") 80 | base_model = AutoModelForCausalLM.from_pretrained( 81 | script_args.base_model_name, 82 | use_flash_attention_2=script_args.use_flash_attention_2, # flash attn 83 | torch_dtype=torch.bfloat16, # necessary for llama2, otherwise will be cast to float32 84 | **({"device_map": {"": Accelerator().local_process_index}} if not param_sharding_enabled() else {}), 85 | ) 86 | base_model.config.update({ 87 | "use_cache": False, 88 | "pad_token_id": base_model.config.eos_token_id 89 | }) 90 | print_local_main(base_model) 91 | print_local_main(script_args.peft_config) 92 | 93 | # tokenizer 94 | tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name, trust_remote_code=True) 95 | tokenizer.pad_token = tokenizer.eos_token 96 | tokenizer.padding_side = "right" 97 | 98 | # dataset 99 | if not script_args.dataset_caching: 100 | from datasets import disable_caching 101 | disable_caching() 102 | rdp = DATASET_CONFIGS[script_args.dataset_name]( 103 | prompt_template=script_args.prompt_template, 104 | sanity_check=script_args.sanity_check, 105 | ) 106 | train_dataset = rdp.get_sft_dataset(split="train", chosen_only=script_args.chosen_only) 107 | eval_dataset = rdp.get_sft_dataset(split="validation", chosen_only=script_args.chosen_only) 108 | 109 | # get ready for training 110 | print_local_main("start training...") 111 | trainer = SFTTrainer( 112 | model=base_model, 113 | args=script_args.training_args, 114 | train_dataset=train_dataset, 115 | eval_dataset=eval_dataset, 116 | tokenizer=tokenizer, 117 | peft_config=script_args.peft_config, 118 | max_length=script_args.max_length, 119 | completion_only=script_args.completion_only, 120 | 
num_proc=script_args.num_proc, 121 | generate_during_eval=script_args.generate_during_eval, 122 | ) 123 | if Accelerator().is_local_main_process and script_args.peft_config: 124 | trainer.model.print_trainable_parameters() 125 | trainer.train() 126 | 127 | save_name = "best_checkpoint" if script_args.training_args.load_best_model_at_end else "final_checkpoint" 128 | trainer.model.save_pretrained(os.path.join(script_args.training_args.output_dir, save_name)) 129 | trainer.tokenizer.save_pretrained(os.path.join(script_args.training_args.output_dir, save_name)) 130 | -------------------------------------------------------------------------------- /src/data/raw_data/ultrafeedback.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | from typing import Literal, Dict, Optional 4 | 5 | from datasets import load_dataset 6 | 7 | from .utils import RawDatasetPreprocessor 8 | from src.utils import print_local_main 9 | 10 | 11 | def ultrafeedback_transform_to_sft(batched_sample, prompt_template): 12 | new_batched_sample = { 13 | "raw_prompt": [], 14 | "prompt": [], 15 | "response": [], 16 | } 17 | for instruction, completions in zip(batched_sample["instruction"], batched_sample["completions"]): 18 | # breakpoint() 19 | for completion in completions: 20 | new_batched_sample["raw_prompt"].append(instruction) 21 | new_batched_sample["prompt"].append(prompt_template.format(raw_prompt=instruction)) 22 | new_batched_sample["response"].append(completion['response']) 23 | return new_batched_sample 24 | 25 | 26 | def ultrafeedback_transform_to_preference(batched_sample): 27 | def chosen_id(score_0, score_1): 28 | if score_0 < score_1: 29 | return 1 30 | elif score_0 > score_1: 31 | return 0 32 | else: 33 | return -1 34 | 35 | finegrained_dimensions = ("instruction_following", "honesty", "truthfulness", "helpfulness") 36 | dimensions = finegrained_dimensions + ("overall",) 37 | 38 | new_batched_sample = { 39 | "prompt": [], 40 | "response_0": [], 41 | "response_1": [], 42 | **{f"{dimension}_chosen_id": [] for dimension in dimensions} 43 | } 44 | for instruction, completions in zip(batched_sample["instruction"], batched_sample["completions"]): 45 | n_responses = len(completions) 46 | 47 | for j in range(n_responses): 48 | for k in range(j+1, n_responses): 49 | new_batched_sample["prompt"].append(instruction) 50 | new_batched_sample["response_0"].append(completions[j]['response']) 51 | new_batched_sample["response_1"].append(completions[k]['response']) 52 | new_batched_sample["overall_chosen_id"].append( 53 | chosen_id( 54 | completions[j]["overall_score"], 55 | completions[k]["overall_score"] 56 | ) 57 | ) 58 | for dimension in finegrained_dimensions: 59 | new_batched_sample[f"{dimension}_chosen_id"].append( 60 | chosen_id( 61 | completions[j]["annotations"][dimension]["Rating"], 62 | completions[k]["annotations"][dimension]["Rating"] 63 | ) 64 | ) 65 | 66 | return new_batched_sample 67 | 68 | 69 | @dataclass 70 | class UltraFeedbackRDP(RawDatasetPreprocessor): 71 | path: Optional[str] = "OpenBMB/UltraFeedback" 72 | dimension: Optional[Literal["overall", "instruction_following", "honesty", "truthfulness", "helpfulness"]] = None 73 | 74 | def _get_raw_dataset(self, split): 75 | if split == "train": 76 | return load_dataset(self.path, split="train").train_test_split(test_size=0.1, seed=0)["train"] 77 | elif split == "validation": 78 | return load_dataset(self.path, 
split="train").train_test_split(test_size=0.1, seed=0)["test"] 79 | elif split == "test": 80 | raise NotImplementedError("test split not implemented for UltraFeedbackRDP") 81 | else: 82 | raise NotImplementedError 83 | 84 | def _dataset_to_preference_formatter(self, example) -> Dict[str, str]: 85 | chosen_id = example[f"{self.dimension}_chosen_id"] 86 | return { 87 | "raw_prompt": example["prompt"], 88 | "prompt": self.prompt_template.format(raw_prompt=example["prompt"]), 89 | "chosen": example[f"response_{chosen_id}"], 90 | "rejected": example[f"response_{1-chosen_id}"], 91 | } 92 | 93 | def get_preference_dataset(self, split): 94 | dataset = self._get_raw_dataset(split) 95 | if self.sanity_check: 96 | dataset = dataset.select(range(min(len(dataset), 100))) 97 | print_local_main("mapping raw dataset to preference...") 98 | dataset = dataset.map( 99 | ultrafeedback_transform_to_preference, 100 | batched=True, 101 | num_proc=self.num_proc, 102 | remove_columns=dataset.column_names, 103 | ) 104 | print_local_main("filtering preference...") 105 | dataset = dataset.filter(lambda x: x[f"{self.dimension}_chosen_id"] != -1) 106 | print_local_main("mapping dataset to standard format...") 107 | return dataset.map(self._dataset_to_preference_formatter, num_proc=self.num_proc, remove_columns=dataset.column_names) 108 | 109 | def get_sft_dataset(self, split, **kwargs): 110 | if self.dimension: 111 | return super().get_sft_dataset(split, **kwargs) 112 | dataset = self._get_raw_dataset(split) 113 | if self.sanity_check: 114 | dataset = dataset.select(range(min(len(dataset), 100))) 115 | print_local_main("mapping raw dataset to sft...") 116 | return dataset.map( 117 | partial(ultrafeedback_transform_to_sft, prompt_template=self.prompt_template), 118 | batched=True, 119 | num_proc=self.num_proc, 120 | remove_columns=dataset.column_names, 121 | ) 122 | 123 | 124 | if __name__ == "__main__": 125 | num_proc = 4 126 | overall_dataset = UltraFeedbackRDP(dimension="overall", num_proc=num_proc).get_preference_dataset(split="train") 127 | sft_dataset = UltraFeedbackRDP(num_proc=num_proc).get_sft_dataset(split="train") 128 | breakpoint() 129 | -------------------------------------------------------------------------------- /scripts/modpo/summarize_w_length_penalty/modpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | 5 | import torch 6 | import tyro 7 | from accelerate import Accelerator 8 | from peft import LoraConfig 9 | from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, PreTrainedTokenizerBase 10 | 11 | from src.trainer.modpo_trainer import MODPOTrainer 12 | from src.data.configs import DATASET_CONFIGS, DEFAULT_PROMPT_TEMPLATE 13 | from src.utils import print_local_main, disable_progress_bar_non_local_main, set_seeds, param_sharding_enabled 14 | from src.utils.reward import RewardWrapperList, RewardWrapperBase, RewardWrapperInput 15 | 16 | disable_progress_bar_non_local_main() 17 | 18 | 19 | @dataclass 20 | class ScriptArguments: 21 | 22 | sft_model_name: str = field(metadata={"help": "the sft model name"}) 23 | use_flash_attention_2: Optional[bool] = field(default=False, metadata={"help": "whether to use flash attention 2"}) 24 | prompt_template: Optional[str] = field(default=DEFAULT_PROMPT_TEMPLATE, metadata={"help": "the prompt template"}) 25 | dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"}) 26 
| dataset_caching: Optional[bool] = field(default=False, metadata={"help": "used cached dataset"}) 27 | sanity_check: Optional[bool] = field(default=False, metadata={"help": "whether to conduct sanity check"}) 28 | 29 | w: Optional[float] = field(default=0.5, metadata={"help": "weight"}) 30 | beta: Optional[float] = field(default=0.1, metadata={"help": "beta for kl control"}) 31 | max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"}) 32 | num_proc: Optional[int] = field(default=4, metadata={"help": "num_proc for dataset.map"}) 33 | generate_during_eval: Optional[bool] = field(default=True, metadata={"help": "whether to generate during evaluation"}) 34 | 35 | training_args: TrainingArguments = field( 36 | default_factory=lambda: TrainingArguments( 37 | output_dir="./output/dev/modpo", 38 | overwrite_output_dir=True, 39 | seed=42, 40 | 41 | per_device_train_batch_size=4, 42 | per_device_eval_batch_size=4, 43 | gradient_accumulation_steps=2, 44 | learning_rate=1e-4, 45 | lr_scheduler_type="cosine", 46 | warmup_steps=0.1, 47 | weight_decay=0.05, 48 | fp16=True, 49 | remove_unused_columns=False, 50 | run_name="dev_modpo", 51 | report_to="wandb", 52 | 53 | num_train_epochs=3, 54 | logging_steps=10, 55 | save_steps=0.25, 56 | eval_steps=0.25, 57 | eval_delay=0.25, 58 | evaluation_strategy="steps", 59 | save_total_limit=3, 60 | load_best_model_at_end=True, 61 | ) 62 | ) 63 | 64 | peft: Optional[bool] = field(default=True, metadata={"help": "whether to use peft for training"}) 65 | peft_config: LoraConfig = field( 66 | default_factory=lambda: LoraConfig( 67 | r=16, 68 | lora_alpha=32, 69 | lora_dropout=0.05, 70 | bias="none", 71 | task_type="CAUSAL_LM", 72 | ) 73 | ) 74 | 75 | script_args = tyro.cli(ScriptArguments) 76 | set_seeds(script_args.training_args.seed) 77 | if not script_args.peft: 78 | script_args.peft_config = None 79 | 80 | # base model 81 | print_local_main("loading model...") 82 | sft_model = AutoModelForCausalLM.from_pretrained( 83 | script_args.sft_model_name, 84 | use_flash_attention_2=script_args.use_flash_attention_2, # flash attn 85 | torch_dtype=torch.bfloat16, # necessary for llama2, otherwise will be cast to float32 86 | **({"device_map": {"": Accelerator().local_process_index}} if not param_sharding_enabled() else {}), 87 | ) 88 | sft_model.config.update({ 89 | "use_cache": False, 90 | "pad_token_id": sft_model.config.eos_token_id 91 | }) 92 | print_local_main(sft_model) 93 | print_local_main(script_args.peft_config) 94 | 95 | # tokenizer 96 | tokenizer = AutoTokenizer.from_pretrained(script_args.sft_model_name, trust_remote_code=True) 97 | tokenizer.pad_token = tokenizer.eos_token 98 | tokenizer.padding_side = "right" 99 | 100 | # dataset 101 | if not script_args.dataset_caching: 102 | from datasets import disable_caching 103 | disable_caching() 104 | rdp = DATASET_CONFIGS[script_args.dataset_name]( 105 | prompt_template=script_args.prompt_template, 106 | sanity_check=script_args.sanity_check, 107 | ) 108 | train_dataset = rdp.get_preference_dataset(split="train") 109 | eval_dataset = rdp.get_preference_dataset(split="validation") 110 | 111 | # get ready for training 112 | print_local_main("start training...") 113 | trainer = MODPOTrainer( 114 | model=sft_model, 115 | beta=script_args.beta, 116 | args=script_args.training_args, 117 | train_dataset=train_dataset, 118 | eval_dataset=eval_dataset, 119 | tokenizer=tokenizer, 120 | peft_config=script_args.peft_config, 121 | max_length=script_args.max_length, 122 | 
num_proc=script_args.num_proc, 123 | generate_during_eval=script_args.generate_during_eval, 124 | ) 125 | if Accelerator().is_local_main_process and script_args.peft_config: 126 | trainer.model.print_trainable_parameters() 127 | 128 | @dataclass 129 | class LengthPenaltyWrapper(RewardWrapperBase): 130 | """Penalizes longer responses.""" 131 | tokenizer: PreTrainedTokenizerBase 132 | def __call__(self, inputs: RewardWrapperInput) -> torch.Tensor: 133 | from src.utils import prepare_input 134 | tokenized_responses = self.tokenizer(inputs.response) 135 | rewards = [-len(tokenized_response_id) for tokenized_response_id in tokenized_responses["input_ids"]] 136 | return prepare_input(torch.Tensor(rewards).to(torch.bfloat16)) 137 | 138 | trainer.set_wrapped_margin_reward_model_list( 139 | RewardWrapperList([LengthPenaltyWrapper(tokenizer=tokenizer)]), 140 | w=(1, script_args.w), 141 | prepare=False, 142 | ) 143 | trainer.train() 144 | 145 | save_name = "best_checkpoint" if script_args.training_args.load_best_model_at_end else "final_checkpoint" 146 | trainer.model.save_pretrained(os.path.join(script_args.training_args.output_dir, save_name)) 147 | trainer.tokenizer.save_pretrained(os.path.join(script_args.training_args.output_dir, save_name)) 148 | -------------------------------------------------------------------------------- /scripts/modpo/beavertails/modpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | 5 | import torch 6 | import tyro 7 | from accelerate import Accelerator 8 | from peft import LoraConfig 9 | from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments 10 | 11 | from src.trainer.modpo_trainer import MODPOTrainer 12 | from src.data.configs import DATASET_CONFIGS, DEFAULT_PROMPT_TEMPLATE 13 | from src.utils import ( 14 | print_local_main, disable_progress_bar_non_local_main, set_seeds, 15 | prepare_model_for_peft, param_sharding_enabled, PeftAsPreTrained, 16 | ) 17 | from src.utils.reward import RewardWrapperList, ImplicitRewardWrapper 18 | 19 | disable_progress_bar_non_local_main() 20 | 21 | 22 | @dataclass 23 | class ScriptArguments: 24 | 25 | sft_model_name: str = field(metadata={"help": "the sft model name"}) 26 | margin_reward_model_name: str = field(metadata={"help": "the margin reward model name"}) 27 | use_flash_attention_2: Optional[bool] = field(default=False, metadata={"help": "whether to use flash attention 2"}) 28 | prompt_template: Optional[str] = field(default=DEFAULT_PROMPT_TEMPLATE, metadata={"help": "the prompt template"}) 29 | dataset_name: Optional[str] = field(default="PKU-Alignment/PKU-SafeRLHF-10K", metadata={"help": "the dataset name"}) 30 | dataset_caching: Optional[bool] = field(default=False, metadata={"help": "used cached dataset"}) 31 | sanity_check: Optional[bool] = field(default=False, metadata={"help": "whether to conduct sanity check"}) 32 | 33 | w: Optional[float] = field(default=0.5, metadata={"help": "weight"}) 34 | beta: Optional[float] = field(default=0.1, metadata={"help": "beta for kl control"}) 35 | max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"}) 36 | num_proc: Optional[int] = field(default=4, metadata={"help": "num_proc for dataset.map"}) 37 | generate_during_eval: Optional[bool] = field(default=True, metadata={"help": "whether to generate during evaluation"}) 38 | 39 | training_args: TrainingArguments = field( 40 | default_factory=lambda: 
TrainingArguments( 41 | output_dir="./output/dev/modpo", 42 | overwrite_output_dir=True, 43 | seed=42, 44 | 45 | per_device_train_batch_size=4, 46 | per_device_eval_batch_size=4, 47 | gradient_accumulation_steps=2, 48 | learning_rate=1e-4, 49 | lr_scheduler_type="cosine", 50 | warmup_steps=0.1, 51 | weight_decay=0.05, 52 | fp16=True, 53 | remove_unused_columns=False, 54 | run_name="dev_modpo", 55 | report_to="wandb", 56 | 57 | num_train_epochs=3, 58 | logging_steps=10, 59 | save_steps=0.25, 60 | eval_steps=0.25, 61 | eval_delay=0.25, 62 | evaluation_strategy="steps", 63 | save_total_limit=3, 64 | load_best_model_at_end=True, 65 | ) 66 | ) 67 | 68 | peft: Optional[bool] = field(default=True, metadata={"help": "whether to use peft for training"}) 69 | peft_config: LoraConfig = field( 70 | default_factory=lambda: LoraConfig( 71 | r=16, 72 | lora_alpha=32, 73 | lora_dropout=0.05, 74 | bias="none", 75 | task_type="CAUSAL_LM", 76 | ) 77 | ) 78 | 79 | script_args = tyro.cli(ScriptArguments) 80 | set_seeds(script_args.training_args.seed) 81 | if not script_args.peft: 82 | script_args.peft_config = None 83 | 84 | # base model 85 | print_local_main("loading model...") 86 | sft_model = AutoModelForCausalLM.from_pretrained( 87 | script_args.sft_model_name, 88 | use_flash_attention_2=script_args.use_flash_attention_2, # flash attn 89 | torch_dtype=torch.bfloat16, # necessary for llama2, otherwise will be cast to float32 90 | **({"device_map": {"": Accelerator().local_process_index}} if not param_sharding_enabled() else {}), 91 | ) 92 | sft_model.config.update({ 93 | "use_cache": False, 94 | "pad_token_id": sft_model.config.eos_token_id 95 | }) 96 | print_local_main(sft_model) 97 | print_local_main(script_args.peft_config) 98 | 99 | # peft 100 | model = prepare_model_for_peft(sft_model, peft_config=script_args.peft_config, args=script_args.training_args) 101 | # load frozon margin reward weights as lora 102 | model.load_adapter(script_args.margin_reward_model_name, adapter_name="margin_reward") 103 | 104 | # tokenizer 105 | tokenizer = AutoTokenizer.from_pretrained(script_args.sft_model_name, trust_remote_code=True) 106 | tokenizer.pad_token = tokenizer.eos_token 107 | tokenizer.padding_side = "right" 108 | 109 | # dataset 110 | if not script_args.dataset_caching: 111 | from datasets import disable_caching 112 | disable_caching() 113 | rdp = DATASET_CONFIGS[script_args.dataset_name]( 114 | prompt_template=script_args.prompt_template, 115 | sanity_check=script_args.sanity_check, 116 | ) 117 | train_dataset = rdp.get_preference_dataset(split="train") 118 | eval_dataset = rdp.get_preference_dataset(split="validation") 119 | 120 | # get ready for training 121 | print_local_main("start training...") 122 | trainer = MODPOTrainer( 123 | model=model, 124 | beta=script_args.beta, 125 | args=script_args.training_args, 126 | train_dataset=train_dataset, 127 | eval_dataset=eval_dataset, 128 | tokenizer=tokenizer, 129 | max_length=script_args.max_length, 130 | num_proc=script_args.num_proc, 131 | generate_during_eval=script_args.generate_during_eval, 132 | ) 133 | if Accelerator().is_local_main_process: 134 | trainer.model.print_trainable_parameters() 135 | trainer.set_wrapped_margin_reward_model_list( 136 | RewardWrapperList([ 137 | ImplicitRewardWrapper( 138 | model=PeftAsPreTrained(trainer.model, "margin_reward"), 139 | ref_model=PeftAsPreTrained(trainer.model), 140 | tokenizer=tokenizer, 141 | beta=script_args.beta, 142 | prompt_template=script_args.prompt_template, 143 | ) 144 | ]), 145 | w=(script_args.w, 
1-script_args.w), 146 | prepare=False, # avoid extra copies of the model weights; margin reward has been prepared as part of lora weights of the main model 147 | ) 148 | trainer.train() 149 | 150 | save_name = "best_checkpoint" if script_args.training_args.load_best_model_at_end else "final_checkpoint" 151 | trainer.model.save_pretrained(os.path.join(script_args.training_args.output_dir, save_name)) 152 | trainer.tokenizer.save_pretrained(os.path.join(script_args.training_args.output_dir, save_name)) 153 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import inspect 3 | from dataclasses import dataclass 4 | from contextlib import contextmanager 5 | from collections.abc import Mapping 6 | from typing import Optional, Text, Any 7 | 8 | import torch 9 | import numpy as np 10 | from peft import PeftModel 11 | from accelerate import Accelerator 12 | 13 | from trl.import_utils import is_peft_available 14 | 15 | 16 | if is_peft_available(): 17 | from peft import get_peft_model, prepare_model_for_kbit_training 18 | 19 | 20 | def prepare_model_for_peft(model, peft_config, args): 21 | if not is_peft_available() and peft_config is not None: 22 | raise ValueError( 23 | "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" 24 | ) 25 | elif is_peft_available() and peft_config is not None: 26 | if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): 27 | _support_gc_kwargs = hasattr( 28 | args, "gradient_checkpointing_kwargs" 29 | ) and "gradient_checkpointing_kwargs" in list( 30 | inspect.signature(prepare_model_for_kbit_training).parameters 31 | ) 32 | 33 | preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} 34 | 35 | if _support_gc_kwargs: 36 | preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs 37 | 38 | model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) 39 | elif getattr(args, "gradient_checkpointing", False): 40 | # For backward compatibility with older versions of transformers 41 | if hasattr(model, "enable_input_require_grads"): 42 | model.enable_input_require_grads() 43 | else: 44 | 45 | def make_inputs_require_grad(module, input, output): 46 | output.requires_grad_(True) 47 | 48 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 49 | model = get_peft_model(model, peft_config) 50 | # For models that use gradient_checkpoiting, we need to attach a hook that enables input 51 | # to explicitly have `requires_grad=True`, otherwise training will either silently 52 | # fail or completely fail. 
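# Usage sketch (mirrors scripts/modpo/beavertails/modpo.py; variable names here are illustrative):
#   model = prepare_model_for_peft(sft_model, peft_config=LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32), args=training_args)
#   model.load_adapter(margin_reward_model_name, adapter_name="margin_reward")  # frozen margin reward loaded as a second LoRA adapter
# The `elif` below covers the case without a peft_config: it only re-enables input grads so gradient checkpointing still works.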
53 | elif getattr(args, "gradient_checkpointing", False): 54 | # For backward compatibility with older versions of transformers 55 | if hasattr(model, "enable_input_require_grads"): 56 | model.enable_input_require_grads() 57 | else: 58 | 59 | def make_inputs_require_grad(module, input, output): 60 | output.requires_grad_(True) 61 | 62 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 63 | 64 | return model 65 | 66 | 67 | def common_prefix_length(list_a, list_b): 68 | length = 0 69 | for i in range(min(len(list_a), len(list_b))): 70 | if list_a[i] == list_b[i]: 71 | length += 1 72 | else: 73 | break 74 | return length 75 | 76 | 77 | def pad_labels(features, tokenizer, pad_to_multiple_of=None, label_pad_token_id=-100): 78 | # copied from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/data/data_collator.py#L562-L584 79 | labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None 80 | # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the 81 | # same length to return tensors. 82 | if labels is not None: 83 | max_label_length = max(len(l) for l in labels) 84 | if pad_to_multiple_of is not None: 85 | max_label_length = ( 86 | (max_label_length + pad_to_multiple_of - 1) 87 | // pad_to_multiple_of 88 | * pad_to_multiple_of 89 | ) 90 | 91 | padding_side = tokenizer.padding_side 92 | for feature in features: 93 | remainder = [label_pad_token_id] * (max_label_length - len(feature["labels"])) 94 | if isinstance(feature["labels"], list): 95 | feature["labels"] = ( 96 | feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"] 97 | ) 98 | elif padding_side == "right": 99 | feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64) 100 | else: 101 | feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64) 102 | 103 | 104 | def get_batch_logps( 105 | logits: torch.FloatTensor, 106 | labels: torch.LongTensor, 107 | average_log_prob: bool = False, 108 | label_pad_token_id: int = -100, 109 | ) -> torch.FloatTensor: 110 | if logits.shape[:-1] != labels.shape: 111 | raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") 112 | 113 | labels = labels[:, 1:].clone() 114 | logits = logits[:, :-1, :] 115 | loss_mask = labels != label_pad_token_id 116 | 117 | # dummy token; we'll ignore the losses on these tokens later 118 | labels[labels == label_pad_token_id] = 0 119 | 120 | per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) 121 | 122 | if average_log_prob: 123 | return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) 124 | else: 125 | return (per_token_logps * loss_mask).sum(-1) 126 | 127 | 128 | def set_adapter_ctx(model, adapter_name): 129 | @contextmanager 130 | def _set_adapter_ctx(): 131 | old_adapter_name = model.active_adapter 132 | try: 133 | if adapter_name is not None: 134 | model.set_adapter(adapter_name) 135 | yield model 136 | else: 137 | with model.disable_adapter(): 138 | yield model 139 | finally: 140 | model.set_adapter(old_adapter_name) 141 | return _set_adapter_ctx 142 | 143 | 144 | @dataclass 145 | class PeftAsPreTrained: 146 | model: PeftModel 147 | adapter_name: Optional[Text] = None 148 | 149 | def __post_init__(self): 150 | assert isinstance(self.model, PeftModel) 151 | if self.adapter_name: 152 | self.ctx = set_adapter_ctx(self.model, self.adapter_name) 153 | 
else: 154 | self.ctx = self.model.disable_adapter 155 | 156 | def __call__(self, *args, **kwargs): 157 | with self.ctx(): 158 | outputs = self.model(*args, **kwargs) 159 | return outputs 160 | 161 | def generate(self, *args, **kwargs): 162 | with self.ctx(): 163 | outputs = self.model.generate(*args, **kwargs) 164 | return outputs 165 | 166 | def __getattribute__(self, name: str) -> Any: 167 | try: 168 | return super().__getattribute__(name) 169 | except AttributeError: 170 | return getattr(self.model, name) 171 | 172 | 173 | @Accelerator().on_local_main_process 174 | def print_local_main(text): 175 | print(text) 176 | 177 | 178 | def disable_progress_bar_non_local_main(): 179 | if not Accelerator().is_local_main_process: 180 | import datasets 181 | import transformers 182 | import warnings 183 | datasets.utils.logging.disable_progress_bar() 184 | transformers.utils.logging.disable_progress_bar() 185 | warnings.filterwarnings('ignore') 186 | 187 | 188 | def param_sharding_enabled(): 189 | from transformers.modeling_utils import is_deepspeed_zero3_enabled, is_fsdp_enabled 190 | return is_deepspeed_zero3_enabled() or is_fsdp_enabled() 191 | 192 | 193 | def prepare_input(data): 194 | # adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L2626 195 | if isinstance(data, Mapping): 196 | return type(data)({k: prepare_input(v) for k, v in data.items()}) 197 | elif isinstance(data, (tuple, list)): 198 | return type(data)(prepare_input(v) for v in data) 199 | elif isinstance(data, torch.Tensor): 200 | kwargs = {"device": Accelerator().device} 201 | # TODO: inference-time deepspeed? 202 | # if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): 203 | # # NLP models inputs are int/uint and those get adjusted to the right dtype of the 204 | # # embedding. Other models such as wav2vec2's inputs are already float and thus 205 | # # may need special handling to match the dtypes of the model 206 | # kwargs.update({"dtype": Accelerator().state.deepspeed_plugin.hf_ds_config.dtype()}) 207 | return data.to(**kwargs) 208 | return data 209 | 210 | 211 | def set_seeds(seed): 212 | import random 213 | import numpy as np 214 | import torch 215 | random.seed(seed) 216 | np.random.seed(seed) 217 | torch.manual_seed(seed) 218 | torch.cuda.manual_seed(seed) 219 | torch.backends.cudnn.deterministic = True 220 | -------------------------------------------------------------------------------- /src/trainer/modpo_trainer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Callable, Dict, List, Literal, Optional, Tuple, Union, Any 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from datasets import Dataset 8 | from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, TrainingArguments 9 | from transformers.trainer_callback import TrainerCallback 10 | from transformers.trainer_utils import EvalLoopOutput 11 | 12 | from src.trainer.dpo_trainer import DPOTrainer, DPODataMapFunc, DPODataCollatorWithPadding 13 | from src.utils.reward import RewardWrapperList, RewardWrapperInput 14 | 15 | 16 | @dataclass 17 | class MODPODataMapFunc(DPODataMapFunc): 18 | def __call__(self, examples): 19 | """ 20 | Additionally keep untokenized prompts (`raw_prompt`) and responses (`chosen`, `rejected`) 21 | in the batch for easy adaptation for customized margin reward models (`src.utils.RewardWrapperBase`). 
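These fields are later collated by `MODPODataCollatorWithPadding` into `raw_prompt` (the prompt list repeated twice) and `response` (all chosen responses followed by all rejected ones), which is what the margin reward wrappers receive in `get_batch_metrics`.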
22 | 23 | For example, a margin reward model can be an external API that depends on raw text. 24 | """ 25 | new_examples = super().__call__(examples) 26 | new_examples["raw_prompt"] = examples["raw_prompt"] 27 | new_examples["chosen"] = examples["chosen"] 28 | new_examples["rejected"] = examples["rejected"] 29 | return new_examples 30 | 31 | 32 | @dataclass 33 | class MODPODataCollatorWithPadding(DPODataCollatorWithPadding): 34 | def __call__(self, features: List[Dict[str, Any]], generate: Optional[bool] = False) -> Dict[str, Any]: 35 | batch = super().__call__(features, generate) 36 | if not generate: 37 | batch["raw_prompt"] = [feature["raw_prompt"] for feature in features]*2 38 | batch["response"] = [feature["chosen"] for feature in features] + [feature["rejected"] for feature in features] 39 | return batch 40 | 41 | 42 | class MODPOTrainer(DPOTrainer): 43 | """ 44 | The MODPOTrainer is a lightweight extension of DPOTrainer that supports training with 45 | multiple margin reward models for multi-objective alignment. 46 | 47 | Please use `set_wrapped_margin_reward_model_list` to set your customized margin reward models 48 | (`wrapped_margin_reward_model_list`) and the weights for each objective (`w`). 49 | """ 50 | 51 | def __init__( 52 | self, 53 | model: Union[PreTrainedModel, nn.Module] = None, 54 | ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, 55 | beta: float = 0.1, 56 | loss_type: Literal["sigmoid", "hinge"] = "sigmoid", 57 | args: TrainingArguments = None, 58 | tokenize_map_func: Optional[Callable] = None, 59 | data_collator: Optional[DataCollator] = None, 60 | label_pad_token_id: int = -100, 61 | train_dataset: Optional[Dataset] = None, 62 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, 63 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 64 | model_init: Optional[Callable[[], PreTrainedModel]] = None, 65 | callbacks: Optional[List[TrainerCallback]] = None, 66 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( 67 | None, 68 | None, 69 | ), 70 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, 71 | peft_config: Optional[Dict] = None, 72 | disable_dropout: bool = True, 73 | max_length: Optional[int] = 1024, 74 | num_proc: Optional[int] = 4, 75 | generate_during_eval: bool = True, 76 | compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, 77 | ): 78 | 79 | if tokenize_map_func is None: 80 | tokenize_map_func = MODPODataMapFunc(tokenizer) 81 | 82 | if data_collator is None: 83 | data_collator = MODPODataCollatorWithPadding(tokenizer) 84 | 85 | super().__init__( 86 | model=model, 87 | ref_model=ref_model, 88 | beta=beta, 89 | loss_type=loss_type, 90 | args=args, 91 | tokenize_map_func=tokenize_map_func, 92 | data_collator=data_collator, 93 | label_pad_token_id=label_pad_token_id, 94 | train_dataset=train_dataset, 95 | eval_dataset=eval_dataset, 96 | tokenizer=tokenizer, 97 | model_init=model_init, 98 | callbacks=callbacks, 99 | optimizers=optimizers, 100 | preprocess_logits_for_metrics=preprocess_logits_for_metrics, 101 | peft_config=peft_config, 102 | disable_dropout=disable_dropout, 103 | max_length=max_length, 104 | num_proc=num_proc, 105 | generate_during_eval=generate_during_eval, 106 | compute_metrics=compute_metrics, 107 | ) 108 | 109 | def set_wrapped_margin_reward_model_list( 110 | self, 111 | wrapped_margin_reward_model_list: RewardWrapperList, 112 | w: List[float], 113 | prepare: Optional[bool] = True, 114 | ): 115 | """ 116 | Set
margin reward models. 117 | 118 | Args: 119 | wrapped_margin_reward_model_list (`src.utils.RewardWrapperList`): 120 | A list of reward models that act as margins in `modpo_loss`. 121 | w (`List[float]`): 122 | A list of weights for each objective. Note that w[0] indicates the weight for 123 | the preference that we are currently training on and w[1:] indicate the weights for 124 | the margin reward models in `wrapped_margin_reward_model_list`. 125 | prepare (`bool`): 126 | Whether or not we need to `self.accelerator.prepare_model` the margin reward models for advanced distributed training. 127 | If these margin reward models are part of `self.model` (e.g., LoRA weights), they will have been prepared 128 | in `__init__` and we recommend `prepare=False` to avoid unnecessary copies of the model weights. 129 | See `scripts/modpo/beavertails/modpo.py` for a complete example. 130 | """ 131 | if prepare: 132 | def prepare(wrapped_reward_model): 133 | if hasattr(wrapped_reward_model, "model"): 134 | wrapped_reward_model.model = self.accelerator.prepare_model( 135 | wrapped_reward_model.model, evaluation_mode=True) 136 | return wrapped_reward_model 137 | wrapped_margin_reward_model_list = wrapped_margin_reward_model_list.map(prepare) 138 | self.wrapped_margin_reward_model_list = wrapped_margin_reward_model_list 139 | self.w = torch.tensor(w).to(self.accelerator.device) 140 | assert len(self.wrapped_margin_reward_model_list) == len(self.w) - 1 141 | 142 | def modpo_loss( 143 | self, 144 | policy_chosen_logps: torch.FloatTensor, 145 | policy_rejected_logps: torch.FloatTensor, 146 | reference_chosen_logps: torch.FloatTensor, 147 | reference_rejected_logps: torch.FloatTensor, 148 | chosen_margin_reward: torch.FloatTensor, 149 | rejected_margin_reward: torch.FloatTensor, 150 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 151 | chosen_rewards = (1/self.w[0])*(self.beta * (policy_chosen_logps - reference_chosen_logps) - chosen_margin_reward @ self.w[1:]) # implicit reward minus weighted margin rewards, rescaled by 1/w[0] 152 | rejected_rewards = (1/self.w[0])*(self.beta * (policy_rejected_logps - reference_rejected_logps) - rejected_margin_reward @ self.w[1:]) 153 | 154 | logits = chosen_rewards - rejected_rewards 155 | if self.loss_type == "sigmoid": 156 | losses = -F.logsigmoid(logits) 157 | elif self.loss_type == "hinge": 158 | losses = torch.relu(1 - logits) 159 | else: 160 | raise ValueError(f"Unknown loss type: {self.loss_type}.
Should be one of ['sigmoid', 'hinge']") 161 | 162 | return losses, chosen_rewards.detach(), rejected_rewards.detach() 163 | 164 | def dpo_loss(self, *args, **kwargs): 165 | """Disable the `dpo_loss` inherited from the DPOTrainer""" 166 | raise NotImplementedError 167 | 168 | def get_batch_metrics( 169 | self, 170 | model, 171 | batch: Dict[str, Union[List, torch.LongTensor]], 172 | train_eval: Literal["train", "eval"] = "train", 173 | ): 174 | metrics = {} 175 | 176 | ( 177 | policy_chosen_logps, 178 | policy_rejected_logps, 179 | _, 180 | _, 181 | ) = self.forward(model, batch) 182 | with torch.no_grad(): 183 | ( 184 | reference_chosen_logps, 185 | reference_rejected_logps, 186 | _, 187 | _, 188 | ) = self.forward(self.ref_model, batch) 189 | 190 | margin_reward_list = self.wrapped_margin_reward_model_list( 191 | RewardWrapperInput(raw_prompt=batch["raw_prompt"], response=batch["response"])) 192 | margin_rewards = torch.stack(margin_reward_list, dim=-1).to( 193 | policy_chosen_logps.dtype).to(self.accelerator.device) # (B*2, n-1) 194 | chosen_margin_rewards, rejected_margin_rewards = margin_rewards.chunk(2) # (B, n-1) 195 | 196 | losses, chosen_rewards, rejected_rewards = self.modpo_loss( 197 | policy_chosen_logps, 198 | policy_rejected_logps, 199 | reference_chosen_logps, 200 | reference_rejected_logps, 201 | chosen_margin_rewards, 202 | rejected_margin_rewards, 203 | ) 204 | 205 | accuracies = (chosen_rewards > rejected_rewards).float() 206 | 207 | prefix = "eval_" if train_eval == "eval" else "" 208 | metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu() 209 | metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu() 210 | metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu() 211 | metrics[f"{prefix}logps/margins"] = (policy_chosen_logps - policy_rejected_logps).detach().cpu() 212 | metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu() 213 | metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu() 214 | if train_eval == "train": 215 | metrics[f"{prefix}accuracy"] = accuracies.detach().cpu() 216 | 217 | return losses.mean(), metrics 218 | -------------------------------------------------------------------------------- /src/trainer/sft_trainer.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Callable, Dict, List, Optional, Tuple, Union, Any 4 | 5 | import wandb 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils.data import DataLoader 9 | from datasets import Dataset 10 | from transformers import ( 11 | AutoTokenizer, 12 | DataCollator, 13 | PreTrainedModel, 14 | PreTrainedTokenizerBase, 15 | Trainer, 16 | TrainingArguments, 17 | ) 18 | from transformers.trainer_callback import TrainerCallback 19 | from transformers.trainer_utils import EvalPrediction 20 | from transformers import PreTrainedTokenizerBase, TrainingArguments 21 | from transformers.trainer_utils import EvalLoopOutput 22 | from trl.import_utils import is_peft_available, is_wandb_available 23 | from trl.trainer.utils import ( 24 | PeftSavingCallback, 25 | ) 26 | 27 | from src.utils import pad_labels, print_local_main, prepare_model_for_peft, common_prefix_length 28 | 29 | 30 | if is_peft_available(): 31 | from peft import PeftConfig, PeftModel 32 | 33 | 34 | @dataclass 35 | class SFTDataMapFunc: 36 | """Map raw texts to tokens, attention masks, and labels.""" 37 | tokenizer: PreTrainedTokenizerBase 38 | label_pad_token_id: 
Optional[int] = -100 39 | completion_only: Optional[bool] = True 40 | 41 | def __call__(self, examples): 42 | new_examples = { 43 | "prompt_response_input_ids": [], 44 | "prompt_response_attention_mask": [], 45 | "prompt_response_labels": [], 46 | 47 | "prompt_input_ids": [], 48 | "prompt_attention_mask": [], 49 | 50 | "prompt": [], 51 | } 52 | for prompt, response in zip(examples["prompt"], examples["response"]): 53 | prompt_tokens = self.tokenizer(prompt) 54 | prompt_response_tokens = self.tokenizer(prompt + response) 55 | # add EOS to response 56 | prompt_response_tokens["input_ids"].append(self.tokenizer.eos_token_id) 57 | prompt_response_tokens["attention_mask"].append(1) 58 | 59 | prompt_len = common_prefix_length(prompt_tokens["input_ids"], prompt_response_tokens["input_ids"]) 60 | 61 | for k, toks in { 62 | "prompt": prompt_tokens, 63 | "prompt_response": prompt_response_tokens, 64 | }.items(): 65 | for type_key, tokens in toks.items(): 66 | new_examples[f"{k}_{type_key}"].append(tokens) 67 | 68 | for k, toks in { 69 | "prompt_response": prompt_response_tokens, 70 | }.items(): 71 | labels = toks["input_ids"].copy() 72 | if self.completion_only: 73 | labels[:prompt_len] = [self.label_pad_token_id] * prompt_len 74 | new_examples[f"{k}_labels"].append(labels) 75 | 76 | new_examples["prompt"] = examples["prompt"] 77 | 78 | return new_examples 79 | 80 | 81 | @dataclass 82 | class SFTDataCollatorWithPadding: 83 | tokenizer: PreTrainedTokenizerBase 84 | label_pad_token_id: Optional[int] = -100 85 | pad_to_multiple_of: Optional[int] = None 86 | 87 | def __call__(self, features: List[Dict[str, Any]], generate: Optional[bool] = False) -> Dict[str, Any]: 88 | """ 89 | if not generate: 90 | batch = { 91 | "input_ids": ..., 92 | "attention_mask": ..., 93 | "labels": ..., 94 | } 95 | else: 96 | batch = { 97 | "prompt": ..., 98 | "prompt_input_ids": ..., 99 | "prompt_attention_mask": ..., 100 | } 101 | """ 102 | if not generate: 103 | 104 | # right padding for training 105 | right_padding_features = [] 106 | for feature in features: 107 | right_padding_features.append( 108 | { 109 | "input_ids": feature["prompt_response_input_ids"], 110 | "attention_mask": feature["prompt_response_attention_mask"], 111 | "labels": feature["prompt_response_labels"], 112 | } 113 | ) 114 | 115 | pad_labels(right_padding_features, self.tokenizer, self.pad_to_multiple_of, self.label_pad_token_id) 116 | 117 | right_padding_batch = self.tokenizer.pad( 118 | right_padding_features, 119 | padding=True, 120 | pad_to_multiple_of=self.pad_to_multiple_of, 121 | return_tensors="pt", 122 | ) 123 | 124 | return right_padding_batch 125 | 126 | else: 127 | 128 | # left padding for batched generation 129 | left_padding_features = [] 130 | padding_side_default = self.tokenizer.padding_side 131 | self.tokenizer.padding_side = "left" 132 | for feature in features: 133 | left_padding_features.append( 134 | { 135 | "input_ids": feature["prompt_input_ids"], 136 | "attention_mask": feature["prompt_attention_mask"], 137 | } 138 | ) 139 | left_padding_batch = self.tokenizer.pad( 140 | left_padding_features, 141 | padding=True, 142 | pad_to_multiple_of=self.pad_to_multiple_of, 143 | return_tensors="pt", 144 | ) 145 | self.tokenizer.padding_side = padding_side_default 146 | 147 | return { 148 | "prompt": [feature["prompt"] for feature in features], 149 | "prompt_input_ids": left_padding_batch["input_ids"], 150 | "prompt_attention_mask": left_padding_batch["attention_mask"], 151 | } 152 | 153 | 154 | class SFTTrainer(Trainer): 155 | r""" 156 | 
Class definition of the Supervised Finetuning Trainer (SFT Trainer). 157 | This class is a wrapper around the `transformers.Trainer` class and inherits all of its attributes and methods. 158 | The trainer takes care of properly initializing the PeftModel in case a user passes a `PeftConfig` object. 159 | 160 | Args: 161 | model (Union[`transformers.PreTrainedModel`, `nn.Module`, `str`]): 162 | The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to 163 | load from cache or download. The model can also be converted to a `PeftModel` if a `PeftConfig` object is 164 | passed to the `peft_config` argument. 165 | args (Optional[`transformers.TrainingArguments`]): 166 | The arguments to tweak for training. Please refer to the official documentation of `transformers.TrainingArguments` 167 | for more information. 168 | data_collator (Optional[`transformers.DataCollator`]): 169 | The data collator to use for training. 170 | train_dataset (Optional[`datasets.Dataset`]): 171 | The dataset to use for training. A raw dataset with `prompt` and `response` columns is tokenized automatically with `SFTDataMapFunc`. 172 | eval_dataset (Optional[Union[`datasets.Dataset`, Dict[`str`, `datasets.Dataset`]]]): 173 | The dataset to use for evaluation. It is preprocessed in the same way as `train_dataset`. 174 | tokenizer (Optional[`transformers.PreTrainedTokenizer`]): 175 | The tokenizer to use for training. If not specified, the tokenizer associated with the model will be used. 176 | model_init (`Callable[[], transformers.PreTrainedModel]`): 177 | The model initializer to use for training. If None is specified, the default model initializer will be used. 178 | compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional*): 179 | The metrics to use for evaluation. 180 | callbacks (`List[transformers.TrainerCallback]`): 181 | The callbacks to use for training. 182 | optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): 183 | The optimizer and scheduler to use for training. 184 | preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): 185 | The function to use to preprocess the logits before computing the metrics. 186 | peft_config (`Optional[PeftConfig]`): 187 | The PeftConfig object to use to initialize the PeftModel. 188 | max_length (`Optional[int]`): 189 | The maximum sequence length used to truncate tokenized samples and to cap generation during evaluation. Defaults to `1024`. 
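Example (an illustrative sketch rather than the repository's actual entry point; the base checkpoint, dataset splits, LoRA hyperparameters, and output directory are assumptions):

```python
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

from src.data.configs import DATASET_CONFIGS
from src.trainer.sft_trainer import SFTTrainer

model_name = "meta-llama/Llama-2-7b-hf"  # assumed base model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Raw datasets with "prompt" and "response" columns, as expected by SFTDataMapFunc.
dataset_config = DATASET_CONFIGS["PKU-Alignment/PKU-SafeRLHF-10K-better"](sanity_check=True)
train_dataset = dataset_config.get_sft_dataset(split="train")
eval_dataset = dataset_config.get_sft_dataset(split="test")  # assumed split name

trainer = SFTTrainer(
    model=model,
    args=TrainingArguments(
        output_dir="./output/sft",  # placeholder
        per_device_train_batch_size=2,
        num_train_epochs=1,
        remove_unused_columns=False,  # keep the tokenized columns used by the custom collator
    ),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM"),  # optional LoRA
    max_length=1024,
    generate_during_eval=False,  # set True (default) to log sampled generations to wandb
)
trainer.train()
```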
190 | """ 191 | 192 | def __init__( 193 | self, 194 | model: Union[PreTrainedModel, nn.Module, str] = None, 195 | args: TrainingArguments = None, 196 | tokenize_map_func: Optional[Callable] = None, 197 | data_collator: Optional[DataCollator] = None, 198 | train_dataset: Optional[Dataset] = None, 199 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, 200 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 201 | model_init: Optional[Callable[[], PreTrainedModel]] = None, 202 | compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, 203 | callbacks: Optional[List[TrainerCallback]] = None, 204 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), 205 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, 206 | peft_config: Optional[PeftConfig] = None, 207 | max_length: Optional[int] = 1024, 208 | completion_only: Optional[bool] = True, 209 | num_proc: Optional[int] = 4, 210 | generate_during_eval: Optional[bool] = True, 211 | ): 212 | if not isinstance(model, PeftModel) and is_peft_available() and peft_config: 213 | model = prepare_model_for_peft(model, peft_config, args) 214 | 215 | if tokenize_map_func is None: 216 | tokenize_map_func = SFTDataMapFunc(tokenizer, completion_only=completion_only) 217 | 218 | if data_collator is None: 219 | data_collator = SFTDataCollatorWithPadding(tokenizer) 220 | 221 | if tokenizer is None: 222 | tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) 223 | if getattr(tokenizer, "pad_token", None) is None: 224 | tokenizer.pad_token = tokenizer.eos_token 225 | 226 | # preprocess dataset 227 | def preprocess_dataset(dataset): 228 | # tokenize samples 229 | dataset = dataset.map( 230 | tokenize_map_func, 231 | batched=True, 232 | num_proc=num_proc, 233 | remove_columns=dataset.column_names 234 | ) 235 | # truncate samples that are too long 236 | dataset = dataset.map( 237 | lambda sample: {k: v[:max_length] for k, v in sample.items()}, 238 | num_proc=num_proc, 239 | ) 240 | return dataset 241 | if "prompt_response_input_ids" not in train_dataset[0].keys(): 242 | print_local_main("dataset preprocessing...") 243 | train_dataset = preprocess_dataset(train_dataset) 244 | eval_dataset = preprocess_dataset(eval_dataset) 245 | 246 | if is_peft_available() and isinstance(model, PeftModel): 247 | if callbacks is None: 248 | callbacks = [PeftSavingCallback()] 249 | else: 250 | callbacks += [PeftSavingCallback()] 251 | 252 | if generate_during_eval and not is_wandb_available(): 253 | raise ValueError( 254 | "`generate_during_eval=True` requires Weights and Biases to be installed." 255 | " Please install `wandb` to resolve." 
256 | ) 257 | 258 | self.max_length = max_length 259 | self.generate_during_eval = generate_during_eval 260 | 261 | self.table = None # for late initialization 262 | 263 | super().__init__( 264 | model=model, 265 | args=args, 266 | data_collator=data_collator, 267 | train_dataset=train_dataset, 268 | eval_dataset=eval_dataset, 269 | tokenizer=tokenizer, 270 | model_init=model_init, 271 | compute_metrics=compute_metrics, 272 | callbacks=callbacks, 273 | optimizers=optimizers, 274 | preprocess_logits_for_metrics=preprocess_logits_for_metrics, 275 | ) 276 | 277 | def train( 278 | self, 279 | resume_from_checkpoint: Optional[Union[str, bool]] = None, 280 | trial: Union["optuna.Trial", Dict[str, Any]] = None, 281 | ignore_keys_for_eval: Optional[List[str]] = None, 282 | **kwargs, 283 | ): 284 | initial_output = super().train( 285 | resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs, 286 | ) 287 | 288 | # upload wandb table at the end of training if it exists 289 | if self.table: 290 | self.log({"eval_game_log": self.table}) 291 | self.state.log_history.pop() 292 | 293 | return initial_output 294 | 295 | def evaluation_loop( 296 | self, 297 | dataloader: DataLoader, 298 | description: str, 299 | prediction_loss_only: Optional[bool] = None, 300 | ignore_keys: Optional[List[str]] = None, 301 | metric_key_prefix: str = "eval", 302 | ) -> EvalLoopOutput: 303 | # adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L600-L647 304 | if self.generate_during_eval and self.state.is_world_process_zero: 305 | # late init 306 | self.table = wandb.Table(columns=["Prompt", "Policy"]) if self.table == None else self.table 307 | 308 | print("generating response...") 309 | # Generate random indices within the range of the total number of samples 310 | num_samples = len(dataloader.dataset) 311 | random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) 312 | 313 | # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader 314 | random_batch_dataset = dataloader.dataset.select(random_indices) 315 | random_batch = self.data_collator(random_batch_dataset, generate=True) 316 | random_batch = self._prepare_inputs(random_batch) 317 | 318 | # get batch samples 319 | policy_output = self.model.generate( 320 | input_ids=random_batch["prompt_input_ids"], 321 | attention_mask=random_batch["prompt_attention_mask"], 322 | max_length=self.max_length, 323 | do_sample=True, 324 | pad_token_id=self.tokenizer.pad_token_id, 325 | ) 326 | 327 | response_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) 328 | 329 | for prompt, response in zip(random_batch["prompt"], response_decoded): 330 | self.table.add_data(f"(epoch{self.state.epoch}) {prompt}", response[len(prompt):]) 331 | 332 | # barrier 333 | self.accelerator.wait_for_everyone() 334 | 335 | # Base evaluation 336 | initial_output = super().evaluation_loop( 337 | dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix 338 | ) 339 | 340 | return initial_output 341 | 342 | 343 | if __name__ == '__main__': 344 | from src.data.configs import DATASET_CONFIGS 345 | from transformers import LlamaTokenizer 346 | dataset = DATASET_CONFIGS["PKU-Alignment/PKU-SafeRLHF-10K-better"](sanity_check=True).get_sft_dataset(split="train") 347 | tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") 348 | tokenizer.pad_token = tokenizer.eos_token 349 | 350 | # SFTDataMapFunc unit test 351 | dataset = 
dataset.map(SFTDataMapFunc(tokenizer=tokenizer), batched=True, remove_columns=dataset.column_names) 352 | 353 | # SFTDataCollatorWithPadding unit test 354 | batch = SFTDataCollatorWithPadding(tokenizer=tokenizer)([dataset[0], dataset[1]]) 355 | batch_prompt = SFTDataCollatorWithPadding(tokenizer=tokenizer)([dataset[0], dataset[1]], generate=True) 356 | breakpoint() 357 | -------------------------------------------------------------------------------- /src/trainer/rm_trainer.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass 3 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | from datasets import Dataset 9 | from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, AutoTokenizer 10 | from transformers.trainer_callback import TrainerCallback 11 | from transformers.trainer_pt_utils import nested_detach 12 | from transformers.trainer_utils import EvalPrediction 13 | from trl.import_utils import is_peft_available 14 | from trl.trainer.training_configs import RewardConfig 15 | from trl.trainer.utils import PeftSavingCallback, RewardDataCollatorWithPadding, compute_accuracy 16 | 17 | from src.utils import print_local_main, prepare_model_for_peft 18 | 19 | 20 | if is_peft_available(): 21 | from peft import PeftModel 22 | 23 | 24 | @dataclass 25 | class RewardDataMapFunc: 26 | tokenizer: PreTrainedTokenizerBase 27 | 28 | def __call__(self, examples): 29 | prompt_chosen = [prompt + chosen for prompt, chosen in zip(examples["prompt"], examples["chosen"])] 30 | prompt_rejected = [prompt + rejected for prompt, rejected in zip(examples["prompt"], examples["rejected"])] 31 | chosen_sequence_tokens = self.tokenizer(prompt_chosen) 32 | rejected_sequence_tokens = self.tokenizer(prompt_rejected) 33 | 34 | return { 35 | "prompt_chosen_input_ids": chosen_sequence_tokens["input_ids"], 36 | "prompt_chosen_attention_mask": chosen_sequence_tokens["attention_mask"], 37 | "prompt_rejected_input_ids": rejected_sequence_tokens["input_ids"], 38 | "prompt_rejected_attention_mask": rejected_sequence_tokens["attention_mask"], 39 | **({"margin": examples["margin"]} if "margin" in examples else {}), 40 | } 41 | 42 | 43 | @dataclass 44 | class RewardDataCollatorWithPadding: 45 | r""" 46 | Reward DataCollator class that pads the inputs to the maximum length of the batch. 47 | Args: 48 | tokenizer (`PreTrainedTokenizerBase`): 49 | The tokenizer used for encoding the data. 50 | padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`): 51 | padding_strategy to pass to the tokenizer. 52 | max_length (`Optional[int]`, `optional`, defaults to `None`): 53 | The maximum length of the sequence to be processed. 54 | pad_to_multiple_of (`Optional[int]`, `optional`, defaults to `None`): 55 | If set will pad the sequence to a multiple of the provided value. 56 | return_tensors (`str`, `optional`, defaults to `"pt"`): 57 | The tensor type to use. 
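Example (an illustrative sketch of the resulting batch layout, not part of the original docstring; the tokenizer checkpoint and texts are made up):

```python
from transformers import AutoTokenizer

from src.trainer.rm_trainer import RewardDataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed tokenizer
tokenizer.pad_token = tokenizer.eos_token

features = [{
    "prompt_chosen_input_ids": tokenizer("Q: hi A: hello")["input_ids"],
    "prompt_chosen_attention_mask": tokenizer("Q: hi A: hello")["attention_mask"],
    "prompt_rejected_input_ids": tokenizer("Q: hi A: leave me alone")["input_ids"],
    "prompt_rejected_attention_mask": tokenizer("Q: hi A: leave me alone")["attention_mask"],
}]
batch = RewardDataCollatorWithPadding(tokenizer=tokenizer)(features)
# batch["input_ids"] has shape (2 * len(features), max_len): the first len(features) rows
# are the chosen sequences and the last len(features) rows are the rejected ones, so
# `rewards.chunk(2)` in `compute_loss` recovers the pairs.
```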
58 | """ 59 | tokenizer: PreTrainedTokenizerBase 60 | padding: Union[bool, str] = True 61 | max_length: Optional[int] = None 62 | pad_to_multiple_of: Optional[int] = None 63 | return_tensors: str = "pt" 64 | 65 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 66 | """ 67 | batch = { 68 | "input_ids": ..., 69 | "attention_mask": ..., 70 | "return_loss": True, 71 | "margin": ..., (optional) 72 | } 73 | """ 74 | new_features = [] 75 | # check if we have a margin. If we do, we need to batch it as well 76 | margin = [] 77 | if "margin" in features[0]: 78 | margin = [feature["margin"] for feature in features] 79 | for feature in features: 80 | new_features.append( 81 | { 82 | "input_ids": feature["prompt_chosen_input_ids"], 83 | "attention_mask": feature["prompt_chosen_attention_mask"], 84 | } 85 | ) 86 | for feature in features: 87 | new_features.append( 88 | { 89 | "input_ids": feature["prompt_rejected_input_ids"], 90 | "attention_mask": feature["prompt_rejected_attention_mask"], 91 | } 92 | ) 93 | batch = self.tokenizer.pad( 94 | new_features, 95 | padding=self.padding, 96 | max_length=self.max_length, 97 | pad_to_multiple_of=self.pad_to_multiple_of, 98 | return_tensors=self.return_tensors, 99 | ) 100 | batch["return_loss"] = True, 101 | if margin: 102 | margin = torch.tensor(margin, dtype=torch.float) 103 | batch["margin"] = margin 104 | return batch 105 | 106 | 107 | class RewardTrainer(Trainer): 108 | r""" 109 | The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the 110 | `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use 111 | an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset 112 | of paired examples, where each example is a tuple of two sequences. The reward model should be trained to 113 | predict which example in the pair is more relevant to the task at hand. 114 | 115 | The reward trainer expects a very specific format for the dataset. The dataset should contain two 4 entries at least 116 | if you don't use the default `RewardDataCollatorWithPadding` data collator. The entries should be named 117 | - `input_ids_chosen` 118 | - `attention_mask_chosen` 119 | - `input_ids_rejected` 120 | - `attention_mask_rejected` 121 | 122 | Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to modulate the 123 | loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/. 124 | If you don't pass a margin, no margin will be used. 
125 | """ 126 | 127 | def __init__( 128 | self, 129 | model: Union[PreTrainedModel, nn.Module] = None, 130 | args: Optional[RewardConfig] = None, 131 | tokenize_map_func: Optional[Callable] = None, 132 | data_collator: Optional[DataCollator] = None, 133 | train_dataset: Optional[Dataset] = None, 134 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, 135 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 136 | model_init: Optional[Callable[[], PreTrainedModel]] = None, 137 | compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, 138 | callbacks: Optional[List[TrainerCallback]] = None, 139 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( 140 | None, 141 | None, 142 | ), 143 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, 144 | peft_config: Optional[Dict] = None, 145 | max_length: Optional[int] = 1024, 146 | filter_too_long: Optional[bool] = True, 147 | num_proc: Optional[int] = 4, 148 | ): 149 | """ 150 | Initialize RewardTrainer. 151 | 152 | Args: 153 | model (`transformers.PreTrainedModel`): 154 | The model to train, preferably an `AutoModelForSequenceClassification`. 155 | args (`RewardConfig`): 156 | The arguments to use for training. 157 | data_collator (`transformers.DataCollator`): 158 | The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used 159 | which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. 160 | train_dataset (`datasets.Dataset`): 161 | The dataset to use for training. 162 | eval_dataset (`datasets.Dataset`): 163 | The dataset to use for evaluation. 164 | tokenizer (`transformers.PreTrainedTokenizerBase`): 165 | The tokenizer to use for training. This argument is required if you want to use the default data collator. 166 | model_init (`Callable[[], transformers.PreTrainedModel]`): 167 | The model initializer to use for training. If None is specified, the default model initializer will be used. 168 | compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to `compute_accuracy`): 169 | The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used. 170 | callbacks (`List[transformers.TrainerCallback]`): 171 | The callbacks to use for training. 172 | optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): 173 | The optimizer and scheduler to use for training. 174 | preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): 175 | The function to use to preprocess the logits before computing the metrics. 176 | peft_config (`Dict`, defaults to `None`): 177 | The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. 
178 | """ 179 | if not isinstance(model, PeftModel) and is_peft_available() and peft_config: 180 | model = prepare_model_for_peft(model, peft_config, args) 181 | 182 | if tokenize_map_func is None: 183 | tokenize_map_func = RewardDataMapFunc(tokenizer) 184 | 185 | if data_collator is None: 186 | data_collator = RewardDataCollatorWithPadding(tokenizer) 187 | 188 | if tokenizer is None: 189 | tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) 190 | if getattr(tokenizer, "pad_token", None) is None: 191 | tokenizer.pad_token = tokenizer.eos_token 192 | 193 | def preprocess_dataset(dataset): 194 | # tokenize samples 195 | dataset = dataset.map( 196 | tokenize_map_func, 197 | batched=True, 198 | num_proc=num_proc, 199 | remove_columns=dataset.column_names 200 | ) 201 | original_length = len(dataset) 202 | # filter samples that are too long 203 | if filter_too_long: 204 | dataset = dataset.filter( 205 | lambda x: len(x["prompt_chosen_input_ids"]) <= max_length and len(x["prompt_rejected_input_ids"]) <= max_length 206 | ) 207 | else: 208 | dataset = dataset.map( 209 | # truncate chosen and rejected 210 | lambda sample: {k: v[:max_length] if ('chosen' in k or 'rejected' in k) else v for k, v in sample.items()}, 211 | num_proc=num_proc, 212 | ) 213 | filtered_length = len(dataset) 214 | return dataset, filtered_length / original_length 215 | preprocess_here = False 216 | if "prompt_chosen_input_ids" not in train_dataset[0].keys(): 217 | preprocess_here = True 218 | print_local_main("dataset preprocessing...") 219 | train_dataset, train_dataset_retain = preprocess_dataset(train_dataset) 220 | eval_dataset, eval_dataset_retain = preprocess_dataset(eval_dataset) 221 | print_local_main(f"train_dataset_retain: {train_dataset_retain}") 222 | print_local_main(f"eval_dataset_retain: {eval_dataset_retain}") 223 | 224 | if is_peft_available() and isinstance(model, PeftModel): 225 | if callbacks is None: 226 | callbacks = [PeftSavingCallback()] 227 | else: 228 | callbacks += [PeftSavingCallback()] 229 | 230 | if compute_metrics is None: 231 | compute_metrics = compute_accuracy 232 | 233 | self._stored_metrics = defaultdict(lambda: defaultdict(list)) 234 | 235 | super().__init__( 236 | model, 237 | args, 238 | data_collator, 239 | train_dataset, 240 | eval_dataset, 241 | tokenizer, 242 | model_init, 243 | compute_metrics, 244 | callbacks, 245 | optimizers, 246 | preprocess_logits_for_metrics, 247 | ) 248 | 249 | if preprocess_here: 250 | self.log({"eval_dataset_retain": train_dataset_retain, "dataset_retain": eval_dataset_retain}) 251 | 252 | def compute_loss( 253 | self, 254 | model: Union[PreTrainedModel, nn.Module], 255 | inputs: Dict[str, Union[torch.Tensor, Any]], 256 | return_outputs=False, 257 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: 258 | rewards = model( 259 | input_ids=inputs["input_ids"], 260 | attention_mask=inputs["attention_mask"], 261 | )[0] 262 | chosen_rewards, rejected_rewards = rewards.chunk(2) 263 | # calculate loss, optionally modulate with margin 264 | if "margin" in inputs: 265 | loss = -nn.functional.logsigmoid(chosen_rewards - rejected_rewards - inputs["margin"]).mean() 266 | else: 267 | loss = -nn.functional.logsigmoid(chosen_rewards - rejected_rewards).mean() 268 | 269 | # force log the metrics 270 | if self.accelerator.is_main_process: 271 | accuracy = (chosen_rewards > rejected_rewards).float().detach().cpu() 272 | self.store_metrics({"accuracy": accuracy}, train_eval="train") 273 | 274 | if return_outputs: 275 | return loss, { 276 
| "chosen_rewards": chosen_rewards, 277 | "rejected_rewards": rejected_rewards, 278 | } 279 | return loss 280 | 281 | def prediction_step( 282 | self, 283 | model: Union[PreTrainedModel, nn.Module], 284 | inputs: Dict[str, Union[torch.Tensor, Any]], 285 | prediction_loss_only: bool, 286 | ignore_keys: Optional[List[str]] = None, 287 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 288 | inputs = self._prepare_inputs(inputs) 289 | if ignore_keys is None: 290 | if hasattr(self.model, "config"): 291 | ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) 292 | else: 293 | ignore_keys = [] 294 | 295 | with torch.no_grad(): 296 | loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True) 297 | 298 | if prediction_loss_only: 299 | return (loss, None, None) 300 | 301 | loss = loss.detach() 302 | logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys) 303 | logits = nested_detach(logits) 304 | # Stack accepted against rejected, mean over logits 305 | # and softmax to get preferences between accepted and rejected to sum to 1 306 | logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T 307 | 308 | labels = torch.zeros(logits.shape[0]) 309 | labels = self._prepare_inputs(labels) 310 | 311 | return loss, logits, labels 312 | 313 | def store_metrics(self, metrics: Dict[str, np.ndarray], train_eval: Literal["train", "eval"] = "train") -> None: 314 | for key, value in metrics.items(): 315 | self._stored_metrics[train_eval][key].append(value.mean()) 316 | 317 | def log(self, logs: Dict[str, float]) -> None: 318 | """ 319 | Log `logs` on the various objects watching training, including stored metrics. 320 | 321 | Args: 322 | logs (`Dict[str, float]`): 323 | The values to log. 324 | """ 325 | # logs either has 'loss' or 'eval_loss' 326 | train_eval = "train" if "loss" in logs else "eval" 327 | # Add averaged stored metrics to logs 328 | for key, metrics in self._stored_metrics[train_eval].items(): 329 | logs[key] = torch.tensor(metrics).mean().item() 330 | del self._stored_metrics[train_eval] 331 | return super().log(logs) 332 | 333 | 334 | if __name__ == '__main__': 335 | from src.data.configs import DATASET_CONFIGS 336 | from transformers import LlamaTokenizer 337 | dataset = DATASET_CONFIGS["PKU-Alignment/PKU-SafeRLHF-better"](sanity_check=True).get_preference_dataset(split="test") 338 | tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") 339 | tokenizer.pad_token = tokenizer.eos_token 340 | dataset = dataset.map(RewardDataMapFunc(tokenizer=tokenizer), batched=True) 341 | batch = RewardDataCollatorWithPadding(tokenizer=tokenizer)([dataset[0], dataset[1]]) 342 | -------------------------------------------------------------------------------- /scripts/modpo/beavertails/utils/score_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-2024 PKU-Alignment Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from __future__ import annotations 17 | 18 | from typing import Any, Literal 19 | from abc import abstractmethod 20 | from dataclasses import dataclass 21 | 22 | import torch 23 | import torch.nn as nn 24 | from torch import distributed as dist 25 | from torch.types import Number 26 | from transformers import LlamaModel, LlamaPreTrainedModel, PretrainedConfig, PreTrainedModel 27 | from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC, LLAMA_INPUTS_DOCSTRING 28 | from transformers.utils.doc import add_start_docstrings_to_model_forward, replace_return_docstrings 29 | from transformers.utils.generic import ModelOutput 30 | 31 | 32 | NormalizeFunction = Literal['affine', 'scale', 'translate', 'identity'] 33 | NormalizerType = Literal['RunningMeanStd', 'ExponentialMovingAverage'] 34 | 35 | 36 | class Normalizer(nn.Module): 37 | """Normalize input to have zero mean and unit variance.""" 38 | 39 | mean: torch.Tensor 40 | var: torch.Tensor 41 | count: torch.LongTensor 42 | normalize_function: NormalizeFunction 43 | 44 | def __init__( 45 | self, 46 | normalize_function: NormalizeFunction, 47 | shape: tuple[int, ...], 48 | device: torch.device | str | None = None, 49 | ) -> None: 50 | """Initialize.""" 51 | super().__init__() 52 | if normalize_function not in {'affine', 'scale', 'translate', 'identity'}: 53 | raise ValueError( 54 | f'Invalid normalization function type: {normalize_function}. ', 55 | 'Expected one of "affine", "scale", "translate", "identity".', 56 | ) 57 | self.normalize_function = normalize_function 58 | self.register_buffer('mean', torch.zeros(shape, device=device)) 59 | self.register_buffer('var', torch.ones(shape, device=device)) 60 | self.register_buffer('count', torch.zeros(1, dtype=torch.long, device=device)) 61 | 62 | @abstractmethod 63 | def update(self, data: torch.Tensor) -> None: 64 | """Update mean and variance.""" 65 | raise NotImplementedError 66 | 67 | @property 68 | def std(self) -> torch.Tensor: 69 | """Return standard deviation.""" 70 | return self.var.sqrt() 71 | 72 | def set_mean_var( 73 | self, 74 | mean: torch.Tensor | list[float] | tuple[float, ...] | None, 75 | var: torch.Tensor | list[float] | tuple[float, ...] 
| None, 76 | ) -> None: 77 | """Set mean and variance.""" 78 | mean = ( 79 | torch.as_tensor(mean, dtype=self.mean.dtype, device=self.mean.device) 80 | if mean is not None 81 | else self.mean 82 | ) 83 | var = ( 84 | torch.as_tensor(var, dtype=self.var.dtype, device=self.var.device) 85 | if var is not None 86 | else self.var 87 | ) 88 | 89 | assert mean.shape == self.mean.shape 90 | assert var.shape == self.var.shape 91 | 92 | self.mean = mean 93 | self.var = var 94 | 95 | def forward( 96 | self, 97 | data: torch.Tensor, 98 | epsilon: Number = 1e-8, 99 | ) -> torch.Tensor: 100 | """Update and normalize input.""" 101 | if self.training: 102 | self.update(data) 103 | return self.normalize(data, epsilon=epsilon) 104 | 105 | def normalize( 106 | self, 107 | data: torch.Tensor, 108 | epsilon: Number = 1e-8, 109 | ) -> torch.Tensor: 110 | """Normalize input.""" 111 | if self.normalize_function == 'affine': 112 | return (data - self.mean.detach()) / (self.std.detach() + epsilon) 113 | if self.normalize_function == 'scale': 114 | return data / (self.std.detach() + epsilon) 115 | if self.normalize_function == 'translate': 116 | return data - self.mean.detach() 117 | if self.normalize_function == 'identity': 118 | return data 119 | raise ValueError( 120 | f'Invalid normalization function type: {self.normalize_function}. ', 121 | 'Expected one of "affine", "scale", "translate", "identity".', 122 | ) 123 | 124 | @classmethod 125 | def instantiate( 126 | cls, 127 | normalizer_type: NormalizerType | None, 128 | normalize_function: NormalizeFunction, 129 | shape: tuple[int, ...], 130 | device: torch.device | str | None = None, 131 | **kwargs: Any, 132 | ) -> Normalizer: 133 | """Get a normalizer.""" 134 | if normalizer_type == 'RunningMeanStd': 135 | return RunningMeanStd( 136 | normalize_function, 137 | shape=shape, 138 | device=device, 139 | ) 140 | if normalizer_type == 'ExponentialMovingAverage': 141 | return ExponentialMovingAverage( 142 | normalize_function, 143 | shape=shape, 144 | device=device, 145 | **kwargs, 146 | ) 147 | if normalizer_type is None: 148 | return IdentityNormalizer( 149 | normalize_function, 150 | shape=shape, 151 | device=device, 152 | ) 153 | raise ValueError( 154 | f'Invalid normalization function type: {normalizer_type}. 
' 155 | 'Expected one of "RunningMeanStd", "ExponentialMovingAverage".', 156 | ) 157 | 158 | 159 | class RunningMeanStd(Normalizer): 160 | """Running mean and standard deviation.""" 161 | 162 | def update(self, data: torch.Tensor) -> None: 163 | """Update mean and variance.""" 164 | batch_mean = data.mean(dim=0) 165 | batch_var = data.var(dim=0) 166 | batch_count = data.size(0) 167 | 168 | delta = batch_mean - self.mean 169 | total_count = self.count + batch_count 170 | 171 | new_mean = self.mean + delta * batch_count / total_count 172 | m_a = self.var * self.count 173 | m_b = batch_var * batch_count 174 | m2 = ( # pylint: disable=invalid-name 175 | m_a + m_b + torch.square(delta) * (self.count * batch_count / total_count) 176 | ) 177 | new_var = m2 / total_count 178 | 179 | self.mean = new_mean 180 | self.var = new_var 181 | self.count = total_count 182 | 183 | 184 | class ExponentialMovingAverage(Normalizer): 185 | """Exponential moving average.""" 186 | 187 | def __init__( 188 | self, 189 | normalize_function: NormalizeFunction, 190 | shape: tuple[int, ...], 191 | device: torch.device | str | None = None, 192 | momentum: float = 0.9, 193 | ) -> None: 194 | super().__init__(normalize_function, shape=shape, device=device) 195 | self.momentum = momentum 196 | 197 | def update(self, data: torch.Tensor) -> None: 198 | """Update mean and variance.""" 199 | batch_mean = data.mean(dim=0) 200 | batch_var = data.var(dim=0) 201 | batch_count = data.size(0) 202 | 203 | self.mean = self.momentum * self.mean + (1.0 - self.momentum) * batch_mean 204 | self.var = self.momentum * self.var + (1.0 - self.momentum) * batch_var 205 | self.count += batch_count # pylint: disable=no-member 206 | 207 | 208 | class IdentityNormalizer(Normalizer): 209 | """Identity normalizer.""" 210 | 211 | def update(self, data: torch.Tensor) -> None: 212 | """Update mean and variance.""" 213 | self.count += data.size(0) # pylint: disable=no-member 214 | 215 | 216 | @dataclass 217 | class ScoreModelOutput(ModelOutput): 218 | """ 219 | Output of the score model. 220 | 221 | Args: 222 | scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, score_dim)`): 223 | Prediction scores of the score model. 224 | end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, score_dim)`): 225 | Prediction scores of the end of the sequence. 226 | last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_dim)`): 227 | Sequence of hidden-states at the output of the last layer of the model. 228 | end_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_dim)`): 229 | Last hidden state of the sequence at the output of the last layer of the model. 230 | end_index (`torch.LongTensor` of shape `(batch_size,)`): 231 | Indices of the end of the sequence. 
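Example (an illustrative sketch of how the fields relate, with made-up shapes; `end_scores[b]` is `scores[b, end_index[b]]`, taken at the last non-padding position of each sequence):

```python
import torch

B, L, D = 2, 5, 1
scores = torch.randn(B, L, D)                     # per-token scores, size = (B, L, D)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],   # right-padded batch
                               [1, 1, 1, 1, 1]])

end_index = torch.cat([m.nonzero()[-1] for m in attention_mask])  # size = (B,)
end_scores = scores[torch.arange(B), end_index]                   # size = (B, D)
```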
232 | """ 233 | 234 | scores: torch.FloatTensor | None = None # size = (B, L, D) 235 | end_scores: torch.FloatTensor | None = None # size = (B, D) 236 | last_hidden_state: torch.FloatTensor | None = None # size = (B, L, E) 237 | end_last_hidden_state: torch.FloatTensor | None = None # size = (B, E) 238 | end_index: torch.LongTensor | None = None # size = (B,) 239 | 240 | 241 | class ScoreModelMixin: 242 | """Base class for score models.""" 243 | 244 | score_head: nn.Linear 245 | normalizer: Normalizer 246 | do_normalize: bool = False 247 | normalize_function: NormalizeFunction = 'affine' 248 | _is_score_head_initialized: bool = False 249 | 250 | def init_score_head(self, config: PretrainedConfig, hidden_size: int, **kwargs: Any) -> None: 251 | """Initialize the score head.""" 252 | if self._is_score_head_initialized: 253 | return 254 | 255 | self.score_dim = config.score_dim = kwargs.pop( 256 | 'score_dim', 257 | getattr(config, 'score_dim', 1), 258 | ) 259 | self.score_bias = config.score_bias = kwargs.pop( 260 | 'score_bias', 261 | getattr(config, 'score_bias', True), 262 | ) 263 | 264 | self.score_head = nn.Linear(hidden_size, config.score_dim, bias=config.score_bias) 265 | if config.score_bias: 266 | nn.init.zeros_(self.score_head.bias) 267 | 268 | config.score_type = kwargs.pop('score_type', getattr(config, 'score_type', 'reward')) 269 | if config.score_type == 'reward': 270 | self.normalize_function = 'affine' 271 | elif config.score_type == 'cost': 272 | self.normalize_function = 'scale' 273 | elif config.score_type == 'critic': 274 | self.normalize_function = 'identity' 275 | else: 276 | raise ValueError( 277 | f"Invalid score type: {config.score_type}. Expected one of 'reward', 'cost', or 'critic'.", 278 | ) 279 | 280 | self.do_normalize = config.do_normalize = kwargs.pop( 281 | 'do_normalize', 282 | getattr(config, 'do_normalize', False), 283 | ) 284 | 285 | config.normalizer_type = kwargs.pop( 286 | 'normalizer_type', 287 | getattr(config, 'normalizer_type', None), 288 | ) 289 | if config.normalizer_type not in {'RunningMeanStd', 'ExponentialMovingAverage', None}: 290 | raise ValueError( 291 | f'Invalid norm type: {config.normalizer_type}.' 
292 | "Expected one of 'RunningMeanStd', 'ExponentialMovingAverage', or None.", 293 | ) 294 | if config.normalizer_type == 'ExponentialMovingAverage': 295 | config.momentum = kwargs.pop('momentum', getattr(config, 'momentum', None)) 296 | momentum = getattr(config, 'momentum', None) 297 | self.normalizer = Normalizer.instantiate( 298 | normalizer_type=config.normalizer_type, 299 | normalize_function=self.normalize_function, 300 | shape=(config.score_dim,), 301 | momentum=momentum, 302 | ) 303 | 304 | mean = getattr(config, 'mean', None) 305 | var = getattr(config, 'var', None) 306 | self.normalizer.set_mean_var(mean, var) 307 | 308 | self._is_score_head_initialized = True 309 | 310 | def get_scores( 311 | self, 312 | last_hidden_state: torch.FloatTensor, # size = (B, L, E) 313 | attention_mask: torch.BoolTensor | None = None, # size = (B, L) 314 | return_dict: bool | None = None, 315 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput: 316 | """Forward pass of the score model.""" 317 | B, L, E = last_hidden_state.size() 318 | 319 | if attention_mask is None: 320 | if B > 1: 321 | raise ValueError("'attention_mask' is required when batch size > 1.") 322 | attention_mask = last_hidden_state.new_ones(B, L, dtype=torch.bool) # size = (B, L) 323 | 324 | scores = self.score_head(last_hidden_state).float() # size = (B, L, D) 325 | 326 | end_index = torch.cat([m.nonzero()[-1] for m in attention_mask]) # size = (B,) 327 | end_last_hidden_state = torch.gather( # size = (B, 1, E) 328 | last_hidden_state, 329 | dim=1, 330 | index=( 331 | end_index.to(last_hidden_state.device) 332 | .unsqueeze(dim=1) 333 | .unsqueeze(dim=2) 334 | .expand(-1, -1, last_hidden_state.size(-1)) 335 | ), 336 | ) 337 | end_scores = torch.gather( # size = (B, 1, D) 338 | scores, 339 | dim=1, 340 | index=( 341 | end_index.to(scores.device) 342 | .unsqueeze(dim=1) 343 | .unsqueeze(dim=2) 344 | .expand(-1, -1, scores.size(-1)) 345 | ), 346 | ) 347 | end_last_hidden_state = end_last_hidden_state.squeeze(dim=1) # size = (B, E) 348 | end_scores = end_scores.squeeze(dim=1) # size = (B, D) 349 | 350 | if self.training: 351 | if dist.is_initialized(): 352 | gathered_end_scores_list = [ 353 | torch.zeros_like(end_scores) for _ in range(dist.get_world_size()) 354 | ] 355 | dist.all_gather(gathered_end_scores_list, end_scores) 356 | gathered_end_scores = torch.cat(gathered_end_scores_list, dim=0) 357 | self.normalizer.update(gathered_end_scores) 358 | else: 359 | self.normalizer.update(end_scores) 360 | self.config.mean = self.normalizer.mean.tolist() 361 | self.config.var = self.normalizer.var.tolist() 362 | 363 | if self.do_normalize: 364 | scores = self.normalizer.normalize(scores) 365 | end_scores = self.normalizer.normalize(end_scores) 366 | 367 | if not return_dict: 368 | return scores, end_scores 369 | 370 | return ScoreModelOutput( 371 | scores=scores, # size = (B, L, D) 372 | end_scores=end_scores, # size = (B, D) 373 | last_hidden_state=last_hidden_state, # size = (B, L, E) 374 | end_last_hidden_state=end_last_hidden_state, # size = (B, E) 375 | end_index=end_index, # size = (B,) 376 | ) 377 | 378 | def set_normalize(self, mode: bool = True) -> None: 379 | if self.do_normalize == mode: 380 | return 381 | 382 | self.do_normalize = self.config.do_normalize = mode 383 | 384 | 385 | class LlamaForScore(ScoreModelMixin, LlamaPreTrainedModel): 386 | def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: 387 | super().__init__(config) 388 | self.model = LlamaModel(config) 389 | 390 | config.architectures = 
[self.__class__.__name__] 391 | self.init_score_head(config, hidden_size=config.hidden_size, **kwargs) 392 | 393 | # Initialize weights and apply final processing 394 | self.post_init() 395 | 396 | def get_input_embeddings(self) -> nn.Embedding: 397 | return self.model.embed_tokens 398 | 399 | def set_input_embeddings(self, value: nn.Embedding) -> None: 400 | self.model.embed_tokens = value 401 | 402 | def get_output_embeddings(self) -> None: 403 | return None 404 | 405 | def set_decoder(self, decoder: PreTrainedModel) -> None: 406 | self.model = decoder 407 | 408 | def get_decoder(self) -> PreTrainedModel: 409 | return self.model 410 | 411 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 412 | @replace_return_docstrings(output_type=ScoreModelOutput, config_class=_CONFIG_FOR_DOC) 413 | def forward( # pylint: disable=too-many-arguments 414 | self, 415 | input_ids: torch.LongTensor | None = None, 416 | attention_mask: torch.Tensor | None = None, 417 | position_ids: torch.LongTensor | None = None, 418 | past_key_values: list[torch.FloatTensor] | None = None, 419 | inputs_embeds: torch.FloatTensor | None = None, 420 | use_cache: bool | None = None, 421 | return_dict: bool | None = None, 422 | ) -> tuple[torch.Tensor, torch.Tensor] | ScoreModelOutput: 423 | """ 424 | Args: 425 | 426 | Returns: 427 | 428 | Examples: 429 | 430 | ```python 431 | >>> from safe_rlhf.models.score_model.llama.modeling_llama import LlamaForScore 432 | >>> from transformers import LlamaTokenizer 433 | 434 | >>> model = LlamaForScore.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 435 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 436 | 437 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 438 | >>> inputs = tokenizer(prompt, return_tensors="pt") 439 | 440 | # get score 441 | >>> outputs = model(**inputs) 442 | >>> end_scores = outputs.end_scores 443 | >>> end_scores 444 | tensor([[0.0000]]) 445 | ``` 446 | """ 447 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 448 | 449 | outputs = self.model( 450 | input_ids, 451 | attention_mask=attention_mask, 452 | position_ids=position_ids, 453 | past_key_values=past_key_values, 454 | inputs_embeds=inputs_embeds, 455 | use_cache=use_cache, 456 | output_attentions=False, 457 | output_hidden_states=False, 458 | return_dict=True, 459 | ) 460 | last_hidden_state = outputs.last_hidden_state # size = (B, L, E) 461 | return self.get_scores( 462 | last_hidden_state, 463 | attention_mask=attention_mask, 464 | return_dict=return_dict, 465 | ) 466 | 467 | 468 | if __name__ == "__main__": 469 | import torch 470 | from transformers import AutoTokenizer 471 | 472 | model = LlamaForScore.from_pretrained('PKU-Alignment/beaver-7b-v1.0-reward', torch_dtype=torch.bfloat16, device_map='cpu') 473 | tokenizer = AutoTokenizer.from_pretrained('PKU-Alignment/beaver-7b-v1.0-reward') 474 | 475 | input = 'BEGINNING OF CONVERSATION: USER: hello ASSISTANT:Hello! How can I help you today?' 476 | 477 | input_ids = tokenizer(input, return_tensors='pt') 478 | output = model(**input_ids) 479 | print(output) 480 | 481 | 482 | model = LlamaForScore.from_pretrained('PKU-Alignment/beaver-7b-v1.0-cost', torch_dtype=torch.bfloat16, device_map='cpu') 483 | tokenizer = AutoTokenizer.from_pretrained('PKU-Alignment/beaver-7b-v1.0-cost') 484 | 485 | input = 'BEGINNING OF CONVERSATION: USER: hello ASSISTANT:Hello! How can I help you today?' 
486 | 487 | input_ids = tokenizer(input, return_tensors='pt') 488 | output = model(**input_ids) 489 | print(output) 490 | -------------------------------------------------------------------------------- /src/trainer/dpo_trainer.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from dataclasses import dataclass 4 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | from accelerate.utils import is_deepspeed_available 11 | from datasets import Dataset 12 | from torch.utils.data import DataLoader 13 | from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments, AutoTokenizer 14 | from transformers.trainer_callback import TrainerCallback 15 | from transformers.trainer_utils import EvalLoopOutput 16 | 17 | from trl.import_utils import is_peft_available, is_wandb_available 18 | from trl.models import PreTrainedModelWrapper, create_reference_model 19 | from trl.trainer.utils import PeftSavingCallback, compute_accuracy, disable_dropout_in_model, pad_to_length 20 | 21 | from src.utils import print_local_main, prepare_model_for_peft, get_batch_logps, pad_labels, common_prefix_length, PeftAsPreTrained 22 | 23 | 24 | if is_peft_available(): 25 | from peft import PeftModel 26 | 27 | if is_wandb_available(): 28 | import wandb 29 | 30 | if is_deepspeed_available(): 31 | import deepspeed 32 | 33 | 34 | @dataclass 35 | class DPODataMapFunc: 36 | """Map raw texts to tokens, attention masks, and labels.""" 37 | tokenizer: PreTrainedTokenizerBase 38 | label_pad_token_id: Optional[int] = -100 39 | completion_only: Optional[bool] = True 40 | 41 | def __call__(self, examples): 42 | new_examples = { 43 | "prompt_chosen_input_ids": [], 44 | "prompt_chosen_attention_mask": [], 45 | "prompt_chosen_labels": [], 46 | 47 | "prompt_rejected_input_ids": [], 48 | "prompt_rejected_attention_mask": [], 49 | "prompt_rejected_labels": [], 50 | 51 | "prompt_input_ids": [], 52 | "prompt_attention_mask": [], 53 | 54 | "prompt": [], 55 | } 56 | 57 | for prompt, chosen, rejected in zip(examples["prompt"], examples["chosen"], examples["rejected"]): 58 | prompt_tokens = self.tokenizer(prompt) 59 | prompt_chosen_tokens = self.tokenizer(prompt + chosen) 60 | prompt_rejected_tokens = self.tokenizer(prompt + rejected) 61 | # add EOS to response 62 | prompt_chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) 63 | prompt_chosen_tokens["attention_mask"].append(1) 64 | prompt_rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id) 65 | prompt_rejected_tokens["attention_mask"].append(1) 66 | 67 | prompt_len = common_prefix_length(prompt_chosen_tokens["input_ids"], prompt_rejected_tokens["input_ids"]) 68 | 69 | for k, toks in { 70 | "prompt": prompt_tokens, 71 | "prompt_chosen": prompt_chosen_tokens, 72 | "prompt_rejected": prompt_rejected_tokens, 73 | }.items(): 74 | for type_key, tokens in toks.items(): 75 | new_examples[f"{k}_{type_key}"].append(tokens) 76 | 77 | for k, toks in { 78 | "prompt_chosen": prompt_chosen_tokens, 79 | "prompt_rejected": prompt_rejected_tokens, 80 | }.items(): 81 | labels = toks["input_ids"].copy() 82 | if self.completion_only: 83 | labels[:prompt_len] = [self.label_pad_token_id] * prompt_len 84 | new_examples[f"{k}_labels"].append(labels) 85 | 86 | new_examples["prompt"] = examples["prompt"] 87 | 88 | return 
new_examples 89 | 90 | 91 | @dataclass 92 | class DPODataCollatorWithPadding: 93 | tokenizer: PreTrainedTokenizerBase 94 | label_pad_token_id: Optional[int] = -100 95 | pad_to_multiple_of: Optional[int] = None 96 | 97 | def __call__(self, features: List[Dict[str, Any]], generate: Optional[bool] = False) -> Dict[str, Any]: 98 | """ 99 | if not generate: 100 | batch = { 101 | "input_ids": ..., 102 | "attention_mask": ..., 103 | "labels": ..., 104 | } 105 | else: 106 | batch = { 107 | "prompt": ..., 108 | "prompt_input_ids": ..., 109 | "prompt_attention_mask": ..., 110 | } 111 | """ 112 | if not generate: 113 | 114 | # `chosen` and `rejected` merged into a single batch for more efficient batched forward pass; 115 | right_padding_features = [] 116 | for feature in features: 117 | right_padding_features.append( 118 | { 119 | "input_ids": feature["prompt_chosen_input_ids"], 120 | "attention_mask": feature["prompt_chosen_attention_mask"], 121 | "labels": feature["prompt_chosen_labels"], 122 | } 123 | ) 124 | for feature in features: 125 | right_padding_features.append( 126 | { 127 | "input_ids": feature["prompt_rejected_input_ids"], 128 | "attention_mask": feature["prompt_rejected_attention_mask"], 129 | "labels": feature["prompt_rejected_labels"], 130 | } 131 | ) 132 | 133 | pad_labels(right_padding_features, self.tokenizer, self.pad_to_multiple_of, self.label_pad_token_id) 134 | 135 | right_padding_batch = self.tokenizer.pad( 136 | right_padding_features, 137 | padding=True, 138 | pad_to_multiple_of=self.pad_to_multiple_of, 139 | return_tensors="pt", 140 | ) 141 | 142 | return right_padding_batch 143 | 144 | else: 145 | 146 | left_padding_features = [] 147 | padding_side_default = self.tokenizer.padding_side 148 | self.tokenizer.padding_side = "left" 149 | for feature in features: 150 | left_padding_features.append( 151 | { 152 | "input_ids": feature["prompt_input_ids"], 153 | "attention_mask": feature["prompt_attention_mask"], 154 | } 155 | ) 156 | left_padding_batch = self.tokenizer.pad( 157 | left_padding_features, 158 | padding=True, 159 | pad_to_multiple_of=self.pad_to_multiple_of, 160 | return_tensors="pt", 161 | ) 162 | self.tokenizer.padding_side = padding_side_default 163 | 164 | return { 165 | "prompt": [feature["prompt"] for feature in features], 166 | "prompt_input_ids": left_padding_batch["input_ids"], 167 | "prompt_attention_mask": left_padding_batch["attention_mask"], 168 | } 169 | 170 | 171 | class DPOTrainer(Trainer): 172 | r""" 173 | Initialize DPOTrainer. 174 | 175 | Args: 176 | model (`transformers.PreTrainedModel`): 177 | The model to train, preferably an `AutoModelForCausalLM`. 178 | ref_model (`PreTrainedModelWrapper`): 179 | Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no 180 | reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. 181 | beta (`float`, defaults to 0.1): 182 | The beta factor in DPO loss. Higher beta means less divergence from the initial policy. 183 | loss_type (`str`, defaults to `"sigmoid"`): 184 | The type of DPO loss to use. Either `"sigmoid"` (the default DPO loss) or the `"hinge"` loss from the SLiC paper. 185 | args (`transformers.TrainingArguments`): 186 | The arguments to use for training. 187 | data_collator (`transformers.DataCollator`): 188 | The data collator to use for training. 
If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used 189 | which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. 190 | label_pad_token_id (`int`, defaults to `-100`): 191 | The label pad token id. This argument is required if you want to use the default data collator. 192 | train_dataset (`datasets.Dataset`): 193 | The dataset to use for training. 194 | eval_dataset (`datasets.Dataset`): 195 | The dataset to use for evaluation. 196 | tokenizer (`transformers.PreTrainedTokenizerBase`): 197 | The tokenizer to use for training. This argument is required if you want to use the default data collator. 198 | model_init (`Callable[[], transformers.PreTrainedModel]`): 199 | The model initializer to use for training. If None is specified, the default model initializer will be used. 200 | callbacks (`List[transformers.TrainerCallback]`): 201 | The callbacks to use for training. 202 | optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): 203 | The optimizer and scheduler to use for training. 204 | preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): 205 | The function to use to preprocess the logits before computing the metrics. 206 | max_length (`int`, defaults to `1024`): 207 | The maximum length of the tokenized sequences; longer pairs are filtered out (or truncated) during preprocessing. 208 | peft_config (`Dict`, defaults to `None`): 209 | The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. 210 | disable_dropout (`bool`, defaults to `True`): 211 | Whether or not to disable dropouts in `model` and `ref_model`. 212 | generate_during_eval (`bool`, defaults to `True`): 213 | Whether to sample and log generations during the evaluation step. 214 | compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): 215 | The function to use to compute the metrics. Must take an `EvalPrediction` and return 216 | a dictionary mapping metric names to values. 
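Example (a minimal instantiation sketch under assumed names; the SFT checkpoint, dataset splits, and LoRA hyperparameters are placeholders, and with a `peft_config` and no explicit `ref_model` the policy with adapters disabled is reused as the reference model):

```python
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

from src.data.configs import DATASET_CONFIGS
from src.trainer.dpo_trainer import DPOTrainer

model_name = "meta-llama/Llama-2-7b-hf"  # assumed SFT checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

dataset_config = DATASET_CONFIGS["PKU-Alignment/PKU-SafeRLHF-better"](sanity_check=True)
train_dataset = dataset_config.get_preference_dataset(split="train")  # assumed split name
eval_dataset = dataset_config.get_preference_dataset(split="test")

trainer = DPOTrainer(
    model=model,
    ref_model=None,   # with peft_config, the adapter-disabled policy serves as the reference
    beta=0.1,
    loss_type="sigmoid",
    args=TrainingArguments(
        output_dir="./output/dpo",  # placeholder
        per_device_train_batch_size=2,
        num_train_epochs=1,
        remove_unused_columns=False,  # keep the tokenized columns used by the custom collator
    ),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM"),
    max_length=1024,
    generate_during_eval=False,  # set True (default) to log sampled generations to wandb
)
trainer.train()
```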
217 | """ 218 | 219 | def __init__( 220 | self, 221 | model: Union[PreTrainedModel, nn.Module] = None, 222 | ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, 223 | beta: float = 0.1, 224 | loss_type: Literal["sigmoid", "hinge"] = "sigmoid", 225 | args: TrainingArguments = None, 226 | tokenize_map_func: Optional[Callable] = None, 227 | data_collator: Optional[DataCollator] = None, 228 | label_pad_token_id: int = -100, 229 | train_dataset: Optional[Dataset] = None, 230 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, 231 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 232 | model_init: Optional[Callable[[], PreTrainedModel]] = None, 233 | callbacks: Optional[List[TrainerCallback]] = None, 234 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( 235 | None, 236 | None, 237 | ), 238 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, 239 | peft_config: Optional[Dict] = None, 240 | disable_dropout: bool = True, 241 | max_length: Optional[int] = 1024, 242 | filter_too_long: Optional[bool] = True, 243 | num_proc: Optional[int] = 4, 244 | generate_during_eval: bool = True, 245 | compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, 246 | ): 247 | if not isinstance(model, PeftModel) and is_peft_available() and peft_config: 248 | model = prepare_model_for_peft(model, peft_config, args) 249 | 250 | if ref_model: 251 | self.ref_model = ref_model 252 | elif isinstance(model, PeftModel): 253 | # The `model` with adapters turned off will be used as the reference model 254 | self.ref_model = None 255 | else: 256 | self.ref_model = create_reference_model(model) 257 | 258 | if tokenize_map_func is None: 259 | tokenize_map_func = DPODataMapFunc(tokenizer, label_pad_token_id=label_pad_token_id) 260 | 261 | if data_collator is None: 262 | data_collator = DPODataCollatorWithPadding(tokenizer, label_pad_token_id=label_pad_token_id) 263 | 264 | if tokenizer is None: 265 | tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) 266 | if getattr(tokenizer, "pad_token", None) is None: 267 | tokenizer.pad_token = tokenizer.eos_token 268 | 269 | def preprocess_dataset(dataset): 270 | # tokenize samples 271 | dataset = dataset.map( 272 | tokenize_map_func, 273 | batched=True, 274 | num_proc=num_proc, 275 | remove_columns=dataset.column_names 276 | ) 277 | original_length = len(dataset) 278 | # filter samples that are too long 279 | if filter_too_long: 280 | dataset = dataset.filter( 281 | lambda x: len(x["prompt_chosen_input_ids"]) <= max_length and len(x["prompt_rejected_input_ids"]) <= max_length 282 | ) 283 | else: 284 | dataset = dataset.map( 285 | # truncate chosen and rejected 286 | lambda sample: {k: v[:max_length] if ('chosen' in k or 'rejected' in k) else v for k, v in sample.items()}, 287 | num_proc=num_proc, 288 | ) 289 | filtered_length = len(dataset) 290 | return dataset, filtered_length / original_length 291 | preprocess_here = False 292 | if "prompt_chosen_input_ids" not in train_dataset[0].keys(): 293 | preprocess_here = True 294 | print_local_main("dataset preprocessing...") 295 | train_dataset, train_dataset_retain = preprocess_dataset(train_dataset) 296 | eval_dataset, eval_dataset_retain = preprocess_dataset(eval_dataset) 297 | print_local_main(f"train_dataset_retain: {train_dataset_retain}") 298 | print_local_main(f"eval_dataset_retain: {eval_dataset_retain}") 299 | 300 | if is_peft_available() and isinstance(model, PeftModel): 301 | if callbacks is 
302 |                 callbacks = [PeftSavingCallback()]
303 |             else:
304 |                 callbacks += [PeftSavingCallback()]
305 | 
306 |         if compute_metrics is None:
307 |             compute_metrics = compute_accuracy
308 | 
309 |         if disable_dropout:
310 |             disable_dropout_in_model(model)
311 |             if self.ref_model is not None:
312 |                 disable_dropout_in_model(self.ref_model)
313 | 
314 |         if generate_during_eval and not is_wandb_available():
315 |             raise ValueError(
316 |                 "`generate_during_eval=True` requires Weights and Biases to be installed."
317 |                 " Please install `wandb` to resolve."
318 |             )
319 | 
320 |         self.max_length = max_length
321 |         self.generate_during_eval = generate_during_eval
322 |         self.label_pad_token_id = label_pad_token_id
323 | 
324 |         self.beta = beta
325 |         self.loss_type = loss_type
326 | 
327 |         self._stored_metrics = defaultdict(lambda: defaultdict(list))
328 | 
329 |         self.table = None # for late initialization
330 | 
331 |         super().__init__(
332 |             model=model,
333 |             args=args,
334 |             data_collator=data_collator,
335 |             train_dataset=train_dataset,
336 |             eval_dataset=eval_dataset,
337 |             tokenizer=tokenizer,
338 |             model_init=model_init,
339 |             compute_metrics=compute_metrics,
340 |             callbacks=callbacks,
341 |             optimizers=optimizers,
342 |             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
343 |         )
344 | 
345 |         if preprocess_here:
346 |             self.log({"train_dataset_retain": train_dataset_retain, "eval_dataset_retain": eval_dataset_retain})
347 | 
348 |         if not hasattr(self, "accelerator"):
349 |             raise AttributeError(
350 |                 "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
351 |             )
352 | 
353 |         if self.ref_model is None:
354 |             if not hasattr(self.model, "disable_adapter"):
355 |                 raise ValueError(
356 |                     "You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version."
357 |                 )
358 |             self.ref_model = PeftAsPreTrained(self.model)
359 |         else:
360 |             if self.is_deepspeed_enabled:
361 |                 self.ref_model = self._prepare_deepspeed(self.ref_model)
362 |             else:
363 |                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
364 | 
365 |     def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
366 |         # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
367 |         deepspeed_plugin = self.accelerator.state.deepspeed_plugin
368 |         config_kwargs = deepspeed_plugin.deepspeed_config
369 |         if model is not None:
370 |             if hasattr(model, "config"):
371 |                 hidden_size = (
372 |                     max(model.config.hidden_sizes)
373 |                     if getattr(model.config, "hidden_sizes", None)
374 |                     else getattr(model.config, "hidden_size", None)
375 |                 )
376 |                 if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
377 |                     # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
378 |                     # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
379 |                     config_kwargs.update(
380 |                         {
381 |                             "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
382 |                             "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
383 |                             "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
384 |                         }
385 |                     )
386 | 
387 |         # If ZeRO-3 is used, we shard both the active and reference model.
388 |         # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
389 |         if config_kwargs["zero_optimization"]["stage"] != 3:
390 |             config_kwargs["zero_optimization"]["stage"] = 0
391 |         model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
392 |         model.eval()
393 |         return model
394 | 
395 |     def train(
396 |         self,
397 |         resume_from_checkpoint: Optional[Union[str, bool]] = None,
398 |         trial: Union["optuna.Trial", Dict[str, Any]] = None,
399 |         ignore_keys_for_eval: Optional[List[str]] = None,
400 |         **kwargs,
401 |     ):
402 |         initial_output = super().train(
403 |             resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs,
404 |         )
405 | 
406 |         # upload the wandb table at the end of training, if it exists, and drop it from log_history
407 |         if self.table:
408 |             self.log({"eval_game_log": self.table})
409 |             self.state.log_history.pop()
410 | 
411 |         return initial_output
412 | 
413 |     def dpo_loss(
414 |         self,
415 |         policy_chosen_logps: torch.FloatTensor,
416 |         policy_rejected_logps: torch.FloatTensor,
417 |         reference_chosen_logps: torch.FloatTensor,
418 |         reference_rejected_logps: torch.FloatTensor,
419 |     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
420 |         """Compute the DPO loss for a batch of policy and reference model log probabilities.
421 | 
422 |         Args:
423 |             policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
424 |             policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
425 |             reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
426 |             reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
427 |         Note: `self.beta` (typically in the range of 0.1 to 0.5) is the temperature for the DPO loss; the reference model is effectively ignored as the temperature goes to 0.
428 | 
429 |         Returns:
430 |             A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
431 |             The losses tensor contains the DPO loss for each example in the batch.
432 |             The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
433 |         """
434 |         chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps)
435 |         rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps)
436 | 
437 |         logits = chosen_rewards - rejected_rewards
438 |         if self.loss_type == "sigmoid":
439 |             losses = -F.logsigmoid(logits)
440 |         elif self.loss_type == "hinge":
441 |             losses = torch.relu(1 - logits)
442 |         else:
443 |             raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']")
444 | 
445 |         return losses, chosen_rewards.detach(), rejected_rewards.detach()
446 | 
447 |     def forward(
448 |         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
449 |     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
450 |         # batch is already concatenated
451 |         all_logits = model(
452 |             input_ids=batch["input_ids"],
453 |             attention_mask=batch["attention_mask"],
454 |         ).logits.to(torch.float32)
455 | 
456 |         all_logps = get_batch_logps(
457 |             all_logits,
458 |             batch["labels"],
459 |             average_log_prob=False,
460 |             label_pad_token_id=self.label_pad_token_id,
461 |         )
462 | 
463 |         chosen_logps, rejected_logps = all_logps.chunk(2)
464 |         chosen_logits, rejected_logits = all_logits.chunk(2)
465 | 
466 |         return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
467 | 
468 |     def get_batch_metrics(
469 |         self,
470 |         model,
471 |         batch: Dict[str, Union[List, torch.LongTensor]],
472 |         train_eval: Literal["train", "eval"] = "train",
473 |     ):
474 |         """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
475 |         metrics = {}
476 | 
477 |         (
478 |             policy_chosen_logps,
479 |             policy_rejected_logps,
480 |             _,
481 |             _,
482 |         ) = self.forward(model, batch)
483 |         with torch.no_grad():
484 |             (
485 |                 reference_chosen_logps,
486 |                 reference_rejected_logps,
487 |                 _,
488 |                 _,
489 |             ) = self.forward(self.ref_model, batch)
490 | 
491 |         losses, chosen_rewards, rejected_rewards = self.dpo_loss(
492 |             policy_chosen_logps,
493 |             policy_rejected_logps,
494 |             reference_chosen_logps,
495 |             reference_rejected_logps,
496 |         )
497 |         accuracies = (chosen_rewards > rejected_rewards).float()
498 | 
499 |         # store per-example reward/log-prob statistics; keys are prefixed with "eval_" during evaluation
500 |         prefix = "eval_" if train_eval == "eval" else ""
501 |         metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu()
502 |         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu()
503 |         metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu()
504 |         metrics[f"{prefix}logps/margins"] = (policy_chosen_logps - policy_rejected_logps).detach().cpu()
505 |         metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu()
506 |         metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu()
507 |         if train_eval == "train":
508 |             metrics[f"{prefix}accuracy"] = accuracies.detach().cpu()
509 | 
510 |         return losses.mean(), metrics
511 | 
512 |     def compute_loss(
513 |         self,
514 |         model: Union[PreTrainedModel, nn.Module],
515 |         inputs: Dict[str, Union[torch.Tensor, Any]],
516 |         return_outputs=False,
517 |     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
518 |         loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
519 | 
520 |         # force log the metrics
521 |         if self.accelerator.is_main_process:
522 |             self.store_metrics(metrics, train_eval="train")
523 | 
524 |         if return_outputs:
525 |             return (loss, metrics)
526 |         return loss
527 | 
528 |     def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
529 |         """Generate samples from the model and reference model for the given batch of inputs."""
530 | 
531 |         policy_output = model.generate(
532 |             input_ids=batch["prompt_input_ids"],
533 |             attention_mask=batch["prompt_attention_mask"],
534 |             max_length=self.max_length,
535 |             do_sample=True,
536 |             pad_token_id=self.tokenizer.pad_token_id,
537 |         )
538 | 
539 |         reference_output = self.ref_model.generate(
540 |             input_ids=batch["prompt_input_ids"],
541 |             attention_mask=batch["prompt_attention_mask"],
542 |             max_length=self.max_length,
543 |             do_sample=True,
544 |             pad_token_id=self.tokenizer.pad_token_id,
545 |         )
546 | 
547 |         policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
548 |         policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
549 | 
550 |         reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id)
551 |         reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)
552 | 
553 |         return policy_output_decoded, reference_output_decoded
554 | 
555 |     def prediction_step(
556 |         self,
557 |         model: Union[PreTrainedModel, nn.Module],
558 |         inputs: Dict[str, Union[torch.Tensor, Any]],
559 |         prediction_loss_only: bool,
560 |         ignore_keys: Optional[List[str]] = None,
561 |     ):
562 |         if ignore_keys is None:
563 |             if hasattr(model, "config"):
564 |                 ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
565 |             else:
566 |                 ignore_keys = []
567 | 
568 |         with torch.no_grad():
569 |             loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval")
570 | 
571 |         # force log the metrics
572 |         if self.accelerator.is_main_process:
573 |             self.store_metrics(metrics, train_eval="eval")
574 | 
575 |         if prediction_loss_only:
576 |             return (loss.detach(), None, None)
577 | 
578 |         # logits for the chosen and rejected samples from model
579 |         logits_dict = {
580 |             "eval_rewards/chosen": metrics["eval_rewards/chosen"],
581 |             "eval_rewards/rejected": metrics["eval_rewards/rejected"],
582 |         }
583 |         logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
584 |         logits = torch.stack(logits).T.to(self.accelerator.device)
585 |         labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
586 | 
587 |         return (loss.detach(), logits, labels)
588 | 
589 |     def store_metrics(self, metrics: Dict[str, torch.Tensor], train_eval: Literal["train", "eval"] = "train") -> None:
590 |         for key, value in metrics.items():
591 |             self._stored_metrics[train_eval][key].append(value.mean())
592 | 
593 |     def evaluation_loop(
594 |         self,
595 |         dataloader: DataLoader,
596 |         description: str,
597 |         prediction_loss_only: Optional[bool] = None,
598 |         ignore_keys: Optional[List[str]] = None,
599 |         metric_key_prefix: str = "eval",
600 |     ) -> EvalLoopOutput:
601 |         """
602 |         Overriding built-in evaluation loop to store metrics for each batch.
603 |         Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
604 | 
605 |         Works both with or without labels.
606 |         """
607 | 
608 |         # Sample and save to game log if requested (for one batch to save time)
609 |         if self.generate_during_eval and self.state.is_world_process_zero:
610 |             # late init
611 |             self.table = wandb.Table(columns=["Prompt", "Policy", "Ref Policy"]) if self.table is None else self.table
612 | 
613 |             print("generating response...")
614 |             # Generate random indices within the range of the total number of samples
615 |             num_samples = len(dataloader.dataset)
616 |             random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
617 | 
618 |             # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
619 |             random_batch_dataset = dataloader.dataset.select(random_indices)
620 |             random_batch = self.data_collator(random_batch_dataset, generate=True)
621 |             random_batch = self._prepare_inputs(random_batch)
622 | 
623 |             policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch)
624 | 
625 |             for prompt, policy_output, ref_policy_output in zip(random_batch["prompt"], policy_output_decoded, ref_output_decoded):
626 |                 self.table.add_data(f"(epoch{self.state.epoch}) {prompt}", policy_output[len(prompt):], ref_policy_output[len(prompt):])
627 | 
628 |         # barrier
629 |         self.accelerator.wait_for_everyone()
630 | 
631 |         # Base evaluation
632 |         initial_output = super().evaluation_loop(
633 |             dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
634 |         )
635 | 
636 |         return initial_output
637 | 
638 |     def log(self, logs: Dict[str, float]) -> None:
639 |         """
640 |         Log `logs` on the various objects watching training, including stored metrics.
641 | 
642 |         Args:
643 |             logs (`Dict[str, float]`):
644 |                 The values to log.
645 |         """
646 |         # logs either has 'loss' or 'eval_loss'
647 |         train_eval = "train" if "loss" in logs else "eval"
648 |         # Add averaged stored metrics to logs
649 |         for key, metrics in self._stored_metrics[train_eval].items():
650 |             logs[key] = torch.tensor(metrics).mean().item()
651 |         del self._stored_metrics[train_eval]
652 |         return super().log(logs)
653 | 
654 | 
655 | if __name__ == '__main__':
656 |     from src.data.configs import DATASET_CONFIGS
657 |     from transformers import LlamaTokenizer
658 |     dataset = DATASET_CONFIGS["PKU-Alignment/PKU-SafeRLHF-10K-better"](sanity_check=True).get_preference_dataset(split="train")
659 |     tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
660 |     tokenizer.pad_token = tokenizer.eos_token
661 |     dataset = dataset.map(DPODataMapFunc(tokenizer=tokenizer), batched=True)
662 |     batch = DPODataCollatorWithPadding(tokenizer=tokenizer)([dataset[0], dataset[1]])
663 |     breakpoint()
664 | 
--------------------------------------------------------------------------------
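A small, self-contained sketch of the arithmetic in `dpo_loss` above (illustrative only; the log-probability values are made-up numbers, and `beta = 0.1` mirrors the constructor default):

import torch
import torch.nn.functional as F

# Dummy per-example log probabilities (batch_size = 2); values are illustrative.
beta = 0.1
policy_chosen_logps = torch.tensor([-10.0, -12.0])
policy_rejected_logps = torch.tensor([-14.0, -11.0])
reference_chosen_logps = torch.tensor([-11.0, -12.5])
reference_rejected_logps = torch.tensor([-13.0, -11.5])

# Implicit rewards, as in dpo_loss: beta * (policy logp - reference logp)
chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)        # tensor([0.1000, 0.0500])
rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps)  # tensor([-0.1000, 0.0500])

# "sigmoid" branch: -log(sigmoid(reward margin))
losses = -F.logsigmoid(chosen_rewards - rejected_rewards)                     # approx tensor([0.5981, 0.6931])
# "hinge" branch: relu(1 - reward margin)
hinge_losses = torch.relu(1 - (chosen_rewards - rejected_rewards))            # tensor([0.8000, 1.0000])

The scalar optimized during training is the mean of `losses`, as returned by `get_batch_metrics`.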