├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── finetune_model.py
├── requirements.txt
└── test.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
**/.DS_Store

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 neuralwork

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Fine-tuning LLMs with PEFT
This project is a tutorial on parameter-efficient fine-tuning (PEFT) and quantization of the [Mistral 7B v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) model. We use LoRA for PEFT and 4-bit quantization to compress the model, then fine-tune it on a semi-manually crafted fashion style recommendation instruct [dataset](https://huggingface.co/datasets/neuralwork/fashion-style-instruct). For more information and a step-by-step guide, see our [blog post](https://blog.neuralwork.ai/an-llm-fine-tuning-cookbook-with-mistral-7b/).

## Usage
Start by cloning the repository, setting up a conda environment, and installing the dependencies. We tested our scripts with Python 3.9 and CUDA 11.7.
```
git clone https://github.com/neuralwork/finetune-mistral.git
cd finetune-mistral

conda create -n llm python=3.9
conda activate llm
pip install -r requirements.txt
```

You can finetune the model on our fashion-style-instruct [dataset](https://huggingface.co/datasets/neuralwork/fashion-style-instruct) or on your own dataset. Note that your dataset will need to have the same features as ours (see the sample record below), and that you will need to pass in your HF Hub token as an argument if using a private dataset. Fine-tuning takes about 2 hours on a single A40; you can either use the default accelerate settings or configure accelerate to use multiple GPUs. To fine-tune the model:
```
accelerate config default

python finetune_model.py --dataset=<dataset_id> --base_model="mistralai/Mistral-7B-v0.1" --model_name=<model_name> --auth_token=<hf_token> --push_to_hub
```
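
Each record in the dataset contains an `input` field (a body type and personal style description), a `context` field (the target event), and a `completion` field (the outfit recommendations); these are the feature names the training script reads. A minimal record with hypothetical values looks like this:
```
{
    "input": "I'm a tall man with a lean build and prefer minimalist, earth-toned clothes.",
    "context": "I'm going to a business lunch.",
    "completion": "1. A slim-fit navy blazer over a white crew-neck tee, paired with ...",
}
```

The wording of the values is free-form; only the three field names matter to `finetune_model.py`.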

Once model training is completed, only the fine-tuned (LoRA) parameters are saved; these are loaded to overwrite the corresponding parameters of the base model during testing.
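
Concretely, since only the adapter is saved, loading for inference goes through `peft`: the adapter config points at the base model, which is fetched first, and the LoRA weights are then attached on top of it. A minimal loading sketch (the Hub id below is a placeholder for your own fine-tuned adapter):
```
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# fetches the base model referenced in the adapter config, then attaches the LoRA weights
model = AutoPeftModelForCausalLM.from_pretrained(
    "<hub_username>/<model_name>",  # placeholder, e.g. neuralwork/mistral-7b-style-instruct
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
tokenizer = AutoTokenizer.from_pretrained("<hub_username>/<model_name>")
```

`test.py` and `app.py` load the model the same way.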

To test the fine-tuned model with a random sample selected from the dataset, run `python test.py`. To launch the full Gradio demo and play around with your own examples, run `python app.py`.


## License
This project is licensed under the [MIT license](https://github.com/neuralwork/finetune-mistral/blob/main/LICENSE).

From [neuralwork](https://neuralwork.ai/) with :heart:

--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
import random

import torch
import gradio as gr
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer


events = [
    "nature retreat",
    "work / office event",
    "wedding as a guest",
    "tropical vacation",
    "conference",
    "sports event",
    "winter vacation",
    "beach",
    "play / concert",
    "picnic",
    "night club",
    "national parks",
    "music festival",
    "job interview",
    "city tour",
    "halloween party",
    "graduation",
    "gala / exhibition opening",
    "fancy date",
    "cruise",
    "casual gathering",
    "concert",
    "cocktail party",
    "casual date",
    "business meeting",
    "camping / hiking",
    "birthday party",
    "bar",
    "business lunch",
    "bachelorette / bachelor party",
    "semi-casual event",
]


def format_instruction(input, context):
    return f"""You are a personal stylist recommending fashion advice and clothing combinations. Use the self body and style description below, combined with the event described in the context to generate 5 self-contained and complete outfit combinations.
### Input:
{input}

### Context:
I'm going to a {context}.

### Response:
"""


def main():
    # load base LLM model, LoRA params and tokenizer
    model = AutoPeftModelForCausalLM.from_pretrained(
        "neuralwork/mistral-7b-style-instruct",
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained("neuralwork/mistral-7b-style-instruct")

    def postprocess(outputs, prompt):
        # decode generated token ids and strip the prompt from the output
        outputs = outputs.detach().cpu().numpy()
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        output = output[len(prompt) :]
        return output

    def generate(
        prompt: str,
        event: str,
        top_p: float,
        temperature: float,
        max_new_tokens: int,
        min_new_tokens: int,
        repetition_penalty: float,
        seed: int,
    ):
        # a seed of -1 randomizes the run, as described in the UI label
        if seed == -1:
            seed = random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        prompt = format_instruction(str(prompt), str(event))
        input_ids = tokenizer(
            prompt, return_tensors="pt", truncation=True
        ).input_ids.cuda()

        with torch.inference_mode():
            outputs = model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                # -1 is the UI sentinel for "no minimum"
                min_new_tokens=min_new_tokens if min_new_tokens > 0 else None,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
            )

        output = postprocess(outputs, prompt)
        return output

    with gr.Blocks() as demo:
        gr.HTML(
            """
            <h1>Instruct Fine-tune Mistral-7B-v0</h1>
            <p>
            Mistral-7B-v0 fine-tuned on the <a href="https://huggingface.co/datasets/neuralwork/fashion-style-instruct">neuralwork/fashion-style-instruct</a> dataset.
            To use the model, simply describe your body type and personal style and select the type of event you're planning to attend.
            <br>
            See our <a href="https://blog.neuralwork.ai/an-llm-fine-tuning-cookbook-with-mistral-7b/">blog post</a> for a detailed tutorial to fine-tune Mistral on your own dataset.
            </p>
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    lines=4,
                    label="Style prompt, describe your body type and fashion style.",
                    interactive=True,
                    value="I'm an above average height athletic woman with slightly broad shoulders and a medium sized bust. I generally prefer a casual but sleek look with dark colors and jeans.",
                )
                event = gr.Dropdown(
                    choices=events, value="semi-casual event", label="Event type"
                )
                seed = gr.Number(
                    value=1371,
                    precision=0,
                    interactive=True,
                    label="Seed for reproducibility, set to -1 to randomize seed",
                )
                top_p = gr.Slider(
                    value=0.9,
                    label="Top p (nucleus sampling)",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                )
                max_new_tokens = gr.Slider(
                    minimum=1,
                    maximum=2048,
                    value=1500,
                    label="Maximum new tokens",
                )
                min_new_tokens = gr.Slider(
                    minimum=-1, maximum=2048, value=-1, label="Minimum new tokens"
                )
                temperature = gr.Slider(
                    minimum=0.01, maximum=5, value=0.9, step=0.01, label="Temperature"
                )
                repetition_penalty = gr.Slider(
                    label="Repetition penalty",
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    value=1.2,
                )
                generate_button = gr.Button("Get outfit suggestions")

            with gr.Column(scale=2):
                response = gr.Textbox(
                    lines=6, label="Outfit suggestions", interactive=False
                )

        gr.Markdown("From [neuralwork](https://neuralwork.ai/) with :heart:")

        generate_button.click(
            fn=generate,
            inputs=[
                prompt,
                event,
                top_p,
                temperature,
                max_new_tokens,
                min_new_tokens,
                repetition_penalty,
                seed,
            ],
            outputs=response,
        )

    demo.launch(share=True)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/finetune_model.py:
--------------------------------------------------------------------------------
import argparse

import torch
from trl import SFTTrainer
from datasets import load_dataset
from transformers import TrainingArguments
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")


def format_instruction(sample):
    return f"""You are a personal stylist recommending fashion advice and clothing combinations. Use the self body and style description below, combined with the event described in the context to generate 5 self-contained and complete outfit combinations.
### Input:
{sample["input"]}

### Context:
{sample["context"]}

### Response:
{sample["completion"]}
"""

def finetune_model(args):
    # load the train split so the trainer receives a Dataset rather than a DatasetDict
    dataset = load_dataset(args.dataset, split="train", token=args.auth_token)
    # base model to finetune
    model_id = args.base_model

    # BitsAndBytesConfig to quantize the model to int-4
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache=False, device_map="auto")
    # use the standard linear layer forward pass (no pretraining tensor parallelism)
    model.config.pretraining_tp = 1

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # LoRA config based on the QLoRA paper
    peft_config = LoraConfig(
        r=32,
        lora_alpha=64,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
            "lm_head",
        ],
        bias="none",
        lora_dropout=0.05,
        task_type="CAUSAL_LM",
    )

    # prepare model for training
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config)

    # print the number of trainable model params
    print_trainable_parameters(model)

    model_args = TrainingArguments(
        output_dir="mistral-7-style",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        gradient_checkpointing=True,
        optim="paged_adamw_32bit",
        logging_steps=10,
        save_strategy="epoch",
        learning_rate=2e-4,
        bf16=True,
        tf32=True,
        max_grad_norm=0.3,
        warmup_ratio=0.03,
        lr_scheduler_type="constant",
        disable_tqdm=False
    )

    max_seq_length = 2048

    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        formatting_func=format_instruction,
        args=model_args,
    )

    # train
    trainer.train()

    # save model; only the LoRA adapter weights are stored
    trainer.save_model()

    if args.push_to_hub:
        trainer.model.push_to_hub(args.model_name)

    torch.cuda.empty_cache()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset", type=str, default="neuralwork/fashion-style-instruct",
        help="Path to local or HF dataset."
    )
    parser.add_argument(
        "--base_model", type=str, default="mistralai/Mistral-7B-v0.1",
        help="HF hub id of the base model to finetune."
    )
    parser.add_argument(
        "--model_name", type=str, default="mistral-7b-style-instruct", help="Name of finetuned model."
    )
    parser.add_argument(
        "--auth_token", type=str, default=None,
        help="HF authentication token, only used if downloading a private dataset."
    )
    parser.add_argument(
        "--push_to_hub", default=False, action="store_true",
        help="Whether to push finetuned model to HF hub."
    )
    args = parser.parse_args()
    finetune_model(args)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==2.1.0
trl==0.4.7
peft==0.4.0
accelerate==0.21.0
datasets==2.13.0
transformers==4.35.0
huggingface-hub==0.19.4
sentencepiece==0.1.99
bitsandbytes==0.41.1
gradio

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse
from random import randrange

import torch
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer


def format_instruction(sample):
    return f"""You are a personal stylist recommending fashion advice and clothing combinations. Use the self body and style description below, combined with the event described in the context to generate 5 self-contained and complete outfit combinations.
### Input:
{sample["input"]}

### Context:
{sample["context"]}

### Response:
"""

def postprocess(outputs, tokenizer, prompt, sample):
    # decode generated token ids and strip the prompt from the output
    outputs = outputs.detach().cpu().numpy()
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    output = outputs[0][len(prompt):]

    print(f"Instruction: \n{sample['input']}\n")
    print(f"Context: \n{sample['context']}\n")
    print(f"Ground truth: \n{sample['completion']}\n")
    print(f"Generated output: \n{output}\n\n\n")


def run_model(config):
    # load the train split and select a random sample
    dataset = load_dataset(config.dataset, split="train")
    sample = dataset[randrange(len(dataset))]
    prompt = format_instruction(sample)

    # load base LLM model, LoRA params and tokenizer
    model = AutoPeftModelForCausalLM.from_pretrained(
        config.model_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(config.model_id)
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()

    # inference
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=800,
            do_sample=True,
            top_p=0.9,
            temperature=0.9
        )

    postprocess(outputs, tokenizer, prompt, sample)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset", type=str, default="neuralwork/fashion-style-instruct",
        help="HF dataset id or path to local dataset folder."
    )
    parser.add_argument(
        "--model_id", type=str, default="neuralwork/mistral-7b-style-instruct",
        help="HF LoRA model id or path to local finetuned model folder."
    )

    config = parser.parse_args()
    run_model(config)
--------------------------------------------------------------------------------