├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── exploration.ipynb ├── interpretation-experiments.ipynb ├── load_multiberts.sh ├── parameter-alignment.ipynb ├── sentiment-analysis.ipynb └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/* 2 | multiberts/* 3 | .ipynb_checkpoints/* 4 | trainer_output/* 5 | artifacts/* 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Analyzing Transformers in Embedding Space! 2 | **Note: the code still requires some refactoring and documentation.** 3 | 4 | This repository contains the code for all the experiments presented in the [paper](https://arxiv.org/abs/2209.02535). 5 | 6 | 7 | ## Setup (Linux) 8 | First create a directory named `artifacts` here. 
Run in shell: 9 | ``` 10 | mkdir artifacts 11 | ``` 12 | 13 | 14 | To use the notebook `parameter-alignment.ipynb` you must also download models from multiBERTs. Run in shell: 15 | ``` 16 | ./load_multiberts.sh 17 | ``` 18 | 19 | ## Cite Us 20 | If you want to cite us: 21 | ``` 22 | @misc{transformers_in_embedding_space, 23 | doi = {10.48550/ARXIV.2209.02535}, 24 | url = {https://arxiv.org/abs/2209.02535}, 25 | author = {Dar, Guy and Geva, Mor and Gupta, Ankit and Berant, Jonathan}, 26 | title = {Analyzing Transformers in Embedding Space}, 27 | publisher = {arXiv}, 28 | year = {2022}, 29 | copyright = {Creative Commons Attribution 4.0 International} 30 | } 31 | ``` -------------------------------------------------------------------------------- /exploration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "11d5afc8-8486-45f2-b2b1-5b1d7aa7135e", 6 | "metadata": {}, 7 | "source": [ 8 | "## Init" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "f5a18fab-8c9c-47b5-815e-5aea6f853959", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "from torch.nn import functional as F\n", 20 | "from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel\n", 21 | "from tabulate import tabulate\n", 22 | "from tqdm import tqdm, trange\n", 23 | "from copy import deepcopy\n", 24 | "import numpy as np\n", 25 | "from collections import Counter" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "id": "24ddef26-399f-4798-a1e1-a10b59896077", 31 | "metadata": {}, 32 | "source": [ 33 | "## Helper Functions" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "0486a4d3-6094-4699-a2df-13c841910e6d", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "ALNUM_CHARSET = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789')\n", 44 | "\n", 45 | "def convert_to_tokens(indices, tokenizer, extended=False, extra_values_pos=None, strip=True):\n", 46 | " if extended:\n", 47 | " res = [tokenizer.convert_ids_to_tokens([idx])[0] if idx < len(tokenizer) else \n", 48 | " (f\"[pos{idx-len(tokenizer)}]\" if idx < extra_values_pos else f\"[val{idx-extra_values_pos}]\") \n", 49 | " for idx in indices]\n", 50 | " else:\n", 51 | " res = tokenizer.convert_ids_to_tokens(indices)\n", 52 | " if strip:\n", 53 | " res = list(map(lambda x: x[1:] if x[0] == 'Ġ' else \"#\" + x, res))\n", 54 | " return res\n", 55 | "\n", 56 | "\n", 57 | "def top_tokens(v, k=100, tokenizer=None, only_alnum=False, only_ascii=True, with_values=False, \n", 58 | " exclude_brackets=False, extended=True, extra_values=None, only_from_list=None):\n", 59 | " if tokenizer is None:\n", 60 | " tokenizer = my_tokenizer\n", 61 | " v = deepcopy(v)\n", 62 | " ignored_indices = []\n", 63 | " if only_ascii:\n", 64 | " ignored_indices.extend([key for val, key in tokenizer.vocab.items() if not val.strip('Ġ▁').isascii()])\n", 65 | " if only_alnum: \n", 66 | " ignored_indices.extend([key for val, key in tokenizer.vocab.items() if not (set(val.strip('Ġ▁[] ')) <= ALNUM_CHARSET)])\n", 67 | " if only_from_list:\n", 68 | " ignored_indices.extend([key for val, key in tokenizer.vocab.items() if val.strip('Ġ▁ ').lower() not in only_from_list])\n", 69 | " if exclude_brackets:\n", 70 | " ignored_indices = set(ignored_indices).intersection(\n", 71 | " {key for val, key in tokenizer.vocab.items() if not (val.isascii() and val.isalnum())})\n", 72 | " 
ignored_indices = list(ignored_indices)\n", 73 | " \n", 74 | " ignored_indices = list(set(ignored_indices))\n", 75 | " v[ignored_indices] = -np.inf\n", 76 | " extra_values_pos = len(v)\n", 77 | " if extra_values is not None:\n", 78 | " v = torch.cat([v, extra_values])\n", 79 | " values, indices = torch.topk(v, k=k)\n", 80 | " res = convert_to_tokens(indices, tokenizer, extended=extended, extra_values_pos=extra_values_pos)\n", 81 | " if with_values:\n", 82 | " res = list(zip(res, values.cpu().numpy()))\n", 83 | " return res" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "id": "5e03a094-4602-47bf-89ea-0656fb4c1d5d", 89 | "metadata": {}, 90 | "source": [ 91 | "## Extract Weights" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 3, 97 | "id": "a6e82d2c-92a5-4892-afc3-b40439d6b971", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "model = AutoModelForCausalLM.from_pretrained(\"gpt2-medium\")\n", 102 | "tokenizer = my_tokenizer = AutoTokenizer.from_pretrained(\"gpt2-medium\")\n", 103 | "emb = model.get_output_embeddings().weight.data.T.detach()\n", 104 | "\n", 105 | "num_layers = model.config.n_layer\n", 106 | "num_heads = model.config.n_head\n", 107 | "hidden_dim = model.config.n_embd\n", 108 | "head_size = hidden_dim // num_heads\n", 109 | "\n", 110 | "K = torch.cat([model.get_parameter(f\"transformer.h.{j}.mlp.c_fc.weight\").T\n", 111 | " for j in range(num_layers)]).detach()\n", 112 | "V = torch.cat([model.get_parameter(f\"transformer.h.{j}.mlp.c_proj.weight\")\n", 113 | " for j in range(num_layers)]).detach()\n", 114 | "\n", 115 | "W_Q, W_K, W_V = torch.cat([model.get_parameter(f\"transformer.h.{j}.attn.c_attn.weight\") \n", 116 | " for j in range(num_layers)]).detach().chunk(3, dim=-1)\n", 117 | "W_O = torch.cat([model.get_parameter(f\"transformer.h.{j}.attn.c_proj.weight\") \n", 118 | " for j in range(num_layers)]).detach()\n" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 4, 124 | "id": "104fd920-f247-4827-be5a-c958dfadbe17", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "K_heads = K.reshape(num_layers, -1, hidden_dim)\n", 129 | "V_heads = V.reshape(num_layers, -1, hidden_dim)\n", 130 | "d_int = K_heads.shape[1]\n", 131 | "\n", 132 | "W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)\n", 133 | "W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)\n", 134 | "W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)\n", 135 | "W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 5, 141 | "id": "cec5b99d-9e6c-413f-a43d-02fd39485640", 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "emb_inv = emb.T" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "id": "e58b9e38-8ff4-41f6-8c47-8ae23f4293e6", 151 | "metadata": {}, 152 | "source": [ 153 | "## Interpretation" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "id": "e81b3452-6449-4397-bda7-3209c573c53a", 159 | "metadata": {}, 160 | "source": [ 161 | "#### Alternative I: No Token List" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "id": "6fbe8ce6-abf5-427b-b374-a43f88b06097", 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "tokens_list = set()" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "id": 
"ada3bfc9-3d15-43ae-a21c-da0747b16ba4", 177 | "metadata": {}, 178 | "source": [ 179 | "#### Alternative II: Can Load Token List from IMDB" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 14, 185 | "id": "2bb14a77-5ac2-449b-b771-97be619cb79a", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "from datasets import load_dataset" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "id": "6fe48f8f-eb62-4bee-a45c-27ff12220825", 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "imdb = load_dataset('imdb')['train']['text']" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 17, 205 | "id": "ae8eb81a-8525-40ef-8376-a902980c18d0", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "max_tokens_num = None" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 18, 215 | "id": "ea6fd97b-1dbf-4e61-bcdb-f4f2d2c23ef9", 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "name": "stderr", 220 | "output_type": "stream", 221 | "text": [ 222 | " 0%| | 0/25000 [00:00 1024). Running this sequence through the model will result in indexing errors\n", 223 | "100%|██████████| 25000/25000 [00:53<00:00, 467.46it/s]\n" 224 | ] 225 | } 226 | ], 227 | "source": [ 228 | "if max_tokens_num is None:\n", 229 | " tokens_list = set()\n", 230 | " for txt in tqdm(imdb):\n", 231 | " tokens_list = tokens_list.union(set(tokenizer.tokenize(txt)))\n", 232 | "else:\n", 233 | " tokens_list = Counter()\n", 234 | " for txt in tqdm(imdb):\n", 235 | " tokens_list.update(set(tokenizer.tokenize(txt)))\n", 236 | " tokens_list = map(lambda x: x[0], tokens_list.most_common(max_tokens_num))\n", 237 | " " 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 19, 243 | "id": "f2944cdd-13d9-41ea-a2d8-a643ab49b583", 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "tokens_list = set([*map(lambda x: x.strip('Ġ▁').lower(), tokens_list)])" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "id": "cd71b325-da96-47b2-99a8-6b909668b642", 253 | "metadata": {}, 254 | "source": [ 255 | "### FF Keys & Values" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 7, 261 | "id": "b891ff3d-18ac-4daf-bb1b-272cc9bb00a5", 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "name": "stdout", 266 | "output_type": "stream", 267 | "text": [ 268 | "23 907\n", 269 | "K V\n", 270 | "---------- ----------\n", 271 | "hands hand\n", 272 | "hand #Hand\n", 273 | "#hands Hand\n", 274 | "#hand #hand\n", 275 | "fingers hands\n", 276 | "#feet Hands\n", 277 | "fingertips fist\n", 278 | "claws #hands\n", 279 | "paw finger\n", 280 | "paws handed\n", 281 | "metab thumb\n", 282 | "palms fingers\n", 283 | "fingert foot\n", 284 | "#Hand #handed\n", 285 | "fists paw\n", 286 | "wrists handing\n", 287 | "levers #finger\n", 288 | "thumbs #hander\n", 289 | "tentacles fingertips\n", 290 | "feet claw\n", 291 | "limb fingert\n", 292 | "slider #Foot\n", 293 | "#handed Stick\n", 294 | "#dimension arm\n", 295 | "jaws #Accessory\n", 296 | "skelet #fing\n", 297 | "lapt Foot\n", 298 | "ankles index\n", 299 | "weap toe\n", 300 | "foot #auntlet\n" 301 | ] 302 | } 303 | ], 304 | "source": [ 305 | "i1, i2 = 23, 907\n", 306 | "# i1, i2 = np.random.randint(num_layers), np.random.randint(d_int)\n", 307 | "\n", 308 | "print(i1, i2)\n", 309 | "print(tabulate([*zip(\n", 310 | " top_tokens((K_heads[i1, i2]) @ emb, k=30, only_from_list=tokens_list, 
only_alnum=False),\n", 311 | " top_tokens((V_heads[i1, i2]) @ emb, k=30, only_from_list=tokens_list, only_alnum=False),\n", 312 | " # top_tokens((-K_heads[i1, i2]) @ emb, k=200, only_from_list=tokens_list),\n", 313 | " # top_tokens((-V_heads[i1, i2]) @ emb, k=200, only_from_list=tokens_list),\n", 314 | ")], headers=['K', 'V', '-K', '-V']))" 315 | ] 316 | }, 317 | { 318 | "cell_type": "markdown", 319 | "id": "5ec96387-7cd5-4348-9444-6e1a10987da3", 320 | "metadata": {}, 321 | "source": [ 322 | "### Attention Weights Interpretation" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 227, 328 | "id": "ca7323f0-7cf9-4020-bbbd-191de7fe25b0", 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "def approx_topk(mat, min_k=500, max_k=250_000, th0=10, max_iters=10, verbose=False):\n", 333 | " _get_actual_k = lambda th, th_max: torch.nonzero((mat > th) & (mat < th_max)).shape[0]\n", 334 | " th_max = np.inf\n", 335 | " left, right = 0, th0 \n", 336 | " while True:\n", 337 | " actual_k = _get_actual_k(right, th_max)\n", 338 | " if verbose:\n", 339 | " print(f\"one more iteration. {actual_k}\")\n", 340 | " if actual_k <= max_k:\n", 341 | " break\n", 342 | " left, right = right, right * 2\n", 343 | " if min_k <= actual_k <= max_k:\n", 344 | " th = right\n", 345 | " else:\n", 346 | " for _ in range(max_iters):\n", 347 | " mid = (left + right) / 2\n", 348 | " actual_k = _get_actual_k(mid, th_max)\n", 349 | " if verbose:\n", 350 | " print(f\"one more iteration. {actual_k}\")\n", 351 | " if min_k <= actual_k <= max_k:\n", 352 | " break\n", 353 | " if actual_k > max_k:\n", 354 | " left = mid\n", 355 | " else:\n", 356 | " right = mid\n", 357 | " th = mid\n", 358 | " return torch.nonzero((mat > th) & (mat < th_max)).tolist()\n", 359 | "\n", 360 | "def get_top_entries(tmp, all_high_pos, only_ascii=False, only_alnum=False, exclude_same=False, exclude_fuzzy=False, tokens_list=None):\n", 361 | " remaining_pos = all_high_pos\n", 362 | " if only_ascii:\n", 363 | " remaining_pos = [*filter(\n", 364 | " lambda x: (tokenizer.decode(x[0]).strip('Ġ▁').isascii() and tokenizer.decode(x[1]).strip('Ġ▁').isascii()), \n", 365 | " remaining_pos)]\n", 366 | " if only_alnum:\n", 367 | " remaining_pos = [*filter(\n", 368 | " lambda x: (tokenizer.decode(x[0]).strip('Ġ▁ ').isalnum() and tokenizer.decode(x[1]).strip('Ġ▁ ').isalnum()), \n", 369 | " remaining_pos)]\n", 370 | " if exclude_same:\n", 371 | " remaining_pos = [*filter(\n", 372 | " lambda x: tokenizer.decode(x[0]).lower().strip() != tokenizer.decode(x[1]).lower().strip(), \n", 373 | " remaining_pos)]\n", 374 | " if exclude_fuzzy:\n", 375 | " remaining_pos = [*filter(\n", 376 | " lambda x: not _fuzzy_eq(tokenizer.decode(x[0]).lower().strip(), tokenizer.decode(x[1]).lower().strip()), \n", 377 | " remaining_pos)]\n", 378 | " if tokens_list:\n", 379 | " remaining_pos = [*filter(\n", 380 | " lambda x: ((tokenizer.decode(x[0]).strip('Ġ▁').lower().strip() in tokens_list) and \n", 381 | " (tokenizer.decode(x[1]).strip('Ġ▁').lower().strip() in tokens_list)), \n", 382 | " remaining_pos)]\n", 383 | "\n", 384 | " pos_val = tmp[[*zip(*remaining_pos)]]\n", 385 | " good_cells = [*map(lambda x: (tokenizer.decode(x[0]), tokenizer.decode(x[1])), remaining_pos)]\n", 386 | " good_tokens = list(map(lambda x: Counter(x).most_common(), zip(*good_cells)))\n", 387 | " remaining_pos_best = np.array(remaining_pos)[torch.argsort(pos_val if reverse_list else -pos_val)[:50]]\n", 388 | " good_cells_best = [*map(lambda x: (tokenizer.decode(x[0]), 
tokenizer.decode(x[1])), remaining_pos_best)]\n", 389 | " # good_cells[:100]\n", 390 | " # list(zip(good_tokens[0], good_tokens[1]))\n", 391 | " return good_cells_best" 392 | ] 393 | }, 394 | { 395 | "cell_type": "markdown", 396 | "id": "2bbd4e4a-a69b-4d58-a371-426b9073ff81", 397 | "metadata": { 398 | "tags": [] 399 | }, 400 | "source": [ 401 | "#### $W_{VO}$ Interpretation" 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "id": "79156807-54af-4fe8-b170-29dc0856d686", 407 | "metadata": {}, 408 | "source": [ 409 | "Choose **layer** and **head** here:" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 253, 415 | "id": "f0d00edd-fdca-40d0-807f-e4bf3d7611e9", 416 | "metadata": {}, 417 | "outputs": [ 418 | { 419 | "data": { 420 | "text/plain": [ 421 | "(24, 9)" 422 | ] 423 | }, 424 | "execution_count": 253, 425 | "metadata": {}, 426 | "output_type": "execute_result" 427 | } 428 | ], 429 | "source": [ 430 | "i1, i2 = np.random.randint(num_layers), np.random.randint(num_heads)\n", 431 | "i1, i2 = 24, 9\n", 432 | "i1, i2" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": 254, 438 | "id": "2743e277-74f7-41a6-a479-5bd65840d2d8", 439 | "metadata": {}, 440 | "outputs": [], 441 | "source": [ 442 | "W_V_tmp, W_O_tmp = W_V_heads[i1, i2, :], W_O_heads[i1, i2]\n", 443 | "tmp = (emb_inv @ (W_V_tmp @ W_O_tmp) @ emb)" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 255, 449 | "id": "0ede26ee-3b82-4446-8b6d-2300e7bdcc16", 450 | "metadata": {}, 451 | "outputs": [ 452 | { 453 | "name": "stdout", 454 | "output_type": "stream", 455 | "text": [ 456 | "one more iteration. 11496\n" 457 | ] 458 | } 459 | ], 460 | "source": [ 461 | "all_high_pos = approx_topk(tmp, th0=1, verbose=True) # torch.nonzero((tmp > th) & (tmp < th_max)).tolist()" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 256, 467 | "id": "3207aed9-f99e-4109-9160-2d62f31e8b76", 468 | "metadata": {}, 469 | "outputs": [], 470 | "source": [ 471 | "exclude_same = False\n", 472 | "reverse_list = False\n", 473 | "only_ascii = True\n", 474 | "only_alnum = False" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": 257, 480 | "id": "0fab412a-3f2e-4e0c-b11e-c462b17b6191", 481 | "metadata": {}, 482 | "outputs": [ 483 | { 484 | "data": { 485 | "text/plain": [ 486 | "[(' interviewer', ' interviewer'),\n", 487 | " (' lectures', ' lectures'),\n", 488 | " (' lecture', ' lecture'),\n", 489 | " ('Interview', ' interview'),\n", 490 | " (' interview', ' interview'),\n", 491 | " (' interviewer', ' interview'),\n", 492 | " (' interviewing', ' interviewing'),\n", 493 | " (' magazine', ' magazine'),\n", 494 | " (' Reviews', ' Reviews'),\n", 495 | " (' reviewer', ' reviewer'),\n", 496 | " (' reviewers', ' reviewers'),\n", 497 | " (' lecture', ' lectures'),\n", 498 | " (' testers', ' testers'),\n", 499 | " (' editors', ' editors'),\n", 500 | " (' interview', ' interviewer'),\n", 501 | " ('Interview', ' Interview'),\n", 502 | " ('Interview', ' interviewer'),\n", 503 | " ('Interview', 'Interview'),\n", 504 | " (' lectures', ' lecture'),\n", 505 | " (' interviewer', ' interviewing'),\n", 506 | " (' journal', ' journal'),\n", 507 | " (' interviewing', ' interviewer'),\n", 508 | " (' blogs', ' blogs'),\n", 509 | " (' editorial', ' editorial'),\n", 510 | " (' tests', ' tests'),\n", 511 | " (' presentations', ' presentations'),\n", 512 | " (' Editorial', ' Editorial'),\n", 513 | " (' Interview', ' interview'),\n", 514 | " (' reviewers', ' 
reviewer'),\n", 515 | " ('Interview', ' interviews'),\n", 516 | " (' interviewing', ' interview'),\n", 517 | " (' Interview', ' interviewer'),\n", 518 | " (' interview', ' interviews'),\n", 519 | " (' Interview', ' Interview'),\n", 520 | " ('Interview', ' interviewing'),\n", 521 | " (' interviewer', 'Interview'),\n", 522 | " (' testifying', ' testifying'),\n", 523 | " (' reviewer', ' reviewers'),\n", 524 | " (' blogging', ' blogging'),\n", 525 | " (' broadcast', ' broadcast'),\n", 526 | " (' interviewer', ' Interview'),\n", 527 | " (' magazines', ' magazine'),\n", 528 | " (' Editorial', ' editorial'),\n", 529 | " (' interviews', ' interview'),\n", 530 | " (' interview', ' interviewing'),\n", 531 | " (' interview', ' Interview'),\n", 532 | " (' interviews', ' interviews'),\n", 533 | " ('tests', ' tests'),\n", 534 | " (' interviewing', ' interviews'),\n", 535 | " (' interview', 'Interview')]" 536 | ] 537 | }, 538 | "execution_count": 257, 539 | "metadata": {}, 540 | "output_type": "execute_result" 541 | } 542 | ], 543 | "source": [ 544 | "get_top_entries(tmp, all_high_pos, only_ascii=only_ascii, only_alnum=only_alnum, \n", 545 | " exclude_same=exclude_same, tokens_list=None)" 546 | ] 547 | }, 548 | { 549 | "cell_type": "markdown", 550 | "id": "49638ac2-455c-4441-aa67-4fde61c9ea83", 551 | "metadata": { 552 | "tags": [] 553 | }, 554 | "source": [ 555 | "#### $W_{QK}$ Interpretation" 556 | ] 557 | }, 558 | { 559 | "cell_type": "markdown", 560 | "id": "8fbfc882-3bf7-46e5-8e6a-f6a245ceb7bf", 561 | "metadata": {}, 562 | "source": [ 563 | "Choose **layer** and **head** here:" 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": 330, 569 | "id": "c43e9c99-f9bf-4eae-b215-289a7f630ccd", 570 | "metadata": {}, 571 | "outputs": [ 572 | { 573 | "data": { 574 | "text/plain": [ 575 | "(20, 13)" 576 | ] 577 | }, 578 | "execution_count": 330, 579 | "metadata": {}, 580 | "output_type": "execute_result" 581 | } 582 | ], 583 | "source": [ 584 | "# i1, i2 = np.random.randint(num_layers), np.random.randint(num_heads)\n", 585 | "i1, i2 = 20, 13\n", 586 | "i1, i2" 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "execution_count": 331, 592 | "id": "e74cb0c2-c39f-42ce-a87c-91125207d537", 593 | "metadata": {}, 594 | "outputs": [], 595 | "source": [ 596 | "W_Q_tmp, W_K_tmp = W_Q_heads[i1, i2, :], W_K_heads[i1, i2, :]\n", 597 | "tmp2 = (emb_inv @ (W_Q_tmp @ W_K_tmp.T) @ emb_inv.T)" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": 332, 603 | "id": "26a461fb-bcbc-477a-be76-20f5b52f329d", 604 | "metadata": {}, 605 | "outputs": [ 606 | { 607 | "name": "stdout", 608 | "output_type": "stream", 609 | "text": [ 610 | "one more iteration. 265\n", 611 | "one more iteration. 
103159\n" 612 | ] 613 | } 614 | ], 615 | "source": [ 616 | "all_high_pos = approx_topk(tmp2, th0=1, verbose=True) # torch.nonzero((tmp2 > th2) & (tmp2 < th_max2)).tolist()" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": 335, 622 | "id": "16706487-23da-4e39-9799-641528d2e6ed", 623 | "metadata": {}, 624 | "outputs": [], 625 | "source": [ 626 | "exclude_same = False\n", 627 | "reverse_list = False\n", 628 | "only_ascii = True\n", 629 | "only_alnum = True" 630 | ] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "execution_count": 336, 635 | "id": "4e402a1a-8481-4481-9bad-a10f70542b63", 636 | "metadata": {}, 637 | "outputs": [ 638 | { 639 | "data": { 640 | "text/plain": [ 641 | "[(' outdoors', ' outdoors'),\n", 642 | " (' outdoor', ' outdoors'),\n", 643 | " (' Gre', 'burg'),\n", 644 | " (' healing', ' healing'),\n", 645 | " (' indoor', ' outdoors'),\n", 646 | " (' Hemp', 'burg'),\n", 647 | " (' Ticket', ' Ticket'),\n", 648 | " (' accommodations', ' accommodations'),\n", 649 | " ('eco', 'aco'),\n", 650 | " ('prem', 'otti'),\n", 651 | " (' Candy', 'cott'),\n", 652 | " (' decorative', ' ornament'),\n", 653 | " ('yan', 'ava'),\n", 654 | " (' deadlines', ' schedule'),\n", 655 | " (' Lor', 'ian'),\n", 656 | " (' architectural', ' ornament'),\n", 657 | " (' Ratings', ' Ratings'),\n", 658 | " (' Bod', 'za'),\n", 659 | " (' exotic', ' exotic'),\n", 660 | " (' food', ' baths'),\n", 661 | " (' Marketplace', ' Marketplace'),\n", 662 | " (' heal', ' healing'),\n", 663 | " (' Ex', 'ilus'),\n", 664 | " (' indoors', ' outdoors'),\n", 665 | " (' therm', ' therm'),\n", 666 | " (' bleach', ' coated'),\n", 667 | " (' Sod', 'opol'),\n", 668 | " (' District', ' Metropolitan'),\n", 669 | " (' Anonymous', ' Rebell'),\n", 670 | " (' Corn', 'burg'),\n", 671 | " (' indoor', ' indoors'),\n", 672 | " (' R', 'vale'),\n", 673 | " ('rom', 'otti'),\n", 674 | " (' ratings', ' Ratings'),\n", 675 | " (' attendance', ' attendance'),\n", 676 | " (' destinations', ' destinations'),\n", 677 | " (' VIDEOS', ' VIDEOS'),\n", 678 | " ('yan', 'opol'),\n", 679 | " (' Suffolk', 'ville'),\n", 680 | " (' retali', ' against'),\n", 681 | " ('mos', 'oli'),\n", 682 | " (' pacing', ' pacing'),\n", 683 | " (' Spectrum', ' QC'),\n", 684 | " (' Il', 'ian'),\n", 685 | " (' archived', ' archived'),\n", 686 | " (' Pledge', ' Pledge'),\n", 687 | " ('alg', 'otti'),\n", 688 | " (' Freedom', 'USA'),\n", 689 | " ('anto', 'ero'),\n", 690 | " (' decorative', ' decoration')]" 691 | ] 692 | }, 693 | "execution_count": 336, 694 | "metadata": {}, 695 | "output_type": "execute_result" 696 | } 697 | ], 698 | "source": [ 699 | "get_top_entries(tmp2, all_high_pos, only_ascii=only_ascii, only_alnum=only_alnum, exclude_same=exclude_same, \n", 700 | " tokens_list=tokens_list)" 701 | ] 702 | }, 703 | { 704 | "cell_type": "markdown", 705 | "id": "d5e6b973-a8c8-4775-aa69-fa1bac08c450", 706 | "metadata": {}, 707 | "source": [ 708 | "## Plots" 709 | ] 710 | }, 711 | { 712 | "cell_type": "markdown", 713 | "id": "bc9d2084-4e47-40b9-b864-62bf0b1d6e9b", 714 | "metadata": {}, 715 | "source": [ 716 | "*We thank Ohad Rubin for the idea of providing plots for better visualizations!*" 717 | ] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "execution_count": 20, 722 | "id": "c30e8596-b8a1-4b53-935c-24318aca7fdf", 723 | "metadata": {}, 724 | "outputs": [], 725 | "source": [ 726 | "i1, i2 = 6, 2152" 727 | ] 728 | }, 729 | { 730 | "cell_type": "code", 731 | "execution_count": 21, 732 | "id": "99271972-75c1-4cc2-b0be-c5f0895558fb", 733 | "metadata": {}, 
734 | "outputs": [], 735 | "source": [ 736 | "from sklearn.manifold import TSNE\n", 737 | "import pandas as pd\n", 738 | "import plotly.express as px" 739 | ] 740 | }, 741 | { 742 | "cell_type": "code", 743 | "execution_count": 22, 744 | "id": "3fb14d53-dc5d-47aa-8609-3633843e1d53", 745 | "metadata": {}, 746 | "outputs": [], 747 | "source": [ 748 | "def _calc_df(vector, k, coef, normalized, tokenizer):\n", 749 | " mat = emb\n", 750 | " if normalized:\n", 751 | " mat = F.normalize(mat, dim=-1)\n", 752 | " dot = vector @ mat\n", 753 | " sol = torch.topk(dot * coef, k=k).indices # np.argsort(dot * coef)[-k:]\n", 754 | " pattern = mat[:, sol].T\n", 755 | " scores = coef * dot[sol]\n", 756 | " # labels = tokenizer.batch_decode(sol)\n", 757 | " labels = convert_to_tokens(sol, tokenizer=tokenizer)\n", 758 | " X_embedded = TSNE(n_components=3,\n", 759 | " learning_rate=10,\n", 760 | " init='pca',\n", 761 | " perplexity=3).fit_transform(pattern)\n", 762 | "\n", 763 | " df = pd.DataFrame(dict(x=X_embedded.T[0], y=X_embedded.T[1], z=X_embedded.T[2], label=labels, score=scores))\n", 764 | " return df\n", 765 | "\n", 766 | "\n", 767 | "def plot_embedding_space(vector, is_3d=False, add_text=False, k=100, coef=1, normalized=False, tokenizer=None):\n", 768 | " df = _calc_df(vector, k=k, coef=coef, normalized=normalized, tokenizer=tokenizer)\n", 769 | " kwargs = {}\n", 770 | " scatter_fn = px.scatter\n", 771 | " if add_text:\n", 772 | " kwargs.update({'text': 'label'})\n", 773 | " if is_3d:\n", 774 | " scatter_fn = px.scatter_3d\n", 775 | " kwargs.update({'z': 'z'})\n", 776 | " fig = scatter_fn(\n", 777 | " data_frame=df, \n", 778 | " x='x', \n", 779 | " y='y',\n", 780 | " custom_data=[\"label\", \"score\"],\n", 781 | " color=\"score\", size_max=1, **kwargs)\n", 782 | "\n", 783 | " fig.update_traces(\n", 784 | " hovertemplate=\"
\".join([\n", 785 | " \"ColX: %{x}\",\n", 786 | " \"ColY: %{y}\",\n", 787 | " \"label: %{customdata[0]}\",\n", 788 | " \"score: %{customdata[1]}\"\n", 789 | " ])\n", 790 | " )\n", 791 | " \n", 792 | " if add_text:\n", 793 | " fig.update_traces(textposition='middle right')\n", 794 | " fig.show()" 795 | ] 796 | }, 797 | { 798 | "cell_type": "code", 799 | "execution_count": null, 800 | "id": "80575a88-3633-4a10-9e0e-9c1b41c65c46", 801 | "metadata": {}, 802 | "outputs": [], 803 | "source": [ 804 | "plot_embedding_space(K_heads[i1][i2], tokenizer=tokenizer, normalized=False)" 805 | ] 806 | } 807 | ], 808 | "metadata": { 809 | "kernelspec": { 810 | "display_name": "Python 3", 811 | "language": "python", 812 | "name": "python3" 813 | }, 814 | "language_info": { 815 | "codemirror_mode": { 816 | "name": "ipython", 817 | "version": 3 818 | }, 819 | "file_extension": ".py", 820 | "mimetype": "text/x-python", 821 | "name": "python", 822 | "nbconvert_exporter": "python", 823 | "pygments_lexer": "ipython3", 824 | "version": "3.8.13" 825 | }, 826 | "toc-showcode": true, 827 | "toc-showmarkdowntxt": true 828 | }, 829 | "nbformat": 4, 830 | "nbformat_minor": 5 831 | } 832 | -------------------------------------------------------------------------------- /load_multiberts.sh: -------------------------------------------------------------------------------- 1 | mkdir multiberts 2 | mkdir multiberts/models 3 | 4 | wget -O multiberts/vocab.txt https://storage.googleapis.com/multiberts/public/vocab.txt 5 | wget -O multiberts/bert_config.json https://storage.googleapis.com/multiberts/public/bert_config.json 6 | for ckpt in {0..2} ; do 7 | wget "https://storage.googleapis.com/multiberts/public/models/seed_${ckpt}.zip" -O "multiberts/seed_${ckpt}.zip" 8 | unzip -o "multiberts/seed_${ckpt}.zip" -d "multiberts/models" 9 | rm "multiberts/seed_${ckpt}.zip" 10 | export BERT_BASE_DIR="multiberts/models/seed_${ckpt}" 11 | transformers-cli convert --model_type bert \ 12 | --tf_checkpoint $BERT_BASE_DIR/bert.ckpt \ 13 | --config multiberts/bert_config.json \ 14 | --pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin 15 | 16 | cp multiberts/vocab.txt $BERT_BASE_DIR/vocab.txt 17 | cp multiberts/bert_config.json $BERT_BASE_DIR/config.json 18 | done -------------------------------------------------------------------------------- /parameter-alignment.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7ff967e6-9737-410f-9ade-52af976bdbc9", 6 | "metadata": {}, 7 | "source": [ 8 | "## Init" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "abadd5e8-88d7-4b2c-b7fa-70f0bc074dfe", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from utils import get_multiberts_tokenizer\n", 19 | "import tensorflow as tf\n", 20 | "import tensorflow.keras.layers as tf_layers\n", 21 | "from tqdm.auto import tqdm\n", 22 | "import numpy as np\n", 23 | "import torch\n", 24 | "from torch import nn\n", 25 | "import torch.nn.functional as F\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "import seaborn as sns\n", 28 | "from scipy.optimize import linear_sum_assignment" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "id": "db9f88ea-9d42-4fca-8f27-d8976bb334b5", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "num_layers = 12" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "id": "d1645359-e748-47a7-a741-919e65178fc2", 45 | 
"metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "def corr_coef(X, Y):\n", 49 | " muX, muY = X.mean(-1, keepdims=True), Y.mean(-1, keepdims=True) \n", 50 | " X, Y = map(lambda A, mu: F.normalize(A - mu, dim=-1), [X, Y], [muX, muY])\n", 51 | " return X @ Y.T" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "id": "5a05ee97-e055-44e9-9f0c-f432d3415496", 58 | "metadata": { 59 | "tags": [] 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "def kernelized_corr_coef(x, y, K1, K2, K12, emb_mean1, emb_mean2, n):\n", 64 | " mu1, mu2 = x @ emb_mean1, y @ emb_mean2\n", 65 | " mu12 = torch.outer(mu1, mu2)\n", 66 | " x_norm, y_norm = map(lambda x, K, mu: (1/n * (x @ K) * x).sum(dim=-1) - mu**2, [x, y], [K1, K2], [mu1, mu2])\n", 67 | " return (1 / n * x @ K12 @ y.T - mu12) / torch.sqrt(torch.outer(x_norm, y_norm))" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 5, 73 | "id": "e7c24c6e-a6af-43c5-907b-39536976b67f", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "tokenizer = get_multiberts_tokenizer()" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "id": "bf557219-c61c-43a8-a77c-a259e4090e40", 83 | "metadata": {}, 84 | "source": [ 85 | "## Extract Parameters" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 6, 91 | "id": "d79efb45-b296-4983-8f80-9c49bfbc5aed", 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "['multiberts/models/seed_4/bert.ckpt', 'multiberts/models/seed_5/bert.ckpt']\n" 99 | ] 100 | }, 101 | { 102 | "data": { 103 | "application/vnd.jupyter.widget-view+json": { 104 | "model_id": "3b871ac73af4424f8e2834345bdc8b40", 105 | "version_major": 2, 106 | "version_minor": 0 107 | }, 108 | "text/plain": [ 109 | " 0%| | 0/2 [00:00" 257 | ] 258 | }, 259 | "execution_count": 14, 260 | "metadata": {}, 261 | "output_type": "execute_result" 262 | }, 263 | { 264 | "data": { 265 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAffklEQVR4nO3de5BfZZ3n8fcn3UlIAgQJlwGCJEyCysVhIAbHBcTJwATHMroFZZBRxsqYtRxW3S1L2bKWAdbdNdYMrBasTjQwCIug4KVHA0FA0N3BmKCJEq5NQEm4JzFXQtL9++4f54T50enu8+vu8yTnd/rzok7l9Ll8+0nT/e0nz3nO81VEYGZm1TVmfzfAzMwG50RtZlZxTtRmZhXnRG1mVnFO1GZmFdeZ+hM89673JJtWMnnO4UniNjZsTRJ39/M7ksTt3ZYkLAAvPnVQkrg7d6X51tvRMzZJXIDNiX5cdipNf+m2cWm+jwFe7NmeJO7P19+rkcbY/cralnPO2MOOH/Hn2xeSJ2ozs32q0bu/W1A6J2ozq5do7O8WlM6J2szqpeFEbWZWaVHDHrVnfZhZvfT2tL4VkDRX0uOSuiVd1s/58ZJuy88vlzSt6dzbJT0oaY2k30o6ID9+fx5zVb4dUdQO96jNrF5KepgoqQO4DjgXWAeskNQVEY80XbYA2BQRMyTNBxYBH5LUCdwMfCQiVkuaAuxuuu/iiFjZalvcozazeolG69vgZgPdEbE2InYBtwLz+lwzD7gx378dmCNJwHnAbyJiNUBEbIiIYf8GKexRS3pr3phj8kPrga6IeHS4n9TMLJkhPEyUtBBY2HRocUQszvePAZ5tOrcOOKNPiNeviYgeSZuBKcAJQEhaBhwO3BoRX2667wZJvcAdwBejYBnTQXvUkj5P9ltEwC/zTcC3+xuvabpvoaSVklbe/OJzg30KM7NSRTSGsMXiiJjVtC0u/gwt6QTOBC7O//ygpDn5uYsj4hTgrHz7SCvBBrMAOCkimsdWkHQ1sAb4Un835X/ZxZD2zUQzs72UNz1vPXBs08dT82P9XbMuH5eeDGwg633/LCJeAZC0FDgNuDci1gNExFZJt5ANsXxrsIYUjVE3gKP7OX5Ufs7MrFp6d7e+DW4FMFPSdEnjgPlAV59ruoBL8v0LgPvyYYxlwCmSJuYJ/N3AI5I6JR0GIGks8D7g4aKGFPWoPwPcK+lJ/m2s5s3ADODSouBmZvtcSfOo8zHnS8mSbgdwfUSskXQVsDIiuoAlwE2SuoGNZMmciNiUjzysAAJYGhE/ljQJWJYn6Q7gHuAbRW0ZNFFHxF2STiDrmjc/TFwxkieYZmbJlPhmYkQsBZb2OXZ50/5O4MIB7r2ZbIpe87HtwOlDbUfhrI/IXvP5xVADm5ntFzV8M9EvvJhZvXitDzOzaotG4UPCtuNEbWb14h710I3pTDeNesyfn5cm7u5dSeJ2/r47SdzYnqZyDMDUXz6RJO7mR9OsXrBj67gkcQEOTlSV5sBJryWJ++3t45PEheytt8ryGLWZWcW5wouZWcW5R21mVnEeozYzq7gWCgK0GydqM6sX96jNzKqtjqtbOFGbWb24R21mVnGe9WFmVnE17FEP+/UwSR8b5Nzrpbhuet6luMxsH+rtaX1rEyN5j/fKgU401yH7yFH9FYgxM0ukvCrklTHo0Iek3wx0Cjiy/OaYmY1QDYc+isaojwT+EtjU57iAf03SIjOzkRiFifpHwIERsarvCUn3p2iQmdmItNGQRquKaiYuGOTch8tvjpnZCLXRQ8JWeXqemdXLKBz6MDNrL6Nt6MPMrO24R10tHce9PUlcHXxYkrgx45U0cXdsThIXYNyEpUniHnp4mrJkB6/fmiQuQGNHmrJynUeMTRJ38k/TxAV4dUy6kmcjVsNEnaZwnZnZ/hLR+lZA0lxJj0vqlnRZP+fHS7otP79c0rSmc2+X9KCkNZJ+K+mA/Pjp+cfdkr4qqbAEpRO1mdVLT0/r2yAkdQDXAecDJwIXSTqxz2ULgE0RMQO4BliU39sJ3Ax8IiJOAs4Bduf3fA34ODAz3+YW/ZWcqM2sXsp7hXw20B0RayNiF3ArMK/PNfOAG/P924E5eQ/5POA3EbEaICI2RESvpKOAgyPiFxERwLeADxQ1xInazOql0Wh5a15ALt8WNkU6Bni26eN1+TH6uyYieoDNwBTgBCAkLZP0K0mfa7p+XUHMvbT1w0Qzs720MPb8b5fGYmBxglZ0AmcC7wB2APdKeogskQ+Ze9RmVi9D6FEXWA8c2/Tx1PxYv9fk49KTgQ1kPeWfRcQrEbEDWAqcll8/tSDmXpyozaxeykvUK4CZkqZLGgfMB7r6XNMFXJLvXwDcl489LwNOkTQxT+DvBh6JiOeBLZLemY9lfxT4YVFDCoc+JL2VbAxleURsazo+NyLuKrrfzGxfit5yittGRI+kS8mSbgdwfUSskXQVsDIiuoAlwE2SuoGNZMmciNgk6WqyZB/A0oj4cR76k8A/AxOAO/NtUEXrUX8K+DvgUWCJpE9HxJ7s/z+AfhN1PiC/EODLM2bi4gFmts+U+MJLRCwlG7ZoPnZ50/5O4MIB7r2ZbIpe3+MrgZOH0o6iHvXHgdMjYls+kft2SdMi4itka1L3q3mA/oWzz0nzOpeZWX9G4VofY/YMd0TEM5LOIUvWxzFIojYz228a9esbFj1MfFHSqXs+yJP2+4DDgFMStsvMbHjKe5hYGUU96o8Cb3jPMp/U/VFJ/5SsVWZmw1XSw8QqKarwsm6Qc/+v/OaYmY1QG/WUW+U3E82sXmo4Ru1EbWb1MgpnfZiZtRf3qIdu55Z0VSYaLz2dJO6Yjvb6/RW9u4svGiaNSTMLc8xBE5LE7Th4Z5K4mTRfZ43rSBK3M+EKER2q7uoT4TFqM7OKG22zPszM2o6HPszMKs5DH2ZmFecetZlZxXl6nplZxblHbWZWbdEzCmd9SJoNRESskHQiMBd4LF9Q28ysWmrYox501rqkvwe+CnxN0v8ErgUmAZdJ+sIg971egv2WDQOu62RmVr5otL61iaIe9QXAqcB44AVgakRskfQPwHLgv/d3U3OFl2dOPbd+v97MrLpq2KMuStQ9EdEL7JD0VERsAYiIVyW1z68jMxs1YhQm6l2SJkbEDuD0PQclTQacqM2sekbhw8SzI+I1gIg3DOiMBS5J1iozs+EabT3qPUm6n+OvAK8kaZGZ2UiMtkRtZtZuIpyozcyqrYY96uqu/m1mNhyNaH0rIGmupMcldUu6rJ/z4yXdlp9fLmlafnyapFclrcq3rzfdc38ec8+5I4rakbxH3ehNUyEEIDa/nCbu4ccliZts+cWUE/fHjksSVn9U+L05LGlqpWQ0fmvC6OU7NOGP9zalq9w0UtFTzs+DpA7gOuBcYB2wQlJXRDzSdNkCYFNEzJA0H1gEfCg/91REnDpA+IsjYmWrbXGP2szqpTGEbXCzge6IWBsRu4BbgXl9rpkH3Jjv3w7MkVR679SJ2sxqJRrR8ta83EW+LWwKdQzwbNPH6/Jj9HdNRPQAm4
Ep+bnpkn4t6QFJZ/W574Z82OO/tpLY/TDRzOplCA8Tm5e7KNnzwJsjYoOk04EfSDopf7v74ohYL+kg4A7gI8C3BgvmHrWZ1Ut5Qx/rgWObPp6aH+v3GkmdwGRgQ0S8FhEbACLiIeAp4IT84/X5n1uBW8iGWAblRG1mtTKUoY8CK4CZkqZLGgfMB7r6XNPFv72lfQFwX0SEpMPzh5FIOh6YCayV1CnpsPz4WOB9wMNFDfHQh5nVSvSUM486InokXQosI5tQdH1ErJF0FbAyIrqAJcBNkrqBjWTJHOBs4CpJu8n67p+IiI2SJgHL8iTdAdwDfKOoLU7UZlYvJc5WzQukLO1z7PKm/Z3Ahf3cdwfZ+HPf49tpWuCuVUMe+pA06KC3mdn+VMO6AYP3qCX1HY8R8B5JhwBExPsHuG8hsBDgi0e/jYsOnTrylpqZtaKNEnCrioY+pgKPAN8EgixRzwL+cbCbmqe8rD3lvPq9eG9mldVOPeVWFQ19zAIeAr4AbI6I+4FXI+KBiHggdePMzIYqelrf2kXRetQN4BpJ383/fLHoHjOz/amOPeqWkm5ErAMulPRXwJa0TTIzG75Rm6j3iIgfAz9O1BYzs5GLdCt27i8exjCzWhn1PWozs6qLhnvUZmaVlrJYyf6SPFE/8dKhyWIfcc+9SeLGE48UXzQcYxNVxdiarvLIrlXPpAlc0noMfb32Yrpp+9s3pKl209OTZm20Z+LVJHEB1vVUd06Bhz7MzCrOQx9mZhUXNXwX2onazGrFPWozs4rzw0Qzs4pzj9rMrOLCbyaamVXbqJ+eJ+lMsoq5D0fE3WmaZGY2fI0a9qgHnWkv6ZdN+x8HrgUOAv5e0mWD3LdQ0kpJK5e++lRpjTUzKxKhlrd2UfRKVPOrdAuBcyPiSuA84OKBboqIxRExKyJmvXfCH5fQTDOz1jR61fLWLoqGPsZIehNZQldEvAxZJV1JbVQfwcxGi9E462MyWSkuASHpqIh4XtKB+TEzs0qp4xh1USmuaQOcagAfLL01ZmYj1E5jz60a1rJdEbEjIp4uuzFmZiMV0fpWRNJcSY9L6u5vAoWk8ZJuy88vlzQtPz5N0quSVuXb15vuOV3Sb/N7viqp8DdLmvUVzcz2k0ao5W0wkjqA64DzgROBiySd2OeyBcCmiJgBXAMsajr3VEScmm+faDr+NeDjwMx8m1v0d3KiNrNaaTTU8lZgNtAdEWsjYhdwKzCvzzXzgBvz/duBOYP1kCUdBRwcEb+IiAC+BXygqCFO1GZWK0PpUTe/85FvC5tCHQM82/TxuvwY/V0TET3AZmBKfm66pF9LekDSWU3XryuIuZfkr5BvHdORLHZj47YkcXVImriMSfOQI7Ykai/Q+4feJHFTvea7a3u6b+ldr6WJvXt3mp+RXZHm/10Wu7qzc4fyMDEiFgOLEzTjeeDNEbFB0unADySdNNxgXuvDzGqlxOl564Fjmz6emh/r75p1kjrJpjRvyIc1XgOIiIckPQWckF8/tSDmXjz0YWa1EkPYCqwAZkqaLmkcMB/o6nNNF3BJvn8BcF9EhKTD84eRSDqe7KHh2oh4Htgi6Z35WPZHgR8WNcQ9ajOrld5GOf3PiOiRdCmwDOgAro+INZKuAlZGRBewBLhJUjewkSyZA5wNXCVpN9l7J5+IiI35uU8C/wxMAO7Mt0E5UZtZrZT5+CMilgJL+xy7vGl/J3BhP/fdAdwxQMyVwMlDaYcTtZnVStRwdQsnajOrlYarkJuZVVujhj3qosIBZ0g6ON+fIOlKSf8iaZGkyfumiWZmrQvU8tYuih6PXg/syPe/QjZHcFF+7IaBbmp+2+eeHd2lNNTMrBW9qOWtXRQWDshfiwSYFRGn5fv/V9KqgW5qftvnu0ddXMMRIzOrqhrWti3sUT8s6WP5/mpJswAknQDsTtoyM7NhaAxhaxdFifpvgXfnrz+eCDwoaS3wjfycmVml1HGMuqjCy2bgb/IHitPz69dFxIv7onFmZkNVw5KJrU3Pi4gtwOrEbTEzG7E6Ts/zPGozq5V0i7vuP07UZlYrjeIShG3HidrMaqWO84GTJ+rtiaqaADS2pakyMWbb9iRxNX5ckrjx6q4kcQF6dhRfMxydByaKOz7dpCuNSZMCXktU4WVcuuJKdCph8BFqp2l3rXKP2sxqZdTO+jAzaxft9Gp4q5yozaxW3KM2M6s4j1GbmVWcZ32YmVWchz7MzCqujkMfRRVePiXp2H3VGDOzkepV61u7KFrm9L8ByyX9XNInJR3eStDmCi/3b39y5K00M2vRaFyPei0wlSxhnw48IukuSZdIOmigmyJicUTMiohZ50yaWWJzzcwGNxoTdUREIyLujogFwNHA/wbmkiVxM7NKiSFs7aIoUb9hFCcidkdEV0RcBByXrllmZsPTUOtbEUlzJT0uqVvSZf2cHy/ptvz8cknT+px/s6Rtkj7bdOwZSb+VtErSylb+TkWzPj400ImISLRcj5nZ8JU1pCGpA7gOOBdYB6yQ1BURjzRdtgDYFBEzJM0HFvHGvHk1cGc/4d8TEa+02pZBe9QR8USrgczMqqB3CFuB2UB3RKyNiF3ArcC8PtfMA27M928H5kjZgtiSPgA8DawZyd8Hioc+zMzaylCGPppnqOXbwqZQxwDPNn28Lj9Gf9dERA+wGZgi6UDg88CV/TQxgLslPdTn8w3IL7yYWa0MZegjIhYDixM04wrgmojYpr0rzpwZEeslHQH8RNJjEfGzwYI5UZtZrZQ4m2M90PzC39T8WH/XrJPUCUwGNgBnABdI+jJwCNCQtDMiro2I9QAR8ZKk75MNsezfRP1Cws/Q84c0MyE7d6epHJNqOlDKCi+7d6Sp5LE70aPoHVvSVNEB2L0rzdeiI1HlmD800n1f7Oh9LVnskWqU95O2ApgpaTpZQp4PfLjPNV3AJcCDwAXAfRERwFl7LpB0BbAtIq6VNAkYExFb8/3zgKuKGuIetZnVSllVyCOiR9KlwDKgA7g+ItZIugpYGRFdwBLgJkndwEayZD6YI4Hv58MhncAtEXFXUVucqM2sVsr8d3ZELAWW9jl2edP+TuDCghhXNO2vBf5kqO1wojazWvEyp2ZmFVfiGHVlOFGbWa3UL007UZtZzbTTqnitcqI2s1rprWGfetBELWkc2XST5yLiHkkfBt4FPAosjojd+6CNZmYtG4096hvyayZKugQ4EPgeMIfsbZpL+rspf399IcAHD53N7ANdPMDM9o3R+DDxlIh4e/5q5Hrg6IjolXQzsHqgm5rfn//ScX9dv6+amVVWHRNOUaIekw9/TAImkr3HvhEYD4xN3DYzsyEbjUMfS4DHyF6f/ALwXUlrgXeSrc1qZlYpo+5hYkRcI+m2fP85Sd8C/gL4RkT8cl800MxsKEbjGDUR8VzT/h/IqhiYmVVS/dK051GbWc2Myh61mVk7GY0PE83M2kq4Rz10h5e1inc/xs+clCTumGOPShKXMWlqCWvCAUniAkx4bl2SuD1b06xFOe61dN9wEWnaPHFSmkos4zanq3YzYUy62CM16
mZ9mJm1Gw99mJlVXCPcozYzq7T6pWknajOrGU/PMzOrOM/6MDOruB4najOzanOP2sys4uo4Pa/wDQxJx0v6rKSvSLpa0ickHbwvGmdmNlQR0fJWRNJcSY9L6pZ0WT/nx0u6LT+/XNK0PuffLGmbpM+2GrM/gyZqSZ8Cvg4cALyDrGDAscAvJJ0zyH0LJa2UtPKB7U+20g4zs1I0iJa3wUjqAK4DzgdOBC6SdGKfyxYAmyJiBnANsKjP+auBO4cYcy9FPeqPA+dHxBfJ1qE+KSK+AMzNG9WviFgcEbMiYta7J7leopntO71Ey1uB2UB3RKyNiF1kxVLm9blmHnBjvn87MEeSACR9AHgaWDPEmHtpZfGJPePY48mK2xIRv8eluMysgobSo27+13++LWwKdQzwbNPH6/Jj9HdNRPQAm4Epkg4EPg9cOdD1g8TcS9HDxG8CKyQtB84i79ZLOpysdqKZWaW0MvbcdO3rhbhLdgVwTURsyzvYI1JUiusrku4B3gb8Y0Q8lh9/GTh7xJ/dzKxkJc76WE/2TG6Pqfmx/q5ZJ6mTrAD4BuAM4AJJXwYOARqSdgIPtRBzL62U4lrDG8dYzMwqq8R51CuAmZKmkyXT+cCH+1zTBVwCPAhcANwXWZf+rD0XSLoC2BYR1+bJvCjmXjyP2sxqpay1PiKiR9KlwDKgA7g+ItZIugpYGRFdwBLgJkndZMPB84cTs6gtTtRmViu9Ud7gR0QsBZb2OXZ50/5O4MKCGFcUxSziRG1mteJXyIdhdefuZLH/7M40pZGmPPxYkrg9r6UpxbVza7qZks+8/EdJ4m4ck+Zbb/uYNN8TADvT/O9jSk+axPK7WJskLsDW3TuSxR4pFw4wM6u4+qVpJ2ozqxkXDjAzqzgnajOziitz1kdVOFGbWa141oeZWcUNZa2PdlG0HvVkSV+S9JikjZI2SHo0P3bIPmqjmVnLylqPukqKZoZ+B9gEnBMRh0bEFOA9+bHvpG6cmdlQlVnhpSqKEvW0iFgUES/sORARL0TEIuC4gW5qXuP14a1PldVWM7NCvTRa3tpFUaL+naTPSTpyzwFJR0r6PG9c/PoNmiu8nHzQH5fVVjOzQo2Ilrd2UZSoPwRMAR7Ix6g3AvcDh1KwEImZ2f4QQ/ivXRQVDthEVk7m833PSfoYcEOidpmZDUs79ZRbNZJlZvrWAjMz2+9GXY9a0m8GOgUcOcA5M7P9po496qIXXo4E/pJsOl4zAf+apEVmZiMwGl8h/xFwYESs6ntC0v0pGmRmNhLtNKTRqqKHiQsGOVdYkNHMbF+LUdijHrEHdz2fLPakcVOTxJ389JuSxN2hNL/pt9GbJC7ACwe8liTupsaWJHF7E/amdjbSVCuaMnZikriPvvD7JHGh2ovzt9Or4a3yokxmVivt9Gp4q5yozaxW3KM2M6u43obHqM3MKm3UzfowM2s3dRyjHskr5GZmlVNm4QBJcyU9Lqlb0mX9nB8v6bb8/HJJ0/LjsyWtyrfVkj7YdM8zkn6bn1vZyt9p2D1qSXdGxPnDvd/MLIWyetSSOoDrgHOBdcAKSV0R8UjTZQuATRExQ9J8YBHZqqMPA7MiokfSUcBqSf8SET35fe+JiFdabUvRWh+nDXQKOLXVT2Jmtq+U+DBxNtAdEWsBJN0KzAOaE/U84Ip8/3bgWkmKiB1N1xzACKeeF/WoVwAPkCXmvg4Z6CZJC4GFAMdNnsHhE48abvvMzIZkKNPzmnNVbnFELM73j+GNBVLWAWf0CfH6NXnveTPZGv6vSDoDuJ6sGtZHmnrTAdwtKYB/avp8AypK1I8C/yEinuznLzhohRdgMcA7jj67fiP7ZlZZQxn6aM5VCdqxHDhJ0tuAG/Ph4p3AmRGxXtIRwE8kPRYRPxssVtHDxCsGueY/DrXhZmaplViKaz1wbNPHU/Nj/V4jqROYDGxoviAiHgW2ASfnH6/P/3wJ+D7ZEMugBk3UEXF7RDw+wOk0C2KYmY1AiYUDVgAzJU2XNA6YD3T1uaYLuCTfvwC4LyIiv6cTQNJxwFuBZyRNknRQfnwScB7Zg8dBjWQe9ZW4FJeZVUxZhQPyMedLgWVAB3B9RKyRdBWwMiK6gCXATZK6gY1kyRzgTOAySbuBBvDJiHhF0vHA9yVBln9viYi7itriCi9mViuNEpc5jYilwNI+xy5v2t9JP4W+I+Im4KZ+jq8F/mSo7XCFFzOrlTq+megKL2ZWK6MuUbvCi5m1m/qlabLfPlXZgIXtFrvd4rZjm/218NditG9VW5RpYfEllYvdbnFTxm63uCljt1vclLFTtnlUqFqiNjOzPpyozcwqrmqJOsk794ljt1vclLHbLW7K2O0WN2XslG0eFZQP9puZWUVVrUdtZmZ9OFGbmVVcZRJ1UW2yEcS9XtJLkgpXqBpi3GMl/VTSI5LWSPp0SXEPkPTLvM7aGklXlhG3KX6HpF9L+lHJcYdcB67FuIdIul3SY5IelfRnJcR8S1M9u1WStkj6TAnNRdJ/yv+/PSzp25IOKCNuHvvTedw1I2lvfz8Tkg6V9BNJT+Z/Dmt1zAFiX5i3uSFp1nDbPart74nc+Rh5B/AUcDwwDlgNnFhS7LOB04CHS27zUcBp+f5BwBNltJlsHZUD8/2xwHLgnSW2+z8DtwA/Kvnr8QxwWILvjRuBv833xwGHlBy/A3gBOK6EWMcATwMT8o+/A/xNSe08mWw5zIlkbxTfA8wYZqy9fiaALwOX5fuXAYtKjP024C3A/WR1BEv9HhkNW1V61K/XJouIXcCe2mQjFlnlhI1lxOoT9/mI+FW+v5WsGs4xJcSNiNiWfzg230p54itpKvBXwDfLiJeapMlkP/hLACJiV0T8oeRPMwd4KiJ+V1K8TmBCvhbxROC5kuK+DVgeETsiK+n0APDvhxNogJ+JeWS/FMn//EBZsSPi0Rh4XXtrQVUSdX+1yUac9PaVvET8n5L1fsuI1yFpFfAS8JPISvqU4X8BnyNbH7dse+rAPZTXoSvDdOBl4IZ8uOab+WLrZZoPfLuMQJFV7vgH4PfA88DmiLi7jNhkvemzJE2RNBF4L2+sPjJSR0bE8/n+C3gZ40qpSqJuW5IOBO4APhMRW8qIGRG9EXEqWemf2ZJOHmlMSe8DXoqIh0YaawBnRsRpwPnA30k6u4SYnWT/jP5aRPwpsJ3sn+WlyKt2vB/4bknx3kTWM50OHA1MkvTXZcSOrJzTIuBu4C5gFdBbRux+PldQ07WN2lVVEnUrtckqR9JYsiT9fyLie2XHz/+Z/1Ngbgnh/h3wfknPkA0t/bmkm0uICwyvDlwL1gHrmv5FcTtZ4i7L+cCvIuLFkuL9BfB0RLwcEbuB7wHvKik2EbEkIk6PiLPJ1oh/oqzYwIuSjgLI/3ypxNg2QlVJ1K3UJqsUZbV0lgCPRsTVJcY9XNIh+f4E4FzgsZHGjYj/EhFTI2Ia2df3vogopbc33DpwRSLiBeBZSW/JD80BHhlp3CYXUdKwR+73wDslTcy/P+aQPbsoRV61GklvJhufvqWs2Lyx9t8lwA9LjG0j
tb+fZu7ZyMbcniCb/fGFEuN+m2y8cDdZD21BSXHPJPvn4W/I/hm6CnhvCXHfDvw6j/swcHmCr/U5lDjrg2y2zup8W1Py/79TgZX51+MHwJtKijuJrFr05JK/tleS/WJ9mKwU0/gSY/+c7BfVamDOCOLs9TMBTAHuBZ4km1FyaImxP5jvvwa8CCwr82s+Gja/Qm5mVnFVGfowM7MBOFGbmVWcE7WZWcU5UZuZVZwTtZlZxTlRm5lVnBO1mVnF/X8ayvby+EAPwAAAAABJRU5ErkJggg==\n", 266 | "text/plain": [ 267 | "
" 268 | ] 269 | }, 270 | "metadata": { 271 | "needs_background": "light" 272 | }, 273 | "output_type": "display_data" 274 | } 275 | ], 276 | "source": [ 277 | "sns.heatmap(S_agg)" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "id": "38de0e8e-6b86-4297-b3df-762427912835", 283 | "metadata": {}, 284 | "source": [ 285 | "**Hungarian Algorithm Matching**" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 17, 291 | "id": "98b41fc7-81e6-4ca0-9fda-850d61d440b7", 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "_, edges = linear_sum_assignment(-S_agg)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 18, 301 | "id": "34db4842-b3eb-49a4-9de0-753a50fa05a9", 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "data": { 306 | "text/plain": [ 307 | "" 308 | ] 309 | }, 310 | "execution_count": 18, 311 | "metadata": {}, 312 | "output_type": "execute_result" 313 | }, 314 | { 315 | "data": { 316 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAAD8CAYAAADUv3dIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAXJUlEQVR4nO3de9RddX3n8fcnCUFuDSqWgSQqrdHqaBeXDFovmA7QRtsFbUcHZHVEFzVds4qXOjM1HWew4ppZ0lE7dpW2RsCKHWGAcWrGpoJa0LYWTKzAJOEWI0K4KygjOEKe85k/9o5zeJpz9jnPs88+ex8+L9de2WdfvueXB/PNL7/927+vbBMREc1YMu0GREQ8nSTpRkQ0KEk3IqJBSboREQ1K0o2IaFCSbkREg5J0IyIGkHSJpAclbR9wXpL+UNIuSTdLOr4qZpJuRMRgfwasH3L+dcCactsA/ElVwCTdiIgBbH8FeHjIJacDl7pwPXC4pKOGxVxWZwP358nv7J7YK28HHf2aSYWOiCnY+8Q9WmyMcXLO8uf89G9S9FD32WR70xhftxK4u+/znvLYfYNumHjSjYhoVG9u5EvLBDtOkl20JN2ImC3uNflt9wCr+z6vKo8NlDHdiJgtvd7o2+JtBt5czmJ4BfB92wOHFiA93YiYMa6xpyvpMmAdcISkPcD7gAOK7/GfAluA1wO7gMeBt1bFTNKNiNkyt7e2ULbfVHHewG+NEzNJNyJmyxgP0qYhSTciZkuzD9LGVpl0Jf0MxQTgleWhe4DNtm+ZZMMiIhakngdkEzN09oKk9wCXAwK+Vm4CLpO0cch9GyRtk7Ttoksvq7O9ERFD2b2Rt2nQsBppkm4H/qntJ+cdXw7ssL2m6gvyRlpEjKqON9J+dMdXR845B6555aK/b1xVwws94Gjg2/OOH1Wei4hol7knq6+Zoqqk+y7gS5Lu4P+/X/xc4AXAuRNsV0TEwnT5QZrtz0t6IXAiT32QttV2u+dlRMTTU8sfpFXOXnAx2nx9A22JiFi8Lvd0IyI6p+s93YiILnGv2w/SIiK65ene053kXNof3vs3E4mb+b8RHZYx3YiIBmXBm4iIBqWnGxHRoKf7mG5ERKNqXMR8EpJ0I2K2pKcbEdGctq9QkKQbEbMlPd2IiAZl9kJERINa3tMdWq5nGEkD67v3l+vp9R5b6FdERIxvbu/o2xQsOOkC7x90wvYm22ttr12y5JBFfEVExJjcG32bgqHDC5JuHnQKOLL+5kRELFLLhxeqxnSPBH4ReGTecQFfnUiLIiIWo+NJ93PAobZvnH9C0nWTaFBExKJ0efaC7XOGnDur/uZERCxSXgOOiGhQx4cXIiK6pcvDCxERnZOe7uRMqqxOygBFdFiSbkREg+xpt2CoJN2ImC17M3shIqI5LX+Qtpi1FyIi2qfXG32rIGm9pNsk7ZK0cT/nnyvpWknfkHSzpNdXxUzSjYjZYo++DSFpKXAh8DrgJcCbJL1k3mX/AbjC9nHAmcAfVzUvwwsRMVvqm71wIrDL9m4ASZcDpwM7+64x8BPl/grg3qqgSboRMVvGSLqSNgAb+g5tsr2p3F8J3N13bg/w8nkhfg+4RtLbgUOAU6q+szLpSvqZ8stvsP2DvuPrbX++6v6IiCZ5bvTClGWC3VR54WBvAv7M9ocl/RzwKUkvtQc/zRs6pivpHcBngbcD2yWd3nf6Pw+5L5UjImI66nuQdg+wuu/zqvJYv3OAKwBs/z3wDOCIYUGrHqS9DTjB9q8A64D/KOmd5TkNuimVIyJiauqrHLEVWCPpGEnLKR6UbZ53zV3AyQCSXkyRdB8aFrRqeGHJviEF23dKWgdcJel5DEm6ERFT06vnjTTbeyWdC1wNLAUusb1D0vnANtubgX8DfFzSb1M8VHuLPXxaRFXSfUDSsfsWMbf9A0m/DFwCvGxxv6WIiAmoce0F21uALfOOnde3vxN41Tgxq5Lum4GnvFNney/wZkkfG+eLIiIaMcaDtGmoqhyxZ8i5v6u/ORERi5RVxiIiGlTTmO6kJOlGxGxp+YI3SboRMVvS0+2eVKSI6C5nTDciokFdnr0QEdE5GV6IiGhQhhciIhqUnm5ERIMyZSwiokHp6UZENMd7Oz57QdKJgG1vLYuyrQduLVffiYholy73dCW9j6IS5jJJX6CoD3QtsFHScbb/04D7flx3SEtXkIXMI6IxHR/TfQNwLHAgcD+wyvajkj4E3ADsN+n21x1atnxlu//aiYjZ0uWeLrDX9hzwuKRv2n4UwPYPJbX7r5OIeFpyx5PuE5IOtv04cMK+g5JWAEm6EdE+HX+QdpLtHwHMKyl8AHD2xFoVEbFQXe7p7ku4+zn+HeA7E2lRRMRidDnpRkR0TUUx3qlL0o2I2ZKebkREg5J0Y5+uVaSAVKWI7vHedk+sStKNiNnS7pybpBsRs6XrL0dERHRLkm5ERIMyvBAR0ZwML0RENMh7k3QjIprT8uGFJePeIOnSSTQkIqIO7o2+TUNV5YjN8w8BPy/pcADbpw24L5UjImI6Wt7TrRpeWAXsBC4
CTJF01wIfHnZTKkdExLS0vFpP5fDCWuDrwHuB79u+Dvih7S/b/vKkGxcRMS7vHX2rImm9pNsk7ZK0ccA1/1LSTkk7JH26KmbVero94A8kXVn++kDVPRER01RXT1fSUuBC4FRgD7BV0mbbO/uuWQP8LvAq249I+smquCMlUNt7gDdK+iXg0YX8BiIimlDj8MKJwC7buwEkXQ6cTjHkus/bgAttPwJg+8GqoGPNXrD9l7b//Tj3REQ0yhp5k7RB0ra+bUNfpJXA3X2f95TH+r0QeKGkv5N0vaT1Vc3LUEFEzJRxerr9D/0XaBmwBlhHMfHgK5JeZvt7w26IiJgZ7qmuUPcAq/s+ryqP9dsD3GD7SeBbkm6nSMJbBwUd++WIiIg2681p5K3CVmCNpGMkLQfOBOa/u/AXFL1cJB1BMdywe1jQ9HRnwCSrO0yqKkUqUsSk1PUgzfZeSecCVwNLgUts75B0PrDN9uby3C9I2gnMAf/O9neHxU3SjYiZUuPwAra3AFvmHTuvb9/Au8ttJEm6ETFTWl6BPUk3ImZLnT3dSUjSjYiZMsIDsqlK0o2ImZKebkREg+wk3YiIxrR9acexkq6kV1MsArHd9jWTaVJExML1Wt7THfpGmqSv9e2/Dfgj4DDgfYPWliyv/fEiEr3eY7U1NiKiiq2Rt2mo6uke0Le/ATjV9kOSPgRcD3xwfzelckRETEvXZy8skfRMih6xbD8EYPsxSSOsux4R0ayuz15YQVGuR4AlHWX7PkmHlsciIlql7WO6VeV6nj/gVA/41dpbExGxSDM5Zcz248C3am5LRMSiZe2FiIgGdXp4ISKia3odf5AWEdEp6elGp02qwkMqUsSkzOSDtIiItkpPNyKiQS2fvJCkGxGzZa7X7iLnSboRMVNavrJjkm5EzBa3fIWCJN2ImCm9lg/qJulGxEzptbynW7WI+csl/US5f5Ck90v6X5IukLSimSZGRIzOaORtGqoe810CPF7uf5RiqccLymOfGHRTKkdExLTMoZG3aahcxNz2vsXK19o+vtz/W0k3DroplSMiYlraPnuhqqe7XdJby/2bJK0FkPRC4MmJtiwiYgF6Y2zTUJV0fwN4raRvAi8B/l7SbuDj5bmIiFZp+5huVeWI7wNvKR+mHVNev8f2A000LiJiXC1f2XG0KWO2HwVumnBbIiIWre1TxjJPNyJmyty0G1AhSTciZkpP6elGRDSm7XNUk3RjKrpWkQJSlaIruj5PNyKiU3oafasiab2k2yTtkrRxyHX/QpL3vcswTHq6ETFT6nq9V9JS4ELgVGAPsFXSZts75113GPBO4IZR4qanGxEzpcae7onALtu7bT8BXA6cvp/rPkCxJs3/HaV9SboRMVPGeQ24f3GuctvQF2olcHff5z3lsR+TdDyw2vZfjtq+DC9ExEwZZ/ZC/+Jc45K0BPgI8JZx7kvSjYiZUuNrwPcAq/s+ryqP7XMY8FLgOhVzg/8JsFnSaba3DQqapBsRM6XGKWNbgTWSjqFItmcCZ+07Wa5Nc8S+z5KuA/7tsIQL1ZUj3iFp9bBrIiLaZE6jb8OUa4mfC1wN3AJcYXuHpPMlnbbQ9lX1dD8AbCyXdrwMuNL2Q1VBy8HoDQBauoIlSw5ZaPsiIsZS58sRtrcAW+YdO2/AtetGiVk1e2E3xTjGB4ATgJ2SPi/p7HJu2qCGbrK91vbaJNyIaFLXFzG37Z7ta2yfAxwN/DGwniIhR0S0isfYpqFqeOEpox62nwQ2UzyhO3hirYqIWKCuL2J+xqATth8fdC4iYlravuBNVbme25tqSEREHbKIeUREg7o+vBAR0SmdHl6IiOiaVI6IaNAkqztMqipFKlLUq9fytJukGxEzJQ/SIiIalDHdiIgGZfZCRESDMqYbEdGgdqfcJN2ImDEZ042IaNBcy/u6Q5OupOUUJSrutf1FSWcBr6RYRX1TuepYRERrdL2n+4nymoMlnQ0cCnwGOJmiJvzZ+7splSMiYlq6/iDtZbZ/VtIyisJsR9uek/TnwE2Dbuova7xs+cp2/wQiYqa0PeFUJd0l5RDDIcDBwArgYeBA4IAJty0iYmxdH164GLgVWAq8F7hS0m7gFcDlE25bRMTYOv0gzfYfSPrv5f69ki4FTgE+bvtrTTQwImIcXR/Txfa9ffvfA66aZIMiIhaj3Sk383QjYsZ0vqcbEdElXX+QFhHRKU5PN2I2TKrCQypS1KvTsxciIromwwsREQ3qOT3diIjGtDvlJulGxIzJlLGIiAZl9kJERIP2JulGRDSn7T3dJdNuQEREnXpjbFUkrZd0m6Rdkjbu5/y7Je2UdLOkL0l6XlXMyp6upJ8Cfg1YDcwBtwOftv3oCG2OiGiUa5oyJmkpcCFwKrAH2Cpps+2dfZd9A1hr+3FJ/xr4feCMYXGH9nQlvQP4U+AZwD+jWLx8NXC9pHVD7tsgaZukbb3eY1W/t4iI2vTwyFuFE4FdtnfbfoJiDfHT+y+wfa3tx8uP1wOrqoJW9XTfBhxbluj5CLDF9jpJHwM+Cxy3v5tSricipmWc14D76zmWNpX5C2AlcHffuT3Ay4eEOwf4q6rvHOVB2jKKYYUDKQpTYvsuSSnXExGtM8483f4O4mJI+nVgLfDaqmurku5FFOMYNwCvAS4ov+A5FLXSIiJapa4xXYpivKv7Pq8qjz2FpFMoypm91vaPqoJWlev5qKQvAi8GPmz71vL4Q8BJo7c9IqIZNS54sxVYI+kYimR7JnBW/wWSjgM+Bqy3/eAoQUcp17MD2DF2cyMipqCuebq290o6F7iaojjvJbZ3SDof2GZ7M/BfKIZdr5QEcJft04bFzcsRETFT6lx7wfYWYMu8Y+f17Z8ybswk3YiYKXNu94q6SboRMVPa/hpwkm7ElHWtDBC0uxRQFjGPiGhQu1Nukm5EzJgsYh4R0aAk3YiIBmX2QkREgzJ7ISKiQTWuvTARVevprpD0QUm3SnpY0ncl3VIeO7yhNkZEjKzG9XQnoqpczxXAI8A628+y/Wzg58tjV0y6cRER47I98jYNVUn3+bYvsH3/vgO277d9ATCwFlAqR0TEtMzRG3mbhqqk+21JvyPpyH0HJB0p6T08dUX1p7C9yfZa22uXLDmkrrZGRFTq2SNv01CVdM8Ang18uRzTfRi4DngW8MYJty0iYmwe43/TULWI+SPAe8rtKSS9FfjEhNoVEbEgbV97oaqnO8z7a2tFRERNOt3TlXTzoFPAkQPORURMTdt7ulUvRxwJ/CLFFLF+Ar46kRZFRCxC118D/hxwqO0b55+QdN0kGhQRsRidfg3Y9jlDzp016FxExLS44z3diOioSVZ3mGRVisXK0o4REQ1q+4I3SboRMVPS042IaNBcL2O6ERGN6fTshYiIrsmYbkREg9o+prvgtRck/VWdDYmIqEPbFzGvWnvh+EGngGNrb01ExCJ1/UHaVuDLFEl2vsMH3SRpA7ABQEtXkIXMI6IpbR9eqEq6twC/afuO+SckDa0cAWwCWLZ8Zbt/AhExU7r+IO33GDzu+/Z6mxIRsXidXt
rR9lVDTj+z5rZERCxa2+fppnJERMyUthemTOWIiJgpvY4v7ZjKERHRKXU+SJO0HvgosBS4yPYH550/ELgUOAH4LnCG7TuHxUzliIiYKXUlXUlLgQuBU4E9wFZJm23v7LvsHOAR2y+QdCZwAXDGsLhDx3Rtn2P7bwecS+WIiGgdj7FVOBHYZXu37SeAy4HT511zOvDJcv8q4GRJ+3uvoa+BY7wyN+kN2NC12F2L28U252eRn8Ukf8/Atr5tQ9+5N1AMKez7/K+AP5p3/3ZgVd/nbwJHDPvOxcxemIQNHYzdtbiTjN21uJOM3bW4k4w9yTYviu1Nttf2bZsm/Z1tS7oREW1xD7C67/Oq8th+r5G0DFhB8UBtoCTdiIj92wqskXSMpOXAmcDmeddsBs4u998A/LXLcYZB2rae7iS79pOK3bW4k4zdtbiTjN21uJOMPfF/sk+C7b2SzgWuppgydontHZLOB7bZ3gxcDHxK0i7gYYrEPJQqknJERNQowwsREQ1K0o2IaFBrkq6k9ZJuk7RL0sYa414i6UFJ2+uKWcZdLelaSTsl7ZD0zpriPkPS1yTdVMatdWEhSUslfUPS52qOe6ek/y3pRknbaox7uKSrJN0q6RZJP1dDzBeV7dy3PSrpXTU0F0m/Xf532y7pMknPqCNuGfudZdwdi2nv/v5MSHqWpC9IuqP8dUGrCA6I/cayzT1Jaxfa7pkx7cnJ5ZjyUopJxT8FLAduAl5SU+yTgOOB7TW3+Sjg+HL/MOD2OtpMsa7FoeX+AcANwCtqbPe7gU8Dn6v553EnFZPCFxj3k8BvlPvLgcNrjr8UuB94Xg2xVgLfAg4qP18BvKWmdr6UYiL+wRQPwL8IvGCBsf7Rnwng94GN5f5G4IIaY78YeBFwHbC27v+PdG1rS093lNftFsT2VyieKtbK9n22/6Hc/z8UVTZW1hDXtn9Qfjyg3Gp52ilpFfBLwEV1xJs0SSso/hBfDGD7Cdvfq/lrTga+afvbNcVbBhxUztk8GLi3prgvBm6w/bjtvRRltH5tIYEG/Jnof531k8Cv1BXb9i22b1tIvFnUlqS7Eugv/7OHGhJYUyQ9HziOoldaR7ylkm4EHgS+YLuWuMB/BX4HmMTadwaukfT1skZeHY4BHgI+UQ6JXCSp7oJ7ZwKX1RHI9j3Ah4C7gPuA79u+po7YFL3c10h6tqSDgdfz1In7i3Wk7fvK/fvJ0q0T05ak21mSDgX+B/Au24/WEdP2nO1jKd6AOVHSSxcbU9IvAw/a/vpiYw3watvHA68DfkvSSTXEXEbxT9U/sX0c8BjFP31rUU54Pw24sqZ4z6ToMR4DHA0cIunX64ht+xaKFayuAT4P3AjM1RF7P9814nowsRBtSbqjvG7XOpIOoEi4/832Z+qOX/5T+lpgfQ3hXgWcJulOiuGbfy7pz2uIC/y4l4ftB4H/STFktFh7gD19Pf2rKJJwXV4H/IPtB2qKdwrwLdsP2X4S+AzwyppiY/ti2yfYPolijevb64oNPCDpKIDy1wdrjB192pJ0R3ndrlXK5dsuBm6x/ZEa4z5H0uHl/kEUa3neuti4tn/X9irbz6f4+f617Vp6YZIOkXTYvn3gFyj+Obwotu8H7pb0ovLQycDOIbeM603UNLRQugt4haSDy/9/nEwx1l8LST9Z/vpcivHcT9cVm6e+zno28NkaY0e/aT/J27dRjFHdTjGL4b01xr2MYnztSYqe0zk1xX01xT/Bbqb4p96NwOtriPuzwDfKuNuB8ybws15HjbMXKGad3FRuO2r+73csxZJ7NwN/ATyzpriHUCxMsqLmn+37Kf6S3A58Cjiwxth/Q/GXzk3AyYuI84/+TADPBr4E3EExM+JZNcb+1XL/R8ADwNV1/sy7tuU14IiIBrVleCEi4mkhSTciokFJuhERDUrSjYhoUJJuRESDknQjIhqUpBsR0aD/B4cTrFpeIVllAAAAAElFTkSuQmCC\n", 317 | "text/plain": [ 318 | "
" 319 | ] 320 | }, 321 | "metadata": { 322 | "needs_background": "light" 323 | }, 324 | "output_type": "display_data" 325 | } 326 | ], 327 | "source": [ 328 | "tmp_plot = torch.zeros(num_layers, num_layers)\n", 329 | "tmp_plot[torch.arange(num_layers), edges] = 1\n", 330 | "sns.heatmap(tmp_plot)" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "id": "57f1a490-e007-4a8c-a6f9-c6dee9862674", 336 | "metadata": {}, 337 | "source": [ 338 | "### Plot All" 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "id": "0dd124d1-d7d1-4738-9d4e-46dd08448c67", 344 | "metadata": {}, 345 | "source": [ 346 | "**Projected Vectors in Embedding Space**" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 19, 352 | "id": "f65b94d9-8225-4d37-b2e5-51b0e8685c83", 353 | "metadata": {}, 354 | "outputs": [ 355 | { 356 | "data": { 357 | "application/vnd.jupyter.widget-view+json": { 358 | "model_id": "6d135a9724e64019a4be61be4e988ce3", 359 | "version_major": 2, 360 | "version_minor": 0 361 | }, 362 | "text/plain": [ 363 | " 0%| | 0/6 [00:00" 374 | ] 375 | }, 376 | "metadata": {}, 377 | "output_type": "display_data" 378 | } 379 | ], 380 | "source": [ 381 | "plt.figure(figsize=(7.5, 4.25))\n", 382 | "for i, (title, param1, param2) in enumerate(tqdm([(\"$K$\", K1, K2), (\"$V$\", V1, V2), (\"$W_K$\", WK1, WK2), \n", 383 | " (\"$W_Q$\", WQ1, WQ2), (\"$W_V$\", WV1, WV2), (\"$W_O$\", WO1, WO2)])):\n", 384 | " S = kernelized_corr_coef(param1.to(device), param2.to(device), \n", 385 | " kernel11.to(device), kernel22.to(device), kernel12.to(device), \n", 386 | " emb_mean1.to(device), emb_mean2.to(device), n=len(tokenizer))\n", 387 | " layer_size = param1.shape[0] // num_layers\n", 388 | " S_agg = S.view(num_layers, layer_size, num_layers, layer_size).abs().mean([-1, -3]).cpu().numpy()\n", 389 | " _, edges = linear_sum_assignment(-S_agg)\n", 390 | " tmp_plot = torch.zeros(num_layers, num_layers)\n", 391 | " tmp_plot[torch.arange(num_layers), edges] = 1\n", 392 | " plt.subplot(2, 3, i+1)\n", 393 | " plt.title(title)\n", 394 | " sns.heatmap(tmp_plot, cbar=False)\n", 395 | " plt.xticks([])\n", 396 | " plt.yticks([])\n", 397 | "plt.savefig(\"artifacts/all_diagonals.pdf\") " 398 | ] 399 | }, 400 | { 401 | "cell_type": "markdown", 402 | "id": "58319b85-c83b-4512-8579-7e8e48f68e55", 403 | "metadata": {}, 404 | "source": [ 405 | "**Raw Vectors in Feature Space**" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 20, 411 | "id": "60027716-58b8-4170-bfdf-7369dd099623", 412 | "metadata": {}, 413 | "outputs": [ 414 | { 415 | "data": { 416 | "application/vnd.jupyter.widget-view+json": { 417 | "model_id": "afe47250216a41a485d392af4e71dd73", 418 | "version_major": 2, 419 | "version_minor": 0 420 | }, 421 | "text/plain": [ 422 | " 0%| | 0/6 [00:00" 433 | ] 434 | }, 435 | "metadata": {}, 436 | "output_type": "display_data" 437 | } 438 | ], 439 | "source": [ 440 | "plt.figure(figsize=(7.5, 4.25))\n", 441 | "for i, (title, param1, param2) in enumerate(tqdm([(\"$K$\", K1, K2), (\"$V$\", V1, V2), (\"$W_K$\", WK1, WK2), \n", 442 | " (\"$W_Q$\", WQ1, WQ2), (\"$W_V$\", WV1, WV2), (\"$W_O$\", WO1, WO2)])):\n", 443 | " S = corr_coef(param1.to(device), param2.to(device))\n", 444 | " layer_size = param1.shape[0] // num_layers\n", 445 | " S_agg = S.view(num_layers, layer_size, num_layers, layer_size).abs().mean([-1, -3]).cpu().numpy()\n", 446 | " _, edges = linear_sum_assignment(-S_agg)\n", 447 | " tmp_plot = torch.zeros(num_layers, num_layers)\n", 448 | " 
tmp_plot[torch.arange(num_layers), edges] = 1\n", 449 | " plt.subplot(2, 3, i+1)\n", 450 | " plt.title(title)\n", 451 | " sns.heatmap(tmp_plot, cbar=False)\n", 452 | " plt.xticks([])\n", 453 | " plt.yticks([])\n", 454 | "plt.savefig(\"artifacts/all_diagonals_raw.pdf\") " 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "id": "ddb9ac26-ca92-4d3b-aeb2-28b86bb8cf84", 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [] 464 | } 465 | ], 466 | "metadata": { 467 | "kernelspec": { 468 | "display_name": "Python 3", 469 | "language": "python", 470 | "name": "python3" 471 | }, 472 | "language_info": { 473 | "codemirror_mode": { 474 | "name": "ipython", 475 | "version": 3 476 | }, 477 | "file_extension": ".py", 478 | "mimetype": "text/x-python", 479 | "name": "python", 480 | "nbconvert_exporter": "python", 481 | "pygments_lexer": "ipython3", 482 | "version": "3.8.13" 483 | } 484 | }, 485 | "nbformat": 4, 486 | "nbformat_minor": 5 487 | } 488 | -------------------------------------------------------------------------------- /sentiment-analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "4c4c816e-f4ca-43ba-8000-c93942620147", 6 | "metadata": {}, 7 | "source": [ 8 | "## Init" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "0dd22ca7-dbdd-4855-83ac-405efbf1e408", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "from torch import nn\n", 20 | "import torch.nn.functional as F\n", 21 | "from copy import deepcopy\n", 22 | "from transformers import (AutoModelForMaskedLM, AutoModelForCausalLM, AutoTokenizer, AutoModelForTokenClassification,\n", 23 | " AutoModelForSequenceClassification, TrainingArguments, Trainer)\n", 24 | "from tqdm.auto import tqdm\n", 25 | "import numpy as np\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "import seaborn as sns\n", 28 | "import json\n", 29 | "from tensorflow.keras.models import load_model\n", 30 | "from datasets import load_dataset, load_metric\n", 31 | "import os\n", 32 | "from utils import top_tokens\n", 33 | "from tabulate import tabulate" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "07fdd79f-7d70-4548-8b9c-909cc86d81f3", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "tokenizer = AutoTokenizer.from_pretrained('gpt2') # ('bert-base-uncased') # get_multiberts_tokenizer()" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "id": "f14b9d71-1578-402e-bb0e-e343f0cbd78a", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "class Gpt2AvgClassifier(nn.Module):\n", 54 | " def __init__(self, name, freeze=None, num_labels=2):\n", 55 | " super().__init__()\n", 56 | " self.model = AutoModelForTokenClassification.from_pretrained(name, num_labels=num_labels)\n", 57 | " self.model.transformer.ln_f = nn.Identity(self.model.config.n_ctx)\n", 58 | " if freeze is not None:\n", 59 | " for n, p in self.named_parameters():\n", 60 | " p.requires_grad = False\n", 61 | " if len(n.split('.transformer.h.')) == 2 and n.endswith('.weight'):\n", 62 | " if int(n.split('.transformer.h.')[1].split('.')[0]) >= freeze:\n", 63 | " p.requires_grad = True\n", 64 | " print(n)\n", 65 | " if n.endswith('.classifier.weight'):\n", 66 | " p.requires_grad = True\n", 67 | " print(n)\n", 68 | " \n", 69 | " def forward(self, input_ids, labels, inputs_embeds=None):\n", 70 | " res = 
self.model(input_ids=input_ids, inputs_embeds=inputs_embeds)\n", 71 | " res.logits = res.logits.mean(dim=-2)\n", 72 | " res['loss'] = F.cross_entropy(res.logits.view(-1, res.logits.shape[-1]), labels.view(-1))\n", 73 | " return res" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "id": "8c4d2819-e575-4f98-83af-1cd016b88bce", 79 | "metadata": {}, 80 | "source": [ 81 | "### Initialize Models" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 4, 87 | "id": "601e4a50-cf00-4245-a62e-421113db425a", 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "freeze = 9 # number of layers to freeze" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 5, 97 | "id": "03ace372-e479-4caa-bcf3-68043ebeb0a0", 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "name": "stdout", 102 | "output_type": "stream", 103 | "text": [ 104 | "['gpt2', 'gpt2-medium']\n" 105 | ] 106 | }, 107 | { 108 | "name": "stderr", 109 | "output_type": "stream", 110 | "text": [ 111 | "Some weights of GPT2ForTokenClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.4.attn.masked_bias', 'h.6.attn.masked_bias', 'classifier.bias', 'h.5.attn.masked_bias', 'h.2.attn.masked_bias', 'h.8.attn.masked_bias', 'h.11.attn.masked_bias', 'h.7.attn.masked_bias', 'h.0.attn.masked_bias', 'classifier.weight', 'h.3.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.1.attn.masked_bias']\n", 112 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 113 | ] 114 | }, 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "model.transformer.h.9.ln_1.weight\n", 120 | "model.transformer.h.9.attn.c_attn.weight\n", 121 | "model.transformer.h.9.attn.c_proj.weight\n", 122 | "model.transformer.h.9.ln_2.weight\n", 123 | "model.transformer.h.9.mlp.c_fc.weight\n", 124 | "model.transformer.h.9.mlp.c_proj.weight\n", 125 | "model.transformer.h.10.ln_1.weight\n", 126 | "model.transformer.h.10.attn.c_attn.weight\n", 127 | "model.transformer.h.10.attn.c_proj.weight\n", 128 | "model.transformer.h.10.ln_2.weight\n", 129 | "model.transformer.h.10.mlp.c_fc.weight\n", 130 | "model.transformer.h.10.mlp.c_proj.weight\n", 131 | "model.transformer.h.11.ln_1.weight\n", 132 | "model.transformer.h.11.attn.c_attn.weight\n", 133 | "model.transformer.h.11.attn.c_proj.weight\n", 134 | "model.transformer.h.11.ln_2.weight\n", 135 | "model.transformer.h.11.mlp.c_fc.weight\n", 136 | "model.transformer.h.11.mlp.c_proj.weight\n", 137 | "model.classifier.weight\n" 138 | ] 139 | }, 140 | { 141 | "name": "stderr", 142 | "output_type": "stream", 143 | "text": [ 144 | "Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2-medium and are newly initialized: ['score.weight']\n", 145 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "model_paths = ['gpt2', 'gpt2-medium'] \n", 151 | "\n", 152 | "print(model_paths)\n", 153 | "\n", 154 | "model1 = Gpt2AvgClassifier(model_paths[0], freeze=freeze) # AutoModelForSequenceClassification.from_pretrained(model_paths[0])\n", 155 | "model2 = AutoModelForSequenceClassification.from_pretrained(model_paths[1])\n", 156 | "# we can use input embedding as the embedding matrices are tied\n", 157 | "emb1 = model1.model.get_input_embeddings().weight.T.cpu().detach() \n", 158 | "emb2 = 
model2.get_input_embeddings().weight.T.cpu().detach() \n", 159 | "num_layers1, hidden_dim1 = (model1.model.config.n_layer, model1.model.config.n_embd)\n", 160 | "num_layers2, hidden_dim2 = (model2.config.n_layer, model2.config.n_embd)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "id": "77f0f339-4a8e-4780-be94-40a3ffd9cef2", 166 | "metadata": {}, 167 | "source": [ 168 | "## Sentiment Analysis Finetuning" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 6, 174 | "id": "48bcbc0b-3ee5-4494-9e92-1a65504e7870", 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "model = model1" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "id": "fe41f0eb-e58f-45f8-9cf3-f4aff4e6c100", 184 | "metadata": {}, 185 | "source": [ 186 | "### Preparing Data" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 7, 192 | "id": "11a17d09-8034-4df8-b1a0-62ff7814598f", 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "def tokenize_imdb(examples):\n", 197 | " return tokenizer(examples[\"text\"], truncation=True)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 8, 203 | "id": "46bcd56a-e0f0-4bd4-abb7-7f43dde6e29f", 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "name": "stderr", 208 | "output_type": "stream", 209 | "text": [ 210 | "Reusing dataset imdb (/home/guydar/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)\n" 211 | ] 212 | }, 213 | { 214 | "data": { 215 | "application/vnd.jupyter.widget-view+json": { 216 | "model_id": "be9861acc9c94a6c9bc9c1062336e129", 217 | "version_major": 2, 218 | "version_minor": 0 219 | }, 220 | "text/plain": [ 221 | " 0%| | 0/3 [00:00\n", 343 | " \n", 344 | " \n", 345 | " [3000/3000 01:33, Epoch 1/1]\n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | "
Epoch | Training Loss | Validation Loss | Accuracy
1     | 0.636800      | 0.950526        | 0.832000
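For readers skimming the dump: the accuracy above comes from the `compute_metrics` callback handed to `Trainer` (its definition sits earlier in this notebook). A minimal sketch of the standard pattern it presumably follows, using the `datasets` accuracy metric:

```python
import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    # The Trainer passes (logits, labels); argmax over the label axis
    # turns logits into hard class predictions before scoring.
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
```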
" 365 | ], 366 | "text/plain": [ 367 | "" 368 | ] 369 | }, 370 | "metadata": {}, 371 | "output_type": "display_data" 372 | }, 373 | { 374 | "name": "stderr", 375 | "output_type": "stream", 376 | "text": [ 377 | "The following columns in the evaluation set don't have a corresponding argument in `Gpt2AvgClassifier.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `Gpt2AvgClassifier.forward`, you can safely ignore this message.\n", 378 | "***** Running Evaluation *****\n", 379 | " Num examples = 500\n", 380 | " Batch size = 1\n", 381 | "\n", 382 | "\n", 383 | "Training completed. Do not forget to share your model on huggingface.co/models =)\n", 384 | "\n", 385 | "\n" 386 | ] 387 | }, 388 | { 389 | "data": { 390 | "text/plain": [ 391 | "TrainOutput(global_step=3000, training_loss=0.8896415710449219, metrics={'train_runtime': 93.5903, 'train_samples_per_second': 32.055, 'train_steps_per_second': 32.055, 'total_flos': 0.0, 'train_loss': 0.8896415710449219, 'epoch': 1.0})" 392 | ] 393 | }, 394 | "execution_count": 14, 395 | "metadata": {}, 396 | "output_type": "execute_result" 397 | } 398 | ], 399 | "source": [ 400 | "trainer = Trainer(model1, args=train_args, train_dataset=imdb_train, eval_dataset=imdb_val, \n", 401 | " compute_metrics=compute_metrics)\n", 402 | "trainer.train()" 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "id": "6685ad43-e590-4e53-b776-cce678426ca8", 408 | "metadata": {}, 409 | "source": [ 410 | "### Visualize Finetuning Vectors" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": 15, 416 | "id": "e7cc0792-d95a-4418-a58d-5999a2045af3", 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "diff_classifier = (model.model.classifier.weight.cpu() - old_model.model.classifier.weight.cpu()).detach()\n", 421 | "# diff_classifier = model.score.weight.detach().cpu() - old_model.score.weight.detach()\n", 422 | "# diff_classifier = model.classifier.weight.detach().cpu() - old_model.classifier.weight.detach()" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 16, 428 | "id": "1e949f1f-b149-4ee3-b521-d7f262aa563e", 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [ 432 | "neg_vector = diff_classifier[0, :]\n", 433 | "pos_vector = diff_classifier[1, :]" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 17, 439 | "id": "557b80ff-3285-4ba3-9397-0ac00d46ade9", 440 | "metadata": {}, 441 | "outputs": [ 442 | { 443 | "name": "stdout", 444 | "output_type": "stream", 445 | "text": [ 446 | "POSITIVE NEGATIVE\n", 447 | "---------- ------------\n", 448 | "#iscover bullshit\n", 449 | "honoured shitty\n", 450 | "pioneers crap\n", 451 | "#knit crappy\n", 452 | "#izons incompetence\n", 453 | "#Vers incompetent\n", 454 | "#raits pointless\n", 455 | "pioneer retarded\n", 456 | "#elight worse\n", 457 | "enchant FUCK\n", 458 | "#Together idiots\n", 459 | "reunited useless\n", 460 | "powerfully fuck\n", 461 | "#joy worthless\n", 462 | "Together garbage\n", 463 | "pioneering inco\n", 464 | "passions #Fuck\n", 465 | "timeless lame\n", 466 | "lively shit\n", 467 | "#inguished stupid\n", 468 | "insepar pathetic\n", 469 | "#Join inept\n", 470 | "renowned #shit\n", 471 | "unmatched piss\n", 472 | "#Born asshole\n", 473 | "#ossom Worse\n", 474 | "welcomes poorly\n", 475 | "Selected awful\n", 476 | "#anqu stupidity\n", 477 | "#Discover ineffective\n" 478 | ] 479 | } 480 | ], 481 | "source": [ 482 | "print(tabulate(\n", 483 | " 
[*zip(*[top_tokens(pos_vector @ emb1, k=30, only_ascii=True, tokenizer=tokenizer),\n", 484 | " top_tokens(neg_vector @ emb1, k=30, only_ascii=True, tokenizer=tokenizer)])],\n", 485 | " headers=['POSITIVE', 'NEGATIVE']))" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 18, 491 | "id": "b63e1491-c1d9-46f9-87e2-c5e8de97b44c", 492 | "metadata": {}, 493 | "outputs": [], 494 | "source": [ 495 | "i1 = 11 # this is the layer we visualize" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": 19, 501 | "id": "83f6a687-7fa9-4e42-82f9-a8f72298d85b", 502 | "metadata": {}, 503 | "outputs": [], 504 | "source": [ 505 | "diff_K = (model.model.transformer.h[i1].mlp.c_fc.weight.cpu() - old_model.model.transformer.h[i1].mlp.c_fc.weight.cpu()).T\n", 506 | "diff_V = (model.model.transformer.h[i1].mlp.c_proj.weight.cpu() - old_model.model.transformer.h[i1].mlp.c_proj.weight.cpu())\n", 507 | "diff_WQ, diff_WK, diff_WV = ((model.model.transformer.h[i1].attn.c_attn.weight.cpu() - \n", 508 | " old_model.model.transformer.h[i1].attn.c_attn.weight.cpu()).T.chunk(3))\n", 509 | "diff_WO = (model.model.transformer.h[i1].attn.c_proj.weight.cpu() - old_model.model.transformer.h[i1].attn.c_proj.weight.cpu())" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": 20, 515 | "id": "8977bab0-6c83-4ac3-a842-d6d7c93e3c96", 516 | "metadata": {}, 517 | "outputs": [], 518 | "source": [ 519 | "diff_param = diff_WV" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": 21, 525 | "id": "2c76f6cb-8acd-4753-b0d9-50b59748cd74", 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [ 529 | "i2 = np.random.randint(diff_param.shape[0]) # index of vector in the parameter" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": 22, 535 | "id": "e8840f5a-3b4e-4b6a-8cc1-06752ec66c0a", 536 | "metadata": {}, 537 | "outputs": [ 538 | { 539 | "name": "stdout", 540 | "output_type": "stream", 541 | "text": [ 542 | "diff -diff\n", 543 | "------------ -------------\n", 544 | "incompetence unforgettable\n", 545 | "bullshit beautifully\n", 546 | "ineffective wonderfully\n", 547 | "worthless vividly\n", 548 | "bogus memorable\n", 549 | "incompetent thrilling\n", 550 | "useless delight\n", 551 | "retarded enjoyed\n", 552 | "retard timeless\n", 553 | "shitty superb\n", 554 | "worse wonderful\n", 555 | "idiots poignant\n", 556 | "#Fuck immensely\n", 557 | "Worse exhilar\n", 558 | "blame inspiring\n", 559 | "nonexistent delightful\n", 560 | "unus #love\n", 561 | "ineligible lively\n", 562 | "quotas vivid\n", 563 | "inco fascinating\n" 564 | ] 565 | } 566 | ], 567 | "source": [ 568 | "print(tabulate(zip(*[top_tokens(diff_param[i2].detach() @ emb1, k=20, only_ascii=True, tokenizer=tokenizer),\n", 569 | " top_tokens(-diff_param[i2].detach() @ emb1, k=20, only_ascii=True, tokenizer=tokenizer)]), \n", 570 | " headers=[\"diff\", \"-diff\"]))" 571 | ] 572 | }, 573 | { 574 | "cell_type": "markdown", 575 | "id": "cd0c489b-62dd-442c-b7ed-4a2cbdf66fe4", 576 | "metadata": {}, 577 | "source": [ 578 | "## Model Stitching" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": 23, 584 | "id": "90c6d7c7-d95f-4acd-b4fe-08e4da1bac07", 585 | "metadata": {}, 586 | "outputs": [], 587 | "source": [ 588 | "def subtract_modules(mod1, mod2, subtract_ln=False, only_weight=False):\n", 589 | " mod_new = deepcopy(mod1)\n", 590 | " with torch.no_grad():\n", 591 | " for n, p in mod_new.named_parameters():\n", 592 | " if only_weight and not 
n.endswith('.weight'):\n", 593 | " continue\n", 594 | " submodule_name = n.rsplit('.', 1)[0] if '.' in n else ''\n", 595 | " is_ln = isinstance(mod_new.get_submodule(submodule_name), nn.LayerNorm)\n", 596 | " if (not is_ln) or subtract_ln:\n", 597 | " p.set_(p.data - mod2.get_parameter(n).data)\n", 598 | " return mod_new" 599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "execution_count": 24, 604 | "id": "c78ba6e8-bdd9-4fbf-97a9-d8f4b2b3e94f", 605 | "metadata": {}, 606 | "outputs": [], 607 | "source": [ 608 | "class StitchedTransformers(nn.Module):\n", 609 | " def __init__(self, old_model, model1, model2, kernel, num_keep_layers, num_transplanted_layers,\n", 610 | " subtract=True, **subtract_args):\n", 611 | " super().__init__()\n", 612 | " self.model2 = deepcopy(model2) \n", 613 | " self.model2.transformer.h = nn.ModuleList(self.model2.transformer.h[:num_keep_layers])\n", 614 | " self.register_buffer(\"stitching_kernel\", kernel) \n", 615 | " self.model1 = deepcopy(model1)\n", 616 | " offset = len(model1.model.transformer.h) - num_transplanted_layers\n", 617 | " self.model1.model.transformer.h = nn.ModuleList([\n", 618 | " subtract_modules(model1.model.transformer.h[offset + i], \n", 619 | " old_model.model.transformer.h[offset + i], \n", 620 | " **subtract_args) if subtract else model1.model.transformer.h[offset + i]\n", 621 | " for i in range(num_transplanted_layers)])\n", 622 | " self.model1.model.classifier = (\n", 623 | " subtract_modules(model1.model.classifier, old_model.model.classifier, **subtract_args) \n", 624 | " if subtract else model1.model.classifier\n", 625 | " )\n", 626 | " \n", 627 | " def forward(self, input_ids, labels):\n", 628 | " x = self.model2(input_ids, output_hidden_states=True).hidden_states[-1]\n", 629 | " x = x @ self.stitching_kernel\n", 630 | " res = self.model1(input_ids=None, inputs_embeds=x, labels=labels)\n", 631 | " res = {'loss': res['loss'], 'logits': res['logits']}\n", 632 | " return res" 633 | ] 634 | }, 635 | { 636 | "cell_type": "code", 637 | "execution_count": 25, 638 | "id": "f57515ac-730f-4a8a-b4c9-2d65dcfb33d1", 639 | "metadata": {}, 640 | "outputs": [], 641 | "source": [ 642 | "extended = False\n", 643 | "kernel = emb_extended2 @ (emb_extended1).pinverse() if extended else emb2 @ (emb1).pinverse()\n", 644 | "# + .1 * torch.eye(1024, 768)" 645 | ] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "execution_count": 26, 650 | "id": "b724a15a-bf85-4abb-baa1-33c380c13d91", 651 | "metadata": {}, 652 | "outputs": [], 653 | "source": [ 654 | "subtract = False" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": 27, 660 | "id": "f5228ca9-3026-4338-973a-f151ea6b5fc8", 661 | "metadata": {}, 662 | "outputs": [], 663 | "source": [ 664 | "num_transplanted_layers = 3\n", 665 | "num_keep_layers = 14" 666 | ] 667 | }, 668 | { 669 | "cell_type": "markdown", 670 | "id": "a5415437-7726-43a3-9823-8af3e9a6ab81", 671 | "metadata": {}, 672 | "source": [ 673 | "### Evaluate" 674 | ] 675 | }, 676 | { 677 | "cell_type": "code", 678 | "execution_count": 28, 679 | "id": "f0c80b60-6c1a-442b-9e64-0aefda4318ab", 680 | "metadata": {}, 681 | "outputs": [], 682 | "source": [ 683 | "stitched_model = StitchedTransformers(old_model.cuda(), model1, model2, kernel, \n", 684 | " num_keep_layers, num_transplanted_layers, subtract=subtract).cpu()" 685 | ] 686 | }, 687 | { 688 | "cell_type": "code", 689 | "execution_count": 29, 690 | "id": "271fa53f-ff22-43e6-a471-945629f46082", 691 | "metadata": {}, 692 | "outputs": [ 693 | { 694 | "name": "stderr", 
695 | "output_type": "stream", 696 | "text": [ 697 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 698 | "***** Running Evaluation *****\n", 699 | " Num examples = 500\n", 700 | " Batch size = 1\n" 701 | ] 702 | }, 703 | { 704 | "data": { 705 | "text/html": [ 706 | "\n", 707 | "

\n", 708 | " \n", 709 | " \n", 710 | " [500/500 00:13]\n", 711 | "
\n", 712 | " " 713 | ], 714 | "text/plain": [ 715 | "" 716 | ] 717 | }, 718 | "metadata": {}, 719 | "output_type": "display_data" 720 | }, 721 | { 722 | "data": { 723 | "text/plain": [ 724 | "{'eval_loss': 10.434187889099121,\n", 725 | " 'eval_accuracy': 0.462,\n", 726 | " 'eval_runtime': 13.4261,\n", 727 | " 'eval_samples_per_second': 37.241,\n", 728 | " 'eval_steps_per_second': 37.241}" 729 | ] 730 | }, 731 | "execution_count": 29, 732 | "metadata": {}, 733 | "output_type": "execute_result" 734 | } 735 | ], 736 | "source": [ 737 | "trainer_stitched = Trainer(stitched_model, args=train_args, train_dataset=imdb_train, eval_dataset=imdb_val, \n", 738 | " compute_metrics=compute_metrics)\n", 739 | "trainer_stitched.evaluate()" 740 | ] 741 | }, 742 | { 743 | "cell_type": "markdown", 744 | "id": "4b9fe505-2e9c-4eea-9799-e50ff6a4b5a1", 745 | "metadata": {}, 746 | "source": [ 747 | "#### Plot All" 748 | ] 749 | }, 750 | { 751 | "cell_type": "code", 752 | "execution_count": 39, 753 | "id": "febb237a-7c52-4c23-ab44-15909508bf04", 754 | "metadata": { 755 | "tags": [] 756 | }, 757 | "outputs": [ 758 | { 759 | "name": "stderr", 760 | "output_type": "stream", 761 | "text": [ 762 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 763 | "***** Running Evaluation *****\n", 764 | " Num examples = 500\n", 765 | " Batch size = 1\n" 766 | ] 767 | }, 768 | { 769 | "data": { 770 | "text/html": [ 771 | "\n", 772 | "
\n", 773 | " \n", 774 | " \n", 775 | " [500/500 00:02]\n", 776 | "
\n", 777 | " " 778 | ], 779 | "text/plain": [ 780 | "" 781 | ] 782 | }, 783 | "metadata": {}, 784 | "output_type": "display_data" 785 | }, 786 | { 787 | "name": "stderr", 788 | "output_type": "stream", 789 | "text": [ 790 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 791 | "***** Running Evaluation *****\n", 792 | " Num examples = 500\n", 793 | " Batch size = 1\n" 794 | ] 795 | }, 796 | { 797 | "data": { 798 | "text/html": [ 799 | "\n", 800 | "
\n", 801 | " \n", 802 | " \n", 803 | " [500/500 00:03]\n", 804 | "
\n", 805 | " " 806 | ], 807 | "text/plain": [ 808 | "" 809 | ] 810 | }, 811 | "metadata": {}, 812 | "output_type": "display_data" 813 | }, 814 | { 815 | "name": "stderr", 816 | "output_type": "stream", 817 | "text": [ 818 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 819 | "***** Running Evaluation *****\n", 820 | " Num examples = 500\n", 821 | " Batch size = 1\n" 822 | ] 823 | }, 824 | { 825 | "data": { 826 | "text/html": [ 827 | "\n", 828 | "
\n", 829 | " \n", 830 | " \n", 831 | " [500/500 00:03]\n", 832 | "
\n", 833 | " " 834 | ], 835 | "text/plain": [ 836 | "" 837 | ] 838 | }, 839 | "metadata": {}, 840 | "output_type": "display_data" 841 | }, 842 | { 843 | "name": "stderr", 844 | "output_type": "stream", 845 | "text": [ 846 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 847 | "***** Running Evaluation *****\n", 848 | " Num examples = 500\n", 849 | " Batch size = 1\n" 850 | ] 851 | }, 852 | { 853 | "data": { 854 | "text/html": [ 855 | "\n", 856 | "
\n", 857 | " \n", 858 | " \n", 859 | " [500/500 00:04]\n", 860 | "
\n", 861 | " " 862 | ], 863 | "text/plain": [ 864 | "" 865 | ] 866 | }, 867 | "metadata": {}, 868 | "output_type": "display_data" 869 | }, 870 | { 871 | "name": "stderr", 872 | "output_type": "stream", 873 | "text": [ 874 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 875 | "***** Running Evaluation *****\n", 876 | " Num examples = 500\n", 877 | " Batch size = 1\n" 878 | ] 879 | }, 880 | { 881 | "data": { 882 | "text/html": [ 883 | "\n", 884 | "
\n", 885 | " \n", 886 | " \n", 887 | " [500/500 00:05]\n", 888 | "
\n", 889 | " " 890 | ], 891 | "text/plain": [ 892 | "" 893 | ] 894 | }, 895 | "metadata": {}, 896 | "output_type": "display_data" 897 | }, 898 | { 899 | "name": "stderr", 900 | "output_type": "stream", 901 | "text": [ 902 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 903 | "***** Running Evaluation *****\n", 904 | " Num examples = 500\n", 905 | " Batch size = 1\n" 906 | ] 907 | }, 908 | { 909 | "data": { 910 | "text/html": [ 911 | "\n", 912 | "
\n", 913 | " \n", 914 | " \n", 915 | " [500/500 00:06]\n", 916 | "
\n", 917 | " " 918 | ], 919 | "text/plain": [ 920 | "" 921 | ] 922 | }, 923 | "metadata": {}, 924 | "output_type": "display_data" 925 | }, 926 | { 927 | "name": "stderr", 928 | "output_type": "stream", 929 | "text": [ 930 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 931 | "***** Running Evaluation *****\n", 932 | " Num examples = 500\n", 933 | " Batch size = 1\n" 934 | ] 935 | }, 936 | { 937 | "data": { 938 | "text/html": [ 939 | "\n", 940 | "
\n", 941 | " \n", 942 | " \n", 943 | " [500/500 00:07]\n", 944 | "
\n", 945 | " " 946 | ], 947 | "text/plain": [ 948 | "" 949 | ] 950 | }, 951 | "metadata": {}, 952 | "output_type": "display_data" 953 | }, 954 | { 955 | "name": "stderr", 956 | "output_type": "stream", 957 | "text": [ 958 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 959 | "***** Running Evaluation *****\n", 960 | " Num examples = 500\n", 961 | " Batch size = 1\n" 962 | ] 963 | }, 964 | { 965 | "data": { 966 | "text/html": [ 967 | "\n", 968 | "
\n", 969 | " \n", 970 | " \n", 971 | " [500/500 00:08]\n", 972 | "
\n", 973 | " " 974 | ], 975 | "text/plain": [ 976 | "" 977 | ] 978 | }, 979 | "metadata": {}, 980 | "output_type": "display_data" 981 | }, 982 | { 983 | "name": "stderr", 984 | "output_type": "stream", 985 | "text": [ 986 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 987 | "***** Running Evaluation *****\n", 988 | " Num examples = 500\n", 989 | " Batch size = 1\n" 990 | ] 991 | }, 992 | { 993 | "data": { 994 | "text/html": [ 995 | "\n", 996 | "
\n", 997 | " \n", 998 | " \n", 999 | " [500/500 00:08]\n", 1000 | "
\n", 1001 | " " 1002 | ], 1003 | "text/plain": [ 1004 | "" 1005 | ] 1006 | }, 1007 | "metadata": {}, 1008 | "output_type": "display_data" 1009 | }, 1010 | { 1011 | "name": "stderr", 1012 | "output_type": "stream", 1013 | "text": [ 1014 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1015 | "***** Running Evaluation *****\n", 1016 | " Num examples = 500\n", 1017 | " Batch size = 1\n" 1018 | ] 1019 | }, 1020 | { 1021 | "data": { 1022 | "text/html": [ 1023 | "\n", 1024 | "
\n", 1025 | " \n", 1026 | " \n", 1027 | " [500/500 00:09]\n", 1028 | "
\n", 1029 | " " 1030 | ], 1031 | "text/plain": [ 1032 | "" 1033 | ] 1034 | }, 1035 | "metadata": {}, 1036 | "output_type": "display_data" 1037 | }, 1038 | { 1039 | "name": "stderr", 1040 | "output_type": "stream", 1041 | "text": [ 1042 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1043 | "***** Running Evaluation *****\n", 1044 | " Num examples = 500\n", 1045 | " Batch size = 1\n" 1046 | ] 1047 | }, 1048 | { 1049 | "data": { 1050 | "text/html": [ 1051 | "\n", 1052 | "
\n", 1053 | " \n", 1054 | " \n", 1055 | " [500/500 00:10]\n", 1056 | "
\n", 1057 | " " 1058 | ], 1059 | "text/plain": [ 1060 | "" 1061 | ] 1062 | }, 1063 | "metadata": {}, 1064 | "output_type": "display_data" 1065 | }, 1066 | { 1067 | "name": "stderr", 1068 | "output_type": "stream", 1069 | "text": [ 1070 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1071 | "***** Running Evaluation *****\n", 1072 | " Num examples = 500\n", 1073 | " Batch size = 1\n" 1074 | ] 1075 | }, 1076 | { 1077 | "data": { 1078 | "text/html": [ 1079 | "\n", 1080 | "
\n", 1081 | " \n", 1082 | " \n", 1083 | " [500/500 00:11]\n", 1084 | "
\n", 1085 | " " 1086 | ], 1087 | "text/plain": [ 1088 | "" 1089 | ] 1090 | }, 1091 | "metadata": {}, 1092 | "output_type": "display_data" 1093 | }, 1094 | { 1095 | "name": "stderr", 1096 | "output_type": "stream", 1097 | "text": [ 1098 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1099 | "***** Running Evaluation *****\n", 1100 | " Num examples = 500\n", 1101 | " Batch size = 1\n" 1102 | ] 1103 | }, 1104 | { 1105 | "data": { 1106 | "text/html": [ 1107 | "\n", 1108 | "
\n", 1109 | " \n", 1110 | " \n", 1111 | " [500/500 00:12]\n", 1112 | "
\n", 1113 | " " 1114 | ], 1115 | "text/plain": [ 1116 | "" 1117 | ] 1118 | }, 1119 | "metadata": {}, 1120 | "output_type": "display_data" 1121 | }, 1122 | { 1123 | "name": "stderr", 1124 | "output_type": "stream", 1125 | "text": [ 1126 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1127 | "***** Running Evaluation *****\n", 1128 | " Num examples = 500\n", 1129 | " Batch size = 1\n" 1130 | ] 1131 | }, 1132 | { 1133 | "data": { 1134 | "text/html": [ 1135 | "\n", 1136 | "
\n", 1137 | " \n", 1138 | " \n", 1139 | " [500/500 00:13]\n", 1140 | "
\n", 1141 | " " 1142 | ], 1143 | "text/plain": [ 1144 | "" 1145 | ] 1146 | }, 1147 | "metadata": {}, 1148 | "output_type": "display_data" 1149 | }, 1150 | { 1151 | "name": "stderr", 1152 | "output_type": "stream", 1153 | "text": [ 1154 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1155 | "***** Running Evaluation *****\n", 1156 | " Num examples = 500\n", 1157 | " Batch size = 1\n" 1158 | ] 1159 | }, 1160 | { 1161 | "data": { 1162 | "text/html": [ 1163 | "\n", 1164 | "
\n", 1165 | " \n", 1166 | " \n", 1167 | " [500/500 00:13]\n", 1168 | "
\n", 1169 | " " 1170 | ], 1171 | "text/plain": [ 1172 | "" 1173 | ] 1174 | }, 1175 | "metadata": {}, 1176 | "output_type": "display_data" 1177 | }, 1178 | { 1179 | "name": "stderr", 1180 | "output_type": "stream", 1181 | "text": [ 1182 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1183 | "***** Running Evaluation *****\n", 1184 | " Num examples = 500\n", 1185 | " Batch size = 1\n" 1186 | ] 1187 | }, 1188 | { 1189 | "data": { 1190 | "text/html": [ 1191 | "\n", 1192 | "
\n", 1193 | " \n", 1194 | " \n", 1195 | " [500/500 00:14]\n", 1196 | "
\n", 1197 | " " 1198 | ], 1199 | "text/plain": [ 1200 | "" 1201 | ] 1202 | }, 1203 | "metadata": {}, 1204 | "output_type": "display_data" 1205 | }, 1206 | { 1207 | "name": "stderr", 1208 | "output_type": "stream", 1209 | "text": [ 1210 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1211 | "***** Running Evaluation *****\n", 1212 | " Num examples = 500\n", 1213 | " Batch size = 1\n" 1214 | ] 1215 | }, 1216 | { 1217 | "data": { 1218 | "text/html": [ 1219 | "\n", 1220 | "
\n", 1221 | " \n", 1222 | " \n", 1223 | " [500/500 00:15]\n", 1224 | "
\n", 1225 | " " 1226 | ], 1227 | "text/plain": [ 1228 | "" 1229 | ] 1230 | }, 1231 | "metadata": {}, 1232 | "output_type": "display_data" 1233 | }, 1234 | { 1235 | "name": "stderr", 1236 | "output_type": "stream", 1237 | "text": [ 1238 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1239 | "***** Running Evaluation *****\n", 1240 | " Num examples = 500\n", 1241 | " Batch size = 1\n" 1242 | ] 1243 | }, 1244 | { 1245 | "data": { 1246 | "text/html": [ 1247 | "\n", 1248 | "
\n", 1249 | " \n", 1250 | " \n", 1251 | " [500/500 00:16]\n", 1252 | "
\n", 1253 | " " 1254 | ], 1255 | "text/plain": [ 1256 | "" 1257 | ] 1258 | }, 1259 | "metadata": {}, 1260 | "output_type": "display_data" 1261 | }, 1262 | { 1263 | "name": "stderr", 1264 | "output_type": "stream", 1265 | "text": [ 1266 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1267 | "***** Running Evaluation *****\n", 1268 | " Num examples = 500\n", 1269 | " Batch size = 1\n" 1270 | ] 1271 | }, 1272 | { 1273 | "data": { 1274 | "text/html": [ 1275 | "\n", 1276 | "
\n", 1277 | " \n", 1278 | " \n", 1279 | " [500/500 00:17]\n", 1280 | "
\n", 1281 | " " 1282 | ], 1283 | "text/plain": [ 1284 | "" 1285 | ] 1286 | }, 1287 | "metadata": {}, 1288 | "output_type": "display_data" 1289 | }, 1290 | { 1291 | "name": "stderr", 1292 | "output_type": "stream", 1293 | "text": [ 1294 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1295 | "***** Running Evaluation *****\n", 1296 | " Num examples = 500\n", 1297 | " Batch size = 1\n" 1298 | ] 1299 | }, 1300 | { 1301 | "data": { 1302 | "text/html": [ 1303 | "\n", 1304 | "
\n", 1305 | " \n", 1306 | " \n", 1307 | " [500/500 00:17]\n", 1308 | "
\n", 1309 | " " 1310 | ], 1311 | "text/plain": [ 1312 | "" 1313 | ] 1314 | }, 1315 | "metadata": {}, 1316 | "output_type": "display_data" 1317 | }, 1318 | { 1319 | "name": "stderr", 1320 | "output_type": "stream", 1321 | "text": [ 1322 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1323 | "***** Running Evaluation *****\n", 1324 | " Num examples = 500\n", 1325 | " Batch size = 1\n" 1326 | ] 1327 | }, 1328 | { 1329 | "data": { 1330 | "text/html": [ 1331 | "\n", 1332 | "
\n", 1333 | " \n", 1334 | " \n", 1335 | " [500/500 00:18]\n", 1336 | "
\n", 1337 | " " 1338 | ], 1339 | "text/plain": [ 1340 | "" 1341 | ] 1342 | }, 1343 | "metadata": {}, 1344 | "output_type": "display_data" 1345 | }, 1346 | { 1347 | "name": "stderr", 1348 | "output_type": "stream", 1349 | "text": [ 1350 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1351 | "***** Running Evaluation *****\n", 1352 | " Num examples = 500\n", 1353 | " Batch size = 1\n" 1354 | ] 1355 | }, 1356 | { 1357 | "data": { 1358 | "text/html": [ 1359 | "\n", 1360 | "
\n", 1361 | " \n", 1362 | " \n", 1363 | " [500/500 00:19]\n", 1364 | "
\n", 1365 | " " 1366 | ], 1367 | "text/plain": [ 1368 | "" 1369 | ] 1370 | }, 1371 | "metadata": {}, 1372 | "output_type": "display_data" 1373 | }, 1374 | { 1375 | "name": "stderr", 1376 | "output_type": "stream", 1377 | "text": [ 1378 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1379 | "***** Running Evaluation *****\n", 1380 | " Num examples = 500\n", 1381 | " Batch size = 1\n" 1382 | ] 1383 | }, 1384 | { 1385 | "data": { 1386 | "text/html": [ 1387 | "\n", 1388 | "
\n", 1389 | " \n", 1390 | " \n", 1391 | " [500/500 00:20]\n", 1392 | "
\n", 1393 | " " 1394 | ], 1395 | "text/plain": [ 1396 | "" 1397 | ] 1398 | }, 1399 | "metadata": {}, 1400 | "output_type": "display_data" 1401 | }, 1402 | { 1403 | "name": "stderr", 1404 | "output_type": "stream", 1405 | "text": [ 1406 | "The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `StitchedTransformers.forward`, you can safely ignore this message.\n", 1407 | "***** Running Evaluation *****\n", 1408 | " Num examples = 500\n", 1409 | " Batch size = 1\n" 1410 | ] 1411 | }, 1412 | { 1413 | "data": { 1414 | "text/html": [ 1415 | "\n", 1416 | "
\n", 1417 | " \n", 1418 | " \n", 1419 | " [500/500 00:21]\n", 1420 | "
\n", 1421 | " " 1422 | ], 1423 | "text/plain": [ 1424 | "" 1425 | ] 1426 | }, 1427 | "metadata": {}, 1428 | "output_type": "display_data" 1429 | } 1430 | ], 1431 | "source": [ 1432 | "accs = []\n", 1433 | "for num_keep_layers in range(model2.config.n_layer):\n", 1434 | " stitched_model = StitchedTransformers(old_model.cuda(), model1, model2, kernel, \n", 1435 | " num_keep_layers, num_transplanted_layers, subtract=subtract).cpu()\n", 1436 | " trainer_stitched = Trainer(stitched_model, args=train_args, train_dataset=imdb_train, eval_dataset=imdb_val, \n", 1437 | " compute_metrics=compute_metrics)\n", 1438 | "\n", 1439 | " accs.append(trainer_stitched.evaluate()['eval_accuracy'])" 1440 | ] 1441 | } 1442 | ], 1443 | "metadata": { 1444 | "kernelspec": { 1445 | "display_name": "Python 3", 1446 | "language": "python", 1447 | "name": "python3" 1448 | }, 1449 | "language_info": { 1450 | "codemirror_mode": { 1451 | "name": "ipython", 1452 | "version": 3 1453 | }, 1454 | "file_extension": ".py", 1455 | "mimetype": "text/x-python", 1456 | "name": "python", 1457 | "nbconvert_exporter": "python", 1458 | "pygments_lexer": "ipython3", 1459 | "version": "3.8.13" 1460 | } 1461 | }, 1462 | "nbformat": 4, 1463 | "nbformat_minor": 5 1464 | } 1465 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from datasets import load_dataset 7 | from copy import deepcopy 8 | 9 | 10 | def keep_k(x, k=100, absolute=True, dim=-1): 11 | shape = x.shape 12 | x_ = x 13 | if absolute: 14 | x_ = abs(x) 15 | values, indices = torch.topk(x_, k=k, dim=dim) 16 | res = torch.zeros_like(x) 17 | res.scatter_(dim, indices, x.gather(dim, indices)) 18 | return res 19 | 20 | 21 | def load_imdb(): 22 | return load_dataset('imdb')['test']['text'] 23 | 24 | 25 | class TokenizerFromVocab: 26 | def __init__(self, vocab): 27 | self.vocab = vocab 28 | 29 | def convert_ids_to_tokens(self, arr): 30 | return [*map(vocab.__getitem__, arr.cpu().tolist())] 31 | 32 | def __len__(self): 33 | return len(self.vocab) 34 | 35 | 36 | def get_multiberts_tokenizer(): 37 | vocab = dict(enumerate(open('multiberts/vocab.txt', 'r').read().split('\n')[:-1])) 38 | return TokenizerFromVocab(vocab) 39 | 40 | 41 | def convert_to_tokens(indices, tokenizer, strip=True, width=15): 42 | res = tokenizer.convert_ids_to_tokens(indices) 43 | if strip: 44 | res = list(map(lambda x: x[1:] if x[0] == 'Ġ' else "#" + x, res)) 45 | if width: 46 | res = list(map(lambda x: x[:width] + (x[width:] and '...'), res)) 47 | return res 48 | 49 | 50 | def top_tokens(v, tokenizer, k=100, only_english=False, only_ascii=False, 51 | exclude_brackets=False): 52 | v = deepcopy(v) 53 | ignored_indices = [] 54 | if only_ascii: 55 | ignored_indices = [key for val, key in tokenizer.vocab.items() if not val.strip('Ġ').isascii()] 56 | if only_english: 57 | ignored_indices =[key for val, key in tokenizer.vocab.items() 58 | if not (val.strip('Ġ').isascii() and val.strip('Ġ[]').isalnum())] 59 | if exclude_brackets: 60 | ignored_indices = set(ignored_indices).intersection( 61 | {key for val, key in tokenizer.vocab.items() if not (val.isascii() and val.isalnum())}) 62 | ignored_indices = list(ignored_indices) 63 | v[ignored_indices] = -np.inf 64 | values, indices = torch.topk(v, k=k) 65 | res = convert_to_tokens(indices, tokenizer) 66 | return res 67 | 68 | 69 | def 
69 | def top_matrix_tokens(mat, tokenizer, k=100, rel_thresh=None, thresh=None,
70 |                       sample_entries=10000, alphabetical=False, only_english=False,
71 |                       exclude_brackets=False):
72 |     mat = deepcopy(mat)
73 |     ignored_indices = []
74 |     if only_english:
75 |         ignored_indices = [key for val, key in tokenizer.vocab.items()
76 |                            if not (val.isascii() and val.strip('[]').isalnum())]
77 |     if exclude_brackets:
78 |         ignored_indices = set(ignored_indices).intersection(
79 |             {key for val, key in tokenizer.vocab.items() if not (val.isascii() and val.isalnum())})
80 |     ignored_indices = list(ignored_indices)
81 |     mat[ignored_indices, :] = -np.inf
82 |     mat[:, ignored_indices] = -np.inf
83 |     cond = torch.ones_like(mat).bool()
84 |     if rel_thresh:
85 |         cond &= (mat > torch.max(mat) * rel_thresh)
86 |     if thresh:
87 |         cond &= (mat > thresh)
88 |     entries = torch.nonzero(cond)
89 |     if sample_entries:
90 |         entries = entries[np.random.randint(len(entries), size=sample_entries)]
91 |     res_indices = sorted(entries, key=lambda x: x[0] if alphabetical else -mat[x[0], x[1]])[:k]
92 |     res = [convert_to_tokens(idx, tokenizer) for idx in res_indices]
93 |     return res
94 | 
--------------------------------------------------------------------------------
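A usage sketch (not part of the repo) showing how these helpers fit the notebooks' interpretation workflow; the layer index and row choice are arbitrary illustrations:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import top_tokens

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

# Project one MLP value vector onto the (tied) embedding matrix and read off
# the tokens it most strongly promotes, as the notebooks do for diff vectors.
emb = model.get_input_embeddings().weight.T.detach()         # (n_embd, vocab_size)
v = model.transformer.h[11].mlp.c_proj.weight.detach()[0]    # one value vector, (n_embd,)
print(top_tokens(v @ emb, tokenizer=tokenizer, k=10, only_ascii=True))
```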