"
433 | ]
434 | },
435 | "metadata": {},
436 | "output_type": "display_data"
437 | }
438 | ],
439 | "source": [
440 | "plt.figure(figsize=(7.5, 4.25))\n",
441 | "for i, (title, param1, param2) in enumerate(tqdm([(\"$K$\", K1, K2), (\"$V$\", V1, V2), (\"$W_K$\", WK1, WK2), \n",
442 | " (\"$W_Q$\", WQ1, WQ2), (\"$W_V$\", WV1, WV2), (\"$W_O$\", WO1, WO2)])):\n",
443 | " S = corr_coef(param1.to(device), param2.to(device))\n",
444 | " layer_size = param1.shape[0] // num_layers\n",
445 | " S_agg = S.view(num_layers, layer_size, num_layers, layer_size).abs().mean([-1, -3]).cpu().numpy()\n",
446 | " _, edges = linear_sum_assignment(-S_agg)\n",
447 | " tmp_plot = torch.zeros(num_layers, num_layers)\n",
448 | " tmp_plot[torch.arange(num_layers), edges] = 1\n",
449 | " plt.subplot(2, 3, i+1)\n",
450 | " plt.title(title)\n",
451 | " sns.heatmap(tmp_plot, cbar=False)\n",
452 | " plt.xticks([])\n",
453 | " plt.yticks([])\n",
454 | "plt.savefig(\"artifacts/all_diagonals_raw.pdf\") "
455 | ]
456 |   }
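,
  {
   "cell_type": "markdown",
   "id": "added-note-corr-coef",
   "metadata": {},
   "source": [
    "The cell above relies on `corr_coef`, which is defined earlier in this notebook (outside this excerpt). A minimal sketch consistent with how it is used here (every row of `param1` Pearson-correlated with every row of `param2`) is given below; the actual definition may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-sketch-corr-coef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only (assumed implementation, not the original):\n",
    "def corr_coef(A, B):\n",
    "    # center and L2-normalize rows; one matmul then gives all pairwise Pearson correlations\n",
    "    A = A - A.mean(dim=1, keepdim=True)\n",
    "    B = B - B.mean(dim=1, keepdim=True)\n",
    "    A = A / A.norm(dim=1, keepdim=True)\n",
    "    B = B / B.norm(dim=1, keepdim=True)\n",
    "    return A @ B.T"
   ]
  }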
465 | ],
466 | "metadata": {
467 | "kernelspec": {
468 | "display_name": "Python 3",
469 | "language": "python",
470 | "name": "python3"
471 | },
472 | "language_info": {
473 | "codemirror_mode": {
474 | "name": "ipython",
475 | "version": 3
476 | },
477 | "file_extension": ".py",
478 | "mimetype": "text/x-python",
479 | "name": "python",
480 | "nbconvert_exporter": "python",
481 | "pygments_lexer": "ipython3",
482 | "version": "3.8.13"
483 | }
484 | },
485 | "nbformat": 4,
486 | "nbformat_minor": 5
487 | }
488 |
--------------------------------------------------------------------------------
/sentiment-analysis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "4c4c816e-f4ca-43ba-8000-c93942620147",
6 | "metadata": {},
7 | "source": [
8 | "## Init"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "0dd22ca7-dbdd-4855-83ac-405efbf1e408",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import torch\n",
19 | "from torch import nn\n",
20 | "import torch.nn.functional as F\n",
21 | "from copy import deepcopy\n",
22 | "from transformers import (AutoModelForMaskedLM, AutoModelForCausalLM, AutoTokenizer, AutoModelForTokenClassification,\n",
23 | " AutoModelForSequenceClassification, TrainingArguments, Trainer)\n",
24 | "from tqdm.auto import tqdm\n",
25 | "import numpy as np\n",
26 | "import matplotlib.pyplot as plt\n",
27 | "import seaborn as sns\n",
28 | "import json\n",
29 | "from tensorflow.keras.models import load_model\n",
30 | "from datasets import load_dataset, load_metric\n",
31 | "import os\n",
32 | "from utils import top_tokens\n",
33 | "from tabulate import tabulate"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 2,
39 | "id": "07fdd79f-7d70-4548-8b9c-909cc86d81f3",
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "tokenizer = AutoTokenizer.from_pretrained('gpt2') # ('bert-base-uncased') # get_multiberts_tokenizer()"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 3,
49 | "id": "f14b9d71-1578-402e-bb0e-e343f0cbd78a",
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "class Gpt2AvgClassifier(nn.Module):\n",
54 | " def __init__(self, name, freeze=None, num_labels=2):\n",
55 | " super().__init__()\n",
56 | " self.model = AutoModelForTokenClassification.from_pretrained(name, num_labels=num_labels)\n",
57 | " self.model.transformer.ln_f = nn.Identity(self.model.config.n_ctx)\n",
58 | " if freeze is not None:\n",
59 | " for n, p in self.named_parameters():\n",
60 | " p.requires_grad = False\n",
61 | " if len(n.split('.transformer.h.')) == 2 and n.endswith('.weight'):\n",
62 | " if int(n.split('.transformer.h.')[1].split('.')[0]) >= freeze:\n",
63 | " p.requires_grad = True\n",
64 | " print(n)\n",
65 | " if n.endswith('.classifier.weight'):\n",
66 | " p.requires_grad = True\n",
67 | " print(n)\n",
68 | " \n",
69 | " def forward(self, input_ids, labels, inputs_embeds=None):\n",
70 | " res = self.model(input_ids=input_ids, inputs_embeds=inputs_embeds)\n",
71 | " res.logits = res.logits.mean(dim=-2)\n",
72 | " res['loss'] = F.cross_entropy(res.logits.view(-1, res.logits.shape[-1]), labels.view(-1))\n",
73 | " return res"
74 | ]
75 | },
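  {
   "cell_type": "markdown",
   "id": "added-note-avgclf-check",
   "metadata": {},
   "source": [
    "A quick, illustrative smoke test for the classifier above (added sketch, not part of the original run): the token-level logits are mean-pooled over the sequence, so the output should have shape `(batch, num_labels)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-sketch-avgclf-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sanity check; instantiating downloads gpt2 if it is not cached\n",
    "clf = Gpt2AvgClassifier('gpt2', freeze=12)  # freeze=12: only the classifier stays trainable\n",
    "ids = tokenizer('a quick sanity check', return_tensors='pt').input_ids\n",
    "out = clf(ids, labels=torch.tensor([0]))\n",
    "print(out.logits.shape)  # expected: torch.Size([1, 2])"
   ]
  },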
76 | {
77 | "cell_type": "markdown",
78 | "id": "8c4d2819-e575-4f98-83af-1cd016b88bce",
79 | "metadata": {},
80 | "source": [
81 | "### Initialize Models"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 4,
87 | "id": "601e4a50-cf00-4245-a62e-421113db425a",
88 | "metadata": {},
89 | "outputs": [],
90 | "source": [
91 | "freeze = 9 # number of layers to freeze"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 5,
97 | "id": "03ace372-e479-4caa-bcf3-68043ebeb0a0",
98 | "metadata": {},
99 | "outputs": [
100 | {
101 | "name": "stdout",
102 | "output_type": "stream",
103 | "text": [
104 | "['gpt2', 'gpt2-medium']\n"
105 | ]
106 | },
107 | {
108 | "name": "stderr",
109 | "output_type": "stream",
110 | "text": [
111 | "Some weights of GPT2ForTokenClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.4.attn.masked_bias', 'h.6.attn.masked_bias', 'classifier.bias', 'h.5.attn.masked_bias', 'h.2.attn.masked_bias', 'h.8.attn.masked_bias', 'h.11.attn.masked_bias', 'h.7.attn.masked_bias', 'h.0.attn.masked_bias', 'classifier.weight', 'h.3.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.1.attn.masked_bias']\n",
112 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
113 | ]
114 | },
115 | {
116 | "name": "stdout",
117 | "output_type": "stream",
118 | "text": [
119 | "model.transformer.h.9.ln_1.weight\n",
120 | "model.transformer.h.9.attn.c_attn.weight\n",
121 | "model.transformer.h.9.attn.c_proj.weight\n",
122 | "model.transformer.h.9.ln_2.weight\n",
123 | "model.transformer.h.9.mlp.c_fc.weight\n",
124 | "model.transformer.h.9.mlp.c_proj.weight\n",
125 | "model.transformer.h.10.ln_1.weight\n",
126 | "model.transformer.h.10.attn.c_attn.weight\n",
127 | "model.transformer.h.10.attn.c_proj.weight\n",
128 | "model.transformer.h.10.ln_2.weight\n",
129 | "model.transformer.h.10.mlp.c_fc.weight\n",
130 | "model.transformer.h.10.mlp.c_proj.weight\n",
131 | "model.transformer.h.11.ln_1.weight\n",
132 | "model.transformer.h.11.attn.c_attn.weight\n",
133 | "model.transformer.h.11.attn.c_proj.weight\n",
134 | "model.transformer.h.11.ln_2.weight\n",
135 | "model.transformer.h.11.mlp.c_fc.weight\n",
136 | "model.transformer.h.11.mlp.c_proj.weight\n",
137 | "model.classifier.weight\n"
138 | ]
139 | },
140 | {
141 | "name": "stderr",
142 | "output_type": "stream",
143 | "text": [
144 | "Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2-medium and are newly initialized: ['score.weight']\n",
145 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
146 | ]
147 | }
148 | ],
149 | "source": [
150 | "model_paths = ['gpt2', 'gpt2-medium'] \n",
151 | "\n",
152 | "print(model_paths)\n",
153 | "\n",
154 | "model1 = Gpt2AvgClassifier(model_paths[0], freeze=freeze) # AutoModelForSequenceClassification.from_pretrained(model_paths[0])\n",
155 | "model2 = AutoModelForSequenceClassification.from_pretrained(model_paths[1])\n",
156 | "# we can use input embedding as the embedding matrices are tied\n",
157 | "emb1 = model1.model.get_input_embeddings().weight.T.cpu().detach() \n",
158 | "emb2 = model2.get_input_embeddings().weight.T.cpu().detach() \n",
159 | "num_layers1, hidden_dim1 = (model1.model.config.n_layer, model1.model.config.n_embd)\n",
160 | "num_layers2, hidden_dim2 = (model2.config.n_layer, model2.config.n_embd)"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "id": "77f0f339-4a8e-4780-be94-40a3ffd9cef2",
166 | "metadata": {},
167 | "source": [
168 | "## Sentiment Analysis Finetuning"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 6,
174 | "id": "48bcbc0b-3ee5-4494-9e92-1a65504e7870",
175 | "metadata": {},
176 | "outputs": [],
177 | "source": [
178 | "model = model1"
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "id": "fe41f0eb-e58f-45f8-9cf3-f4aff4e6c100",
184 | "metadata": {},
185 | "source": [
186 | "### Preparing Data"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": 7,
192 | "id": "11a17d09-8034-4df8-b1a0-62ff7814598f",
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "def tokenize_imdb(examples):\n",
197 | " return tokenizer(examples[\"text\"], truncation=True)"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": 8,
203 | "id": "46bcd56a-e0f0-4bd4-abb7-7f43dde6e29f",
204 | "metadata": {},
205 | "outputs": [
206 | {
207 | "name": "stderr",
208 | "output_type": "stream",
209 | "text": [
210 | "Reusing dataset imdb (/home/guydar/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)\n"
211 | ]
212 | },
213 | {
214 | "data": {
215 | "application/vnd.jupyter.widget-view+json": {
216 | "model_id": "be9861acc9c94a6c9bc9c1062336e129",
217 | "version_major": 2,
218 | "version_minor": 0
219 | },
220 | "text/plain": [
221 | " 0%| | 0/3 [00:00, ?it/s]"
222 | ]
223 | },
224 | "metadata": {},
225 | "output_type": "display_data"
226 | },
227 | {
228 | "name": "stderr",
229 | "output_type": "stream",
230 | "text": [
231 | "Loading cached processed dataset at /home/guydar/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-3524c89eaed1ab3e.arrow\n",
232 | "Loading cached processed dataset at /home/guydar/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-bd746f6e0438ac54.arrow\n",
233 | "Loading cached processed dataset at /home/guydar/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-fcdd099119f0e220.arrow\n",
234 | "Loading cached shuffled indices for dataset at /home/guydar/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-0f38f07d8eec9b87.arrow\n"
235 | ]
236 | }
237 | ],
238 | "source": [
239 | "imdb = load_dataset('imdb')\n",
240 | "imdb = imdb.map(tokenize_imdb, batched=False)\n",
241 | "imdb_train, imdb_val = imdb['train'].shuffle().select(range(3000)), imdb['test'].shuffle().select(range(500))\n"
242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "id": "46ca5ca0-367c-4a78-90cf-fe550e4dd346",
247 | "metadata": {},
248 | "source": [
249 | "### Training"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": 9,
255 | "id": "3460654a-3f6b-4469-87ab-123ab78abccf",
256 | "metadata": {},
257 | "outputs": [],
258 | "source": [
259 | "metric = load_metric('accuracy')\n",
260 | "def compute_metrics(eval_pred):\n",
261 | " logits, labels = eval_pred\n",
262 | " predictions = np.argmax(logits, axis=-1)\n",
263 | " return metric.compute(predictions=predictions, references=labels)"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 10,
269 | "id": "3afd2645-9df9-4185-9219-619b6623b8cf",
270 | "metadata": {},
271 | "outputs": [],
272 | "source": [
273 | "os.environ[\"WANDB_DISABLED\"] = \"true\""
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": 11,
279 | "id": "66a63fc7-d14c-484b-b4cf-e7223fd23997",
280 | "metadata": {},
281 | "outputs": [
282 | {
283 | "name": "stderr",
284 | "output_type": "stream",
285 | "text": [
286 | "Using the `WAND_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n"
287 | ]
288 | }
289 | ],
290 | "source": [
291 | "train_args = TrainingArguments(learning_rate=1e-5, report_to=None, output_dir='trainer_output', \n",
292 | " per_device_eval_batch_size=1, per_device_train_batch_size=1, \n",
293 | " save_steps=False, evaluation_strategy='epoch', num_train_epochs=1)"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": 12,
299 | "id": "38f57c62-7d19-40b0-bcd0-2a869ae0d272",
300 | "metadata": {},
301 | "outputs": [],
302 | "source": [
303 | "train_args._n_gpu = 1"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": 13,
309 | "id": "a9a30474-c98a-46ff-9925-851652170702",
310 | "metadata": {},
311 | "outputs": [],
312 | "source": [
313 | "old_model = deepcopy(model)"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": 14,
319 | "id": "432e8261-d5a1-468e-a62e-c1a79a5324be",
320 | "metadata": {},
321 | "outputs": [
322 | {
323 | "name": "stderr",
324 | "output_type": "stream",
325 | "text": [
326 | "The following columns in the training set don't have a corresponding argument in `Gpt2AvgClassifier.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `Gpt2AvgClassifier.forward`, you can safely ignore this message.\n",
327 | "/mnt/netapp7/dar/miniconda/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
328 | " warnings.warn(\n",
329 | "***** Running training *****\n",
330 | " Num examples = 3000\n",
331 | " Num Epochs = 1\n",
332 | " Instantaneous batch size per device = 1\n",
333 | " Total train batch size (w. parallel, distributed & accumulation) = 1\n",
334 | " Gradient Accumulation steps = 1\n",
335 | " Total optimization steps = 3000\n"
336 | ]
337 | },
338 | {
339 | "data": {
340 | "text/html": [
341 | "\n",
342 | " \n",
343 | " \n",
344 | "
\n",
345 | " [3000/3000 01:33, Epoch 1/1]\n",
346 | "
\n",
347 | " \n",
348 | " \n",
349 | " \n",
350 | " Epoch | \n",
351 | " Training Loss | \n",
352 | " Validation Loss | \n",
353 | " Accuracy | \n",
354 | "
\n",
355 | " \n",
356 | " \n",
357 | " \n",
358 | " 1 | \n",
359 | " 0.636800 | \n",
360 | " 0.950526 | \n",
361 | " 0.832000 | \n",
362 | "
\n",
363 | " \n",
364 | "
"
365 | ],
366 | "text/plain": [
367 | ""
368 | ]
369 | },
370 | "metadata": {},
371 | "output_type": "display_data"
372 | },
373 | {
374 | "name": "stderr",
375 | "output_type": "stream",
376 | "text": [
377 | "The following columns in the evaluation set don't have a corresponding argument in `Gpt2AvgClassifier.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `Gpt2AvgClassifier.forward`, you can safely ignore this message.\n",
378 | "***** Running Evaluation *****\n",
379 | " Num examples = 500\n",
380 | " Batch size = 1\n",
381 | "\n",
382 | "\n",
383 | "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
384 | "\n",
385 | "\n"
386 | ]
387 | },
388 | {
389 | "data": {
390 | "text/plain": [
391 | "TrainOutput(global_step=3000, training_loss=0.8896415710449219, metrics={'train_runtime': 93.5903, 'train_samples_per_second': 32.055, 'train_steps_per_second': 32.055, 'total_flos': 0.0, 'train_loss': 0.8896415710449219, 'epoch': 1.0})"
392 | ]
393 | },
394 | "execution_count": 14,
395 | "metadata": {},
396 | "output_type": "execute_result"
397 | }
398 | ],
399 | "source": [
400 | "trainer = Trainer(model1, args=train_args, train_dataset=imdb_train, eval_dataset=imdb_val, \n",
401 | " compute_metrics=compute_metrics)\n",
402 | "trainer.train()"
403 | ]
404 | },
405 | {
406 | "cell_type": "markdown",
407 | "id": "6685ad43-e590-4e53-b776-cce678426ca8",
408 | "metadata": {},
409 | "source": [
410 | "### Visualize Finetuning Vectors"
411 | ]
412 | },
413 | {
414 | "cell_type": "code",
415 | "execution_count": 15,
416 | "id": "e7cc0792-d95a-4418-a58d-5999a2045af3",
417 | "metadata": {},
418 | "outputs": [],
419 | "source": [
420 | "diff_classifier = (model.model.classifier.weight.cpu() - old_model.model.classifier.weight.cpu()).detach()\n",
421 | "# diff_classifier = model.score.weight.detach().cpu() - old_model.score.weight.detach()\n",
422 | "# diff_classifier = model.classifier.weight.detach().cpu() - old_model.classifier.weight.detach()"
423 | ]
424 | },
425 | {
426 | "cell_type": "code",
427 | "execution_count": 16,
428 | "id": "1e949f1f-b149-4ee3-b521-d7f262aa563e",
429 | "metadata": {},
430 | "outputs": [],
431 | "source": [
432 | "neg_vector = diff_classifier[0, :]\n",
433 | "pos_vector = diff_classifier[1, :]"
434 | ]
435 | },
436 | {
437 | "cell_type": "code",
438 | "execution_count": 17,
439 | "id": "557b80ff-3285-4ba3-9397-0ac00d46ade9",
440 | "metadata": {},
441 | "outputs": [
442 | {
443 | "name": "stdout",
444 | "output_type": "stream",
445 | "text": [
446 | "POSITIVE NEGATIVE\n",
447 | "---------- ------------\n",
448 | "#iscover bullshit\n",
449 | "honoured shitty\n",
450 | "pioneers crap\n",
451 | "#knit crappy\n",
452 | "#izons incompetence\n",
453 | "#Vers incompetent\n",
454 | "#raits pointless\n",
455 | "pioneer retarded\n",
456 | "#elight worse\n",
457 | "enchant FUCK\n",
458 | "#Together idiots\n",
459 | "reunited useless\n",
460 | "powerfully fuck\n",
461 | "#joy worthless\n",
462 | "Together garbage\n",
463 | "pioneering inco\n",
464 | "passions #Fuck\n",
465 | "timeless lame\n",
466 | "lively shit\n",
467 | "#inguished stupid\n",
468 | "insepar pathetic\n",
469 | "#Join inept\n",
470 | "renowned #shit\n",
471 | "unmatched piss\n",
472 | "#Born asshole\n",
473 | "#ossom Worse\n",
474 | "welcomes poorly\n",
475 | "Selected awful\n",
476 | "#anqu stupidity\n",
477 | "#Discover ineffective\n"
478 | ]
479 | }
480 | ],
481 | "source": [
482 | "print(tabulate(\n",
483 | " [*zip(*[top_tokens(pos_vector @ emb1, k=30, only_ascii=True, tokenizer=tokenizer),\n",
484 | " top_tokens(neg_vector @ emb1, k=30, only_ascii=True, tokenizer=tokenizer)])],\n",
485 | " headers=['POSITIVE', 'NEGATIVE']))"
486 | ]
487 | },
488 | {
489 | "cell_type": "code",
490 | "execution_count": 18,
491 | "id": "b63e1491-c1d9-46f9-87e2-c5e8de97b44c",
492 | "metadata": {},
493 | "outputs": [],
494 | "source": [
495 | "i1 = 11 # this is the layer we visualize"
496 | ]
497 | },
498 | {
499 | "cell_type": "code",
500 | "execution_count": 19,
501 | "id": "83f6a687-7fa9-4e42-82f9-a8f72298d85b",
502 | "metadata": {},
503 | "outputs": [],
504 | "source": [
505 | "diff_K = (model.model.transformer.h[i1].mlp.c_fc.weight.cpu() - old_model.model.transformer.h[i1].mlp.c_fc.weight.cpu()).T\n",
506 | "diff_V = (model.model.transformer.h[i1].mlp.c_proj.weight.cpu() - old_model.model.transformer.h[i1].mlp.c_proj.weight.cpu())\n",
507 | "diff_WQ, diff_WK, diff_WV = ((model.model.transformer.h[i1].attn.c_attn.weight.cpu() - \n",
508 | " old_model.model.transformer.h[i1].attn.c_attn.weight.cpu()).T.chunk(3))\n",
509 | "diff_WO = (model.model.transformer.h[i1].attn.c_proj.weight.cpu() - old_model.model.transformer.h[i1].attn.c_proj.weight.cpu())"
510 | ]
511 | },
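  {
   "cell_type": "markdown",
   "id": "added-note-conv1d-shapes",
   "metadata": {},
   "source": [
    "Note on the transposes above: GPT-2 stores its linear layers as `Conv1D` modules whose weights have shape `(in_features, out_features)`, so `c_fc` and `c_attn` are transposed to make each row a vector in the 768-dimensional embedding space. A quick shape check (added for illustration):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-sketch-conv1d-shapes",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gpt2-small: n_embd = 768, FFN inner dim = 4 * 768 = 3072\n",
    "print(diff_K.shape)   # torch.Size([3072, 768]) -- one FFN key per row\n",
    "print(diff_V.shape)   # torch.Size([3072, 768]) -- one FFN value per row\n",
    "print(diff_WQ.shape)  # torch.Size([768, 768])"
   ]
  },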
512 | {
513 | "cell_type": "code",
514 | "execution_count": 20,
515 | "id": "8977bab0-6c83-4ac3-a842-d6d7c93e3c96",
516 | "metadata": {},
517 | "outputs": [],
518 | "source": [
519 | "diff_param = diff_WV"
520 | ]
521 | },
522 | {
523 | "cell_type": "code",
524 | "execution_count": 21,
525 | "id": "2c76f6cb-8acd-4753-b0d9-50b59748cd74",
526 | "metadata": {},
527 | "outputs": [],
528 | "source": [
529 | "i2 = np.random.randint(diff_param.shape[0]) # index of vector in the parameter"
530 | ]
531 | },
532 | {
533 | "cell_type": "code",
534 | "execution_count": 22,
535 | "id": "e8840f5a-3b4e-4b6a-8cc1-06752ec66c0a",
536 | "metadata": {},
537 | "outputs": [
538 | {
539 | "name": "stdout",
540 | "output_type": "stream",
541 | "text": [
542 | "diff -diff\n",
543 | "------------ -------------\n",
544 | "incompetence unforgettable\n",
545 | "bullshit beautifully\n",
546 | "ineffective wonderfully\n",
547 | "worthless vividly\n",
548 | "bogus memorable\n",
549 | "incompetent thrilling\n",
550 | "useless delight\n",
551 | "retarded enjoyed\n",
552 | "retard timeless\n",
553 | "shitty superb\n",
554 | "worse wonderful\n",
555 | "idiots poignant\n",
556 | "#Fuck immensely\n",
557 | "Worse exhilar\n",
558 | "blame inspiring\n",
559 | "nonexistent delightful\n",
560 | "unus #love\n",
561 | "ineligible lively\n",
562 | "quotas vivid\n",
563 | "inco fascinating\n"
564 | ]
565 | }
566 | ],
567 | "source": [
568 | "print(tabulate(zip(*[top_tokens(diff_param[i2].detach() @ emb1, k=20, only_ascii=True, tokenizer=tokenizer),\n",
569 | " top_tokens(-diff_param[i2].detach() @ emb1, k=20, only_ascii=True, tokenizer=tokenizer)]), \n",
570 | " headers=[\"diff\", \"-diff\"]))"
571 | ]
572 | },
573 | {
574 | "cell_type": "markdown",
575 | "id": "cd0c489b-62dd-442c-b7ed-4a2cbdf66fe4",
576 | "metadata": {},
577 | "source": [
578 | "## Model Stitching"
579 | ]
580 | },
581 | {
582 | "cell_type": "code",
583 | "execution_count": 23,
584 | "id": "90c6d7c7-d95f-4acd-b4fe-08e4da1bac07",
585 | "metadata": {},
586 | "outputs": [],
587 | "source": [
588 | "def subtract_modules(mod1, mod2, subtract_ln=False, only_weight=False):\n",
589 | " mod_new = deepcopy(mod1)\n",
590 | " with torch.no_grad():\n",
591 | " for n, p in mod_new.named_parameters():\n",
592 | " if only_weight and not n.endswith('.weight'):\n",
593 | " continue\n",
594 | " submodule_name = n.rsplit('.', 1)[0] if '.' in n else ''\n",
595 | " is_ln = isinstance(mod_new.get_submodule(submodule_name), nn.LayerNorm)\n",
596 | " if (not is_ln) or subtract_ln:\n",
597 | " p.set_(p.data - mod2.get_parameter(n).data)\n",
598 | " return mod_new"
599 | ]
600 | },
601 | {
602 | "cell_type": "code",
603 | "execution_count": 24,
604 | "id": "c78ba6e8-bdd9-4fbf-97a9-d8f4b2b3e94f",
605 | "metadata": {},
606 | "outputs": [],
607 | "source": [
608 | "class StitchedTransformers(nn.Module):\n",
609 | " def __init__(self, old_model, model1, model2, kernel, num_keep_layers, num_transplanted_layers,\n",
610 | " subtract=True, **subtract_args):\n",
611 | " super().__init__()\n",
612 | " self.model2 = deepcopy(model2) \n",
613 | " self.model2.transformer.h = nn.ModuleList(self.model2.transformer.h[:num_keep_layers])\n",
614 | " self.register_buffer(\"stitching_kernel\", kernel) \n",
615 | " self.model1 = deepcopy(model1)\n",
616 | " offset = len(model1.model.transformer.h) - num_transplanted_layers\n",
617 | " self.model1.model.transformer.h = nn.ModuleList([\n",
618 | " subtract_modules(model1.model.transformer.h[offset + i], \n",
619 | " old_model.model.transformer.h[offset + i], \n",
620 | " **subtract_args) if subtract else model1.model.transformer.h[offset + i]\n",
621 | " for i in range(num_transplanted_layers)])\n",
622 | " self.model1.model.classifier = (\n",
623 | " subtract_modules(model1.model.classifier, old_model.model.classifier, **subtract_args) \n",
624 | " if subtract else model1.model.classifier\n",
625 | " )\n",
626 | " \n",
627 | " def forward(self, input_ids, labels):\n",
628 | " x = self.model2(input_ids, output_hidden_states=True).hidden_states[-1]\n",
629 | " x = x @ self.stitching_kernel\n",
630 | " res = self.model1(input_ids=None, inputs_embeds=x, labels=labels)\n",
631 | " res = {'loss': res['loss'], 'logits': res['logits']}\n",
632 | " return res"
633 | ]
634 | },
635 | {
636 | "cell_type": "code",
637 | "execution_count": 25,
638 | "id": "f57515ac-730f-4a8a-b4c9-2d65dcfb33d1",
639 | "metadata": {},
640 | "outputs": [],
641 | "source": [
642 | "extended = False\n",
643 | "kernel = emb_extended2 @ (emb_extended1).pinverse() if extended else emb2 @ (emb1).pinverse()\n",
644 | "# + .1 * torch.eye(1024, 768)"
645 | ]
646 | },
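  {
   "cell_type": "markdown",
   "id": "added-note-stitching-kernel",
   "metadata": {},
   "source": [
    "What this kernel computes (added note): with the embedding matrices $E_1 \\in \\mathbb{R}^{768 \\times |V|}$ (`emb1`) and $E_2 \\in \\mathbb{R}^{1024 \\times |V|}$ (`emb2`), the kernel is $E_2 E_1^{+}$, so for a gpt2-medium hidden state $h_2$ the stitched state $h_1 = h_2 E_2 E_1^{+}$ minimizes $\\lVert h_1 E_1 - h_2 E_2 \\rVert_2$ (assuming $E_1$ has full row rank). In words: it picks the gpt2 state whose logits under the tied embedding best match the logits the gpt2-medium state would produce."
   ]
  },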
647 | {
648 | "cell_type": "code",
649 | "execution_count": 26,
650 | "id": "b724a15a-bf85-4abb-baa1-33c380c13d91",
651 | "metadata": {},
652 | "outputs": [],
653 | "source": [
654 | "subtract = False"
655 | ]
656 | },
657 | {
658 | "cell_type": "code",
659 | "execution_count": 27,
660 | "id": "f5228ca9-3026-4338-973a-f151ea6b5fc8",
661 | "metadata": {},
662 | "outputs": [],
663 | "source": [
664 | "num_transplanted_layers = 3\n",
665 | "num_keep_layers = 14"
666 | ]
667 | },
668 | {
669 | "cell_type": "markdown",
670 | "id": "a5415437-7726-43a3-9823-8af3e9a6ab81",
671 | "metadata": {},
672 | "source": [
673 | "### Evaluate"
674 | ]
675 | },
676 | {
677 | "cell_type": "code",
678 | "execution_count": 28,
679 | "id": "f0c80b60-6c1a-442b-9e64-0aefda4318ab",
680 | "metadata": {},
681 | "outputs": [],
682 | "source": [
683 | "stitched_model = StitchedTransformers(old_model.cuda(), model1, model2, kernel, \n",
684 | " num_keep_layers, num_transplanted_layers, subtract=subtract).cpu()"
685 | ]
686 | },
687 | {
688 | "cell_type": "code",
689 | "execution_count": 29,
690 | "id": "271fa53f-ff22-43e6-a471-945629f46082",
691 | "metadata": {},
692 | "outputs": [
693 | {
694 | "name": "stderr",
695 | "output_type": "stream",
696 | "text": [
697 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n",
698 | "***** Running Evaluation *****\n",
699 | " Num examples = 500\n",
700 | " Batch size = 1\n"
701 | ]
702 | },
703 | {
704 | "data": {
705 | "text/html": [
706 | "\n",
707 | " \n",
708 | " \n",
709 | "
\n",
710 | " [500/500 00:13]\n",
711 | "
\n",
712 | " "
713 | ],
714 | "text/plain": [
715 | ""
716 | ]
717 | },
718 | "metadata": {},
719 | "output_type": "display_data"
720 | },
721 | {
722 | "data": {
723 | "text/plain": [
724 | "{'eval_loss': 10.434187889099121,\n",
725 | " 'eval_accuracy': 0.462,\n",
726 | " 'eval_runtime': 13.4261,\n",
727 | " 'eval_samples_per_second': 37.241,\n",
728 | " 'eval_steps_per_second': 37.241}"
729 | ]
730 | },
731 | "execution_count": 29,
732 | "metadata": {},
733 | "output_type": "execute_result"
734 | }
735 | ],
736 | "source": [
737 | "trainer_stitched = Trainer(stitched_model, args=train_args, train_dataset=imdb_train, eval_dataset=imdb_val, \n",
738 | " compute_metrics=compute_metrics)\n",
739 | "trainer_stitched.evaluate()"
740 | ]
741 | },
742 | {
743 | "cell_type": "markdown",
744 | "id": "4b9fe505-2e9c-4eea-9799-e50ff6a4b5a1",
745 | "metadata": {},
746 | "source": [
747 | "#### Plot All"
748 | ]
749 | },
750 | {
751 | "cell_type": "code",
752 | "execution_count": 39,
753 | "id": "febb237a-7c52-4c23-ab44-15909508bf04",
754 | "metadata": {
755 | "tags": []
756 | },
757 | "outputs": [
758 |    {
759 |     "name": "stderr",
760 |     "output_type": "stream",
761 |     "text": [
762 |      "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n",
763 |      "***** Running Evaluation *****\n",
764 |      "  Num examples = 500\n",
765 |      "  Batch size = 1\n"
766 |     ]
767 |    },
768 |    {
769 |     "data": {
770 |      "text/html": [
771 |       "<div>[500/500] evaluation progress bars (00:02 through 00:21), repeated once per value of num_keep_layers</div>"
772 |      ],
773 |      "text/plain": [
774 |       "<IPython.core.display.HTML object>"
775 |      ]
776 |     },
777 |     "metadata": {},
778 |     "output_type": "display_data"
779 |    }
1430 | ],
1431 | "source": [
1432 | "accs = []\n",
1433 | "for num_keep_layers in range(model2.config.n_layer):\n",
1434 | " stitched_model = StitchedTransformers(old_model.cuda(), model1, model2, kernel, \n",
1435 | " num_keep_layers, num_transplanted_layers, subtract=subtract).cpu()\n",
1436 | " trainer_stitched = Trainer(stitched_model, args=train_args, train_dataset=imdb_train, eval_dataset=imdb_val, \n",
1437 | " compute_metrics=compute_metrics)\n",
1438 | "\n",
1439 | " accs.append(trainer_stitched.evaluate()['eval_accuracy'])"
1440 | ]
1441 | }
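,
  {
   "cell_type": "markdown",
   "id": "added-note-plot-accs",
   "metadata": {},
   "source": [
    "A possible way to visualize the sweep above (added sketch, not part of the original run):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-sketch-plot-accs",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(range(len(accs)), accs)\n",
    "plt.axhline(0.5, linestyle='--', color='gray')  # chance level for binary sentiment\n",
    "plt.xlabel('num_keep_layers (gpt2-medium layers kept)')\n",
    "plt.ylabel('eval accuracy')\n",
    "plt.show()"
   ]
  }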
1442 | ],
1443 | "metadata": {
1444 | "kernelspec": {
1445 | "display_name": "Python 3",
1446 | "language": "python",
1447 | "name": "python3"
1448 | },
1449 | "language_info": {
1450 | "codemirror_mode": {
1451 | "name": "ipython",
1452 | "version": 3
1453 | },
1454 | "file_extension": ".py",
1455 | "mimetype": "text/x-python",
1456 | "name": "python",
1457 | "nbconvert_exporter": "python",
1458 | "pygments_lexer": "ipython3",
1459 | "version": "3.8.13"
1460 | }
1461 | },
1462 | "nbformat": 4,
1463 | "nbformat_minor": 5
1464 | }
1465 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from datasets import load_dataset
7 | from copy import deepcopy
8 |
9 |
10 | def keep_k(x, k=100, absolute=True, dim=-1):
11 |     x_ = abs(x) if absolute else x
12 |     _, indices = torch.topk(x_, k=k, dim=dim)
16 | res = torch.zeros_like(x)
17 | res.scatter_(dim, indices, x.gather(dim, indices))
18 | return res
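
# Example (illustrative): keep only the 2 largest-magnitude entries per row.
# >>> keep_k(torch.tensor([[1., -5., 3.]]), k=2)
# tensor([[ 0., -5.,  3.]])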
19 |
20 |
21 | def load_imdb():
22 | return load_dataset('imdb')['test']['text']
23 |
24 |
25 | class TokenizerFromVocab:
26 | def __init__(self, vocab):
27 | self.vocab = vocab
28 |
29 | def convert_ids_to_tokens(self, arr):
30 |         return [*map(self.vocab.__getitem__, arr.cpu().tolist())]
31 |
32 | def __len__(self):
33 | return len(self.vocab)
34 |
35 |
36 | def get_multiberts_tokenizer():
37 | vocab = dict(enumerate(open('multiberts/vocab.txt', 'r').read().split('\n')[:-1]))
38 | return TokenizerFromVocab(vocab)
39 |
40 |
41 | def convert_to_tokens(indices, tokenizer, strip=True, width=15):
42 | res = tokenizer.convert_ids_to_tokens(indices)
43 | if strip:
44 | res = list(map(lambda x: x[1:] if x[0] == 'Ġ' else "#" + x, res))
45 | if width:
46 | res = list(map(lambda x: x[:width] + (x[width:] and '...'), res))
47 | return res
48 |
49 |
50 | def top_tokens(v, tokenizer, k=100, only_english=False, only_ascii=False,
51 | exclude_brackets=False):
52 | v = deepcopy(v)
53 | ignored_indices = []
54 | if only_ascii:
55 | ignored_indices = [key for val, key in tokenizer.vocab.items() if not val.strip('Ġ').isascii()]
56 | if only_english:
57 | ignored_indices =[key for val, key in tokenizer.vocab.items()
58 | if not (val.strip('Ġ').isascii() and val.strip('Ġ[]').isalnum())]
59 | if exclude_brackets:
60 | ignored_indices = set(ignored_indices).intersection(
61 | {key for val, key in tokenizer.vocab.items() if not (val.isascii() and val.isalnum())})
62 | ignored_indices = list(ignored_indices)
63 | v[ignored_indices] = -np.inf
64 | values, indices = torch.topk(v, k=k)
65 | res = convert_to_tokens(indices, tokenizer)
66 | return res
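
# Example (illustrative), mirroring the notebooks: project a direction onto the
# vocabulary and list the k most-activated tokens.
# >>> top_tokens(pos_vector @ emb1, k=30, only_ascii=True, tokenizer=tokenizer)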
67 |
68 |
69 | def top_matrix_tokens(mat, tokenizer, k=100, rel_thresh=None, thresh=None,
70 | sample_entries=10000, alphabetical=False, only_english=False,
71 | exclude_brackets=False):
72 | mat = deepcopy(mat)
73 | ignored_indices = []
74 | if only_english:
75 | ignored_indices = [key for val, key in tokenizer.vocab.items()
76 | if not (val.isascii() and val.strip('[]').isalnum())]
77 | if exclude_brackets:
78 | ignored_indices = set(ignored_indices).intersection(
79 | {key for val, key in tokenizer.vocab.items() if not (val.isascii() and val.isalnum())})
80 | ignored_indices = list(ignored_indices)
81 | mat[ignored_indices, :] = -np.inf
82 | mat[:, ignored_indices] = -np.inf
83 | cond = torch.ones_like(mat).bool()
84 | if rel_thresh:
85 | cond &= (mat > torch.max(mat) * rel_thresh)
86 | if thresh:
87 | cond &= (mat > thresh)
88 | entries = torch.nonzero(cond)
89 | if sample_entries:
90 |         entries = entries[np.random.randint(len(entries), size=sample_entries)]
91 | res_indices = sorted(entries, key=lambda x: x[0] if alphabetical else -mat[x[0], x[1]])
92 |     res = [convert_to_tokens(pair, tokenizer) for pair in res_indices]
93 | return res
94 |
--------------------------------------------------------------------------------