"""Fine-tune a causal language model with LoRA adapters and save the result."""
from functools import lru_cache

from datasets import load_dataset
from peft import get_peft_model, LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


@lru_cache(maxsize=None)
def _get_tokenizer(model_name="gpt2"):
    """Load the tokenizer once per model name and ensure it has a pad token.

    GPT-2's tokenizer ships without a pad token, so padding="max_length"
    would raise; reuse the EOS token as padding in that case.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def tokenize_function(examples):
    """Tokenize a batch of examples; expects a "text" column.

    NOTE(review): the tokenizer is hard-coded to "gpt2" (as in the original
    code) even when train_with_lora is called with a different model_name —
    confirm this is intended before training other base models.
    """
    tokenizer = _get_tokenizer()  # cached: no per-batch hub reload
    return tokenizer(examples["text"], padding="max_length", truncation=True)


def train_with_lora(model_name, dataset_name, output_dir="./lora_model"):
    """Fine-tune *model_name* on *dataset_name* with LoRA adapters.

    Args:
        model_name: Hugging Face model id of the causal-LM base model.
        dataset_name: Hugging Face dataset id; its "train" split must
            contain a "text" column.
        output_dir: Directory for checkpoints and the saved adapter.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = _get_tokenizer(model_name)

    # Wrap attention projections with low-rank adapters; only these train.
    lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
    model = get_peft_model(model, lora_config)

    dataset = load_dataset(dataset_name, split="train").map(tokenize_function, batched=True)
    training_args = TrainingArguments(output_dir=output_dir, per_device_train_batch_size=4)

    # Causal-LM collator copies input_ids into labels; without it the
    # Trainer has nothing to compute a loss against.
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
    )
    trainer.train()

    model.save_pretrained(output_dir)
    # Persist the tokenizer alongside the adapter so the output dir is usable on its own.
    tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
    train_with_lora("gpt2", "yahma/alpaca-cleaned")