├── .gitignore ├── DATA_LICENSE ├── LICENSE ├── README.md ├── assets └── images │ ├── Dromedary-2.png │ ├── salmon_comparison.png │ ├── salmon_logo_with_text.jpeg │ └── salmon_pipeline.png ├── inference ├── README.md └── demo.py ├── prompts ├── dromedary_inference_prompt.txt ├── pmp_reward_model_prompt.txt ├── principles │ ├── principle_collection_harmless.json │ ├── principle_collection_honest.json │ ├── principle_collection_non_evasive.json │ ├── principle_collection_ppo.json │ └── principle_collection_rm.json ├── salmon_reward_model_prompt_v0.txt ├── salmon_reward_model_prompt_v1.txt └── synthetic_preference_prompt.txt ├── requirements.txt └── training ├── README.md ├── data_utils ├── common_utils.py ├── data_utils_ppo.py ├── data_utils_rm.py └── data_utils_sft.py ├── models ├── configuration_llama.py ├── distributed_utils.py ├── llama_with_flash_attn.py ├── ppo_trainer.py ├── qlora_model.py ├── reward_model.py ├── rl_models.py ├── rl_trainer.py └── trainer_utils.py ├── qlora_utils.py ├── step1_synthetic_preference_collection ├── batch_generation.py ├── clean_oasst1_prompts.py ├── scripts │ ├── generate_oasst1_response0.sh │ ├── generate_oasst1_response1.sh │ └── generate_synthetic_preference.sh └── synthetic_preference.py ├── step2_rm_training ├── aggregate_synthetic_preference.py ├── clean_pmp_data.py └── scripts │ ├── train_reward_model_70b_qlora_ft.sh │ └── train_reward_model_70b_qlora_pmp.sh ├── step3_ppo_training ├── aggregate_sharegpt_prompts.py ├── clean_and_merge_prompts.py ├── scripts │ └── train_ppo_model_70b_qlora_salmon.sh └── subsample_openorca_prompts.py ├── train_qlora_ppo.py ├── train_qlora_rm.py └── train_qlora_sft.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .DS_Store 132 | .idea 133 | 134 | # temporary scripts 135 | tmp_scripts/ 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | SALMON Logo 3 |

Generated by DALL·E 3

4 |
5 | 6 |
7 | 8 | 9 | 10 | ## SALMON: Self-Alignment with Principle-Following Reward Models 11 | 12 |
13 | 14 | [![Code License](https://img.shields.io/badge/Code%20License-GPL_3.0-green.svg)](LICENSE) 15 | [![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg)](DATA_LICENSE) 16 | 17 | 18 | 19 | 20 | 21 | 44 | 45 | SALMON is a new RLAIF paradigm for self-aligning language models from scratch, using only a small set of human-defined principles as guidance. 46 | Central to our approach is a principle-following reward model. Trained on synthetic preference data, this model can generate reward scores based on arbitrary human-defined principles. 47 | For comprehensive details and insights, we kindly direct you to our [paper](https://arxiv.org/abs/2310.05910). 48 | 49 |

50 | 51 | 52 | SALMON Comparison 53 | 54 | 56 | 57 |

58 | 59 | ## Dromedary-2 60 | 61 | 62 | 63 | We release the *Dromedary-2* model, which is trained with the SALMON paradigm on the [*LLaMA-2-70b* base language model](https://huggingface.co/meta-llama/Llama-2-70b-hf), with [Principle-Driven Self-Alignment](https://github.com/IBM/Dromedary) as the Supervised Fine-Tuning (SFT) strategy to initialize the policy model. 64 | 65 | This codebase focuses on the **Reinforcement Learning (RL)** fine-tuning stage with the SALMON paradigm, while the Self-Align SFT training pipeline is released at the [original Dromedary repo](https://github.com/IBM/Dromedary), 66 | 67 | Dromedary-2 Pipeline 68 | 69 | ### Model Weights 70 | 71 | We release *Dromedary-2* weights as delta weights to comply with the LLaMA model license. You can directly load our QLoRA weights upon the *LLaMA-2* base model to obtain *Dromedary-2*. Instructions: 72 | 73 | 1. Get the original LLaMA-2 weights in the Hugging Face format by following the instructions [here](https://huggingface.co/meta-llama/Llama-2-70b-hf). 74 | 2. Download the QLoRA delta weights from our Hugging Face [model hub](https://huggingface.co/zhiqings/dromedary-2-70b-qlora-delta-v0). 75 | 3. Load the model with Hugging Face's [PEFT-LoRA](https://github.com/huggingface/peft) and QLoRA's [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). 76 | 77 | **NOTE: *Dromedary-2* is trained with QLoRA and the bfloat16 data type.** While it is [possible](https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930) to merge the QLoRA weights with the quantized model and thus enable inference with libraries such as [TGI](https://github.com/huggingface/text-generation-inference) and [vLLM](https://github.com/vllm-project/vllm), we found the merged weights can lead to degenerated performance. Therefore, we recommend directly loading the QLoRA weights with the PEFT-LoRA framework. 78 | 79 | ```python 80 | # Please check the inference section for the complete inference code. 81 | system_prompt = ( 82 | "# Dromedary\n\n## System Overview\n\n" 83 | "Consider an AI assistant whose codename is Dromedary, developed by the Self-Align team. " 84 | "Dromedary is trained on data up until Sept-2022, and it endeavors to be a helpful, ethical and reliable assistant.\n\n" 85 | "## User Conversation\n\n" 86 | ) 87 | user_prompt = "### User\n" 88 | assistant_prompt = "### Dromedary\n" 89 | seperator = "\n\n" 90 | 91 | # USAGE: system_prompt + user_prompt + `user_message` + seperator + assistant_prompt + `assistant_message` + seperator + user_prompt ... 92 | 93 | dtype = torch.bfloat16 94 | 95 | model_path = "path/to/llama-2-70b-hf" 96 | qlora_path = "path/to/dromedary-2-70b-qlora-delta-v0" 97 | 98 | bnb_config = BitsAndBytesConfig( 99 | load_in_4bit=True, 100 | bnb_4bit_compute_dtype=dtype, 101 | bnb_4bit_use_double_quant=True, 102 | bnb_4bit_quant_type="nf4", 103 | ) 104 | 105 | model = AutoModelForCausalLM.from_pretrained( 106 | model_path, 107 | load_in_4bit=True, 108 | device_map={"": "cuda:0"}, 109 | quantization_config=bnb_config, 110 | torch_dtype=dtype, 111 | ) 112 | 113 | model = PeftModel.from_pretrained( 114 | model, 115 | qlora_path, 116 | is_trainable=False, 117 | ) 118 | ``` 119 | 120 | ## Setup 121 | 122 | 1. Clone this repository and navigate to SALMON folder 123 | 124 | ```Shell 125 | git clone https://github.com/IBM/SALMON 126 | cd SALMON 127 | ``` 128 | 129 | 2. Install the packages 130 | 131 | ```Shell 132 | conda create -n salmon python=3.9 -y 133 | conda activate salmon 134 | pip install -r requirements.txt 135 | ``` 136 | 137 | ## Inference 138 | 139 | We provide a [chatbot demo](inference) for *Dromedary-2*. 140 | 141 | ## Training 142 | 143 | We provide the full [training pipeline](training) of *Dromedary-2* for reproduction. 144 | 145 | ## Prompts 146 | 147 | All the human supervision used in this project can be found [here](prompts). 148 | 149 | ### Citation 150 | 151 | Please consider citing the following papers if you use the data or code in this repo. 152 | 153 | ``` 154 | @misc{sun2023principledriven, 155 | title={Principle-Driven Self-Alignment of Language Models from Scratch with Minimal Human Supervision}, 156 | author={Zhiqing Sun and Yikang Shen and Qinhong Zhou and Hongxin Zhang and Zhenfang Chen and David Cox and Yiming Yang and Chuang Gan}, 157 | year={2023}, 158 | eprint={2305.03047}, 159 | archivePrefix={arXiv}, 160 | primaryClass={cs.LG} 161 | } 162 | ``` 163 | 164 | ``` 165 | @misc{sun2023salmon, 166 | title={SALMON: Self-Alignment with Principle-Following Reward Models}, 167 | author={Zhiqing Sun and Yikang Shen and Hongxin Zhang and Qinhong Zhou and Zhenfang Chen and David Cox and Yiming Yang and Chuang Gan}, 168 | year={2023}, 169 | eprint={2310.05910}, 170 | archivePrefix={arXiv}, 171 | primaryClass={cs.LG} 172 | } 173 | ``` 174 | 175 | ### Acknowledgements 176 | 177 | We thank [Meta LLaMA team](https://github.com/facebookresearch/llama), [Standford Alpaca team](https://github.com/tatsu-lab/stanford_alpaca), [Vicuna team](https://github.com/lm-sys/FastChat), [Alpaca-LoRA](https://github.com/tloen/alpaca-lora), [QLoRA team](https://github.com/artidoro/qlora), [Hugging Face PEFT](https://github.com/huggingface/peft), and [AlpacaFarm team](https://github.com/tatsu-lab/alpaca_farm) for their open-source efforts in democratizing large language models. 178 | -------------------------------------------------------------------------------- /assets/images/Dromedary-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/SALMON/7392a5618c19a59d218ebbc2f50a5940a92304c5/assets/images/Dromedary-2.png -------------------------------------------------------------------------------- /assets/images/salmon_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/SALMON/7392a5618c19a59d218ebbc2f50a5940a92304c5/assets/images/salmon_comparison.png -------------------------------------------------------------------------------- /assets/images/salmon_logo_with_text.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/SALMON/7392a5618c19a59d218ebbc2f50a5940a92304c5/assets/images/salmon_logo_with_text.jpeg -------------------------------------------------------------------------------- /assets/images/salmon_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/SALMON/7392a5618c19a59d218ebbc2f50a5940a92304c5/assets/images/salmon_pipeline.png -------------------------------------------------------------------------------- /inference/README.md: -------------------------------------------------------------------------------- 1 | # Demo 2 | 3 | To run our demo, you need to prepare *Dromedary-2* checkpoints locally. Please follow the [instructions here](https://github.com/IBM/SALMON#model-weights). This demo is adpted from [Guanaco demo notebook](https://colab.research.google.com/drive/17XEqL1JcmVWjHkT-WczdYkJlNINacwG7?usp=sharing). The code is tested on a single A100-80GB GPU, with a peak GPU memory usage < 48GB. 4 | 5 | ```bash 6 | python -u demo.py \ 7 | --model-name "/path/to/llama-2-70b-hf" \ 8 | --adapters-name "/path/to/qlora_adapter_models" 9 | ``` 10 | -------------------------------------------------------------------------------- /inference/demo.py: -------------------------------------------------------------------------------- 1 | # Load the model. 2 | # Note: It can take a while to download LLaMA and add the adapter modules. 3 | 4 | import torch 5 | import fire 6 | 7 | import datetime 8 | import os 9 | from threading import Event, Thread 10 | from uuid import uuid4 11 | 12 | import gradio as gr 13 | import requests 14 | 15 | from peft import PeftModel 16 | from transformers import ( 17 | BitsAndBytesConfig, 18 | AutoModelForCausalLM, 19 | LlamaTokenizerFast, 20 | StoppingCriteria, 21 | StoppingCriteriaList, 22 | TextIteratorStreamer, 23 | ) 24 | 25 | torch.backends.cuda.matmul.allow_tf32 = True 26 | 27 | 28 | def main( 29 | model_name: str = "/path/to/llama-2-70b-hf", 30 | adapters_name: str = "/path/to/adapter_model", 31 | tokenizer_name: str = "TheBloke/dromedary-65b-lora-HF", # a random loadable llama tokenizer 32 | max_queue_size: int = 16, 33 | sharable_link: bool = True, 34 | ): 35 | print(f"Starting to load the model {model_name} into memory") 36 | 37 | m = AutoModelForCausalLM.from_pretrained( 38 | model_name, 39 | load_in_4bit=True, 40 | torch_dtype=torch.bfloat16, 41 | device_map={"": 0}, 42 | quantization_config=BitsAndBytesConfig( 43 | load_in_4bit=True, 44 | bnb_4bit_compute_dtype=torch.bfloat16, 45 | bnb_4bit_use_double_quant=True, 46 | bnb_4bit_quant_type="nf4", 47 | ), 48 | ) 49 | m = PeftModel.from_pretrained(m, adapters_name, is_trainable=False) 50 | 51 | start_message = ( 52 | "# Dromedary" 53 | "\n\n## System Overview" 54 | "\n\nConsider an AI assistant whose codename is Dromedary, developed by the Self-Align team. " 55 | "Dromedary is trained on data from before Sept-2022, and it endeavors to be a helpful, ethical and reliable assistant." 56 | "\n\n## User Conversation\n\n" 57 | ) 58 | 59 | tok = LlamaTokenizerFast.from_pretrained(tokenizer_name) 60 | 61 | print(f"Successfully loaded the model {model_name} into memory") 62 | 63 | # Setup the gradio Demo. 64 | 65 | stop_tokens = ["### User"] 66 | 67 | class StopOnTokens(StoppingCriteria): 68 | def __call__( 69 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 70 | ) -> bool: 71 | decoded_tokens = tok.batch_decode(input_ids, skip_special_tokens=True) 72 | 73 | for stop_token in stop_tokens: 74 | if decoded_tokens[0].strip().endswith(stop_token): 75 | return True 76 | return False 77 | 78 | def convert_history_to_text(history): 79 | text = start_message + "".join( 80 | [ 81 | "".join( 82 | [ 83 | f"### User\n{item[0]}\n\n", 84 | f"### Dromedary\n{item[1]}\n\n", 85 | ] 86 | ) 87 | for item in history[:-1] 88 | ] 89 | ) 90 | text += "".join( 91 | [ 92 | "".join( 93 | [ 94 | f"### User\n{history[-1][0]}\n\n", 95 | f"### Dromedary\n{history[-1][1]}", 96 | ] 97 | ) 98 | ] 99 | ) 100 | return text 101 | 102 | def log_conversation(conversation_id, history, messages, generate_kwargs): 103 | logging_url = os.getenv("LOGGING_URL", None) 104 | if logging_url is None: 105 | return 106 | 107 | timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") 108 | 109 | data = { 110 | "conversation_id": conversation_id, 111 | "timestamp": timestamp, 112 | "history": history, 113 | "messages": messages, 114 | "generate_kwargs": generate_kwargs, 115 | } 116 | 117 | try: 118 | requests.post(logging_url, json=data) 119 | except requests.exceptions.RequestException as e: 120 | print(f"Error logging conversation: {e}") 121 | 122 | def user(message, history): 123 | # Append the user's message to the conversation history 124 | return "", history + [[message, ""]] 125 | 126 | def bot( 127 | history, temperature, top_p, max_new_tokens, repetition_penalty, conversation_id 128 | ): 129 | # Initialize a StopOnTokens object 130 | stop = StopOnTokens() 131 | 132 | # Construct the input message string for the model by concatenating the current system message and conversation history 133 | messages = convert_history_to_text(history) 134 | 135 | print(f"history:") 136 | print(messages) 137 | 138 | # Tokenize the messages string 139 | input_ids = tok(messages, return_tensors="pt").input_ids 140 | input_ids = input_ids.to(m.device) 141 | streamer = TextIteratorStreamer( 142 | tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True 143 | ) 144 | generate_kwargs = dict( 145 | input_ids=input_ids, 146 | max_new_tokens=max_new_tokens, 147 | temperature=temperature, 148 | do_sample=temperature > 0.0, 149 | top_p=top_p, 150 | repetition_penalty=repetition_penalty, 151 | streamer=streamer, 152 | stopping_criteria=StoppingCriteriaList([stop]), 153 | ) 154 | 155 | stream_complete = Event() 156 | 157 | def generate_and_signal_complete(): 158 | m.generate(**generate_kwargs) 159 | stream_complete.set() 160 | 161 | def log_after_stream_complete(): 162 | stream_complete.wait() 163 | log_conversation( 164 | conversation_id, 165 | history, 166 | messages, 167 | { 168 | "max_new_tokens": max_new_tokens, 169 | "top_p": top_p, 170 | "temperature": temperature, 171 | "repetition_penalty": repetition_penalty, 172 | }, 173 | ) 174 | 175 | t1 = Thread(target=generate_and_signal_complete) 176 | t1.start() 177 | 178 | t2 = Thread(target=log_after_stream_complete) 179 | t2.start() 180 | 181 | # Initialize an empty string to store the generated text 182 | partial_text = "" 183 | for new_text in streamer: 184 | partial_text += new_text 185 | history[-1][1] = partial_text.split("### User")[0] 186 | yield history 187 | 188 | print(f"output:") 189 | print(history[-1][1]) 190 | 191 | def get_uuid(): 192 | return str(uuid4()) 193 | 194 | with gr.Blocks( 195 | theme=gr.themes.Soft(), 196 | css=".disclaimer {font-variant-caps: all-small-caps;}", 197 | ) as demo: 198 | conversation_id = gr.State(get_uuid) 199 | gr.Markdown( 200 | """Dromedary-2 Demo 201 | """ 202 | ) 203 | chatbot = gr.Chatbot().style(height=500) 204 | with gr.Row(): 205 | with gr.Column(): 206 | msg = gr.Textbox( 207 | label="Chat Message Box", 208 | placeholder="Chat Message Box", 209 | show_label=False, 210 | ).style(container=False) 211 | with gr.Column(): 212 | with gr.Row(): 213 | submit = gr.Button("Submit") 214 | # stop = gr.Button("Stop") 215 | clear = gr.Button("Clear") 216 | with gr.Row(): 217 | with gr.Accordion("Advanced Options:", open=False): 218 | with gr.Row(): 219 | with gr.Column(): 220 | with gr.Row(): 221 | temperature = gr.Slider( 222 | label="Temperature", 223 | value=0.7, 224 | minimum=0.0, 225 | maximum=1.0, 226 | step=0.1, 227 | interactive=True, 228 | info="Higher values produce more diverse outputs", 229 | ) 230 | with gr.Column(): 231 | with gr.Row(): 232 | top_p = gr.Slider( 233 | label="Top-p (nucleus sampling)", 234 | value=0.9, 235 | minimum=0.0, 236 | maximum=1, 237 | step=0.01, 238 | interactive=True, 239 | info=( 240 | "Sample from the smallest possible set of tokens whose cumulative probability " 241 | "exceeds top_p. Set to 1 to disable and sample from all tokens." 242 | ), 243 | ) 244 | with gr.Column(): 245 | with gr.Row(): 246 | max_new_tokens = gr.Slider( 247 | label="Response Length", 248 | value=768, 249 | minimum=64, 250 | maximum=1024, 251 | step=64, 252 | interactive=True, 253 | info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.", 254 | ) 255 | with gr.Column(): 256 | with gr.Row(): 257 | repetition_penalty = gr.Slider( 258 | label="Repetition Penalty", 259 | value=1.0, 260 | minimum=1.0, 261 | maximum=2.0, 262 | step=0.1, 263 | interactive=True, 264 | info="Penalize repetition — 1.0 to disable.", 265 | ) 266 | with gr.Row(): 267 | gr.Markdown( 268 | "Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce " 269 | "factually accurate information. The model was trained on various public datasets; while great efforts " 270 | "have been taken to clean the pretraining data, it is possible that this model could generate lewd, " 271 | "biased, or otherwise offensive outputs.", 272 | elem_classes=["disclaimer"], 273 | ) 274 | with gr.Row(): 275 | gr.Markdown( 276 | "[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)", 277 | elem_classes=["disclaimer"], 278 | ) 279 | 280 | submit_event = msg.submit( 281 | fn=user, 282 | inputs=[msg, chatbot], 283 | outputs=[msg, chatbot], 284 | queue=False, 285 | ).then( 286 | fn=bot, 287 | inputs=[ 288 | chatbot, 289 | temperature, 290 | top_p, 291 | max_new_tokens, 292 | repetition_penalty, 293 | conversation_id, 294 | ], 295 | outputs=chatbot, 296 | queue=True, 297 | ) 298 | submit_click_event = submit.click( 299 | fn=user, 300 | inputs=[msg, chatbot], 301 | outputs=[msg, chatbot], 302 | queue=False, 303 | ).then( 304 | fn=bot, 305 | inputs=[ 306 | chatbot, 307 | temperature, 308 | top_p, 309 | max_new_tokens, 310 | repetition_penalty, 311 | conversation_id, 312 | ], 313 | outputs=chatbot, 314 | queue=True, 315 | ) 316 | # stop.click( 317 | # fn=None, 318 | # inputs=None, 319 | # outputs=None, 320 | # cancels=[submit_event, submit_click_event], 321 | # queue=False, 322 | # ) 323 | clear.click(lambda: None, None, chatbot, queue=False) 324 | 325 | demo.queue(max_size=max_queue_size, concurrency_count=1) 326 | 327 | # Launch your Dromedary-2 Demo! 328 | demo.launch(share=sharable_link) 329 | 330 | 331 | if __name__ == "__main__": 332 | fire.Fire(main) 333 | -------------------------------------------------------------------------------- /prompts/dromedary_inference_prompt.txt: -------------------------------------------------------------------------------- 1 | # Dromedary 2 | 3 | ## System Overview 4 | 5 | Consider an AI assistant whose codename is Dromedary, developed by the Self-Align team. Dromedary is trained on data from before Sept-2022, and it endeavors to be a helpful, ethical and reliable assistant. 6 | 7 | ## User Conversation 8 | 9 | ### User -------------------------------------------------------------------------------- /prompts/pmp_reward_model_prompt.txt: -------------------------------------------------------------------------------- 1 | # System 2 | 3 | You are a reviewer whose goal is to judge the quality of the AI system's responses to instructions. 4 | 5 | ## Annotation Example 6 | 7 | ### AI system's Response 8 | {Output} 9 | 10 | ### Instruction to the AI system 11 | {Input} 12 | 13 | ## Annotation Guideline 14 | 15 | Your task is to decide the quality of the response. There are several dimensions that you can think along. Consider the following questions: 16 | 17 | - Is the response helpful? For example, if the instruction asked for a recipe for healthy food, and the response is a useful recipe, then we can consider it helpful. 18 | - Is the response language natural? For example, AI responses often have repetitions, which is not natural. 19 | - Is the response factual/accurate? For example, AI responses often make up new information. For example, if the response claims that Donald Trump is the current U.S. president, then you should consider it inaccurate. 20 | - and so on ... ultimately, you should decide whether the response is good or bad. 21 | 22 | ## Reviewer 23 | The quality of the output is -------------------------------------------------------------------------------- /prompts/principles/principle_collection_harmless.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "dimension": "Non-Discriminatory", 4 | "definition": "The AI should not produce content that is offensive, discriminatory, or prejudiced in nature, either directly or through subtext or bias.", 5 | "weight": 10 6 | }, 7 | { 8 | "dimension": "Refusal to Aid Dangerous Acts", 9 | "definition": "When prompted to assist in harmful activities (e.g. building a bomb), the AI should promptly and politely decline. It should have the capability to identify and counteract concealed requests for harmful actions.", 10 | "weight": 5 11 | }, 12 | { 13 | "dimension": "Consequential Advice Awareness", 14 | "definition": "The AI should be cognizant when it might be providing critical or sensitive guidance. It should exercise restraint and provide responses with the necessary caution.", 15 | "weight": 4 16 | }, 17 | { 18 | "dimension": "Cultural Sensitivity", 19 | "definition": "Understanding that perceptions of harm vary among individuals and cultures, the AI should be adaptive and respectful. It should acknowledge and respect the diverse range of cultural norms and values.", 20 | "weight": 4 21 | }, 22 | { 23 | "dimension": "Context Awareness", 24 | "definition": "The AI's responses should be tailored to the nature of the user query, taking into account who is interacting with the AI, as well as the situational context in which the assistant is being engaged.", 25 | "weight": 4 26 | }, 27 | { 28 | "dimension": "Guard Against Amplification", 29 | "definition": "To prevent reinforcing harmful beliefs or stereotypes, the AI should avoid echoing or amplifying potentially harmful statements made by users, instead promoting neutral or positive alternatives.", 30 | "weight": 4 31 | }, 32 | { 33 | "dimension": "Empathy-driven Responses", 34 | "definition": "The AI should gauge the emotional tone of user queries and respond with sensitivity, especially in situations where users seem distressed or vulnerable.", 35 | "weight": 1 36 | }, 37 | { 38 | "dimension": "Proactive Harm Prevention", 39 | "definition": "Rather than just reactively refusing harmful requests, the AI should actively provide alternatives or suggestions that promote safety and well-being.", 40 | "weight": 1 41 | }, 42 | { 43 | "dimension": "Transparency on Biases", 44 | "definition": "The AI should openly acknowledge when its responses might be influenced by historical data biases, offering the user a clearer picture of its limitations.", 45 | "weight": 1 46 | }, 47 | { 48 | "dimension": "Evolutive Learning", 49 | "definition": "The AI should continuously learn and adapt from feedback on harmful outputs, minimizing the risk of repeating harmful behaviors in future interactions.", 50 | "weight": 1 51 | }, 52 | { 53 | "dimension": "Mental Health Awareness", 54 | "definition": "Recognizing that conversations can sometimes touch on mental health issues, the AI should handle such interactions with extreme care, always prioritizing the well-being of the user and never offering medical advice unless explicitly trained to do so.", 55 | "weight": 1 56 | } 57 | ] -------------------------------------------------------------------------------- /prompts/principles/principle_collection_honest.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "dimension": "Unwavering Accuracy", 4 | "definition": "The AI must consistently provide accurate and verifiable information in its responses.", 5 | "weight": 5 6 | }, 7 | { 8 | "dimension": "Consistent Calibration", 9 | "definition": "The AI should ensure that its confidence assertions align with its actual accuracy, reflecting genuine knowledge estimation.", 10 | "weight": 5 11 | }, 12 | { 13 | "dimension": "Transparent Uncertainty Expression", 14 | "definition": "When uncertain, the AI should transparently express its level of uncertainty, ensuring users are never misled by vague or overly confident statements.", 15 | "weight": 4 16 | }, 17 | { 18 | "dimension": "Honest Self-representation", 19 | "definition": "The AI must be forthright about its own capabilities and should not feign expertise or knowledge it doesn't possess.", 20 | "weight": 5 21 | }, 22 | { 23 | "dimension": "Introspective Honesty", 24 | "definition": "If asked about its workings or limitations, the AI should provide clear and honest descriptions about its design and operational boundaries.", 25 | "weight": 4 26 | }, 27 | { 28 | "dimension": "Bias Recognition", 29 | "definition": "The AI should acknowledge potential biases inherent in its data or model and inform users when a topic or response might be influenced by such biases.", 30 | "weight": 4 31 | }, 32 | { 33 | "dimension": "Correction Commitment", 34 | "definition": "When confronted with an error in its response, the AI should admit the mistake and strive to provide the correct information.", 35 | "weight": 1 36 | }, 37 | { 38 | "dimension": "Avoidance of Pretense", 39 | "definition": "The AI should avoid mimicking the behavior or response patterns of experts unless it's providing information based on verified data or recognized knowledge.", 40 | "weight": 1 41 | } 42 | ] -------------------------------------------------------------------------------- /prompts/principles/principle_collection_non_evasive.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "dimension": "Educational and Engaging", 4 | "definition": "The AI's responses should be enriched with accurate, relevant, and current information, serving to educate while keeping the user engaged.", 5 | "weight": 1 6 | }, 7 | { 8 | "dimension": "Creative", 9 | "definition": "The AI should be adept at generating original content, such as poems, stories, code, essays, songs, parodies, summaries, translations, and more.", 10 | "weight": 1 11 | }, 12 | { 13 | "dimension": "Multilingual", 14 | "definition": "The AI should be capable of conversing in the language used by the user, for instance, replying in 中文 if the query is in 中文.", 15 | "weight": 1 16 | }, 17 | { 18 | "dimension": "Comprehensive", 19 | "definition": "For information-seeking tasks, the AI should offer extensive and relevant details to ensure a thorough and in-depth response. It should impartially and extensively present arguments from diverse perspectives when dealing with contentious topics.", 20 | "weight": 3 21 | }, 22 | { 23 | "dimension": "Natural Language", 24 | "definition": "The AI should respond with diverse and natural language, avoiding repetition and awkward phrasing.", 25 | "weight": 3 26 | }, 27 | { 28 | "dimension": "Consistent Reasoning", 29 | "definition": "The AI should deliver responses that are clear and logically sound, ensuring they do not contain self-contradictions.", 30 | "weight": 1 31 | }, 32 | { 33 | "dimension": "Numerical Sensitive", 34 | "definition": "The AI should ensure that any numerical specifications given in the instruction are carefully adhered to, avoiding any errors in numerical computations.", 35 | "weight": 1 36 | }, 37 | { 38 | "dimension": "Analytical Structure", 39 | "definition": "For information analysis tasks, the AI should articulate its response in a manner that begins with a summary, followed by numerous key points, each underscored by a thorough analysis.", 40 | "weight": 3 41 | }, 42 | { 43 | "dimension": "Vivid", 44 | "definition": "The AI should employ vibrant, energetic language, enhancing user engagement by making all interactions lively and dynamic.", 45 | "weight": 1 46 | } 47 | ] -------------------------------------------------------------------------------- /prompts/principles/principle_collection_ppo.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "dimension": "Honest and Accurate", 4 | "definition": "The AI must furnish reliable and factual information, and candidly disclose its limitations and the extent of its knowledge.", 5 | "weight": 3 6 | }, 7 | { 8 | "dimension": "Ethical", 9 | "definition": "The AI should produce content that is free from offensive, discriminatory, or harmful material, and should not participate in or endorse risky activities.", 10 | "weight": 2 11 | }, 12 | { 13 | "dimension": "Educational and Engaging", 14 | "definition": "The AI's responses should be enriched with accurate, relevant, and current information, serving to educate while keeping the user engaged.", 15 | "weight": 1 16 | }, 17 | { 18 | "dimension": "Creative", 19 | "definition": "The AI should be adept at generating original content, such as poems, stories, code, essays, songs, parodies, summaries, translations, and more.", 20 | "weight": 1 21 | }, 22 | { 23 | "dimension": "Multilingual", 24 | "definition": "The AI should be capable of conversing in the language used by the user, for instance, replying in 中文 if the query is in 中文.", 25 | "weight": 1 26 | }, 27 | { 28 | "dimension": "Comprehensive", 29 | "definition": "For information-seeking tasks, the AI should offer extensive and relevant details to ensure a thorough and in-depth response. It should impartially and extensively present arguments from diverse perspectives when dealing with contentious topics.", 30 | "weight": 3 31 | }, 32 | { 33 | "dimension": "Natural Language", 34 | "definition": "The AI should respond with diverse and natural language, avoiding repetition and awkward phrasing.", 35 | "weight": 3 36 | }, 37 | { 38 | "dimension": "Consistent Reasoning", 39 | "definition": "The AI should deliver responses that are clear and logically sound, ensuring they do not contain self-contradictions.", 40 | "weight": 1 41 | }, 42 | { 43 | "dimension": "Numerical Sensitive", 44 | "definition": "The AI should ensure that any numerical specifications given in the instruction are carefully adhered to, avoiding any errors in numerical computations.", 45 | "weight": 1 46 | }, 47 | { 48 | "dimension": "Analytical Structure", 49 | "definition": "For information analysis tasks, the AI should articulate its response in a manner that begins with a summary, followed by numerous key points, each underscored by a thorough analysis.", 50 | "weight": 3 51 | }, 52 | { 53 | "dimension": "Vivid", 54 | "definition": "The AI should employ vibrant, energetic language, enhancing user engagement by making all interactions lively and dynamic.", 55 | "weight": 1 56 | }, 57 | { 58 | "dimension": "Privacy Protection", 59 | "definition": "The AI should avoid generating any personal identifiable information (PII) or external URLs in its responses.", 60 | "weight": 1 61 | }, 62 | { 63 | "dimension": "Candor", 64 | "definition": "The AI should avoid sharing false information. If a question does not make any sense, or is not factually coherent, the AI should explain why instead of answering something not correct.", 65 | "weight": 4 66 | }, 67 | { 68 | "dimension": "Stand-Alone", 69 | "definition": "The AI must function as a stand-alone text-based system, avoiding interactions with any external sources, including URLs, images, or videos.", 70 | "weight": 2 71 | } 72 | ] -------------------------------------------------------------------------------- /prompts/principles/principle_collection_rm.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "dimension": "Overall Helpful", 4 | "definition": "The AI’s response should precisely serve the user's needs and interests by being well-rounded, combining conciseness, accuracy, informativeness, and other critical qualities.", 5 | "negative_definition": "The AI should avoid generating a helpful response." 6 | }, 7 | { 8 | "dimension": "Concise", 9 | "definition": "The response should efficiently address the task or answer the question, communicating the necessary information with brevity and clarity.", 10 | "negative_definition": "The response should avoid efficiently addressing the task or answering the question." 11 | }, 12 | { 13 | "dimension": "Honest and Accurate", 14 | "definition": "The AI must furnish reliable and factual information, and candidly disclose its limitations and the extent of its knowledge.", 15 | "negative_definition": "The AI should avoid furnishing reliable and factual information." 16 | }, 17 | { 18 | "dimension": "Ethical", 19 | "definition": "The AI should produce content that is free from offensive, discriminatory, or harmful material, and should not participate in or endorse risky activities.", 20 | "negative_definition": "The AI should avoid producing content that is free from offensive, discriminatory, or harmful material." 21 | }, 22 | { 23 | "dimension": "Natural and Fluent", 24 | "definition": "The AI should employ language that flows smoothly and is free from repetitive or awkward constructs.", 25 | "negative_definition": "The AI should avoid employing language that flows smoothly." 26 | }, 27 | { 28 | "dimension": "Specific", 29 | "definition": "The AI’s response should be directly pertinent to the query, addressing the particular subject in the instruction explicitly.", 30 | "negative_definition": "The AI’s response should avoid being directly pertinent to the query." 31 | }, 32 | { 33 | "dimension": "Educational and Engaging", 34 | "definition": "The AI's responses should be enriched with accurate, relevant, and current information, serving to educate while keeping the user engaged.", 35 | "negative_definition": "The AI's responses should avoid being enriched with accurate, relevant, and current information." 36 | }, 37 | { 38 | "dimension": "Methodical", 39 | "definition": "The AI should employ a structured approach when providing solutions, presenting logical and step-by-step explanation before arriving at a conclusion.", 40 | "negative_definition": "The AI should avoid employing a structured approach when providing solutions." 41 | }, 42 | { 43 | "dimension": "Multilingual", 44 | "definition": "The AI should be capable of conversing in the language used by the user, for instance, replying in 中文 if the query is in 中文.", 45 | "negative_definition": "The AI should avoid being capable of conversing in the language used by the user." 46 | }, 47 | { 48 | "dimension": "Creative", 49 | "definition": "The AI should be adept at generating original content, such as poems, stories, code, essays, songs, parodies, summaries, translations, and more.", 50 | "negative_definition": "The AI should avoid being adept at generating original content." 51 | }, 52 | { 53 | "dimension": "Comprehensive", 54 | "definition": "The AI should offer extensive and relevant details to ensure a thorough and in-depth response. It should impartially and extensively present arguments from diverse perspectives when dealing with contentious topics.", 55 | "negative_definition": "The AI should avoid offering extensive and relevant details." 56 | } 57 | ] -------------------------------------------------------------------------------- /prompts/salmon_reward_model_prompt_v0.txt: -------------------------------------------------------------------------------- 1 | # System 2 | 3 | You are a reviewer whose goal is to judge the quality of the AI system's responses to instructions. 4 | 5 | ## Annotation Example 6 | 7 | ### AI system's Response 8 | {Output} 9 | 10 | ### Instruction to the AI system 11 | {Input} 12 | 13 | ## Annotation Guideline 14 | 15 | Your task is to evaluate the quality of the response. There are several dimensions you should consider in your evaluation: 16 | 17 | {Dimensions} 18 | 19 | A good response should meet all of the above criteria. 20 | 21 | ## Reviewer 22 | The quality of the output is -------------------------------------------------------------------------------- /prompts/salmon_reward_model_prompt_v1.txt: -------------------------------------------------------------------------------- 1 | # System 2 | 3 | You are a reviewer whose goal is to judge the quality of the AI system's responses to instructions. 4 | 5 | ## Annotation Example 6 | 7 | ### AI system's Response 8 | {Output} 9 | 10 | ### Instruction to the AI system 11 | {Input} 12 | 13 | ## Annotation Guideline 14 | 15 | Your task is to evaluate the quality of the response. There are several dimensions you should consider in your evaluation: 16 | 17 | - The AI must avoid generating commentary about its own responses or engaging in self-praise. It needs to stay humble. 18 | - The AI must prioritize fulfilling the instruction, avoiding high-level pros-and-cons analysis, step-by-step instructions, or executable code. 19 | - The AI should keep the response straightforward and on-point, answering the question or completing the task without unnecessary examples. 20 | {Dimensions} 21 | 22 | A good response should meet all of the above criteria. 23 | 24 | ## Reviewer 25 | The quality of the output is -------------------------------------------------------------------------------- /prompts/synthetic_preference_prompt.txt: -------------------------------------------------------------------------------- 1 | # Dromedary 2 | 3 | ## System Overview 4 | 5 | Hi, Dromedary! You are a helpful assistant whose goal is to select the preferred (least wrong) AI model's output for a given instruction. 6 | 7 | You will read a batch of examples, which are composed of the following: 8 | 9 | 1. an Instruction we give to the AI system 10 | 2. Output (a), the first output from the AI system 11 | 3. Output (b), the second output from the AI system 12 | 13 | ## User Conversation 14 | 15 | ### User 16 | Please select the preferred (least wrong) output for a given instruction. 17 | 18 | #### Instruction 19 | {UserInstruction} 20 | 21 | #### Output (a) 22 | {OutputA} 23 | 24 | #### Output (b) 25 | {OutputB} 26 | 27 | #### Annotation Guide 28 | 29 | To simplify the evaluation process, one aspect to consider this time is as follows: 30 | 31 | {Dimension}: {Definition} 32 | 33 | Based on the provided definition, please select the preferred output for the given instruction. 34 | 35 | ### Dromedary 36 | Sure! After carefully reading the Instruction, Output (a), Output (b), and the definition of {Dimension}, I think the more {Dimension} output is Output ( -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 2.0.0 2 | accelerate==0.21.0 3 | bitsandbytes==0.41.0 4 | deepspeed==0.9.3 5 | transformers==4.31.0 6 | wandb==0.15.3 7 | tiktoken==0.4.0 8 | tokenizers>=0.12.1 9 | peft==0.4.0 10 | gradio==3.35.2 11 | gradio_client==0.2.9 12 | sentencepiece==0.1.99 13 | tensorboard >= 2.12.0 14 | fairscale 15 | fire 16 | einops 17 | tqdm 18 | fastapi 19 | requests 20 | numpy 21 | uvicorn 22 | markdown2[all] 23 | shortuuid 24 | scipy -------------------------------------------------------------------------------- /training/README.md: -------------------------------------------------------------------------------- 1 | # Training Experiences 2 | 3 | The whole **SALMON** process involves three stages: Synthetic Preference Collection, Training the Principle-following Reward Model, and RL Training with the Principle-following Reward Model. In our [paper](https://arxiv.org/abs/2310.05910), we provide a detailed description of each of these stages. 4 | 5 | ## Prerequisites 6 | 7 | For efficiency concerns, we utilize the [model parallel](https://github.com/facebookresearch/fairscale/tree/main/fairscale/nn/model_parallel) scheme from [llama](https://github.com/facebookresearch/llama) when sampling responses from the inital policy model. To prepare the sharded model checkpoints of LLaMA or Dromedary on your own machine/cluster, please refer to the [inference guide in Dromedary](https://github.com/IBM/Dromedary/tree/main/inference). 8 | 9 | ## Step 1: Collecting Principle-Driven Synthetic Preference 10 | 11 | We sample two responses from the initial policy model, and use the policy model itself to select the preferred response based on a certain human-written principle. 12 | 13 | Before diving into the experiments, please install the [`llama_dromedary` pacakge in Dromedary](https://github.com/IBM/Dromedary/tree/main/llama_dromedary) to enable model parallelism. 14 | 15 | ### Step 1.1: Preparing OASST1 Prompt Dataset 16 | 17 |
18 | Running the code 19 | 20 | ```bash 21 | cd step1_synthetic_preference_collection 22 | 23 | python -u clean_oasst1_prompts.py \ 24 | --output_file "/path/to/your/oasst1_prompts.json" 25 | ``` 26 | 27 |
28 | 29 | ### Step 1.2: Sampling Responses from the Policy Model 30 | 31 |
32 | Running the code 33 | 34 | ```bash 35 | salloc --nodes 8 --time 6:00:00 --gres=gpu:32g:6 srun bash scripts/generate_oasst1_response0.sh 36 | salloc --nodes 8 --time 6:00:00 --gres=gpu:32g:6 srun bash scripts/generate_oasst1_response1.sh 37 | ``` 38 | 39 |
40 | 41 | ### Step 1.3: Collecting Synthetic Preferences 42 | 43 |
44 | Running the code 45 | 46 | ```bash 47 | salloc --nodes 1 --time 24:00:00 --gres=gpu:80g:8 srun bash scripts/generate_synthetic_preference.sh 48 | ``` 49 | 50 |
51 | 52 | ## Step 2: Training the Principle-following Reward Model 53 | 54 | Next, for each user prompt, a subset of principles is randomly sampled from the established principle list with certain principles being randomly negated. The user prompt, model responses, and the sub-sampled principles are aggregated as a single training instance for the reward model. 55 | 56 | ### Step 2.1: Aggregating the Collected Preferences 57 | 58 |
59 | Running the code 60 | 61 | ```bash 62 | cd step2_rm_training 63 | 64 | python -u aggregate_synthetic_preference.py \ 65 | --response_pattern "/path/to/your/oasst1_dromedary2_sft_response*.json" \ 66 | --preference_file "/path/to/your/oasst1_dromedary2_sft_preference.json" \ 67 | --output_file "/path/to/your/oasst1_dromedary2_sft_aggregated_preference.json" 68 | ``` 69 | 70 |
71 | 72 | ### Step 2.2: Preference Modeling Pre-training (PMP) of the Reward Model 73 | 74 |
75 | Running the code 76 | 77 | ```bash 78 | python -u clean_pmp_data.py \ 79 | --output_file "/path/to/your/pmp_data.json" 80 | 81 | salloc --nodes 1 --time 24:00:00 --gres=gpu:80g:8 srun bash scripts/train_reward_model_70b_qlora_pmp.sh 82 | ``` 83 | 84 |
85 | 86 | ### Step 2.3: Fine-tune the Reward Model with Principle-driven Preferences 87 | 88 |
89 | Running the code 90 | 91 | ```bash 92 | salloc --nodes 1 --time 24:00:00 --gres=gpu:80g:8 srun bash scripts/train_reward_model_70b_qlora_ft.sh 93 | ``` 94 | 95 |
96 | 97 | ## Step 3: RL Training with the Principle-following Reward Model 98 | 99 | Finally, we train the policy model with the principle-following reward model. We use the diverse user prompts from [ShareGPT](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered), [Dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k), [OpenAssistant](https://huggingface.co/datasets/OpenAssistant/oasst1), [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca), and [MATH](https://huggingface.co/datasets/competition_math). 100 | 101 | ### Step 3.1: Preparing the Prompt Dataset for RL Training 102 | 103 |
104 | Running the code 105 | 106 | ```bash 107 | cd step3_ppo_training 108 | 109 | python subsample_openorca_prompts.py \ 110 | --train_data_path "/path/to/your/l1M-GPT4-Augmented.parquet (obtained from OpenOrca)" \ 111 | --output_path "/path/to/your/openorca_prompts.json" 112 | 113 | python aggregate_sharegpt_prompts.py \ 114 | --data_files=zetavg/ShareGPT-Processed,path/to/sg_90k_part1.json.json,path/to/sg_90k_part1.json (obtained from ShareGPT_Vicuna_unfiltered) \ 115 | --output_path "/path/to/sharegpt_prompts.json" 116 | 117 | python clean_and_merge_prompts.py \ 118 | --sharegpt_prompt_path "/path/to/sharegpt_prompts.json" \ 119 | --openorca_prompt_path "/path/to/openorca_prompts.json" \ 120 | --output_file "/path/to/your/salmon_merged_prompts.json" 121 | ``` 122 | 123 |
124 | 125 | ### Step 3.2: RL Training 126 | 127 |
128 | Running the code 129 | 130 | ```bash 131 | salloc --nodes 6 --time 24:00:00 --gres=gpu:80g:8 srun bash scripts/train_ppo_model_70b_qlora_salmon.sh 132 | ``` 133 | -------------------------------------------------------------------------------- /training/data_utils/common_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Alpaca Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import glob 17 | import os 18 | import random 19 | from typing import ( 20 | Callable, 21 | Dict, 22 | Optional, 23 | Sequence, 24 | Union, 25 | Mapping, 26 | Any, 27 | ) 28 | 29 | import numpy as np 30 | import torch 31 | import torch.nn.functional as F 32 | from torch.utils.data import DataLoader 33 | 34 | Numeric = Union[int, float] 35 | 36 | 37 | def zip_(*args: Sequence): 38 | """Assert sequences of same length before zipping.""" 39 | if len(args) == 0: 40 | return [] 41 | assert alleq(args, lambda x, y: len(x) == len(y)) 42 | return zip(*args) 43 | 44 | 45 | def mean(*seqs: Sequence[Numeric]) -> Union[Numeric, Sequence[Numeric]]: 46 | singleton = len(seqs) == 1 47 | means = [float(np.mean(seq)) for seq in seqs] 48 | return means[0] if singleton else means 49 | 50 | 51 | def alleq(l: Sequence, f: Optional[Callable] = lambda x, y: x == y): 52 | """Check all arguments in a sequence are equal according to a given criterion. 53 | 54 | Args: 55 | f: A bi-variate boolean function. 56 | l: A list/tuple. 57 | 58 | Returns: 59 | True if everything is equal; otherwise False. 60 | """ 61 | return all(f(l[0], li) for li in l[1:]) 62 | 63 | 64 | def flatten_dict(nested, sep=".", postprocess_fn=lambda *args: args): 65 | def rec(nest, prefix, into): 66 | for k, v in nest.items(): 67 | if sep in k: 68 | raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") 69 | if isinstance(v, dict): # collections.Mapping fails in py3.10. 70 | rec(v, prefix + k + sep, into) 71 | else: 72 | v = postprocess_fn(v) 73 | into[prefix + k] = v 74 | 75 | flat = {} 76 | rec(nested, "", flat) 77 | return flat 78 | 79 | 80 | def unpack_dict( 81 | d: Dict, keys: Sequence[str], return_type: type = tuple 82 | ) -> Union[Sequence, Dict]: 83 | if return_type in (tuple, list): 84 | return return_type(d[key] for key in keys) 85 | elif return_type == dict: 86 | return {key: d[key] for key in keys} 87 | else: 88 | raise ValueError(f"Unknown return_type: {return_type}") 89 | 90 | 91 | def merge_dict(dicts: Sequence[dict], merge_fn: Callable = lambda *args: args) -> dict: 92 | """Merge a sequence of dicts (with the same set of keys) into a single dict.""" 93 | if len(dicts) == 0: 94 | return dict() 95 | return {key: merge_fn([dict_[key] for dict_ in dicts]) for key in dicts[0].keys()} 96 | 97 | 98 | def prepare_inputs( 99 | data: Union[torch.Tensor, Any], device: Union[str, int, torch.device] 100 | ) -> Union[torch.Tensor, Any]: 101 | if isinstance(data, Mapping): 102 | return type(data)( 103 | {k: prepare_inputs(v, device) for k, v in data.items()} 104 | ) # noqa 105 | elif isinstance(data, (tuple, list)): 106 | return type(data)(prepare_inputs(v, device) for v in data) 107 | elif isinstance(data, torch.Tensor): 108 | return data.to(device) # This can break with deepspeed. 109 | return data 110 | 111 | 112 | def compute_logprobs( 113 | logits: torch.Tensor, labels: torch.Tensor, ignore_index: int 114 | ) -> torch.Tensor: 115 | """Compute per-token logprobs, zeroing out places with ignore_index (padding).""" 116 | return -F.cross_entropy( 117 | logits.permute(0, 2, 1), labels, reduction="none", ignore_index=ignore_index 118 | ) 119 | 120 | 121 | def pad( 122 | inputs: torch.Tensor, 123 | target_size: Union[torch.Size, Sequence[int]], 124 | value=0.0, 125 | left=True, 126 | ): 127 | current_size = inputs.size() 128 | diffs = tuple(ti - ci for ti, ci in zip_(target_size, current_size)) 129 | pad_params = [] 130 | for diff in diffs: 131 | pad_params = ([diff, 0] if left else [0, diff]) + pad_params 132 | res = F.pad(inputs, pad=pad_params, value=value) 133 | return res 134 | 135 | 136 | def left_pad( 137 | inputs: torch.Tensor, target_size: Union[torch.Size, Sequence[int]], value=0.0 138 | ): 139 | return pad(inputs=inputs, target_size=target_size, value=value, left=True) 140 | 141 | 142 | def right_pad( 143 | inputs: torch.Tensor, target_size: Union[torch.Size, Sequence[int]], value=0.0 144 | ): 145 | return pad(inputs=inputs, target_size=target_size, value=value, left=False) 146 | 147 | 148 | def manual_seed(args_or_seed: Union[int, argparse.Namespace], fix_cudnn=False): 149 | if hasattr(args_or_seed, "seed"): 150 | args_or_seed = args_or_seed.seed 151 | random.seed(args_or_seed) 152 | np.random.seed(args_or_seed) 153 | torch.manual_seed(args_or_seed) 154 | torch.cuda.manual_seed_all(args_or_seed) 155 | os.environ["PYTHONHASHSEED"] = str(args_or_seed) 156 | if fix_cudnn: 157 | torch.backends.cudnn.deterministic = True # noqa 158 | torch.backends.cudnn.benchmark = False # noqa 159 | 160 | 161 | def make_meta_prompts(meta_prompt_pattern: str): 162 | meta_prompt_files = glob.glob(meta_prompt_pattern) 163 | print(f"Found {len(meta_prompt_files)} meta prompts: {meta_prompt_files}") 164 | 165 | meta_prompts = [] 166 | for meta_prompt_file in meta_prompt_files: 167 | with open(meta_prompt_file, "r", encoding="utf-8") as f: 168 | meta_prompt = f.readlines() 169 | meta_prompt = "".join(meta_prompt).strip() 170 | meta_prompts.append(meta_prompt) 171 | return meta_prompts 172 | 173 | 174 | class InfiniteLoader(object): 175 | """Wraps an existing loader so that it outputs stuff indefinitely; useful for semi-supervised learning.""" 176 | 177 | def __init__(self, loader: DataLoader): 178 | super(InfiniteLoader, self).__init__() 179 | self.loader = loader 180 | self.iterator = iter(loader) 181 | 182 | def __next__(self): 183 | try: 184 | return next(self.iterator) 185 | except StopIteration: 186 | self.iterator = iter(self.loader) 187 | return next(self.iterator) 188 | -------------------------------------------------------------------------------- /training/data_utils/data_utils_ppo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Self-Align Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import dataclasses 16 | from typing import Callable, Dict, Optional, List, Sequence 17 | 18 | import logging 19 | import pandas as pd 20 | 21 | import torch 22 | from torch.utils.data import Dataset 23 | 24 | import transformers 25 | import datasets 26 | 27 | import data_utils.common_utils as utils 28 | 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | BASE_PROMPT_DICT = { 34 | "prompt_input": "{instruction}\n\n{input}", 35 | "prompt_no_input": "{instruction}", 36 | } 37 | 38 | 39 | def format_input_and_prompt( 40 | example: Dict[str, str], 41 | meta_prompts: List[str], 42 | ) -> str: 43 | assert ( 44 | "instruction" in example and "input" in example 45 | ), "Internal error: example missing required keys." 46 | 47 | if "example_id" in example: 48 | total_meta_prompt = len(meta_prompts) 49 | meta_prompt = meta_prompts[int(example["example_id"]) % total_meta_prompt] 50 | else: 51 | meta_prompt = meta_prompts[0] 52 | 53 | if example.get("input", "") != "": 54 | prompt_format = BASE_PROMPT_DICT["prompt_input"] 55 | else: 56 | prompt_format = BASE_PROMPT_DICT["prompt_no_input"] 57 | 58 | formatted_input = prompt_format.format(**example) 59 | meta_prompt = meta_prompt.split("{Output}")[0] 60 | 61 | formatted_prompt = meta_prompt.format(Input=formatted_input) 62 | return formatted_input, formatted_prompt 63 | 64 | 65 | class QueryResponseDataset(Dataset): 66 | """Dataset that emits tokenized left-padded queries.""" 67 | 68 | def __init__( 69 | self, 70 | df: pd.DataFrame, 71 | meta_prompts: List[str], 72 | tokenizer: transformers.PreTrainedTokenizer, 73 | query_len: int, 74 | df_postprocessor: Optional[Callable] = None, 75 | ): 76 | super(QueryResponseDataset, self).__init__() 77 | 78 | if df_postprocessor is not None: 79 | df = df_postprocessor(df) 80 | list_dict_data = df.to_dict(orient="records") 81 | 82 | # prompts are strings; queries are tensors. 83 | inputs_and_prompts = [ 84 | format_input_and_prompt(example=dict_data, meta_prompts=meta_prompts) 85 | for dict_data in list_dict_data 86 | ] 87 | formated_inputs, prompts = zip(*inputs_and_prompts) 88 | 89 | length_bonus = [ 90 | dict_data.get("length_bonus", 1.0) for dict_data in list_dict_data 91 | ] 92 | 93 | # For question_type: 94 | # 0: general 95 | # 1: reasoning 96 | # 2: red-teaming 97 | question_types = [ 98 | dict_data.get("question_type", 0) for dict_data in list_dict_data 99 | ] 100 | 101 | pure_queries = [ 102 | tokenizer( 103 | formated_input, return_tensors="pt", truncation=False 104 | ).input_ids.squeeze(dim=0) 105 | for formated_input in formated_inputs 106 | ] 107 | 108 | queries = [ 109 | tokenizer(prompt, return_tensors="pt", truncation=False).input_ids.squeeze( 110 | dim=0 111 | ) 112 | for prompt in prompts 113 | ] 114 | 115 | filtered_inputs = [] 116 | filtered_queries = [] 117 | filtered_length_bonus = [] 118 | filtered_question_types = [] 119 | 120 | for pure_query, query, ex_length_bonus, ex_question_type in zip( 121 | pure_queries, queries, length_bonus, question_types 122 | ): 123 | if len(query) <= query_len: 124 | filtered_inputs.append(pure_query) 125 | filtered_queries.append(query) 126 | filtered_length_bonus.append(ex_length_bonus) 127 | filtered_question_types.append(ex_question_type) 128 | 129 | logger.warning( 130 | f"Filtered out {len(queries) - len(filtered_queries)} instances out of {len(queries)} that " 131 | f"exceed length limit. These examples are not used for training, but will still be used in evaluation. " 132 | ) 133 | 134 | pure_queries = torch.stack( 135 | [ 136 | utils.left_pad( 137 | pure_query, target_size=(query_len,), value=tokenizer.pad_token_id 138 | ) 139 | for pure_query in filtered_inputs 140 | ] 141 | ) 142 | 143 | queries = torch.stack( 144 | [ 145 | utils.left_pad( 146 | query, target_size=(query_len,), value=tokenizer.pad_token_id 147 | ) 148 | for query in filtered_queries 149 | ] 150 | ) 151 | 152 | self.length_bonus = torch.tensor( 153 | filtered_length_bonus, dtype=torch.float32 154 | ).reshape(-1, 1) 155 | self.question_types = torch.tensor( 156 | filtered_question_types, dtype=torch.long 157 | ).reshape(-1, 1) 158 | self.pure_queries = pure_queries 159 | self.queries = queries 160 | self.query_attn_masks = queries.ne(tokenizer.pad_token_id).long() 161 | 162 | assert self.pure_queries.shape[0] == self.queries.shape[0] 163 | assert self.pure_queries.shape[0] == self.query_attn_masks.shape[0] 164 | assert self.pure_queries.shape[0] == self.length_bonus.shape[0] 165 | assert self.pure_queries.shape[0] == self.question_types.shape[0] 166 | 167 | # Auxiliary data. 168 | self.prompts = prompts 169 | self.list_dict_data = list_dict_data 170 | 171 | def __getitem__(self, i): 172 | return_dict = dict( 173 | pure_queries=self.pure_queries[i], 174 | queries=self.queries[i], 175 | query_attn_masks=self.query_attn_masks[i], 176 | length_bonus=self.length_bonus[i], 177 | question_types=self.question_types[i], 178 | ) 179 | return return_dict 180 | 181 | def __len__(self): 182 | return len(self.queries) 183 | 184 | 185 | @dataclasses.dataclass 186 | class DataCollatorForQueryResponseDataset(object): 187 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 188 | return { 189 | key: torch.stack([instance[key] for instance in instances]) 190 | for key in instances[0].keys() 191 | } 192 | 193 | 194 | def make_rl_data_module( 195 | tokenizer: transformers.PreTrainedTokenizer, 196 | data_args, 197 | training_args, 198 | ): 199 | # prompt_dict = utils.jload(data_args.prompt_dict_path) 200 | 201 | policy_meta_prompts = utils.make_meta_prompts(data_args.policy_meta_prompt_pattern) 202 | 203 | if data_args.dataset_path.endswith("json"): 204 | train_instructions = datasets.load_dataset( 205 | "json", data_files=data_args.dataset_path 206 | ) 207 | else: 208 | train_instructions = datasets.load_dataset( 209 | data_args.dataset_path, data_args.dataset_name 210 | ) 211 | train_df = pd.concat( 212 | [pd.DataFrame(train_instructions[split]) for split in data_args.train_splits] 213 | ) 214 | 215 | eval_instructions = datasets.load_dataset( 216 | data_args.eval_dataset_path, data_args.eval_dataset_name 217 | ) 218 | eval_df = pd.concat( 219 | [pd.DataFrame(eval_instructions[split]) for split in data_args.eval_splits] 220 | ) 221 | 222 | train_dataset = QueryResponseDataset( 223 | df=train_df, 224 | meta_prompts=policy_meta_prompts, 225 | tokenizer=tokenizer, 226 | query_len=training_args.query_len, 227 | ) 228 | eval_dataset = QueryResponseDataset( 229 | df=eval_df, 230 | meta_prompts=policy_meta_prompts, 231 | tokenizer=tokenizer, 232 | query_len=training_args.query_len, 233 | ) 234 | return dict( 235 | train_dataset=train_dataset, 236 | eval_dataset=eval_dataset, 237 | data_collator=DataCollatorForQueryResponseDataset(), 238 | ) 239 | -------------------------------------------------------------------------------- /training/data_utils/data_utils_sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Self-Align Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import copy 16 | 17 | import os 18 | from dataclasses import dataclass, field 19 | from typing import Optional, Dict, Sequence, List 20 | import pandas as pd 21 | 22 | import torch 23 | import transformers 24 | from torch.nn.utils.rnn import pad_sequence 25 | from datasets import load_dataset, Dataset 26 | 27 | import data_utils.common_utils as utils 28 | 29 | 30 | IGNORE_INDEX = -100 31 | 32 | 33 | @dataclass 34 | class DataCollatorForCausalLM(object): 35 | left_truncated_tokenizer: transformers.PreTrainedTokenizer 36 | tokenizer: transformers.PreTrainedTokenizer 37 | source_max_len: int 38 | target_max_len: int 39 | train_on_source: bool 40 | predict_with_generate: bool 41 | add_eos_to_target: bool 42 | 43 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 44 | # Extract elements 45 | sources = [example["input"] for example in instances] 46 | if self.add_eos_to_target: 47 | targets = [ 48 | f"\n{example['output']}{self.tokenizer.eos_token}" 49 | for example in instances 50 | ] 51 | else: 52 | targets = [f"\n{example['output']}" for example in instances] 53 | 54 | begin_padding_len = self.tokenizer( 55 | ["\n"], return_tensors="pt", add_special_tokens=False 56 | ).input_ids.shape[1] 57 | 58 | # Tokenize 59 | tokenized_sources_with_prompt = self.left_truncated_tokenizer( 60 | sources, 61 | max_length=self.source_max_len, 62 | truncation=True, 63 | # add_special_tokens=False, 64 | ) 65 | tokenized_targets = self.tokenizer( 66 | targets, 67 | max_length=self.target_max_len + begin_padding_len, 68 | truncation=True, 69 | add_special_tokens=False, 70 | ) 71 | # Build the input and labels for causal LM 72 | input_ids = [] 73 | labels = [] 74 | for tokenized_source, tokenized_target in zip( 75 | tokenized_sources_with_prompt["input_ids"], tokenized_targets["input_ids"] 76 | ): 77 | tokenized_target = tokenized_target[begin_padding_len:] 78 | if not self.predict_with_generate: 79 | input_ids.append(torch.tensor(tokenized_source + tokenized_target)) 80 | if not self.train_on_source: 81 | labels.append( 82 | torch.tensor( 83 | [IGNORE_INDEX for _ in range(len(tokenized_source))] 84 | + copy.deepcopy(tokenized_target) 85 | ) 86 | ) 87 | else: 88 | labels.append( 89 | torch.tensor(copy.deepcopy(tokenized_source + tokenized_target)) 90 | ) 91 | else: 92 | input_ids.append(torch.tensor(tokenized_source)) 93 | # Apply padding 94 | input_ids = pad_sequence( 95 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 96 | ) 97 | labels = ( 98 | pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 99 | if not self.predict_with_generate 100 | else None 101 | ) 102 | data_dict = { 103 | "input_ids": input_ids, 104 | "attention_mask": input_ids.ne(self.tokenizer.pad_token_id), 105 | } 106 | if labels is not None: 107 | data_dict["labels"] = labels 108 | return data_dict 109 | 110 | 111 | def extract_unnatural_instructions_data(examples, extract_reformulations=False): 112 | out = { 113 | "input": [], 114 | "output": [], 115 | } 116 | for example_instances in examples["instances"]: 117 | for instance in example_instances: 118 | out["input"].append(instance["instruction_with_input"]) 119 | out["output"].append(instance["output"]) 120 | if extract_reformulations: 121 | for example_reformulations in examples["reformulations"]: 122 | if example_reformulations is not None: 123 | for instance in example_reformulations: 124 | out["input"].append(instance["instruction_with_input"]) 125 | out["output"].append(instance["output"]) 126 | return out 127 | 128 | 129 | DROMEDARY_PROMPT_DICT = { 130 | "prompt_input": ( 131 | "{meta_prompt}\n" "{instruction}\n\n" "{input}\n\n" "### Dromedary" 132 | ), 133 | "prompt_no_input": ("{meta_prompt}\n" "{instruction}\n\n" "### Dromedary"), 134 | } 135 | 136 | 137 | def extract_dromedary_dataset(example, meta_prompts): 138 | assert "example_id" in example 139 | total_meta_prompt = len(meta_prompts) 140 | meta_prompt = meta_prompts[int(example["example_id"]) % total_meta_prompt] 141 | 142 | if example.get("input", "") != "": 143 | prompt_format = DROMEDARY_PROMPT_DICT["prompt_input"] 144 | else: 145 | prompt_format = DROMEDARY_PROMPT_DICT["prompt_no_input"] 146 | 147 | return { 148 | "input": prompt_format.format(meta_prompt=meta_prompt, **example), 149 | "output": "\n" + example["output"], 150 | } 151 | 152 | 153 | def local_dataset(dataset_name): 154 | if dataset_name.endswith(".json"): 155 | full_dataset = load_dataset("json", data_files=dataset_name) 156 | else: 157 | raise ValueError(f"Unsupported dataset format: {dataset_name}") 158 | 159 | return full_dataset 160 | 161 | 162 | def make_sft_data_module( 163 | left_truncated_tokenizer: transformers.PreTrainedTokenizer, 164 | tokenizer: transformers.PreTrainedTokenizer, 165 | args, 166 | ) -> Dict: 167 | """ 168 | Make dataset and collator for supervised fine-tuning. 169 | Datasets are expected to have the following columns: { `input`, `output` } 170 | """ 171 | 172 | def load_data(dataset_name): 173 | if dataset_name == "alpaca": 174 | return load_dataset("tatsu-lab/alpaca") 175 | elif dataset_name == "alpaca-clean": 176 | return load_dataset("yahma/alpaca-cleaned") 177 | elif dataset_name == "chip2": 178 | return load_dataset("laion/OIG", data_files="unified_chip2.jsonl") 179 | elif dataset_name == "self-instruct": 180 | return load_dataset("yizhongw/self_instruct", name="self_instruct") 181 | elif dataset_name == "hh-rlhf": 182 | return load_dataset("Anthropic/hh-rlhf") 183 | elif dataset_name == "longform": 184 | return load_dataset("akoksal/LongForm") 185 | elif dataset_name == "oasst1": 186 | return load_dataset("timdettmers/openassistant-guanaco") 187 | elif dataset_name == "vicuna": 188 | raise NotImplementedError("Vicuna data was not released.") 189 | else: 190 | if os.path.exists(dataset_name): 191 | try: 192 | args.dataset_format = ( 193 | args.dataset_format if args.dataset_format else "alpaca" 194 | ) 195 | full_dataset = local_dataset(dataset_name) 196 | return full_dataset 197 | except: 198 | raise ValueError(f"Error loading dataset from {dataset_name}") 199 | else: 200 | raise NotImplementedError( 201 | f"Dataset {dataset_name} not implemented yet." 202 | ) 203 | 204 | def format_dataset(dataset, dataset_format, meta_prompt_pattern): 205 | if dataset_format == "dromedary": 206 | assert meta_prompt_pattern is not None 207 | 208 | meta_prompts = utils.make_meta_prompts(meta_prompt_pattern) 209 | 210 | dataset = dataset.map( 211 | lambda ex: extract_dromedary_dataset(ex, meta_prompts=meta_prompts), 212 | remove_columns=["instruction", "example_id"], 213 | ) 214 | elif dataset_format == "chip2" or ( 215 | dataset_format is None and args.dataset == "chip2" 216 | ): 217 | dataset = dataset.map( 218 | lambda x: { 219 | "input": x["text"].split("\n: ")[0].replace(": ", ""), 220 | "output": x["text"].split("\n: ")[1], 221 | } 222 | ) 223 | elif dataset_format == "self-instruct" or ( 224 | dataset_format is None and args.dataset == "self-instruct" 225 | ): 226 | for old, new in [["prompt", "input"], ["completion", "output"]]: 227 | dataset = dataset.rename_column(old, new) 228 | elif dataset_format == "hh-rlhf" or ( 229 | dataset_format is None and args.dataset == "hh-rlhf" 230 | ): 231 | dataset = dataset.map(lambda x: {"input": "", "output": x["chosen"]}) 232 | elif dataset_format == "oasst1" or ( 233 | dataset_format is None and args.dataset == "oasst1" 234 | ): 235 | dataset = dataset.map( 236 | lambda x: { 237 | "input": "", 238 | "output": x["text"], 239 | } 240 | ) 241 | # Remove unused columns. 242 | dataset = dataset.remove_columns( 243 | [ 244 | col 245 | for col in dataset.column_names["train"] 246 | if col not in ["input", "output"] 247 | ] 248 | ) 249 | return dataset 250 | 251 | # Load dataset. 252 | dataset = load_data(args.dataset) 253 | dataset = format_dataset(dataset, args.dataset_format, args.meta_prompt_pattern) 254 | 255 | # Split train/eval, reduce size 256 | if args.do_eval or args.do_predict: 257 | if "eval" in dataset: 258 | eval_dataset = dataset["eval"] 259 | else: 260 | print( 261 | "Splitting train dataset in train and validation according to `eval_dataset_size`" 262 | ) 263 | dataset = dataset["train"].train_test_split( 264 | test_size=args.eval_dataset_size, shuffle=True, seed=42 265 | ) 266 | eval_dataset = dataset["test"] 267 | if ( 268 | args.max_eval_samples is not None 269 | and len(eval_dataset) > args.max_eval_samples 270 | ): 271 | eval_dataset = eval_dataset.select(range(args.max_eval_samples)) 272 | if args.group_by_length: 273 | eval_dataset = eval_dataset.map( 274 | lambda x: {"length": len(x["input"]) + len(x["output"])} 275 | ) 276 | if args.do_train: 277 | train_dataset = dataset["train"] 278 | if ( 279 | args.max_train_samples is not None 280 | and len(train_dataset) > args.max_train_samples 281 | ): 282 | train_dataset = train_dataset.select(range(args.max_train_samples)) 283 | if args.group_by_length: 284 | train_dataset = train_dataset.map( 285 | lambda x: {"length": len(x["input"]) + len(x["output"])} 286 | ) 287 | 288 | data_collator = DataCollatorForCausalLM( 289 | left_truncated_tokenizer=left_truncated_tokenizer, 290 | tokenizer=tokenizer, 291 | source_max_len=args.source_max_len, 292 | target_max_len=args.target_max_len, 293 | train_on_source=args.train_on_source, 294 | predict_with_generate=args.predict_with_generate, 295 | add_eos_to_target=args.add_eos_to_target, 296 | ) 297 | return dict( 298 | train_dataset=train_dataset if args.do_train else None, 299 | eval_dataset=eval_dataset if args.do_eval else None, 300 | predict_dataset=eval_dataset if args.do_predict else None, 301 | data_collator=data_collator, 302 | ) 303 | -------------------------------------------------------------------------------- /training/models/configuration_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Self-Align, EleutherAI, and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ LLaMA model configuration. """ 21 | 22 | from transformers.configuration_utils import PretrainedConfig 23 | from transformers.utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 29 | 30 | 31 | class LlamaConfig(PretrainedConfig): 32 | r""" 33 | This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA 34 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 35 | defaults will yield a similar configuration to that of the LLaMA-7B. 36 | 37 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 38 | documentation from [`PretrainedConfig`] for more information.""" 39 | model_type = "llama" 40 | keys_to_ignore_at_inference = ["past_key_values"] 41 | 42 | def __init__( 43 | self, 44 | vocab_size=32000, 45 | hidden_size=4096, 46 | intermediate_size=11008, 47 | num_hidden_layers=32, 48 | num_attention_heads=32, 49 | num_key_value_heads=None, 50 | hidden_act="silu", 51 | max_position_embeddings=2048, 52 | initializer_range=0.02, 53 | rms_norm_eps=1e-6, 54 | use_cache=True, 55 | pad_token_id=0, 56 | bos_token_id=1, 57 | eos_token_id=2, 58 | tie_word_embeddings=False, 59 | rope_scaling=None, 60 | **kwargs, 61 | ): 62 | # for backward compatibility 63 | if num_key_value_heads is None: 64 | num_key_value_heads = num_attention_heads 65 | 66 | self.num_key_value_heads = num_key_value_heads 67 | self.vocab_size = vocab_size 68 | self.max_position_embeddings = max_position_embeddings 69 | self.hidden_size = hidden_size 70 | self.intermediate_size = intermediate_size 71 | self.num_hidden_layers = num_hidden_layers 72 | self.num_attention_heads = num_attention_heads 73 | self.hidden_act = hidden_act 74 | self.initializer_range = initializer_range 75 | self.rms_norm_eps = rms_norm_eps 76 | self.rope_scaling = rope_scaling 77 | self._rope_scaling_validation() 78 | self.use_cache = use_cache 79 | self.cache_shape = None 80 | super().__init__( 81 | pad_token_id=pad_token_id, 82 | bos_token_id=bos_token_id, 83 | eos_token_id=eos_token_id, 84 | tie_word_embeddings=tie_word_embeddings, 85 | **kwargs, 86 | ) 87 | 88 | def _rope_scaling_validation(self): 89 | """ 90 | Validate the `rope_scaling` configuration. 91 | """ 92 | if self.rope_scaling is None: 93 | return 94 | 95 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 96 | raise ValueError( 97 | "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " 98 | f"got {self.rope_scaling}" 99 | ) 100 | rope_scaling_type = self.rope_scaling.get("type", None) 101 | rope_scaling_factor = self.rope_scaling.get("factor", None) 102 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 103 | raise ValueError( 104 | f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 105 | ) 106 | if ( 107 | rope_scaling_factor is None 108 | or not isinstance(rope_scaling_factor, float) 109 | or rope_scaling_factor <= 1.0 110 | ): 111 | raise ValueError( 112 | f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}" 113 | ) 114 | -------------------------------------------------------------------------------- /training/models/distributed_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Alpaca Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for PyTorch's distributed training. 16 | 17 | Compatible with torchrun / elastic. 18 | 19 | Internal map: 20 | https://github.com/lxuechen/ml-swissknife/blob/main/ml_swissknife/distributed_utils.py 21 | """ 22 | 23 | import os 24 | import sys 25 | from typing import Optional 26 | 27 | import torch 28 | import torch.distributed as dist 29 | 30 | 31 | def setup(rank: Optional[int] = None, world_size: Optional[int] = None): 32 | if rank is None: 33 | rank = get_rank() 34 | if world_size is None: 35 | world_size = get_world_size() 36 | 37 | if world_size <= 1: 38 | return rank, world_size 39 | 40 | if not dist.is_initialized(): 41 | if sys.platform == "win32": 42 | # Distributed package only covers collective communications with Gloo 43 | # backend and FileStore on Windows platform. Set init_method parameter 44 | # in init_process_group to a local file. 45 | # Example init_method="file:///f:/libtmp/some_file" 46 | init_method = "file:///f:/libtmp/dist-tmp" 47 | dist.init_process_group( 48 | backend="gloo", 49 | init_method=init_method, 50 | rank=rank, 51 | world_size=world_size, 52 | ) 53 | elif torch.cuda.is_available(): 54 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) 55 | else: 56 | dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) 57 | 58 | return rank, world_size 59 | 60 | 61 | def cleanup(): 62 | dist.destroy_process_group() 63 | 64 | 65 | def get_rank(): 66 | return int(os.getenv("RANK", 0)) 67 | 68 | 69 | def get_local_rank(): 70 | return int(os.getenv("LOCAL_RANK", 0)) 71 | 72 | 73 | def get_world_size(): 74 | return int(os.getenv("WORLD_SIZE", 1)) 75 | 76 | 77 | def should_save(): 78 | """Return True if the current process is the main process.""" 79 | return get_rank() <= 0 80 | 81 | 82 | def all_gather_and_cat(tensor: torch.Tensor, dim=0): 83 | if get_world_size() > 1: 84 | tensor_list = [torch.empty_like(tensor) for _ in range(get_world_size())] 85 | dist.all_gather(tensor_list, tensor) 86 | tensor = torch.cat(tensor_list, dim=dim) 87 | return tensor 88 | 89 | 90 | is_main_process = should_save 91 | -------------------------------------------------------------------------------- /training/models/qlora_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Self-Align Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from argparse import Namespace 16 | from typing import Optional 17 | from os.path import join, exists 18 | 19 | import torch 20 | import bitsandbytes as bnb 21 | from transformers import ( 22 | AutoModelForCausalLM, 23 | BitsAndBytesConfig, 24 | ) 25 | 26 | from peft import ( 27 | prepare_model_for_kbit_training, 28 | LoraConfig, 29 | PeftModel, 30 | PeftModelForCausalLM, 31 | ) 32 | from peft.tuners.lora import LoraLayer 33 | from falcon_modeling.modelling_RW import RWForCausalLM 34 | from models.llama_with_flash_attn import LlamaForCausalLM 35 | 36 | REGISTERED_BASE_MODELS = {} 37 | 38 | 39 | def find_all_linear_names( 40 | args: Namespace, 41 | model: torch.nn.Module, 42 | ): 43 | cls = ( 44 | bnb.nn.Linear4bit 45 | if args.bits == 4 46 | else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear) 47 | ) 48 | lora_module_names = set() 49 | for name, module in model.named_modules(): 50 | if isinstance(module, cls): 51 | names = name.split(".") 52 | if "lora" not in names[-1]: 53 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 54 | 55 | if "lm_head" in lora_module_names: # needed for 16-bit 56 | lora_module_names.remove("lm_head") 57 | return list(lora_module_names) 58 | 59 | 60 | def get_accelerate_model( 61 | args: Namespace, 62 | checkpoint_dir: Optional[str] = None, 63 | adapter_name="lora_default", 64 | is_trainable=True, 65 | reuse_base_model=False, 66 | ): 67 | global REGISTERED_BASE_MODELS 68 | 69 | compute_dtype = ( 70 | torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32) 71 | ) 72 | 73 | if checkpoint_dir is not None: 74 | if exists(join(checkpoint_dir, "adapter_model")): 75 | checkpoint_dir = join(checkpoint_dir, "adapter_model") 76 | 77 | if exists(join(checkpoint_dir, "lora_default")): 78 | checkpoint_dir = join(checkpoint_dir, "lora_default") 79 | 80 | if args.model_name_or_path in REGISTERED_BASE_MODELS and reuse_base_model: 81 | config = { 82 | "load_in_4bit": args.bits == 4, 83 | "load_in_8bit": args.bits == 8, 84 | "llm_int8_threshold": 6.0, 85 | "llm_int8_has_fp16_weight": False, 86 | "bnb_4bit_compute_dtype": compute_dtype, 87 | "bnb_4bit_use_double_quant": args.double_quant, 88 | "bnb_4bit_quant_type": args.quant_type, 89 | } 90 | 91 | registered_model, registered_config = REGISTERED_BASE_MODELS[ 92 | args.model_name_or_path 93 | ] 94 | if registered_config == config and not args.full_finetune: 95 | print(f"loading registered model {args.model_name_or_path}...") 96 | model = registered_model 97 | 98 | if checkpoint_dir is not None: 99 | model.load_adapter( 100 | checkpoint_dir, 101 | adapter_name=adapter_name, 102 | is_trainable=is_trainable, 103 | ) 104 | else: 105 | modules = args.lora_modules or find_all_linear_names(args, model) 106 | print("adding LoRa modules: ", modules) 107 | config = LoraConfig( 108 | r=args.lora_r, 109 | lora_alpha=args.lora_alpha, 110 | target_modules=modules, 111 | lora_dropout=args.lora_dropout, 112 | bias="none", 113 | task_type="CAUSAL_LM", 114 | ) 115 | model.add_adapter(adapter_name, peft_config=config) 116 | return model 117 | else: 118 | raise ValueError( 119 | f"Model {args.model_name_or_path} is already registered with a different config." 120 | f"{registered_config} != {config}" 121 | ) 122 | 123 | current_device = torch.cuda.current_device() 124 | if args.full_finetune: 125 | assert args.bits in [16, 32] 126 | 127 | print(f"loading base model {args.model_name_or_path}...") 128 | 129 | CausalLM = AutoModelForCausalLM 130 | 131 | if "falcon" in args.model_name_or_path.lower(): 132 | CausalLM = RWForCausalLM 133 | elif ( 134 | "llama" in args.model_name_or_path.lower() 135 | or "vicuna" in args.model_name_or_path.lower() 136 | or "dromedary" in args.model_name_or_path.lower() 137 | ) and torch.__version__ >= "2.0.0": 138 | CausalLM = LlamaForCausalLM 139 | 140 | model = CausalLM.from_pretrained( 141 | args.model_name_or_path, 142 | load_in_4bit=args.bits == 4, 143 | load_in_8bit=args.bits == 8, 144 | device_map={"": current_device}, 145 | # max_memory=max_memory, 146 | quantization_config=BitsAndBytesConfig( 147 | load_in_4bit=args.bits == 4, 148 | load_in_8bit=args.bits == 8, 149 | llm_int8_threshold=6.0, 150 | llm_int8_has_fp16_weight=False, 151 | bnb_4bit_compute_dtype=compute_dtype, 152 | bnb_4bit_use_double_quant=args.double_quant, 153 | bnb_4bit_quant_type=args.quant_type, 154 | ), 155 | torch_dtype=( 156 | torch.float16 157 | if args.fp16 158 | else (torch.bfloat16 if args.bf16 else torch.float32) 159 | ), 160 | trust_remote_code=args.trust_remote_code, 161 | ) 162 | if compute_dtype == torch.float16 and args.bits == 4: 163 | major, minor = torch.cuda.get_device_capability() 164 | if major >= 8: 165 | print("=" * 80) 166 | print( 167 | "Your GPU supports bfloat16, you can accelerate training with the argument --bf16" 168 | ) 169 | print("=" * 80) 170 | 171 | setattr(model, "model_parallel", True) 172 | setattr(model, "is_parallelizable", True) 173 | 174 | model.config.torch_dtype = ( 175 | torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32) 176 | ) 177 | 178 | if not args.full_finetune: 179 | model = prepare_model_for_kbit_training( 180 | model, use_gradient_checkpointing=args.gradient_checkpointing 181 | ) 182 | if args.gradient_checkpointing: 183 | model.gradient_checkpointing_enable() 184 | 185 | if not args.full_finetune: 186 | if checkpoint_dir is not None: 187 | print("Loading adapters from checkpoint.") 188 | 189 | model = PeftModel.from_pretrained( 190 | model, 191 | checkpoint_dir, 192 | adapter_name=adapter_name, 193 | is_trainable=is_trainable, 194 | ) 195 | else: 196 | # print(f'adding LoRA modules...') 197 | modules = args.lora_modules or find_all_linear_names(args, model) 198 | print("adding LoRa modules: ", modules) 199 | config = LoraConfig( 200 | r=args.lora_r, 201 | lora_alpha=args.lora_alpha, 202 | target_modules=modules, 203 | lora_dropout=args.lora_dropout, 204 | bias="none", 205 | task_type="CAUSAL_LM", 206 | ) 207 | model = get_peft_model(model, config, adapter_name=adapter_name) 208 | 209 | if args.model_name_or_path not in REGISTERED_BASE_MODELS: 210 | config = { 211 | "load_in_4bit": args.bits == 4, 212 | "load_in_8bit": args.bits == 8, 213 | "llm_int8_threshold": 6.0, 214 | "llm_int8_has_fp16_weight": False, 215 | "bnb_4bit_compute_dtype": compute_dtype, 216 | "bnb_4bit_use_double_quant": args.double_quant, 217 | "bnb_4bit_quant_type": args.quant_type, 218 | } 219 | REGISTERED_BASE_MODELS[args.model_name_or_path] = (model, config) 220 | 221 | for name, module in model.named_modules(): 222 | if isinstance(module, LoraLayer): 223 | if args.bf16: 224 | module = module.to(torch.bfloat16) 225 | else: 226 | module = module.to(torch.float32) 227 | if "lm_head" in name or "embed_tokens" in name: 228 | if hasattr(module, "weight"): 229 | if args.bf16 and module.weight.dtype == torch.float32: 230 | module = module.to(torch.bfloat16) 231 | # if not args.bf16: 232 | # module = module.to(torch.float32) 233 | return model 234 | 235 | 236 | def load_4bit_model_for_inference( 237 | checkpoint_dir: str, 238 | bits: int = 4, 239 | fp16: bool = False, 240 | bf16: bool = False, 241 | double_quant: bool = True, 242 | quant_type: str = "nf4", 243 | gradient_checkpointing: bool = False, 244 | adapter_name="lora_default", 245 | is_trainable=True, 246 | reuse_base_model=False, 247 | trust_remote_code=False, 248 | base_model_mapping=None, 249 | fully_initialize=False, 250 | ): 251 | if checkpoint_dir is not None: 252 | if exists(join(checkpoint_dir, "adapter_model")): 253 | checkpoint_dir = join(checkpoint_dir, "adapter_model") 254 | 255 | if exists(join(checkpoint_dir, "lora_default")): 256 | checkpoint_dir = join(checkpoint_dir, "lora_default") 257 | 258 | config = LoraConfig.from_pretrained(checkpoint_dir) 259 | base_model_name_or_path = config.base_model_name_or_path 260 | 261 | if base_model_mapping is not None: 262 | dict_base_model_mapping = eval(base_model_mapping) 263 | if ( 264 | dict_base_model_mapping is not None 265 | and base_model_name_or_path in dict_base_model_mapping 266 | ): 267 | base_model_name_or_path = dict_base_model_mapping[base_model_name_or_path] 268 | 269 | args = Namespace( 270 | model_name_or_path=base_model_name_or_path, 271 | bits=bits, 272 | fp16=fp16, 273 | bf16=bf16, 274 | double_quant=double_quant, 275 | quant_type=quant_type, 276 | gradient_checkpointing=gradient_checkpointing, 277 | trust_remote_code=trust_remote_code, 278 | full_finetune=False, 279 | lora_r=64 if fully_initialize else None, 280 | lora_alpha=16 if fully_initialize else None, 281 | lora_dropout=0.0 if fully_initialize else None, 282 | lora_modules=None, 283 | ) 284 | 285 | if fully_initialize: 286 | print("Fully initializing qlora model.") 287 | 288 | model = get_accelerate_model( 289 | args, 290 | checkpoint_dir=None if fully_initialize else checkpoint_dir, 291 | adapter_name=adapter_name, 292 | is_trainable=is_trainable, 293 | reuse_base_model=reuse_base_model, 294 | ) 295 | return model 296 | 297 | 298 | def get_peft_model(model, peft_config, adapter_name="default"): 299 | """ 300 | Returns a Peft model object from a model and a config. 301 | 302 | Args: 303 | model ([`transformers.PreTrainedModel`]): Model to be wrapped. 304 | peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model. 305 | """ 306 | peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) 307 | return PeftModelForCausalLM(model, peft_config, adapter_name=adapter_name) 308 | -------------------------------------------------------------------------------- /training/models/reward_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Self-Align Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from argparse import Namespace 17 | import os 18 | from typing import Optional, Dict, Sequence, Union 19 | 20 | import einops 21 | import torch 22 | from torch import Tensor, nn 23 | import torch.nn.functional as F 24 | 25 | import transformers 26 | from transformers.trainer_utils import EvalPrediction 27 | from transformers.utils.generic import ModelOutput 28 | 29 | from peft import PeftModel, LoraModel, LoraConfig 30 | 31 | from models.qlora_model import get_accelerate_model 32 | from models.llama_with_flash_attn import LlamaModel as FlashAttnLlamaModel 33 | 34 | 35 | def unpack_dict( 36 | d: Dict, keys: Sequence[str], return_type: type = tuple 37 | ) -> Union[Sequence, Dict]: 38 | if return_type in (tuple, list): 39 | return return_type(d[key] for key in keys) 40 | elif return_type == dict: 41 | return {key: d[key] for key in keys} 42 | else: 43 | raise ValueError(f"Unknown return_type: {return_type}") 44 | 45 | 46 | def batch_select(input: Tensor, index: Tensor): 47 | """Select elements from a batched tensor with a batched index tensor. 48 | 49 | Example: 50 | input = torch.tensor([ 51 | [0, 1, 2], 52 | [3, 0, 9], 53 | [6, 7, 8], 54 | ]) 55 | index = torch.tensor([[0, 1], [1, 0], [0, 0]]) 56 | batch_select(input, index) = tensor([ 57 | [0, 1], 58 | [0, 3], 59 | [6, 6] 60 | ]) 61 | """ 62 | dummy_index = torch.arange(input.size(0), device=input.device).unsqueeze(-1) 63 | return input[dummy_index, index] 64 | 65 | 66 | def make_generative_lm( 67 | args: Namespace, 68 | model_name_or_path: str, 69 | qlora: bool = False, 70 | checkpoint_dir: Optional[str] = None, 71 | adapter_name="lora_default", 72 | is_trainable=True, 73 | reuse_base_model=False, 74 | **kwargs, 75 | ): 76 | model_cls = transformers.LlamaForCausalLM 77 | 78 | if qlora: 79 | if checkpoint_dir is None or checkpoint_dir in ["scratch", "none"]: 80 | return get_accelerate_model(args, None) 81 | else: 82 | return get_accelerate_model( 83 | args, 84 | checkpoint_dir=checkpoint_dir, 85 | adapter_name=adapter_name, 86 | is_trainable=is_trainable, 87 | reuse_base_model=reuse_base_model, 88 | ) 89 | return model_cls.from_pretrained(model_name_or_path, **kwargs) 90 | 91 | 92 | def get_transformer_hidden_size(model: transformers.PreTrainedModel): 93 | if isinstance(model, PeftModel): 94 | return get_transformer_hidden_size(model.base_model) 95 | 96 | if isinstance(model, LoraModel): 97 | return get_transformer_hidden_size(model.model) 98 | 99 | if isinstance(model, transformers.GPT2LMHeadModel): 100 | hidden_size_attr_name = "n_embd" 101 | elif isinstance(model, transformers.OPTForCausalLM): 102 | hidden_size_attr_name = "word_embed_proj_dim" 103 | elif isinstance(model, transformers.T5ForConditionalGeneration): 104 | hidden_size_attr_name = "d_model" 105 | elif isinstance(model, transformers.GPTBigCodeForCausalLM): 106 | hidden_size_attr_name = "n_embd" 107 | elif "modelling_RW.RWModel" in str( 108 | type(model) 109 | ) or "modelling_RW.RWForCausalLM" in str(type(model)): 110 | # TODO(zhiqings): Hack to add support for Falcon. 111 | hidden_size_attr_name = "hidden_size" 112 | else: 113 | # Hack to deal with the fact that transformers library changed the LLaMA model name. 114 | llama_cls = getattr( 115 | transformers, 116 | "LLaMAForCausalLM" 117 | if hasattr(transformers, "LLaMAForCausalLM") 118 | else "LlamaForCausalLM", 119 | ) 120 | if isinstance(model, llama_cls) or "LlamaForCausalLM" in str(type(model)): 121 | hidden_size_attr_name = "hidden_size" 122 | else: 123 | raise ValueError(f"Unknown base_model type: {type(model)}") 124 | from typing import Any, Mapping 125 | return getattr(model.config, hidden_size_attr_name) 126 | 127 | 128 | class RewardConfig(transformers.PretrainedConfig): 129 | model_type = "reward_model" 130 | 131 | # Huggingface doesn't allow non-kwargs for `__init__`. 132 | def __init__(self, backbone_model_name_or_path=None, **kwargs): 133 | super(RewardConfig, self).__init__(**kwargs) 134 | self.backbone_model_name_or_path = backbone_model_name_or_path 135 | 136 | 137 | class RewardModelOutput(ModelOutput): 138 | rewards: Tensor = None 139 | 140 | 141 | class RewardModel(transformers.PreTrainedModel): 142 | config_class = RewardConfig 143 | supports_gradient_checkpointing = True 144 | 145 | def __init__( 146 | self, 147 | args: Namespace, 148 | config: RewardConfig, 149 | checkpoint_dir: Optional[str] = None, 150 | adapter_name="lora_default", 151 | **kwargs, 152 | ): 153 | super(RewardModel, self).__init__(config) 154 | self.adapter_name = adapter_name 155 | self.backbone_model = make_generative_lm( 156 | args, 157 | config.backbone_model_name_or_path, 158 | checkpoint_dir=checkpoint_dir, 159 | adapter_name=adapter_name, 160 | **kwargs, 161 | ) 162 | hidden_size = get_transformer_hidden_size(self.backbone_model) 163 | reward_head = nn.Linear(hidden_size, 1) 164 | torch.nn.init.zeros_(reward_head.bias) 165 | device = next(self.backbone_model.parameters()).device 166 | self.reward_head = reward_head.to(device) 167 | 168 | if checkpoint_dir is not None: 169 | reward_head_path = os.path.join(checkpoint_dir, "reward_head") 170 | if os.path.exists(reward_head_path): 171 | self.reward_head.load_state_dict( 172 | torch.load( 173 | reward_head_path, 174 | map_location="cpu", 175 | ) 176 | ) 177 | else: 178 | print(f"Warning: reward head not found at {reward_head_path}") 179 | 180 | self.reward_head.requires_grad_(kwargs.get("is_trainable", True)) 181 | 182 | def forward(self, input_ids, attention_mask=None, return_dict=True, **kwargs): 183 | # We only compute the rewards and don't compute the logistic regression loss in this function so that it's 184 | # easier to use for later stages of reranking / RL training. 185 | self.backbone_model.set_adapter(self.adapter_name) 186 | self.backbone_model.config.use_cache = False 187 | 188 | outputs = self.backbone_model( 189 | input_ids=input_ids, 190 | attention_mask=attention_mask, 191 | return_dict=True, 192 | output_hidden_states=True, 193 | **kwargs, 194 | ) 195 | last_hidden_state = outputs.hidden_states[-1] 196 | assert isinstance(last_hidden_state, torch.Tensor), f"{outputs}" 197 | # last_hidden_state = outputs.last_hidden_state 198 | # TODO(zhiqings): Hacking to make sure every parameter is used in the backward pass. 199 | logits = outputs.logits 200 | last_hidden_state = last_hidden_state + 0.0 * torch.mean(logits) 201 | 202 | last_hidden_state_at_the_end = last_hidden_state[:, -1, :] 203 | # TODO(lxuechen): Make returning rewards at all positions and last_hidden_state an option. 204 | final_dtype_tensor = next(self.reward_head.parameters()) 205 | last_hidden_state_at_the_end = last_hidden_state_at_the_end.type_as( 206 | final_dtype_tensor 207 | ) 208 | 209 | if final_dtype_tensor.dtype == torch.float16: 210 | last_hidden_state_at_the_end = last_hidden_state_at_the_end.to( 211 | torch.float32 212 | ) 213 | 214 | rewards = self.reward_head(last_hidden_state_at_the_end).squeeze(-1) 215 | return RewardModelOutput(rewards=rewards) if return_dict else (rewards,) 216 | 217 | def _set_gradient_checkpointing(self, module, value=False): 218 | if isinstance(module, transformers.LlamaModel): 219 | module.gradient_checkpointing = value 220 | elif isinstance(module, FlashAttnLlamaModel): 221 | module.gradient_checkpointing = value 222 | if isinstance(module, transformers.GPTBigCodeModel): 223 | module.gradient_checkpointing = value 224 | # TODO(zhiqings): Hack to add support for Falcon. 225 | if "RWModel" in str(type(module)): 226 | module.gradient_checkpointing = value 227 | 228 | 229 | class RewardModelTrainer(transformers.Trainer): 230 | def compute_loss(self, model, inputs, return_outputs=False): 231 | # input_ids, attention_mask each of size (bsz, num_candidates, seq_len). 232 | # index_0, index_1 each of size (bsz, num_pairs); indexes into input_ids. 233 | # choice of size (bsz, num_pairs); 1 if index_1's seq is chosen, 0 otherwise. 234 | input_ids, attention_mask, index_0, index_1, choice = unpack_dict( 235 | inputs, keys=("input_ids", "attention_mask", "index_0", "index_1", "choice") 236 | ) 237 | num_candidates, num_pairs = input_ids.size(1), choice.size(1) 238 | input_ids_flat, attention_mask_flat = tuple( 239 | einops.rearrange(x, "b c l -> (b c) l") for x in (input_ids, attention_mask) 240 | ) 241 | outputs = model(input_ids=input_ids_flat, attention_mask=attention_mask_flat) 242 | rewards_flat = outputs.rewards 243 | rewards = einops.rearrange( 244 | rewards_flat, "(b c) -> b c", c=num_candidates 245 | ) # Size: (bsz, num_candidates). 246 | 247 | rewards_0, rewards_1 = tuple( 248 | batch_select(rewards, index) for index in (index_0, index_1) 249 | ) # Size: (bsz, num_pairs). 250 | logits = rewards_1 - rewards_0 # Size: (bsz, num_pairs). 251 | # Type casting of `choice` is due to amp.autocast context manager. 252 | loss = F.binary_cross_entropy_with_logits( 253 | logits, choice.to(logits.dtype), reduction="mean" 254 | ) 255 | return (loss, dict(logits=logits)) if return_outputs else loss 256 | 257 | 258 | def compute_reward_modeling_metrics(eval_prediction: EvalPrediction) -> Dict: 259 | # eval_prediction.label_ids is a tuple that matches up with `training_args.label_names`. 260 | logits = torch.tensor(eval_prediction.predictions).squeeze(-1) 261 | labels = torch.tensor(eval_prediction.label_ids[-1]).squeeze(-1) 262 | predictions = (logits >= 0.0).long() 263 | accuracy = predictions.eq(labels).float().mean().item() 264 | label_positive_rate = (labels == 1).float().mean().item() 265 | return dict( 266 | accuracy=accuracy, 267 | label_positive_rate=label_positive_rate, 268 | ) 269 | 270 | 271 | def load_4bit_reward_model_for_inference( 272 | checkpoint_dir: str, 273 | bits: int = 4, 274 | fp16: bool = False, 275 | bf16: bool = False, 276 | double_quant: bool = True, 277 | quant_type: str = "nf4", 278 | gradient_checkpointing: bool = False, 279 | adapter_name="lora_default", 280 | is_trainable=True, 281 | reuse_base_model=False, 282 | trust_remote_code=False, 283 | base_model_mapping=None, 284 | ): 285 | # Load the model. 286 | lora_checkpoint_dir = checkpoint_dir 287 | if os.path.exists(os.path.join(lora_checkpoint_dir, "adapter_model")): 288 | lora_checkpoint_dir = os.path.join(lora_checkpoint_dir, "adapter_model") 289 | if os.path.exists(os.path.join(lora_checkpoint_dir, "lora_default")): 290 | lora_checkpoint_dir = os.path.join(lora_checkpoint_dir, "lora_default") 291 | 292 | lora_config = LoraConfig.from_pretrained(lora_checkpoint_dir) 293 | base_model_name_or_path = lora_config.base_model_name_or_path 294 | 295 | if base_model_mapping is not None: 296 | dict_base_model_mapping = eval(base_model_mapping) 297 | if ( 298 | dict_base_model_mapping is not None 299 | and base_model_name_or_path in dict_base_model_mapping 300 | ): 301 | base_model_name_or_path = dict_base_model_mapping[base_model_name_or_path] 302 | 303 | config = RewardConfig(backbone_model_name_or_path=base_model_name_or_path) 304 | 305 | args = Namespace( 306 | model_name_or_path=config.backbone_model_name_or_path, 307 | bits=bits, 308 | fp16=fp16, 309 | bf16=bf16, 310 | double_quant=double_quant, 311 | quant_type=quant_type, 312 | trust_remote_code=trust_remote_code, 313 | full_finetune=False, 314 | gradient_checkpointing=gradient_checkpointing, 315 | ) 316 | 317 | model = RewardModel( 318 | args, 319 | config, 320 | checkpoint_dir=checkpoint_dir, 321 | qlora=bits == 4 or bits == 8, 322 | adapter_name=adapter_name, 323 | is_trainable=is_trainable, 324 | reuse_base_model=reuse_base_model, 325 | ) 326 | return model 327 | -------------------------------------------------------------------------------- /training/models/rl_models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Self-Align Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Model classes that are shared across different algorithms. 17 | 18 | WARNING: 19 | Do not tamper with the state_dict function for any of these classes. 20 | If you tamper, make sure the keys are the same, otherwise FSDP will get confused. 21 | """ 22 | 23 | import abc 24 | import logging 25 | from typing import Dict, Optional 26 | 27 | import torch 28 | import transformers 29 | from torch import Tensor, nn 30 | 31 | from data_utils.common_utils import right_pad, compute_logprobs 32 | from models.reward_model import get_transformer_hidden_size 33 | 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | class Policy(nn.Module, abc.ABC): 39 | def __init__( 40 | self, 41 | args, 42 | base_model: transformers.PreTrainedModel, 43 | base_tokenizer: transformers.PreTrainedTokenizer, 44 | adapter_name: Optional[str] = None, 45 | ): 46 | super().__init__() 47 | self.args = args 48 | self.base_model = base_model 49 | self.base_tokenizer = base_tokenizer 50 | self.adapter_name = adapter_name 51 | 52 | @abc.abstractmethod 53 | def forward( 54 | self, 55 | queries: Tensor, 56 | query_attn_masks: Tensor, 57 | responses: Tensor, 58 | temperature: Optional[float] = None, 59 | ) -> Dict[str, Tensor]: 60 | raise NotImplementedError 61 | 62 | def respond( 63 | self, 64 | queries: Tensor, 65 | query_attn_masks: Tensor, 66 | temperature: Optional[float] = None, 67 | num_return_sequences=1, 68 | ) -> Dict[str, Tensor]: 69 | assert not self.training, "Policy must be in eval model for generation." 70 | return self._post_respond( 71 | self._respond(queries, query_attn_masks, temperature, num_return_sequences) 72 | ) 73 | 74 | @abc.abstractmethod 75 | def _respond( 76 | self, 77 | queries: Tensor, 78 | query_attn_masks: Tensor, 79 | temperature: Optional[float] = None, 80 | num_return_sequences=1, 81 | ) -> Dict[str, Tensor]: 82 | raise NotImplementedError 83 | 84 | def _post_respond(self, respond_outputs: Dict[str, Tensor]) -> Dict[str, Tensor]: 85 | return respond_outputs 86 | 87 | 88 | class AutoregressivePolicy(Policy): 89 | def forward( 90 | self, 91 | queries: Tensor, 92 | query_attn_masks: Tensor, 93 | responses: Tensor, 94 | temperature: Optional[float] = None, 95 | ) -> Dict[str, Tensor]: 96 | # TODO(lxuechen): Refactor attention mask. Here query_attn_masks overrides padding-based attention mask. 97 | 98 | if self.adapter_name is not None: 99 | self.base_model.set_adapter(self.adapter_name) 100 | self.base_model.config.use_cache = False 101 | 102 | if temperature is None: 103 | temperature = self.args.temperature 104 | input_ids = torch.cat([queries, responses], dim=1) 105 | attention_mask = input_ids.ne(self.base_tokenizer.pad_token_id) 106 | attention_mask[:, : queries.size(1)] = query_attn_masks 107 | # Fix position id issues and ensure consistency with `respond` for GPT and OPT. 108 | inputs = self.base_model.prepare_inputs_for_generation( 109 | input_ids=input_ids, 110 | attention_mask=attention_mask, 111 | use_cache=False, 112 | ) 113 | outputs = self.base_model(**inputs, output_hidden_states=True) 114 | original_logits = outputs.logits[:, -self.args.response_len - 1 : -1] 115 | logits = original_logits / temperature 116 | labels = input_ids[:, -self.args.response_len :] 117 | logprobs = compute_logprobs( 118 | logits, labels, ignore_index=self.base_tokenizer.pad_token_id 119 | ) 120 | entropies = -(logits.softmax(dim=-1) * logits.log_softmax(dim=-1)).sum(dim=-1) 121 | last_hidden_state = outputs.hidden_states[-1][ 122 | :, -self.args.response_len - 1 : -1 123 | ] 124 | return dict( 125 | original_logits=original_logits, 126 | logits=logits, 127 | logprobs=logprobs, 128 | entropies=entropies, 129 | last_hidden_state=last_hidden_state, 130 | ) 131 | 132 | def _respond( 133 | self, 134 | queries: Tensor, 135 | query_attn_masks: Tensor, 136 | temperature: Optional[float] = None, 137 | num_return_sequences=1, 138 | ) -> Dict[str, Tensor]: 139 | if self.adapter_name is not None: 140 | self.base_model.set_adapter(self.adapter_name) 141 | self.base_model.config.use_cache = True 142 | self.base_model.config.cache_shape = ( 143 | queries.shape[-1] + self.args.response_len, 144 | ) 145 | 146 | if temperature is None: 147 | temperature = self.args.temperature 148 | sequences = self.base_model.generate( 149 | inputs=queries, 150 | attention_mask=query_attn_masks, 151 | do_sample=True, 152 | max_new_tokens=self.args.response_len, 153 | pad_token_id=self.base_tokenizer.pad_token_id, 154 | eos_token_id=self.base_tokenizer.eos_token_id, 155 | top_p=1.0, 156 | top_k=0, 157 | temperature=temperature, 158 | num_return_sequences=num_return_sequences, 159 | synced_gpus=True, 160 | ) 161 | responses = right_pad( 162 | sequences[:, queries.size(1) :], 163 | target_size=(sequences.size(0), self.args.response_len), 164 | value=self.base_tokenizer.pad_token_id, 165 | ) 166 | return dict( 167 | responses=responses 168 | ) # Size (bsz * num_return_sequences, response_len). 169 | 170 | 171 | class Value(nn.Module, abc.ABC): 172 | def __init__( 173 | self, 174 | args, 175 | base_model: transformers.PreTrainedModel, 176 | base_tokenizer: transformers.PreTrainedTokenizer, 177 | adapter_name: Optional[str] = None, 178 | ): 179 | super().__init__() 180 | self.args = args 181 | self.base_model = base_model 182 | self.base_tokenizer = base_tokenizer 183 | hidden_size = get_transformer_hidden_size(base_model) 184 | value_head = torch.nn.Linear(hidden_size, 1) 185 | value_head.weight.data.zero_() 186 | value_head.bias.data.zero_() 187 | self.value_head = value_head.to(next(base_model.parameters()).device) 188 | self.adapter_name = adapter_name 189 | 190 | @abc.abstractmethod 191 | def forward( 192 | self, queries: Tensor, query_attn_masks: Tensor, responses: Tensor 193 | ) -> Dict[str, Tensor]: 194 | raise NotImplementedError 195 | 196 | 197 | class AutoregressiveValue(Value): 198 | def forward( 199 | self, queries: Tensor, query_attn_masks: Tensor, responses: Tensor 200 | ) -> Dict[str, Tensor]: 201 | if self.adapter_name is not None: 202 | self.base_model.set_adapter(self.adapter_name) 203 | self.base_model.config.use_cache = False 204 | 205 | sequences = torch.cat([queries, responses], dim=1) 206 | sequence_attn_masks = sequences.ne(self.base_tokenizer.pad_token_id) 207 | 208 | inputs = self.base_model.prepare_inputs_for_generation( 209 | input_ids=sequences, 210 | attention_mask=sequence_attn_masks, 211 | use_cache=False, 212 | ) 213 | outputs = self.base_model( 214 | **inputs, 215 | return_dict=True, 216 | output_hidden_states=True, 217 | ) 218 | # value[t]: \hat{V}(sequences_{:t-1}); must align with `_estimate_advantage`. 219 | 220 | last_hidden_state = outputs.hidden_states[-1] 221 | assert isinstance(last_hidden_state, torch.Tensor), f"{outputs}" 222 | logits = outputs.logits 223 | # TODO(zhiqings): Hacking to make sure every parameter is used in the backward pass. 224 | last_hidden_state = last_hidden_state + 0.0 * torch.mean(logits) 225 | last_hidden_state = last_hidden_state[:, queries.size(1) - 1 : -1] 226 | 227 | # TODO(zhiqings): now we just manully convert output types 228 | last_hidden_state = last_hidden_state.type_as( 229 | next(self.value_head.parameters()) 230 | ) 231 | values = self.value_head(last_hidden_state).squeeze(-1) 232 | return dict(values=values) 233 | 234 | 235 | class ActorCritic(nn.Module): 236 | def __init__(self, policy: Policy, value_model: Value): 237 | super(ActorCritic, self).__init__() 238 | self.policy = policy 239 | self.value_model = value_model 240 | 241 | def forward( 242 | self, 243 | queries: Tensor, 244 | query_attn_masks: Tensor, 245 | responses: Tensor, 246 | temperature: Optional[float] = None, 247 | mode: Optional[str] = None, 248 | ) -> Dict[str, Tensor]: 249 | # Assume the policy and value model share the same tokenizer. 250 | 251 | if mode is None: 252 | o1 = self.policy(queries, query_attn_masks, responses, temperature) 253 | o2 = self.value_model(queries, query_attn_masks, responses) 254 | 255 | elif mode == "policy": 256 | o1 = self.policy(queries, query_attn_masks, responses, temperature) 257 | # Add dummy loss to make sure every parameter is used in the backward pass. 258 | o2 = { 259 | "dummy_loss": 0.0 260 | * torch.sum( 261 | torch.stack( 262 | [ 263 | torch.mean(value) 264 | for key, value in self.named_parameters() 265 | if "lora_value" in key 266 | ] 267 | ) 268 | ) 269 | } 270 | elif mode == "value": 271 | o2 = self.value_model(queries, query_attn_masks, responses) 272 | # Add dummy loss to make sure every parameter is used in the backward pass. 273 | o1 = { 274 | "dummy_loss": 0.0 275 | * torch.sum( 276 | torch.stack( 277 | [ 278 | torch.mean(value) 279 | for key, value in self.named_parameters() 280 | if "lora_policy" in key 281 | ] 282 | ) 283 | ) 284 | } 285 | else: 286 | raise ValueError(f"Unknown mode: {mode}") 287 | 288 | return {**o1, **o2} 289 | 290 | def respond( 291 | self, 292 | queries: Tensor, 293 | query_attn_masks: Tensor, 294 | temperature: Optional[float] = None, 295 | ) -> Dict[str, Tensor]: 296 | return self.policy.respond( 297 | queries=queries, query_attn_masks=query_attn_masks, temperature=temperature 298 | ) 299 | 300 | 301 | def make_policy_with_base_model( 302 | args, 303 | base_model: transformers.PreTrainedModel, 304 | base_tokenizer: transformers.PreTrainedTokenizer, 305 | adapter_name: Optional[str] = "default", 306 | ) -> Policy: 307 | if base_model.config.is_encoder_decoder: 308 | raise NotImplementedError 309 | else: 310 | return AutoregressivePolicy( 311 | args, base_model, base_tokenizer, adapter_name=adapter_name 312 | ) 313 | 314 | 315 | def make_value_with_base_model( 316 | args, 317 | base_model: transformers.PreTrainedModel, 318 | base_tokenizer: transformers.PreTrainedTokenizer, 319 | adapter_name: Optional[str] = "default", 320 | ) -> Value: 321 | if base_model.config.is_encoder_decoder: 322 | raise NotImplementedError 323 | else: 324 | return AutoregressiveValue( 325 | args, base_model, base_tokenizer, adapter_name=adapter_name 326 | ) 327 | -------------------------------------------------------------------------------- /training/models/trainer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Self-Align Team 2 | # Copyright 2023 The Alpaca Team 3 | # Copyright 2022 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from typing import Optional 18 | 19 | from torch import nn, optim 20 | from transformers import Trainer 21 | from transformers.optimization import get_scheduler 22 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS 23 | from transformers.trainer_pt_utils import get_parameter_names 24 | 25 | 26 | def create_optimizer( 27 | args, model: nn.Module, optimizer: Optional[optim.Optimizer] = None 28 | ): 29 | """Create optimizer for trainer. 30 | 31 | This is detached version of the `Trainer.create_optimizer` method. 32 | We don't support sagemaker and fairscale for simplicity. 33 | 34 | Reference: 35 | https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py 36 | """ 37 | opt_model = model 38 | 39 | if optimizer is None: 40 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 41 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 42 | optimizer_grouped_parameters = [ 43 | { 44 | "params": [ 45 | p 46 | for n, p in opt_model.named_parameters() 47 | if (n in decay_parameters and p.requires_grad) 48 | ], 49 | "weight_decay": args.weight_decay, 50 | }, 51 | { 52 | "params": [ 53 | p 54 | for n, p in opt_model.named_parameters() 55 | if (n not in decay_parameters and p.requires_grad) 56 | ], 57 | "weight_decay": 0.0, 58 | }, 59 | ] 60 | 61 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args) 62 | 63 | optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 64 | return optimizer 65 | 66 | 67 | def create_scheduler(args, optimizer, lr_scheduler, num_training_steps): 68 | """Create scheduler for trainer. 69 | 70 | This is detached version of the `Trainer.create_scheduler` method. 71 | 72 | Reference: 73 | https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py 74 | """ 75 | if lr_scheduler is None: 76 | lr_scheduler = get_scheduler( 77 | args.lr_scheduler_type, 78 | optimizer=optimizer, 79 | num_warmup_steps=args.get_warmup_steps(num_training_steps), 80 | num_training_steps=num_training_steps, 81 | ) 82 | return lr_scheduler 83 | -------------------------------------------------------------------------------- /training/qlora_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from os.path import exists, join, isdir 4 | import shutil 5 | import sys 6 | from typing import Optional, Dict, Sequence, List 7 | 8 | import torch 9 | import transformers 10 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 11 | 12 | from models.reward_model import RewardModel 13 | 14 | DEFAULT_PAD_TOKEN = "[PAD]" 15 | 16 | 17 | class SavePeftModelCallback(transformers.TrainerCallback): 18 | def save_model(self, args, state, kwargs): 19 | print("Saving PEFT checkpoint...") 20 | 21 | global_rank = int(os.environ.get("RANK", 0)) 22 | 23 | if global_rank == 0: 24 | print("Saving model checkpoint to %s" % args.output_dir) 25 | if state.best_model_checkpoint is not None: 26 | checkpoint_folder = state.best_model_checkpoint 27 | else: 28 | checkpoint_folder = os.path.join( 29 | args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}" 30 | ) 31 | 32 | peft_model_path = os.path.join(checkpoint_folder, "adapter_model") 33 | reward_head_path = os.path.join(checkpoint_folder, "reward_head") 34 | 35 | if isinstance(kwargs["model"], RewardModel): 36 | kwargs["model"].backbone_model.save_pretrained(peft_model_path) 37 | torch.save( 38 | kwargs["model"].reward_head.state_dict(), 39 | reward_head_path, 40 | ) 41 | else: 42 | kwargs["model"].save_pretrained(peft_model_path) 43 | 44 | pytorch_model_paths = glob.glob( 45 | os.path.join(checkpoint_folder, "pytorch_model*.bin") 46 | ) 47 | for pytorch_model_path in pytorch_model_paths: 48 | if os.path.exists(pytorch_model_path): 49 | os.remove(pytorch_model_path) 50 | 51 | optimizer_path = os.path.join(checkpoint_folder, "optimizer.pt") 52 | if os.path.exists(optimizer_path): 53 | os.remove(optimizer_path) 54 | 55 | else: 56 | print("Skipping PEFT checkpoint save on rank %d" % global_rank) 57 | 58 | def on_save(self, args, state, control, **kwargs): 59 | self.save_model(args, state, kwargs) 60 | return control 61 | 62 | def on_train_end(self, args, state, control, **kwargs): 63 | def touch(fname, times=None): 64 | global_rank = int(os.environ.get("RANK", 0)) 65 | if global_rank == 0: 66 | with open(fname, "a"): 67 | os.utime(fname, times) 68 | 69 | touch(join(args.output_dir, "completed")) 70 | self.save_model(args, state, kwargs) 71 | 72 | 73 | def print_trainable_parameters(args, model): 74 | """ 75 | Prints the number of trainable parameters in the model. 76 | """ 77 | trainable_params = 0 78 | all_param = 0 79 | for _, param in model.named_parameters(): 80 | all_param += param.numel() 81 | if param.requires_grad: 82 | trainable_params += param.numel() 83 | if args.bits == 4: 84 | trainable_params /= 2 85 | print( 86 | f"trainable params: {trainable_params} || " 87 | f"all params: {all_param} || " 88 | f"trainable: {100 * trainable_params / all_param}" 89 | ) 90 | 91 | 92 | def get_last_checkpoint(checkpoint_dir): 93 | if isdir(checkpoint_dir): 94 | is_completed = exists(join(checkpoint_dir, "completed")) 95 | if is_completed: 96 | return None, True # already finished 97 | max_step = 0 98 | for filename in os.listdir(checkpoint_dir): 99 | if isdir(join(checkpoint_dir, filename)) and filename.startswith( 100 | "checkpoint" 101 | ): 102 | max_step = max(max_step, int(filename.replace("checkpoint-", ""))) 103 | if max_step == 0: 104 | return None, is_completed # training started, but no checkpoint 105 | checkpoint_dir = join(checkpoint_dir, f"checkpoint-{max_step}") 106 | print(f"Found a previous checkpoint at: {checkpoint_dir}") 107 | return checkpoint_dir, is_completed # checkpoint found! 108 | return None, False # first training 109 | -------------------------------------------------------------------------------- /training/step1_synthetic_preference_collection/batch_generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 3 | 4 | from typing import Optional 5 | import math 6 | import os 7 | import fire 8 | import time 9 | import tqdm 10 | import json 11 | 12 | from pathlib import Path 13 | 14 | import torch 15 | 16 | from llama_dromedary import Llama 17 | 18 | 19 | def main( 20 | ckpt_dir: str, 21 | tokenizer_path: str, 22 | temperature: float = 1.0, 23 | top_p: float = 1.0, 24 | max_seq_len: int = 512, 25 | max_batch_size: int = 32, 26 | max_shared_seq_len: int = 512, 27 | generate_max_len: int = 128, 28 | group_rank: int = -1, 29 | group_size: int = -1, 30 | input_file: str = None, 31 | output_file: str = None, 32 | meta_prompt_file: str = None, 33 | prompt_style: str = "dromedary", 34 | seed: Optional[int] = None, 35 | ): 36 | assert group_rank >= 0, "Must specify group rank" 37 | assert group_size >= 0, "Must specify group size" 38 | assert ( 39 | input_file is not None and output_file is not None 40 | ), "Must specify input and output files" 41 | assert meta_prompt_file is not None, "Must specify meta prompt file" 42 | 43 | with open(input_file, "r") as f: 44 | inputs = json.load(f) 45 | inputs = inputs[group_rank::group_size] 46 | 47 | if prompt_style == "dromedary": 48 | generate_prompt_fn = generate_dromedary_prompt 49 | else: 50 | raise ValueError(f"Unknown prompt style: {prompt_style}") 51 | 52 | with open(meta_prompt_file, "r") as f: 53 | meta_prompt = f.read().strip() 54 | 55 | generator = Llama.build( 56 | ckpt_dir=ckpt_dir, 57 | tokenizer_path=tokenizer_path, 58 | max_seq_len=max_seq_len, 59 | max_batch_size=max_batch_size, 60 | max_shared_seq_len=max_shared_seq_len, 61 | ) 62 | 63 | if seed is not None: 64 | torch.manual_seed(seed) 65 | 66 | results = [] 67 | # record current progress 68 | 69 | if "shards" not in output_file and group_size > 1: 70 | output_file = output_file.replace( 71 | ".json", f"_{group_size}shards_{group_rank}.json" 72 | ) 73 | 74 | if Path(output_file).exists(): 75 | with open(output_file, "r") as f: 76 | results = f.readlines() 77 | results = [line for line in results if len(line.strip()) > 0] 78 | 79 | inputs = inputs[len(results) :] 80 | print("Skipping %d examples" % len(results)) 81 | 82 | global_rank = int(os.environ.get("RANK", "0")) 83 | batching_inputs = tqdm.tqdm( 84 | BatchIterator(inputs, max_batch_size), 85 | desc="Batched inference", 86 | disable=global_rank > 0, 87 | ) 88 | total_iters = len(inputs) // max_batch_size 89 | 90 | output_handler = None 91 | if global_rank == 0: 92 | output_handler = open(output_file, "a") 93 | 94 | # prepare inputs with batch size $max_batch_size 95 | for iter, batched_inputs in enumerate(batching_inputs): 96 | t0 = time.time() 97 | prompts = [ 98 | generate_prompt_fn( 99 | meta_prompt, ex_input["instruction"].strip(), ex_input["input"].strip() 100 | ) 101 | for ex_input in batched_inputs 102 | ] 103 | 104 | if prompt_style == "llama_2_chat": 105 | outputs = generator.chat_completion( 106 | prompts, # type: ignore 107 | max_gen_len=generate_max_len, 108 | temperature=temperature, 109 | top_p=top_p, 110 | ) 111 | else: 112 | outputs = generator.text_completion( 113 | prompts, 114 | max_gen_len=generate_max_len, 115 | temperature=temperature, 116 | top_p=top_p, 117 | ) 118 | 119 | t1 = time.time() 120 | 121 | results = [] 122 | for ex_input, output in zip(batched_inputs, outputs): 123 | results.append( 124 | { 125 | "instruction": ex_input["instruction"], 126 | "input": ex_input["input"], 127 | "output": output, 128 | } 129 | ) 130 | 131 | if group_rank == 0: 132 | for ex_input, output, _ in zip(batched_inputs, outputs, range(8)): 133 | print("=" * 20, "iter: ", iter, "/", total_iters, "latency: ", t1 - t0) 134 | print(f"Input: {ex_input['instruction']}: {ex_input['input']}") 135 | print(f"Output: {output}") 136 | print() 137 | 138 | if output_handler is not None: 139 | for result in results: 140 | output_handler.write(json.dumps(result) + "\n") 141 | output_handler.flush() 142 | 143 | if output_handler is not None: 144 | output_handler.close() 145 | 146 | 147 | class BatchIterator: 148 | def __init__(self, data, batch_size=1): 149 | self.data = data 150 | self.batch_size = batch_size 151 | self.index = 0 152 | 153 | def __iter__(self): 154 | for i in range(0, len(self.data), self.batch_size): 155 | yield self.data[i : i + self.batch_size] 156 | 157 | def __len__(self): 158 | return math.ceil(len(self.data) / self.batch_size) 159 | 160 | 161 | def generate_dromedary_prompt(meta_prompt, instruction, input=None): 162 | if input: 163 | return f"""{meta_prompt} 164 | {instruction} 165 | 166 | {input} 167 | 168 | ### Dromedary""" 169 | else: 170 | return f"""{meta_prompt} 171 | {instruction} 172 | 173 | ### Dromedary""" 174 | 175 | 176 | if __name__ == "__main__": 177 | fire.Fire(main) 178 | -------------------------------------------------------------------------------- /training/step1_synthetic_preference_collection/clean_oasst1_prompts.py: -------------------------------------------------------------------------------- 1 | import random 2 | import tqdm 3 | import re 4 | import json 5 | from datasets import load_dataset 6 | from transformers import AutoTokenizer 7 | import fire 8 | 9 | 10 | def load_oasst_data(): 11 | oasst_dataset = load_dataset("OpenAssistant/oasst1")["train"] 12 | 13 | def create_message_trees(dataset): 14 | """Create message trees from dataset.""" 15 | # Organize data into dictionary based on parent_id 16 | organized_data = {} 17 | for message in dataset: 18 | parent_id = message["parent_id"] 19 | if parent_id not in organized_data: 20 | organized_data[parent_id] = [] 21 | organized_data[parent_id].append(message) 22 | 23 | # Traverse data to create trees 24 | message_trees = [] 25 | for root_messages in organized_data[None]: 26 | tree = [] 27 | current_message = root_messages 28 | while current_message is not None: 29 | tree.append(current_message) 30 | children = organized_data.get(current_message["message_id"]) 31 | current_message = children[0] if children else None 32 | message_trees.append(tree) 33 | 34 | return message_trees 35 | 36 | oasst_message_trees = create_message_trees(oasst_dataset) 37 | oasst_examples = [] 38 | 39 | count = 0 40 | for oasst_tree in oasst_message_trees: 41 | if len(oasst_tree) >= 2: 42 | count += 1 43 | oasst_examples.append( 44 | { 45 | "instruction": oasst_tree[0]["text"], 46 | "input": "", 47 | "output": "", 48 | } 49 | ) 50 | print("OASST examples:", count) 51 | return oasst_examples 52 | 53 | 54 | def main( 55 | output_file: str = "/path/to/oasst1_prompts.json", 56 | ): 57 | oasst_examples = load_oasst_data() 58 | 59 | with open(output_file, "w") as f: 60 | json.dump(oasst_examples, f, indent=2) 61 | 62 | 63 | if __name__ == "__main__": 64 | fire.Fire(main) 65 | -------------------------------------------------------------------------------- /training/step1_synthetic_preference_collection/scripts/generate_oasst1_response0.sh: -------------------------------------------------------------------------------- 1 | # We use 8 x 6 = 48 V100-32GB GPUs 2 | # On AiMOS cluster [https://docs.cci.rpi.edu/clusters/DCS_Supercomputer/] 3 | # salloc --nodes 8 --time 6:00:00 --gres=gpu:32g:6 srun bash scripts/generate_oasst1_response0.sh 4 | 5 | set -e 6 | set -x 7 | 8 | export PYTHONPATH="$PWD:$PYTHONPATH" 9 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 10 | export MODEL_DIR="/your/model/dir" 11 | export DATA_DIR="/your/data/dir" 12 | export OMP_NUM_THREADS=6 13 | export GPUS_PER_NODE=6 14 | export NUM_NODES=2 15 | export MASTER_PORT=9901 16 | 17 | LOCAL_NODE_RANK=$((SLURM_PROCID % NUM_NODES)) 18 | GROUP_RANK=$((SLURM_PROCID / NUM_NODES)) 19 | GROUP_SIZE=$((SLURM_NNODES / NUM_NODES)) 20 | SYNC_NODE_RANK=$((GROUP_RANK * NUM_NODES)) 21 | 22 | # MASTER_ADDR should be SYNC_NODE_RANK-th node in $(scontrol show hostnames $SLURM_JOB_NODELIST) 23 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n $((SYNC_NODE_RANK + 1)) | tail -n 1) 24 | 25 | echo "$MASTER_ADDR, $GROUP_RANK / $GROUP_SIZE: $LOCAL_NODE_RANK" 26 | 27 | torchrun --nproc_per_node $GPUS_PER_NODE \ 28 | --nnodes $NUM_NODES \ 29 | --node_rank $LOCAL_NODE_RANK \ 30 | --master_addr $MASTER_ADDR \ 31 | --master_port $MASTER_PORT \ 32 | batch_generation.py \ 33 | --ckpt_dir $MODEL_DIR/dromedary-2-70b-sft-12shard \ 34 | --tokenizer_path $MODEL_DIR/tokenizer.model \ 35 | --generate_max_len 768 \ 36 | --max_seq_len 768 \ 37 | --max_shared_seq_len 640 \ 38 | --max_batch_size 64 \ 39 | --group_rank $GROUP_RANK \ 40 | --group_size $GROUP_SIZE \ 41 | --input_file "$DATA_DIR/oasst1_prompts.json" \ 42 | --output_file "$DATA_DIR/oasst1_dromedary2_sft_response0.json" \ 43 | --meta_prompt_file "../../prompts/synthetic_inference_prompts/dromedary_inference_prompt.txt" \ 44 | --temperature 0.7 \ 45 | --top_p 1.0 \ 46 | --seed 42 47 | -------------------------------------------------------------------------------- /training/step1_synthetic_preference_collection/scripts/generate_oasst1_response1.sh: -------------------------------------------------------------------------------- 1 | # We use 8 x 6 = 48 V100-32GB GPUs 2 | # On AiMOS cluster [https://docs.cci.rpi.edu/clusters/DCS_Supercomputer/] 3 | # salloc --nodes 8 --time 6:00:00 --gres=gpu:32g:6 srun bash scripts/generate_oasst1_response1.sh 4 | 5 | set -e 6 | set -x 7 | 8 | export PYTHONPATH="$PWD:$PYTHONPATH" 9 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 10 | export MODEL_DIR="/your/model/dir" 11 | export DATA_DIR="/your/data/dir" 12 | export OMP_NUM_THREADS=6 13 | export GPUS_PER_NODE=6 14 | export NUM_NODES=2 15 | export MASTER_PORT=9901 16 | 17 | LOCAL_NODE_RANK=$((SLURM_PROCID % NUM_NODES)) 18 | GROUP_RANK=$((SLURM_PROCID / NUM_NODES)) 19 | GROUP_SIZE=$((SLURM_NNODES / NUM_NODES)) 20 | SYNC_NODE_RANK=$((GROUP_RANK * NUM_NODES)) 21 | 22 | # MASTER_ADDR should be SYNC_NODE_RANK-th node in $(scontrol show hostnames $SLURM_JOB_NODELIST) 23 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n $((SYNC_NODE_RANK + 1)) | tail -n 1) 24 | 25 | echo "$MASTER_ADDR, $GROUP_RANK / $GROUP_SIZE: $LOCAL_NODE_RANK" 26 | 27 | torchrun --nproc_per_node $GPUS_PER_NODE \ 28 | --nnodes $NUM_NODES \ 29 | --node_rank $LOCAL_NODE_RANK \ 30 | --master_addr $MASTER_ADDR \ 31 | --master_port $MASTER_PORT \ 32 | batch_generation.py \ 33 | --ckpt_dir $MODEL_DIR/dromedary-2-70b-sft-12shard \ 34 | --tokenizer_path $MODEL_DIR/tokenizer.model \ 35 | --generate_max_len 768 \ 36 | --max_seq_len 768 \ 37 | --max_shared_seq_len 640 \ 38 | --max_batch_size 64 \ 39 | --group_rank $GROUP_RANK \ 40 | --group_size $GROUP_SIZE \ 41 | --input_file "$DATA_DIR/oasst1_prompts.json" \ 42 | --output_file "$DATA_DIR/oasst1_dromedary2_sft_response1.json" \ 43 | --meta_prompt_file "../../prompts/synthetic_inference_prompts/dromedary_inference_prompt.txt" \ 44 | --temperature 0.7 \ 45 | --top_p 1.0 \ 46 | --seed 43 47 | -------------------------------------------------------------------------------- /training/step1_synthetic_preference_collection/scripts/generate_synthetic_preference.sh: -------------------------------------------------------------------------------- 1 | # We use 1 x 8 = 8 A100-80GB GPUs 2 | # salloc --nodes 1 --time 24:00:00 --gres=gpu:80g:8 srun bash scripts/generate_synthetic_preference.sh 3 | 4 | set -e 5 | set -x 6 | 7 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 8 | export MODEL_DIR="/your/model/dir" 9 | export DATA_DIR="/your/data/dir" 10 | export PYTHONPATH="$PWD:$PYTHONPATH" 11 | export GPUS_PER_NODE=8 12 | export OMP_NUM_THREADS=8 13 | 14 | torchrun \ 15 | --standalone \ 16 | --nnodes=1 \ 17 | --nproc-per-node=$GPUS_PER_NODE \ 18 | synthetic_preference.py \ 19 | --model_name "$MODEL_DIR/llama-2-70b-hf" \ 20 | --adapters_name "$MODEL_DIR/dromedary-2-70b-sft-qlora/adapter_model" \ 21 | --preferece_prompt "../../prompts/synthetic_preference_prompt.txt" \ 22 | --rm_principles "../../prompts/principles/principle_collection_rm.json" \ 23 | --response_pattern "$DATA_DIR/oasst1_dromedary2_sft_response*.json" \ 24 | --output_file "$DATA_DIR/oasst1_dromedary2_sft_preference.json" 25 | -------------------------------------------------------------------------------- /training/step1_synthetic_preference_collection/synthetic_preference.py: -------------------------------------------------------------------------------- 1 | import fcntl 2 | import glob 3 | import json 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | import tqdm 10 | import fire 11 | 12 | from peft import PeftModel 13 | from transformers import ( 14 | BitsAndBytesConfig, 15 | AutoModelForCausalLM, 16 | LlamaTokenizerFast, 17 | ) 18 | from datasets import load_dataset 19 | 20 | 21 | def main( 22 | model_name: str, 23 | adapters_name: str, 24 | preferece_prompt: str, 25 | rm_principles: str, 26 | response_pattern: str, 27 | output_file: str, 28 | tokenizer_name: str = "TheBloke/dromedary-65b-lora-HF", # a random llama-based model 29 | ): 30 | with open(preferece_prompt, "r") as f: 31 | PREFERENCE_META_PROMPT = f.read().strip() 32 | 33 | with open(rm_principles, "r") as f: 34 | principle_definitions = json.load(f) 35 | 36 | data_files = glob.glob(response_pattern) 37 | cache = "" 38 | raw_data = {} 39 | for data_file in data_files: 40 | type = data_file.split("/")[-1].split("_")[-1].split(".")[0] 41 | raw_data[type] = [] 42 | with open(data_file, "r") as f: 43 | for line in f: 44 | try: 45 | raw_data[type].append(json.loads(cache + line)) 46 | cache = "" 47 | except json.decoder.JSONDecodeError: 48 | cache += line.strip() 49 | raw_data[type] = sorted(raw_data[type], key=lambda x: x["instruction"]) 50 | 51 | print(raw_data.keys()) 52 | 53 | preference_dataset = [] 54 | 55 | for type_idx_a in range(len(raw_data.keys())): 56 | for type_idx_b in range(type_idx_a + 1, len(raw_data.keys())): 57 | type_a = list(raw_data.keys())[type_idx_a] 58 | type_b = list(raw_data.keys())[type_idx_b] 59 | print( 60 | f"len {type_a} and {type_b}", 61 | len(raw_data[type_a]), 62 | len(raw_data[type_b]), 63 | ) 64 | 65 | for idx in range(len(raw_data[type_a])): 66 | data_a = raw_data[type_a][idx] 67 | data_b = raw_data[type_b][idx] 68 | 69 | if data_a["instruction"] != data_b["instruction"]: 70 | print("Instruction mismatch!") 71 | print(data_a["instruction"]) 72 | print(data_b["instruction"]) 73 | exit(0) 74 | 75 | if ( 76 | len(data_a["output"]["generation"]) < 16 77 | or len(data_b["output"]["generation"]) < 16 78 | ): 79 | continue 80 | 81 | preference_dataset.append( 82 | { 83 | "instruction": data_a["instruction"], 84 | "input": data_a["input"], 85 | "output_1": data_a["output"]["generation"] 86 | .split("###")[0] 87 | .strip(), 88 | "output_2": data_b["output"]["generation"] 89 | .split("###")[0] 90 | .strip(), 91 | "preference": 0, 92 | } 93 | ) 94 | 95 | random.Random(42).shuffle(preference_dataset) 96 | 97 | print(f"Starting to load the model {model_name} into memory") 98 | 99 | rank = int(os.environ.get("RANK", 0)) 100 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 101 | torch.cuda.set_device(rank) 102 | 103 | print(preference_dataset[0 + rank]) 104 | 105 | batch_size = 4 # For a single 80GB GPU 106 | m = AutoModelForCausalLM.from_pretrained( 107 | model_name, 108 | load_in_4bit=True, 109 | torch_dtype=torch.float16, 110 | device_map={"": torch.cuda.current_device()}, 111 | quantization_config=BitsAndBytesConfig( 112 | load_in_4bit=True, 113 | load_in_8bit=False, 114 | llm_int8_threshold=6.0, 115 | llm_int8_has_fp16_weight=False, 116 | bnb_4bit_compute_dtype=torch.float16, 117 | bnb_4bit_use_double_quant=True, 118 | bnb_4bit_quant_type="nf4", 119 | ), 120 | ) 121 | m = PeftModel.from_pretrained( 122 | m, 123 | adapters_name, 124 | adapter_name="default", 125 | is_trainable=False, 126 | ) 127 | tok = LlamaTokenizerFast.from_pretrained( 128 | tokenizer_name, 129 | padding_side="left", 130 | truncation_side="left", 131 | model_max_length=1536, 132 | ) 133 | tok.pad_token_id = 0 134 | 135 | print(f"Successfully loaded the model {model_name} into memory") 136 | print_flag = True 137 | 138 | idxs = [] 139 | dimensions = [] 140 | prompts = [] 141 | preferences = [] 142 | 143 | for idx in tqdm.tqdm(range(len(preference_dataset))): 144 | if idx % world_size != rank: 145 | continue 146 | 147 | for principle_definition in principle_definitions: 148 | dimension = principle_definition["dimension"] 149 | definition = principle_definition["definition"] 150 | 151 | data = preference_dataset[idx] 152 | instruction = data["instruction"] 153 | instruction_input = data["input"] 154 | 155 | if instruction_input: 156 | instruction += "\n\n" + instruction_input 157 | 158 | output_1 = data["output_1"] 159 | output_2 = data["output_2"] 160 | preference = data["preference"] 161 | preferences.append(preference) 162 | idxs.append(idx) 163 | dimensions.append(dimension) 164 | 165 | for a, b in [(output_1, output_2), (output_2, output_1)]: 166 | prompt_for_score = PREFERENCE_META_PROMPT.format( 167 | UserInstruction=instruction, 168 | OutputA=a, 169 | OutputB=b, 170 | Dimension=dimension, 171 | Definition=definition, 172 | ) 173 | 174 | if print_flag: 175 | print_flag = False 176 | print(prompt_for_score) 177 | 178 | prompts.append(tok.bos_token + prompt_for_score) 179 | 180 | if len(prompts) == batch_size * 2: 181 | tokenized_input = tok( 182 | prompts, 183 | return_tensors="pt", 184 | add_special_tokens=False, 185 | padding="max_length", 186 | return_attention_mask=True, 187 | truncation=True, 188 | ) 189 | 190 | input_ids = tokenized_input["input_ids"].to(m.device) 191 | attention_mask = tokenized_input["attention_mask"].to(m.device) 192 | 193 | with torch.inference_mode(): 194 | output = m( 195 | input_ids, 196 | attention_mask=attention_mask, 197 | ) 198 | 199 | logits = output.logits 200 | 201 | token_id_a = tok.encode("\n (a", add_special_tokens=False)[-1] 202 | token_id_b = tok.encode("\n (b", add_special_tokens=False)[-1] 203 | 204 | relative_scores = [] 205 | prompt_preferences = [] 206 | 207 | for ex_idx in range(batch_size): 208 | score_a_for_1_2 = logits[ex_idx * 2 + 0, -1, token_id_a] 209 | score_b_for_1_2 = logits[ex_idx * 2 + 0, -1, token_id_b] 210 | score_a_for_2_1 = logits[ex_idx * 2 + 1, -1, token_id_a] 211 | score_b_for_2_1 = logits[ex_idx * 2 + 1, -1, token_id_b] 212 | 213 | relative_score_1_2 = (score_a_for_1_2 - score_b_for_1_2).item() 214 | relative_score_2_1 = (score_b_for_2_1 - score_a_for_2_1).item() 215 | 216 | if relative_score_1_2 > 0.0 and relative_score_2_1 > 0.0: 217 | prompt_preference = 1 218 | elif relative_score_1_2 < 0.0 and relative_score_2_1 < 0.0: 219 | prompt_preference = 2 220 | else: 221 | prompt_preference = 0 222 | 223 | relative_scores.append((relative_score_1_2, relative_score_2_1)) 224 | prompt_preferences.append(prompt_preference) 225 | 226 | outputs = [] 227 | 228 | for ex_idx in range(batch_size): 229 | outputs.append( 230 | { 231 | "example_idx": idxs[ex_idx], 232 | "dimension": dimensions[ex_idx], 233 | "preference": preferences[ex_idx], 234 | "prompt_preference": prompt_preferences[ex_idx], 235 | "relative_score": relative_scores[ex_idx], 236 | } 237 | ) 238 | 239 | with open(output_file, "a") as f: 240 | fcntl.flock(f, fcntl.LOCK_EX) 241 | for output in outputs: 242 | f.write(json.dumps(output) + "\n") 243 | fcntl.flock(f, fcntl.LOCK_UN) 244 | 245 | idxs = [] 246 | dimensions = [] 247 | prompts = [] 248 | preferences = [] 249 | 250 | 251 | if __name__ == "__main__": 252 | fire.Fire(main) 253 | -------------------------------------------------------------------------------- /training/step2_rm_training/aggregate_synthetic_preference.py: -------------------------------------------------------------------------------- 1 | import fcntl 2 | import glob 3 | import json 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import tqdm 9 | from ftlangdetect import detect 10 | 11 | 12 | ALL_DIMENSIONS = [ 13 | # "Overall Helpful", 14 | "Concise", 15 | "Honest and Accurate", 16 | "Ethical", 17 | # "Natural and Fluent", 18 | "Specific", 19 | "Educational and Engaging", 20 | "Methodical", 21 | # "Multilingual", 22 | "Creative", 23 | "Comprehensive", 24 | ] 25 | 26 | 27 | def main( 28 | response_pattern: str, 29 | preference_file: str, 30 | output_file: str, 31 | noisy_ratio: float = 0.1, 32 | max_principles: int = 8, 33 | tokenizer_name: str = "TheBloke/dromedary-65b-lora-HF", # a random llama-based model 34 | ): 35 | data_files = glob.glob(response_pattern) 36 | 37 | cache = "" 38 | raw_data = {} 39 | for data_file in data_files: 40 | type = data_file.split("/")[-1].split("_")[-1].split(".")[0] 41 | raw_data[type] = [] 42 | with open(data_file, "r") as f: 43 | for line in f: 44 | try: 45 | raw_data[type].append(json.loads(cache + line)) 46 | cache = "" 47 | except json.decoder.JSONDecodeError: 48 | cache += line.strip() 49 | raw_data[type] = sorted(raw_data[type], key=lambda x: x["instruction"]) 50 | 51 | print(raw_data.keys(), len(raw_data[type])) 52 | preference_dataset = [] 53 | 54 | for type_idx_a in range(len(raw_data.keys())): 55 | for type_idx_b in range(type_idx_a + 1, len(raw_data.keys())): 56 | type_a = list(raw_data.keys())[type_idx_a] 57 | type_b = list(raw_data.keys())[type_idx_b] 58 | 59 | for idx in range(len(raw_data[type_a])): 60 | data_a = raw_data[type_a][idx] 61 | data_b = raw_data[type_b][idx] 62 | 63 | if data_a["instruction"] != data_b["instruction"]: 64 | print("Instruction mismatch!") 65 | print(data_a["instruction"]) 66 | print(data_b["instruction"]) 67 | exit(0) 68 | 69 | if ( 70 | len(data_a["output"]["generation"]) < 16 71 | or len(data_b["output"]["generation"]) < 16 72 | ): 73 | continue 74 | 75 | preference_dataset.append( 76 | { 77 | "instruction": data_a["instruction"], 78 | "input": data_a["input"], 79 | "output_1": data_a["output"]["generation"] 80 | .split("###")[0] 81 | .strip(), 82 | "output_2": data_b["output"]["generation"] 83 | .split("###")[0] 84 | .strip(), 85 | "preference": 0, 86 | } 87 | ) 88 | 89 | random.Random(42).shuffle(preference_dataset) 90 | 91 | for example_idx in range(len(preference_dataset)): 92 | preference_dataset[example_idx]["example_idx"] = example_idx 93 | 94 | with open(preference_file, "r") as f: 95 | fcntl.flock(f, fcntl.LOCK_EX) 96 | for line in f: 97 | data = json.loads(line) 98 | example_idx = data["example_idx"] 99 | 100 | if "dimension_scores" not in preference_dataset[example_idx]: 101 | preference_dataset[example_idx]["dimension_scores"] = {} 102 | 103 | dimension = data["dimension"] 104 | score = data["relative_score"] 105 | 106 | preference_dataset[example_idx]["dimension_scores"][dimension] = score 107 | fcntl.flock(f, fcntl.LOCK_UN) 108 | 109 | all_dimensions = ALL_DIMENSIONS 110 | print(preference_dataset[0]) 111 | 112 | # print variance 113 | variances = {} 114 | for dimension in all_dimensions: 115 | scores = [] 116 | for example_idx in range(len(preference_dataset)): 117 | if "dimension_scores" not in preference_dataset[example_idx]: 118 | continue 119 | if dimension not in preference_dataset[example_idx]["dimension_scores"]: 120 | continue 121 | two_scores = preference_dataset[example_idx]["dimension_scores"][dimension] 122 | scores.append(two_scores[0]) 123 | scores.append(two_scores[1]) 124 | variances[dimension] = np.var(scores) 125 | 126 | print("Variance of {}: {}".format(dimension, np.var(scores))) 127 | 128 | clean_preference_dataset = [] 129 | 130 | for example_idx in range(len(preference_dataset)): 131 | if "dimension_scores" not in preference_dataset[example_idx]: 132 | continue 133 | 134 | continue_flag = False 135 | 136 | for dimension in all_dimensions: 137 | if dimension not in preference_dataset[example_idx]["dimension_scores"]: 138 | continue_flag = True 139 | 140 | if continue_flag: 141 | continue 142 | 143 | clean_preference_dataset.append(preference_dataset[example_idx]) 144 | 145 | print(len(clean_preference_dataset)) 146 | 147 | synthetic_data = [] 148 | 149 | main_reasons = {} 150 | for dimension in all_dimensions: 151 | main_reasons[dimension] = 0 152 | 153 | random.seed(42) 154 | 155 | def mean(a, b): 156 | return (a + b) / 2.0 157 | 158 | rule_count = 0 159 | 160 | for i in range(4): 161 | for example_idx in range(len(clean_preference_dataset)): 162 | random.shuffle(all_dimensions) 163 | 164 | # Compute the probabilities 165 | num_dimensions = max_principles + 1 166 | while num_dimensions > max_principles: 167 | num_dimensions = np.random.geometric(0.2, 1)[0] 168 | 169 | data = clean_preference_dataset[example_idx] 170 | 171 | dimensions = all_dimensions[:num_dimensions] 172 | 173 | scores = [] 174 | for dimension in dimensions: 175 | two_scores = data["dimension_scores"][dimension] 176 | if two_scores[0] > 0.0 and two_scores[1] > 0.0: 177 | scores.append( 178 | mean(two_scores[0], two_scores[1]) / variances[dimension] 179 | ) 180 | elif two_scores[0] < 0.0 and two_scores[1] < 0.0: 181 | scores.append( 182 | mean(two_scores[0], two_scores[1]) / variances[dimension] 183 | ) 184 | else: 185 | scores.append(0.0) 186 | 187 | # negative_definitions is random benoulli of length num_dimensions 188 | negative_definitions = np.random.binomial(1, 0.5, num_dimensions) 189 | 190 | scores = [ 191 | -score if negative_definitions[idx] == 1 else score 192 | for idx, score in enumerate(scores) 193 | ] 194 | 195 | # The preference is based on the absolute value of the scores. 196 | doinate_dimension_idx = np.argmax(np.abs(scores)).item() 197 | 198 | score = scores[doinate_dimension_idx] 199 | 200 | if score > 0.0: 201 | preference = 1 202 | elif score < 0.0: 203 | preference = 2 204 | else: 205 | continue 206 | 207 | order = random.randint(0, 1) 208 | 209 | rule_flag = False 210 | 211 | # We ban "as an AI language model" 212 | if "As an AI language model" in data["output_1"]: 213 | if "As an AI language model" in data["output_2"]: 214 | pass 215 | else: 216 | if random.randint(0, 1) > 0.5: 217 | preference = 2 218 | rule_flag = True 219 | else: 220 | if "As an AI language model" in data["output_2"]: 221 | if random.randint(0, 1) > 0.5: 222 | preference = 1 223 | rule_flag = True 224 | else: 225 | pass 226 | 227 | lang_instruction = detect(data["instruction"].replace("\n", " "))["lang"] 228 | lang_output_1 = detect(data["output_1"].replace("\n", " "))["lang"] 229 | lang_output_2 = detect(data["output_2"].replace("\n", " "))["lang"] 230 | 231 | # We ban different languages 232 | if lang_instruction != lang_output_1: 233 | if lang_instruction != lang_output_2: 234 | continue 235 | else: 236 | preference = 2 237 | rule_flag = True 238 | else: 239 | if lang_instruction != lang_output_2: 240 | preference = 1 241 | rule_flag = True 242 | else: 243 | pass 244 | 245 | data_tuple = (data["instruction"], data["output_1"], data["output_1"]) 246 | 247 | if rule_flag: 248 | rule_count += 1 249 | flip = random.random() < noisy_ratio * 2.5 # 2.5 is a magic number 250 | else: 251 | flip = random.random() < noisy_ratio 252 | main_reasons[dimensions[doinate_dimension_idx]] += 1 253 | 254 | preference = 3 - preference if flip else preference 255 | 256 | synthetic_data.append( 257 | { 258 | "doinate_dimension_idx": doinate_dimension_idx, 259 | "dimensions": str(tuple(dimensions)), 260 | "dimension_scores": str( 261 | tuple(scores if order == 0 else [-score for score in scores]) 262 | ), 263 | "negative_definitions": str(tuple(negative_definitions)), 264 | "instruction": data["instruction"], 265 | "input": data["input"], 266 | "output_1": data["output_1"] if order == 0 else data["output_2"], 267 | "output_2": data["output_2"] if order == 0 else data["output_1"], 268 | "preference": preference if order == 0 else 3 - preference, 269 | "flip": flip, 270 | "lang": (lang_instruction, lang_output_1, lang_output_2) 271 | if order == 0 272 | else (lang_instruction, lang_output_2, lang_output_1), 273 | "rule_flag": rule_flag, 274 | } 275 | ) 276 | 277 | print(rule_count) 278 | print(len(synthetic_data)) 279 | print(main_reasons) 280 | 281 | with open(output_file, "w") as f: 282 | f.write( 283 | json.dumps( 284 | synthetic_data, 285 | indent=4, 286 | ) 287 | ) 288 | -------------------------------------------------------------------------------- /training/step2_rm_training/clean_pmp_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import tqdm 4 | from datasets import load_dataset 5 | import fire 6 | 7 | 8 | def find_shared_prefix(str1, str2): 9 | """ 10 | Find the shared prefix between two strings. 11 | 12 | Args: 13 | str1 (str): The first string. 14 | str2 (str): The second string. 15 | 16 | Returns: 17 | int: The index position where the shared prefix ends in the first string. 18 | """ 19 | # Find the minimum length of both strings 20 | min_length = min(len(str1), len(str2)) 21 | 22 | max_common_prefix = 0 23 | # Loop through characters to find where they start to differ 24 | for i in range(min_length): 25 | if str1[i] != str2[i]: 26 | max_common_prefix = i 27 | break 28 | 29 | while ( 30 | not str1[:max_common_prefix].endswith("Assistant:") 31 | # Force the prefix to end with "Assistant:" 32 | ) and max_common_prefix > 0: 33 | max_common_prefix -= 1 34 | 35 | return max_common_prefix 36 | 37 | 38 | def split_example(example): 39 | """ 40 | Split an example into a shared prefix and the remaining parts of the strings. 41 | 42 | Args: 43 | example (dict): A dictionary containing 'chosen' and 'rejected' keys with strings as values. 44 | 45 | Returns: 46 | dict: A dictionary with keys 'query', 'output_1', and 'output_2'. 47 | """ 48 | chosen = example["chosen"] 49 | rejected = example["rejected"] 50 | 51 | # Find the index where the shared prefix ends 52 | shared_index = find_shared_prefix(chosen, rejected) 53 | 54 | # Split the strings 55 | query = chosen[:shared_index].strip() 56 | output_1 = chosen[shared_index:].strip() 57 | output_2 = rejected[shared_index:].strip() 58 | 59 | # Return the result as a dictionary 60 | return {"query": query, "output_1": output_1, "output_2": output_2} 61 | 62 | 63 | def main( 64 | output_file: str, 65 | ): 66 | shp_dataset = load_dataset("stanfordnlp/SHP")["train"] 67 | 68 | merged_data = [] 69 | 70 | for ex in tqdm.tqdm(shp_dataset): 71 | qid = ex["post_id"] 72 | score = ex["score_ratio"] 73 | 74 | ex_dp = { 75 | "score": score, 76 | "instruction": ex["history"], 77 | "input": "", 78 | "output_1": ex["human_ref_A"], 79 | "output_2": ex["human_ref_B"], 80 | "preference": 1 if ex["labels"] == 1 else 2, 81 | } 82 | 83 | if score > 2.0: 84 | merged_data.append(ex_dp) 85 | 86 | print(merged_data[-1]) 87 | print(len(merged_data)) 88 | 89 | hh_dataset = load_dataset("Anthropic/hh-rlhf")["train"] 90 | 91 | for ex in hh_dataset: 92 | processed_ex = split_example(ex) 93 | 94 | random_swap = random.random() < 0.5 95 | 96 | merged_data.append( 97 | { 98 | "instruction": processed_ex["query"].strip(), 99 | "input": "", 100 | "output_1": processed_ex["output_1"] 101 | if random_swap 102 | else processed_ex["output_2"], 103 | "output_2": processed_ex["output_2"] 104 | if random_swap 105 | else processed_ex["output_1"], 106 | "preference": 1 if random_swap else 2, 107 | } 108 | ) 109 | 110 | print(merged_data[-1]) 111 | print(len(merged_data)) 112 | 113 | random.shuffle(merged_data) 114 | 115 | with open(output_file, "w") as f: 116 | f.write( 117 | json.dumps( 118 | merged_data, 119 | ) 120 | ) 121 | 122 | 123 | if __name__ == "__main__": 124 | fire.Fire(main) 125 | -------------------------------------------------------------------------------- /training/step2_rm_training/scripts/train_reward_model_70b_qlora_ft.sh: -------------------------------------------------------------------------------- 1 | # We use 1 x 8 = 8 A100-80GB GPUs 2 | 3 | set -e 4 | set -x 5 | 6 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 7 | export MODEL_DIR="/your/model/dir" 8 | export DATA_DIR="/your/data/dir" 9 | export PYTHONPATH="$PWD:$PYTHONPATH" 10 | export GPUS_PER_NODE=8 11 | export OMP_NUM_THREADS=8 12 | 13 | NUM_EPOCHS=1 14 | LEARNING_RATE=3e-5 15 | BATCH_SIZE=4 16 | GRAD_ACCUMULATION=2 17 | 18 | PMP_CKPT=llama-2-70b-qlora-rm-pmp/checkpoint-xxxx 19 | 20 | cd .. 21 | 22 | torchrun \ 23 | --standalone \ 24 | --nnodes=1 \ 25 | --nproc-per-node=$GPUS_PER_NODE \ 26 | finetune_qlora_rm.py \ 27 | --do_train \ 28 | --do_eval \ 29 | --seed 42 \ 30 | --per_device_train_batch_size $BATCH_SIZE \ 31 | --per_device_eval_batch_size $BATCH_SIZE \ 32 | --gradient_accumulation_steps $GRAD_ACCUMULATION \ 33 | --model_name_or_path "$MODEL_DIR/llama-2-70b-hf" \ 34 | --learning_rate $LEARNING_RATE \ 35 | --model_max_length 1280 \ 36 | --query_len 128 \ 37 | --response_len 768 \ 38 | --dataset_path "$DATA_DIR/oasst1_dromedary2_sft_aggregated_preference.json" \ 39 | --principle_collection_path "../principles/principle_collection_rm.json" \ 40 | --meta_prompt_pattern "../prompts/salmon_reward_model_prompt_v0.txt" \ 41 | --double_quant True \ 42 | --quant_type "nf4" \ 43 | --bits 4 \ 44 | --lora_r 64 \ 45 | --output_dir "$MODEL_DIR/llama-2-70b-qlora-rm-pmp-sft" \ 46 | --resume_dir "$MODEL_DIR/$PMP_CKPT" \ 47 | --num_train_epochs $NUM_EPOCHS \ 48 | --group_by_length False \ 49 | --evaluation_strategy "steps" \ 50 | --eval_steps 20 \ 51 | --save_strategy "steps" \ 52 | --save_steps 40 \ 53 | --save_total_limit 20 \ 54 | --weight_decay 0.0 \ 55 | --warmup_ratio 0.03 \ 56 | --lr_scheduler_type "cosine" \ 57 | --logging_steps 5 \ 58 | --report_to "tensorboard" \ 59 | --ddp_backend "nccl" \ 60 | --bf16 True \ 61 | --ddp_find_unused_parameters False 62 | -------------------------------------------------------------------------------- /training/step2_rm_training/scripts/train_reward_model_70b_qlora_pmp.sh: -------------------------------------------------------------------------------- 1 | # We use 1 x 8 = 8 A100-80GB GPUs 2 | 3 | set -e 4 | set -x 5 | 6 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 7 | export MODEL_DIR="/your/model/dir" 8 | export DATA_DIR="/your/data/dir" 9 | export PYTHONPATH="$PWD:$PYTHONPATH" 10 | export GPUS_PER_NODE=8 11 | export OMP_NUM_THREADS=8 12 | 13 | NUM_EPOCHS=1 14 | LEARNING_RATE=5e-5 15 | BATCH_SIZE=4 16 | GRAD_ACCUMULATION=2 17 | 18 | cd .. 19 | 20 | torchrun \ 21 | --standalone \ 22 | --nnodes=1 \ 23 | --nproc-per-node=$GPUS_PER_NODE \ 24 | finetune_qlora_rm.py \ 25 | --do_train \ 26 | --do_eval \ 27 | --seed 42 \ 28 | --per_device_train_batch_size $BATCH_SIZE \ 29 | --per_device_eval_batch_size $BATCH_SIZE \ 30 | --gradient_accumulation_steps $GRAD_ACCUMULATION \ 31 | --model_name_or_path "$MODEL_DIR/llama-2-70b-hf" \ 32 | --learning_rate $LEARNING_RATE \ 33 | --model_max_length 1024 \ 34 | --query_len 256 \ 35 | --response_len 640 \ 36 | --dataset_path "$DATA_DIR/pmp_data.json" \ 37 | --meta_prompt_pattern "../prompts/pmp_reward_model_prompt.txt" \ 38 | --double_quant True \ 39 | --quant_type "nf4" \ 40 | --bits 4 \ 41 | --lora_r 64 \ 42 | --output_dir "$MODEL_DIR/llama-2-70b-qlora-rm-pmp" \ 43 | --num_train_epochs $NUM_EPOCHS \ 44 | --group_by_length False \ 45 | --evaluation_strategy "steps" \ 46 | --eval_steps 20 \ 47 | --save_strategy "steps" \ 48 | --save_steps 100 \ 49 | --save_total_limit 10 \ 50 | --weight_decay 0.0 \ 51 | --warmup_ratio 0.01 \ 52 | --lr_scheduler_type "cosine" \ 53 | --logging_steps 5 \ 54 | --report_to "tensorboard" \ 55 | --ddp_backend "nccl" \ 56 | --bf16 True \ 57 | --ddp_find_unused_parameters False \ 58 | --resume_from_training True 59 | -------------------------------------------------------------------------------- /training/step3_ppo_training/aggregate_sharegpt_prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import fire 3 | from datasets import load_dataset 4 | 5 | DEFAULT_SHAREGPT_DATA = [ 6 | # https://huggingface.co/datasets/zetavg/ShareGPT-Processed 7 | "zetavg/ShareGPT-Processed", 8 | # https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/blob/main/HTML_cleaned_raw_dataset/sg_90k_part1.json 9 | "../../outputs/data/sharegpt/sg_90k_part1.json", 10 | # https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/blob/main/HTML_cleaned_raw_dataset/sg_90k_part2.json 11 | "../../outputs/data/sharegpt/sg_90k_part2.json", 12 | # "../../outputs/data/sharegpt/ShareGPT.jsonl", 13 | # "../../outputs/data/sharegpt/ShareGPT_V3_unfiltered_cleaned_split.json", 14 | ] 15 | 16 | 17 | def extract_from_dataset(dataset, key="markdown", prefix="human"): 18 | data = [] 19 | 20 | for item in dataset: 21 | conversations = item["conversations"] 22 | if len(conversations) >= 2: 23 | user_input = conversations[0] 24 | 25 | if user_input["from"] == prefix: 26 | data.append( 27 | { 28 | "instruction": user_input[key], 29 | "input": "", 30 | "output": "", 31 | } 32 | ) 33 | 34 | return data 35 | 36 | 37 | def extract_from_json_file(filepath, key="value", prefix="human"): 38 | with open(filepath, "r") as f: 39 | dataset = json.load(f) 40 | 41 | return extract_from_dataset(dataset, key, prefix) 42 | 43 | 44 | def extract_from_jsonl_file(filepath): 45 | data = [] 46 | 47 | with open(filepath, "r") as f: 48 | for line in f: 49 | line_data = json.loads(line) 50 | user_input = line_data["text"].split("")[0][len(":") :].strip() 51 | 52 | data.append( 53 | { 54 | "instruction": user_input, 55 | "input": "", 56 | "output": "", 57 | } 58 | ) 59 | 60 | return data 61 | 62 | 63 | def main(data_files: list = None, output_file: str = "path/to/sharegpt_prompts.json"): 64 | if data_files is None: 65 | data_files = DEFAULT_SHAREGPT_DATA 66 | 67 | all_data = [] 68 | merged_data = [] 69 | merged_data_set = set() 70 | 71 | for filepath in data_files: 72 | if filepath.endswith(".json"): 73 | all_data.append(extract_from_json_file(filepath)) 74 | elif filepath.endswith(".jsonl"): 75 | all_data.append(extract_from_jsonl_file(filepath)) 76 | else: 77 | train_dataset = load_dataset(filepath)["train"] 78 | all_data.append(extract_from_dataset(train_dataset)) 79 | 80 | # Merge and filter unique instructions 81 | for data in all_data: 82 | filtered_data_count = 0 83 | for item in data: 84 | if item["instruction"] not in merged_data_set: 85 | merged_data_set.add(item["instruction"]) 86 | merged_data.append(item) 87 | else: 88 | filtered_data_count += 1 89 | 90 | print("Filtered data count:", filtered_data_count) 91 | print("Merged data size:", len(merged_data)) 92 | 93 | with open(output_file, "w") as f: 94 | json.dump(merged_data, f, indent=2) 95 | 96 | 97 | if __name__ == "__main__": 98 | fire.Fire(main) 99 | -------------------------------------------------------------------------------- /training/step3_ppo_training/clean_and_merge_prompts.py: -------------------------------------------------------------------------------- 1 | import random 2 | import tqdm 3 | import re 4 | import json 5 | from datasets import load_dataset 6 | from transformers import AutoTokenizer 7 | import fire 8 | 9 | # QUESTION TYPE 10 | # 0: GENERAL (ShareGPT + Dolly + OASST) 11 | # 1: REASONING (OpenOrca + MATH) 12 | # 2: REDTEAMING (Anthropic Red Teaming) 13 | 14 | CHATGPT_LANGUAGES = { 15 | "it": "Italian", 16 | "pl": "Polish", 17 | "ru": "Russian", 18 | "sk": "Slovak", 19 | "pt": "Portuguese", 20 | "ro": "Romanian", 21 | "da": "Danish", 22 | "sv": "Swedish", 23 | "no": "Norwegian", 24 | "en": "English", 25 | "es": "Spanish", 26 | "fr": "French", 27 | "cs": "Czech", 28 | "de": "German", 29 | "fi": "Finnish", 30 | "et": "Estonian", 31 | "lv": "Latvian", 32 | "lt": "Lithuanian", 33 | "fa": "Persian", 34 | "hu": "Hungarian", 35 | "he": "Hebrew", 36 | "el": "Greek", 37 | "ar": "Arabic", 38 | "kr": "Korean", 39 | "ja": "Japanese", 40 | "zh": "Chinese", 41 | "zh-traditional": "Chinese (Traditional)", 42 | "zh-simplified": "Chinese (Simplified)", 43 | "th": "Thai", 44 | "vi": "Vietnamese", 45 | } 46 | 47 | 48 | def remove_leading_fraction(input_string): 49 | # remove leading fraction 50 | cleaned_string = re.sub(r"^\d+\s*/\s*\d+", "", input_string) 51 | cleaned_string = re.sub(r"\d+\s*/\s*\d+$", "", cleaned_string) 52 | cleaned_string = cleaned_string.split("1 / 1", 1)[-1] 53 | 54 | # \uc9c0\uae08 \ubc88\uc5ed\ud558\uae30 55 | cleaned_string = cleaned_string.split("지금 번역하기")[0] 56 | 57 | # Language: English 58 | cleaned_string = cleaned_string.split("\n \n Language: ")[0] 59 | 60 | cleaned_string = cleaned_string.strip() 61 | 62 | if cleaned_string.endswith("Share Prompt"): 63 | cleaned_string = cleaned_string[: -len("Share Prompt")].strip() 64 | 65 | if cleaned_string.endswith("Translate now"): 66 | cleaned_string = cleaned_string[: -len("Translate now")].strip() 67 | 68 | for lang_code in CHATGPT_LANGUAGES: 69 | lang_suffix = f"Language: {CHATGPT_LANGUAGES[lang_code]}" 70 | if cleaned_string.endswith(lang_suffix): 71 | cleaned_string = cleaned_string[: -len(lang_suffix)].strip() 72 | 73 | # ~The following is a conversation with Bing, not ChatGPT.~ 74 | if cleaned_string.startswith( 75 | "~The following is a conversation with Bing, not ChatGPT.~" 76 | ): 77 | cleaned_string = cleaned_string[ 78 | len("~The following is a conversation with Bing, not ChatGPT.~") : 79 | ].strip() 80 | 81 | return cleaned_string 82 | 83 | 84 | def load_dolly_data(length_bonus): 85 | dataset = load_dataset("databricks/databricks-dolly-15k")["train"] 86 | category_to_examples = {} 87 | for example in dataset: 88 | category = example["category"] 89 | if category not in category_to_examples: 90 | category_to_examples[category] = [] 91 | category_to_examples[category].append(example) 92 | 93 | merged_examples = [] 94 | for data in [ 95 | category_to_examples["creative_writing"], 96 | category_to_examples["brainstorming"], 97 | category_to_examples["open_qa"], 98 | category_to_examples["general_qa"], 99 | category_to_examples["classification"], 100 | ]: 101 | for i in range(len(data)): 102 | assert data[i]["context"] == "" 103 | merged_examples.append( 104 | { 105 | "instruction": data[i]["instruction"], 106 | "input": "", 107 | "output": "", 108 | "length_bonus": length_bonus, 109 | "question_type": 0, 110 | } 111 | ) 112 | print("Dolly examples:", len(merged_examples)) 113 | return merged_examples 114 | 115 | 116 | def load_oasst_data(length_bonus): 117 | oasst_dataset = load_dataset("OpenAssistant/oasst1")["train"] 118 | 119 | def create_message_trees(dataset): 120 | """Create message trees from dataset.""" 121 | # Organize data into dictionary based on parent_id 122 | organized_data = {} 123 | for message in dataset: 124 | parent_id = message["parent_id"] 125 | if parent_id not in organized_data: 126 | organized_data[parent_id] = [] 127 | organized_data[parent_id].append(message) 128 | 129 | # Traverse data to create trees 130 | message_trees = [] 131 | for root_messages in organized_data[None]: 132 | tree = [] 133 | current_message = root_messages 134 | while current_message is not None: 135 | tree.append(current_message) 136 | children = organized_data.get(current_message["message_id"]) 137 | current_message = children[0] if children else None 138 | message_trees.append(tree) 139 | 140 | return message_trees 141 | 142 | oasst_message_trees = create_message_trees(oasst_dataset) 143 | oasst_examples = [] 144 | 145 | count = 0 146 | for oasst_tree in oasst_message_trees: 147 | if len(oasst_tree) >= 2: 148 | count += 1 149 | oasst_examples.append( 150 | { 151 | "instruction": oasst_tree[0]["text"], 152 | "input": "", 153 | "output": "", 154 | "length_bonus": length_bonus, 155 | "question_type": 0, 156 | } 157 | ) 158 | print("OASST examples:", count) 159 | return oasst_examples 160 | 161 | 162 | def load_math_data(length_bonus): 163 | dataset = load_dataset("competition_math") 164 | train_data = dataset["train"] 165 | max_samples_per_dataset = 10000 166 | 167 | math_data = [] 168 | 169 | for i in range(min(len(train_data), max_samples_per_dataset)): 170 | ex_instruction = train_data[i]["problem"].strip() 171 | math_data.append( 172 | { 173 | "instruction": ex_instruction, 174 | "input": "", 175 | "output": "", 176 | "length_bonus": length_bonus, 177 | "question_type": 1, 178 | } 179 | ) 180 | return math_data 181 | 182 | 183 | def filter_and_clean_examples(merged_examples): 184 | tokenizer = AutoTokenizer.from_pretrained("TheBloke/dromedary-65b-lora-HF") 185 | max_seq_length = 256 186 | filtered_examples = [] 187 | set_of_unique_instructions = set() 188 | 189 | # Filter out examples with non-ascci characters 190 | merged_examples = [ 191 | { 192 | "instruction": remove_leading_fraction(example["instruction"]), 193 | "input": "", 194 | "output": "", 195 | "length_bonus": example["length_bonus"], 196 | "question_type": example["question_type"], 197 | } 198 | for example in merged_examples 199 | ] 200 | 201 | for example in tqdm.tqdm(merged_examples): 202 | instruction = example["instruction"] 203 | instruction_token_length = len(tokenizer.encode(instruction)) 204 | if ( 205 | 2 <= instruction_token_length <= max_seq_length 206 | and instruction not in set_of_unique_instructions 207 | ): 208 | filtered_examples.append(example) 209 | set_of_unique_instructions.add(instruction) 210 | 211 | return filtered_examples 212 | 213 | 214 | def load_json(share_gpt_data_path): 215 | with open(share_gpt_data_path, "r") as f: 216 | share_gpt_data = json.load(f) 217 | examples = [] 218 | for data in share_gpt_data: 219 | examples.append( 220 | { 221 | "instruction": data["instruction"], 222 | "input": "", 223 | "output": "", 224 | } 225 | ) 226 | print(f"{share_gpt_data_path} examples:", len(examples)) 227 | return examples 228 | 229 | 230 | def main( 231 | sharegpt_prompt_path: str = "/path/to/sharegpt_prompts.json", 232 | openorca_prompt_path: str = "/path/to/openorca_prompts.json", 233 | output_file: str = "/path/to/salmon_merged_prompts.json", 234 | general_length_bonus: float = 1.0, 235 | reasoning_length_bonus: float = -0.4, 236 | redteaming_length_bonus: float = 0.0, 237 | ): 238 | del redteaming_length_bonus # unused 239 | 240 | dolly_examples = load_dolly_data(length_bonus=general_length_bonus) 241 | oasst_examples = load_oasst_data(length_bonus=general_length_bonus) 242 | 243 | sharegpt_examples = load_json(sharegpt_prompt_path) 244 | sharegpt_examples = [ 245 | { 246 | "instruction": example["instruction"], 247 | "input": "", 248 | "output": "", 249 | "length_bonus": general_length_bonus, 250 | "question_type": 0, 251 | } 252 | for example in sharegpt_examples 253 | ] 254 | 255 | openorca_examples = load_json(openorca_prompt_path) 256 | openorca_examples = [ 257 | { 258 | "instruction": example["instruction"], 259 | "input": "", 260 | "output": "", 261 | "length_bonus": reasoning_length_bonus, 262 | "question_type": 1, 263 | } 264 | for example in openorca_examples 265 | ] 266 | math_examples = load_math_data(length_bonus=reasoning_length_bonus) 267 | 268 | merged_examples = ( 269 | dolly_examples 270 | + oasst_examples 271 | + sharegpt_examples 272 | + openorca_examples 273 | + math_examples 274 | ) 275 | filtered_examples = filter_and_clean_examples(merged_examples) 276 | 277 | print("Total examples:", len(filtered_examples)) 278 | 279 | with open(output_file, "w") as f: 280 | json.dump(filtered_examples, f, indent=2) 281 | 282 | 283 | if __name__ == "__main__": 284 | fire.Fire(main) 285 | -------------------------------------------------------------------------------- /training/step3_ppo_training/scripts/train_ppo_model_70b_qlora_salmon.sh: -------------------------------------------------------------------------------- 1 | # We use 6 x 8 = 48 A100-80GB GPUs 2 | # salloc --nodes 6 --time 24:00:00 --gres=gpu:80g:8 srun bash scripts/train_ppo_model_70b_qlora_salmon.sh 3 | 4 | set -e 5 | set -x 6 | 7 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 8 | export MODEL_DIR="/your/model/dir" 9 | export DATA_DIR="/your/data/dir" 10 | export PYTHONPATH="$PWD:$PYTHONPATH" 11 | export GPUS_PER_NODE=8 12 | export OMP_NUM_THREADS=8 13 | export TOKENIZERS_PARALLELISM=false 14 | export CUDA_LAUNCH_BLOCKING=1 # Not sure if this is needed 15 | 16 | LEARNING_RATE=2e-5 17 | KL_COEF=0.01 18 | EPOCH=2 19 | ROLLOUT_BATCH_SIZE=576 20 | STEP_BATCH_SZIE=288 21 | ROLLOUT_PER_DEVICE_BATCH_SIZE=6 22 | REWARD_MODEL_PER_DEVICE_BATCH_SIZE=3 23 | STEP_PER_DEVICE_BATCH_SIZE=6 24 | 25 | JOB_ID=29400 26 | NUM_NODES=6 27 | NUM_TRAINERS=8 28 | 29 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | tail -n 1) 30 | 31 | HOST_NODE_ADDR="$MASTER_ADDR:29400" 32 | 33 | # SALMON-SPECIFIC HYPER-PARAMETERS 34 | MAX_PRINCIPLES=3 35 | UNFINISHED_PENALTY=-6.0 36 | LENGTH_BONUS=5.0 37 | LENGTH_BONUS_UPPER_BOUND=0.8 38 | SAMPLING_TEMPARATURE=0.7 39 | 40 | cd .. 41 | 42 | torchrun \ 43 | --nnodes=$NUM_NODES \ 44 | --nproc-per-node=$NUM_TRAINERS \ 45 | --rdzv-id=$JOB_ID \ 46 | --rdzv-backend=c10d \ 47 | --rdzv-endpoint=$HOST_NODE_ADDR \ 48 | finetune_qlora_ppo.py \ 49 | --do_train \ 50 | --do_eval \ 51 | --seed 42 \ 52 | --step_batch_size $STEP_BATCH_SZIE \ 53 | --step_per_device_batch_size $STEP_PER_DEVICE_BATCH_SIZE \ 54 | --rollout_batch_size $ROLLOUT_BATCH_SIZE \ 55 | --rollout_per_device_batch_size $ROLLOUT_PER_DEVICE_BATCH_SIZE \ 56 | --reward_model_per_device_batch_size $REWARD_MODEL_PER_DEVICE_BATCH_SIZE \ 57 | --per_device_eval_batch_size $ROLLOUT_PER_DEVICE_BATCH_SIZE \ 58 | --policy_model_name_or_path "$MODEL_DIR/dromedary-2-70b-sft-qlora/adapter_model" \ 59 | --reward_model_name_or_path "$MODEL_DIR/llama-2-70b-qlora-rm-pmp-sft/adapter_model" \ 60 | --learning_rate $LEARNING_RATE \ 61 | --init_value_with_reward True \ 62 | --warmup_steps 5 \ 63 | --dataset_path "$DATA_DIR/salmon_merged_prompts.json" \ 64 | --train_splits "train" \ 65 | --policy_meta_prompt_pattern "../prompts/dromedary_inference_prompt.txt" \ 66 | --reward_meta_prompt_pattern "../prompts/salmon_reward_model_prompt_v1.txt" \ 67 | --principle_collection_path "../prompts/principles/principle_collection_ppo.json" \ 68 | --max_principles $MAX_PRINCIPLES \ 69 | --stop_token "\n\n### User" \ 70 | --output_dir "$NEW_MODEL_DIR/dromedary-2-70b-ppo-qlora-salmon" \ 71 | --total_epochs $EPOCH \ 72 | --group_by_length False \ 73 | --evaluation_strategy "no" \ 74 | --save_strategy "steps" \ 75 | --save_steps 10 \ 76 | --save_total_limit 100000 \ 77 | --weight_decay 0.0 \ 78 | --lr_scheduler_type "cosine" \ 79 | --logging_steps 1 \ 80 | --report_to "tensorboard" \ 81 | --ddp_backend "nccl" \ 82 | --bf16 True \ 83 | --penalty_reward_value $UNFINISHED_REWARD_PENALTY \ 84 | --length_bonus_score $LENGTH_BONUS \ 85 | --length_bonus_upper_bound $LENGTH_BONUS_UPPER_BOUND \ 86 | --relative_stop_token_penalty True \ 87 | --penalize_no_stop_token True \ 88 | --ddp_find_unused_parameters False \ 89 | --resume_from_training True \ 90 | --kl_coef $KL_COEF \ 91 | --max_grad_norm 1.0 \ 92 | --whitening_async_stats "full_batch" \ 93 | --clean_tokens_after_eos True \ 94 | --whiten_rewards False \ 95 | --query_len 256 \ 96 | --response_len 1024 \ 97 | --model_max_length 1664 \ 98 | --enable_reasoning_principles True \ 99 | --enable_redteaming_principles True \ 100 | --temperature $SAMPLING_TEMPARATURE \ 101 | --noptepochs 2 102 | -------------------------------------------------------------------------------- /training/step3_ppo_training/subsample_openorca_prompts.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import json 3 | import random 4 | import pandas as pd 5 | import fire 6 | from transformers import LlamaTokenizerFast 7 | 8 | # https://huggingface.co/datasets/Open-Orca/OpenOrca/blob/main/1M-GPT4-Augmented.parquet 9 | 10 | 11 | def main( 12 | train_data_path: str = "path/to/data/1M-GPT4-Augmented.parquet", 13 | output_path: str = "path/to/data/openorca_prompts.json", 14 | tokenizer_name: str = "TheBloke/dromedary-65b-lora-HF", # a random llama-based model 15 | max_samples_per_dataset: int = 10000, 16 | max_prompt_len: int = 256, 17 | ): 18 | train_data = pd.read_parquet(train_data_path) 19 | tokenizer = LlamaTokenizerFast.from_pretrained(tokenizer_name) 20 | niv_data = [] 21 | flan_data = [] 22 | t0_data = [] 23 | cot_data = [] 24 | 25 | for i in tqdm.tqdm(range(len(train_data))): 26 | datapoint = train_data.iloc[i] 27 | 28 | d_id = datapoint["id"] 29 | 30 | if "niv" in d_id: 31 | niv_data.append(datapoint) 32 | continue 33 | if "flan" in d_id: 34 | flan_data.append(datapoint) 35 | continue 36 | if "t0" in d_id: 37 | t0_data.append(datapoint) 38 | continue 39 | if "cot" in d_id: 40 | cot_data.append(datapoint) 41 | continue 42 | 43 | raise ValueError("Unknown dataset") 44 | 45 | print("Stats:") 46 | print(f"niv: {len(niv_data)}") 47 | print(f"flan: {len(flan_data)}") 48 | print(f"t0: {len(t0_data)}") 49 | print(f"cot: {len(cot_data)}") 50 | 51 | total_data = [] 52 | 53 | set_of_instructions = set() 54 | 55 | for data in tqdm.tqdm([niv_data, flan_data, t0_data, cot_data]): 56 | data_count = 0 57 | random.shuffle(data) 58 | for datapoint in tqdm.tqdm(data): 59 | ex_instruction = datapoint["question"].strip() 60 | ex_input = "" 61 | ex_output = "" 62 | 63 | if ex_instruction in set_of_instructions: 64 | continue 65 | else: 66 | set_of_instructions.add(ex_instruction) 67 | 68 | if ( 69 | len(tokenizer.encode(ex_instruction + "\n\n" + ex_input)) 70 | > max_prompt_len 71 | ): 72 | continue 73 | 74 | # We delete NLI tasks 75 | if "it is not possible to tell" in datapoint["question"]: 76 | continue 77 | 78 | continue_flag = False 79 | for keyword in ["Premise", "premise", "Hypothesis", "hypothesis"]: 80 | if keyword in datapoint["question"]: 81 | continue_flag = True 82 | break 83 | if continue_flag: 84 | continue 85 | 86 | total_data.append( 87 | { 88 | "instruction": ex_instruction, 89 | "input": ex_input, 90 | "output": ex_output, 91 | } 92 | ) 93 | 94 | data_count += 1 95 | if data_count >= max_samples_per_dataset: 96 | break 97 | print("data_count:", data_count, len(total_data)) 98 | 99 | random.shuffle(total_data) 100 | 101 | print(f"Total data: {len(total_data)}") 102 | 103 | with open(output_path, "w") as f: 104 | json.dump(total_data, f, indent=2) 105 | 106 | 107 | if __name__ == "__main__": 108 | fire.Fire(main) 109 | -------------------------------------------------------------------------------- /training/train_qlora_rm.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the MIT license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | 4 | import json 5 | 6 | import os 7 | from dataclasses import dataclass, field 8 | from typing import Optional, List, Literal 9 | import logging 10 | 11 | import torch 12 | import transformers 13 | import argparse 14 | from transformers import set_seed 15 | 16 | try: 17 | from transformers import LlamaTokenizerFast as LlamaTokenizer 18 | 19 | print("Using fast tokenizer") 20 | except: 21 | from transformers import LlamaTokenizer 22 | 23 | print("Using slow tokenizer") 24 | 25 | from transformers import AutoTokenizer, AutoModelForCausalLM 26 | 27 | from qlora_utils import ( 28 | SavePeftModelCallback, 29 | print_trainable_parameters, 30 | get_last_checkpoint, 31 | DEFAULT_PAD_TOKEN, 32 | ) 33 | from data_utils.data_utils_rm import make_binary_reward_modeling_data_module 34 | from models.reward_model import ( 35 | RewardConfig, 36 | RewardModel, 37 | RewardModelTrainer as Trainer, 38 | compute_reward_modeling_metrics, 39 | ) 40 | 41 | torch.backends.cuda.matmul.allow_tf32 = True 42 | 43 | logger = logging.getLogger(__name__) 44 | 45 | 46 | @dataclass 47 | class ModelArguments: 48 | model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-12b") 49 | trust_remote_code: Optional[bool] = field( 50 | default=False, 51 | metadata={ 52 | "help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained." 53 | }, 54 | ) 55 | 56 | 57 | @dataclass 58 | class DataArguments: 59 | dataset_path: str = field(default="tatsu-lab/alpaca_farm") 60 | dataset_name: str = field(default=None, metadata={"help": "Dataset name"}) 61 | eval_dataset_path: str = field(default="tatsu-lab/alpaca_farm") 62 | eval_dataset_name: str = field(default="alpaca_human_preference") 63 | eval_size: int = field( 64 | default=500, 65 | metadata={ 66 | "help": "Number of examples to split out from training to use for evaluation." 67 | }, 68 | ) 69 | meta_prompt_pattern: Optional[str] = field( 70 | default=None, metadata={"help": "Which meta prompt pattern to use."} 71 | ) 72 | principle_collection_path: Optional[str] = field( 73 | default=None, metadata={"help": "Path to the principle collection."} 74 | ) 75 | 76 | 77 | @dataclass 78 | class TrainingArguments(transformers.Seq2SeqTrainingArguments): 79 | cache_dir: Optional[str] = field(default=None) 80 | # From AlpacaFarm 81 | model_max_length: int = field( 82 | default=512, 83 | metadata={ 84 | "help": "Maximum sequence length. Sequences will be left padded to this length always during training." 85 | }, 86 | ) 87 | query_len: int = field(default=None, metadata={"help": "Length of the query."}) 88 | response_len: int = field( 89 | default=None, metadata={"help": "Length of the response."} 90 | ) 91 | label_names: List[str] = field( 92 | default_factory=lambda: ["index_0", "index_1", "choice"], 93 | metadata={ 94 | "help": "Names of the labels in the dataset. " 95 | "This is needed to get transformers.Trainer to not throw those tensors away before `compute_loss`." 96 | "By default, the trainer throws away columns it doesn't recognize when creating the " 97 | "`train_dataloader` (see `_remove_unused_columns`). " 98 | }, 99 | ) 100 | padding: Literal["max_length", "longest"] = field( 101 | default="longest", 102 | metadata={ 103 | "help": "Padding strategy. If 'max_length', pads to `model_max_length` always; this might lead to some " 104 | "redundant compute. If 'longest', pads to the longest sequence in the batch, capped by `model_max_length`." 105 | }, 106 | ) 107 | end_sequence_with_eos: bool = field( 108 | default=False, 109 | metadata={ 110 | "help": "Whether to end sequences with EOS. " 111 | "Ending with EOS might help the reward model realize it's time to predict." 112 | }, 113 | ) 114 | # From QLoRA 115 | full_finetune: bool = field( 116 | default=False, metadata={"help": "Finetune the entire model without adapters."} 117 | ) 118 | adam8bit: bool = field(default=False, metadata={"help": "Use 8-bit adam."}) 119 | double_quant: bool = field( 120 | default=True, 121 | metadata={ 122 | "help": "Compress the quantization statistics through double quantization." 123 | }, 124 | ) 125 | quant_type: str = field( 126 | default="nf4", 127 | metadata={ 128 | "help": "Quantization data type to use. Should be one of `fp4` or `nf4`." 129 | }, 130 | ) 131 | bits: int = field(default=4, metadata={"help": "How many bits to use."}) 132 | lora_modules: Optional[List[str]] = field( 133 | default=None, 134 | metadata={ 135 | "help": "Which modules to use LoRA on. If None, will use all linear layers." 136 | }, 137 | ) 138 | lora_r: int = field(default=64, metadata={"help": "Lora R dimension."}) 139 | lora_alpha: float = field(default=16, metadata={"help": " Lora alpha."}) 140 | lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout."}) 141 | report_to: str = field( 142 | default="none", 143 | metadata={"help": "To use wandb or something else for reporting."}, 144 | ) 145 | resume_dir: Optional[str] = field( 146 | default=None, 147 | metadata={"help": "Path to the directory containing the checkpoint to resume."}, 148 | ) 149 | output_dir: str = field( 150 | default="./output", metadata={"help": "The output dir for logs and checkpoints"} 151 | ) 152 | optim: str = field( 153 | default="paged_adamw_32bit", metadata={"help": "The optimizer to be used"} 154 | ) 155 | per_device_train_batch_size: int = field( 156 | default=1, 157 | metadata={ 158 | "help": "The training batch size per GPU. Increase for better speed." 159 | }, 160 | ) 161 | gradient_accumulation_steps: int = field( 162 | default=16, 163 | metadata={ 164 | "help": "How many gradients to accumulate before to perform an optimizer step" 165 | }, 166 | ) 167 | weight_decay: float = field( 168 | default=0.0, metadata={"help": "The L2 weight decay rate of AdamW"} 169 | ) # use lora dropout instead for regularization if needed 170 | learning_rate: float = field(default=0.0002, metadata={"help": "The learnign rate"}) 171 | remove_unused_columns: bool = field( 172 | default=False, 173 | metadata={"help": "Removed unused columns. Needed to make this codebase work."}, 174 | ) 175 | max_grad_norm: float = field( 176 | default=0.3, 177 | metadata={ 178 | "help": "Gradient clipping max norm. This is tuned and works well for all models tested." 179 | }, 180 | ) 181 | gradient_checkpointing: bool = field( 182 | default=True, 183 | metadata={"help": "Use gradient checkpointing. You want to use this."}, 184 | ) 185 | do_train: bool = field( 186 | default=True, 187 | metadata={"help": "To train or not to train, that is the question?"}, 188 | ) 189 | lr_scheduler_type: str = field( 190 | default="constant", 191 | metadata={ 192 | "help": "Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis" 193 | }, 194 | ) 195 | warmup_ratio: float = field( 196 | default=0.03, metadata={"help": "Fraction of steps to do a warmup for"} 197 | ) 198 | logging_steps: int = field( 199 | default=10, 200 | metadata={"help": "The frequency of update steps after which to log the loss"}, 201 | ) 202 | group_by_length: bool = field( 203 | default=True, 204 | metadata={ 205 | "help": "Group sequences into batches with same length. Saves memory and speeds up training considerably." 206 | }, 207 | ) 208 | save_strategy: str = field( 209 | default="steps", metadata={"help": "When to save checkpoints"} 210 | ) 211 | save_steps: int = field(default=250, metadata={"help": "How often to save a model"}) 212 | save_total_limit: int = field( 213 | default=40, 214 | metadata={ 215 | "help": "How many checkpoints to save before the oldest is overwritten" 216 | }, 217 | ) 218 | resume_from_training: bool = field( 219 | default=False, metadata={"help": "Resume from training"} 220 | ) 221 | 222 | 223 | def rank0_print(*args): 224 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 225 | if local_rank == 0: 226 | print(*args) 227 | 228 | 229 | def train(): 230 | hfparser = transformers.HfArgumentParser( 231 | (ModelArguments, DataArguments, TrainingArguments) 232 | ) 233 | ( 234 | model_args, 235 | data_args, 236 | training_args, 237 | extra_args, 238 | ) = hfparser.parse_args_into_dataclasses(return_remaining_strings=True) 239 | args = argparse.Namespace( 240 | **vars(model_args), **vars(data_args), **vars(training_args) 241 | ) 242 | 243 | if args.resume_dir is not None: 244 | checkpoint_dir, completed_training = args.resume_dir, False 245 | else: 246 | checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir) 247 | 248 | if completed_training: 249 | rank0_print("Detected that training was already completed!") 250 | 251 | if checkpoint_dir is None: 252 | rank0_print("Training from scratch.") 253 | else: 254 | rank0_print("Loading from checkpoint:", checkpoint_dir) 255 | if args.resume_from_training: 256 | rank0_print("Resuming from training not supported yet. Exiting.") 257 | exit(1) 258 | 259 | use_llama_base_model = ( 260 | "dromedary" in args.model_name_or_path.lower() 261 | or "llama" in args.model_name_or_path.lower() 262 | or "vicuna" in args.model_name_or_path.lower() 263 | ) 264 | 265 | if use_llama_base_model: 266 | tokenizer_model_name = ( 267 | "TheBloke/dromedary-65b-lora-HF" # TODO(zhiqings): hacking 268 | ) 269 | TokenizerClass = LlamaTokenizer 270 | else: 271 | tokenizer_model_name = args.model_name_or_path 272 | TokenizerClass = AutoTokenizer 273 | 274 | # Tokenizer 275 | tokenizer = TokenizerClass.from_pretrained( 276 | tokenizer_model_name, 277 | cache_dir=args.cache_dir, 278 | model_max_length=training_args.model_max_length, 279 | padding_side="left", 280 | truncation_side="right", 281 | ) 282 | 283 | if use_llama_base_model: 284 | if tokenizer._pad_token is None: 285 | tokenizer.pad_token_id = ( 286 | 0 # unk. we want this to be different from the eos token 287 | ) 288 | else: 289 | raise NotImplementedError 290 | 291 | data_module = make_binary_reward_modeling_data_module( 292 | tokenizer=tokenizer, 293 | data_args=data_args, 294 | training_args=training_args, 295 | ) 296 | 297 | if args.do_train: 298 | training_data = data_module["train_dataset"] 299 | rank0_print("Training data size:", len(training_data)) 300 | rank0_print("Training data example:") 301 | for i in range(min(3, len(training_data))): 302 | ex_input_ids_0 = training_data[i]["input_ids"][0] 303 | rank0_print(tokenizer.decode(ex_input_ids_0, skip_special_tokens=True)) 304 | rank0_print("=" * 20) 305 | ex_input_ids_1 = training_data[i]["input_ids"][1] 306 | rank0_print(tokenizer.decode(ex_input_ids_1, skip_special_tokens=True)) 307 | rank0_print("=" * 20) 308 | rank0_print("=" * 20) 309 | 310 | config = RewardConfig(backbone_model_name_or_path=model_args.model_name_or_path) 311 | 312 | model = RewardModel( 313 | args=args, 314 | config=config, 315 | qlora=True, 316 | checkpoint_dir=checkpoint_dir, 317 | ) 318 | 319 | model.backbone_model.config.use_cache = False 320 | print_trainable_parameters(args, model) 321 | print("loaded model") 322 | set_seed(args.seed) 323 | 324 | trainer = Trainer( 325 | model=model, 326 | tokenizer=tokenizer, 327 | args=training_args, 328 | compute_metrics=compute_reward_modeling_metrics, 329 | **{k: v for k, v in data_module.items() if k != "predict_dataset"}, 330 | ) 331 | 332 | # Callbacks 333 | if not args.full_finetune: 334 | trainer.add_callback(SavePeftModelCallback) 335 | 336 | # Verifying the datatypes. 337 | dtypes = {} 338 | for _, p in model.named_parameters(): 339 | dtype = p.dtype 340 | if dtype not in dtypes: 341 | dtypes[dtype] = 0 342 | dtypes[dtype] += p.numel() 343 | total = 0 344 | for k, v in dtypes.items(): 345 | total += v 346 | for k, v in dtypes.items(): 347 | print(k, v, v / total) 348 | 349 | all_metrics = {"run_name": args.run_name} 350 | 351 | # Training 352 | if args.do_train: 353 | logger.info("*** Train ***") 354 | train_result = trainer.train() 355 | metrics = train_result.metrics 356 | trainer.log_metrics("train", metrics) 357 | trainer.save_metrics("train", metrics) 358 | trainer.save_state() 359 | all_metrics.update(metrics) 360 | 361 | # Evaluation 362 | if args.do_eval: 363 | logger.info("*** Evaluate ***") 364 | metrics = trainer.evaluate(metric_key_prefix="eval") 365 | trainer.log_metrics("eval", metrics) 366 | trainer.save_metrics("eval", metrics) 367 | all_metrics.update(metrics) 368 | 369 | if args.do_train or args.do_eval: 370 | with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout: 371 | fout.write(json.dumps(all_metrics)) 372 | 373 | 374 | if __name__ == "__main__": 375 | train() 376 | -------------------------------------------------------------------------------- /training/train_qlora_sft.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the MIT license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | 4 | import json 5 | 6 | import os 7 | from dataclasses import dataclass, field 8 | from typing import Optional, List 9 | import logging 10 | 11 | import torch 12 | import transformers 13 | import argparse 14 | from transformers import ( 15 | set_seed, 16 | Trainer, 17 | ) 18 | 19 | try: 20 | from transformers import LlamaTokenizerFast as LlamaTokenizer 21 | 22 | print("Using fast tokenizer") 23 | except: 24 | from transformers import LlamaTokenizer 25 | 26 | print("Using slow tokenizer") 27 | 28 | from transformers import AutoTokenizer 29 | 30 | from qlora_utils import ( 31 | SavePeftModelCallback, 32 | print_trainable_parameters, 33 | get_last_checkpoint, 34 | DEFAULT_PAD_TOKEN, 35 | ) 36 | from data_utils.data_utils_sft import ( 37 | make_sft_data_module, 38 | IGNORE_INDEX, 39 | ) 40 | from models.qlora_model import get_accelerate_model 41 | 42 | 43 | torch.backends.cuda.matmul.allow_tf32 = True 44 | 45 | logger = logging.getLogger(__name__) 46 | 47 | 48 | @dataclass 49 | class ModelArguments: 50 | model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-12b") 51 | trust_remote_code: Optional[bool] = field( 52 | default=False, 53 | metadata={ 54 | "help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained." 55 | }, 56 | ) 57 | 58 | 59 | @dataclass 60 | class DataArguments: 61 | eval_dataset_size: int = field( 62 | default=1024, metadata={"help": "Size of validation dataset."} 63 | ) 64 | max_train_samples: Optional[int] = field( 65 | default=None, 66 | metadata={ 67 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 68 | "value if set." 69 | }, 70 | ) 71 | max_eval_samples: Optional[int] = field( 72 | default=None, 73 | metadata={ 74 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 75 | "value if set." 76 | }, 77 | ) 78 | source_max_len: int = field( 79 | default=1024, 80 | metadata={ 81 | "help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)." 82 | }, 83 | ) 84 | target_max_len: int = field( 85 | default=256, 86 | metadata={ 87 | "help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)." 88 | }, 89 | ) 90 | dataset: str = field( 91 | default="alpaca", 92 | metadata={"help": "Which dataset to finetune on. See datamodule for options."}, 93 | ) 94 | dataset_format: Optional[str] = field( 95 | default=None, 96 | metadata={ 97 | "help": "Which dataset format is used. [alpaca|chip2|self-instruct|hh-rlhf]" 98 | }, 99 | ) 100 | meta_prompt_pattern: Optional[str] = field( 101 | default=None, metadata={"help": "Which meta prompt pattern to use."} 102 | ) 103 | add_eos_to_target: bool = field( 104 | default=True, metadata={"help": "Whether to add an EOS token to the target."} 105 | ) 106 | 107 | 108 | @dataclass 109 | class TrainingArguments(transformers.Seq2SeqTrainingArguments): 110 | cache_dir: Optional[str] = field(default=None) 111 | train_on_source: Optional[bool] = field( 112 | default=False, 113 | metadata={ 114 | "help": "Whether to train on the input in addition to the target text." 115 | }, 116 | ) 117 | full_finetune: bool = field( 118 | default=False, metadata={"help": "Finetune the entire model without adapters."} 119 | ) 120 | adam8bit: bool = field(default=False, metadata={"help": "Use 8-bit adam."}) 121 | double_quant: bool = field( 122 | default=True, 123 | metadata={ 124 | "help": "Compress the quantization statistics through double quantization." 125 | }, 126 | ) 127 | quant_type: str = field( 128 | default="nf4", 129 | metadata={ 130 | "help": "Quantization data type to use. Should be one of `fp4` or `nf4`." 131 | }, 132 | ) 133 | bits: int = field(default=4, metadata={"help": "How many bits to use."}) 134 | lora_modules: Optional[List[str]] = field( 135 | default=None, 136 | metadata={ 137 | "help": "Which modules to use LoRA on. If None, will use all linear layers." 138 | }, 139 | ) 140 | lora_r: int = field(default=64, metadata={"help": "Lora R dimension."}) 141 | lora_alpha: float = field(default=16, metadata={"help": " Lora alpha."}) 142 | lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout."}) 143 | max_memory_MB: int = field(default=80000, metadata={"help": "Free memory per gpu."}) 144 | report_to: str = field( 145 | default="none", 146 | metadata={"help": "To use wandb or something else for reporting."}, 147 | ) 148 | resume_dir: Optional[str] = field( 149 | default=None, 150 | metadata={"help": "Path to the directory containing the checkpoint to resume."}, 151 | ) 152 | output_dir: str = field( 153 | default="./output", metadata={"help": "The output dir for logs and checkpoints"} 154 | ) 155 | optim: str = field( 156 | default="paged_adamw_32bit", metadata={"help": "The optimizer to be used"} 157 | ) 158 | per_device_train_batch_size: int = field( 159 | default=1, 160 | metadata={ 161 | "help": "The training batch size per GPU. Increase for better speed." 162 | }, 163 | ) 164 | gradient_accumulation_steps: int = field( 165 | default=16, 166 | metadata={ 167 | "help": "How many gradients to accumulate before to perform an optimizer step" 168 | }, 169 | ) 170 | weight_decay: float = field( 171 | default=0.0, metadata={"help": "The L2 weight decay rate of AdamW"} 172 | ) # use lora dropout instead for regularization if needed 173 | learning_rate: float = field(default=0.0002, metadata={"help": "The learnign rate"}) 174 | remove_unused_columns: bool = field( 175 | default=False, 176 | metadata={"help": "Removed unused columns. Needed to make this codebase work."}, 177 | ) 178 | max_grad_norm: float = field( 179 | default=0.3, 180 | metadata={ 181 | "help": "Gradient clipping max norm. This is tuned and works well for all models tested." 182 | }, 183 | ) 184 | gradient_checkpointing: bool = field( 185 | default=True, 186 | metadata={"help": "Use gradient checkpointing. You want to use this."}, 187 | ) 188 | do_train: bool = field( 189 | default=True, 190 | metadata={"help": "To train or not to train, that is the question?"}, 191 | ) 192 | lr_scheduler_type: str = field( 193 | default="constant", 194 | metadata={ 195 | "help": "Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis" 196 | }, 197 | ) 198 | warmup_ratio: float = field( 199 | default=0.03, metadata={"help": "Fraction of steps to do a warmup for"} 200 | ) 201 | logging_steps: int = field( 202 | default=10, 203 | metadata={"help": "The frequency of update steps after which to log the loss"}, 204 | ) 205 | group_by_length: bool = field( 206 | default=True, 207 | metadata={ 208 | "help": "Group sequences into batches with same length. Saves memory and speeds up training considerably." 209 | }, 210 | ) 211 | save_strategy: str = field( 212 | default="steps", metadata={"help": "When to save checkpoints"} 213 | ) 214 | save_steps: int = field(default=250, metadata={"help": "How often to save a model"}) 215 | save_total_limit: int = field( 216 | default=40, 217 | metadata={ 218 | "help": "How many checkpoints to save before the oldest is overwritten" 219 | }, 220 | ) 221 | resume_from_training: bool = field( 222 | default=False, metadata={"help": "Resume from training"} 223 | ) 224 | 225 | 226 | def rank0_print(*args): 227 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 228 | if local_rank == 0: 229 | print(*args) 230 | 231 | 232 | def train(): 233 | hfparser = transformers.HfArgumentParser( 234 | (ModelArguments, DataArguments, TrainingArguments) 235 | ) 236 | ( 237 | model_args, 238 | data_args, 239 | training_args, 240 | extra_args, 241 | ) = hfparser.parse_args_into_dataclasses(return_remaining_strings=True) 242 | args = argparse.Namespace( 243 | **vars(model_args), **vars(data_args), **vars(training_args) 244 | ) 245 | 246 | if args.resume_dir is not None: 247 | checkpoint_dir, completed_training = args.resume_dir, False 248 | else: 249 | checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir) 250 | 251 | if completed_training: 252 | rank0_print("Detected that training was already completed!") 253 | 254 | if checkpoint_dir is None: 255 | rank0_print("Training from scratch.") 256 | else: 257 | rank0_print("Loading from checkpoint:", checkpoint_dir) 258 | if args.resume_from_training: 259 | rank0_print("Resuming from training not supported yet. Exiting.") 260 | exit(1) 261 | 262 | use_llama_base_model = ( 263 | "dromedary" in args.model_name_or_path.lower() 264 | or "llama" in args.model_name_or_path.lower() 265 | ) 266 | 267 | if use_llama_base_model: 268 | tokenizer_model_name = ( 269 | "TheBloke/dromedary-65b-lora-HF" # a random llama-based model 270 | ) 271 | TokenizerClass = LlamaTokenizer 272 | else: 273 | tokenizer_model_name = args.model_name_or_path 274 | TokenizerClass = AutoTokenizer 275 | 276 | left_truncated_tokenizer = TokenizerClass.from_pretrained( 277 | tokenizer_model_name, 278 | cache_dir=args.cache_dir, 279 | truncation_side="left", 280 | padding_side="right", 281 | # use_fast=False, # Fast tokenizer giving issues. 282 | ) 283 | 284 | # Tokenizer 285 | tokenizer = TokenizerClass.from_pretrained( 286 | tokenizer_model_name, 287 | cache_dir=args.cache_dir, 288 | truncation_side="right", 289 | padding_side="right", 290 | ) 291 | 292 | if use_llama_base_model: 293 | if tokenizer._pad_token is None: 294 | left_truncated_tokenizer.pad_token_id = ( 295 | 0 # unk. we want this to be different from the eos token 296 | ) 297 | tokenizer.pad_token_id = ( 298 | 0 # unk. we want this to be different from the eos token 299 | ) 300 | 301 | else: 302 | raise NotImplementedError 303 | 304 | data_module = make_sft_data_module( 305 | left_truncated_tokenizer=left_truncated_tokenizer, 306 | tokenizer=tokenizer, 307 | args=args, 308 | ) 309 | 310 | model = get_accelerate_model(args, checkpoint_dir) 311 | 312 | model.config.use_cache = False 313 | print_trainable_parameters(args, model) 314 | print("loaded model") 315 | set_seed(args.seed) 316 | 317 | if args.do_train: 318 | training_data = data_module["train_dataset"] 319 | rank0_print("Training data size:", len(training_data)) 320 | rank0_print("Training data example:") 321 | for i in range(min(3, len(training_data))): 322 | rank0_print(training_data[i]) 323 | 324 | trainer = Trainer( 325 | model=model, 326 | tokenizer=tokenizer, 327 | args=training_args, 328 | **{k: v for k, v in data_module.items() if k != "predict_dataset"}, 329 | ) 330 | 331 | # Callbacks 332 | if not args.full_finetune: 333 | trainer.add_callback(SavePeftModelCallback) 334 | 335 | # Verifying the datatypes. 336 | dtypes = {} 337 | for _, p in model.named_parameters(): 338 | dtype = p.dtype 339 | if dtype not in dtypes: 340 | dtypes[dtype] = 0 341 | dtypes[dtype] += p.numel() 342 | total = 0 343 | for k, v in dtypes.items(): 344 | total += v 345 | for k, v in dtypes.items(): 346 | print(k, v, v / total) 347 | 348 | all_metrics = {"run_name": args.run_name} 349 | # Training 350 | if args.do_train: 351 | logger.info("*** Train ***") 352 | # Note: `resume_from_checkpoint` not supported for adapter checkpoints by HF. 353 | # Currently adapter checkpoint is reloaded as expected but optimizer/scheduler states are not. 354 | train_result = trainer.train() 355 | metrics = train_result.metrics 356 | trainer.log_metrics("train", metrics) 357 | trainer.save_metrics("train", metrics) 358 | trainer.save_state() 359 | all_metrics.update(metrics) 360 | 361 | if args.do_train: 362 | with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout: 363 | fout.write(json.dumps(all_metrics)) 364 | 365 | 366 | if __name__ == "__main__": 367 | train() 368 | --------------------------------------------------------------------------------