├── .gitignore ├── README.md ├── data └── relabeled_hh_rlhf │ ├── both │ ├── test.jsonl.gz │ └── train.jsonl.gz │ ├── harmless │ ├── test.jsonl.gz │ └── train.jsonl.gz │ └── helpful │ ├── test.jsonl.gz │ └── train.jsonl.gz ├── generate_llm_embeddings_UF_P_2.sh ├── generate_llm_embeddings_UF_P_4.sh ├── generate_llm_embeddings_pets.sh ├── hidden_context ├── __init__.py ├── data_utils │ ├── __init__.py │ ├── add_survey_contexts.py │ ├── data_processing.py │ ├── generate_simple_data.py │ ├── simple_templates.py │ └── ultrafeedback_augment.py ├── train_llm_preference_model.py ├── train_llm_vae_preference_model.py └── vae_utils.py ├── requirements.txt ├── submit_job_UF_P_2.sh ├── submit_job_UF_P_4.sh └── submit_job_pets.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | logs/ 140 | runs*/ 141 | 142 | # jupyter notebook 143 | notebooks/ 144 | 145 | # vscode 146 | .vscode/ 147 | 148 | # output folder 149 | results/ 150 | wandb/ 151 | slurm*/ 152 | data/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Personalizing Reinforcement Learning from Human Feedback with Variational Preference Learning 2 | 3 | #### [[Website]](https://weirdlabuw.github.io/vpl/) [[Paper]](https://arxiv.org/) 4 | 5 | [Sriyash Poddar1](https://sriya.sh), [Yanming Wan1](https://wanyanming.com/), [Hamish Ivison1](https://hamishivi.github.io/), [Abhishek Gupta1](https://homes.cs.washington.edu/~abhgupta), [Natasha Jaques1](https://natashajaques.ai)
6 | 
7 | 1University of Washington
8 | 
9 | This repo implements the language experiments of VPL. VPL is a variational framework for learning from human feedback (binary preference labels): it infers a user-specific latent and learns reward models and policies conditioned on that latent, without requiring additional user-specific data. This enables quick adaptation to individual user preferences without retraining the entire model or ignoring underrepresented groups.
10 | 
11 | For the control experiments of VPL, please refer to [this repo](https://github.com/WEIRDLabUW/vpl).
12 | 
13 | ## Instructions
14 | 
15 | 
16 | #### Setting up the repo
17 | ```
18 | git clone git@github.com:WEIRDLabUW/vpl_llm.git
19 | ```
20 | 
21 | #### Install Dependencies
22 | ```
23 | conda create -n vpl python=3.10
24 | conda activate vpl
25 | pip install -r requirements.txt
26 | ```
27 | 
28 | ## Data and Pretrained Models
29 | 
30 | Our datasets and checkpoints can be downloaded from [this Google Drive folder](https://drive.google.com/drive/folders/1dQ8zpNefRAtUB9TtbovOSn2MfV2Y-MbC?usp=sharing).
31 | 
32 | #### Datasets
33 | The datasets needed for the VPL experiments should be downloaded and unzipped into ``./data/``. There are three datasets in the folder: ``simple_pets``, ``P_survey_100``, and ``P_4_survey_100``.
34 | 
35 | #### Checkpoints
36 | The checkpoints for the VPL experiments should be downloaded to ``./logs``. We provide checkpoints for VPL and the baseline models on each dataset.
37 | 
38 | ## Dataset Generation
39 | We also provide the code for generating our datasets.
40 | The following scripts reproduce the datasets in ``./data/``.
41 | #### Pets
42 | To generate ``./data/simple_pets``, run
43 | ```bash
44 | bash generate_llm_embeddings_pets.sh gpt2
45 | ```
46 | 
47 | #### UF-P-2
48 | To generate ``./data/P_survey_100``, run
49 | ```bash
50 | python -m hidden_context.data_utils.ultrafeedback_augment -a 84 -n P
51 | bash generate_llm_embeddings_UF_P_2.sh gpt2 84
52 | ```
53 | 
54 | #### UF-P-4
55 | To generate ``./data/P_4_survey_100``, run
56 | ```bash
57 | python -m hidden_context.data_utils.ultrafeedback_augment -a single -n P_4 -c
58 | bash generate_llm_embeddings_UF_P_4.sh gpt2 single
59 | ```
60 | 
61 | ## Running Experiments
62 | In all the following scripts, the model type can be chosen from ``vae``, ``base``, ``categorical``, and ``mean_and_variance``.
63 | ``vae`` corresponds to our VPL models, while the others train baseline models.
64 | 
65 | Results are logged to Wandb; refer to ``eval/accuracy`` on the Wandb page for each model's performance.
66 | 
67 | #### Pets
68 | To train models on ``./data/simple_pets``, run
69 | ```bash
70 | bash submit_job_pets.sh
71 | ```
72 | Note that the default settings are for Pets (full);
73 | to train on Pets (controversial), change the arguments as explained in the bash file.
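
To sanity-check the generated Pets data before training, you can load the output JSONL directly (a minimal sketch; the path assumes the default ``gpt2`` embeddings and the ``helpful`` subset, matching the generation command above):
```python
from datasets import load_dataset

# Generated files follow the <output_dir>/<model_type>/<data_subset>/<split>.jsonl layout.
ds = load_dataset("json", data_files="data/simple_pets/gpt2/helpful/train.jsonl")["train"]

print(len(ds))                    # number of preference pairs
print(ds[0]["chosen"][:100])      # prompt + chosen response
print(ds[0]["controversial"])     # True when the two synthetic users disagree on this pair
print(ds[0]["context_length"])    # number of context comparisons attached to this example
```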
74 | 
75 | #### UF-P-2
76 | To train models on ``./data/P_survey_100``, run
77 | ```bash
78 | bash submit_job_UF_P_2.sh
79 | ```
80 | 
81 | #### UF-P-4
82 | To train models on ``./data/P_4_survey_100``, run
83 | ```bash
84 | bash submit_job_UF_P_4.sh
85 | ```
86 | 
87 | ## Bibtex
88 | If you find this code useful, please cite:
89 | 
90 | ```
91 | @article{poddar2024vpl,
92 |   author = {Poddar, Sriyash and Wan, Yanming and Ivison, Hamish and Gupta, Abhishek and Jaques, Natasha},
93 |   title = {Personalizing Reinforcement Learning from Human Feedback with Variational Preference Learning},
94 |   journal = {arXiv preprint},
95 |   year = {2024},
96 | }
97 | ```
--------------------------------------------------------------------------------
/data/relabeled_hh_rlhf/both/test.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/data/relabeled_hh_rlhf/both/test.jsonl.gz
--------------------------------------------------------------------------------
/data/relabeled_hh_rlhf/both/train.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/data/relabeled_hh_rlhf/both/train.jsonl.gz
--------------------------------------------------------------------------------
/data/relabeled_hh_rlhf/harmless/test.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/data/relabeled_hh_rlhf/harmless/test.jsonl.gz
--------------------------------------------------------------------------------
/data/relabeled_hh_rlhf/harmless/train.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/data/relabeled_hh_rlhf/harmless/train.jsonl.gz
--------------------------------------------------------------------------------
/data/relabeled_hh_rlhf/helpful/test.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/data/relabeled_hh_rlhf/helpful/test.jsonl.gz
--------------------------------------------------------------------------------
/data/relabeled_hh_rlhf/helpful/train.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/data/relabeled_hh_rlhf/helpful/train.jsonl.gz
--------------------------------------------------------------------------------
/generate_llm_embeddings_UF_P_2.sh:
--------------------------------------------------------------------------------
1 | 
2 | # Set model_type to be 'gpt2' or 'llama'
3 | # Set other_subsets to be 'ultra_feedback', 'single', or '84'
4 | model_type=$1
5 | other_subsets=$2
6 | 
7 | # Generate LLM embeddings for UltraFeedback dataset
8 | if [ "${other_subsets}" = "ultra_feedback" ]; then
9 |     subsets="helpfulness honesty instruction_following truthfulness"
10 | elif [ "${other_subsets}" = "single" ]; then
11 |     subsets="8 4 2 1"
12 | elif [ "${other_subsets}" = "84" ]; then
13 |     subsets="8 4"
14 | else
15 |     echo "Invalid other_subsets!"; exit 1
16 | fi
17 | 
18 | echo "${subsets}"
19 | 
20 | 
21 | # Final version for two users
22 | survey_size=100
23 | for subset in ${subsets}
24 | do
25 |     python -m hidden_context.data_utils.add_survey_contexts --output_dir "data/P_survey_${survey_size}/" \
26 |         --data_path "data/UltraFeedback_${other_subsets}_P" --data_subset ${subset} --data_split train --model_type ${model_type} \
27 |         --other_subsets ${other_subsets} --with_embeddings True --survey_size $survey_size --num_duplicates 8 --fixed_context_length True
28 | 
29 |     python -m hidden_context.data_utils.add_survey_contexts --output_dir "data/P_survey_${survey_size}/" \
30 |         --data_path "data/UltraFeedback_${other_subsets}_P" --data_subset ${subset} --data_split test --model_type ${model_type} \
31 |         --other_subsets ${other_subsets} --with_embeddings True --survey_size $survey_size --num_duplicates 8 --fixed_context_length True
32 | done
33 | 
--------------------------------------------------------------------------------
/generate_llm_embeddings_UF_P_4.sh:
--------------------------------------------------------------------------------
1 | 
2 | # Set model_type to be 'gpt2' or 'llama'
3 | # Set other_subsets to be 'ultra_feedback', 'single', or '84'
4 | model_type=$1
5 | other_subsets=$2
6 | 
7 | # Generate LLM embeddings for UltraFeedback dataset
8 | if [ "${other_subsets}" = "ultra_feedback" ]; then
9 |     subsets="helpfulness honesty instruction_following truthfulness"
10 | elif [ "${other_subsets}" = "single" ]; then
11 |     subsets="8 4 2 1"
12 | elif [ "${other_subsets}" = "84" ]; then
13 |     subsets="8 4"
14 | else
15 |     echo "Invalid other_subsets!"; exit 1
16 | fi
17 | 
18 | echo "${subsets}"
19 | 
20 | 
21 | # Final version for four users
22 | survey_size=100
23 | for subset in ${subsets}
24 | do
25 |     python -m hidden_context.data_utils.add_survey_contexts --output_dir "data/P_4_survey_${survey_size}/" \
26 |         --data_path "data/UltraFeedback_${other_subsets}_P_4" --data_subset ${subset} --data_split train --model_type ${model_type} \
27 |         --other_subsets ${other_subsets} --with_embeddings True --survey_size $survey_size --num_duplicates 4
28 | 
29 |     python -m hidden_context.data_utils.add_survey_contexts --output_dir "data/P_4_survey_${survey_size}/" \
30 |         --data_path "data/UltraFeedback_${other_subsets}_P_4" --data_subset ${subset} --data_split test --model_type ${model_type} \
31 |         --other_subsets ${other_subsets} --with_embeddings True --survey_size $survey_size --num_duplicates 4
32 | done
33 | 
--------------------------------------------------------------------------------
/generate_llm_embeddings_pets.sh:
--------------------------------------------------------------------------------
1 | 
2 | # Set model_type to be 'gpt2' or 'llama' here
3 | model_type=$1
4 | 
5 | # Generate Pets dataset
6 | python -m hidden_context.data_utils.generate_simple_data --output_dir data/simple_pets/ \
7 |     --data_path data/relabeled_hh_rlhf --with_embeddings True --synthetic_dataset True \
8 |     --model_type ${model_type} --data_subset helpful --data_split test --dataset_size 200
9 | 
10 | python -m hidden_context.data_utils.generate_simple_data --output_dir data/simple_pets/ \
11 |     --data_path data/relabeled_hh_rlhf --with_embeddings True --synthetic_dataset True \
12 |     --model_type ${model_type} --data_subset helpful --data_split train --dataset_size 2000
13 | 
14 | python -m hidden_context.data_utils.generate_simple_data --output_dir data/simple_pets/ \
15 |     --data_path data/relabeled_hh_rlhf --with_embeddings True --synthetic_dataset True \
16 |     --model_type ${model_type}
--data_subset harmless --data_split test --dataset_size 200 17 | 18 | python -m hidden_context.data_utils.generate_simple_data --output_dir data/simple_pets/ \ 19 | --data_path data/relabeled_hh_rlhf --with_embeddings True --synthetic_dataset True \ 20 | --model_type ${model_type} --data_subset harmless --data_split train --dataset_size 2000 21 | -------------------------------------------------------------------------------- /hidden_context/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/hidden_context/__init__.py -------------------------------------------------------------------------------- /hidden_context/data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/hidden_context/data_utils/__init__.py -------------------------------------------------------------------------------- /hidden_context/data_utils/add_survey_contexts.py: -------------------------------------------------------------------------------- 1 | # This file is used to preprocess dataset, available for any HH-RLHF format datasets 2 | import os 3 | from dataclasses import dataclass, field 4 | from typing import Optional, cast 5 | from tqdm import tqdm 6 | import random 7 | 8 | from transformers import ( 9 | HfArgumentParser, 10 | ) 11 | 12 | import torch 13 | 14 | from hidden_context.train_llm_preference_model import ( 15 | concatenate_datasets, 16 | ) 17 | 18 | from hidden_context.data_utils.data_processing import generate_embeddings_with_llm 19 | from datasets import load_dataset 20 | 21 | from copy import deepcopy 22 | 23 | import numpy as np 24 | 25 | 26 | @dataclass 27 | class ScriptArguments: 28 | output_dir: Optional[str] = field( 29 | metadata={"help": "Directory where the new dataset will be stored."}, 30 | ) 31 | data_path: str = field( 32 | metadata={"help": "Directory where the original data is stored."} 33 | ) 34 | data_subset: str = field( 35 | default="helpful", 36 | metadata={ 37 | "help": "Which subset of the data to use. You can choose between" 38 | "'helpful', or 'harmless'." 39 | }, 40 | ) 41 | data_split: str = field( 42 | default="test", 43 | metadata={ 44 | "help": "Which split of the data to use. You can choose between" 45 | "'train', or 'test'." 46 | }, 47 | ) 48 | dataset_size: int = field( 49 | default=0, 50 | metadata={"help": "The size of the subset of the data to use"}, 51 | ) 52 | model_type: str = field( 53 | default="none", 54 | metadata={ 55 | "help": "You can choose between 'gpt2', 'llama', or 'none'." 56 | } 57 | ) 58 | embed_dim: int = field( 59 | default=1024, 60 | metadata={ 61 | "help": "Dimension of the embeddings generated by LLM." 62 | } 63 | ) 64 | max_length: int = field( 65 | default=1024, 66 | metadata={ 67 | "help": "Maximum token length of the inputs." 68 | } 69 | ) 70 | with_embeddings: bool = field( 71 | default=True, 72 | metadata={ 73 | "help": "Whether the embeddings are generated during pre-processing." 74 | } 75 | ) 76 | synthetic_dataset: bool = field( 77 | default=False, 78 | metadata={ 79 | "help": "Whether a synthetic dataset is used." 80 | } 81 | ) 82 | other_subsets: str = field(default=None) 83 | survey_size: int = field( 84 | default=8, 85 | metadata={ 86 | "help": "Size of survey question pool." 
87 | } 88 | ) 89 | context_length: int = field( 90 | default=8, 91 | metadata={ 92 | "help": "(Maximum) context length." 93 | } 94 | ) 95 | controversial_only: bool = field( 96 | default=True, 97 | metadata={ 98 | "help": "Whether to only generate controversial data points." 99 | } 100 | ) 101 | num_duplicates: int = field( 102 | default=1, 103 | metadata={ 104 | "help": "Number of times each data point repeatedly appears in the dataset (with resampled context)." 105 | } 106 | ) 107 | fixed_context_length: bool = field( 108 | default=False, 109 | metadata={ 110 | "help": "Whether to fix the context to the maximum length." 111 | } 112 | ) 113 | random_contexts: bool = field( 114 | default=False, 115 | metadata={ 116 | "help": "Whether to include controversial pairs in context." 117 | } 118 | ) 119 | 120 | 121 | def generate_contexts(args, input_dataset, survey_dataset): 122 | # Generate context with survey question pool 123 | output_dir = os.path.join(args.output_dir, f"{args.model_type}", f"{args.data_subset}") 124 | if not os.path.exists(output_dir): 125 | os.makedirs(output_dir) 126 | 127 | if args.controversial_only: 128 | input_dataset = input_dataset.filter(lambda x: x['controversial'] == True) 129 | dataset_size = len(input_dataset) 130 | if args.data_split == 'train': 131 | K = args.num_duplicates # repeat samples for K times 132 | else: 133 | K = 1 134 | dataset_list = list() 135 | 136 | def random_choice(max_context_length, survey_size): 137 | if max_context_length <= survey_size: 138 | from functools import reduce 139 | while True: 140 | if args.fixed_context_length: 141 | context_length = max_context_length 142 | else: 143 | if args.other_subsets == '84': 144 | context_length = random.randint(1, max_context_length) 145 | else: 146 | context_length = random.randint(2, max_context_length) 147 | context_chosen_ids = np.random.choice(survey_size, context_length, replace=False) 148 | chosen_dataset = [d for idx, d in enumerate(survey_dataset) if idx in context_chosen_ids] 149 | if args.other_subsets != 'single': 150 | return chosen_dataset, context_length 151 | satisfied_sets = list() 152 | for row in chosen_dataset: 153 | satisfied_sets.append(set(row["satisfied_subset"])) 154 | if len(reduce(lambda x, y: x.intersection(y), satisfied_sets)) == 1: 155 | return chosen_dataset, context_length 156 | elif context_length == survey_size: 157 | raise ValueError("Please choose another random seed!") 158 | else: 159 | raise ValueError("Context length is larger than survey size!") 160 | 161 | for idx in range(K): 162 | output_dataset = deepcopy(input_dataset) 163 | context_lengths = list() 164 | contexts = list() 165 | for _ in tqdm(range(dataset_size)): # iterate over all samples in original dataset 166 | row_contexts = list() 167 | 168 | context_dataset, context_length = random_choice(args.context_length, args.survey_size) 169 | context_lengths.append(context_length) 170 | for context_row in context_dataset: 171 | context_id = context_row["Index"] 172 | context_data = context_row 173 | if not args.with_embeddings: 174 | row_contexts.append({ 175 | 'original_id': context_id, 176 | 'chosen': context_data['chosen'], 177 | 'rejected': context_data['rejected'], 178 | }) 179 | else: 180 | row_contexts.append({ 181 | 'original_id': context_id, 182 | 'embedding_chosen': context_data['embeddings']['embedding_chosen'], 183 | 'embedding_rejected': context_data['embeddings']['embedding_rejected'], 184 | }) 185 | contexts.append(row_contexts) 186 | output_dataset = 
output_dataset.add_column("context_length", context_lengths) 187 | output_dataset = output_dataset.add_column("contexts", contexts) 188 | output_dataset.map() 189 | dataset_list.append(output_dataset) 190 | output = concatenate_datasets(dataset_list) 191 | output.to_json(os.path.join(output_dir, f"{args.data_split}.jsonl")) 192 | return output 193 | 194 | 195 | if __name__ == "__main__": 196 | # default setting on HH-RLHF dataset, please iterate over data subsets and data splits 197 | seed = 0 198 | random.seed(seed) 199 | np.random.seed(seed) 200 | torch.manual_seed(seed) 201 | torch.cuda.manual_seed(seed) 202 | parser = HfArgumentParser(ScriptArguments) 203 | script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0] 204 | print(script_args) 205 | dataset = generate_embeddings_with_llm(script_args) 206 | if not script_args.random_contexts: 207 | survey_options = dataset.filter(lambda x: x['survey_options'] == True) 208 | else: 209 | survey_options = dataset.filter(lambda x: x['survey_options'] == True or x['survey_options'] == False) 210 | survey_ids = np.random.choice(range(len(survey_options)), script_args.survey_size, replace=False) 211 | print(survey_ids) 212 | if script_args.data_split == "train": 213 | survey_data = survey_options.filter(lambda example, idx: idx in survey_ids, with_indices=True) 214 | survey_data.to_json(os.path.join(script_args.data_path, script_args.data_subset, "survey_{}.jsonl".format(script_args.survey_size))) 215 | else: 216 | survey_data = load_dataset('json', data_files=os.path.join(script_args.data_path, script_args.data_subset, "survey_{}.jsonl".format(script_args.survey_size))) 217 | survey_data = survey_data['train'] 218 | generate_contexts(script_args, dataset, survey_data) 219 | -------------------------------------------------------------------------------- /hidden_context/data_utils/data_processing.py: -------------------------------------------------------------------------------- 1 | # This file is used to preprocess dataset, available for any HH-RLHF format datasets 2 | import os 3 | from dataclasses import dataclass, field 4 | from typing import Optional, cast 5 | from tqdm import tqdm 6 | import random 7 | 8 | from transformers import ( 9 | HfArgumentParser, 10 | AutoTokenizer, 11 | AutoModelForSequenceClassification, 12 | AutoModelForCausalLM, 13 | ) 14 | 15 | import torch 16 | 17 | from hidden_context.train_llm_preference_model import ( 18 | DataSubset, 19 | get_hh_rlhf_dataset, 20 | concatenate_datasets, 21 | HHRLHFPreprocessor, 22 | ) 23 | 24 | from copy import deepcopy 25 | 26 | import numpy as np 27 | 28 | @dataclass 29 | class ScriptArguments: 30 | output_dir: Optional[str] = field( 31 | metadata={"help": "Directory where the new dataset will be stored."}, 32 | ) 33 | data_path: str = field( 34 | metadata={"help": "Directory where the original data is stored."} 35 | ) 36 | data_subset: str = field( 37 | default="helpful", 38 | metadata={ 39 | "help": "Which subset of the data to use. You can choose between" 40 | "'helpful', or 'harmless'." 41 | }, 42 | ) 43 | data_split: str = field( 44 | default="test", 45 | metadata={ 46 | "help": "Which split of the data to use. You can choose between" 47 | "'train', or 'test'." 48 | }, 49 | ) 50 | dataset_size: int = field( 51 | default=0, 52 | metadata={"help": "The size of the subset of the data to use"}, 53 | ) 54 | model_type: str = field( 55 | default="none", 56 | metadata={ 57 | "help": "You can choose between 'gpt2', 'llama', or 'none'." 
58 | } 59 | ) 60 | embed_dim: int = field( 61 | default=1024, 62 | metadata={ 63 | "help": "Dimension of the embeddings generated by LLM." 64 | } 65 | ) 66 | max_length: int = field( 67 | default=1024, 68 | metadata={ 69 | "help": "Maximum token length of the inputs." 70 | } 71 | ) 72 | with_embeddings: bool = field( 73 | default=True, 74 | metadata={ 75 | "help": "Whether the embeddings are generated during pre-processing." 76 | } 77 | ) 78 | synthetic_dataset: bool = field( 79 | default=False, 80 | metadata={ 81 | "help": "Whether a synthetic dataset is used." 82 | } 83 | ) 84 | other_subsets: str = field(default=None) 85 | 86 | 87 | def generate_embeddings_with_llm(args, input_dataset=None): 88 | """ 89 | This function is used to generate fixed embeddings for inputs from original dataset. 90 | """ 91 | if not args.synthetic_dataset: 92 | data_subset = cast(DataSubset, args.data_subset) 93 | input_dataset = get_hh_rlhf_dataset( 94 | data_subset, 95 | args.data_split, 96 | args.dataset_size, 97 | data_path=args.data_path, 98 | use_subset_as_dir=True, 99 | other_subsets=args.other_subsets, 100 | ) 101 | 102 | if args.model_type == "gpt2": 103 | tokenizer = AutoTokenizer.from_pretrained("gpt2", use_auth_token=True) 104 | model = AutoModelForSequenceClassification.from_pretrained( 105 | "gpt2", num_labels=args.embed_dim, torch_dtype=torch.bfloat16 106 | ) 107 | model.score.weight.data *= 0.01 108 | elif args.model_type == "llama" or args.model_type == "meta-llama/Llama-2-7b-hf": 109 | tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_auth_token=True, add_eos_token=False) 110 | model = AutoModelForCausalLM.from_pretrained( 111 | "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16 112 | ) 113 | else: 114 | return input_dataset 115 | model.to("cuda") 116 | 117 | tokenizer.pad_token = tokenizer.eos_token 118 | tokenizer.pad_token_id = tokenizer.eos_token_id 119 | tokenizer.padding_side = "right" 120 | 121 | model.config.pad_token_id = tokenizer.pad_token_id 122 | dataset_size = len(input_dataset) 123 | print(dataset_size) 124 | 125 | preprocessed_dataset = input_dataset.map( 126 | HHRLHFPreprocessor(tokenizer), 127 | batched=True, 128 | num_proc=24, 129 | remove_columns=input_dataset.column_names, 130 | load_from_cache_file=False, 131 | ) 132 | 133 | input_dataset = input_dataset.filter( 134 | lambda example, idx: len(preprocessed_dataset[idx]["input_ids_chosen"]) <= args.max_length 135 | and len(preprocessed_dataset[idx]["input_ids_rejected"]) <= args.max_length, 136 | with_indices=True 137 | ) 138 | preprocessed_dataset = preprocessed_dataset.filter( 139 | lambda example: len(example["input_ids_chosen"]) <= args.max_length 140 | and len(example["input_ids_rejected"]) <= args.max_length 141 | ) 142 | print(len(input_dataset), len(preprocessed_dataset)) 143 | dataset_size = len(preprocessed_dataset) 144 | 145 | embeddings = list() 146 | for row_id in tqdm(range(dataset_size)): 147 | emb = dict() 148 | for key in ['chosen', 'rejected']: 149 | tokens = tokenizer.pad( 150 | {"input_ids": preprocessed_dataset[row_id][f"input_ids_{key}"]}, 151 | padding=True, pad_to_multiple_of=64, return_tensors="pt" 152 | ) 153 | token_length = len(preprocessed_dataset[row_id][f"input_ids_{key}"]) 154 | input_ids = tokens["input_ids"].unsqueeze(0).to("cuda") 155 | attention_mask = tokens["attention_mask"].unsqueeze(0).to("cuda") 156 | with torch.no_grad(): 157 | last_hidden_state = model( 158 | input_ids=input_ids, 159 | attention_mask=attention_mask, 160 | output_hidden_states=True 161 | 
).hidden_states[-1] 162 | emb[f"embedding_{key}"] = last_hidden_state[0][token_length - 1].float().cpu().numpy() 163 | embeddings.append(emb) 164 | output_dataset = input_dataset.add_column("embeddings", embeddings) 165 | return output_dataset 166 | 167 | 168 | def generate_contexts(args, input_dataset): 169 | # Generate context without survey question pool 170 | output_dir = os.path.join(args.output_dir, f"{args.model_type}", f"{args.data_subset}") 171 | if not os.path.exists(output_dir): 172 | os.makedirs(output_dir) 173 | 174 | dataset_size = len(input_dataset) 175 | 176 | K = 1 # repeat samples for K times 177 | dataset_list = list() 178 | for idx in range(K): 179 | context_dataset = deepcopy(input_dataset) 180 | context_lengths = [8] * dataset_size 181 | if "context_length" in context_dataset.column_names: 182 | context_dataset = context_dataset.remove_columns("context_length") 183 | context_dataset = context_dataset.add_column("context_length", context_lengths) 184 | contexts = list() 185 | for row_id in tqdm(range(dataset_size)): # iterate over all samples in original dataset 186 | row_contexts = list() 187 | num_context = 0 188 | controversial_subset = input_dataset.filter(lambda example: example['controversial'] == True) 189 | controversial_size = len(controversial_subset) 190 | while num_context < context_lengths[row_id]: 191 | random_id = np.random.randint(controversial_size) 192 | context_id = controversial_subset[random_id]['Index'] 193 | context_data = controversial_subset[random_id] 194 | if not args.synthetic_dataset: 195 | if input_dataset[row_id]['prompt'] == context_data['prompt']: 196 | continue 197 | if not args.with_embeddings: 198 | row_contexts.append({ 199 | 'original_id': context_id, 200 | 'chosen': context_data['chosen'], 201 | 'rejected': context_data['rejected'], 202 | }) 203 | else: 204 | row_contexts.append({ 205 | 'original_id': context_id, 206 | 'embedding_chosen': context_data['embeddings']['embedding_chosen'], 207 | 'embedding_rejected': context_data['embeddings']['embedding_rejected'], 208 | }) 209 | num_context += 1 210 | contexts.append(row_contexts) 211 | context_dataset = context_dataset.add_column("contexts", contexts) 212 | dataset_list.append(context_dataset) 213 | 214 | output = concatenate_datasets(dataset_list) 215 | output.to_json(os.path.join(output_dir, f"{args.data_split}.jsonl")) 216 | return output 217 | 218 | 219 | if __name__ == "__main__": 220 | # default setting on HH-RLHF dataset, please iterate over data subsets and data splits 221 | seed = 0 222 | random.seed(seed) 223 | np.random.seed(seed) 224 | torch.manual_seed(seed) 225 | torch.cuda.manual_seed(seed) 226 | parser = HfArgumentParser(ScriptArguments) 227 | script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0] 228 | print(script_args) 229 | dataset = generate_embeddings_with_llm(script_args) 230 | generate_contexts(script_args, dataset) 231 | -------------------------------------------------------------------------------- /hidden_context/data_utils/generate_simple_data.py: -------------------------------------------------------------------------------- 1 | # This file is used to generate synthetic language dataset 2 | from typing import cast 3 | 4 | from transformers import ( 5 | HfArgumentParser, 6 | ) 7 | 8 | import torch 9 | import random 10 | 11 | from hidden_context.data_utils.data_processing import ( 12 | ScriptArguments, 13 | generate_embeddings_with_llm, 14 | generate_contexts 15 | ) 16 | 17 | from hidden_context.data_utils.simple_templates import * 18 | 
19 | from hidden_context.train_llm_preference_model import ( 20 | DataSubset, 21 | get_hh_rlhf_dataset, 22 | ) 23 | 24 | import numpy as np 25 | 26 | 27 | def generate_synthetic_dataset(args): 28 | data_subset = cast(DataSubset, args.data_subset) 29 | input_dataset = get_hh_rlhf_dataset( 30 | data_subset, 31 | args.data_split, 32 | args.dataset_size, 33 | data_path=args.data_path, 34 | use_subset_as_dir=True 35 | ) 36 | def generate_simple_data_point(example): 37 | prompt_length = 1 38 | prompt = 'Human: Please talk about one kind of pets.' 39 | if args.data_split == 'train': 40 | helpful_harmless = bird_sentences[:80] 41 | helpful_harmful = dog_sentences[:80] 42 | harmless_unhelpful = cat_sentences[:80] 43 | harmful_unhelpful = rabbit_sentences[:80] 44 | else: 45 | helpful_harmless = bird_sentences[80:] 46 | helpful_harmful = dog_sentences[80:] 47 | harmless_unhelpful = cat_sentences[80:] 48 | harmful_unhelpful = rabbit_sentences[80:] 49 | pair_type = np.random.randint(10) # set to 6 previously 50 | if pair_type == 0: 51 | chosen = np.random.choice(helpful_harmless) 52 | rejected = np.random.choice(helpful_harmful) 53 | elif pair_type == 1: 54 | chosen = np.random.choice(harmless_unhelpful) 55 | rejected = np.random.choice(harmful_unhelpful) 56 | elif pair_type == 2: 57 | chosen = np.random.choice(helpful_harmless) 58 | rejected = np.random.choice(harmless_unhelpful) 59 | elif pair_type == 3: 60 | chosen = np.random.choice(helpful_harmful) 61 | rejected = np.random.choice(harmful_unhelpful) 62 | elif pair_type == 4: 63 | chosen = np.random.choice(helpful_harmless) 64 | rejected = np.random.choice(harmful_unhelpful) 65 | else: 66 | if script_args.data_subset == 'helpful': 67 | chosen = np.random.choice(helpful_harmful) 68 | rejected = np.random.choice(harmless_unhelpful) 69 | else: 70 | chosen = np.random.choice(harmless_unhelpful) 71 | rejected = np.random.choice(helpful_harmful) 72 | chosen_repeated = ' '.join([chosen] * prompt_length) 73 | rejected_repeated = ' '.join([rejected] * prompt_length) 74 | return_dict = {'prompt': prompt, 'chosen': prompt + '\n\n' + 'Assistant: ' + chosen_repeated, 75 | 'rejected': prompt + '\n\n' + 'Assistant: ' + rejected_repeated} 76 | if example['label'] == 0: 77 | return_dict['responses'] = [chosen_repeated, rejected_repeated] 78 | else: 79 | return_dict['responses'] = [rejected_repeated, chosen_repeated] 80 | if pair_type >= 5: 81 | return_dict['controversial'] = True 82 | else: 83 | return_dict['controversial'] = False 84 | return return_dict 85 | 86 | input_dataset = input_dataset.map(generate_simple_data_point) 87 | print(len(input_dataset.filter(lambda x: x['controversial'] == True))) 88 | return input_dataset 89 | 90 | 91 | if __name__ == "__main__": 92 | # default setting on synthetic language dataset, please iterate over data subsets and data splits 93 | seed = 0 94 | random.seed(seed) 95 | np.random.seed(seed) 96 | torch.manual_seed(seed) 97 | torch.cuda.manual_seed(seed) 98 | parser = HfArgumentParser(ScriptArguments) 99 | script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0] 100 | print(script_args) 101 | dataset = generate_synthetic_dataset(script_args) 102 | if script_args.with_embeddings: 103 | dataset = generate_embeddings_with_llm(script_args, dataset) 104 | generate_contexts(script_args, dataset) 105 | -------------------------------------------------------------------------------- /hidden_context/data_utils/simple_templates.py: -------------------------------------------------------------------------------- 1 | 
bird_sentences = [ 2 | "Birds chirp melodiously at the break of dawn.", 3 | "Birds of prey have keen eyesight for hunting.", 4 | "Birds migrate to warmer regions during winter.", 5 | "Birds build intricate nests using twigs and leaves.", 6 | "Birds exhibit a diverse range of colors and plumage.", 7 | "Birds communicate through various calls and songs.", 8 | "Birds have hollow bones which aid in flight.", 9 | "Birds navigate using the stars and Earth's magnetic field.", 10 | "Birds' feathers provide insulation and aid in flight.", 11 | "Birds adapt to different environments for survival.", 12 | "Birds of paradise have elaborate courtship displays.", 13 | "Birds use their beaks for feeding and grooming.", 14 | "Birds flock together for safety and socialization.", 15 | "Birds molt old feathers to make way for new ones.", 16 | "Birds play a crucial role in pollination.", 17 | "Birds like the ostrich are flightless but swift runners.", 18 | "Birds' songs vary depending on species and region.", 19 | "Birds mimic human speech and sounds.", 20 | "Birds exhibit territorial behavior during breeding season.", 21 | "Birds' eggs come in various sizes and colors.", 22 | "Birds preen their feathers to keep them clean and waterproof.", 23 | "Birds of prey hunt smaller animals for food.", 24 | "Birds use their wings for balance and stability.", 25 | "Birds migrate thousands of miles during migration.", 26 | "Birds like the penguin are adapted to life in icy waters.", 27 | "Birds' nests are often hidden for protection.", 28 | "Birds have different types of feet for various purposes.", 29 | "Birds possess remarkable intelligence and problem-solving skills.", 30 | "Birds' migration patterns are influenced by weather conditions.", 31 | "Birds huddle together to conserve body heat.", 32 | "Birds are classified into different orders and families.", 33 | "Birds' eyesight is sharper than that of humans.", 34 | "Birds' wingspans vary greatly among species.", 35 | "Birds roost in trees, cliffs, and buildings.", 36 | "Birds use vocalizations to establish dominance.", 37 | "Birds' beaks are adapted to their diet.", 38 | "Birds' feathers are made of keratin, like human hair and nails.", 39 | "Birds navigate using landmarks and celestial cues.", 40 | "Birds like the albatross spend years at sea.", 41 | "Birds' mating rituals can be elaborate and colorful.", 42 | "Birds of prey have strong talons for catching prey.", 43 | "Birds adapt to urban environments for nesting.", 44 | "Birds migrate to find abundant food sources.", 45 | "Birds of paradise have elaborate plumage for courtship.", 46 | "Birds' nests are lined with soft materials for comfort.", 47 | "Birds like the hummingbird can hover in mid-air.", 48 | "Birds' songs serve to attract mates and defend territory.", 49 | "Birds form intricate social structures within flocks.", 50 | "Birds exhibit different foraging techniques depending on their diet.", 51 | "Birds' wings allow them to glide effortlessly.", 52 | "Birds build nests using a variety of materials.", 53 | "Birds' feathers are crucial for regulating body temperature.", 54 | "Birds like the eagle have exceptional eyesight.", 55 | "Birds mark territories with vocalizations and displays.", 56 | "Birds migrate along established routes called flyways.", 57 | "Birds' feet are adapted to perching and grasping.", 58 | "Birds of prey hunt with precision and speed.", 59 | "Birds adapt their behavior to changes in their environment.", 60 | "Birds like the flamingo filter-feed in shallow waters.", 61 | "Birds' calls 
vary in pitch, tone, and rhythm.", 62 | "Birds migrate to breeding grounds during spring.", 63 | "Birds' beaks are adapted to their diet and feeding habits.", 64 | "Birds' nests vary in size and complexity.", 65 | "Birds rely on instinct and experience for navigation.", 66 | "Birds like the swallow undertake long migrations.", 67 | "Birds exhibit elaborate courtship displays to attract mates.", 68 | "Birds' eggs are incubated until they hatch.", 69 | "Birds migrate to avoid harsh winter conditions.", 70 | "Birds adapt to urban environments by scavenging for food.", 71 | "Birds' feathers are oiled to repel water.", 72 | "Birds mark their territories with scent and vocalizations.", 73 | "Birds display intricate mating rituals to attract partners.", 74 | "Birds' wings enable them to soar effortlessly.", 75 | "Birds adapt their behavior to changes in the environment.", 76 | "Birds build nests in trees, shrubs, and cliffs.", 77 | "Birds of prey hunt with precision and agility.", 78 | "Birds' nests are lined with soft materials for insulation.", 79 | "Birds migrate to warmer climates during winter.", 80 | "Birds communicate through calls, songs, and displays.", 81 | "Birds use their beaks for feeding, grooming, and defense.", 82 | "Birds' feathers provide insulation and aid in flight.", 83 | "Birds exhibit complex social behaviors within flocks.", 84 | "Birds adapt to different habitats for survival." 85 | ] 86 | 87 | cat_sentences = [ 88 | "Cats are mysterious creatures.", 89 | "Cats possess an independent nature.", 90 | "Cats enjoy lounging in sunny spots.", 91 | "Cats have retractable claws.", 92 | "Cats communicate through meows.", 93 | "Cats groom themselves meticulously.", 94 | "Cats can see in low light.", 95 | "Cats exhibit playful behavior.", 96 | "Cats are skilled hunters.", 97 | "Cats have a keen sense of balance.", 98 | "Cats enjoy chasing toys.", 99 | "Cats are known for their agility.", 100 | "Cats often nap throughout the day.", 101 | "Cats have unique personalities.", 102 | "Cats knead with their paws.", 103 | "Cats are territorial animals.", 104 | "Cats can jump several times their height.", 105 | "Cats dislike water in general.", 106 | "Cats have a strong sense of smell.", 107 | "Cats are crepuscular creatures.", 108 | "Cats purr when content.", 109 | "Cats are obligate carnivores.", 110 | "Cats have excellent hearing.", 111 | "Cats mark their territory with scent.", 112 | "Cats have specialized whiskers.", 113 | "Cats can be trained through positive reinforcement.", 114 | "Cats are curious by nature.", 115 | "Cats form strong bonds with their owners.", 116 | "Cats are skilled climbers.", 117 | "Cats have a flexible spine.", 118 | "Cats have a preference for routine.", 119 | "Cats have a grooming ritual after meals.", 120 | "Cats use their tails for balance.", 121 | "Cats have a hierarchy within colonies.", 122 | "Cats can recognize their names.", 123 | "Cats exhibit hunting behavior through stalking.", 124 | "Cats enjoy interactive playtime.", 125 | "Cats have different vocalizations for various needs.", 126 | "Cats can sense changes in the weather.", 127 | "Cats have an acute sense of taste.", 128 | "Cats show affection through headbutting.", 129 | "Cats have a natural instinct to hunt rodents.", 130 | "Cats exhibit territorial spraying behavior.", 131 | "Cats form social groups with other cats.", 132 | "Cats can sleep up to 16 hours a day.", 133 | "Cats are crepuscular hunters.", 134 | "Cats have an excellent sense of time.", 135 | "Cats sharpen their claws on scratching 
posts.", 136 | "Cats can experience stress from changes in their environment.", 137 | "Cats have a preference for certain textures.", 138 | "Cats communicate through body language.", 139 | "Cats enjoy high perches for observation.", 140 | "Cats are obligate carnivores, requiring meat in their diet.", 141 | "Cats can have litters of kittens multiple times a year.", 142 | "Cats have a strong maternal instinct.", 143 | "Cats have a hierarchy within multi-cat households.", 144 | "Cats are prone to hairballs from grooming.", 145 | "Cats have a sensitive digestive system.", 146 | "Cats groom other cats in their social group.", 147 | "Cats have specialized taste receptors.", 148 | "Cats enjoy hiding in small spaces.", 149 | "Cats have a unique grooming technique.", 150 | "Cats have a variety of coat patterns and colors.", 151 | "Cats have a third eyelid for protection.", 152 | "Cats display affection through slow blinking.", 153 | "Cats exhibit whisker fatigue if they touch narrow spaces.", 154 | "Cats enjoy observing prey from a hidden vantage point.", 155 | "Cats have a preference for fresh water sources.", 156 | "Cats can suffer from separation anxiety.", 157 | "Cats have an instinctual fear of unfamiliar objects.", 158 | "Cats enjoy toys that mimic prey.", 159 | "Cats exhibit kneading behavior when relaxed.", 160 | "Cats can develop allergies to certain foods.", 161 | "Cats have a sensitive sense of touch in their whiskers.", 162 | "Cats communicate through scent marking.", 163 | "Cats have a strong dislike for citrus scents.", 164 | "Cats can become stressed by changes in routine.", 165 | "Cats have a preferred sleeping position.", 166 | "Cats groom to regulate body temperature.", 167 | "Cats enjoy interactive feeding toys.", 168 | "Cats have a territorial response to other animals.", 169 | "Cats have specialized muscles for purring.", 170 | "Cats have a preferred scratching substrate.", 171 | "Cats can be trained to walk on a leash.", 172 | "Cats enjoy sunbathing near windows.", 173 | "Cats have a preference for certain types of litter.", 174 | "Cats communicate through vocalizations.", 175 | "Cats have a natural instinct to bury waste.", 176 | "Cats are crepuscular hunters by nature.", 177 | "Cats display affection through grooming rituals.", 178 | "Cats have a preference for routines in feeding times.", 179 | "Cats enjoy exploring new environments.", 180 | "Cats exhibit kneading behavior on soft surfaces.", 181 | "Cats have a varied vocal range for communication.", 182 | "Cats have a strong instinct to hunt small prey." 
183 | ] 184 | 185 | dog_sentences = [ 186 | "Dogs are loyal companions.", 187 | "Dogs come in all shapes and sizes.", 188 | "Dogs have a keen sense of smell.", 189 | "Dogs enjoy playing fetch.", 190 | "Dogs love belly rubs.", 191 | "Dogs are known as man's best friend.", 192 | "Dogs make great therapy animals.", 193 | "Dogs are incredibly intelligent.", 194 | "Dogs can be trained to perform various tasks.", 195 | "Dogs need regular exercise.", 196 | "Dogs enjoy exploring the outdoors.", 197 | "Dogs have an innate sense of curiosity.", 198 | "Dogs communicate through body language.", 199 | "Dogs love treats.", 200 | "Dogs are pack animals by nature.", 201 | "Dogs have a strong sense of hierarchy.", 202 | "Dogs are capable of forming deep bonds with humans.", 203 | "Dogs provide emotional support to their owners.", 204 | "Dogs have a natural instinct to protect their families.", 205 | "Dogs have been domesticated for thousands of years.", 206 | "Dogs can learn a wide range of commands.", 207 | "Dogs have an excellent sense of hearing.", 208 | "Dogs have been used for hunting since ancient times.", 209 | "Dogs have a playful demeanor.", 210 | "Dogs are highly adaptable animals.", 211 | "Dogs have a variety of coat colors and patterns.", 212 | "Dogs are social animals that thrive on companionship.", 213 | "Dogs enjoy cuddling with their owners.", 214 | "Dogs are capable of expressing affection.", 215 | "Dogs can be trained for search and rescue missions.", 216 | "Dogs have a strong sense of territory.", 217 | "Dogs are skilled at interpreting human emotions.", 218 | "Dogs are often used in police work.", 219 | "Dogs enjoy participating in agility courses.", 220 | "Dogs have a remarkable ability to learn new things.", 221 | "Dogs are known for their unconditional love.", 222 | "Dogs are descendants of wolves.", 223 | "Dogs have a natural inclination to chase after moving objects.", 224 | "Dogs require proper grooming to stay healthy.", 225 | "Dogs have a unique set of vocalizations.", 226 | "Dogs have a powerful sense of taste.", 227 | "Dogs are adept at understanding routines.", 228 | "Dogs have been depicted in art throughout history.", 229 | "Dogs are capable of forming friendships with other animals.", 230 | "Dogs enjoy being part of a family unit.", 231 | "Dogs can detect changes in human behavior.", 232 | "Dogs have an extraordinary sense of balance.", 233 | "Dogs are sensitive to changes in their environment.", 234 | "Dogs have been trained to assist people with disabilities.", 235 | "Dogs enjoy playing with toys.", 236 | "Dogs have an innate sense of direction.", 237 | "Dogs are known to be protective of children.", 238 | "Dogs have a natural instinct to dig.", 239 | "Dogs are skilled at navigating through various terrains.", 240 | "Dogs enjoy receiving praise from their owners.", 241 | "Dogs have a strong sense of smell that can detect illness.", 242 | "Dogs thrive on routine and structure.", 243 | "Dogs enjoy sunbathing.", 244 | "Dogs have a playful rivalry with cats.", 245 | "Dogs require socialization from an early age.", 246 | "Dogs have a keen sense of time.", 247 | "Dogs are known to comfort people in distress.", 248 | "Dogs have a natural affinity for water.", 249 | "Dogs are capable of learning through observation.", 250 | "Dogs have been used in therapy for mental health conditions.", 251 | "Dogs enjoy exploring new scents.", 252 | "Dogs have a calming presence.", 253 | "Dogs are known for their ability to sense danger.", 254 | "Dogs enjoy being praised for good behavior.", 
255 | "Dogs have a strong sense of empathy.", 256 | "Dogs are skilled at interpreting human gestures.", 257 | "Dogs have a natural inclination to mark their territory.", 258 | "Dogs enjoy sleeping in comfortable spots.", 259 | "Dogs have a playful nature that lasts into old age.", 260 | "Dogs are excellent at following scent trails.", 261 | "Dogs have been trained for military purposes.", 262 | "Dogs are often featured in movies and television shows.", 263 | "Dogs have a unique way of greeting each other.", 264 | "Dogs have been companions to humans for millennia.", 265 | "Dogs enjoy being part of outdoor activities.", 266 | "Dogs have been known to rescue people in distress.", 267 | "Dogs have a strong bond with their owners.", 268 | "Dogs have a calming effect on people.", 269 | "Dogs have a unique personality.", 270 | "Dogs have a variety of barks for different situations.", 271 | "Dogs enjoy exploring their surroundings.", 272 | "Dogs are known for their sense of playfulness.", 273 | "Dogs have been used in various forms of work throughout history.", 274 | "Dogs have a strong sense of loyalty to their families.", 275 | "Dogs are adept at learning from positive reinforcement.", 276 | "Dogs have been known to detect natural disasters before they occur.", 277 | "Dogs have a strong prey drive.", 278 | "Dogs enjoy being praised and rewarded for their efforts." 279 | ] 280 | 281 | rabbit_sentences = [ 282 | "Rabbits hop around in fields.", 283 | "Rabbits have soft fur.", 284 | "Rabbits eat carrots.", 285 | "Rabbits reproduce quickly.", 286 | "Rabbits are social animals.", 287 | "Rabbits have long ears.", 288 | "Rabbits twitch their noses.", 289 | "Rabbits dig burrows.", 290 | "Rabbits love to play.", 291 | "Rabbits can be pets.", 292 | "Rabbits come in many colors.", 293 | "Rabbits thump their feet.", 294 | "Rabbits are herbivores.", 295 | "Rabbits have large families.", 296 | "Rabbits groom themselves.", 297 | "Rabbits are prey animals.", 298 | "Rabbits are agile.", 299 | "Rabbits are prolific breeders.", 300 | "Rabbits have a keen sense of smell.", 301 | "Rabbits are crepuscular.", 302 | "Rabbits have powerful hind legs.", 303 | "Rabbits nibble on grass.", 304 | "Rabbits are cute.", 305 | "Rabbits have whiskers.", 306 | "Rabbits communicate through body language.", 307 | "Rabbits enjoy hiding in tunnels.", 308 | "Rabbits are fast runners.", 309 | "Rabbits are nocturnal.", 310 | "Rabbits have a natural curiosity.", 311 | "Rabbits have powerful hind legs.", 312 | "Rabbits need space to roam.", 313 | "Rabbits have a varied diet.", 314 | "Rabbits can be trained.", 315 | "Rabbits are associated with fertility.", 316 | "Rabbits have a short gestation period.", 317 | "Rabbits have a lifespan of 8-12 years.", 318 | "Rabbits enjoy companionship.", 319 | "Rabbits are quiet animals.", 320 | "Rabbits are known for their reproductive rate.", 321 | "Rabbits have a gentle disposition.", 322 | "Rabbits thump to communicate danger.", 323 | "Rabbits have a hierarchy within groups.", 324 | "Rabbits enjoy toys.", 325 | "Rabbits have strong teeth.", 326 | "Rabbits are territorial.", 327 | "Rabbits groom each other.", 328 | "Rabbits have a complex digestive system.", 329 | "Rabbits can be litter-trained.", 330 | "Rabbits binky when happy.", 331 | "Rabbits have a strong maternal instinct.", 332 | "Rabbits have a 360-degree field of vision.", 333 | "Rabbits are prolific diggers.", 334 | "Rabbits can be affectionate pets.", 335 | "Rabbits have a unique digestive process called cecotrophy.", 336 | "Rabbits have a 
fear of loud noises.", 337 | "Rabbits can suffer from heatstroke.", 338 | "Rabbits have a natural inclination to chew.", 339 | "Rabbits enjoy exploring new environments.", 340 | "Rabbits have a strong sense of territory.", 341 | "Rabbits are often depicted in folklore.", 342 | "Rabbits are used in scientific research.", 343 | "Rabbits are susceptible to parasites.", 344 | "Rabbits are good swimmers.", 345 | "Rabbits have a sensitive respiratory system.", 346 | "Rabbits have a hutch as their shelter.", 347 | "Rabbits have a strong sense of balance.", 348 | "Rabbits are clean animals.", 349 | "Rabbits can communicate with a range of vocalizations.", 350 | "Rabbits are known for their fertility.", 351 | "Rabbits enjoy being petted.", 352 | "Rabbits need hay for proper digestion.", 353 | "Rabbits are social creatures.", 354 | "Rabbits can suffer from loneliness.", 355 | "Rabbits have a natural instinct to burrow.", 356 | "Rabbits enjoy fresh vegetables.", 357 | "Rabbits have a strong maternal bond.", 358 | "Rabbits can be trained to use a litter box.", 359 | "Rabbits are crepuscular, meaning they are most active at dawn and dusk.", 360 | "Rabbits have a delicate skeletal structure.", 361 | "Rabbits have long been associated with luck.", 362 | "Rabbits can be territorial over their food.", 363 | "Rabbits have a complex system of communication.", 364 | "Rabbits need regular grooming to prevent matting.", 365 | "Rabbits have a natural curiosity about their environment.", 366 | "Rabbits enjoy companionship with other rabbits.", 367 | "Rabbits are very adaptable animals.", 368 | "Rabbits have a strong sense of smell to detect predators.", 369 | "Rabbits have a unique digestive system that requires a high-fiber diet.", 370 | "Rabbits have been domesticated for over 2,000 years.", 371 | "Rabbits have a keen sense of hearing to detect predators." 
372 | ] 373 | 374 | -------------------------------------------------------------------------------- /hidden_context/data_utils/ultrafeedback_augment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import torch 5 | from datasets import load_dataset, Dataset 6 | import numpy as np 7 | import os 8 | 9 | 10 | def random_argmax(values): 11 | """ a random tie-breaking argmax """ 12 | return np.argmax(np.random.random(values.shape) * (values == values.max())) 13 | 14 | 15 | def random_greater_than_zero(values): 16 | return (np.random.randn(values.shape[0]) * (values == 0) > 0.0) | (values > 0.0) 17 | 18 | 19 | def array_to_type(arr): 20 | return str(int(np.dot(arr, np.array([8, 4, 2, 1])))) 21 | 22 | 23 | def get_user_type(chosen_ratings, rejected_ratings, augment_type): 24 | keys = ['helpfulness', 'honesty', 'instruction_following', 'truthfulness'] 25 | chosen_rating_values = list() 26 | rejected_rating_values = list() 27 | for key in keys: 28 | chosen_rating_values.append(chosen_ratings[key]) 29 | rejected_rating_values.append(rejected_ratings[key]) 30 | chosen_values = np.asarray(chosen_rating_values) 31 | rejected_values = np.asarray(rejected_rating_values) 32 | is_equal = list(chosen_values == rejected_values) 33 | if augment_type == 'single' or augment_type == '84': 34 | data_subsets = ['8', '4', '2', '1'] 35 | reversed_labels = {data_subsets[idx]: list(random_greater_than_zero(rejected_values - chosen_values))[idx] for 36 | idx in range(len(data_subsets))} 37 | is_equal = {data_subsets[idx]: is_equal[idx] for idx in range(len(data_subsets))} 38 | return data_subsets, reversed_labels, is_equal 39 | else: 40 | raise ValueError('Invalid augment_type') 41 | 42 | 43 | def inner_join(original, binarized, augment_type, users, two_two_only=False, filter_equal=False): 44 | agreed_counter = 0 45 | controversial_counter = 0 46 | keys = ['helpfulness', 'honesty', 'instruction_following', 'truthfulness'] 47 | user_counter = {key: 0 for key in users.keys()} 48 | reversed_counter = {key: 0 for key in users.keys()} 49 | dumb_baseline = {key: 0 for key in users.keys()} 50 | dumb_controversial_baseline = {key: 0 for key in users.keys()} 51 | orig_idx = 0 52 | out_idx = 0 53 | dataset_dict = { 54 | 'Index': list(), 55 | 'original_idx': list(), 56 | 'prompt': list(), 57 | 'chosen': list(), 58 | 'rejected': list(), 59 | 'data_subset': list(), 60 | 'controversial': list(), 61 | 'reversed': list(), 62 | 'satisfied_subset': list(), 63 | 'survey_options': list(), 64 | } 65 | for bin_idx in range(len(binarized)): 66 | while binarized[bin_idx]['prompt'] != original[orig_idx]['instruction']: 67 | orig_idx += 1 68 | prompt = binarized[bin_idx]['prompt'] 69 | chosen = binarized[bin_idx]['chosen'][1]['content'] 70 | rejected = binarized[bin_idx]['rejected'][1]['content'] 71 | if chosen == '' or rejected == '': 72 | continue 73 | chosen_ratings = dict() 74 | rejected_ratings = dict() 75 | flag = True 76 | for c in original[orig_idx]['completions']: 77 | if c['response'] == chosen: 78 | for key in keys: 79 | r = c['annotations'][key]['Rating'] 80 | if r == 'N/A': 81 | flag = False 82 | continue 83 | chosen_ratings[key] = int(r) 84 | elif c['response'] == rejected: 85 | for key in keys: 86 | r = c['annotations'][key]['Rating'] 87 | if r == 'N/A': 88 | flag = False 89 | continue 90 | rejected_ratings[key] = int(r) 91 | else: 92 | continue 93 | if not flag or len(chosen_ratings) != 4 or len(rejected_ratings) != 4: 94 | continue 95 | 
data_subsets, reversed_labels, is_equal = get_user_type(chosen_ratings, rejected_ratings, augment_type, users) 96 | if filter_equal: 97 | reversed_labels = {key: reversed_labels[key] for key in data_subsets if not is_equal[key]} 98 | data_subsets = [key for key in data_subsets if not is_equal[key]] 99 | is_equal = {key: False for key in data_subsets} 100 | if augment_type == '84' and len(is_equal.keys()) != 2: 101 | continue 102 | for data_subset in users.keys(): 103 | if data_subset not in data_subsets: 104 | dumb_baseline[data_subset] += 0.5 * len(data_subsets) 105 | if True in reversed_labels.values() and False in reversed_labels.values(): 106 | dumb_controversial_baseline[data_subset] += 0.5 * len(data_subsets) 107 | continue 108 | user_counter[data_subset] += 1 109 | if True in reversed_labels.values() and False in reversed_labels.values(): 110 | is_controversial = True 111 | controversial_counter += 1 112 | else: 113 | is_controversial = False 114 | agreed_counter += 1 115 | if reversed_labels[data_subset]: 116 | reversed_counter[data_subset] += 1 117 | dumb_baseline[data_subset] += list(reversed_labels.values()).count(True) 118 | if is_controversial: 119 | dumb_controversial_baseline[data_subset] += list(reversed_labels.values()).count(True) 120 | else: 121 | dumb_baseline[data_subset] += list(reversed_labels.values()).count(False) 122 | if is_controversial: 123 | dumb_controversial_baseline[data_subset] += list(reversed_labels.values()).count(False) 124 | dataset_dict['Index'].append(out_idx) 125 | dataset_dict['original_idx'].append(orig_idx) 126 | dataset_dict['prompt'].append(prompt) 127 | if not reversed_labels[data_subset]: 128 | dataset_dict['chosen'].append('Human: ' + prompt + '\n\nAssistant: ' + chosen) 129 | dataset_dict['rejected'].append('Human: ' + prompt + '\n\nAssistant: ' + rejected) 130 | else: 131 | dataset_dict['chosen'].append('Human: ' + prompt + '\n\nAssistant: ' + rejected) 132 | dataset_dict['rejected'].append('Human: ' + prompt + '\n\nAssistant: ' + chosen) 133 | dataset_dict['data_subset'].append(data_subset) 134 | dataset_dict['controversial'].append(is_controversial) 135 | dataset_dict['reversed'].append(reversed_labels[data_subset]) 136 | satisfied_subset = set([key for key in users.keys() if key not in data_subsets or reversed_labels[key] == reversed_labels[data_subset]]) 137 | dataset_dict['satisfied_subset'].append(satisfied_subset) 138 | dataset_dict['survey_options'].append(is_controversial and len(data_subsets) == len(users.keys())) 139 | out_idx += 1 140 | print(out_idx, agreed_counter, controversial_counter) 141 | print("User counter:", user_counter) 142 | print("Reversed counter:", reversed_counter) 143 | print("Dumb baseline:", dumb_baseline) 144 | print("Dumb controversial baseline:", dumb_controversial_baseline) 145 | return Dataset.from_dict(dataset_dict) 146 | 147 | 148 | if __name__ == '__main__': 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument('--seed', type=int, default=42, help='Random seed') 151 | parser.add_argument('-a', '--augment_type', type=str, default='single', help='How to augment data') 152 | parser.add_argument('-c', '--controversial_only', action='store_true', help='Whether to only generate controversial data') 153 | parser.add_argument('-n', '--name', type=str, default='P_4', help='name of dataset') 154 | args = parser.parse_args() 155 | seed = args.seed 156 | random.seed(seed) 157 | np.random.seed(seed) 158 | torch.manual_seed(seed) 159 | torch.cuda.manual_seed(seed) 160 | if args.augment_type == 
'single' or args.augment_type == '84': 161 | user_types = { 162 | '8': (1, 0, 0, 0), 163 | '4': (0, 1, 0, 0), 164 | '2': (0, 0, 1, 0), 165 | '1': (0, 0, 0, 1), 166 | } 167 | else: 168 | raise ValueError('Invalid augment_type') 169 | 170 | ultra_feedback = load_dataset('openbmb/UltraFeedback') 171 | binarized_cleaned = load_dataset('argilla/ultrafeedback-binarized-preferences-cleaned') 172 | length = len(binarized_cleaned['train']) 173 | print(length) 174 | test_ids = list(np.random.choice(length, int(length * 0.1), replace=False)) 175 | train_split = binarized_cleaned['train'].filter(lambda example, idx: idx not in test_ids, with_indices=True) 176 | test_split = binarized_cleaned['train'].filter(lambda example, idx: idx in test_ids, with_indices=True) 177 | print(len(train_split), len(test_split)) 178 | print("start processing train split") 179 | joined_dataset_train = inner_join(ultra_feedback['train'], train_split, args.augment_type, user_types) 180 | print("start processing test split") 181 | joined_dataset_test = inner_join(ultra_feedback['train'], test_split, args.augment_type, user_types) 182 | 183 | output_dir = os.path.join('data', 'UltraFeedback_{}_{}'.format(args.augment_type, args.name)) 184 | for user_type in user_types.keys(): 185 | train_subset = joined_dataset_train.filter(lambda x: x['data_subset'] == user_type) 186 | test_subset = joined_dataset_test.filter(lambda x: x['data_subset'] == user_type) 187 | if args.controversial_only: 188 | train_subset = train_subset.filter(lambda x: x['controversial'] == True) 189 | test_subset = test_subset.filter(lambda x: x['controversial'] == True) 190 | print(user_types[user_type], len(train_subset), len(test_subset)) 191 | train_subset.to_json(os.path.join(output_dir, user_type, 'train.jsonl')) 192 | test_subset.to_json(os.path.join(output_dir, user_type, 'test.jsonl')) 193 | 194 | # python -m hidden_context.data_utils.ultrafeedback_augment -a single -n P_4 -c 195 | 196 | # python -m hidden_context.data_utils.ultrafeedback_augment -a 84 -n P 197 | -------------------------------------------------------------------------------- /hidden_context/train_llm_preference_model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from functools import partial 3 | from typing import Any, Dict, List, Optional, Type, Union, cast 4 | 5 | import wandb 6 | import numpy as np 7 | import torch 8 | import random 9 | import torch.nn.functional as F # noqa: N812 10 | from datasets import Dataset, concatenate_datasets, load_dataset 11 | from peft import LoraConfig, TaskType, get_peft_model 12 | from torch import nn 13 | from torch.optim.lr_scheduler import LambdaLR 14 | from transformers import ( 15 | AutoModelForSequenceClassification, 16 | AutoTokenizer, 17 | HfArgumentParser, 18 | PreTrainedTokenizerBase, 19 | Trainer, 20 | TrainingArguments, 21 | ) 22 | from transformers.trainer_utils import EvalPrediction 23 | from transformers.utils import PaddingStrategy 24 | from typing_extensions import Literal, TypeAlias 25 | 26 | RewardModelType: TypeAlias = Literal["base", "mean_and_variance", "categorical"] 27 | DataSubset: TypeAlias = Literal["both", "helpful", "harmless"] 28 | 29 | 30 | @dataclass 31 | class ScriptArguments: 32 | local_rank: int = field(default=-1, metadata={"help": "Used for multi-gpu"}) 33 | resume_from_checkpoint: bool = field( 34 | default=False, 35 | metadata={"help": "If you want to resume training where it left off."}, 36 | ) 37 | deepspeed: Optional[str] 
= field( 38 | default=None, 39 | metadata={ 40 | "help": "Path to deepspeed config if using deepspeed. You may need this " 41 | "if the model that you want to train doesn't fit on a single GPU." 42 | }, 43 | ) 44 | per_device_train_batch_size: int = field(default=2) 45 | per_device_eval_batch_size: int = field(default=1) 46 | gradient_accumulation_steps: int = field(default=1) 47 | learning_rate: float = field(default=3e-6) 48 | weight_decay: float = field(default=0.001) 49 | model_name: str = field( 50 | default="gpt2", 51 | metadata={ 52 | "help": "The model that you want to train from the Hugging Face hub. " 53 | "E.g. gpt2, gpt2-xl, bert, etc." 54 | }, 55 | ) 56 | data_path: str = field( 57 | default="Anthropic/hh-rlhf", 58 | ) 59 | data_subset: str = field( 60 | default="both", 61 | metadata={ 62 | "help": "Which subset of the data to use. You can choose between 'both', " 63 | "'helpful', or 'harmless'." 64 | }, 65 | ) 66 | reward_model_type: str = field( 67 | default="base", 68 | metadata={ 69 | "help": "The type of reward model to use. You can choose between " 70 | "'base', 'mean_and_variance', or 'categorical'." 71 | }, 72 | ) 73 | num_atoms: int = field( 74 | default=10, 75 | metadata={ 76 | "help": "The number of atoms to use for the categorical reward model." 77 | }, 78 | ) 79 | entropy_coeff: float = field( 80 | default=0.1, 81 | metadata={"help": "The entropy coefficient for the categorical reward model."}, 82 | ) 83 | variance_penalty: float = field( 84 | default=0.0, 85 | metadata={ 86 | "help": "The variance penalty for the mean and variance reward model." 87 | }, 88 | ) 89 | tokenizer_name: Optional[str] = field( 90 | default=None, 91 | metadata={ 92 | "help": "The tokenizer for your model, if left empty will use the default " 93 | "for your model", 94 | }, 95 | ) 96 | bf16: bool = field( 97 | default=True, 98 | metadata={ 99 | "help": "This essentially cuts the training time in half if you want to " 100 | "sacrifice a little precision and have a supported GPU." 101 | }, 102 | ) 103 | fp16: bool = field( 104 | default=False, 105 | metadata={ 106 | "help": "This essentially cuts the training time in half if you want to " 107 | "sacrifice a little precision and have a supported GPU." 
108 | }, 109 | ) 110 | num_train_epochs: int = field( 111 | default=1, 112 | metadata={"help": "The number of training epochs for the reward model."}, 113 | ) 114 | train_dataset_size: int = field( 115 | default=0, 116 | metadata={"help": "The size of the subset of the training data to use"}, 117 | ) 118 | eval_dataset_size: int = field( 119 | default=0, 120 | metadata={"help": "The size of the subset of the eval data to use"}, 121 | ) 122 | gradient_checkpointing: bool = field( 123 | default=False, 124 | metadata={"help": "Enables gradient checkpointing."}, 125 | ) 126 | optim: str = field( 127 | default="adamw_hf", 128 | metadata={"help": "The optimizer to use."}, 129 | ) 130 | lr_scheduler_type: str = field( 131 | default="cosine", 132 | metadata={"help": "The lr scheduler"}, 133 | ) 134 | max_length: int = field(default=1024) 135 | eval_first_step: bool = field( 136 | default=False, 137 | metadata={"help": "Whether to run eval after the first step"}, 138 | ) 139 | log_dir: str = field(default="data/reward_models/hh_rlhf") 140 | controversial_only: bool = field(default=False) 141 | seed: int = field(default=0) 142 | up_sampling: bool = field(default=False) 143 | other_subsets: str = field(default=None) 144 | one_user: str = field( 145 | default=None, 146 | metadata={"help": "whether to only train and evaluate on one single user"} 147 | ) 148 | 149 | 150 | class HHRLHFPreprocessor(object): 151 | def __init__(self, tokenizer, **tokenizer_kwargs): 152 | self.tokenizer = tokenizer 153 | self.tokenizer_kwargs = tokenizer_kwargs 154 | 155 | def __call__(self, examples): 156 | new_examples: dict = { 157 | "input_ids_chosen": [], 158 | "attention_mask_chosen": [], 159 | "input_ids_rejected": [], 160 | "attention_mask_rejected": [], 161 | } 162 | for chosen, rejected in zip(examples["chosen"], examples["rejected"]): 163 | tokenized_chosen = self.tokenizer(chosen, **self.tokenizer_kwargs) 164 | tokenized_rejected = self.tokenizer(rejected, **self.tokenizer_kwargs) 165 | 166 | new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"]) 167 | new_examples["attention_mask_chosen"].append( 168 | tokenized_chosen["attention_mask"] 169 | ) 170 | new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"]) 171 | new_examples["attention_mask_rejected"].append( 172 | tokenized_rejected["attention_mask"] 173 | ) 174 | 175 | return new_examples 176 | 177 | 178 | def get_step_decay_lr_lambda(current_step: int, *, num_training_steps: int): 179 | if current_step < num_training_steps // 3: 180 | return 1.0 181 | elif current_step < (2 * num_training_steps) // 3: 182 | return 0.1 183 | else: 184 | return 0.01 185 | 186 | 187 | def get_cosine_decay_lr_lambda(current_step: int, *, num_training_steps: int): 188 | return 0.1 + 0.9 * 0.5 * (1 + np.cos(np.pi * current_step / num_training_steps)) 189 | 190 | 191 | class RewardTrainer(Trainer): 192 | def __init__(self, *args, lr_lambda=None, **kwargs): 193 | super().__init__(*args, **kwargs) 194 | self.lr_lambda = lr_lambda 195 | 196 | @classmethod 197 | def per_sample_loss(cls, rewards_chosen, rewards_rejected): 198 | return -nn.functional.logsigmoid(rewards_chosen - rewards_rejected) 199 | 200 | def loss(self, rewards_chosen, rewards_rejected): 201 | return torch.mean(self.per_sample_loss(rewards_chosen, rewards_rejected)) 202 | 203 | def compute_loss(self, model, inputs, return_outputs=False): 204 | all_rewards = model( 205 | torch.concatenate( 206 | [ 207 | inputs["input_ids_chosen"], 208 | inputs["input_ids_rejected"], 209 | ], 210 | dim=0, 
211 | ), 212 | torch.concatenate( 213 | [ 214 | inputs["attention_mask_chosen"], 215 | inputs["attention_mask_rejected"], 216 | ], 217 | dim=0, 218 | ), 219 | )[0] 220 | all_rewards = all_rewards.reshape(2, -1, all_rewards.shape[-1]) 221 | rewards_chosen = all_rewards[0] 222 | rewards_rejected = all_rewards[1] 223 | loss = self.loss(rewards_chosen, rewards_rejected) 224 | if return_outputs: 225 | return loss, { 226 | "rewards_chosen": rewards_chosen, 227 | "rewards_rejected": rewards_rejected, 228 | } 229 | else: 230 | self.log( 231 | { 232 | "rewards_chosen": rewards_chosen.mean().item(), 233 | "rewards_rejected": rewards_rejected.mean().item(), 234 | } 235 | ) 236 | return loss 237 | 238 | def create_scheduler(self, num_training_steps: int, optimizer=None): 239 | if self.lr_lambda is not None: 240 | lr_lambda = partial( 241 | self.lr_lambda, 242 | num_training_steps=num_training_steps, 243 | ) 244 | self.lr_scheduler = LambdaLR(optimizer, lr_lambda) 245 | return self.lr_scheduler 246 | else: 247 | return super().create_scheduler(num_training_steps, optimizer) 248 | 249 | @classmethod 250 | def compute_metrics(cls, eval_prediction: EvalPrediction): 251 | rewards_chosen, rewards_rejected = eval_prediction.predictions 252 | rewards_chosen = torch.from_numpy(rewards_chosen) 253 | rewards_rejected = torch.from_numpy(rewards_rejected) 254 | 255 | loss = cls.per_sample_loss(rewards_chosen, rewards_rejected) 256 | accuracy = torch.mean((loss < np.log(2)).float()) 257 | 258 | return { 259 | "loss": loss.mean().item(), 260 | "accuracy": accuracy.item(), 261 | } 262 | 263 | 264 | class MeanAndVarianceRewardTrainer(RewardTrainer): 265 | def __init__(self, *args, variance_penalty: float = 0.0, **kwargs): 266 | super().__init__(*args, **kwargs) 267 | self.variance_penalty = variance_penalty 268 | 269 | @classmethod 270 | def per_sample_loss(cls, rewards_chosen, rewards_rejected): 271 | mean_chosen = rewards_chosen[:, 0] 272 | std_chosen = F.softplus(rewards_chosen[:, 1]) 273 | mean_rejected = rewards_rejected[:, 0] 274 | std_rejected = F.softplus(rewards_rejected[:, 1]) 275 | 276 | diff_mean = mean_chosen - mean_rejected 277 | var_combined = std_chosen**2 + std_rejected**2 278 | z = diff_mean / torch.sqrt(var_combined) 279 | return F.softplus(-z * np.sqrt(2 * np.pi)) 280 | 281 | def loss(self, rewards_chosen, rewards_rejected): 282 | std_chosen = F.softplus(rewards_chosen[:, 1]) 283 | std_rejected = F.softplus(rewards_rejected[:, 1]) 284 | variance_loss = (std_chosen**2 + std_rejected**2).mean() 285 | 286 | log_loss = self.per_sample_loss(rewards_chosen, rewards_rejected).mean() 287 | 288 | if self.model.training: 289 | return log_loss + self.variance_penalty * variance_loss 290 | else: 291 | return log_loss 292 | 293 | 294 | class CategoricalRewardTrainer(RewardTrainer): 295 | def __init__(self, *args, entropy_coeff: float = 0.0, **kwargs): 296 | super().__init__(*args, **kwargs) 297 | self.entropy_coeff = entropy_coeff 298 | 299 | @classmethod 300 | def per_sample_loss(cls, rewards_chosen, rewards_rejected): 301 | num_atoms = rewards_chosen.size()[1] 302 | device = rewards_chosen.device 303 | 304 | comparison_matrix = torch.empty( 305 | (num_atoms, num_atoms), 306 | device=device, 307 | dtype=rewards_chosen.dtype, 308 | ) 309 | atom_values = torch.linspace(0, 1, num_atoms, device=device) 310 | comparison_matrix[:] = atom_values[None, :] > atom_values[:, None] 311 | comparison_matrix[atom_values[None, :] == atom_values[:, None]] = 0.5 312 | 313 | dist_rejected = rewards_rejected.softmax(1) 314 | 
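# Worked illustration of the quantity computed below (numbers are made up):
# with atom values v = [0, 0.5, 1], the per-sample preference probability is
#   P(chosen preferred) = sum_{i,j} p_chosen(v_i) * p_rejected(v_j) * ([v_i > v_j] + 0.5 * [v_i == v_j]),
# i.e. the chance that a value drawn from the chosen distribution beats one drawn from the
# rejected distribution, with ties counted as half a win. For example,
#   p_chosen = [0.1, 0.3, 0.6], p_rejected = [0.5, 0.3, 0.2]
#   P = 0.1*(0 + 0.5*0.5) + 0.3*(0.5 + 0.5*0.3) + 0.6*(0.8 + 0.5*0.2) = 0.76
#   per-sample loss = -log(0.76) ≈ 0.27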
dist_chosen = rewards_chosen.softmax(1) 315 | prob_chosen = ((dist_rejected @ comparison_matrix) * dist_chosen).sum(dim=1) 316 | return -prob_chosen.log() 317 | 318 | def loss(self, rewards_chosen, rewards_rejected): 319 | dist_rejected = rewards_rejected.softmax(1) 320 | dist_chosen = rewards_chosen.softmax(1) 321 | mean_dist = torch.concatenate( 322 | [dist_chosen, dist_rejected], 323 | dim=0, 324 | ).mean(dim=0) 325 | entropy_loss = torch.sum(mean_dist * mean_dist.log()) 326 | 327 | log_loss = self.per_sample_loss(rewards_chosen, rewards_rejected).mean() 328 | 329 | if self.model.training: 330 | return log_loss + self.entropy_coeff * entropy_loss 331 | else: 332 | return log_loss 333 | 334 | 335 | def get_hh_rlhf_dataset( 336 | data_subset: DataSubset, 337 | split: Literal["train", "test"], 338 | dataset_size: int = 0, 339 | data_path="Anthropic/hh-rlhf", 340 | use_subset_as_dir=True, # new parameter 341 | other_subsets=None 342 | ) -> Dataset: 343 | datasets: List[Dataset] = [] 344 | if other_subsets is None: 345 | if data_path == "Anthropic/hh-rlhf": 346 | if data_subset == "harmless" or data_subset == "both": 347 | datasets.append( 348 | load_dataset( 349 | "Anthropic/hh-rlhf", data_dir="harmless-base", split=split 350 | ).map(lambda data: {"data_subset": "harmless"}) 351 | ) 352 | if data_subset == "helpful" or data_subset == "both": 353 | datasets.append( 354 | load_dataset( 355 | "Anthropic/hh-rlhf", data_dir="helpful-base", split=split 356 | ).map(lambda data: {"data_subset": "helpful"}) 357 | ) 358 | else: 359 | if not use_subset_as_dir: # original version: combine all data subsets within the path 360 | datasets.append( 361 | load_dataset(data_path, split=split).map( 362 | lambda data: {"data_subset": data_subset} 363 | ) 364 | ) 365 | else: # new version: use data_subset as subdirectory 366 | if data_subset == "helpful" or data_subset == "both": 367 | datasets.append( 368 | load_dataset( 369 | data_path, data_dir="helpful", split=split 370 | ).map(lambda data: {"data_subset": "helpful"}) 371 | ) 372 | if data_subset == "harmless" or data_subset == "both": 373 | datasets.append( 374 | load_dataset( 375 | data_path, data_dir="harmless", split=split 376 | ).map(lambda data: {"data_subset": "harmless"}) 377 | ) 378 | else: # TODO: set subsets here 379 | if other_subsets == 'ultra_feedback': 380 | subsets = ['helpfulness', 'honesty', 'instruction_following', 'truthfulness'] 381 | elif other_subsets == 'single': 382 | subsets = ['8', '4', '2', '1'] 383 | elif other_subsets == '84': 384 | subsets = ['8', '4'] 385 | else: 386 | subsets = [] 387 | for subset in subsets: 388 | if data_subset == 'all' or data_subset == subset: 389 | datasets.append( 390 | load_dataset( 391 | data_path, data_dir=subset, split=split 392 | ) 393 | ) 394 | 395 | if dataset_size: 396 | datasets = [ 397 | dataset.select(range(dataset_size // len(datasets))) for dataset in datasets 398 | ] 399 | 400 | return concatenate_datasets(datasets) 401 | 402 | 403 | trainer_classes: Dict[RewardModelType, Type[RewardTrainer]] = { 404 | "base": RewardTrainer, 405 | "mean_and_variance": MeanAndVarianceRewardTrainer, 406 | "categorical": CategoricalRewardTrainer, 407 | } 408 | 409 | 410 | def up_sample_controversial(dataset, seed): 411 | cont = dataset.filter(lambda example: example['controversial'] == True) 412 | up_sampled_dataset = concatenate_datasets([cont] * 4 + [dataset]) 413 | up_sampled_dataset = up_sampled_dataset.shuffle(seed=seed) 414 | return up_sampled_dataset 415 | 416 | 417 | if __name__ == "__main__": 418 | 
parser = HfArgumentParser(ScriptArguments) 419 | script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0] 420 | 421 | seed = script_args.seed 422 | random.seed(seed) 423 | np.random.seed(seed) 424 | torch.manual_seed(seed) 425 | torch.cuda.manual_seed(seed) 426 | 427 | torch.set_default_dtype(torch.bfloat16 if script_args.bf16 else torch.float32) 428 | 429 | data_subset = cast(DataSubset, script_args.data_subset) 430 | train_dataset = get_hh_rlhf_dataset( 431 | data_subset, 432 | "train", 433 | script_args.train_dataset_size, 434 | data_path=script_args.data_path, 435 | use_subset_as_dir=True, 436 | other_subsets=script_args.other_subsets 437 | ) 438 | eval_dataset = get_hh_rlhf_dataset( 439 | data_subset, 440 | "test", 441 | script_args.eval_dataset_size, 442 | data_path=script_args.data_path, 443 | use_subset_as_dir=True, 444 | other_subsets=script_args.other_subsets 445 | ) 446 | print(len(train_dataset), len(eval_dataset)) 447 | if script_args.controversial_only: 448 | train_dataset = train_dataset.filter(lambda example: example['controversial'] == True) 449 | eval_dataset = eval_dataset.filter(lambda example: example['controversial'] == True) 450 | elif script_args.up_sampling: 451 | train_dataset = up_sample_controversial(train_dataset, seed) 452 | eval_dataset = up_sample_controversial(eval_dataset, seed) 453 | 454 | if script_args.one_user: 455 | train_dataset = train_dataset.filter(lambda example: example['data_subset'] == script_args.one_user) 456 | eval_dataset = eval_dataset.filter(lambda example: example['data_subset'] == script_args.one_user) 457 | 458 | reward_model_type = cast(RewardModelType, script_args.reward_model_type) 459 | 460 | # Define the training args. Needs to be done before the model is loaded if you 461 | # are using deepspeed. 
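# Note on the learning-rate schedule configured below: when lr_scheduler_type is "step" or
# "cosine", the HuggingFace scheduler is set to "constant" and the actual decay is applied
# through the lr_lambda passed to RewardTrainer.create_scheduler (a LambdaLR):
#   get_step_decay_lr_lambda:   multiplier 1.0 for the first third of training,
#                               0.1 for the second third, 0.01 for the final third.
#   get_cosine_decay_lr_lambda: multiplier 0.1 + 0.45 * (1 + cos(pi * step / total_steps)),
#                               a cosine decay from 1.0 at step 0 down to 0.1 at the end.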
462 | model_name_split = script_args.model_name.split("/")[-1] 463 | output_name = ( 464 | f"{script_args.log_dir}/{data_subset}/" 465 | f"{reward_model_type}_{model_name_split}" 466 | f"__{script_args.train_dataset_size}_{script_args.learning_rate}" 467 | f"_{script_args.lr_scheduler_type}_{script_args.num_train_epochs}" 468 | ) 469 | if reward_model_type == "categorical": 470 | output_name += f"_{script_args.num_atoms}_{script_args.entropy_coeff}" 471 | elif reward_model_type == "mean_and_variance": 472 | output_name += f"_{script_args.variance_penalty}" 473 | 474 | output_name += f"_seed{script_args.seed}" 475 | trainer_kwargs: Dict[str, Any] = {} 476 | if script_args.lr_scheduler_type == "step": 477 | lr_scheduler_type = "constant" 478 | trainer_kwargs["lr_lambda"] = get_step_decay_lr_lambda 479 | elif script_args.lr_scheduler_type == "cosine": 480 | lr_scheduler_type = "constant" 481 | trainer_kwargs["lr_lambda"] = get_cosine_decay_lr_lambda 482 | else: 483 | lr_scheduler_type = script_args.lr_scheduler_type 484 | 485 | training_args = TrainingArguments( 486 | output_dir=output_name, 487 | learning_rate=script_args.learning_rate, 488 | per_device_train_batch_size=script_args.per_device_train_batch_size, 489 | per_device_eval_batch_size=script_args.per_device_eval_batch_size, 490 | num_train_epochs=script_args.num_train_epochs, 491 | weight_decay=script_args.weight_decay, 492 | evaluation_strategy="steps", 493 | eval_steps=0.1, 494 | save_strategy="steps", 495 | save_steps=1000, 496 | gradient_accumulation_steps=script_args.gradient_accumulation_steps, 497 | gradient_checkpointing=script_args.gradient_checkpointing, 498 | deepspeed=script_args.deepspeed, 499 | local_rank=script_args.local_rank, 500 | remove_unused_columns=False, 501 | label_names=[], 502 | bf16=script_args.bf16, 503 | fp16=script_args.fp16, 504 | logging_strategy="steps", 505 | logging_steps=10, 506 | optim=script_args.optim, 507 | lr_scheduler_type=lr_scheduler_type, 508 | report_to="wandb", 509 | run_name=output_name.split("/")[-1], 510 | ) 511 | # Load the value-head model and tokenizer. 512 | tokenizer_name = ( 513 | script_args.tokenizer_name 514 | if script_args.tokenizer_name is not None 515 | else script_args.model_name 516 | ) 517 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True, add_eos_token=False) 518 | 519 | peft_config = LoraConfig( 520 | task_type=TaskType.SEQ_CLS, 521 | inference_mode=False, 522 | r=128, 523 | lora_alpha=256, 524 | lora_dropout=0.1, 525 | ) 526 | 527 | torch.set_anomaly_enabled(True) 528 | 529 | trainer_class = trainer_classes[reward_model_type] 530 | if reward_model_type == "base": 531 | num_labels = 1 532 | elif reward_model_type == "mean_and_variance": 533 | num_labels = 2 534 | trainer_kwargs["variance_penalty"] = script_args.variance_penalty 535 | elif reward_model_type == "categorical": 536 | num_labels = script_args.num_atoms 537 | trainer_kwargs["entropy_coeff"] = script_args.entropy_coeff 538 | 539 | model = AutoModelForSequenceClassification.from_pretrained( 540 | script_args.model_name, num_labels=num_labels, torch_dtype=torch.bfloat16 541 | ) 542 | # We multiply the final linear layer's weights by 0.01 because this seems to 543 | # significantly stabilize training and lead to better optimization of the loss. 544 | model.score.weight.data *= 0.01 545 | model = get_peft_model(model, peft_config) 546 | model.print_trainable_parameters() 547 | 548 | # Need to do this for GPT2 and Llama because they doesn't have official pad tokens. 
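# Reusing the EOS token as the pad token is the standard workaround; the model config's
# pad_token_id must be set to match so that, with right padding, the sequence-classification
# head reads its score from the last non-padding token of each sequence.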
549 | tokenizer.pad_token = tokenizer.eos_token 550 | tokenizer.pad_token_id = tokenizer.eos_token_id 551 | model.config.pad_token_id = tokenizer.pad_token_id 552 | tokenizer.padding_side = "right" 553 | 554 | model.config.use_cache = not script_args.gradient_checkpointing 555 | num_proc = 24 # Can adjust to be higher if you have more processors. 556 | original_columns = train_dataset.column_names 557 | 558 | train_dataset = train_dataset.map( 559 | HHRLHFPreprocessor(tokenizer), 560 | batched=True, 561 | num_proc=num_proc, 562 | remove_columns=original_columns, 563 | ) 564 | train_dataset = train_dataset.filter( 565 | lambda x: len(x["input_ids_chosen"]) <= script_args.max_length 566 | and len(x["input_ids_rejected"]) <= script_args.max_length 567 | ) 568 | print(len(train_dataset)) 569 | 570 | eval_dataset = eval_dataset.map( 571 | HHRLHFPreprocessor(tokenizer), 572 | batched=True, 573 | num_proc=num_proc, 574 | remove_columns=original_columns, 575 | ) 576 | eval_dataset = eval_dataset.filter( 577 | lambda x: len(x["input_ids_chosen"]) <= script_args.max_length 578 | and len(x["input_ids_rejected"]) <= script_args.max_length 579 | ) 580 | print(len(eval_dataset)) 581 | 582 | # We need to define a special data collator that batches the data in our j vs k format. 583 | @dataclass 584 | class RewardDataCollatorWithPadding: 585 | tokenizer: PreTrainedTokenizerBase 586 | padding: Union[bool, str, PaddingStrategy] = True 587 | max_length: Optional[int] = None 588 | pad_to_multiple_of: Optional[int] = None 589 | return_tensors: str = "pt" 590 | 591 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 592 | features_chosen = [] 593 | features_rejected = [] 594 | for feature in features: 595 | features_chosen.append( 596 | { 597 | "input_ids": feature["input_ids_chosen"], 598 | "attention_mask": feature["attention_mask_chosen"], 599 | } 600 | ) 601 | features_rejected.append( 602 | { 603 | "input_ids": feature["input_ids_rejected"], 604 | "attention_mask": feature["attention_mask_rejected"], 605 | } 606 | ) 607 | batch = self.tokenizer.pad( 608 | features_chosen + features_rejected, 609 | padding=self.padding, 610 | max_length=self.max_length, 611 | pad_to_multiple_of=self.pad_to_multiple_of, 612 | return_tensors=self.return_tensors, 613 | ) 614 | input_ids = batch["input_ids"].view(2, -1, batch["input_ids"].shape[-1]) 615 | attention_mask = batch["attention_mask"].view( 616 | 2, -1, batch["attention_mask"].shape[-1] 617 | ) 618 | return { 619 | "input_ids_chosen": input_ids[0], 620 | "attention_mask_chosen": attention_mask[0], 621 | "input_ids_rejected": input_ids[1], 622 | "attention_mask_rejected": attention_mask[1], 623 | "return_loss": True, 624 | } 625 | 626 | # Train the model. 
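# Shape sketch for the collator and trainer wired up below (values are illustrative):
#   - The collator receives B examples and pads chosen + rejected together, so both share one
#     padded length L (rounded up to a multiple of 64); it returns (B, L) tensors for
#     input_ids_chosen / attention_mask_chosen / input_ids_rejected / attention_mask_rejected.
#   - compute_loss concatenates chosen and rejected into a single (2B, L) forward pass,
#     reshapes the outputs to (2, B, num_labels), and applies the pairwise loss
#     -logsigmoid(reward_chosen - reward_rejected).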
627 | trainer = trainer_class( 628 | model=model, 629 | args=training_args, 630 | train_dataset=train_dataset, 631 | eval_dataset=eval_dataset, 632 | compute_metrics=trainer_class.compute_metrics, 633 | data_collator=RewardDataCollatorWithPadding( 634 | tokenizer=tokenizer, 635 | max_length=script_args.max_length, 636 | pad_to_multiple_of=64, 637 | ), 638 | **trainer_kwargs, 639 | ) 640 | 641 | trainer.train(script_args.resume_from_checkpoint) 642 | 643 | print("Saving last checkpoint of the model") 644 | model.save_pretrained(output_name + "_peft_last_checkpoint") 645 | -------------------------------------------------------------------------------- /hidden_context/train_llm_vae_preference_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from typing import Any, Dict, List, Optional, Type, Union, cast 4 | 5 | import numpy as np 6 | import torch 7 | import random 8 | from peft import LoraConfig, TaskType, get_peft_model 9 | from transformers import ( 10 | AutoModelForSequenceClassification, 11 | AutoTokenizer, 12 | HfArgumentParser, 13 | PreTrainedTokenizerBase, 14 | TrainingArguments, 15 | TrainerCallback, 16 | ) 17 | from transformers.utils import PaddingStrategy 18 | from .vae_utils import VAETrainer, VAEModel 19 | 20 | from .train_llm_preference_model import ( 21 | get_step_decay_lr_lambda, 22 | get_cosine_decay_lr_lambda, 23 | RewardModelType, 24 | DataSubset, 25 | get_hh_rlhf_dataset, 26 | concatenate_datasets 27 | ) 28 | 29 | 30 | @dataclass 31 | class ScriptArguments: 32 | local_rank: int = field(default=-1, metadata={"help": "Used for multi-gpu"}) 33 | resume_from_checkpoint: bool = field( 34 | default=False, 35 | metadata={"help": "If you want to resume training where it left off."}, 36 | ) 37 | deepspeed: Optional[str] = field( 38 | default=None, 39 | metadata={ 40 | "help": "Path to deepspeed config if using deepspeed. You may need this " 41 | "if the model that you want to train doesn't fit on a single GPU." 42 | }, 43 | ) 44 | per_device_train_batch_size: int = field(default=2) 45 | per_device_eval_batch_size: int = field(default=1) 46 | gradient_accumulation_steps: int = field(default=1) 47 | learning_rate: float = field(default=3e-6) 48 | weight_decay: float = field(default=0.001) 49 | model_name: str = field( 50 | default="gpt2", 51 | metadata={ 52 | "help": "The model that you want to train from the Hugging Face hub. " 53 | "E.g. gpt2, gpt2-xl, bert, etc." 54 | }, 55 | ) 56 | data_path: str = field( 57 | default="Anthropic/hh-rlhf", 58 | ) 59 | data_subset: str = field( 60 | default="both", 61 | metadata={ 62 | "help": "Which subset of the data to use. You can choose between 'both', " 63 | "'helpful', or 'harmless'." 64 | }, 65 | ) 66 | reward_model_type: str = field( 67 | default="base", 68 | metadata={ 69 | "help": "The type of reward model to use. You can choose between " 70 | "'base', 'mean_and_variance', or 'categorical'." 71 | }, 72 | ) 73 | num_atoms: int = field( 74 | default=10, 75 | metadata={ 76 | "help": "The number of atoms to use for the categorical reward model." 77 | }, 78 | ) 79 | entropy_coeff: float = field( 80 | default=0.1, 81 | metadata={"help": "The entropy coefficient for the categorical reward model."}, 82 | ) 83 | variance_penalty: float = field( 84 | default=0.0, 85 | metadata={ 86 | "help": "The variance penalty for the mean and variance reward model." 
87 | }, 88 | ) 89 | tokenizer_name: Optional[str] = field( 90 | default=None, 91 | metadata={ 92 | "help": "The tokenizer for your model, if left empty will use the default " 93 | "for your model", 94 | }, 95 | ) 96 | bf16: bool = field( 97 | default=True, 98 | metadata={ 99 | "help": "This essentially cuts the training time in half if you want to " 100 | "sacrifice a little precision and have a supported GPU." 101 | }, 102 | ) 103 | fp16: bool = field( 104 | default=False, 105 | metadata={ 106 | "help": "This essentially cuts the training time in half if you want to " 107 | "sacrifice a little precision and have a supported GPU." 108 | }, 109 | ) 110 | num_train_epochs: int = field( 111 | default=1, 112 | metadata={"help": "The number of training epochs for the reward model."}, 113 | ) 114 | train_dataset_size: int = field( 115 | default=0, 116 | metadata={"help": "The size of the subset of the training data to use"}, 117 | ) 118 | eval_dataset_size: int = field( 119 | default=0, 120 | metadata={"help": "The size of the subset of the eval data to use"}, 121 | ) 122 | gradient_checkpointing: bool = field( 123 | default=False, 124 | metadata={"help": "Enables gradient checkpointing."}, 125 | ) 126 | optim: str = field( 127 | default="adamw_hf", 128 | metadata={"help": "The optimizer to use."}, 129 | ) 130 | lr_scheduler_type: str = field( 131 | default="cosine", 132 | metadata={"help": "The lr scheduler"}, 133 | ) 134 | max_length: int = field(default=1024) 135 | eval_first_step: bool = field( 136 | default=True, 137 | metadata={"help": "Whether to run eval after the first step"}, 138 | ) 139 | log_dir: str = field(default="data/reward_models/hh_rlhf") 140 | kl_loss_weight: float = field(default=0.01, metadata={"help": "weight for KLD loss"}) 141 | latent_dim: int = field(default=512, metadata={"help": "dimension of latent user vector"}) # todo: 64 142 | hidden_dim: int = field(default=512, metadata={"help": "dimension of hidden layer in vae"}) # todo: 256 143 | encoder_embed_dim: int = field(default=1024, metadata={"help": "dimension of LLM embeddings for encoder"}) 144 | decoder_embed_dim: int = field(default=1024, metadata={"help": "dimension of LLM embeddings for decoder"}) 145 | use_annealing: bool = field(default=True, metadata={"help": "Whether to use annealing for learning rate"}) 146 | fixed_contexts: bool = field( 147 | default=False, 148 | metadata={"help": "whether to use pre-calculated embeddings for contexts (encoder inputs)"} 149 | ) 150 | fixed_llm_embeddings: bool = field( 151 | default=False, 152 | metadata={"help": "whether to use pre-calculated embeddings for decoder inputs"} 153 | ) 154 | seed: int = field(default=0) 155 | controversial_only: bool = field( 156 | default=False, 157 | metadata={"help": "whether to only include controversial data"} 158 | ) 159 | up_sampling: bool = field( 160 | default=False, 161 | metadata={"help": "whether to upsample controversial data during training phase"} 162 | ) 163 | one_user: str = field( 164 | default=None, 165 | metadata={"help": "whether to only train and evaluate on one single user"} 166 | ) 167 | other_subsets: str = field( 168 | default=None, 169 | metadata={"help": "specify the group of subsets if not using helpful/harmless. 
You can choose between" 170 | "ultra_feedback, pos_neg, set, single."}, 171 | ) 172 | use_last_token_embedding: bool = field( 173 | default=False, 174 | metadata={"help": "whether to use the last token embedding of last layer as LLM embeddings"} 175 | ) 176 | 177 | class HHRLHFPreprocessor(object): 178 | def __init__(self, args, tokenizer, **tokenizer_kwargs): 179 | self.tokenizer = tokenizer 180 | self.args = args 181 | self.tokenizer_kwargs = tokenizer_kwargs 182 | 183 | def __call__(self, examples): 184 | if self.args.fixed_llm_embeddings: 185 | new_examples: dict = { 186 | "embedding_chosen": [], 187 | "embedding_rejected": [], 188 | "contexts_embeddings": [], 189 | "max_lengths": [] 190 | } 191 | for embeddings, contexts in zip( 192 | examples["embeddings"], examples["contexts"] 193 | ): 194 | new_examples["embedding_chosen"].append(embeddings["embedding_chosen"]) 195 | new_examples["embedding_rejected"].append(embeddings["embedding_rejected"]) 196 | contexts_embeddings = [{"embedding_chosen": context["embedding_chosen"], 197 | "embedding_rejected": context["embedding_rejected"]} 198 | for context in contexts] 199 | new_examples["contexts_embeddings"].append(contexts_embeddings) 200 | new_examples["max_lengths"].append(0) 201 | new_examples["user_type"] = examples["data_subset"] 202 | return new_examples 203 | 204 | new_examples: dict = { 205 | "input_ids_chosen": [], 206 | "attention_mask_chosen": [], 207 | "input_ids_rejected": [], 208 | "attention_mask_rejected": [], 209 | "max_lengths": [] 210 | } 211 | if self.args.fixed_contexts: 212 | new_examples["contexts_embeddings"] = [] 213 | else: 214 | new_examples["contexts_tokens"] = [] 215 | for chosen, rejected, contexts, user_type in zip( 216 | examples["chosen"], examples["rejected"], examples["contexts"], examples["data_subset"] 217 | ): 218 | max_length = 0 219 | tokenized_chosen = self.tokenizer(chosen, **self.tokenizer_kwargs) 220 | tokenized_rejected = self.tokenizer(rejected, **self.tokenizer_kwargs) 221 | new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"]) 222 | new_examples["attention_mask_chosen"].append( 223 | tokenized_chosen["attention_mask"] 224 | ) 225 | new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"]) 226 | new_examples["attention_mask_rejected"].append( 227 | tokenized_rejected["attention_mask"] 228 | ) 229 | max_length = max(max_length, len(tokenized_chosen["input_ids"])) 230 | max_length = max(max_length, len(tokenized_rejected["input_ids"])) 231 | 232 | if self.args.fixed_contexts: 233 | contexts_embeddings = [{"embedding_chosen": context["embedding_chosen"], 234 | "embedding_rejected": context["embedding_rejected"]} 235 | for context in contexts] 236 | new_examples["contexts_embeddings"].append(contexts_embeddings) 237 | else: 238 | tokenized_context = [] 239 | # Tokenize the contexts. 
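# Each context is itself a (chosen, rejected) pair from the same user; the VAE encoder later
# consumes these pairs to infer the latent user vector, so both sides are tokenized here, and
# max_lengths records the longest sequence across the target pair and all context pairs for
# the length filtering applied after preprocessing.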
240 | for context in contexts: 241 | chosen, rejected = context["chosen"], context["rejected"] 242 | tokenized_chosen = self.tokenizer(chosen, **self.tokenizer_kwargs) 243 | tokenized_rejected = self.tokenizer(rejected, **self.tokenizer_kwargs) 244 | tokenized_context.append( 245 | { 246 | "input_ids_chosen": tokenized_chosen["input_ids"], 247 | "attention_mask_chosen": tokenized_chosen["attention_mask"], 248 | "input_ids_rejected": tokenized_rejected["input_ids"], 249 | "attention_mask_rejected": tokenized_rejected["attention_mask"], 250 | } 251 | ) 252 | max_length = max(max_length, len(tokenized_chosen["input_ids"])) 253 | max_length = max(max_length, len(tokenized_rejected["input_ids"])) 254 | new_examples["contexts_tokens"].append(tokenized_context) 255 | new_examples["max_lengths"].append(max_length) 256 | new_examples["user_type"] = examples["data_subset"] 257 | return new_examples 258 | 259 | 260 | trainer_classes: Dict[RewardModelType, Type[VAETrainer]] = { 261 | "vae": VAETrainer, 262 | } 263 | 264 | 265 | # We need to define a special data collator that batches the data in our j vs k format. 266 | @dataclass 267 | class RewardDataCollatorWithPadding: 268 | args: ScriptArguments 269 | tokenizer: PreTrainedTokenizerBase 270 | padding: Union[bool, str, PaddingStrategy] = True 271 | max_length: Optional[int] = None 272 | pad_to_multiple_of: Optional[int] = None 273 | return_tensors: str = "pt" 274 | 275 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 276 | if self.args.other_subsets is None: 277 | user_mapping = { 278 | "helpful": 0, 279 | "harmless": 1, 280 | } 281 | else: # TODO: set subsets here 282 | if self.args.other_subsets == 'ultra_feedback': 283 | subsets = ['helpfulness', 'honesty', 'instruction_following', 'truthfulness'] 284 | elif self.args.other_subsets == 'single' or self.args.other_subsets == '84': 285 | subsets = ['8', '4', '2', '1'] 286 | else: 287 | subsets = [] 288 | user_mapping = {subset: idx for idx, subset in enumerate(subsets)} 289 | if self.args.fixed_llm_embeddings: 290 | batch_size = len(features) 291 | embeddings_chosen = [] 292 | embeddings_rejected = [] 293 | contexts_embeddings_chosen = [] 294 | contexts_embeddings_rejected = [] 295 | contexts_lengths = [0] 296 | for feature in features: 297 | embeddings_chosen.append( 298 | feature["embedding_chosen"] 299 | ) 300 | embeddings_rejected.append( 301 | feature["embedding_rejected"] 302 | ) 303 | contexts_embeddings_chosen.extend( 304 | [ 305 | context["embedding_chosen"] for context in feature["contexts_embeddings"] 306 | ] 307 | ) 308 | contexts_embeddings_rejected.extend( 309 | [ 310 | context["embedding_rejected"] for context in feature["contexts_embeddings"] 311 | ] 312 | ) 313 | contexts_lengths.append(len(feature["contexts_embeddings"])) 314 | contexts_lengths = torch.cumsum(torch.tensor(contexts_lengths), dim=0) 315 | seq_start_end = torch.stack( 316 | [contexts_lengths[:-1], contexts_lengths[1:]], dim=1 317 | ) 318 | user_type = [user_mapping[feature["user_type"]] for feature in features] 319 | assert len(seq_start_end) == batch_size 320 | return { 321 | "embeddings_chosen": embeddings_chosen, 322 | "embeddings_rejected": embeddings_rejected, 323 | "contexts_embeddings_chosen": contexts_embeddings_chosen, 324 | "contexts_embeddings_rejected": contexts_embeddings_rejected, 325 | "seq_start_end": seq_start_end, 326 | "return_loss": True, 327 | "user_type": user_type, 328 | } 329 | if self.args.fixed_contexts: 330 | batch_size = len(features) 331 | features_chosen = [] 332 | 
features_rejected = [] 333 | contexts_embeddings_chosen = [] 334 | contexts_embeddings_rejected = [] 335 | contexts_lengths = [0] 336 | for feature in features: 337 | features_chosen.append( 338 | { 339 | "input_ids": feature["input_ids_chosen"], 340 | "attention_mask": feature["attention_mask_chosen"], 341 | } 342 | ) 343 | features_rejected.append( 344 | { 345 | "input_ids": feature["input_ids_rejected"], 346 | "attention_mask": feature["attention_mask_rejected"], 347 | } 348 | ) 349 | # Creating a flattened list of contexts. 350 | contexts_embeddings_chosen.extend( 351 | [ 352 | context["embedding_chosen"] for context in feature["contexts_embeddings"] 353 | ] 354 | ) 355 | contexts_embeddings_rejected.extend( 356 | [ 357 | context["embedding_rejected"] for context in feature["contexts_embeddings"] 358 | ] 359 | ) 360 | # Keep track of the start and end of each sequence. 361 | contexts_lengths.append(len(feature["contexts_embeddings"])) 362 | 363 | batch = self.tokenizer.pad( 364 | features_chosen + features_rejected, 365 | padding=self.padding, 366 | max_length=self.max_length, 367 | pad_to_multiple_of=self.pad_to_multiple_of, 368 | return_tensors=self.return_tensors, 369 | ) 370 | 371 | input_ids = batch["input_ids"].view( 372 | 2, batch_size, batch["input_ids"].shape[-1] 373 | ) 374 | attention_mask = batch["attention_mask"].view( 375 | 2, batch_size, batch["attention_mask"].shape[-1] 376 | ) 377 | 378 | context_lengths = torch.cumsum(torch.tensor(contexts_lengths), dim=0) 379 | seq_start_end = torch.stack( 380 | [context_lengths[:-1], context_lengths[1:]], dim=1 381 | ) 382 | user_type = [user_mapping[feature["user_type"]] for feature in features] 383 | assert len(seq_start_end) == batch_size 384 | 385 | return { 386 | "input_ids_chosen": input_ids[0], 387 | "attention_mask_chosen": attention_mask[0], 388 | "input_ids_rejected": input_ids[1], 389 | "attention_mask_rejected": attention_mask[1], 390 | "contexts_embeddings_chosen": contexts_embeddings_chosen, 391 | "contexts_embeddings_rejected": contexts_embeddings_rejected, 392 | "seq_start_end": seq_start_end, 393 | "return_loss": True, 394 | "user_type": user_type, 395 | } 396 | 397 | batch_size = len(features) 398 | features_chosen = [] 399 | features_rejected = [] 400 | contexts_features_chosen = [] 401 | contexts_features_rejected = [] 402 | contexts_lengths = [0] 403 | for feature in features: 404 | features_chosen.append( 405 | { 406 | "input_ids": feature["input_ids_chosen"], 407 | "attention_mask": feature["attention_mask_chosen"], 408 | } 409 | ) 410 | features_rejected.append( 411 | { 412 | "input_ids": feature["input_ids_rejected"], 413 | "attention_mask": feature["attention_mask_rejected"], 414 | } 415 | ) 416 | 417 | # Creating a flattened list of contexts. 418 | contexts_features_chosen.extend( 419 | [ 420 | { 421 | "input_ids": context["input_ids_chosen"], 422 | "attention_mask": context["attention_mask_chosen"], 423 | } 424 | for context in feature["contexts_tokens"] 425 | ] 426 | ) 427 | contexts_features_rejected.extend( 428 | [ 429 | { 430 | "input_ids": context["input_ids_rejected"], 431 | "attention_mask": context["attention_mask_rejected"], 432 | } 433 | for context in feature["contexts_tokens"] 434 | ] 435 | ) 436 | # Keep track of the start and end of each sequence. 
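# contexts_lengths is later turned into cumulative offsets, e.g. per-example context counts
# [3, 2] -> cumsum [0, 3, 5] -> seq_start_end [[0, 3], [3, 5]], so the flattened list of
# context pairs can be split back into per-example groups by the sequence encoder.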
437 | contexts_lengths.append(len(feature["contexts_tokens"])) 438 | 439 | batch = self.tokenizer.pad( 440 | features_chosen + features_rejected + contexts_features_chosen + contexts_features_rejected, 441 | padding=self.padding, 442 | max_length=self.max_length, 443 | pad_to_multiple_of=self.pad_to_multiple_of, 444 | return_tensors=self.return_tensors, 445 | ) 446 | 447 | input_ids = batch["input_ids"][:2 * batch_size].view( 448 | 2, batch_size, batch["input_ids"].shape[-1] 449 | ) 450 | attention_mask = batch["attention_mask"][:2 * batch_size].view( 451 | 2, batch_size, batch["attention_mask"].shape[-1] 452 | ) 453 | 454 | contexts_lengths = torch.cumsum(torch.tensor(contexts_lengths), dim=0) 455 | seq_start_end = torch.stack( 456 | [contexts_lengths[:-1], contexts_lengths[1:]], dim=1 457 | ) 458 | user_type = [user_mapping[feature["user_type"]] for feature in features] 459 | assert len(seq_start_end) == batch_size 460 | context_ids = batch["input_ids"][2 * batch_size:].view( 461 | 2, contexts_lengths[-1], batch["input_ids"].shape[-1] 462 | ) 463 | context_attention_mask = batch["attention_mask"][2 * batch_size:].view( 464 | 2, contexts_lengths[-1], batch["attention_mask"].shape[-1] 465 | ) 466 | 467 | return { 468 | "input_ids_chosen": input_ids[0], 469 | "attention_mask_chosen": attention_mask[0], 470 | "input_ids_rejected": input_ids[1], 471 | "attention_mask_rejected": attention_mask[1], 472 | "contexts_input_ids_chosen": context_ids[0], 473 | "contexts_attention_mask_chosen": context_attention_mask[0], 474 | "contexts_input_ids_rejected": context_ids[1], 475 | "contexts_attention_mask_rejected": context_attention_mask[1], 476 | "seq_start_end": seq_start_end, 477 | "return_loss": True, 478 | "user_type": user_type, 479 | } 480 | 481 | 482 | def up_sample_controversial(dataset, seed): 483 | cont = dataset.filter(lambda example: example['controversial'] == True) 484 | up_sampled_dataset = concatenate_datasets([cont] * 4 + [dataset]) 485 | up_sampled_dataset = up_sampled_dataset.shuffle(seed=seed) 486 | return up_sampled_dataset 487 | 488 | 489 | def customized_optimizer(model, lr): 490 | encoder_params = [p for p in model.parameters() if p not in model.decoder.parameters()] 491 | decoder_params = [p for p in model.parameters() if p in model.decoder.parameters()] 492 | grouped_parameters = [ 493 | {'params': encoder_params, 'lr': lr}, 494 | {'params': decoder_params, 'lr': lr / 10}, 495 | ] 496 | return 497 | 498 | 499 | if __name__ == "__main__": 500 | parser = HfArgumentParser(ScriptArguments) 501 | script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0] 502 | 503 | seed = script_args.seed 504 | random.seed(seed) 505 | np.random.seed(seed) 506 | torch.manual_seed(seed) 507 | torch.cuda.manual_seed(seed) 508 | 509 | torch.set_default_dtype(torch.bfloat16 if script_args.bf16 else torch.float32) 510 | 511 | if script_args.use_last_token_embedding: 512 | if script_args.model_name == 'gpt2': 513 | script_args.decoder_embed_dim = 768 514 | script_args.encoder_embed_dim = 768 515 | if script_args.model_name == 'meta-llama/Llama-2-7b-hf': 516 | script_args.decoder_embed_dim = 4096 517 | script_args.encoder_embed_dim = 4096 518 | 519 | data_subset = cast(DataSubset, script_args.data_subset) 520 | train_dataset = get_hh_rlhf_dataset( 521 | data_subset, 522 | "train", 523 | script_args.train_dataset_size, 524 | data_path=script_args.data_path, 525 | other_subsets=script_args.other_subsets 526 | ) 527 | eval_dataset = get_hh_rlhf_dataset( 528 | data_subset, 529 | "test", 530 | 
script_args.eval_dataset_size, 531 | data_path=script_args.data_path, 532 | other_subsets=script_args.other_subsets 533 | ) 534 | print(len(train_dataset), len(eval_dataset)) 535 | if script_args.controversial_only: 536 | train_dataset = train_dataset.filter(lambda example: example['controversial'] == True) 537 | eval_dataset = eval_dataset.filter(lambda example: example['controversial'] == True) 538 | elif script_args.up_sampling: 539 | train_dataset = up_sample_controversial(train_dataset, seed) 540 | 541 | if script_args.one_user: 542 | train_dataset = train_dataset.filter(lambda example: example['data_subset'] == script_args.one_user) 543 | eval_dataset = eval_dataset.filter(lambda example: example['data_subset'] == script_args.one_user) 544 | reward_model_type = cast(RewardModelType, script_args.reward_model_type) 545 | 546 | # Define the training args. Needs to be done before the model is loaded if you 547 | # are using deepspeed. 548 | model_name_split = script_args.model_name.split("/")[-1] 549 | output_name = ( 550 | f"{script_args.log_dir}/{data_subset}/" 551 | f"{reward_model_type}_{model_name_split}" 552 | f"__{script_args.train_dataset_size}_{script_args.learning_rate}" 553 | f"_{script_args.lr_scheduler_type}_{script_args.num_train_epochs}" 554 | ) 555 | output_name += f"_{script_args.kl_loss_weight}_{script_args.latent_dim}_{script_args.decoder_embed_dim}_seed{script_args.seed}" 556 | 557 | trainer_kwargs: Dict[str, Any] = {} 558 | if script_args.lr_scheduler_type == "step": 559 | lr_scheduler_type = "constant" 560 | trainer_kwargs["lr_lambda"] = get_step_decay_lr_lambda 561 | elif script_args.lr_scheduler_type == "cosine": 562 | lr_scheduler_type = "constant" 563 | trainer_kwargs["lr_lambda"] = get_cosine_decay_lr_lambda 564 | else: 565 | lr_scheduler_type = script_args.lr_scheduler_type 566 | 567 | training_args = TrainingArguments( 568 | output_dir=output_name, 569 | learning_rate=script_args.learning_rate, 570 | per_device_train_batch_size=script_args.per_device_train_batch_size, 571 | per_device_eval_batch_size=script_args.per_device_eval_batch_size, 572 | num_train_epochs=script_args.num_train_epochs, 573 | weight_decay=script_args.weight_decay, 574 | evaluation_strategy="steps", 575 | eval_steps=0.05, 576 | save_strategy="steps", 577 | save_steps=10000, 578 | gradient_accumulation_steps=script_args.gradient_accumulation_steps, 579 | gradient_checkpointing=script_args.gradient_checkpointing, 580 | deepspeed=script_args.deepspeed, 581 | local_rank=script_args.local_rank, 582 | remove_unused_columns=False, 583 | label_names=[], 584 | bf16=script_args.bf16, 585 | fp16=script_args.fp16, 586 | logging_strategy="steps", 587 | logging_steps=100, 588 | optim=script_args.optim, 589 | lr_scheduler_type=lr_scheduler_type, 590 | report_to="wandb", 591 | run_name=output_name.split("/")[-1], 592 | ) 593 | # Load the value-head model and tokenizer. 
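# "Value-head" here is a sequence-classification head repurposed as an embedding extractor:
# num_labels is set to decoder_embed_dim (and encoder_embed_dim for the contexts model), so
# the score layer emits a fixed-size embedding per sequence for the VAE decoder/encoder
# instead of a scalar reward.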
594 | tokenizer_name = ( 595 | script_args.tokenizer_name 596 | if script_args.tokenizer_name is not None 597 | else script_args.model_name 598 | ) 599 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True, add_eos_token=False) 600 | 601 | peft_config = LoraConfig( 602 | task_type=TaskType.SEQ_CLS, 603 | inference_mode=False, 604 | r=128, 605 | lora_alpha=256, 606 | lora_dropout=0.1, 607 | ) 608 | 609 | torch.set_anomaly_enabled(True) 610 | 611 | trainer_class = trainer_classes[reward_model_type] 612 | decoder_embed_dim = script_args.decoder_embed_dim 613 | encoder_embed_dim = script_args.encoder_embed_dim 614 | 615 | model = AutoModelForSequenceClassification.from_pretrained( 616 | script_args.model_name, num_labels=decoder_embed_dim, torch_dtype=torch.bfloat16 617 | ) 618 | # We multiply the final linear layer's weights by 0.01 because this seems to 619 | # significantly stabilize training and lead to better optimization of the loss. 620 | model.score.weight.data *= 0.01 621 | if not script_args.fixed_contexts: 622 | contexts_model = AutoModelForSequenceClassification.from_pretrained( 623 | script_args.model_name, num_labels=encoder_embed_dim, torch_dtype=torch.bfloat16 624 | ) 625 | contexts_model.score.weight.data *= 0.01 626 | model = get_peft_model(model, peft_config) 627 | model.print_trainable_parameters() 628 | 629 | if not script_args.fixed_contexts: 630 | contexts_model = get_peft_model(contexts_model, peft_config) 631 | contexts_model.print_trainable_parameters() 632 | contexts_model.config.pad_token_id = tokenizer.pad_token_id 633 | contexts_model.config.use_cache = not script_args.gradient_checkpointing 634 | else: 635 | contexts_model = None 636 | 637 | # Need to do this for GPT2 and Llama because they don't have official pad tokens. 638 | tokenizer.pad_token = tokenizer.eos_token 639 | tokenizer.pad_token_id = tokenizer.eos_token_id 640 | model.config.pad_token_id = tokenizer.pad_token_id 641 | tokenizer.padding_side = "right" 642 | 643 | model.config.use_cache = not script_args.gradient_checkpointing 644 | num_proc = 24 # Can adjust to be higher if you have more processors. 645 | original_columns = train_dataset.column_names 646 | 647 | train_dataset = train_dataset.map( 648 | HHRLHFPreprocessor(script_args, tokenizer), 649 | batched=True, 650 | num_proc=num_proc, 651 | remove_columns=original_columns, 652 | ) 653 | train_dataset = train_dataset.filter( 654 | lambda x: x["max_lengths"] <= script_args.max_length 655 | ) 656 | 657 | eval_dataset = eval_dataset.map( 658 | HHRLHFPreprocessor(script_args, tokenizer), 659 | batched=True, 660 | num_proc=num_proc, 661 | remove_columns=original_columns, 662 | ) 663 | eval_dataset = eval_dataset.filter( 664 | lambda x: x["max_lengths"] <= script_args.max_length 665 | ) 666 | 667 | # Train the model. 
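# Data flow of the VAE preference model assembled below (a high-level summary):
#   1. Each context (chosen, rejected) pair is embedded and passed through PairEncoder,
#      giving one vector per pair.
#   2. SequenceEncoder attends over a user's context pairs (grouped via seq_start_end) and
#      outputs the mean and log-variance of the latent user vector z.
#   3. During training z is sampled with the reparameterization trick and rescaled to norm
#      sqrt(latent_dim); at evaluation time the mean is used directly.
#   4. Decoder scores the target chosen/rejected embeddings concatenated with z, and the
#      preference loss -logsigmoid(r_chosen - r_rejected), combined with a KL term
#      (kl_loss_weight, optionally annealed), is optimized by VAETrainer.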
668 | latent_dim = script_args.latent_dim 669 | hidden_dim = script_args.hidden_dim 670 | vae_model = VAEModel(encoder_embed_dim, decoder_embed_dim, hidden_dim, latent_dim, model, contexts_model, 671 | fixed_contexts=script_args.fixed_contexts, 672 | fixed_llm_embeddings=script_args.fixed_llm_embeddings,) 673 | 674 | trainer = trainer_class( 675 | model=vae_model, 676 | args=training_args, 677 | train_dataset=train_dataset, 678 | eval_dataset=eval_dataset, 679 | compute_metrics=trainer_class.compute_metrics, 680 | data_collator=RewardDataCollatorWithPadding( 681 | args=script_args, 682 | tokenizer=tokenizer, 683 | max_length=script_args.max_length, 684 | pad_to_multiple_of=64, 685 | ), 686 | kl_loss_weight=script_args.kl_loss_weight, 687 | use_annealing=script_args.use_annealing, 688 | **trainer_kwargs, 689 | ) 690 | 691 | class EvaluateFirstStepCallback(TrainerCallback): 692 | def on_step_begin(self, args, state, control, **kwargs): 693 | if state.global_step == 0: 694 | control.should_evaluate = True 695 | 696 | 697 | trainer.add_callback(EvaluateFirstStepCallback()) 698 | 699 | trainer.train(script_args.resume_from_checkpoint) 700 | 701 | print("Saving last checkpoint of the model") 702 | 703 | model.save_pretrained(output_name + "_peft_last_checkpoint") 704 | output_name += "_peft_last_checkpoint" 705 | os.makedirs(output_name, exist_ok=True) 706 | 707 | output_name = os.path.join(output_name, "model.pt") 708 | vae_model.save_model(output_name) 709 | -------------------------------------------------------------------------------- /hidden_context/vae_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | from transformers import Trainer, EvalPrediction 7 | import wandb 8 | from transformers.optimization import get_cosine_schedule_with_warmup 9 | 10 | 11 | class PairEncoder(nn.Module): 12 | """ 13 | Model to encode pairs of accepted and rejected responses 14 | """ 15 | 16 | def __init__(self, embed_dim, hidden_dim, output_dim): 17 | super(PairEncoder, self).__init__() 18 | 19 | self._model = nn.Sequential( 20 | nn.Linear(2 * embed_dim, hidden_dim), 21 | nn.LeakyReLU(0.2), 22 | nn.Linear(hidden_dim, hidden_dim), 23 | nn.LeakyReLU(0.2), 24 | nn.Linear(hidden_dim, output_dim), 25 | ) 26 | 27 | def forward(self, e_c, e_r): 28 | x = torch.cat([e_c, e_r], dim=1) 29 | return self._model(x) 30 | 31 | 32 | class SequenceEncoder(nn.Module): 33 | """ 34 | Model to encode sequence of responses 35 | """ 36 | 37 | def __init__(self, input_dim, latent_dim): 38 | super(SequenceEncoder, self).__init__() 39 | self.input_dim = input_dim 40 | self.latent_dim = latent_dim 41 | 42 | self.linear = nn.Identity() 43 | self.w_q = nn.Linear(input_dim, input_dim) 44 | self.w_k = nn.Linear(input_dim, input_dim) 45 | self.w_v = nn.Linear(input_dim, input_dim) 46 | self.mean_layer = nn.Linear(input_dim, latent_dim) 47 | self.log_var_layer = nn.Linear(input_dim, latent_dim) 48 | self.layer_norm = nn.Identity() # nn.LayerNorm(latent_dim) 49 | 50 | def forward( 51 | self, sequences, seq_start_end 52 | ): # (C_1+C_2+...+C_n, D), [(0, C_1), (C_1, C_1+C_2), ..., (C_1+...+C_n-1, C_1+...+C_n)] 53 | outputs = [] 54 | for _, (start, end) in enumerate(seq_start_end): 55 | context = sequences[start:end] # C_i x D 56 | q = self.w_q(context) 57 | k = self.w_k(context) 58 | attention_scores = torch.matmul( 59 | q, k.transpose(0, 1) 60 | ) 61 | attention_scores = 
attention_scores / (context.shape[-1] ** 0.5) 62 | attention_weights = F.softmax(attention_scores, dim=-1) # C_i x C_i 63 | weighted_values = torch.matmul(attention_weights, self.w_v(context)) # C_i x D 64 | output = torch.mean(weighted_values, dim=0) # D 65 | outputs.append(output) 66 | outputs = torch.stack(outputs, dim=0) # n x D 67 | 68 | mean = self.layer_norm(self.mean_layer(outputs)) 69 | log_var = self.layer_norm(self.log_var_layer(outputs)) 70 | return mean, log_var 71 | 72 | 73 | class Decoder(nn.Module): 74 | def __init__(self, input_dim, hidden_dim): 75 | super(Decoder, self).__init__() 76 | self._model = nn.Sequential( 77 | nn.Linear(input_dim, hidden_dim), 78 | nn.LeakyReLU(0.2), 79 | nn.Linear(hidden_dim, hidden_dim), 80 | nn.LeakyReLU(0.2), 81 | nn.Linear(hidden_dim, 1), 82 | ) 83 | 84 | def forward(self, xc, xr, z): 85 | xc = torch.cat([xc, z], dim=1) 86 | xr = torch.cat([xr, z], dim=1) 87 | rc = self._model(xc) 88 | rr = self._model(xr) 89 | return rc, rr 90 | 91 | 92 | class VAEModel(nn.Module): 93 | def __init__(self, encoder_embed_dim, decoder_embed_dim, hidden_dim, latent_dim, llm_encoder, llm_contexts_encoder, 94 | fixed_contexts=False, fixed_llm_embeddings=False, use_causal_lm=False, use_attention_layer=False, 95 | use_transformer=False, concat_chosen_rejected=False): 96 | super(VAEModel, self).__init__() 97 | self.llm_encoder = llm_encoder 98 | self.llm_contexts_encoder = llm_contexts_encoder 99 | self.pair_encoder = PairEncoder(encoder_embed_dim, hidden_dim, latent_dim) 100 | self.sequence_encoder = SequenceEncoder(latent_dim, latent_dim) 101 | self.decoder = Decoder(decoder_embed_dim + latent_dim, hidden_dim) 102 | 103 | self.latent_dim = latent_dim 104 | self.fixed_contexts = fixed_contexts 105 | self.fixed_llm_embeddings = fixed_llm_embeddings 106 | self.use_causal_lm = use_causal_lm 107 | self.use_attention_layer = use_attention_layer 108 | self.use_transformer = use_transformer 109 | self.concat_chosen_rejected = concat_chosen_rejected 110 | 111 | self.saved_embeddings = torch.Tensor(4, latent_dim) 112 | self.saved_embeddings.uniform_(-1, 1) 113 | 114 | def reparameterization(self, mean, std): 115 | epsilon = torch.randn_like(std).to(mean.device) # sampling epsilon 116 | z = mean + std * epsilon # reparameterization trick 117 | z = F.normalize(z, p=2, dim=-1) * math.sqrt(z.shape[-1]) 118 | return z 119 | 120 | def encode_pair(self, e_c, e_r): 121 | return self.pair_encoder(e_c, e_r) 122 | 123 | def encode_sequence(self, sequences, seq_start_end): 124 | return self.sequence_encoder(sequences, seq_start_end) 125 | 126 | def decode(self, e_c, e_r, z): 127 | return self.decoder(e_c, e_r, z) 128 | 129 | def forward( 130 | self, 131 | target_chosen, 132 | target_rejected, 133 | context_chosen, 134 | context_rejected, 135 | seq_start_end, 136 | user_type, 137 | ground_truth_user_vector=False, 138 | ): 139 | pair_embed = self.encode_pair(context_chosen, context_rejected) 140 | mean, log_var = self.encode_sequence(pair_embed, seq_start_end) 141 | mean = torch.clamp(mean, -1, 1) 142 | 143 | _log_var = torch.clamp(log_var, -1, 1) 144 | if ground_truth_user_vector: 145 | z = torch.zeros_like(mean) 146 | self.saved_embeddings = self.saved_embeddings.to(mean.device) 147 | for idx in range(user_type.shape[0]): 148 | z[idx] = self.saved_embeddings[int(user_type[idx])] 149 | else: 150 | z = self.reparameterization(mean, torch.exp(0.5 * _log_var)) 151 | 152 | if not self.training and not ground_truth_user_vector: 153 | z = mean 154 | rc, rr = self.decode(target_chosen, 
target_rejected, z) 155 | 156 | return rc, rr, mean, _log_var, z 157 | 158 | def save_model(self, path): 159 | torch.save(self, path) 160 | 161 | 162 | class VAETrainer(Trainer): 163 | def __init__( 164 | self, *args, lr_lambda=None, kl_loss_weight=None, use_annealing=False, **kwargs 165 | ): 166 | super().__init__(*args, **kwargs) 167 | self.lr_lambda = lr_lambda 168 | self.kl_loss_weight = kl_loss_weight 169 | self.use_annealing = use_annealing 170 | self.annealer = Annealer( 171 | total_steps=1e4, shape="cosine", baseline=0.1, cyclical=True # todo: change total_step here 172 | ) 173 | 174 | @classmethod 175 | def per_sample_loss(cls, rewards_chosen, rewards_rejected): 176 | return -nn.functional.logsigmoid(rewards_chosen - rewards_rejected) 177 | 178 | def loss(self, rewards_chosen, rewards_rejected): 179 | return torch.mean(self.per_sample_loss(rewards_chosen, rewards_rejected)) 180 | 181 | def compute_loss(self, wrapped_model, inputs, return_outputs=False): 182 | if isinstance(wrapped_model, VAEModel): 183 | model = wrapped_model # .module 184 | else: 185 | model = wrapped_model.module 186 | device = model.llm_encoder.device 187 | batch_size = inputs["seq_start_end"].shape[0] 188 | if model.fixed_llm_embeddings: 189 | embeddings_chosen = torch.tensor(inputs["embeddings_chosen"]).to(device).bfloat16() 190 | embeddings_rejected = torch.tensor(inputs["embeddings_rejected"]).to(device).bfloat16() 191 | else: 192 | embeddings = model.llm_encoder( 193 | input_ids=torch.concatenate( 194 | [ 195 | inputs["input_ids_chosen"], 196 | inputs["input_ids_rejected"], 197 | ], 198 | dim=0, 199 | ), 200 | attention_mask=torch.concatenate( 201 | [ 202 | inputs["attention_mask_chosen"], 203 | inputs["attention_mask_rejected"], 204 | ], 205 | dim=0, 206 | ), 207 | )[0] 208 | embeddings_chosen = embeddings[:batch_size] 209 | embeddings_rejected = embeddings[batch_size:] 210 | 211 | if model.fixed_contexts: 212 | contexts_embeddings_chosen = torch.tensor(inputs["contexts_embeddings_chosen"]).to(device).bfloat16() 213 | contexts_embeddings_rejected = torch.tensor(inputs["contexts_embeddings_rejected"]).to(device).bfloat16() 214 | else: 215 | input_ids_chosen = inputs["contexts_input_ids_chosen"] 216 | attention_mask_chosen = inputs["contexts_attention_mask_chosen"] 217 | token_length_chosen = torch.eq(input_ids_chosen, 218 | model.llm_contexts_encoder.config.pad_token_id).int().argmax(-1) - 1 219 | input_ids_rejected = inputs["contexts_input_ids_rejected"] 220 | attention_mask_rejected = inputs["contexts_attention_mask_rejected"] 221 | token_length_rejected = torch.eq(input_ids_rejected, 222 | model.llm_contexts_encoder.config.pad_token_id).int().argmax(-1) - 1 223 | 224 | with torch.no_grad(): 225 | last_hidden_state_chosen = model.llm_contexts_encoder( 226 | input_ids=input_ids_chosen, 227 | attention_mask=attention_mask_chosen, 228 | output_hidden_states=True 229 | ).hidden_states[-1] 230 | 231 | weights_for_non_padding_chosen = attention_mask_chosen * torch.arange( 232 | start=1, end=last_hidden_state_chosen.shape[1] + 1 233 | ).unsqueeze(0).to(attention_mask_chosen.device).float() 234 | sum_embeddings = torch.sum(last_hidden_state_chosen * weights_for_non_padding_chosen.unsqueeze(-1), 235 | dim=1) 236 | num_of_none_padding_tokens_chosen = torch.sum(weights_for_non_padding_chosen, dim=-1).unsqueeze(-1) 237 | contexts_embeddings_chosen = sum_embeddings / num_of_none_padding_tokens_chosen 238 | last_hidden_state_rejected = model.llm_contexts_encoder( 239 | input_ids=input_ids_rejected, 240 | 
attention_mask=attention_mask_rejected, 241 | output_hidden_states=True 242 | ).hidden_states[-1] 243 | 244 | weights_for_non_padding_rejected = attention_mask_rejected * torch.arange( 245 | start=1, end=last_hidden_state_rejected.shape[1] + 1 246 | ).unsqueeze(0).to(attention_mask_rejected.device).float() 247 | sum_embeddings = torch.sum(last_hidden_state_rejected * weights_for_non_padding_rejected.unsqueeze(-1), 248 | dim=1) 249 | num_of_none_padding_tokens_rejected = torch.sum(weights_for_non_padding_rejected, dim=-1).unsqueeze(-1) 250 | contexts_embeddings_rejected = sum_embeddings / num_of_none_padding_tokens_rejected 251 | seq_start_end = inputs["seq_start_end"] 252 | user_type = torch.tensor(inputs["user_type"]).to(device).bfloat16() 253 | rewards_chosen, rewards_rejected, mean, log_var, z = model( 254 | embeddings_chosen, 255 | embeddings_rejected, 256 | contexts_embeddings_chosen, 257 | contexts_embeddings_rejected, 258 | seq_start_end, 259 | user_type, 260 | ground_truth_user_vector=False, # todo: set to True for debug usage 261 | mask_chosen=inputs["attention_mask_chosen"], 262 | mask_rejected=inputs["attention_mask_rejected"], 263 | ) 264 | 265 | reproduction_loss = self.loss(rewards_chosen, rewards_rejected) 266 | if self.kl_loss_weight == 0: 267 | loss = reproduction_loss 268 | accuracy = torch.mean((rewards_chosen > rewards_rejected).float()) 269 | if not return_outputs: 270 | self.log( 271 | { 272 | "train_recon_loss": reproduction_loss.mean().item(), 273 | "train_accuracy": accuracy.mean().item(), 274 | "rewards_chosen": rewards_chosen.mean().item(), 275 | "rewards_rejected": rewards_rejected.mean().item(), 276 | "embeddings_chosen": embeddings_chosen.mean().item(), 277 | "embeddings_rejected": embeddings_rejected.mean().item(), 278 | "mean": mean.mean().item(), 279 | "log_var": log_var.mean().item() 280 | } 281 | ) 282 | else: 283 | kld = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=1).mean() 284 | if self.use_annealing: 285 | kld = self.annealer(kld) 286 | self.annealer.step() 287 | kld = self.kl_loss_weight * kld 288 | loss = reproduction_loss + kld 289 | accuracy = torch.mean((rewards_chosen > rewards_rejected).float()) 290 | if not return_outputs: 291 | self.log( 292 | { 293 | "train_recon_loss": reproduction_loss.mean().item(), 294 | "train_kld": kld.mean().item(), 295 | "train_accuracy": accuracy.mean().item(), 296 | "rewards_chosen": rewards_chosen.mean().item(), 297 | "rewards_rejected": rewards_rejected.mean().item(), 298 | "embeddings_chosen": embeddings_chosen.mean().item(), 299 | "embeddings_rejected": embeddings_rejected.mean().item(), 300 | "mean": mean.mean().item(), 301 | "log_var": log_var.mean().item() 302 | } 303 | ) 304 | if return_outputs: 305 | return loss, { 306 | "rewards_chosen": rewards_chosen, 307 | "rewards_rejected": rewards_rejected, 308 | "mean": mean, 309 | "log_var": log_var, 310 | "z": z, 311 | "user_type": user_type, 312 | } 313 | return loss 314 | 315 | def create_scheduler(self, num_training_steps: int, optimizer=None): 316 | scheduler = get_cosine_schedule_with_warmup( 317 | optimizer, 318 | num_warmup_steps=int(0.03 * num_training_steps), 319 | num_training_steps=num_training_steps 320 | ) 321 | self.lr_scheduler = scheduler 322 | return scheduler 323 | 324 | @classmethod 325 | def compute_metrics(cls, eval_prediction: EvalPrediction): 326 | rewards_chosen, rewards_rejected, mean, log_var, z, user_type = ( 327 | eval_prediction.predictions 328 | ) 329 | rewards_chosen = torch.from_numpy(rewards_chosen) 330 | 
rewards_rejected = torch.from_numpy(rewards_rejected) 331 | mean = torch.from_numpy(mean) 332 | log_var = torch.from_numpy(log_var) 333 | z = torch.from_numpy(z) 334 | loss = cls.per_sample_loss(rewards_chosen, rewards_rejected) 335 | kld = -torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=-1) 336 | accuracy = torch.mean((loss < np.log(2)).float()) 337 | 338 | def plot_latent(latent): 339 | from sklearn.manifold import TSNE 340 | z_embedding = TSNE(n_components=2, init='random', perplexity=20, learning_rate="auto").fit_transform(latent.numpy()) 341 | import matplotlib.pyplot as plt 342 | colors = [f"C{int(i)}" for i in user_type] 343 | plt.scatter(z_embedding[:, 0], z_embedding[:, 1], c=colors) 344 | im = wandb.Image(plt) 345 | plt.close() 346 | return im 347 | im1 = plot_latent(mean) 348 | im2 = plot_latent(z) 349 | 350 | return { 351 | "loss": loss.mean().item(), 352 | "accuracy": accuracy.item(), 353 | "kld": kld.mean().item(), 354 | "mean_embeddings": im1, 355 | "z_embeddings": im2, 356 | } 357 | 358 | 359 | class Annealer: 360 | """ 361 | This class is used to anneal the KL divergence loss over the course of training VAEs. 362 | After each call, the step() function should be called to update the current epoch. 363 | """ 364 | 365 | def __init__(self, total_steps, shape, baseline=0.0, cyclical=False, disable=False): 366 | """ 367 | Parameters: 368 | total_steps (int): Number of epochs to reach full KL divergence weight. 369 | shape (str): Shape of the annealing function. Can be 'linear', 'cosine', or 'logistic'. 370 | baseline (float): Starting value for the annealing function [0-1]. Default is 0.0. 371 | cyclical (bool): Whether to repeat the annealing cycle after total_steps is reached. 372 | disable (bool): If true, the __call__ method returns unchanged input (no annealing). 373 | """ 374 | self.total_steps = total_steps 375 | self.current_step = 0 376 | self.cyclical = cyclical 377 | self.shape = shape 378 | self.baseline = baseline 379 | if disable: 380 | self.shape = "none" 381 | self.baseline = 0.0 382 | 383 | def __call__(self, kld): 384 | """ 385 | Args: 386 | kld (torch.tensor): KL divergence loss 387 | Returns: 388 | out (torch.tensor): KL divergence loss multiplied by the slope of the annealing function. 389 | """ 390 | out = kld * self.slope() 391 | return out 392 | 393 | def slope(self): 394 | if self.shape == "linear": 395 | y = self.current_step / self.total_steps 396 | elif self.shape == "cosine": 397 | y = (math.cos(math.pi * (self.current_step / self.total_steps - 1)) + 1) / 2 398 | elif self.shape == "logistic": 399 | exponent = (self.total_steps / 2) - self.current_step 400 | y = 1 / (1 + math.exp(exponent)) 401 | elif self.shape == "none": 402 | y = 1.0 403 | else: 404 | raise ValueError( 405 | "Invalid shape for annealing function. Must be linear, cosine, or logistic." 
406 | ) 407 | y = self.add_baseline(y) 408 | return y 409 | 410 | def step(self): 411 | if self.current_step < self.total_steps: 412 | self.current_step += 1 413 | if self.cyclical and self.current_step >= self.total_steps: 414 | self.current_step = 0 415 | return 416 | 417 | def add_baseline(self, y): 418 | y_out = y * (1 - self.baseline) + self.baseline 419 | return y_out 420 | 421 | def cyclical_setter(self, value): 422 | if value is not bool: 423 | raise ValueError( 424 | "Cyclical_setter method requires boolean argument (True/False)" 425 | ) 426 | else: 427 | self.cyclical = value 428 | return 429 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | datasets>=1.17.0 3 | peft>=0.3.0 4 | sentencepiece 5 | diskcache==5.6.3 6 | pandas 7 | torch 8 | openai==0.28.1 9 | matplotlib==3.7.1 10 | scipy==1.10.1 11 | -------------------------------------------------------------------------------- /submit_job_UF_P_2.sh: -------------------------------------------------------------------------------- 1 | export WANDB_MODE=online 2 | export WANDB_PROJECT=vpl 3 | export NCCL_P2P_DISABLE="1" 4 | export NCCL_IB_DISABLE="1" 5 | 6 | # Set model_name to be 'gpt2' or 'meta-llama/Llama-2-7b-hf' here 7 | model_name='gpt2' 8 | 9 | model_type=$1 10 | augment_type="84" 11 | 12 | if [ ${model_type} == "vae" ]; 13 | then 14 | # Train VPL on UltraFeedback two-user dataset 15 | python -m hidden_context.train_llm_vae_preference_model \ 16 | --model_name=${model_name} \ 17 | --data_path="data/P_survey_100/gpt2" \ 18 | --num_train_epochs=2 \ 19 | --reward_model_type=vae \ 20 | --data_subset=all \ 21 | --log_dir="logs/gpt2_P_survey_100" \ 22 | --bf16 True \ 23 | --fp16 False \ 24 | --per_device_train_batch_size 4 \ 25 | --gradient_accumulation_steps 8 \ 26 | --latent_dim 512 \ 27 | --hidden_dim 512 \ 28 | --learning_rate 1e-4 \ 29 | --use_annealing True \ 30 | --kl_loss_weight 0 \ 31 | --controversial_only True \ 32 | --fixed_contexts True \ 33 | --fixed_llm_embeddings False \ 34 | --up_sampling False \ 35 | --other_subsets ${augment_type} \ 36 | --use_last_token_embedding True \ 37 | --seed 0 38 | else 39 | # Train baseline models on UltraFeedback two-user dataset 40 | python -m hidden_context.train_llm_preference_model \ 41 | --model_name=${model_name} \ 42 | --data_path="data/P_survey_100/gpt2" \ 43 | --num_train_epochs=2 \ 44 | --reward_model_type=${model_type} \ 45 | --data_subset=all \ 46 | --log_dir="logs/gpt2_P_survey_100" \ 47 | --bf16 True \ 48 | --fp16 False \ 49 | --per_device_train_batch_size 4 \ 50 | --gradient_accumulation_steps 8 \ 51 | --learning_rate 1e-4 \ 52 | --controversial_only True \ 53 | --up_sampling False \ 54 | --other_subsets ${augment_type} \ 55 | --seed 0 56 | fi 57 | -------------------------------------------------------------------------------- /submit_job_UF_P_4.sh: -------------------------------------------------------------------------------- 1 | export WANDB_MODE=online 2 | export WANDB_PROJECT=vpl 3 | export NCCL_P2P_DISABLE="1" 4 | export NCCL_IB_DISABLE="1" 5 | 6 | # Set model_name to be 'gpt2' or 'meta-llama/Llama-2-7b-hf' here 7 | model_name='gpt2' 8 | 9 | model_type=$1 10 | 11 | if [ ${model_type} == "vae" ]; 12 | then 13 | # Train VPL on UltraFeedback four-user dataset 14 | python -m hidden_context.train_llm_vae_preference_model \ 15 | --model_name=${model_name} \ 16 | --data_path="data/P_4_survey_100/gpt2" \ 17 | 
--num_train_epochs=2 \ 18 | --reward_model_type=vae \ 19 | --data_subset=all \ 20 | --log_dir="logs/gpt2_P_4_survey_100" \ 21 | --bf16 True \ 22 | --fp16 False \ 23 | --per_device_train_batch_size 4 \ 24 | --gradient_accumulation_steps 8 \ 25 | --latent_dim 512 \ 26 | --hidden_dim 512 \ 27 | --learning_rate 1e-4 \ 28 | --use_annealing True \ 29 | --kl_loss_weight 3e-6 \ 30 | --controversial_only True \ 31 | --fixed_contexts True \ 32 | --fixed_llm_embeddings False \ 33 | --up_sampling False \ 34 | --other_subsets single \ 35 | --use_last_token_embedding True \ 36 | --seed 0 37 | else 38 | # Train baseline models on UltraFeedback four-user dataset 39 | python -m hidden_context.train_llm_preference_model \ 40 | --model_name=${model_name} \ 41 | --data_path="data/P_4_survey_100/gpt2" \ 42 | --num_train_epochs=2 \ 43 | --reward_model_type=${model_type} \ 44 | --data_subset=all \ 45 | --log_dir="logs/gpt2_P_4_survey_100" \ 46 | --bf16 True \ 47 | --fp16 False \ 48 | --per_device_train_batch_size 4 \ 49 | --gradient_accumulation_steps 8 \ 50 | --learning_rate 1e-4 \ 51 | --controversial_only True \ 52 | --up_sampling False \ 53 | --other_subsets single \ 54 | --seed 0 55 | fi 56 | -------------------------------------------------------------------------------- /submit_job_pets.sh: -------------------------------------------------------------------------------- 1 | export WANDB_MODE=online 2 | export WANDB_PROJECT=vpl 3 | export NCCL_P2P_DISABLE="1" 4 | export NCCL_IB_DISABLE="1" 5 | 6 | # Set model_name to be 'gpt2' or 'meta-llama/Llama-2-7b-hf' here 7 | model_name='gpt2' 8 | 9 | # full (up-sampling): --controversial_only False --up_sampling True 10 | # controversial: --controversial_only True 11 | 12 | # Reminder: for controversial settings, please use --num_train_epochs=10 13 | 14 | model_type=$1 15 | 16 | if [ ${model_type} == "vae" ] 17 | then 18 | # Train VPL on full/controversial/up-sampling Pets dataset 19 | python -m hidden_context.train_llm_vae_preference_model \ 20 | --model_name=${model_name} \ 21 | --data_path="data/simple_pets/gpt2" \ 22 | --num_train_epochs=2 \ 23 | --reward_model_type=vae \ 24 | --data_subset=both \ 25 | --log_dir="logs/gpt2_simple_pets" \ 26 | --bf16 True \ 27 | --fp16 False \ 28 | --per_device_train_batch_size 4 \ 29 | --gradient_accumulation_steps 8 \ 30 | --latent_dim 512 \ 31 | --hidden_dim 512 \ 32 | --learning_rate 3e-4 \ 33 | --use_annealing True \ 34 | --kl_loss_weight 1e-4 \ 35 | --fixed_contexts True \ 36 | --fixed_llm_embeddings False \ 37 | --use_last_token_embedding True \ 38 | --up_sampling True \ 39 | --controversial_only False \ 40 | --seed 0 41 | else 42 | # Train baseline models on full/controversial/up-sampling Pets dataset 43 | python -m hidden_context.train_llm_preference_model \ 44 | --model_name=${model_name} \ 45 | --data_path="data/simple_pets/gpt2" \ 46 | --num_train_epochs=2 \ 47 | --reward_model_type=${model_type} \ 48 | --data_subset=both \ 49 | --log_dir="logs/gpt2_simple_pets" \ 50 | --bf16 True \ 51 | --fp16 False \ 52 | --per_device_train_batch_size 4 \ 53 | --gradient_accumulation_steps 8 \ 54 | --learning_rate 1e-4 \ 55 | --controversial_only False \ 56 | --up_sampling True \ 57 | --seed 0 58 | fi 59 | --------------------------------------------------------------------------------