├── 01-finetune-opt-with-lora.ipynb
├── 02-finetune-gpt2-with-lora.ipynb
├── Readme.md
└── images
    ├── auto_regressive_transformer.png
    └── lora.png
/01-finetune-opt-with-lora.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "857cafc6-da38-4aa7-8afc-63aa626fa7aa",
6 | "metadata": {},
7 | "source": [
8 | "# 01. Finetuning OPT with LoRA\n",
9 | "\n",
10 | "Today's popular auto-regressive models, such as GPT, LLaMA, and Falcon, are decoder-only models, in which each output token is predicted using only the preceding text, i.e., the input text (called a prompt) plus the tokens generated so far.\n",
11 | "\n",
12 | "\n",
13 | "\n",
14 | "*A \"decoder-only\" model is implemented with the layers in the red box.<br>\n",
15 | "(Diagram from : [Attention Is All You Need](https://arxiv.org/abs/1706.03762))*\n",
16 | "\n",
17 | "In such a model, the task itself is also specified through the input text (i.e., the prompt).\n",
18 | "\n",
19 | "> Note : See [this repository](https://github.com/tsmatz/nlp-tutorials) for the underlying ideas of the transformers used in LLMs.\n",
20 | "\n",
21 | "In this example, we fine-tune a pre-trained auto-regressive model, Meta's OPT (```facebook/opt-125m```), by applying LoRA (Low-Rank Adaptation).\n",
22 | "\n",
23 | "I download the pre-trained model from the Hugging Face Hub, but fine-tune it with a regular PyTorch training loop.<br>\n",
24 | "(The Hugging Face ```Trainer``` class is not used here.)\n",
25 | "\n",
26 | "See [Readme](https://github.com/tsmatz/finetune_llm_with_lora) for the prerequisite setup."
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 1,
32 | "id": "3d49acf1-9ad1-4a6c-9312-6785cb3f5862",
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "model_name = \"facebook/opt-125m\"\n",
37 | "# model_name = \"facebook/opt-350m\"\n",
38 | "# model_name = \"facebook/opt-1.3b\"\n",
39 | "# model_name = \"facebook/opt-6.7b\""
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 2,
45 | "id": "4d835e84-a01d-4c33-926b-60d9dd4a7627",
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "import torch\n",
50 | "\n",
51 | "device = torch.device(\"cuda\")"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "id": "ead383e5-149b-4bfb-9324-3cc639fd398d",
57 | "metadata": {},
58 | "source": [
59 | "## Prepare dataset and dataloader"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "id": "91ecbb08-6a74-4623-bfe8-bddba5254e35",
65 | "metadata": {},
66 | "source": [
67 | "In this example, we use the E2E dataset from the [official LoRA example](https://github.com/microsoft/LoRA).\n",
68 | "\n",
69 | "Download the dataset from the official repository."
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 3,
75 | "id": "54a564f1-f8f3-42a6-b160-bebdbcc3aac0",
76 | "metadata": {},
77 | "outputs": [
78 | {
79 | "name": "stdout",
80 | "output_type": "stream",
81 | "text": [
82 | "--2023-10-06 03:27:50-- https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/train.txt\n",
83 | "Resolving github.com (github.com)... 140.82.114.3\n",
84 | "Connecting to github.com (github.com)|140.82.114.3|:443... connected.\n",
85 | "HTTP request sent, awaiting response... 302 Found\n",
86 | "Location: https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/train.txt [following]\n",
87 | "--2023-10-06 03:27:51-- https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/train.txt\n",
88 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...\n",
89 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
90 | "HTTP request sent, awaiting response... 200 OK\n",
91 | "Length: 9624463 (9.2M) [text/plain]\n",
92 | "Saving to: ‘train.txt’\n",
93 | "\n",
94 | "train.txt 100%[===================>] 9.18M --.-KB/s in 0.04s \n",
95 | "\n",
96 | "2023-10-06 03:27:51 (248 MB/s) - ‘train.txt’ saved [9624463/9624463]\n",
97 | "\n"
98 | ]
99 | }
100 | ],
101 | "source": [
102 | "!wget https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/train.txt"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 4,
108 | "id": "d48464ea-991f-48b2-9166-3323cfd61676",
109 | "metadata": {
110 | "scrolled": true
111 | },
112 | "outputs": [
113 | {
114 | "name": "stdout",
115 | "output_type": "stream",
116 | "text": [
117 | "--2023-10-06 03:27:54-- https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/test.txt\n",
118 | "Resolving github.com (github.com)... 140.82.114.3\n",
119 | "Connecting to github.com (github.com)|140.82.114.3|:443... connected.\n",
120 | "HTTP request sent, awaiting response... 302 Found\n",
121 | "Location: https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/test.txt [following]\n",
122 | "--2023-10-06 03:27:54-- https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/test.txt\n",
123 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.109.133, ...\n",
124 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
125 | "HTTP request sent, awaiting response... 200 OK\n",
126 | "Length: 1351149 (1.3M) [text/plain]\n",
127 | "Saving to: ‘test.txt’\n",
128 | "\n",
129 | "test.txt 100%[===================>] 1.29M --.-KB/s in 0.006s \n",
130 | "\n",
131 | "2023-10-06 03:27:54 (208 MB/s) - ‘test.txt’ saved [1351149/1351149]\n",
132 | "\n"
133 | ]
134 | }
135 | ],
136 | "source": [
137 | "!wget https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/test.txt"
138 | ]
139 | },
140 | {
141 | "cell_type": "markdown",
142 | "id": "09472803-8c62-48e0-9a63-b9b9448f16d3",
143 | "metadata": {},
144 | "source": [
145 | "Show the downloaded data (first 5 rows)."
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": 5,
151 | "id": "e6e60596-028f-4c4b-a95d-f74a0ff3b188",
152 | "metadata": {},
153 | "outputs": [
154 | {
155 | "name": "stdout",
156 | "output_type": "stream",
157 | "text": [
158 | "name : The Vaults | Type : pub | price : more than £ 30 | customer rating : 5 out of 5 | near : Café Adriatic||The Vaults pub near Café Adriatic has a 5 star rating . Prices start at £ 30 . \n",
159 | "name : The Cambridge Blue | Type : pub | food : English | price : cheap | near : Café Brazil||Close to Café Brazil , The Cambridge Blue pub serves delicious Tuscan Beef for the cheap price of £ 10.50 . Delicious Pub food . \n",
160 | "name : The Eagle | Type : coffee shop | food : Japanese | price : less than £ 20 | customer rating : low | area : riverside | family friendly : yes | near : Burger King||The Eagle is a low rated coffee shop near Burger King and the riverside that is family friendly and is less than £ 20 for Japanese food . \n",
161 | "name : The Mill | Type : coffee shop | food : French | price : £ 20 - 25 | area : riverside | near : The Sorrento||Located near The Sorrento is a French Theme eatery and coffee shop called The Mill , with a price range at £ 20- £ 25 it is in the riverside area . \n",
162 | "name : Loch Fyne | food : French | customer rating : high | area : riverside | near : The Rice Boat||For luxurious French food , the Loch Fyne is located by the river next to The Rice Boat . \n"
163 | ]
164 | }
165 | ],
166 | "source": [
167 | "!head -n 5 train.txt"
168 | ]
169 | },
170 | {
171 | "cell_type": "markdown",
172 | "id": "93f5fabe-590c-459b-aa16-4b5a506fb54b",
173 | "metadata": {},
174 | "source": [
175 | "Convert the above data into JSON Lines (JSONL) format."
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 6,
181 | "id": "7376e0c0-16c9-46f4-ad4c-83d1a677f5a2",
182 | "metadata": {},
183 | "outputs": [],
184 | "source": [
185 | "import sys\n",
186 | "import io\n",
187 | "import json\n",
188 | "\n",
189 | "def format_convert(read_file, write_file):\n",
190 | "    with open(read_file, \"r\", encoding=\"utf8\") as reader, \\\n",
191 | "         open(write_file, \"w\", encoding=\"utf8\") as writer:\n",
192 | "        for line in reader:\n",
193 | "            items = line.strip().split(\"||\")\n",
194 | "            context = items[0]\n",
195 | "            completion = items[1].strip(\"\\n\")\n",
196 | "            x = {}\n",
197 | "            x[\"context\"] = context\n",
198 | "            x[\"completion\"] = completion\n",
199 | "            writer.write(json.dumps(x)+\"\\n\")\n",
200 | "\n",
201 | "format_convert(\"train.txt\", \"train_formatted.jsonl\")\n",
202 | "format_convert(\"test.txt\", \"test_formatted.jsonl\")"
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "id": "3ceec952-fe03-475f-9f3e-22237cc9c44b",
208 | "metadata": {},
209 | "source": [
210 | "Show the converted data (first 5 rows)."
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": 7,
216 | "id": "cb32aca7-bd0e-4847-a4c2-cc7e67dc2b7a",
217 | "metadata": {},
218 | "outputs": [
219 | {
220 | "name": "stdout",
221 | "output_type": "stream",
222 | "text": [
223 | "{\"context\": \"name : The Vaults | Type : pub | price : more than \\u00a3 30 | customer rating : 5 out of 5 | near : Caf\\u00e9 Adriatic\", \"completion\": \"The Vaults pub near Caf\\u00e9 Adriatic has a 5 star rating . Prices start at \\u00a3 30 .\"}\n",
224 | "\n",
225 | "{\"context\": \"name : The Cambridge Blue | Type : pub | food : English | price : cheap | near : Caf\\u00e9 Brazil\", \"completion\": \"Close to Caf\\u00e9 Brazil , The Cambridge Blue pub serves delicious Tuscan Beef for the cheap price of \\u00a3 10.50 . Delicious Pub food .\"}\n",
226 | "\n",
227 | "{\"context\": \"name : The Eagle | Type : coffee shop | food : Japanese | price : less than \\u00a3 20 | customer rating : low | area : riverside | family friendly : yes | near : Burger King\", \"completion\": \"The Eagle is a low rated coffee shop near Burger King and the riverside that is family friendly and is less than \\u00a3 20 for Japanese food .\"}\n",
228 | "\n",
229 | "{\"context\": \"name : The Mill | Type : coffee shop | food : French | price : \\u00a3 20 - 25 | area : riverside | near : The Sorrento\", \"completion\": \"Located near The Sorrento is a French Theme eatery and coffee shop called The Mill , with a price range at \\u00a3 20- \\u00a3 25 it is in the riverside area .\"}\n",
230 | "\n",
231 | "{\"context\": \"name : Loch Fyne | food : French | customer rating : high | area : riverside | near : The Rice Boat\", \"completion\": \"For luxurious French food , the Loch Fyne is located by the river next to The Rice Boat .\"}\n",
232 | "\n"
233 | ]
234 | }
235 | ],
236 | "source": [
237 | "with open(\"train_formatted.jsonl\", \"r\") as reader:\n",
238 | "    for _ in range(5):\n",
239 | "        print(next(reader))"
240 | ]
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "id": "6631f786-be4b-40cf-89d9-7009c1888821",
245 | "metadata": {},
246 | "source": [
247 | "Load the tokenizer from Hugging Face."
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": 8,
253 | "id": "e5433dc0-b5a5-4c01-adb5-3ffa2279eca8",
254 | "metadata": {},
255 | "outputs": [],
256 | "source": [
257 | "from transformers import AutoTokenizer\n",
258 | "import os\n",
259 | "\n",
260 | "tokenizer = AutoTokenizer.from_pretrained(\n",
261 | "    model_name,\n",
262 | "    use_fast=True)\n",
263 | "tokenizer.pad_token = tokenizer.eos_token\n",
264 | "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
265 | ]
266 | },
267 | {
268 | "cell_type": "markdown",
269 | "id": "50817c47-a97b-4f80-975b-836859a0a7cf",
270 | "metadata": {},
271 | "source": [
272 | "Set the block size: every sequence is truncated or padded to this fixed length before it is fed to the model."
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": 9,
278 | "id": "5f250929-5703-4b17-9f7b-26340950c055",
279 | "metadata": {},
280 | "outputs": [],
281 | "source": [
282 | "block_size = 512"
283 | ]
284 | },
285 | {
286 | "cell_type": "markdown",
287 | "id": "2332617b-1e66-4812-ad47-5eaeb52b101b",
288 | "metadata": {},
289 | "source": [
290 | "Create a function to convert data. (This function is later used as the collate function of the data loader.)<br>\n",
291 | "In this function,\n",
292 | "\n",
293 | "1. Tokenize both contexts and completions, e.g., ```\"This is a pen.\"``` --> ```[1212, 318, 257, 3112, 13]``` (the actual ids depend on the tokenizer).\n",
294 | "2. Concatenate the context tokens and the completion tokens, with \"\\n\" as the delimiter between context and completion. The result is used as the input to the LLM.\n",
295 | "3. Create labels (targets) from the inputs. The label is ```input[1:]``` followed by an EOS token (i.e., shifted left by one element, so that each position's label is the next token), and it is filled with ```-100``` at the context's positions. (See the note below.)\n",
296 | "4. Truncate or pad the tokens so that every sequence has length ```block_size```.\n",
297 | "\n",
298 | "> Note : Here I set ```-100``` as the ignored index for loss computation, because PyTorch's cross-entropy function (```torch.nn.functional.cross_entropy()```) has a parameter ```ignore_index``` whose default value is ```-100```."
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": 10,
304 | "id": "9f2f38aa-b3d0-4614-aa59-8ddd977176d1",
305 | "metadata": {},
306 | "outputs": [],
307 | "source": [
308 | "from torch.utils.data import DataLoader\n",
309 | "import pandas as pd\n",
310 | "\n",
311 | "def fill_ignore_label(l, c):\n",
312 | "    l[:len(c) - 1] = [-100] * (len(c) - 1)\n",
313 | "    return l\n",
314 | "\n",
315 | "def pad_tokens(tokens, max_seq_length, padding_token):\n",
316 | "    res_tokens = tokens[:max_seq_length]\n",
317 | "    token_len = len(res_tokens)\n",
318 | "    res_tokens = res_tokens + \\\n",
319 | "        [padding_token for _ in range(max_seq_length - token_len)]\n",
320 | "    return res_tokens\n",
321 | "\n",
322 | "def collate_batch(batch):\n",
323 | "    # tokenize both context and completion respectively\n",
324 | "    # (context and completion are delimited by \"\\n\")\n",
325 | "    context_list = list(zip(*batch))[0]\n",
326 | "    context_list = [c + \"\\n\" for c in context_list]\n",
327 | "    completion_list = list(zip(*batch))[1]\n",
328 | "    context_result = tokenizer(context_list)\n",
329 | "    context_tokens = context_result[\"input_ids\"]\n",
330 | "    context_masks = context_result[\"attention_mask\"]\n",
331 | "    completion_result = tokenizer(completion_list)\n",
332 | "    completion_tokens = completion_result[\"input_ids\"]\n",
333 | "    completion_masks = completion_result[\"attention_mask\"]\n",
334 | "    # the OPT tokenizer prepends a start token to every sequence,\n",
335 | "    # so we remove it from the completion\n",
336 | "    completion_tokens = [t[1:] for t in completion_tokens]\n",
337 | "    completion_masks = [t[1:] for t in completion_masks]\n",
338 | "    # concatenate tokens\n",
339 | "    inputs = [i + j for i, j in zip(context_tokens, completion_tokens)]\n",
340 | "    masks = [i + j for i, j in zip(context_masks, completion_masks)]\n",
341 | "    # create labels (next token at each position; context positions are ignored)\n",
342 | "    eos_id = tokenizer.encode(tokenizer.eos_token)[0]\n",
343 | "    labels = [t[1:] + [eos_id] for t in inputs]\n",
344 | "    labels = list(map(fill_ignore_label, labels, context_tokens))\n",
345 | "    # truncate and pad tokens\n",
346 | "    inputs = [pad_tokens(t, block_size, 0) for t in inputs] # padded positions are masked by the attention mask, so the pad id (0) doesn't matter\n",
347 | "    masks = [pad_tokens(t, block_size, 0) for t in masks]\n",
348 | "    labels = [pad_tokens(t, block_size, -100) for t in labels]\n",
349 | "    # convert to tensor\n",
350 | "    inputs = torch.tensor(inputs, dtype=torch.int64).to(device)\n",
351 | "    masks = torch.tensor(masks, dtype=torch.int64).to(device)\n",
352 | "    labels = torch.tensor(labels, dtype=torch.int64).to(device)\n",
353 | "    return inputs, labels, masks"
354 | ]
355 | },
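{
"cell_type": "markdown",
"id": "added-label-masking-demo-md",
"metadata": {},
"source": [
"As a small sanity check of step 3 above, the following sketch uses hypothetical token ids (not real OPT tokenizer output) and random logits. It builds labels the same way as ```collate_batch()``` (shift by one, append EOS, mask the context with ```-100```) and confirms that ```cross_entropy()``` with its default ```ignore_index``` gives the same loss as dropping the masked positions by hand."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-label-masking-demo-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# hypothetical token ids (for illustration only)\n",
"toy_context = [2, 10, 11]  # context tokens (the last one stands for the newline delimiter)\n",
"toy_completion = [20, 21]  # completion tokens\n",
"toy_eos_id = 2  # hypothetical EOS id\n",
"\n",
"toy_inputs = toy_context + toy_completion  # [2, 10, 11, 20, 21]\n",
"toy_labels = toy_inputs[1:] + [toy_eos_id]  # next token at each position\n",
"toy_labels[:len(toy_context) - 1] = [-100] * (len(toy_context) - 1)  # ignore context predictions\n",
"print(toy_labels)  # [-100, -100, 20, 21, 2]\n",
"\n",
"# random logits over a toy vocabulary of 30 tokens\n",
"toy_logits = torch.randn(len(toy_inputs), 30)\n",
"toy_target = torch.tensor(toy_labels)\n",
"\n",
"loss_default = F.cross_entropy(toy_logits, toy_target)  # ignore_index defaults to -100\n",
"keep = toy_target != -100\n",
"loss_manual = F.cross_entropy(toy_logits[keep], toy_target[keep])\n",
"print(torch.allclose(loss_default, loss_manual))  # True"
]
},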
356 | {
357 | "cell_type": "markdown",
358 | "id": "2084d2e9-ef64-47a2-aec9-d24ead1cb38a",
359 | "metadata": {},
360 | "source": [
361 | "Now create a PyTorch data loader with the previous function as its collate function.\n",
362 | "\n",
363 | "> Note : In this example, the data is small, so we load all JSON data into memory.<br>\n",
364 | "> When it's large, load the data progressively by implementing a custom PyTorch dataset. (See [here](https://github.com/tsmatz/decision-transformer) for an example.)"
365 | ]
366 | },
367 | {
368 | "cell_type": "code",
369 | "execution_count": 11,
370 | "id": "f3bce3bb-2215-4bd6-a6a6-5b6b9d5afdc0",
371 | "metadata": {},
372 | "outputs": [],
373 | "source": [
374 | "batch_size = 8\n",
375 | "gradient_accumulation_steps = 16\n",
376 | "\n",
377 | "data = pd.read_json(\"train_formatted.jsonl\", lines=True)\n",
378 | "dataloader = DataLoader(\n",
379 | "    list(zip(data[\"context\"], data[\"completion\"])),\n",
380 | "    batch_size=batch_size,\n",
381 | "    shuffle=True,\n",
382 | "    collate_fn=collate_batch\n",
383 | ")"
384 | ]
385 | },
386 | {
387 | "cell_type": "markdown",
388 | "id": "3ba64144-b698-457e-b827-941020456536",
389 | "metadata": {},
390 | "source": [
391 | "## Load model"
392 | ]
393 | },
394 | {
395 | "cell_type": "markdown",
396 | "id": "1bfd360d-7bdc-4fd7-9b12-bcf9fe0a8db2",
397 | "metadata": {},
398 | "source": [
399 | "Load the model from Hugging Face."
400 | ]
401 | },
402 | {
403 | "cell_type": "code",
404 | "execution_count": 12,
405 | "id": "271181bd-677a-4da9-9e57-2874f5e47bd0",
406 | "metadata": {},
407 | "outputs": [],
408 | "source": [
409 | "from transformers import AutoModelForCausalLM, AutoConfig\n",
410 | "\n",
411 | "config = AutoConfig.from_pretrained(model_name)\n",
412 | "model = AutoModelForCausalLM.from_pretrained(\n",
413 | "    model_name,\n",
414 | "    config=config,\n",
415 | ").to(device)"
416 | ]
417 | },
418 | {
419 | "cell_type": "markdown",
420 | "id": "27ab764a-d634-40f8-9edb-a01146845233",
421 | "metadata": {},
422 | "source": [
423 | "## Generate text (before fine-tuning)"
424 | ]
425 | },
426 | {
427 | "cell_type": "markdown",
428 | "id": "559efeaf-4b38-4a0c-9be6-eb394221e374",
429 | "metadata": {},
430 | "source": [
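431 | "Now run prediction with the downloaded model (which is not yet fine-tuned).\n",
432 | "\n",
433 | "First we create a function to generate text. It uses greedy decoding: at each step it appends the most probable next token, until the EOS token is predicted or ```pred_sequence_length``` tokens have been generated."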
434 | ]
435 | },
436 | {
437 | "cell_type": "code",
438 | "execution_count": 13,
439 | "id": "51a0c4fc-e0a7-4bbf-b25a-c335fe61f3df",
440 | "metadata": {},
441 | "outputs": [],
442 | "source": [
443 | "def generate_text(model, input, mask, eos_id, pred_sequence_length):\n",
444 | "    predicted_last_id = -1\n",
445 | "    start_token_len = torch.sum(mask).cpu().numpy()\n",
446 | "    token_len = start_token_len\n",
447 | "    with torch.no_grad():\n",
448 | "        while (predicted_last_id != eos_id) and \\\n",
449 | "                (token_len - start_token_len < pred_sequence_length):\n",
450 | "            output = model(\n",
451 | "                input_ids=input,\n",
452 | "                attention_mask=mask,\n",
453 | "            )\n",
454 | "            predicted_ids = torch.argmax(output.logits, axis=-1).cpu().numpy()  # greedy: most probable token\n",
455 | "            predicted_last_id = predicted_ids[0][token_len - 1]  # prediction after the last real token\n",
456 | "            input[0][token_len] = predicted_last_id  # append it at the first padding slot\n",
457 | "            mask[0][token_len] = 1\n",
458 | "            token_len = torch.sum(mask).cpu().numpy()\n",
459 | "    return input, token_len"
460 | ]
461 | },
462 | {
463 | "cell_type": "markdown",
464 | "id": "3936b1a1-ae9f-48a5-80db-691261dda704",
465 | "metadata": {},
466 | "source": [
467 | "Let's test our function and generate text. (Here generation stops after 15 predicted tokens, or earlier if EOS is predicted.)"
468 | ]
469 | },
470 | {
471 | "cell_type": "code",
472 | "execution_count": 14,
473 | "id": "28b7e13f-e8fb-4a9f-90ed-0464463ef569",
474 | "metadata": {},
475 | "outputs": [
476 | {
477 | "name": "stdout",
478 | "output_type": "stream",
479 | "text": [
480 | "Once upon a time, I was a student at the University of California, Berkeley. I was a\n",
481 | "My name is Clara and I am a student at the University of California, Berkeley. I am a member of\n"
482 | ]
483 | }
484 | ],
485 | "source": [
486 | "eos_id = tokenizer.encode(tokenizer.eos_token)[0]\n",
487 | "\n",
488 | "result = tokenizer(\"Once upon a time,\")\n",
489 | "input = result[\"input_ids\"]\n",
490 | "mask = result[\"attention_mask\"]\n",
491 | "input = pad_tokens(input, block_size, 0)\n",
492 | "mask = pad_tokens(mask, block_size, 0)\n",
493 | "input = torch.tensor([input], dtype=torch.int64).to(device)\n",
494 | "mask = torch.tensor([mask], dtype=torch.int64).to(device)\n",
495 | "\n",
496 | "result_token, result_len = generate_text(\n",
497 | "    model,\n",
498 | "    input,\n",
499 | "    mask,\n",
500 | "    eos_id,\n",
501 | "    pred_sequence_length=15)\n",
502 | "print(tokenizer.decode(result_token[0][:result_len]))\n",
503 | "\n",
504 | "result = tokenizer(\"My name is Clara and I am\")\n",
505 | "input = result[\"input_ids\"]\n",
506 | "mask = result[\"attention_mask\"]\n",
507 | "input = pad_tokens(input, block_size, 0)\n",
508 | "mask = pad_tokens(mask, block_size, 0)\n",
509 | "input = torch.tensor([input], dtype=torch.int64).to(device)\n",
510 | "mask = torch.tensor([mask], dtype=torch.int64).to(device)\n",
511 | "\n",
512 | "result_token, result_len = generate_text(\n",
513 | "    model,\n",
514 | "    input,\n",
515 | "    mask,\n",
516 | "    eos_id,\n",
517 | "    pred_sequence_length=15)\n",
518 | "print(tokenizer.decode(result_token[0][:result_len]))"
519 | ]
520 | },
521 | {
522 | "cell_type": "markdown",
523 | "id": "d48fb60b-c05d-4884-a9bc-92152c94c894",
524 | "metadata": {},
525 | "source": [
526 | "Now we generate text for our test dataset (5 rows).<br>\n",
527 | "As you can see below, the model can't produce proper completions, because it's not yet fine-tuned."
528 | ]
529 | },
530 | {
531 | "cell_type": "code",
532 | "execution_count": 15,
533 | "id": "495728ef-fbe6-4953-a354-4b7a8bb88798",
534 | "metadata": {},
535 | "outputs": [
536 | {
537 | "name": "stdout",
538 | "output_type": "stream",
539 | "text": [
540 | "********** input **********\n",
541 | "name : The Wrestlers | Type : pub | food : Italian | price : less than £ 20 | area : riverside | family friendly : no | near : Raja Indian Cuisine\n",
542 | "\n",
543 | "********** result **********\n",
544 | "name : The Wrestlers | Type : pub | food : Italian | price : less than £ 20 | area : riverside | family friendly : no | near : Raja Indian Cuisine\n",
545 | "\n",
546 | "The Wrestlers is a restaurant in the heart of the city of Raja, India. It is located in the heart of the city of Raj\n",
547 | "********** input **********\n",
548 | "name : The Cricketers | Type : coffee shop | customer rating : 1 out of 5 | family friendly : yes | near : Avalon\n",
549 | "\n",
550 | "********** result **********\n",
551 | "name : The Cricketers | Type : coffee shop | customer rating : 1 out of 5 | family friendly : yes | near : Avalon\n",
552 | "\n",
553 | "The Cricketers is a coffee shop in Avalon, New York. It is located at the corner of Main Street and Main Street. The coffee\n",
554 | "********** input **********\n",
555 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : 5 out of 5 | area : city centre | family friendly : no | near : All Bar One\n",
556 | "\n",
557 | "********** result **********\n",
558 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : 5 out of 5 | area : city centre | family friendly : no | near : All Bar One\n",
559 | "\n",
560 | "The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : 5 out of 5 | area : city centre\n",
561 | "********** input **********\n",
562 | "name : The Punter | Type : restaurant | food : English | price : high | area : riverside | family friendly : no | near : Raja Indian Cuisine\n",
563 | "\n",
564 | "********** result **********\n",
565 | "name : The Punter | Type : restaurant | food : English | price : high | area : riverside | family friendly : no | near : Raja Indian Cuisine\n",
566 | "\n",
567 | "The Punter is a restaurant in Raja, India. It is located in the heart of the Raja district of Rajasthan. It\n",
568 | "********** input **********\n",
569 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : average | area : city centre | family friendly : yes | near : All Bar One\n",
570 | "\n",
571 | "********** result **********\n",
572 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : average | area : city centre | family friendly : yes | near : All Bar One\n",
573 | "\n",
574 | "The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : average | area : city centre | family friendly\n"
575 | ]
576 | }
577 | ],
578 | "source": [
579 | "test_data = pd.read_json(\"test_formatted.jsonl\", lines=True)\n",
580 | "test_data = test_data[::2] # because it's duplicated\n",
581 | "test_loader = DataLoader(\n",
582 | "    list(zip(test_data[\"context\"], [\"\"] * len(test_data[\"context\"]))),\n",
583 | "    batch_size=1,\n",
584 | "    shuffle=True,\n",
585 | "    collate_fn=collate_batch\n",
586 | ")\n",
587 | "\n",
588 | "for i, (input, _, mask) in enumerate(test_loader):\n",
589 | "    if i == 5:\n",
590 | "        break\n",
591 | "    print(\"********** input **********\")\n",
592 | "    input_len = torch.sum(mask).cpu().numpy()\n",
593 | "    print(tokenizer.decode(input[0][:input_len]))\n",
594 | "    result_token, result_len = generate_text(\n",
595 | "        model,\n",
596 | "        input,\n",
597 | "        mask,\n",
598 | "        eos_id,\n",
599 | "        pred_sequence_length=30)\n",
600 | "    print(\"********** result **********\")\n",
601 | "    print(tokenizer.decode(result_token[0][:result_len]))"
602 | ]
603 | },
604 | {
605 | "cell_type": "markdown",
606 | "id": "e3138341-e01c-4fae-af78-c61e34967e92",
607 | "metadata": {},
608 | "source": [
609 | "## LoRA (Low-Rank Adaptation)\n",
610 | "\n",
611 | "Now we apply LoRA to our downloaded model.\n",
612 | "\n",
613 | "[LoRA (Low-Rank Adaptation)](https://arxiv.org/abs/2106.09685), developed by Microsoft Research, is a popular method for parameter-efficient fine-tuning.\n",
614 | "\n",
615 | "LoRA is based on the hypothesis that, in task-specific fine-tuning, the change in weights during model adaptation has a low intrinsic rank.<br>\n",
616 | "Under this hypothesis, the weight update $ \\Delta W $ can be re-parameterized as the product $ B \\cdot A $ of two much smaller low-rank matrices, as follows.\n",
617 | "\n",
618 | "$$ \\displaystyle W_0 x + \\Delta W x = W_0 x + B \\cdot A x $$\n",
619 | "\n",
620 | "where\n",
621 | "\n",
622 | "- $ W_0 \\in \\mathbb{R}^{d \\times k} $ is the pre-trained weight matrix (which is frozen).\n",
623 | "- $ \\Delta W $ is the weight update.\n",
624 | "- $ B \\in \\mathbb{R}^{d \\times r} $, $ A \\in \\mathbb{R}^{r \\times k} $, and the rank $ r \\ll \\min(d, k) $.\n",
625 | "\n",
626 | "\n",
627 | "\n",
628 | "*From : [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)*\n",
629 | "\n",
630 | "Under this assumption, we freeze all pre-trained weights and train only the low-rank matrices $ B $ and $ A $.<br>\n",
631 | "In this way, you can fine-tune a large transformer for a specific task without full-parameter fine-tuning.\n",
632 | "\n",
633 | "This significantly reduces the GPU memory required for training: in the paper's experiment with GPT-3 175B, VRAM consumption during training drops from 1.2TB to 350GB (roughly one-third).\n",
634 | "\n",
635 | "For learning purposes, here I convert the current model into a model with LoRA manually (from scratch).\n",
636 | "\n",
637 | "> Note : You can use the ```PEFT``` package to get a LoRA model with a few lines of code. (Here I don't use this package.)"
638 | ]
639 | },
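{
"cell_type": "markdown",
"id": "added-lora-equation-demo-md",
"metadata": {},
"source": [
"The equation above can be checked numerically. The sketch below uses random matrices (hypothetical values, not OPT weights) with $ d = k = 768 $ and $ r = 128 $ (the same rank as ```lora_dim``` used later in this notebook). It confirms that $ W_0 x + B A x $ equals $ (W_0 + B A) x $, so the trained update can also be merged into $ W_0 $, and it compares the number of trainable parameters in $ \\Delta W $ with that in $ B $ and $ A $. With a smaller rank, e.g. $ r = 8 $, $ B $ and $ A $ would hold only $ 8 \\times 1536 = 12{,}288 $ parameters for this matrix."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-lora-equation-demo-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"d, k, r = 768, 768, 128  # output dim, input dim, and LoRA rank\n",
"\n",
"W0 = torch.randn(d, k, dtype=torch.float64)  # stands in for a frozen pre-trained weight\n",
"B = 0.01 * torch.randn(d, r, dtype=torch.float64)  # hypothetical trained values\n",
"A = 0.01 * torch.randn(r, k, dtype=torch.float64)\n",
"x = torch.randn(k, dtype=torch.float64)\n",
"\n",
"h_lora = W0 @ x + B @ (A @ x)  # W0 x + B A x, without forming the d x k update\n",
"h_merged = (W0 + B @ A) @ x  # the same update merged into the weight\n",
"print(torch.allclose(h_lora, h_merged))  # True\n",
"\n",
"full_params = d * k  # trainable parameters when fine-tuning W0 directly\n",
"lora_params = r * (d + k)  # trainable parameters in B and A\n",
"print(full_params, lora_params)  # 589824 196608"
]
},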
640 | {
641 | "cell_type": "markdown",
642 | "id": "5265832d-a736-4d68-80d3-347833d2c590",
643 | "metadata": {},
644 | "source": [
645 | "Before changing our model, let's check its structure.<br>\n",
646 | "As you can see in the output below, each transformer layer in OPT contains the following 6 linear layers.\n",
647 | "\n",
648 | "- Linear layer to get key\n",
649 | "- Linear layer to get value\n",
650 | "- Linear layer to get query\n",
651 | "- Linear layer for the output of attention\n",
652 | "- 2 linear layers (```fc1``` and ```fc2```) in the feed-forward network of the transformer layer\n",
653 | "\n",
654 | "In this example, we'll convert all these layers into LoRA layers.<br>\n",
655 | "The transformer in OPT-125M has 12 layers, so there are 6 x 12 = 72 linear layers to convert in total."
656 | ]
657 | },
658 | {
659 | "cell_type": "code",
660 | "execution_count": 16,
661 | "id": "5acb8f62-791a-4fa4-b00c-2666cf34827f",
662 | "metadata": {},
663 | "outputs": [
664 | {
665 | "data": {
666 | "text/plain": [
667 | "OPTForCausalLM(\n",
668 | " (model): OPTModel(\n",
669 | " (decoder): OPTDecoder(\n",
670 | " (embed_tokens): Embedding(50272, 768, padding_idx=1)\n",
671 | " (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)\n",
672 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
673 | " (layers): ModuleList(\n",
674 | " (0-11): 12 x OPTDecoderLayer(\n",
675 | " (self_attn): OPTAttention(\n",
676 | " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
677 | " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
678 | " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
679 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
680 | " )\n",
681 | " (activation_fn): ReLU()\n",
682 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
683 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
684 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
685 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
686 | " )\n",
687 | " )\n",
688 | " )\n",
689 | " )\n",
690 | " (lm_head): Linear(in_features=768, out_features=50272, bias=False)\n",
691 | ")"
692 | ]
693 | },
694 | "execution_count": 16,
695 | "metadata": {},
696 | "output_type": "execute_result"
697 | }
698 | ],
699 | "source": [
700 | "model"
701 | ]
702 | },
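{
"cell_type": "markdown",
"id": "added-lora-param-count-md",
"metadata": {},
"source": [
"Using the layer shapes printed above, the following sketch counts how many parameters the LoRA matrices add across the 72 target layers with rank 128 (the value of ```lora_dim``` used below), and compares this with the number of pre-trained weights in those layers (biases excluded). It is plain arithmetic on the printed shapes, not a query on the loaded model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-lora-param-count-code",
"metadata": {},
"outputs": [],
"source": [
"# (in_features, out_features) of the 6 target linear layers in one OPT-125M decoder layer,\n",
"# copied from the printed model structure above\n",
"layer_shapes = {\n",
"    \"k_proj\": (768, 768),\n",
"    \"v_proj\": (768, 768),\n",
"    \"q_proj\": (768, 768),\n",
"    \"out_proj\": (768, 768),\n",
"    \"fc1\": (768, 3072),\n",
"    \"fc2\": (3072, 768),\n",
"}\n",
"num_layers = 12\n",
"r = 128  # LoRA rank\n",
"\n",
"# each LoRA layer adds an (in x r) matrix and an (r x out) matrix\n",
"lora_per_layer = sum(r * (n_in + n_out) for n_in, n_out in layer_shapes.values())\n",
"frozen_per_layer = sum(n_in * n_out for n_in, n_out in layer_shapes.values())\n",
"\n",
"print(lora_per_layer * num_layers)  # 21233664\n",
"print(frozen_per_layer * num_layers)  # 84934656\n",
"print(lora_per_layer / frozen_per_layer)  # 0.25"
]
},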
703 | {
704 | "cell_type": "markdown",
705 | "id": "045e7239-cb8a-46dd-815d-e48e7e49eea4",
706 | "metadata": {},
707 | "source": [
708 | "First, we build a custom linear layer with LoRA, as follows."
709 | ]
710 | },
711 | {
712 | "cell_type": "code",
713 | "execution_count": 17,
714 | "id": "77889272-9a93-491b-93cb-b0bed5ce7cd8",
715 | "metadata": {},
716 | "outputs": [],
717 | "source": [
718 | "import math\n",
719 | "from torch import nn\n",
720 | "\n",
721 | "class LoRA_Linear(nn.Module):\n",
722 | "    def __init__(self, weight, bias, lora_dim):\n",
723 | "        super(LoRA_Linear, self).__init__()\n",
724 | "\n",
725 | "        row, column = weight.shape  # row = out_features (d), column = in_features (k)\n",
726 | "\n",
727 | "        # restore the original Linear layer with the pre-trained weight (and bias)\n",
728 | "        if bias is None:\n",
729 | "            self.linear = nn.Linear(column, row, bias=False)\n",
730 | "            self.linear.load_state_dict({\"weight\": weight})\n",
731 | "        else:\n",
732 | "            self.linear = nn.Linear(column, row)\n",
733 | "            self.linear.load_state_dict({\"weight\": weight, \"bias\": bias})\n",
734 | "\n",
735 | "        # create LoRA weights: lora_right is random and lora_left is zero, so the LoRA update starts at zero\n",
736 | "        self.lora_right = nn.Parameter(torch.zeros(column, lora_dim))\n",
737 | "        nn.init.kaiming_uniform_(self.lora_right, a=math.sqrt(5))\n",
738 | "        self.lora_left = nn.Parameter(torch.zeros(lora_dim, row))\n",
739 | "\n",
740 | "    def forward(self, input):\n",
741 | "        x = self.linear(input)\n",
742 | "        y = input @ self.lora_right @ self.lora_left  # low-rank update\n",
743 | "        return x + y"
744 | ]
745 | },
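{
"cell_type": "markdown",
"id": "added-lora-linear-check-md",
"metadata": {},
"source": [
"To relate ```LoRA_Linear``` to the equation above: for row-vector inputs, ```input @ lora_right @ lora_left``` corresponds to $ B A x $ with $ A $ = ```lora_right.T``` ($ r \\times k $) and $ B $ = ```lora_left.T``` ($ d \\times r $). Because ```lora_left``` starts at zero, the layer initially returns exactly the same output as the original linear layer. The sketch below checks both points on small random tensors (toy sizes, independent of OPT). It reproduces the initialization rather than calling ```LoRA_Linear``` itself, so it runs on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-lora-linear-check-code",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"from torch import nn\n",
"\n",
"k, d, r = 16, 8, 4  # toy in_features, out_features, and rank\n",
"\n",
"linear = nn.Linear(k, d)  # stands in for a pre-trained layer\n",
"lora_right = torch.empty(k, r)\n",
"nn.init.kaiming_uniform_(lora_right, a=math.sqrt(5))  # random, as in LoRA_Linear\n",
"lora_left = torch.zeros(r, d)  # zero, as in LoRA_Linear\n",
"\n",
"x = torch.randn(3, k)  # a batch of 3 input row vectors\n",
"y = linear(x) + x @ lora_right @ lora_left\n",
"print(torch.equal(y, linear(x)))  # True: the LoRA branch adds nothing at first\n",
"\n",
"# pretend lora_left has been trained, then compare with the paper's notation\n",
"lora_left = torch.randn(r, d)\n",
"A, B = lora_right.T, lora_left.T  # A: r x k, B: d x r\n",
"print(torch.allclose(x @ lora_right @ lora_left, (B @ A @ x.T).T, atol=1e-4))  # True"
]
},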
746 | {
747 | "cell_type": "markdown",
748 | "id": "954e2c9d-545e-4bd9-9b0f-eba3fe29a1de",
749 | "metadata": {},
750 | "source": [
751 | "Replace the target linear layers with LoRA layers."
752 | ]
753 | },
754 | {
755 | "cell_type": "code",
756 | "execution_count": 18,
757 | "id": "baf8a748-a3e3-45b8-9c64-252c56abe923",
758 | "metadata": {},
759 | "outputs": [],
760 | "source": [
761 | "lora_dim = 128\n",
762 | "\n",
763 | "# get target module name\n",
764 | "target_names = []\n",
765 | "for name, module in model.named_modules():\n",
766 | " if isinstance(module, nn.Linear) and \"decoder.layers.\" in name:\n",
767 | " target_names.append(name)\n",
768 | "\n",
769 | "# replace each module with LoRA\n",
770 | "for name in target_names:\n",
771 | " name_struct = name.split(\".\")\n",
772 | " # get target module\n",
773 | " module_list = [model]\n",
774 | " for struct in name_struct:\n",
775 | " module_list.append(getattr(module_list[-1], struct))\n",
776 | " # build LoRA\n",
777 | " lora = LoRA_Linear(\n",
778 | " weight = module_list[-1].weight,\n",
779 | " bias = module_list[-1].bias,\n",
780 | " lora_dim = lora_dim,\n",
781 | " ).to(device)\n",
782 | " # replace\n",
783 | " setattr(module_list[-2], name_struct[-1], lora)"
784 | ]
785 | },
786 | {
787 | "cell_type": "markdown",
788 | "id": "8aae2df9-fae7-4ecc-8260-80e8e578d951",
789 | "metadata": {},
790 | "source": [
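791 | "Let's see how the model has changed. (The LoRA weights ```lora_right``` and ```lora_left``` are ```nn.Parameter``` attributes rather than submodules, so they don't appear in this printout.)"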
792 | ]
793 | },
794 | {
795 | "cell_type": "code",
796 | "execution_count": 19,
797 | "id": "bf16b414-b973-40eb-be81-fd2aa3dde439",
798 | "metadata": {},
799 | "outputs": [
800 | {
801 | "data": {
802 | "text/plain": [
803 | "OPTForCausalLM(\n",
804 | " (model): OPTModel(\n",
805 | " (decoder): OPTDecoder(\n",
806 | " (embed_tokens): Embedding(50272, 768, padding_idx=1)\n",
807 | " (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)\n",
808 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
809 | " (layers): ModuleList(\n",
810 | " (0-11): 12 x OPTDecoderLayer(\n",
811 | " (self_attn): OPTAttention(\n",
812 | " (k_proj): LoRA_Linear(\n",
813 | " (linear): Linear(in_features=768, out_features=768, bias=True)\n",
814 | " )\n",
815 | " (v_proj): LoRA_Linear(\n",
816 | " (linear): Linear(in_features=768, out_features=768, bias=True)\n",
817 | " )\n",
818 | " (q_proj): LoRA_Linear(\n",
819 | " (linear): Linear(in_features=768, out_features=768, bias=True)\n",
820 | " )\n",
821 | " (out_proj): LoRA_Linear(\n",
822 | " (linear): Linear(in_features=768, out_features=768, bias=True)\n",
823 | " )\n",
824 | " )\n",
825 | " (activation_fn): ReLU()\n",
826 | " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
827 | " (fc1): LoRA_Linear(\n",
828 | " (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
829 | " )\n",
830 | " (fc2): LoRA_Linear(\n",
831 | " (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
832 | " )\n",
833 | " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
834 | " )\n",
835 | " )\n",
836 | " )\n",
837 | " )\n",
838 | " (lm_head): Linear(in_features=768, out_features=50272, bias=False)\n",
839 | ")"
840 | ]
841 | },
842 | "execution_count": 19,
843 | "metadata": {},
844 | "output_type": "execute_result"
845 | }
846 | ],
847 | "source": [
848 | "model"
849 | ]
850 | },
851 | {
852 | "cell_type": "markdown",
853 | "id": "e9099c08-f6a6-45f8-939b-cc3ed9415976",
854 | "metadata": {},
855 | "source": [
856 | "Finally, freeze all parameters except the LoRA parameters (```lora_right``` and ```lora_left```), so that only these are updated during training."
857 | ]
858 | },
859 | {
860 | "cell_type": "code",
861 | "execution_count": 20,
862 | "id": "81d06bba-955b-4806-8ff7-f217252e3268",
863 | "metadata": {},
864 | "outputs": [],
865 | "source": [
866 | "for name, param in model.named_parameters():\n",
867 | " if \"lora_right\" in name or \"lora_left\" in name:\n",
868 | " param.requires_grad = True\n",
869 | " else:\n",
870 | " param.requires_grad = False"
871 | ]
872 | },
873 | {
874 | "cell_type": "code",
875 | "execution_count": null,
876 | "id": "6c0a4469-2827-4f30-9324-711a9feea1ae",
877 | "metadata": {},
878 | "outputs": [],
879 | "source": [
880 | "### If you enable gradient checkpointing while the base weights are frozen (e.g., adapter fine-tuning on Hugging Face models), also make the inputs require grads:\n",
881 | "# model.gradient_checkpointing_enable()\n",
882 | "# model.enable_input_require_grads()"
883 | ]
884 | },
885 | {
886 | "cell_type": "markdown",
887 | "id": "6d6c7d6f-6c50-4839-88a5-c851caab9ba2",
888 | "metadata": {},
889 | "source": [
890 | "## Fine-tune"
891 | ]
892 | },
893 | {
894 | "cell_type": "markdown",
895 | "id": "a12b875f-36cc-40b8-aaab-1efda68710f3",
896 | "metadata": {},
897 | "source": [
898 | "Now let's run fine-tuning.\n",
899 | "\n",
900 | "First, we build the optimizer as follows."
901 | ]
902 | },
903 | {
904 | "cell_type": "code",
905 | "execution_count": 21,
906 | "id": "bb51298a-2d55-466c-a990-0ea08a247350",
907 | "metadata": {},
908 | "outputs": [],
909 | "source": [
910 | "optimizer = torch.optim.AdamW(\n",
911 | " params=model.parameters(),\n",
912 | " lr=1e-3,\n",
913 | " betas=(0.9, 0.95),\n",
914 | ")"
915 | ]
916 | },
917 | {
918 | "cell_type": "markdown",
919 | "id": "d37db1a8-0053-4acc-94ce-89d87c78942e",
920 | "metadata": {},
921 | "source": [
922 | "In this example, we use a cosine learning-rate schedule (with optional linear warmup) for training."
923 | ]
924 | },
925 | {
926 | "cell_type": "code",
927 | "execution_count": 22,
928 | "id": "6f95bdf6-4498-4d40-90aa-1267d55f38c3",
929 | "metadata": {},
930 | "outputs": [],
931 | "source": [
932 | "from torch.optim.lr_scheduler import LambdaLR\n",
933 | "\n",
934 | "num_epochs = 2\n",
935 | "\n",
936 | "num_update_steps = math.ceil(len(dataloader) / gradient_accumulation_steps) # optimizer steps per epoch (len(dataloader) already counts batches)\n",
937 | "def _get_cosine_schedule(\n",
938 | " current_step: int,\n",
939 | " num_warmup_steps: int = 0,\n",
940 | " num_training_steps: int = num_epochs * num_update_steps\n",
941 | "):\n",
942 | " if current_step < num_warmup_steps:\n",
943 | " return float(current_step) / float(max(1, num_warmup_steps))\n",
944 | " progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n",
945 | " return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))\n",
946 | "scheduler = LambdaLR(optimizer, lr_lambda=_get_cosine_schedule)"
947 | ]
948 | },
949 | {
950 | "cell_type": "markdown",
951 | "id": "a9f9e828-c4fb-493d-a6de-78e03dbf035e",
952 | "metadata": {},
953 | "source": [
954 | "Run fine-tuning."
955 | ]
956 | },
957 | {
958 | "cell_type": "code",
959 | "execution_count": 23,
960 | "id": "75d22125-830a-4ec6-8417-cdb8a97ec559",
961 | "metadata": {},
962 | "outputs": [
963 | {
964 | "name": "stdout",
965 | "output_type": "stream",
966 | "text": [
967 | "Epoch 1 42/42 - loss: 1.0724\n",
968 | "Epoch 2 42/42 - loss: 1.3185\n"
969 | ]
970 | }
971 | ],
972 | "source": [
973 | "from torch.nn import functional as F\n",
974 | "\n",
975 | "if os.path.exists(\"loss.txt\"):\n",
976 | " os.remove(\"loss.txt\")\n",
977 | "\n",
978 | "for epoch in range(num_epochs):\n",
979 | " optimizer.zero_grad()\n",
980 | " model.train()\n",
981 | " for i, (inputs, labels, masks) in enumerate(dataloader):\n",
982 | " with torch.set_grad_enabled(True):\n",
983 | " outputs = model(\n",
984 | " input_ids=inputs,\n",
985 | " attention_mask=masks,\n",
986 | " )\n",
987 | " loss = F.cross_entropy(outputs.logits.transpose(1,2), labels)\n",
988 | " loss.backward()\n",
989 | " if ((i + 1) % gradient_accumulation_steps == 0) or \\\n",
990 | " (i + 1 == len(dataloader)):\n",
991 | " optimizer.step()\n",
992 | " optimizer.zero_grad()\n",
993 | " scheduler.step()\n",
994 | "\n",
995 | " print(f\"Epoch {epoch+1} {math.ceil((i + 1) / gradient_accumulation_steps)}/{math.ceil(len(dataloader) / gradient_accumulation_steps)} - loss: {loss.item() :2.4f}\", end=\"\\r\")\n",
996 | "\n",
997 | " # record loss\n",
998 | " with open(\"loss.txt\", \"a\") as f:\n",
999 | " f.write(str(loss.item()))\n",
1000 | " f.write(\"\\n\")\n",
1001 | " print(\"\")\n",
1002 | "\n",
1003 | "# save model\n",
1004 | "torch.save(model.state_dict(), \"finetuned_opt.bin\")"
1005 | ]
1006 | },
1007 | {
1008 | "cell_type": "markdown",
1009 | "id": "83993d92-d7ed-4a07-8985-cc59bd4e4fef",
1010 | "metadata": {},
1011 | "source": [
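1012 | "> Note : Here we save the LoRA-enabled model as is, but you can also merge the trained LoRA parameters into the original linear layers' weights (i.e., replace each weight ```W``` with ```W + (lora_right @ lora_left).T```), so that inference needs no extra LoRA computation."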
1013 | ]
1014 | },
1015 | {
1016 | "cell_type": "markdown",
1017 | "id": "1bc086e5-e93f-4264-a8fa-6428f844ac3c",
1018 | "metadata": {},
1019 | "source": [
1020 | "Plot how the loss changes during training."
1021 | ]
1022 | },
1023 | {
1024 | "cell_type": "code",
1025 | "execution_count": 25,
1026 | "id": "e37c5aee-38d4-4a2a-952c-4fd2bef41e2b",
1027 | "metadata": {},
1028 | "outputs": [
1029 | {
1030 | "data": {
1031 | "image/png": "",
1032 | "text/plain": [
1033 | ""
1034 | ]
1035 | },
1036 | "metadata": {},
1037 | "output_type": "display_data"
1038 | }
1039 | ],
1040 | "source": [
1041 | "import matplotlib.pyplot as plt\n",
1042 | "import pandas as pd\n",
1043 | "\n",
1044 | "data = pd.read_csv(\"loss.txt\", header=None) # no header row: one loss value per line\n",
1045 | "plt.plot(data)\n",
1046 | "plt.show()"
1047 | ]
1048 | },
1049 | {
1050 | "cell_type": "markdown",
1051 | "id": "9809bc9f-4ff6-46c3-9c43-08c6c2694a82",
1052 | "metadata": {},
1053 | "source": [
1054 | "## Generate text with fine-tuned model\n",
1055 | "\n",
1056 | "Again we check the results with our test dataset (5 rows).\n",
1057 | "As you can see below, the fine-tuned model now generates completions in the style of the training data."
1058 | ]
1059 | },
1060 | {
1061 | "cell_type": "code",
1062 | "execution_count": 26,
1063 | "id": "29903cae-404e-4209-9c84-6c8a69609c13",
1064 | "metadata": {},
1065 | "outputs": [
1066 | {
1067 | "name": "stdout",
1068 | "output_type": "stream",
1069 | "text": [
1070 | "********** input **********\n",
1071 | "name : The Punter | Type : pub | food : Chinese | price : more than £ 30 | area : riverside | family friendly : yes | near : Raja Indian Cuisine\n",
1072 | "\n",
1073 | "********** result **********\n",
1074 | "name : The Punter | Type : pub | food : Chinese | price : more than £ 30 | area : riverside | family friendly : yes | near : Raja Indian Cuisine\n",
1075 | "The Punter is a children friendly pub that serves Chinese food. It is located in the riverside area near Raja Indian Cuisine and has a\n",
1076 | "********** input **********\n",
1077 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : 5 out of 5 | area : city centre | family friendly : no | near : All Bar One\n",
1078 | "\n",
1079 | "********** result **********\n",
1080 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : 5 out of 5 | area : city centre | family friendly : no | near : All Bar One\n",
1081 | "The Cricketers is a Chinese restaurant with a cheap price range, located in the city centre near All Bar One. It has a customer rating of\n",
1082 | "********** input **********\n",
1083 | "name : The Phoenix | Type : pub | food : French | price : moderate | customer rating : 1 out of 5 | area : riverside | family friendly : no | near : Crowne Plaza Hotel\n",
1084 | "\n",
1085 | "********** result **********\n",
1086 | "name : The Phoenix | Type : pub | food : French | price : moderate | customer rating : 1 out of 5 | area : riverside | family friendly : no | near : Crowne Plaza Hotel\n",
1087 | "The Phoenix is a pub that serves French food. It is located near Crown Plaza Hotel in the riverside area. It has a moderate price range and\n",
1088 | "********** input **********\n",
1089 | "name : Giraffe | Type : restaurant | food : Fast food | area : riverside | family friendly : yes | near : Rainbow Vegetarian Café\n",
1090 | "\n",
1091 | "********** result **********\n",
1092 | "name : Giraffe | Type : restaurant | food : Fast food | area : riverside | family friendly : yes | near : Rainbow Vegetarian Café\n",
1093 | "Giraffe is a fast food restaurant located in the riverside area near Rainbow Vegetarian Café. It is family friendly.\n",
1094 | "********** input **********\n",
1095 | "name : The Vaults | Type : pub | food : French | price : more than £ 30 | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n",
1096 | "\n",
1097 | "********** result **********\n",
1098 | "name : The Vaults | Type : pub | food : French | price : more than £ 30 | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n",
1099 | "The Vaults is a children friendly French pub located in the city centre near Raja Indian Cuisine.\n"
1100 | ]
1101 | }
1102 | ],
1103 | "source": [
1104 | "test_data = pd.read_json(\"test_formatted.jsonl\", lines=True)\n",
1105 | "test_data = test_data[::2] # because it's duplicated\n",
1106 | "test_loader = DataLoader(\n",
1107 | " list(zip(test_data[\"context\"], [\"\"] * len(test_data[\"context\"]))),\n",
1108 | " batch_size=1,\n",
1109 | " shuffle=True,\n",
1110 | " collate_fn=collate_batch\n",
1111 | ")\n",
1112 | "\n",
1113 | "for i, (input, _, mask) in enumerate(test_loader):\n",
1114 | " if i == 5:\n",
1115 | " break\n",
1116 | " print(\"********** input **********\")\n",
1117 | " input_len = torch.sum(mask).cpu().numpy()\n",
1118 | " print(tokenizer.decode(input[0][:input_len]))\n",
1119 | " result_token, result_len = generate_text(\n",
1120 | " model,\n",
1121 | " input,\n",
1122 | " mask,\n",
1123 | " eos_id,\n",
1124 | " pred_sequence_length=30)\n",
1125 | " print(\"********** result **********\")\n",
1126 | " print(tokenizer.decode(result_token[0][:result_len]))"
1127 | ]
1128 | },
1129 | {
1130 | "cell_type": "code",
1131 | "execution_count": null,
1132 | "id": "6a7c1dd3-4057-497a-83ae-f99b1883697e",
1133 | "metadata": {},
1134 | "outputs": [],
1135 | "source": []
1136 | }
1137 | ],
1138 | "metadata": {
1139 | "kernelspec": {
1140 | "display_name": "Python 3 (ipykernel)",
1141 | "language": "python",
1142 | "name": "python3"
1143 | },
1144 | "language_info": {
1145 | "codemirror_mode": {
1146 | "name": "ipython",
1147 | "version": 3
1148 | },
1149 | "file_extension": ".py",
1150 | "mimetype": "text/x-python",
1151 | "name": "python",
1152 | "nbconvert_exporter": "python",
1153 | "pygments_lexer": "ipython3",
1154 | "version": "3.8.10"
1155 | }
1156 | },
1157 | "nbformat": 4,
1158 | "nbformat_minor": 5
1159 | }
1160 |
--------------------------------------------------------------------------------
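The closing note in 01-finetune-opt-with-lora.ipynb says the trained LoRA parameters can be merged into the original linear weights. Because `nn.Linear` computes `x @ W.T + b` while the LoRA path adds `x @ lora_right @ lora_left`, the merged weight should be `W + (lora_right @ lora_left).T`. The following is a minimal sketch under that assumption: it re-declares the notebook's `LoRA_Linear` layer, and `merge_lora` is a hypothetical helper (not part of the notebook) checked on a small random layer rather than on OPT.

```python
import math
import torch
from torch import nn

class LoRA_Linear(nn.Module):
    # same structure as the LoRA_Linear cell in the notebook
    def __init__(self, weight, bias, lora_dim):
        super().__init__()
        row, column = weight.shape
        self.linear = nn.Linear(column, row, bias=bias is not None)
        state = {"weight": weight} if bias is None else {"weight": weight, "bias": bias}
        self.linear.load_state_dict(state)
        self.lora_right = nn.Parameter(torch.zeros(column, lora_dim))
        nn.init.kaiming_uniform_(self.lora_right, a=math.sqrt(5))
        self.lora_left = nn.Parameter(torch.zeros(lora_dim, row))

    def forward(self, input):
        return self.linear(input) + input @ self.lora_right @ self.lora_left

def merge_lora(lora_layer):
    """Hypothetical helper: fold the LoRA update into a plain nn.Linear."""
    linear = lora_layer.linear
    merged = nn.Linear(linear.in_features, linear.out_features,
                       bias=linear.bias is not None)
    with torch.no_grad():
        # nn.Linear computes x @ W.T, while LoRA adds x @ (R @ L),
        # so the merged weight is W + (R @ L).T, shape (out_features, in_features)
        merged.weight.copy_(linear.weight + (lora_layer.lora_right @ lora_layer.lora_left).T)
        if linear.bias is not None:
            merged.bias.copy_(linear.bias)
    return merged

torch.manual_seed(0)
base = nn.Linear(16, 8)
layer = LoRA_Linear(base.weight.detach(), base.bias.detach(), lora_dim=4)
x = torch.randn(3, 16)

with torch.no_grad():
    # lora_left is zero-initialized, so the LoRA layer starts out identical to the base layer
    same_at_init = torch.allclose(layer(x), base(x))
    # pretend training has updated lora_left, then merge
    layer.lora_left.normal_()
    merged = merge_lora(layer)
    same_after_merge = torch.allclose(layer(x), merged(x), atol=1e-4)

print(same_at_init, same_after_merge)  # True True
```

After merging, the model runs a single matrix multiplication per layer again, at the cost of no longer being able to detach the adapter cheaply.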
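To see how much of the network LoRA actually trains, the arithmetic below counts the LoRA parameters added by the replacement loop in 01-finetune-opt-with-lora.ipynb (`lora_dim = 128`, 12 decoder layers), using the layer shapes from the printed OPT-125m model. It is a back-of-the-envelope sketch; on the real model, `sum(p.numel() for p in model.parameters() if p.requires_grad)` after the freezing step should give the same LoRA total.

```python
# (in_features, out_features) of every nn.Linear replaced in one OPTDecoderLayer,
# taken from the printed OPT-125m model in the notebook
LINEAR_SHAPES = {
    "k_proj": (768, 768),
    "v_proj": (768, 768),
    "q_proj": (768, 768),
    "out_proj": (768, 768),
    "fc1": (768, 3072),
    "fc2": (3072, 768),
}

def lora_param_count(shapes, lora_dim, num_layers):
    # lora_right is (in_features, lora_dim) and lora_left is (lora_dim, out_features)
    per_layer = sum(i * lora_dim + lora_dim * o for i, o in shapes.values())
    return per_layer * num_layers

def linear_param_count(shapes, num_layers):
    # weight (out_features x in_features) plus bias (out_features) of the frozen layers
    per_layer = sum(i * o + o for i, o in shapes.values())
    return per_layer * num_layers

lora_params = lora_param_count(LINEAR_SHAPES, lora_dim=128, num_layers=12)
linear_params = linear_param_count(LINEAR_SHAPES, num_layers=12)
print(lora_params, linear_params, f"{lora_params / linear_params:.1%}")
# 21233664 85017600 25.0%
```

With `lora_dim = 128` the adapters are about a quarter the size of the frozen linear layers they wrap; the LoRA count scales linearly with `lora_dim`, so a smaller rank shrinks it proportionally.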
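The cosine schedule in 01-finetune-opt-with-lora.ipynb is a function of the optimizer step, so `num_training_steps` must count optimizer steps: per epoch, `len(dataloader)` batches divided by `gradient_accumulation_steps`. The sketch below reuses the notebook's formula; the 5,000-batch dataset is a hypothetical size chosen for illustration. It also shows why under-counting matters: once `current_step` passes `num_training_steps`, `cos(pi * progress)` rises again and the learning rate climbs back toward its peak.

```python
import math

def num_optimizer_steps(num_batches, gradient_accumulation_steps):
    # len(dataloader) already counts batches, so divide only by the accumulation steps
    return math.ceil(num_batches / gradient_accumulation_steps)

def cosine_lr_factor(current_step, num_training_steps, num_warmup_steps=0):
    # same formula as _get_cosine_schedule in the notebook (multiplier on the base lr)
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

num_batches = 5000            # hypothetical len(dataloader)
batch_size = 8
gradient_accumulation_steps = 16
num_epochs = 2

steps_per_epoch = num_optimizer_steps(num_batches, gradient_accumulation_steps)
total_steps = num_epochs * steps_per_epoch
print(steps_per_epoch, total_steps)  # 313 626
for step in (0, total_steps // 2, total_steps):
    print(step, round(cosine_lr_factor(step, total_steps), 4))

# dividing by batch_size as well under-counts the optimizer steps ...
wrong_total = num_epochs * math.ceil(num_batches / batch_size / gradient_accumulation_steps)
# ... so by the last real step the schedule has wrapped around and is back near the peak lr
print(wrong_total, cosine_lr_factor(total_steps, wrong_total))
```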
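Steps 3 and 4 of `collate_batch` in 02-finetune-gpt2-with-lora.ipynb are easiest to check on a toy example. The sketch below copies `fill_ignore_label` and `pad_tokens` from the notebook and applies them to made-up token IDs (hypothetical values, except 50256, which is GPT-2's end-of-text token); `block_size` is shrunk to 8 for readability.

```python
def fill_ignore_label(l, c):
    # copied from the notebook: ignore the label positions that predict context tokens
    l[:len(c) - 1] = [-100] * (len(c) - 1)
    return l

def pad_tokens(tokens, max_seq_length, padding_token):
    # copied from the notebook: truncate, then right-pad to max_seq_length
    res_tokens = tokens[:max_seq_length]
    return res_tokens + [padding_token] * (max_seq_length - len(res_tokens))

context_tokens = [10, 11, 12, 198]   # hypothetical IDs for "<context>\n"
completion_tokens = [20, 21]         # hypothetical IDs for "<completion>"
eos_id = 50256                       # GPT-2's <|endoftext|> token
block_size = 8

inputs = context_tokens + completion_tokens
labels = fill_ignore_label(inputs[1:] + [eos_id], context_tokens)

padded_inputs = pad_tokens(inputs, block_size, 0)
padded_labels = pad_tokens(labels, block_size, -100)
print(padded_inputs)  # [10, 11, 12, 198, 20, 21, 0, 0]
print(padded_labels)  # [-100, -100, -100, 20, 21, 50256, -100, -100]
```

The first trained position is the delimiter at index 3, whose label is the first completion token; the last real position learns to emit EOS, which is what lets `generate_text` stop on its own.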
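The note on `-100` in 02-finetune-gpt2-with-lora.ipynb relies on `torch.nn.functional.cross_entropy` skipping targets equal to `ignore_index` (default `-100`) and, with the default `reduction="mean"`, averaging only over the remaining positions. The toy check below uses random logits (illustrative values only) in the same layout as the training loop's `outputs.logits.transpose(1, 2)`.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
vocab_size = 5
logits = torch.randn(1, 4, vocab_size)        # (batch, seq_len, vocab), like outputs.logits
labels = torch.tensor([[-100, -100, 3, 1]])   # the first two (context) positions are ignored

# the class dimension must come second, hence the transpose used in the training loop
loss = F.cross_entropy(logits.transpose(1, 2), labels)

# the same loss computed only over the non-ignored positions
manual = F.cross_entropy(logits[0, 2:], labels[0, 2:])
print(torch.allclose(loss, manual))  # True
```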
/02-finetune-gpt2-with-lora.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "857cafc6-da38-4aa7-8afc-63aa626fa7aa",
6 | "metadata": {},
7 | "source": [
8 | "# 02. Finetuning GPT-2 with LoRA\n",
9 | "\n",
10 | "In this example, we fine-tune the pre-trained auto-regressive model, **OpenAI's GPT-2** (small version, 124M parameters), by applying LoRA (Low-Rank Adaptation) optimization.\n",
11 | "\n",
12 | "In this example, I download the pre-trained model from the Hugging Face Hub, but fine-tune it with a regular PyTorch training loop.\n",
13 | "(I don't use the Hugging Face ```Trainer``` class here.)\n",
14 | "\n",
15 | "See [Readme](https://github.com/tsmatz/finetune_llm_with_lora) for the prerequisite setup."
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 1,
21 | "id": "3d49acf1-9ad1-4a6c-9312-6785cb3f5862",
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "model_name = \"gpt2\""
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "id": "4d835e84-a01d-4c33-926b-60d9dd4a7627",
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "import torch\n",
36 | "\n",
37 | "device = torch.device(\"cuda\")"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "id": "ead383e5-149b-4bfb-9324-3cc639fd398d",
43 | "metadata": {},
44 | "source": [
45 | "## Prepare dataset and dataloader"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "id": "91ecbb08-6a74-4623-bfe8-bddba5254e35",
51 | "metadata": {},
52 | "source": [
53 | "In this example, we use the E2E dataset from the [official LoRA example](https://github.com/microsoft/LoRA).\n",
54 | "\n",
55 | "Download the dataset from the official repository."
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 3,
61 | "id": "54a564f1-f8f3-42a6-b160-bebdbcc3aac0",
62 | "metadata": {},
63 | "outputs": [
64 | {
65 | "name": "stdout",
66 | "output_type": "stream",
67 | "text": [
68 | "--2023-10-06 03:27:50-- https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/train.txt\n",
69 | "Resolving github.com (github.com)... 140.82.114.3\n",
70 | "Connecting to github.com (github.com)|140.82.114.3|:443... connected.\n",
71 | "HTTP request sent, awaiting response... 302 Found\n",
72 | "Location: https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/train.txt [following]\n",
73 | "--2023-10-06 03:27:51-- https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/train.txt\n",
74 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...\n",
75 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
76 | "HTTP request sent, awaiting response... 200 OK\n",
77 | "Length: 9624463 (9.2M) [text/plain]\n",
78 | "Saving to: ‘train.txt’\n",
79 | "\n",
80 | "train.txt 100%[===================>] 9.18M --.-KB/s in 0.04s \n",
81 | "\n",
82 | "2023-10-06 03:27:51 (248 MB/s) - ‘train.txt’ saved [9624463/9624463]\n",
83 | "\n"
84 | ]
85 | }
86 | ],
87 | "source": [
88 | "!wget https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/train.txt"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 4,
94 | "id": "d48464ea-991f-48b2-9166-3323cfd61676",
95 | "metadata": {
96 | "scrolled": true
97 | },
98 | "outputs": [
99 | {
100 | "name": "stdout",
101 | "output_type": "stream",
102 | "text": [
103 | "--2023-10-06 03:27:54-- https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/test.txt\n",
104 | "Resolving github.com (github.com)... 140.82.114.3\n",
105 | "Connecting to github.com (github.com)|140.82.114.3|:443... connected.\n",
106 | "HTTP request sent, awaiting response... 302 Found\n",
107 | "Location: https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/test.txt [following]\n",
108 | "--2023-10-06 03:27:54-- https://raw.githubusercontent.com/microsoft/LoRA/main/examples/NLG/data/e2e/test.txt\n",
109 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.109.133, ...\n",
110 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
111 | "HTTP request sent, awaiting response... 200 OK\n",
112 | "Length: 1351149 (1.3M) [text/plain]\n",
113 | "Saving to: ‘test.txt’\n",
114 | "\n",
115 | "test.txt 100%[===================>] 1.29M --.-KB/s in 0.006s \n",
116 | "\n",
117 | "2023-10-06 03:27:54 (208 MB/s) - ‘test.txt’ saved [1351149/1351149]\n",
118 | "\n"
119 | ]
120 | }
121 | ],
122 | "source": [
123 | "!wget https://github.com/microsoft/LoRA/raw/main/examples/NLG/data/e2e/test.txt"
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "id": "09472803-8c62-48e0-9a63-b9b9448f16d3",
129 | "metadata": {},
130 | "source": [
131 | "Show the downloaded data (first 5 rows)."
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 5,
137 | "id": "e6e60596-028f-4c4b-a95d-f74a0ff3b188",
138 | "metadata": {},
139 | "outputs": [
140 | {
141 | "name": "stdout",
142 | "output_type": "stream",
143 | "text": [
144 | "name : The Vaults | Type : pub | price : more than £ 30 | customer rating : 5 out of 5 | near : Café Adriatic||The Vaults pub near Café Adriatic has a 5 star rating . Prices start at £ 30 . \n",
145 | "name : The Cambridge Blue | Type : pub | food : English | price : cheap | near : Café Brazil||Close to Café Brazil , The Cambridge Blue pub serves delicious Tuscan Beef for the cheap price of £ 10.50 . Delicious Pub food . \n",
146 | "name : The Eagle | Type : coffee shop | food : Japanese | price : less than £ 20 | customer rating : low | area : riverside | family friendly : yes | near : Burger King||The Eagle is a low rated coffee shop near Burger King and the riverside that is family friendly and is less than £ 20 for Japanese food . \n",
147 | "name : The Mill | Type : coffee shop | food : French | price : £ 20 - 25 | area : riverside | near : The Sorrento||Located near The Sorrento is a French Theme eatery and coffee shop called The Mill , with a price range at £ 20- £ 25 it is in the riverside area . \n",
148 | "name : Loch Fyne | food : French | customer rating : high | area : riverside | near : The Rice Boat||For luxurious French food , the Loch Fyne is located by the river next to The Rice Boat . \n"
149 | ]
150 | }
151 | ],
152 | "source": [
153 | "!head -n 5 train.txt"
154 | ]
155 | },
156 | {
157 | "cell_type": "markdown",
158 | "id": "93f5fabe-590c-459b-aa16-4b5a506fb54b",
159 | "metadata": {},
160 | "source": [
161 | "Convert the above data into JSON Lines (JSONL) format, with a ```context``` and a ```completion``` field per line."
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": 6,
167 | "id": "7376e0c0-16c9-46f4-ad4c-83d1a677f5a2",
168 | "metadata": {},
169 | "outputs": [],
170 | "source": [
171 | "import json\n",
172 | "\n",
173 | "def format_convert(read_file, write_file):\n",
174 | "    # each input line has the form \"<context>||<completion>\"\n",
175 | "    with open(read_file, \"r\", encoding=\"utf8\") as reader, \\\n",
176 | "         open(write_file, \"w\", encoding=\"utf8\") as writer:\n",
177 | "        for line in reader:\n",
178 | "            items = line.strip().split(\"||\")\n",
179 | "            x = {\"context\": items[0], \"completion\": items[1]}\n",
180 | "            writer.write(json.dumps(x) + \"\\n\")\n",
181 | "\n",
182 | "format_convert(\"train.txt\", \"train_formatted.jsonl\")\n",
183 | "format_convert(\"test.txt\", \"test_formatted.jsonl\")"
189 | ]
190 | },
191 | {
192 | "cell_type": "markdown",
193 | "id": "3ceec952-fe03-475f-9f3e-22237cc9c44b",
194 | "metadata": {},
195 | "source": [
196 | "Show the converted data (first 5 rows)."
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 7,
202 | "id": "cb32aca7-bd0e-4847-a4c2-cc7e67dc2b7a",
203 | "metadata": {},
204 | "outputs": [
205 | {
206 | "name": "stdout",
207 | "output_type": "stream",
208 | "text": [
209 | "{\"context\": \"name : The Vaults | Type : pub | price : more than \\u00a3 30 | customer rating : 5 out of 5 | near : Caf\\u00e9 Adriatic\", \"completion\": \"The Vaults pub near Caf\\u00e9 Adriatic has a 5 star rating . Prices start at \\u00a3 30 .\"}\n",
210 | "\n",
211 | "{\"context\": \"name : The Cambridge Blue | Type : pub | food : English | price : cheap | near : Caf\\u00e9 Brazil\", \"completion\": \"Close to Caf\\u00e9 Brazil , The Cambridge Blue pub serves delicious Tuscan Beef for the cheap price of \\u00a3 10.50 . Delicious Pub food .\"}\n",
212 | "\n",
213 | "{\"context\": \"name : The Eagle | Type : coffee shop | food : Japanese | price : less than \\u00a3 20 | customer rating : low | area : riverside | family friendly : yes | near : Burger King\", \"completion\": \"The Eagle is a low rated coffee shop near Burger King and the riverside that is family friendly and is less than \\u00a3 20 for Japanese food .\"}\n",
214 | "\n",
215 | "{\"context\": \"name : The Mill | Type : coffee shop | food : French | price : \\u00a3 20 - 25 | area : riverside | near : The Sorrento\", \"completion\": \"Located near The Sorrento is a French Theme eatery and coffee shop called The Mill , with a price range at \\u00a3 20- \\u00a3 25 it is in the riverside area .\"}\n",
216 | "\n",
217 | "{\"context\": \"name : Loch Fyne | food : French | customer rating : high | area : riverside | near : The Rice Boat\", \"completion\": \"For luxurious French food , the Loch Fyne is located by the river next to The Rice Boat .\"}\n",
218 | "\n"
219 | ]
220 | }
221 | ],
222 | "source": [
223 | "with open(\"train_formatted.jsonl\", \"r\") as reader:\n",
224 | " for _ in range(5):\n",
225 | " print(next(reader))"
226 | ]
227 | },
228 | {
229 | "cell_type": "markdown",
230 | "id": "6631f786-be4b-40cf-89d9-7009c1888821",
231 | "metadata": {},
232 | "source": [
233 | "Load the tokenizer from Hugging Face."
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": 8,
239 | "id": "e5433dc0-b5a5-4c01-adb5-3ffa2279eca8",
240 | "metadata": {},
241 | "outputs": [],
242 | "source": [
243 | "from transformers import AutoTokenizer\n",
244 | "import os\n",
245 | "\n",
246 | "tokenizer = AutoTokenizer.from_pretrained(\n",
247 | " model_name,\n",
248 | " use_fast=True)\n",
249 | "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
250 | ]
251 | },
252 | {
253 | "cell_type": "markdown",
254 | "id": "50817c47-a97b-4f80-975b-836859a0a7cf",
255 | "metadata": {},
256 | "source": [
257 | "Set the block size, i.e., the fixed sequence length to which every example is truncated or padded.\n",
258 | "GPT-2 accepts up to 1024 tokens, but here I set 512, because that's enough for our dataset."
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": 9,
264 | "id": "5f250929-5703-4b17-9f7b-26340950c055",
265 | "metadata": {},
266 | "outputs": [
267 | {
268 | "name": "stdout",
269 | "output_type": "stream",
270 | "text": [
271 | "Max length of tokens is 1024 in this model.\n",
272 | "But here we use max 512 tokens in the training.\n"
273 | ]
274 | }
275 | ],
276 | "source": [
277 | "block_size = 512\n",
278 | "\n",
279 | "print(f\"Max length of tokens is {tokenizer.model_max_length} in this model.\")\n",
280 | "print(f\"But here we use max {block_size} tokens in the training.\")"
281 | ]
282 | },
283 | {
284 | "cell_type": "markdown",
285 | "id": "2332617b-1e66-4812-ad47-5eaeb52b101b",
286 | "metadata": {},
287 | "source": [
288 | "Create a function to convert the data. (This function is later used as the collator in the data loader.)\n",
289 | "In this function:\n",
290 | "\n",
291 | "1. Tokenize both the contexts and the completions, e.g., ```\"This is a pen.\"``` --> ```[1212, 318, 257, 3112, 13]```\n",
292 | "2. Concatenate the context tokens and the completion tokens, delimited by a newline (```\"\\n\"```, appended to the context). The result is used as the input to the LLM.\n",
293 | "3. Create the labels (targets) from the inputs. The labels are ```input[1:]``` followed by the EOS token (i.e., the inputs shifted by one position, so that each position's label is the next token), and the label positions that would predict context tokens are filled with ```-100```, so that the loss is computed only on the completion. (See the note below.)\n",
294 | "4. Truncate or pad the tokens so that every sequence has exactly ```block_size``` tokens.\n",
295 | "\n",
296 | "> Note : Here I set ```-100``` as the ignored index for loss computation, because the PyTorch cross-entropy function (```torch.nn.functional.cross_entropy()```) has an ```ignore_index``` parameter whose default value is ```-100```."
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": 10,
302 | "id": "9f2f38aa-b3d0-4614-aa59-8ddd977176d1",
303 | "metadata": {},
304 | "outputs": [],
305 | "source": [
306 | "from torch.utils.data import DataLoader\n",
307 | "import pandas as pd\n",
308 | "\n",
309 | "def fill_ignore_label(l, c):\n",
310 | " l[:len(c) - 1] = [-100] * (len(c) - 1)\n",
311 | " return l\n",
312 | "\n",
313 | "def pad_tokens(tokens, max_seq_length, padding_token):\n",
314 | " res_tokens = tokens[:max_seq_length]\n",
315 | " token_len = len(res_tokens)\n",
316 | " res_tokens = res_tokens + \\\n",
317 | " [padding_token for _ in range(max_seq_length - token_len)]\n",
318 | " return res_tokens\n",
319 | "\n",
320 | "def collate_batch(batch):\n",
321 | " # tokenize both context and completion respectively\n",
322 | " # (context and completion is delimited by \"\\n\")\n",
323 | " context_list = list(zip(*batch))[0]\n",
324 | " context_list = [c + \"\\n\" for c in context_list]\n",
325 | " completion_list = list(zip(*batch))[1]\n",
326 | " context_result = tokenizer(context_list)\n",
327 | " context_tokens = context_result[\"input_ids\"]\n",
328 | " context_masks = context_result[\"attention_mask\"]\n",
329 | " completion_result = tokenizer(completion_list)\n",
330 | " completion_tokens = completion_result[\"input_ids\"]\n",
331 | " completion_masks = completion_result[\"attention_mask\"]\n",
332 | " # concatenate token\n",
333 | " inputs = [i + j for i, j in zip(context_tokens, completion_tokens)]\n",
334 | " masks = [i + j for i, j in zip(context_masks, completion_masks)]\n",
335 | " # create label\n",
336 | " eos_id = tokenizer.encode(tokenizer.eos_token)[0]\n",
337 | " labels = [t[1:] + [eos_id] for t in inputs]\n",
338 | " labels = list(map(fill_ignore_label, labels, context_tokens))\n",
339 | " # truncate and pad tokens\n",
340 | " inputs = [pad_tokens(t, block_size, 0) for t in inputs] # the pad ID (0) is arbitrary: padded positions are hidden by the attention mask and ignored in the labels\n",
341 | " masks = [pad_tokens(t, block_size, 0) for t in masks]\n",
342 | " labels = [pad_tokens(t, block_size, -100) for t in labels]\n",
343 | " # convert to tensor\n",
344 | " inputs = torch.tensor(inputs, dtype=torch.int64).to(device)\n",
345 | " masks = torch.tensor(masks, dtype=torch.int64).to(device)\n",
346 | " labels = torch.tensor(labels, dtype=torch.int64).to(device)\n",
347 | " return inputs, labels, masks"
348 | ]
349 | },
350 | {
351 | "cell_type": "markdown",
352 | "id": "2084d2e9-ef64-47a2-aec9-d24ead1cb38a",
353 | "metadata": {},
354 | "source": [
355 | "Now create a PyTorch data loader with the previous function as its collator (```collate_fn```).\n",
356 | "\n",
357 | "> Note : In this example, the data is small, so we load all of the JSON data into memory.\n",
358 | "> When the data is large, load it progressively by implementing a custom PyTorch dataset. (See [here](https://github.com/tsmatz/decision-transformer) for an example.)"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 11,
364 | "id": "f3bce3bb-2215-4bd6-a6a6-5b6b9d5afdc0",
365 | "metadata": {},
366 | "outputs": [],
367 | "source": [
368 | "batch_size = 8\n",
369 | "gradient_accumulation_steps = 16\n",
370 | "\n",
371 | "data = pd.read_json(\"train_formatted.jsonl\", lines=True)\n",
372 | "dataloader = DataLoader(\n",
373 | " list(zip(data[\"context\"], data[\"completion\"])),\n",
374 | " batch_size=batch_size,\n",
375 | " shuffle=True,\n",
376 | " collate_fn=collate_batch\n",
377 | ")"
378 | ]
379 | },
380 | {
381 | "cell_type": "markdown",
382 | "id": "3ba64144-b698-457e-b827-941020456536",
383 | "metadata": {},
384 | "source": [
385 | "## Load model"
386 | ]
387 | },
388 | {
389 | "cell_type": "markdown",
390 | "id": "1bfd360d-7bdc-4fd7-9b12-bcf9fe0a8db2",
391 | "metadata": {},
392 | "source": [
393 | "Load the model from Hugging Face."
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": 12,
399 | "id": "271181bd-677a-4da9-9e57-2874f5e47bd0",
400 | "metadata": {},
401 | "outputs": [],
402 | "source": [
403 | "from transformers import AutoModelForCausalLM, AutoConfig\n",
404 | "\n",
405 | "config = AutoConfig.from_pretrained(model_name)\n",
406 | "model = AutoModelForCausalLM.from_pretrained(\n",
407 | " model_name,\n",
408 | " config=config,\n",
409 | ").to(device)"
410 | ]
411 | },
412 | {
413 | "cell_type": "markdown",
414 | "id": "27ab764a-d634-40f8-9edb-a01146845233",
415 | "metadata": {},
416 | "source": [
417 | "## Generate text (before fine-tuning)"
418 | ]
419 | },
420 | {
421 | "cell_type": "markdown",
422 | "id": "559efeaf-4b38-4a0c-9be6-eb394221e374",
423 | "metadata": {},
424 | "source": [
425 | "Now run prediction with the downloaded model (which is not yet fine-tuned).\n",
426 | "\n",
427 | "First, we create a function that generates text with greedy decoding (the most probable token is picked at each step)."
428 | ]
429 | },
430 | {
431 | "cell_type": "code",
432 | "execution_count": 13,
433 | "id": "51a0c4fc-e0a7-4bbf-b25a-c335fe61f3df",
434 | "metadata": {},
435 | "outputs": [],
436 | "source": [
437 | "def generate_text(model, input, mask, eos_id, pred_sequence_length):\n",
438 | "    predicted_last_id = -1\n",
439 | "    start_token_len = torch.sum(mask).cpu().numpy()  # number of prompt tokens\n",
440 | "    token_len = start_token_len\n",
441 | "    with torch.no_grad():\n",
442 | "        while (predicted_last_id != eos_id) and (token_len < input.shape[1]) and \\\n",
443 | "              (token_len - start_token_len < pred_sequence_length):  # stop at EOS, at the end of the block, or at the length limit\n",
444 | "            output = model(\n",
445 | "                input_ids=input,\n",
446 | "                attention_mask=mask,\n",
447 | "            )\n",
448 | "            predicted_ids = torch.argmax(output.logits, dim=-1).cpu().numpy()  # greedy: most probable token at each position\n",
449 | "            predicted_last_id = predicted_ids[0][token_len - 1]  # token predicted after the last real token\n",
450 | "            input[0][token_len] = predicted_last_id  # write it into the first padding position\n",
451 | "            mask[0][token_len] = 1\n",
452 | "            token_len = torch.sum(mask).cpu().numpy()\n",
453 | "    return input, token_len"
454 | ]
455 | },
456 | {
457 | "cell_type": "markdown",
458 | "id": "3936b1a1-ae9f-48a5-80db-691261dda704",
459 | "metadata": {},
460 | "source": [
461 | "Let's test our function by generating text. (Here, generation stops after 15 predicted tokens, or earlier if the EOS token is generated.)"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": 14,
467 | "id": "28b7e13f-e8fb-4a9f-90ed-0464463ef569",
468 | "metadata": {},
469 | "outputs": [
470 | {
471 | "name": "stdout",
472 | "output_type": "stream",
473 | "text": [
474 | "Once upon a time, the world was a place of great beauty and great danger. The world was\n",
475 | "My name is Clara and I am a woman. I am a woman who is a woman. I am a\n"
476 | ]
477 | }
478 | ],
479 | "source": [
480 | "eos_id = tokenizer.encode(tokenizer.eos_token)[0]\n",
481 | "\n",
482 | "result = tokenizer(\"Once upon a time,\")\n",
483 | "input = result[\"input_ids\"]\n",
484 | "mask = result[\"attention_mask\"]\n",
485 | "input = pad_tokens(input, block_size, 0)\n",
486 | "mask = pad_tokens(mask, block_size, 0)\n",
487 | "input = torch.tensor([input], dtype=torch.int64).to(device)\n",
488 | "mask = torch.tensor([mask], dtype=torch.int64).to(device)\n",
489 | "\n",
490 | "result_token, result_len = generate_text(\n",
491 | " model,\n",
492 | " input,\n",
493 | " mask,\n",
494 | " eos_id,\n",
495 | " pred_sequence_length=15)\n",
496 | "print(tokenizer.decode(result_token[0][:result_len]))\n",
497 | "\n",
498 | "result = tokenizer(\"My name is Clara and I am\")\n",
499 | "input = result[\"input_ids\"]\n",
500 | "mask = result[\"attention_mask\"]\n",
501 | "input = pad_tokens(input, block_size, 0)\n",
502 | "mask = pad_tokens(mask, block_size, 0)\n",
503 | "input = torch.tensor([input], dtype=torch.int64).to(device)\n",
504 | "mask = torch.tensor([mask], dtype=torch.int64).to(device)\n",
505 | "\n",
506 | "result_token, result_len = generate_text(\n",
507 | " model,\n",
508 | " input,\n",
509 | " mask,\n",
510 | " eos_id,\n",
511 | " pred_sequence_length=15)\n",
512 | "print(tokenizer.decode(result_token[0][:result_len]))"
513 | ]
514 | },
515 | {
516 | "cell_type": "markdown",
517 | "id": "d48fb60b-c05d-4884-a9bc-92152c94c894",
518 | "metadata": {},
519 | "source": [
520 | "Now we generate text with our test dataset (5 rows).\n",
521 | "As you can see below, the model cannot complete the text well, because it is not yet fine-tuned."
522 | ]
523 | },
524 | {
525 | "cell_type": "code",
526 | "execution_count": 15,
527 | "id": "495728ef-fbe6-4953-a354-4b7a8bb88798",
528 | "metadata": {},
529 | "outputs": [
530 | {
531 | "name": "stdout",
532 | "output_type": "stream",
533 | "text": [
534 | "********** input **********\n",
535 | "name : Wildwood | Type : pub | food : Indian | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n",
536 | "\n",
537 | "********** result **********\n",
538 | "name : Wildwood | Type : pub | food : Indian | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n",
539 | "\n",
540 | "Raja Indian Cuisine : Indian | price : Rs. 1,000 | menu : Indian | menu type : food | menu size :\n",
541 | "********** input **********\n",
542 | "name : Giraffe | Type : pub | food : Fast food | area : riverside | family friendly : yes | near : Rainbow Vegetarian Café\n",
543 | "\n",
544 | "********** result **********\n",
545 | "name : Giraffe | Type : pub | food : Fast food | area : riverside | family friendly : yes | near : Rainbow Vegetarian Café\n",
546 | "\n",
547 | ": Giraffe | Type : pub | food : Fast food | area : riverside | family friendly : yes | near : Rainbow Vegetarian Café\n",
548 | "********** input **********\n",
549 | "name : The Waterman | Type : pub | food : Italian | price : less than £ 20 | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n",
550 | "\n",
551 | "********** result **********\n",
552 | "name : The Waterman | Type : pub | food : Italian | price : less than £ 20 | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n",
553 | "\n",
554 | "The Waterman is a pub in the heart of the city centre. It is a place where you can enjoy a good meal and drink a good\n",
555 | "********** input **********\n",
556 | "name : The Vaults | Type : pub | food : Italian | price : moderate | customer rating : 1 out of 5 | area : city centre | family friendly : no | near : Rainbow Vegetarian Café\n",
557 | "\n",
558 | "********** result **********\n",
559 | "name : The Vaults | Type : pub | food : Italian | price : moderate | customer rating : 1 out of 5 | area : city centre | family friendly : no | near : Rainbow Vegetarian Café\n",
560 | "\n",
561 | "The Vaults | Type : pub | food : Italian | price : moderate | customer rating : 1 out of 5 | area : city centre | family\n",
562 | "********** input **********\n",
563 | "name : The Vaults | Type : restaurant | food : French | price : less than £ 20 | area : riverside | family friendly : yes | near : Raja Indian Cuisine\n",
564 | "\n",
565 | "********** result **********\n",
566 | "name : The Vaults | Type : restaurant | food : French | price : less than £ 20 | area : riverside | family friendly : yes | near : Raja Indian Cuisine\n",
567 | "\n",
568 | "The restaurant is located in the centre of the city. It is a small restaurant with a small menu. The menu is very simple and the food\n"
569 | ]
570 | }
571 | ],
572 | "source": [
573 | "test_data = pd.read_json(\"test_formatted.jsonl\", lines=True)\n",
574 | "test_data = test_data[::2] # because it's duplicated\n",
575 | "test_loader = DataLoader(\n",
576 | " list(zip(test_data[\"context\"], [\"\"] * len(test_data[\"context\"]))),\n",
577 | " batch_size=1,\n",
578 | " shuffle=True,\n",
579 | " collate_fn=collate_batch\n",
580 | ")\n",
581 | "\n",
582 | "for i, (input, _, mask) in enumerate(test_loader):\n",
583 | " if i == 5:\n",
584 | " break\n",
585 | " print(\"********** input **********\")\n",
586 | " input_len = torch.sum(mask).cpu().numpy()\n",
587 | " print(tokenizer.decode(input[0][:input_len]))\n",
588 | " result_token, result_len = generate_text(\n",
589 | " model,\n",
590 | " input,\n",
591 | " mask,\n",
592 | " eos_id,\n",
593 | " pred_sequence_length=30)\n",
594 | " print(\"********** result **********\")\n",
595 | " print(tokenizer.decode(result_token[0][:result_len]))"
596 | ]
597 | },
598 | {
599 | "cell_type": "markdown",
600 | "id": "e3138341-e01c-4fae-af78-c61e34967e92",
601 | "metadata": {},
602 | "source": [
603 | "## LoRA (Low-Rank Adaptation)\n",
604 | "\n",
605 | "Now we apply LoRA to our downloaded model.\n",
606 | "For the idea behind LoRA (Low-Rank Adaptation), see [01-finetune-opt-with-lora.ipynb](./01-finetune-opt-with-lora.ipynb).\n",
607 | "\n",
608 | "To help you learn how it works, here I convert the current model into a LoRA model manually (from scratch).\n",
609 | "\n",
610 | "> Note : You can also use the ```PEFT``` package to get a LoRA model with a few lines of code. (Here I don't use this package.)"
611 | ]
612 | },
613 | {
614 | "cell_type": "markdown",
615 | "id": "e296bdf2-129a-4278-8fe3-d08333ebf1df",
616 | "metadata": {},
617 | "source": [
618 | "Before changing our model, let's first check its structure. (See the output of the following cell.)\n",
619 | "\n",
620 | "As you can see below, unlike [Meta's OPT transformer](./01-finetune-opt-with-lora.ipynb), the GPT-2 transformer blocks (```transformer.h```) contain no linear layers. Instead, you will find ```Conv1D``` layers.\n",
621 | "However, this Conv1D is not ```torch.nn.Conv1d```. It is a custom layer defined for OpenAI GPT, which works the same as a linear layer except that its weight is stored transposed, with shape ```(in_features, out_features)```. (See [this source code](https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py) for the custom ```pytorch_utils.Conv1D``` implementation.)\n",
622 | "In the GPT-2 transformer, this custom Conv1D layer (intrinsically a linear layer) is used to compute query/key/value and in the MLP, as follows.\n",
623 | "(See the [source code](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py) of GPT-2 in Hugging Face transformers.)\n",
624 | "\n",
625 | "- ```transformer.h.n.attn.c_attn``` : Layer that computes query/key/value (concatenated, 768 to 3 x 768 = 2304 features) before attention.\n",
626 | "- ```transformer.h.n.attn.c_proj``` : Output projection after attention.\n",
627 | "- ```transformer.h.n.mlp.c_fc``` : The MLP in GPT-2 is Linear(GELU(Linear)). This is the inner Linear layer (custom Conv1D).\n",
628 | "- ```transformer.h.n.mlp.c_proj``` : The MLP in GPT-2 is Linear(GELU(Linear)). This is the outer Linear layer (custom Conv1D).\n",
629 | "\n",
630 | "In this example, we only convert the ```transformer.h.n.attn.c_attn``` layers into LoRA layers.\n",
631 | "GPT-2 small (124M parameters) has 12 transformer blocks, so there are 12 layers in total (n = 0, 1, ... , 11) to be converted."
632 | ]
633 | },
634 | {
635 | "cell_type": "code",
636 | "execution_count": 16,
637 | "id": "5acb8f62-791a-4fa4-b00c-2666cf34827f",
638 | "metadata": {},
639 | "outputs": [
640 | {
641 | "data": {
642 | "text/plain": [
643 | "GPT2LMHeadModel(\n",
644 | " (transformer): GPT2Model(\n",
645 | " (wte): Embedding(50257, 768)\n",
646 | " (wpe): Embedding(1024, 768)\n",
647 | " (drop): Dropout(p=0.1, inplace=False)\n",
648 | " (h): ModuleList(\n",
649 | " (0-11): 12 x GPT2Block(\n",
650 | " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
651 | " (attn): GPT2Attention(\n",
652 | " (c_attn): Conv1D()\n",
653 | " (c_proj): Conv1D()\n",
654 | " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
655 | " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
656 | " )\n",
657 | " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
658 | " (mlp): GPT2MLP(\n",
659 | " (c_fc): Conv1D()\n",
660 | " (c_proj): Conv1D()\n",
661 | " (act): NewGELUActivation()\n",
662 | " (dropout): Dropout(p=0.1, inplace=False)\n",
663 | " )\n",
664 | " )\n",
665 | " )\n",
666 | " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
667 | " )\n",
668 | " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
669 | ")"
670 | ]
671 | },
672 | "execution_count": 16,
673 | "metadata": {},
674 | "output_type": "execute_result"
675 | }
676 | ],
677 | "source": [
678 | "model"
679 | ]
680 | },
681 | {
682 | "cell_type": "markdown",
683 | "id": "045e7239-cb8a-46dd-815d-e48e7e49eea4",
684 | "metadata": {},
685 | "source": [
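686 | "First, we build a custom linear layer with LoRA as follows. The pre-trained weight is kept in a regular ```nn.Linear```, and two trainable low-rank matrices are added: ```lora_right``` (input dim x ```lora_dim```) and ```lora_left``` (```lora_dim``` x output dim). Because ```lora_left``` is initialized to zero, the layer initially behaves exactly like the original one."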
687 | ]
688 | },
689 | {
690 | "cell_type": "code",
691 | "execution_count": 17,
692 | "id": "77889272-9a93-491b-93cb-b0bed5ce7cd8",
693 | "metadata": {},
694 | "outputs": [],
695 | "source": [
696 | "import math\n",
697 | "from torch import nn\n",
698 | "\n",
699 | "class LoRA_Linear(nn.Module):\n",
700 | " def __init__(self, weight, bias, lora_dim):\n",
701 | "        super().__init__()\n",
702 | "\n",
703 | "        row, column = weight.shape  # (out_features, in_features)\n",
704 | "\n",
705 | "        # restore the pre-trained weights in a regular nn.Linear\n",
706 | "        if bias is None:\n",
707 | "            self.linear = nn.Linear(column, row, bias=False)\n",
708 | "            self.linear.load_state_dict({\"weight\": weight})\n",
709 | "        else:\n",
710 | "            self.linear = nn.Linear(column, row)\n",
711 | "            self.linear.load_state_dict({\"weight\": weight, \"bias\": bias})\n",
712 | "\n",
713 | "        # create LoRA weights (lora_left starts at zero, so the LoRA term is initially 0)\n",
714 | "        self.lora_right = nn.Parameter(torch.zeros(column, lora_dim))\n",
715 | "        nn.init.kaiming_uniform_(self.lora_right, a=math.sqrt(5))\n",
716 | "        self.lora_left = nn.Parameter(torch.zeros(lora_dim, row))\n",
717 | "\n",
718 | "    def forward(self, input):\n",
719 | "        x = self.linear(input)  # output of the pre-trained layer (frozen later)\n",
720 | "        y = input @ self.lora_right @ self.lora_left  # low-rank update\n",
721 | " return x + y"
722 | ]
723 | },
724 | {
725 | "cell_type": "markdown",
726 | "id": "954e2c9d-545e-4bd9-9b0f-eba3fe29a1de",
727 | "metadata": {},
728 | "source": [
729 | "Replace the target layers with LoRA layers.\n",
730 | "\n",
731 | "> Note : As mentioned above, the custom Conv1D layer in GPT-2 is intrinsically a linear layer with transposed weights, so we transpose the weight (```torch.transpose(..., 0, 1)```) before passing it to ```LoRA_Linear```."
732 | ]
733 | },
734 | {
735 | "cell_type": "code",
736 | "execution_count": 18,
737 | "id": "baf8a748-a3e3-45b8-9c64-252c56abe923",
738 | "metadata": {},
739 | "outputs": [],
740 | "source": [
741 | "lora_dim = 128\n",
742 | "\n",
743 | "# get target module name\n",
744 | "target_names = []\n",
745 | "for name, module in model.named_modules():\n",
746 | " if \"attn.c_attn\" in name:\n",
747 | " target_names.append(name)\n",
748 | "\n",
749 | "# replace each module with LoRA\n",
750 | "for name in target_names:\n",
751 | " name_struct = name.split(\".\")\n",
752 | " # get target module\n",
753 | " module_list = [model]\n",
754 | " for struct in name_struct:\n",
755 | " module_list.append(getattr(module_list[-1], struct))\n",
756 | " # build LoRA\n",
757 | " lora = LoRA_Linear(\n",
758 | " weight = torch.transpose(module_list[-1].weight, 0, 1),\n",
759 | " bias = module_list[-1].bias,\n",
760 | " lora_dim = lora_dim,\n",
761 | " ).to(device)\n",
762 | " # replace\n",
763 | " module_list[-2].__setattr__(name_struct[-1], lora)"
764 | ]
765 | },
766 | {
767 | "cell_type": "markdown",
768 | "id": "8aae2df9-fae7-4ecc-8260-80e8e578d951",
769 | "metadata": {},
770 | "source": [
771 | "See how the model has changed. (Each ```c_attn``` layer is now a ```LoRA_Linear``` layer.)"
772 | ]
773 | },
774 | {
775 | "cell_type": "code",
776 | "execution_count": 19,
777 | "id": "bf16b414-b973-40eb-be81-fd2aa3dde439",
778 | "metadata": {},
779 | "outputs": [
780 | {
781 | "data": {
782 | "text/plain": [
783 | "GPT2LMHeadModel(\n",
784 | " (transformer): GPT2Model(\n",
785 | " (wte): Embedding(50257, 768)\n",
786 | " (wpe): Embedding(1024, 768)\n",
787 | " (drop): Dropout(p=0.1, inplace=False)\n",
788 | " (h): ModuleList(\n",
789 | " (0-11): 12 x GPT2Block(\n",
790 | " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
791 | " (attn): GPT2Attention(\n",
792 | " (c_attn): LoRA_Linear(\n",
793 | " (linear): Linear(in_features=768, out_features=2304, bias=True)\n",
794 | " )\n",
795 | " (c_proj): Conv1D()\n",
796 | " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
797 | " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
798 | " )\n",
799 | " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
800 | " (mlp): GPT2MLP(\n",
801 | " (c_fc): Conv1D()\n",
802 | " (c_proj): Conv1D()\n",
803 | " (act): NewGELUActivation()\n",
804 | " (dropout): Dropout(p=0.1, inplace=False)\n",
805 | " )\n",
806 | " )\n",
807 | " )\n",
808 | " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
809 | " )\n",
810 | " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
811 | ")"
812 | ]
813 | },
814 | "execution_count": 19,
815 | "metadata": {},
816 | "output_type": "execute_result"
817 | }
818 | ],
819 | "source": [
820 | "model"
821 | ]
822 | },
823 | {
824 | "cell_type": "markdown",
825 | "id": "e9099c08-f6a6-45f8-939b-cc3ed9415976",
826 | "metadata": {},
827 | "source": [
828 | "Finally, freeze all parameters except the LoRA parameters."
829 | ]
830 | },
831 | {
832 | "cell_type": "code",
833 | "execution_count": 20,
834 | "id": "81d06bba-955b-4806-8ff7-f217252e3268",
835 | "metadata": {},
836 | "outputs": [],
837 | "source": [
838 | "for name, param in model.named_parameters():\n",
839 | " if \"lora_right\" in name or \"lora_left\" in name:\n",
840 | " param.requires_grad = True\n",
841 | " else:\n",
842 | " param.requires_grad = False"
843 | ]
844 | },
845 | {
846 | "cell_type": "code",
847 | "execution_count": null,
848 | "id": "6c0a4469-2827-4f30-9324-711a9feea1ae",
849 | "metadata": {},
850 | "outputs": [],
851 | "source": [
852 | "### Optional: uncomment to use gradient checkpointing in Hugging Face transformers while the base weights are frozen\n",
853 | "# model.gradient_checkpointing_enable()\n",
854 | "# model.enable_input_require_grads()"
855 | ]
856 | },
857 | {
858 | "cell_type": "markdown",
859 | "id": "6d6c7d6f-6c50-4839-88a5-c851caab9ba2",
860 | "metadata": {},
861 | "source": [
862 | "## Fine-tune"
863 | ]
864 | },
865 | {
866 | "cell_type": "markdown",
867 | "id": "a12b875f-36cc-40b8-aaab-1efda68710f3",
868 | "metadata": {},
869 | "source": [
870 | "Now let's run fine-tuning.\n",
871 | "\n",
872 | "First, we build the optimizer as follows."
873 | ]
874 | },
875 | {
876 | "cell_type": "code",
877 | "execution_count": 21,
878 | "id": "bb51298a-2d55-466c-a990-0ea08a247350",
879 | "metadata": {},
880 | "outputs": [],
881 | "source": [
882 | "optimizer = torch.optim.AdamW(\n",
883 | "    params=[p for p in model.parameters() if p.requires_grad],  # only the LoRA parameters are trained\n",
884 | " lr=0.0002,\n",
885 | " betas=(0.9, 0.999),\n",
886 | " eps=1e-6,\n",
887 | ")"
888 | ]
889 | },
890 | {
891 | "cell_type": "markdown",
892 | "id": "d37db1a8-0053-4acc-94ce-89d87c78942e",
893 | "metadata": {},
894 | "source": [
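895 | "In this example, we build a linear learning-rate scheduler: the learning rate increases linearly during the first ```num_warmup_steps``` steps and then decreases linearly to zero."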
896 | ]
897 | },
898 | {
899 | "cell_type": "code",
900 | "execution_count": 22,
901 | "id": "6f95bdf6-4498-4d40-90aa-1267d55f38c3",
902 | "metadata": {},
903 | "outputs": [],
904 | "source": [
905 | "from torch.optim.lr_scheduler import LambdaLR\n",
906 | "\n",
907 | "num_epochs = 2\n",
908 | "num_warmup_steps = 500\n",
909 | "\n",
910 | "num_update_steps = math.ceil(len(dataloader) / batch_size / gradient_accumulation_steps)\n",
911 | "def _get_linear_schedule(current_step):\n",
912 | " if current_step < num_warmup_steps:\n",
913 | " return float(current_step) / float(max(1, num_warmup_steps))\n",
914 | " return max(0.0, float(num_update_steps * num_epochs - current_step) / float(max(1, num_update_steps * num_epochs - num_warmup_steps)))\n",
915 | "scheduler = LambdaLR(optimizer, lr_lambda=_get_linear_schedule)"
916 | ]
917 | },
918 | {
919 | "cell_type": "markdown",
920 | "id": "a9f9e828-c4fb-493d-a6de-78e03dbf035e",
921 | "metadata": {},
922 | "source": [
923 | "Run fine-tuning."
924 | ]
925 | },
926 | {
927 | "cell_type": "code",
928 | "execution_count": 23,
929 | "id": "3752481d-8ee8-4c43-b677-add136a2fd5b",
930 | "metadata": {},
931 | "outputs": [
932 | {
933 | "name": "stdout",
934 | "output_type": "stream",
935 | "text": [
936 | "Epoch 1 42/42 - loss: 1.3620\n",
937 | "Epoch 2 42/42 - loss: 1.4432\n"
938 | ]
939 | }
940 | ],
941 | "source": [
942 | "from torch.nn import functional as F\n",
943 | "\n",
944 | "if os.path.exists(\"loss.txt\"):\n",
945 | " os.remove(\"loss.txt\")\n",
946 | "\n",
947 | "for epoch in range(num_epochs):\n",
948 | " optimizer.zero_grad()\n",
949 | " model.train()\n",
950 | " for i, (inputs, labels, masks) in enumerate(dataloader):\n",
951 | " with torch.set_grad_enabled(True):\n",
952 | " outputs = model(\n",
953 | " input_ids=inputs,\n",
954 | " attention_mask=masks,\n",
955 | " )\n",
956 | " loss = F.cross_entropy(outputs.logits.transpose(1,2), labels)\n",
957 | " loss.backward()\n",
958 | " if ((i + 1) % gradient_accumulation_steps == 0) or \\\n",
959 | " (i + 1 == len(dataloader)):\n",
960 | " optimizer.step()\n",
961 | " scheduler.step()\n",
962 | " optimizer.zero_grad()\n",
963 | "\n",
964 | " print(f\"Epoch {epoch+1} {math.ceil((i + 1) / batch_size / gradient_accumulation_steps)}/{num_update_steps} - loss: {loss.item() :2.4f}\", end=\"\\r\")\n",
965 | "\n",
966 | " # record loss\n",
967 | " with open(\"loss.txt\", \"a\") as f:\n",
968 | " f.write(str(loss.item()))\n",
969 | " f.write(\"\\n\")\n",
970 | " print(\"\")\n",
971 | "\n",
972 | "# save model\n",
973 | "torch.save(model.state_dict(), \"finetuned_gpt2.bin\")"
974 | ]
975 | },
976 | {
977 | "cell_type": "markdown",
978 | "id": "83993d92-d7ed-4a07-8985-cc59bd4e4fef",
979 | "metadata": {},
980 | "source": [
981 | "> Note : Here we save the LoRA-enabled model as is, but you can also merge the trained LoRA parameters into the original model's weights (for each layer, add the transposed product ```lora_right @ lora_left``` to the ```nn.Linear``` weight)."
982 | ]
983 | },
984 | {
985 | "cell_type": "markdown",
986 | "id": "1bc086e5-e93f-4264-a8fa-6428f844ac3c",
987 | "metadata": {},
988 | "source": [
989 | "Plot how the loss changes during training."
990 | ]
991 | },
992 | {
993 | "cell_type": "code",
994 | "execution_count": 24,
995 | "id": "e37c5aee-38d4-4a2a-952c-4fd2bef41e2b",
996 | "metadata": {},
997 | "outputs": [
998 | {
999 | "data": {
1000 | "image/png": "",
1001 | "text/plain": [
1002 | ""
1003 | ]
1004 | },
1005 | "metadata": {},
1006 | "output_type": "display_data"
1007 | }
1008 | ],
1009 | "source": [
1010 | "import matplotlib.pyplot as plt\n",
1011 | "import pandas as pd\n",
1012 | "\n",
1013 | "data = pd.read_csv(\"loss.txt\", header=None)  # loss.txt has no header row\n",
1014 | "plt.plot(data)\n",
1015 | "plt.show()"
1016 | ]
1017 | },
1018 | {
1019 | "cell_type": "markdown",
1020 | "id": "9809bc9f-4ff6-46c3-9c43-08c6c2694a82",
1021 | "metadata": {},
1022 | "source": [
1023 | "## Generate text with the fine-tuned model\n",
1024 | "\n",
1025 | "Again, we check the results with our test dataset (5 rows).\n",
1026 | "As you can see below, the fine-tuned model now produces much better completions that describe the input attributes."
1027 | ]
1028 | },
1029 | {
1030 | "cell_type": "code",
1031 | "execution_count": 25,
1032 | "id": "29903cae-404e-4209-9c84-6c8a69609c13",
1033 | "metadata": {},
1034 | "outputs": [
1035 | {
1036 | "name": "stdout",
1037 | "output_type": "stream",
1038 | "text": [
1039 | "********** input **********\n",
1040 | "name : The Vaults | Type : pub | food : Italian | price : less than £ 20 | customer rating : low | area : city centre | family friendly : no | near : Rainbow Vegetarian Café\n",
1041 | "\n",
1042 | "********** result **********\n",
1043 | "name : The Vaults | Type : pub | food : Italian | price : less than £ 20 | customer rating : low | area : city centre | family friendly : no | near : Rainbow Vegetarian Café\n",
1044 | "The Vaults is a pub near the Rainbow Vegetarian Café in the city centre. It is not family friendly and has a low customer rating of less than\n",
1045 | "********** input **********\n",
1046 | "name : The Cricketers | Type : restaurant | customer rating : average | family friendly : yes | near : Café Sicilia\n",
1047 | "\n",
1048 | "********** result **********\n",
1049 | "name : The Cricketers | Type : restaurant | customer rating : average | family friendly : yes | near : Café Sicilia\n",
1050 | "The Cricketers is a restaurant near Café Sicilia. It is family friendly and has an average customer rating.<|endoftext|>\n",
1051 | "********** input **********\n",
1052 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : average | area : city centre | family friendly : no | near : All Bar One\n",
1053 | "\n",
1054 | "********** result **********\n",
1055 | "name : The Cricketers | Type : restaurant | food : Chinese | price : cheap | customer rating : average | area : city centre | family friendly : no | near : All Bar One\n",
1056 | "The Cricketers is a restaurant located in the city centre near All Bar One. It is not family - friendly. It is located in the cheap\n",
1057 | "********** input **********\n",
1058 | "name : The Vaults | Type : pub | food : Japanese | price : cheap | customer rating : 5 out of 5 | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n",
1059 | "\n",
1060 | "********** result **********\n",
1061 | "name : The Vaults | Type : pub | food : Japanese | price : cheap | customer rating : 5 out of 5 | area : city centre | family friendly : yes | near : Raja Indian Cuisine\n",
1062 | "The Vaults is a cheap, family friendly pub located in the city centre near Raja Indian Cuisine.<|endoftext|>\n",
1063 | "********** input **********\n",
1064 | "name : The Wrestlers | Type : pub | food : Italian | price : less than £ 20 | area : riverside | family friendly : no | near : Raja Indian Cuisine\n",
1065 | "\n",
1066 | "********** result **********\n",
1067 | "name : The Wrestlers | Type : pub | food : Italian | price : less than £ 20 | area : riverside | family friendly : no | near : Raja Indian Cuisine\n",
1068 | "The Wrestlers is a pub near Raja Indian Cuisine in riverside. It is not family friendly.<|endoftext|>\n"
1069 | ]
1070 | }
1071 | ],
1072 | "source": [
1073 | "test_data = pd.read_json(\"test_formatted.jsonl\", lines=True)\n",
1074 | "test_data = test_data[::2] # because it's duplicated\n",
1075 | "test_loader = DataLoader(\n",
1076 | " list(zip(test_data[\"context\"], [\"\"] * len(test_data[\"context\"]))),\n",
1077 | " batch_size=1,\n",
1078 | " shuffle=True,\n",
1079 | " collate_fn=collate_batch\n",
1080 | ")\n",
1081 | "\n",
1082 | "for i, (input, _, mask) in enumerate(test_loader):\n",
1083 | " if i == 5:\n",
1084 | " break\n",
1085 | " print(\"********** input **********\")\n",
1086 | " input_len = torch.sum(mask).cpu().numpy()\n",
1087 | " print(tokenizer.decode(input[0][:input_len]))\n",
1088 | " result_token, result_len = generate_text(\n",
1089 | " model,\n",
1090 | " input,\n",
1091 | " mask,\n",
1092 | " eos_id,\n",
1093 | " pred_sequence_length=30)\n",
1094 | " print(\"********** result **********\")\n",
1095 | " print(tokenizer.decode(result_token[0][:result_len]))"
1096 | ]
1097 | },
1098 | {
1099 | "cell_type": "code",
1100 | "execution_count": null,
1101 | "id": "6a7c1dd3-4057-497a-83ae-f99b1883697e",
1102 | "metadata": {},
1103 | "outputs": [],
1104 | "source": []
1105 | }
1106 | ],
1107 | "metadata": {
1108 | "kernelspec": {
1109 | "display_name": "Python 3 (ipykernel)",
1110 | "language": "python",
1111 | "name": "python3"
1112 | },
1113 | "language_info": {
1114 | "codemirror_mode": {
1115 | "name": "ipython",
1116 | "version": 3
1117 | },
1118 | "file_extension": ".py",
1119 | "mimetype": "text/x-python",
1120 | "name": "python",
1121 | "nbconvert_exporter": "python",
1122 | "pygments_lexer": "ipython3",
1123 | "version": "3.8.10"
1124 | }
1125 | },
1126 | "nbformat": 4,
1127 | "nbformat_minor": 5
1128 | }
1129 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # Fine-tuning LLM with LoRA (Low-Rank Adaptation)
2 |
3 | LoRA (Low-Rank Adaptation) is one of the most widely used parameter-efficient fine-tuning (PEFT) methods today.
4 |
5 | This repository shows you how to implement [LoRA (Low-Rank Adaptation)](https://arxiv.org/abs/2106.09685) from scratch (manually), step by step, without the ```PEFT``` package, and explains the ideas behind the implementation in Jupyter (IPython) notebooks.
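As a rough illustration of what "from scratch" means here, the following is a minimal sketch (not the notebooks' exact code) of a LoRA layer: the pre-trained `nn.Linear` is frozen, and a trainable low-rank update `x @ A @ B` is added to its output. The class name `LoRALinearSketch`, the rank of 8, and the layer sizes are illustrative assumptions. Like `LoRA_Linear` in the GPT-2 notebook, it omits the `alpha / r` scaling factor described in the LoRA paper.

```python
import math

import torch
from torch import nn


class LoRALinearSketch(nn.Module):
    """Hypothetical minimal LoRA layer: frozen W x + b plus a trainable x @ A @ B."""

    def __init__(self, base: nn.Linear, rank: int):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # the pre-trained weights stay frozen
        # A: (in_features x rank), random init; B: (rank x out_features), zero init
        self.lora_a = nn.Parameter(torch.empty(base.in_features, rank))
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
        self.lora_b = nn.Parameter(torch.zeros(rank, base.out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + x @ self.lora_a @ self.lora_b


torch.manual_seed(0)
base = nn.Linear(768, 2304)
lora = LoRALinearSketch(base, rank=8)
x = torch.randn(2, 5, 768)

# B starts at zero, so the wrapped layer initially reproduces the base layer.
same_at_init = torch.allclose(lora(x), base(x))
trainable = sum(p.numel() for p in lora.parameters() if p.requires_grad)
print(same_at_init, trainable)  # True 24576
```

Only `lora_a` and `lora_b` (768 x 8 + 8 x 2304 = 24,576 values in this sketch) need gradients, while the 768 x 2304 base weight is left untouched.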
6 |
7 | The examples also run on mainstream hardware with a small footprint - such as a single Tesla T4 GPU or consumer GPUs (NVIDIA RTX) - so you can easily try the code.
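To make the "small footprint" concrete, here is some back-of-the-envelope arithmetic for the GPT-2 notebook's settings: LoRA with `lora_dim = 128` on the 12 `attn.c_attn` layers, each mapping 768 to 2304 features. The 124M total is the approximate model size quoted in the table below, not an exact count. Only these LoRA parameters receive gradients and AdamW optimizer state (two extra values per trainable parameter), which accounts for a large part of the memory savings.

```python
# Back-of-the-envelope count for the GPT-2 notebook's configuration
# (LoRA on the 12 attn.c_attn layers, 768 -> 2304 features, lora_dim = 128).
num_layers = 12
in_features, out_features = 768, 2304
lora_dim = 128

# lora_right is (in_features x lora_dim), lora_left is (lora_dim x out_features)
lora_params_per_layer = in_features * lora_dim + lora_dim * out_features
lora_params_total = num_layers * lora_params_per_layer

# approximate size of GPT-2 small (124M), not an exact parameter count
gpt2_small_params = 124_000_000
trainable_ratio = lora_params_total / gpt2_small_params

print(f"{lora_params_per_layer:,} per layer, {lora_params_total:,} in total")
# 393,216 per layer, 4,718,592 in total
print(f"about {trainable_ratio:.1%} of ~124M parameters are trained")
# about 3.8% of ~124M parameters are trained
```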
8 |
9 | | Example | Description |
10 | | -------------------------------------------------------------------- | ----------------------------------------------------------------------- |
11 | | [01-finetune-opt-with-lora.ipynb](01-finetune-opt-with-lora.ipynb) | Fine-tuning Meta's OPT-125M with LoRA (also explains the LoRA method) |
12 | | [02-finetune-gpt2-with-lora.ipynb](02-finetune-gpt2-with-lora.ipynb) | Fine-tuning OpenAI's GPT-2 small (124M) with LoRA |
13 |
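In the GPT-2 example, Hugging Face's GPT-2 uses the custom `transformers.pytorch_utils.Conv1D` layer instead of `nn.Linear`. That layer stores its weight as `(in_features, out_features)` and computes `x @ weight + bias`, so the notebook builds an equivalent `nn.Linear` by transposing the weight. The sketch below checks this equivalence with a hypothetical stand-in, `Conv1DLike`, which mirrors that computation so it runs with PyTorch alone.

```python
import torch
from torch import nn


class Conv1DLike(nn.Module):
    """Hypothetical stand-in for transformers' Conv1D: weight is (in, out), y = x @ W + b."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_features, out_features) * 0.02)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight + self.bias


torch.manual_seed(0)
conv = Conv1DLike(768, 2304)

# nn.Linear stores its weight as (out_features, in_features), so transpose it.
linear = nn.Linear(768, 2304)
linear.load_state_dict({"weight": conv.weight.detach().t(), "bias": conv.bias.detach()})

x = torch.randn(3, 768)
print(torch.allclose(conv(x), linear(x), atol=1e-5))  # True
```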
14 | Unlike the examples in the [official repository](https://github.com/microsoft/LoRA), here I download pre-trained models so that the examples can focus on the LoRA implementation.
15 |
16 | > Note : In this repository, the Hugging Face API is used to download pre-trained models, and fine-tuning is then done with a regular PyTorch training loop. (I don't use the black-box ```Trainer``` class in the Hugging Face API.)
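The following is a minimal sketch of the training-loop pattern that the notebooks follow instead of `Trainer`: frozen base weights, AdamW, gradient accumulation, and a linear warmup/decay schedule through `LambdaLR`. It runs on a hypothetical toy model (a tiny frozen `nn.Linear` plus a rank-2 update) with random data. All sizes, step counts, and the learning rate are made up for illustration, and unlike the notebooks it passes only the trainable tensors to the optimizer.

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import LambdaLR

torch.manual_seed(0)

# Hypothetical toy model: a frozen base layer plus a trainable low-rank update
base = nn.Linear(16, 4)
for p in base.parameters():
    p.requires_grad = False
lora_right = nn.Parameter(torch.randn(16, 2) * 0.1)
lora_left = nn.Parameter(torch.zeros(2, 4))


def toy_model(x):
    return base(x) + x @ lora_right @ lora_left


gradient_accumulation_steps = 4
num_update_steps, num_warmup_steps = 10, 2


def linear_schedule(step):
    # linear warmup, then linear decay to zero (same shape as the notebooks' schedule)
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_update_steps - step) / max(1, num_update_steps - num_warmup_steps))


optimizer = torch.optim.AdamW([lora_right, lora_left], lr=1e-2)
scheduler = LambdaLR(optimizer, lr_lambda=linear_schedule)
base_weight_before = base.weight.clone()

batches = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(16)]
optimizer.zero_grad()
for i, (inputs, labels) in enumerate(batches):
    loss = F.cross_entropy(toy_model(inputs), labels)
    loss.backward()  # gradients accumulate until the next optimizer step
    if (i + 1) % gradient_accumulation_steps == 0 or i + 1 == len(batches):
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

print(torch.equal(base.weight, base_weight_before))  # True: frozen weights unchanged
print(lora_left.abs().sum().item() > 0)  # True: LoRA weights were updated
```

Note that with this schedule the learning rate is 0 at step 0, so the first optimizer step does not change the weights. Warmup starts having an effect from the second step.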
17 |
18 | ## 1. Set-up and Install
19 |
20 | To run these examples, install the prerequisite software and set up your environment as follows.
21 | For the following setup, I used a GPU virtual machine (VM) with the "Ubuntu Server 20.04 LTS" image in Microsoft Azure.
22 |
23 | ### Install GPU driver (CUDA)
24 |
25 | Install CUDA (the installer also includes the NVIDIA GPU driver) as follows.
26 |
27 | ```
28 | # compilers and development settings
29 | sudo apt-get update
30 | sudo apt install -y gcc
31 | sudo apt-get install -y make
32 |
33 | # install CUDA
34 | wget https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run
35 | sudo sh cuda_12.2.2_535.104.05_linux.run
36 | echo -e "export LD_LIBRARY_PATH=/usr/local/cuda-12.2/lib64" >> ~/.bashrc
37 | source ~/.bashrc
38 | ```
39 |
40 | ### Install packages
41 |
42 | Install PyTorch, Hugging Face Transformers, and other libraries as follows.
43 |
44 | ```
45 | # install and upgrade pip
46 | sudo apt-get install -y python3-pip
47 | sudo -H pip3 install --upgrade pip
48 | # install packages
49 | pip3 install torch transformers pandas matplotlib
50 | # install jupyter for running notebook
51 | pip3 install jupyter
52 | ```
53 |
54 | ## 2. Fine-tune (Train)
55 |
56 | Download this repository.
57 |
58 | ```
59 | git clone https://github.com/tsmatz/finetune_llm_with_lora
60 | ```
61 |
62 | Run jupyter notebook.
63 |
64 | ```
65 | jupyter notebook
66 | ```
67 |
68 | Open Jupyter Notebook in your browser, and run the examples in this repository.
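The notebooks save the LoRA-enabled model as is after fine-tuning. If you would rather have plain layers, you can fold the low-rank update into the pre-trained weight. Because `x @ right @ left == x @ (right @ left)` and `nn.Linear` computes `x @ W.T + b`, the merged weight is `W + (right @ left).T`. The minimal sketch below assumes the layout of `LoRA_Linear` in the GPT-2 notebook (an `nn.Linear` plus `lora_right` of shape in x r and `lora_left` of shape r x out). The sizes and the random "trained" values are made up for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
in_features, out_features, lora_dim = 768, 2304, 8

# Layout assumed from LoRA_Linear: nn.Linear plus lora_right (in x r) and lora_left (r x out)
linear = nn.Linear(in_features, out_features)
lora_right = torch.randn(in_features, lora_dim) * 0.01
lora_left = torch.randn(lora_dim, out_features) * 0.01  # stands in for trained values


def lora_forward(x):
    return linear(x) + x @ lora_right @ lora_left


# Fold the low-rank update into a single nn.Linear weight: W + (right @ left).T
merged = nn.Linear(in_features, out_features)
with torch.no_grad():
    merged.weight.copy_(linear.weight + (lora_right @ lora_left).T)
    merged.bias.copy_(linear.bias)

x = torch.randn(4, in_features)
print(torch.allclose(lora_forward(x), merged(x), atol=1e-5))  # True
```

The merged layer computes the same function, so it could replace the `LoRA_Linear` module and remove the extra matrix multiplications at inference time. The trade-off is that you can no longer detach or swap the adapter unless you also keep a copy of the original weights.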
69 |
--------------------------------------------------------------------------------
/images/auto_regressive_transformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsmatz/finetune_llm_with_lora/2e84a5e9e5095aaeacacaa723ee5a7c34c36b678/images/auto_regressive_transformer.png
--------------------------------------------------------------------------------
/images/lora.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsmatz/finetune_llm_with_lora/2e84a5e9e5095aaeacacaa723ee5a7c34c36b678/images/lora.png
--------------------------------------------------------------------------------