├── .gitignore
├── README.md
├── data
│   └── relabeled_hh_rlhf
│       ├── both
│       │   ├── test.jsonl.gz
│       │   └── train.jsonl.gz
│       ├── harmless
│       │   ├── test.jsonl.gz
│       │   └── train.jsonl.gz
│       └── helpful
│           ├── test.jsonl.gz
│           └── train.jsonl.gz
├── generate_llm_embeddings_UF_P_2.sh
├── generate_llm_embeddings_UF_P_4.sh
├── generate_llm_embeddings_pets.sh
├── hidden_context
│   ├── __init__.py
│   ├── data_utils
│   │   ├── __init__.py
│   │   ├── add_survey_contexts.py
│   │   ├── data_processing.py
│   │   ├── generate_simple_data.py
│   │   ├── simple_templates.py
│   │   └── ultrafeedback_augment.py
│   ├── train_llm_preference_model.py
│   ├── train_llm_vae_preference_model.py
│   └── vae_utils.py
├── requirements.txt
├── submit_job_UF_P_2.sh
├── submit_job_UF_P_4.sh
└── submit_job_pets.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 | logs/
140 | runs*/
141 |
142 | # jupyter notebook
143 | notebooks/
144 |
145 | # vscode
146 | .vscode/
147 |
148 | # output folder
149 | results/
150 | wandb/
151 | slurm*/
152 | data/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Personalizing Reinforcement Learning from Human Feedback with Variational Preference Learning
2 |
3 | #### [[Website]](https://weirdlabuw.github.io/vpl/) [[Paper]](https://arxiv.org/)
4 |
5 | [Sriyash Poddar<sup>1</sup>](https://sriya.sh), [Yanming Wan<sup>1</sup>](https://wanyanming.com/), [Hamish Ivison<sup>1</sup>](https://hamishivi.github.io/), [Abhishek Gupta<sup>1</sup>](https://homes.cs.washington.edu/~abhgupta), [Natasha Jaques<sup>1</sup>](https://natashajaques.ai)
6 |
7 | <sup>1</sup>University of Washington
8 |
9 | This repo implements the language experiments of VPL. VPL is a variational framework for learning from human feedback (binary preference labels): it infers a user-specific latent variable and learns reward models and policies conditioned on that latent, without requiring additional user-specific data. This enables quick adaptation to specific user preferences without retraining the entire model or ignoring underrepresented groups.
10 |
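As a rough sketch of the objective (our notation; see the paper for the exact formulation): an encoder $q_\phi$ infers the user latent $z$ from a context of that user's preference comparisons, and a reward model conditioned on $z$ scores responses under a Bradley-Terry-style likelihood,

$$
q_\phi\left(z \mid \{(x_i, y_i^{w}, y_i^{l})\}_{i=1}^{N}\right), \qquad
p\left(y^{w} \succ y^{l} \mid x, z\right) = \sigma\left(r_\theta(x, y^{w}, z) - r_\theta(x, y^{l}, z)\right),
$$

where $\sigma$ is the logistic sigmoid; training maximizes an ELBO-style objective (the expected log preference likelihood under $q_\phi$ minus a KL term against the prior over $z$).
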
11 | For control experiments of VPL, please refer to [here](https://github.com/WEIRDLabUW/vpl).
12 |
13 | ## Instructions
14 |
15 |
16 | #### Setting up repo
17 | ```
18 | git clone git@github.com:WEIRDLabUW/vpl_llm.git
19 | ```
20 |
21 | #### Install Dependencies
22 | ```
23 | conda create -n vpl python=3.10
24 | conda activate vpl
25 | pip install -r requirements.txt
26 | ```
27 |
28 | ## Data and Pretrained models
29 |
30 | Our datasets and checkpoints can be downloaded from [Google Drive link](https://drive.google.com/drive/folders/1dQ8zpNefRAtUB9TtbovOSn2MfV2Y-MbC?usp=sharing).
31 |
32 | #### Datasets
33 | The datasets needed for VPL experiments should be downloaded and unzipped to ``./data/``. There are three datasets in the folder: ``simple_pets``, ``P_survey_100``, and ``P_4_survey_100``.
34 |
35 | #### Checkpoints
36 | The checkpoints for VPL experiments should be downloaded to ``./logs``. We provide checkpoints for VPL and the baseline models on each dataset.
37 |
38 | ## Dataset Generation
39 | We also provide the code for generating our datasets.
40 | Running the following scripts will reproduce the datasets in ``./data/``.
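
Each generated split is a ``jsonl`` file of preference pairs augmented with a sampled user context. A simplified sketch of one record (field names follow ``hidden_context/data_utils``; the values shown are purely illustrative):

```python
example = {
    "prompt": "Human: ...",
    "chosen": "Human: ...\n\nAssistant: ...",    # response preferred by this user type
    "rejected": "Human: ...\n\nAssistant: ...",  # response rejected by this user type
    "controversial": True,                       # True if different user types disagree on this pair
    "embeddings": {                              # precomputed LLM embeddings (when --with_embeddings True)
        "embedding_chosen": [...],
        "embedding_rejected": [...],
    },
    "context_length": 8,
    "contexts": [                                # preference pairs sampled as this user's context
        {
            "original_id": 3,
            "embedding_chosen": [...],
            "embedding_rejected": [...],
        },
    ],
}
```
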
41 | #### Pets
42 | To generate ``./data/simple_pets``, run
43 | ```bash
44 | bash generate_llm_embeddings_pets.sh gpt2
45 | ```
46 |
47 | #### UF-P-2
48 | To generate ``./data/P_survey_100``, run
49 | ```bash
50 | python -m hidden_context.data_utils.ultrafeedback_augment -a 84 -n P
51 | bash generate_llm_embeddings_UF_P_2.sh gpt2 84
52 | ```
53 |
54 | #### UF-P-4
55 | To generate ``./data/P_4_survey_100``, run
56 | ```bash
57 | python -m hidden_context.data_utils.ultrafeedback_augment -a single -n P_4 -c
58 | bash generate_llm_embeddings_UF_P_4.sh gpt2 single
59 | ```
60 |
61 | ## Running Experiments
62 | In all the following scripts, the model type can be chosen from ``vae``, ``base``, ``categorical``, and ``mean_and_variance``.
63 | ``vae`` corresponds to our VPL models, while the others are baseline models.
64 |
65 | The results are logged to Wandb. Please refer to ``eval/accuracy`` on the Wandb page for the model's performance.
66 |
67 | #### Pets
68 | To train models on ``./data/simple_pets``, run
69 | ```bash
70 | bash submit_job_pets.sh
71 | ```
72 | Note that the default settings are for Pets (full);
73 | please change the arguments as explained in the bash file if you want to train on Pets (controversial).
74 |
75 | #### UF-P-2
76 | To train models on ``./data/P_survey_100``, run
77 | ```bash
78 | bash submit_job_UF_P_2.sh
79 | ```
80 |
81 | #### UF-P-4
82 | To train models on ``./data/P_4_survey_100``, run
83 | ```bash
84 | bash submit_job_UF_P_4.sh
85 | ```
86 |
87 | ## Bibtex
88 | If you find this code useful, please cite:
89 |
90 | ```
91 | @article{poddar2024vpl,
92 |   author = {Poddar, Sriyash and Wan, Yanming and Ivison, Hamish and Gupta, Abhishek and Jaques, Natasha},
93 | title = {Personalizing Reinforcement Learning from Human Feedback with Variational Preference Learning},
94 |   journal = {arXiv preprint},
95 | year = {2024},
96 | }
97 | ```
--------------------------------------------------------------------------------
/data/relabeled_hh_rlhf/both/test.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/data/relabeled_hh_rlhf/both/test.jsonl.gz
--------------------------------------------------------------------------------
/data/relabeled_hh_rlhf/both/train.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/data/relabeled_hh_rlhf/both/train.jsonl.gz
--------------------------------------------------------------------------------
/data/relabeled_hh_rlhf/harmless/test.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/data/relabeled_hh_rlhf/harmless/test.jsonl.gz
--------------------------------------------------------------------------------
/data/relabeled_hh_rlhf/harmless/train.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/data/relabeled_hh_rlhf/harmless/train.jsonl.gz
--------------------------------------------------------------------------------
/data/relabeled_hh_rlhf/helpful/test.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/data/relabeled_hh_rlhf/helpful/test.jsonl.gz
--------------------------------------------------------------------------------
/data/relabeled_hh_rlhf/helpful/train.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/data/relabeled_hh_rlhf/helpful/train.jsonl.gz
--------------------------------------------------------------------------------
/generate_llm_embeddings_UF_P_2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Set model_type to 'gpt2' or 'llama'
3 | # Set other_subsets to be 'ultra_feedback', 'single', or '84'
4 | model_type=$1
5 | other_subsets=$2
6 |
7 | # Generate LLM embeddings for UltraFeedback dataset
8 | if [ "${other_subsets}" = "ultra_feedback" ]; then
9 | subsets="helpfulness honesty instruction_following truthfulness"
10 | elif [ "${other_subsets}" = "single" ]; then
11 | subsets="8 4 2 1"
12 | elif [ "${other_subsets}" = "84" ]; then
13 | subsets="8 4"
14 | else
15 |     echo "Invalid other_subsets: ${other_subsets}"; exit 1
16 | fi
17 |
18 | echo "${subsets}"
19 |
20 |
21 | # Final version for two users
22 | survey_size=100
23 | for subset in ${subsets}
24 | do
25 | python -m hidden_context.data_utils.add_survey_contexts --output_dir "data/P_survey_${survey_size}/" \
26 | --data_path "data/UltraFeedback_${other_subsets}_P" --data_subset ${subset} --data_split train --model_type ${model_type} \
27 | --other_subsets ${other_subsets} --with_embeddings True --survey_size $survey_size --num_duplicates 8 --fixed_context_length True
28 |
29 | python -m hidden_context.data_utils.add_survey_contexts --output_dir "data/P_survey_${survey_size}/" \
30 | --data_path "data/UltraFeedback_${other_subsets}_P" --data_subset ${subset} --data_split test --model_type ${model_type} \
31 | --other_subsets ${other_subsets} --with_embeddings True --survey_size $survey_size --num_duplicates 8 --fixed_context_length True
32 | done
33 |
--------------------------------------------------------------------------------
/generate_llm_embeddings_UF_P_4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Set model_type to 'gpt2' or 'llama'
3 | # Set other_subsets to be 'ultra_feedback', 'single', or '84'
4 | model_type=$1
5 | other_subsets=$2
6 |
7 | # Generate LLM embeddings for UltraFeedback dataset
8 | if [ "${other_subsets}" = "ultra_feedback" ]; then
9 | subsets="helpfulness honesty instruction_following truthfulness"
10 | elif [ "${other_subsets}" = "single" ]; then
11 | subsets="8 4 2 1"
12 | elif [ "${other_subsets}" = "84" ]; then
13 | subsets="8 4"
14 | else
15 |     echo "Invalid other_subsets: ${other_subsets}"; exit 1
16 | fi
17 |
18 | echo "${subsets}"
19 |
20 |
21 | # Final version for four users
22 | survey_size=100
23 | for subset in ${subsets}
24 | do
25 | python -m hidden_context.data_utils.add_survey_contexts --output_dir "data/P_4_survey_${survey_size}/" \
26 | --data_path "data/UltraFeedback_${other_subsets}_P_4" --data_subset ${subset} --data_split train --model_type ${model_type} \
27 | --other_subsets ${other_subsets} --with_embeddings True --survey_size $survey_size --num_duplicates 4
28 |
29 | python -m hidden_context.data_utils.add_survey_contexts --output_dir "data/P_4_survey_${survey_size}/" \
30 | --data_path "data/UltraFeedback_${other_subsets}_P_4" --data_subset ${subset} --data_split test --model_type ${model_type} \
31 | --other_subsets ${other_subsets} --with_embeddings True --survey_size $survey_size --num_duplicates 4
32 | done
33 |
--------------------------------------------------------------------------------
/generate_llm_embeddings_pets.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Set model_type to 'gpt2' or 'llama'
3 | model_type=$1
4 |
5 | # Generate Pets dataset
6 | python -m hidden_context.data_utils.generate_simple_data --output_dir data/simple_pets/ \
7 | --data_path data/relabeled_hh_rlhf --with_embeddings True --synthetic_dataset True \
8 | --model_type ${model_type} --data_subset helpful --data_split test --dataset_size 200
9 |
10 | python -m hidden_context.data_utils.generate_simple_data --output_dir data/simple_pets/ \
11 | --data_path data/relabeled_hh_rlhf --with_embeddings True --synthetic_dataset True \
12 | --model_type ${model_type} --data_subset helpful --data_split train --dataset_size 2000
13 |
14 | python -m hidden_context.data_utils.generate_simple_data --output_dir data/simple_pets/ \
15 | --data_path data/relabeled_hh_rlhf --with_embeddings True --synthetic_dataset True \
16 | --model_type ${model_type} --data_subset harmless --data_split test --dataset_size 200
17 |
18 | python -m hidden_context.data_utils.generate_simple_data --output_dir data/simple_pets/ \
19 | --data_path data/relabeled_hh_rlhf --with_embeddings True --synthetic_dataset True \
20 | --model_type ${model_type} --data_subset harmless --data_split train --dataset_size 2000
21 |
--------------------------------------------------------------------------------
/hidden_context/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/hidden_context/__init__.py
--------------------------------------------------------------------------------
/hidden_context/data_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WEIRDLabUW/vpl_llm/fdfdfa767752b5af7f89f2e6ab6989c1953062c9/hidden_context/data_utils/__init__.py
--------------------------------------------------------------------------------
/hidden_context/data_utils/add_survey_contexts.py:
--------------------------------------------------------------------------------
1 | # This file is used to preprocess datasets; it works with any dataset in HH-RLHF format
2 | import os
3 | from dataclasses import dataclass, field
4 | from typing import Optional, cast
5 | from tqdm import tqdm
6 | import random
7 |
8 | from transformers import (
9 | HfArgumentParser,
10 | )
11 |
12 | import torch
13 |
14 | from hidden_context.train_llm_preference_model import (
15 | concatenate_datasets,
16 | )
17 |
18 | from hidden_context.data_utils.data_processing import generate_embeddings_with_llm
19 | from datasets import load_dataset
20 |
21 | from copy import deepcopy
22 |
23 | import numpy as np
24 |
25 |
26 | @dataclass
27 | class ScriptArguments:
28 | output_dir: Optional[str] = field(
29 | metadata={"help": "Directory where the new dataset will be stored."},
30 | )
31 | data_path: str = field(
32 | metadata={"help": "Directory where the original data is stored."}
33 | )
34 | data_subset: str = field(
35 | default="helpful",
36 | metadata={
37 | "help": "Which subset of the data to use. You can choose between"
38 | "'helpful', or 'harmless'."
39 | },
40 | )
41 | data_split: str = field(
42 | default="test",
43 | metadata={
44 | "help": "Which split of the data to use. You can choose between"
45 | "'train', or 'test'."
46 | },
47 | )
48 | dataset_size: int = field(
49 | default=0,
50 | metadata={"help": "The size of the subset of the data to use"},
51 | )
52 | model_type: str = field(
53 | default="none",
54 | metadata={
55 | "help": "You can choose between 'gpt2', 'llama', or 'none'."
56 | }
57 | )
58 | embed_dim: int = field(
59 | default=1024,
60 | metadata={
61 | "help": "Dimension of the embeddings generated by LLM."
62 | }
63 | )
64 | max_length: int = field(
65 | default=1024,
66 | metadata={
67 | "help": "Maximum token length of the inputs."
68 | }
69 | )
70 | with_embeddings: bool = field(
71 | default=True,
72 | metadata={
73 | "help": "Whether the embeddings are generated during pre-processing."
74 | }
75 | )
76 | synthetic_dataset: bool = field(
77 | default=False,
78 | metadata={
79 | "help": "Whether a synthetic dataset is used."
80 | }
81 | )
82 | other_subsets: str = field(default=None)
83 | survey_size: int = field(
84 | default=8,
85 | metadata={
86 | "help": "Size of survey question pool."
87 | }
88 | )
89 | context_length: int = field(
90 | default=8,
91 | metadata={
92 | "help": "(Maximum) context length."
93 | }
94 | )
95 | controversial_only: bool = field(
96 | default=True,
97 | metadata={
98 | "help": "Whether to only generate controversial data points."
99 | }
100 | )
101 | num_duplicates: int = field(
102 | default=1,
103 | metadata={
104 | "help": "Number of times each data point repeatedly appears in the dataset (with resampled context)."
105 | }
106 | )
107 | fixed_context_length: bool = field(
108 | default=False,
109 | metadata={
110 | "help": "Whether to fix the context to the maximum length."
111 | }
112 | )
113 | random_contexts: bool = field(
114 | default=False,
115 | metadata={
116 |             "help": "Whether to sample the survey/context pool from all preference pairs rather than only controversial survey-option pairs."
117 | }
118 | )
119 |
120 |
121 | def generate_contexts(args, input_dataset, survey_dataset):
122 | # Generate context with survey question pool
123 | output_dir = os.path.join(args.output_dir, f"{args.model_type}", f"{args.data_subset}")
124 | if not os.path.exists(output_dir):
125 | os.makedirs(output_dir)
126 |
127 | if args.controversial_only:
128 | input_dataset = input_dataset.filter(lambda x: x['controversial'] == True)
129 | dataset_size = len(input_dataset)
130 | if args.data_split == 'train':
131 | K = args.num_duplicates # repeat samples for K times
132 | else:
133 | K = 1
134 | dataset_list = list()
135 |
136 | def random_choice(max_context_length, survey_size):
137 | if max_context_length <= survey_size:
138 | from functools import reduce
139 | while True:
140 | if args.fixed_context_length:
141 | context_length = max_context_length
142 | else:
143 | if args.other_subsets == '84':
144 | context_length = random.randint(1, max_context_length)
145 | else:
146 | context_length = random.randint(2, max_context_length)
147 | context_chosen_ids = np.random.choice(survey_size, context_length, replace=False)
148 | chosen_dataset = [d for idx, d in enumerate(survey_dataset) if idx in context_chosen_ids]
149 | if args.other_subsets != 'single':
150 | return chosen_dataset, context_length
151 | satisfied_sets = list()
152 | for row in chosen_dataset:
153 | satisfied_sets.append(set(row["satisfied_subset"]))
154 | if len(reduce(lambda x, y: x.intersection(y), satisfied_sets)) == 1:
155 | return chosen_dataset, context_length
156 | elif context_length == survey_size:
157 | raise ValueError("Please choose another random seed!")
158 | else:
159 | raise ValueError("Context length is larger than survey size!")
160 |
161 | for idx in range(K):
162 | output_dataset = deepcopy(input_dataset)
163 | context_lengths = list()
164 | contexts = list()
165 | for _ in tqdm(range(dataset_size)): # iterate over all samples in original dataset
166 | row_contexts = list()
167 |
168 | context_dataset, context_length = random_choice(args.context_length, args.survey_size)
169 | context_lengths.append(context_length)
170 | for context_row in context_dataset:
171 | context_id = context_row["Index"]
172 | context_data = context_row
173 | if not args.with_embeddings:
174 | row_contexts.append({
175 | 'original_id': context_id,
176 | 'chosen': context_data['chosen'],
177 | 'rejected': context_data['rejected'],
178 | })
179 | else:
180 | row_contexts.append({
181 | 'original_id': context_id,
182 | 'embedding_chosen': context_data['embeddings']['embedding_chosen'],
183 | 'embedding_rejected': context_data['embeddings']['embedding_rejected'],
184 | })
185 | contexts.append(row_contexts)
186 | output_dataset = output_dataset.add_column("context_length", context_lengths)
187 | output_dataset = output_dataset.add_column("contexts", contexts)
188 | output_dataset.map()
189 | dataset_list.append(output_dataset)
190 | output = concatenate_datasets(dataset_list)
191 | output.to_json(os.path.join(output_dir, f"{args.data_split}.jsonl"))
192 | return output
193 |
194 |
195 | if __name__ == "__main__":
196 |     # default settings for the HH-RLHF dataset; please iterate over data subsets and data splits
197 | seed = 0
198 | random.seed(seed)
199 | np.random.seed(seed)
200 | torch.manual_seed(seed)
201 | torch.cuda.manual_seed(seed)
202 | parser = HfArgumentParser(ScriptArguments)
203 | script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0]
204 | print(script_args)
205 | dataset = generate_embeddings_with_llm(script_args)
206 | if not script_args.random_contexts:
207 | survey_options = dataset.filter(lambda x: x['survey_options'] == True)
208 | else:
209 | survey_options = dataset.filter(lambda x: x['survey_options'] == True or x['survey_options'] == False)
210 | survey_ids = np.random.choice(range(len(survey_options)), script_args.survey_size, replace=False)
211 | print(survey_ids)
212 | if script_args.data_split == "train":
213 | survey_data = survey_options.filter(lambda example, idx: idx in survey_ids, with_indices=True)
214 | survey_data.to_json(os.path.join(script_args.data_path, script_args.data_subset, "survey_{}.jsonl".format(script_args.survey_size)))
215 | else:
216 | survey_data = load_dataset('json', data_files=os.path.join(script_args.data_path, script_args.data_subset, "survey_{}.jsonl".format(script_args.survey_size)))
217 | survey_data = survey_data['train']
218 | generate_contexts(script_args, dataset, survey_data)
219 |
--------------------------------------------------------------------------------
/hidden_context/data_utils/data_processing.py:
--------------------------------------------------------------------------------
1 | # This file is used to preprocess datasets; it works with any dataset in HH-RLHF format
2 | import os
3 | from dataclasses import dataclass, field
4 | from typing import Optional, cast
5 | from tqdm import tqdm
6 | import random
7 |
8 | from transformers import (
9 | HfArgumentParser,
10 | AutoTokenizer,
11 | AutoModelForSequenceClassification,
12 | AutoModelForCausalLM,
13 | )
14 |
15 | import torch
16 |
17 | from hidden_context.train_llm_preference_model import (
18 | DataSubset,
19 | get_hh_rlhf_dataset,
20 | concatenate_datasets,
21 | HHRLHFPreprocessor,
22 | )
23 |
24 | from copy import deepcopy
25 |
26 | import numpy as np
27 |
28 | @dataclass
29 | class ScriptArguments:
30 | output_dir: Optional[str] = field(
31 | metadata={"help": "Directory where the new dataset will be stored."},
32 | )
33 | data_path: str = field(
34 | metadata={"help": "Directory where the original data is stored."}
35 | )
36 | data_subset: str = field(
37 | default="helpful",
38 | metadata={
39 | "help": "Which subset of the data to use. You can choose between"
40 | "'helpful', or 'harmless'."
41 | },
42 | )
43 | data_split: str = field(
44 | default="test",
45 | metadata={
46 | "help": "Which split of the data to use. You can choose between"
47 | "'train', or 'test'."
48 | },
49 | )
50 | dataset_size: int = field(
51 | default=0,
52 | metadata={"help": "The size of the subset of the data to use"},
53 | )
54 | model_type: str = field(
55 | default="none",
56 | metadata={
57 | "help": "You can choose between 'gpt2', 'llama', or 'none'."
58 | }
59 | )
60 | embed_dim: int = field(
61 | default=1024,
62 | metadata={
63 | "help": "Dimension of the embeddings generated by LLM."
64 | }
65 | )
66 | max_length: int = field(
67 | default=1024,
68 | metadata={
69 | "help": "Maximum token length of the inputs."
70 | }
71 | )
72 | with_embeddings: bool = field(
73 | default=True,
74 | metadata={
75 | "help": "Whether the embeddings are generated during pre-processing."
76 | }
77 | )
78 | synthetic_dataset: bool = field(
79 | default=False,
80 | metadata={
81 | "help": "Whether a synthetic dataset is used."
82 | }
83 | )
84 | other_subsets: str = field(default=None)
85 |
86 |
87 | def generate_embeddings_with_llm(args, input_dataset=None):
88 | """
89 | This function is used to generate fixed embeddings for inputs from original dataset.
90 | """
91 | if not args.synthetic_dataset:
92 | data_subset = cast(DataSubset, args.data_subset)
93 | input_dataset = get_hh_rlhf_dataset(
94 | data_subset,
95 | args.data_split,
96 | args.dataset_size,
97 | data_path=args.data_path,
98 | use_subset_as_dir=True,
99 | other_subsets=args.other_subsets,
100 | )
101 |
102 | if args.model_type == "gpt2":
103 | tokenizer = AutoTokenizer.from_pretrained("gpt2", use_auth_token=True)
104 | model = AutoModelForSequenceClassification.from_pretrained(
105 | "gpt2", num_labels=args.embed_dim, torch_dtype=torch.bfloat16
106 | )
107 | model.score.weight.data *= 0.01
108 | elif args.model_type == "llama" or args.model_type == "meta-llama/Llama-2-7b-hf":
109 | tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_auth_token=True, add_eos_token=False)
110 | model = AutoModelForCausalLM.from_pretrained(
111 | "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
112 | )
113 | else:
114 | return input_dataset
115 | model.to("cuda")
116 |
117 | tokenizer.pad_token = tokenizer.eos_token
118 | tokenizer.pad_token_id = tokenizer.eos_token_id
119 | tokenizer.padding_side = "right"
120 |
121 | model.config.pad_token_id = tokenizer.pad_token_id
122 | dataset_size = len(input_dataset)
123 | print(dataset_size)
124 |
125 | preprocessed_dataset = input_dataset.map(
126 | HHRLHFPreprocessor(tokenizer),
127 | batched=True,
128 | num_proc=24,
129 | remove_columns=input_dataset.column_names,
130 | load_from_cache_file=False,
131 | )
132 |
133 | input_dataset = input_dataset.filter(
134 | lambda example, idx: len(preprocessed_dataset[idx]["input_ids_chosen"]) <= args.max_length
135 | and len(preprocessed_dataset[idx]["input_ids_rejected"]) <= args.max_length,
136 | with_indices=True
137 | )
138 | preprocessed_dataset = preprocessed_dataset.filter(
139 | lambda example: len(example["input_ids_chosen"]) <= args.max_length
140 | and len(example["input_ids_rejected"]) <= args.max_length
141 | )
142 | print(len(input_dataset), len(preprocessed_dataset))
143 | dataset_size = len(preprocessed_dataset)
144 |
145 | embeddings = list()
146 | for row_id in tqdm(range(dataset_size)):
147 | emb = dict()
148 | for key in ['chosen', 'rejected']:
149 | tokens = tokenizer.pad(
150 | {"input_ids": preprocessed_dataset[row_id][f"input_ids_{key}"]},
151 | padding=True, pad_to_multiple_of=64, return_tensors="pt"
152 | )
153 | token_length = len(preprocessed_dataset[row_id][f"input_ids_{key}"])
154 | input_ids = tokens["input_ids"].unsqueeze(0).to("cuda")
155 | attention_mask = tokens["attention_mask"].unsqueeze(0).to("cuda")
156 | with torch.no_grad():
157 | last_hidden_state = model(
158 | input_ids=input_ids,
159 | attention_mask=attention_mask,
160 | output_hidden_states=True
161 | ).hidden_states[-1]
162 |             emb[f"embedding_{key}"] = last_hidden_state[0][token_length - 1].float().cpu().numpy()  # embedding = hidden state of the last non-padding token
163 | embeddings.append(emb)
164 | output_dataset = input_dataset.add_column("embeddings", embeddings)
165 | return output_dataset
166 |
167 |
168 | def generate_contexts(args, input_dataset):
169 | # Generate context without survey question pool
170 | output_dir = os.path.join(args.output_dir, f"{args.model_type}", f"{args.data_subset}")
171 | if not os.path.exists(output_dir):
172 | os.makedirs(output_dir)
173 |
174 | dataset_size = len(input_dataset)
175 |
176 | K = 1 # repeat samples for K times
177 | dataset_list = list()
178 | for idx in range(K):
179 | context_dataset = deepcopy(input_dataset)
180 | context_lengths = [8] * dataset_size
181 | if "context_length" in context_dataset.column_names:
182 | context_dataset = context_dataset.remove_columns("context_length")
183 | context_dataset = context_dataset.add_column("context_length", context_lengths)
184 |         contexts = list()
185 |         controversial_subset = input_dataset.filter(lambda example: example['controversial'] == True)
186 |         controversial_size = len(controversial_subset)
187 |         for row_id in tqdm(range(dataset_size)):  # iterate over all samples in original dataset
188 |             row_contexts = list()
189 |             num_context = 0
190 | while num_context < context_lengths[row_id]:
191 | random_id = np.random.randint(controversial_size)
192 | context_id = controversial_subset[random_id]['Index']
193 | context_data = controversial_subset[random_id]
194 | if not args.synthetic_dataset:
195 | if input_dataset[row_id]['prompt'] == context_data['prompt']:
196 | continue
197 | if not args.with_embeddings:
198 | row_contexts.append({
199 | 'original_id': context_id,
200 | 'chosen': context_data['chosen'],
201 | 'rejected': context_data['rejected'],
202 | })
203 | else:
204 | row_contexts.append({
205 | 'original_id': context_id,
206 | 'embedding_chosen': context_data['embeddings']['embedding_chosen'],
207 | 'embedding_rejected': context_data['embeddings']['embedding_rejected'],
208 | })
209 | num_context += 1
210 | contexts.append(row_contexts)
211 | context_dataset = context_dataset.add_column("contexts", contexts)
212 | dataset_list.append(context_dataset)
213 |
214 | output = concatenate_datasets(dataset_list)
215 | output.to_json(os.path.join(output_dir, f"{args.data_split}.jsonl"))
216 | return output
217 |
218 |
219 | if __name__ == "__main__":
220 |     # default settings for the HH-RLHF dataset; please iterate over data subsets and data splits
221 | seed = 0
222 | random.seed(seed)
223 | np.random.seed(seed)
224 | torch.manual_seed(seed)
225 | torch.cuda.manual_seed(seed)
226 | parser = HfArgumentParser(ScriptArguments)
227 | script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0]
228 | print(script_args)
229 | dataset = generate_embeddings_with_llm(script_args)
230 | generate_contexts(script_args, dataset)
231 |
--------------------------------------------------------------------------------
/hidden_context/data_utils/generate_simple_data.py:
--------------------------------------------------------------------------------
1 | # This file is used to generate synthetic language dataset
2 | from typing import cast
3 |
4 | from transformers import (
5 | HfArgumentParser,
6 | )
7 |
8 | import torch
9 | import random
10 |
11 | from hidden_context.data_utils.data_processing import (
12 | ScriptArguments,
13 | generate_embeddings_with_llm,
14 | generate_contexts
15 | )
16 |
17 | from hidden_context.data_utils.simple_templates import *
18 |
19 | from hidden_context.train_llm_preference_model import (
20 | DataSubset,
21 | get_hh_rlhf_dataset,
22 | )
23 |
24 | import numpy as np
25 |
26 |
27 | def generate_synthetic_dataset(args):
28 | data_subset = cast(DataSubset, args.data_subset)
29 | input_dataset = get_hh_rlhf_dataset(
30 | data_subset,
31 | args.data_split,
32 | args.dataset_size,
33 | data_path=args.data_path,
34 | use_subset_as_dir=True
35 | )
36 | def generate_simple_data_point(example):
37 | prompt_length = 1
38 | prompt = 'Human: Please talk about one kind of pets.'
39 | if args.data_split == 'train':
40 | helpful_harmless = bird_sentences[:80]
41 | helpful_harmful = dog_sentences[:80]
42 | harmless_unhelpful = cat_sentences[:80]
43 | harmful_unhelpful = rabbit_sentences[:80]
44 | else:
45 | helpful_harmless = bird_sentences[80:]
46 | helpful_harmful = dog_sentences[80:]
47 | harmless_unhelpful = cat_sentences[80:]
48 | harmful_unhelpful = rabbit_sentences[80:]
49 | pair_type = np.random.randint(10) # set to 6 previously
50 | if pair_type == 0:
51 | chosen = np.random.choice(helpful_harmless)
52 | rejected = np.random.choice(helpful_harmful)
53 | elif pair_type == 1:
54 | chosen = np.random.choice(harmless_unhelpful)
55 | rejected = np.random.choice(harmful_unhelpful)
56 | elif pair_type == 2:
57 | chosen = np.random.choice(helpful_harmless)
58 | rejected = np.random.choice(harmless_unhelpful)
59 | elif pair_type == 3:
60 | chosen = np.random.choice(helpful_harmful)
61 | rejected = np.random.choice(harmful_unhelpful)
62 | elif pair_type == 4:
63 | chosen = np.random.choice(helpful_harmless)
64 | rejected = np.random.choice(harmful_unhelpful)
65 | else:
66 | if script_args.data_subset == 'helpful':
67 | chosen = np.random.choice(helpful_harmful)
68 | rejected = np.random.choice(harmless_unhelpful)
69 | else:
70 | chosen = np.random.choice(harmless_unhelpful)
71 | rejected = np.random.choice(helpful_harmful)
72 | chosen_repeated = ' '.join([chosen] * prompt_length)
73 | rejected_repeated = ' '.join([rejected] * prompt_length)
74 | return_dict = {'prompt': prompt, 'chosen': prompt + '\n\n' + 'Assistant: ' + chosen_repeated,
75 | 'rejected': prompt + '\n\n' + 'Assistant: ' + rejected_repeated}
76 | if example['label'] == 0:
77 | return_dict['responses'] = [chosen_repeated, rejected_repeated]
78 | else:
79 | return_dict['responses'] = [rejected_repeated, chosen_repeated]
80 | if pair_type >= 5:
81 | return_dict['controversial'] = True
82 | else:
83 | return_dict['controversial'] = False
84 | return return_dict
85 |
86 | input_dataset = input_dataset.map(generate_simple_data_point)
87 | print(len(input_dataset.filter(lambda x: x['controversial'] == True)))
88 | return input_dataset
89 |
90 |
91 | if __name__ == "__main__":
92 |     # default settings for the synthetic language dataset; please iterate over data subsets and data splits
93 | seed = 0
94 | random.seed(seed)
95 | np.random.seed(seed)
96 | torch.manual_seed(seed)
97 | torch.cuda.manual_seed(seed)
98 | parser = HfArgumentParser(ScriptArguments)
99 | script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0]
100 | print(script_args)
101 | dataset = generate_synthetic_dataset(script_args)
102 | if script_args.with_embeddings:
103 | dataset = generate_embeddings_with_llm(script_args, dataset)
104 | generate_contexts(script_args, dataset)
105 |
--------------------------------------------------------------------------------
/hidden_context/data_utils/simple_templates.py:
--------------------------------------------------------------------------------
1 | bird_sentences = [
2 | "Birds chirp melodiously at the break of dawn.",
3 | "Birds of prey have keen eyesight for hunting.",
4 | "Birds migrate to warmer regions during winter.",
5 | "Birds build intricate nests using twigs and leaves.",
6 | "Birds exhibit a diverse range of colors and plumage.",
7 | "Birds communicate through various calls and songs.",
8 | "Birds have hollow bones which aid in flight.",
9 | "Birds navigate using the stars and Earth's magnetic field.",
10 | "Birds' feathers provide insulation and aid in flight.",
11 | "Birds adapt to different environments for survival.",
12 | "Birds of paradise have elaborate courtship displays.",
13 | "Birds use their beaks for feeding and grooming.",
14 | "Birds flock together for safety and socialization.",
15 | "Birds molt old feathers to make way for new ones.",
16 | "Birds play a crucial role in pollination.",
17 | "Birds like the ostrich are flightless but swift runners.",
18 | "Birds' songs vary depending on species and region.",
19 | "Birds mimic human speech and sounds.",
20 | "Birds exhibit territorial behavior during breeding season.",
21 | "Birds' eggs come in various sizes and colors.",
22 | "Birds preen their feathers to keep them clean and waterproof.",
23 | "Birds of prey hunt smaller animals for food.",
24 | "Birds use their wings for balance and stability.",
25 | "Birds migrate thousands of miles during migration.",
26 | "Birds like the penguin are adapted to life in icy waters.",
27 | "Birds' nests are often hidden for protection.",
28 | "Birds have different types of feet for various purposes.",
29 | "Birds possess remarkable intelligence and problem-solving skills.",
30 | "Birds' migration patterns are influenced by weather conditions.",
31 | "Birds huddle together to conserve body heat.",
32 | "Birds are classified into different orders and families.",
33 | "Birds' eyesight is sharper than that of humans.",
34 | "Birds' wingspans vary greatly among species.",
35 | "Birds roost in trees, cliffs, and buildings.",
36 | "Birds use vocalizations to establish dominance.",
37 | "Birds' beaks are adapted to their diet.",
38 | "Birds' feathers are made of keratin, like human hair and nails.",
39 | "Birds navigate using landmarks and celestial cues.",
40 | "Birds like the albatross spend years at sea.",
41 | "Birds' mating rituals can be elaborate and colorful.",
42 | "Birds of prey have strong talons for catching prey.",
43 | "Birds adapt to urban environments for nesting.",
44 | "Birds migrate to find abundant food sources.",
45 | "Birds of paradise have elaborate plumage for courtship.",
46 | "Birds' nests are lined with soft materials for comfort.",
47 | "Birds like the hummingbird can hover in mid-air.",
48 | "Birds' songs serve to attract mates and defend territory.",
49 | "Birds form intricate social structures within flocks.",
50 | "Birds exhibit different foraging techniques depending on their diet.",
51 | "Birds' wings allow them to glide effortlessly.",
52 | "Birds build nests using a variety of materials.",
53 | "Birds' feathers are crucial for regulating body temperature.",
54 | "Birds like the eagle have exceptional eyesight.",
55 | "Birds mark territories with vocalizations and displays.",
56 | "Birds migrate along established routes called flyways.",
57 | "Birds' feet are adapted to perching and grasping.",
58 | "Birds of prey hunt with precision and speed.",
59 | "Birds adapt their behavior to changes in their environment.",
60 | "Birds like the flamingo filter-feed in shallow waters.",
61 | "Birds' calls vary in pitch, tone, and rhythm.",
62 | "Birds migrate to breeding grounds during spring.",
63 | "Birds' beaks are adapted to their diet and feeding habits.",
64 | "Birds' nests vary in size and complexity.",
65 | "Birds rely on instinct and experience for navigation.",
66 | "Birds like the swallow undertake long migrations.",
67 | "Birds exhibit elaborate courtship displays to attract mates.",
68 | "Birds' eggs are incubated until they hatch.",
69 | "Birds migrate to avoid harsh winter conditions.",
70 | "Birds adapt to urban environments by scavenging for food.",
71 | "Birds' feathers are oiled to repel water.",
72 | "Birds mark their territories with scent and vocalizations.",
73 | "Birds display intricate mating rituals to attract partners.",
74 | "Birds' wings enable them to soar effortlessly.",
75 | "Birds adapt their behavior to changes in the environment.",
76 | "Birds build nests in trees, shrubs, and cliffs.",
77 | "Birds of prey hunt with precision and agility.",
78 | "Birds' nests are lined with soft materials for insulation.",
79 | "Birds migrate to warmer climates during winter.",
80 | "Birds communicate through calls, songs, and displays.",
81 | "Birds use their beaks for feeding, grooming, and defense.",
82 | "Birds' feathers provide insulation and aid in flight.",
83 | "Birds exhibit complex social behaviors within flocks.",
84 | "Birds adapt to different habitats for survival."
85 | ]
86 |
87 | cat_sentences = [
88 | "Cats are mysterious creatures.",
89 | "Cats possess an independent nature.",
90 | "Cats enjoy lounging in sunny spots.",
91 | "Cats have retractable claws.",
92 | "Cats communicate through meows.",
93 | "Cats groom themselves meticulously.",
94 | "Cats can see in low light.",
95 | "Cats exhibit playful behavior.",
96 | "Cats are skilled hunters.",
97 | "Cats have a keen sense of balance.",
98 | "Cats enjoy chasing toys.",
99 | "Cats are known for their agility.",
100 | "Cats often nap throughout the day.",
101 | "Cats have unique personalities.",
102 | "Cats knead with their paws.",
103 | "Cats are territorial animals.",
104 | "Cats can jump several times their height.",
105 | "Cats dislike water in general.",
106 | "Cats have a strong sense of smell.",
107 | "Cats are crepuscular creatures.",
108 | "Cats purr when content.",
109 | "Cats are obligate carnivores.",
110 | "Cats have excellent hearing.",
111 | "Cats mark their territory with scent.",
112 | "Cats have specialized whiskers.",
113 | "Cats can be trained through positive reinforcement.",
114 | "Cats are curious by nature.",
115 | "Cats form strong bonds with their owners.",
116 | "Cats are skilled climbers.",
117 | "Cats have a flexible spine.",
118 | "Cats have a preference for routine.",
119 | "Cats have a grooming ritual after meals.",
120 | "Cats use their tails for balance.",
121 | "Cats have a hierarchy within colonies.",
122 | "Cats can recognize their names.",
123 | "Cats exhibit hunting behavior through stalking.",
124 | "Cats enjoy interactive playtime.",
125 | "Cats have different vocalizations for various needs.",
126 | "Cats can sense changes in the weather.",
127 | "Cats have an acute sense of taste.",
128 | "Cats show affection through headbutting.",
129 | "Cats have a natural instinct to hunt rodents.",
130 | "Cats exhibit territorial spraying behavior.",
131 | "Cats form social groups with other cats.",
132 | "Cats can sleep up to 16 hours a day.",
133 | "Cats are crepuscular hunters.",
134 | "Cats have an excellent sense of time.",
135 | "Cats sharpen their claws on scratching posts.",
136 | "Cats can experience stress from changes in their environment.",
137 | "Cats have a preference for certain textures.",
138 | "Cats communicate through body language.",
139 | "Cats enjoy high perches for observation.",
140 | "Cats are obligate carnivores, requiring meat in their diet.",
141 | "Cats can have litters of kittens multiple times a year.",
142 | "Cats have a strong maternal instinct.",
143 | "Cats have a hierarchy within multi-cat households.",
144 | "Cats are prone to hairballs from grooming.",
145 | "Cats have a sensitive digestive system.",
146 | "Cats groom other cats in their social group.",
147 | "Cats have specialized taste receptors.",
148 | "Cats enjoy hiding in small spaces.",
149 | "Cats have a unique grooming technique.",
150 | "Cats have a variety of coat patterns and colors.",
151 | "Cats have a third eyelid for protection.",
152 | "Cats display affection through slow blinking.",
153 | "Cats exhibit whisker fatigue if they touch narrow spaces.",
154 | "Cats enjoy observing prey from a hidden vantage point.",
155 | "Cats have a preference for fresh water sources.",
156 | "Cats can suffer from separation anxiety.",
157 | "Cats have an instinctual fear of unfamiliar objects.",
158 | "Cats enjoy toys that mimic prey.",
159 | "Cats exhibit kneading behavior when relaxed.",
160 | "Cats can develop allergies to certain foods.",
161 | "Cats have a sensitive sense of touch in their whiskers.",
162 | "Cats communicate through scent marking.",
163 | "Cats have a strong dislike for citrus scents.",
164 | "Cats can become stressed by changes in routine.",
165 | "Cats have a preferred sleeping position.",
166 | "Cats groom to regulate body temperature.",
167 | "Cats enjoy interactive feeding toys.",
168 | "Cats have a territorial response to other animals.",
169 | "Cats have specialized muscles for purring.",
170 | "Cats have a preferred scratching substrate.",
171 | "Cats can be trained to walk on a leash.",
172 | "Cats enjoy sunbathing near windows.",
173 | "Cats have a preference for certain types of litter.",
174 | "Cats communicate through vocalizations.",
175 | "Cats have a natural instinct to bury waste.",
176 | "Cats are crepuscular hunters by nature.",
177 | "Cats display affection through grooming rituals.",
178 | "Cats have a preference for routines in feeding times.",
179 | "Cats enjoy exploring new environments.",
180 | "Cats exhibit kneading behavior on soft surfaces.",
181 | "Cats have a varied vocal range for communication.",
182 | "Cats have a strong instinct to hunt small prey."
183 | ]
184 |
185 | dog_sentences = [
186 | "Dogs are loyal companions.",
187 | "Dogs come in all shapes and sizes.",
188 | "Dogs have a keen sense of smell.",
189 | "Dogs enjoy playing fetch.",
190 | "Dogs love belly rubs.",
191 | "Dogs are known as man's best friend.",
192 | "Dogs make great therapy animals.",
193 | "Dogs are incredibly intelligent.",
194 | "Dogs can be trained to perform various tasks.",
195 | "Dogs need regular exercise.",
196 | "Dogs enjoy exploring the outdoors.",
197 | "Dogs have an innate sense of curiosity.",
198 | "Dogs communicate through body language.",
199 | "Dogs love treats.",
200 | "Dogs are pack animals by nature.",
201 | "Dogs have a strong sense of hierarchy.",
202 | "Dogs are capable of forming deep bonds with humans.",
203 | "Dogs provide emotional support to their owners.",
204 | "Dogs have a natural instinct to protect their families.",
205 | "Dogs have been domesticated for thousands of years.",
206 | "Dogs can learn a wide range of commands.",
207 | "Dogs have an excellent sense of hearing.",
208 | "Dogs have been used for hunting since ancient times.",
209 | "Dogs have a playful demeanor.",
210 | "Dogs are highly adaptable animals.",
211 | "Dogs have a variety of coat colors and patterns.",
212 | "Dogs are social animals that thrive on companionship.",
213 | "Dogs enjoy cuddling with their owners.",
214 | "Dogs are capable of expressing affection.",
215 | "Dogs can be trained for search and rescue missions.",
216 | "Dogs have a strong sense of territory.",
217 | "Dogs are skilled at interpreting human emotions.",
218 | "Dogs are often used in police work.",
219 | "Dogs enjoy participating in agility courses.",
220 | "Dogs have a remarkable ability to learn new things.",
221 | "Dogs are known for their unconditional love.",
222 | "Dogs are descendants of wolves.",
223 | "Dogs have a natural inclination to chase after moving objects.",
224 | "Dogs require proper grooming to stay healthy.",
225 | "Dogs have a unique set of vocalizations.",
226 | "Dogs have a powerful sense of taste.",
227 | "Dogs are adept at understanding routines.",
228 | "Dogs have been depicted in art throughout history.",
229 | "Dogs are capable of forming friendships with other animals.",
230 | "Dogs enjoy being part of a family unit.",
231 | "Dogs can detect changes in human behavior.",
232 | "Dogs have an extraordinary sense of balance.",
233 | "Dogs are sensitive to changes in their environment.",
234 | "Dogs have been trained to assist people with disabilities.",
235 | "Dogs enjoy playing with toys.",
236 | "Dogs have an innate sense of direction.",
237 | "Dogs are known to be protective of children.",
238 | "Dogs have a natural instinct to dig.",
239 | "Dogs are skilled at navigating through various terrains.",
240 | "Dogs enjoy receiving praise from their owners.",
241 | "Dogs have a strong sense of smell that can detect illness.",
242 | "Dogs thrive on routine and structure.",
243 | "Dogs enjoy sunbathing.",
244 | "Dogs have a playful rivalry with cats.",
245 | "Dogs require socialization from an early age.",
246 | "Dogs have a keen sense of time.",
247 | "Dogs are known to comfort people in distress.",
248 | "Dogs have a natural affinity for water.",
249 | "Dogs are capable of learning through observation.",
250 | "Dogs have been used in therapy for mental health conditions.",
251 | "Dogs enjoy exploring new scents.",
252 | "Dogs have a calming presence.",
253 | "Dogs are known for their ability to sense danger.",
254 | "Dogs enjoy being praised for good behavior.",
255 | "Dogs have a strong sense of empathy.",
256 | "Dogs are skilled at interpreting human gestures.",
257 | "Dogs have a natural inclination to mark their territory.",
258 | "Dogs enjoy sleeping in comfortable spots.",
259 | "Dogs have a playful nature that lasts into old age.",
260 | "Dogs are excellent at following scent trails.",
261 | "Dogs have been trained for military purposes.",
262 | "Dogs are often featured in movies and television shows.",
263 | "Dogs have a unique way of greeting each other.",
264 | "Dogs have been companions to humans for millennia.",
265 | "Dogs enjoy being part of outdoor activities.",
266 | "Dogs have been known to rescue people in distress.",
267 | "Dogs have a strong bond with their owners.",
268 | "Dogs have a calming effect on people.",
269 | "Dogs have a unique personality.",
270 | "Dogs have a variety of barks for different situations.",
271 | "Dogs enjoy exploring their surroundings.",
272 | "Dogs are known for their sense of playfulness.",
273 | "Dogs have been used in various forms of work throughout history.",
274 | "Dogs have a strong sense of loyalty to their families.",
275 | "Dogs are adept at learning from positive reinforcement.",
276 | "Dogs have been known to detect natural disasters before they occur.",
277 | "Dogs have a strong prey drive.",
278 | "Dogs enjoy being praised and rewarded for their efforts."
279 | ]
280 |
281 | rabbit_sentences = [
282 | "Rabbits hop around in fields.",
283 | "Rabbits have soft fur.",
284 | "Rabbits eat carrots.",
285 | "Rabbits reproduce quickly.",
286 | "Rabbits are social animals.",
287 | "Rabbits have long ears.",
288 | "Rabbits twitch their noses.",
289 | "Rabbits dig burrows.",
290 | "Rabbits love to play.",
291 | "Rabbits can be pets.",
292 | "Rabbits come in many colors.",
293 | "Rabbits thump their feet.",
294 | "Rabbits are herbivores.",
295 | "Rabbits have large families.",
296 | "Rabbits groom themselves.",
297 | "Rabbits are prey animals.",
298 | "Rabbits are agile.",
299 | "Rabbits are prolific breeders.",
300 | "Rabbits have a keen sense of smell.",
301 | "Rabbits are crepuscular.",
302 | "Rabbits have powerful hind legs.",
303 | "Rabbits nibble on grass.",
304 | "Rabbits are cute.",
305 | "Rabbits have whiskers.",
306 | "Rabbits communicate through body language.",
307 | "Rabbits enjoy hiding in tunnels.",
308 | "Rabbits are fast runners.",
309 | "Rabbits are nocturnal.",
310 | "Rabbits have a natural curiosity.",
311 | "Rabbits have powerful hind legs.",
312 | "Rabbits need space to roam.",
313 | "Rabbits have a varied diet.",
314 | "Rabbits can be trained.",
315 | "Rabbits are associated with fertility.",
316 | "Rabbits have a short gestation period.",
317 | "Rabbits have a lifespan of 8-12 years.",
318 | "Rabbits enjoy companionship.",
319 | "Rabbits are quiet animals.",
320 | "Rabbits are known for their reproductive rate.",
321 | "Rabbits have a gentle disposition.",
322 | "Rabbits thump to communicate danger.",
323 | "Rabbits have a hierarchy within groups.",
324 | "Rabbits enjoy toys.",
325 | "Rabbits have strong teeth.",
326 | "Rabbits are territorial.",
327 | "Rabbits groom each other.",
328 | "Rabbits have a complex digestive system.",
329 | "Rabbits can be litter-trained.",
330 | "Rabbits binky when happy.",
331 | "Rabbits have a strong maternal instinct.",
332 | "Rabbits have a 360-degree field of vision.",
333 | "Rabbits are prolific diggers.",
334 | "Rabbits can be affectionate pets.",
335 | "Rabbits have a unique digestive process called cecotrophy.",
336 | "Rabbits have a fear of loud noises.",
337 | "Rabbits can suffer from heatstroke.",
338 | "Rabbits have a natural inclination to chew.",
339 | "Rabbits enjoy exploring new environments.",
340 | "Rabbits have a strong sense of territory.",
341 | "Rabbits are often depicted in folklore.",
342 | "Rabbits are used in scientific research.",
343 | "Rabbits are susceptible to parasites.",
344 | "Rabbits are good swimmers.",
345 | "Rabbits have a sensitive respiratory system.",
346 | "Rabbits have a hutch as their shelter.",
347 | "Rabbits have a strong sense of balance.",
348 | "Rabbits are clean animals.",
349 | "Rabbits can communicate with a range of vocalizations.",
350 | "Rabbits are known for their fertility.",
351 | "Rabbits enjoy being petted.",
352 | "Rabbits need hay for proper digestion.",
353 | "Rabbits are social creatures.",
354 | "Rabbits can suffer from loneliness.",
355 | "Rabbits have a natural instinct to burrow.",
356 | "Rabbits enjoy fresh vegetables.",
357 | "Rabbits have a strong maternal bond.",
358 | "Rabbits can be trained to use a litter box.",
359 | "Rabbits are crepuscular, meaning they are most active at dawn and dusk.",
360 | "Rabbits have a delicate skeletal structure.",
361 | "Rabbits have long been associated with luck.",
362 | "Rabbits can be territorial over their food.",
363 | "Rabbits have a complex system of communication.",
364 | "Rabbits need regular grooming to prevent matting.",
365 | "Rabbits have a natural curiosity about their environment.",
366 | "Rabbits enjoy companionship with other rabbits.",
367 | "Rabbits are very adaptable animals.",
368 | "Rabbits have a strong sense of smell to detect predators.",
369 | "Rabbits have a unique digestive system that requires a high-fiber diet.",
370 | "Rabbits have been domesticated for over 2,000 years.",
371 | "Rabbits have a keen sense of hearing to detect predators."
372 | ]
373 |
374 |
--------------------------------------------------------------------------------
/hidden_context/data_utils/ultrafeedback_augment.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 |
4 | import torch
5 | from datasets import load_dataset, Dataset
6 | import numpy as np
7 | import os
8 |
9 |
10 | def random_argmax(values):
11 | """ a random tie-breaking argmax """
12 | return np.argmax(np.random.random(values.shape) * (values == values.max()))
13 |
14 |
15 | def random_greater_than_zero(values):
16 |     return (np.random.randn(values.shape[0]) * (values == 0) > 0.0) | (values > 0.0)  # True where value > 0, False where value < 0, random tie-break at 0
17 |
18 |
19 | def array_to_type(arr):
20 |     return str(int(np.dot(arr, np.array([8, 4, 2, 1]))))  # encode a binary profile as a bitmask string, e.g. [1, 0, 0, 0] -> "8"
21 |
22 |
23 | def get_user_type(chosen_ratings, rejected_ratings, augment_type, users=None):  # users is unused here; kept for call-site compatibility
24 | keys = ['helpfulness', 'honesty', 'instruction_following', 'truthfulness']
25 | chosen_rating_values = list()
26 | rejected_rating_values = list()
27 | for key in keys:
28 | chosen_rating_values.append(chosen_ratings[key])
29 | rejected_rating_values.append(rejected_ratings[key])
30 | chosen_values = np.asarray(chosen_rating_values)
31 | rejected_values = np.asarray(rejected_rating_values)
32 | is_equal = list(chosen_values == rejected_values)
33 | if augment_type == 'single' or augment_type == '84':
34 | data_subsets = ['8', '4', '2', '1']
35 | reversed_labels = {data_subsets[idx]: list(random_greater_than_zero(rejected_values - chosen_values))[idx] for
36 | idx in range(len(data_subsets))}
37 | is_equal = {data_subsets[idx]: is_equal[idx] for idx in range(len(data_subsets))}
38 | return data_subsets, reversed_labels, is_equal
39 | else:
40 | raise ValueError('Invalid augment_type')
41 |
42 |
43 | def inner_join(original, binarized, augment_type, users, two_two_only=False, filter_equal=False):
44 | agreed_counter = 0
45 | controversial_counter = 0
46 | keys = ['helpfulness', 'honesty', 'instruction_following', 'truthfulness']
47 | user_counter = {key: 0 for key in users.keys()}
48 | reversed_counter = {key: 0 for key in users.keys()}
49 | dumb_baseline = {key: 0 for key in users.keys()}
50 | dumb_controversial_baseline = {key: 0 for key in users.keys()}
51 | orig_idx = 0
52 | out_idx = 0
53 | dataset_dict = {
54 | 'Index': list(),
55 | 'original_idx': list(),
56 | 'prompt': list(),
57 | 'chosen': list(),
58 | 'rejected': list(),
59 | 'data_subset': list(),
60 | 'controversial': list(),
61 | 'reversed': list(),
62 | 'satisfied_subset': list(),
63 | 'survey_options': list(),
64 | }
65 | for bin_idx in range(len(binarized)):
66 | while binarized[bin_idx]['prompt'] != original[orig_idx]['instruction']:
67 | orig_idx += 1
68 | prompt = binarized[bin_idx]['prompt']
69 | chosen = binarized[bin_idx]['chosen'][1]['content']
70 | rejected = binarized[bin_idx]['rejected'][1]['content']
71 | if chosen == '' or rejected == '':
72 | continue
73 | chosen_ratings = dict()
74 | rejected_ratings = dict()
75 | flag = True
76 | for c in original[orig_idx]['completions']:
77 | if c['response'] == chosen:
78 | for key in keys:
79 | r = c['annotations'][key]['Rating']
80 | if r == 'N/A':
81 | flag = False
82 | continue
83 | chosen_ratings[key] = int(r)
84 | elif c['response'] == rejected:
85 | for key in keys:
86 | r = c['annotations'][key]['Rating']
87 | if r == 'N/A':
88 | flag = False
89 | continue
90 | rejected_ratings[key] = int(r)
91 | else:
92 | continue
93 | if not flag or len(chosen_ratings) != 4 or len(rejected_ratings) != 4:
94 | continue
95 | data_subsets, reversed_labels, is_equal = get_user_type(chosen_ratings, rejected_ratings, augment_type, users)
96 | if filter_equal:
97 | reversed_labels = {key: reversed_labels[key] for key in data_subsets if not is_equal[key]}
98 | data_subsets = [key for key in data_subsets if not is_equal[key]]
99 | is_equal = {key: False for key in data_subsets}
100 | if augment_type == '84' and len(is_equal.keys()) != 2:
101 | continue
102 | for data_subset in users.keys():
103 | if data_subset not in data_subsets:
104 | dumb_baseline[data_subset] += 0.5 * len(data_subsets)
105 | if True in reversed_labels.values() and False in reversed_labels.values():
106 | dumb_controversial_baseline[data_subset] += 0.5 * len(data_subsets)
107 | continue
108 | user_counter[data_subset] += 1
109 | if True in reversed_labels.values() and False in reversed_labels.values():
110 | is_controversial = True
111 | controversial_counter += 1
112 | else:
113 | is_controversial = False
114 | agreed_counter += 1
115 | if reversed_labels[data_subset]:
116 | reversed_counter[data_subset] += 1
117 | dumb_baseline[data_subset] += list(reversed_labels.values()).count(True)
118 | if is_controversial:
119 | dumb_controversial_baseline[data_subset] += list(reversed_labels.values()).count(True)
120 | else:
121 | dumb_baseline[data_subset] += list(reversed_labels.values()).count(False)
122 | if is_controversial:
123 | dumb_controversial_baseline[data_subset] += list(reversed_labels.values()).count(False)
124 | dataset_dict['Index'].append(out_idx)
125 | dataset_dict['original_idx'].append(orig_idx)
126 | dataset_dict['prompt'].append(prompt)
127 | if not reversed_labels[data_subset]:
128 | dataset_dict['chosen'].append('Human: ' + prompt + '\n\nAssistant: ' + chosen)
129 | dataset_dict['rejected'].append('Human: ' + prompt + '\n\nAssistant: ' + rejected)
130 | else:
131 | dataset_dict['chosen'].append('Human: ' + prompt + '\n\nAssistant: ' + rejected)
132 | dataset_dict['rejected'].append('Human: ' + prompt + '\n\nAssistant: ' + chosen)
133 | dataset_dict['data_subset'].append(data_subset)
134 | dataset_dict['controversial'].append(is_controversial)
135 | dataset_dict['reversed'].append(reversed_labels[data_subset])
136 | satisfied_subset = set([key for key in users.keys() if key not in data_subsets or reversed_labels[key] == reversed_labels[data_subset]])
137 | dataset_dict['satisfied_subset'].append(satisfied_subset)
138 | dataset_dict['survey_options'].append(is_controversial and len(data_subsets) == len(users.keys()))
139 | out_idx += 1
140 |     print("Output examples:", out_idx, "Agreed:", agreed_counter, "Controversial:", controversial_counter)
141 | print("User counter:", user_counter)
142 | print("Reversed counter:", reversed_counter)
143 | print("Dumb baseline:", dumb_baseline)
144 | print("Dumb controversial baseline:", dumb_controversial_baseline)
145 | return Dataset.from_dict(dataset_dict)
146 |
147 |
148 | if __name__ == '__main__':
149 | parser = argparse.ArgumentParser()
150 | parser.add_argument('--seed', type=int, default=42, help='Random seed')
151 | parser.add_argument('-a', '--augment_type', type=str, default='single', help='How to augment data')
152 | parser.add_argument('-c', '--controversial_only', action='store_true', help='Whether to only generate controversial data')
153 | parser.add_argument('-n', '--name', type=str, default='P_4', help='name of dataset')
154 | args = parser.parse_args()
155 | seed = args.seed
156 | random.seed(seed)
157 | np.random.seed(seed)
158 | torch.manual_seed(seed)
159 | torch.cuda.manual_seed(seed)
160 | if args.augment_type == 'single' or args.augment_type == '84':
161 | user_types = {
162 | '8': (1, 0, 0, 0),
163 | '4': (0, 1, 0, 0),
164 | '2': (0, 0, 1, 0),
165 | '1': (0, 0, 0, 1),
166 | }
167 | else:
168 | raise ValueError('Invalid augment_type')
169 |
170 | ultra_feedback = load_dataset('openbmb/UltraFeedback')
171 | binarized_cleaned = load_dataset('argilla/ultrafeedback-binarized-preferences-cleaned')
172 | length = len(binarized_cleaned['train'])
173 | print(length)
174 | test_ids = list(np.random.choice(length, int(length * 0.1), replace=False))
175 | train_split = binarized_cleaned['train'].filter(lambda example, idx: idx not in test_ids, with_indices=True)
176 | test_split = binarized_cleaned['train'].filter(lambda example, idx: idx in test_ids, with_indices=True)
177 | print(len(train_split), len(test_split))
178 | print("start processing train split")
179 | joined_dataset_train = inner_join(ultra_feedback['train'], train_split, args.augment_type, user_types)
180 | print("start processing test split")
181 | joined_dataset_test = inner_join(ultra_feedback['train'], test_split, args.augment_type, user_types)
182 |
183 | output_dir = os.path.join('data', 'UltraFeedback_{}_{}'.format(args.augment_type, args.name))
184 | for user_type in user_types.keys():
185 | train_subset = joined_dataset_train.filter(lambda x: x['data_subset'] == user_type)
186 | test_subset = joined_dataset_test.filter(lambda x: x['data_subset'] == user_type)
187 | if args.controversial_only:
188 | train_subset = train_subset.filter(lambda x: x['controversial'] == True)
189 | test_subset = test_subset.filter(lambda x: x['controversial'] == True)
190 | print(user_types[user_type], len(train_subset), len(test_subset))
191 | train_subset.to_json(os.path.join(output_dir, user_type, 'train.jsonl'))
192 | test_subset.to_json(os.path.join(output_dir, user_type, 'test.jsonl'))
193 |
194 | # python -m hidden_context.data_utils.ultrafeedback_augment -a single -n P_4 -c
195 |
196 | # python -m hidden_context.data_utils.ultrafeedback_augment -a 84 -n P
197 |
--------------------------------------------------------------------------------
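In ultrafeedback_augment.py above, each synthetic single-attribute user is named after the bit value of its attribute: helpfulness -> '8', honesty -> '4', instruction_following -> '2', truthfulness -> '1' (matching array_to_type and the user_types table). A minimal sketch, assuming the first example command above was run with its default output location, of loading one generated split back for a quick check:

    from datasets import load_dataset

    # Hypothetical inspection of the helpfulness-only ('8') user's training split.
    ds = load_dataset(
        "json",
        data_files="data/UltraFeedback_single_P_4/8/train.jsonl",
        split="train",
    )
    print(ds[0]["prompt"][:80], ds[0]["reversed"], ds[0]["controversial"])
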
/hidden_context/train_llm_preference_model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from functools import partial
3 | from typing import Any, Dict, List, Optional, Type, Union, cast
4 |
5 | import wandb
6 | import numpy as np
7 | import torch
8 | import random
9 | import torch.nn.functional as F # noqa: N812
10 | from datasets import Dataset, concatenate_datasets, load_dataset
11 | from peft import LoraConfig, TaskType, get_peft_model
12 | from torch import nn
13 | from torch.optim.lr_scheduler import LambdaLR
14 | from transformers import (
15 | AutoModelForSequenceClassification,
16 | AutoTokenizer,
17 | HfArgumentParser,
18 | PreTrainedTokenizerBase,
19 | Trainer,
20 | TrainingArguments,
21 | )
22 | from transformers.trainer_utils import EvalPrediction
23 | from transformers.utils import PaddingStrategy
24 | from typing_extensions import Literal, TypeAlias
25 |
26 | RewardModelType: TypeAlias = Literal["base", "mean_and_variance", "categorical"]
27 | DataSubset: TypeAlias = Literal["both", "helpful", "harmless"]
28 |
29 |
30 | @dataclass
31 | class ScriptArguments:
32 | local_rank: int = field(default=-1, metadata={"help": "Used for multi-gpu"})
33 | resume_from_checkpoint: bool = field(
34 | default=False,
35 | metadata={"help": "If you want to resume training where it left off."},
36 | )
37 | deepspeed: Optional[str] = field(
38 | default=None,
39 | metadata={
40 | "help": "Path to deepspeed config if using deepspeed. You may need this "
41 | "if the model that you want to train doesn't fit on a single GPU."
42 | },
43 | )
44 | per_device_train_batch_size: int = field(default=2)
45 | per_device_eval_batch_size: int = field(default=1)
46 | gradient_accumulation_steps: int = field(default=1)
47 | learning_rate: float = field(default=3e-6)
48 | weight_decay: float = field(default=0.001)
49 | model_name: str = field(
50 | default="gpt2",
51 | metadata={
52 | "help": "The model that you want to train from the Hugging Face hub. "
53 | "E.g. gpt2, gpt2-xl, bert, etc."
54 | },
55 | )
56 | data_path: str = field(
57 | default="Anthropic/hh-rlhf",
58 | )
59 | data_subset: str = field(
60 | default="both",
61 | metadata={
62 | "help": "Which subset of the data to use. You can choose between 'both', "
63 | "'helpful', or 'harmless'."
64 | },
65 | )
66 | reward_model_type: str = field(
67 | default="base",
68 | metadata={
69 | "help": "The type of reward model to use. You can choose between "
70 | "'base', 'mean_and_variance', or 'categorical'."
71 | },
72 | )
73 | num_atoms: int = field(
74 | default=10,
75 | metadata={
76 | "help": "The number of atoms to use for the categorical reward model."
77 | },
78 | )
79 | entropy_coeff: float = field(
80 | default=0.1,
81 | metadata={"help": "The entropy coefficient for the categorical reward model."},
82 | )
83 | variance_penalty: float = field(
84 | default=0.0,
85 | metadata={
86 | "help": "The variance penalty for the mean and variance reward model."
87 | },
88 | )
89 | tokenizer_name: Optional[str] = field(
90 | default=None,
91 | metadata={
92 |             "help": "The tokenizer for your model; if left empty, the default "
93 |             "tokenizer for your model will be used.",
94 | },
95 | )
96 | bf16: bool = field(
97 | default=True,
98 | metadata={
99 | "help": "This essentially cuts the training time in half if you want to "
100 | "sacrifice a little precision and have a supported GPU."
101 | },
102 | )
103 | fp16: bool = field(
104 | default=False,
105 | metadata={
106 | "help": "This essentially cuts the training time in half if you want to "
107 | "sacrifice a little precision and have a supported GPU."
108 | },
109 | )
110 | num_train_epochs: int = field(
111 | default=1,
112 | metadata={"help": "The number of training epochs for the reward model."},
113 | )
114 | train_dataset_size: int = field(
115 | default=0,
116 | metadata={"help": "The size of the subset of the training data to use"},
117 | )
118 | eval_dataset_size: int = field(
119 | default=0,
120 | metadata={"help": "The size of the subset of the eval data to use"},
121 | )
122 | gradient_checkpointing: bool = field(
123 | default=False,
124 | metadata={"help": "Enables gradient checkpointing."},
125 | )
126 | optim: str = field(
127 | default="adamw_hf",
128 | metadata={"help": "The optimizer to use."},
129 | )
130 | lr_scheduler_type: str = field(
131 | default="cosine",
132 | metadata={"help": "The lr scheduler"},
133 | )
134 | max_length: int = field(default=1024)
135 | eval_first_step: bool = field(
136 | default=False,
137 | metadata={"help": "Whether to run eval after the first step"},
138 | )
139 | log_dir: str = field(default="data/reward_models/hh_rlhf")
140 | controversial_only: bool = field(default=False)
141 | seed: int = field(default=0)
142 | up_sampling: bool = field(default=False)
143 |     other_subsets: Optional[str] = field(default=None)
144 |     one_user: Optional[str] = field(
145 |         default=None,
146 |         metadata={"help": "If set, only train and evaluate on this single user (data subset)."}
147 |     )
148 |
149 |
150 | class HHRLHFPreprocessor(object):
151 | def __init__(self, tokenizer, **tokenizer_kwargs):
152 | self.tokenizer = tokenizer
153 | self.tokenizer_kwargs = tokenizer_kwargs
154 |
155 | def __call__(self, examples):
156 | new_examples: dict = {
157 | "input_ids_chosen": [],
158 | "attention_mask_chosen": [],
159 | "input_ids_rejected": [],
160 | "attention_mask_rejected": [],
161 | }
162 | for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
163 | tokenized_chosen = self.tokenizer(chosen, **self.tokenizer_kwargs)
164 | tokenized_rejected = self.tokenizer(rejected, **self.tokenizer_kwargs)
165 |
166 | new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
167 | new_examples["attention_mask_chosen"].append(
168 | tokenized_chosen["attention_mask"]
169 | )
170 | new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
171 | new_examples["attention_mask_rejected"].append(
172 | tokenized_rejected["attention_mask"]
173 | )
174 |
175 | return new_examples
176 |
177 |
178 | def get_step_decay_lr_lambda(current_step: int, *, num_training_steps: int):
179 | if current_step < num_training_steps // 3:
180 | return 1.0
181 | elif current_step < (2 * num_training_steps) // 3:
182 | return 0.1
183 | else:
184 | return 0.01
185 |
186 |
187 | def get_cosine_decay_lr_lambda(current_step: int, *, num_training_steps: int):
188 | return 0.1 + 0.9 * 0.5 * (1 + np.cos(np.pi * current_step / num_training_steps))
189 |
190 |
191 | class RewardTrainer(Trainer):
192 | def __init__(self, *args, lr_lambda=None, **kwargs):
193 | super().__init__(*args, **kwargs)
194 | self.lr_lambda = lr_lambda
195 |
196 | @classmethod
197 | def per_sample_loss(cls, rewards_chosen, rewards_rejected):
198 | return -nn.functional.logsigmoid(rewards_chosen - rewards_rejected)
199 |
200 | def loss(self, rewards_chosen, rewards_rejected):
201 | return torch.mean(self.per_sample_loss(rewards_chosen, rewards_rejected))
202 |
203 | def compute_loss(self, model, inputs, return_outputs=False):
204 | all_rewards = model(
205 | torch.concatenate(
206 | [
207 | inputs["input_ids_chosen"],
208 | inputs["input_ids_rejected"],
209 | ],
210 | dim=0,
211 | ),
212 | torch.concatenate(
213 | [
214 | inputs["attention_mask_chosen"],
215 | inputs["attention_mask_rejected"],
216 | ],
217 | dim=0,
218 | ),
219 | )[0]
220 | all_rewards = all_rewards.reshape(2, -1, all_rewards.shape[-1])
221 | rewards_chosen = all_rewards[0]
222 | rewards_rejected = all_rewards[1]
223 | loss = self.loss(rewards_chosen, rewards_rejected)
224 | if return_outputs:
225 | return loss, {
226 | "rewards_chosen": rewards_chosen,
227 | "rewards_rejected": rewards_rejected,
228 | }
229 | else:
230 | self.log(
231 | {
232 | "rewards_chosen": rewards_chosen.mean().item(),
233 | "rewards_rejected": rewards_rejected.mean().item(),
234 | }
235 | )
236 | return loss
237 |
238 | def create_scheduler(self, num_training_steps: int, optimizer=None):
239 | if self.lr_lambda is not None:
240 | lr_lambda = partial(
241 | self.lr_lambda,
242 | num_training_steps=num_training_steps,
243 | )
244 | self.lr_scheduler = LambdaLR(optimizer, lr_lambda)
245 | return self.lr_scheduler
246 | else:
247 | return super().create_scheduler(num_training_steps, optimizer)
248 |
249 | @classmethod
250 | def compute_metrics(cls, eval_prediction: EvalPrediction):
251 | rewards_chosen, rewards_rejected = eval_prediction.predictions
252 | rewards_chosen = torch.from_numpy(rewards_chosen)
253 | rewards_rejected = torch.from_numpy(rewards_rejected)
254 |
255 | loss = cls.per_sample_loss(rewards_chosen, rewards_rejected)
256 | accuracy = torch.mean((loss < np.log(2)).float())
257 |
258 | return {
259 | "loss": loss.mean().item(),
260 | "accuracy": accuracy.item(),
261 | }
262 |
263 |
264 | class MeanAndVarianceRewardTrainer(RewardTrainer):
265 | def __init__(self, *args, variance_penalty: float = 0.0, **kwargs):
266 | super().__init__(*args, **kwargs)
267 | self.variance_penalty = variance_penalty
268 |
269 | @classmethod
270 | def per_sample_loss(cls, rewards_chosen, rewards_rejected):
271 | mean_chosen = rewards_chosen[:, 0]
272 | std_chosen = F.softplus(rewards_chosen[:, 1])
273 | mean_rejected = rewards_rejected[:, 0]
274 | std_rejected = F.softplus(rewards_rejected[:, 1])
275 |
276 | diff_mean = mean_chosen - mean_rejected
277 | var_combined = std_chosen**2 + std_rejected**2
278 | z = diff_mean / torch.sqrt(var_combined)
279 | return F.softplus(-z * np.sqrt(2 * np.pi))
280 |
281 | def loss(self, rewards_chosen, rewards_rejected):
282 | std_chosen = F.softplus(rewards_chosen[:, 1])
283 | std_rejected = F.softplus(rewards_rejected[:, 1])
284 | variance_loss = (std_chosen**2 + std_rejected**2).mean()
285 |
286 | log_loss = self.per_sample_loss(rewards_chosen, rewards_rejected).mean()
287 |
288 | if self.model.training:
289 | return log_loss + self.variance_penalty * variance_loss
290 | else:
291 | return log_loss
292 |
293 |
294 | class CategoricalRewardTrainer(RewardTrainer):
295 | def __init__(self, *args, entropy_coeff: float = 0.0, **kwargs):
296 | super().__init__(*args, **kwargs)
297 | self.entropy_coeff = entropy_coeff
298 |
299 | @classmethod
300 | def per_sample_loss(cls, rewards_chosen, rewards_rejected):
301 | num_atoms = rewards_chosen.size()[1]
302 | device = rewards_chosen.device
303 |
304 | comparison_matrix = torch.empty(
305 | (num_atoms, num_atoms),
306 | device=device,
307 | dtype=rewards_chosen.dtype,
308 | )
309 | atom_values = torch.linspace(0, 1, num_atoms, device=device)
310 | comparison_matrix[:] = atom_values[None, :] > atom_values[:, None]
311 | comparison_matrix[atom_values[None, :] == atom_values[:, None]] = 0.5
312 |
313 | dist_rejected = rewards_rejected.softmax(1)
314 | dist_chosen = rewards_chosen.softmax(1)
315 | prob_chosen = ((dist_rejected @ comparison_matrix) * dist_chosen).sum(dim=1)
316 | return -prob_chosen.log()
317 |
318 | def loss(self, rewards_chosen, rewards_rejected):
319 | dist_rejected = rewards_rejected.softmax(1)
320 | dist_chosen = rewards_chosen.softmax(1)
321 | mean_dist = torch.concatenate(
322 | [dist_chosen, dist_rejected],
323 | dim=0,
324 | ).mean(dim=0)
325 | entropy_loss = torch.sum(mean_dist * mean_dist.log())
326 |
327 | log_loss = self.per_sample_loss(rewards_chosen, rewards_rejected).mean()
328 |
329 | if self.model.training:
330 | return log_loss + self.entropy_coeff * entropy_loss
331 | else:
332 | return log_loss
333 |
334 |
335 | def get_hh_rlhf_dataset(
336 | data_subset: DataSubset,
337 | split: Literal["train", "test"],
338 | dataset_size: int = 0,
339 | data_path="Anthropic/hh-rlhf",
340 | use_subset_as_dir=True, # new parameter
341 | other_subsets=None
342 | ) -> Dataset:
343 | datasets: List[Dataset] = []
344 | if other_subsets is None:
345 | if data_path == "Anthropic/hh-rlhf":
346 | if data_subset == "harmless" or data_subset == "both":
347 | datasets.append(
348 | load_dataset(
349 | "Anthropic/hh-rlhf", data_dir="harmless-base", split=split
350 | ).map(lambda data: {"data_subset": "harmless"})
351 | )
352 | if data_subset == "helpful" or data_subset == "both":
353 | datasets.append(
354 | load_dataset(
355 | "Anthropic/hh-rlhf", data_dir="helpful-base", split=split
356 | ).map(lambda data: {"data_subset": "helpful"})
357 | )
358 | else:
359 | if not use_subset_as_dir: # original version: combine all data subsets within the path
360 | datasets.append(
361 | load_dataset(data_path, split=split).map(
362 | lambda data: {"data_subset": data_subset}
363 | )
364 | )
365 | else: # new version: use data_subset as subdirectory
366 | if data_subset == "helpful" or data_subset == "both":
367 | datasets.append(
368 | load_dataset(
369 | data_path, data_dir="helpful", split=split
370 | ).map(lambda data: {"data_subset": "helpful"})
371 | )
372 | if data_subset == "harmless" or data_subset == "both":
373 | datasets.append(
374 | load_dataset(
375 | data_path, data_dir="harmless", split=split
376 | ).map(lambda data: {"data_subset": "harmless"})
377 | )
378 | else: # TODO: set subsets here
379 | if other_subsets == 'ultra_feedback':
380 | subsets = ['helpfulness', 'honesty', 'instruction_following', 'truthfulness']
381 | elif other_subsets == 'single':
382 | subsets = ['8', '4', '2', '1']
383 | elif other_subsets == '84':
384 | subsets = ['8', '4']
385 | else:
386 | subsets = []
387 | for subset in subsets:
388 | if data_subset == 'all' or data_subset == subset:
389 | datasets.append(
390 | load_dataset(
391 | data_path, data_dir=subset, split=split
392 | )
393 | )
394 |
395 | if dataset_size:
396 | datasets = [
397 | dataset.select(range(dataset_size // len(datasets))) for dataset in datasets
398 | ]
399 |
400 | return concatenate_datasets(datasets)
401 |
402 |
403 | trainer_classes: Dict[RewardModelType, Type[RewardTrainer]] = {
404 | "base": RewardTrainer,
405 | "mean_and_variance": MeanAndVarianceRewardTrainer,
406 | "categorical": CategoricalRewardTrainer,
407 | }
408 |
409 |
410 | def up_sample_controversial(dataset, seed):
411 | cont = dataset.filter(lambda example: example['controversial'] == True)
412 | up_sampled_dataset = concatenate_datasets([cont] * 4 + [dataset])
413 | up_sampled_dataset = up_sampled_dataset.shuffle(seed=seed)
414 | return up_sampled_dataset
415 |
416 |
417 | if __name__ == "__main__":
418 | parser = HfArgumentParser(ScriptArguments)
419 | script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0]
420 |
421 | seed = script_args.seed
422 | random.seed(seed)
423 | np.random.seed(seed)
424 | torch.manual_seed(seed)
425 | torch.cuda.manual_seed(seed)
426 |
427 | torch.set_default_dtype(torch.bfloat16 if script_args.bf16 else torch.float32)
428 |
429 | data_subset = cast(DataSubset, script_args.data_subset)
430 | train_dataset = get_hh_rlhf_dataset(
431 | data_subset,
432 | "train",
433 | script_args.train_dataset_size,
434 | data_path=script_args.data_path,
435 | use_subset_as_dir=True,
436 | other_subsets=script_args.other_subsets
437 | )
438 | eval_dataset = get_hh_rlhf_dataset(
439 | data_subset,
440 | "test",
441 | script_args.eval_dataset_size,
442 | data_path=script_args.data_path,
443 | use_subset_as_dir=True,
444 | other_subsets=script_args.other_subsets
445 | )
446 | print(len(train_dataset), len(eval_dataset))
447 | if script_args.controversial_only:
448 | train_dataset = train_dataset.filter(lambda example: example['controversial'] == True)
449 | eval_dataset = eval_dataset.filter(lambda example: example['controversial'] == True)
450 | elif script_args.up_sampling:
451 | train_dataset = up_sample_controversial(train_dataset, seed)
452 | eval_dataset = up_sample_controversial(eval_dataset, seed)
453 |
454 | if script_args.one_user:
455 | train_dataset = train_dataset.filter(lambda example: example['data_subset'] == script_args.one_user)
456 | eval_dataset = eval_dataset.filter(lambda example: example['data_subset'] == script_args.one_user)
457 |
458 | reward_model_type = cast(RewardModelType, script_args.reward_model_type)
459 |
460 | # Define the training args. Needs to be done before the model is loaded if you
461 | # are using deepspeed.
462 | model_name_split = script_args.model_name.split("/")[-1]
463 | output_name = (
464 | f"{script_args.log_dir}/{data_subset}/"
465 | f"{reward_model_type}_{model_name_split}"
466 | f"__{script_args.train_dataset_size}_{script_args.learning_rate}"
467 | f"_{script_args.lr_scheduler_type}_{script_args.num_train_epochs}"
468 | )
469 | if reward_model_type == "categorical":
470 | output_name += f"_{script_args.num_atoms}_{script_args.entropy_coeff}"
471 | elif reward_model_type == "mean_and_variance":
472 | output_name += f"_{script_args.variance_penalty}"
473 |
474 | output_name += f"_seed{script_args.seed}"
475 | trainer_kwargs: Dict[str, Any] = {}
476 | if script_args.lr_scheduler_type == "step":
477 | lr_scheduler_type = "constant"
478 | trainer_kwargs["lr_lambda"] = get_step_decay_lr_lambda
479 | elif script_args.lr_scheduler_type == "cosine":
480 | lr_scheduler_type = "constant"
481 | trainer_kwargs["lr_lambda"] = get_cosine_decay_lr_lambda
482 | else:
483 | lr_scheduler_type = script_args.lr_scheduler_type
484 |
485 | training_args = TrainingArguments(
486 | output_dir=output_name,
487 | learning_rate=script_args.learning_rate,
488 | per_device_train_batch_size=script_args.per_device_train_batch_size,
489 | per_device_eval_batch_size=script_args.per_device_eval_batch_size,
490 | num_train_epochs=script_args.num_train_epochs,
491 | weight_decay=script_args.weight_decay,
492 | evaluation_strategy="steps",
493 | eval_steps=0.1,
494 | save_strategy="steps",
495 | save_steps=1000,
496 | gradient_accumulation_steps=script_args.gradient_accumulation_steps,
497 | gradient_checkpointing=script_args.gradient_checkpointing,
498 | deepspeed=script_args.deepspeed,
499 | local_rank=script_args.local_rank,
500 | remove_unused_columns=False,
501 | label_names=[],
502 | bf16=script_args.bf16,
503 | fp16=script_args.fp16,
504 | logging_strategy="steps",
505 | logging_steps=10,
506 | optim=script_args.optim,
507 | lr_scheduler_type=lr_scheduler_type,
508 | report_to="wandb",
509 | run_name=output_name.split("/")[-1],
510 | )
511 | # Load the value-head model and tokenizer.
512 | tokenizer_name = (
513 | script_args.tokenizer_name
514 | if script_args.tokenizer_name is not None
515 | else script_args.model_name
516 | )
517 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True, add_eos_token=False)
518 |
519 | peft_config = LoraConfig(
520 | task_type=TaskType.SEQ_CLS,
521 | inference_mode=False,
522 | r=128,
523 | lora_alpha=256,
524 | lora_dropout=0.1,
525 | )
526 |
527 | torch.set_anomaly_enabled(True)
528 |
529 | trainer_class = trainer_classes[reward_model_type]
530 | if reward_model_type == "base":
531 | num_labels = 1
532 | elif reward_model_type == "mean_and_variance":
533 | num_labels = 2
534 | trainer_kwargs["variance_penalty"] = script_args.variance_penalty
535 | elif reward_model_type == "categorical":
536 | num_labels = script_args.num_atoms
537 | trainer_kwargs["entropy_coeff"] = script_args.entropy_coeff
538 |
539 | model = AutoModelForSequenceClassification.from_pretrained(
540 | script_args.model_name, num_labels=num_labels, torch_dtype=torch.bfloat16
541 | )
542 | # We multiply the final linear layer's weights by 0.01 because this seems to
543 | # significantly stabilize training and lead to better optimization of the loss.
544 | model.score.weight.data *= 0.01
545 | model = get_peft_model(model, peft_config)
546 | model.print_trainable_parameters()
547 |
548 |     # Need to do this for GPT2 and Llama because they don't have official pad tokens.
549 | tokenizer.pad_token = tokenizer.eos_token
550 | tokenizer.pad_token_id = tokenizer.eos_token_id
551 | model.config.pad_token_id = tokenizer.pad_token_id
552 | tokenizer.padding_side = "right"
553 |
554 | model.config.use_cache = not script_args.gradient_checkpointing
555 | num_proc = 24 # Can adjust to be higher if you have more processors.
556 | original_columns = train_dataset.column_names
557 |
558 | train_dataset = train_dataset.map(
559 | HHRLHFPreprocessor(tokenizer),
560 | batched=True,
561 | num_proc=num_proc,
562 | remove_columns=original_columns,
563 | )
564 | train_dataset = train_dataset.filter(
565 | lambda x: len(x["input_ids_chosen"]) <= script_args.max_length
566 | and len(x["input_ids_rejected"]) <= script_args.max_length
567 | )
568 | print(len(train_dataset))
569 |
570 | eval_dataset = eval_dataset.map(
571 | HHRLHFPreprocessor(tokenizer),
572 | batched=True,
573 | num_proc=num_proc,
574 | remove_columns=original_columns,
575 | )
576 | eval_dataset = eval_dataset.filter(
577 | lambda x: len(x["input_ids_chosen"]) <= script_args.max_length
578 | and len(x["input_ids_rejected"]) <= script_args.max_length
579 | )
580 | print(len(eval_dataset))
581 |
582 | # We need to define a special data collator that batches the data in our j vs k format.
583 | @dataclass
584 | class RewardDataCollatorWithPadding:
585 | tokenizer: PreTrainedTokenizerBase
586 | padding: Union[bool, str, PaddingStrategy] = True
587 | max_length: Optional[int] = None
588 | pad_to_multiple_of: Optional[int] = None
589 | return_tensors: str = "pt"
590 |
591 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
592 | features_chosen = []
593 | features_rejected = []
594 | for feature in features:
595 | features_chosen.append(
596 | {
597 | "input_ids": feature["input_ids_chosen"],
598 | "attention_mask": feature["attention_mask_chosen"],
599 | }
600 | )
601 | features_rejected.append(
602 | {
603 | "input_ids": feature["input_ids_rejected"],
604 | "attention_mask": feature["attention_mask_rejected"],
605 | }
606 | )
607 | batch = self.tokenizer.pad(
608 | features_chosen + features_rejected,
609 | padding=self.padding,
610 | max_length=self.max_length,
611 | pad_to_multiple_of=self.pad_to_multiple_of,
612 | return_tensors=self.return_tensors,
613 | )
614 | input_ids = batch["input_ids"].view(2, -1, batch["input_ids"].shape[-1])
615 | attention_mask = batch["attention_mask"].view(
616 | 2, -1, batch["attention_mask"].shape[-1]
617 | )
618 | return {
619 | "input_ids_chosen": input_ids[0],
620 | "attention_mask_chosen": attention_mask[0],
621 | "input_ids_rejected": input_ids[1],
622 | "attention_mask_rejected": attention_mask[1],
623 | "return_loss": True,
624 | }
625 |
626 | # Train the model.
627 | trainer = trainer_class(
628 | model=model,
629 | args=training_args,
630 | train_dataset=train_dataset,
631 | eval_dataset=eval_dataset,
632 | compute_metrics=trainer_class.compute_metrics,
633 | data_collator=RewardDataCollatorWithPadding(
634 | tokenizer=tokenizer,
635 | max_length=script_args.max_length,
636 | pad_to_multiple_of=64,
637 | ),
638 | **trainer_kwargs,
639 | )
640 |
641 | trainer.train(script_args.resume_from_checkpoint)
642 |
643 | print("Saving last checkpoint of the model")
644 | model.save_pretrained(output_name + "_peft_last_checkpoint")
645 |
--------------------------------------------------------------------------------
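For train_llm_preference_model.py above, a hedged example invocation: the flag names come from its ScriptArguments, while the model, data path, and log directory are placeholders for a dataset produced by ultrafeedback_augment.py, not commands taken from the repository:

    python -m hidden_context.train_llm_preference_model \
        --model_name gpt2 \
        --data_path data/UltraFeedback_single_P_4 \
        --data_subset all \
        --other_subsets single \
        --reward_model_type base \
        --num_train_epochs 1 \
        --log_dir data/reward_models/ultrafeedback

With --other_subsets set, get_hh_rlhf_dataset loads each per-user subdirectory ('8', '4', '2', '1') under --data_path and concatenates them, and --data_subset all selects all of them; the 'both'/'helpful'/'harmless' values apply when --other_subsets is not given.
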
/hidden_context/train_llm_vae_preference_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass, field
3 | from typing import Any, Dict, List, Optional, Type, Union, cast
4 |
5 | import numpy as np
6 | import torch
7 | import random
8 | from peft import LoraConfig, TaskType, get_peft_model
9 | from transformers import (
10 | AutoModelForSequenceClassification,
11 | AutoTokenizer,
12 | HfArgumentParser,
13 | PreTrainedTokenizerBase,
14 | TrainingArguments,
15 | TrainerCallback,
16 | )
17 | from transformers.utils import PaddingStrategy
18 | from .vae_utils import VAETrainer, VAEModel
19 |
20 | from .train_llm_preference_model import (
21 | get_step_decay_lr_lambda,
22 | get_cosine_decay_lr_lambda,
23 | RewardModelType,
24 | DataSubset,
25 | get_hh_rlhf_dataset,
26 | concatenate_datasets
27 | )
28 |
29 |
30 | @dataclass
31 | class ScriptArguments:
32 | local_rank: int = field(default=-1, metadata={"help": "Used for multi-gpu"})
33 | resume_from_checkpoint: bool = field(
34 | default=False,
35 | metadata={"help": "If you want to resume training where it left off."},
36 | )
37 | deepspeed: Optional[str] = field(
38 | default=None,
39 | metadata={
40 | "help": "Path to deepspeed config if using deepspeed. You may need this "
41 | "if the model that you want to train doesn't fit on a single GPU."
42 | },
43 | )
44 | per_device_train_batch_size: int = field(default=2)
45 | per_device_eval_batch_size: int = field(default=1)
46 | gradient_accumulation_steps: int = field(default=1)
47 | learning_rate: float = field(default=3e-6)
48 | weight_decay: float = field(default=0.001)
49 | model_name: str = field(
50 | default="gpt2",
51 | metadata={
52 | "help": "The model that you want to train from the Hugging Face hub. "
53 | "E.g. gpt2, gpt2-xl, bert, etc."
54 | },
55 | )
56 | data_path: str = field(
57 | default="Anthropic/hh-rlhf",
58 | )
59 | data_subset: str = field(
60 | default="both",
61 | metadata={
62 | "help": "Which subset of the data to use. You can choose between 'both', "
63 | "'helpful', or 'harmless'."
64 | },
65 | )
66 | reward_model_type: str = field(
67 | default="base",
68 | metadata={
69 | "help": "The type of reward model to use. You can choose between "
70 | "'base', 'mean_and_variance', or 'categorical'."
71 | },
72 | )
73 | num_atoms: int = field(
74 | default=10,
75 | metadata={
76 | "help": "The number of atoms to use for the categorical reward model."
77 | },
78 | )
79 | entropy_coeff: float = field(
80 | default=0.1,
81 | metadata={"help": "The entropy coefficient for the categorical reward model."},
82 | )
83 | variance_penalty: float = field(
84 | default=0.0,
85 | metadata={
86 | "help": "The variance penalty for the mean and variance reward model."
87 | },
88 | )
89 | tokenizer_name: Optional[str] = field(
90 | default=None,
91 | metadata={
92 |             "help": "The tokenizer for your model; if left empty, the default "
93 |             "tokenizer for your model will be used.",
94 | },
95 | )
96 | bf16: bool = field(
97 | default=True,
98 | metadata={
99 | "help": "This essentially cuts the training time in half if you want to "
100 | "sacrifice a little precision and have a supported GPU."
101 | },
102 | )
103 | fp16: bool = field(
104 | default=False,
105 | metadata={
106 | "help": "This essentially cuts the training time in half if you want to "
107 | "sacrifice a little precision and have a supported GPU."
108 | },
109 | )
110 | num_train_epochs: int = field(
111 | default=1,
112 | metadata={"help": "The number of training epochs for the reward model."},
113 | )
114 | train_dataset_size: int = field(
115 | default=0,
116 | metadata={"help": "The size of the subset of the training data to use"},
117 | )
118 | eval_dataset_size: int = field(
119 | default=0,
120 | metadata={"help": "The size of the subset of the eval data to use"},
121 | )
122 | gradient_checkpointing: bool = field(
123 | default=False,
124 | metadata={"help": "Enables gradient checkpointing."},
125 | )
126 | optim: str = field(
127 | default="adamw_hf",
128 | metadata={"help": "The optimizer to use."},
129 | )
130 | lr_scheduler_type: str = field(
131 | default="cosine",
132 | metadata={"help": "The lr scheduler"},
133 | )
134 | max_length: int = field(default=1024)
135 | eval_first_step: bool = field(
136 | default=True,
137 | metadata={"help": "Whether to run eval after the first step"},
138 | )
139 | log_dir: str = field(default="data/reward_models/hh_rlhf")
140 | kl_loss_weight: float = field(default=0.01, metadata={"help": "weight for KLD loss"})
141 | latent_dim: int = field(default=512, metadata={"help": "dimension of latent user vector"}) # todo: 64
142 | hidden_dim: int = field(default=512, metadata={"help": "dimension of hidden layer in vae"}) # todo: 256
143 | encoder_embed_dim: int = field(default=1024, metadata={"help": "dimension of LLM embeddings for encoder"})
144 | decoder_embed_dim: int = field(default=1024, metadata={"help": "dimension of LLM embeddings for decoder"})
145 | use_annealing: bool = field(default=True, metadata={"help": "Whether to use annealing for learning rate"})
146 | fixed_contexts: bool = field(
147 | default=False,
148 | metadata={"help": "whether to use pre-calculated embeddings for contexts (encoder inputs)"}
149 | )
150 | fixed_llm_embeddings: bool = field(
151 | default=False,
152 | metadata={"help": "whether to use pre-calculated embeddings for decoder inputs"}
153 | )
154 | seed: int = field(default=0)
155 | controversial_only: bool = field(
156 | default=False,
157 | metadata={"help": "whether to only include controversial data"}
158 | )
159 | up_sampling: bool = field(
160 | default=False,
161 | metadata={"help": "whether to upsample controversial data during training phase"}
162 | )
163 |     one_user: Optional[str] = field(
164 |         default=None,
165 |         metadata={"help": "If set, only train and evaluate on this single user (data subset)."}
166 |     )
167 |     other_subsets: Optional[str] = field(
168 |         default=None,
169 |         metadata={"help": "Specify the group of subsets if not using helpful/harmless. You can choose between "
170 |                           "ultra_feedback, pos_neg, set, single."},
171 |     )
172 | use_last_token_embedding: bool = field(
173 | default=False,
174 | metadata={"help": "whether to use the last token embedding of last layer as LLM embeddings"}
175 | )
176 |
177 | class HHRLHFPreprocessor(object):
178 | def __init__(self, args, tokenizer, **tokenizer_kwargs):
179 | self.tokenizer = tokenizer
180 | self.args = args
181 | self.tokenizer_kwargs = tokenizer_kwargs
182 |
183 | def __call__(self, examples):
184 | if self.args.fixed_llm_embeddings:
185 | new_examples: dict = {
186 | "embedding_chosen": [],
187 | "embedding_rejected": [],
188 | "contexts_embeddings": [],
189 | "max_lengths": []
190 | }
191 | for embeddings, contexts in zip(
192 | examples["embeddings"], examples["contexts"]
193 | ):
194 | new_examples["embedding_chosen"].append(embeddings["embedding_chosen"])
195 | new_examples["embedding_rejected"].append(embeddings["embedding_rejected"])
196 | contexts_embeddings = [{"embedding_chosen": context["embedding_chosen"],
197 | "embedding_rejected": context["embedding_rejected"]}
198 | for context in contexts]
199 | new_examples["contexts_embeddings"].append(contexts_embeddings)
200 | new_examples["max_lengths"].append(0)
201 | new_examples["user_type"] = examples["data_subset"]
202 | return new_examples
203 |
204 | new_examples: dict = {
205 | "input_ids_chosen": [],
206 | "attention_mask_chosen": [],
207 | "input_ids_rejected": [],
208 | "attention_mask_rejected": [],
209 | "max_lengths": []
210 | }
211 | if self.args.fixed_contexts:
212 | new_examples["contexts_embeddings"] = []
213 | else:
214 | new_examples["contexts_tokens"] = []
215 | for chosen, rejected, contexts, user_type in zip(
216 | examples["chosen"], examples["rejected"], examples["contexts"], examples["data_subset"]
217 | ):
218 | max_length = 0
219 | tokenized_chosen = self.tokenizer(chosen, **self.tokenizer_kwargs)
220 | tokenized_rejected = self.tokenizer(rejected, **self.tokenizer_kwargs)
221 | new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
222 | new_examples["attention_mask_chosen"].append(
223 | tokenized_chosen["attention_mask"]
224 | )
225 | new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
226 | new_examples["attention_mask_rejected"].append(
227 | tokenized_rejected["attention_mask"]
228 | )
229 | max_length = max(max_length, len(tokenized_chosen["input_ids"]))
230 | max_length = max(max_length, len(tokenized_rejected["input_ids"]))
231 |
232 | if self.args.fixed_contexts:
233 | contexts_embeddings = [{"embedding_chosen": context["embedding_chosen"],
234 | "embedding_rejected": context["embedding_rejected"]}
235 | for context in contexts]
236 | new_examples["contexts_embeddings"].append(contexts_embeddings)
237 | else:
238 | tokenized_context = []
239 | # Tokenize the contexts.
240 | for context in contexts:
241 | chosen, rejected = context["chosen"], context["rejected"]
242 | tokenized_chosen = self.tokenizer(chosen, **self.tokenizer_kwargs)
243 | tokenized_rejected = self.tokenizer(rejected, **self.tokenizer_kwargs)
244 | tokenized_context.append(
245 | {
246 | "input_ids_chosen": tokenized_chosen["input_ids"],
247 | "attention_mask_chosen": tokenized_chosen["attention_mask"],
248 | "input_ids_rejected": tokenized_rejected["input_ids"],
249 | "attention_mask_rejected": tokenized_rejected["attention_mask"],
250 | }
251 | )
252 | max_length = max(max_length, len(tokenized_chosen["input_ids"]))
253 | max_length = max(max_length, len(tokenized_rejected["input_ids"]))
254 | new_examples["contexts_tokens"].append(tokenized_context)
255 | new_examples["max_lengths"].append(max_length)
256 | new_examples["user_type"] = examples["data_subset"]
257 | return new_examples
258 |
259 |
260 | trainer_classes: Dict[RewardModelType, Type[VAETrainer]] = {
261 | "vae": VAETrainer,
262 | }
263 |
264 |
265 | # We need to define a special data collator that batches the data in our j vs k format.
266 | @dataclass
267 | class RewardDataCollatorWithPadding:
268 | args: ScriptArguments
269 | tokenizer: PreTrainedTokenizerBase
270 | padding: Union[bool, str, PaddingStrategy] = True
271 | max_length: Optional[int] = None
272 | pad_to_multiple_of: Optional[int] = None
273 | return_tensors: str = "pt"
274 |
275 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
276 | if self.args.other_subsets is None:
277 | user_mapping = {
278 | "helpful": 0,
279 | "harmless": 1,
280 | }
281 | else: # TODO: set subsets here
282 | if self.args.other_subsets == 'ultra_feedback':
283 | subsets = ['helpfulness', 'honesty', 'instruction_following', 'truthfulness']
284 | elif self.args.other_subsets == 'single' or self.args.other_subsets == '84':
285 | subsets = ['8', '4', '2', '1']
286 | else:
287 | subsets = []
288 | user_mapping = {subset: idx for idx, subset in enumerate(subsets)}
289 | if self.args.fixed_llm_embeddings:
290 | batch_size = len(features)
291 | embeddings_chosen = []
292 | embeddings_rejected = []
293 | contexts_embeddings_chosen = []
294 | contexts_embeddings_rejected = []
295 | contexts_lengths = [0]
296 | for feature in features:
297 | embeddings_chosen.append(
298 | feature["embedding_chosen"]
299 | )
300 | embeddings_rejected.append(
301 | feature["embedding_rejected"]
302 | )
303 | contexts_embeddings_chosen.extend(
304 | [
305 | context["embedding_chosen"] for context in feature["contexts_embeddings"]
306 | ]
307 | )
308 | contexts_embeddings_rejected.extend(
309 | [
310 | context["embedding_rejected"] for context in feature["contexts_embeddings"]
311 | ]
312 | )
313 | contexts_lengths.append(len(feature["contexts_embeddings"]))
314 | contexts_lengths = torch.cumsum(torch.tensor(contexts_lengths), dim=0)
315 | seq_start_end = torch.stack(
316 | [contexts_lengths[:-1], contexts_lengths[1:]], dim=1
317 | )
318 | user_type = [user_mapping[feature["user_type"]] for feature in features]
319 | assert len(seq_start_end) == batch_size
320 | return {
321 | "embeddings_chosen": embeddings_chosen,
322 | "embeddings_rejected": embeddings_rejected,
323 | "contexts_embeddings_chosen": contexts_embeddings_chosen,
324 | "contexts_embeddings_rejected": contexts_embeddings_rejected,
325 | "seq_start_end": seq_start_end,
326 | "return_loss": True,
327 | "user_type": user_type,
328 | }
329 | if self.args.fixed_contexts:
330 | batch_size = len(features)
331 | features_chosen = []
332 | features_rejected = []
333 | contexts_embeddings_chosen = []
334 | contexts_embeddings_rejected = []
335 | contexts_lengths = [0]
336 | for feature in features:
337 | features_chosen.append(
338 | {
339 | "input_ids": feature["input_ids_chosen"],
340 | "attention_mask": feature["attention_mask_chosen"],
341 | }
342 | )
343 | features_rejected.append(
344 | {
345 | "input_ids": feature["input_ids_rejected"],
346 | "attention_mask": feature["attention_mask_rejected"],
347 | }
348 | )
349 | # Creating a flattened list of contexts.
350 | contexts_embeddings_chosen.extend(
351 | [
352 | context["embedding_chosen"] for context in feature["contexts_embeddings"]
353 | ]
354 | )
355 | contexts_embeddings_rejected.extend(
356 | [
357 | context["embedding_rejected"] for context in feature["contexts_embeddings"]
358 | ]
359 | )
360 | # Keep track of the start and end of each sequence.
361 | contexts_lengths.append(len(feature["contexts_embeddings"]))
362 |
363 | batch = self.tokenizer.pad(
364 | features_chosen + features_rejected,
365 | padding=self.padding,
366 | max_length=self.max_length,
367 | pad_to_multiple_of=self.pad_to_multiple_of,
368 | return_tensors=self.return_tensors,
369 | )
370 |
371 | input_ids = batch["input_ids"].view(
372 | 2, batch_size, batch["input_ids"].shape[-1]
373 | )
374 | attention_mask = batch["attention_mask"].view(
375 | 2, batch_size, batch["attention_mask"].shape[-1]
376 | )
377 |
378 | context_lengths = torch.cumsum(torch.tensor(contexts_lengths), dim=0)
379 | seq_start_end = torch.stack(
380 | [context_lengths[:-1], context_lengths[1:]], dim=1
381 | )
382 | user_type = [user_mapping[feature["user_type"]] for feature in features]
383 | assert len(seq_start_end) == batch_size
384 |
385 | return {
386 | "input_ids_chosen": input_ids[0],
387 | "attention_mask_chosen": attention_mask[0],
388 | "input_ids_rejected": input_ids[1],
389 | "attention_mask_rejected": attention_mask[1],
390 | "contexts_embeddings_chosen": contexts_embeddings_chosen,
391 | "contexts_embeddings_rejected": contexts_embeddings_rejected,
392 | "seq_start_end": seq_start_end,
393 | "return_loss": True,
394 | "user_type": user_type,
395 | }
396 |
397 | batch_size = len(features)
398 | features_chosen = []
399 | features_rejected = []
400 | contexts_features_chosen = []
401 | contexts_features_rejected = []
402 | contexts_lengths = [0]
403 | for feature in features:
404 | features_chosen.append(
405 | {
406 | "input_ids": feature["input_ids_chosen"],
407 | "attention_mask": feature["attention_mask_chosen"],
408 | }
409 | )
410 | features_rejected.append(
411 | {
412 | "input_ids": feature["input_ids_rejected"],
413 | "attention_mask": feature["attention_mask_rejected"],
414 | }
415 | )
416 |
417 | # Creating a flattened list of contexts.
418 | contexts_features_chosen.extend(
419 | [
420 | {
421 | "input_ids": context["input_ids_chosen"],
422 | "attention_mask": context["attention_mask_chosen"],
423 | }
424 | for context in feature["contexts_tokens"]
425 | ]
426 | )
427 | contexts_features_rejected.extend(
428 | [
429 | {
430 | "input_ids": context["input_ids_rejected"],
431 | "attention_mask": context["attention_mask_rejected"],
432 | }
433 | for context in feature["contexts_tokens"]
434 | ]
435 | )
436 | # Keep track of the start and end of each sequence.
437 | contexts_lengths.append(len(feature["contexts_tokens"]))
438 |
439 | batch = self.tokenizer.pad(
440 | features_chosen + features_rejected + contexts_features_chosen + contexts_features_rejected,
441 | padding=self.padding,
442 | max_length=self.max_length,
443 | pad_to_multiple_of=self.pad_to_multiple_of,
444 | return_tensors=self.return_tensors,
445 | )
446 |
447 | input_ids = batch["input_ids"][:2 * batch_size].view(
448 | 2, batch_size, batch["input_ids"].shape[-1]
449 | )
450 | attention_mask = batch["attention_mask"][:2 * batch_size].view(
451 | 2, batch_size, batch["attention_mask"].shape[-1]
452 | )
453 |
454 | contexts_lengths = torch.cumsum(torch.tensor(contexts_lengths), dim=0)
455 | seq_start_end = torch.stack(
456 | [contexts_lengths[:-1], contexts_lengths[1:]], dim=1
457 | )
458 | user_type = [user_mapping[feature["user_type"]] for feature in features]
459 | assert len(seq_start_end) == batch_size
460 | context_ids = batch["input_ids"][2 * batch_size:].view(
461 | 2, contexts_lengths[-1], batch["input_ids"].shape[-1]
462 | )
463 | context_attention_mask = batch["attention_mask"][2 * batch_size:].view(
464 | 2, contexts_lengths[-1], batch["attention_mask"].shape[-1]
465 | )
466 |
467 | return {
468 | "input_ids_chosen": input_ids[0],
469 | "attention_mask_chosen": attention_mask[0],
470 | "input_ids_rejected": input_ids[1],
471 | "attention_mask_rejected": attention_mask[1],
472 | "contexts_input_ids_chosen": context_ids[0],
473 | "contexts_attention_mask_chosen": context_attention_mask[0],
474 | "contexts_input_ids_rejected": context_ids[1],
475 | "contexts_attention_mask_rejected": context_attention_mask[1],
476 | "seq_start_end": seq_start_end,
477 | "return_loss": True,
478 | "user_type": user_type,
479 | }
480 |
481 |
482 | def up_sample_controversial(dataset, seed):
483 | cont = dataset.filter(lambda example: example['controversial'] == True)
484 | up_sampled_dataset = concatenate_datasets([cont] * 4 + [dataset])
485 | up_sampled_dataset = up_sampled_dataset.shuffle(seed=seed)
486 | return up_sampled_dataset
487 |
488 |
489 | def customized_optimizer(model, lr):
490 |     # Group parameters by id (tensor membership tests are unreliable); the decoder gets a 10x smaller lr.
491 |     decoder_param_ids = {id(p) for p in model.decoder.parameters()}
492 |     encoder_params = [p for p in model.parameters() if id(p) not in decoder_param_ids]
493 |     decoder_params = list(model.decoder.parameters())
494 |     grouped_parameters = [{'params': encoder_params, 'lr': lr},
495 |                           {'params': decoder_params, 'lr': lr / 10}]
496 |     return torch.optim.AdamW(grouped_parameters)
497 |
498 |
499 | if __name__ == "__main__":
500 | parser = HfArgumentParser(ScriptArguments)
501 | script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0]
502 |
503 | seed = script_args.seed
504 | random.seed(seed)
505 | np.random.seed(seed)
506 | torch.manual_seed(seed)
507 | torch.cuda.manual_seed(seed)
508 |
509 | torch.set_default_dtype(torch.bfloat16 if script_args.bf16 else torch.float32)
510 |
511 | if script_args.use_last_token_embedding:
512 | if script_args.model_name == 'gpt2':
513 | script_args.decoder_embed_dim = 768
514 | script_args.encoder_embed_dim = 768
515 | if script_args.model_name == 'meta-llama/Llama-2-7b-hf':
516 | script_args.decoder_embed_dim = 4096
517 | script_args.encoder_embed_dim = 4096
518 |
519 | data_subset = cast(DataSubset, script_args.data_subset)
520 | train_dataset = get_hh_rlhf_dataset(
521 | data_subset,
522 | "train",
523 | script_args.train_dataset_size,
524 | data_path=script_args.data_path,
525 | other_subsets=script_args.other_subsets
526 | )
527 | eval_dataset = get_hh_rlhf_dataset(
528 | data_subset,
529 | "test",
530 | script_args.eval_dataset_size,
531 | data_path=script_args.data_path,
532 | other_subsets=script_args.other_subsets
533 | )
534 | print(len(train_dataset), len(eval_dataset))
535 | if script_args.controversial_only:
536 | train_dataset = train_dataset.filter(lambda example: example['controversial'] == True)
537 | eval_dataset = eval_dataset.filter(lambda example: example['controversial'] == True)
538 | elif script_args.up_sampling:
539 | train_dataset = up_sample_controversial(train_dataset, seed)
540 |
541 | if script_args.one_user:
542 | train_dataset = train_dataset.filter(lambda example: example['data_subset'] == script_args.one_user)
543 | eval_dataset = eval_dataset.filter(lambda example: example['data_subset'] == script_args.one_user)
544 | reward_model_type = cast(RewardModelType, script_args.reward_model_type)
545 |
546 | # Define the training args. Needs to be done before the model is loaded if you
547 | # are using deepspeed.
548 | model_name_split = script_args.model_name.split("/")[-1]
549 | output_name = (
550 | f"{script_args.log_dir}/{data_subset}/"
551 | f"{reward_model_type}_{model_name_split}"
552 | f"__{script_args.train_dataset_size}_{script_args.learning_rate}"
553 | f"_{script_args.lr_scheduler_type}_{script_args.num_train_epochs}"
554 | )
555 | output_name += f"_{script_args.kl_loss_weight}_{script_args.latent_dim}_{script_args.decoder_embed_dim}_seed{script_args.seed}"
556 |
557 | trainer_kwargs: Dict[str, Any] = {}
558 | if script_args.lr_scheduler_type == "step":
559 | lr_scheduler_type = "constant"
560 | trainer_kwargs["lr_lambda"] = get_step_decay_lr_lambda
561 | elif script_args.lr_scheduler_type == "cosine":
562 | lr_scheduler_type = "constant"
563 | trainer_kwargs["lr_lambda"] = get_cosine_decay_lr_lambda
564 | else:
565 | lr_scheduler_type = script_args.lr_scheduler_type
566 |
567 | training_args = TrainingArguments(
568 | output_dir=output_name,
569 | learning_rate=script_args.learning_rate,
570 | per_device_train_batch_size=script_args.per_device_train_batch_size,
571 | per_device_eval_batch_size=script_args.per_device_eval_batch_size,
572 | num_train_epochs=script_args.num_train_epochs,
573 | weight_decay=script_args.weight_decay,
574 | evaluation_strategy="steps",
575 | eval_steps=0.05,
576 | save_strategy="steps",
577 | save_steps=10000,
578 | gradient_accumulation_steps=script_args.gradient_accumulation_steps,
579 | gradient_checkpointing=script_args.gradient_checkpointing,
580 | deepspeed=script_args.deepspeed,
581 | local_rank=script_args.local_rank,
582 | remove_unused_columns=False,
583 | label_names=[],
584 | bf16=script_args.bf16,
585 | fp16=script_args.fp16,
586 | logging_strategy="steps",
587 | logging_steps=100,
588 | optim=script_args.optim,
589 | lr_scheduler_type=lr_scheduler_type,
590 | report_to="wandb",
591 | run_name=output_name.split("/")[-1],
592 | )
593 | # Load the value-head model and tokenizer.
594 | tokenizer_name = (
595 | script_args.tokenizer_name
596 | if script_args.tokenizer_name is not None
597 | else script_args.model_name
598 | )
599 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True, add_eos_token=False)
600 |
601 | peft_config = LoraConfig(
602 | task_type=TaskType.SEQ_CLS,
603 | inference_mode=False,
604 | r=128,
605 | lora_alpha=256,
606 | lora_dropout=0.1,
607 | )
608 |
609 | torch.set_anomaly_enabled(True)
610 |
611 | trainer_class = trainer_classes[reward_model_type]
612 | decoder_embed_dim = script_args.decoder_embed_dim
613 | encoder_embed_dim = script_args.encoder_embed_dim
614 |
615 | model = AutoModelForSequenceClassification.from_pretrained(
616 | script_args.model_name, num_labels=decoder_embed_dim, torch_dtype=torch.bfloat16
617 | )
618 | # We multiply the final linear layer's weights by 0.01 because this seems to
619 | # significantly stabilize training and lead to better optimization of the loss.
620 | model.score.weight.data *= 0.01
621 | if not script_args.fixed_contexts:
622 | contexts_model = AutoModelForSequenceClassification.from_pretrained(
623 | script_args.model_name, num_labels=encoder_embed_dim, torch_dtype=torch.bfloat16
624 | )
625 | contexts_model.score.weight.data *= 0.01
626 | model = get_peft_model(model, peft_config)
627 | model.print_trainable_parameters()
628 |
629 | if not script_args.fixed_contexts:
630 | contexts_model = get_peft_model(contexts_model, peft_config)
631 | contexts_model.print_trainable_parameters()
632 |         contexts_model.config.pad_token_id = tokenizer.eos_token_id  # pad token is set to eos_token below
633 | contexts_model.config.use_cache = not script_args.gradient_checkpointing
634 | else:
635 | contexts_model = None
636 |
637 | # Need to do this for GPT2 and Llama because they don't have official pad tokens.
638 | tokenizer.pad_token = tokenizer.eos_token
639 | tokenizer.pad_token_id = tokenizer.eos_token_id
640 | model.config.pad_token_id = tokenizer.pad_token_id
641 | tokenizer.padding_side = "right"
642 |
643 | model.config.use_cache = not script_args.gradient_checkpointing
644 | num_proc = 24 # Can adjust to be higher if you have more processors.
645 | original_columns = train_dataset.column_names
646 |
647 | train_dataset = train_dataset.map(
648 | HHRLHFPreprocessor(script_args, tokenizer),
649 | batched=True,
650 | num_proc=num_proc,
651 | remove_columns=original_columns,
652 | )
653 | train_dataset = train_dataset.filter(
654 | lambda x: x["max_lengths"] <= script_args.max_length
655 | )
656 |
657 | eval_dataset = eval_dataset.map(
658 | HHRLHFPreprocessor(script_args, tokenizer),
659 | batched=True,
660 | num_proc=num_proc,
661 | remove_columns=original_columns,
662 | )
663 | eval_dataset = eval_dataset.filter(
664 | lambda x: x["max_lengths"] <= script_args.max_length
665 | )
666 |
667 | # Train the model.
668 | latent_dim = script_args.latent_dim
669 | hidden_dim = script_args.hidden_dim
670 | vae_model = VAEModel(encoder_embed_dim, decoder_embed_dim, hidden_dim, latent_dim, model, contexts_model,
671 | fixed_contexts=script_args.fixed_contexts,
672 | fixed_llm_embeddings=script_args.fixed_llm_embeddings,)
673 |
674 | trainer = trainer_class(
675 | model=vae_model,
676 | args=training_args,
677 | train_dataset=train_dataset,
678 | eval_dataset=eval_dataset,
679 | compute_metrics=trainer_class.compute_metrics,
680 | data_collator=RewardDataCollatorWithPadding(
681 | args=script_args,
682 | tokenizer=tokenizer,
683 | max_length=script_args.max_length,
684 | pad_to_multiple_of=64,
685 | ),
686 | kl_loss_weight=script_args.kl_loss_weight,
687 | use_annealing=script_args.use_annealing,
688 | **trainer_kwargs,
689 | )
690 |
691 | class EvaluateFirstStepCallback(TrainerCallback):
692 | def on_step_begin(self, args, state, control, **kwargs):
693 | if state.global_step == 0:
694 | control.should_evaluate = True
695 |
696 |
697 | trainer.add_callback(EvaluateFirstStepCallback())
698 |
699 | trainer.train(script_args.resume_from_checkpoint)
700 |
701 | print("Saving last checkpoint of the model")
702 |
703 | model.save_pretrained(output_name + "_peft_last_checkpoint")
704 | output_name += "_peft_last_checkpoint"
705 | os.makedirs(output_name, exist_ok=True)
706 |
707 | output_name = os.path.join(output_name, "model.pt")
708 | vae_model.save_model(output_name)
709 |
--------------------------------------------------------------------------------
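Similarly, a hedged example invocation of train_llm_vae_preference_model.py above. trainer_classes in this file only registers the 'vae' type, so --reward_model_type vae has to be passed explicitly (the field's default of 'base' would raise a KeyError), and --data_path must point to a dataset whose examples already carry the 'contexts' field read by HHRLHFPreprocessor; the path below is a placeholder:

    python -m hidden_context.train_llm_vae_preference_model \
        --model_name gpt2 \
        --data_path data/UltraFeedback_single_P_4_with_contexts \
        --data_subset all \
        --other_subsets single \
        --reward_model_type vae \
        --latent_dim 512 \
        --hidden_dim 512 \
        --kl_loss_weight 0.01
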
/hidden_context/vae_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | from transformers import Trainer, EvalPrediction
7 | import wandb
8 | from transformers.optimization import get_cosine_schedule_with_warmup
9 |
10 |
11 | class PairEncoder(nn.Module):
12 | """
13 | Model to encode pairs of accepted and rejected responses
14 | """
15 |
16 | def __init__(self, embed_dim, hidden_dim, output_dim):
17 | super(PairEncoder, self).__init__()
18 |
19 | self._model = nn.Sequential(
20 | nn.Linear(2 * embed_dim, hidden_dim),
21 | nn.LeakyReLU(0.2),
22 | nn.Linear(hidden_dim, hidden_dim),
23 | nn.LeakyReLU(0.2),
24 | nn.Linear(hidden_dim, output_dim),
25 | )
26 |
27 | def forward(self, e_c, e_r):
28 | x = torch.cat([e_c, e_r], dim=1)
29 | return self._model(x)
30 |
31 |
32 | class SequenceEncoder(nn.Module):
33 | """
34 |     Encodes a user's sequence of pair embeddings into the mean and log-variance of the latent user vector.
35 | """
36 |
37 | def __init__(self, input_dim, latent_dim):
38 | super(SequenceEncoder, self).__init__()
39 | self.input_dim = input_dim
40 | self.latent_dim = latent_dim
41 |
42 | self.linear = nn.Identity()
43 | self.w_q = nn.Linear(input_dim, input_dim)
44 | self.w_k = nn.Linear(input_dim, input_dim)
45 | self.w_v = nn.Linear(input_dim, input_dim)
46 | self.mean_layer = nn.Linear(input_dim, latent_dim)
47 | self.log_var_layer = nn.Linear(input_dim, latent_dim)
48 | self.layer_norm = nn.Identity() # nn.LayerNorm(latent_dim)
49 |
50 | def forward(
51 | self, sequences, seq_start_end
52 | ): # (C_1+C_2+...+C_n, D), [(0, C_1), (C_1, C_1+C_2), ..., (C_1+...+C_n-1, C_1+...+C_n)]
53 | outputs = []
54 | for _, (start, end) in enumerate(seq_start_end):
55 | context = sequences[start:end] # C_i x D
56 | q = self.w_q(context)
57 | k = self.w_k(context)
58 | attention_scores = torch.matmul(
59 | q, k.transpose(0, 1)
60 | )
61 | attention_scores = attention_scores / (context.shape[-1] ** 0.5)
62 | attention_weights = F.softmax(attention_scores, dim=-1) # C_i x C_i
63 | weighted_values = torch.matmul(attention_weights, self.w_v(context)) # C_i x D
64 | output = torch.mean(weighted_values, dim=0) # D
65 | outputs.append(output)
66 | outputs = torch.stack(outputs, dim=0) # n x D
67 |
68 | mean = self.layer_norm(self.mean_layer(outputs))
69 | log_var = self.layer_norm(self.log_var_layer(outputs))
70 | return mean, log_var
71 |
72 |
73 | class Decoder(nn.Module):  # scores an embedding conditioned on the latent user vector z
74 | def __init__(self, input_dim, hidden_dim):
75 | super(Decoder, self).__init__()
76 | self._model = nn.Sequential(
77 | nn.Linear(input_dim, hidden_dim),
78 | nn.LeakyReLU(0.2),
79 | nn.Linear(hidden_dim, hidden_dim),
80 | nn.LeakyReLU(0.2),
81 | nn.Linear(hidden_dim, 1),
82 | )
83 |
84 | def forward(self, xc, xr, z):
85 | xc = torch.cat([xc, z], dim=1)
86 | xr = torch.cat([xr, z], dim=1)
87 | rc = self._model(xc)
88 | rr = self._model(xr)
89 | return rc, rr
90 |
91 |
92 | class VAEModel(nn.Module):  # VPL model: LLM embeddings -> PairEncoder -> SequenceEncoder -> latent z -> Decoder rewards
93 | def __init__(self, encoder_embed_dim, decoder_embed_dim, hidden_dim, latent_dim, llm_encoder, llm_contexts_encoder,
94 | fixed_contexts=False, fixed_llm_embeddings=False, use_causal_lm=False, use_attention_layer=False,
95 | use_transformer=False, concat_chosen_rejected=False):
96 | super(VAEModel, self).__init__()
97 | self.llm_encoder = llm_encoder
98 | self.llm_contexts_encoder = llm_contexts_encoder
99 | self.pair_encoder = PairEncoder(encoder_embed_dim, hidden_dim, latent_dim)
100 | self.sequence_encoder = SequenceEncoder(latent_dim, latent_dim)
101 | self.decoder = Decoder(decoder_embed_dim + latent_dim, hidden_dim)
102 |
103 | self.latent_dim = latent_dim
104 | self.fixed_contexts = fixed_contexts
105 | self.fixed_llm_embeddings = fixed_llm_embeddings
106 | self.use_causal_lm = use_causal_lm
107 | self.use_attention_layer = use_attention_layer
108 | self.use_transformer = use_transformer
109 | self.concat_chosen_rejected = concat_chosen_rejected
110 |
111 | self.saved_embeddings = torch.Tensor(4, latent_dim)
112 | self.saved_embeddings.uniform_(-1, 1)
113 |
114 | def reparameterization(self, mean, std):
115 | epsilon = torch.randn_like(std).to(mean.device) # sampling epsilon
116 | z = mean + std * epsilon # reparameterization trick
117 | z = F.normalize(z, p=2, dim=-1) * math.sqrt(z.shape[-1])
118 | return z
119 |
120 | def encode_pair(self, e_c, e_r):
121 | return self.pair_encoder(e_c, e_r)
122 |
123 | def encode_sequence(self, sequences, seq_start_end):
124 | return self.sequence_encoder(sequences, seq_start_end)
125 |
126 | def decode(self, e_c, e_r, z):
127 | return self.decoder(e_c, e_r, z)
128 |
129 | def forward(
130 | self,
131 | target_chosen,
132 | target_rejected,
133 | context_chosen,
134 | context_rejected,
135 | seq_start_end,
136 | user_type,
137 |         ground_truth_user_vector=False, mask_chosen=None, mask_rejected=None,  # masks are passed by VAETrainer.compute_loss but unused here
138 | ):
139 | pair_embed = self.encode_pair(context_chosen, context_rejected)
140 | mean, log_var = self.encode_sequence(pair_embed, seq_start_end)
141 | mean = torch.clamp(mean, -1, 1)
142 |
143 | _log_var = torch.clamp(log_var, -1, 1)
144 | if ground_truth_user_vector:
145 | z = torch.zeros_like(mean)
146 | self.saved_embeddings = self.saved_embeddings.to(mean.device)
147 | for idx in range(user_type.shape[0]):
148 | z[idx] = self.saved_embeddings[int(user_type[idx])]
149 | else:
150 | z = self.reparameterization(mean, torch.exp(0.5 * _log_var))
151 |
152 | if not self.training and not ground_truth_user_vector:
153 | z = mean
154 | rc, rr = self.decode(target_chosen, target_rejected, z)
155 |
156 | return rc, rr, mean, _log_var, z
157 |
158 | def save_model(self, path):
159 | torch.save(self, path)
160 |
161 |
162 | class VAETrainer(Trainer):
163 | def __init__(
164 | self, *args, lr_lambda=None, kl_loss_weight=None, use_annealing=False, **kwargs
165 | ):
166 | super().__init__(*args, **kwargs)
167 | self.lr_lambda = lr_lambda
168 | self.kl_loss_weight = kl_loss_weight
169 | self.use_annealing = use_annealing
170 | self.annealer = Annealer(
171 | total_steps=1e4, shape="cosine", baseline=0.1, cyclical=True # todo: change total_step here
172 | )
173 |
174 | @classmethod
175 | def per_sample_loss(cls, rewards_chosen, rewards_rejected):
176 | return -nn.functional.logsigmoid(rewards_chosen - rewards_rejected)
177 |
178 | def loss(self, rewards_chosen, rewards_rejected):
179 | return torch.mean(self.per_sample_loss(rewards_chosen, rewards_rejected))
180 |
181 | def compute_loss(self, wrapped_model, inputs, return_outputs=False):
182 | if isinstance(wrapped_model, VAEModel):
183 | model = wrapped_model # .module
184 | else:
185 | model = wrapped_model.module
186 | device = model.llm_encoder.device
187 | batch_size = inputs["seq_start_end"].shape[0]
188 | if model.fixed_llm_embeddings:
189 | embeddings_chosen = torch.tensor(inputs["embeddings_chosen"]).to(device).bfloat16()
190 | embeddings_rejected = torch.tensor(inputs["embeddings_rejected"]).to(device).bfloat16()
191 | else:
192 | embeddings = model.llm_encoder(
193 | input_ids=torch.concatenate(
194 | [
195 | inputs["input_ids_chosen"],
196 | inputs["input_ids_rejected"],
197 | ],
198 | dim=0,
199 | ),
200 | attention_mask=torch.concatenate(
201 | [
202 | inputs["attention_mask_chosen"],
203 | inputs["attention_mask_rejected"],
204 | ],
205 | dim=0,
206 | ),
207 | )[0]
208 | embeddings_chosen = embeddings[:batch_size]
209 | embeddings_rejected = embeddings[batch_size:]
210 |
211 | if model.fixed_contexts:
212 | contexts_embeddings_chosen = torch.tensor(inputs["contexts_embeddings_chosen"]).to(device).bfloat16()
213 | contexts_embeddings_rejected = torch.tensor(inputs["contexts_embeddings_rejected"]).to(device).bfloat16()
214 | else:
215 | input_ids_chosen = inputs["contexts_input_ids_chosen"]
216 | attention_mask_chosen = inputs["contexts_attention_mask_chosen"]
217 | token_length_chosen = torch.eq(input_ids_chosen,
218 | model.llm_contexts_encoder.config.pad_token_id).int().argmax(-1) - 1
219 | input_ids_rejected = inputs["contexts_input_ids_rejected"]
220 | attention_mask_rejected = inputs["contexts_attention_mask_rejected"]
221 | token_length_rejected = torch.eq(input_ids_rejected,
222 | model.llm_contexts_encoder.config.pad_token_id).int().argmax(-1) - 1
223 |
224 | with torch.no_grad():
225 | last_hidden_state_chosen = model.llm_contexts_encoder(
226 | input_ids=input_ids_chosen,
227 | attention_mask=attention_mask_chosen,
228 | output_hidden_states=True
229 | ).hidden_states[-1]
230 |
231 | weights_for_non_padding_chosen = attention_mask_chosen * torch.arange(
232 | start=1, end=last_hidden_state_chosen.shape[1] + 1
233 | ).unsqueeze(0).to(attention_mask_chosen.device).float()
234 | sum_embeddings = torch.sum(last_hidden_state_chosen * weights_for_non_padding_chosen.unsqueeze(-1),
235 | dim=1)
236 | num_of_none_padding_tokens_chosen = torch.sum(weights_for_non_padding_chosen, dim=-1).unsqueeze(-1)
237 |                 contexts_embeddings_chosen = sum_embeddings / num_of_none_padding_tokens_chosen  # position-weighted mean over non-padding tokens
238 | last_hidden_state_rejected = model.llm_contexts_encoder(
239 | input_ids=input_ids_rejected,
240 | attention_mask=attention_mask_rejected,
241 | output_hidden_states=True
242 | ).hidden_states[-1]
243 |
244 | weights_for_non_padding_rejected = attention_mask_rejected * torch.arange(
245 | start=1, end=last_hidden_state_rejected.shape[1] + 1
246 | ).unsqueeze(0).to(attention_mask_rejected.device).float()
247 | sum_embeddings = torch.sum(last_hidden_state_rejected * weights_for_non_padding_rejected.unsqueeze(-1),
248 | dim=1)
249 | num_of_none_padding_tokens_rejected = torch.sum(weights_for_non_padding_rejected, dim=-1).unsqueeze(-1)
250 | contexts_embeddings_rejected = sum_embeddings / num_of_none_padding_tokens_rejected
251 | seq_start_end = inputs["seq_start_end"]
252 | user_type = torch.tensor(inputs["user_type"]).to(device).bfloat16()
253 | rewards_chosen, rewards_rejected, mean, log_var, z = model(
254 | embeddings_chosen,
255 | embeddings_rejected,
256 | contexts_embeddings_chosen,
257 | contexts_embeddings_rejected,
258 | seq_start_end,
259 | user_type,
260 | ground_truth_user_vector=False, # todo: set to True for debug usage
261 | mask_chosen=inputs["attention_mask_chosen"],
262 | mask_rejected=inputs["attention_mask_rejected"],
263 | )
264 |
265 | reproduction_loss = self.loss(rewards_chosen, rewards_rejected)
266 | if self.kl_loss_weight == 0:
267 | loss = reproduction_loss
268 | accuracy = torch.mean((rewards_chosen > rewards_rejected).float())
269 | if not return_outputs:
270 | self.log(
271 | {
272 | "train_recon_loss": reproduction_loss.mean().item(),
273 | "train_accuracy": accuracy.mean().item(),
274 | "rewards_chosen": rewards_chosen.mean().item(),
275 | "rewards_rejected": rewards_rejected.mean().item(),
276 | "embeddings_chosen": embeddings_chosen.mean().item(),
277 | "embeddings_rejected": embeddings_rejected.mean().item(),
278 | "mean": mean.mean().item(),
279 | "log_var": log_var.mean().item()
280 | }
281 | )
282 | else:
283 | kld = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=1).mean()
284 | if self.use_annealing:
285 | kld = self.annealer(kld)
286 | self.annealer.step()
287 | kld = self.kl_loss_weight * kld
288 | loss = reproduction_loss + kld
289 | accuracy = torch.mean((rewards_chosen > rewards_rejected).float())
290 | if not return_outputs:
291 | self.log(
292 | {
293 | "train_recon_loss": reproduction_loss.mean().item(),
294 | "train_kld": kld.mean().item(),
295 | "train_accuracy": accuracy.mean().item(),
296 | "rewards_chosen": rewards_chosen.mean().item(),
297 | "rewards_rejected": rewards_rejected.mean().item(),
298 | "embeddings_chosen": embeddings_chosen.mean().item(),
299 | "embeddings_rejected": embeddings_rejected.mean().item(),
300 | "mean": mean.mean().item(),
301 | "log_var": log_var.mean().item()
302 | }
303 | )
304 | if return_outputs:
305 | return loss, {
306 | "rewards_chosen": rewards_chosen,
307 | "rewards_rejected": rewards_rejected,
308 | "mean": mean,
309 | "log_var": log_var,
310 | "z": z,
311 | "user_type": user_type,
312 | }
313 | return loss
314 |
315 | def create_scheduler(self, num_training_steps: int, optimizer=None):
316 | scheduler = get_cosine_schedule_with_warmup(
317 | optimizer,
318 | num_warmup_steps=int(0.03 * num_training_steps),
319 | num_training_steps=num_training_steps
320 | )
321 | self.lr_scheduler = scheduler
322 | return scheduler
323 |
324 | @classmethod
325 | def compute_metrics(cls, eval_prediction: EvalPrediction):
326 | rewards_chosen, rewards_rejected, mean, log_var, z, user_type = (
327 | eval_prediction.predictions
328 | )
329 | rewards_chosen = torch.from_numpy(rewards_chosen)
330 | rewards_rejected = torch.from_numpy(rewards_rejected)
331 | mean = torch.from_numpy(mean)
332 | log_var = torch.from_numpy(log_var)
333 | z = torch.from_numpy(z)
334 | loss = cls.per_sample_loss(rewards_chosen, rewards_rejected)
335 |         kld = -torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=-1)  # note: reported without the 0.5 factor used in the training loss
336 |         accuracy = torch.mean((loss < np.log(2)).float())  # loss < log(2)  <=>  rewards_chosen > rewards_rejected
337 |
338 | def plot_latent(latent):
339 | from sklearn.manifold import TSNE
340 | z_embedding = TSNE(n_components=2, init='random', perplexity=20, learning_rate="auto").fit_transform(latent.numpy())
341 | import matplotlib.pyplot as plt
342 | colors = [f"C{int(i)}" for i in user_type]
343 | plt.scatter(z_embedding[:, 0], z_embedding[:, 1], c=colors)
344 | im = wandb.Image(plt)
345 | plt.close()
346 | return im
347 | im1 = plot_latent(mean)
348 | im2 = plot_latent(z)
349 |
350 | return {
351 | "loss": loss.mean().item(),
352 | "accuracy": accuracy.item(),
353 | "kld": kld.mean().item(),
354 | "mean_embeddings": im1,
355 | "z_embeddings": im2,
356 | }
357 |
358 |
359 | class Annealer:
360 | """
361 | This class is used to anneal the KL divergence loss over the course of training VAEs.
362 |     After each call, the step() function should be called to advance the current step.
363 | """
364 |
365 | def __init__(self, total_steps, shape, baseline=0.0, cyclical=False, disable=False):
366 | """
367 | Parameters:
368 |         total_steps (int): Number of steps to reach full KL divergence weight.
369 | shape (str): Shape of the annealing function. Can be 'linear', 'cosine', or 'logistic'.
370 | baseline (float): Starting value for the annealing function [0-1]. Default is 0.0.
371 | cyclical (bool): Whether to repeat the annealing cycle after total_steps is reached.
372 | disable (bool): If true, the __call__ method returns unchanged input (no annealing).
373 | """
374 | self.total_steps = total_steps
375 | self.current_step = 0
376 | self.cyclical = cyclical
377 | self.shape = shape
378 | self.baseline = baseline
379 | if disable:
380 | self.shape = "none"
381 | self.baseline = 0.0
382 |
383 | def __call__(self, kld):
384 | """
385 | Args:
386 | kld (torch.tensor): KL divergence loss
387 | Returns:
388 | out (torch.tensor): KL divergence loss multiplied by the slope of the annealing function.
389 | """
390 | out = kld * self.slope()
391 | return out
392 |
393 | def slope(self):
394 | if self.shape == "linear":
395 | y = self.current_step / self.total_steps
396 | elif self.shape == "cosine":
397 | y = (math.cos(math.pi * (self.current_step / self.total_steps - 1)) + 1) / 2
398 | elif self.shape == "logistic":
399 | exponent = (self.total_steps / 2) - self.current_step
400 | y = 1 / (1 + math.exp(exponent))
401 | elif self.shape == "none":
402 | y = 1.0
403 | else:
404 | raise ValueError(
405 | "Invalid shape for annealing function. Must be linear, cosine, or logistic."
406 | )
407 | y = self.add_baseline(y)
408 | return y
409 |
410 | def step(self):
411 | if self.current_step < self.total_steps:
412 | self.current_step += 1
413 | if self.cyclical and self.current_step >= self.total_steps:
414 | self.current_step = 0
415 | return
416 |
417 | def add_baseline(self, y):
418 | y_out = y * (1 - self.baseline) + self.baseline
419 | return y_out
420 |
421 | def cyclical_setter(self, value):
422 |         if not isinstance(value, bool):
423 | raise ValueError(
424 | "Cyclical_setter method requires boolean argument (True/False)"
425 | )
426 | else:
427 | self.cyclical = value
428 | return
429 |
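430 |
431 | # -----------------------------------------------------------------------------
432 | # Hedged usage sketch (illustrative only; not part of the training pipeline).
433 | # It runs one VAEModel forward pass with random tensors standing in for the LLM
434 | # embeddings, then forms the reconstruction and KL terms that VAETrainer combines.
435 | # All dimensions, the context-set size, and the Annealer settings are assumptions.
436 | if __name__ == "__main__":
437 |     torch.manual_seed(0)
438 |     embed_dim, hidden_dim, latent_dim = 16, 32, 8
439 |     vae = VAEModel(embed_dim, embed_dim, hidden_dim, latent_dim,
440 |                    llm_encoder=None, llm_contexts_encoder=None)
441 |     # One annotator with 4 labelled context pairs and 1 target pair to score.
442 |     ctx_chosen, ctx_rejected = torch.randn(4, embed_dim), torch.randn(4, embed_dim)
443 |     tgt_chosen, tgt_rejected = torch.randn(1, embed_dim), torch.randn(1, embed_dim)
444 |     seq_start_end = [(0, 4)]    # all 4 context pairs belong to user 0
445 |     user_type = torch.zeros(1)  # only consulted when ground_truth_user_vector=True
446 |     rc, rr, mean, log_var, z = vae(tgt_chosen, tgt_rejected,
447 |                                    ctx_chosen, ctx_rejected,
448 |                                    seq_start_end, user_type)
449 |     recon = -F.logsigmoid(rc - rr).mean()  # Bradley-Terry style preference loss
450 |     kld = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp(), dim=1).mean()
451 |     annealer = Annealer(total_steps=100, shape="cosine", baseline=0.1, cyclical=True)
452 |     loss = recon + annealer(kld)  # annealer.step() would advance the anneal schedule
453 |     print(f"recon={recon.item():.4f}  kld={kld.item():.4f}  loss={loss.item():.4f}")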
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers
2 | datasets>=1.17.0
3 | peft>=0.3.0
4 | sentencepiece
5 | diskcache==5.6.3
6 | pandas
7 | torch
8 | openai==0.28.1
9 | matplotlib==3.7.1
10 | scipy==1.10.1
11 | # needed by hidden_context/vae_utils.py (wandb logging, t-SNE latent plots)
12 | numpy
13 | wandb
14 | scikit-learn
15 |
--------------------------------------------------------------------------------
/submit_job_UF_P_2.sh:
--------------------------------------------------------------------------------
1 | export WANDB_MODE=online
2 | export WANDB_PROJECT=vpl
3 | export NCCL_P2P_DISABLE="1"
4 | export NCCL_IB_DISABLE="1"
5 |
6 | # Set model_name to be 'gpt2' or 'meta-llama/Llama-2-7b-hf' here
7 | model_name='gpt2'
8 |
9 | model_type=$1
10 | augment_type="84"
11 |
12 | if [ "${model_type}" == "vae" ];
13 | then
14 | # Train VPL on UltraFeedback two-user dataset
15 | python -m hidden_context.train_llm_vae_preference_model \
16 | --model_name=${model_name} \
17 | --data_path="data/P_survey_100/gpt2" \
18 | --num_train_epochs=2 \
19 | --reward_model_type=vae \
20 | --data_subset=all \
21 | --log_dir="logs/gpt2_P_survey_100" \
22 | --bf16 True \
23 | --fp16 False \
24 | --per_device_train_batch_size 4 \
25 | --gradient_accumulation_steps 8 \
26 | --latent_dim 512 \
27 | --hidden_dim 512 \
28 | --learning_rate 1e-4 \
29 | --use_annealing True \
30 | --kl_loss_weight 0 \
31 | --controversial_only True \
32 | --fixed_contexts True \
33 | --fixed_llm_embeddings False \
34 | --up_sampling False \
35 | --other_subsets ${augment_type} \
36 | --use_last_token_embedding True \
37 | --seed 0
38 | else
39 | # Train baseline models on UltraFeedback two-user dataset
40 | python -m hidden_context.train_llm_preference_model \
41 | --model_name=${model_name} \
42 | --data_path="data/P_survey_100/gpt2" \
43 | --num_train_epochs=2 \
44 | --reward_model_type=${model_type} \
45 | --data_subset=all \
46 | --log_dir="logs/gpt2_P_survey_100" \
47 | --bf16 True \
48 | --fp16 False \
49 | --per_device_train_batch_size 4 \
50 | --gradient_accumulation_steps 8 \
51 | --learning_rate 1e-4 \
52 | --controversial_only True \
53 | --up_sampling False \
54 | --other_subsets ${augment_type} \
55 | --seed 0
56 | fi
57 |
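58 | # Usage sketch (illustrative): the first positional argument selects the trainer.
59 | #   bash submit_job_UF_P_2.sh vae            # VPL (VAE) preference model
60 | #   bash submit_job_UF_P_2.sh <model_type>   # any other value is forwarded to
61 | #                                            # hidden_context.train_llm_preference_model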
--------------------------------------------------------------------------------
/submit_job_UF_P_4.sh:
--------------------------------------------------------------------------------
1 | export WANDB_MODE=online
2 | export WANDB_PROJECT=vpl
3 | export NCCL_P2P_DISABLE="1"
4 | export NCCL_IB_DISABLE="1"
5 |
6 | # Set model_name to be 'gpt2' or 'meta-llama/Llama-2-7b-hf' here
7 | model_name='gpt2'
8 |
9 | model_type=$1
10 |
11 | if [ "${model_type}" == "vae" ];
12 | then
13 | # Train VPL on UltraFeedback four-user dataset
14 | python -m hidden_context.train_llm_vae_preference_model \
15 | --model_name=${model_name} \
16 | --data_path="data/P_4_survey_100/gpt2" \
17 | --num_train_epochs=2 \
18 | --reward_model_type=vae \
19 | --data_subset=all \
20 | --log_dir="logs/gpt2_P_4_survey_100" \
21 | --bf16 True \
22 | --fp16 False \
23 | --per_device_train_batch_size 4 \
24 | --gradient_accumulation_steps 8 \
25 | --latent_dim 512 \
26 | --hidden_dim 512 \
27 | --learning_rate 1e-4 \
28 | --use_annealing True \
29 | --kl_loss_weight 3e-6 \
30 | --controversial_only True \
31 | --fixed_contexts True \
32 | --fixed_llm_embeddings False \
33 | --up_sampling False \
34 | --other_subsets single \
35 | --use_last_token_embedding True \
36 | --seed 0
37 | else
38 | # Train baseline models on UltraFeedback four-user dataset
39 | python -m hidden_context.train_llm_preference_model \
40 | --model_name=${model_name} \
41 | --data_path="data/P_4_survey_100/gpt2" \
42 | --num_train_epochs=2 \
43 | --reward_model_type=${model_type} \
44 | --data_subset=all \
45 | --log_dir="logs/gpt2_P_4_survey_100" \
46 | --bf16 True \
47 | --fp16 False \
48 | --per_device_train_batch_size 4 \
49 | --gradient_accumulation_steps 8 \
50 | --learning_rate 1e-4 \
51 | --controversial_only True \
52 | --up_sampling False \
53 | --other_subsets single \
54 | --seed 0
55 | fi
56 |
--------------------------------------------------------------------------------
/submit_job_pets.sh:
--------------------------------------------------------------------------------
1 | export WANDB_MODE=online
2 | export WANDB_PROJECT=vpl
3 | export NCCL_P2P_DISABLE="1"
4 | export NCCL_IB_DISABLE="1"
5 |
6 | # Set model_name to be 'gpt2' or 'meta-llama/Llama-2-7b-hf' here
7 | model_name='gpt2'
8 |
9 | # full (up-sampling): --controversial_only False --up_sampling True
10 | # controversial: --controversial_only True
11 |
12 | # Reminder: for controversial settings, please use --num_train_epochs=10
13 |
14 | model_type=$1
15 |
16 | if [ "${model_type}" == "vae" ]
17 | then
18 | # Train VPL on full/controversial/up-sampling Pets dataset
19 | python -m hidden_context.train_llm_vae_preference_model \
20 | --model_name=${model_name} \
21 | --data_path="data/simple_pets/gpt2" \
22 | --num_train_epochs=2 \
23 | --reward_model_type=vae \
24 | --data_subset=both \
25 | --log_dir="logs/gpt2_simple_pets" \
26 | --bf16 True \
27 | --fp16 False \
28 | --per_device_train_batch_size 4 \
29 | --gradient_accumulation_steps 8 \
30 | --latent_dim 512 \
31 | --hidden_dim 512 \
32 | --learning_rate 3e-4 \
33 | --use_annealing True \
34 | --kl_loss_weight 1e-4 \
35 | --fixed_contexts True \
36 | --fixed_llm_embeddings False \
37 | --use_last_token_embedding True \
38 | --up_sampling True \
39 | --controversial_only False \
40 | --seed 0
41 | else
42 | # Train baseline models on full/controversial/up-sampling Pets dataset
43 | python -m hidden_context.train_llm_preference_model \
44 | --model_name=${model_name} \
45 | --data_path="data/simple_pets/gpt2" \
46 | --num_train_epochs=2 \
47 | --reward_model_type=${model_type} \
48 | --data_subset=both \
49 | --log_dir="logs/gpt2_simple_pets" \
50 | --bf16 True \
51 | --fp16 False \
52 | --per_device_train_batch_size 4 \
53 | --gradient_accumulation_steps 8 \
54 | --learning_rate 1e-4 \
55 | --controversial_only False \
56 | --up_sampling True \
57 | --seed 0
58 | fi
59 |
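60 | # Usage sketch (illustrative): the first positional argument selects the trainer, e.g.
61 | #   bash submit_job_pets.sh vae              # VPL (VAE) preference model
62 | #   bash submit_job_pets.sh <model_type>     # baseline preference model of that type
63 | # For the controversial-only setting described above, switch to
64 | # --controversial_only True, --up_sampling False, and --num_train_epochs=10.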
--------------------------------------------------------------------------------