├── 01-finetune-opt-with-lora.ipynb
├── 02-finetune-gpt2-with-lora.ipynb
├── Readme.md
└── images
    ├── auto_regressive_transformer.png
    └── lora.png

/01-finetune-opt-with-lora.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "857cafc6-da38-4aa7-8afc-63aa626fa7aa", 6 | "metadata": {}, 7 | "source": [ 8 | "# 01. Finetuning OPT with LoRA\n", 9 | "\n", 10 | "Today's popular auto-regressive models (such as GPT, LLaMA, and Falcon) are decoder-only models, in which each output token is predicted from the input text alone (called a prompt).\n", 11 | "\n", 12 | "![Decoder-only transformers](./images/auto_regressive_transformer.png)\n", 13 | "\n", 14 | "*A \"decoder-only\" model is implemented using the layers in the red box.
\n", 15 | "(Diagram from: [Attention Is All You Need](https://arxiv.org/abs/1706.03762))*\n", 16 | "\n", 17 | "In this kind of model, the task itself is also specified through the input text (i.e., the prompt).\n", 18 | "\n", 19 | "> Note : See [this repository](https://github.com/tsmatz/nlp-tutorials) for the ideas behind LLM transformers.\n", 20 | "\n", 21 | "In this example, we fine-tune a pre-trained auto-regressive model, Meta's OPT (```facebook/opt-125m```), by applying LoRA (Low-Rank Adaptation).\n", 22 | "\n", 23 | "I download the pre-trained model from the Hugging Face Hub, but fine-tune it with a regular PyTorch training loop.
\n", 24 | "(Here I don't use the Hugging Face Trainer class.)\n", 25 | "\n", 26 | "See the [Readme](https://github.com/tsmatz/finetune_llm_with_lora) for the prerequisite setup." 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "id": "3d49acf1-9ad1-4a6c-9312-6785cb3f5862", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "model_name = \"facebook/opt-125m\"\n", 37 | "# model_name = \"facebook/opt-350m\"\n", 38 | "# model_name = \"facebook/opt-1.3b\"\n", 39 | "# model_name = \"facebook/opt-6.7b\"" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "id": "4d835e84-a01d-4c33-926b-60d9dd4a7627", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import torch\n", 50 | "\n", 51 | "device = torch.device(\"cuda\")" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "ead383e5-149b-4bfb-9324-3cc639fd398d", 57 | "metadata": {}, 58 | "source": [ 59 | "## Prepare dataset and dataloader" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "91ecbb08-6a74-4623-bfe8-bddba5254e35", 65 | "metadata": {}, 66 | "source": [ 67 | "In this example, we use the dataset from the [official LoRA example](https://github.com/microsoft/LoRA).\n", 68 | "\n", 69 | "Download the dataset from the official repository." 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 3, 75 | "id": "54a564f1-f8f3-42a6-b160-bebdbcc3aac0", 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "--2023-10-06 03:27:50-- https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/train.txt\n", 83 | "Resolving github.com (github.com)... 140.82.114.3\n", 84 | "Connecting to github.com (github.com)|140.82.114.3|:443... connected.\n", 85 | "HTTP request sent, awaiting response... 
302 Found\n", 86 | "Location: https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/train.txt [following]\n", 87 | "--2023-10-06 03:27:51-- https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/train.txt\n", 88 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...\n", 89 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n", 90 | "HTTP request sent, awaiting response... 200 OK\n", 91 | "Length: 9624463 (9.2M) [text/plain]\n", 92 | "Saving to: ‘train.txt’\n", 93 | "\n", 94 | "train.txt 100%[===================>] 9.18M --.-KB/s in 0.04s \n", 95 | "\n", 96 | "2023-10-06 03:27:51 (248 MB/s) - ‘train.txt’ saved [9624463/9624463]\n", 97 | "\n" 98 | ] 99 | } 100 | ], 101 | "source": [ 102 | "!wget https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/train.txt" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 4, 108 | "id": "d48464ea-991f-48b2-9166-3323cfd61676", 109 | "metadata": { 110 | "scrolled": true 111 | }, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "--2023-10-06 03:27:54-- https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/test.txt\n", 118 | "Resolving github.com (github.com)... 140.82.114.3\n", 119 | "Connecting to github.com (github.com)|140.82.114.3|:443... connected.\n", 120 | "HTTP request sent, awaiting response... 302 Found\n", 121 | "Location: https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/test.txt [following]\n", 122 | "--2023-10-06 03:27:54-- https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/test.txt\n", 123 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 
185.199.111.133, 185.199.108.133, 185.199.109.133, ...\n", 124 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n", 125 | "HTTP request sent, awaiting response... 200 OK\n", 126 | "Length: 1351149 (1.3M) [text/plain]\n", 127 | "Saving to: ‘test.txt’\n", 128 | "\n", 129 | "test.txt 100%[===================>] 1.29M --.-KB/s in 0.006s \n", 130 | "\n", 131 | "2023-10-06 03:27:54 (208 MB/s) - ‘test.txt’ saved [1351149/1351149]\n", 132 | "\n" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "!wget https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/test.txt" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "id": "09472803-8c62-48e0-9a63-b9b9448f16d3", 143 | "metadata": {}, 144 | "source": [ 145 | "Show the downloaded data (first 5 rows)." 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 5, 151 | "id": "e6e60596-028f-4c4b-a95d-f74a0ff3b188", 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "name": "stdout", 156 | "output_type": "stream", 157 | "text": [ 158 | "name : The Vaults | Type : pub | price : more than £ 30 | customer rating : 5 out of 5 | near : Café Adriatic||The Vaults pub near Café Adriatic has a 5 star rating . Prices start at £ 30 . \n", 159 | "name : The Cambridge Blue | Type : pub | food : English | price : cheap | near : Café Brazil||Close to Café Brazil , The Cambridge Blue pub serves delicious Tuscan Beef for the cheap price of £ 10.50 . Delicious Pub food . \n", 160 | "name : The Eagle | Type : coffee shop | food : Japanese | price : less than £ 20 | customer rating : low | area : riverside | family friendly : yes | near : Burger King||The Eagle is a low rated coffee shop near Burger King and the riverside that is family friendly and is less than £ 20 for Japanese food . 
\n", 161 | "name : The Mill | Type : coffee shop | food : French | price : £ 20 - 25 | area : riverside | near : The Sorrento||Located near The Sorrento is a French Theme eatery and coffee shop called The Mill , with a price range at £ 20- £ 25 it is in the riverside area . \n", 162 | "name : Loch Fyne | food : French | customer rating : high | area : riverside | near : The Rice Boat||For luxurious French food , the Loch Fyne is located by the river next to The Rice Boat . \n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "!head -n 5 train.txt" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "id": "93f5fabe-590c-459b-aa16-4b5a506fb54b", 173 | "metadata": {}, 174 | "source": [ 175 | "Convert above data into JsonL format." 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 6, 181 | "id": "7376e0c0-16c9-46f4-ad4c-83d1a677f5a2", 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "import sys\n", 186 | "import io\n", 187 | "import json\n", 188 | "\n", 189 | "def format_convert(read_file, write_file):\n", 190 | " with open(read_file, \"r\", encoding=\"utf8\") as reader, \\\n", 191 | " \t open(write_file, \"w\", encoding=\"utf8\") as writer :\n", 192 | " \tfor line in reader:\n", 193 | " \t\titems = line.strip().split(\"||\")\n", 194 | " \t\tcontext = items[0]\n", 195 | " \t\tcompletion = items[1].strip(\"\\n\")\n", 196 | " \t\tx = {}\n", 197 | " \t\tx[\"context\"] = context\n", 198 | " \t\tx[\"completion\"] = completion\n", 199 | " \t\twriter.write(json.dumps(x)+\"\\n\")\n", 200 | "\n", 201 | "format_convert(\"train.txt\", \"train_formatted.jsonl\")\n", 202 | "format_convert(\"test.txt\", \"test_formatted.jsonl\")" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "id": "3ceec952-fe03-475f-9f3e-22237cc9c44b", 208 | "metadata": {}, 209 | "source": [ 210 | "Show the converted data (first 5 rows)." 
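A minimal sketch of what the conversion step does to a single record, assuming the `context||completion` layout shown in the raw file. The sample line and the helper name `to_record` are illustrative assumptions and are not part of this notebook. Note that `json.dumps` escapes non-ASCII characters by default (`ensure_ascii=True`), so `£` is stored as `\u00a3` in the JSONL file, and `json.loads` restores the original character.

```python
import json

# Illustrative sample in the "context||completion" layout (shortened from the raw data).
raw_line = "name : The Vaults | Type : pub | price : more than £ 30||Prices start at £ 30 . \n"

def to_record(line):
    # Split one raw line into its two fields at the first "||" delimiter.
    context, completion = line.strip().split("||", 1)
    return {"context": context, "completion": completion}

record = to_record(raw_line)
encoded = json.dumps(record)   # ensure_ascii=True by default: "£" becomes "\u00a3"
decoded = json.loads(encoded)  # parsing restores the original characters

print(encoded)
print(decoded["completion"])
```

If readable characters are preferred in the file itself, `json.dumps(record, ensure_ascii=False)` keeps them as-is; the parsed result is the same either way.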
211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 7, 216 | "id": "cb32aca7-bd0e-4847-a4c2-cc7e67dc2b7a", 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "{\"context\": \"name : The Vaults | Type : pub | price : more than \\u00a3 30 | customer rating : 5 out of 5 | near : Caf\\u00e9 Adriatic\", \"completion\": \"The Vaults pub near Caf\\u00e9 Adriatic has a 5 star rating . Prices start at \\u00a3 30 .\"}\n", 224 | "\n", 225 | "{\"context\": \"name : The Cambridge Blue | Type : pub | food : English | price : cheap | near : Caf\\u00e9 Brazil\", \"completion\": \"Close to Caf\\u00e9 Brazil , The Cambridge Blue pub serves delicious Tuscan Beef for the cheap price of \\u00a3 10.50 . Delicious Pub food .\"}\n", 226 | "\n", 227 | "{\"context\": \"name : The Eagle | Type : coffee shop | food : Japanese | price : less than \\u00a3 20 | customer rating : low | area : riverside | family friendly : yes | near : Burger King\", \"completion\": \"The Eagle is a low rated coffee shop near Burger King and the riverside that is family friendly and is less than \\u00a3 20 for Japanese food .\"}\n", 228 | "\n", 229 | "{\"context\": \"name : The Mill | Type : coffee shop | food : French | price : \\u00a3 20 - 25 | area : riverside | near : The Sorrento\", \"completion\": \"Located near The Sorrento is a French Theme eatery and coffee shop called The Mill , with a price range at \\u00a3 20- \\u00a3 25 it is in the riverside area .\"}\n", 230 | "\n", 231 | "{\"context\": \"name : Loch Fyne | food : French | customer rating : high | area : riverside | near : The Rice Boat\", \"completion\": \"For luxurious French food , the Loch Fyne is located by the river next to The Rice Boat .\"}\n", 232 | "\n" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "with open(\"train_formatted.jsonl\", \"r\") as reader:\n", 238 | " for _ in range(5):\n", 239 | " print(next(reader))" 240 | ] 241 | }, 242 | 
{ 243 | "cell_type": "markdown", 244 | "id": "6631f786-be4b-40cf-89d9-7009c1888821", 245 | "metadata": {}, 246 | "source": [ 247 | "Load the tokenizer from Hugging Face." 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 8, 253 | "id": "e5433dc0-b5a5-4c01-adb5-3ffa2279eca8", 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "from transformers import AutoTokenizer\n", 258 | "import os\n", 259 | "\n", 260 | "tokenizer = AutoTokenizer.from_pretrained(\n", 261 | "    model_name,\n", 262 | "    use_fast=True)\n", 263 | "tokenizer.pad_token = tokenizer.eos_token\n", 264 | "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" 265 | ] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "id": "50817c47-a97b-4f80-975b-836859a0a7cf", 270 | "metadata": {}, 271 | "source": [ 272 | "Set the block size: each tokenized example is truncated or padded to this length before it is fed to the model." 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 9, 278 | "id": "5f250929-5703-4b17-9f7b-26340950c055", 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "block_size = 512" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "id": "2332617b-1e66-4812-ad47-5eaeb52b101b", 288 | "metadata": {}, 289 | "source": [ 290 | "Create a function to convert the data. (This function is later used as the collator in the data loader.)
\n", 291 | "In this function,\n", 292 | "\n", 293 | "1. Tokenize both contexts and completions, e.g., ```\"This is a pen.\"``` --> ```[1212, 318, 257, 3112, 13]```\n", 294 | "2. Concatenate the context tokens and the completion tokens. (The context and the completion are delimited by \"\\n\".) The result is used as the input to the LLM.\n", 295 | "3. Create labels (targets) from the inputs. The label is ```input[1:]``` (i.e., the input shifted left by one position, so that each position is labeled with its next token), and the context positions are filled with ```-100```. (See the note below.)\n", 296 | "4. Truncate or pad the tokens so that the sequence length becomes ```block_size```.\n", 297 | "\n", 298 | "> Note : Here I set ```-100``` as the ignored index for loss computation, because the PyTorch cross-entropy function (```torch.nn.functional.cross_entropy()```) has an ```ignore_index``` argument whose default value is ```-100```." 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 10, 304 | "id": "9f2f38aa-b3d0-4614-aa59-8ddd977176d1", 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "from torch.utils.data import DataLoader\n", 309 | "import pandas as pd\n", 310 | "\n", 311 | "def fill_ignore_label(l, c):\n", 312 | "    l[:len(c) - 1] = [-100] * (len(c) - 1)\n", 313 | "    return l\n", 314 | "\n", 315 | "def pad_tokens(tokens, max_seq_length, padding_token):\n", 316 | "    res_tokens = tokens[:max_seq_length]\n", 317 | "    token_len = len(res_tokens)\n", 318 | "    res_tokens = res_tokens + \\\n", 319 | "        [padding_token for _ in range(max_seq_length - token_len)]\n", 320 | "    return res_tokens\n", 321 | "\n", 322 | "def collate_batch(batch):\n", 323 | "    # tokenize context and completion separately\n", 324 | "    # (context and completion are delimited by \"\\n\")\n", 325 | "    context_list = list(zip(*batch))[0]\n", 326 | "    context_list = [c + \"\\n\" for c in context_list]\n", 327 | "    completion_list = list(zip(*batch))[1]\n", 328 | "    context_result = tokenizer(context_list)\n", 329 | "    context_tokens = context_result[\"input_ids\"]\n", 330 | "    context_masks = context_result[\"attention_mask\"]\n", 331 | "    completion_result = tokenizer(completion_list)\n", 332 | "    completion_tokens = completion_result[\"input_ids\"]\n", 333 | "    completion_masks = completion_result[\"attention_mask\"]\n", 334 | "    # the OPT tokenizer adds the start token to each sequence,\n", 335 | "    # so we remove it from the completion\n", 336 | "    completion_tokens = [t[1:] for t in completion_tokens]\n", 337 | "    completion_masks = [t[1:] for t in completion_masks]\n", 338 | "    # concatenate tokens\n", 339 | "    inputs = [i + j for i, j in zip(context_tokens, completion_tokens)]\n", 340 | "    masks = [i + j for i, j in zip(context_masks, completion_masks)]\n", 341 | "    # create labels\n", 342 | "    eos_id = tokenizer.encode(tokenizer.eos_token)[0]\n", 343 | "    labels = [t[1:] + [eos_id] for t in inputs]\n", 344 | "    labels = list(map(fill_ignore_label, labels, context_tokens))\n", 345 | "    # truncate and pad tokens\n", 346 | "    inputs = [pad_tokens(t, block_size, 0) for t in inputs] # OPT and GPT-2 don't use a pad token (the attention mask is used instead)\n", 347 | "    masks = [pad_tokens(t, block_size, 0) for t in masks]\n", 348 | "    labels = [pad_tokens(t, block_size, -100) for t in labels]\n", 349 | "    # convert to tensors\n", 350 | "    inputs = torch.tensor(inputs, dtype=torch.int64).to(device)\n", 351 | "    masks = torch.tensor(masks, dtype=torch.int64).to(device)\n", 352 | "    labels = torch.tensor(labels, dtype=torch.int64).to(device)\n", 353 | "    return inputs, labels, masks" 354 | ] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "id": "2084d2e9-ef64-47a2-aec9-d24ead1cb38a", 359 | "metadata": {}, 360 | "source": [ 361 | "Now create a PyTorch dataloader with the previous function (as the collator function).\n", 362 | "\n", 363 | "> Note : In this example, the data is small, so we load all JSON data into memory.
\n", 364 | "> When it's large, load data progressively by implementing custom PyTorch dataset. (See [here](https://github.com/tsmatz/decision-transformer) for example.)" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": 11, 370 | "id": "f3bce3bb-2215-4bd6-a6a6-5b6b9d5afdc0", 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "batch_size = 8\n", 375 | "gradient_accumulation_steps = 16\n", 376 | "\n", 377 | "data = pd.read_json(\"train_formatted.jsonl\", lines=True)\n", 378 | "dataloader = DataLoader(\n", 379 | " list(zip(data[\"context\"], data[\"completion\"])),\n", 380 | " batch_size=batch_size,\n", 381 | " shuffle=True,\n", 382 | " collate_fn=collate_batch\n", 383 | ")" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "id": "3ba64144-b698-457e-b827-941020456536", 389 | "metadata": {}, 390 | "source": [ 391 | "## Load model" 392 | ] 393 | }, 394 | { 395 | "cell_type": "markdown", 396 | "id": "1bfd360d-7bdc-4fd7-9b12-bcf9fe0a8db2", 397 | "metadata": {}, 398 | "source": [ 399 | "Load model from Hugging Face." 
400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 12, 405 | "id": "271181bd-677a-4da9-9e57-2874f5e47bd0", 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "from transformers import AutoModelForCausalLM, AutoConfig\n", 410 | "\n", 411 | "config = AutoConfig.from_pretrained(model_name)\n", 412 | "model = AutoModelForCausalLM.from_pretrained(\n", 413 | "    model_name,\n", 414 | "    config=config,\n", 415 | ").to(device)" 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "id": "27ab764a-d634-40f8-9edb-a01146845233", 421 | "metadata": {}, 422 | "source": [ 423 | "## Generate text (before fine-tuning)" 424 | ] 425 | }, 426 | { 427 | "cell_type": "markdown", 428 | "id": "559efeaf-4b38-4a0c-9be6-eb394221e374", 429 | "metadata": {}, 430 | "source": [ 431 | "Now run prediction with the downloaded model (which is not yet fine-tuned).\n", 432 | "\n", 433 | "First we create a function to generate text." 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 13, 439 | "id": "51a0c4fc-e0a7-4bbf-b25a-c335fe61f3df", 440 | "metadata": {}, 441 | "outputs": [], 442 | "source": [ 443 | "def generate_text(model, input, mask, eos_id, pred_sequence_length):\n", 444 | "    predicted_last_id = -1\n", 445 | "    start_token_len = torch.sum(mask).cpu().numpy()\n", 446 | "    token_len = start_token_len\n", 447 | "    with torch.no_grad():\n", 448 | "        while (predicted_last_id != eos_id) and \\\n", 449 | "              (token_len - start_token_len < pred_sequence_length):\n", 450 | "            output = model(\n", 451 | "                input_ids=input,\n", 452 | "                attention_mask=mask,\n", 453 | "            )\n", 454 | "            predicted_ids = torch.argmax(output.logits, dim=-1).cpu().numpy()\n", 455 | "            predicted_last_id = predicted_ids[0][token_len - 1]\n", 456 | "            input[0][token_len] = predicted_last_id\n", 457 | "            mask[0][token_len] = 1\n", 458 | "            token_len = torch.sum(mask).cpu().numpy()\n", 459 | "    return input, token_len" 460 | ] 461 | }, 462 | { 463 | "cell_type": "markdown", 464
| "id": "3936b1a1-ae9f-48a5-80db-691261dda704", 465 | "metadata": {}, 466 | "source": [ 467 | "Let's test our function and generate text. (Here we stop the text generation when it reaches 15 tokens in prediction.)" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 14, 473 | "id": "28b7e13f-e8fb-4a9f-90ed-0464463ef569", 474 | "metadata": {}, 475 | "outputs": [ 476 | { 477 | "name": "stdout", 478 | "output_type": "stream", 479 | "text": [ 480 | "Once upon a time, I was a student at the University of California, Berkeley. I was a\n", 481 | "My name is Clara and I am a student at the University of California, Berkeley. I am a member of\n" 482 | ] 483 | } 484 | ], 485 | "source": [ 486 | "eos_id = tokenizer.encode(tokenizer.eos_token)[0]\n", 487 | "\n", 488 | "result = tokenizer(\"Once upon a time,\")\n", 489 | "input = result[\"input_ids\"]\n", 490 | "mask = result[\"attention_mask\"]\n", 491 | "input = pad_tokens(input, block_size, 0)\n", 492 | "mask = pad_tokens(mask, block_size, 0)\n", 493 | "input = torch.tensor([input], dtype=torch.int64).to(device)\n", 494 | "mask = torch.tensor([mask], dtype=torch.int64).to(device)\n", 495 | "\n", 496 | "result_token, result_len = generate_text(\n", 497 | " model,\n", 498 | " input,\n", 499 | " mask,\n", 500 | " eos_id,\n", 501 | " pred_sequence_length=15)\n", 502 | "print(tokenizer.decode(result_token[0][:result_len]))\n", 503 | "\n", 504 | "result = tokenizer(\"My name is Clara and I am\")\n", 505 | "input = result[\"input_ids\"]\n", 506 | "mask = result[\"attention_mask\"]\n", 507 | "input = pad_tokens(input, block_size, 0)\n", 508 | "mask = pad_tokens(mask, block_size, 0)\n", 509 | "input = torch.tensor([input], dtype=torch.int64).to(device)\n", 510 | "mask = torch.tensor([mask], dtype=torch.int64).to(device)\n", 511 | "\n", 512 | "result_token, result_len = generate_text(\n", 513 | " model,\n", 514 | " input,\n", 515 | " mask,\n", 516 | " eos_id,\n", 517 | " pred_sequence_length=15)\n", 518 | 
"print(tokenizer.decode(result_token[0][:result_len]))" 519 | ] 520 | }, 521 | { 522 | "cell_type": "markdown", 523 | "id": "d48fb60b-c05d-4884-a9bc-92152c94c894", 524 | "metadata": {}, 525 | "source": [ 526 | "Now we generate text with our test dataset (5 rows).
\n", 527 | "As you can see below, it does not produce a good completion, because it is not yet fine-tuned." 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 15, 533 | "id": "495728ef-fbe6-4953-a354-4b7a8bb88798", 534 | "metadata": {}, 535 | "outputs": [ 536 | { 537 | "name": "stdout", 538 | "output_type": "stream", 539 | "text": [ 540 | "********** input **********\n", 541 | "name : The Wrestlers | Type : pub | food : Italian | price : less than £ 20 | area : riverside | family friendly : no | near : Raja Indian Cuisine\n", 542 | "\n", 543 | "********** result **********\n", 544 | "name : The Wrestlers | Type : pub | food : Italian | price : less than £ 20 | area : riverside | family friendly : no | near : Raja Indian Cuisine\n", 545 | "\n", 546 | "The Wrestlers is a restaurant in the heart of the city of Raja, India. It is located in the heart of the city of Raj\n", 547 | "********** input **********\n", 548 | "name : The Cricketers | Type : coffee shop | customer rating : 1 out of 5 | family friendly : yes | near : Avalon\n", 549 | "\n", 550 | "********** result **********\n", 551 | "name : The Cricketers | Type : coffee shop | customer rating : 1 out of 5 | family friendly : yes | near : Avalon\n", 552 | "\n", 553 | "The Cricketers is a coffee shop in Avalon, New York. It is located at the corner of Main Street and Main Street. 
The coffee\n", 554 | "********** input **********\n", 555 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : 5 out of 5 | area : city centre | family friendly : no | near : All Bar One\n", 556 | "\n", 557 | "********** result **********\n", 558 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : 5 out of 5 | area : city centre | family friendly : no | near : All Bar One\n", 559 | "\n", 560 | "The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : 5 out of 5 | area : city centre\n", 561 | "********** input **********\n", 562 | "name : The Punter | Type : restaurant | food : English | price : high | area : riverside | family friendly : no | near : Raja Indian Cuisine\n", 563 | "\n", 564 | "********** result **********\n", 565 | "name : The Punter | Type : restaurant | food : English | price : high | area : riverside | family friendly : no | near : Raja Indian Cuisine\n", 566 | "\n", 567 | "The Punter is a restaurant in Raja, India. It is located in the heart of the Raja district of Rajasthan. 
It\n", 568 | "********** input **********\n", 569 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : average | area : city centre | family friendly : yes | near : All Bar One\n", 570 | "\n", 571 | "********** result **********\n", 572 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : average | area : city centre | family friendly : yes | near : All Bar One\n", 573 | "\n", 574 | "The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : average | area : city centre | family friendly\n" 575 | ] 576 | } 577 | ], 578 | "source": [ 579 | "test_data = pd.read_json(\"test_formatted.jsonl\", lines=True)\n", 580 | "test_data = test_data[::2] # because it's duplicated\n", 581 | "test_loader = DataLoader(\n", 582 | "    list(zip(test_data[\"context\"], [\"\"] * len(test_data[\"context\"]))),\n", 583 | "    batch_size=1,\n", 584 | "    shuffle=True,\n", 585 | "    collate_fn=collate_batch\n", 586 | ")\n", 587 | "\n", 588 | "for i, (input, _, mask) in enumerate(test_loader):\n", 589 | "    if i == 5:\n", 590 | "        break\n", 591 | "    print(\"********** input **********\")\n", 592 | "    input_len = torch.sum(mask).cpu().numpy()\n", 593 | "    print(tokenizer.decode(input[0][:input_len]))\n", 594 | "    result_token, result_len = generate_text(\n", 595 | "        model,\n", 596 | "        input,\n", 597 | "        mask,\n", 598 | "        eos_id,\n", 599 | "        pred_sequence_length=30)\n", 600 | "    print(\"********** result **********\")\n", 601 | "    print(tokenizer.decode(result_token[0][:result_len]))" 602 | ] 603 | }, 604 | { 605 | "cell_type": "markdown", 606 | "id": "e3138341-e01c-4fae-af78-c61e34967e92", 607 | "metadata": {}, 608 | "source": [ 609 | "## LoRA (Low-Rank Adaptation)\n", 610 | "\n", 611 | "Now we apply LoRA to our downloaded model.\n", 612 | "\n", 613 | "[LoRA (Low-Rank Adaptation)](https://arxiv.org/abs/2106.09685) (developed by Microsoft Research) is a popular adaptation method for 
efficient fine-tuning.\n", 614 | "\n", 615 | "LoRA rests on the hypothesis that, in task-specific fine-tuning, the change in weights during model adaptation has a low intrinsic rank.
\n", 616 | "Under this hypothesis, the model's update ($ \\Delta W $) can be rewritten as the product of two much smaller low-rank matrices, $ B \\cdot A $, as follows.\n", 617 | "\n", 618 | "$$ \\displaystyle W_0 x + \\Delta W x = W_0 x + B \\cdot A x $$\n", 619 | "\n", 620 | "where\n", 621 | "\n", 622 | "- $ W_0 \\in \\mathbb{R}^{d \\times k} $ is the pre-trained weight matrix (which is frozen).\n", 623 | "- $ \\Delta W $ is the update to the weights.\n", 624 | "- $ B \\in \\mathbb{R}^{d \\times r}, A \\in \\mathbb{R}^{r \\times k} $, and the rank $ r \\ll \\min(d, k) $.\n", 625 | "\n", 626 | "![LoRA](./images/lora.png)\n", 627 | "\n", 628 | "*From : [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)*\n", 629 | "\n", 630 | "Under this assumption, we freeze all weights except for $ B $ and $ A $, and train only these low-rank matrices.
\n", 631 | "In this way, you can fine-tune large transformers for a specific task without full-parameter fine-tuning.\n", 632 | "\n", 633 | "This significantly reduces the GPU memory required for training; in the benchmark with GPT-3, the number of required GPUs can be reduced to approximately one-fourth.\n", 634 | "\n", 635 | "For learning purposes, here I manually (from scratch) convert the current model into a model with LoRA.\n", 636 | "\n", 637 | "> Note : With the ```PEFT``` package, you can get a LoRA model with a few lines of code. (Here I don't use this package.)" 638 | ] 639 | }, 640 | { 641 | "cell_type": "markdown", 642 | "id": "5265832d-a736-4d68-80d3-347833d2c590", 643 | "metadata": {}, 644 | "source": [ 645 | "Before changing our model, we first check its structure.
\n", 646 | "As you can see below (see the result in the cell), the following 6 linear layers are used in a single transformer layer of OPT.\n", 647 | "\n", 648 | "- Linear layer to get key (```k_proj```)\n", 649 | "- Linear layer to get value (```v_proj```)\n", 650 | "- Linear layer to get query (```q_proj```)\n", 651 | "- Linear layer for the output of attention (```out_proj```)\n", 652 | "- 2 linear layers (feed-forward layer) for the output of a single layer of transformer (```fc1``` and ```fc2```)\n", 653 | "\n", 654 | "In this example, we'll convert all these layers into LoRA layers.
\n", 655 | "OPT-125M has 12 transformer layers, so there are 6 x 12 = 72 linear layers in total to be converted." 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 16, 661 | "id": "5acb8f62-791a-4fa4-b00c-2666cf34827f", 662 | "metadata": {}, 663 | "outputs": [ 664 | { 665 | "data": { 666 | "text/plain": [ 667 | "OPTForCausalLM(\n", 668 | "  (model): OPTModel(\n", 669 | "    (decoder): OPTDecoder(\n", 670 | "      (embed_tokens): Embedding(50272, 768, padding_idx=1)\n", 671 | "      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)\n", 672 | "      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 673 | "      (layers): ModuleList(\n", 674 | "        (0-11): 12 x OPTDecoderLayer(\n", 675 | "          (self_attn): OPTAttention(\n", 676 | "            (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", 677 | "            (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", 678 | "            (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", 679 | "            (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 680 | "          )\n", 681 | "          (activation_fn): ReLU()\n", 682 | "          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 683 | "          (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 684 | "          (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 685 | "          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 686 | "        )\n", 687 | "      )\n", 688 | "    )\n", 689 | "  )\n", 690 | "  (lm_head): Linear(in_features=768, out_features=50272, bias=False)\n", 691 | ")" 692 | ] 693 | }, 694 | "execution_count": 16, 695 | "metadata": {}, 696 | "output_type": "execute_result" 697 | } 698 | ], 699 | "source": [ 700 | "model" 701 | ] 702 | }, 703 | { 704 | "cell_type": "markdown", 705 | "id": "045e7239-cb8a-46dd-815d-e48e7e49eea4", 706 | "metadata": {}, 707 | "source": [ 708 | "First we build a custom linear layer with LoRA as follows." 
709 | ] 710 | }, 711 | { 712 | "cell_type": "code", 713 | "execution_count": 17, 714 | "id": "77889272-9a93-491b-93cb-b0bed5ce7cd8", 715 | "metadata": {}, 716 | "outputs": [], 717 | "source": [ 718 | "import math\n", 719 | "from torch import nn\n", 720 | "\n", 721 | "class LoRA_Linear(nn.Module):\n", 722 | "    def __init__(self, weight, bias, lora_dim):\n", 723 | "        super(LoRA_Linear, self).__init__()\n", 724 | "\n", 725 | "        row, column = weight.shape\n", 726 | "\n", 727 | "        # restore Linear\n", 728 | "        if bias is None:\n", 729 | "            self.linear = nn.Linear(column, row, bias=False)\n", 730 | "            self.linear.load_state_dict({\"weight\": weight})\n", 731 | "        else:\n", 732 | "            self.linear = nn.Linear(column, row)\n", 733 | "            self.linear.load_state_dict({\"weight\": weight, \"bias\": bias})\n", 734 | "\n", 735 | "        # create LoRA weights (lora_right: random init, lora_left: zeros, so the LoRA term is initially zero)\n", 736 | "        self.lora_right = nn.Parameter(torch.zeros(column, lora_dim))\n", 737 | "        nn.init.kaiming_uniform_(self.lora_right, a=math.sqrt(5))\n", 738 | "        self.lora_left = nn.Parameter(torch.zeros(lora_dim, row))\n", 739 | "\n", 740 | "    def forward(self, input):\n", 741 | "        x = self.linear(input)\n", 742 | "        y = input @ self.lora_right @ self.lora_left\n", 743 | "        return x + y" 744 | ] 745 | }, 746 | { 747 | "cell_type": "markdown", 748 | "id": "954e2c9d-545e-4bd9-9b0f-eba3fe29a1de", 749 | "metadata": {}, 750 | "source": [ 751 | "Replace the target linear layers with LoRA layers." 
752 | ] 753 | }, 754 | { 755 | "cell_type": "code", 756 | "execution_count": 18, 757 | "id": "baf8a748-a3e3-45b8-9c64-252c56abe923", 758 | "metadata": {}, 759 | "outputs": [], 760 | "source": [ 761 | "lora_dim = 128\n", 762 | "\n", 763 | "# get target module names\n", 764 | "target_names = []\n", 765 | "for name, module in model.named_modules():\n", 766 | "    if isinstance(module, nn.Linear) and \"decoder.layers.\" in name:\n", 767 | "        target_names.append(name)\n", 768 | "\n", 769 | "# replace each module with LoRA\n", 770 | "for name in target_names:\n", 771 | "    name_struct = name.split(\".\")\n", 772 | "    # get target module (walk down from the root model)\n", 773 | "    module_list = [model]\n", 774 | "    for struct in name_struct:\n", 775 | "        module_list.append(getattr(module_list[-1], struct))\n", 776 | "    # build LoRA\n", 777 | "    lora = LoRA_Linear(\n", 778 | "        weight=module_list[-1].weight,\n", 779 | "        bias=module_list[-1].bias,\n", 780 | "        lora_dim=lora_dim,\n", 781 | "    ).to(device)\n", 782 | "    # replace the original module on its parent\n", 783 | "    setattr(module_list[-2], name_struct[-1], lora)" 784 | ] 785 | }, 786 | { 787 | "cell_type": "markdown", 788 | "id": "8aae2df9-fae7-4ecc-8260-80e8e578d951", 789 | "metadata": {}, 790 | "source": [ 791 | "See how the model has changed." 
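As a rough back-of-the-envelope check of what the replacement adds, the sketch below counts the LoRA parameters against the frozen linear weights, using the layer shapes from the printed OPT-125M structure and `lora_dim = 128` as set in this notebook. The names `layer_shapes`, `lora_param_count`, and `dense_param_count` are illustrative assumptions, and biases are left out of the dense count for simplicity.

```python
# (in_features, out_features) of the six linear layers in one OPT-125M decoder layer,
# taken from the printed model structure.
layer_shapes = {
    "k_proj": (768, 768),
    "v_proj": (768, 768),
    "q_proj": (768, 768),
    "out_proj": (768, 768),
    "fc1": (768, 3072),
    "fc2": (3072, 768),
}
lora_dim = 128   # rank used in this notebook
num_layers = 12  # decoder layers in OPT-125M

def lora_param_count(in_features, out_features, rank):
    # lora_right has shape (in_features, rank); lora_left has shape (rank, out_features).
    return in_features * rank + rank * out_features

def dense_param_count(in_features, out_features):
    # Weight matrix of the frozen linear layer (bias not counted).
    return in_features * out_features

lora_total = num_layers * sum(
    lora_param_count(i, o, lora_dim) for i, o in layer_shapes.values())
dense_total = num_layers * sum(
    dense_param_count(i, o) for i, o in layer_shapes.values())

print(f"LoRA parameters   : {lora_total:,}")
print(f"frozen weights    : {dense_total:,}")
print(f"LoRA / frozen     : {lora_total / dense_total:.2f}")
```

With a rank this large relative to the 768-dimensional hidden size, the LoRA matrices are not tiny compared with the layers they wrap; a smaller `lora_dim` shrinks the trainable count proportionally.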
792 | ] 793 | }, 794 | { 795 | "cell_type": "code", 796 | "execution_count": 19, 797 | "id": "bf16b414-b973-40eb-be81-fd2aa3dde439", 798 | "metadata": {}, 799 | "outputs": [ 800 | { 801 | "data": { 802 | "text/plain": [ 803 | "OPTForCausalLM(\n", 804 | " (model): OPTModel(\n", 805 | " (decoder): OPTDecoder(\n", 806 | " (embed_tokens): Embedding(50272, 768, padding_idx=1)\n", 807 | " (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)\n", 808 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 809 | " (layers): ModuleList(\n", 810 | " (0-11): 12 x OPTDecoderLayer(\n", 811 | " (self_attn): OPTAttention(\n", 812 | " (k_proj): LoRA_Linear(\n", 813 | " (linear): Linear(in_features=768, out_features=768, bias=True)\n", 814 | " )\n", 815 | " (v_proj): LoRA_Linear(\n", 816 | " (linear): Linear(in_features=768, out_features=768, bias=True)\n", 817 | " )\n", 818 | " (q_proj): LoRA_Linear(\n", 819 | " (linear): Linear(in_features=768, out_features=768, bias=True)\n", 820 | " )\n", 821 | " (out_proj): LoRA_Linear(\n", 822 | " (linear): Linear(in_features=768, out_features=768, bias=True)\n", 823 | " )\n", 824 | " )\n", 825 | " (activation_fn): ReLU()\n", 826 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 827 | " (fc1): LoRA_Linear(\n", 828 | " (linear): Linear(in_features=768, out_features=3072, bias=True)\n", 829 | " )\n", 830 | " (fc2): LoRA_Linear(\n", 831 | " (linear): Linear(in_features=3072, out_features=768, bias=True)\n", 832 | " )\n", 833 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 834 | " )\n", 835 | " )\n", 836 | " )\n", 837 | " )\n", 838 | " (lm_head): Linear(in_features=768, out_features=50272, bias=False)\n", 839 | ")" 840 | ] 841 | }, 842 | "execution_count": 19, 843 | "metadata": {}, 844 | "output_type": "execute_result" 845 | } 846 | ], 847 | "source": [ 848 | "model" 849 | ] 850 | }, 851 | { 852 | "cell_type": "markdown", 853 | "id": 
"e9099c08-f6a6-45f8-939b-cc3ed9415976", 854 | "metadata": {}, 855 | "source": [ 856 | "Finally, freeze all parameters except for LoRA parameters." 857 | ] 858 | }, 859 | { 860 | "cell_type": "code", 861 | "execution_count": 20, 862 | "id": "81d06bba-955b-4806-8ff7-f217252e3268", 863 | "metadata": {}, 864 | "outputs": [], 865 | "source": [ 866 | "for name, param in model.named_parameters():\n", 867 | " if \"lora_right\" in name or \"lora_left\" in name:\n", 868 | " param.requires_grad = True\n", 869 | " else:\n", 870 | " param.requires_grad = False" 871 | ] 872 | }, 873 | { 874 | "cell_type": "code", 875 | "execution_count": null, 876 | "id": "6c0a4469-2827-4f30-9324-711a9feea1ae", 877 | "metadata": {}, 878 | "outputs": [], 879 | "source": [ 880 | "### Do this when you run adapter fine-tuning on Hugging Face framework\n", 881 | "# model.gradient_checkpointing_enable()\n", 882 | "# model.enable_input_require_grads()" 883 | ] 884 | }, 885 | { 886 | "cell_type": "markdown", 887 | "id": "6d6c7d6f-6c50-4839-88a5-c851caab9ba2", 888 | "metadata": {}, 889 | "source": [ 890 | "## Fine-tune" 891 | ] 892 | }, 893 | { 894 | "cell_type": "markdown", 895 | "id": "a12b875f-36cc-40b8-aaab-1efda68710f3", 896 | "metadata": {}, 897 | "source": [ 898 | "Now let's start to run fine-tuning.\n", 899 | "\n", 900 | "First we build optimizer as follows." 901 | ] 902 | }, 903 | { 904 | "cell_type": "code", 905 | "execution_count": 21, 906 | "id": "bb51298a-2d55-466c-a990-0ea08a247350", 907 | "metadata": {}, 908 | "outputs": [], 909 | "source": [ 910 | "optimizer = torch.optim.AdamW(\n", 911 | " params=model.parameters(),\n", 912 | " lr=1e-3,\n", 913 | " betas=(0.9, 0.95),\n", 914 | ")" 915 | ] 916 | }, 917 | { 918 | "cell_type": "markdown", 919 | "id": "d37db1a8-0053-4acc-94ce-89d87c78942e", 920 | "metadata": {}, 921 | "source": [ 922 | "In this example, we build cosine scheduler for training." 
923 | ] 924 | }, 925 | { 926 | "cell_type": "code", 927 | "execution_count": 22, 928 | "id": "6f95bdf6-4498-4d40-90aa-1267d55f38c3", 929 | "metadata": {}, 930 | "outputs": [], 931 | "source": [ 932 | "from torch.optim.lr_scheduler import LambdaLR\n", 933 | "\n", 934 | "num_epochs = 2\n", 935 | "\n", 936 | "num_update_steps = math.ceil(len(dataloader) / gradient_accumulation_steps) # len(dataloader) is already the number of batches\n", 937 | "def _get_cosine_schedule(\n", 938 | " current_step: int,\n", 939 | " num_warmup_steps: int = 0,\n", 940 | " num_training_steps: int = num_epochs * num_update_steps\n", 941 | "):\n", 942 | " if current_step < num_warmup_steps:\n", 943 | " return float(current_step) / float(max(1, num_warmup_steps))\n", 944 | " progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n", 945 | " return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))\n", 946 | "scheduler = LambdaLR(optimizer, lr_lambda=_get_cosine_schedule)" 947 | ] 948 | }, 949 | { 950 | "cell_type": "markdown", 951 | "id": "a9f9e828-c4fb-493d-a6de-78e03dbf035e", 952 | "metadata": {}, 953 | "source": [ 954 | "Run fine-tuning."
955 | ] 956 | }, 957 | { 958 | "cell_type": "code", 959 | "execution_count": 23, 960 | "id": "75d22125-830a-4ec6-8417-cdb8a97ec559", 961 | "metadata": {}, 962 | "outputs": [ 963 | { 964 | "name": "stdout", 965 | "output_type": "stream", 966 | "text": [ 967 | "Epoch 1 42/42 - loss: 1.0724\n", 968 | "Epoch 2 42/42 - loss: 1.3185\n" 969 | ] 970 | } 971 | ], 972 | "source": [ 973 | "from torch.nn import functional as F\n", 974 | "\n", 975 | "if os.path.exists(\"loss.txt\"):\n", 976 | " os.remove(\"loss.txt\")\n", 977 | "\n", 978 | "for epoch in range(num_epochs):\n", 979 | " optimizer.zero_grad()\n", 980 | " model.train()\n", 981 | " for i, (inputs, labels, masks) in enumerate(dataloader):\n", 982 | " with torch.set_grad_enabled(True):\n", 983 | " outputs = model(\n", 984 | " input_ids=inputs,\n", 985 | " attention_mask=masks,\n", 986 | " )\n", 987 | " loss = F.cross_entropy(outputs.logits.transpose(1,2), labels)\n", 988 | " loss.backward()\n", 989 | " if ((i + 1) % gradient_accumulation_steps == 0) or \\\n", 990 | " (i + 1 == len(dataloader)):\n", 991 | " optimizer.step()\n", 992 | " optimizer.zero_grad()\n", 993 | " scheduler.step()\n", 994 | "\n", 995 | " print(f\"Epoch {epoch+1} {math.ceil((i + 1) / gradient_accumulation_steps)}/{num_update_steps} - loss: {loss.item() :2.4f}\", end=\"\\r\")\n", 996 | "\n", 997 | " # record loss\n", 998 | " with open(\"loss.txt\", \"a\") as f:\n", 999 | " f.write(str(loss.item()))\n", 1000 | " f.write(\"\\n\")\n", 1001 | " print(\"\")\n", 1002 | "\n", 1003 | "# save model\n", 1004 | "torch.save(model.state_dict(), \"finetuned_opt.bin\")" 1005 | ] 1006 | }, 1007 | { 1008 | "cell_type": "markdown", 1009 | "id": "83993d92-d7ed-4a07-8985-cc59bd4e4fef", 1010 | "metadata": {}, 1011 | "source": [ 1012 | "> Note : Here we save the LoRA-enabled model as it is, but you can also merge the trained LoRA parameters into the original linear layers' weights."
1013 | ] 1014 | }, 1015 | { 1016 | "cell_type": "markdown", 1017 | "id": "1bc086e5-e93f-4264-a8fa-6428f844ac3c", 1018 | "metadata": {}, 1019 | "source": [ 1020 | "Plot the recorded loss over the training steps." 1021 | ] 1022 | }, 1023 | { 1024 | "cell_type": "code", 1025 | "execution_count": 25, 1026 | "id": "e37c5aee-38d4-4a2a-952c-4fd2bef41e2b", 1027 | "metadata": {}, 1028 | "outputs": [ 1029 | { 1030 | "data": { 1031 | "image/png": "", 1032 | "text/plain": [ 1033 | "
" 1034 | ] 1035 | }, 1036 | "metadata": {}, 1037 | "output_type": "display_data" 1038 | } 1039 | ], 1040 | "source": [ 1041 | "import matplotlib.pyplot as plt\n", 1042 | "import pandas as pd\n", 1043 | "\n", 1044 | "data = pd.read_csv(\"loss.txt\", header=None) # loss.txt has no header row\n", 1045 | "plt.plot(data)\n", 1046 | "plt.show()" 1047 | ] 1048 | }, 1049 | { 1050 | "cell_type": "markdown", 1051 | "id": "9809bc9f-4ff6-46c3-9c43-08c6c2694a82", 1052 | "metadata": {}, 1053 | "source": [ 1054 | "## Generate text with fine-tuned model\n", 1055 | "\n", 1056 | "Again we check results with our test dataset (5 rows).
\n", 1057 | "As you can see below, it can output the completion very well, because it's fine-tuned." 1058 | ] 1059 | }, 1060 | { 1061 | "cell_type": "code", 1062 | "execution_count": 26, 1063 | "id": "29903cae-404e-4209-9c84-6c8a69609c13", 1064 | "metadata": {}, 1065 | "outputs": [ 1066 | { 1067 | "name": "stdout", 1068 | "output_type": "stream", 1069 | "text": [ 1070 | "********** input **********\n", 1071 | "name : The Punter | Type : pub | food : Chinese | price : more than £ 30 | area : riverside | family friendly : yes | near : Raja Indian Cuisine\n", 1072 | "\n", 1073 | "********** result **********\n", 1074 | "name : The Punter | Type : pub | food : Chinese | price : more than £ 30 | area : riverside | family friendly : yes | near : Raja Indian Cuisine\n", 1075 | "The Punter is a children friendly pub that serves Chinese food. It is located in the riverside area near Raja Indian Cuisine and has a\n", 1076 | "********** input **********\n", 1077 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : 5 out of 5 | area : city centre | family friendly : no | near : All Bar One\n", 1078 | "\n", 1079 | "********** result **********\n", 1080 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : 5 out of 5 | area : city centre | family friendly : no | near : All Bar One\n", 1081 | "The Cricketers is a Chinese restaurant with a cheap price range, located in the city centre near All Bar One. 
It has a customer rating of\n", 1082 | "********** input **********\n", 1083 | "name : The Phoenix | Type : pub | food : French | price : moderate | customer rating : 1 out of 5 | area : riverside | family friendly : no | near : Crowne Plaza Hotel\n", 1084 | "\n", 1085 | "********** result **********\n", 1086 | "name : The Phoenix | Type : pub | food : French | price : moderate | customer rating : 1 out of 5 | area : riverside | family friendly : no | near : Crowne Plaza Hotel\n", 1087 | "The Phoenix is a pub that serves French food. It is located near Crown Plaza Hotel in the riverside area. It has a moderate price range and\n", 1088 | "********** input **********\n", 1089 | "name : Giraffe | Type : restaurant | food : Fast food | area : riverside | family friendly : yes | near : Rainbow Vegetarian Café\n", 1090 | "\n", 1091 | "********** result **********\n", 1092 | "name : Giraffe | Type : restaurant | food : Fast food | area : riverside | family friendly : yes | near : Rainbow Vegetarian Café\n", 1093 | "Giraffe is a fast food restaurant located in the riverside area near Rainbow Vegetarian Café. 
It is family friendly.\n", 1094 | "********** input **********\n", 1095 | "name : The Vaults | Type : pub | food : French | price : more than £ 30 | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n", 1096 | "\n", 1097 | "********** result **********\n", 1098 | "name : The Vaults | Type : pub | food : French | price : more than £ 30 | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n", 1099 | "The Vaults is a children friendly French pub located in the city centre near Raja Indian Cuisine.\n" 1100 | ] 1101 | } 1102 | ], 1103 | "source": [ 1104 | "test_data = pd.read_json(\"test_formatted.jsonl\", lines=True)\n", 1105 | "test_data = test_data[::2] # because it's duplicated\n", 1106 | "test_loader = DataLoader(\n", 1107 | " list(zip(test_data[\"context\"], [\"\"] * len(test_data[\"context\"]))),\n", 1108 | " batch_size=1,\n", 1109 | " shuffle=True,\n", 1110 | " collate_fn=collate_batch\n", 1111 | ")\n", 1112 | "\n", 1113 | "for i, (input, _, mask) in enumerate(test_loader):\n", 1114 | " if i == 5:\n", 1115 | " break\n", 1116 | " print(\"********** input **********\")\n", 1117 | " input_len = torch.sum(mask).cpu().numpy()\n", 1118 | " print(tokenizer.decode(input[0][:input_len]))\n", 1119 | " result_token, result_len = generate_text(\n", 1120 | " model,\n", 1121 | " input,\n", 1122 | " mask,\n", 1123 | " eos_id,\n", 1124 | " pred_sequence_length=30)\n", 1125 | " print(\"********** result **********\")\n", 1126 | " print(tokenizer.decode(result_token[0][:result_len]))" 1127 | ] 1128 | }, 1129 | { 1130 | "cell_type": "code", 1131 | "execution_count": null, 1132 | "id": "6a7c1dd3-4057-497a-83ae-f99b1883697e", 1133 | "metadata": {}, 1134 | "outputs": [], 1135 | "source": [] 1136 | } 1137 | ], 1138 | "metadata": { 1139 | "kernelspec": { 1140 | "display_name": "Python 3 (ipykernel)", 1141 | "language": "python", 1142 | "name": "python3" 1143 | }, 1144 | "language_info": { 1145 | "codemirror_mode": { 1146 | "name": "ipython", 
1147 | "version": 3 1148 | }, 1149 | "file_extension": ".py", 1150 | "mimetype": "text/x-python", 1151 | "name": "python", 1152 | "nbconvert_exporter": "python", 1153 | "pygments_lexer": "ipython3", 1154 | "version": "3.8.10" 1155 | } 1156 | }, 1157 | "nbformat": 4, 1158 | "nbformat_minor": 5 1159 | } 1160 | -------------------------------------------------------------------------------- /02-finetune-gpt2-with-lora.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "857cafc6-da38-4aa7-8afc-63aa626fa7aa", 6 | "metadata": {}, 7 | "source": [ 8 | "# 02. Finetuning GPT-2 with LoRA\n", 9 | "\n", 10 | "In this example, we fine-tune the pre-trained auto-regressive model, **OpenAI's GPT-2** (small version, 124M parameters), by applying LoRA (Low-Rank Adaptation) optimization.\n", 11 | "\n", 12 | "In this example, I download the pre-trained model from Hugging Face hub, but fine-tune model with regular PyTorch training loop.
\n", 13 | "(Here I don't use Hugging Face Trainer class.)\n", 14 | "\n", 15 | "See [Readme](https://github.com/tsmatz/finetune_llm_with_lora) for prerequisite's setup." 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "id": "3d49acf1-9ad1-4a6c-9312-6785cb3f5862", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "model_name = \"gpt2\"" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "id": "4d835e84-a01d-4c33-926b-60d9dd4a7627", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "import torch\n", 36 | "\n", 37 | "device = torch.device(\"cuda\")" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "ead383e5-149b-4bfb-9324-3cc639fd398d", 43 | "metadata": {}, 44 | "source": [ 45 | "## Prepare dataset and dataloader" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "91ecbb08-6a74-4623-bfe8-bddba5254e35", 51 | "metadata": {}, 52 | "source": [ 53 | "In this example, we use dataset used in [official LoRA example](https://github.com/microsoft/LoRA).\n", 54 | "\n", 55 | "Download dataset from official repository." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "id": "54a564f1-f8f3-42a6-b160-bebdbcc3aac0", 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stdout", 66 | "output_type": "stream", 67 | "text": [ 68 | "--2023-10-06 03:27:50-- https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/train.txt\n", 69 | "Resolving github.com (github.com)... 140.82.114.3\n", 70 | "Connecting to github.com (github.com)|140.82.114.3|:443... connected.\n", 71 | "HTTP request sent, awaiting response... 302 Found\n", 72 | "Location: https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/train.txt [following]\n", 73 | "--2023-10-06 03:27:51-- https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/train.txt\n", 74 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 
185.199.110.133, 185.199.109.133, 185.199.108.133, ...\n", 75 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n", 76 | "HTTP request sent, awaiting response... 200 OK\n", 77 | "Length: 9624463 (9.2M) [text/plain]\n", 78 | "Saving to: ‘train.txt’\n", 79 | "\n", 80 | "train.txt 100%[===================>] 9.18M --.-KB/s in 0.04s \n", 81 | "\n", 82 | "2023-10-06 03:27:51 (248 MB/s) - ‘train.txt’ saved [9624463/9624463]\n", 83 | "\n" 84 | ] 85 | } 86 | ], 87 | "source": [ 88 | "!wget https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/train.txt" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 4, 94 | "id": "d48464ea-991f-48b2-9166-3323cfd61676", 95 | "metadata": { 96 | "scrolled": true 97 | }, 98 | "outputs": [ 99 | { 100 | "name": "stdout", 101 | "output_type": "stream", 102 | "text": [ 103 | "--2023-10-06 03:27:54-- https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/test.txt\n", 104 | "Resolving github.com (github.com)... 140.82.114.3\n", 105 | "Connecting to github.com (github.com)|140.82.114.3|:443... connected.\n", 106 | "HTTP request sent, awaiting response... 302 Found\n", 107 | "Location: https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/test.txt [following]\n", 108 | "--2023-10-06 03:27:54-- https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/test.txt\n", 109 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.109.133, ...\n", 110 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n", 111 | "HTTP request sent, awaiting response... 
200 OK\n", 112 | "Length: 1351149 (1.3M) [text/plain]\n", 113 | "Saving to: ‘test.txt’\n", 114 | "\n", 115 | "test.txt 100%[===================>] 1.29M --.-KB/s in 0.006s \n", 116 | "\n", 117 | "2023-10-06 03:27:54 (208 MB/s) - ‘test.txt’ saved [1351149/1351149]\n", 118 | "\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "!wget https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/test.txt" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "id": "09472803-8c62-48e0-9a63-b9b9448f16d3", 129 | "metadata": {}, 130 | "source": [ 131 | "Show the downloaded data (first 5 rows)." 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 5, 137 | "id": "e6e60596-028f-4c4b-a95d-f74a0ff3b188", 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "name : The Vaults | Type : pub | price : more than £ 30 | customer rating : 5 out of 5 | near : Café Adriatic||The Vaults pub near Café Adriatic has a 5 star rating . Prices start at £ 30 . \n", 145 | "name : The Cambridge Blue | Type : pub | food : English | price : cheap | near : Café Brazil||Close to Café Brazil , The Cambridge Blue pub serves delicious Tuscan Beef for the cheap price of £ 10.50 . Delicious Pub food . \n", 146 | "name : The Eagle | Type : coffee shop | food : Japanese | price : less than £ 20 | customer rating : low | area : riverside | family friendly : yes | near : Burger King||The Eagle is a low rated coffee shop near Burger King and the riverside that is family friendly and is less than £ 20 for Japanese food . \n", 147 | "name : The Mill | Type : coffee shop | food : French | price : £ 20 - 25 | area : riverside | near : The Sorrento||Located near The Sorrento is a French Theme eatery and coffee shop called The Mill , with a price range at £ 20- £ 25 it is in the riverside area . 
\n", 148 | "name : Loch Fyne | food : French | customer rating : high | area : riverside | near : The Rice Boat||For luxurious French food , the Loch Fyne is located by the river next to The Rice Boat . \n" 149 | ] 150 | } 151 | ], 152 | "source": [ 153 | "!head -n 5 train.txt" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "id": "93f5fabe-590c-459b-aa16-4b5a506fb54b", 159 | "metadata": {}, 160 | "source": [ 161 | "Convert above data into JsonL format." 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "id": "7376e0c0-16c9-46f4-ad4c-83d1a677f5a2", 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "import sys\n", 172 | "import io\n", 173 | "import json\n", 174 | "\n", 175 | "def format_convert(read_file, write_file):\n", 176 | " with open(read_file, \"r\", encoding=\"utf8\") as reader, \\\n", 177 | " \t open(write_file, \"w\", encoding=\"utf8\") as writer :\n", 178 | " \tfor line in reader:\n", 179 | " \t\titems = line.strip().split(\"||\")\n", 180 | " \t\tcontext = items[0]\n", 181 | " \t\tcompletion = items[1].strip(\"\\n\")\n", 182 | " \t\tx = {}\n", 183 | " \t\tx[\"context\"] = context\n", 184 | " \t\tx[\"completion\"] = completion\n", 185 | " \t\twriter.write(json.dumps(x)+\"\\n\")\n", 186 | "\n", 187 | "format_convert(\"train.txt\", \"train_formatted.jsonl\")\n", 188 | "format_convert(\"test.txt\", \"test_formatted.jsonl\")" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "id": "3ceec952-fe03-475f-9f3e-22237cc9c44b", 194 | "metadata": {}, 195 | "source": [ 196 | "Show the converted data (first 5 rows)." 
197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 7, 202 | "id": "cb32aca7-bd0e-4847-a4c2-cc7e67dc2b7a", 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "{\"context\": \"name : The Vaults | Type : pub | price : more than \\u00a3 30 | customer rating : 5 out of 5 | near : Caf\\u00e9 Adriatic\", \"completion\": \"The Vaults pub near Caf\\u00e9 Adriatic has a 5 star rating . Prices start at \\u00a3 30 .\"}\n", 210 | "\n", 211 | "{\"context\": \"name : The Cambridge Blue | Type : pub | food : English | price : cheap | near : Caf\\u00e9 Brazil\", \"completion\": \"Close to Caf\\u00e9 Brazil , The Cambridge Blue pub serves delicious Tuscan Beef for the cheap price of \\u00a3 10.50 . Delicious Pub food .\"}\n", 212 | "\n", 213 | "{\"context\": \"name : The Eagle | Type : coffee shop | food : Japanese | price : less than \\u00a3 20 | customer rating : low | area : riverside | family friendly : yes | near : Burger King\", \"completion\": \"The Eagle is a low rated coffee shop near Burger King and the riverside that is family friendly and is less than \\u00a3 20 for Japanese food .\"}\n", 214 | "\n", 215 | "{\"context\": \"name : The Mill | Type : coffee shop | food : French | price : \\u00a3 20 - 25 | area : riverside | near : The Sorrento\", \"completion\": \"Located near The Sorrento is a French Theme eatery and coffee shop called The Mill , with a price range at \\u00a3 20- \\u00a3 25 it is in the riverside area .\"}\n", 216 | "\n", 217 | "{\"context\": \"name : Loch Fyne | food : French | customer rating : high | area : riverside | near : The Rice Boat\", \"completion\": \"For luxurious French food , the Loch Fyne is located by the river next to The Rice Boat .\"}\n", 218 | "\n" 219 | ] 220 | } 221 | ], 222 | "source": [ 223 | "with open(\"train_formatted.jsonl\", \"r\") as reader:\n", 224 | " for _ in range(5):\n", 225 | " print(next(reader))" 226 | ] 227 | }, 228 | 
{ 229 | "cell_type": "markdown", 230 | "id": "6631f786-be4b-40cf-89d9-7009c1888821", 231 | "metadata": {}, 232 | "source": [ 233 | "Load tokenizer from Hugging Face." 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 8, 239 | "id": "e5433dc0-b5a5-4c01-adb5-3ffa2279eca8", 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "from transformers import AutoTokenizer\n", 244 | "import os\n", 245 | "\n", 246 | "tokenizer = AutoTokenizer.from_pretrained(\n", 247 | " model_name,\n", 248 | " use_fast=True)\n", 249 | "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "id": "50817c47-a97b-4f80-975b-836859a0a7cf", 255 | "metadata": {}, 256 | "source": [ 257 | "Set the block size, i.e. the fixed token length to which each sequence is truncated or padded before it is fed to the model.
\n", 258 | "GPT-2 accepts up to 1024 tokens, but here I set 512, because that is enough for our dataset." 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 9, 264 | "id": "5f250929-5703-4b17-9f7b-26340950c055", 265 | "metadata": {}, 266 | "outputs": [ 267 | { 268 | "name": "stdout", 269 | "output_type": "stream", 270 | "text": [ 271 | "Max length of tokens is 1024 in this model.\n", 272 | "But here we use max 512 tokens in the training.\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "block_size = 512\n", 278 | "\n", 279 | "print(f\"Max length of tokens is {tokenizer.model_max_length} in this model.\")\n", 280 | "print(f\"But here we use max {block_size} tokens in the training.\")" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "id": "2332617b-1e66-4812-ad47-5eaeb52b101b", 286 | "metadata": {}, 287 | "source": [ 288 | "Create a function to convert data. (This function is later used as the collate function of the data loader.)
\n", 289 | "In this function,\n", 290 | "\n", 291 | "1. Tokenize both contexts and completions, e.g., ```\"This is a pen.\"``` --> ```[1212, 318, 257, 3112, 13]```\n", 292 | "2. Concatenate the context tokens and the completion tokens, delimited by \"\\n\" between context and completion. This sequence is used as the input for the LLM.\n", 293 | "3. Create labels (targets) from the inputs. The label is ```input[1:]``` followed by the EOS token (i.e., the inputs shifted left by one element), and it is filled with ```-100``` in the context positions. (See below note.)\n", 294 | "4. Truncate or pad the tokens so that each sequence has length ```block_size```.\n", 295 | "\n", 296 | "> Note : Here I set ```-100``` as an ignored index for loss computation, because PyTorch's cross-entropy function (```torch.nn.functional.cross_entropy()```) has an ```ignore_index``` parameter whose default value is ```-100```." 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 10, 302 | "id": "9f2f38aa-b3d0-4614-aa59-8ddd977176d1", 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "from torch.utils.data import DataLoader\n", 307 | "import pandas as pd\n", 308 | "\n", 309 | "def fill_ignore_label(l, c):\n", 310 | " l[:len(c) - 1] = [-100] * (len(c) - 1)\n", 311 | " return l\n", 312 | "\n", 313 | "def pad_tokens(tokens, max_seq_length, padding_token):\n", 314 | " res_tokens = tokens[:max_seq_length]\n", 315 | " token_len = len(res_tokens)\n", 316 | " res_tokens = res_tokens + \\\n", 317 | " [padding_token for _ in range(max_seq_length - token_len)]\n", 318 | " return res_tokens\n", 319 | "\n", 320 | "def collate_batch(batch):\n", 321 | " # tokenize both context and completion respectively\n", 322 | " # (context and completion is delimited by \"\\n\")\n", 323 | " context_list = list(zip(*batch))[0]\n", 324 | " context_list = [c + \"\\n\" for c in context_list]\n", 325 | " completion_list = list(zip(*batch))[1]\n", 326 | " context_result = tokenizer(context_list)\n", 327 | " context_tokens = context_result[\"input_ids\"]\n", 
328 | " context_masks = context_result[\"attention_mask\"]\n", 329 | " completion_result = tokenizer(completion_list)\n", 330 | " completion_tokens = completion_result[\"input_ids\"]\n", 331 | " completion_masks = completion_result[\"attention_mask\"]\n", 332 | " # concatenate token\n", 333 | " inputs = [i + j for i, j in zip(context_tokens, completion_tokens)]\n", 334 | " masks = [i + j for i, j in zip(context_masks, completion_masks)]\n", 335 | " # create label\n", 336 | " eos_id = tokenizer.encode(tokenizer.eos_token)[0]\n", 337 | " labels = [t[1:] + [eos_id] for t in inputs]\n", 338 | " labels = list(map(fill_ignore_label, labels, context_tokens))\n", 339 | " # truncate and pad tokens\n", 340 | " inputs = [pad_tokens(t, block_size, 0) for t in inputs] # the pad id (0) is arbitrary here, because padded positions are masked out by the attention mask\n", 341 | " masks = [pad_tokens(t, block_size, 0) for t in masks]\n", 342 | " labels = [pad_tokens(t, block_size, -100) for t in labels]\n", 343 | " # convert to tensor\n", 344 | " inputs = torch.tensor(inputs, dtype=torch.int64).to(device)\n", 345 | " masks = torch.tensor(masks, dtype=torch.int64).to(device)\n", 346 | " labels = torch.tensor(labels, dtype=torch.int64).to(device)\n", 347 | " return inputs, labels, masks" 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "id": "2084d2e9-ef64-47a2-aec9-d24ead1cb38a", 353 | "metadata": {}, 354 | "source": [ 355 | "Now create a PyTorch dataloader that uses the previous function as its collate function.\n", 356 | "\n", 357 | "> Note : In this example, the data is small, so we load all the JSON data into memory.
\n", 358 | "> When it's large, load data progressively by implementing custom PyTorch dataset. (See [here](https://github.com/tsmatz/decision-transformer) for example.)" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 11, 364 | "id": "f3bce3bb-2215-4bd6-a6a6-5b6b9d5afdc0", 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "batch_size = 8\n", 369 | "gradient_accumulation_steps = 16\n", 370 | "\n", 371 | "data = pd.read_json(\"train_formatted.jsonl\", lines=True)\n", 372 | "dataloader = DataLoader(\n", 373 | " list(zip(data[\"context\"], data[\"completion\"])),\n", 374 | " batch_size=batch_size,\n", 375 | " shuffle=True,\n", 376 | " collate_fn=collate_batch\n", 377 | ")" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "id": "3ba64144-b698-457e-b827-941020456536", 383 | "metadata": {}, 384 | "source": [ 385 | "## Load model" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "id": "1bfd360d-7bdc-4fd7-9b12-bcf9fe0a8db2", 391 | "metadata": {}, 392 | "source": [ 393 | "Load model from Hugging Face." 
394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": 12, 399 | "id": "271181bd-677a-4da9-9e57-2874f5e47bd0", 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "from transformers import AutoModelForCausalLM, AutoConfig\n", 404 | "\n", 405 | "config = AutoConfig.from_pretrained(model_name)\n", 406 | "model = AutoModelForCausalLM.from_pretrained(\n", 407 | " model_name,\n", 408 | " config=config,\n", 409 | ").to(device)" 410 | ] 411 | }, 412 | { 413 | "cell_type": "markdown", 414 | "id": "27ab764a-d634-40f8-9edb-a01146845233", 415 | "metadata": {}, 416 | "source": [ 417 | "## Generate text (before fine-tuning)" 418 | ] 419 | }, 420 | { 421 | "cell_type": "markdown", 422 | "id": "559efeaf-4b38-4a0c-9be6-eb394221e374", 423 | "metadata": {}, 424 | "source": [ 425 | "Now run prediction with the downloaded model (which is not yet fine-tuned).\n", 426 | "\n", 427 | "First we create a function to generate text." 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 13, 433 | "id": "51a0c4fc-e0a7-4bbf-b25a-c335fe61f3df", 434 | "metadata": {}, 435 | "outputs": [], 436 | "source": [ 437 | "def generate_text(model, input, mask, eos_id, pred_sequence_length):\n", 438 | " predicted_last_id = -1\n", 439 | " start_token_len = torch.sum(mask).cpu().numpy()\n", 440 | " token_len = start_token_len\n", 441 | " with torch.no_grad():\n", 442 | " while (predicted_last_id != eos_id) and \\\n", 443 | " (token_len - start_token_len < pred_sequence_length):\n", 444 | " output = model(\n", 445 | " input_ids=input,\n", 446 | " attention_mask=mask,\n", 447 | " )\n", 448 | " predicted_ids = torch.argmax(output.logits, axis=-1).cpu().numpy()\n", 449 | " predicted_last_id = predicted_ids[0][token_len - 1]\n", 450 | " input[0][token_len] = predicted_last_id\n", 451 | " mask[0][token_len] = 1\n", 452 | " token_len = torch.sum(mask).cpu().numpy()\n", 453 | " return input, token_len" 454 | ] 455 | }, 456 | { 457 | "cell_type": "markdown", 458 
| "id": "3936b1a1-ae9f-48a5-80db-691261dda704", 459 | "metadata": {}, 460 | "source": [ 461 | "Let's test our function and generate text. (Here we stop the text generation when it reaches 15 tokens in prediction.)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 14, 467 | "id": "28b7e13f-e8fb-4a9f-90ed-0464463ef569", 468 | "metadata": {}, 469 | "outputs": [ 470 | { 471 | "name": "stdout", 472 | "output_type": "stream", 473 | "text": [ 474 | "Once upon a time, the world was a place of great beauty and great danger. The world was\n", 475 | "My name is Clara and I am a woman. I am a woman who is a woman. I am a\n" 476 | ] 477 | } 478 | ], 479 | "source": [ 480 | "eos_id = tokenizer.encode(tokenizer.eos_token)[0]\n", 481 | "\n", 482 | "result = tokenizer(\"Once upon a time,\")\n", 483 | "input = result[\"input_ids\"]\n", 484 | "mask = result[\"attention_mask\"]\n", 485 | "input = pad_tokens(input, block_size, 0)\n", 486 | "mask = pad_tokens(mask, block_size, 0)\n", 487 | "input = torch.tensor([input], dtype=torch.int64).to(device)\n", 488 | "mask = torch.tensor([mask], dtype=torch.int64).to(device)\n", 489 | "\n", 490 | "result_token, result_len = generate_text(\n", 491 | " model,\n", 492 | " input,\n", 493 | " mask,\n", 494 | " eos_id,\n", 495 | " pred_sequence_length=15)\n", 496 | "print(tokenizer.decode(result_token[0][:result_len]))\n", 497 | "\n", 498 | "result = tokenizer(\"My name is Clara and I am\")\n", 499 | "input = result[\"input_ids\"]\n", 500 | "mask = result[\"attention_mask\"]\n", 501 | "input = pad_tokens(input, block_size, 0)\n", 502 | "mask = pad_tokens(mask, block_size, 0)\n", 503 | "input = torch.tensor([input], dtype=torch.int64).to(device)\n", 504 | "mask = torch.tensor([mask], dtype=torch.int64).to(device)\n", 505 | "\n", 506 | "result_token, result_len = generate_text(\n", 507 | " model,\n", 508 | " input,\n", 509 | " mask,\n", 510 | " eos_id,\n", 511 | " pred_sequence_length=15)\n", 512 | 
"print(tokenizer.decode(result_token[0][:result_len]))" 513 | ] 514 | }, 515 | { 516 | "cell_type": "markdown", 517 | "id": "d48fb60b-c05d-4884-a9bc-92152c94c894", 518 | "metadata": {}, 519 | "source": [ 520 | "Now we generate text with our test dataset (5 rows).
\n", 521 | "As you can see below, it cannot produce a good completion, because it is not yet fine-tuned." 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": 15, 527 | "id": "495728ef-fbe6-4953-a354-4b7a8bb88798", 528 | "metadata": {}, 529 | "outputs": [ 530 | { 531 | "name": "stdout", 532 | "output_type": "stream", 533 | "text": [ 534 | "********** input **********\n", 535 | "name : Wildwood | Type : pub | food : Indian | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n", 536 | "\n", 537 | "********** result **********\n", 538 | "name : Wildwood | Type : pub | food : Indian | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n", 539 | "\n", 540 | "Raja Indian Cuisine : Indian | price : Rs. 1,000 | menu : Indian | menu type : food | menu size :\n", 541 | "********** input **********\n", 542 | "name : Giraffe | Type : pub | food : Fast food | area : riverside | family friendly : yes | near : Rainbow Vegetarian Café\n", 543 | "\n", 544 | "********** result **********\n", 545 | "name : Giraffe | Type : pub | food : Fast food | area : riverside | family friendly : yes | near : Rainbow Vegetarian Café\n", 546 | "\n", 547 | ": Giraffe | Type : pub | food : Fast food | area : riverside | family friendly : yes | near : Rainbow Vegetarian Café\n", 548 | "********** input **********\n", 549 | "name : The Waterman | Type : pub | food : Italian | price : less than £ 20 | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n", 550 | "\n", 551 | "********** result **********\n", 552 | "name : The Waterman | Type : pub | food : Italian | price : less than £ 20 | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n", 553 | "\n", 554 | "The Waterman is a pub in the heart of the city centre. 
It is a place where you can enjoy a good meal and drink a good\n", 555 | "********** input **********\n", 556 | "name : The Vaults | Type : pub | food : Italian | price : moderate | customer rating : 1 out of 5 | area : city centre | family friendly : no | near : Rainbow Vegetarian Café\n", 557 | "\n", 558 | "********** result **********\n", 559 | "name : The Vaults | Type : pub | food : Italian | price : moderate | customer rating : 1 out of 5 | area : city centre | family friendly : no | near : Rainbow Vegetarian Café\n", 560 | "\n", 561 | "The Vaults | Type : pub | food : Italian | price : moderate | customer rating : 1 out of 5 | area : city centre | family\n", 562 | "********** input **********\n", 563 | "name : The Vaults | Type : restaurant | food : French | price : less than £ 20 | area : riverside | family friendly : yes | near : Raja Indian Cuisine\n", 564 | "\n", 565 | "********** result **********\n", 566 | "name : The Vaults | Type : restaurant | food : French | price : less than £ 20 | area : riverside | family friendly : yes | near : Raja Indian Cuisine\n", 567 | "\n", 568 | "The restaurant is located in the centre of the city. It is a small restaurant with a small menu. 
The menu is very simple and the food\n" 569 | ] 570 | } 571 | ], 572 | "source": [ 573 | "test_data = pd.read_json(\"test_formatted.jsonl\", lines=True)\n", 574 | "test_data = test_data[::2] # because it's duplicated\n", 575 | "test_loader = DataLoader(\n", 576 | " list(zip(test_data[\"context\"], [\"\"] * len(test_data[\"context\"]))),\n", 577 | " batch_size=1,\n", 578 | " shuffle=True,\n", 579 | " collate_fn=collate_batch\n", 580 | ")\n", 581 | "\n", 582 | "for i, (input, _, mask) in enumerate(test_loader):\n", 583 | " if i == 5:\n", 584 | " break\n", 585 | " print(\"********** input **********\")\n", 586 | " input_len = torch.sum(mask).cpu().numpy()\n", 587 | " print(tokenizer.decode(input[0][:input_len]))\n", 588 | " result_token, result_len = generate_text(\n", 589 | " model,\n", 590 | " input,\n", 591 | " mask,\n", 592 | " eos_id,\n", 593 | " pred_sequence_length=30)\n", 594 | " print(\"********** result **********\")\n", 595 | " print(tokenizer.decode(result_token[0][:result_len]))" 596 | ] 597 | }, 598 | { 599 | "cell_type": "markdown", 600 | "id": "e3138341-e01c-4fae-af78-c61e34967e92", 601 | "metadata": {}, 602 | "source": [ 603 | "## LoRA (Low-Rank Adaptation)\n", 604 | "\n", 605 | "Now we apply LoRA in our downloaded model.
\n", 606 | "For the semantics of LoRA (Low-Rank Adaptation), see [01-finetune-opt-with-lora.ipynb](./01-finetune-opt-with-lora.ipynb).\n", 607 | "\n", 608 | "For learning purposes, here I manually (from scratch) convert the current model into a model with LoRA.\n", 609 | "\n", 610 | "> Note : With the ```PEFT``` package, you can get a LoRA model in a few lines of code. (I don't use that package here.)" 611 | ] 612 | }, 613 | { 614 | "cell_type": "markdown", 615 | "id": "e296bdf2-129a-4278-8fe3-d08333ebf1df", 616 | "metadata": {}, 617 | "source": [ 618 | "Before changing the model, we first check its structure. (See the output of the following cell.)\n", 619 | "\n", 620 | "As you can see below, there are no linear layers in the transformer blocks of OpenAI's GPT-2, unlike [Meta's OPT transformer](./01-finetune-opt-with-lora.ipynb). Instead, you will find Conv1D (1D convolution) layers.
\n", 621 | "However, this Conv1D is not ```torch.nn.Conv1d```; it is a custom layer defined for OpenAI GPT that works the same as a linear layer, except that its weights are stored transposed. (See [this source code](https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py) for the custom ```pytorch_utils.Conv1D``` implementation.)
\n", 622 | "This custom Conv1D layer (in essence, a linear layer) is used in the MLP and for computing key/value/query in the GPT-2 transformer, as follows.
\n", 623 | "(See the [source code](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py) for GPT-2 in Hugging Face Transformers.)\n", 624 | "\n", 625 | "- ```transformer.h.n.attn.c_attn``` : Layer to get key/value/query before processing attention.\n", 626 | "- ```transformer.h.n.attn.c_proj``` : Layer for projection after processing attention.\n", 627 | "- ```transformer.h.n.mlp.c_fc``` : MLP in GPT-2 is Linear(GeLU(Linear)). This is the inner Linear layer (custom Conv1D).\n", 628 | "- ```transformer.h.n.mlp.c_proj``` : MLP in GPT-2 is Linear(GeLU(Linear)). This is the outer Linear layer (custom Conv1D).\n", 629 | "\n", 630 | "In this example, we'll only convert the ```transformer.h.n.attn.c_attn``` layers into LoRA layers.
\n", 631 | "GPT-2 with 124M parameters has 12 transformer blocks, so a total of 12 layers (n = 0, 1, ..., 11) will be converted." 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 16, 637 | "id": "5acb8f62-791a-4fa4-b00c-2666cf34827f", 638 | "metadata": {}, 639 | "outputs": [ 640 | { 641 | "data": { 642 | "text/plain": [ 643 | "GPT2LMHeadModel(\n", 644 | " (transformer): GPT2Model(\n", 645 | " (wte): Embedding(50257, 768)\n", 646 | " (wpe): Embedding(1024, 768)\n", 647 | " (drop): Dropout(p=0.1, inplace=False)\n", 648 | " (h): ModuleList(\n", 649 | " (0-11): 12 x GPT2Block(\n", 650 | " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 651 | " (attn): GPT2Attention(\n", 652 | " (c_attn): Conv1D()\n", 653 | " (c_proj): Conv1D()\n", 654 | " (attn_dropout): Dropout(p=0.1, inplace=False)\n", 655 | " (resid_dropout): Dropout(p=0.1, inplace=False)\n", 656 | " )\n", 657 | " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 658 | " (mlp): GPT2MLP(\n", 659 | " (c_fc): Conv1D()\n", 660 | " (c_proj): Conv1D()\n", 661 | " (act): NewGELUActivation()\n", 662 | " (dropout): Dropout(p=0.1, inplace=False)\n", 663 | " )\n", 664 | " )\n", 665 | " )\n", 666 | " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 667 | " )\n", 668 | " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n", 669 | ")" 670 | ] 671 | }, 672 | "execution_count": 16, 673 | "metadata": {}, 674 | "output_type": "execute_result" 675 | } 676 | ], 677 | "source": [ 678 | "model" 679 | ] 680 | }, 681 | { 682 | "cell_type": "markdown", 683 | "id": "045e7239-cb8a-46dd-815d-e48e7e49eea4", 684 | "metadata": {}, 685 | "source": [ 686 | "First we build a custom linear layer with LoRA as follows." 
687 | ] 688 | }, 689 | { 690 | "cell_type": "code", 691 | "execution_count": 17, 692 | "id": "77889272-9a93-491b-93cb-b0bed5ce7cd8", 693 | "metadata": {}, 694 | "outputs": [], 695 | "source": [ 696 | "import math\n", 697 | "from torch import nn\n", 698 | "\n", 699 | "class LoRA_Linear(nn.Module):\n", 700 | "    def __init__(self, weight, bias, lora_dim):\n", 701 | "        super(LoRA_Linear, self).__init__()\n", 702 | "\n", 703 | "        row, column = weight.shape\n", 704 | "\n", 705 | "        # restore Linear\n", 706 | "        if bias is None:\n", 707 | "            self.linear = nn.Linear(column, row, bias=False)\n", 708 | "            self.linear.load_state_dict({\"weight\": weight})\n", 709 | "        else:\n", 710 | "            self.linear = nn.Linear(column, row)\n", 711 | "            self.linear.load_state_dict({\"weight\": weight, \"bias\": bias})\n", 712 | "\n", 713 | "        # create LoRA weights (with initialization)\n", 714 | "        self.lora_right = nn.Parameter(torch.zeros(column, lora_dim))\n", 715 | "        nn.init.kaiming_uniform_(self.lora_right, a=math.sqrt(5))\n", 716 | "        self.lora_left = nn.Parameter(torch.zeros(lora_dim, row))\n", 717 | "\n", 718 | "    def forward(self, input):\n", 719 | "        x = self.linear(input)\n", 720 | "        y = input @ self.lora_right @ self.lora_left\n", 721 | "        return x + y" 722 | ] 723 | }, 724 | { 725 | "cell_type": "markdown", 726 | "id": "954e2c9d-545e-4bd9-9b0f-eba3fe29a1de", 727 | "metadata": {}, 728 | "source": [ 729 | "Replace the target linear layers with LoRA layers.\n", 730 | "\n", 731 | "> Note : As mentioned above, the custom Conv1D layer in GPT-2 is in essence a linear layer, but its weights are transposed." 
732 | ] 733 | }, 734 | { 735 | "cell_type": "code", 736 | "execution_count": 18, 737 | "id": "baf8a748-a3e3-45b8-9c64-252c56abe923", 738 | "metadata": {}, 739 | "outputs": [], 740 | "source": [ 741 | "lora_dim = 128\n", 742 | "\n", 743 | "# get target module name\n", 744 | "target_names = []\n", 745 | "for name, module in model.named_modules():\n", 746 | "    if \"attn.c_attn\" in name:\n", 747 | "        target_names.append(name)\n", 748 | "\n", 749 | "# replace each module with LoRA\n", 750 | "for name in target_names:\n", 751 | "    name_struct = name.split(\".\")\n", 752 | "    # get target module\n", 753 | "    module_list = [model]\n", 754 | "    for struct in name_struct:\n", 755 | "        module_list.append(getattr(module_list[-1], struct))\n", 756 | "    # build LoRA\n", 757 | "    lora = LoRA_Linear(\n", 758 | "        weight = torch.transpose(module_list[-1].weight, 0, 1),\n", 759 | "        bias = module_list[-1].bias,\n", 760 | "        lora_dim = lora_dim,\n", 761 | "    ).to(device)\n", 762 | "    # replace\n", 763 | "    setattr(module_list[-2], name_struct[-1], lora)" 764 | ] 765 | }, 766 | { 767 | "cell_type": "markdown", 768 | "id": "8aae2df9-fae7-4ecc-8260-80e8e578d951", 769 | "metadata": {}, 770 | "source": [ 771 | "See how the model has changed." 
772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "execution_count": 19, 777 | "id": "bf16b414-b973-40eb-be81-fd2aa3dde439", 778 | "metadata": {}, 779 | "outputs": [ 780 | { 781 | "data": { 782 | "text/plain": [ 783 | "GPT2LMHeadModel(\n", 784 | " (transformer): GPT2Model(\n", 785 | " (wte): Embedding(50257, 768)\n", 786 | " (wpe): Embedding(1024, 768)\n", 787 | " (drop): Dropout(p=0.1, inplace=False)\n", 788 | " (h): ModuleList(\n", 789 | " (0-11): 12 x GPT2Block(\n", 790 | " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 791 | " (attn): GPT2Attention(\n", 792 | " (c_attn): LoRA_Linear(\n", 793 | " (linear): Linear(in_features=768, out_features=2304, bias=True)\n", 794 | " )\n", 795 | " (c_proj): Conv1D()\n", 796 | " (attn_dropout): Dropout(p=0.1, inplace=False)\n", 797 | " (resid_dropout): Dropout(p=0.1, inplace=False)\n", 798 | " )\n", 799 | " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 800 | " (mlp): GPT2MLP(\n", 801 | " (c_fc): Conv1D()\n", 802 | " (c_proj): Conv1D()\n", 803 | " (act): NewGELUActivation()\n", 804 | " (dropout): Dropout(p=0.1, inplace=False)\n", 805 | " )\n", 806 | " )\n", 807 | " )\n", 808 | " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 809 | " )\n", 810 | " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n", 811 | ")" 812 | ] 813 | }, 814 | "execution_count": 19, 815 | "metadata": {}, 816 | "output_type": "execute_result" 817 | } 818 | ], 819 | "source": [ 820 | "model" 821 | ] 822 | }, 823 | { 824 | "cell_type": "markdown", 825 | "id": "e9099c08-f6a6-45f8-939b-cc3ed9415976", 826 | "metadata": {}, 827 | "source": [ 828 | "Finally, freeze all parameters except for LoRA parameters." 
829 | ] 830 | }, 831 | { 832 | "cell_type": "code", 833 | "execution_count": 20, 834 | "id": "81d06bba-955b-4806-8ff7-f217252e3268", 835 | "metadata": {}, 836 | "outputs": [], 837 | "source": [ 838 | "for name, param in model.named_parameters():\n", 839 | "    if \"lora_right\" in name or \"lora_left\" in name:\n", 840 | "        param.requires_grad = True\n", 841 | "    else:\n", 842 | "        param.requires_grad = False" 843 | ] 844 | }, 845 | { 846 | "cell_type": "code", 847 | "execution_count": null, 848 | "id": "6c0a4469-2827-4f30-9324-711a9feea1ae", 849 | "metadata": {}, 850 | "outputs": [], 851 | "source": [ 852 | "### Do this when you run adapter fine-tuning on Hugging Face framework\n", 853 | "# model.gradient_checkpointing_enable()\n", 854 | "# model.enable_input_require_grads()" 855 | ] 856 | }, 857 | { 858 | "cell_type": "markdown", 859 | "id": "6d6c7d6f-6c50-4839-88a5-c851caab9ba2", 860 | "metadata": {}, 861 | "source": [ 862 | "## Fine-tune" 863 | ] 864 | }, 865 | { 866 | "cell_type": "markdown", 867 | "id": "a12b875f-36cc-40b8-aaab-1efda68710f3", 868 | "metadata": {}, 869 | "source": [ 870 | "Now let's start fine-tuning.\n", 871 | "\n", 872 | "First we build the optimizer as follows." 873 | ] 874 | }, 875 | { 876 | "cell_type": "code", 877 | "execution_count": 21, 878 | "id": "bb51298a-2d55-466c-a990-0ea08a247350", 879 | "metadata": {}, 880 | "outputs": [], 881 | "source": [ 882 | "optimizer = torch.optim.AdamW(\n", 883 | "    params=model.parameters(),\n", 884 | "    lr=0.0002,\n", 885 | "    betas=(0.9, 0.999),\n", 886 | "    eps=1e-6,\n", 887 | ")" 888 | ] 889 | }, 890 | { 891 | "cell_type": "markdown", 892 | "id": "d37db1a8-0053-4acc-94ce-89d87c78942e", 893 | "metadata": {}, 894 | "source": [ 895 | "In this example, we build a linear scheduler (linear warmup followed by linear decay) for training." 
896 | ] 897 | }, 898 | { 899 | "cell_type": "code", 900 | "execution_count": 22, 901 | "id": "6f95bdf6-4498-4d40-90aa-1267d55f38c3", 902 | "metadata": {}, 903 | "outputs": [], 904 | "source": [ 905 | "from torch.optim.lr_scheduler import LambdaLR\n", 906 | "\n", 907 | "num_epochs = 2\n", 908 | "num_warmup_steps = 500\n", 909 | "\n", 910 | "num_update_steps = math.ceil(len(dataloader) / gradient_accumulation_steps)  # optimizer (and scheduler) steps per epoch: one per gradient_accumulation_steps batches\n", 911 | "def _get_linear_schedule(current_step):\n", 912 | "    if current_step < num_warmup_steps:\n", 913 | "        return float(current_step) / float(max(1, num_warmup_steps))\n", 914 | "    return max(0.0, float(num_update_steps * num_epochs - current_step) / float(max(1, num_update_steps * num_epochs - num_warmup_steps)))\n", 915 | "scheduler = LambdaLR(optimizer, lr_lambda=_get_linear_schedule)" 916 | ] 917 | }, 918 | { 919 | "cell_type": "markdown", 920 | "id": "a9f9e828-c4fb-493d-a6de-78e03dbf035e", 921 | "metadata": {}, 922 | "source": [ 923 | "Run fine-tuning." 924 | ] 925 | }, 926 | { 927 | "cell_type": "code", 928 | "execution_count": 23, 929 | "id": "3752481d-8ee8-4c43-b677-add136a2fd5b", 930 | "metadata": {}, 931 | "outputs": [ 932 | { 933 | "name": "stdout", 934 | "output_type": "stream", 935 | "text": [ 936 | "Epoch 1 42/42 - loss: 1.3620\n", 937 | "Epoch 2 42/42 - loss: 1.4432\n" 938 | ] 939 | } 940 | ], 941 | "source": [ 942 | "from torch.nn import functional as F\n", 943 | "\n", 944 | "if os.path.exists(\"loss.txt\"):\n", 945 | "    os.remove(\"loss.txt\")\n", 946 | "\n", 947 | "for epoch in range(num_epochs):\n", 948 | "    optimizer.zero_grad()\n", 949 | "    model.train()\n", 950 | "    for i, (inputs, labels, masks) in enumerate(dataloader):\n", 951 | "        with torch.set_grad_enabled(True):\n", 952 | "            outputs = model(\n", 953 | "                input_ids=inputs,\n", 954 | "                attention_mask=masks,\n", 955 | "            )\n", 956 | "            loss = F.cross_entropy(outputs.logits.transpose(1,2), labels)\n", 957 | "            loss.backward()\n", 958 | "            if ((i + 1) % 
gradient_accumulation_steps == 0) or \\\n", 959 | "               (i + 1 == len(dataloader)):\n", 960 | "                optimizer.step()\n", 961 | "                scheduler.step()\n", 962 | "                optimizer.zero_grad()\n", 963 | "\n", 964 | "        print(f\"Epoch {epoch+1} {math.ceil((i + 1) / gradient_accumulation_steps)}/{num_update_steps} - loss: {loss.item() :2.4f}\", end=\"\\r\")\n", 965 | "\n", 966 | "        # record loss\n", 967 | "        with open(\"loss.txt\", \"a\") as f:\n", 968 | "            f.write(str(loss.item()))\n", 969 | "            f.write(\"\\n\")\n", 970 | "    print(\"\")\n", 971 | "\n", 972 | "# save model\n", 973 | "torch.save(model.state_dict(), \"finetuned_gpt2.bin\")" 974 | ] 975 | }, 976 | { 977 | "cell_type": "markdown", 978 | "id": "83993d92-d7ed-4a07-8985-cc59bd4e4fef", 979 | "metadata": {}, 980 | "source": [ 981 | "> Note : Here we save the LoRA-enabled model as is, but you can also merge the trained LoRA parameters into the original model's weights." 982 | ] 983 | }, 984 | { 985 | "cell_type": "markdown", 986 | "id": "1bc086e5-e93f-4264-a8fa-6428f844ac3c", 987 | "metadata": {}, 988 | "source": [ 989 | "Plot the loss over the training steps." 990 | ] 991 | }, 992 | { 993 | "cell_type": "code", 994 | "execution_count": 24, 995 | "id": "e37c5aee-38d4-4a2a-952c-4fd2bef41e2b", 996 | "metadata": {}, 997 | "outputs": [ 998 | { 999 | "data": { 1000 | "image/png": "", 1001 | "text/plain": [ 1002 | "
" 1003 | ] 1004 | }, 1005 | "metadata": {}, 1006 | "output_type": "display_data" 1007 | } 1008 | ], 1009 | "source": [ 1010 | "import matplotlib.pyplot as plt\n", 1011 | "import pandas as pd\n", 1012 | "\n", 1013 | "data = pd.read_csv(\"loss.txt\", header=None)  # loss.txt has no header row\n", 1014 | "plt.plot(data)\n", 1015 | "plt.show()" 1016 | ] 1017 | }, 1018 | { 1019 | "cell_type": "markdown", 1020 | "id": "9809bc9f-4ff6-46c3-9c43-08c6c2694a82", 1021 | "metadata": {}, 1022 | "source": [ 1023 | "## Generate text with fine-tuned model\n", 1024 | "\n", 1025 | "Again we check results with our test dataset (5 rows).
\n", 1026 | "As you can see below, it can output the completion very well, because it's fine-tuned." 1027 | ] 1028 | }, 1029 | { 1030 | "cell_type": "code", 1031 | "execution_count": 25, 1032 | "id": "29903cae-404e-4209-9c84-6c8a69609c13", 1033 | "metadata": {}, 1034 | "outputs": [ 1035 | { 1036 | "name": "stdout", 1037 | "output_type": "stream", 1038 | "text": [ 1039 | "********** input **********\n", 1040 | "name : The Vaults | Type : pub | food : Italian | price : less than £ 20 | customer rating : low | area : city centre | family friendly : no | near : Rainbow Vegetarian Café\n", 1041 | "\n", 1042 | "********** result **********\n", 1043 | "name : The Vaults | Type : pub | food : Italian | price : less than £ 20 | customer rating : low | area : city centre | family friendly : no | near : Rainbow Vegetarian Café\n", 1044 | "The Vaults is a pub near the Rainbow Vegetarian Café in the city centre. It is not family friendly and has a low customer rating of less than\n", 1045 | "********** input **********\n", 1046 | "name : The Cricketers | Type : restaurant | customer rating : average | family friendly : yes | near : Café Sicilia\n", 1047 | "\n", 1048 | "********** result **********\n", 1049 | "name : The Cricketers | Type : restaurant | customer rating : average | family friendly : yes | near : Café Sicilia\n", 1050 | "The Cricketers is a restaurant near Café Sicilia. 
It is family friendly and has an average customer rating.<|endoftext|>\n", 1051 | "********** input **********\n", 1052 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : average | area : city centre | family friendly : no | near : All Bar One\n", 1053 | "\n", 1054 | "********** result **********\n", 1055 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : average | area : city centre | family friendly : no | near : All Bar One\n", 1056 | "The Cricketers is a restaurant located in the city centre near All Bar One. It is not family - friendly. It is located in the cheap\n", 1057 | "********** input **********\n", 1058 | "name : The Vaults | Type : pub | food : Japanese | price : cheap | customer rating : 5 out of 5 | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n", 1059 | "\n", 1060 | "********** result **********\n", 1061 | "name : The Vaults | Type : pub | food : Japanese | price : cheap | customer rating : 5 out of 5 | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n", 1062 | "The Vaults is a cheap, family friendly pub located in the city centre near Raja Indian Cuisine.<|endoftext|>\n", 1063 | "********** input **********\n", 1064 | "name : The Wrestlers | Type : pub | food : Italian | price : less than £ 20 | area : riverside | family friendly : no | near : Raja Indian Cuisine\n", 1065 | "\n", 1066 | "********** result **********\n", 1067 | "name : The Wrestlers | Type : pub | food : Italian | price : less than £ 20 | area : riverside | family friendly : no | near : Raja Indian Cuisine\n", 1068 | "The Wrestlers is a pub near Raja Indian Cuisine in riverside. 
It is not family friendly.<|endoftext|>\n" 1069 | ] 1070 | } 1071 | ], 1072 | "source": [ 1073 | "test_data = pd.read_json(\"test_formatted.jsonl\", lines=True)\n", 1074 | "test_data = test_data[::2] # because it's duplicated\n", 1075 | "test_loader = DataLoader(\n", 1076 | " list(zip(test_data[\"context\"], [\"\"] * len(test_data[\"context\"]))),\n", 1077 | " batch_size=1,\n", 1078 | " shuffle=True,\n", 1079 | " collate_fn=collate_batch\n", 1080 | ")\n", 1081 | "\n", 1082 | "for i, (input, _, mask) in enumerate(test_loader):\n", 1083 | " if i == 5:\n", 1084 | " break\n", 1085 | " print(\"********** input **********\")\n", 1086 | " input_len = torch.sum(mask).cpu().numpy()\n", 1087 | " print(tokenizer.decode(input[0][:input_len]))\n", 1088 | " result_token, result_len = generate_text(\n", 1089 | " model,\n", 1090 | " input,\n", 1091 | " mask,\n", 1092 | " eos_id,\n", 1093 | " pred_sequence_length=30)\n", 1094 | " print(\"********** result **********\")\n", 1095 | " print(tokenizer.decode(result_token[0][:result_len]))" 1096 | ] 1097 | }, 1098 | { 1099 | "cell_type": "code", 1100 | "execution_count": null, 1101 | "id": "6a7c1dd3-4057-497a-83ae-f99b1883697e", 1102 | "metadata": {}, 1103 | "outputs": [], 1104 | "source": [] 1105 | } 1106 | ], 1107 | "metadata": { 1108 | "kernelspec": { 1109 | "display_name": "Python 3 (ipykernel)", 1110 | "language": "python", 1111 | "name": "python3" 1112 | }, 1113 | "language_info": { 1114 | "codemirror_mode": { 1115 | "name": "ipython", 1116 | "version": 3 1117 | }, 1118 | "file_extension": ".py", 1119 | "mimetype": "text/x-python", 1120 | "name": "python", 1121 | "nbconvert_exporter": "python", 1122 | "pygments_lexer": "ipython3", 1123 | "version": "3.8.10" 1124 | } 1125 | }, 1126 | "nbformat": 4, 1127 | "nbformat_minor": 5 1128 | } 1129 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning 
LLM with LoRA (Low-Rank Adaptation) 2 | 3 | LoRA (Low-Rank Adaptation) is one of the most widely used parameter-efficient fine-tuning (PEFT) methods today. 4 | 5 | This example shows a from-scratch (manual), step-by-step implementation of [LoRA (Low-Rank Adaptation)](https://arxiv.org/abs/2106.09685) without the ```PEFT``` package, and explains the ideas behind the implementation in IPython notebooks. 6 | 7 | It also runs on mainstream hardware with a small footprint, such as a single Tesla T4 GPU or consumer GPUs (NVIDIA RTX), so that you can try the code easily. 8 | 9 | | Example | Description | 10 | | -------------------------------------------------------------------- | ----------------------------------------------------------------------- | 11 | | [01-finetune-opt-with-lora.ipynb](01-finetune-opt-with-lora.ipynb) | Fine-tuning Meta's OPT-125M with LoRA
(Also explains the LoRA method) | 12 | | [02-finetune-gpt2-with-lora.ipynb](02-finetune-gpt2-with-lora.ipynb) | Fine-tuning OpenAI's GPT-2 small (124M) with LoRA | 13 | 14 | Unlike the examples in the [official repository](https://github.com/microsoft/LoRA), here I download pre-trained models in order to focus on the LoRA implementation. 15 | 16 | > Note : In this repository, the Hugging Face API is used to download pre-trained models, and I then apply a regular PyTorch training loop for fine-tuning. (I don't use the black-box ```Trainer``` class in the Hugging Face API.) 17 | 18 | ## 1. Set-up and Install 19 | 20 | To run this example, please install the prerequisite software and set up your environment as follows.
21 | In the following setup, I used a GPU-equipped virtual machine (VM) with the "Ubuntu Server 20.04 LTS" image in Microsoft Azure. 22 | 23 | ### Install GPU driver (CUDA) 24 | 25 | Install CUDA (NVIDIA GPU driver) as follows. 26 | 27 | ``` 28 | # compilers and development settings 29 | sudo apt-get update 30 | sudo apt install -y gcc 31 | sudo apt-get install -y make 32 | 33 | # install CUDA 34 | wget https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run 35 | sudo sh cuda_12.2.2_535.104.05_linux.run 36 | echo -e "export LD_LIBRARY_PATH=/usr/local/cuda-12.2/lib64" >> ~/.bashrc 37 | source ~/.bashrc 38 | ``` 39 | 40 | ### Install packages 41 | 42 | Install PyTorch, Hugging Face Transformers, and other libraries as follows. 43 | 44 | ``` 45 | # install and upgrade pip 46 | sudo apt-get install -y python3-pip 47 | sudo -H pip3 install --upgrade pip 48 | # install packages 49 | pip3 install torch transformers pandas matplotlib 50 | # install jupyter for running notebook 51 | pip3 install jupyter 52 | ``` 53 | 54 | ## 2. Fine-tune (Train) 55 | 56 | Download this repository. 57 | 58 | ``` 59 | git clone https://github.com/tsmatz/finetune_llm_with_lora 60 | ``` 61 | 62 | Run Jupyter notebook. 63 | 64 | ``` 65 | jupyter notebook 66 | ``` 67 | 68 | Open Jupyter notebook in your browser, and run the examples in this repository. 
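
Before running the notebooks, it may help to confirm that the packages installed above can be found by Python. The following is only a minimal sketch and is not part of this repository: `check_packages` and `REQUIRED_PACKAGES` are hypothetical names introduced here for illustration, and the package list simply mirrors the `pip3 install` line above.

```python
import importlib.util

# Assumption: this list mirrors the "pip3 install ..." line in this README.
REQUIRED_PACKAGES = ["torch", "transformers", "pandas", "matplotlib"]


def check_packages(names):
    """Return {package name: True if Python can locate it for import, else False}."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Print one status line per package.
    for name, found in check_packages(REQUIRED_PACKAGES).items():
        print(f"{name}: {'ok' if found else 'MISSING'}")
```

If every package is reported as found, you can additionally evaluate `torch.cuda.is_available()` in a notebook cell to check that PyTorch sees the GPU, since the notebooks place the model on a CUDA device.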
69 | -------------------------------------------------------------------------------- /images/auto_regressive_transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/finetune_llm_with_lora/2e84a5e9e5095aaeacacaa723ee5a7c34c36b678/images/auto_regressive_transformer.png -------------------------------------------------------------------------------- /images/lora.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsmatz/finetune_llm_with_lora/2e84a5e9e5095aaeacacaa723ee5a7c34c36b678/images/lora.png --------------------------------------------------------------------------------