├── README.md
├── finetune
│   ├── configs
│   │   └── ds_config_zero3.json
│   ├── finetune_lora_plus.py
│   └── run_lora.sh
├── requirements.txt
└── tricks
    ├── __init__.py
    └── lora_plus.py

/README.md:
--------------------------------------------------------------------------------
# 🚀 Simple lora plus 🚀
A simple implementation of LoRA+: Efficient Low Rank Adaptation of Large Models
## Introduction
This project is a simple implementation of the LoRA+ method. As a worked example, the implemented LoRA+ method is applied to fine-tuning deepseek-coder. 🎉

deepseek-coder is only used as an example here; with the LoRA+ utilities built in this project you can fine-tune other models as well.

## Directory structure

**finetune**: fine-tuning scripts, adapted from the official deepseek-coder fine-tuning code.

**tricks**: `lora_plus.py` in this directory is the LoRA+ implementation.

## Usage & details

---
### Requirements
Since the experiments use deepseek-coder as the example, the environment requirements are the same as for deepseek-coder.

### Using LoRA+
To train with LoRA+, simply replace the `Trainer` in your existing training script (it must be a LoRA training script) with `LoraPlusTrainer`! 👋
```python
# original
# trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
# with the LoRA+ Trainer
trainer = LoraPlusTrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
```

`LoraPlusTrainer` is built on top of the Hugging Face `Trainer`, overriding its `create_optimizer` method.
The full code is as follows:

```python
class LoraPlusTrainer(Trainer):
    def create_optimizer(self):
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if self.optimizer is None:
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args
            )

            lora_lr_ratio = LORA_LR_RATIO
            lora_lr_embedding = LORA_LR_EMBEDDING

            self.optimizer = create_lorap_optimizer(opt_model, lora_lr_ratio, optimizer_cls, optimizer_kwargs,
                                                    lora_lr_embedding)
        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer
```

### deepseek-coder example
Just run the following command.

Note that ds_config_zero3.json differs from the original config: the learning-rate-related parameters were removed, because LoRA+ essentially works by adjusting the learning rates of the LoRA A and B matrices.

> bash run_lora.sh


## Contributing
🤝 If you have suggestions for improvement, find a bug, or want to add a new feature, feel free to open an issue or a pull request.

--------------------------------------------------------------------------------
/finetune/configs/ds_config_zero3.json:
--------------------------------------------------------------------------------
{
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 20,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
--------------------------------------------------------------------------------
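Note on the config above: DeepSpeed fills the `"auto"` fields from the Hugging Face `TrainingArguments` (e.g. `--learning_rate` in `run_lora.sh`), and LoRA+ then departs from a single learning rate by training the LoRA `B` matrices faster than `A`. As a minimal illustration of that idea (this snippet is not code from this repo; the shapes and values are only examples):

```python
import torch

# Illustrative only: the core of LoRA+ is one optimizer with separate parameter
# groups, where the LoRA B matrix uses a much larger learning rate than A.
lora_A = torch.nn.Parameter(torch.randn(8, 4096) * 0.01)
lora_B = torch.nn.Parameter(torch.zeros(4096, 8))

base_lr = 2e-4   # what --learning_rate provides
ratio = 16       # LORA_LR_RATIO in tricks/lora_plus.py

optimizer = torch.optim.AdamW([
    {"params": [lora_A], "lr": base_lr},
    {"params": [lora_B], "lr": base_lr * ratio},
])
```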
/finetune/finetune_lora_plus.py:
--------------------------------------------------------------------------------
import deepspeed

deepspeed.ops.op_builder.CPUAdamBuilder().load()
import copy
import random
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import torch
import torch.distributed
import transformers
from transformers import Trainer
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from tricks.lora_plus import LoraPlusTrainer

IGNORE_INDEX = -100
EOT_TOKEN = "<|EOT|>"


def build_instruction_prompt(instruction: str):
    return '''
You are an AI programming assistant, utilizing the DeepSeek Coder model, developed by DeepSeek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.
### Instruction:
{}
### Response:
'''.format(instruction.strip()).lstrip()


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="deepseek-ai/deepseek-coder-6.7b-instruct")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    # NEFTune noisy-embedding setting (disabled by default)
    # neftune_noise_alpha: int = field(default=5)
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]

    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]

    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]

    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = [torch.tensor(x) for x in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = [torch.tensor(x) for x in labels]
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def train_tokenize_function(examples, tokenizer):
    sources = [
        build_instruction_prompt(instruction)
        for instruction in examples['instruction']
    ]
    targets = [f"{output}\n{EOT_TOKEN}" for output in examples['output']]
    data_dict = preprocess(sources, targets, tokenizer)
    return data_dict


def train():
    # LoRA configuration
    config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        inference_mode=False,  # training mode
        r=8,  # LoRA rank
        lora_alpha=32,  # LoRA alpha; see the LoRA paper for its role
        lora_dropout=0.1  # dropout probability
    )

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if training_args.local_rank == 0:
        print('=' * 100)
        print(training_args)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
        trust_remote_code=True
    )

    print("PAD Token:", tokenizer.pad_token, tokenizer.pad_token_id)
    print("BOS Token", tokenizer.bos_token, tokenizer.bos_token_id)
    print("EOS Token", tokenizer.eos_token, tokenizer.eos_token_id)

    if training_args.local_rank == 0:
        print("Load tokenizer from {} over.".format(model_args.model_name_or_path))

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch.bfloat16
    )
    model.enable_input_require_grads()
    # wrap the base model with LoRA adapters
    model = get_peft_model(model, config)

    if training_args.local_rank == 0:
        print("Load model from {} over.".format(model_args.model_name_or_path))

    raw_train_datasets = load_dataset(
        'json',
        data_files=data_args.data_path,
        split="train",
        cache_dir=training_args.cache_dir
    )
    if training_args.local_rank > 0:
        torch.distributed.barrier()

    train_dataset = raw_train_datasets.map(
        train_tokenize_function,
        batched=True,
        batch_size=3000,
        num_proc=32,
        remove_columns=raw_train_datasets.column_names,
        load_from_cache_file=True,  # not args.overwrite_cache
        desc="Running Encoding",
        fn_kwargs={"tokenizer": tokenizer}
    )

    if training_args.local_rank == 0:
        torch.distributed.barrier()

    if training_args.local_rank == 0:
        print("Training dataset samples:", len(train_dataset))
        for index in random.sample(range(len(train_dataset)), 3):
            print(
                f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['labels']}.")
            print(f"Sample {index} of the training set: {tokenizer.decode(list(train_dataset[index]['input_ids']))}.")

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    data_module = dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)

    # trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    # use the LoRA+ trainer instead of the stock Trainer
    trainer = LoraPlusTrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
--------------------------------------------------------------------------------
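For reference, `train_tokenize_function` above reads `instruction` and `output` columns from a JSON dataset loaded with `load_dataset('json', ...)`. A minimal sketch of producing such a file (the example record and the file name `train_data.json` are hypothetical, not shipped with this repo):

```python
import json

# Hypothetical example records; the only requirement imposed by
# finetune_lora_plus.py is that each record has "instruction" and "output".
records = [
    {
        "instruction": "Write a Python function that reverses a string.",
        "output": "def reverse_string(s):\n    return s[::-1]",
    },
]

# Write JSON Lines; the Hugging Face 'json' loader accepts this format.
with open("train_data.json", "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
```

Point `DATA_PATH` in `run_lora.sh` at a file prepared this way.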
/finetune/run_lora.sh:
--------------------------------------------------------------------------------
DATA_PATH=""
OUTPUT_PATH=""
MODEL="deepseek-ai/deepseek-coder-6.7b-instruct"
MODEL_PATH="../deepseek-ai/deepseek-coder-6.7b-instruct"

# launch the LoRA+ fine-tuning script (finetune_lora_plus.py in this repo)
cd finetune && nohup deepspeed --include localhost:1 finetune_lora_plus.py \
    --model_name_or_path $MODEL_PATH \
    --data_path $DATA_PATH \
    --output_dir $OUTPUT_PATH \
    --num_train_epochs 1 \
    --model_max_length 1024 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1562 \
    --save_total_limit 100 \
    --learning_rate 2e-4 \
    --warmup_steps 10 \
    --logging_steps 1 \
    --lr_scheduler_type "cosine" \
    --gradient_checkpointing True \
    --report_to "tensorboard" \
    --deepspeed configs/ds_config_zero3.json \
    --bf16 False
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
transformers==4.37.2
torch
deepspeed
# imported by finetune_lora_plus.py and tricks/lora_plus.py
peft
datasets
--------------------------------------------------------------------------------
/tricks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mst272/simple-lora-plus/df4044bc2f88f55edb268c6430d9e933ac4ad231/tricks/__init__.py
--------------------------------------------------------------------------------
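Putting the numbers together before the implementation below: with `--learning_rate 2e-4` from `run_lora.sh` and the defaults `LORA_LR_RATIO = 16`, `LORA_LR_EMBEDDING = 1e-6` in `tricks/lora_plus.py`, the per-group learning rates work out as follows (a quick check, assuming those defaults are left unchanged):

```python
base_lr = 2e-4                  # --learning_rate in run_lora.sh
lora_lr_ratio = 16              # LORA_LR_RATIO
lora_lr_embedding = 1e-6        # LORA_LR_EMBEDDING

lr_A = base_lr                  # 2e-4   -> lora_A matrices
lr_B = base_lr * lora_lr_ratio  # 3.2e-3 -> lora_B matrices
lr_emb = lora_lr_embedding      # 1e-6   -> LoRA embedding weights, if any
print(lr_A, lr_B, lr_emb)
```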
/tricks/lora_plus.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass, field
from typing import Optional

import torch.nn as nn
from functools import reduce
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from peft.tuners import lora
from transformers import Trainer, HfArgumentParser
from transformers.utils import is_sagemaker_mp_enabled, logging

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

logger = logging.get_logger(__name__)


# LayerNorm layers, biases, etc. should not receive weight decay
# ALL_LAYERNORM_LAYERS = [nn.LayerNorm, LlamaRMSNorm]

# Recommended settings from the LoRA+ paper
LORA_LR_RATIO = 16
LORA_LR_EMBEDDING = 1e-6
WEIGHT_DECAY = 0.0


def get_modules(name, model):
    """
    Resolve the module that owns a parameter from the parameter's dotted name.
    """
    if "lora" in name:
        parent_idx = 2
    else:
        parent_idx = 1

    module_name = name.split(sep=".")[:-parent_idx]
    module = reduce(getattr, module_name, model)
    return module


def get_parameter_names(model, forbidden_layer_types):
    """
    Returns the names of the model parameters that are not inside a forbidden layer.
    """
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
    result += list(model._parameters.keys())
    return result


def create_lorap_optimizer(model, lora_lr_ratio, optimizer_cls, optimizer_kwargs, lora_lr_embedding=None):
    if lora_lr_embedding is None:
        lora_lr_embedding = 1e-6
    decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    parameters = {
        "A": {},
        "B": {},
        "B_no_decay": {},
        "embedding": {}
    }

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        module = get_modules(name, model)
        if isinstance(module, lora.Embedding):
            parameters['embedding'][name] = param
        elif "lora_B" in name:
            if name in decay_parameters:
                parameters['B'][name] = param
            else:
                parameters['B_no_decay'][name] = param
        else:
            parameters['A'][name] = param

    apply_param_groups = ""
    for group in parameters:
        apply_param_groups += f"{group}\n {list(parameters[group].keys())}\n\n"
    logger.info(apply_param_groups)

    lr = optimizer_kwargs["lr"]
    weight_decay = WEIGHT_DECAY

    optimizer_grouped_parameters = [
        {
            "params": list(parameters["A"].values()),
            "weight_decay": weight_decay,
            "lr": lr,
        },
        {
            "params": list(parameters["embedding"].values()),
            "weight_decay": weight_decay,
            "lr": lora_lr_embedding,
        },
        {
            "params": list(parameters["B"].values()),
            "weight_decay": weight_decay,
            "lr": lr * lora_lr_ratio,
        },
        {
            "params": list(parameters["B_no_decay"].values()),
            "weight_decay": 0.0,
            "lr": lr * lora_lr_ratio,
        },
    ]

    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

    # Same special case as the transformers Trainer: keep embedding layers in fp32 for Adam8bit
    if optimizer_cls.__name__ == "Adam8bit":
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

        skipped = 0
        for module in model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum(
                    {p.data_ptr(): p.numel() for p in module.parameters()}.values()
                )
                logger.info(f"skipped {module}: {skipped / 2 ** 20}M params")
                manager.register_module_override(module, "weight", {"optim_bits": 32})
                logger.debug(f"bitsandbytes: will optimize {module} in fp32")
        logger.info(f"skipped: {skipped / 2 ** 20}M params")

    return optimizer


# Override Trainer.create_optimizer so the LoRA+ parameter groups are used
class LoraPlusTrainer(Trainer):
    def create_optimizer(self):
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if self.optimizer is None:
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args
            )

            lora_lr_ratio = LORA_LR_RATIO
            lora_lr_embedding = LORA_LR_EMBEDDING

            self.optimizer = create_lorap_optimizer(opt_model, lora_lr_ratio, optimizer_cls, optimizer_kwargs,
                                                    lora_lr_embedding)
        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer
--------------------------------------------------------------------------------
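Finally, a quick way to see the grouping in action. This is an illustrative sanity check, not part of the repo: it assumes it is run from the repository root (so `tricks` is importable) and a `peft` version that can wrap a plain `nn.Module`; `TinyModel` is a hypothetical stand-in so nothing needs to be downloaded.

```python
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

from tricks.lora_plus import create_lorap_optimizer


class TinyModel(nn.Module):
    """A stand-in model whose module names match the LoRA target_modules below."""

    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(64, 64)
        self.v_proj = nn.Linear(64, 64)

    def forward(self, x):
        return self.v_proj(self.q_proj(x))


peft_model = get_peft_model(
    TinyModel(),
    LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"]),
)

optimizer = create_lorap_optimizer(
    peft_model,
    lora_lr_ratio=16,
    optimizer_cls=torch.optim.AdamW,
    optimizer_kwargs={"lr": 2e-4},
)

# Expected: the lora_A group keeps lr=2e-4, the lora_B group gets 2e-4 * 16 = 3.2e-3,
# and the (empty here) embedding group uses the 1e-6 default.
for i, group in enumerate(optimizer.param_groups):
    print(f"group {i}: lr={group['lr']:g}, n_params={len(group['params'])}")
```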