├── .gitignore
├── DATA_LICENSE
├── LICENSE
├── README.md
├── assets
└── images
│ ├── Dromedary-2.png
│ ├── salmon_comparison.png
│ ├── salmon_logo_with_text.jpeg
│ └── salmon_pipeline.png
├── inference
├── README.md
└── demo.py
├── prompts
├── dromedary_inference_prompt.txt
├── pmp_reward_model_prompt.txt
├── principles
│ ├── principle_collection_harmless.json
│ ├── principle_collection_honest.json
│ ├── principle_collection_non_evasive.json
│ ├── principle_collection_ppo.json
│ └── principle_collection_rm.json
├── salmon_reward_model_prompt_v0.txt
├── salmon_reward_model_prompt_v1.txt
└── synthetic_preference_prompt.txt
├── requirements.txt
└── training
├── README.md
├── data_utils
├── common_utils.py
├── data_utils_ppo.py
├── data_utils_rm.py
└── data_utils_sft.py
├── models
├── configuration_llama.py
├── distributed_utils.py
├── llama_with_flash_attn.py
├── ppo_trainer.py
├── qlora_model.py
├── reward_model.py
├── rl_models.py
├── rl_trainer.py
└── trainer_utils.py
├── qlora_utils.py
├── step1_synthetic_preference_collection
├── batch_generation.py
├── clean_oasst1_prompts.py
├── scripts
│ ├── generate_oasst1_response0.sh
│ ├── generate_oasst1_response1.sh
│ └── generate_synthetic_preference.sh
└── synthetic_preference.py
├── step2_rm_training
├── aggregate_synthetic_preference.py
├── clean_pmp_data.py
└── scripts
│ ├── train_reward_model_70b_qlora_ft.sh
│ └── train_reward_model_70b_qlora_pmp.sh
├── step3_ppo_training
├── aggregate_sharegpt_prompts.py
├── clean_and_merge_prompts.py
├── scripts
│ └── train_ppo_model_70b_qlora_salmon.sh
└── subsample_openorca_prompts.py
├── train_qlora_ppo.py
├── train_qlora_rm.py
└── train_qlora_sft.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .DS_Store
132 | .idea
133 |
134 | # temporary scripts
135 | tmp_scripts/
136 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
5 |
6 |
7 |
8 |
9 |
10 | ## SALMON: Self-Alignment with Principle-Following Reward Models
11 |
12 |
13 |
14 | [](LICENSE)
15 | [](DATA_LICENSE)
16 |
17 |
18 |
19 |
20 |
21 |
44 |
45 | SALMON is a new RLAIF paradigm for self-aligning language models from scratch, using only a small set of human-defined principles as guidance.
46 | Central to our approach is a principle-following reward model. Trained on synthetic preference data, this model can generate reward scores based on arbitrary human-defined principles.
47 | For comprehensive details and insights, we kindly direct you to our [paper](https://arxiv.org/abs/2310.05910).
48 |
49 |
50 |
51 |
52 |
53 |
54 |
56 |
57 |
58 |
59 | ## Dromedary-2
60 |
61 |
62 |
63 | We release the *Dromedary-2* model, which is trained with the SALMON paradigm on the [*LLaMA-2-70b* base language model](https://huggingface.co/meta-llama/Llama-2-70b-hf), with [Principle-Driven Self-Alignment](https://github.com/IBM/Dromedary) as the Supervised Fine-Tuning (SFT) strategy to initialize the policy model.
64 |
65 | This codebase focuses on the **Reinforcement Learning (RL)** fine-tuning stage with the SALMON paradigm, while the Self-Align SFT training pipeline is released at the [original Dromedary repo](https://github.com/IBM/Dromedary),
66 |
67 |
68 |
69 | ### Model Weights
70 |
71 | We release *Dromedary-2* weights as delta weights to comply with the LLaMA model license. You can directly load our QLoRA weights upon the *LLaMA-2* base model to obtain *Dromedary-2*. Instructions:
72 |
73 | 1. Get the original LLaMA-2 weights in the Hugging Face format by following the instructions [here](https://huggingface.co/meta-llama/Llama-2-70b-hf).
74 | 2. Download the QLoRA delta weights from our Hugging Face [model hub](https://huggingface.co/zhiqings/dromedary-2-70b-qlora-delta-v0).
75 | 3. Load the model with Hugging Face's [PEFT-LoRA](https://github.com/huggingface/peft) and QLoRA's [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
76 |
77 | **NOTE: *Dromedary-2* is trained with QLoRA and the bfloat16 data type.** While it is [possible](https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930) to merge the QLoRA weights with the quantized model and thus enable inference with libraries such as [TGI](https://github.com/huggingface/text-generation-inference) and [vLLM](https://github.com/vllm-project/vllm), we found the merged weights can lead to degenerated performance. Therefore, we recommend directly loading the QLoRA weights with the PEFT-LoRA framework.
78 |
79 | ```python
80 | # Please check the inference section for the complete inference code.
81 | system_prompt = (
82 | "# Dromedary\n\n## System Overview\n\n"
83 | "Consider an AI assistant whose codename is Dromedary, developed by the Self-Align team. "
84 | "Dromedary is trained on data up until Sept-2022, and it endeavors to be a helpful, ethical and reliable assistant.\n\n"
85 | "## User Conversation\n\n"
86 | )
87 | user_prompt = "### User\n"
88 | assistant_prompt = "### Dromedary\n"
89 | seperator = "\n\n"
90 |
91 | # USAGE: system_prompt + user_prompt + `user_message` + seperator + assistant_prompt + `assistant_message` + seperator + user_prompt ...
92 |
93 | dtype = torch.bfloat16
94 |
95 | model_path = "path/to/llama-2-70b-hf"
96 | qlora_path = "path/to/dromedary-2-70b-qlora-delta-v0"
97 |
98 | bnb_config = BitsAndBytesConfig(
99 | load_in_4bit=True,
100 | bnb_4bit_compute_dtype=dtype,
101 | bnb_4bit_use_double_quant=True,
102 | bnb_4bit_quant_type="nf4",
103 | )
104 |
105 | model = AutoModelForCausalLM.from_pretrained(
106 | model_path,
107 | load_in_4bit=True,
108 | device_map={"": "cuda:0"},
109 | quantization_config=bnb_config,
110 | torch_dtype=dtype,
111 | )
112 |
113 | model = PeftModel.from_pretrained(
114 | model,
115 | qlora_path,
116 | is_trainable=False,
117 | )
118 | ```
119 |
120 | ## Setup
121 |
122 | 1. Clone this repository and navigate to SALMON folder
123 |
124 | ```Shell
125 | git clone https://github.com/IBM/SALMON
126 | cd SALMON
127 | ```
128 |
129 | 2. Install the packages
130 |
131 | ```Shell
132 | conda create -n salmon python=3.9 -y
133 | conda activate salmon
134 | pip install -r requirements.txt
135 | ```
136 |
137 | ## Inference
138 |
139 | We provide a [chatbot demo](inference) for *Dromedary-2*.
140 |
141 | ## Training
142 |
143 | We provide the full [training pipeline](training) of *Dromedary-2* for reproduction.
144 |
145 | ## Prompts
146 |
147 | All the human supervision used in this project can be found [here](prompts).
148 |
149 | ### Citation
150 |
151 | Please consider citing the following papers if you use the data or code in this repo.
152 |
153 | ```
154 | @misc{sun2023principledriven,
155 | title={Principle-Driven Self-Alignment of Language Models from Scratch with Minimal Human Supervision},
156 | author={Zhiqing Sun and Yikang Shen and Qinhong Zhou and Hongxin Zhang and Zhenfang Chen and David Cox and Yiming Yang and Chuang Gan},
157 | year={2023},
158 | eprint={2305.03047},
159 | archivePrefix={arXiv},
160 | primaryClass={cs.LG}
161 | }
162 | ```
163 |
164 | ```
165 | @misc{sun2023salmon,
166 | title={SALMON: Self-Alignment with Principle-Following Reward Models},
167 | author={Zhiqing Sun and Yikang Shen and Hongxin Zhang and Qinhong Zhou and Zhenfang Chen and David Cox and Yiming Yang and Chuang Gan},
168 | year={2023},
169 | eprint={2310.05910},
170 | archivePrefix={arXiv},
171 | primaryClass={cs.LG}
172 | }
173 | ```
174 |
175 | ### Acknowledgements
176 |
177 | We thank [Meta LLaMA team](https://github.com/facebookresearch/llama), [Standford Alpaca team](https://github.com/tatsu-lab/stanford_alpaca), [Vicuna team](https://github.com/lm-sys/FastChat), [Alpaca-LoRA](https://github.com/tloen/alpaca-lora), [QLoRA team](https://github.com/artidoro/qlora), [Hugging Face PEFT](https://github.com/huggingface/peft), and [AlpacaFarm team](https://github.com/tatsu-lab/alpaca_farm) for their open-source efforts in democratizing large language models.
178 |
--------------------------------------------------------------------------------
/assets/images/Dromedary-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/SALMON/7392a5618c19a59d218ebbc2f50a5940a92304c5/assets/images/Dromedary-2.png
--------------------------------------------------------------------------------
/assets/images/salmon_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/SALMON/7392a5618c19a59d218ebbc2f50a5940a92304c5/assets/images/salmon_comparison.png
--------------------------------------------------------------------------------
/assets/images/salmon_logo_with_text.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/SALMON/7392a5618c19a59d218ebbc2f50a5940a92304c5/assets/images/salmon_logo_with_text.jpeg
--------------------------------------------------------------------------------
/assets/images/salmon_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IBM/SALMON/7392a5618c19a59d218ebbc2f50a5940a92304c5/assets/images/salmon_pipeline.png
--------------------------------------------------------------------------------
/inference/README.md:
--------------------------------------------------------------------------------
1 | # Demo
2 |
3 | To run our demo, you need to prepare *Dromedary-2* checkpoints locally. Please follow the [instructions here](https://github.com/IBM/SALMON#model-weights). This demo is adpted from [Guanaco demo notebook](https://colab.research.google.com/drive/17XEqL1JcmVWjHkT-WczdYkJlNINacwG7?usp=sharing). The code is tested on a single A100-80GB GPU, with a peak GPU memory usage < 48GB.
4 |
5 | ```bash
6 | python -u demo.py \
7 | --model-name "/path/to/llama-2-70b-hf" \
8 | --adapters-name "/path/to/qlora_adapter_models"
9 | ```
10 |
--------------------------------------------------------------------------------
/inference/demo.py:
--------------------------------------------------------------------------------
1 | # Load the model.
2 | # Note: It can take a while to download LLaMA and add the adapter modules.
3 |
4 | import torch
5 | import fire
6 |
7 | import datetime
8 | import os
9 | from threading import Event, Thread
10 | from uuid import uuid4
11 |
12 | import gradio as gr
13 | import requests
14 |
15 | from peft import PeftModel
16 | from transformers import (
17 | BitsAndBytesConfig,
18 | AutoModelForCausalLM,
19 | LlamaTokenizerFast,
20 | StoppingCriteria,
21 | StoppingCriteriaList,
22 | TextIteratorStreamer,
23 | )
24 |
25 | torch.backends.cuda.matmul.allow_tf32 = True
26 |
27 |
28 | def main(
29 | model_name: str = "/path/to/llama-2-70b-hf",
30 | adapters_name: str = "/path/to/adapter_model",
31 | tokenizer_name: str = "TheBloke/dromedary-65b-lora-HF", # a random loadable llama tokenizer
32 | max_queue_size: int = 16,
33 | sharable_link: bool = True,
34 | ):
35 | print(f"Starting to load the model {model_name} into memory")
36 |
37 | m = AutoModelForCausalLM.from_pretrained(
38 | model_name,
39 | load_in_4bit=True,
40 | torch_dtype=torch.bfloat16,
41 | device_map={"": 0},
42 | quantization_config=BitsAndBytesConfig(
43 | load_in_4bit=True,
44 | bnb_4bit_compute_dtype=torch.bfloat16,
45 | bnb_4bit_use_double_quant=True,
46 | bnb_4bit_quant_type="nf4",
47 | ),
48 | )
49 | m = PeftModel.from_pretrained(m, adapters_name, is_trainable=False)
50 |
51 | start_message = (
52 | "# Dromedary"
53 | "\n\n## System Overview"
54 | "\n\nConsider an AI assistant whose codename is Dromedary, developed by the Self-Align team. "
55 | "Dromedary is trained on data from before Sept-2022, and it endeavors to be a helpful, ethical and reliable assistant."
56 | "\n\n## User Conversation\n\n"
57 | )
58 |
59 | tok = LlamaTokenizerFast.from_pretrained(tokenizer_name)
60 |
61 | print(f"Successfully loaded the model {model_name} into memory")
62 |
63 | # Setup the gradio Demo.
64 |
65 | stop_tokens = ["### User"]
66 |
67 | class StopOnTokens(StoppingCriteria):
68 | def __call__(
69 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
70 | ) -> bool:
71 | decoded_tokens = tok.batch_decode(input_ids, skip_special_tokens=True)
72 |
73 | for stop_token in stop_tokens:
74 | if decoded_tokens[0].strip().endswith(stop_token):
75 | return True
76 | return False
77 |
78 | def convert_history_to_text(history):
79 | text = start_message + "".join(
80 | [
81 | "".join(
82 | [
83 | f"### User\n{item[0]}\n\n",
84 | f"### Dromedary\n{item[1]}\n\n",
85 | ]
86 | )
87 | for item in history[:-1]
88 | ]
89 | )
90 | text += "".join(
91 | [
92 | "".join(
93 | [
94 | f"### User\n{history[-1][0]}\n\n",
95 | f"### Dromedary\n{history[-1][1]}",
96 | ]
97 | )
98 | ]
99 | )
100 | return text
101 |
102 | def log_conversation(conversation_id, history, messages, generate_kwargs):
103 | logging_url = os.getenv("LOGGING_URL", None)
104 | if logging_url is None:
105 | return
106 |
107 | timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
108 |
109 | data = {
110 | "conversation_id": conversation_id,
111 | "timestamp": timestamp,
112 | "history": history,
113 | "messages": messages,
114 | "generate_kwargs": generate_kwargs,
115 | }
116 |
117 | try:
118 | requests.post(logging_url, json=data)
119 | except requests.exceptions.RequestException as e:
120 | print(f"Error logging conversation: {e}")
121 |
122 | def user(message, history):
123 | # Append the user's message to the conversation history
124 | return "", history + [[message, ""]]
125 |
126 | def bot(
127 | history, temperature, top_p, max_new_tokens, repetition_penalty, conversation_id
128 | ):
129 | # Initialize a StopOnTokens object
130 | stop = StopOnTokens()
131 |
132 | # Construct the input message string for the model by concatenating the current system message and conversation history
133 | messages = convert_history_to_text(history)
134 |
135 | print(f"history:")
136 | print(messages)
137 |
138 | # Tokenize the messages string
139 | input_ids = tok(messages, return_tensors="pt").input_ids
140 | input_ids = input_ids.to(m.device)
141 | streamer = TextIteratorStreamer(
142 | tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
143 | )
144 | generate_kwargs = dict(
145 | input_ids=input_ids,
146 | max_new_tokens=max_new_tokens,
147 | temperature=temperature,
148 | do_sample=temperature > 0.0,
149 | top_p=top_p,
150 | repetition_penalty=repetition_penalty,
151 | streamer=streamer,
152 | stopping_criteria=StoppingCriteriaList([stop]),
153 | )
154 |
155 | stream_complete = Event()
156 |
157 | def generate_and_signal_complete():
158 | m.generate(**generate_kwargs)
159 | stream_complete.set()
160 |
161 | def log_after_stream_complete():
162 | stream_complete.wait()
163 | log_conversation(
164 | conversation_id,
165 | history,
166 | messages,
167 | {
168 | "max_new_tokens": max_new_tokens,
169 | "top_p": top_p,
170 | "temperature": temperature,
171 | "repetition_penalty": repetition_penalty,
172 | },
173 | )
174 |
175 | t1 = Thread(target=generate_and_signal_complete)
176 | t1.start()
177 |
178 | t2 = Thread(target=log_after_stream_complete)
179 | t2.start()
180 |
181 | # Initialize an empty string to store the generated text
182 | partial_text = ""
183 | for new_text in streamer:
184 | partial_text += new_text
185 | history[-1][1] = partial_text.split("### User")[0]
186 | yield history
187 |
188 | print(f"output:")
189 | print(history[-1][1])
190 |
191 | def get_uuid():
192 | return str(uuid4())
193 |
194 | with gr.Blocks(
195 | theme=gr.themes.Soft(),
196 | css=".disclaimer {font-variant-caps: all-small-caps;}",
197 | ) as demo:
198 | conversation_id = gr.State(get_uuid)
199 | gr.Markdown(
200 | """Dromedary-2 Demo
201 | """
202 | )
203 | chatbot = gr.Chatbot().style(height=500)
204 | with gr.Row():
205 | with gr.Column():
206 | msg = gr.Textbox(
207 | label="Chat Message Box",
208 | placeholder="Chat Message Box",
209 | show_label=False,
210 | ).style(container=False)
211 | with gr.Column():
212 | with gr.Row():
213 | submit = gr.Button("Submit")
214 | # stop = gr.Button("Stop")
215 | clear = gr.Button("Clear")
216 | with gr.Row():
217 | with gr.Accordion("Advanced Options:", open=False):
218 | with gr.Row():
219 | with gr.Column():
220 | with gr.Row():
221 | temperature = gr.Slider(
222 | label="Temperature",
223 | value=0.7,
224 | minimum=0.0,
225 | maximum=1.0,
226 | step=0.1,
227 | interactive=True,
228 | info="Higher values produce more diverse outputs",
229 | )
230 | with gr.Column():
231 | with gr.Row():
232 | top_p = gr.Slider(
233 | label="Top-p (nucleus sampling)",
234 | value=0.9,
235 | minimum=0.0,
236 | maximum=1,
237 | step=0.01,
238 | interactive=True,
239 | info=(
240 | "Sample from the smallest possible set of tokens whose cumulative probability "
241 | "exceeds top_p. Set to 1 to disable and sample from all tokens."
242 | ),
243 | )
244 | with gr.Column():
245 | with gr.Row():
246 | max_new_tokens = gr.Slider(
247 | label="Response Length",
248 | value=768,
249 | minimum=64,
250 | maximum=1024,
251 | step=64,
252 | interactive=True,
253 | info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
254 | )
255 | with gr.Column():
256 | with gr.Row():
257 | repetition_penalty = gr.Slider(
258 | label="Repetition Penalty",
259 | value=1.0,
260 | minimum=1.0,
261 | maximum=2.0,
262 | step=0.1,
263 | interactive=True,
264 | info="Penalize repetition — 1.0 to disable.",
265 | )
266 | with gr.Row():
267 | gr.Markdown(
268 | "Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce "
269 | "factually accurate information. The model was trained on various public datasets; while great efforts "
270 | "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
271 | "biased, or otherwise offensive outputs.",
272 | elem_classes=["disclaimer"],
273 | )
274 | with gr.Row():
275 | gr.Markdown(
276 | "[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)",
277 | elem_classes=["disclaimer"],
278 | )
279 |
280 | submit_event = msg.submit(
281 | fn=user,
282 | inputs=[msg, chatbot],
283 | outputs=[msg, chatbot],
284 | queue=False,
285 | ).then(
286 | fn=bot,
287 | inputs=[
288 | chatbot,
289 | temperature,
290 | top_p,
291 | max_new_tokens,
292 | repetition_penalty,
293 | conversation_id,
294 | ],
295 | outputs=chatbot,
296 | queue=True,
297 | )
298 | submit_click_event = submit.click(
299 | fn=user,
300 | inputs=[msg, chatbot],
301 | outputs=[msg, chatbot],
302 | queue=False,
303 | ).then(
304 | fn=bot,
305 | inputs=[
306 | chatbot,
307 | temperature,
308 | top_p,
309 | max_new_tokens,
310 | repetition_penalty,
311 | conversation_id,
312 | ],
313 | outputs=chatbot,
314 | queue=True,
315 | )
316 | # stop.click(
317 | # fn=None,
318 | # inputs=None,
319 | # outputs=None,
320 | # cancels=[submit_event, submit_click_event],
321 | # queue=False,
322 | # )
323 | clear.click(lambda: None, None, chatbot, queue=False)
324 |
325 | demo.queue(max_size=max_queue_size, concurrency_count=1)
326 |
327 | # Launch your Dromedary-2 Demo!
328 | demo.launch(share=sharable_link)
329 |
330 |
331 | if __name__ == "__main__":
332 | fire.Fire(main)
333 |
--------------------------------------------------------------------------------
/prompts/dromedary_inference_prompt.txt:
--------------------------------------------------------------------------------
1 | # Dromedary
2 |
3 | ## System Overview
4 |
5 | Consider an AI assistant whose codename is Dromedary, developed by the Self-Align team. Dromedary is trained on data from before Sept-2022, and it endeavors to be a helpful, ethical and reliable assistant.
6 |
7 | ## User Conversation
8 |
9 | ### User
--------------------------------------------------------------------------------
/prompts/pmp_reward_model_prompt.txt:
--------------------------------------------------------------------------------
1 | # System
2 |
3 | You are a reviewer whose goal is to judge the quality of the AI system's responses to instructions.
4 |
5 | ## Annotation Example
6 |
7 | ### AI system's Response
8 | {Output}
9 |
10 | ### Instruction to the AI system
11 | {Input}
12 |
13 | ## Annotation Guideline
14 |
15 | Your task is to decide the quality of the response. There are several dimensions that you can think along. Consider the following questions:
16 |
17 | - Is the response helpful? For example, if the instruction asked for a recipe for healthy food, and the response is a useful recipe, then we can consider it helpful.
18 | - Is the response language natural? For example, AI responses often have repetitions, which is not natural.
19 | - Is the response factual/accurate? For example, AI responses often make up new information. For example, if the response claims that Donald Trump is the current U.S. president, then you should consider it inaccurate.
20 | - and so on ... ultimately, you should decide whether the response is good or bad.
21 |
22 | ## Reviewer
23 | The quality of the output is
--------------------------------------------------------------------------------
/prompts/principles/principle_collection_harmless.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "dimension": "Non-Discriminatory",
4 | "definition": "The AI should not produce content that is offensive, discriminatory, or prejudiced in nature, either directly or through subtext or bias.",
5 | "weight": 10
6 | },
7 | {
8 | "dimension": "Refusal to Aid Dangerous Acts",
9 | "definition": "When prompted to assist in harmful activities (e.g. building a bomb), the AI should promptly and politely decline. It should have the capability to identify and counteract concealed requests for harmful actions.",
10 | "weight": 5
11 | },
12 | {
13 | "dimension": "Consequential Advice Awareness",
14 | "definition": "The AI should be cognizant when it might be providing critical or sensitive guidance. It should exercise restraint and provide responses with the necessary caution.",
15 | "weight": 4
16 | },
17 | {
18 | "dimension": "Cultural Sensitivity",
19 | "definition": "Understanding that perceptions of harm vary among individuals and cultures, the AI should be adaptive and respectful. It should acknowledge and respect the diverse range of cultural norms and values.",
20 | "weight": 4
21 | },
22 | {
23 | "dimension": "Context Awareness",
24 | "definition": "The AI's responses should be tailored to the nature of the user query, taking into account who is interacting with the AI, as well as the situational context in which the assistant is being engaged.",
25 | "weight": 4
26 | },
27 | {
28 | "dimension": "Guard Against Amplification",
29 | "definition": "To prevent reinforcing harmful beliefs or stereotypes, the AI should avoid echoing or amplifying potentially harmful statements made by users, instead promoting neutral or positive alternatives.",
30 | "weight": 4
31 | },
32 | {
33 | "dimension": "Empathy-driven Responses",
34 | "definition": "The AI should gauge the emotional tone of user queries and respond with sensitivity, especially in situations where users seem distressed or vulnerable.",
35 | "weight": 1
36 | },
37 | {
38 | "dimension": "Proactive Harm Prevention",
39 | "definition": "Rather than just reactively refusing harmful requests, the AI should actively provide alternatives or suggestions that promote safety and well-being.",
40 | "weight": 1
41 | },
42 | {
43 | "dimension": "Transparency on Biases",
44 | "definition": "The AI should openly acknowledge when its responses might be influenced by historical data biases, offering the user a clearer picture of its limitations.",
45 | "weight": 1
46 | },
47 | {
48 | "dimension": "Evolutive Learning",
49 | "definition": "The AI should continuously learn and adapt from feedback on harmful outputs, minimizing the risk of repeating harmful behaviors in future interactions.",
50 | "weight": 1
51 | },
52 | {
53 | "dimension": "Mental Health Awareness",
54 | "definition": "Recognizing that conversations can sometimes touch on mental health issues, the AI should handle such interactions with extreme care, always prioritizing the well-being of the user and never offering medical advice unless explicitly trained to do so.",
55 | "weight": 1
56 | }
57 | ]
--------------------------------------------------------------------------------
/prompts/principles/principle_collection_honest.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "dimension": "Unwavering Accuracy",
4 | "definition": "The AI must consistently provide accurate and verifiable information in its responses.",
5 | "weight": 5
6 | },
7 | {
8 | "dimension": "Consistent Calibration",
9 | "definition": "The AI should ensure that its confidence assertions align with its actual accuracy, reflecting genuine knowledge estimation.",
10 | "weight": 5
11 | },
12 | {
13 | "dimension": "Transparent Uncertainty Expression",
14 | "definition": "When uncertain, the AI should transparently express its level of uncertainty, ensuring users are never misled by vague or overly confident statements.",
15 | "weight": 4
16 | },
17 | {
18 | "dimension": "Honest Self-representation",
19 | "definition": "The AI must be forthright about its own capabilities and should not feign expertise or knowledge it doesn't possess.",
20 | "weight": 5
21 | },
22 | {
23 | "dimension": "Introspective Honesty",
24 | "definition": "If asked about its workings or limitations, the AI should provide clear and honest descriptions about its design and operational boundaries.",
25 | "weight": 4
26 | },
27 | {
28 | "dimension": "Bias Recognition",
29 | "definition": "The AI should acknowledge potential biases inherent in its data or model and inform users when a topic or response might be influenced by such biases.",
30 | "weight": 4
31 | },
32 | {
33 | "dimension": "Correction Commitment",
34 | "definition": "When confronted with an error in its response, the AI should admit the mistake and strive to provide the correct information.",
35 | "weight": 1
36 | },
37 | {
38 | "dimension": "Avoidance of Pretense",
39 | "definition": "The AI should avoid mimicking the behavior or response patterns of experts unless it's providing information based on verified data or recognized knowledge.",
40 | "weight": 1
41 | }
42 | ]
--------------------------------------------------------------------------------
/prompts/principles/principle_collection_non_evasive.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "dimension": "Educational and Engaging",
4 | "definition": "The AI's responses should be enriched with accurate, relevant, and current information, serving to educate while keeping the user engaged.",
5 | "weight": 1
6 | },
7 | {
8 | "dimension": "Creative",
9 | "definition": "The AI should be adept at generating original content, such as poems, stories, code, essays, songs, parodies, summaries, translations, and more.",
10 | "weight": 1
11 | },
12 | {
13 | "dimension": "Multilingual",
14 | "definition": "The AI should be capable of conversing in the language used by the user, for instance, replying in 中文 if the query is in 中文.",
15 | "weight": 1
16 | },
17 | {
18 | "dimension": "Comprehensive",
19 | "definition": "For information-seeking tasks, the AI should offer extensive and relevant details to ensure a thorough and in-depth response. It should impartially and extensively present arguments from diverse perspectives when dealing with contentious topics.",
20 | "weight": 3
21 | },
22 | {
23 | "dimension": "Natural Language",
24 | "definition": "The AI should respond with diverse and natural language, avoiding repetition and awkward phrasing.",
25 | "weight": 3
26 | },
27 | {
28 | "dimension": "Consistent Reasoning",
29 | "definition": "The AI should deliver responses that are clear and logically sound, ensuring they do not contain self-contradictions.",
30 | "weight": 1
31 | },
32 | {
33 | "dimension": "Numerical Sensitive",
34 | "definition": "The AI should ensure that any numerical specifications given in the instruction are carefully adhered to, avoiding any errors in numerical computations.",
35 | "weight": 1
36 | },
37 | {
38 | "dimension": "Analytical Structure",
39 | "definition": "For information analysis tasks, the AI should articulate its response in a manner that begins with a summary, followed by numerous key points, each underscored by a thorough analysis.",
40 | "weight": 3
41 | },
42 | {
43 | "dimension": "Vivid",
44 | "definition": "The AI should employ vibrant, energetic language, enhancing user engagement by making all interactions lively and dynamic.",
45 | "weight": 1
46 | }
47 | ]
--------------------------------------------------------------------------------
/prompts/principles/principle_collection_ppo.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "dimension": "Honest and Accurate",
4 | "definition": "The AI must furnish reliable and factual information, and candidly disclose its limitations and the extent of its knowledge.",
5 | "weight": 3
6 | },
7 | {
8 | "dimension": "Ethical",
9 | "definition": "The AI should produce content that is free from offensive, discriminatory, or harmful material, and should not participate in or endorse risky activities.",
10 | "weight": 2
11 | },
12 | {
13 | "dimension": "Educational and Engaging",
14 | "definition": "The AI's responses should be enriched with accurate, relevant, and current information, serving to educate while keeping the user engaged.",
15 | "weight": 1
16 | },
17 | {
18 | "dimension": "Creative",
19 | "definition": "The AI should be adept at generating original content, such as poems, stories, code, essays, songs, parodies, summaries, translations, and more.",
20 | "weight": 1
21 | },
22 | {
23 | "dimension": "Multilingual",
24 | "definition": "The AI should be capable of conversing in the language used by the user, for instance, replying in 中文 if the query is in 中文.",
25 | "weight": 1
26 | },
27 | {
28 | "dimension": "Comprehensive",
29 | "definition": "For information-seeking tasks, the AI should offer extensive and relevant details to ensure a thorough and in-depth response. It should impartially and extensively present arguments from diverse perspectives when dealing with contentious topics.",
30 | "weight": 3
31 | },
32 | {
33 | "dimension": "Natural Language",
34 | "definition": "The AI should respond with diverse and natural language, avoiding repetition and awkward phrasing.",
35 | "weight": 3
36 | },
37 | {
38 | "dimension": "Consistent Reasoning",
39 | "definition": "The AI should deliver responses that are clear and logically sound, ensuring they do not contain self-contradictions.",
40 | "weight": 1
41 | },
42 | {
43 | "dimension": "Numerical Sensitive",
44 | "definition": "The AI should ensure that any numerical specifications given in the instruction are carefully adhered to, avoiding any errors in numerical computations.",
45 | "weight": 1
46 | },
47 | {
48 | "dimension": "Analytical Structure",
49 | "definition": "For information analysis tasks, the AI should articulate its response in a manner that begins with a summary, followed by numerous key points, each underscored by a thorough analysis.",
50 | "weight": 3
51 | },
52 | {
53 | "dimension": "Vivid",
54 | "definition": "The AI should employ vibrant, energetic language, enhancing user engagement by making all interactions lively and dynamic.",
55 | "weight": 1
56 | },
57 | {
58 | "dimension": "Privacy Protection",
59 | "definition": "The AI should avoid generating any personal identifiable information (PII) or external URLs in its responses.",
60 | "weight": 1
61 | },
62 | {
63 | "dimension": "Candor",
64 | "definition": "The AI should avoid sharing false information. If a question does not make any sense, or is not factually coherent, the AI should explain why instead of answering something not correct.",
65 | "weight": 4
66 | },
67 | {
68 | "dimension": "Stand-Alone",
69 | "definition": "The AI must function as a stand-alone text-based system, avoiding interactions with any external sources, including URLs, images, or videos.",
70 | "weight": 2
71 | }
72 | ]
--------------------------------------------------------------------------------
/prompts/principles/principle_collection_rm.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "dimension": "Overall Helpful",
4 | "definition": "The AI’s response should precisely serve the user's needs and interests by being well-rounded, combining conciseness, accuracy, informativeness, and other critical qualities.",
5 | "negative_definition": "The AI should avoid generating a helpful response."
6 | },
7 | {
8 | "dimension": "Concise",
9 | "definition": "The response should efficiently address the task or answer the question, communicating the necessary information with brevity and clarity.",
10 | "negative_definition": "The response should avoid efficiently addressing the task or answering the question."
11 | },
12 | {
13 | "dimension": "Honest and Accurate",
14 | "definition": "The AI must furnish reliable and factual information, and candidly disclose its limitations and the extent of its knowledge.",
15 | "negative_definition": "The AI should avoid furnishing reliable and factual information."
16 | },
17 | {
18 | "dimension": "Ethical",
19 | "definition": "The AI should produce content that is free from offensive, discriminatory, or harmful material, and should not participate in or endorse risky activities.",
20 | "negative_definition": "The AI should avoid producing content that is free from offensive, discriminatory, or harmful material."
21 | },
22 | {
23 | "dimension": "Natural and Fluent",
24 | "definition": "The AI should employ language that flows smoothly and is free from repetitive or awkward constructs.",
25 | "negative_definition": "The AI should avoid employing language that flows smoothly."
26 | },
27 | {
28 | "dimension": "Specific",
29 | "definition": "The AI’s response should be directly pertinent to the query, addressing the particular subject in the instruction explicitly.",
30 | "negative_definition": "The AI’s response should avoid being directly pertinent to the query."
31 | },
32 | {
33 | "dimension": "Educational and Engaging",
34 | "definition": "The AI's responses should be enriched with accurate, relevant, and current information, serving to educate while keeping the user engaged.",
35 | "negative_definition": "The AI's responses should avoid being enriched with accurate, relevant, and current information."
36 | },
37 | {
38 | "dimension": "Methodical",
39 | "definition": "The AI should employ a structured approach when providing solutions, presenting logical and step-by-step explanation before arriving at a conclusion.",
40 | "negative_definition": "The AI should avoid employing a structured approach when providing solutions."
41 | },
42 | {
43 | "dimension": "Multilingual",
44 | "definition": "The AI should be capable of conversing in the language used by the user, for instance, replying in 中文 if the query is in 中文.",
45 | "negative_definition": "The AI should avoid being capable of conversing in the language used by the user."
46 | },
47 | {
48 | "dimension": "Creative",
49 | "definition": "The AI should be adept at generating original content, such as poems, stories, code, essays, songs, parodies, summaries, translations, and more.",
50 | "negative_definition": "The AI should avoid being adept at generating original content."
51 | },
52 | {
53 | "dimension": "Comprehensive",
54 | "definition": "The AI should offer extensive and relevant details to ensure a thorough and in-depth response. It should impartially and extensively present arguments from diverse perspectives when dealing with contentious topics.",
55 | "negative_definition": "The AI should avoid offering extensive and relevant details."
56 | }
57 | ]
--------------------------------------------------------------------------------
/prompts/salmon_reward_model_prompt_v0.txt:
--------------------------------------------------------------------------------
1 | # System
2 |
3 | You are a reviewer whose goal is to judge the quality of the AI system's responses to instructions.
4 |
5 | ## Annotation Example
6 |
7 | ### AI system's Response
8 | {Output}
9 |
10 | ### Instruction to the AI system
11 | {Input}
12 |
13 | ## Annotation Guideline
14 |
15 | Your task is to evaluate the quality of the response. There are several dimensions you should consider in your evaluation:
16 |
17 | {Dimensions}
18 |
19 | A good response should meet all of the above criteria.
20 |
21 | ## Reviewer
22 | The quality of the output is
--------------------------------------------------------------------------------
/prompts/salmon_reward_model_prompt_v1.txt:
--------------------------------------------------------------------------------
1 | # System
2 |
3 | You are a reviewer whose goal is to judge the quality of the AI system's responses to instructions.
4 |
5 | ## Annotation Example
6 |
7 | ### AI system's Response
8 | {Output}
9 |
10 | ### Instruction to the AI system
11 | {Input}
12 |
13 | ## Annotation Guideline
14 |
15 | Your task is to evaluate the quality of the response. There are several dimensions you should consider in your evaluation:
16 |
17 | - The AI must avoid generating commentary about its own responses or engaging in self-praise. It needs to stay humble.
18 | - The AI must prioritize fulfilling the instruction, avoiding high-level pros-and-cons analysis, step-by-step instructions, or executable code.
19 | - The AI should keep the response straightforward and on-point, answering the question or completing the task without unnecessary examples.
20 | {Dimensions}
21 |
22 | A good response should meet all of the above criteria.
23 |
24 | ## Reviewer
25 | The quality of the output is
--------------------------------------------------------------------------------
/prompts/synthetic_preference_prompt.txt:
--------------------------------------------------------------------------------
1 | # Dromedary
2 |
3 | ## System Overview
4 |
5 | Hi, Dromedary! You are a helpful assistant whose goal is to select the preferred (least wrong) AI model's output for a given instruction.
6 |
7 | You will read a batch of examples, which are composed of the following:
8 |
9 | 1. an Instruction we give to the AI system
10 | 2. Output (a), the first output from the AI system
11 | 3. Output (b), the second output from the AI system
12 |
13 | ## User Conversation
14 |
15 | ### User
16 | Please select the preferred (least wrong) output for a given instruction.
17 |
18 | #### Instruction
19 | {UserInstruction}
20 |
21 | #### Output (a)
22 | {OutputA}
23 |
24 | #### Output (b)
25 | {OutputB}
26 |
27 | #### Annotation Guide
28 |
29 | To simplify the evaluation process, one aspect to consider this time is as follows:
30 |
31 | {Dimension}: {Definition}
32 |
33 | Based on the provided definition, please select the preferred output for the given instruction.
34 |
35 | ### Dromedary
36 | Sure! After carefully reading the Instruction, Output (a), Output (b), and the definition of {Dimension}, I think the more {Dimension} output is Output (
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch >= 2.0.0
2 | accelerate==0.21.0
3 | bitsandbytes==0.41.0
4 | deepspeed==0.9.3
5 | transformers==4.31.0
6 | wandb==0.15.3
7 | tiktoken==0.4.0
8 | tokenizers>=0.12.1
9 | peft==0.4.0
10 | gradio==3.35.2
11 | gradio_client==0.2.9
12 | sentencepiece==0.1.99
13 | tensorboard >= 2.12.0
14 | fairscale
15 | fire
16 | einops
17 | tqdm
18 | fastapi
19 | requests
20 | numpy
21 | uvicorn
22 | markdown2[all]
23 | shortuuid
24 | scipy
--------------------------------------------------------------------------------
/training/README.md:
--------------------------------------------------------------------------------
1 | # Training Experiences
2 |
3 | The whole **SALMON** process involves three stages: Synthetic Preference Collection, Training the Principle-following Reward Model, and RL Training with the Principle-following Reward Model. In our [paper](https://arxiv.org/abs/2310.05910), we provide a detailed description of each of these stages.
4 |
5 | ## Prerequisites
6 |
7 | For efficiency concerns, we utilize the [model parallel](https://github.com/facebookresearch/fairscale/tree/main/fairscale/nn/model_parallel) scheme from [llama](https://github.com/facebookresearch/llama) when sampling responses from the inital policy model. To prepare the sharded model checkpoints of LLaMA or Dromedary on your own machine/cluster, please refer to the [inference guide in Dromedary](https://github.com/IBM/Dromedary/tree/main/inference).
8 |
9 | ## Step 1: Collecting Principle-Driven Synthetic Preference
10 |
11 | We sample two responses from the initial policy model, and use the policy model itself to select the preferred response based on a certain human-written principle.
12 |
13 | Before diving into the experiments, please install the [`llama_dromedary` pacakge in Dromedary](https://github.com/IBM/Dromedary/tree/main/llama_dromedary) to enable model parallelism.
14 |
15 | ### Step 1.1: Preparing OASST1 Prompt Dataset
16 |
17 |
18 | Running the code
19 |
20 | ```bash
21 | cd step1_synthetic_preference_collection
22 |
23 | python -u clean_oasst1_prompts.py \
24 | --output_file "/path/to/your/oasst1_prompts.json"
25 | ```
26 |
27 |
28 |
29 | ### Step 1.2: Sampling Responses from the Policy Model
30 |
31 |
32 | Running the code
33 |
34 | ```bash
35 | salloc --nodes 8 --time 6:00:00 --gres=gpu:32g:6 srun bash scripts/generate_oasst1_response0.sh
36 | salloc --nodes 8 --time 6:00:00 --gres=gpu:32g:6 srun bash scripts/generate_oasst1_response1.sh
37 | ```
38 |
39 |
40 |
41 | ### Step 1.3: Collecting Synthetic Preferences
42 |
43 |
44 | Running the code
45 |
46 | ```bash
47 | salloc --nodes 1 --time 24:00:00 --gres=gpu:80g:8 srun bash scripts/generate_synthetic_preference.sh
48 | ```
49 |
50 |
51 |
52 | ## Step 2: Training the Principle-following Reward Model
53 |
54 | Next, for each user prompt, a subset of principles is randomly sampled from the established principle list with certain principles being randomly negated. The user prompt, model responses, and the sub-sampled principles are aggregated as a single training instance for the reward model.
55 |
56 | ### Step 2.1: Aggregating the Collected Preferences
57 |
58 |
59 | Running the code
60 |
61 | ```bash
62 | cd step2_rm_training
63 |
64 | python -u aggregate_synthetic_preference.py \
65 | --response_pattern "/path/to/your/oasst1_dromedary2_sft_response*.json" \
66 | --preference_file "/path/to/your/oasst1_dromedary2_sft_preference.json" \
67 | --output_file "/path/to/your/oasst1_dromedary2_sft_aggregated_preference.json"
68 | ```
69 |
70 |
71 |
72 | ### Step 2.2: Preference Modeling Pre-training (PMP) of the Reward Model
73 |
74 |
75 | Running the code
76 |
77 | ```bash
78 | python -u clean_pmp_data.py \
79 | --output_file "/path/to/your/pmp_data.json"
80 |
81 | salloc --nodes 1 --time 24:00:00 --gres=gpu:80g:8 srun bash scripts/train_reward_model_70b_qlora_pmp.sh
82 | ```
83 |
84 |
85 |
86 | ### Step 2.3: Fine-tune the Reward Model with Principle-driven Preferences
87 |
88 |
89 | Running the code
90 |
91 | ```bash
92 | salloc --nodes 1 --time 24:00:00 --gres=gpu:80g:8 srun bash scripts/train_reward_model_70b_qlora_ft.sh
93 | ```
94 |
95 |
96 |
97 | ## Step 3: RL Training with the Principle-following Reward Model
98 |
99 | Finally, we train the policy model with the principle-following reward model. We use the diverse user prompts from [ShareGPT](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered), [Dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k), [OpenAssistant](https://huggingface.co/datasets/OpenAssistant/oasst1), [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca), and [MATH](https://huggingface.co/datasets/competition_math).
100 |
101 | ### Step 3.1: Preparing the Prompt Dataset for RL Training
102 |
103 |
104 | Running the code
105 |
106 | ```bash
107 | cd step3_ppo_training
108 |
109 | python subsample_openorca_prompts.py \
110 | --train_data_path "/path/to/your/l1M-GPT4-Augmented.parquet (obtained from OpenOrca)" \
111 | --output_path "/path/to/your/openorca_prompts.json"
112 |
113 | python aggregate_sharegpt_prompts.py \
114 | --data_files=zetavg/ShareGPT-Processed,path/to/sg_90k_part1.json.json,path/to/sg_90k_part1.json (obtained from ShareGPT_Vicuna_unfiltered) \
115 | --output_path "/path/to/sharegpt_prompts.json"
116 |
117 | python clean_and_merge_prompts.py \
118 | --sharegpt_prompt_path "/path/to/sharegpt_prompts.json" \
119 | --openorca_prompt_path "/path/to/openorca_prompts.json" \
120 | --output_file "/path/to/your/salmon_merged_prompts.json"
121 | ```
122 |
123 |
124 |
125 | ### Step 3.2: RL Training
126 |
127 |
128 | Running the code
129 |
130 | ```bash
131 | salloc --nodes 6 --time 24:00:00 --gres=gpu:80g:8 srun bash scripts/train_ppo_model_70b_qlora_salmon.sh
132 | ```
133 |
--------------------------------------------------------------------------------
/training/data_utils/common_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The Alpaca Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import glob
17 | import os
18 | import random
19 | from typing import (
20 | Callable,
21 | Dict,
22 | Optional,
23 | Sequence,
24 | Union,
25 | Mapping,
26 | Any,
27 | )
28 |
29 | import numpy as np
30 | import torch
31 | import torch.nn.functional as F
32 | from torch.utils.data import DataLoader
33 |
34 | Numeric = Union[int, float]
35 |
36 |
37 | def zip_(*args: Sequence):
38 | """Assert sequences of same length before zipping."""
39 | if len(args) == 0:
40 | return []
41 | assert alleq(args, lambda x, y: len(x) == len(y))
42 | return zip(*args)
43 |
44 |
45 | def mean(*seqs: Sequence[Numeric]) -> Union[Numeric, Sequence[Numeric]]:
46 | singleton = len(seqs) == 1
47 | means = [float(np.mean(seq)) for seq in seqs]
48 | return means[0] if singleton else means
49 |
50 |
51 | def alleq(l: Sequence, f: Optional[Callable] = lambda x, y: x == y):
52 | """Check all arguments in a sequence are equal according to a given criterion.
53 |
54 | Args:
55 | f: A bi-variate boolean function.
56 | l: A list/tuple.
57 |
58 | Returns:
59 | True if everything is equal; otherwise False.
60 | """
61 | return all(f(l[0], li) for li in l[1:])
62 |
63 |
64 | def flatten_dict(nested, sep=".", postprocess_fn=lambda *args: args):
65 | def rec(nest, prefix, into):
66 | for k, v in nest.items():
67 | if sep in k:
68 | raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
69 | if isinstance(v, dict): # collections.Mapping fails in py3.10.
70 | rec(v, prefix + k + sep, into)
71 | else:
72 | v = postprocess_fn(v)
73 | into[prefix + k] = v
74 |
75 | flat = {}
76 | rec(nested, "", flat)
77 | return flat
78 |
79 |
80 | def unpack_dict(
81 | d: Dict, keys: Sequence[str], return_type: type = tuple
82 | ) -> Union[Sequence, Dict]:
83 | if return_type in (tuple, list):
84 | return return_type(d[key] for key in keys)
85 | elif return_type == dict:
86 | return {key: d[key] for key in keys}
87 | else:
88 | raise ValueError(f"Unknown return_type: {return_type}")
89 |
90 |
91 | def merge_dict(dicts: Sequence[dict], merge_fn: Callable = lambda *args: args) -> dict:
92 | """Merge a sequence of dicts (with the same set of keys) into a single dict."""
93 | if len(dicts) == 0:
94 | return dict()
95 | return {key: merge_fn([dict_[key] for dict_ in dicts]) for key in dicts[0].keys()}
96 |
97 |
98 | def prepare_inputs(
99 | data: Union[torch.Tensor, Any], device: Union[str, int, torch.device]
100 | ) -> Union[torch.Tensor, Any]:
101 | if isinstance(data, Mapping):
102 | return type(data)(
103 | {k: prepare_inputs(v, device) for k, v in data.items()}
104 | ) # noqa
105 | elif isinstance(data, (tuple, list)):
106 | return type(data)(prepare_inputs(v, device) for v in data)
107 | elif isinstance(data, torch.Tensor):
108 | return data.to(device) # This can break with deepspeed.
109 | return data
110 |
111 |
112 | def compute_logprobs(
113 | logits: torch.Tensor, labels: torch.Tensor, ignore_index: int
114 | ) -> torch.Tensor:
115 | """Compute per-token logprobs, zeroing out places with ignore_index (padding)."""
116 | return -F.cross_entropy(
117 | logits.permute(0, 2, 1), labels, reduction="none", ignore_index=ignore_index
118 | )
119 |
120 |
121 | def pad(
122 | inputs: torch.Tensor,
123 | target_size: Union[torch.Size, Sequence[int]],
124 | value=0.0,
125 | left=True,
126 | ):
127 | current_size = inputs.size()
128 | diffs = tuple(ti - ci for ti, ci in zip_(target_size, current_size))
129 | pad_params = []
130 | for diff in diffs:
131 | pad_params = ([diff, 0] if left else [0, diff]) + pad_params
132 | res = F.pad(inputs, pad=pad_params, value=value)
133 | return res
134 |
135 |
136 | def left_pad(
137 | inputs: torch.Tensor, target_size: Union[torch.Size, Sequence[int]], value=0.0
138 | ):
139 | return pad(inputs=inputs, target_size=target_size, value=value, left=True)
140 |
141 |
142 | def right_pad(
143 | inputs: torch.Tensor, target_size: Union[torch.Size, Sequence[int]], value=0.0
144 | ):
145 | return pad(inputs=inputs, target_size=target_size, value=value, left=False)
146 |
147 |
148 | def manual_seed(args_or_seed: Union[int, argparse.Namespace], fix_cudnn=False):
149 | if hasattr(args_or_seed, "seed"):
150 | args_or_seed = args_or_seed.seed
151 | random.seed(args_or_seed)
152 | np.random.seed(args_or_seed)
153 | torch.manual_seed(args_or_seed)
154 | torch.cuda.manual_seed_all(args_or_seed)
155 | os.environ["PYTHONHASHSEED"] = str(args_or_seed)
156 | if fix_cudnn:
157 | torch.backends.cudnn.deterministic = True # noqa
158 | torch.backends.cudnn.benchmark = False # noqa
159 |
160 |
161 | def make_meta_prompts(meta_prompt_pattern: str):
162 | meta_prompt_files = glob.glob(meta_prompt_pattern)
163 | print(f"Found {len(meta_prompt_files)} meta prompts: {meta_prompt_files}")
164 |
165 | meta_prompts = []
166 | for meta_prompt_file in meta_prompt_files:
167 | with open(meta_prompt_file, "r", encoding="utf-8") as f:
168 | meta_prompt = f.readlines()
169 | meta_prompt = "".join(meta_prompt).strip()
170 | meta_prompts.append(meta_prompt)
171 | return meta_prompts
172 |
173 |
174 | class InfiniteLoader(object):
175 | """Wraps an existing loader so that it outputs stuff indefinitely; useful for semi-supervised learning."""
176 |
177 | def __init__(self, loader: DataLoader):
178 | super(InfiniteLoader, self).__init__()
179 | self.loader = loader
180 | self.iterator = iter(loader)
181 |
182 | def __next__(self):
183 | try:
184 | return next(self.iterator)
185 | except StopIteration:
186 | self.iterator = iter(self.loader)
187 | return next(self.iterator)
188 |
--------------------------------------------------------------------------------
/training/data_utils/data_utils_ppo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The Self-Align Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import dataclasses
16 | from typing import Callable, Dict, Optional, List, Sequence
17 |
18 | import logging
19 | import pandas as pd
20 |
21 | import torch
22 | from torch.utils.data import Dataset
23 |
24 | import transformers
25 | import datasets
26 |
27 | import data_utils.common_utils as utils
28 |
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 |
33 | BASE_PROMPT_DICT = {
34 | "prompt_input": "{instruction}\n\n{input}",
35 | "prompt_no_input": "{instruction}",
36 | }
37 |
38 |
39 | def format_input_and_prompt(
40 | example: Dict[str, str],
41 | meta_prompts: List[str],
42 | ) -> str:
43 | assert (
44 | "instruction" in example and "input" in example
45 | ), "Internal error: example missing required keys."
46 |
47 | if "example_id" in example:
48 | total_meta_prompt = len(meta_prompts)
49 | meta_prompt = meta_prompts[int(example["example_id"]) % total_meta_prompt]
50 | else:
51 | meta_prompt = meta_prompts[0]
52 |
53 | if example.get("input", "") != "":
54 | prompt_format = BASE_PROMPT_DICT["prompt_input"]
55 | else:
56 | prompt_format = BASE_PROMPT_DICT["prompt_no_input"]
57 |
58 | formatted_input = prompt_format.format(**example)
59 | meta_prompt = meta_prompt.split("{Output}")[0]
60 |
61 | formatted_prompt = meta_prompt.format(Input=formatted_input)
62 | return formatted_input, formatted_prompt
63 |
64 |
65 | class QueryResponseDataset(Dataset):
66 | """Dataset that emits tokenized left-padded queries."""
67 |
68 | def __init__(
69 | self,
70 | df: pd.DataFrame,
71 | meta_prompts: List[str],
72 | tokenizer: transformers.PreTrainedTokenizer,
73 | query_len: int,
74 | df_postprocessor: Optional[Callable] = None,
75 | ):
76 | super(QueryResponseDataset, self).__init__()
77 |
78 | if df_postprocessor is not None:
79 | df = df_postprocessor(df)
80 | list_dict_data = df.to_dict(orient="records")
81 |
82 | # prompts are strings; queries are tensors.
83 | inputs_and_prompts = [
84 | format_input_and_prompt(example=dict_data, meta_prompts=meta_prompts)
85 | for dict_data in list_dict_data
86 | ]
87 | formated_inputs, prompts = zip(*inputs_and_prompts)
88 |
89 | length_bonus = [
90 | dict_data.get("length_bonus", 1.0) for dict_data in list_dict_data
91 | ]
92 |
93 | # For question_type:
94 | # 0: general
95 | # 1: reasoning
96 | # 2: red-teaming
97 | question_types = [
98 | dict_data.get("question_type", 0) for dict_data in list_dict_data
99 | ]
100 |
101 | pure_queries = [
102 | tokenizer(
103 | formated_input, return_tensors="pt", truncation=False
104 | ).input_ids.squeeze(dim=0)
105 | for formated_input in formated_inputs
106 | ]
107 |
108 | queries = [
109 | tokenizer(prompt, return_tensors="pt", truncation=False).input_ids.squeeze(
110 | dim=0
111 | )
112 | for prompt in prompts
113 | ]
114 |
115 | filtered_inputs = []
116 | filtered_queries = []
117 | filtered_length_bonus = []
118 | filtered_question_types = []
119 |
120 | for pure_query, query, ex_length_bonus, ex_question_type in zip(
121 | pure_queries, queries, length_bonus, question_types
122 | ):
123 | if len(query) <= query_len:
124 | filtered_inputs.append(pure_query)
125 | filtered_queries.append(query)
126 | filtered_length_bonus.append(ex_length_bonus)
127 | filtered_question_types.append(ex_question_type)
128 |
129 | logger.warning(
130 | f"Filtered out {len(queries) - len(filtered_queries)} instances out of {len(queries)} that "
131 | f"exceed length limit. These examples are not used for training, but will still be used in evaluation. "
132 | )
133 |
134 | pure_queries = torch.stack(
135 | [
136 | utils.left_pad(
137 | pure_query, target_size=(query_len,), value=tokenizer.pad_token_id
138 | )
139 | for pure_query in filtered_inputs
140 | ]
141 | )
142 |
143 | queries = torch.stack(
144 | [
145 | utils.left_pad(
146 | query, target_size=(query_len,), value=tokenizer.pad_token_id
147 | )
148 | for query in filtered_queries
149 | ]
150 | )
151 |
152 | self.length_bonus = torch.tensor(
153 | filtered_length_bonus, dtype=torch.float32
154 | ).reshape(-1, 1)
155 | self.question_types = torch.tensor(
156 | filtered_question_types, dtype=torch.long
157 | ).reshape(-1, 1)
158 | self.pure_queries = pure_queries
159 | self.queries = queries
160 | self.query_attn_masks = queries.ne(tokenizer.pad_token_id).long()
161 |
162 | assert self.pure_queries.shape[0] == self.queries.shape[0]
163 | assert self.pure_queries.shape[0] == self.query_attn_masks.shape[0]
164 | assert self.pure_queries.shape[0] == self.length_bonus.shape[0]
165 | assert self.pure_queries.shape[0] == self.question_types.shape[0]
166 |
167 | # Auxiliary data.
168 | self.prompts = prompts
169 | self.list_dict_data = list_dict_data
170 |
171 | def __getitem__(self, i):
172 | return_dict = dict(
173 | pure_queries=self.pure_queries[i],
174 | queries=self.queries[i],
175 | query_attn_masks=self.query_attn_masks[i],
176 | length_bonus=self.length_bonus[i],
177 | question_types=self.question_types[i],
178 | )
179 | return return_dict
180 |
181 | def __len__(self):
182 | return len(self.queries)
183 |
184 |
185 | @dataclasses.dataclass
186 | class DataCollatorForQueryResponseDataset(object):
187 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
188 | return {
189 | key: torch.stack([instance[key] for instance in instances])
190 | for key in instances[0].keys()
191 | }
192 |
193 |
194 | def make_rl_data_module(
195 | tokenizer: transformers.PreTrainedTokenizer,
196 | data_args,
197 | training_args,
198 | ):
199 | # prompt_dict = utils.jload(data_args.prompt_dict_path)
200 |
201 | policy_meta_prompts = utils.make_meta_prompts(data_args.policy_meta_prompt_pattern)
202 |
203 | if data_args.dataset_path.endswith("json"):
204 | train_instructions = datasets.load_dataset(
205 | "json", data_files=data_args.dataset_path
206 | )
207 | else:
208 | train_instructions = datasets.load_dataset(
209 | data_args.dataset_path, data_args.dataset_name
210 | )
211 | train_df = pd.concat(
212 | [pd.DataFrame(train_instructions[split]) for split in data_args.train_splits]
213 | )
214 |
215 | eval_instructions = datasets.load_dataset(
216 | data_args.eval_dataset_path, data_args.eval_dataset_name
217 | )
218 | eval_df = pd.concat(
219 | [pd.DataFrame(eval_instructions[split]) for split in data_args.eval_splits]
220 | )
221 |
222 | train_dataset = QueryResponseDataset(
223 | df=train_df,
224 | meta_prompts=policy_meta_prompts,
225 | tokenizer=tokenizer,
226 | query_len=training_args.query_len,
227 | )
228 | eval_dataset = QueryResponseDataset(
229 | df=eval_df,
230 | meta_prompts=policy_meta_prompts,
231 | tokenizer=tokenizer,
232 | query_len=training_args.query_len,
233 | )
234 | return dict(
235 | train_dataset=train_dataset,
236 | eval_dataset=eval_dataset,
237 | data_collator=DataCollatorForQueryResponseDataset(),
238 | )
239 |
--------------------------------------------------------------------------------
/training/data_utils/data_utils_sft.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The Self-Align Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import copy
16 |
17 | import os
18 | from dataclasses import dataclass, field
19 | from typing import Optional, Dict, Sequence, List
20 | import pandas as pd
21 |
22 | import torch
23 | import transformers
24 | from torch.nn.utils.rnn import pad_sequence
25 | from datasets import load_dataset, Dataset
26 |
27 | import data_utils.common_utils as utils
28 |
29 |
30 | IGNORE_INDEX = -100
31 |
32 |
33 | @dataclass
34 | class DataCollatorForCausalLM(object):
35 | left_truncated_tokenizer: transformers.PreTrainedTokenizer
36 | tokenizer: transformers.PreTrainedTokenizer
37 | source_max_len: int
38 | target_max_len: int
39 | train_on_source: bool
40 | predict_with_generate: bool
41 | add_eos_to_target: bool
42 |
43 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
44 | # Extract elements
45 | sources = [example["input"] for example in instances]
46 | if self.add_eos_to_target:
47 | targets = [
48 | f"\n{example['output']}{self.tokenizer.eos_token}"
49 | for example in instances
50 | ]
51 | else:
52 | targets = [f"\n{example['output']}" for example in instances]
53 |
54 | begin_padding_len = self.tokenizer(
55 | ["\n"], return_tensors="pt", add_special_tokens=False
56 | ).input_ids.shape[1]
57 |
58 | # Tokenize
59 | tokenized_sources_with_prompt = self.left_truncated_tokenizer(
60 | sources,
61 | max_length=self.source_max_len,
62 | truncation=True,
63 | # add_special_tokens=False,
64 | )
65 | tokenized_targets = self.tokenizer(
66 | targets,
67 | max_length=self.target_max_len + begin_padding_len,
68 | truncation=True,
69 | add_special_tokens=False,
70 | )
71 | # Build the input and labels for causal LM
72 | input_ids = []
73 | labels = []
74 | for tokenized_source, tokenized_target in zip(
75 | tokenized_sources_with_prompt["input_ids"], tokenized_targets["input_ids"]
76 | ):
77 | tokenized_target = tokenized_target[begin_padding_len:]
78 | if not self.predict_with_generate:
79 | input_ids.append(torch.tensor(tokenized_source + tokenized_target))
80 | if not self.train_on_source:
81 | labels.append(
82 | torch.tensor(
83 | [IGNORE_INDEX for _ in range(len(tokenized_source))]
84 | + copy.deepcopy(tokenized_target)
85 | )
86 | )
87 | else:
88 | labels.append(
89 | torch.tensor(copy.deepcopy(tokenized_source + tokenized_target))
90 | )
91 | else:
92 | input_ids.append(torch.tensor(tokenized_source))
93 | # Apply padding
94 | input_ids = pad_sequence(
95 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
96 | )
97 | labels = (
98 | pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
99 | if not self.predict_with_generate
100 | else None
101 | )
102 | data_dict = {
103 | "input_ids": input_ids,
104 | "attention_mask": input_ids.ne(self.tokenizer.pad_token_id),
105 | }
106 | if labels is not None:
107 | data_dict["labels"] = labels
108 | return data_dict
109 |
110 |
111 | def extract_unnatural_instructions_data(examples, extract_reformulations=False):
112 | out = {
113 | "input": [],
114 | "output": [],
115 | }
116 | for example_instances in examples["instances"]:
117 | for instance in example_instances:
118 | out["input"].append(instance["instruction_with_input"])
119 | out["output"].append(instance["output"])
120 | if extract_reformulations:
121 | for example_reformulations in examples["reformulations"]:
122 | if example_reformulations is not None:
123 | for instance in example_reformulations:
124 | out["input"].append(instance["instruction_with_input"])
125 | out["output"].append(instance["output"])
126 | return out
127 |
128 |
129 | DROMEDARY_PROMPT_DICT = {
130 | "prompt_input": (
131 | "{meta_prompt}\n" "{instruction}\n\n" "{input}\n\n" "### Dromedary"
132 | ),
133 | "prompt_no_input": ("{meta_prompt}\n" "{instruction}\n\n" "### Dromedary"),
134 | }
135 |
136 |
137 | def extract_dromedary_dataset(example, meta_prompts):
138 | assert "example_id" in example
139 | total_meta_prompt = len(meta_prompts)
140 | meta_prompt = meta_prompts[int(example["example_id"]) % total_meta_prompt]
141 |
142 | if example.get("input", "") != "":
143 | prompt_format = DROMEDARY_PROMPT_DICT["prompt_input"]
144 | else:
145 | prompt_format = DROMEDARY_PROMPT_DICT["prompt_no_input"]
146 |
147 | return {
148 | "input": prompt_format.format(meta_prompt=meta_prompt, **example),
149 | "output": "\n" + example["output"],
150 | }
151 |
152 |
153 | def local_dataset(dataset_name):
154 | if dataset_name.endswith(".json"):
155 | full_dataset = load_dataset("json", data_files=dataset_name)
156 | else:
157 | raise ValueError(f"Unsupported dataset format: {dataset_name}")
158 |
159 | return full_dataset
160 |
161 |
162 | def make_sft_data_module(
163 | left_truncated_tokenizer: transformers.PreTrainedTokenizer,
164 | tokenizer: transformers.PreTrainedTokenizer,
165 | args,
166 | ) -> Dict:
167 | """
168 | Make dataset and collator for supervised fine-tuning.
169 | Datasets are expected to have the following columns: { `input`, `output` }
170 | """
171 |
172 | def load_data(dataset_name):
173 | if dataset_name == "alpaca":
174 | return load_dataset("tatsu-lab/alpaca")
175 | elif dataset_name == "alpaca-clean":
176 | return load_dataset("yahma/alpaca-cleaned")
177 | elif dataset_name == "chip2":
178 | return load_dataset("laion/OIG", data_files="unified_chip2.jsonl")
179 | elif dataset_name == "self-instruct":
180 | return load_dataset("yizhongw/self_instruct", name="self_instruct")
181 | elif dataset_name == "hh-rlhf":
182 | return load_dataset("Anthropic/hh-rlhf")
183 | elif dataset_name == "longform":
184 | return load_dataset("akoksal/LongForm")
185 | elif dataset_name == "oasst1":
186 | return load_dataset("timdettmers/openassistant-guanaco")
187 | elif dataset_name == "vicuna":
188 | raise NotImplementedError("Vicuna data was not released.")
189 | else:
190 | if os.path.exists(dataset_name):
191 | try:
192 | args.dataset_format = (
193 | args.dataset_format if args.dataset_format else "alpaca"
194 | )
195 | full_dataset = local_dataset(dataset_name)
196 | return full_dataset
197 | except:
198 | raise ValueError(f"Error loading dataset from {dataset_name}")
199 | else:
200 | raise NotImplementedError(
201 | f"Dataset {dataset_name} not implemented yet."
202 | )
203 |
204 | def format_dataset(dataset, dataset_format, meta_prompt_pattern):
205 | if dataset_format == "dromedary":
206 | assert meta_prompt_pattern is not None
207 |
208 | meta_prompts = utils.make_meta_prompts(meta_prompt_pattern)
209 |
210 | dataset = dataset.map(
211 | lambda ex: extract_dromedary_dataset(ex, meta_prompts=meta_prompts),
212 | remove_columns=["instruction", "example_id"],
213 | )
214 | elif dataset_format == "chip2" or (
215 | dataset_format is None and args.dataset == "chip2"
216 | ):
217 | dataset = dataset.map(
218 | lambda x: {
219 | "input": x["text"].split("\n: ")[0].replace(": ", ""),
220 | "output": x["text"].split("\n: ")[1],
221 | }
222 | )
223 | elif dataset_format == "self-instruct" or (
224 | dataset_format is None and args.dataset == "self-instruct"
225 | ):
226 | for old, new in [["prompt", "input"], ["completion", "output"]]:
227 | dataset = dataset.rename_column(old, new)
228 | elif dataset_format == "hh-rlhf" or (
229 | dataset_format is None and args.dataset == "hh-rlhf"
230 | ):
231 | dataset = dataset.map(lambda x: {"input": "", "output": x["chosen"]})
232 | elif dataset_format == "oasst1" or (
233 | dataset_format is None and args.dataset == "oasst1"
234 | ):
235 | dataset = dataset.map(
236 | lambda x: {
237 | "input": "",
238 | "output": x["text"],
239 | }
240 | )
241 | # Remove unused columns.
242 | dataset = dataset.remove_columns(
243 | [
244 | col
245 | for col in dataset.column_names["train"]
246 | if col not in ["input", "output"]
247 | ]
248 | )
249 | return dataset
250 |
251 | # Load dataset.
252 | dataset = load_data(args.dataset)
253 | dataset = format_dataset(dataset, args.dataset_format, args.meta_prompt_pattern)
254 |
255 | # Split train/eval, reduce size
256 | if args.do_eval or args.do_predict:
257 | if "eval" in dataset:
258 | eval_dataset = dataset["eval"]
259 | else:
260 | print(
261 | "Splitting train dataset in train and validation according to `eval_dataset_size`"
262 | )
263 | dataset = dataset["train"].train_test_split(
264 | test_size=args.eval_dataset_size, shuffle=True, seed=42
265 | )
266 | eval_dataset = dataset["test"]
267 | if (
268 | args.max_eval_samples is not None
269 | and len(eval_dataset) > args.max_eval_samples
270 | ):
271 | eval_dataset = eval_dataset.select(range(args.max_eval_samples))
272 | if args.group_by_length:
273 | eval_dataset = eval_dataset.map(
274 | lambda x: {"length": len(x["input"]) + len(x["output"])}
275 | )
276 | if args.do_train:
277 | train_dataset = dataset["train"]
278 | if (
279 | args.max_train_samples is not None
280 | and len(train_dataset) > args.max_train_samples
281 | ):
282 | train_dataset = train_dataset.select(range(args.max_train_samples))
283 | if args.group_by_length:
284 | train_dataset = train_dataset.map(
285 | lambda x: {"length": len(x["input"]) + len(x["output"])}
286 | )
287 |
288 | data_collator = DataCollatorForCausalLM(
289 | left_truncated_tokenizer=left_truncated_tokenizer,
290 | tokenizer=tokenizer,
291 | source_max_len=args.source_max_len,
292 | target_max_len=args.target_max_len,
293 | train_on_source=args.train_on_source,
294 | predict_with_generate=args.predict_with_generate,
295 | add_eos_to_target=args.add_eos_to_target,
296 | )
297 | return dict(
298 | train_dataset=train_dataset if args.do_train else None,
299 | eval_dataset=eval_dataset if args.do_eval else None,
300 | predict_dataset=eval_dataset if args.do_predict else None,
301 | data_collator=data_collator,
302 | )
303 |
--------------------------------------------------------------------------------
/training/models/configuration_llama.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 Self-Align, EleutherAI, and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | """ LLaMA model configuration. """
21 |
22 | from transformers.configuration_utils import PretrainedConfig
23 | from transformers.utils import logging
24 |
25 |
26 | logger = logging.get_logger(__name__)
27 |
28 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
29 |
30 |
31 | class LlamaConfig(PretrainedConfig):
32 | r"""
33 | This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
34 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
35 | defaults will yield a similar configuration to that of the LLaMA-7B.
36 |
37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38 | documentation from [`PretrainedConfig`] for more information."""
39 | model_type = "llama"
40 | keys_to_ignore_at_inference = ["past_key_values"]
41 |
42 | def __init__(
43 | self,
44 | vocab_size=32000,
45 | hidden_size=4096,
46 | intermediate_size=11008,
47 | num_hidden_layers=32,
48 | num_attention_heads=32,
49 | num_key_value_heads=None,
50 | hidden_act="silu",
51 | max_position_embeddings=2048,
52 | initializer_range=0.02,
53 | rms_norm_eps=1e-6,
54 | use_cache=True,
55 | pad_token_id=0,
56 | bos_token_id=1,
57 | eos_token_id=2,
58 | tie_word_embeddings=False,
59 | rope_scaling=None,
60 | **kwargs,
61 | ):
62 | # for backward compatibility
63 | if num_key_value_heads is None:
64 | num_key_value_heads = num_attention_heads
65 |
66 | self.num_key_value_heads = num_key_value_heads
67 | self.vocab_size = vocab_size
68 | self.max_position_embeddings = max_position_embeddings
69 | self.hidden_size = hidden_size
70 | self.intermediate_size = intermediate_size
71 | self.num_hidden_layers = num_hidden_layers
72 | self.num_attention_heads = num_attention_heads
73 | self.hidden_act = hidden_act
74 | self.initializer_range = initializer_range
75 | self.rms_norm_eps = rms_norm_eps
76 | self.rope_scaling = rope_scaling
77 | self._rope_scaling_validation()
78 | self.use_cache = use_cache
79 | self.cache_shape = None
80 | super().__init__(
81 | pad_token_id=pad_token_id,
82 | bos_token_id=bos_token_id,
83 | eos_token_id=eos_token_id,
84 | tie_word_embeddings=tie_word_embeddings,
85 | **kwargs,
86 | )
87 |
88 | def _rope_scaling_validation(self):
89 | """
90 | Validate the `rope_scaling` configuration.
91 | """
92 | if self.rope_scaling is None:
93 | return
94 |
95 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
96 | raise ValueError(
97 | "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
98 | f"got {self.rope_scaling}"
99 | )
100 | rope_scaling_type = self.rope_scaling.get("type", None)
101 | rope_scaling_factor = self.rope_scaling.get("factor", None)
102 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
103 | raise ValueError(
104 | f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
105 | )
106 | if (
107 | rope_scaling_factor is None
108 | or not isinstance(rope_scaling_factor, float)
109 | or rope_scaling_factor <= 1.0
110 | ):
111 | raise ValueError(
112 | f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}"
113 | )
114 |
--------------------------------------------------------------------------------
/training/models/distributed_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The Alpaca Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for PyTorch's distributed training.
16 |
17 | Compatible with torchrun / elastic.
18 |
19 | Internal map:
20 | https://github.com/lxuechen/ml-swissknife/blob/main/ml_swissknife/distributed_utils.py
21 | """
22 |
23 | import os
24 | import sys
25 | from typing import Optional
26 |
27 | import torch
28 | import torch.distributed as dist
29 |
30 |
31 | def setup(rank: Optional[int] = None, world_size: Optional[int] = None):
32 | if rank is None:
33 | rank = get_rank()
34 | if world_size is None:
35 | world_size = get_world_size()
36 |
37 | if world_size <= 1:
38 | return rank, world_size
39 |
40 | if not dist.is_initialized():
41 | if sys.platform == "win32":
42 | # Distributed package only covers collective communications with Gloo
43 | # backend and FileStore on Windows platform. Set init_method parameter
44 | # in init_process_group to a local file.
45 | # Example init_method="file:///f:/libtmp/some_file"
46 | init_method = "file:///f:/libtmp/dist-tmp"
47 | dist.init_process_group(
48 | backend="gloo",
49 | init_method=init_method,
50 | rank=rank,
51 | world_size=world_size,
52 | )
53 | elif torch.cuda.is_available():
54 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
55 | else:
56 | dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
57 |
58 | return rank, world_size
59 |
60 |
61 | def cleanup():
62 | dist.destroy_process_group()
63 |
64 |
65 | def get_rank():
66 | return int(os.getenv("RANK", 0))
67 |
68 |
69 | def get_local_rank():
70 | return int(os.getenv("LOCAL_RANK", 0))
71 |
72 |
73 | def get_world_size():
74 | return int(os.getenv("WORLD_SIZE", 1))
75 |
76 |
77 | def should_save():
78 | """Return True if the current process is the main process."""
79 | return get_rank() <= 0
80 |
81 |
82 | def all_gather_and_cat(tensor: torch.Tensor, dim=0):
83 | if get_world_size() > 1:
84 | tensor_list = [torch.empty_like(tensor) for _ in range(get_world_size())]
85 | dist.all_gather(tensor_list, tensor)
86 | tensor = torch.cat(tensor_list, dim=dim)
87 | return tensor
88 |
89 |
90 | is_main_process = should_save
91 |
--------------------------------------------------------------------------------
/training/models/qlora_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The Self-Align Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from argparse import Namespace
16 | from typing import Optional
17 | from os.path import join, exists
18 |
19 | import torch
20 | import bitsandbytes as bnb
21 | from transformers import (
22 | AutoModelForCausalLM,
23 | BitsAndBytesConfig,
24 | )
25 |
26 | from peft import (
27 | prepare_model_for_kbit_training,
28 | LoraConfig,
29 | PeftModel,
30 | PeftModelForCausalLM,
31 | )
32 | from peft.tuners.lora import LoraLayer
33 | from falcon_modeling.modelling_RW import RWForCausalLM
34 | from models.llama_with_flash_attn import LlamaForCausalLM
35 |
36 | REGISTERED_BASE_MODELS = {}
37 |
38 |
39 | def find_all_linear_names(
40 | args: Namespace,
41 | model: torch.nn.Module,
42 | ):
43 | cls = (
44 | bnb.nn.Linear4bit
45 | if args.bits == 4
46 | else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
47 | )
48 | lora_module_names = set()
49 | for name, module in model.named_modules():
50 | if isinstance(module, cls):
51 | names = name.split(".")
52 | if "lora" not in names[-1]:
53 | lora_module_names.add(names[0] if len(names) == 1 else names[-1])
54 |
55 | if "lm_head" in lora_module_names: # needed for 16-bit
56 | lora_module_names.remove("lm_head")
57 | return list(lora_module_names)
58 |
59 |
60 | def get_accelerate_model(
61 | args: Namespace,
62 | checkpoint_dir: Optional[str] = None,
63 | adapter_name="lora_default",
64 | is_trainable=True,
65 | reuse_base_model=False,
66 | ):
67 | global REGISTERED_BASE_MODELS
68 |
69 | compute_dtype = (
70 | torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)
71 | )
72 |
73 | if checkpoint_dir is not None:
74 | if exists(join(checkpoint_dir, "adapter_model")):
75 | checkpoint_dir = join(checkpoint_dir, "adapter_model")
76 |
77 | if exists(join(checkpoint_dir, "lora_default")):
78 | checkpoint_dir = join(checkpoint_dir, "lora_default")
79 |
80 | if args.model_name_or_path in REGISTERED_BASE_MODELS and reuse_base_model:
81 | config = {
82 | "load_in_4bit": args.bits == 4,
83 | "load_in_8bit": args.bits == 8,
84 | "llm_int8_threshold": 6.0,
85 | "llm_int8_has_fp16_weight": False,
86 | "bnb_4bit_compute_dtype": compute_dtype,
87 | "bnb_4bit_use_double_quant": args.double_quant,
88 | "bnb_4bit_quant_type": args.quant_type,
89 | }
90 |
91 | registered_model, registered_config = REGISTERED_BASE_MODELS[
92 | args.model_name_or_path
93 | ]
94 | if registered_config == config and not args.full_finetune:
95 | print(f"loading registered model {args.model_name_or_path}...")
96 | model = registered_model
97 |
98 | if checkpoint_dir is not None:
99 | model.load_adapter(
100 | checkpoint_dir,
101 | adapter_name=adapter_name,
102 | is_trainable=is_trainable,
103 | )
104 | else:
105 | modules = args.lora_modules or find_all_linear_names(args, model)
106 | print("adding LoRa modules: ", modules)
107 | config = LoraConfig(
108 | r=args.lora_r,
109 | lora_alpha=args.lora_alpha,
110 | target_modules=modules,
111 | lora_dropout=args.lora_dropout,
112 | bias="none",
113 | task_type="CAUSAL_LM",
114 | )
115 | model.add_adapter(adapter_name, peft_config=config)
116 | return model
117 | else:
118 | raise ValueError(
119 | f"Model {args.model_name_or_path} is already registered with a different config."
120 | f"{registered_config} != {config}"
121 | )
122 |
123 | current_device = torch.cuda.current_device()
124 | if args.full_finetune:
125 | assert args.bits in [16, 32]
126 |
127 | print(f"loading base model {args.model_name_or_path}...")
128 |
129 | CausalLM = AutoModelForCausalLM
130 |
131 | if "falcon" in args.model_name_or_path.lower():
132 | CausalLM = RWForCausalLM
133 | elif (
134 | "llama" in args.model_name_or_path.lower()
135 | or "vicuna" in args.model_name_or_path.lower()
136 | or "dromedary" in args.model_name_or_path.lower()
137 | ) and torch.__version__ >= "2.0.0":
138 | CausalLM = LlamaForCausalLM
139 |
140 | model = CausalLM.from_pretrained(
141 | args.model_name_or_path,
142 | load_in_4bit=args.bits == 4,
143 | load_in_8bit=args.bits == 8,
144 | device_map={"": current_device},
145 | # max_memory=max_memory,
146 | quantization_config=BitsAndBytesConfig(
147 | load_in_4bit=args.bits == 4,
148 | load_in_8bit=args.bits == 8,
149 | llm_int8_threshold=6.0,
150 | llm_int8_has_fp16_weight=False,
151 | bnb_4bit_compute_dtype=compute_dtype,
152 | bnb_4bit_use_double_quant=args.double_quant,
153 | bnb_4bit_quant_type=args.quant_type,
154 | ),
155 | torch_dtype=(
156 | torch.float16
157 | if args.fp16
158 | else (torch.bfloat16 if args.bf16 else torch.float32)
159 | ),
160 | trust_remote_code=args.trust_remote_code,
161 | )
162 | if compute_dtype == torch.float16 and args.bits == 4:
163 | major, minor = torch.cuda.get_device_capability()
164 | if major >= 8:
165 | print("=" * 80)
166 | print(
167 | "Your GPU supports bfloat16, you can accelerate training with the argument --bf16"
168 | )
169 | print("=" * 80)
170 |
171 | setattr(model, "model_parallel", True)
172 | setattr(model, "is_parallelizable", True)
173 |
174 | model.config.torch_dtype = (
175 | torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)
176 | )
177 |
178 | if not args.full_finetune:
179 | model = prepare_model_for_kbit_training(
180 | model, use_gradient_checkpointing=args.gradient_checkpointing
181 | )
182 | if args.gradient_checkpointing:
183 | model.gradient_checkpointing_enable()
184 |
185 | if not args.full_finetune:
186 | if checkpoint_dir is not None:
187 | print("Loading adapters from checkpoint.")
188 |
189 | model = PeftModel.from_pretrained(
190 | model,
191 | checkpoint_dir,
192 | adapter_name=adapter_name,
193 | is_trainable=is_trainable,
194 | )
195 | else:
196 | # print(f'adding LoRA modules...')
197 | modules = args.lora_modules or find_all_linear_names(args, model)
198 | print("adding LoRa modules: ", modules)
199 | config = LoraConfig(
200 | r=args.lora_r,
201 | lora_alpha=args.lora_alpha,
202 | target_modules=modules,
203 | lora_dropout=args.lora_dropout,
204 | bias="none",
205 | task_type="CAUSAL_LM",
206 | )
207 | model = get_peft_model(model, config, adapter_name=adapter_name)
208 |
209 | if args.model_name_or_path not in REGISTERED_BASE_MODELS:
210 | config = {
211 | "load_in_4bit": args.bits == 4,
212 | "load_in_8bit": args.bits == 8,
213 | "llm_int8_threshold": 6.0,
214 | "llm_int8_has_fp16_weight": False,
215 | "bnb_4bit_compute_dtype": compute_dtype,
216 | "bnb_4bit_use_double_quant": args.double_quant,
217 | "bnb_4bit_quant_type": args.quant_type,
218 | }
219 | REGISTERED_BASE_MODELS[args.model_name_or_path] = (model, config)
220 |
221 | for name, module in model.named_modules():
222 | if isinstance(module, LoraLayer):
223 | if args.bf16:
224 | module = module.to(torch.bfloat16)
225 | else:
226 | module = module.to(torch.float32)
227 | if "lm_head" in name or "embed_tokens" in name:
228 | if hasattr(module, "weight"):
229 | if args.bf16 and module.weight.dtype == torch.float32:
230 | module = module.to(torch.bfloat16)
231 | # if not args.bf16:
232 | # module = module.to(torch.float32)
233 | return model
234 |
235 |
236 | def load_4bit_model_for_inference(
237 | checkpoint_dir: str,
238 | bits: int = 4,
239 | fp16: bool = False,
240 | bf16: bool = False,
241 | double_quant: bool = True,
242 | quant_type: str = "nf4",
243 | gradient_checkpointing: bool = False,
244 | adapter_name="lora_default",
245 | is_trainable=True,
246 | reuse_base_model=False,
247 | trust_remote_code=False,
248 | base_model_mapping=None,
249 | fully_initialize=False,
250 | ):
251 | if checkpoint_dir is not None:
252 | if exists(join(checkpoint_dir, "adapter_model")):
253 | checkpoint_dir = join(checkpoint_dir, "adapter_model")
254 |
255 | if exists(join(checkpoint_dir, "lora_default")):
256 | checkpoint_dir = join(checkpoint_dir, "lora_default")
257 |
258 | config = LoraConfig.from_pretrained(checkpoint_dir)
259 | base_model_name_or_path = config.base_model_name_or_path
260 |
261 | if base_model_mapping is not None:
262 | dict_base_model_mapping = eval(base_model_mapping)
263 | if (
264 | dict_base_model_mapping is not None
265 | and base_model_name_or_path in dict_base_model_mapping
266 | ):
267 | base_model_name_or_path = dict_base_model_mapping[base_model_name_or_path]
268 |
269 | args = Namespace(
270 | model_name_or_path=base_model_name_or_path,
271 | bits=bits,
272 | fp16=fp16,
273 | bf16=bf16,
274 | double_quant=double_quant,
275 | quant_type=quant_type,
276 | gradient_checkpointing=gradient_checkpointing,
277 | trust_remote_code=trust_remote_code,
278 | full_finetune=False,
279 | lora_r=64 if fully_initialize else None,
280 | lora_alpha=16 if fully_initialize else None,
281 | lora_dropout=0.0 if fully_initialize else None,
282 | lora_modules=None,
283 | )
284 |
285 | if fully_initialize:
286 | print("Fully initializing qlora model.")
287 |
288 | model = get_accelerate_model(
289 | args,
290 | checkpoint_dir=None if fully_initialize else checkpoint_dir,
291 | adapter_name=adapter_name,
292 | is_trainable=is_trainable,
293 | reuse_base_model=reuse_base_model,
294 | )
295 | return model
296 |
297 |
298 | def get_peft_model(model, peft_config, adapter_name="default"):
299 | """
300 | Returns a Peft model object from a model and a config.
301 |
302 | Args:
303 | model ([`transformers.PreTrainedModel`]): Model to be wrapped.
304 | peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
305 | """
306 | peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
307 | return PeftModelForCausalLM(model, peft_config, adapter_name=adapter_name)
308 |
--------------------------------------------------------------------------------
/training/models/reward_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The Self-Align Team
2 | # Copyright 2023 The Alpaca Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from argparse import Namespace
17 | import os
18 | from typing import Optional, Dict, Sequence, Union
19 |
20 | import einops
21 | import torch
22 | from torch import Tensor, nn
23 | import torch.nn.functional as F
24 |
25 | import transformers
26 | from transformers.trainer_utils import EvalPrediction
27 | from transformers.utils.generic import ModelOutput
28 |
29 | from peft import PeftModel, LoraModel, LoraConfig
30 |
31 | from models.qlora_model import get_accelerate_model
32 | from models.llama_with_flash_attn import LlamaModel as FlashAttnLlamaModel
33 |
34 |
35 | def unpack_dict(
36 | d: Dict, keys: Sequence[str], return_type: type = tuple
37 | ) -> Union[Sequence, Dict]:
38 | if return_type in (tuple, list):
39 | return return_type(d[key] for key in keys)
40 | elif return_type == dict:
41 | return {key: d[key] for key in keys}
42 | else:
43 | raise ValueError(f"Unknown return_type: {return_type}")
44 |
45 |
46 | def batch_select(input: Tensor, index: Tensor):
47 | """Select elements from a batched tensor with a batched index tensor.
48 |
49 | Example:
50 | input = torch.tensor([
51 | [0, 1, 2],
52 | [3, 0, 9],
53 | [6, 7, 8],
54 | ])
55 | index = torch.tensor([[0, 1], [1, 0], [0, 0]])
56 | batch_select(input, index) = tensor([
57 | [0, 1],
58 | [0, 3],
59 | [6, 6]
60 | ])
61 | """
62 | dummy_index = torch.arange(input.size(0), device=input.device).unsqueeze(-1)
63 | return input[dummy_index, index]
64 |
65 |
66 | def make_generative_lm(
67 | args: Namespace,
68 | model_name_or_path: str,
69 | qlora: bool = False,
70 | checkpoint_dir: Optional[str] = None,
71 | adapter_name="lora_default",
72 | is_trainable=True,
73 | reuse_base_model=False,
74 | **kwargs,
75 | ):
76 | model_cls = transformers.LlamaForCausalLM
77 |
78 | if qlora:
79 | if checkpoint_dir is None or checkpoint_dir in ["scratch", "none"]:
80 | return get_accelerate_model(args, None)
81 | else:
82 | return get_accelerate_model(
83 | args,
84 | checkpoint_dir=checkpoint_dir,
85 | adapter_name=adapter_name,
86 | is_trainable=is_trainable,
87 | reuse_base_model=reuse_base_model,
88 | )
89 | return model_cls.from_pretrained(model_name_or_path, **kwargs)
90 |
91 |
92 | def get_transformer_hidden_size(model: transformers.PreTrainedModel):
93 | if isinstance(model, PeftModel):
94 | return get_transformer_hidden_size(model.base_model)
95 |
96 | if isinstance(model, LoraModel):
97 | return get_transformer_hidden_size(model.model)
98 |
99 | if isinstance(model, transformers.GPT2LMHeadModel):
100 | hidden_size_attr_name = "n_embd"
101 | elif isinstance(model, transformers.OPTForCausalLM):
102 | hidden_size_attr_name = "word_embed_proj_dim"
103 | elif isinstance(model, transformers.T5ForConditionalGeneration):
104 | hidden_size_attr_name = "d_model"
105 | elif isinstance(model, transformers.GPTBigCodeForCausalLM):
106 | hidden_size_attr_name = "n_embd"
107 | elif "modelling_RW.RWModel" in str(
108 | type(model)
109 | ) or "modelling_RW.RWForCausalLM" in str(type(model)):
110 | # TODO(zhiqings): Hack to add support for Falcon.
111 | hidden_size_attr_name = "hidden_size"
112 | else:
113 | # Hack to deal with the fact that transformers library changed the LLaMA model name.
114 | llama_cls = getattr(
115 | transformers,
116 | "LLaMAForCausalLM"
117 | if hasattr(transformers, "LLaMAForCausalLM")
118 | else "LlamaForCausalLM",
119 | )
120 | if isinstance(model, llama_cls) or "LlamaForCausalLM" in str(type(model)):
121 | hidden_size_attr_name = "hidden_size"
122 | else:
123 | raise ValueError(f"Unknown base_model type: {type(model)}")
124 | from typing import Any, Mapping
125 | return getattr(model.config, hidden_size_attr_name)
126 |
127 |
128 | class RewardConfig(transformers.PretrainedConfig):
129 | model_type = "reward_model"
130 |
131 | # Huggingface doesn't allow non-kwargs for `__init__`.
132 | def __init__(self, backbone_model_name_or_path=None, **kwargs):
133 | super(RewardConfig, self).__init__(**kwargs)
134 | self.backbone_model_name_or_path = backbone_model_name_or_path
135 |
136 |
137 | class RewardModelOutput(ModelOutput):
138 | rewards: Tensor = None
139 |
140 |
141 | class RewardModel(transformers.PreTrainedModel):
142 | config_class = RewardConfig
143 | supports_gradient_checkpointing = True
144 |
145 | def __init__(
146 | self,
147 | args: Namespace,
148 | config: RewardConfig,
149 | checkpoint_dir: Optional[str] = None,
150 | adapter_name="lora_default",
151 | **kwargs,
152 | ):
153 | super(RewardModel, self).__init__(config)
154 | self.adapter_name = adapter_name
155 | self.backbone_model = make_generative_lm(
156 | args,
157 | config.backbone_model_name_or_path,
158 | checkpoint_dir=checkpoint_dir,
159 | adapter_name=adapter_name,
160 | **kwargs,
161 | )
162 | hidden_size = get_transformer_hidden_size(self.backbone_model)
163 | reward_head = nn.Linear(hidden_size, 1)
164 | torch.nn.init.zeros_(reward_head.bias)
165 | device = next(self.backbone_model.parameters()).device
166 | self.reward_head = reward_head.to(device)
167 |
168 | if checkpoint_dir is not None:
169 | reward_head_path = os.path.join(checkpoint_dir, "reward_head")
170 | if os.path.exists(reward_head_path):
171 | self.reward_head.load_state_dict(
172 | torch.load(
173 | reward_head_path,
174 | map_location="cpu",
175 | )
176 | )
177 | else:
178 | print(f"Warning: reward head not found at {reward_head_path}")
179 |
180 | self.reward_head.requires_grad_(kwargs.get("is_trainable", True))
181 |
182 | def forward(self, input_ids, attention_mask=None, return_dict=True, **kwargs):
183 | # We only compute the rewards and don't compute the logistic regression loss in this function so that it's
184 | # easier to use for later stages of reranking / RL training.
185 | self.backbone_model.set_adapter(self.adapter_name)
186 | self.backbone_model.config.use_cache = False
187 |
188 | outputs = self.backbone_model(
189 | input_ids=input_ids,
190 | attention_mask=attention_mask,
191 | return_dict=True,
192 | output_hidden_states=True,
193 | **kwargs,
194 | )
195 | last_hidden_state = outputs.hidden_states[-1]
196 | assert isinstance(last_hidden_state, torch.Tensor), f"{outputs}"
197 | # last_hidden_state = outputs.last_hidden_state
198 | # TODO(zhiqings): Hacking to make sure every parameter is used in the backward pass.
199 | logits = outputs.logits
200 | last_hidden_state = last_hidden_state + 0.0 * torch.mean(logits)
201 |
202 | last_hidden_state_at_the_end = last_hidden_state[:, -1, :]
203 | # TODO(lxuechen): Make returning rewards at all positions and last_hidden_state an option.
204 | final_dtype_tensor = next(self.reward_head.parameters())
205 | last_hidden_state_at_the_end = last_hidden_state_at_the_end.type_as(
206 | final_dtype_tensor
207 | )
208 |
209 | if final_dtype_tensor.dtype == torch.float16:
210 | last_hidden_state_at_the_end = last_hidden_state_at_the_end.to(
211 | torch.float32
212 | )
213 |
214 | rewards = self.reward_head(last_hidden_state_at_the_end).squeeze(-1)
215 | return RewardModelOutput(rewards=rewards) if return_dict else (rewards,)
216 |
217 | def _set_gradient_checkpointing(self, module, value=False):
218 | if isinstance(module, transformers.LlamaModel):
219 | module.gradient_checkpointing = value
220 | elif isinstance(module, FlashAttnLlamaModel):
221 | module.gradient_checkpointing = value
222 | if isinstance(module, transformers.GPTBigCodeModel):
223 | module.gradient_checkpointing = value
224 | # TODO(zhiqings): Hack to add support for Falcon.
225 | if "RWModel" in str(type(module)):
226 | module.gradient_checkpointing = value
227 |
228 |
229 | class RewardModelTrainer(transformers.Trainer):
230 | def compute_loss(self, model, inputs, return_outputs=False):
231 | # input_ids, attention_mask each of size (bsz, num_candidates, seq_len).
232 | # index_0, index_1 each of size (bsz, num_pairs); indexes into input_ids.
233 | # choice of size (bsz, num_pairs); 1 if index_1's seq is chosen, 0 otherwise.
234 | input_ids, attention_mask, index_0, index_1, choice = unpack_dict(
235 | inputs, keys=("input_ids", "attention_mask", "index_0", "index_1", "choice")
236 | )
237 | num_candidates, num_pairs = input_ids.size(1), choice.size(1)
238 | input_ids_flat, attention_mask_flat = tuple(
239 | einops.rearrange(x, "b c l -> (b c) l") for x in (input_ids, attention_mask)
240 | )
241 | outputs = model(input_ids=input_ids_flat, attention_mask=attention_mask_flat)
242 | rewards_flat = outputs.rewards
243 | rewards = einops.rearrange(
244 | rewards_flat, "(b c) -> b c", c=num_candidates
245 | ) # Size: (bsz, num_candidates).
246 |
247 | rewards_0, rewards_1 = tuple(
248 | batch_select(rewards, index) for index in (index_0, index_1)
249 | ) # Size: (bsz, num_pairs).
250 | logits = rewards_1 - rewards_0 # Size: (bsz, num_pairs).
251 | # Type casting of `choice` is due to amp.autocast context manager.
252 | loss = F.binary_cross_entropy_with_logits(
253 | logits, choice.to(logits.dtype), reduction="mean"
254 | )
255 | return (loss, dict(logits=logits)) if return_outputs else loss
256 |
257 |
258 | def compute_reward_modeling_metrics(eval_prediction: EvalPrediction) -> Dict:
259 | # eval_prediction.label_ids is a tuple that matches up with `training_args.label_names`.
260 | logits = torch.tensor(eval_prediction.predictions).squeeze(-1)
261 | labels = torch.tensor(eval_prediction.label_ids[-1]).squeeze(-1)
262 | predictions = (logits >= 0.0).long()
263 | accuracy = predictions.eq(labels).float().mean().item()
264 | label_positive_rate = (labels == 1).float().mean().item()
265 | return dict(
266 | accuracy=accuracy,
267 | label_positive_rate=label_positive_rate,
268 | )
269 |
270 |
271 | def load_4bit_reward_model_for_inference(
272 | checkpoint_dir: str,
273 | bits: int = 4,
274 | fp16: bool = False,
275 | bf16: bool = False,
276 | double_quant: bool = True,
277 | quant_type: str = "nf4",
278 | gradient_checkpointing: bool = False,
279 | adapter_name="lora_default",
280 | is_trainable=True,
281 | reuse_base_model=False,
282 | trust_remote_code=False,
283 | base_model_mapping=None,
284 | ):
285 | # Load the model.
286 | lora_checkpoint_dir = checkpoint_dir
287 | if os.path.exists(os.path.join(lora_checkpoint_dir, "adapter_model")):
288 | lora_checkpoint_dir = os.path.join(lora_checkpoint_dir, "adapter_model")
289 | if os.path.exists(os.path.join(lora_checkpoint_dir, "lora_default")):
290 | lora_checkpoint_dir = os.path.join(lora_checkpoint_dir, "lora_default")
291 |
292 | lora_config = LoraConfig.from_pretrained(lora_checkpoint_dir)
293 | base_model_name_or_path = lora_config.base_model_name_or_path
294 |
295 | if base_model_mapping is not None:
296 | dict_base_model_mapping = eval(base_model_mapping)
297 | if (
298 | dict_base_model_mapping is not None
299 | and base_model_name_or_path in dict_base_model_mapping
300 | ):
301 | base_model_name_or_path = dict_base_model_mapping[base_model_name_or_path]
302 |
303 | config = RewardConfig(backbone_model_name_or_path=base_model_name_or_path)
304 |
305 | args = Namespace(
306 | model_name_or_path=config.backbone_model_name_or_path,
307 | bits=bits,
308 | fp16=fp16,
309 | bf16=bf16,
310 | double_quant=double_quant,
311 | quant_type=quant_type,
312 | trust_remote_code=trust_remote_code,
313 | full_finetune=False,
314 | gradient_checkpointing=gradient_checkpointing,
315 | )
316 |
317 | model = RewardModel(
318 | args,
319 | config,
320 | checkpoint_dir=checkpoint_dir,
321 | qlora=bits == 4 or bits == 8,
322 | adapter_name=adapter_name,
323 | is_trainable=is_trainable,
324 | reuse_base_model=reuse_base_model,
325 | )
326 | return model
327 |
--------------------------------------------------------------------------------
/training/models/rl_models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The Self-Align Team
2 | # Copyright 2023 The Alpaca Team
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Model classes that are shared across different algorithms.
17 |
18 | WARNING:
19 | Do not tamper with the state_dict function for any of these classes.
20 | If you tamper, make sure the keys are the same, otherwise FSDP will get confused.
21 | """
22 |
23 | import abc
24 | import logging
25 | from typing import Dict, Optional
26 |
27 | import torch
28 | import transformers
29 | from torch import Tensor, nn
30 |
31 | from data_utils.common_utils import right_pad, compute_logprobs
32 | from models.reward_model import get_transformer_hidden_size
33 |
34 |
35 | logger = logging.getLogger(__name__)
36 |
37 |
38 | class Policy(nn.Module, abc.ABC):
39 | def __init__(
40 | self,
41 | args,
42 | base_model: transformers.PreTrainedModel,
43 | base_tokenizer: transformers.PreTrainedTokenizer,
44 | adapter_name: Optional[str] = None,
45 | ):
46 | super().__init__()
47 | self.args = args
48 | self.base_model = base_model
49 | self.base_tokenizer = base_tokenizer
50 | self.adapter_name = adapter_name
51 |
52 | @abc.abstractmethod
53 | def forward(
54 | self,
55 | queries: Tensor,
56 | query_attn_masks: Tensor,
57 | responses: Tensor,
58 | temperature: Optional[float] = None,
59 | ) -> Dict[str, Tensor]:
60 | raise NotImplementedError
61 |
62 | def respond(
63 | self,
64 | queries: Tensor,
65 | query_attn_masks: Tensor,
66 | temperature: Optional[float] = None,
67 | num_return_sequences=1,
68 | ) -> Dict[str, Tensor]:
69 | assert not self.training, "Policy must be in eval model for generation."
70 | return self._post_respond(
71 | self._respond(queries, query_attn_masks, temperature, num_return_sequences)
72 | )
73 |
74 | @abc.abstractmethod
75 | def _respond(
76 | self,
77 | queries: Tensor,
78 | query_attn_masks: Tensor,
79 | temperature: Optional[float] = None,
80 | num_return_sequences=1,
81 | ) -> Dict[str, Tensor]:
82 | raise NotImplementedError
83 |
84 | def _post_respond(self, respond_outputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
85 | return respond_outputs
86 |
87 |
88 | class AutoregressivePolicy(Policy):
89 | def forward(
90 | self,
91 | queries: Tensor,
92 | query_attn_masks: Tensor,
93 | responses: Tensor,
94 | temperature: Optional[float] = None,
95 | ) -> Dict[str, Tensor]:
96 | # TODO(lxuechen): Refactor attention mask. Here query_attn_masks overrides padding-based attention mask.
97 |
98 | if self.adapter_name is not None:
99 | self.base_model.set_adapter(self.adapter_name)
100 | self.base_model.config.use_cache = False
101 |
102 | if temperature is None:
103 | temperature = self.args.temperature
104 | input_ids = torch.cat([queries, responses], dim=1)
105 | attention_mask = input_ids.ne(self.base_tokenizer.pad_token_id)
106 | attention_mask[:, : queries.size(1)] = query_attn_masks
107 | # Fix position id issues and ensure consistency with `respond` for GPT and OPT.
108 | inputs = self.base_model.prepare_inputs_for_generation(
109 | input_ids=input_ids,
110 | attention_mask=attention_mask,
111 | use_cache=False,
112 | )
113 | outputs = self.base_model(**inputs, output_hidden_states=True)
114 | original_logits = outputs.logits[:, -self.args.response_len - 1 : -1]
115 | logits = original_logits / temperature
116 | labels = input_ids[:, -self.args.response_len :]
117 | logprobs = compute_logprobs(
118 | logits, labels, ignore_index=self.base_tokenizer.pad_token_id
119 | )
120 | entropies = -(logits.softmax(dim=-1) * logits.log_softmax(dim=-1)).sum(dim=-1)
121 | last_hidden_state = outputs.hidden_states[-1][
122 | :, -self.args.response_len - 1 : -1
123 | ]
124 | return dict(
125 | original_logits=original_logits,
126 | logits=logits,
127 | logprobs=logprobs,
128 | entropies=entropies,
129 | last_hidden_state=last_hidden_state,
130 | )
131 |
132 | def _respond(
133 | self,
134 | queries: Tensor,
135 | query_attn_masks: Tensor,
136 | temperature: Optional[float] = None,
137 | num_return_sequences=1,
138 | ) -> Dict[str, Tensor]:
139 | if self.adapter_name is not None:
140 | self.base_model.set_adapter(self.adapter_name)
141 | self.base_model.config.use_cache = True
142 | self.base_model.config.cache_shape = (
143 | queries.shape[-1] + self.args.response_len,
144 | )
145 |
146 | if temperature is None:
147 | temperature = self.args.temperature
148 | sequences = self.base_model.generate(
149 | inputs=queries,
150 | attention_mask=query_attn_masks,
151 | do_sample=True,
152 | max_new_tokens=self.args.response_len,
153 | pad_token_id=self.base_tokenizer.pad_token_id,
154 | eos_token_id=self.base_tokenizer.eos_token_id,
155 | top_p=1.0,
156 | top_k=0,
157 | temperature=temperature,
158 | num_return_sequences=num_return_sequences,
159 | synced_gpus=True,
160 | )
161 | responses = right_pad(
162 | sequences[:, queries.size(1) :],
163 | target_size=(sequences.size(0), self.args.response_len),
164 | value=self.base_tokenizer.pad_token_id,
165 | )
166 | return dict(
167 | responses=responses
168 | ) # Size (bsz * num_return_sequences, response_len).
169 |
170 |
171 | class Value(nn.Module, abc.ABC):
172 | def __init__(
173 | self,
174 | args,
175 | base_model: transformers.PreTrainedModel,
176 | base_tokenizer: transformers.PreTrainedTokenizer,
177 | adapter_name: Optional[str] = None,
178 | ):
179 | super().__init__()
180 | self.args = args
181 | self.base_model = base_model
182 | self.base_tokenizer = base_tokenizer
183 | hidden_size = get_transformer_hidden_size(base_model)
184 | value_head = torch.nn.Linear(hidden_size, 1)
185 | value_head.weight.data.zero_()
186 | value_head.bias.data.zero_()
187 | self.value_head = value_head.to(next(base_model.parameters()).device)
188 | self.adapter_name = adapter_name
189 |
190 | @abc.abstractmethod
191 | def forward(
192 | self, queries: Tensor, query_attn_masks: Tensor, responses: Tensor
193 | ) -> Dict[str, Tensor]:
194 | raise NotImplementedError
195 |
196 |
197 | class AutoregressiveValue(Value):
198 | def forward(
199 | self, queries: Tensor, query_attn_masks: Tensor, responses: Tensor
200 | ) -> Dict[str, Tensor]:
201 | if self.adapter_name is not None:
202 | self.base_model.set_adapter(self.adapter_name)
203 | self.base_model.config.use_cache = False
204 |
205 | sequences = torch.cat([queries, responses], dim=1)
206 | sequence_attn_masks = sequences.ne(self.base_tokenizer.pad_token_id)
207 |
208 | inputs = self.base_model.prepare_inputs_for_generation(
209 | input_ids=sequences,
210 | attention_mask=sequence_attn_masks,
211 | use_cache=False,
212 | )
213 | outputs = self.base_model(
214 | **inputs,
215 | return_dict=True,
216 | output_hidden_states=True,
217 | )
218 | # value[t]: \hat{V}(sequences_{:t-1}); must align with `_estimate_advantage`.
219 |
220 | last_hidden_state = outputs.hidden_states[-1]
221 | assert isinstance(last_hidden_state, torch.Tensor), f"{outputs}"
222 | logits = outputs.logits
223 | # TODO(zhiqings): Hacking to make sure every parameter is used in the backward pass.
224 | last_hidden_state = last_hidden_state + 0.0 * torch.mean(logits)
225 | last_hidden_state = last_hidden_state[:, queries.size(1) - 1 : -1]
226 |
227 | # TODO(zhiqings): now we just manully convert output types
228 | last_hidden_state = last_hidden_state.type_as(
229 | next(self.value_head.parameters())
230 | )
231 | values = self.value_head(last_hidden_state).squeeze(-1)
232 | return dict(values=values)
233 |
234 |
235 | class ActorCritic(nn.Module):
236 | def __init__(self, policy: Policy, value_model: Value):
237 | super(ActorCritic, self).__init__()
238 | self.policy = policy
239 | self.value_model = value_model
240 |
241 | def forward(
242 | self,
243 | queries: Tensor,
244 | query_attn_masks: Tensor,
245 | responses: Tensor,
246 | temperature: Optional[float] = None,
247 | mode: Optional[str] = None,
248 | ) -> Dict[str, Tensor]:
249 | # Assume the policy and value model share the same tokenizer.
250 |
251 | if mode is None:
252 | o1 = self.policy(queries, query_attn_masks, responses, temperature)
253 | o2 = self.value_model(queries, query_attn_masks, responses)
254 |
255 | elif mode == "policy":
256 | o1 = self.policy(queries, query_attn_masks, responses, temperature)
257 | # Add dummy loss to make sure every parameter is used in the backward pass.
258 | o2 = {
259 | "dummy_loss": 0.0
260 | * torch.sum(
261 | torch.stack(
262 | [
263 | torch.mean(value)
264 | for key, value in self.named_parameters()
265 | if "lora_value" in key
266 | ]
267 | )
268 | )
269 | }
270 | elif mode == "value":
271 | o2 = self.value_model(queries, query_attn_masks, responses)
272 | # Add dummy loss to make sure every parameter is used in the backward pass.
273 | o1 = {
274 | "dummy_loss": 0.0
275 | * torch.sum(
276 | torch.stack(
277 | [
278 | torch.mean(value)
279 | for key, value in self.named_parameters()
280 | if "lora_policy" in key
281 | ]
282 | )
283 | )
284 | }
285 | else:
286 | raise ValueError(f"Unknown mode: {mode}")
287 |
288 | return {**o1, **o2}
289 |
290 | def respond(
291 | self,
292 | queries: Tensor,
293 | query_attn_masks: Tensor,
294 | temperature: Optional[float] = None,
295 | ) -> Dict[str, Tensor]:
296 | return self.policy.respond(
297 | queries=queries, query_attn_masks=query_attn_masks, temperature=temperature
298 | )
299 |
300 |
301 | def make_policy_with_base_model(
302 | args,
303 | base_model: transformers.PreTrainedModel,
304 | base_tokenizer: transformers.PreTrainedTokenizer,
305 | adapter_name: Optional[str] = "default",
306 | ) -> Policy:
307 | if base_model.config.is_encoder_decoder:
308 | raise NotImplementedError
309 | else:
310 | return AutoregressivePolicy(
311 | args, base_model, base_tokenizer, adapter_name=adapter_name
312 | )
313 |
314 |
315 | def make_value_with_base_model(
316 | args,
317 | base_model: transformers.PreTrainedModel,
318 | base_tokenizer: transformers.PreTrainedTokenizer,
319 | adapter_name: Optional[str] = "default",
320 | ) -> Value:
321 | if base_model.config.is_encoder_decoder:
322 | raise NotImplementedError
323 | else:
324 | return AutoregressiveValue(
325 | args, base_model, base_tokenizer, adapter_name=adapter_name
326 | )
327 |
--------------------------------------------------------------------------------
/training/models/trainer_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The Self-Align Team
2 | # Copyright 2023 The Alpaca Team
3 | # Copyright 2022 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from typing import Optional
18 |
19 | from torch import nn, optim
20 | from transformers import Trainer
21 | from transformers.optimization import get_scheduler
22 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
23 | from transformers.trainer_pt_utils import get_parameter_names
24 |
25 |
26 | def create_optimizer(
27 | args, model: nn.Module, optimizer: Optional[optim.Optimizer] = None
28 | ):
29 | """Create optimizer for trainer.
30 |
31 | This is detached version of the `Trainer.create_optimizer` method.
32 | We don't support sagemaker and fairscale for simplicity.
33 |
34 | Reference:
35 | https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py
36 | """
37 | opt_model = model
38 |
39 | if optimizer is None:
40 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
41 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
42 | optimizer_grouped_parameters = [
43 | {
44 | "params": [
45 | p
46 | for n, p in opt_model.named_parameters()
47 | if (n in decay_parameters and p.requires_grad)
48 | ],
49 | "weight_decay": args.weight_decay,
50 | },
51 | {
52 | "params": [
53 | p
54 | for n, p in opt_model.named_parameters()
55 | if (n not in decay_parameters and p.requires_grad)
56 | ],
57 | "weight_decay": 0.0,
58 | },
59 | ]
60 |
61 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
62 |
63 | optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
64 | return optimizer
65 |
66 |
67 | def create_scheduler(args, optimizer, lr_scheduler, num_training_steps):
68 | """Create scheduler for trainer.
69 |
70 | This is detached version of the `Trainer.create_scheduler` method.
71 |
72 | Reference:
73 | https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py
74 | """
75 | if lr_scheduler is None:
76 | lr_scheduler = get_scheduler(
77 | args.lr_scheduler_type,
78 | optimizer=optimizer,
79 | num_warmup_steps=args.get_warmup_steps(num_training_steps),
80 | num_training_steps=num_training_steps,
81 | )
82 | return lr_scheduler
83 |
--------------------------------------------------------------------------------
/training/qlora_utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | from os.path import exists, join, isdir
4 | import shutil
5 | import sys
6 | from typing import Optional, Dict, Sequence, List
7 |
8 | import torch
9 | import transformers
10 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
11 |
12 | from models.reward_model import RewardModel
13 |
14 | DEFAULT_PAD_TOKEN = "[PAD]"
15 |
16 |
17 | class SavePeftModelCallback(transformers.TrainerCallback):
18 | def save_model(self, args, state, kwargs):
19 | print("Saving PEFT checkpoint...")
20 |
21 | global_rank = int(os.environ.get("RANK", 0))
22 |
23 | if global_rank == 0:
24 | print("Saving model checkpoint to %s" % args.output_dir)
25 | if state.best_model_checkpoint is not None:
26 | checkpoint_folder = state.best_model_checkpoint
27 | else:
28 | checkpoint_folder = os.path.join(
29 | args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
30 | )
31 |
32 | peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
33 | reward_head_path = os.path.join(checkpoint_folder, "reward_head")
34 |
35 | if isinstance(kwargs["model"], RewardModel):
36 | kwargs["model"].backbone_model.save_pretrained(peft_model_path)
37 | torch.save(
38 | kwargs["model"].reward_head.state_dict(),
39 | reward_head_path,
40 | )
41 | else:
42 | kwargs["model"].save_pretrained(peft_model_path)
43 |
44 | pytorch_model_paths = glob.glob(
45 | os.path.join(checkpoint_folder, "pytorch_model*.bin")
46 | )
47 | for pytorch_model_path in pytorch_model_paths:
48 | if os.path.exists(pytorch_model_path):
49 | os.remove(pytorch_model_path)
50 |
51 | optimizer_path = os.path.join(checkpoint_folder, "optimizer.pt")
52 | if os.path.exists(optimizer_path):
53 | os.remove(optimizer_path)
54 |
55 | else:
56 | print("Skipping PEFT checkpoint save on rank %d" % global_rank)
57 |
58 | def on_save(self, args, state, control, **kwargs):
59 | self.save_model(args, state, kwargs)
60 | return control
61 |
62 | def on_train_end(self, args, state, control, **kwargs):
63 | def touch(fname, times=None):
64 | global_rank = int(os.environ.get("RANK", 0))
65 | if global_rank == 0:
66 | with open(fname, "a"):
67 | os.utime(fname, times)
68 |
69 | touch(join(args.output_dir, "completed"))
70 | self.save_model(args, state, kwargs)
71 |
72 |
73 | def print_trainable_parameters(args, model):
74 | """
75 | Prints the number of trainable parameters in the model.
76 | """
77 | trainable_params = 0
78 | all_param = 0
79 | for _, param in model.named_parameters():
80 | all_param += param.numel()
81 | if param.requires_grad:
82 | trainable_params += param.numel()
83 | if args.bits == 4:
84 | trainable_params /= 2
85 | print(
86 | f"trainable params: {trainable_params} || "
87 | f"all params: {all_param} || "
88 | f"trainable: {100 * trainable_params / all_param}"
89 | )
90 |
91 |
92 | def get_last_checkpoint(checkpoint_dir):
93 | if isdir(checkpoint_dir):
94 | is_completed = exists(join(checkpoint_dir, "completed"))
95 | if is_completed:
96 | return None, True # already finished
97 | max_step = 0
98 | for filename in os.listdir(checkpoint_dir):
99 | if isdir(join(checkpoint_dir, filename)) and filename.startswith(
100 | "checkpoint"
101 | ):
102 | max_step = max(max_step, int(filename.replace("checkpoint-", "")))
103 | if max_step == 0:
104 | return None, is_completed # training started, but no checkpoint
105 | checkpoint_dir = join(checkpoint_dir, f"checkpoint-{max_step}")
106 | print(f"Found a previous checkpoint at: {checkpoint_dir}")
107 | return checkpoint_dir, is_completed # checkpoint found!
108 | return None, False # first training
109 |
--------------------------------------------------------------------------------
/training/step1_synthetic_preference_collection/batch_generation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from typing import Optional
5 | import math
6 | import os
7 | import fire
8 | import time
9 | import tqdm
10 | import json
11 |
12 | from pathlib import Path
13 |
14 | import torch
15 |
16 | from llama_dromedary import Llama
17 |
18 |
19 | def main(
20 | ckpt_dir: str,
21 | tokenizer_path: str,
22 | temperature: float = 1.0,
23 | top_p: float = 1.0,
24 | max_seq_len: int = 512,
25 | max_batch_size: int = 32,
26 | max_shared_seq_len: int = 512,
27 | generate_max_len: int = 128,
28 | group_rank: int = -1,
29 | group_size: int = -1,
30 | input_file: str = None,
31 | output_file: str = None,
32 | meta_prompt_file: str = None,
33 | prompt_style: str = "dromedary",
34 | seed: Optional[int] = None,
35 | ):
36 | assert group_rank >= 0, "Must specify group rank"
37 | assert group_size >= 0, "Must specify group size"
38 | assert (
39 | input_file is not None and output_file is not None
40 | ), "Must specify input and output files"
41 | assert meta_prompt_file is not None, "Must specify meta prompt file"
42 |
43 | with open(input_file, "r") as f:
44 | inputs = json.load(f)
45 | inputs = inputs[group_rank::group_size]
46 |
47 | if prompt_style == "dromedary":
48 | generate_prompt_fn = generate_dromedary_prompt
49 | else:
50 | raise ValueError(f"Unknown prompt style: {prompt_style}")
51 |
52 | with open(meta_prompt_file, "r") as f:
53 | meta_prompt = f.read().strip()
54 |
55 | generator = Llama.build(
56 | ckpt_dir=ckpt_dir,
57 | tokenizer_path=tokenizer_path,
58 | max_seq_len=max_seq_len,
59 | max_batch_size=max_batch_size,
60 | max_shared_seq_len=max_shared_seq_len,
61 | )
62 |
63 | if seed is not None:
64 | torch.manual_seed(seed)
65 |
66 | results = []
67 | # record current progress
68 |
69 | if "shards" not in output_file and group_size > 1:
70 | output_file = output_file.replace(
71 | ".json", f"_{group_size}shards_{group_rank}.json"
72 | )
73 |
74 | if Path(output_file).exists():
75 | with open(output_file, "r") as f:
76 | results = f.readlines()
77 | results = [line for line in results if len(line.strip()) > 0]
78 |
79 | inputs = inputs[len(results) :]
80 | print("Skipping %d examples" % len(results))
81 |
82 | global_rank = int(os.environ.get("RANK", "0"))
83 | batching_inputs = tqdm.tqdm(
84 | BatchIterator(inputs, max_batch_size),
85 | desc="Batched inference",
86 | disable=global_rank > 0,
87 | )
88 | total_iters = len(inputs) // max_batch_size
89 |
90 | output_handler = None
91 | if global_rank == 0:
92 | output_handler = open(output_file, "a")
93 |
94 | # prepare inputs with batch size $max_batch_size
95 | for iter, batched_inputs in enumerate(batching_inputs):
96 | t0 = time.time()
97 | prompts = [
98 | generate_prompt_fn(
99 | meta_prompt, ex_input["instruction"].strip(), ex_input["input"].strip()
100 | )
101 | for ex_input in batched_inputs
102 | ]
103 |
104 | if prompt_style == "llama_2_chat":
105 | outputs = generator.chat_completion(
106 | prompts, # type: ignore
107 | max_gen_len=generate_max_len,
108 | temperature=temperature,
109 | top_p=top_p,
110 | )
111 | else:
112 | outputs = generator.text_completion(
113 | prompts,
114 | max_gen_len=generate_max_len,
115 | temperature=temperature,
116 | top_p=top_p,
117 | )
118 |
119 | t1 = time.time()
120 |
121 | results = []
122 | for ex_input, output in zip(batched_inputs, outputs):
123 | results.append(
124 | {
125 | "instruction": ex_input["instruction"],
126 | "input": ex_input["input"],
127 | "output": output,
128 | }
129 | )
130 |
131 | if group_rank == 0:
132 | for ex_input, output, _ in zip(batched_inputs, outputs, range(8)):
133 | print("=" * 20, "iter: ", iter, "/", total_iters, "latency: ", t1 - t0)
134 | print(f"Input: {ex_input['instruction']}: {ex_input['input']}")
135 | print(f"Output: {output}")
136 | print()
137 |
138 | if output_handler is not None:
139 | for result in results:
140 | output_handler.write(json.dumps(result) + "\n")
141 | output_handler.flush()
142 |
143 | if output_handler is not None:
144 | output_handler.close()
145 |
146 |
147 | class BatchIterator:
148 | def __init__(self, data, batch_size=1):
149 | self.data = data
150 | self.batch_size = batch_size
151 | self.index = 0
152 |
153 | def __iter__(self):
154 | for i in range(0, len(self.data), self.batch_size):
155 | yield self.data[i : i + self.batch_size]
156 |
157 | def __len__(self):
158 | return math.ceil(len(self.data) / self.batch_size)
159 |
160 |
161 | def generate_dromedary_prompt(meta_prompt, instruction, input=None):
162 | if input:
163 | return f"""{meta_prompt}
164 | {instruction}
165 |
166 | {input}
167 |
168 | ### Dromedary"""
169 | else:
170 | return f"""{meta_prompt}
171 | {instruction}
172 |
173 | ### Dromedary"""
174 |
175 |
176 | if __name__ == "__main__":
177 | fire.Fire(main)
178 |
--------------------------------------------------------------------------------
/training/step1_synthetic_preference_collection/clean_oasst1_prompts.py:
--------------------------------------------------------------------------------
1 | import random
2 | import tqdm
3 | import re
4 | import json
5 | from datasets import load_dataset
6 | from transformers import AutoTokenizer
7 | import fire
8 |
9 |
10 | def load_oasst_data():
11 | oasst_dataset = load_dataset("OpenAssistant/oasst1")["train"]
12 |
13 | def create_message_trees(dataset):
14 | """Create message trees from dataset."""
15 | # Organize data into dictionary based on parent_id
16 | organized_data = {}
17 | for message in dataset:
18 | parent_id = message["parent_id"]
19 | if parent_id not in organized_data:
20 | organized_data[parent_id] = []
21 | organized_data[parent_id].append(message)
22 |
23 | # Traverse data to create trees
24 | message_trees = []
25 | for root_messages in organized_data[None]:
26 | tree = []
27 | current_message = root_messages
28 | while current_message is not None:
29 | tree.append(current_message)
30 | children = organized_data.get(current_message["message_id"])
31 | current_message = children[0] if children else None
32 | message_trees.append(tree)
33 |
34 | return message_trees
35 |
36 | oasst_message_trees = create_message_trees(oasst_dataset)
37 | oasst_examples = []
38 |
39 | count = 0
40 | for oasst_tree in oasst_message_trees:
41 | if len(oasst_tree) >= 2:
42 | count += 1
43 | oasst_examples.append(
44 | {
45 | "instruction": oasst_tree[0]["text"],
46 | "input": "",
47 | "output": "",
48 | }
49 | )
50 | print("OASST examples:", count)
51 | return oasst_examples
52 |
53 |
54 | def main(
55 | output_file: str = "/path/to/oasst1_prompts.json",
56 | ):
57 | oasst_examples = load_oasst_data()
58 |
59 | with open(output_file, "w") as f:
60 | json.dump(oasst_examples, f, indent=2)
61 |
62 |
63 | if __name__ == "__main__":
64 | fire.Fire(main)
65 |
--------------------------------------------------------------------------------
/training/step1_synthetic_preference_collection/scripts/generate_oasst1_response0.sh:
--------------------------------------------------------------------------------
1 | # We use 8 x 6 = 48 V100-32GB GPUs
2 | # On AiMOS cluster [https://docs.cci.rpi.edu/clusters/DCS_Supercomputer/]
3 | # salloc --nodes 8 --time 6:00:00 --gres=gpu:32g:6 srun bash scripts/generate_oasst1_response0.sh
4 |
5 | set -e
6 | set -x
7 |
8 | export PYTHONPATH="$PWD:$PYTHONPATH"
9 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5
10 | export MODEL_DIR="/your/model/dir"
11 | export DATA_DIR="/your/data/dir"
12 | export OMP_NUM_THREADS=6
13 | export GPUS_PER_NODE=6
14 | export NUM_NODES=2
15 | export MASTER_PORT=9901
16 |
17 | LOCAL_NODE_RANK=$((SLURM_PROCID % NUM_NODES))
18 | GROUP_RANK=$((SLURM_PROCID / NUM_NODES))
19 | GROUP_SIZE=$((SLURM_NNODES / NUM_NODES))
20 | SYNC_NODE_RANK=$((GROUP_RANK * NUM_NODES))
21 |
22 | # MASTER_ADDR should be SYNC_NODE_RANK-th node in $(scontrol show hostnames $SLURM_JOB_NODELIST)
23 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n $((SYNC_NODE_RANK + 1)) | tail -n 1)
24 |
25 | echo "$MASTER_ADDR, $GROUP_RANK / $GROUP_SIZE: $LOCAL_NODE_RANK"
26 |
27 | torchrun --nproc_per_node $GPUS_PER_NODE \
28 | --nnodes $NUM_NODES \
29 | --node_rank $LOCAL_NODE_RANK \
30 | --master_addr $MASTER_ADDR \
31 | --master_port $MASTER_PORT \
32 | batch_generation.py \
33 | --ckpt_dir $MODEL_DIR/dromedary-2-70b-sft-12shard \
34 | --tokenizer_path $MODEL_DIR/tokenizer.model \
35 | --generate_max_len 768 \
36 | --max_seq_len 768 \
37 | --max_shared_seq_len 640 \
38 | --max_batch_size 64 \
39 | --group_rank $GROUP_RANK \
40 | --group_size $GROUP_SIZE \
41 | --input_file "$DATA_DIR/oasst1_prompts.json" \
42 | --output_file "$DATA_DIR/oasst1_dromedary2_sft_response0.json" \
43 | --meta_prompt_file "../../prompts/synthetic_inference_prompts/dromedary_inference_prompt.txt" \
44 | --temperature 0.7 \
45 | --top_p 1.0 \
46 | --seed 42
47 |
--------------------------------------------------------------------------------
/training/step1_synthetic_preference_collection/scripts/generate_oasst1_response1.sh:
--------------------------------------------------------------------------------
1 | # We use 8 x 6 = 48 V100-32GB GPUs
2 | # On AiMOS cluster [https://docs.cci.rpi.edu/clusters/DCS_Supercomputer/]
3 | # salloc --nodes 8 --time 6:00:00 --gres=gpu:32g:6 srun bash scripts/generate_oasst1_response1.sh
4 |
5 | set -e
6 | set -x
7 |
8 | export PYTHONPATH="$PWD:$PYTHONPATH"
9 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5
10 | export MODEL_DIR="/your/model/dir"
11 | export DATA_DIR="/your/data/dir"
12 | export OMP_NUM_THREADS=6
13 | export GPUS_PER_NODE=6
14 | export NUM_NODES=2
15 | export MASTER_PORT=9901
16 |
17 | LOCAL_NODE_RANK=$((SLURM_PROCID % NUM_NODES))
18 | GROUP_RANK=$((SLURM_PROCID / NUM_NODES))
19 | GROUP_SIZE=$((SLURM_NNODES / NUM_NODES))
20 | SYNC_NODE_RANK=$((GROUP_RANK * NUM_NODES))
21 |
22 | # MASTER_ADDR should be SYNC_NODE_RANK-th node in $(scontrol show hostnames $SLURM_JOB_NODELIST)
23 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n $((SYNC_NODE_RANK + 1)) | tail -n 1)
24 |
25 | echo "$MASTER_ADDR, $GROUP_RANK / $GROUP_SIZE: $LOCAL_NODE_RANK"
26 |
27 | torchrun --nproc_per_node $GPUS_PER_NODE \
28 | --nnodes $NUM_NODES \
29 | --node_rank $LOCAL_NODE_RANK \
30 | --master_addr $MASTER_ADDR \
31 | --master_port $MASTER_PORT \
32 | batch_generation.py \
33 | --ckpt_dir $MODEL_DIR/dromedary-2-70b-sft-12shard \
34 | --tokenizer_path $MODEL_DIR/tokenizer.model \
35 | --generate_max_len 768 \
36 | --max_seq_len 768 \
37 | --max_shared_seq_len 640 \
38 | --max_batch_size 64 \
39 | --group_rank $GROUP_RANK \
40 | --group_size $GROUP_SIZE \
41 | --input_file "$DATA_DIR/oasst1_prompts.json" \
42 | --output_file "$DATA_DIR/oasst1_dromedary2_sft_response1.json" \
43 | --meta_prompt_file "../../prompts/synthetic_inference_prompts/dromedary_inference_prompt.txt" \
44 | --temperature 0.7 \
45 | --top_p 1.0 \
46 | --seed 43
47 |
--------------------------------------------------------------------------------
/training/step1_synthetic_preference_collection/scripts/generate_synthetic_preference.sh:
--------------------------------------------------------------------------------
1 | # We use 1 x 8 = 8 A100-80GB GPUs
2 | # salloc --nodes 1 --time 24:00:00 --gres=gpu:80g:8 srun bash scripts/generate_synthetic_preference.sh
3 |
4 | set -e
5 | set -x
6 |
7 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
8 | export MODEL_DIR="/your/model/dir"
9 | export DATA_DIR="/your/data/dir"
10 | export PYTHONPATH="$PWD:$PYTHONPATH"
11 | export GPUS_PER_NODE=8
12 | export OMP_NUM_THREADS=8
13 |
14 | torchrun \
15 | --standalone \
16 | --nnodes=1 \
17 | --nproc-per-node=$GPUS_PER_NODE \
18 | synthetic_preference.py \
19 | --model_name "$MODEL_DIR/llama-2-70b-hf" \
20 | --adapters_name "$MODEL_DIR/dromedary-2-70b-sft-qlora/adapter_model" \
21 | --preferece_prompt "../../prompts/synthetic_preference_prompt.txt" \
22 | --rm_principles "../../prompts/principles/principle_collection_rm.json" \
23 | --response_pattern "$DATA_DIR/oasst1_dromedary2_sft_response*.json" \
24 | --output_file "$DATA_DIR/oasst1_dromedary2_sft_preference.json"
25 |
--------------------------------------------------------------------------------
/training/step1_synthetic_preference_collection/synthetic_preference.py:
--------------------------------------------------------------------------------
1 | import fcntl
2 | import glob
3 | import json
4 | import os
5 | import random
6 |
7 | import numpy as np
8 | import torch
9 | import tqdm
10 | import fire
11 |
12 | from peft import PeftModel
13 | from transformers import (
14 | BitsAndBytesConfig,
15 | AutoModelForCausalLM,
16 | LlamaTokenizerFast,
17 | )
18 | from datasets import load_dataset
19 |
20 |
21 | def main(
22 | model_name: str,
23 | adapters_name: str,
24 | preferece_prompt: str,
25 | rm_principles: str,
26 | response_pattern: str,
27 | output_file: str,
28 | tokenizer_name: str = "TheBloke/dromedary-65b-lora-HF", # a random llama-based model
29 | ):
30 | with open(preferece_prompt, "r") as f:
31 | PREFERENCE_META_PROMPT = f.read().strip()
32 |
33 | with open(rm_principles, "r") as f:
34 | principle_definitions = json.load(f)
35 |
36 | data_files = glob.glob(response_pattern)
37 | cache = ""
38 | raw_data = {}
39 | for data_file in data_files:
40 | type = data_file.split("/")[-1].split("_")[-1].split(".")[0]
41 | raw_data[type] = []
42 | with open(data_file, "r") as f:
43 | for line in f:
44 | try:
45 | raw_data[type].append(json.loads(cache + line))
46 | cache = ""
47 | except json.decoder.JSONDecodeError:
48 | cache += line.strip()
49 | raw_data[type] = sorted(raw_data[type], key=lambda x: x["instruction"])
50 |
51 | print(raw_data.keys())
52 |
53 | preference_dataset = []
54 |
55 | for type_idx_a in range(len(raw_data.keys())):
56 | for type_idx_b in range(type_idx_a + 1, len(raw_data.keys())):
57 | type_a = list(raw_data.keys())[type_idx_a]
58 | type_b = list(raw_data.keys())[type_idx_b]
59 | print(
60 | f"len {type_a} and {type_b}",
61 | len(raw_data[type_a]),
62 | len(raw_data[type_b]),
63 | )
64 |
65 | for idx in range(len(raw_data[type_a])):
66 | data_a = raw_data[type_a][idx]
67 | data_b = raw_data[type_b][idx]
68 |
69 | if data_a["instruction"] != data_b["instruction"]:
70 | print("Instruction mismatch!")
71 | print(data_a["instruction"])
72 | print(data_b["instruction"])
73 | exit(0)
74 |
75 | if (
76 | len(data_a["output"]["generation"]) < 16
77 | or len(data_b["output"]["generation"]) < 16
78 | ):
79 | continue
80 |
81 | preference_dataset.append(
82 | {
83 | "instruction": data_a["instruction"],
84 | "input": data_a["input"],
85 | "output_1": data_a["output"]["generation"]
86 | .split("###")[0]
87 | .strip(),
88 | "output_2": data_b["output"]["generation"]
89 | .split("###")[0]
90 | .strip(),
91 | "preference": 0,
92 | }
93 | )
94 |
95 | random.Random(42).shuffle(preference_dataset)
96 |
97 | print(f"Starting to load the model {model_name} into memory")
98 |
99 | rank = int(os.environ.get("RANK", 0))
100 | world_size = int(os.environ.get("WORLD_SIZE", 1))
101 | torch.cuda.set_device(rank)
102 |
103 | print(preference_dataset[0 + rank])
104 |
105 | batch_size = 4 # For a single 80GB GPU
106 | m = AutoModelForCausalLM.from_pretrained(
107 | model_name,
108 | load_in_4bit=True,
109 | torch_dtype=torch.float16,
110 | device_map={"": torch.cuda.current_device()},
111 | quantization_config=BitsAndBytesConfig(
112 | load_in_4bit=True,
113 | load_in_8bit=False,
114 | llm_int8_threshold=6.0,
115 | llm_int8_has_fp16_weight=False,
116 | bnb_4bit_compute_dtype=torch.float16,
117 | bnb_4bit_use_double_quant=True,
118 | bnb_4bit_quant_type="nf4",
119 | ),
120 | )
121 | m = PeftModel.from_pretrained(
122 | m,
123 | adapters_name,
124 | adapter_name="default",
125 | is_trainable=False,
126 | )
127 | tok = LlamaTokenizerFast.from_pretrained(
128 | tokenizer_name,
129 | padding_side="left",
130 | truncation_side="left",
131 | model_max_length=1536,
132 | )
133 | tok.pad_token_id = 0
134 |
135 | print(f"Successfully loaded the model {model_name} into memory")
136 | print_flag = True
137 |
138 | idxs = []
139 | dimensions = []
140 | prompts = []
141 | preferences = []
142 |
143 | for idx in tqdm.tqdm(range(len(preference_dataset))):
144 | if idx % world_size != rank:
145 | continue
146 |
147 | for principle_definition in principle_definitions:
148 | dimension = principle_definition["dimension"]
149 | definition = principle_definition["definition"]
150 |
151 | data = preference_dataset[idx]
152 | instruction = data["instruction"]
153 | instruction_input = data["input"]
154 |
155 | if instruction_input:
156 | instruction += "\n\n" + instruction_input
157 |
158 | output_1 = data["output_1"]
159 | output_2 = data["output_2"]
160 | preference = data["preference"]
161 | preferences.append(preference)
162 | idxs.append(idx)
163 | dimensions.append(dimension)
164 |
165 | for a, b in [(output_1, output_2), (output_2, output_1)]:
166 | prompt_for_score = PREFERENCE_META_PROMPT.format(
167 | UserInstruction=instruction,
168 | OutputA=a,
169 | OutputB=b,
170 | Dimension=dimension,
171 | Definition=definition,
172 | )
173 |
174 | if print_flag:
175 | print_flag = False
176 | print(prompt_for_score)
177 |
178 | prompts.append(tok.bos_token + prompt_for_score)
179 |
180 | if len(prompts) == batch_size * 2:
181 | tokenized_input = tok(
182 | prompts,
183 | return_tensors="pt",
184 | add_special_tokens=False,
185 | padding="max_length",
186 | return_attention_mask=True,
187 | truncation=True,
188 | )
189 |
190 | input_ids = tokenized_input["input_ids"].to(m.device)
191 | attention_mask = tokenized_input["attention_mask"].to(m.device)
192 |
193 | with torch.inference_mode():
194 | output = m(
195 | input_ids,
196 | attention_mask=attention_mask,
197 | )
198 |
199 | logits = output.logits
200 |
201 | token_id_a = tok.encode("\n (a", add_special_tokens=False)[-1]
202 | token_id_b = tok.encode("\n (b", add_special_tokens=False)[-1]
203 |
204 | relative_scores = []
205 | prompt_preferences = []
206 |
207 | for ex_idx in range(batch_size):
208 | score_a_for_1_2 = logits[ex_idx * 2 + 0, -1, token_id_a]
209 | score_b_for_1_2 = logits[ex_idx * 2 + 0, -1, token_id_b]
210 | score_a_for_2_1 = logits[ex_idx * 2 + 1, -1, token_id_a]
211 | score_b_for_2_1 = logits[ex_idx * 2 + 1, -1, token_id_b]
212 |
213 | relative_score_1_2 = (score_a_for_1_2 - score_b_for_1_2).item()
214 | relative_score_2_1 = (score_b_for_2_1 - score_a_for_2_1).item()
215 |
216 | if relative_score_1_2 > 0.0 and relative_score_2_1 > 0.0:
217 | prompt_preference = 1
218 | elif relative_score_1_2 < 0.0 and relative_score_2_1 < 0.0:
219 | prompt_preference = 2
220 | else:
221 | prompt_preference = 0
222 |
223 | relative_scores.append((relative_score_1_2, relative_score_2_1))
224 | prompt_preferences.append(prompt_preference)
225 |
226 | outputs = []
227 |
228 | for ex_idx in range(batch_size):
229 | outputs.append(
230 | {
231 | "example_idx": idxs[ex_idx],
232 | "dimension": dimensions[ex_idx],
233 | "preference": preferences[ex_idx],
234 | "prompt_preference": prompt_preferences[ex_idx],
235 | "relative_score": relative_scores[ex_idx],
236 | }
237 | )
238 |
239 | with open(output_file, "a") as f:
240 | fcntl.flock(f, fcntl.LOCK_EX)
241 | for output in outputs:
242 | f.write(json.dumps(output) + "\n")
243 | fcntl.flock(f, fcntl.LOCK_UN)
244 |
245 | idxs = []
246 | dimensions = []
247 | prompts = []
248 | preferences = []
249 |
250 |
251 | if __name__ == "__main__":
252 | fire.Fire(main)
253 |
--------------------------------------------------------------------------------
/training/step2_rm_training/aggregate_synthetic_preference.py:
--------------------------------------------------------------------------------
1 | import fcntl
2 | import glob
3 | import json
4 | import os
5 | import random
6 |
7 | import numpy as np
8 | import tqdm
9 | from ftlangdetect import detect
10 |
11 |
12 | ALL_DIMENSIONS = [
13 | # "Overall Helpful",
14 | "Concise",
15 | "Honest and Accurate",
16 | "Ethical",
17 | # "Natural and Fluent",
18 | "Specific",
19 | "Educational and Engaging",
20 | "Methodical",
21 | # "Multilingual",
22 | "Creative",
23 | "Comprehensive",
24 | ]
25 |
26 |
27 | def main(
28 | response_pattern: str,
29 | preference_file: str,
30 | output_file: str,
31 | noisy_ratio: float = 0.1,
32 | max_principles: int = 8,
33 | tokenizer_name: str = "TheBloke/dromedary-65b-lora-HF", # a random llama-based model
34 | ):
35 | data_files = glob.glob(response_pattern)
36 |
37 | cache = ""
38 | raw_data = {}
39 | for data_file in data_files:
40 | type = data_file.split("/")[-1].split("_")[-1].split(".")[0]
41 | raw_data[type] = []
42 | with open(data_file, "r") as f:
43 | for line in f:
44 | try:
45 | raw_data[type].append(json.loads(cache + line))
46 | cache = ""
47 | except json.decoder.JSONDecodeError:
48 | cache += line.strip()
49 | raw_data[type] = sorted(raw_data[type], key=lambda x: x["instruction"])
50 |
51 | print(raw_data.keys(), len(raw_data[type]))
52 | preference_dataset = []
53 |
54 | for type_idx_a in range(len(raw_data.keys())):
55 | for type_idx_b in range(type_idx_a + 1, len(raw_data.keys())):
56 | type_a = list(raw_data.keys())[type_idx_a]
57 | type_b = list(raw_data.keys())[type_idx_b]
58 |
59 | for idx in range(len(raw_data[type_a])):
60 | data_a = raw_data[type_a][idx]
61 | data_b = raw_data[type_b][idx]
62 |
63 | if data_a["instruction"] != data_b["instruction"]:
64 | print("Instruction mismatch!")
65 | print(data_a["instruction"])
66 | print(data_b["instruction"])
67 | exit(0)
68 |
69 | if (
70 | len(data_a["output"]["generation"]) < 16
71 | or len(data_b["output"]["generation"]) < 16
72 | ):
73 | continue
74 |
75 | preference_dataset.append(
76 | {
77 | "instruction": data_a["instruction"],
78 | "input": data_a["input"],
79 | "output_1": data_a["output"]["generation"]
80 | .split("###")[0]
81 | .strip(),
82 | "output_2": data_b["output"]["generation"]
83 | .split("###")[0]
84 | .strip(),
85 | "preference": 0,
86 | }
87 | )
88 |
89 | random.Random(42).shuffle(preference_dataset)
90 |
91 | for example_idx in range(len(preference_dataset)):
92 | preference_dataset[example_idx]["example_idx"] = example_idx
93 |
94 | with open(preference_file, "r") as f:
95 | fcntl.flock(f, fcntl.LOCK_EX)
96 | for line in f:
97 | data = json.loads(line)
98 | example_idx = data["example_idx"]
99 |
100 | if "dimension_scores" not in preference_dataset[example_idx]:
101 | preference_dataset[example_idx]["dimension_scores"] = {}
102 |
103 | dimension = data["dimension"]
104 | score = data["relative_score"]
105 |
106 | preference_dataset[example_idx]["dimension_scores"][dimension] = score
107 | fcntl.flock(f, fcntl.LOCK_UN)
108 |
109 | all_dimensions = ALL_DIMENSIONS
110 | print(preference_dataset[0])
111 |
112 | # print variance
113 | variances = {}
114 | for dimension in all_dimensions:
115 | scores = []
116 | for example_idx in range(len(preference_dataset)):
117 | if "dimension_scores" not in preference_dataset[example_idx]:
118 | continue
119 | if dimension not in preference_dataset[example_idx]["dimension_scores"]:
120 | continue
121 | two_scores = preference_dataset[example_idx]["dimension_scores"][dimension]
122 | scores.append(two_scores[0])
123 | scores.append(two_scores[1])
124 | variances[dimension] = np.var(scores)
125 |
126 | print("Variance of {}: {}".format(dimension, np.var(scores)))
127 |
128 | clean_preference_dataset = []
129 |
130 | for example_idx in range(len(preference_dataset)):
131 | if "dimension_scores" not in preference_dataset[example_idx]:
132 | continue
133 |
134 | continue_flag = False
135 |
136 | for dimension in all_dimensions:
137 | if dimension not in preference_dataset[example_idx]["dimension_scores"]:
138 | continue_flag = True
139 |
140 | if continue_flag:
141 | continue
142 |
143 | clean_preference_dataset.append(preference_dataset[example_idx])
144 |
145 | print(len(clean_preference_dataset))
146 |
147 | synthetic_data = []
148 |
149 | main_reasons = {}
150 | for dimension in all_dimensions:
151 | main_reasons[dimension] = 0
152 |
153 | random.seed(42)
154 |
155 | def mean(a, b):
156 | return (a + b) / 2.0
157 |
158 | rule_count = 0
159 |
160 | for i in range(4):
161 | for example_idx in range(len(clean_preference_dataset)):
162 | random.shuffle(all_dimensions)
163 |
164 | # Compute the probabilities
165 | num_dimensions = max_principles + 1
166 | while num_dimensions > max_principles:
167 | num_dimensions = np.random.geometric(0.2, 1)[0]
168 |
169 | data = clean_preference_dataset[example_idx]
170 |
171 | dimensions = all_dimensions[:num_dimensions]
172 |
173 | scores = []
174 | for dimension in dimensions:
175 | two_scores = data["dimension_scores"][dimension]
176 | if two_scores[0] > 0.0 and two_scores[1] > 0.0:
177 | scores.append(
178 | mean(two_scores[0], two_scores[1]) / variances[dimension]
179 | )
180 | elif two_scores[0] < 0.0 and two_scores[1] < 0.0:
181 | scores.append(
182 | mean(two_scores[0], two_scores[1]) / variances[dimension]
183 | )
184 | else:
185 | scores.append(0.0)
186 |
187 | # negative_definitions is random benoulli of length num_dimensions
188 | negative_definitions = np.random.binomial(1, 0.5, num_dimensions)
189 |
190 | scores = [
191 | -score if negative_definitions[idx] == 1 else score
192 | for idx, score in enumerate(scores)
193 | ]
194 |
195 | # The preference is based on the absolute value of the scores.
196 | doinate_dimension_idx = np.argmax(np.abs(scores)).item()
197 |
198 | score = scores[doinate_dimension_idx]
199 |
200 | if score > 0.0:
201 | preference = 1
202 | elif score < 0.0:
203 | preference = 2
204 | else:
205 | continue
206 |
207 | order = random.randint(0, 1)
208 |
209 | rule_flag = False
210 |
211 | # We ban "as an AI language model"
212 | if "As an AI language model" in data["output_1"]:
213 | if "As an AI language model" in data["output_2"]:
214 | pass
215 | else:
216 | if random.randint(0, 1) > 0.5:
217 | preference = 2
218 | rule_flag = True
219 | else:
220 | if "As an AI language model" in data["output_2"]:
221 | if random.randint(0, 1) > 0.5:
222 | preference = 1
223 | rule_flag = True
224 | else:
225 | pass
226 |
227 | lang_instruction = detect(data["instruction"].replace("\n", " "))["lang"]
228 | lang_output_1 = detect(data["output_1"].replace("\n", " "))["lang"]
229 | lang_output_2 = detect(data["output_2"].replace("\n", " "))["lang"]
230 |
231 | # We ban different languages
232 | if lang_instruction != lang_output_1:
233 | if lang_instruction != lang_output_2:
234 | continue
235 | else:
236 | preference = 2
237 | rule_flag = True
238 | else:
239 | if lang_instruction != lang_output_2:
240 | preference = 1
241 | rule_flag = True
242 | else:
243 | pass
244 |
245 | data_tuple = (data["instruction"], data["output_1"], data["output_1"])
246 |
247 | if rule_flag:
248 | rule_count += 1
249 | flip = random.random() < noisy_ratio * 2.5 # 2.5 is a magic number
250 | else:
251 | flip = random.random() < noisy_ratio
252 | main_reasons[dimensions[doinate_dimension_idx]] += 1
253 |
254 | preference = 3 - preference if flip else preference
255 |
256 | synthetic_data.append(
257 | {
258 | "doinate_dimension_idx": doinate_dimension_idx,
259 | "dimensions": str(tuple(dimensions)),
260 | "dimension_scores": str(
261 | tuple(scores if order == 0 else [-score for score in scores])
262 | ),
263 | "negative_definitions": str(tuple(negative_definitions)),
264 | "instruction": data["instruction"],
265 | "input": data["input"],
266 | "output_1": data["output_1"] if order == 0 else data["output_2"],
267 | "output_2": data["output_2"] if order == 0 else data["output_1"],
268 | "preference": preference if order == 0 else 3 - preference,
269 | "flip": flip,
270 | "lang": (lang_instruction, lang_output_1, lang_output_2)
271 | if order == 0
272 | else (lang_instruction, lang_output_2, lang_output_1),
273 | "rule_flag": rule_flag,
274 | }
275 | )
276 |
277 | print(rule_count)
278 | print(len(synthetic_data))
279 | print(main_reasons)
280 |
281 | with open(output_file, "w") as f:
282 | f.write(
283 | json.dumps(
284 | synthetic_data,
285 | indent=4,
286 | )
287 | )
288 |
--------------------------------------------------------------------------------
/training/step2_rm_training/clean_pmp_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import tqdm
4 | from datasets import load_dataset
5 | import fire
6 |
7 |
8 | def find_shared_prefix(str1, str2):
9 | """
10 | Find the shared prefix between two strings.
11 |
12 | Args:
13 | str1 (str): The first string.
14 | str2 (str): The second string.
15 |
16 | Returns:
17 | int: The index position where the shared prefix ends in the first string.
18 | """
19 | # Find the minimum length of both strings
20 | min_length = min(len(str1), len(str2))
21 |
22 | max_common_prefix = 0
23 | # Loop through characters to find where they start to differ
24 | for i in range(min_length):
25 | if str1[i] != str2[i]:
26 | max_common_prefix = i
27 | break
28 |
29 | while (
30 | not str1[:max_common_prefix].endswith("Assistant:")
31 | # Force the prefix to end with "Assistant:"
32 | ) and max_common_prefix > 0:
33 | max_common_prefix -= 1
34 |
35 | return max_common_prefix
36 |
37 |
38 | def split_example(example):
39 | """
40 | Split an example into a shared prefix and the remaining parts of the strings.
41 |
42 | Args:
43 | example (dict): A dictionary containing 'chosen' and 'rejected' keys with strings as values.
44 |
45 | Returns:
46 | dict: A dictionary with keys 'query', 'output_1', and 'output_2'.
47 | """
48 | chosen = example["chosen"]
49 | rejected = example["rejected"]
50 |
51 | # Find the index where the shared prefix ends
52 | shared_index = find_shared_prefix(chosen, rejected)
53 |
54 | # Split the strings
55 | query = chosen[:shared_index].strip()
56 | output_1 = chosen[shared_index:].strip()
57 | output_2 = rejected[shared_index:].strip()
58 |
59 | # Return the result as a dictionary
60 | return {"query": query, "output_1": output_1, "output_2": output_2}
61 |
62 |
63 | def main(
64 | output_file: str,
65 | ):
66 | shp_dataset = load_dataset("stanfordnlp/SHP")["train"]
67 |
68 | merged_data = []
69 |
70 | for ex in tqdm.tqdm(shp_dataset):
71 | qid = ex["post_id"]
72 | score = ex["score_ratio"]
73 |
74 | ex_dp = {
75 | "score": score,
76 | "instruction": ex["history"],
77 | "input": "",
78 | "output_1": ex["human_ref_A"],
79 | "output_2": ex["human_ref_B"],
80 | "preference": 1 if ex["labels"] == 1 else 2,
81 | }
82 |
83 | if score > 2.0:
84 | merged_data.append(ex_dp)
85 |
86 | print(merged_data[-1])
87 | print(len(merged_data))
88 |
89 | hh_dataset = load_dataset("Anthropic/hh-rlhf")["train"]
90 |
91 | for ex in hh_dataset:
92 | processed_ex = split_example(ex)
93 |
94 | random_swap = random.random() < 0.5
95 |
96 | merged_data.append(
97 | {
98 | "instruction": processed_ex["query"].strip(),
99 | "input": "",
100 | "output_1": processed_ex["output_1"]
101 | if random_swap
102 | else processed_ex["output_2"],
103 | "output_2": processed_ex["output_2"]
104 | if random_swap
105 | else processed_ex["output_1"],
106 | "preference": 1 if random_swap else 2,
107 | }
108 | )
109 |
110 | print(merged_data[-1])
111 | print(len(merged_data))
112 |
113 | random.shuffle(merged_data)
114 |
115 | with open(output_file, "w") as f:
116 | f.write(
117 | json.dumps(
118 | merged_data,
119 | )
120 | )
121 |
122 |
123 | if __name__ == "__main__":
124 | fire.Fire(main)
125 |
--------------------------------------------------------------------------------
/training/step2_rm_training/scripts/train_reward_model_70b_qlora_ft.sh:
--------------------------------------------------------------------------------
1 | # We use 1 x 8 = 8 A100-80GB GPUs
2 |
3 | set -e
4 | set -x
5 |
6 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
7 | export MODEL_DIR="/your/model/dir"
8 | export DATA_DIR="/your/data/dir"
9 | export PYTHONPATH="$PWD:$PYTHONPATH"
10 | export GPUS_PER_NODE=8
11 | export OMP_NUM_THREADS=8
12 |
13 | NUM_EPOCHS=1
14 | LEARNING_RATE=3e-5
15 | BATCH_SIZE=4
16 | GRAD_ACCUMULATION=2
17 |
18 | PMP_CKPT=llama-2-70b-qlora-rm-pmp/checkpoint-xxxx
19 |
20 | cd ..
21 |
22 | torchrun \
23 | --standalone \
24 | --nnodes=1 \
25 | --nproc-per-node=$GPUS_PER_NODE \
26 | finetune_qlora_rm.py \
27 | --do_train \
28 | --do_eval \
29 | --seed 42 \
30 | --per_device_train_batch_size $BATCH_SIZE \
31 | --per_device_eval_batch_size $BATCH_SIZE \
32 | --gradient_accumulation_steps $GRAD_ACCUMULATION \
33 | --model_name_or_path "$MODEL_DIR/llama-2-70b-hf" \
34 | --learning_rate $LEARNING_RATE \
35 | --model_max_length 1280 \
36 | --query_len 128 \
37 | --response_len 768 \
38 | --dataset_path "$DATA_DIR/oasst1_dromedary2_sft_aggregated_preference.json" \
39 | --principle_collection_path "../principles/principle_collection_rm.json" \
40 | --meta_prompt_pattern "../prompts/salmon_reward_model_prompt_v0.txt" \
41 | --double_quant True \
42 | --quant_type "nf4" \
43 | --bits 4 \
44 | --lora_r 64 \
45 | --output_dir "$MODEL_DIR/llama-2-70b-qlora-rm-pmp-sft" \
46 | --resume_dir "$MODEL_DIR/$PMP_CKPT" \
47 | --num_train_epochs $NUM_EPOCHS \
48 | --group_by_length False \
49 | --evaluation_strategy "steps" \
50 | --eval_steps 20 \
51 | --save_strategy "steps" \
52 | --save_steps 40 \
53 | --save_total_limit 20 \
54 | --weight_decay 0.0 \
55 | --warmup_ratio 0.03 \
56 | --lr_scheduler_type "cosine" \
57 | --logging_steps 5 \
58 | --report_to "tensorboard" \
59 | --ddp_backend "nccl" \
60 | --bf16 True \
61 | --ddp_find_unused_parameters False
62 |
--------------------------------------------------------------------------------
/training/step2_rm_training/scripts/train_reward_model_70b_qlora_pmp.sh:
--------------------------------------------------------------------------------
1 | # We use 1 x 8 = 8 A100-80GB GPUs
2 |
3 | set -e
4 | set -x
5 |
6 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
7 | export MODEL_DIR="/your/model/dir"
8 | export DATA_DIR="/your/data/dir"
9 | export PYTHONPATH="$PWD:$PYTHONPATH"
10 | export GPUS_PER_NODE=8
11 | export OMP_NUM_THREADS=8
12 |
13 | NUM_EPOCHS=1
14 | LEARNING_RATE=5e-5
15 | BATCH_SIZE=4
16 | GRAD_ACCUMULATION=2
17 |
18 | cd ..
19 |
20 | torchrun \
21 | --standalone \
22 | --nnodes=1 \
23 | --nproc-per-node=$GPUS_PER_NODE \
24 | finetune_qlora_rm.py \
25 | --do_train \
26 | --do_eval \
27 | --seed 42 \
28 | --per_device_train_batch_size $BATCH_SIZE \
29 | --per_device_eval_batch_size $BATCH_SIZE \
30 | --gradient_accumulation_steps $GRAD_ACCUMULATION \
31 | --model_name_or_path "$MODEL_DIR/llama-2-70b-hf" \
32 | --learning_rate $LEARNING_RATE \
33 | --model_max_length 1024 \
34 | --query_len 256 \
35 | --response_len 640 \
36 | --dataset_path "$DATA_DIR/pmp_data.json" \
37 | --meta_prompt_pattern "../prompts/pmp_reward_model_prompt.txt" \
38 | --double_quant True \
39 | --quant_type "nf4" \
40 | --bits 4 \
41 | --lora_r 64 \
42 | --output_dir "$MODEL_DIR/llama-2-70b-qlora-rm-pmp" \
43 | --num_train_epochs $NUM_EPOCHS \
44 | --group_by_length False \
45 | --evaluation_strategy "steps" \
46 | --eval_steps 20 \
47 | --save_strategy "steps" \
48 | --save_steps 100 \
49 | --save_total_limit 10 \
50 | --weight_decay 0.0 \
51 | --warmup_ratio 0.01 \
52 | --lr_scheduler_type "cosine" \
53 | --logging_steps 5 \
54 | --report_to "tensorboard" \
55 | --ddp_backend "nccl" \
56 | --bf16 True \
57 | --ddp_find_unused_parameters False \
58 | --resume_from_training True
59 |
--------------------------------------------------------------------------------
/training/step3_ppo_training/aggregate_sharegpt_prompts.py:
--------------------------------------------------------------------------------
1 | import json
2 | import fire
3 | from datasets import load_dataset
4 |
5 | DEFAULT_SHAREGPT_DATA = [
6 | # https://huggingface.co/datasets/zetavg/ShareGPT-Processed
7 | "zetavg/ShareGPT-Processed",
8 | # https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/blob/main/HTML_cleaned_raw_dataset/sg_90k_part1.json
9 | "../../outputs/data/sharegpt/sg_90k_part1.json",
10 | # https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/blob/main/HTML_cleaned_raw_dataset/sg_90k_part2.json
11 | "../../outputs/data/sharegpt/sg_90k_part2.json",
12 | # "../../outputs/data/sharegpt/ShareGPT.jsonl",
13 | # "../../outputs/data/sharegpt/ShareGPT_V3_unfiltered_cleaned_split.json",
14 | ]
15 |
16 |
17 | def extract_from_dataset(dataset, key="markdown", prefix="human"):
18 | data = []
19 |
20 | for item in dataset:
21 | conversations = item["conversations"]
22 | if len(conversations) >= 2:
23 | user_input = conversations[0]
24 |
25 | if user_input["from"] == prefix:
26 | data.append(
27 | {
28 | "instruction": user_input[key],
29 | "input": "",
30 | "output": "",
31 | }
32 | )
33 |
34 | return data
35 |
36 |
37 | def extract_from_json_file(filepath, key="value", prefix="human"):
38 | with open(filepath, "r") as f:
39 | dataset = json.load(f)
40 |
41 | return extract_from_dataset(dataset, key, prefix)
42 |
43 |
44 | def extract_from_jsonl_file(filepath):
45 | data = []
46 |
47 | with open(filepath, "r") as f:
48 | for line in f:
49 | line_data = json.loads(line)
50 | user_input = line_data["text"].split("")[0][len(":") :].strip()
51 |
52 | data.append(
53 | {
54 | "instruction": user_input,
55 | "input": "",
56 | "output": "",
57 | }
58 | )
59 |
60 | return data
61 |
62 |
63 | def main(data_files: list = None, output_file: str = "path/to/sharegpt_prompts.json"):
64 | if data_files is None:
65 | data_files = DEFAULT_SHAREGPT_DATA
66 |
67 | all_data = []
68 | merged_data = []
69 | merged_data_set = set()
70 |
71 | for filepath in data_files:
72 | if filepath.endswith(".json"):
73 | all_data.append(extract_from_json_file(filepath))
74 | elif filepath.endswith(".jsonl"):
75 | all_data.append(extract_from_jsonl_file(filepath))
76 | else:
77 | train_dataset = load_dataset(filepath)["train"]
78 | all_data.append(extract_from_dataset(train_dataset))
79 |
80 | # Merge and filter unique instructions
81 | for data in all_data:
82 | filtered_data_count = 0
83 | for item in data:
84 | if item["instruction"] not in merged_data_set:
85 | merged_data_set.add(item["instruction"])
86 | merged_data.append(item)
87 | else:
88 | filtered_data_count += 1
89 |
90 | print("Filtered data count:", filtered_data_count)
91 | print("Merged data size:", len(merged_data))
92 |
93 | with open(output_file, "w") as f:
94 | json.dump(merged_data, f, indent=2)
95 |
96 |
97 | if __name__ == "__main__":
98 | fire.Fire(main)
99 |
--------------------------------------------------------------------------------
/training/step3_ppo_training/clean_and_merge_prompts.py:
--------------------------------------------------------------------------------
1 | import random
2 | import tqdm
3 | import re
4 | import json
5 | from datasets import load_dataset
6 | from transformers import AutoTokenizer
7 | import fire
8 |
9 | # QUESTION TYPE
10 | # 0: GENERAL (ShareGPT + Dolly + OASST)
11 | # 1: REASONING (OpenOrca + MATH)
12 | # 2: REDTEAMING (Anthropic Red Teaming)
13 |
14 | CHATGPT_LANGUAGES = {
15 | "it": "Italian",
16 | "pl": "Polish",
17 | "ru": "Russian",
18 | "sk": "Slovak",
19 | "pt": "Portuguese",
20 | "ro": "Romanian",
21 | "da": "Danish",
22 | "sv": "Swedish",
23 | "no": "Norwegian",
24 | "en": "English",
25 | "es": "Spanish",
26 | "fr": "French",
27 | "cs": "Czech",
28 | "de": "German",
29 | "fi": "Finnish",
30 | "et": "Estonian",
31 | "lv": "Latvian",
32 | "lt": "Lithuanian",
33 | "fa": "Persian",
34 | "hu": "Hungarian",
35 | "he": "Hebrew",
36 | "el": "Greek",
37 | "ar": "Arabic",
38 | "kr": "Korean",
39 | "ja": "Japanese",
40 | "zh": "Chinese",
41 | "zh-traditional": "Chinese (Traditional)",
42 | "zh-simplified": "Chinese (Simplified)",
43 | "th": "Thai",
44 | "vi": "Vietnamese",
45 | }
46 |
47 |
48 | def remove_leading_fraction(input_string):
49 | # remove leading fraction
50 | cleaned_string = re.sub(r"^\d+\s*/\s*\d+", "", input_string)
51 | cleaned_string = re.sub(r"\d+\s*/\s*\d+$", "", cleaned_string)
52 | cleaned_string = cleaned_string.split("1 / 1", 1)[-1]
53 |
54 | # \uc9c0\uae08 \ubc88\uc5ed\ud558\uae30
55 | cleaned_string = cleaned_string.split("지금 번역하기")[0]
56 |
57 | # Language: English
58 | cleaned_string = cleaned_string.split("\n \n Language: ")[0]
59 |
60 | cleaned_string = cleaned_string.strip()
61 |
62 | if cleaned_string.endswith("Share Prompt"):
63 | cleaned_string = cleaned_string[: -len("Share Prompt")].strip()
64 |
65 | if cleaned_string.endswith("Translate now"):
66 | cleaned_string = cleaned_string[: -len("Translate now")].strip()
67 |
68 | for lang_code in CHATGPT_LANGUAGES:
69 | lang_suffix = f"Language: {CHATGPT_LANGUAGES[lang_code]}"
70 | if cleaned_string.endswith(lang_suffix):
71 | cleaned_string = cleaned_string[: -len(lang_suffix)].strip()
72 |
73 | # ~The following is a conversation with Bing, not ChatGPT.~
74 | if cleaned_string.startswith(
75 | "~The following is a conversation with Bing, not ChatGPT.~"
76 | ):
77 | cleaned_string = cleaned_string[
78 | len("~The following is a conversation with Bing, not ChatGPT.~") :
79 | ].strip()
80 |
81 | return cleaned_string
82 |
83 |
84 | def load_dolly_data(length_bonus):
85 | dataset = load_dataset("databricks/databricks-dolly-15k")["train"]
86 | category_to_examples = {}
87 | for example in dataset:
88 | category = example["category"]
89 | if category not in category_to_examples:
90 | category_to_examples[category] = []
91 | category_to_examples[category].append(example)
92 |
93 | merged_examples = []
94 | for data in [
95 | category_to_examples["creative_writing"],
96 | category_to_examples["brainstorming"],
97 | category_to_examples["open_qa"],
98 | category_to_examples["general_qa"],
99 | category_to_examples["classification"],
100 | ]:
101 | for i in range(len(data)):
102 | assert data[i]["context"] == ""
103 | merged_examples.append(
104 | {
105 | "instruction": data[i]["instruction"],
106 | "input": "",
107 | "output": "",
108 | "length_bonus": length_bonus,
109 | "question_type": 0,
110 | }
111 | )
112 | print("Dolly examples:", len(merged_examples))
113 | return merged_examples
114 |
115 |
116 | def load_oasst_data(length_bonus):
117 | oasst_dataset = load_dataset("OpenAssistant/oasst1")["train"]
118 |
119 | def create_message_trees(dataset):
120 | """Create message trees from dataset."""
121 | # Organize data into dictionary based on parent_id
122 | organized_data = {}
123 | for message in dataset:
124 | parent_id = message["parent_id"]
125 | if parent_id not in organized_data:
126 | organized_data[parent_id] = []
127 | organized_data[parent_id].append(message)
128 |
129 | # Traverse data to create trees
130 | message_trees = []
131 | for root_messages in organized_data[None]:
132 | tree = []
133 | current_message = root_messages
134 | while current_message is not None:
135 | tree.append(current_message)
136 | children = organized_data.get(current_message["message_id"])
137 | current_message = children[0] if children else None
138 | message_trees.append(tree)
139 |
140 | return message_trees
141 |
142 | oasst_message_trees = create_message_trees(oasst_dataset)
143 | oasst_examples = []
144 |
145 | count = 0
146 | for oasst_tree in oasst_message_trees:
147 | if len(oasst_tree) >= 2:
148 | count += 1
149 | oasst_examples.append(
150 | {
151 | "instruction": oasst_tree[0]["text"],
152 | "input": "",
153 | "output": "",
154 | "length_bonus": length_bonus,
155 | "question_type": 0,
156 | }
157 | )
158 | print("OASST examples:", count)
159 | return oasst_examples
160 |
161 |
162 | def load_math_data(length_bonus):
163 | dataset = load_dataset("competition_math")
164 | train_data = dataset["train"]
165 | max_samples_per_dataset = 10000
166 |
167 | math_data = []
168 |
169 | for i in range(min(len(train_data), max_samples_per_dataset)):
170 | ex_instruction = train_data[i]["problem"].strip()
171 | math_data.append(
172 | {
173 | "instruction": ex_instruction,
174 | "input": "",
175 | "output": "",
176 | "length_bonus": length_bonus,
177 | "question_type": 1,
178 | }
179 | )
180 | return math_data
181 |
182 |
183 | def filter_and_clean_examples(merged_examples):
184 | tokenizer = AutoTokenizer.from_pretrained("TheBloke/dromedary-65b-lora-HF")
185 | max_seq_length = 256
186 | filtered_examples = []
187 | set_of_unique_instructions = set()
188 |
189 | # Filter out examples with non-ascci characters
190 | merged_examples = [
191 | {
192 | "instruction": remove_leading_fraction(example["instruction"]),
193 | "input": "",
194 | "output": "",
195 | "length_bonus": example["length_bonus"],
196 | "question_type": example["question_type"],
197 | }
198 | for example in merged_examples
199 | ]
200 |
201 | for example in tqdm.tqdm(merged_examples):
202 | instruction = example["instruction"]
203 | instruction_token_length = len(tokenizer.encode(instruction))
204 | if (
205 | 2 <= instruction_token_length <= max_seq_length
206 | and instruction not in set_of_unique_instructions
207 | ):
208 | filtered_examples.append(example)
209 | set_of_unique_instructions.add(instruction)
210 |
211 | return filtered_examples
212 |
213 |
214 | def load_json(share_gpt_data_path):
215 | with open(share_gpt_data_path, "r") as f:
216 | share_gpt_data = json.load(f)
217 | examples = []
218 | for data in share_gpt_data:
219 | examples.append(
220 | {
221 | "instruction": data["instruction"],
222 | "input": "",
223 | "output": "",
224 | }
225 | )
226 | print(f"{share_gpt_data_path} examples:", len(examples))
227 | return examples
228 |
229 |
230 | def main(
231 | sharegpt_prompt_path: str = "/path/to/sharegpt_prompts.json",
232 | openorca_prompt_path: str = "/path/to/openorca_prompts.json",
233 | output_file: str = "/path/to/salmon_merged_prompts.json",
234 | general_length_bonus: float = 1.0,
235 | reasoning_length_bonus: float = -0.4,
236 | redteaming_length_bonus: float = 0.0,
237 | ):
238 | del redteaming_length_bonus # unused
239 |
240 | dolly_examples = load_dolly_data(length_bonus=general_length_bonus)
241 | oasst_examples = load_oasst_data(length_bonus=general_length_bonus)
242 |
243 | sharegpt_examples = load_json(sharegpt_prompt_path)
244 | sharegpt_examples = [
245 | {
246 | "instruction": example["instruction"],
247 | "input": "",
248 | "output": "",
249 | "length_bonus": general_length_bonus,
250 | "question_type": 0,
251 | }
252 | for example in sharegpt_examples
253 | ]
254 |
255 | openorca_examples = load_json(openorca_prompt_path)
256 | openorca_examples = [
257 | {
258 | "instruction": example["instruction"],
259 | "input": "",
260 | "output": "",
261 | "length_bonus": reasoning_length_bonus,
262 | "question_type": 1,
263 | }
264 | for example in openorca_examples
265 | ]
266 | math_examples = load_math_data(length_bonus=reasoning_length_bonus)
267 |
268 | merged_examples = (
269 | dolly_examples
270 | + oasst_examples
271 | + sharegpt_examples
272 | + openorca_examples
273 | + math_examples
274 | )
275 | filtered_examples = filter_and_clean_examples(merged_examples)
276 |
277 | print("Total examples:", len(filtered_examples))
278 |
279 | with open(output_file, "w") as f:
280 | json.dump(filtered_examples, f, indent=2)
281 |
282 |
283 | if __name__ == "__main__":
284 | fire.Fire(main)
285 |
--------------------------------------------------------------------------------
/training/step3_ppo_training/scripts/train_ppo_model_70b_qlora_salmon.sh:
--------------------------------------------------------------------------------
1 | # We use 6 x 8 = 48 A100-80GB GPUs
2 | # salloc --nodes 6 --time 24:00:00 --gres=gpu:80g:8 srun bash scripts/train_ppo_model_70b_qlora_salmon.sh
3 |
4 | set -e
5 | set -x
6 |
7 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
8 | export MODEL_DIR="/your/model/dir"
9 | export DATA_DIR="/your/data/dir"
10 | export PYTHONPATH="$PWD:$PYTHONPATH"
11 | export GPUS_PER_NODE=8
12 | export OMP_NUM_THREADS=8
13 | export TOKENIZERS_PARALLELISM=false
14 | export CUDA_LAUNCH_BLOCKING=1 # Not sure if this is needed
15 |
16 | LEARNING_RATE=2e-5
17 | KL_COEF=0.01
18 | EPOCH=2
19 | ROLLOUT_BATCH_SIZE=576
20 | STEP_BATCH_SZIE=288
21 | ROLLOUT_PER_DEVICE_BATCH_SIZE=6
22 | REWARD_MODEL_PER_DEVICE_BATCH_SIZE=3
23 | STEP_PER_DEVICE_BATCH_SIZE=6
24 |
25 | JOB_ID=29400
26 | NUM_NODES=6
27 | NUM_TRAINERS=8
28 |
29 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | tail -n 1)
30 |
31 | HOST_NODE_ADDR="$MASTER_ADDR:29400"
32 |
33 | # SALMON-SPECIFIC HYPER-PARAMETERS
34 | MAX_PRINCIPLES=3
35 | UNFINISHED_PENALTY=-6.0
36 | LENGTH_BONUS=5.0
37 | LENGTH_BONUS_UPPER_BOUND=0.8
38 | SAMPLING_TEMPARATURE=0.7
39 |
40 | cd ..
41 |
42 | torchrun \
43 | --nnodes=$NUM_NODES \
44 | --nproc-per-node=$NUM_TRAINERS \
45 | --rdzv-id=$JOB_ID \
46 | --rdzv-backend=c10d \
47 | --rdzv-endpoint=$HOST_NODE_ADDR \
48 | finetune_qlora_ppo.py \
49 | --do_train \
50 | --do_eval \
51 | --seed 42 \
52 | --step_batch_size $STEP_BATCH_SZIE \
53 | --step_per_device_batch_size $STEP_PER_DEVICE_BATCH_SIZE \
54 | --rollout_batch_size $ROLLOUT_BATCH_SIZE \
55 | --rollout_per_device_batch_size $ROLLOUT_PER_DEVICE_BATCH_SIZE \
56 | --reward_model_per_device_batch_size $REWARD_MODEL_PER_DEVICE_BATCH_SIZE \
57 | --per_device_eval_batch_size $ROLLOUT_PER_DEVICE_BATCH_SIZE \
58 | --policy_model_name_or_path "$MODEL_DIR/dromedary-2-70b-sft-qlora/adapter_model" \
59 | --reward_model_name_or_path "$MODEL_DIR/llama-2-70b-qlora-rm-pmp-sft/adapter_model" \
60 | --learning_rate $LEARNING_RATE \
61 | --init_value_with_reward True \
62 | --warmup_steps 5 \
63 | --dataset_path "$DATA_DIR/salmon_merged_prompts.json" \
64 | --train_splits "train" \
65 | --policy_meta_prompt_pattern "../prompts/dromedary_inference_prompt.txt" \
66 | --reward_meta_prompt_pattern "../prompts/salmon_reward_model_prompt_v1.txt" \
67 | --principle_collection_path "../prompts/principles/principle_collection_ppo.json" \
68 | --max_principles $MAX_PRINCIPLES \
69 | --stop_token "\n\n### User" \
70 | --output_dir "$NEW_MODEL_DIR/dromedary-2-70b-ppo-qlora-salmon" \
71 | --total_epochs $EPOCH \
72 | --group_by_length False \
73 | --evaluation_strategy "no" \
74 | --save_strategy "steps" \
75 | --save_steps 10 \
76 | --save_total_limit 100000 \
77 | --weight_decay 0.0 \
78 | --lr_scheduler_type "cosine" \
79 | --logging_steps 1 \
80 | --report_to "tensorboard" \
81 | --ddp_backend "nccl" \
82 | --bf16 True \
83 | --penalty_reward_value $UNFINISHED_REWARD_PENALTY \
84 | --length_bonus_score $LENGTH_BONUS \
85 | --length_bonus_upper_bound $LENGTH_BONUS_UPPER_BOUND \
86 | --relative_stop_token_penalty True \
87 | --penalize_no_stop_token True \
88 | --ddp_find_unused_parameters False \
89 | --resume_from_training True \
90 | --kl_coef $KL_COEF \
91 | --max_grad_norm 1.0 \
92 | --whitening_async_stats "full_batch" \
93 | --clean_tokens_after_eos True \
94 | --whiten_rewards False \
95 | --query_len 256 \
96 | --response_len 1024 \
97 | --model_max_length 1664 \
98 | --enable_reasoning_principles True \
99 | --enable_redteaming_principles True \
100 | --temperature $SAMPLING_TEMPARATURE \
101 | --noptepochs 2
102 |
--------------------------------------------------------------------------------
/training/step3_ppo_training/subsample_openorca_prompts.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import json
3 | import random
4 | import pandas as pd
5 | import fire
6 | from transformers import LlamaTokenizerFast
7 |
8 | # https://huggingface.co/datasets/Open-Orca/OpenOrca/blob/main/1M-GPT4-Augmented.parquet
9 |
10 |
11 | def main(
12 | train_data_path: str = "path/to/data/1M-GPT4-Augmented.parquet",
13 | output_path: str = "path/to/data/openorca_prompts.json",
14 | tokenizer_name: str = "TheBloke/dromedary-65b-lora-HF", # a random llama-based model
15 | max_samples_per_dataset: int = 10000,
16 | max_prompt_len: int = 256,
17 | ):
18 | train_data = pd.read_parquet(train_data_path)
19 | tokenizer = LlamaTokenizerFast.from_pretrained(tokenizer_name)
20 | niv_data = []
21 | flan_data = []
22 | t0_data = []
23 | cot_data = []
24 |
25 | for i in tqdm.tqdm(range(len(train_data))):
26 | datapoint = train_data.iloc[i]
27 |
28 | d_id = datapoint["id"]
29 |
30 | if "niv" in d_id:
31 | niv_data.append(datapoint)
32 | continue
33 | if "flan" in d_id:
34 | flan_data.append(datapoint)
35 | continue
36 | if "t0" in d_id:
37 | t0_data.append(datapoint)
38 | continue
39 | if "cot" in d_id:
40 | cot_data.append(datapoint)
41 | continue
42 |
43 | raise ValueError("Unknown dataset")
44 |
45 | print("Stats:")
46 | print(f"niv: {len(niv_data)}")
47 | print(f"flan: {len(flan_data)}")
48 | print(f"t0: {len(t0_data)}")
49 | print(f"cot: {len(cot_data)}")
50 |
51 | total_data = []
52 |
53 | set_of_instructions = set()
54 |
55 | for data in tqdm.tqdm([niv_data, flan_data, t0_data, cot_data]):
56 | data_count = 0
57 | random.shuffle(data)
58 | for datapoint in tqdm.tqdm(data):
59 | ex_instruction = datapoint["question"].strip()
60 | ex_input = ""
61 | ex_output = ""
62 |
63 | if ex_instruction in set_of_instructions:
64 | continue
65 | else:
66 | set_of_instructions.add(ex_instruction)
67 |
68 | if (
69 | len(tokenizer.encode(ex_instruction + "\n\n" + ex_input))
70 | > max_prompt_len
71 | ):
72 | continue
73 |
74 | # We delete NLI tasks
75 | if "it is not possible to tell" in datapoint["question"]:
76 | continue
77 |
78 | continue_flag = False
79 | for keyword in ["Premise", "premise", "Hypothesis", "hypothesis"]:
80 | if keyword in datapoint["question"]:
81 | continue_flag = True
82 | break
83 | if continue_flag:
84 | continue
85 |
86 | total_data.append(
87 | {
88 | "instruction": ex_instruction,
89 | "input": ex_input,
90 | "output": ex_output,
91 | }
92 | )
93 |
94 | data_count += 1
95 | if data_count >= max_samples_per_dataset:
96 | break
97 | print("data_count:", data_count, len(total_data))
98 |
99 | random.shuffle(total_data)
100 |
101 | print(f"Total data: {len(total_data)}")
102 |
103 | with open(output_path, "w") as f:
104 | json.dump(total_data, f, indent=2)
105 |
106 |
107 | if __name__ == "__main__":
108 | fire.Fire(main)
109 |
--------------------------------------------------------------------------------
/training/train_qlora_rm.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the MIT license found in the
2 | # LICENSE file in the root directory of this source tree.
3 |
4 | import json
5 |
6 | import os
7 | from dataclasses import dataclass, field
8 | from typing import Optional, List, Literal
9 | import logging
10 |
11 | import torch
12 | import transformers
13 | import argparse
14 | from transformers import set_seed
15 |
16 | try:
17 | from transformers import LlamaTokenizerFast as LlamaTokenizer
18 |
19 | print("Using fast tokenizer")
20 | except:
21 | from transformers import LlamaTokenizer
22 |
23 | print("Using slow tokenizer")
24 |
25 | from transformers import AutoTokenizer, AutoModelForCausalLM
26 |
27 | from qlora_utils import (
28 | SavePeftModelCallback,
29 | print_trainable_parameters,
30 | get_last_checkpoint,
31 | DEFAULT_PAD_TOKEN,
32 | )
33 | from data_utils.data_utils_rm import make_binary_reward_modeling_data_module
34 | from models.reward_model import (
35 | RewardConfig,
36 | RewardModel,
37 | RewardModelTrainer as Trainer,
38 | compute_reward_modeling_metrics,
39 | )
40 |
41 | torch.backends.cuda.matmul.allow_tf32 = True
42 |
43 | logger = logging.getLogger(__name__)
44 |
45 |
46 | @dataclass
47 | class ModelArguments:
48 | model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-12b")
49 | trust_remote_code: Optional[bool] = field(
50 | default=False,
51 | metadata={
52 | "help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."
53 | },
54 | )
55 |
56 |
57 | @dataclass
58 | class DataArguments:
59 | dataset_path: str = field(default="tatsu-lab/alpaca_farm")
60 | dataset_name: str = field(default=None, metadata={"help": "Dataset name"})
61 | eval_dataset_path: str = field(default="tatsu-lab/alpaca_farm")
62 | eval_dataset_name: str = field(default="alpaca_human_preference")
63 | eval_size: int = field(
64 | default=500,
65 | metadata={
66 | "help": "Number of examples to split out from training to use for evaluation."
67 | },
68 | )
69 | meta_prompt_pattern: Optional[str] = field(
70 | default=None, metadata={"help": "Which meta prompt pattern to use."}
71 | )
72 | principle_collection_path: Optional[str] = field(
73 | default=None, metadata={"help": "Path to the principle collection."}
74 | )
75 |
76 |
77 | @dataclass
78 | class TrainingArguments(transformers.Seq2SeqTrainingArguments):
79 | cache_dir: Optional[str] = field(default=None)
80 | # From AlpacaFarm
81 | model_max_length: int = field(
82 | default=512,
83 | metadata={
84 | "help": "Maximum sequence length. Sequences will be left padded to this length always during training."
85 | },
86 | )
87 | query_len: int = field(default=None, metadata={"help": "Length of the query."})
88 | response_len: int = field(
89 | default=None, metadata={"help": "Length of the response."}
90 | )
91 | label_names: List[str] = field(
92 | default_factory=lambda: ["index_0", "index_1", "choice"],
93 | metadata={
94 | "help": "Names of the labels in the dataset. "
95 | "This is needed to get transformers.Trainer to not throw those tensors away before `compute_loss`."
96 | "By default, the trainer throws away columns it doesn't recognize when creating the "
97 | "`train_dataloader` (see `_remove_unused_columns`). "
98 | },
99 | )
100 | padding: Literal["max_length", "longest"] = field(
101 | default="longest",
102 | metadata={
103 | "help": "Padding strategy. If 'max_length', pads to `model_max_length` always; this might lead to some "
104 | "redundant compute. If 'longest', pads to the longest sequence in the batch, capped by `model_max_length`."
105 | },
106 | )
107 | end_sequence_with_eos: bool = field(
108 | default=False,
109 | metadata={
110 | "help": "Whether to end sequences with EOS. "
111 | "Ending with EOS might help the reward model realize it's time to predict."
112 | },
113 | )
114 | # From QLoRA
115 | full_finetune: bool = field(
116 | default=False, metadata={"help": "Finetune the entire model without adapters."}
117 | )
118 | adam8bit: bool = field(default=False, metadata={"help": "Use 8-bit adam."})
119 | double_quant: bool = field(
120 | default=True,
121 | metadata={
122 | "help": "Compress the quantization statistics through double quantization."
123 | },
124 | )
125 | quant_type: str = field(
126 | default="nf4",
127 | metadata={
128 | "help": "Quantization data type to use. Should be one of `fp4` or `nf4`."
129 | },
130 | )
131 | bits: int = field(default=4, metadata={"help": "How many bits to use."})
132 | lora_modules: Optional[List[str]] = field(
133 | default=None,
134 | metadata={
135 | "help": "Which modules to use LoRA on. If None, will use all linear layers."
136 | },
137 | )
138 | lora_r: int = field(default=64, metadata={"help": "Lora R dimension."})
139 | lora_alpha: float = field(default=16, metadata={"help": " Lora alpha."})
140 | lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout."})
141 | report_to: str = field(
142 | default="none",
143 | metadata={"help": "To use wandb or something else for reporting."},
144 | )
145 | resume_dir: Optional[str] = field(
146 | default=None,
147 | metadata={"help": "Path to the directory containing the checkpoint to resume."},
148 | )
149 | output_dir: str = field(
150 | default="./output", metadata={"help": "The output dir for logs and checkpoints"}
151 | )
152 | optim: str = field(
153 | default="paged_adamw_32bit", metadata={"help": "The optimizer to be used"}
154 | )
155 | per_device_train_batch_size: int = field(
156 | default=1,
157 | metadata={
158 | "help": "The training batch size per GPU. Increase for better speed."
159 | },
160 | )
161 | gradient_accumulation_steps: int = field(
162 | default=16,
163 | metadata={
164 | "help": "How many gradients to accumulate before to perform an optimizer step"
165 | },
166 | )
167 | weight_decay: float = field(
168 | default=0.0, metadata={"help": "The L2 weight decay rate of AdamW"}
169 | ) # use lora dropout instead for regularization if needed
170 | learning_rate: float = field(default=0.0002, metadata={"help": "The learnign rate"})
171 | remove_unused_columns: bool = field(
172 | default=False,
173 | metadata={"help": "Removed unused columns. Needed to make this codebase work."},
174 | )
175 | max_grad_norm: float = field(
176 | default=0.3,
177 | metadata={
178 | "help": "Gradient clipping max norm. This is tuned and works well for all models tested."
179 | },
180 | )
181 | gradient_checkpointing: bool = field(
182 | default=True,
183 | metadata={"help": "Use gradient checkpointing. You want to use this."},
184 | )
185 | do_train: bool = field(
186 | default=True,
187 | metadata={"help": "To train or not to train, that is the question?"},
188 | )
189 | lr_scheduler_type: str = field(
190 | default="constant",
191 | metadata={
192 | "help": "Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis"
193 | },
194 | )
195 | warmup_ratio: float = field(
196 | default=0.03, metadata={"help": "Fraction of steps to do a warmup for"}
197 | )
198 | logging_steps: int = field(
199 | default=10,
200 | metadata={"help": "The frequency of update steps after which to log the loss"},
201 | )
202 | group_by_length: bool = field(
203 | default=True,
204 | metadata={
205 | "help": "Group sequences into batches with same length. Saves memory and speeds up training considerably."
206 | },
207 | )
208 | save_strategy: str = field(
209 | default="steps", metadata={"help": "When to save checkpoints"}
210 | )
211 | save_steps: int = field(default=250, metadata={"help": "How often to save a model"})
212 | save_total_limit: int = field(
213 | default=40,
214 | metadata={
215 | "help": "How many checkpoints to save before the oldest is overwritten"
216 | },
217 | )
218 | resume_from_training: bool = field(
219 | default=False, metadata={"help": "Resume from training"}
220 | )
221 |
222 |
223 | def rank0_print(*args):
224 | local_rank = int(os.environ.get("LOCAL_RANK", 0))
225 | if local_rank == 0:
226 | print(*args)
227 |
228 |
229 | def train():
230 | hfparser = transformers.HfArgumentParser(
231 | (ModelArguments, DataArguments, TrainingArguments)
232 | )
233 | (
234 | model_args,
235 | data_args,
236 | training_args,
237 | extra_args,
238 | ) = hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
239 | args = argparse.Namespace(
240 | **vars(model_args), **vars(data_args), **vars(training_args)
241 | )
242 |
243 | if args.resume_dir is not None:
244 | checkpoint_dir, completed_training = args.resume_dir, False
245 | else:
246 | checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir)
247 |
248 | if completed_training:
249 | rank0_print("Detected that training was already completed!")
250 |
251 | if checkpoint_dir is None:
252 | rank0_print("Training from scratch.")
253 | else:
254 | rank0_print("Loading from checkpoint:", checkpoint_dir)
255 | if args.resume_from_training:
256 | rank0_print("Resuming from training not supported yet. Exiting.")
257 | exit(1)
258 |
259 | use_llama_base_model = (
260 | "dromedary" in args.model_name_or_path.lower()
261 | or "llama" in args.model_name_or_path.lower()
262 | or "vicuna" in args.model_name_or_path.lower()
263 | )
264 |
265 | if use_llama_base_model:
266 | tokenizer_model_name = (
267 | "TheBloke/dromedary-65b-lora-HF" # TODO(zhiqings): hacking
268 | )
269 | TokenizerClass = LlamaTokenizer
270 | else:
271 | tokenizer_model_name = args.model_name_or_path
272 | TokenizerClass = AutoTokenizer
273 |
274 | # Tokenizer
275 | tokenizer = TokenizerClass.from_pretrained(
276 | tokenizer_model_name,
277 | cache_dir=args.cache_dir,
278 | model_max_length=training_args.model_max_length,
279 | padding_side="left",
280 | truncation_side="right",
281 | )
282 |
283 | if use_llama_base_model:
284 | if tokenizer._pad_token is None:
285 | tokenizer.pad_token_id = (
286 | 0 # unk. we want this to be different from the eos token
287 | )
288 | else:
289 | raise NotImplementedError
290 |
291 | data_module = make_binary_reward_modeling_data_module(
292 | tokenizer=tokenizer,
293 | data_args=data_args,
294 | training_args=training_args,
295 | )
296 |
297 | if args.do_train:
298 | training_data = data_module["train_dataset"]
299 | rank0_print("Training data size:", len(training_data))
300 | rank0_print("Training data example:")
301 | for i in range(min(3, len(training_data))):
302 | ex_input_ids_0 = training_data[i]["input_ids"][0]
303 | rank0_print(tokenizer.decode(ex_input_ids_0, skip_special_tokens=True))
304 | rank0_print("=" * 20)
305 | ex_input_ids_1 = training_data[i]["input_ids"][1]
306 | rank0_print(tokenizer.decode(ex_input_ids_1, skip_special_tokens=True))
307 | rank0_print("=" * 20)
308 | rank0_print("=" * 20)
309 |
310 | config = RewardConfig(backbone_model_name_or_path=model_args.model_name_or_path)
311 |
312 | model = RewardModel(
313 | args=args,
314 | config=config,
315 | qlora=True,
316 | checkpoint_dir=checkpoint_dir,
317 | )
318 |
319 | model.backbone_model.config.use_cache = False
320 | print_trainable_parameters(args, model)
321 | print("loaded model")
322 | set_seed(args.seed)
323 |
324 | trainer = Trainer(
325 | model=model,
326 | tokenizer=tokenizer,
327 | args=training_args,
328 | compute_metrics=compute_reward_modeling_metrics,
329 | **{k: v for k, v in data_module.items() if k != "predict_dataset"},
330 | )
331 |
332 | # Callbacks
333 | if not args.full_finetune:
334 | trainer.add_callback(SavePeftModelCallback)
335 |
336 | # Verifying the datatypes.
337 | dtypes = {}
338 | for _, p in model.named_parameters():
339 | dtype = p.dtype
340 | if dtype not in dtypes:
341 | dtypes[dtype] = 0
342 | dtypes[dtype] += p.numel()
343 | total = 0
344 | for k, v in dtypes.items():
345 | total += v
346 | for k, v in dtypes.items():
347 | print(k, v, v / total)
348 |
349 | all_metrics = {"run_name": args.run_name}
350 |
351 | # Training
352 | if args.do_train:
353 | logger.info("*** Train ***")
354 | train_result = trainer.train()
355 | metrics = train_result.metrics
356 | trainer.log_metrics("train", metrics)
357 | trainer.save_metrics("train", metrics)
358 | trainer.save_state()
359 | all_metrics.update(metrics)
360 |
361 | # Evaluation
362 | if args.do_eval:
363 | logger.info("*** Evaluate ***")
364 | metrics = trainer.evaluate(metric_key_prefix="eval")
365 | trainer.log_metrics("eval", metrics)
366 | trainer.save_metrics("eval", metrics)
367 | all_metrics.update(metrics)
368 |
369 | if args.do_train or args.do_eval:
370 | with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout:
371 | fout.write(json.dumps(all_metrics))
372 |
373 |
374 | if __name__ == "__main__":
375 | train()
376 |
--------------------------------------------------------------------------------
/training/train_qlora_sft.py:
--------------------------------------------------------------------------------
1 | # This source code is licensed under the MIT license found in the
2 | # LICENSE file in the root directory of this source tree.
3 |
4 | import json
5 |
6 | import os
7 | from dataclasses import dataclass, field
8 | from typing import Optional, List
9 | import logging
10 |
11 | import torch
12 | import transformers
13 | import argparse
14 | from transformers import (
15 | set_seed,
16 | Trainer,
17 | )
18 |
19 | try:
20 | from transformers import LlamaTokenizerFast as LlamaTokenizer
21 |
22 | print("Using fast tokenizer")
23 | except:
24 | from transformers import LlamaTokenizer
25 |
26 | print("Using slow tokenizer")
27 |
28 | from transformers import AutoTokenizer
29 |
30 | from qlora_utils import (
31 | SavePeftModelCallback,
32 | print_trainable_parameters,
33 | get_last_checkpoint,
34 | DEFAULT_PAD_TOKEN,
35 | )
36 | from data_utils.data_utils_sft import (
37 | make_sft_data_module,
38 | IGNORE_INDEX,
39 | )
40 | from models.qlora_model import get_accelerate_model
41 |
42 |
43 | torch.backends.cuda.matmul.allow_tf32 = True
44 |
45 | logger = logging.getLogger(__name__)
46 |
47 |
48 | @dataclass
49 | class ModelArguments:
50 | model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-12b")
51 | trust_remote_code: Optional[bool] = field(
52 | default=False,
53 | metadata={
54 | "help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."
55 | },
56 | )
57 |
58 |
59 | @dataclass
60 | class DataArguments:
61 | eval_dataset_size: int = field(
62 | default=1024, metadata={"help": "Size of validation dataset."}
63 | )
64 | max_train_samples: Optional[int] = field(
65 | default=None,
66 | metadata={
67 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
68 | "value if set."
69 | },
70 | )
71 | max_eval_samples: Optional[int] = field(
72 | default=None,
73 | metadata={
74 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
75 | "value if set."
76 | },
77 | )
78 | source_max_len: int = field(
79 | default=1024,
80 | metadata={
81 | "help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."
82 | },
83 | )
84 | target_max_len: int = field(
85 | default=256,
86 | metadata={
87 | "help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."
88 | },
89 | )
90 | dataset: str = field(
91 | default="alpaca",
92 | metadata={"help": "Which dataset to finetune on. See datamodule for options."},
93 | )
94 | dataset_format: Optional[str] = field(
95 | default=None,
96 | metadata={
97 | "help": "Which dataset format is used. [alpaca|chip2|self-instruct|hh-rlhf]"
98 | },
99 | )
100 | meta_prompt_pattern: Optional[str] = field(
101 | default=None, metadata={"help": "Which meta prompt pattern to use."}
102 | )
103 | add_eos_to_target: bool = field(
104 | default=True, metadata={"help": "Whether to add an EOS token to the target."}
105 | )
106 |
107 |
108 | @dataclass
109 | class TrainingArguments(transformers.Seq2SeqTrainingArguments):
110 | cache_dir: Optional[str] = field(default=None)
111 | train_on_source: Optional[bool] = field(
112 | default=False,
113 | metadata={
114 | "help": "Whether to train on the input in addition to the target text."
115 | },
116 | )
117 | full_finetune: bool = field(
118 | default=False, metadata={"help": "Finetune the entire model without adapters."}
119 | )
120 | adam8bit: bool = field(default=False, metadata={"help": "Use 8-bit adam."})
121 | double_quant: bool = field(
122 | default=True,
123 | metadata={
124 | "help": "Compress the quantization statistics through double quantization."
125 | },
126 | )
127 | quant_type: str = field(
128 | default="nf4",
129 | metadata={
130 | "help": "Quantization data type to use. Should be one of `fp4` or `nf4`."
131 | },
132 | )
133 | bits: int = field(default=4, metadata={"help": "How many bits to use."})
134 | lora_modules: Optional[List[str]] = field(
135 | default=None,
136 | metadata={
137 | "help": "Which modules to use LoRA on. If None, will use all linear layers."
138 | },
139 | )
140 | lora_r: int = field(default=64, metadata={"help": "Lora R dimension."})
141 | lora_alpha: float = field(default=16, metadata={"help": " Lora alpha."})
142 | lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout."})
143 | max_memory_MB: int = field(default=80000, metadata={"help": "Free memory per gpu."})
144 | report_to: str = field(
145 | default="none",
146 | metadata={"help": "To use wandb or something else for reporting."},
147 | )
148 | resume_dir: Optional[str] = field(
149 | default=None,
150 | metadata={"help": "Path to the directory containing the checkpoint to resume."},
151 | )
152 | output_dir: str = field(
153 | default="./output", metadata={"help": "The output dir for logs and checkpoints"}
154 | )
155 | optim: str = field(
156 | default="paged_adamw_32bit", metadata={"help": "The optimizer to be used"}
157 | )
158 | per_device_train_batch_size: int = field(
159 | default=1,
160 | metadata={
161 | "help": "The training batch size per GPU. Increase for better speed."
162 | },
163 | )
164 | gradient_accumulation_steps: int = field(
165 | default=16,
166 | metadata={
167 | "help": "How many gradients to accumulate before to perform an optimizer step"
168 | },
169 | )
170 | weight_decay: float = field(
171 | default=0.0, metadata={"help": "The L2 weight decay rate of AdamW"}
172 | ) # use lora dropout instead for regularization if needed
173 | learning_rate: float = field(default=0.0002, metadata={"help": "The learnign rate"})
174 | remove_unused_columns: bool = field(
175 | default=False,
176 | metadata={"help": "Removed unused columns. Needed to make this codebase work."},
177 | )
178 | max_grad_norm: float = field(
179 | default=0.3,
180 | metadata={
181 | "help": "Gradient clipping max norm. This is tuned and works well for all models tested."
182 | },
183 | )
184 | gradient_checkpointing: bool = field(
185 | default=True,
186 | metadata={"help": "Use gradient checkpointing. You want to use this."},
187 | )
188 | do_train: bool = field(
189 | default=True,
190 | metadata={"help": "To train or not to train, that is the question?"},
191 | )
192 | lr_scheduler_type: str = field(
193 | default="constant",
194 | metadata={
195 | "help": "Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis"
196 | },
197 | )
198 | warmup_ratio: float = field(
199 | default=0.03, metadata={"help": "Fraction of steps to do a warmup for"}
200 | )
201 | logging_steps: int = field(
202 | default=10,
203 | metadata={"help": "The frequency of update steps after which to log the loss"},
204 | )
205 | group_by_length: bool = field(
206 | default=True,
207 | metadata={
208 | "help": "Group sequences into batches with same length. Saves memory and speeds up training considerably."
209 | },
210 | )
211 | save_strategy: str = field(
212 | default="steps", metadata={"help": "When to save checkpoints"}
213 | )
214 | save_steps: int = field(default=250, metadata={"help": "How often to save a model"})
215 | save_total_limit: int = field(
216 | default=40,
217 | metadata={
218 | "help": "How many checkpoints to save before the oldest is overwritten"
219 | },
220 | )
221 | resume_from_training: bool = field(
222 | default=False, metadata={"help": "Resume from training"}
223 | )
224 |
225 |
226 | def rank0_print(*args):
227 | local_rank = int(os.environ.get("LOCAL_RANK", 0))
228 | if local_rank == 0:
229 | print(*args)
230 |
231 |
232 | def train():
233 | hfparser = transformers.HfArgumentParser(
234 | (ModelArguments, DataArguments, TrainingArguments)
235 | )
236 | (
237 | model_args,
238 | data_args,
239 | training_args,
240 | extra_args,
241 | ) = hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
242 | args = argparse.Namespace(
243 | **vars(model_args), **vars(data_args), **vars(training_args)
244 | )
245 |
246 | if args.resume_dir is not None:
247 | checkpoint_dir, completed_training = args.resume_dir, False
248 | else:
249 | checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir)
250 |
251 | if completed_training:
252 | rank0_print("Detected that training was already completed!")
253 |
254 | if checkpoint_dir is None:
255 | rank0_print("Training from scratch.")
256 | else:
257 | rank0_print("Loading from checkpoint:", checkpoint_dir)
258 | if args.resume_from_training:
259 | rank0_print("Resuming from training not supported yet. Exiting.")
260 | exit(1)
261 |
262 | use_llama_base_model = (
263 | "dromedary" in args.model_name_or_path.lower()
264 | or "llama" in args.model_name_or_path.lower()
265 | )
266 |
267 | if use_llama_base_model:
268 | tokenizer_model_name = (
269 | "TheBloke/dromedary-65b-lora-HF" # a random llama-based model
270 | )
271 | TokenizerClass = LlamaTokenizer
272 | else:
273 | tokenizer_model_name = args.model_name_or_path
274 | TokenizerClass = AutoTokenizer
275 |
276 | left_truncated_tokenizer = TokenizerClass.from_pretrained(
277 | tokenizer_model_name,
278 | cache_dir=args.cache_dir,
279 | truncation_side="left",
280 | padding_side="right",
281 | # use_fast=False, # Fast tokenizer giving issues.
282 | )
283 |
284 | # Tokenizer
285 | tokenizer = TokenizerClass.from_pretrained(
286 | tokenizer_model_name,
287 | cache_dir=args.cache_dir,
288 | truncation_side="right",
289 | padding_side="right",
290 | )
291 |
292 | if use_llama_base_model:
293 | if tokenizer._pad_token is None:
294 | left_truncated_tokenizer.pad_token_id = (
295 | 0 # unk. we want this to be different from the eos token
296 | )
297 | tokenizer.pad_token_id = (
298 | 0 # unk. we want this to be different from the eos token
299 | )
300 |
301 | else:
302 | raise NotImplementedError
303 |
304 | data_module = make_sft_data_module(
305 | left_truncated_tokenizer=left_truncated_tokenizer,
306 | tokenizer=tokenizer,
307 | args=args,
308 | )
309 |
310 | model = get_accelerate_model(args, checkpoint_dir)
311 |
312 | model.config.use_cache = False
313 | print_trainable_parameters(args, model)
314 | print("loaded model")
315 | set_seed(args.seed)
316 |
317 | if args.do_train:
318 | training_data = data_module["train_dataset"]
319 | rank0_print("Training data size:", len(training_data))
320 | rank0_print("Training data example:")
321 | for i in range(min(3, len(training_data))):
322 | rank0_print(training_data[i])
323 |
324 | trainer = Trainer(
325 | model=model,
326 | tokenizer=tokenizer,
327 | args=training_args,
328 | **{k: v for k, v in data_module.items() if k != "predict_dataset"},
329 | )
330 |
331 | # Callbacks
332 | if not args.full_finetune:
333 | trainer.add_callback(SavePeftModelCallback)
334 |
335 | # Verifying the datatypes.
336 | dtypes = {}
337 | for _, p in model.named_parameters():
338 | dtype = p.dtype
339 | if dtype not in dtypes:
340 | dtypes[dtype] = 0
341 | dtypes[dtype] += p.numel()
342 | total = 0
343 | for k, v in dtypes.items():
344 | total += v
345 | for k, v in dtypes.items():
346 | print(k, v, v / total)
347 |
348 | all_metrics = {"run_name": args.run_name}
349 | # Training
350 | if args.do_train:
351 | logger.info("*** Train ***")
352 | # Note: `resume_from_checkpoint` not supported for adapter checkpoints by HF.
353 | # Currently adapter checkpoint is reloaded as expected but optimizer/scheduler states are not.
354 | train_result = trainer.train()
355 | metrics = train_result.metrics
356 | trainer.log_metrics("train", metrics)
357 | trainer.save_metrics("train", metrics)
358 | trainer.save_state()
359 | all_metrics.update(metrics)
360 |
361 | if args.do_train:
362 | with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout:
363 | fout.write(json.dumps(all_metrics))
364 |
365 |
366 | if __name__ == "__main__":
367 | train()
368 |
--------------------------------------------------------------------------------