├── common.py
├── README.md
├── inference.py
└── train.py

/common.py:
--------------------------------------------------------------------------------
from pathlib import Path

from modal import Image, Stub, Volume

BASE_MODEL = "mistralai/Mistral-7B-v0.1"
MODEL_PATH = Path("/model")


# Bake the pretrained model weights and tokenizer into our container image, so we don't
# need to re-download them every time. Note that we run this function as a build step
# when we define our image below.
def download_model():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
    model.save_pretrained(MODEL_PATH)

    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL,
        model_max_length=512,
        padding_side="left",
        add_eos_token=True,
    )
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.save_pretrained(MODEL_PATH)


# Define our container image: install all the required dependencies and download our
# pretrained model weights.
image = (
    Image.micromamba()
    .micromamba_install(
        "cudatoolkit=11.8",
        "cudnn=8.1.0",
        "cuda-nvcc",
        channels=["conda-forge", "nvidia"],
    )
    .apt_install("git")
    .pip_install(
        # packages pinned to 11/1/2023
        "bitsandbytes==0.41.1",
        "peft==0.6.0",
        "transformers==4.35.0",
        "accelerate==0.24.1",
        "datasets==2.14.6",
        "scipy==1.11.3",
        "wandb==0.15.12",
        "py7zr",  # needed for the samsum dataset
    )
    .pip_install(
        "torch==2.0.1+cu118", index_url="https://download.pytorch.org/whl/cu118"
    )
    .run_function(download_model)
)

stub = Stub(name="example-mistral-7b-finetune", image=image)

# Set up persisted Volumes to store our training data and fine-tuning results across runs.
stub.training_data_volume = Volume.persisted("training-data-vol")
stub.results_volume = Volume.persisted("results-vol")

# Mount paths for the Volumes within the container.
VOLUME_CONFIG = {
    "/training_data": stub.training_data_volume,
    "/results": stub.results_volume,
}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fine-tuning Mistral 7B on a single GPU with QLoRA

This simple guide will help you fine-tune any language model to make it better at a specific task. With Modal, you can run this training and serve your model in the cloud in minutes, without having to deal with infrastructure headaches like building images and setting up GPUs.

For this guide, we train [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) on a single GPU using [QLoRA](https://github.com/artidoro/qlora), an efficient fine-tuning technique that combines quantization with LoRA to reduce memory usage while preserving task performance. We use 4-bit quantization and train our model on the [SAMSum dataset](https://huggingface.co/datasets/samsum), a dataset of messenger-like conversations paired with third-person summaries. Modal's easy [GPU-accelerated setup](https://modal.com/docs/guide/gpu) and [built-in storage system](https://modal.com/docs/guide/volumes) help us kick off training in no time.
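Each training example is built from a SAMSum dialogue and its summary using the instruction-style prompt template defined in `train.py`. A single training record looks roughly like this, where `<dialogue>` and `<summary>` stand in for the dataset's `dialogue` and `summary` fields:

```
[INST] <<SYS>>
Use the Input to provide a summary of a conversation.
<</SYS>>

Input:
<dialogue> [/INST]

Summary: <summary>
```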
It's easy to tweak this repository to fit your needs:
- To train another language model, set `BASE_MODEL` to the desired model name (in `common.py`).
- To use your own training data (saved in local .csv or .jsonl files), upload your training and validation datasets to your modal.Volume using the CLI, e.g. `modal volume put training-data-vol /local_path/to/data_train.jsonl /data_train.jsonl`. Note that `train.py` expects the files `data_train.jsonl` and `data_val.jsonl` in the Volume, which is mounted at `/training_data` in the container. Make sure to modify the prompt templates to match your dataset.
- To change the quantization parameters (or do without quantization altogether), modify the `BitsAndBytesConfig` in `train.py` (and make sure to apply the same modifications in `inference.py`).

## Before we start - set up a Modal account
1. Create an account on [modal.com](https://modal.com/).
2. Install `modal` in your current Python virtual environment (`pip install modal`).
3. Set up a Modal token in your environment (`python3 -m modal setup`).
4. If you want to monitor your training runs using Weights & Biases, you need a [secret](https://modal.com/secrets) named `my-wandb-secret` in your Modal workspace. Only the `WANDB_API_KEY` is needed; you can get it by logging into your Weights & Biases account and going to the [Authorize page](https://wandb.ai/authorize).

## Training
To launch a training job, use:
```
modal run train.py
```

Flags:
- `--detach`: don't terminate the app when your local process dies or disconnects (i.e. it makes sure you don't accidentally terminate your training job when you close your terminal).
```
modal run --detach train.py
```
- `--run_id`: use your own run ID to track your training runs (otherwise it defaults to `mistral7b-finetune-%Y-%m-%d-%H-%M`).
```
modal run train.py --run_id <your_run_id>
```
- `--resume-from-checkpoint`: resume training from a checkpoint saved to your results volume.
```
modal run train.py --resume-from-checkpoint /results/<path-to-checkpoint>/
```

You should make sure your adapter weights have been properly saved in your results volume by running:
```
modal volume ls results-vol
```

## Inference
To try out your freshly fine-tuned model, use:
```
modal run inference.py --run_id <your_run_id>
```

## Next steps
- Serve your model's inference function using a Modal [web endpoint](https://modal.com/docs/guide/webhooks); a minimal sketch follows below. Note that leaving an endpoint deployed on Modal doesn't cost you anything, since we scale them serverlessly. See our [QuiLLMan](https://github.com/modal-labs/quillman/) repository for an example of a FastAPI server with an inference endpoint (in `quillman/src/app.py`).
- Try training another model. If the model you'd like to fine-tune is too large for single-GPU training, take a look at our [Llama fine-tuning repository](https://github.com/modal-labs/llama-finetuning/), which uses FSDP to scale training across multiple GPUs.
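As a rough starting point for the first item above, here is a minimal, untested sketch of a web endpoint that could be appended to `inference.py`. The `summarize` function name, the request shape, and the JSON response are illustrative assumptions rather than part of this repo, and you may need to add `fastapi` to the image's `pip_install` list in `common.py`:

```
from modal import web_endpoint


@stub.function()
@web_endpoint(method="POST")
def summarize(item: dict):
    # Expects a JSON body like {"message": "...", "run_id": "..."}; run_id is optional.
    # Reuses the Model class defined above in inference.py.
    summary = Model(run_id=item.get("run_id")).generate.remote(item["message"])
    return {"summary": summary}
```

You could then deploy the app with `modal deploy inference.py` and POST conversations to the URL the deploy command prints.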
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
from typing import Optional

from modal import method

from common import stub, MODEL_PATH, VOLUME_CONFIG


@stub.cls(gpu="A100", volumes=VOLUME_CONFIG)
class Model:
    def __init__(self, run_id: Optional[str] = None):
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
        from peft import PeftModel

        # Quantization config; make sure it is the same as the one used during training.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )

        if run_id:
            self.model = PeftModel.from_pretrained(  # base model with trained adapter
                base_model,
                f"/results/{run_id}",
                torch_dtype=torch.bfloat16,
            )
        else:
            self.model = base_model

        self.eval_tokenizer = AutoTokenizer.from_pretrained(
            MODEL_PATH,
            add_bos_token=True,
            trust_remote_code=True,
            padding_side="right",
        )

        self.template = "[INST] <<SYS>>\nUse the Input to provide a summary of a conversation.\n<</SYS>>\n\nInput:\n{message} [/INST]\n\nSummary:"

    def tokenize_prompt(self, prompt: str = ""):
        return self.eval_tokenizer(prompt, return_tensors="pt").to("cuda")

    @method()
    async def generate(self, message: str):
        import torch

        model_input = self.tokenize_prompt(self.template.format(message=message))

        self.model.eval()
        with torch.no_grad():
            return self.eval_tokenizer.decode(
                self.model.generate(
                    **model_input,
                    max_new_tokens=100,
                    eos_token_id=self.eval_tokenizer.eos_token_id,
                )[0],
                skip_special_tokens=True,
            )


@stub.local_entrypoint()
def main(run_id: str = ""):
    if not run_id:
        print(
            "Warning: run_id not provided. Please pass the run_id from a previous training run to generate with the trained adapter."
        )
        print("Usage with trained adapter: modal run inference.py --run_id <run_id>")

    messages = [
        "Eric: MACHINE! Rob: That's so gr8! Eric: I know! And shows how Americans see Russian ;) Rob: And it's really funny! Eric: I know! I especially like the train part! Rob: Hahaha! No one talks to the machine like that! Eric: Is this his only stand-up? Rob: Idk. I'll check. Eric: Sure. Rob: Turns out no! There are some of his stand-ups on youtube. Eric: Gr8! I'll watch them now! Rob: Me too! Eric: MACHINE! Rob: MACHINE! Eric: TTYL? Rob: Sure :)",
        "Ollie: Hi , are you in Warsaw Jane: yes, just back! Btw are you free for diner the 19th? Ollie: nope! Jane: and the 18th? Ollie: nope, we have this party and you must be there, remember? Jane: oh right! i lost my calendar.. thanks for reminding me Ollie: we have lunch this week? Jane: with pleasure! Ollie: friday? Jane: ok Jane: what do you mean 'we don't have any more whisky!' lol.. Ollie: what!!! Jane: you just call me and the all thing i heard was that sentence about whisky... what's wrong with you? Ollie: oh oh... very strange! i have to be carefull may be there is some spy in my mobile! lol Jane: dont' worry, we'll check on friday. Ollie: don't forget to bring some sun with you Jane: I can't wait to be in Morocco.. Ollie: enjoy and see you friday Jane: sorry Ollie, i'm very busy, i won't have time for lunch tomorrow, but may be at 6pm after my courses?this trip to Morocco was so nice, but time consuming! Ollie: ok for tea! Jane: I'm on my way.. Ollie: tea is ready, did you bring the pastries? Jane: I already ate them all... see you in a minute Ollie: ok",
        "Rita: I'm so bloody tired. Falling asleep at work. :-( Tina: I know what you mean. Tina: I keep on nodding off at my keyboard hoping that the boss doesn't notice.. Rita: The time just keeps on dragging on and on and on.... Rita: I keep on looking at the clock and there's still 4 hours of this drudgery to go. Tina: Times like these I really hate my work. Rita: I'm really not cut out for this level of boredom. Tina: Neither am I.",
    ]

    print("=" * 20 + "Generating without adapter" + "=" * 20)
    for summary in Model().generate.map(messages):
        print(summary)

    if run_id:
        print("=" * 20 + "Generating with adapter" + "=" * 20)
        for summary in Model(run_id=run_id).generate.map(messages):
            print(summary)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from datetime import datetime

from modal import Secret
from transformers import TrainerCallback

from common import stub, BASE_MODEL, MODEL_PATH, VOLUME_CONFIG

WANDB_PROJECT = "hf-mistral7b-finetune"


# Callback to store model checkpoints in the modal.Volume as they are saved.
class CheckpointCallback(TrainerCallback):
    def __init__(self, volume):
        self.volume = volume

    def on_save(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            print("running commit on modal.Volume after model checkpoint")
            self.volume.commit()


# Download the training dataset from Hugging Face and push it to the modal.Volume.
# You can load your own dataset to push to a Volume here, or push local data files
# using `modal volume put VOLUME_NAME [LOCAL_PATH] [REMOTE_PATH]` in your CLI.
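# For reference, each line of the JSONL files written below holds a single JSON record.
# Illustrative shape (abbreviated; the exact text comes from PROMPT_TEMPLATE in
# format_instruction, with <dialogue> and <summary> standing in for the SAMSum fields):
#   {"text": "[INST] <<SYS>>\n...\n<</SYS>>\n\nInput:\n<dialogue> [/INST]\n\nSummary: <summary>"}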
@stub.function(volumes=VOLUME_CONFIG)
def download_dataset():
    import os

    from datasets import load_dataset

    if not os.path.exists('/training_data/data_train.jsonl'):

        def format_instruction(sample):
            PROMPT_TEMPLATE = "[INST] <<SYS>>\nUse the Input to provide a summary of a conversation.\n<</SYS>>\n\nInput:\n{message} [/INST]\n\nSummary: {summary}"
            return {"text": PROMPT_TEMPLATE.format(message=sample["dialogue"], summary=sample["summary"])}

        # Download the data from Hugging Face.
        train_dataset = load_dataset('samsum', split='train')
        val_dataset = load_dataset('samsum', split='validation')

        train_dataset = train_dataset.map(format_instruction, remove_columns=['id', 'dialogue', 'summary'])
        val_dataset = val_dataset.map(format_instruction, remove_columns=['id', 'dialogue', 'summary'])

        # Write the data to the Volume mounted at "/training_data" in the container.
        train_dataset.to_json("/training_data/data_train.jsonl")
        val_dataset.to_json("/training_data/data_val.jsonl")

    stub.training_data_volume.commit()


@stub.function(
    gpu="A100",
    secret=Secret.from_name("my-wandb-secret") if WANDB_PROJECT else None,
    timeout=60 * 60 * 4,
    volumes=VOLUME_CONFIG,
)
def finetune(
    model_name: str,
    run_id: str = "",
    wandb_project: str = "",
    resume_from_checkpoint: str = None,  # path to a checkpoint in the results Volume (e.g. "/results/<run_id>/checkpoint-300/")
):
    import os

    import torch
    import transformers
    from datasets import load_dataset
    from peft import (
        LoraConfig,
        get_peft_model,
        prepare_model_for_kbit_training,
        set_peft_model_state_dict,
    )
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Load and quantize the pretrained model baked into our image.
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, quantization_config=bnb_config)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize(sample, cutoff_len=512, add_eos_token=True):
        prompt = sample["text"]
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding="max_length",
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)
        result["labels"] = result["input_ids"].copy()

        return result

    # Load the datasets from the training data Volume.
    train_dataset = load_dataset('json', data_files='/training_data/data_train.jsonl', split="train")
    eval_dataset = load_dataset('json', data_files='/training_data/data_val.jsonl', split="train")

    tokenized_train_dataset = train_dataset.map(tokenize)
    tokenized_val_dataset = eval_dataset.map(tokenize)

    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)

    config = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
            "lm_head",
        ],
        bias="none",
        lora_dropout=0.05,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    if len(wandb_project) > 0:
        # Set environment variables if Weights & Biases is enabled.
        os.environ["WANDB_PROJECT"] = wandb_project
        os.environ["WANDB_WATCH"] = "gradients"
        os.environ["WANDB_LOG_MODEL"] = "checkpoint"

    if resume_from_checkpoint:
        # Check the available weights and load them.
        checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin")  # full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only the LoRA adapter - the LoraConfig above has to match
            resume_from_checkpoint = False  # so the trainer won't try loading its state
        # The two files above have different names depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()

    if torch.cuda.device_count() > 1:
        # Keep the Trainer from trying its own DataParallelism when more than one GPU is available.
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=model,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_val_dataset,
        callbacks=[CheckpointCallback(stub.results_volume)],  # commit each checkpoint to the results Volume
        args=transformers.TrainingArguments(
            output_dir=f"/results/{run_id}",  # Must point into the results Volume's mount location
            warmup_steps=5,
            per_device_train_batch_size=8,
            gradient_accumulation_steps=4,  # effective batch size of 8 * 4 = 32
            max_steps=1000,  # Feel free to tweak to correct for under- or overfitting
            learning_rate=2e-5,  # ~10x smaller than Mistral's learning rate
            bf16=True,
            optim="adamw_8bit",
            save_strategy="steps",  # Save a model checkpoint at regular step intervals
            save_steps=50,  # Save checkpoints every 50 steps
            evaluation_strategy="steps",  # Evaluate at regular step intervals
            eval_steps=50,  # Evaluate every 50 steps
            do_eval=True,  # Perform evaluation during training
            report_to="wandb" if wandb_project else "none",
            run_name=run_id if wandb_project else "",
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )

    model.config.use_cache = False  # Silence the warnings. Re-enable for inference!
    trainer.train()  # Run training

    model.save_pretrained(f"/results/{run_id}")
    stub.results_volume.commit()


@stub.local_entrypoint()
def main(run_id: str = "", resume_from_checkpoint: str = None):
    print("Downloading data from Hugging Face and syncing it to the Volume.")
    download_dataset.remote()
    print("Finished syncing data.")

    if not run_id:
        run_id = f"mistral7b-finetune-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"

    print(f"Starting training run {run_id=}.")
    finetune.remote(model_name=BASE_MODEL, run_id=run_id, wandb_project=WANDB_PROJECT, resume_from_checkpoint=resume_from_checkpoint)
    print(f"Completed training run {run_id=}.")
    print("To test your trained model, run `modal run inference.py --run_id <run_id>`.")
--------------------------------------------------------------------------------