def get_loader(text_lens=100,
               *,
               num_samples=2000,
               batch_size=32,
               max_length=512,
               vocab_file='tokenizer/vocab.txt'):
    """Build the tokenizer, synthetic dataset and DataLoader for the demo.

    Every sample is a random digit label (0-9) whose text is that digit
    repeated ``text_lens`` times, separated by single spaces, so the
    classification task is trivially learnable.

    Args:
        text_lens: Number of times the label digit is repeated in the text.
        num_samples: Number of samples to generate.
        batch_size: Number of samples per batch.
        max_length: Tokenizer ``model_max_length`` and the truncation length
            applied when a batch is encoded.
        vocab_file: Path of the vocabulary file handed to ``BertTokenizer``.
            A relative path is resolved against the current working
            directory.

    Returns:
        A ``(tokenizer, dataset, loader)`` tuple.

    Raises:
        ValueError: If ``num_samples`` or ``batch_size`` is smaller than 1,
            or if ``batch_size`` exceeds ``num_samples``.
    """
    # Validate before the heavy imports so a bad call fails immediately.
    if num_samples < 1:
        raise ValueError(f'num_samples must be >= 1, got {num_samples}')
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    if batch_size > num_samples:
        # The loader uses drop_last=True, so a batch larger than the dataset
        # would leave it with no batches at all.
        raise ValueError(
            f'batch_size ({batch_size}) must not exceed num_samples '
            f'({num_samples}) because incomplete batches are dropped')

    # Imports stay function-scoped, as in the original module, so importing
    # this file does not itself require the third-party packages.
    import torch
    import random
    from transformers import BertTokenizer
    from datasets import Dataset

    tokenizer = BertTokenizer(vocab_file=vocab_file,
                              model_max_length=max_length)

    def generate_samples():
        """Yield ``num_samples`` dicts with keys ``text`` and ``label``."""
        for _ in range(num_samples):
            label = random.randint(0, 9)
            # str(label) is one character, so this is the digit repeated
            # text_lens times with a space between repetitions.
            text = ' '.join(str(label) * text_lens)
            yield {'text': text, 'label': label}

    dataset = Dataset.from_generator(generate_samples)

    def collate(batch):
        """Tokenize a list of samples and attach their labels."""
        texts = [row['text'] for row in batch]
        labels = [row['label'] for row in batch]

        encoded = tokenizer(texts,
                            padding=True,
                            truncation=True,
                            max_length=max_length,
                            return_tensors='pt')

        # Stored under 'labels' so the batch can be passed as model(**batch).
        encoded['labels'] = torch.LongTensor(labels)

        return encoded

    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         drop_last=True,
                                         collate_fn=collate)

    return tokenizer, dataset, loader


def get_model(num_hidden_layers=32,
              *,
              num_labels=10,
              lr=1e-4,
              num_training_steps=50):
    """Create a BERT sequence classifier with its optimizer and scheduler.

    The model is constructed from a new ``BertConfig``; no pretrained
    weights are loaded here.

    Args:
        num_hidden_layers: Number of hidden layers set on the config.
        num_labels: Number of output classes.
        lr: Learning rate given to the Adam optimizer.
        num_training_steps: Number of steps given to the cosine scheduler.
            With ``get_loader``'s defaults one epoch is 2000 // 32 = 62
            batches, so the default of 50 is shorter than one epoch; pass
            the real step count if the scheduler is stepped once per batch.

    Returns:
        A ``(model, optimizer, scheduler)`` tuple.
    """
    import torch
    from transformers import BertConfig, BertForSequenceClassification
    from transformers.optimization import get_scheduler

    config = BertConfig(num_labels=num_labels,
                        num_hidden_layers=num_hidden_layers)
    model = BertForSequenceClassification(config)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = get_scheduler(name='cosine',
                              num_warmup_steps=0,
                              num_training_steps=num_training_steps,
                              optimizer=optimizer)

    return model, optimizer, scheduler
"\n", 61 | "model = get_peft_model(model, config)\n", 62 | "\n", 63 | "model.print_trainable_parameters()\n", 64 | "\n", 65 | "model.classifier" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "id": "79902e1e", 72 | "metadata": { 73 | "scrolled": false 74 | }, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "0 62 2.2774503231048584 0.125\n", 81 | "1 62 2.2287027835845947 0.1875\n", 82 | "2 62 2.3163657188415527 0.15625\n", 83 | "3 62 2.209228992462158 0.21875\n", 84 | "4 62 2.1856229305267334 0.15625\n", 85 | "5 62 2.2188992500305176 0.1875\n", 86 | "6 62 2.0847771167755127 0.34375\n", 87 | "7 62 2.1538450717926025 0.3125\n", 88 | "8 62 2.0640904903411865 0.375\n", 89 | "9 62 1.9859848022460938 0.5625\n", 90 | "10 62 1.953343152999878 0.59375\n", 91 | "11 62 1.9733211994171143 0.6875\n", 92 | "12 62 1.9323341846466064 0.6875\n", 93 | "13 62 1.8512815237045288 0.78125\n", 94 | "14 62 1.9471282958984375 0.625\n", 95 | "15 62 1.7356795072555542 0.90625\n", 96 | "16 62 1.759588599205017 0.875\n", 97 | "17 62 1.6408013105392456 0.96875\n", 98 | "18 62 1.7141221761703491 0.8125\n", 99 | "19 62 1.6078336238861084 0.9375\n", 100 | "20 62 1.5887126922607422 1.0\n", 101 | "21 62 1.4986778497695923 0.96875\n", 102 | "22 62 1.5481253862380981 0.875\n", 103 | "23 62 1.4178872108459473 1.0\n", 104 | "24 62 1.371886968612671 1.0\n", 105 | "25 62 1.4725046157836914 0.9375\n", 106 | "26 62 1.3218789100646973 0.96875\n", 107 | "27 62 1.359554409980774 1.0\n", 108 | "28 62 1.2736034393310547 0.96875\n", 109 | "29 62 1.2936562299728394 0.96875\n", 110 | "30 62 1.1820000410079956 1.0\n", 111 | "31 62 1.1500431299209595 1.0\n", 112 | "32 62 1.122283697128296 1.0\n", 113 | "33 62 1.097461462020874 1.0\n", 114 | "34 62 0.9414330720901489 1.0\n", 115 | "35 62 1.0303120613098145 1.0\n", 116 | "36 62 0.9005396366119385 1.0\n", 117 | "37 62 0.8811569213867188 1.0\n", 118 | "38 62 0.8738059997558594 1.0\n", 119 | 
"39 62 0.8506068587303162 1.0\n", 120 | "40 62 0.7993475198745728 1.0\n", 121 | "41 62 0.7811093926429749 1.0\n", 122 | "42 62 0.7609606981277466 1.0\n", 123 | "43 62 0.759360671043396 1.0\n", 124 | "44 62 0.6664047241210938 1.0\n", 125 | "45 62 0.660634458065033 1.0\n", 126 | "46 62 0.6855562925338745 1.0\n", 127 | "47 62 0.6103322505950928 1.0\n", 128 | "48 62 0.5793362855911255 1.0\n", 129 | "49 62 0.5576601624488831 1.0\n", 130 | "50 62 0.5652727484703064 1.0\n", 131 | "51 62 0.5442566275596619 1.0\n", 132 | "52 62 0.5221884250640869 1.0\n", 133 | "53 62 0.4475937783718109 1.0\n", 134 | "54 62 0.5214552879333496 1.0\n", 135 | "55 62 0.48425009846687317 1.0\n", 136 | "56 62 0.3857068419456482 1.0\n", 137 | "57 62 0.4370235204696655 1.0\n", 138 | "58 62 0.4066462218761444 1.0\n", 139 | "59 62 0.39984583854675293 1.0\n", 140 | "60 62 0.3576478660106659 1.0\n", 141 | "61 62 0.36198505759239197 1.0\n" 142 | ] 143 | }, 144 | { 145 | "data": { 146 | "text/plain": [ 147 | "datetime.timedelta(seconds=25, microseconds=576708)" 148 | ] 149 | }, 150 | "execution_count": 3, 151 | "metadata": {}, 152 | "output_type": "execute_result" 153 | } 154 | ], 155 | "source": [ 156 | "import datetime\n", 157 | "\n", 158 | "#正常训练\n", 159 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n", 160 | "model.to(device)\n", 161 | "\n", 162 | "now = datetime.datetime.now()\n", 163 | "for i, data in enumerate(loader):\n", 164 | " for k, v in data.items():\n", 165 | " data[k] = v.to(device)\n", 166 | " out = model(**data)\n", 167 | " out.loss.backward()\n", 168 | " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", 169 | " optimizer.step()\n", 170 | " optimizer.zero_grad()\n", 171 | "\n", 172 | " if i % 1 == 0:\n", 173 | " labels = data['labels']\n", 174 | " logits = out['logits'].argmax(1)\n", 175 | " acc = (labels == logits).sum().item() / len(labels)\n", 176 | "\n", 177 | " print(i, len(loader), out.loss.item(), acc)\n", 178 | "\n", 179 | "datetime.datetime.now() - now" 
180 | ] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "Python [conda env:cuda117]", 186 | "language": "python", 187 | "name": "conda-env-cuda117-py" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.10.13" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 5 204 | } 205 | -------------------------------------------------------------------------------- /3.初始化参数的方式.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "53ce9fe8", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "Linear(in_features=768, out_features=10, bias=True)" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import torch\n", 22 | "from functions import get_loader, get_model\n", 23 | "\n", 24 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 25 | "_, _, loader = get_loader()\n", 26 | "model, _, _ = get_model()\n", 27 | "\n", 28 | "model.classifier" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "id": "d0dd067d", 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "trainable params: 7,700 || all params: 251,255,080 || trainable%: 0.0030646146537614285\n" 42 | ] 43 | }, 44 | { 45 | "data": { 46 | "text/plain": [ 47 | "ModulesToSaveWrapper(\n", 48 | " (original_module): Linear(\n", 49 | " (base_layer): Linear(in_features=768, out_features=10, bias=True)\n", 50 | " (ia3_l): ParameterDict( (default): Parameter containing: [torch.FloatTensor of size 10x1])\n", 51 
| " )\n", 52 | " (modules_to_save): ModuleDict(\n", 53 | " (default): Linear(\n", 54 | " (base_layer): Linear(in_features=768, out_features=10, bias=True)\n", 55 | " (ia3_l): ParameterDict( (default): Parameter containing: [torch.FloatTensor of size 10x1])\n", 56 | " )\n", 57 | " )\n", 58 | ")" 59 | ] 60 | }, 61 | "execution_count": 2, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | } 65 | ], 66 | "source": [ 67 | "from peft import LoraConfig, TaskType, get_peft_model, LoftQConfig, PromptEncoderConfig, IA3Config\n", 68 | "\n", 69 | "config = LoraConfig(\n", 70 | " task_type=TaskType.SEQ_CLS,\n", 71 | " inference_mode=False,\n", 72 | " r=8,\n", 73 | " lora_alpha=32,\n", 74 | " lora_dropout=0.1,\n", 75 | " target_modules=['classifier'],\n", 76 | "\n", 77 | " #设置A层参数初始化方式,默认A层是凯明均匀分布,B层是全0\n", 78 | " #init_lora_weights='gaussian',\n", 79 | "\n", 80 | " #使用loftq初始化参数,一般会获得更好的效果\n", 81 | " init_lora_weights='loftq',\n", 82 | " loftq_config=LoftQConfig(loftq_bits=4),\n", 83 | "\n", 84 | " #使用数值缩放,也是增进训练效果的\n", 85 | " use_rslora=True,\n", 86 | "\n", 87 | " #另一种插入层的结构,和loftq不共存\n", 88 | " use_dora=False,\n", 89 | ")\n", 90 | "\n", 91 | "#适用于CAUSAL_LM任务的配置\n", 92 | "config = PromptEncoderConfig(task_type='SEQ_CLS',\n", 93 | " num_virtual_tokens=20,\n", 94 | " encoder_hidden_size=128)\n", 95 | "\n", 96 | "#IA3是比lora更激进的方式,可训练的参数更少\n", 97 | "config = IA3Config(task_type='SEQ_CLS', target_modules=['classifier'])\n", 98 | "\n", 99 | "model = get_peft_model(model, config)\n", 100 | "\n", 101 | "model.print_trainable_parameters()\n", 102 | "\n", 103 | "model.classifier" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 3, 109 | "id": "965de7ec", 110 | "metadata": { 111 | "scrolled": false 112 | }, 113 | "outputs": [ 114 | { 115 | "name": "stdout", 116 | "output_type": "stream", 117 | "text": [ 118 | "0 62 2.4089834690093994 0.0\n", 119 | "1 62 2.3106164932250977 0.09375\n", 120 | "2 62 2.316065788269043 0.0625\n", 121 | "3 62 
2.363997459411621 0.09375\n", 122 | "4 62 2.301146984100342 0.09375\n", 123 | "5 62 2.2765536308288574 0.1875\n", 124 | "6 62 2.3127336502075195 0.0625\n", 125 | "7 62 2.313250780105591 0.0625\n", 126 | "8 62 2.263867139816284 0.09375\n", 127 | "9 62 2.257852554321289 0.03125\n", 128 | "10 62 2.229757785797119 0.1875\n", 129 | "11 62 2.2166881561279297 0.1875\n", 130 | "12 62 2.1789209842681885 0.125\n", 131 | "13 62 2.2374184131622314 0.1875\n", 132 | "14 62 2.230839729309082 0.15625\n", 133 | "15 62 2.2052810192108154 0.125\n", 134 | "16 62 2.1605560779571533 0.21875\n", 135 | "17 62 2.1335458755493164 0.25\n", 136 | "18 62 2.146886110305786 0.1875\n", 137 | "19 62 2.1742775440216064 0.25\n", 138 | "20 62 2.192859411239624 0.1875\n", 139 | "21 62 2.1038081645965576 0.28125\n", 140 | "22 62 2.1520001888275146 0.25\n", 141 | "23 62 2.1106300354003906 0.375\n", 142 | "24 62 2.1247026920318604 0.34375\n", 143 | "25 62 2.1607823371887207 0.21875\n", 144 | "26 62 2.1247217655181885 0.15625\n", 145 | "27 62 2.051104784011841 0.5625\n", 146 | "28 62 2.0689423084259033 0.46875\n", 147 | "29 62 2.037259340286255 0.59375\n", 148 | "30 62 2.0886518955230713 0.4375\n", 149 | "31 62 2.0399112701416016 0.53125\n", 150 | "32 62 2.0339131355285645 0.5625\n", 151 | "33 62 2.0201523303985596 0.59375\n", 152 | "34 62 2.03676438331604 0.53125\n", 153 | "35 62 2.012573480606079 0.625\n", 154 | "36 62 2.0173566341400146 0.625\n", 155 | "37 62 1.942825198173523 0.75\n", 156 | "38 62 1.9864193201065063 0.625\n", 157 | "39 62 1.9341926574707031 0.75\n", 158 | "40 62 1.9701616764068604 0.625\n", 159 | "41 62 1.9886361360549927 0.625\n", 160 | "42 62 1.9486091136932373 0.59375\n", 161 | "43 62 1.9258460998535156 0.84375\n", 162 | "44 62 1.9121100902557373 0.78125\n", 163 | "45 62 1.9130499362945557 0.84375\n", 164 | "46 62 1.9048962593078613 0.78125\n", 165 | "47 62 1.887307047843933 0.875\n", 166 | "48 62 1.9133539199829102 0.75\n", 167 | "49 62 1.8954181671142578 0.8125\n", 168 | "50 62 
1.8833553791046143 0.71875\n", 169 | "51 62 1.878673791885376 0.90625\n", 170 | "52 62 1.8301059007644653 0.90625\n", 171 | "53 62 1.833159327507019 0.90625\n", 172 | "54 62 1.8676620721817017 0.78125\n", 173 | "55 62 1.7989805936813354 0.9375\n", 174 | "56 62 1.8499683141708374 0.84375\n", 175 | "57 62 1.7384623289108276 1.0\n", 176 | "58 62 1.8065447807312012 0.96875\n", 177 | "59 62 1.810263991355896 0.78125\n", 178 | "60 62 1.8282135725021362 0.84375\n", 179 | "61 62 1.7763408422470093 0.8125\n" 180 | ] 181 | }, 182 | { 183 | "data": { 184 | "text/plain": [ 185 | "datetime.timedelta(seconds=12, microseconds=140622)" 186 | ] 187 | }, 188 | "execution_count": 3, 189 | "metadata": {}, 190 | "output_type": "execute_result" 191 | } 192 | ], 193 | "source": [ 194 | "import datetime\n", 195 | "\n", 196 | "#正常训练\n", 197 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n", 198 | "model.to(device)\n", 199 | "\n", 200 | "now = datetime.datetime.now()\n", 201 | "for i, data in enumerate(loader):\n", 202 | " for k, v in data.items():\n", 203 | " data[k] = v.to(device)\n", 204 | " out = model(**data)\n", 205 | " out.loss.backward()\n", 206 | " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", 207 | " optimizer.step()\n", 208 | " optimizer.zero_grad()\n", 209 | "\n", 210 | " if i % 1 == 0:\n", 211 | " labels = data['labels']\n", 212 | " logits = out['logits'].argmax(1)\n", 213 | " acc = (labels == logits).sum().item() / len(labels)\n", 214 | "\n", 215 | " print(i, len(loader), out.loss.item(), acc)\n", 216 | "\n", 217 | "datetime.datetime.now() - now" 218 | ] 219 | } 220 | ], 221 | "metadata": { 222 | "kernelspec": { 223 | "display_name": "Python [conda env:cuda117]", 224 | "language": "python", 225 | "name": "conda-env-cuda117-py" 226 | }, 227 | "language_info": { 228 | "codemirror_mode": { 229 | "name": "ipython", 230 | "version": 3 231 | }, 232 | "file_extension": ".py", 233 | "mimetype": "text/x-python", 234 | "name": "python", 235 | 
"nbconvert_exporter": "python", 236 | "pygments_lexer": "ipython3", 237 | "version": "3.10.13" 238 | } 239 | }, 240 | "nbformat": 4, 241 | "nbformat_minor": 5 242 | } 243 | -------------------------------------------------------------------------------- /1.快速上手.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "ef4a17fd", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "Linear(in_features=768, out_features=10, bias=True)" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import torch\n", 22 | "from functions import get_loader, get_model\n", 23 | "\n", 24 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 25 | "_, _, loader = get_loader()\n", 26 | "model, _, _ = get_model()\n", 27 | "\n", 28 | "#保存原模型参数\n", 29 | "model.save_pretrained('model/save_pretrained')\n", 30 | "\n", 31 | "model.classifier" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "id": "a5c3f7c5", 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "trainable params: 13,914 || all params: 251,267,508 || trainable%: 0.005537524573213023\n" 45 | ] 46 | }, 47 | { 48 | "data": { 49 | "text/plain": [ 50 | "ModulesToSaveWrapper(\n", 51 | " (original_module): lora.Linear(\n", 52 | " (base_layer): Linear(in_features=768, out_features=10, bias=True)\n", 53 | " (lora_dropout): ModuleDict(\n", 54 | " (default): Dropout(p=0.1, inplace=False)\n", 55 | " )\n", 56 | " (lora_A): ModuleDict(\n", 57 | " (default): Linear(in_features=768, out_features=8, bias=False)\n", 58 | " )\n", 59 | " (lora_B): ModuleDict(\n", 60 | " (default): Linear(in_features=8, out_features=10, bias=False)\n", 61 | " )\n", 62 | " (lora_embedding_A): ParameterDict()\n", 63 | " 
(lora_embedding_B): ParameterDict()\n", 64 | " )\n", 65 | " (modules_to_save): ModuleDict(\n", 66 | " (default): lora.Linear(\n", 67 | " (base_layer): Linear(in_features=768, out_features=10, bias=True)\n", 68 | " (lora_dropout): ModuleDict(\n", 69 | " (default): Dropout(p=0.1, inplace=False)\n", 70 | " )\n", 71 | " (lora_A): ModuleDict(\n", 72 | " (default): Linear(in_features=768, out_features=8, bias=False)\n", 73 | " )\n", 74 | " (lora_B): ModuleDict(\n", 75 | " (default): Linear(in_features=8, out_features=10, bias=False)\n", 76 | " )\n", 77 | " (lora_embedding_A): ParameterDict()\n", 78 | " (lora_embedding_B): ParameterDict()\n", 79 | " )\n", 80 | " )\n", 81 | ")" 82 | ] 83 | }, 84 | "execution_count": 2, 85 | "metadata": {}, 86 | "output_type": "execute_result" 87 | } 88 | ], 89 | "source": [ 90 | "from peft import LoraConfig, TaskType, get_peft_model, LoftQConfig\n", 91 | "\n", 92 | "config = LoraConfig(\n", 93 | " #任务类型, SEQ_CLS,SEQ_2_SEQ_LM,CAUSAL_LM,TOKEN_CLS,QUESTION_ANS,FEATURE_EXTRACTION\n", 94 | " task_type=TaskType.SEQ_CLS,\n", 95 | " #是否是推理模式.\n", 96 | " inference_mode=False,\n", 97 | " #降秩矩阵的尺寸,这个参数会影响训练的参数量\n", 98 | " r=8,\n", 99 | " #lora的缩放系数,不影响参数量\n", 100 | " lora_alpha=32,\n", 101 | " #降秩矩阵的dropout\n", 102 | " lora_dropout=0.1,\n", 103 | " #指定要对原模型中的那一部分添加lora层,默认是qk线性层\n", 104 | " target_modules=['classifier'],\n", 105 | ")\n", 106 | "\n", 107 | "model = get_peft_model(model, config)\n", 108 | "\n", 109 | "model.print_trainable_parameters()\n", 110 | "\n", 111 | "model.classifier" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 3, 117 | "id": "936e6af3", 118 | "metadata": { 119 | "scrolled": false 120 | }, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "0 62 2.336458683013916 0.03125\n", 127 | "1 62 2.335559368133545 0.03125\n", 128 | "2 62 2.3020806312561035 0.125\n", 129 | "3 62 2.3629519939422607 0.15625\n", 130 | "4 62 2.3952343463897705 0.09375\n", 
131 | "5 62 2.3703486919403076 0.03125\n", 132 | "6 62 2.3296892642974854 0.09375\n", 133 | "7 62 2.2450196743011475 0.15625\n", 134 | "8 62 2.204324722290039 0.25\n", 135 | "9 62 2.306922197341919 0.125\n", 136 | "10 62 2.308375597000122 0.0625\n", 137 | "11 62 2.3042044639587402 0.125\n", 138 | "12 62 2.22579026222229 0.25\n", 139 | "13 62 2.248403549194336 0.15625\n", 140 | "14 62 2.2131221294403076 0.25\n", 141 | "15 62 2.1622021198272705 0.25\n", 142 | "16 62 2.1667230129241943 0.1875\n", 143 | "17 62 2.12541127204895 0.28125\n", 144 | "18 62 2.1942343711853027 0.21875\n", 145 | "19 62 2.086634874343872 0.34375\n", 146 | "20 62 2.1510114669799805 0.25\n", 147 | "21 62 2.159356117248535 0.375\n", 148 | "22 62 2.141340732574463 0.34375\n", 149 | "23 62 2.131850481033325 0.34375\n", 150 | "24 62 2.1174545288085938 0.21875\n", 151 | "25 62 2.087817668914795 0.53125\n", 152 | "26 62 2.1108903884887695 0.375\n", 153 | "27 62 2.0811893939971924 0.46875\n", 154 | "28 62 2.055372953414917 0.40625\n", 155 | "29 62 2.0447239875793457 0.46875\n", 156 | "30 62 2.0550787448883057 0.34375\n", 157 | "31 62 2.0166432857513428 0.53125\n", 158 | "32 62 1.993998408317566 0.4375\n", 159 | "33 62 2.0733022689819336 0.4375\n", 160 | "34 62 2.0332281589508057 0.5\n", 161 | "35 62 2.022028923034668 0.5\n", 162 | "36 62 2.0159997940063477 0.46875\n", 163 | "37 62 1.9176312685012817 0.78125\n", 164 | "38 62 1.942380428314209 0.625\n", 165 | "39 62 1.9386800527572632 0.5625\n", 166 | "40 62 1.9307005405426025 0.65625\n", 167 | "41 62 1.9886863231658936 0.59375\n", 168 | "42 62 1.9458670616149902 0.5625\n", 169 | "43 62 1.928817868232727 0.71875\n", 170 | "44 62 1.8715364933013916 0.78125\n", 171 | "45 62 1.873715877532959 0.65625\n", 172 | "46 62 1.8763353824615479 0.6875\n", 173 | "47 62 1.86095130443573 0.75\n", 174 | "48 62 1.948620319366455 0.59375\n", 175 | "49 62 1.8806779384613037 0.8125\n", 176 | "50 62 1.8373451232910156 0.75\n", 177 | "51 62 1.842321753501892 0.78125\n", 178 | 
"52 62 1.8271459341049194 0.90625\n", 179 | "53 62 1.8349123001098633 0.84375\n", 180 | "54 62 1.8083330392837524 0.875\n", 181 | "55 62 1.8403195142745972 0.8125\n", 182 | "56 62 1.749794840812683 0.96875\n", 183 | "57 62 1.755589485168457 0.90625\n", 184 | "58 62 1.7760461568832397 0.90625\n", 185 | "59 62 1.736643671989441 0.84375\n", 186 | "60 62 1.76534104347229 0.9375\n", 187 | "61 62 1.723689079284668 0.90625\n" 188 | ] 189 | }, 190 | { 191 | "data": { 192 | "text/plain": [ 193 | "datetime.timedelta(seconds=12, microseconds=194751)" 194 | ] 195 | }, 196 | "execution_count": 3, 197 | "metadata": {}, 198 | "output_type": "execute_result" 199 | } 200 | ], 201 | "source": [ 202 | "import datetime\n", 203 | "\n", 204 | "#正常训练\n", 205 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n", 206 | "model.to(device)\n", 207 | "\n", 208 | "now = datetime.datetime.now()\n", 209 | "for i, data in enumerate(loader):\n", 210 | " for k, v in data.items():\n", 211 | " data[k] = v.to(device)\n", 212 | " out = model(**data)\n", 213 | " out.loss.backward()\n", 214 | " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", 215 | " optimizer.step()\n", 216 | " optimizer.zero_grad()\n", 217 | "\n", 218 | " if i % 1 == 0:\n", 219 | " labels = data['labels']\n", 220 | " logits = out['logits'].argmax(1)\n", 221 | " acc = (labels == logits).sum().item() / len(labels)\n", 222 | "\n", 223 | " print(i, len(loader), out.loss.item(), acc)\n", 224 | "\n", 225 | "datetime.datetime.now() - now" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 4, 231 | "id": "b8d5bac0", 232 | "metadata": {}, 233 | "outputs": [ 234 | { 235 | "name": "stderr", 236 | "output_type": "stream", 237 | "text": [ 238 | "/root/miniconda3/envs/cuda117/lib/python3.10/site-packages/peft/utils/save_and_load.py:154: UserWarning: Could not find a config file in - will assume that the vocabulary was not modified.\n", 239 | " warnings.warn(\n" 240 | ] 241 | }, 242 | { 243 | "data": { 
244 | "text/plain": [ 245 | "Parameter containing:\n", 246 | "tensor([[ 0.0019, -0.0047, 0.0170, ..., 0.0103, -0.0181, -0.0162],\n", 247 | " [ 0.0281, 0.0129, 0.0396, ..., -0.0123, 0.0515, -0.0117],\n", 248 | " [-0.0530, -0.0161, -0.0173, ..., -0.0548, 0.0034, -0.0369],\n", 249 | " ...,\n", 250 | " [-0.0228, -0.0049, 0.0235, ..., -0.0174, 0.0303, 0.0107],\n", 251 | " [-0.0392, 0.0481, 0.0245, ..., 0.0204, -0.0020, 0.0287],\n", 252 | " [ 0.0116, -0.0089, -0.0318, ..., 0.0126, -0.0058, -0.0059]],\n", 253 | " device='cuda:0', requires_grad=True)" 254 | ] 255 | }, 256 | "execution_count": 4, 257 | "metadata": {}, 258 | "output_type": "execute_result" 259 | } 260 | ], 261 | "source": [ 262 | "#peft保存,保存的文件会很小,因为只保存了lora层\n", 263 | "model.save_pretrained('model/peft.save_pretrained')\n", 264 | "\n", 265 | "model.base_model.classifier.modules_to_save.default.weight" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 5, 271 | "id": "64c2a5de", 272 | "metadata": { 273 | "scrolled": true 274 | }, 275 | "outputs": [ 276 | { 277 | "data": { 278 | "text/plain": [ 279 | "Parameter containing:\n", 280 | "tensor([[ 0.0019, -0.0047, 0.0170, ..., 0.0103, -0.0181, -0.0162],\n", 281 | " [ 0.0281, 0.0129, 0.0396, ..., -0.0123, 0.0515, -0.0117],\n", 282 | " [-0.0530, -0.0161, -0.0173, ..., -0.0548, 0.0034, -0.0369],\n", 283 | " ...,\n", 284 | " [-0.0228, -0.0049, 0.0235, ..., -0.0174, 0.0303, 0.0107],\n", 285 | " [-0.0392, 0.0481, 0.0245, ..., 0.0204, -0.0020, 0.0287],\n", 286 | " [ 0.0116, -0.0089, -0.0318, ..., 0.0126, -0.0058, -0.0059]],\n", 287 | " requires_grad=True)" 288 | ] 289 | }, 290 | "execution_count": 5, 291 | "metadata": {}, 292 | "output_type": "execute_result" 293 | } 294 | ], 295 | "source": [ 296 | "from transformers import BertForSequenceClassification\n", 297 | "from peft import PeftConfig, PeftModel\n", 298 | "\n", 299 | "#重启初始化原模型\n", 300 | "model = BertForSequenceClassification.from_pretrained('model/save_pretrained')\n", 301 | "\n", 302 
| "#加载保存的config\n", 303 | "PeftConfig.from_pretrained('model/peft.save_pretrained')\n", 304 | "\n", 305 | "#插入保存的lora层\n", 306 | "model = PeftModel.from_pretrained(model,\n", 307 | " './model/peft.save_pretrained',\n", 308 | " is_trainable=True)\n", 309 | "\n", 310 | "model.base_model.classifier.modules_to_save.default.weight" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 6, 316 | "id": "bad9dac1", 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "data": { 321 | "text/plain": [ 322 | "1.0" 323 | ] 324 | }, 325 | "execution_count": 6, 326 | "metadata": {}, 327 | "output_type": "execute_result" 328 | } 329 | ], 330 | "source": [ 331 | "#测试模型性能\n", 332 | "def test(model):\n", 333 | " model.to(device)\n", 334 | " data = next(iter(loader))\n", 335 | " for k, v in data.items():\n", 336 | " data[k] = v.to(device)\n", 337 | " with torch.no_grad():\n", 338 | " outs = model(**data)\n", 339 | " acc = (outs.logits.argmax(1) == data.labels).sum().item() / len(\n", 340 | " data.labels)\n", 341 | " return acc\n", 342 | "\n", 343 | "\n", 344 | "test(model)" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 7, 350 | "id": "879db67d", 351 | "metadata": {}, 352 | "outputs": [ 353 | { 354 | "data": { 355 | "text/plain": [ 356 | "(transformers.models.bert.modeling_bert.BertForSequenceClassification, 1.0)" 357 | ] 358 | }, 359 | "execution_count": 7, 360 | "metadata": {}, 361 | "output_type": "execute_result" 362 | } 363 | ], 364 | "source": [ 365 | "#合并lora层到原始模型中,效果不会改变\n", 366 | "model_merge = model.merge_and_unload()\n", 367 | "\n", 368 | "type(model_merge), test(model_merge)" 369 | ] 370 | } 371 | ], 372 | "metadata": { 373 | "kernelspec": { 374 | "display_name": "Python [conda env:cuda117]", 375 | "language": "python", 376 | "name": "conda-env-cuda117-py" 377 | }, 378 | "language_info": { 379 | "codemirror_mode": { 380 | "name": "ipython", 381 | "version": 3 382 | }, 383 | "file_extension": ".py", 384 | "mimetype": 
"text/x-python", 385 | "name": "python", 386 | "nbconvert_exporter": "python", 387 | "pygments_lexer": "ipython3", 388 | "version": "3.10.13" 389 | } 390 | }, 391 | "nbformat": 4, 392 | "nbformat_minor": 5 393 | } 394 | --------------------------------------------------------------------------------