├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── chat.py ├── direct_preference_finetuning.py ├── merge_peft_adapter.py ├── requirements.txt └── supervised_finetuning.py /.gitignore: -------------------------------------------------------------------------------- 1 | venv/ 2 | huggingface_cache/ 3 | llama2-supervised-finetuning-output/ 4 | llama2-direct-preference-finetuning-output/ 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 Anthony Zhang (me@anthonyz.ca) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ####################### 2 | # user-facing targets # 3 | ####################### 4 | 5 | .PHONY: train 6 | train: download-datasets-and-models train-supervised-finetuning train-direct-preference-finetuning 7 | 8 | .PHONY: download-datasets-and-models 9 | download-datasets-and-models: venv/requirements_installed 10 | . ./venv/bin/activate && python3 -c 'import huggingface_hub as h; h.hf_hub_download(repo_id="jondurbin/airoboros-gpt4-1.4.1", repo_type="dataset", revision="433c04038d724bf29a193bc3c1a48b600cc417a1", filename="instructions.jsonl", cache_dir="./huggingface_cache", resume_download=True)' 11 | . ./venv/bin/activate && python3 -c 'import huggingface_hub as h; h.snapshot_download(repo_id="NousResearch/Llama-2-13b-hf", revision="81da3af9503579bf991e3995564baa683b27d38c", ignore_patterns=["pytorch_*"], cache_dir="./huggingface_cache", resume_download=True)' 12 | . ./venv/bin/activate && python3 -c 'import huggingface_hub as h; h.snapshot_download(repo_id="Anthropic/hh-rlhf", repo_type="dataset", revision="09be8c5bbc57cb3887f3a9732ad6aa7ec602a1fa", cache_dir="./huggingface_cache", resume_download=True)' 13 | 14 | .PHONY: train-supervised-finetuning 15 | train-supervised-finetuning: llama2-supervised-finetuning-output/final_checkpoint_merged 16 | 17 | .PHONY: train-direct-preference-finetuning 18 | train-direct-preference-finetuning: llama2-direct-preference-finetuning-output/final_checkpoint_merged 19 | 20 | .PHONY: chat 21 | chat: train-direct-preference-finetuning 22 | . 
./venv/bin/activate && python3 chat.py 23 | 24 | .PHONY: generate-ggml 25 | generate-ggml: exported-models/ggml-robot-agent-q5_K_M.bin 26 | 27 | chat-llama-cpp: exported-models/ggml-robot-agent-q5_K_M.bin llama.cpp/main 28 | cd llama.cpp && ./main --model ../exported-models/ggml-robot-agent-q5_K_M.bin --color -i --interactive-first --mirostat 2 --ctx-size 2048 -r $$'\n\n### Human:\n' --in-prefix $$'\n\n### Human:\n' --in-suffix $$'\n\n### Assistant:\n' -n -1 29 | 30 | #################### 31 | # internal targets # 32 | #################### 33 | 34 | llama2-supervised-finetuning-output/final_checkpoint_merged: venv/requirements_installed 35 | . ./venv/bin/activate && python3 supervised_finetuning.py 36 | . ./venv/bin/activate && python3 merge_peft_adapter.py llama2-supervised-finetuning-output/final_checkpoint huggingface_cache/models--NousResearch--Llama-2-13b-hf/snapshots/81da3af9503579bf991e3995564baa683b27d38c llama2-supervised-finetuning-output/final_checkpoint_merged 37 | 38 | llama2-direct-preference-finetuning-output/final_checkpoint_merged: venv/requirements_installed llama2-supervised-finetuning-output/final_checkpoint_merged 39 | . ./venv/bin/activate && python3 direct_preference_finetuning.py 40 | . ./venv/bin/activate && python3 merge_peft_adapter.py llama2-direct-preference-finetuning-output/final_checkpoint llama2-supervised-finetuning-output/final_checkpoint_merged llama2-direct-preference-finetuning-output/final_checkpoint_merged 41 | 42 | venv/requirements_installed: requirements.txt 43 | python3 -m venv venv 44 | . ./venv/bin/activate && pip install setuptools==68.0.0 && pip install torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118 # install the CUDA 11.8 version of PyTorch rather than the default CUDA 11.7 version for a nice 50% GPU performance bump - currently the PyTorch install page (https://pytorch.org/get-started/locally/) shows "pip install torch" under the CUDA 11.7 section, whereas the CUDA 11.8 section shows the different command we're using here. Because it uses a different Python package index, this command can't go in the requirements.txt file either 45 | . ./venv/bin/activate && pip install -r requirements.txt && touch ./venv/requirements_installed 46 | 47 | llama.cpp/main: 48 | [ -d llama.cpp/.git ] || git clone https://github.com/ggerganov/llama.cpp.git 49 | cd llama.cpp && git reset --hard d01bccde9f759b24449fdaa16306b406a50eb367 && make 50 | 51 | exported-models/ggml-robot-agent-q5_K_M.bin: llama2-direct-preference-finetuning-output/final_checkpoint_merged llama.cpp/main 52 | mkdir -p exported-models 53 | . ./venv/bin/activate && python3 llama.cpp/convert.py --outfile exported-models/ggml-robot-agent-f16.bin llama2-direct-preference-finetuning-output/final_checkpoint_merged 54 | ./llama.cpp/quantize exported-models/ggml-robot-agent-f16.bin exported-models/ggml-robot-agent-q5_K_M.bin q5_K_M 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Robot Agent 2 | =========== 3 | 4 | A fine-tuned Llama2 13B model designed for ReAct-style and Tree-of-Thoughts-style prompting. The codebase has the following desirable features: 5 | 6 | * Entire training procedure runs out of the box on a single computer with 32GB of RAM and 24GB of VRAM (i.e. consumer-grade graphics cards such as the RTX 3090 and RTX 4090) in less than 30 hours of compute time. 
7 | * Carefully tuned to use no more than 27GiB of RAM and 23.6GiB of VRAM. 8 | * This is accomplished through quantization, FP16, TF32, and the usual gradient accumulation/checkpointing settings. 9 | * Training is fully interruptible/resumable. 10 | * Heavily commented, short, clean, and reproducible training code. 11 | * All library dependency versions are fully pinned; base models and datasets are pinned and downloaded as part of the setup process. 12 | * After initial setup, the training process does not require network access - the entire project folder is portable and can be moved into air-gapped/offline environments. 13 | * Uses SafeTensors everywhere for speed and security. 14 | 15 | Technical details: 16 | 17 | * Based on [Llama2 13B](https://huggingface.co/NousResearch/Llama-2-13b-hf). 18 | * QLoRA training with a rank-128 LoRA, similar to [Guanaco](https://github.com/artidoro/qlora/blob/cc488110b5ea23594a418daca7085000a9420625/qlora.py#L324). 19 | * 2048-token context window used in supervised finetuning, 1300-token context window used in direct preference finetuning. 20 | * Supervised finetuning using [Airoboros' self-instruct dataset](https://huggingface.co/datasets/jondurbin/airoboros-gpt4-1.4.1), generated by [Airoboros' self-instruct implementation](https://github.com/jondurbin/airoboros). 21 | * The dataset has been filtered for refusals, and so could be considered "uncensored". 22 | * The dataset generation code also uses a GPT4 jailbreak to reduce the number of refusals in the first place. 23 | * Direct preference finetuning using [Anthropic's hh-rlhf dataset](https://huggingface.co/datasets/Anthropic/hh-rlhf). 24 | * This replaces the reward modelling and reinforcement learning steps in a standard RLHF pipeline (see the sketch after this list). 25 | * Codebase takes ideas and inspiration from [StackLLaMa](https://github.com/lvwerra/trl/tree/5c7bfbc8d9aeabee893290cc02121d7260636978/examples/research_projects/stack_llama/scripts), [QLoRA](https://github.com/artidoro/qlora), [LLaMA-TRL](https://github.com/jasonvanf/llama-trl), and [Airoboros](https://github.com/jondurbin/airoboros). 26 | 
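For readers who haven't seen DPO before, here is a minimal sketch of the objective that `trl`'s `DPOTrainer` optimizes in `direct_preference_finetuning.py` (Rafailov et al., 2023). It is an illustration of the technique rather than code from this repository - the function and argument names are made up, and only `beta=0.1` is taken from the training script:

```python
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps, reference_rejected_logps, beta=0.1):
    # each argument is the summed log-probability of a "chosen" or "rejected" response,
    # under either the policy being trained (the LoRA model) or the frozen reference model
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps)
    # push the policy to rank each chosen response above its rejected counterpart,
    # while beta keeps it anchored to the reference (supervised-finetuned) model
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
```

This lets a single supervised-style training loop learn directly from preference pairs, which is why no separate reward model or PPO stage is needed.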
27 | Roadmap 28 | ------- 29 | 30 | * [x] Fully reproducible environment with all datasets, base models, and dependencies included. 31 | * [x] Supervised finetuning script using high-quality, publicly available instruct datasets. 32 | * [x] Human-preference finetuning script based on Anthropic's hh-rlhf "helpfulness" dataset. 33 | * [x] Accidentally delete the training results on my GPU server and start the training over again from scratch. 34 | * [ ] Fiddle with agentic dataset generation using the Charades dataset. 35 | * [ ] If that doesn't work, fiddle with video captioning using multimodal models like Otter to generate agentic captions from how-to videos on YouTube. 36 | 37 | Prompt Format 38 | ------------- 39 | 40 | ``` 41 | ### Human: 42 | INSTRUCTIONS_GO_HERE 43 | 44 | ### Assistant: 45 | ``` 46 | 47 | Note that there is a single newline at the end of the prompt. Example: 48 | 49 | ``` 50 | ### Human: 51 | What color is the sky? 52 | 53 | ### Assistant: 54 | The sky is blue. 55 | ``` 56 | 
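For reference, the exact string that gets tokenized during supervised finetuning is built by `prompt_formatter` in `supervised_finetuning.py`; the helper below just restates it for illustration (the name `format_example` is not part of the repo). Note the two newlines before the first `### Human:` header:

```python
def format_example(instruction: str, response: str) -> str:
    # mirrors prompt_formatter() in supervised_finetuning.py
    return f"\n\n### Human:\n{instruction}\n\n### Assistant:\n{response}"

print(format_example("What color is the sky?", "The sky is blue."))
```

At inference time, `chat.py` builds the same string but stops after the newline that follows `### Assistant:`, leaving the model to generate the response.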
57 | Training 58 | -------- 59 | 60 | First, download everything that requires an internet connection into the current project folder. It will increase to around 30GiB in size: 61 | 62 | ```sh 63 | make download-datasets-and-models 64 | ``` 65 | 66 | Next, transfer the current project folder to the training machine, where the rest of the training can be performed fully offline: 67 | 68 | ```sh 69 | make train 70 | ``` 71 | 72 | Inference 73 | --------- 74 | 75 | To use the model, a simple chat-like interface is included for demo purposes; it's not very fancy, but it's good enough for testing: 76 | 77 | ```sh 78 | make chat 79 | ``` 80 | 81 | ### Using Llama.cpp 82 | 83 | First, run the following command to create `./exported-models/ggml-robot-agent-q5_K_M.bin`, an 8.6GiB GGML file compatible with Llama.cpp: 84 | 85 | ```sh 86 | make generate-ggml 87 | ``` 88 | 89 | Now to load the model using Llama.cpp: 90 | 91 | ```sh 92 | make chat-llama-cpp 93 | ``` 94 | 95 | To use Llama.cpp manually, navigate to your llama.cpp folder and start using the model with the following command (replace `PATH_TO_PROJECT_FOLDER` with the path to the current project folder): 96 | 97 | ```sh 98 | ./main --model PATH_TO_PROJECT_FOLDER/exported-models/ggml-robot-agent-q5_K_M.bin --color --interactive --interactive-first --mirostat 2 --ctx-size 2048 --reverse-prompt $'\n\n### Human:\n' --prompt $'\n\n### Human:\n' --in-suffix $'\n### Assistant:\n' 99 | ``` 100 | -------------------------------------------------------------------------------- /chat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import random 5 | 6 | os.environ['HF_DATASETS_OFFLINE'] = '1' # ask datasets library not to make arbitrary web requests 7 | os.environ['TRANSFORMERS_OFFLINE'] = '1' # ask transformer library not to make arbitrary web requests 8 | os.environ['BITSANDBYTES_NOWELCOME'] = '1' # disable welcome message that bitsandbytes prints, it's unnecessary noise 9 | 10 | import transformers 11 | 12 | 13 | BASE_MODEL_PATH = "./llama2-direct-preference-finetuning-output/final_checkpoint_merged" # generated as the output of running `make train-direct-preference-finetuning` 14 | CONTEXT_WINDOW_SIZE = 2048 15 | 16 | 17 | if __name__ == "__main__": 18 | tokenizer = transformers.AutoTokenizer.from_pretrained(BASE_MODEL_PATH, use_safetensors=True) 19 | tokenizer.pad_token = tokenizer.eos_token # the padding token isn't set in the included tokenizer by default (see tokenizer.special_tokens_map for existing special tokens), set it manually 20 | base_model = transformers.AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH, load_in_8bit=True, low_cpu_mem_usage=True, use_safetensors=True) # load in 8-bit quantized mode 21 | stopping_criteria = transformers.StoppingCriteriaList([lambda input_ids, scores: tokenizer.decode(input_ids[0]).endswith("\n\n### Human:")]) 22 | 23 | prompt_so_far = "" 24 | print("enter your prompt below, ending with an EOF character (Enter, Ctrl+D, Enter on most terminals)") 25 | print("### Human:") 26 | while True: 27 | prompt = [] 28 | while True: 29 | try: 30 | prompt.append(input()) 31 | except EOFError: 32 | break 33 | if not prompt: 34 | break 35 | prompt_so_far = (prompt_so_far + "\n\n### Human:\n" + "\n".join(prompt) + "\n\n### Assistant:\n")[-CONTEXT_WINDOW_SIZE:] # note: this truncates by characters rather than tokens, so it's only a rough bound on the prompt length 36 | tokenized_prompt = tokenizer(prompt_so_far, return_tensors="pt").to("cuda") 37 | print("\n### Assistant:") 38 | generated_token_ids = base_model.generate( 39 | input_ids=tokenized_prompt.input_ids, 40 | attention_mask=tokenized_prompt.attention_mask, 41 | generation_config=transformers.GenerationConfig( 42 | max_new_tokens=CONTEXT_WINDOW_SIZE, # generate up to an entire context window's worth of new output 43 | pad_token_id=tokenizer.eos_token_id, 44 | eos_token_id=tokenizer.eos_token_id, 45 | penalty_alpha=0.6, top_k=4, # contrastive search sampling, gives more coherent but non-repetitive outputs - the alpha penalty determines how much to penalize the score of each next token being "similar" to the existing tokens in the context (so when alpha=0, there's no penalty and it just becomes greedy sampling), and the top-k parameter is the number of top candidates to consider after computing this score 46 | low_memory=True, # use a more memory-efficient but slower sampling method, sequentially instead of in parallel 47 | ), 48 | stopping_criteria=stopping_criteria, 49 | )[0] 50 | if generated_token_ids[-1] == tokenizer.eos_token_id: 51 | generated_token_ids = generated_token_ids[0:-1] 52 | output = tokenizer.decode(generated_token_ids[tokenized_prompt.input_ids.shape[1]:]) 53 | if output.endswith("\n\n### Human:"): 54 | output = output[:-len("\n\n### Human:")] 55 | prompt_so_far = (prompt_so_far + output)[-CONTEXT_WINDOW_SIZE:] 56 | print(f"{output}\n\n### Human:") 57 | -------------------------------------------------------------------------------- /direct_preference_finetuning.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import json 5 | import gzip 6 | from collections import defaultdict 7 | 8 | os.environ['HF_DATASETS_OFFLINE'] = '1' # ask datasets library not to make arbitrary web requests 9 | os.environ['TRANSFORMERS_OFFLINE'] = '1' # ask transformer library not to make arbitrary web requests 10 | os.environ['BITSANDBYTES_NOWELCOME'] = '1' # disable welcome message that bitsandbytes prints, it's unnecessary noise 11 | 12 | import numpy as np 13 | import torch 14 | import peft 15 | import transformers 16 | import trl 17 | 18 | 19 | DATA_JSON_FILES_PATH = "./huggingface_cache/datasets--Anthropic--hh-rlhf/snapshots/09be8c5bbc57cb3887f3a9732ad6aa7ec602a1fa/helpful-online/" # downloaded by running `make download-datasets-and-models` in this repo 20 | BASE_MODEL_PATH = "./llama2-supervised-finetuning-output/final_checkpoint_merged" # generated as the output of running `make train-supervised-finetuning` 21 | OUTPUT_DIRECTORY = "./llama2-direct-preference-finetuning-output" 22 | RANDOMNESS_SEED = 0 23 | BATCH_SIZE = 1 # number of samples seen per gradient update - to increase training speed, set this to the largest size that your hardware can support without running out of memory 24 | CONTEXT_WINDOW_SIZE = 1300 # maximum length of any input to the model, used to filter out too-long data points (this doesn't have to be the same value as in supervised_finetuning.py) - to improve performance on longer prompts, set this to the largest size that your hardware can support without running out of memory 25 | TRAINING_STEPS = 1000 # number of steps to train for (since we're using gradient_accumulation_steps=16, the model will see TRAINING_STEPS * 16 * BATCH_SIZE samples throughout the entire training run) 26 | 27 | 28 | def prompt_formatter(example): 29 | chosen_messages, rejected_messages = re.split(r"\n\n(Human|Assistant): ", example["chosen"])[1:], re.split(r"\n\n(Human|Assistant): ", example["rejected"])[1:] 30 | chosen_prompt_so_far = "" 31 | if len(chosen_messages) != len(rejected_messages) or len(chosen_messages) % 4 != 0: 32 | return [] 33 | result = [] 34 | for 
i in range(0, min(len(chosen_messages), len(rejected_messages)), 4): 35 | if not (chosen_messages[i] == rejected_messages[i] == "Human"): 36 | return [] 37 | if not (chosen_messages[i + 2] == rejected_messages[i + 2] == "Assistant"): 38 | return [] 39 | prompt, chosen, rejected = chosen_messages[i + 1], chosen_messages[i + 3], rejected_messages[i + 3] 40 | chosen_prompt_so_far += f'\n\n### Human:\n{prompt}\n\n### Assistant:\n' 41 | result.append({"prompt": chosen_prompt_so_far, "chosen": chosen, "rejected": rejected}) 42 | chosen_prompt_so_far += chosen_messages[i + 3] 43 | return result 44 | 45 | 46 | def create_datasets(tokenizer): 47 | with gzip.open(os.path.join(DATA_JSON_FILES_PATH, "train.jsonl.gz"), mode="rt") as f: 48 | train_dataset = [example for line in f for example in prompt_formatter(json.loads(line))] 49 | with gzip.open(os.path.join(DATA_JSON_FILES_PATH, "test.jsonl.gz"), mode="rt") as f: 50 | test_dataset = [example for line in f for example in prompt_formatter(json.loads(line))] 51 | return train_dataset, test_dataset 52 | 53 | 54 | def run_training(train_dataset, test_dataset, tokenizer, resume_from_checkpoint): 55 | base_model = transformers.AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH, load_in_8bit=True, low_cpu_mem_usage=True, use_safetensors=True) # load in 8-bit quantized mode 56 | peft.prepare_model_for_kbit_training(base_model) 57 | if resume_from_checkpoint is not None: 58 | checkpoint_name = os.path.join(OUTPUT_DIRECTORY, resume_from_checkpoint) 59 | assert os.path.exists(checkpoint_name), checkpoint_name 60 | peft_base_model = peft.PeftModel.from_pretrained(base_model, checkpoint_name, is_trainable=True) # TODO: is there a way to make this error out if it isn't in safetensors format? 61 | else: 62 | peft_base_model = peft.get_peft_model(base_model, peft.LoraConfig( 63 | # by default, this adds LoRA adapters around Llama2's "q_proj" and "v_proj", see https://github.com/huggingface/peft/blob/5a0e19dda1048ff8caaa12970ba7574f9cdfbf76/src/peft/utils/other.py#L280 for more details 64 | # some other models add more adapters, such as Guanaco, which does it on every linear layer except the head (https://github.com/artidoro/qlora/blob/845188de110d8eb7c95cc8907b54d8cb2e7c01bd/qlora.py#L221), but this doesn't seem to have too noticeable a benefit compared to models that just use these default settings, like Airoboros does (https://github.com/jondurbin/FastChat/blob/5bd738586ae6a495bd73152d74969465f30d43ac/fastchat/train/train_lora.py#L51) 65 | r=128, # number of LoRA attention dimension parameters - directly proportional to LoRA adapter VRAM usage, set this to the largest value your hardware can support without running out of memory 66 | lora_alpha=16, # alpha parameter for LoRA, essentially scales all of the LoRA weights, which determines "how much of an effect" this LoRA has on the final model 67 | lora_dropout=0.05, # dropout probability for LoRA layers 68 | task_type=peft.TaskType.CAUSAL_LM, 69 | )) 70 | peft_base_model.print_trainable_parameters() 71 | torch.cuda.empty_cache() # helps reduce VRAM usage - it's right after the PEFT version of the model was created, so some old stuff is still around unnecessarily 72 | 73 | trainer = trl.DPOTrainer( 74 | peft_base_model, 75 | base_model, # use the original base model for comparison purposes (this model will be used in evaluation/inference mode, as a reference) 76 | args=transformers.TrainingArguments( 77 | output_dir=OUTPUT_DIRECTORY, 78 | dataloader_drop_last=True, # when the dataset size isn't evenly divisible 
by the batch size, the remainder forms an incomplete batch - throw away this batch to avoid having to ever see an incomplete batch 79 | max_steps=TRAINING_STEPS, # perform a fixed number of training steps before stopping 80 | evaluation_strategy="steps", eval_steps=200, # run an evaluation every 200 training steps (~1 hour) 81 | save_strategy="steps", save_steps=200, save_safetensors=True, # save a checkpoint every 200 training steps (~1 hour) 82 | logging_strategy="steps", logging_steps=1, # log output every training step 83 | per_device_train_batch_size=BATCH_SIZE, 84 | per_device_eval_batch_size=BATCH_SIZE, 85 | learning_rate=1e-5, warmup_steps=150, # linearly ramp up the learning rate for the AdamW optimizer from 0 to 1e-5 over the first 150 steps, then keep it at 1e-5 afterwards 86 | gradient_accumulation_steps=16, 87 | gradient_checkpointing=True, # use gradient checkpointing to decrease VRAM usage 88 | bf16=True, # use 16-bit bfloats for training instead of 32-bit floats in most operations (some are still kept in 32-bit for precision) to decrease VRAM usage and increase training performance, in practice the precision loss has a relatively small effect on the final result 89 | tf32=True, # in newer NVIDIA hardware, this replaces the remaining 32-bit operations with 19-bit TensorFloat operations to increase training performance, in practice the precision loss has no noticeable effect on the final result 90 | remove_unused_columns=False, # the DPO default data collator requires this setting, since it customizes the model training/prediction steps 91 | run_name="llama2-direct-preference-finetuning", 92 | # TODO: when Apex becomes more stable + easier to install, look into using adamw_apex_fused rather than adamw_hf for the optim= parameter (note that some code uses optimizers= on the DPOTrainer itself, which overrides the optim= parameter on TrainingArguments, see https://github.com/huggingface/transformers/blob/53e1f5cf66d320b9c809f3940c707b6fef435d2d/src/transformers/trainer.py#L1084) 93 | ), 94 | beta=0.1, # parameter controlling the deviation from the reference model - higher values prevent the model from deviating too far from the reference model; 0.1 is a relatively low value, so the model will change quite a bit 95 | train_dataset=train_dataset, 96 | eval_dataset=test_dataset, 97 | tokenizer=tokenizer, 98 | max_length=CONTEXT_WINDOW_SIZE, # any inputs that are longer than this value get truncated - first by cutting off the start of the prompt, then by cutting off the end of the response 99 | max_prompt_length=CONTEXT_WINDOW_SIZE // 2, # when truncating by cutting off the start of the prompt, cut it down to half the context window size 100 | ) 101 | 102 | os.makedirs(OUTPUT_DIRECTORY, exist_ok=True) 103 | trainer.train() 104 | peft_base_model.save_pretrained(os.path.join(OUTPUT_DIRECTORY, "final_checkpoint"), safe_serialization=True) # save LoRA by itself 105 | 106 | 107 | if __name__ == "__main__": 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument("--resume_from_checkpoint", type=str, default=None, help=f"If specified, start the training from the specified checkpoint (e.g., checkpoint-500). 
You can get a list of all checkpoints by running: ls {OUTPUT_DIRECTORY}") 110 | args = parser.parse_args() 111 | 112 | transformers.set_seed(RANDOMNESS_SEED) 113 | 114 | tokenizer = transformers.AutoTokenizer.from_pretrained(BASE_MODEL_PATH, use_safetensors=True) 115 | tokenizer.pad_token = tokenizer.eos_token # the padding token isn't set in the included tokenizer by default (see tokenizer.special_tokens_map for existing special tokens), set it manually 116 | train_dataset, test_dataset = create_datasets(tokenizer) 117 | run_training(train_dataset, test_dataset, tokenizer, args.resume_from_checkpoint) 118 | --------------------------------------------------------------------------------
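The hh-rlhf records parsed by `prompt_formatter` above are raw transcripts whose `chosen` and `rejected` fields look like `\n\nHuman: ...\n\nAssistant: ...`. As a quick illustration of the transformation (the conversation text below is made up rather than taken from the dataset), a single-turn record becomes one prompt/chosen/rejected triple in the repo's own prompt format:

```python
example = {
    "chosen": "\n\nHuman: What color is the sky?\n\nAssistant: The sky is blue.",
    "rejected": "\n\nHuman: What color is the sky?\n\nAssistant: Green, probably.",
}
print(prompt_formatter(example))
# [{'prompt': '\n\n### Human:\nWhat color is the sky?\n\n### Assistant:\n',
#   'chosen': 'The sky is blue.',
#   'rejected': 'Green, probably.'}]
```

Multi-turn records yield one such triple per Human/Assistant exchange, with the prompt accumulating the earlier turns.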
/merge_peft_adapter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | os.environ['HF_DATASETS_OFFLINE'] = '1' # ask datasets library not to make arbitrary web requests 5 | os.environ['TRANSFORMERS_OFFLINE'] = '1' # ask transformer library not to make arbitrary web requests 6 | os.environ['BITSANDBYTES_NOWELCOME'] = '1' # disable welcome message that bitsandbytes prints, it's unnecessary noise 7 | 8 | import torch 9 | import peft 10 | import transformers 11 | 12 | def merge_lora_back_into_base_model(lora_path, base_model_path, output_path): 13 | peft_config = peft.PeftConfig.from_pretrained(lora_path) 14 | assert peft_config.task_type == peft.TaskType.CAUSAL_LM, peft_config.task_type 15 | base_model = transformers.AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, return_dict=True, use_safetensors=True) # the dtype should already be torch.float16 for this particular model, but we have to specify it explicitly because from_pretrained otherwise defaults to loading the weights as float32 instead of using the torch_dtype recorded in the model's config.json, see https://github.com/huggingface/transformers/blob/07360b6c9c9448d619a82798419ed291dfc6ac8f/src/transformers/models/llama/convert_llama_weights_to_hf.py#L259 for details 16 | peft_base_model = peft.PeftModel.from_pretrained(base_model, lora_path) 17 | peft_base_model.eval() # switch the model over into inference mode, disabling training-specific functionality such as dropout layers 18 | merged_model = peft_base_model.merge_and_unload() 19 | del merged_model.config._name_or_path # remove path metadata from the model config 20 | merged_model.save_pretrained(output_path, safe_serialization=True) 21 | 22 | tokenizer = transformers.AutoTokenizer.from_pretrained(base_model_path) 23 | tokenizer.save_pretrained(output_path) 24 | 25 | if __name__ == "__main__": 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("lora_path", type=str, help=f"Path to the directory containing adapter_model.safetensors") 28 | parser.add_argument("base_model_path", type=str, help=f"Path to the directory containing config.json") 29 | parser.add_argument("output_path", type=str, help=f"Path that will become the output directory") 30 | args = parser.parse_args() 31 | 32 | merge_lora_back_into_base_model(args.lora_path, args.base_model_path, args.output_path) 33 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # NOTE: the Makefile also has a special pip install command for PyTorch under the `venv/requirements_installed` target, which can't be put in this file because it uses a different Python package index 2 | 3 | git+https://github.com/huggingface/transformers.git@b257c46a075419c09e5ce5c5aa39bc346ecdb9a5 4 | git+https://github.com/huggingface/peft.git@e06d94ddeb6c70913593740618df76908b918d66 5 | git+https://github.com/lvwerra/trl.git@170d58ffcede84b3bc822294317fc2bb6df85865 6 | 7 | # extra libraries whose presence changes the functionality of the above libraries 8 | safetensors==0.3.1 9 | bitsandbytes==0.41.0 10 | 11 | # transitive dependencies that the above libraries don't specify properly but are necessary anyway 12 | einops==0.6.1 13 | scipy==1.11.1 14 | sentencepiece==0.1.99 15 | protobuf==4.23.4 16 | -------------------------------------------------------------------------------- /supervised_finetuning.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import random 5 | 6 | os.environ['HF_DATASETS_OFFLINE'] = '1' # ask datasets library not to make arbitrary web requests 7 | os.environ['TRANSFORMERS_OFFLINE'] = '1' # ask transformer library not to make arbitrary web requests 8 | os.environ['BITSANDBYTES_NOWELCOME'] = '1' # disable welcome message that bitsandbytes prints, it's unnecessary noise 9 | 10 | import torch 11 | import peft 12 | import transformers 13 | import trl.trainer.utils 14 | 15 | 16 | DATA_JSON_PATH = "./huggingface_cache/datasets--jondurbin--airoboros-gpt4-1.4.1/snapshots/433c04038d724bf29a193bc3c1a48b600cc417a1/instructions.jsonl" # downloaded by running `make download-datasets-and-models` in this repo 17 | BASE_MODEL_PATH = "./huggingface_cache/models--NousResearch--Llama-2-13b-hf/snapshots/81da3af9503579bf991e3995564baa683b27d38c/" # downloaded by running `make download-datasets-and-models` in this repo 18 | OUTPUT_DIRECTORY = "./llama2-supervised-finetuning-output" 19 | RANDOMNESS_SEED = 0 20 | BATCH_SIZE = 1 # number of samples seen per gradient update - to increase training speed, set this to the largest size that your hardware can support without running out of memory 21 | CONTEXT_WINDOW_SIZE = 2048 # size of individual dataset entries used when training the model - to improve performance on longer prompts, set this to the largest size that your hardware can support without running out of memory 22 | TRAINING_STEPS = 4500 # number of steps to train for (since we're using gradient_accumulation_steps=4, the model will see TRAINING_STEPS * 4 * BATCH_SIZE samples throughout the entire training run) 23 | 24 | 25 | def prompt_formatter(example): 26 | return f'\n\n### Human:\n{example["instruction"]}\n\n### Assistant:\n{example["response"]}' 27 | 28 | 29 | def create_datasets(tokenizer): 30 | # generate train/test split with 0.5% test 31 | with open(DATA_JSON_PATH) as f: 32 | dataset = [prompt_formatter(json.loads(line)) for line in f] 33 | random.shuffle(dataset) 34 | test_dataset_size = round(len(dataset) * 0.005) 35 | train_dataset, test_dataset = dataset[test_dataset_size:], dataset[:test_dataset_size] 36 | print(f"train dataset size: {len(train_dataset)}, test dataset size: {len(test_dataset)}") 37 | 38 | # estimate the average number of characters per token in the dataset using 400 samples 39 | total_characters, total_tokens = 0, 0 40 | for _, example in zip(range(400), train_dataset): 41 | total_characters += len(example) 42 | total_tokens += len(tokenizer(example).tokens()) 43 | estimated_chars_per_token = total_characters / total_tokens 44 | print(f"dataset character to token ratio: {estimated_chars_per_token}") 45 | 46 | # pack multiple short examples into a single CONTEXT_WINDOW_SIZE-token-long input sequence, rather than training on each short example individually - 
improves training efficiency (this technique is known as "example packing") 47 | train_dataset_packed = trl.trainer.utils.ConstantLengthDataset( 48 | tokenizer, 49 | train_dataset, 50 | formatting_func=lambda x: x, 51 | seq_length=CONTEXT_WINDOW_SIZE, 52 | chars_per_token=estimated_chars_per_token, 53 | ) 54 | test_dataset_packed = trl.trainer.utils.ConstantLengthDataset( 55 | tokenizer, 56 | test_dataset, 57 | formatting_func=lambda x: x, 58 | seq_length=CONTEXT_WINDOW_SIZE, 59 | chars_per_token=estimated_chars_per_token, 60 | ) 61 | print(f"packed train dataset size: {sum(1 for _ in train_dataset_packed)}, packed test dataset size: {sum(1 for _ in test_dataset_packed)}") 62 | train_dataset_packed.infinite = True # generate unlimited sequences by repeatedly going through the dataset 63 | return train_dataset_packed, test_dataset_packed 64 | 65 | 66 | def run_training(train_dataset: torch.utils.data.IterableDataset, test_dataset: torch.utils.data.IterableDataset, resume_from_checkpoint: str): 67 | base_model = transformers.AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH, load_in_8bit=True, low_cpu_mem_usage=True, use_safetensors=True) # load in 8-bit quantized mode 68 | peft.prepare_model_for_kbit_training(base_model) 69 | if resume_from_checkpoint is not None: 70 | checkpoint_name = os.path.join(OUTPUT_DIRECTORY, resume_from_checkpoint) 71 | assert os.path.exists(checkpoint_name), checkpoint_name 72 | peft_base_model = peft.PeftModel.from_pretrained(base_model, checkpoint_name, is_trainable=True) # TODO: is there a way to make this error out if it isn't in safetensors format? 73 | else: 74 | peft_base_model = peft.get_peft_model(base_model, peft.LoraConfig( 75 | # by default, this adds LoRA adapters around Llama2's "q_proj" and "v_proj", see https://github.com/huggingface/peft/blob/5a0e19dda1048ff8caaa12970ba7574f9cdfbf76/src/peft/utils/other.py#L280 for more details 76 | # some other models add more adapters, such as Guanaco, which does it on every linear layer except the head (https://github.com/artidoro/qlora/blob/845188de110d8eb7c95cc8907b54d8cb2e7c01bd/qlora.py#L221), but this doesn't seem to have too noticeable a benefit compared to models that just use these default settings, like Airoboros does (https://github.com/jondurbin/FastChat/blob/5bd738586ae6a495bd73152d74969465f30d43ac/fastchat/train/train_lora.py#L51) 77 | r=128, # number of LoRA attention dimension parameters - directly proportional to LoRA adapter VRAM usage, set this to the largest value your hardware can support without running out of memory 78 | lora_alpha=16, # alpha parameter for LoRA, essentially scales all of the LoRA weights, which determines "how much of an effect" this LoRA has on the final model 79 | lora_dropout=0.05, # dropout probability for LoRA layers 80 | task_type=peft.TaskType.CAUSAL_LM, 81 | )) 82 | peft_base_model.print_trainable_parameters() 83 | torch.cuda.empty_cache() # helps reduce VRAM usage - it's right after the PEFT version of the model was created, so some old stuff is still around unnecessarily 84 | 85 | trainer = transformers.Trainer( 86 | model=peft_base_model, 87 | args=transformers.TrainingArguments( 88 | output_dir=OUTPUT_DIRECTORY, 89 | dataloader_drop_last=True, # when the dataset size isn't evenly divisible by the batch size, the remainder forms an incomplete batch - throw away this batch to avoid having to ever see an incomplete batch 90 | max_steps=TRAINING_STEPS, # perform a fixed number of training steps before stopping 91 | evaluation_strategy="steps", eval_steps=300, # 
run an evaluation every 300 training steps (~1 hour) 92 | save_strategy="steps", save_steps=300, save_safetensors=True, # save a checkpoint every 300 training steps (~1 hour) 93 | logging_strategy="steps", logging_steps=1, # log output every training step 94 | per_device_train_batch_size=BATCH_SIZE, # batch size used in training 95 | per_device_eval_batch_size=BATCH_SIZE, # batch size used in evaluation 96 | learning_rate=1e-5, warmup_steps=100, # linearly ramp up the learning rate for the AdamW optimizer from 0 to 1e-5 over the first 100 steps, then keep it at 1e-5 afterwards 97 | gradient_accumulation_steps=4, # use gradient accumulation to multiply effective batch size by 4 (without increasing VRAM usage by 4) 98 | gradient_checkpointing=True, # use gradient checkpointing to decrease VRAM usage 99 | bf16=True, # use 16-bit bfloats for training instead of 32-bit floats in most operations (some are still kept in 32-bit for precision) to decrease VRAM usage and increase training performance, in practice the precision loss has a relatively small effect on the final result 100 | tf32=True, # in newer NVIDIA hardware, this replaces the remaining 32-bit operations with 19-bit TensorFloat operations to increase training performance, in practice the precision loss has no noticeable effect on the final result 101 | weight_decay=0.05, # set the weight decay regularization factor of the optimizer 102 | run_name="llama2-supervised-finetuning", 103 | # TODO: when Apex becomes more stable + easier to install, look into using adamw_apex_fused rather than adamw_hf for the optim= parameter 104 | ), 105 | train_dataset=train_dataset, 106 | eval_dataset=test_dataset, 107 | tokenizer=tokenizer, 108 | callbacks=[trl.trainer.utils.PeftSavingCallback], 109 | ) 110 | 111 | os.makedirs(OUTPUT_DIRECTORY, exist_ok=True) 112 | trainer.train() 113 | peft_base_model.save_pretrained(os.path.join(OUTPUT_DIRECTORY, "final_checkpoint"), safe_serialization=True) # save LoRA by itself 114 | 115 | 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser() 118 | parser.add_argument("--resume_from_checkpoint", type=str, default=None, help=f"If specified, start the training from the specified checkpoint (e.g., checkpoint-500). You can get a list of all checkpoints by running: ls {OUTPUT_DIRECTORY}") 119 | args = parser.parse_args() 120 | 121 | transformers.set_seed(RANDOMNESS_SEED) 122 | 123 | tokenizer = transformers.AutoTokenizer.from_pretrained(BASE_MODEL_PATH, use_safetensors=True) 124 | tokenizer.pad_token = tokenizer.eos_token # the padding token isn't set in the included tokenizer by default (see tokenizer.special_tokens_map for existing special tokens), set it manually 125 | train_dataset, test_dataset = create_datasets(tokenizer) 126 | run_training(train_dataset, test_dataset, args.resume_from_checkpoint) 127 | --------------------------------------------------------------------------------