└── finetuning_jamba.py

/finetuning_jamba.py:
--------------------------------------------------------------------------------
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig

# Check that the fast Mamba kernels import cleanly.
# If they don't, this will (very appropriately) break before the weights are loaded.
import mamba_ssm

# With 4-bit quantization, modeling_jamba.py has to be patched manually at l. 1070:
# if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
# so that it becomes:
# if not is_fast_path_available:

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # llm_int8_skip_modules also applies to 4-bit loading.
    # Maybe not necessary (per axolotl), but worth testing.
    llm_int8_skip_modules=["mamba"]
)

tokenizer = AutoTokenizer.from_pretrained("jamba")

dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    optim="adamw_8bit",
    max_grad_norm=0.3,
    weight_decay=0.001,
    warmup_ratio=0.03,
    gradient_checkpointing=True,
    logging_dir='./logs',
    logging_steps=1,
    max_steps=50,
    group_by_length=True,
    lr_scheduler_type="linear",
    learning_rate=2e-3
)

# LoRA on the embeddings and the Mamba projections.
lora_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    init_lora_weights=False,
    r=8,
    target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)

# Load Jamba in 4-bit, with FlashAttention-2 for the attention layers
# and the fast kernels for the Mamba layers.
model = AutoModelForCausalLM.from_pretrained(
    "jamba",
    trust_remote_code=True,
    device_map='auto',
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config,
    use_mamba_kernels=True
)

# Supervised fine-tuning on the raw "quote" field, truncated to 256 tokens.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    max_seq_length=256,
    dataset_text_field="quote",
)

trainer.train()
--------------------------------------------------------------------------------
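Not part of the repository: a minimal sketch of how the trained adapter could be saved and reloaded for a quick generation check. The adapter directory ./jamba-lora and the prompt string are illustrative assumptions; tokenizer, quantization_config, and trainer are reused from the script above.

# Hypothetical follow-up, not in the original script.
trainer.save_model("./jamba-lora")  # writes only the LoRA adapter weights

from peft import PeftModel

# Reload the 4-bit base model and attach the trained adapter.
base = AutoModelForCausalLM.from_pretrained(
    "jamba",
    trust_remote_code=True,
    device_map="auto",
    quantization_config=quantization_config,
)
model = PeftModel.from_pretrained(base, "./jamba-lora")

inputs = tokenizer("Oscar Wilde once said", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=40)[0]))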