├── .gitignore ├── Beam_Search.ipynb ├── Colab_Train.ipynb ├── Inference.ipynb ├── Local_Train.ipynb ├── README.md ├── attention_visual.ipynb ├── conda.txt ├── config.py ├── dataset.py ├── model.py ├── requirements.txt ├── train.py ├── train_wb.py └── translate.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pt 3 | *.json 4 | runs/ 5 | weights/ 6 | 7 | wandb/ -------------------------------------------------------------------------------- /Beam_Search.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pathlib import Path\n", 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "from config import get_config, get_weights_file_path\n", 13 | "from train import get_model, get_ds, run_validation, causal_mask" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "Using device: cuda\n", 26 | "Max length of source sentence: 309\n", 27 | "Max length of target sentence: 274\n" 28 | ] 29 | }, 30 | { 31 | "data": { 32 | "text/plain": [ 33 | "" 34 | ] 35 | }, 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "output_type": "execute_result" 39 | } 40 | ], 41 | "source": [ 42 | "# Define the device\n", 43 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 44 | "print(\"Using device:\", device)\n", 45 | "config = get_config()\n", 46 | "train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n", 47 | "model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)\n", 48 | "\n", 49 | "# Load the pretrained weights\n", 50 | "model_filename = get_weights_file_path(config, f\"19\")\n", 51 | "state = torch.load(model_filename)\n", 52 | "model.load_state_dict(state['model_state_dict'])" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "--------------------------------------------------------------------------------\n", 65 | " SOURCE: Hence it is that for so long a time, and during so much fighting in the past twenty years, whenever there has been an army wholly Italian, it has always given a poor account of itself; the first witness to this is Il Taro, afterwards Allesandria, Capua, Genoa, Vaila, Bologna, Mestri.\n", 66 | " TARGET: Di qui nasce che, in tanto tempo, in tante guerre fatte ne' passati venti anni, quando elli è stato uno esercito tutto italiano, sempre ha fatto mala pruova. Di che è testimone prima el Taro, di poi Alessandria, Capua, Genova, Vailà, Bologna, Mestri.\n", 67 | " PREDICTED GREEDY: Di qui nasce che , in tanto , in tanto tempo , in tante guerre fatte ne ' passati\n", 68 | " PREDICTED BEAM: Di qui nasce che , in tanto tempo , in tante guerre fatte ne ' passati venti anni ,\n", 69 | "--------------------------------------------------------------------------------\n", 70 | " SOURCE: She went out.\n", 71 | " TARGET: Aprì lo sportello e venne fuori.\n", 72 | " PREDICTED GREEDY: Aprì lo sportello e venne fuori .\n", 73 | " PREDICTED BEAM: Aprì lo sportello e venne fuori . — Ecco , poi uscì e andò via . 
— Ecco ,\n", 74 | "--------------------------------------------------------------------------------\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "def beam_search_decode(model, beam_size, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):\n", 80 | " sos_idx = tokenizer_tgt.token_to_id('[SOS]')\n", 81 | " eos_idx = tokenizer_tgt.token_to_id('[EOS]')\n", 82 | "\n", 83 | " # Precompute the encoder output and reuse it for every step\n", 84 | " encoder_output = model.encode(source, source_mask)\n", 85 | " # Initialize the decoder input with the sos token\n", 86 | " decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)\n", 87 | "\n", 88 | " # Create a candidate list\n", 89 | " candidates = [(decoder_initial_input, 1)]\n", 90 | "\n", 91 | " while True:\n", 92 | "\n", 93 | " # If a candidate has reached the maximum length, it means we have run the decoding for at least max_len iterations, so stop the search\n", 94 | " if any([cand.size(1) == max_len for cand, _ in candidates]):\n", 95 | " break\n", 96 | "\n", 97 | " # Create a new list of candidates\n", 98 | " new_candidates = []\n", 99 | "\n", 100 | " for candidate, score in candidates:\n", 101 | "\n", 102 | " # Do not expand candidates that have reached the eos token\n", 103 | " if candidate[0][-1].item() == eos_idx:\n", 104 | " continue\n", 105 | "\n", 106 | " # Build the candidate's mask\n", 107 | " candidate_mask = causal_mask(candidate.size(1)).type_as(source_mask).to(device)\n", 108 | " # calculate output\n", 109 | " out = model.decode(encoder_output, source_mask, candidate, candidate_mask)\n", 110 | " # get next token probabilities\n", 111 | " prob = model.project(out[:, -1])\n", 112 | " # get the top k candidates\n", 113 | " topk_prob, topk_idx = torch.topk(prob, beam_size, dim=1)\n", 114 | " for i in range(beam_size):\n", 115 | " # for each of the top k candidates, get the token and its probability\n", 116 | " token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)\n", 117 | " token_prob = topk_prob[0][i].item()\n", 118 | " # create a new candidate by appending the token to the current candidate\n", 119 | " new_candidate = torch.cat([candidate, token], dim=1)\n", 120 | " # We sum the log probabilities because the probabilities are in log space\n", 121 | " new_candidates.append((new_candidate, score + token_prob))\n", 122 | "\n", 123 | " # Sort the new candidates by their score\n", 124 | " candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)\n", 125 | " # Keep only the top k candidates\n", 126 | " candidates = candidates[:beam_size]\n", 127 | "\n", 128 | " # If all the candidates have reached the eos token, stop\n", 129 | " if all([cand[0][-1].item() == eos_idx for cand, _ in candidates]):\n", 130 | " break\n", 131 | "\n", 132 | " # Return the best candidate\n", 133 | " return candidates[0][0].squeeze()\n", 134 | "\n", 135 | "def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):\n", 136 | " sos_idx = tokenizer_tgt.token_to_id('[SOS]')\n", 137 | " eos_idx = tokenizer_tgt.token_to_id('[EOS]')\n", 138 | "\n", 139 | " # Precompute the encoder output and reuse it for every step\n", 140 | " encoder_output = model.encode(source, source_mask)\n", 141 | " # Initialize the decoder input with the sos token\n", 142 | " decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)\n", 143 | " while True:\n", 144 | " if decoder_input.size(1) == max_len:\n", 145 | " break\n", 146 | "\n", 147 | " # build mask for target\n", 148 | " 
decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)\n", 149 | "\n", 150 | " # calculate output\n", 151 | " out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)\n", 152 | "\n", 153 | " # get next token\n", 154 | " prob = model.project(out[:, -1])\n", 155 | " _, next_word = torch.max(prob, dim=1)\n", 156 | " decoder_input = torch.cat(\n", 157 | " [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1\n", 158 | " )\n", 159 | "\n", 160 | " if next_word == eos_idx:\n", 161 | " break\n", 162 | "\n", 163 | " return decoder_input.squeeze(0)\n", 164 | "\n", 165 | "def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, num_examples=2):\n", 166 | " model.eval()\n", 167 | " count = 0\n", 168 | "\n", 169 | " console_width = 80\n", 170 | "\n", 171 | " with torch.no_grad():\n", 172 | " for batch in validation_ds:\n", 173 | " count += 1\n", 174 | " encoder_input = batch[\"encoder_input\"].to(device) # (b, seq_len)\n", 175 | " encoder_mask = batch[\"encoder_mask\"].to(device) # (b, 1, 1, seq_len)\n", 176 | "\n", 177 | " # check that the batch size is 1\n", 178 | " assert encoder_input.size(\n", 179 | " 0) == 1, \"Batch size must be 1 for validation\"\n", 180 | "\n", 181 | " \n", 182 | " model_out_greedy = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)\n", 183 | " model_out_beam = beam_search_decode(model, 3, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)\n", 184 | "\n", 185 | " source_text = batch[\"src_text\"][0]\n", 186 | " target_text = batch[\"tgt_text\"][0]\n", 187 | " model_out_text_beam = tokenizer_tgt.decode(model_out_beam.detach().cpu().numpy())\n", 188 | " model_out_text_greedy = tokenizer_tgt.decode(model_out_greedy.detach().cpu().numpy())\n", 189 | " \n", 190 | " # Print the source, target and model output\n", 191 | " print_msg('-'*console_width)\n", 192 | " print_msg(f\"{f'SOURCE: ':>20}{source_text}\")\n", 193 | " print_msg(f\"{f'TARGET: ':>20}{target_text}\")\n", 194 | " print_msg(f\"{f'PREDICTED GREEDY: ':>20}{model_out_text_greedy}\")\n", 195 | " print_msg(f\"{f'PREDICTED BEAM: ':>20}{model_out_text_beam}\")\n", 196 | "\n", 197 | " if count == num_examples:\n", 198 | " print_msg('-'*console_width)\n", 199 | " break\n", 200 | "\n", 201 | "run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, 20, device, print_msg=print, num_examples=2)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [] 210 | } 211 | ], 212 | "metadata": { 213 | "kernelspec": { 214 | "display_name": "transformer", 215 | "language": "python", 216 | "name": "python3" 217 | }, 218 | "language_info": { 219 | "codemirror_mode": { 220 | "name": "ipython", 221 | "version": 3 222 | }, 223 | "file_extension": ".py", 224 | "mimetype": "text/x-python", 225 | "name": "python", 226 | "nbconvert_exporter": "python", 227 | "pygments_lexer": "ipython3", 228 | "version": "3.11.3" 229 | }, 230 | "orig_nbformat": 4 231 | }, 232 | "nbformat": 4, 233 | "nbformat_minor": 2 234 | } 235 | -------------------------------------------------------------------------------- /Inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pathlib import 
Path\n", 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "from config import get_config, latest_weights_file_path\n", 13 | "from train import get_model, get_ds, run_validation\n", 14 | "from translate import translate" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# Define the device\n", 24 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 25 | "print(\"Using device:\", device)\n", 26 | "config = get_config()\n", 27 | "train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n", 28 | "model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)\n", 29 | "\n", 30 | "# Load the pretrained weights\n", 31 | "model_filename = latest_weights_file_path(config)\n", 32 | "state = torch.load(model_filename)\n", 33 | "model.load_state_dict(state['model_state_dict'])" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: print(msg), 0, None, num_examples=10)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "t = translate(\"Why do I need to translate this?\")" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "t = translate(34)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [] 69 | } 70 | ], 71 | "metadata": { 72 | "kernelspec": { 73 | "display_name": "transformer", 74 | "language": "python", 75 | "name": "python3" 76 | }, 77 | "language_info": { 78 | "codemirror_mode": { 79 | "name": "ipython", 80 | "version": 3 81 | }, 82 | "file_extension": ".py", 83 | "mimetype": "text/x-python", 84 | "name": "python", 85 | "nbconvert_exporter": "python", 86 | "pygments_lexer": "ipython3", 87 | "version": "3.9.0" 88 | }, 89 | "orig_nbformat": 4 90 | }, 91 | "nbformat": 4, 92 | "nbformat_minor": 2 93 | } 94 | -------------------------------------------------------------------------------- /Local_Train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/", 9 | "height": 198, 10 | "referenced_widgets": [ 11 | "0ce327d5112b44dbb20e57752afc478a", 12 | "423a3059ad1a4e01bd01095cf1b41e14", 13 | "9cf2d2e2bfe24f2ab185165d79da8bdb", 14 | "996ac47b200c427088ee7644fe886896", 15 | "9b9addf13301466b9ef30b9d4b836a67", 16 | "ec2051bf0e9343d394e8a0ecb4fd5ec8", 17 | "56049bd375cd4512a0deaf69b7dae245", 18 | "140f33387db341398bc39e9c47703df4", 19 | "b3a8424c0b584a37ad2ede748085425c", 20 | "cb7d88a70af746f2ae31416b4b670c63", 21 | "4837276e5cf248449e287b1eeaef30ec", 22 | "3ab0f2022e654458875c2c091908e8c9", 23 | "f74bdeb79a224de8b1c85f4ca8657331", 24 | "4eb62038f89d4a8cb2c46e6a7cc70150", 25 | "9055fd09043642e0ae3d8a7a7c0ab31b", 26 | "4a2ead337d5c4ded9f28c93a70db1f08", 27 | "888a323362ae4daeac99915bcb3dcf10", 28 | "4d0e364e9f274e8ea7447e4e01c7f28f", 29 | "78a32764678a42f0a5a892f5275d88de", 30 | "aa17c3a834694a978046808fc5d29da1", 31 | "11e011e4acb24519bd41a054ddecbfb1", 32 | 
"5d1a9518abd44c18b122e575a7548ed2", 33 | "76e80fb236f5491597c992d1a809be33", 34 | "f7359467b0214c5385de8ee4334f7ba3", 35 | "a58ac736aa884eb9a27264cb04bb36ce", 36 | "6e6f7b7cccaa4f0cbfc9311db257bea1", 37 | "0656eee26364487f81580c3864e7a159", 38 | "05240e68c55a458286f43967e7f90889", 39 | "8cfa6df0ee654643bfdb4a3825e8fbbe", 40 | "96baa91869eb478eb492754b98169470", 41 | "bbda5260ca1c450386f9191e9f9dde97", 42 | "6fc5bec49f17469db39e0d4b535b94e9", 43 | "67822d28f8584e69abcb041b88377a9f", 44 | "aa082ade829247dc8ea0d75cc8a5b2a7", 45 | "83bc41f428b7492e9defdaa177f33a3e", 46 | "7f168d0ea11c4ea1a96202d3a36ec389", 47 | "ebb7ee3fd084466f9667771a99e6e3b2", 48 | "1e3c2a94251b4e75af0413a88b53bfe1", 49 | "a1188f80f78c49c7a822d71694e47074", 50 | "068552491889440e8a66e61b9f013786", 51 | "c88027eb3e1c4771ab57366070ecd553", 52 | "df75b255bfb04057b553830b59f0a153", 53 | "f0e5024d0d054c1eb8e01c4c8b027e79", 54 | "937ee45f4d634d189c6d95c886e97bca", 55 | "c2d14fa4280c48e0ae04859b73c80781", 56 | "d3104837d9734834b7c87e87289b08df", 57 | "02b02005adf241a4a0be8173ca3a4aee", 58 | "b317ba38f2b145f9b0b49f523547684f", 59 | "434340d109d1401d8868498a23b291cf", 60 | "2c95f5b81fc84ad698fe77b52cb84076", 61 | "ca588157678e4cc09c3fd760676efd39", 62 | "c020b38c6d2c436e8b742fd87d3b8b89", 63 | "3dc97a04373f484d9ccd1c46646d96cc", 64 | "4aed1fa58b7342eba35c2106ec934019", 65 | "60c72c47a8d84f0eab652822bed1ed09" 66 | ] 67 | }, 68 | "id": "gGDOaOoIwGc5", 69 | "outputId": "4180e60a-8985-4795-8e72-373deabc1ebc" 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "from config import get_config\n", 74 | "cfg = get_config()\n", 75 | "cfg['batch_size'] = 6\n", 76 | "cfg['preload'] = None\n", 77 | "cfg['num_epochs'] = 30\n", 78 | "\n", 79 | "from train import train_model\n", 80 | "\n", 81 | "train_model(cfg)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [] 90 | } 91 | ], 92 | "metadata": { 93 | "accelerator": "GPU", 94 | "colab": { 95 | "gpuType": "T4", 96 | "provenance": [] 97 | }, 98 | "gpuClass": "standard", 99 | "kernelspec": { 100 | "display_name": "Python 3", 101 | "name": "python3" 102 | }, 103 | "language_info": { 104 | "codemirror_mode": { 105 | "name": "ipython", 106 | "version": 3 107 | }, 108 | "file_extension": ".py", 109 | "mimetype": "text/x-python", 110 | "name": "python", 111 | "nbconvert_exporter": "python", 112 | "pygments_lexer": "ipython3", 113 | "version": "3.10.6" 114 | }, 115 | "widgets": { 116 | "application/vnd.jupyter.widget-state+json": { 117 | "02b02005adf241a4a0be8173ca3a4aee": { 118 | "model_module": "@jupyter-widgets/controls", 119 | "model_module_version": "1.5.0", 120 | "model_name": "FloatProgressModel", 121 | "state": { 122 | "_dom_classes": [], 123 | "_model_module": "@jupyter-widgets/controls", 124 | "_model_module_version": "1.5.0", 125 | "_model_name": "FloatProgressModel", 126 | "_view_count": null, 127 | "_view_module": "@jupyter-widgets/controls", 128 | "_view_module_version": "1.5.0", 129 | "_view_name": "ProgressView", 130 | "bar_style": "", 131 | "description": "", 132 | "description_tooltip": null, 133 | "layout": "IPY_MODEL_c020b38c6d2c436e8b742fd87d3b8b89", 134 | "max": 32332, 135 | "min": 0, 136 | "orientation": "horizontal", 137 | "style": "IPY_MODEL_3dc97a04373f484d9ccd1c46646d96cc", 138 | "value": 32332 139 | } 140 | }, 141 | "05240e68c55a458286f43967e7f90889": { 142 | "model_module": "@jupyter-widgets/base", 143 | "model_module_version": "1.2.0", 144 | "model_name": "LayoutModel", 145 | "state": { 146 | 
"_model_module": "@jupyter-widgets/base", 147 | "_model_module_version": "1.2.0", 148 | "_model_name": "LayoutModel", 149 | "_view_count": null, 150 | "_view_module": "@jupyter-widgets/base", 151 | "_view_module_version": "1.2.0", 152 | "_view_name": "LayoutView", 153 | "align_content": null, 154 | "align_items": null, 155 | "align_self": null, 156 | "border": null, 157 | "bottom": null, 158 | "display": null, 159 | "flex": null, 160 | "flex_flow": null, 161 | "grid_area": null, 162 | "grid_auto_columns": null, 163 | "grid_auto_flow": null, 164 | "grid_auto_rows": null, 165 | "grid_column": null, 166 | "grid_gap": null, 167 | "grid_row": null, 168 | "grid_template_areas": null, 169 | "grid_template_columns": null, 170 | "grid_template_rows": null, 171 | "height": null, 172 | "justify_content": null, 173 | "justify_items": null, 174 | "left": null, 175 | "margin": null, 176 | "max_height": null, 177 | "max_width": null, 178 | "min_height": null, 179 | "min_width": null, 180 | "object_fit": null, 181 | "object_position": null, 182 | "order": null, 183 | "overflow": null, 184 | "overflow_x": null, 185 | "overflow_y": null, 186 | "padding": null, 187 | "right": null, 188 | "top": null, 189 | "visibility": null, 190 | "width": null 191 | } 192 | }, 193 | "0656eee26364487f81580c3864e7a159": { 194 | "model_module": "@jupyter-widgets/base", 195 | "model_module_version": "1.2.0", 196 | "model_name": "LayoutModel", 197 | "state": { 198 | "_model_module": "@jupyter-widgets/base", 199 | "_model_module_version": "1.2.0", 200 | "_model_name": "LayoutModel", 201 | "_view_count": null, 202 | "_view_module": "@jupyter-widgets/base", 203 | "_view_module_version": "1.2.0", 204 | "_view_name": "LayoutView", 205 | "align_content": null, 206 | "align_items": null, 207 | "align_self": null, 208 | "border": null, 209 | "bottom": null, 210 | "display": null, 211 | "flex": null, 212 | "flex_flow": null, 213 | "grid_area": null, 214 | "grid_auto_columns": null, 215 | "grid_auto_flow": null, 216 | "grid_auto_rows": null, 217 | "grid_column": null, 218 | "grid_gap": null, 219 | "grid_row": null, 220 | "grid_template_areas": null, 221 | "grid_template_columns": null, 222 | "grid_template_rows": null, 223 | "height": null, 224 | "justify_content": null, 225 | "justify_items": null, 226 | "left": null, 227 | "margin": null, 228 | "max_height": null, 229 | "max_width": null, 230 | "min_height": null, 231 | "min_width": null, 232 | "object_fit": null, 233 | "object_position": null, 234 | "order": null, 235 | "overflow": null, 236 | "overflow_x": null, 237 | "overflow_y": null, 238 | "padding": null, 239 | "right": null, 240 | "top": null, 241 | "visibility": null, 242 | "width": null 243 | } 244 | }, 245 | "068552491889440e8a66e61b9f013786": { 246 | "model_module": "@jupyter-widgets/controls", 247 | "model_module_version": "1.5.0", 248 | "model_name": "DescriptionStyleModel", 249 | "state": { 250 | "_model_module": "@jupyter-widgets/controls", 251 | "_model_module_version": "1.5.0", 252 | "_model_name": "DescriptionStyleModel", 253 | "_view_count": null, 254 | "_view_module": "@jupyter-widgets/base", 255 | "_view_module_version": "1.2.0", 256 | "_view_name": "StyleView", 257 | "description_width": "" 258 | } 259 | }, 260 | "0ce327d5112b44dbb20e57752afc478a": { 261 | "model_module": "@jupyter-widgets/controls", 262 | "model_module_version": "1.5.0", 263 | "model_name": "HBoxModel", 264 | "state": { 265 | "_dom_classes": [], 266 | "_model_module": "@jupyter-widgets/controls", 267 | "_model_module_version": "1.5.0", 268 | 
"_model_name": "HBoxModel", 269 | "_view_count": null, 270 | "_view_module": "@jupyter-widgets/controls", 271 | "_view_module_version": "1.5.0", 272 | "_view_name": "HBoxView", 273 | "box_style": "", 274 | "children": [ 275 | "IPY_MODEL_423a3059ad1a4e01bd01095cf1b41e14", 276 | "IPY_MODEL_9cf2d2e2bfe24f2ab185165d79da8bdb", 277 | "IPY_MODEL_996ac47b200c427088ee7644fe886896" 278 | ], 279 | "layout": "IPY_MODEL_9b9addf13301466b9ef30b9d4b836a67" 280 | } 281 | }, 282 | "11e011e4acb24519bd41a054ddecbfb1": { 283 | "model_module": "@jupyter-widgets/base", 284 | "model_module_version": "1.2.0", 285 | "model_name": "LayoutModel", 286 | "state": { 287 | "_model_module": "@jupyter-widgets/base", 288 | "_model_module_version": "1.2.0", 289 | "_model_name": "LayoutModel", 290 | "_view_count": null, 291 | "_view_module": "@jupyter-widgets/base", 292 | "_view_module_version": "1.2.0", 293 | "_view_name": "LayoutView", 294 | "align_content": null, 295 | "align_items": null, 296 | "align_self": null, 297 | "border": null, 298 | "bottom": null, 299 | "display": null, 300 | "flex": null, 301 | "flex_flow": null, 302 | "grid_area": null, 303 | "grid_auto_columns": null, 304 | "grid_auto_flow": null, 305 | "grid_auto_rows": null, 306 | "grid_column": null, 307 | "grid_gap": null, 308 | "grid_row": null, 309 | "grid_template_areas": null, 310 | "grid_template_columns": null, 311 | "grid_template_rows": null, 312 | "height": null, 313 | "justify_content": null, 314 | "justify_items": null, 315 | "left": null, 316 | "margin": null, 317 | "max_height": null, 318 | "max_width": null, 319 | "min_height": null, 320 | "min_width": null, 321 | "object_fit": null, 322 | "object_position": null, 323 | "order": null, 324 | "overflow": null, 325 | "overflow_x": null, 326 | "overflow_y": null, 327 | "padding": null, 328 | "right": null, 329 | "top": null, 330 | "visibility": null, 331 | "width": null 332 | } 333 | }, 334 | "140f33387db341398bc39e9c47703df4": { 335 | "model_module": "@jupyter-widgets/base", 336 | "model_module_version": "1.2.0", 337 | "model_name": "LayoutModel", 338 | "state": { 339 | "_model_module": "@jupyter-widgets/base", 340 | "_model_module_version": "1.2.0", 341 | "_model_name": "LayoutModel", 342 | "_view_count": null, 343 | "_view_module": "@jupyter-widgets/base", 344 | "_view_module_version": "1.2.0", 345 | "_view_name": "LayoutView", 346 | "align_content": null, 347 | "align_items": null, 348 | "align_self": null, 349 | "border": null, 350 | "bottom": null, 351 | "display": null, 352 | "flex": null, 353 | "flex_flow": null, 354 | "grid_area": null, 355 | "grid_auto_columns": null, 356 | "grid_auto_flow": null, 357 | "grid_auto_rows": null, 358 | "grid_column": null, 359 | "grid_gap": null, 360 | "grid_row": null, 361 | "grid_template_areas": null, 362 | "grid_template_columns": null, 363 | "grid_template_rows": null, 364 | "height": null, 365 | "justify_content": null, 366 | "justify_items": null, 367 | "left": null, 368 | "margin": null, 369 | "max_height": null, 370 | "max_width": null, 371 | "min_height": null, 372 | "min_width": null, 373 | "object_fit": null, 374 | "object_position": null, 375 | "order": null, 376 | "overflow": null, 377 | "overflow_x": null, 378 | "overflow_y": null, 379 | "padding": null, 380 | "right": null, 381 | "top": null, 382 | "visibility": null, 383 | "width": null 384 | } 385 | }, 386 | "1e3c2a94251b4e75af0413a88b53bfe1": { 387 | "model_module": "@jupyter-widgets/base", 388 | "model_module_version": "1.2.0", 389 | "model_name": "LayoutModel", 390 | "state": { 391 | 
"_model_module": "@jupyter-widgets/base", 392 | "_model_module_version": "1.2.0", 393 | "_model_name": "LayoutModel", 394 | "_view_count": null, 395 | "_view_module": "@jupyter-widgets/base", 396 | "_view_module_version": "1.2.0", 397 | "_view_name": "LayoutView", 398 | "align_content": null, 399 | "align_items": null, 400 | "align_self": null, 401 | "border": null, 402 | "bottom": null, 403 | "display": null, 404 | "flex": null, 405 | "flex_flow": null, 406 | "grid_area": null, 407 | "grid_auto_columns": null, 408 | "grid_auto_flow": null, 409 | "grid_auto_rows": null, 410 | "grid_column": null, 411 | "grid_gap": null, 412 | "grid_row": null, 413 | "grid_template_areas": null, 414 | "grid_template_columns": null, 415 | "grid_template_rows": null, 416 | "height": null, 417 | "justify_content": null, 418 | "justify_items": null, 419 | "left": null, 420 | "margin": null, 421 | "max_height": null, 422 | "max_width": null, 423 | "min_height": null, 424 | "min_width": null, 425 | "object_fit": null, 426 | "object_position": null, 427 | "order": null, 428 | "overflow": null, 429 | "overflow_x": null, 430 | "overflow_y": null, 431 | "padding": null, 432 | "right": null, 433 | "top": null, 434 | "visibility": null, 435 | "width": null 436 | } 437 | }, 438 | "2c95f5b81fc84ad698fe77b52cb84076": { 439 | "model_module": "@jupyter-widgets/base", 440 | "model_module_version": "1.2.0", 441 | "model_name": "LayoutModel", 442 | "state": { 443 | "_model_module": "@jupyter-widgets/base", 444 | "_model_module_version": "1.2.0", 445 | "_model_name": "LayoutModel", 446 | "_view_count": null, 447 | "_view_module": "@jupyter-widgets/base", 448 | "_view_module_version": "1.2.0", 449 | "_view_name": "LayoutView", 450 | "align_content": null, 451 | "align_items": null, 452 | "align_self": null, 453 | "border": null, 454 | "bottom": null, 455 | "display": null, 456 | "flex": null, 457 | "flex_flow": null, 458 | "grid_area": null, 459 | "grid_auto_columns": null, 460 | "grid_auto_flow": null, 461 | "grid_auto_rows": null, 462 | "grid_column": null, 463 | "grid_gap": null, 464 | "grid_row": null, 465 | "grid_template_areas": null, 466 | "grid_template_columns": null, 467 | "grid_template_rows": null, 468 | "height": null, 469 | "justify_content": null, 470 | "justify_items": null, 471 | "left": null, 472 | "margin": null, 473 | "max_height": null, 474 | "max_width": null, 475 | "min_height": null, 476 | "min_width": null, 477 | "object_fit": null, 478 | "object_position": null, 479 | "order": null, 480 | "overflow": null, 481 | "overflow_x": null, 482 | "overflow_y": null, 483 | "padding": null, 484 | "right": null, 485 | "top": null, 486 | "visibility": null, 487 | "width": null 488 | } 489 | }, 490 | "3ab0f2022e654458875c2c091908e8c9": { 491 | "model_module": "@jupyter-widgets/controls", 492 | "model_module_version": "1.5.0", 493 | "model_name": "HBoxModel", 494 | "state": { 495 | "_dom_classes": [], 496 | "_model_module": "@jupyter-widgets/controls", 497 | "_model_module_version": "1.5.0", 498 | "_model_name": "HBoxModel", 499 | "_view_count": null, 500 | "_view_module": "@jupyter-widgets/controls", 501 | "_view_module_version": "1.5.0", 502 | "_view_name": "HBoxView", 503 | "box_style": "", 504 | "children": [ 505 | "IPY_MODEL_f74bdeb79a224de8b1c85f4ca8657331", 506 | "IPY_MODEL_4eb62038f89d4a8cb2c46e6a7cc70150", 507 | "IPY_MODEL_9055fd09043642e0ae3d8a7a7c0ab31b" 508 | ], 509 | "layout": "IPY_MODEL_4a2ead337d5c4ded9f28c93a70db1f08" 510 | } 511 | }, 512 | "3dc97a04373f484d9ccd1c46646d96cc": { 513 | "model_module": 
"@jupyter-widgets/controls", 514 | "model_module_version": "1.5.0", 515 | "model_name": "ProgressStyleModel", 516 | "state": { 517 | "_model_module": "@jupyter-widgets/controls", 518 | "_model_module_version": "1.5.0", 519 | "_model_name": "ProgressStyleModel", 520 | "_view_count": null, 521 | "_view_module": "@jupyter-widgets/base", 522 | "_view_module_version": "1.2.0", 523 | "_view_name": "StyleView", 524 | "bar_color": null, 525 | "description_width": "" 526 | } 527 | }, 528 | "423a3059ad1a4e01bd01095cf1b41e14": { 529 | "model_module": "@jupyter-widgets/controls", 530 | "model_module_version": "1.5.0", 531 | "model_name": "HTMLModel", 532 | "state": { 533 | "_dom_classes": [], 534 | "_model_module": "@jupyter-widgets/controls", 535 | "_model_module_version": "1.5.0", 536 | "_model_name": "HTMLModel", 537 | "_view_count": null, 538 | "_view_module": "@jupyter-widgets/controls", 539 | "_view_module_version": "1.5.0", 540 | "_view_name": "HTMLView", 541 | "description": "", 542 | "description_tooltip": null, 543 | "layout": "IPY_MODEL_ec2051bf0e9343d394e8a0ecb4fd5ec8", 544 | "placeholder": "​", 545 | "style": "IPY_MODEL_56049bd375cd4512a0deaf69b7dae245", 546 | "value": "Downloading builder script: 100%" 547 | } 548 | }, 549 | "434340d109d1401d8868498a23b291cf": { 550 | "model_module": "@jupyter-widgets/base", 551 | "model_module_version": "1.2.0", 552 | "model_name": "LayoutModel", 553 | "state": { 554 | "_model_module": "@jupyter-widgets/base", 555 | "_model_module_version": "1.2.0", 556 | "_model_name": "LayoutModel", 557 | "_view_count": null, 558 | "_view_module": "@jupyter-widgets/base", 559 | "_view_module_version": "1.2.0", 560 | "_view_name": "LayoutView", 561 | "align_content": null, 562 | "align_items": null, 563 | "align_self": null, 564 | "border": null, 565 | "bottom": null, 566 | "display": null, 567 | "flex": null, 568 | "flex_flow": null, 569 | "grid_area": null, 570 | "grid_auto_columns": null, 571 | "grid_auto_flow": null, 572 | "grid_auto_rows": null, 573 | "grid_column": null, 574 | "grid_gap": null, 575 | "grid_row": null, 576 | "grid_template_areas": null, 577 | "grid_template_columns": null, 578 | "grid_template_rows": null, 579 | "height": null, 580 | "justify_content": null, 581 | "justify_items": null, 582 | "left": null, 583 | "margin": null, 584 | "max_height": null, 585 | "max_width": null, 586 | "min_height": null, 587 | "min_width": null, 588 | "object_fit": null, 589 | "object_position": null, 590 | "order": null, 591 | "overflow": null, 592 | "overflow_x": null, 593 | "overflow_y": null, 594 | "padding": null, 595 | "right": null, 596 | "top": null, 597 | "visibility": "hidden", 598 | "width": null 599 | } 600 | }, 601 | "4837276e5cf248449e287b1eeaef30ec": { 602 | "model_module": "@jupyter-widgets/controls", 603 | "model_module_version": "1.5.0", 604 | "model_name": "DescriptionStyleModel", 605 | "state": { 606 | "_model_module": "@jupyter-widgets/controls", 607 | "_model_module_version": "1.5.0", 608 | "_model_name": "DescriptionStyleModel", 609 | "_view_count": null, 610 | "_view_module": "@jupyter-widgets/base", 611 | "_view_module_version": "1.2.0", 612 | "_view_name": "StyleView", 613 | "description_width": "" 614 | } 615 | }, 616 | "4a2ead337d5c4ded9f28c93a70db1f08": { 617 | "model_module": "@jupyter-widgets/base", 618 | "model_module_version": "1.2.0", 619 | "model_name": "LayoutModel", 620 | "state": { 621 | "_model_module": "@jupyter-widgets/base", 622 | "_model_module_version": "1.2.0", 623 | "_model_name": "LayoutModel", 624 | "_view_count": 
null, 625 | "_view_module": "@jupyter-widgets/base", 626 | "_view_module_version": "1.2.0", 627 | "_view_name": "LayoutView", 628 | "align_content": null, 629 | "align_items": null, 630 | "align_self": null, 631 | "border": null, 632 | "bottom": null, 633 | "display": null, 634 | "flex": null, 635 | "flex_flow": null, 636 | "grid_area": null, 637 | "grid_auto_columns": null, 638 | "grid_auto_flow": null, 639 | "grid_auto_rows": null, 640 | "grid_column": null, 641 | "grid_gap": null, 642 | "grid_row": null, 643 | "grid_template_areas": null, 644 | "grid_template_columns": null, 645 | "grid_template_rows": null, 646 | "height": null, 647 | "justify_content": null, 648 | "justify_items": null, 649 | "left": null, 650 | "margin": null, 651 | "max_height": null, 652 | "max_width": null, 653 | "min_height": null, 654 | "min_width": null, 655 | "object_fit": null, 656 | "object_position": null, 657 | "order": null, 658 | "overflow": null, 659 | "overflow_x": null, 660 | "overflow_y": null, 661 | "padding": null, 662 | "right": null, 663 | "top": null, 664 | "visibility": null, 665 | "width": null 666 | } 667 | }, 668 | "4aed1fa58b7342eba35c2106ec934019": { 669 | "model_module": "@jupyter-widgets/base", 670 | "model_module_version": "1.2.0", 671 | "model_name": "LayoutModel", 672 | "state": { 673 | "_model_module": "@jupyter-widgets/base", 674 | "_model_module_version": "1.2.0", 675 | "_model_name": "LayoutModel", 676 | "_view_count": null, 677 | "_view_module": "@jupyter-widgets/base", 678 | "_view_module_version": "1.2.0", 679 | "_view_name": "LayoutView", 680 | "align_content": null, 681 | "align_items": null, 682 | "align_self": null, 683 | "border": null, 684 | "bottom": null, 685 | "display": null, 686 | "flex": null, 687 | "flex_flow": null, 688 | "grid_area": null, 689 | "grid_auto_columns": null, 690 | "grid_auto_flow": null, 691 | "grid_auto_rows": null, 692 | "grid_column": null, 693 | "grid_gap": null, 694 | "grid_row": null, 695 | "grid_template_areas": null, 696 | "grid_template_columns": null, 697 | "grid_template_rows": null, 698 | "height": null, 699 | "justify_content": null, 700 | "justify_items": null, 701 | "left": null, 702 | "margin": null, 703 | "max_height": null, 704 | "max_width": null, 705 | "min_height": null, 706 | "min_width": null, 707 | "object_fit": null, 708 | "object_position": null, 709 | "order": null, 710 | "overflow": null, 711 | "overflow_x": null, 712 | "overflow_y": null, 713 | "padding": null, 714 | "right": null, 715 | "top": null, 716 | "visibility": null, 717 | "width": null 718 | } 719 | }, 720 | "4d0e364e9f274e8ea7447e4e01c7f28f": { 721 | "model_module": "@jupyter-widgets/controls", 722 | "model_module_version": "1.5.0", 723 | "model_name": "DescriptionStyleModel", 724 | "state": { 725 | "_model_module": "@jupyter-widgets/controls", 726 | "_model_module_version": "1.5.0", 727 | "_model_name": "DescriptionStyleModel", 728 | "_view_count": null, 729 | "_view_module": "@jupyter-widgets/base", 730 | "_view_module_version": "1.2.0", 731 | "_view_name": "StyleView", 732 | "description_width": "" 733 | } 734 | }, 735 | "4eb62038f89d4a8cb2c46e6a7cc70150": { 736 | "model_module": "@jupyter-widgets/controls", 737 | "model_module_version": "1.5.0", 738 | "model_name": "FloatProgressModel", 739 | "state": { 740 | "_dom_classes": [], 741 | "_model_module": "@jupyter-widgets/controls", 742 | "_model_module_version": "1.5.0", 743 | "_model_name": "FloatProgressModel", 744 | "_view_count": null, 745 | "_view_module": "@jupyter-widgets/controls", 746 | 
"_view_module_version": "1.5.0", 747 | "_view_name": "ProgressView", 748 | "bar_style": "success", 749 | "description": "", 750 | "description_tooltip": null, 751 | "layout": "IPY_MODEL_78a32764678a42f0a5a892f5275d88de", 752 | "max": 161154, 753 | "min": 0, 754 | "orientation": "horizontal", 755 | "style": "IPY_MODEL_aa17c3a834694a978046808fc5d29da1", 756 | "value": 161154 757 | } 758 | }, 759 | "56049bd375cd4512a0deaf69b7dae245": { 760 | "model_module": "@jupyter-widgets/controls", 761 | "model_module_version": "1.5.0", 762 | "model_name": "DescriptionStyleModel", 763 | "state": { 764 | "_model_module": "@jupyter-widgets/controls", 765 | "_model_module_version": "1.5.0", 766 | "_model_name": "DescriptionStyleModel", 767 | "_view_count": null, 768 | "_view_module": "@jupyter-widgets/base", 769 | "_view_module_version": "1.2.0", 770 | "_view_name": "StyleView", 771 | "description_width": "" 772 | } 773 | }, 774 | "5d1a9518abd44c18b122e575a7548ed2": { 775 | "model_module": "@jupyter-widgets/controls", 776 | "model_module_version": "1.5.0", 777 | "model_name": "DescriptionStyleModel", 778 | "state": { 779 | "_model_module": "@jupyter-widgets/controls", 780 | "_model_module_version": "1.5.0", 781 | "_model_name": "DescriptionStyleModel", 782 | "_view_count": null, 783 | "_view_module": "@jupyter-widgets/base", 784 | "_view_module_version": "1.2.0", 785 | "_view_name": "StyleView", 786 | "description_width": "" 787 | } 788 | }, 789 | "60c72c47a8d84f0eab652822bed1ed09": { 790 | "model_module": "@jupyter-widgets/controls", 791 | "model_module_version": "1.5.0", 792 | "model_name": "DescriptionStyleModel", 793 | "state": { 794 | "_model_module": "@jupyter-widgets/controls", 795 | "_model_module_version": "1.5.0", 796 | "_model_name": "DescriptionStyleModel", 797 | "_view_count": null, 798 | "_view_module": "@jupyter-widgets/base", 799 | "_view_module_version": "1.2.0", 800 | "_view_name": "StyleView", 801 | "description_width": "" 802 | } 803 | }, 804 | "67822d28f8584e69abcb041b88377a9f": { 805 | "model_module": "@jupyter-widgets/controls", 806 | "model_module_version": "1.5.0", 807 | "model_name": "DescriptionStyleModel", 808 | "state": { 809 | "_model_module": "@jupyter-widgets/controls", 810 | "_model_module_version": "1.5.0", 811 | "_model_name": "DescriptionStyleModel", 812 | "_view_count": null, 813 | "_view_module": "@jupyter-widgets/base", 814 | "_view_module_version": "1.2.0", 815 | "_view_name": "StyleView", 816 | "description_width": "" 817 | } 818 | }, 819 | "6e6f7b7cccaa4f0cbfc9311db257bea1": { 820 | "model_module": "@jupyter-widgets/controls", 821 | "model_module_version": "1.5.0", 822 | "model_name": "HTMLModel", 823 | "state": { 824 | "_dom_classes": [], 825 | "_model_module": "@jupyter-widgets/controls", 826 | "_model_module_version": "1.5.0", 827 | "_model_name": "HTMLModel", 828 | "_view_count": null, 829 | "_view_module": "@jupyter-widgets/controls", 830 | "_view_module_version": "1.5.0", 831 | "_view_name": "HTMLView", 832 | "description": "", 833 | "description_tooltip": null, 834 | "layout": "IPY_MODEL_6fc5bec49f17469db39e0d4b535b94e9", 835 | "placeholder": "​", 836 | "style": "IPY_MODEL_67822d28f8584e69abcb041b88377a9f", 837 | "value": " 20.5k/20.5k [00:00<00:00, 1.34MB/s]" 838 | } 839 | }, 840 | "6fc5bec49f17469db39e0d4b535b94e9": { 841 | "model_module": "@jupyter-widgets/base", 842 | "model_module_version": "1.2.0", 843 | "model_name": "LayoutModel", 844 | "state": { 845 | "_model_module": "@jupyter-widgets/base", 846 | "_model_module_version": "1.2.0", 847 | 
"_model_name": "LayoutModel", 848 | "_view_count": null, 849 | "_view_module": "@jupyter-widgets/base", 850 | "_view_module_version": "1.2.0", 851 | "_view_name": "LayoutView", 852 | "align_content": null, 853 | "align_items": null, 854 | "align_self": null, 855 | "border": null, 856 | "bottom": null, 857 | "display": null, 858 | "flex": null, 859 | "flex_flow": null, 860 | "grid_area": null, 861 | "grid_auto_columns": null, 862 | "grid_auto_flow": null, 863 | "grid_auto_rows": null, 864 | "grid_column": null, 865 | "grid_gap": null, 866 | "grid_row": null, 867 | "grid_template_areas": null, 868 | "grid_template_columns": null, 869 | "grid_template_rows": null, 870 | "height": null, 871 | "justify_content": null, 872 | "justify_items": null, 873 | "left": null, 874 | "margin": null, 875 | "max_height": null, 876 | "max_width": null, 877 | "min_height": null, 878 | "min_width": null, 879 | "object_fit": null, 880 | "object_position": null, 881 | "order": null, 882 | "overflow": null, 883 | "overflow_x": null, 884 | "overflow_y": null, 885 | "padding": null, 886 | "right": null, 887 | "top": null, 888 | "visibility": null, 889 | "width": null 890 | } 891 | }, 892 | "76e80fb236f5491597c992d1a809be33": { 893 | "model_module": "@jupyter-widgets/controls", 894 | "model_module_version": "1.5.0", 895 | "model_name": "HBoxModel", 896 | "state": { 897 | "_dom_classes": [], 898 | "_model_module": "@jupyter-widgets/controls", 899 | "_model_module_version": "1.5.0", 900 | "_model_name": "HBoxModel", 901 | "_view_count": null, 902 | "_view_module": "@jupyter-widgets/controls", 903 | "_view_module_version": "1.5.0", 904 | "_view_name": "HBoxView", 905 | "box_style": "", 906 | "children": [ 907 | "IPY_MODEL_f7359467b0214c5385de8ee4334f7ba3", 908 | "IPY_MODEL_a58ac736aa884eb9a27264cb04bb36ce", 909 | "IPY_MODEL_6e6f7b7cccaa4f0cbfc9311db257bea1" 910 | ], 911 | "layout": "IPY_MODEL_0656eee26364487f81580c3864e7a159" 912 | } 913 | }, 914 | "78a32764678a42f0a5a892f5275d88de": { 915 | "model_module": "@jupyter-widgets/base", 916 | "model_module_version": "1.2.0", 917 | "model_name": "LayoutModel", 918 | "state": { 919 | "_model_module": "@jupyter-widgets/base", 920 | "_model_module_version": "1.2.0", 921 | "_model_name": "LayoutModel", 922 | "_view_count": null, 923 | "_view_module": "@jupyter-widgets/base", 924 | "_view_module_version": "1.2.0", 925 | "_view_name": "LayoutView", 926 | "align_content": null, 927 | "align_items": null, 928 | "align_self": null, 929 | "border": null, 930 | "bottom": null, 931 | "display": null, 932 | "flex": null, 933 | "flex_flow": null, 934 | "grid_area": null, 935 | "grid_auto_columns": null, 936 | "grid_auto_flow": null, 937 | "grid_auto_rows": null, 938 | "grid_column": null, 939 | "grid_gap": null, 940 | "grid_row": null, 941 | "grid_template_areas": null, 942 | "grid_template_columns": null, 943 | "grid_template_rows": null, 944 | "height": null, 945 | "justify_content": null, 946 | "justify_items": null, 947 | "left": null, 948 | "margin": null, 949 | "max_height": null, 950 | "max_width": null, 951 | "min_height": null, 952 | "min_width": null, 953 | "object_fit": null, 954 | "object_position": null, 955 | "order": null, 956 | "overflow": null, 957 | "overflow_x": null, 958 | "overflow_y": null, 959 | "padding": null, 960 | "right": null, 961 | "top": null, 962 | "visibility": null, 963 | "width": null 964 | } 965 | }, 966 | "7f168d0ea11c4ea1a96202d3a36ec389": { 967 | "model_module": "@jupyter-widgets/controls", 968 | "model_module_version": "1.5.0", 969 | "model_name": 
"FloatProgressModel", 970 | "state": { 971 | "_dom_classes": [], 972 | "_model_module": "@jupyter-widgets/controls", 973 | "_model_module_version": "1.5.0", 974 | "_model_name": "FloatProgressModel", 975 | "_view_count": null, 976 | "_view_module": "@jupyter-widgets/controls", 977 | "_view_module_version": "1.5.0", 978 | "_view_name": "ProgressView", 979 | "bar_style": "success", 980 | "description": "", 981 | "description_tooltip": null, 982 | "layout": "IPY_MODEL_c88027eb3e1c4771ab57366070ecd553", 983 | "max": 3295251, 984 | "min": 0, 985 | "orientation": "horizontal", 986 | "style": "IPY_MODEL_df75b255bfb04057b553830b59f0a153", 987 | "value": 3295251 988 | } 989 | }, 990 | "83bc41f428b7492e9defdaa177f33a3e": { 991 | "model_module": "@jupyter-widgets/controls", 992 | "model_module_version": "1.5.0", 993 | "model_name": "HTMLModel", 994 | "state": { 995 | "_dom_classes": [], 996 | "_model_module": "@jupyter-widgets/controls", 997 | "_model_module_version": "1.5.0", 998 | "_model_name": "HTMLModel", 999 | "_view_count": null, 1000 | "_view_module": "@jupyter-widgets/controls", 1001 | "_view_module_version": "1.5.0", 1002 | "_view_name": "HTMLView", 1003 | "description": "", 1004 | "description_tooltip": null, 1005 | "layout": "IPY_MODEL_a1188f80f78c49c7a822d71694e47074", 1006 | "placeholder": "​", 1007 | "style": "IPY_MODEL_068552491889440e8a66e61b9f013786", 1008 | "value": "Downloading data: 100%" 1009 | } 1010 | }, 1011 | "888a323362ae4daeac99915bcb3dcf10": { 1012 | "model_module": "@jupyter-widgets/base", 1013 | "model_module_version": "1.2.0", 1014 | "model_name": "LayoutModel", 1015 | "state": { 1016 | "_model_module": "@jupyter-widgets/base", 1017 | "_model_module_version": "1.2.0", 1018 | "_model_name": "LayoutModel", 1019 | "_view_count": null, 1020 | "_view_module": "@jupyter-widgets/base", 1021 | "_view_module_version": "1.2.0", 1022 | "_view_name": "LayoutView", 1023 | "align_content": null, 1024 | "align_items": null, 1025 | "align_self": null, 1026 | "border": null, 1027 | "bottom": null, 1028 | "display": null, 1029 | "flex": null, 1030 | "flex_flow": null, 1031 | "grid_area": null, 1032 | "grid_auto_columns": null, 1033 | "grid_auto_flow": null, 1034 | "grid_auto_rows": null, 1035 | "grid_column": null, 1036 | "grid_gap": null, 1037 | "grid_row": null, 1038 | "grid_template_areas": null, 1039 | "grid_template_columns": null, 1040 | "grid_template_rows": null, 1041 | "height": null, 1042 | "justify_content": null, 1043 | "justify_items": null, 1044 | "left": null, 1045 | "margin": null, 1046 | "max_height": null, 1047 | "max_width": null, 1048 | "min_height": null, 1049 | "min_width": null, 1050 | "object_fit": null, 1051 | "object_position": null, 1052 | "order": null, 1053 | "overflow": null, 1054 | "overflow_x": null, 1055 | "overflow_y": null, 1056 | "padding": null, 1057 | "right": null, 1058 | "top": null, 1059 | "visibility": null, 1060 | "width": null 1061 | } 1062 | }, 1063 | "8cfa6df0ee654643bfdb4a3825e8fbbe": { 1064 | "model_module": "@jupyter-widgets/controls", 1065 | "model_module_version": "1.5.0", 1066 | "model_name": "DescriptionStyleModel", 1067 | "state": { 1068 | "_model_module": "@jupyter-widgets/controls", 1069 | "_model_module_version": "1.5.0", 1070 | "_model_name": "DescriptionStyleModel", 1071 | "_view_count": null, 1072 | "_view_module": "@jupyter-widgets/base", 1073 | "_view_module_version": "1.2.0", 1074 | "_view_name": "StyleView", 1075 | "description_width": "" 1076 | } 1077 | }, 1078 | "9055fd09043642e0ae3d8a7a7c0ab31b": { 1079 | "model_module": 
"@jupyter-widgets/controls", 1080 | "model_module_version": "1.5.0", 1081 | "model_name": "HTMLModel", 1082 | "state": { 1083 | "_dom_classes": [], 1084 | "_model_module": "@jupyter-widgets/controls", 1085 | "_model_module_version": "1.5.0", 1086 | "_model_name": "HTMLModel", 1087 | "_view_count": null, 1088 | "_view_module": "@jupyter-widgets/controls", 1089 | "_view_module_version": "1.5.0", 1090 | "_view_name": "HTMLView", 1091 | "description": "", 1092 | "description_tooltip": null, 1093 | "layout": "IPY_MODEL_11e011e4acb24519bd41a054ddecbfb1", 1094 | "placeholder": "​", 1095 | "style": "IPY_MODEL_5d1a9518abd44c18b122e575a7548ed2", 1096 | "value": " 161k/161k [00:00<00:00, 865kB/s]" 1097 | } 1098 | }, 1099 | "937ee45f4d634d189c6d95c886e97bca": { 1100 | "model_module": "@jupyter-widgets/controls", 1101 | "model_module_version": "1.5.0", 1102 | "model_name": "DescriptionStyleModel", 1103 | "state": { 1104 | "_model_module": "@jupyter-widgets/controls", 1105 | "_model_module_version": "1.5.0", 1106 | "_model_name": "DescriptionStyleModel", 1107 | "_view_count": null, 1108 | "_view_module": "@jupyter-widgets/base", 1109 | "_view_module_version": "1.2.0", 1110 | "_view_name": "StyleView", 1111 | "description_width": "" 1112 | } 1113 | }, 1114 | "96baa91869eb478eb492754b98169470": { 1115 | "model_module": "@jupyter-widgets/base", 1116 | "model_module_version": "1.2.0", 1117 | "model_name": "LayoutModel", 1118 | "state": { 1119 | "_model_module": "@jupyter-widgets/base", 1120 | "_model_module_version": "1.2.0", 1121 | "_model_name": "LayoutModel", 1122 | "_view_count": null, 1123 | "_view_module": "@jupyter-widgets/base", 1124 | "_view_module_version": "1.2.0", 1125 | "_view_name": "LayoutView", 1126 | "align_content": null, 1127 | "align_items": null, 1128 | "align_self": null, 1129 | "border": null, 1130 | "bottom": null, 1131 | "display": null, 1132 | "flex": null, 1133 | "flex_flow": null, 1134 | "grid_area": null, 1135 | "grid_auto_columns": null, 1136 | "grid_auto_flow": null, 1137 | "grid_auto_rows": null, 1138 | "grid_column": null, 1139 | "grid_gap": null, 1140 | "grid_row": null, 1141 | "grid_template_areas": null, 1142 | "grid_template_columns": null, 1143 | "grid_template_rows": null, 1144 | "height": null, 1145 | "justify_content": null, 1146 | "justify_items": null, 1147 | "left": null, 1148 | "margin": null, 1149 | "max_height": null, 1150 | "max_width": null, 1151 | "min_height": null, 1152 | "min_width": null, 1153 | "object_fit": null, 1154 | "object_position": null, 1155 | "order": null, 1156 | "overflow": null, 1157 | "overflow_x": null, 1158 | "overflow_y": null, 1159 | "padding": null, 1160 | "right": null, 1161 | "top": null, 1162 | "visibility": null, 1163 | "width": null 1164 | } 1165 | }, 1166 | "996ac47b200c427088ee7644fe886896": { 1167 | "model_module": "@jupyter-widgets/controls", 1168 | "model_module_version": "1.5.0", 1169 | "model_name": "HTMLModel", 1170 | "state": { 1171 | "_dom_classes": [], 1172 | "_model_module": "@jupyter-widgets/controls", 1173 | "_model_module_version": "1.5.0", 1174 | "_model_name": "HTMLModel", 1175 | "_view_count": null, 1176 | "_view_module": "@jupyter-widgets/controls", 1177 | "_view_module_version": "1.5.0", 1178 | "_view_name": "HTMLView", 1179 | "description": "", 1180 | "description_tooltip": null, 1181 | "layout": "IPY_MODEL_cb7d88a70af746f2ae31416b4b670c63", 1182 | "placeholder": "​", 1183 | "style": "IPY_MODEL_4837276e5cf248449e287b1eeaef30ec", 1184 | "value": " 6.08k/6.08k [00:00<00:00, 279kB/s]" 1185 | } 1186 | }, 1187 | 
"9b9addf13301466b9ef30b9d4b836a67": { 1188 | "model_module": "@jupyter-widgets/base", 1189 | "model_module_version": "1.2.0", 1190 | "model_name": "LayoutModel", 1191 | "state": { 1192 | "_model_module": "@jupyter-widgets/base", 1193 | "_model_module_version": "1.2.0", 1194 | "_model_name": "LayoutModel", 1195 | "_view_count": null, 1196 | "_view_module": "@jupyter-widgets/base", 1197 | "_view_module_version": "1.2.0", 1198 | "_view_name": "LayoutView", 1199 | "align_content": null, 1200 | "align_items": null, 1201 | "align_self": null, 1202 | "border": null, 1203 | "bottom": null, 1204 | "display": null, 1205 | "flex": null, 1206 | "flex_flow": null, 1207 | "grid_area": null, 1208 | "grid_auto_columns": null, 1209 | "grid_auto_flow": null, 1210 | "grid_auto_rows": null, 1211 | "grid_column": null, 1212 | "grid_gap": null, 1213 | "grid_row": null, 1214 | "grid_template_areas": null, 1215 | "grid_template_columns": null, 1216 | "grid_template_rows": null, 1217 | "height": null, 1218 | "justify_content": null, 1219 | "justify_items": null, 1220 | "left": null, 1221 | "margin": null, 1222 | "max_height": null, 1223 | "max_width": null, 1224 | "min_height": null, 1225 | "min_width": null, 1226 | "object_fit": null, 1227 | "object_position": null, 1228 | "order": null, 1229 | "overflow": null, 1230 | "overflow_x": null, 1231 | "overflow_y": null, 1232 | "padding": null, 1233 | "right": null, 1234 | "top": null, 1235 | "visibility": null, 1236 | "width": null 1237 | } 1238 | }, 1239 | "9cf2d2e2bfe24f2ab185165d79da8bdb": { 1240 | "model_module": "@jupyter-widgets/controls", 1241 | "model_module_version": "1.5.0", 1242 | "model_name": "FloatProgressModel", 1243 | "state": { 1244 | "_dom_classes": [], 1245 | "_model_module": "@jupyter-widgets/controls", 1246 | "_model_module_version": "1.5.0", 1247 | "_model_name": "FloatProgressModel", 1248 | "_view_count": null, 1249 | "_view_module": "@jupyter-widgets/controls", 1250 | "_view_module_version": "1.5.0", 1251 | "_view_name": "ProgressView", 1252 | "bar_style": "success", 1253 | "description": "", 1254 | "description_tooltip": null, 1255 | "layout": "IPY_MODEL_140f33387db341398bc39e9c47703df4", 1256 | "max": 6081, 1257 | "min": 0, 1258 | "orientation": "horizontal", 1259 | "style": "IPY_MODEL_b3a8424c0b584a37ad2ede748085425c", 1260 | "value": 6081 1261 | } 1262 | }, 1263 | "a1188f80f78c49c7a822d71694e47074": { 1264 | "model_module": "@jupyter-widgets/base", 1265 | "model_module_version": "1.2.0", 1266 | "model_name": "LayoutModel", 1267 | "state": { 1268 | "_model_module": "@jupyter-widgets/base", 1269 | "_model_module_version": "1.2.0", 1270 | "_model_name": "LayoutModel", 1271 | "_view_count": null, 1272 | "_view_module": "@jupyter-widgets/base", 1273 | "_view_module_version": "1.2.0", 1274 | "_view_name": "LayoutView", 1275 | "align_content": null, 1276 | "align_items": null, 1277 | "align_self": null, 1278 | "border": null, 1279 | "bottom": null, 1280 | "display": null, 1281 | "flex": null, 1282 | "flex_flow": null, 1283 | "grid_area": null, 1284 | "grid_auto_columns": null, 1285 | "grid_auto_flow": null, 1286 | "grid_auto_rows": null, 1287 | "grid_column": null, 1288 | "grid_gap": null, 1289 | "grid_row": null, 1290 | "grid_template_areas": null, 1291 | "grid_template_columns": null, 1292 | "grid_template_rows": null, 1293 | "height": null, 1294 | "justify_content": null, 1295 | "justify_items": null, 1296 | "left": null, 1297 | "margin": null, 1298 | "max_height": null, 1299 | "max_width": null, 1300 | "min_height": null, 1301 | "min_width": 
null, 1302 | "object_fit": null, 1303 | "object_position": null, 1304 | "order": null, 1305 | "overflow": null, 1306 | "overflow_x": null, 1307 | "overflow_y": null, 1308 | "padding": null, 1309 | "right": null, 1310 | "top": null, 1311 | "visibility": null, 1312 | "width": null 1313 | } 1314 | }, 1315 | "a58ac736aa884eb9a27264cb04bb36ce": { 1316 | "model_module": "@jupyter-widgets/controls", 1317 | "model_module_version": "1.5.0", 1318 | "model_name": "FloatProgressModel", 1319 | "state": { 1320 | "_dom_classes": [], 1321 | "_model_module": "@jupyter-widgets/controls", 1322 | "_model_module_version": "1.5.0", 1323 | "_model_name": "FloatProgressModel", 1324 | "_view_count": null, 1325 | "_view_module": "@jupyter-widgets/controls", 1326 | "_view_module_version": "1.5.0", 1327 | "_view_name": "ProgressView", 1328 | "bar_style": "success", 1329 | "description": "", 1330 | "description_tooltip": null, 1331 | "layout": "IPY_MODEL_96baa91869eb478eb492754b98169470", 1332 | "max": 20464, 1333 | "min": 0, 1334 | "orientation": "horizontal", 1335 | "style": "IPY_MODEL_bbda5260ca1c450386f9191e9f9dde97", 1336 | "value": 20464 1337 | } 1338 | }, 1339 | "aa082ade829247dc8ea0d75cc8a5b2a7": { 1340 | "model_module": "@jupyter-widgets/controls", 1341 | "model_module_version": "1.5.0", 1342 | "model_name": "HBoxModel", 1343 | "state": { 1344 | "_dom_classes": [], 1345 | "_model_module": "@jupyter-widgets/controls", 1346 | "_model_module_version": "1.5.0", 1347 | "_model_name": "HBoxModel", 1348 | "_view_count": null, 1349 | "_view_module": "@jupyter-widgets/controls", 1350 | "_view_module_version": "1.5.0", 1351 | "_view_name": "HBoxView", 1352 | "box_style": "", 1353 | "children": [ 1354 | "IPY_MODEL_83bc41f428b7492e9defdaa177f33a3e", 1355 | "IPY_MODEL_7f168d0ea11c4ea1a96202d3a36ec389", 1356 | "IPY_MODEL_ebb7ee3fd084466f9667771a99e6e3b2" 1357 | ], 1358 | "layout": "IPY_MODEL_1e3c2a94251b4e75af0413a88b53bfe1" 1359 | } 1360 | }, 1361 | "aa17c3a834694a978046808fc5d29da1": { 1362 | "model_module": "@jupyter-widgets/controls", 1363 | "model_module_version": "1.5.0", 1364 | "model_name": "ProgressStyleModel", 1365 | "state": { 1366 | "_model_module": "@jupyter-widgets/controls", 1367 | "_model_module_version": "1.5.0", 1368 | "_model_name": "ProgressStyleModel", 1369 | "_view_count": null, 1370 | "_view_module": "@jupyter-widgets/base", 1371 | "_view_module_version": "1.2.0", 1372 | "_view_name": "StyleView", 1373 | "bar_color": null, 1374 | "description_width": "" 1375 | } 1376 | }, 1377 | "b317ba38f2b145f9b0b49f523547684f": { 1378 | "model_module": "@jupyter-widgets/controls", 1379 | "model_module_version": "1.5.0", 1380 | "model_name": "HTMLModel", 1381 | "state": { 1382 | "_dom_classes": [], 1383 | "_model_module": "@jupyter-widgets/controls", 1384 | "_model_module_version": "1.5.0", 1385 | "_model_name": "HTMLModel", 1386 | "_view_count": null, 1387 | "_view_module": "@jupyter-widgets/controls", 1388 | "_view_module_version": "1.5.0", 1389 | "_view_name": "HTMLView", 1390 | "description": "", 1391 | "description_tooltip": null, 1392 | "layout": "IPY_MODEL_4aed1fa58b7342eba35c2106ec934019", 1393 | "placeholder": "​", 1394 | "style": "IPY_MODEL_60c72c47a8d84f0eab652822bed1ed09", 1395 | "value": " 32332/32332 [00:01<00:00, 27628.23 examples/s]" 1396 | } 1397 | }, 1398 | "b3a8424c0b584a37ad2ede748085425c": { 1399 | "model_module": "@jupyter-widgets/controls", 1400 | "model_module_version": "1.5.0", 1401 | "model_name": "ProgressStyleModel", 1402 | "state": { 1403 | "_model_module": "@jupyter-widgets/controls", 
1404 | "_model_module_version": "1.5.0", 1405 | "_model_name": "ProgressStyleModel", 1406 | "_view_count": null, 1407 | "_view_module": "@jupyter-widgets/base", 1408 | "_view_module_version": "1.2.0", 1409 | "_view_name": "StyleView", 1410 | "bar_color": null, 1411 | "description_width": "" 1412 | } 1413 | }, 1414 | "bbda5260ca1c450386f9191e9f9dde97": { 1415 | "model_module": "@jupyter-widgets/controls", 1416 | "model_module_version": "1.5.0", 1417 | "model_name": "ProgressStyleModel", 1418 | "state": { 1419 | "_model_module": "@jupyter-widgets/controls", 1420 | "_model_module_version": "1.5.0", 1421 | "_model_name": "ProgressStyleModel", 1422 | "_view_count": null, 1423 | "_view_module": "@jupyter-widgets/base", 1424 | "_view_module_version": "1.2.0", 1425 | "_view_name": "StyleView", 1426 | "bar_color": null, 1427 | "description_width": "" 1428 | } 1429 | }, 1430 | "c020b38c6d2c436e8b742fd87d3b8b89": { 1431 | "model_module": "@jupyter-widgets/base", 1432 | "model_module_version": "1.2.0", 1433 | "model_name": "LayoutModel", 1434 | "state": { 1435 | "_model_module": "@jupyter-widgets/base", 1436 | "_model_module_version": "1.2.0", 1437 | "_model_name": "LayoutModel", 1438 | "_view_count": null, 1439 | "_view_module": "@jupyter-widgets/base", 1440 | "_view_module_version": "1.2.0", 1441 | "_view_name": "LayoutView", 1442 | "align_content": null, 1443 | "align_items": null, 1444 | "align_self": null, 1445 | "border": null, 1446 | "bottom": null, 1447 | "display": null, 1448 | "flex": null, 1449 | "flex_flow": null, 1450 | "grid_area": null, 1451 | "grid_auto_columns": null, 1452 | "grid_auto_flow": null, 1453 | "grid_auto_rows": null, 1454 | "grid_column": null, 1455 | "grid_gap": null, 1456 | "grid_row": null, 1457 | "grid_template_areas": null, 1458 | "grid_template_columns": null, 1459 | "grid_template_rows": null, 1460 | "height": null, 1461 | "justify_content": null, 1462 | "justify_items": null, 1463 | "left": null, 1464 | "margin": null, 1465 | "max_height": null, 1466 | "max_width": null, 1467 | "min_height": null, 1468 | "min_width": null, 1469 | "object_fit": null, 1470 | "object_position": null, 1471 | "order": null, 1472 | "overflow": null, 1473 | "overflow_x": null, 1474 | "overflow_y": null, 1475 | "padding": null, 1476 | "right": null, 1477 | "top": null, 1478 | "visibility": null, 1479 | "width": null 1480 | } 1481 | }, 1482 | "c2d14fa4280c48e0ae04859b73c80781": { 1483 | "model_module": "@jupyter-widgets/controls", 1484 | "model_module_version": "1.5.0", 1485 | "model_name": "HBoxModel", 1486 | "state": { 1487 | "_dom_classes": [], 1488 | "_model_module": "@jupyter-widgets/controls", 1489 | "_model_module_version": "1.5.0", 1490 | "_model_name": "HBoxModel", 1491 | "_view_count": null, 1492 | "_view_module": "@jupyter-widgets/controls", 1493 | "_view_module_version": "1.5.0", 1494 | "_view_name": "HBoxView", 1495 | "box_style": "", 1496 | "children": [ 1497 | "IPY_MODEL_d3104837d9734834b7c87e87289b08df", 1498 | "IPY_MODEL_02b02005adf241a4a0be8173ca3a4aee", 1499 | "IPY_MODEL_b317ba38f2b145f9b0b49f523547684f" 1500 | ], 1501 | "layout": "IPY_MODEL_434340d109d1401d8868498a23b291cf" 1502 | } 1503 | }, 1504 | "c88027eb3e1c4771ab57366070ecd553": { 1505 | "model_module": "@jupyter-widgets/base", 1506 | "model_module_version": "1.2.0", 1507 | "model_name": "LayoutModel", 1508 | "state": { 1509 | "_model_module": "@jupyter-widgets/base", 1510 | "_model_module_version": "1.2.0", 1511 | "_model_name": "LayoutModel", 1512 | "_view_count": null, 1513 | "_view_module": 
"@jupyter-widgets/base", 1514 | "_view_module_version": "1.2.0", 1515 | "_view_name": "LayoutView", 1516 | "align_content": null, 1517 | "align_items": null, 1518 | "align_self": null, 1519 | "border": null, 1520 | "bottom": null, 1521 | "display": null, 1522 | "flex": null, 1523 | "flex_flow": null, 1524 | "grid_area": null, 1525 | "grid_auto_columns": null, 1526 | "grid_auto_flow": null, 1527 | "grid_auto_rows": null, 1528 | "grid_column": null, 1529 | "grid_gap": null, 1530 | "grid_row": null, 1531 | "grid_template_areas": null, 1532 | "grid_template_columns": null, 1533 | "grid_template_rows": null, 1534 | "height": null, 1535 | "justify_content": null, 1536 | "justify_items": null, 1537 | "left": null, 1538 | "margin": null, 1539 | "max_height": null, 1540 | "max_width": null, 1541 | "min_height": null, 1542 | "min_width": null, 1543 | "object_fit": null, 1544 | "object_position": null, 1545 | "order": null, 1546 | "overflow": null, 1547 | "overflow_x": null, 1548 | "overflow_y": null, 1549 | "padding": null, 1550 | "right": null, 1551 | "top": null, 1552 | "visibility": null, 1553 | "width": null 1554 | } 1555 | }, 1556 | "ca588157678e4cc09c3fd760676efd39": { 1557 | "model_module": "@jupyter-widgets/controls", 1558 | "model_module_version": "1.5.0", 1559 | "model_name": "DescriptionStyleModel", 1560 | "state": { 1561 | "_model_module": "@jupyter-widgets/controls", 1562 | "_model_module_version": "1.5.0", 1563 | "_model_name": "DescriptionStyleModel", 1564 | "_view_count": null, 1565 | "_view_module": "@jupyter-widgets/base", 1566 | "_view_module_version": "1.2.0", 1567 | "_view_name": "StyleView", 1568 | "description_width": "" 1569 | } 1570 | }, 1571 | "cb7d88a70af746f2ae31416b4b670c63": { 1572 | "model_module": "@jupyter-widgets/base", 1573 | "model_module_version": "1.2.0", 1574 | "model_name": "LayoutModel", 1575 | "state": { 1576 | "_model_module": "@jupyter-widgets/base", 1577 | "_model_module_version": "1.2.0", 1578 | "_model_name": "LayoutModel", 1579 | "_view_count": null, 1580 | "_view_module": "@jupyter-widgets/base", 1581 | "_view_module_version": "1.2.0", 1582 | "_view_name": "LayoutView", 1583 | "align_content": null, 1584 | "align_items": null, 1585 | "align_self": null, 1586 | "border": null, 1587 | "bottom": null, 1588 | "display": null, 1589 | "flex": null, 1590 | "flex_flow": null, 1591 | "grid_area": null, 1592 | "grid_auto_columns": null, 1593 | "grid_auto_flow": null, 1594 | "grid_auto_rows": null, 1595 | "grid_column": null, 1596 | "grid_gap": null, 1597 | "grid_row": null, 1598 | "grid_template_areas": null, 1599 | "grid_template_columns": null, 1600 | "grid_template_rows": null, 1601 | "height": null, 1602 | "justify_content": null, 1603 | "justify_items": null, 1604 | "left": null, 1605 | "margin": null, 1606 | "max_height": null, 1607 | "max_width": null, 1608 | "min_height": null, 1609 | "min_width": null, 1610 | "object_fit": null, 1611 | "object_position": null, 1612 | "order": null, 1613 | "overflow": null, 1614 | "overflow_x": null, 1615 | "overflow_y": null, 1616 | "padding": null, 1617 | "right": null, 1618 | "top": null, 1619 | "visibility": null, 1620 | "width": null 1621 | } 1622 | }, 1623 | "d3104837d9734834b7c87e87289b08df": { 1624 | "model_module": "@jupyter-widgets/controls", 1625 | "model_module_version": "1.5.0", 1626 | "model_name": "HTMLModel", 1627 | "state": { 1628 | "_dom_classes": [], 1629 | "_model_module": "@jupyter-widgets/controls", 1630 | "_model_module_version": "1.5.0", 1631 | "_model_name": "HTMLModel", 1632 | "_view_count": 
null, 1633 | "_view_module": "@jupyter-widgets/controls", 1634 | "_view_module_version": "1.5.0", 1635 | "_view_name": "HTMLView", 1636 | "description": "", 1637 | "description_tooltip": null, 1638 | "layout": "IPY_MODEL_2c95f5b81fc84ad698fe77b52cb84076", 1639 | "placeholder": "​", 1640 | "style": "IPY_MODEL_ca588157678e4cc09c3fd760676efd39", 1641 | "value": "Generating train split: 100%" 1642 | } 1643 | }, 1644 | "df75b255bfb04057b553830b59f0a153": { 1645 | "model_module": "@jupyter-widgets/controls", 1646 | "model_module_version": "1.5.0", 1647 | "model_name": "ProgressStyleModel", 1648 | "state": { 1649 | "_model_module": "@jupyter-widgets/controls", 1650 | "_model_module_version": "1.5.0", 1651 | "_model_name": "ProgressStyleModel", 1652 | "_view_count": null, 1653 | "_view_module": "@jupyter-widgets/base", 1654 | "_view_module_version": "1.2.0", 1655 | "_view_name": "StyleView", 1656 | "bar_color": null, 1657 | "description_width": "" 1658 | } 1659 | }, 1660 | "ebb7ee3fd084466f9667771a99e6e3b2": { 1661 | "model_module": "@jupyter-widgets/controls", 1662 | "model_module_version": "1.5.0", 1663 | "model_name": "HTMLModel", 1664 | "state": { 1665 | "_dom_classes": [], 1666 | "_model_module": "@jupyter-widgets/controls", 1667 | "_model_module_version": "1.5.0", 1668 | "_model_name": "HTMLModel", 1669 | "_view_count": null, 1670 | "_view_module": "@jupyter-widgets/controls", 1671 | "_view_module_version": "1.5.0", 1672 | "_view_name": "HTMLView", 1673 | "description": "", 1674 | "description_tooltip": null, 1675 | "layout": "IPY_MODEL_f0e5024d0d054c1eb8e01c4c8b027e79", 1676 | "placeholder": "​", 1677 | "style": "IPY_MODEL_937ee45f4d634d189c6d95c886e97bca", 1678 | "value": " 3.30M/3.30M [00:01<00:00, 2.77MB/s]" 1679 | } 1680 | }, 1681 | "ec2051bf0e9343d394e8a0ecb4fd5ec8": { 1682 | "model_module": "@jupyter-widgets/base", 1683 | "model_module_version": "1.2.0", 1684 | "model_name": "LayoutModel", 1685 | "state": { 1686 | "_model_module": "@jupyter-widgets/base", 1687 | "_model_module_version": "1.2.0", 1688 | "_model_name": "LayoutModel", 1689 | "_view_count": null, 1690 | "_view_module": "@jupyter-widgets/base", 1691 | "_view_module_version": "1.2.0", 1692 | "_view_name": "LayoutView", 1693 | "align_content": null, 1694 | "align_items": null, 1695 | "align_self": null, 1696 | "border": null, 1697 | "bottom": null, 1698 | "display": null, 1699 | "flex": null, 1700 | "flex_flow": null, 1701 | "grid_area": null, 1702 | "grid_auto_columns": null, 1703 | "grid_auto_flow": null, 1704 | "grid_auto_rows": null, 1705 | "grid_column": null, 1706 | "grid_gap": null, 1707 | "grid_row": null, 1708 | "grid_template_areas": null, 1709 | "grid_template_columns": null, 1710 | "grid_template_rows": null, 1711 | "height": null, 1712 | "justify_content": null, 1713 | "justify_items": null, 1714 | "left": null, 1715 | "margin": null, 1716 | "max_height": null, 1717 | "max_width": null, 1718 | "min_height": null, 1719 | "min_width": null, 1720 | "object_fit": null, 1721 | "object_position": null, 1722 | "order": null, 1723 | "overflow": null, 1724 | "overflow_x": null, 1725 | "overflow_y": null, 1726 | "padding": null, 1727 | "right": null, 1728 | "top": null, 1729 | "visibility": null, 1730 | "width": null 1731 | } 1732 | }, 1733 | "f0e5024d0d054c1eb8e01c4c8b027e79": { 1734 | "model_module": "@jupyter-widgets/base", 1735 | "model_module_version": "1.2.0", 1736 | "model_name": "LayoutModel", 1737 | "state": { 1738 | "_model_module": "@jupyter-widgets/base", 1739 | "_model_module_version": "1.2.0", 1740 | 
"_model_name": "LayoutModel", 1741 | "_view_count": null, 1742 | "_view_module": "@jupyter-widgets/base", 1743 | "_view_module_version": "1.2.0", 1744 | "_view_name": "LayoutView", 1745 | "align_content": null, 1746 | "align_items": null, 1747 | "align_self": null, 1748 | "border": null, 1749 | "bottom": null, 1750 | "display": null, 1751 | "flex": null, 1752 | "flex_flow": null, 1753 | "grid_area": null, 1754 | "grid_auto_columns": null, 1755 | "grid_auto_flow": null, 1756 | "grid_auto_rows": null, 1757 | "grid_column": null, 1758 | "grid_gap": null, 1759 | "grid_row": null, 1760 | "grid_template_areas": null, 1761 | "grid_template_columns": null, 1762 | "grid_template_rows": null, 1763 | "height": null, 1764 | "justify_content": null, 1765 | "justify_items": null, 1766 | "left": null, 1767 | "margin": null, 1768 | "max_height": null, 1769 | "max_width": null, 1770 | "min_height": null, 1771 | "min_width": null, 1772 | "object_fit": null, 1773 | "object_position": null, 1774 | "order": null, 1775 | "overflow": null, 1776 | "overflow_x": null, 1777 | "overflow_y": null, 1778 | "padding": null, 1779 | "right": null, 1780 | "top": null, 1781 | "visibility": null, 1782 | "width": null 1783 | } 1784 | }, 1785 | "f7359467b0214c5385de8ee4334f7ba3": { 1786 | "model_module": "@jupyter-widgets/controls", 1787 | "model_module_version": "1.5.0", 1788 | "model_name": "HTMLModel", 1789 | "state": { 1790 | "_dom_classes": [], 1791 | "_model_module": "@jupyter-widgets/controls", 1792 | "_model_module_version": "1.5.0", 1793 | "_model_name": "HTMLModel", 1794 | "_view_count": null, 1795 | "_view_module": "@jupyter-widgets/controls", 1796 | "_view_module_version": "1.5.0", 1797 | "_view_name": "HTMLView", 1798 | "description": "", 1799 | "description_tooltip": null, 1800 | "layout": "IPY_MODEL_05240e68c55a458286f43967e7f90889", 1801 | "placeholder": "​", 1802 | "style": "IPY_MODEL_8cfa6df0ee654643bfdb4a3825e8fbbe", 1803 | "value": "Downloading readme: 100%" 1804 | } 1805 | }, 1806 | "f74bdeb79a224de8b1c85f4ca8657331": { 1807 | "model_module": "@jupyter-widgets/controls", 1808 | "model_module_version": "1.5.0", 1809 | "model_name": "HTMLModel", 1810 | "state": { 1811 | "_dom_classes": [], 1812 | "_model_module": "@jupyter-widgets/controls", 1813 | "_model_module_version": "1.5.0", 1814 | "_model_name": "HTMLModel", 1815 | "_view_count": null, 1816 | "_view_module": "@jupyter-widgets/controls", 1817 | "_view_module_version": "1.5.0", 1818 | "_view_name": "HTMLView", 1819 | "description": "", 1820 | "description_tooltip": null, 1821 | "layout": "IPY_MODEL_888a323362ae4daeac99915bcb3dcf10", 1822 | "placeholder": "​", 1823 | "style": "IPY_MODEL_4d0e364e9f274e8ea7447e4e01c7f28f", 1824 | "value": "Downloading metadata: 100%" 1825 | } 1826 | } 1827 | } 1828 | } 1829 | }, 1830 | "nbformat": 4, 1831 | "nbformat_minor": 0 1832 | } 1833 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-transformer 2 | Attention is all you need implementation 3 | 4 | YouTube video with full step-by-step implementation: https://www.youtube.com/watch?v=ISNdQcPhsts 5 | 6 | -------------------------------------------------------------------------------- /attention_visual.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import 
torch\n", 10 | "import torch.nn as nn\n", 11 | "from model import Transformer\n", 12 | "from config import get_config, get_weights_file_path\n", 13 | "from train import get_model, get_ds, greedy_decode\n", 14 | "import altair as alt\n", 15 | "import pandas as pd\n", 16 | "import numpy as np\n", 17 | "import warnings\n", 18 | "warnings.filterwarnings(\"ignore\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# Define the device\n", 28 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 29 | "print(\"Using device:\", device)" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "config = get_config()\n", 39 | "train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)\n", 40 | "model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)\n", 41 | "\n", 42 | "# Load the pretrained weights\n", 43 | "model_filename = get_weights_file_path(config, f\"29\")\n", 44 | "state = torch.load(model_filename)\n", 45 | "model.load_state_dict(state['model_state_dict'])" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "def load_next_batch():\n", 55 | " # Load a sample batch from the validation set\n", 56 | " batch = next(iter(val_dataloader))\n", 57 | " encoder_input = batch[\"encoder_input\"].to(device)\n", 58 | " encoder_mask = batch[\"encoder_mask\"].to(device)\n", 59 | " decoder_input = batch[\"decoder_input\"].to(device)\n", 60 | " decoder_mask = batch[\"decoder_mask\"].to(device)\n", 61 | "\n", 62 | " encoder_input_tokens = [vocab_src.id_to_token(idx) for idx in encoder_input[0].cpu().numpy()]\n", 63 | " decoder_input_tokens = [vocab_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()]\n", 64 | "\n", 65 | " # check that the batch size is 1\n", 66 | " assert encoder_input.size(\n", 67 | " 0) == 1, \"Batch size must be 1 for validation\"\n", 68 | "\n", 69 | " model_out = greedy_decode(\n", 70 | " model, encoder_input, encoder_mask, vocab_src, vocab_tgt, config['seq_len'], device)\n", 71 | " \n", 72 | " return batch, encoder_input_tokens, decoder_input_tokens" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "def mtx2df(m, max_row, max_col, row_tokens, col_tokens):\n", 82 | " return pd.DataFrame(\n", 83 | " [\n", 84 | " (\n", 85 | " r,\n", 86 | " c,\n", 87 | " float(m[r, c]),\n", 88 | " \"%.3d %s\" % (r, row_tokens[r] if len(row_tokens) > r else \"\"),\n", 89 | " \"%.3d %s\" % (c, col_tokens[c] if len(col_tokens) > c else \"\"),\n", 90 | " )\n", 91 | " for r in range(m.shape[0])\n", 92 | " for c in range(m.shape[1])\n", 93 | " if r < max_row and c < max_col\n", 94 | " ],\n", 95 | " columns=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n", 96 | " )\n", 97 | "\n", 98 | "def get_attn_map(attn_type: str, layer: int, head: int):\n", 99 | " if attn_type == \"encoder\":\n", 100 | " attn = model.encoder.layers[layer].self_attention_block.attention_scores\n", 101 | " elif attn_type == \"decoder\":\n", 102 | " attn = model.decoder.layers[layer].self_attention_block.attention_scores\n", 103 | " elif attn_type == \"encoder-decoder\":\n", 104 | " attn = model.decoder.layers[layer].cross_attention_block.attention_scores\n", 105 | " return attn[0, 
head].data\n", 106 | "\n", 107 | "def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):\n", 108 | " df = mtx2df(\n", 109 | " get_attn_map(attn_type, layer, head),\n", 110 | " max_sentence_len,\n", 111 | " max_sentence_len,\n", 112 | " row_tokens,\n", 113 | " col_tokens,\n", 114 | " )\n", 115 | " return (\n", 116 | " alt.Chart(data=df)\n", 117 | " .mark_rect()\n", 118 | " .encode(\n", 119 | " x=alt.X(\"col_token\", axis=alt.Axis(title=\"\")),\n", 120 | " y=alt.Y(\"row_token\", axis=alt.Axis(title=\"\")),\n", 121 | " color=\"value\",\n", 122 | " tooltip=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n", 123 | " )\n", 124 | " #.title(f\"Layer {layer} Head {head}\")\n", 125 | " .properties(height=400, width=400, title=f\"Layer {layer} Head {head}\")\n", 126 | " .interactive()\n", 127 | " )\n", 128 | "\n", 129 | "def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):\n", 130 | " charts = []\n", 131 | " for layer in layers:\n", 132 | " rowCharts = []\n", 133 | " for head in heads:\n", 134 | " rowCharts.append(attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len))\n", 135 | " charts.append(alt.hconcat(*rowCharts))\n", 136 | " return alt.vconcat(*charts)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()\n", 146 | "print(f'Source: {batch[\"src_text\"][0]}')\n", 147 | "print(f'Target: {batch[\"tgt_text\"][0]}')\n", 148 | "sentence_len = encoder_input_tokens.index(\"[PAD]\")" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "layers = [0, 1, 2]\n", 158 | "heads = [0, 1, 2, 3, 4, 5, 6, 7]\n", 159 | "\n", 160 | "# Encoder Self-Attention\n", 161 | "get_all_attention_maps(\"encoder\", layers, heads, encoder_input_tokens, encoder_input_tokens, min(20, sentence_len))\n" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "# Encoder Self-Attention\n", 171 | "get_all_attention_maps(\"decoder\", layers, heads, decoder_input_tokens, decoder_input_tokens, min(20, sentence_len))" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "# Encoder Self-Attention\n", 181 | "get_all_attention_maps(\"encoder-decoder\", layers, heads, encoder_input_tokens, decoder_input_tokens, min(20, sentence_len))" 182 | ] 183 | } 184 | ], 185 | "metadata": { 186 | "kernelspec": { 187 | "display_name": "transformer", 188 | "language": "python", 189 | "name": "python3" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 3 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | "pygments_lexer": "ipython3", 201 | "version": "3.10.6" 202 | }, 203 | "orig_nbformat": 4 204 | }, 205 | "nbformat": 4, 206 | "nbformat_minor": 2 207 | } 208 | -------------------------------------------------------------------------------- /conda.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | 
# platform: linux-64 4 | @EXPLICIT 5 | https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.conda 6 | https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2023.08.22-h06a4308_0.conda 7 | https://repo.anaconda.com/pkgs/main/linux-64/ld_impl_linux-64-2.38-h1181459_1.conda 8 | https://repo.anaconda.com/pkgs/main/linux-64/libstdcxx-ng-11.2.0-h1234567_1.conda 9 | https://repo.anaconda.com/pkgs/main/noarch/tzdata-2023c-h04d1e81_0.conda 10 | https://repo.anaconda.com/pkgs/main/linux-64/libgomp-11.2.0-h1234567_1.conda 11 | https://repo.anaconda.com/pkgs/main/linux-64/_openmp_mutex-5.1-1_gnu.conda 12 | https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-11.2.0-h1234567_1.conda 13 | https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.4.4-h6a678d5_0.conda 14 | https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.4-h6a678d5_0.conda 15 | https://repo.anaconda.com/pkgs/main/linux-64/openssl-3.0.12-h7f8727e_0.conda 16 | https://repo.anaconda.com/pkgs/main/linux-64/xz-5.4.5-h5eee18b_0.conda 17 | https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.13-h5eee18b_0.conda 18 | https://repo.anaconda.com/pkgs/main/linux-64/readline-8.2-h5eee18b_0.conda 19 | https://repo.anaconda.com/pkgs/main/linux-64/tk-8.6.12-h1ccaba5_0.conda 20 | https://repo.anaconda.com/pkgs/main/linux-64/sqlite-3.41.2-h5eee18b_0.conda 21 | https://repo.anaconda.com/pkgs/main/linux-64/python-3.9.18-h955ad1f_0.conda 22 | https://repo.anaconda.com/pkgs/main/linux-64/setuptools-68.0.0-py39h06a4308_0.conda 23 | https://repo.anaconda.com/pkgs/main/linux-64/wheel-0.41.2-py39h06a4308_0.conda 24 | https://repo.anaconda.com/pkgs/main/linux-64/pip-23.3.1-py39h06a4308_0.conda 25 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | def get_config(): 4 | return { 5 | "batch_size": 8, 6 | "num_epochs": 20, 7 | "lr": 10**-4, 8 | "seq_len": 350, 9 | "d_model": 512, 10 | "datasource": 'opus_books', 11 | "lang_src": "en", 12 | "lang_tgt": "it", 13 | "model_folder": "weights", 14 | "model_basename": "tmodel_", 15 | "preload": "latest", 16 | "tokenizer_file": "tokenizer_{0}.json", 17 | "experiment_name": "runs/tmodel" 18 | } 19 | 20 | def get_weights_file_path(config, epoch: str): 21 | model_folder = f"{config['datasource']}_{config['model_folder']}" 22 | model_filename = f"{config['model_basename']}{epoch}.pt" 23 | return str(Path('.') / model_folder / model_filename) 24 | 25 | # Find the latest weights file in the weights folder 26 | def latest_weights_file_path(config): 27 | model_folder = f"{config['datasource']}_{config['model_folder']}" 28 | model_filename = f"{config['model_basename']}*" 29 | weights_files = list(Path(model_folder).glob(model_filename)) 30 | if len(weights_files) == 0: 31 | return None 32 | weights_files.sort() 33 | return str(weights_files[-1]) 34 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import Dataset 4 | 5 | class BilingualDataset(Dataset): 6 | 7 | def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len): 8 | super().__init__() 9 | self.seq_len = seq_len 10 | 11 | self.ds = ds 12 | self.tokenizer_src = tokenizer_src 13 | self.tokenizer_tgt = tokenizer_tgt 14 | self.src_lang = src_lang 15 | self.tgt_lang = tgt_lang 16 | 
17 | self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64) 18 | self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64) 19 | self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64) 20 | 21 | def __len__(self): 22 | return len(self.ds) 23 | 24 | def __getitem__(self, idx): 25 | src_target_pair = self.ds[idx] 26 | src_text = src_target_pair['translation'][self.src_lang] 27 | tgt_text = src_target_pair['translation'][self.tgt_lang] 28 | 29 | # Transform the text into tokens 30 | enc_input_tokens = self.tokenizer_src.encode(src_text).ids 31 | dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids 32 | 33 | # Add sos, eos and padding to each sentence 34 | enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2 # We will add and 35 | # We will only add , and only on the label 36 | dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1 37 | 38 | # Make sure the number of padding tokens is not negative. If it is, the sentence is too long 39 | if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0: 40 | raise ValueError("Sentence is too long") 41 | 42 | # Add and token 43 | encoder_input = torch.cat( 44 | [ 45 | self.sos_token, 46 | torch.tensor(enc_input_tokens, dtype=torch.int64), 47 | self.eos_token, 48 | torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64), 49 | ], 50 | dim=0, 51 | ) 52 | 53 | # Add only token 54 | decoder_input = torch.cat( 55 | [ 56 | self.sos_token, 57 | torch.tensor(dec_input_tokens, dtype=torch.int64), 58 | torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64), 59 | ], 60 | dim=0, 61 | ) 62 | 63 | # Add only token 64 | label = torch.cat( 65 | [ 66 | torch.tensor(dec_input_tokens, dtype=torch.int64), 67 | self.eos_token, 68 | torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64), 69 | ], 70 | dim=0, 71 | ) 72 | 73 | # Double check the size of the tensors to make sure they are all seq_len long 74 | assert encoder_input.size(0) == self.seq_len 75 | assert decoder_input.size(0) == self.seq_len 76 | assert label.size(0) == self.seq_len 77 | 78 | return { 79 | "encoder_input": encoder_input, # (seq_len) 80 | "decoder_input": decoder_input, # (seq_len) 81 | "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len) 82 | "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len) & (1, seq_len, seq_len), 83 | "label": label, # (seq_len) 84 | "src_text": src_text, 85 | "tgt_text": tgt_text, 86 | } 87 | 88 | def causal_mask(size): 89 | mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int) 90 | return mask == 0 -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | class LayerNormalization(nn.Module): 6 | 7 | def __init__(self, features: int, eps:float=10**-6) -> None: 8 | super().__init__() 9 | self.eps = eps 10 | self.alpha = nn.Parameter(torch.ones(features)) # alpha is a learnable parameter 11 | self.bias = nn.Parameter(torch.zeros(features)) # bias is a learnable parameter 12 | 13 | def forward(self, x): 14 | # x: (batch, seq_len, hidden_size) 15 | # Keep the dimension for broadcasting 16 | mean = x.mean(dim = -1, keepdim = True) # (batch, seq_len, 1) 17 | # Keep the dimension for 
broadcasting 18 | std = x.std(dim = -1, keepdim = True) # (batch, seq_len, 1) 19 | # eps is to prevent dividing by zero or when std is very small 20 | return self.alpha * (x - mean) / (std + self.eps) + self.bias 21 | 22 | class FeedForwardBlock(nn.Module): 23 | 24 | def __init__(self, d_model: int, d_ff: int, dropout: float) -> None: 25 | super().__init__() 26 | self.linear_1 = nn.Linear(d_model, d_ff) # w1 and b1 27 | self.dropout = nn.Dropout(dropout) 28 | self.linear_2 = nn.Linear(d_ff, d_model) # w2 and b2 29 | 30 | def forward(self, x): 31 | # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model) 32 | return self.linear_2(self.dropout(torch.relu(self.linear_1(x)))) 33 | 34 | class InputEmbeddings(nn.Module): 35 | 36 | def __init__(self, d_model: int, vocab_size: int) -> None: 37 | super().__init__() 38 | self.d_model = d_model 39 | self.vocab_size = vocab_size 40 | self.embedding = nn.Embedding(vocab_size, d_model) 41 | 42 | def forward(self, x): 43 | # (batch, seq_len) --> (batch, seq_len, d_model) 44 | # Multiply by sqrt(d_model) to scale the embeddings according to the paper 45 | return self.embedding(x) * math.sqrt(self.d_model) 46 | 47 | class PositionalEncoding(nn.Module): 48 | 49 | def __init__(self, d_model: int, seq_len: int, dropout: float) -> None: 50 | super().__init__() 51 | self.d_model = d_model 52 | self.seq_len = seq_len 53 | self.dropout = nn.Dropout(dropout) 54 | # Create a matrix of shape (seq_len, d_model) 55 | pe = torch.zeros(seq_len, d_model) 56 | # Create a vector of shape (seq_len) 57 | position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1) 58 | # Create a vector of shape (d_model) 59 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model / 2) 60 | # Apply sine to even indices 61 | pe[:, 0::2] = torch.sin(position * div_term) # sin(position * (10000 ** (2i / d_model)) 62 | # Apply cosine to odd indices 63 | pe[:, 1::2] = torch.cos(position * div_term) # cos(position * (10000 ** (2i / d_model)) 64 | # Add a batch dimension to the positional encoding 65 | pe = pe.unsqueeze(0) # (1, seq_len, d_model) 66 | # Register the positional encoding as a buffer 67 | self.register_buffer('pe', pe) 68 | 69 | def forward(self, x): 70 | x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False) # (batch, seq_len, d_model) 71 | return self.dropout(x) 72 | 73 | class ResidualConnection(nn.Module): 74 | 75 | def __init__(self, features: int, dropout: float) -> None: 76 | super().__init__() 77 | self.dropout = nn.Dropout(dropout) 78 | self.norm = LayerNormalization(features) 79 | 80 | def forward(self, x, sublayer): 81 | return x + self.dropout(sublayer(self.norm(x))) 82 | 83 | class MultiHeadAttentionBlock(nn.Module): 84 | 85 | def __init__(self, d_model: int, h: int, dropout: float) -> None: 86 | super().__init__() 87 | self.d_model = d_model # Embedding vector size 88 | self.h = h # Number of heads 89 | # Make sure d_model is divisible by h 90 | assert d_model % h == 0, "d_model is not divisible by h" 91 | 92 | self.d_k = d_model // h # Dimension of vector seen by each head 93 | self.w_q = nn.Linear(d_model, d_model, bias=False) # Wq 94 | self.w_k = nn.Linear(d_model, d_model, bias=False) # Wk 95 | self.w_v = nn.Linear(d_model, d_model, bias=False) # Wv 96 | self.w_o = nn.Linear(d_model, d_model, bias=False) # Wo 97 | self.dropout = nn.Dropout(dropout) 98 | 99 | @staticmethod 100 | def attention(query, key, value, mask, dropout: nn.Dropout): 101 | d_k = query.shape[-1] 
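        # The formula applied in the lines below is scaled dot-product attention
        # (Vaswani et al., 2017), computed independently for each head:
        #   Attention(Q, K, V) = softmax(Q @ K.T / sqrt(d_k)) @ V
        # with the padding/causal mask applied to the scores before the softmax.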
102 | # Just apply the formula from the paper 103 | # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len) 104 | attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k) 105 | if mask is not None: 106 | # Write a very low value (indicating -inf) to the positions where mask == 0 107 | attention_scores.masked_fill_(mask == 0, -1e9) 108 | attention_scores = attention_scores.softmax(dim=-1) # (batch, h, seq_len, seq_len) # Apply softmax 109 | if dropout is not None: 110 | attention_scores = dropout(attention_scores) 111 | # (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k) 112 | # return attention scores which can be used for visualization 113 | return (attention_scores @ value), attention_scores 114 | 115 | def forward(self, q, k, v, mask): 116 | query = self.w_q(q) # (batch, seq_len, d_model) --> (batch, seq_len, d_model) 117 | key = self.w_k(k) # (batch, seq_len, d_model) --> (batch, seq_len, d_model) 118 | value = self.w_v(v) # (batch, seq_len, d_model) --> (batch, seq_len, d_model) 119 | 120 | # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k) 121 | query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2) 122 | key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2) 123 | value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2) 124 | 125 | # Calculate attention 126 | x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout) 127 | 128 | # Combine all the heads together 129 | # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model) 130 | x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k) 131 | 132 | # Multiply by Wo 133 | # (batch, seq_len, d_model) --> (batch, seq_len, d_model) 134 | return self.w_o(x) 135 | 136 | class EncoderBlock(nn.Module): 137 | 138 | def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None: 139 | super().__init__() 140 | self.self_attention_block = self_attention_block 141 | self.feed_forward_block = feed_forward_block 142 | self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)]) 143 | 144 | def forward(self, x, src_mask): 145 | x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask)) 146 | x = self.residual_connections[1](x, self.feed_forward_block) 147 | return x 148 | 149 | class Encoder(nn.Module): 150 | 151 | def __init__(self, features: int, layers: nn.ModuleList) -> None: 152 | super().__init__() 153 | self.layers = layers 154 | self.norm = LayerNormalization(features) 155 | 156 | def forward(self, x, mask): 157 | for layer in self.layers: 158 | x = layer(x, mask) 159 | return self.norm(x) 160 | 161 | class DecoderBlock(nn.Module): 162 | 163 | def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None: 164 | super().__init__() 165 | self.self_attention_block = self_attention_block 166 | self.cross_attention_block = cross_attention_block 167 | self.feed_forward_block = feed_forward_block 168 | self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)]) 169 | 170 | def forward(self, x, encoder_output, src_mask, tgt_mask): 171 | x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, 
x, x, tgt_mask)) 172 | x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask)) 173 | x = self.residual_connections[2](x, self.feed_forward_block) 174 | return x 175 | 176 | class Decoder(nn.Module): 177 | 178 | def __init__(self, features: int, layers: nn.ModuleList) -> None: 179 | super().__init__() 180 | self.layers = layers 181 | self.norm = LayerNormalization(features) 182 | 183 | def forward(self, x, encoder_output, src_mask, tgt_mask): 184 | for layer in self.layers: 185 | x = layer(x, encoder_output, src_mask, tgt_mask) 186 | return self.norm(x) 187 | 188 | class ProjectionLayer(nn.Module): 189 | 190 | def __init__(self, d_model, vocab_size) -> None: 191 | super().__init__() 192 | self.proj = nn.Linear(d_model, vocab_size) 193 | 194 | def forward(self, x) -> None: 195 | # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size) 196 | return self.proj(x) 197 | 198 | class Transformer(nn.Module): 199 | 200 | def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None: 201 | super().__init__() 202 | self.encoder = encoder 203 | self.decoder = decoder 204 | self.src_embed = src_embed 205 | self.tgt_embed = tgt_embed 206 | self.src_pos = src_pos 207 | self.tgt_pos = tgt_pos 208 | self.projection_layer = projection_layer 209 | 210 | def encode(self, src, src_mask): 211 | # (batch, seq_len, d_model) 212 | src = self.src_embed(src) 213 | src = self.src_pos(src) 214 | return self.encoder(src, src_mask) 215 | 216 | def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor): 217 | # (batch, seq_len, d_model) 218 | tgt = self.tgt_embed(tgt) 219 | tgt = self.tgt_pos(tgt) 220 | return self.decoder(tgt, encoder_output, src_mask, tgt_mask) 221 | 222 | def project(self, x): 223 | # (batch, seq_len, vocab_size) 224 | return self.projection_layer(x) 225 | 226 | def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int=512, N: int=6, h: int=8, dropout: float=0.1, d_ff: int=2048) -> Transformer: 227 | # Create the embedding layers 228 | src_embed = InputEmbeddings(d_model, src_vocab_size) 229 | tgt_embed = InputEmbeddings(d_model, tgt_vocab_size) 230 | 231 | # Create the positional encoding layers 232 | src_pos = PositionalEncoding(d_model, src_seq_len, dropout) 233 | tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout) 234 | 235 | # Create the encoder blocks 236 | encoder_blocks = [] 237 | for _ in range(N): 238 | encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout) 239 | feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout) 240 | encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout) 241 | encoder_blocks.append(encoder_block) 242 | 243 | # Create the decoder blocks 244 | decoder_blocks = [] 245 | for _ in range(N): 246 | decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout) 247 | decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout) 248 | feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout) 249 | decoder_block = DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout) 250 | decoder_blocks.append(decoder_block) 251 | 252 | # Create the encoder and decoder 253 | encoder = Encoder(d_model, 
nn.ModuleList(encoder_blocks)) 254 | decoder = Decoder(d_model, nn.ModuleList(decoder_blocks)) 255 | 256 | # Create the projection layer 257 | projection_layer = ProjectionLayer(d_model, tgt_vocab_size) 258 | 259 | # Create the transformer 260 | transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer) 261 | 262 | # Initialize the parameters 263 | for p in transformer.parameters(): 264 | if p.dim() > 1: 265 | nn.init.xavier_uniform_(p) 266 | 267 | return transformer -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ## Use python 3.9 2 | 3 | torch==2.0.1 4 | torchvision==0.15.2 5 | torchaudio==2.0.2 6 | torchtext==0.15.2 7 | datasets==2.15.0 8 | tokenizers==0.13.3 9 | torchmetrics==1.0.3 10 | tensorboard==2.13.0 11 | altair==5.1.1 12 | wandb==0.15.9 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from model import build_transformer 2 | from dataset import BilingualDataset, causal_mask 3 | from config import get_config, get_weights_file_path, latest_weights_file_path 4 | 5 | import torchtext.datasets as datasets 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils.data import Dataset, DataLoader, random_split 9 | from torch.optim.lr_scheduler import LambdaLR 10 | 11 | import warnings 12 | from tqdm import tqdm 13 | import os 14 | from pathlib import Path 15 | 16 | # Huggingface datasets and tokenizers 17 | from datasets import load_dataset 18 | from tokenizers import Tokenizer 19 | from tokenizers.models import WordLevel 20 | from tokenizers.trainers import WordLevelTrainer 21 | from tokenizers.pre_tokenizers import Whitespace 22 | 23 | import torchmetrics 24 | from torch.utils.tensorboard import SummaryWriter 25 | 26 | def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device): 27 | sos_idx = tokenizer_tgt.token_to_id('[SOS]') 28 | eos_idx = tokenizer_tgt.token_to_id('[EOS]') 29 | 30 | # Precompute the encoder output and reuse it for every step 31 | encoder_output = model.encode(source, source_mask) 32 | # Initialize the decoder input with the sos token 33 | decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device) 34 | while True: 35 | if decoder_input.size(1) == max_len: 36 | break 37 | 38 | # build mask for target 39 | decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device) 40 | 41 | # calculate output 42 | out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask) 43 | 44 | # get next token 45 | prob = model.project(out[:, -1]) 46 | _, next_word = torch.max(prob, dim=1) 47 | decoder_input = torch.cat( 48 | [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1 49 | ) 50 | 51 | if next_word == eos_idx: 52 | break 53 | 54 | return decoder_input.squeeze(0) 55 | 56 | 57 | def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2): 58 | model.eval() 59 | count = 0 60 | 61 | source_texts = [] 62 | expected = [] 63 | predicted = [] 64 | 65 | try: 66 | # get the console window width 67 | with os.popen('stty size', 'r') as console: 68 | _, console_width = console.read().split() 69 | console_width = int(console_width) 70 | except: 71 | # If we can't get the console width, use 
80 as default 72 | console_width = 80 73 | 74 | with torch.no_grad(): 75 | for batch in validation_ds: 76 | count += 1 77 | encoder_input = batch["encoder_input"].to(device) # (b, seq_len) 78 | encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len) 79 | 80 | # check that the batch size is 1 81 | assert encoder_input.size( 82 | 0) == 1, "Batch size must be 1 for validation" 83 | 84 | model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device) 85 | 86 | source_text = batch["src_text"][0] 87 | target_text = batch["tgt_text"][0] 88 | model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy()) 89 | 90 | source_texts.append(source_text) 91 | expected.append(target_text) 92 | predicted.append(model_out_text) 93 | 94 | # Print the source, target and model output 95 | print_msg('-'*console_width) 96 | print_msg(f"{f'SOURCE: ':>12}{source_text}") 97 | print_msg(f"{f'TARGET: ':>12}{target_text}") 98 | print_msg(f"{f'PREDICTED: ':>12}{model_out_text}") 99 | 100 | if count == num_examples: 101 | print_msg('-'*console_width) 102 | break 103 | 104 | if writer: 105 | # Evaluate the character error rate 106 | # Compute the char error rate 107 | metric = torchmetrics.CharErrorRate() 108 | cer = metric(predicted, expected) 109 | writer.add_scalar('validation cer', cer, global_step) 110 | writer.flush() 111 | 112 | # Compute the word error rate 113 | metric = torchmetrics.WordErrorRate() 114 | wer = metric(predicted, expected) 115 | writer.add_scalar('validation wer', wer, global_step) 116 | writer.flush() 117 | 118 | # Compute the BLEU metric 119 | metric = torchmetrics.BLEUScore() 120 | bleu = metric(predicted, expected) 121 | writer.add_scalar('validation BLEU', bleu, global_step) 122 | writer.flush() 123 | 124 | def get_all_sentences(ds, lang): 125 | for item in ds: 126 | yield item['translation'][lang] 127 | 128 | def get_or_build_tokenizer(config, ds, lang): 129 | tokenizer_path = Path(config['tokenizer_file'].format(lang)) 130 | if not Path.exists(tokenizer_path): 131 | # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour 132 | tokenizer = Tokenizer(WordLevel(unk_token="[UNK]")) 133 | tokenizer.pre_tokenizer = Whitespace() 134 | trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2) 135 | tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer) 136 | tokenizer.save(str(tokenizer_path)) 137 | else: 138 | tokenizer = Tokenizer.from_file(str(tokenizer_path)) 139 | return tokenizer 140 | 141 | def get_ds(config): 142 | # It only has the train split, so we divide it overselves 143 | ds_raw = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='train') 144 | 145 | # Build tokenizers 146 | tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src']) 147 | tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt']) 148 | 149 | # Keep 90% for training, 10% for validation 150 | train_ds_size = int(0.9 * len(ds_raw)) 151 | val_ds_size = len(ds_raw) - train_ds_size 152 | train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size]) 153 | 154 | train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len']) 155 | val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len']) 156 | 157 | # Find the maximum length of each sentence in the source and 
target sentence 158 | max_len_src = 0 159 | max_len_tgt = 0 160 | 161 | for item in ds_raw: 162 | src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids 163 | tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids 164 | max_len_src = max(max_len_src, len(src_ids)) 165 | max_len_tgt = max(max_len_tgt, len(tgt_ids)) 166 | 167 | print(f'Max length of source sentence: {max_len_src}') 168 | print(f'Max length of target sentence: {max_len_tgt}') 169 | 170 | 171 | train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True) 172 | val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True) 173 | 174 | return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt 175 | 176 | def get_model(config, vocab_src_len, vocab_tgt_len): 177 | model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model']) 178 | return model 179 | 180 | def train_model(config): 181 | # Define the device 182 | device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu" 183 | print("Using device:", device) 184 | if (device == 'cuda'): 185 | print(f"Device name: {torch.cuda.get_device_name(device.index)}") 186 | print(f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3} GB") 187 | elif (device == 'mps'): 188 | print(f"Device name: ") 189 | else: 190 | print("NOTE: If you have a GPU, consider using it for training.") 191 | print(" On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc") 192 | print(" On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu") 193 | device = torch.device(device) 194 | 195 | # Make sure the weights folder exists 196 | Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True) 197 | 198 | train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config) 199 | model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device) 200 | # Tensorboard 201 | writer = SummaryWriter(config['experiment_name']) 202 | 203 | optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9) 204 | 205 | # If the user specified a model to preload before training, load it 206 | initial_epoch = 0 207 | global_step = 0 208 | preload = config['preload'] 209 | model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None 210 | if model_filename: 211 | print(f'Preloading model {model_filename}') 212 | state = torch.load(model_filename) 213 | model.load_state_dict(state['model_state_dict']) 214 | initial_epoch = state['epoch'] + 1 215 | optimizer.load_state_dict(state['optimizer_state_dict']) 216 | global_step = state['global_step'] 217 | else: 218 | print('No model to preload, starting from scratch') 219 | 220 | loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device) 221 | 222 | for epoch in range(initial_epoch, config['num_epochs']): 223 | torch.cuda.empty_cache() 224 | model.train() 225 | batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}") 226 | for batch in batch_iterator: 227 | 228 | encoder_input = batch['encoder_input'].to(device) # (b, seq_len) 229 | decoder_input = batch['decoder_input'].to(device) # (B, seq_len) 230 
| encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len) 231 | decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len) 232 | 233 | # Run the tensors through the encoder, decoder and the projection layer 234 | encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model) 235 | decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model) 236 | proj_output = model.project(decoder_output) # (B, seq_len, vocab_size) 237 | 238 | # Compare the output with the label 239 | label = batch['label'].to(device) # (B, seq_len) 240 | 241 | # Compute the loss using a simple cross entropy 242 | loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1)) 243 | batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"}) 244 | 245 | # Log the loss 246 | writer.add_scalar('train loss', loss.item(), global_step) 247 | writer.flush() 248 | 249 | # Backpropagate the loss 250 | loss.backward() 251 | 252 | # Update the weights 253 | optimizer.step() 254 | optimizer.zero_grad(set_to_none=True) 255 | 256 | global_step += 1 257 | 258 | # Run validation at the end of every epoch 259 | run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer) 260 | 261 | # Save the model at the end of every epoch 262 | model_filename = get_weights_file_path(config, f"{epoch:02d}") 263 | torch.save({ 264 | 'epoch': epoch, 265 | 'model_state_dict': model.state_dict(), 266 | 'optimizer_state_dict': optimizer.state_dict(), 267 | 'global_step': global_step 268 | }, model_filename) 269 | 270 | 271 | if __name__ == '__main__': 272 | warnings.filterwarnings("ignore") 273 | config = get_config() 274 | train_model(config) 275 | -------------------------------------------------------------------------------- /train_wb.py: -------------------------------------------------------------------------------- 1 | from model import build_transformer 2 | from dataset import BilingualDataset, causal_mask 3 | from config import get_config, get_weights_file_path 4 | 5 | import torchtext.datasets as datasets 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils.data import Dataset, DataLoader, random_split 9 | from torch.optim.lr_scheduler import LambdaLR 10 | 11 | import warnings 12 | from tqdm import tqdm 13 | import os 14 | from pathlib import Path 15 | 16 | # Huggingface datasets and tokenizers 17 | from datasets import load_dataset 18 | from tokenizers import Tokenizer 19 | from tokenizers.models import WordLevel 20 | from tokenizers.trainers import WordLevelTrainer 21 | from tokenizers.pre_tokenizers import Whitespace 22 | 23 | import wandb 24 | 25 | import torchmetrics 26 | 27 | def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device): 28 | sos_idx = tokenizer_tgt.token_to_id('[SOS]') 29 | eos_idx = tokenizer_tgt.token_to_id('[EOS]') 30 | 31 | # Precompute the encoder output and reuse it for every step 32 | encoder_output = model.encode(source, source_mask) 33 | # Initialize the decoder input with the sos token 34 | decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device) 35 | while True: 36 | if decoder_input.size(1) == max_len: 37 | break 38 | 39 | # build mask for target 40 | decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device) 41 | 42 | # calculate output 43 | out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask) 
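        # Greedy decoding: only the last position of the decoder output is
        # projected to vocabulary logits; the argmax becomes the next token and
        # is appended to the decoder input, until [EOS] is produced or max_len is reached.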
44 | 45 | # get next token 46 | prob = model.project(out[:, -1]) 47 | _, next_word = torch.max(prob, dim=1) 48 | decoder_input = torch.cat( 49 | [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1 50 | ) 51 | 52 | if next_word == eos_idx: 53 | break 54 | 55 | return decoder_input.squeeze(0) 56 | 57 | 58 | def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, num_examples=2): 59 | model.eval() 60 | count = 0 61 | 62 | source_texts = [] 63 | expected = [] 64 | predicted = [] 65 | 66 | try: 67 | # get the console window width 68 | with os.popen('stty size', 'r') as console: 69 | _, console_width = console.read().split() 70 | console_width = int(console_width) 71 | except: 72 | # If we can't get the console width, use 80 as default 73 | console_width = 80 74 | 75 | with torch.no_grad(): 76 | for batch in validation_ds: 77 | count += 1 78 | encoder_input = batch["encoder_input"].to(device) # (b, seq_len) 79 | encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len) 80 | 81 | # check that the batch size is 1 82 | assert encoder_input.size( 83 | 0) == 1, "Batch size must be 1 for validation" 84 | 85 | model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device) 86 | 87 | source_text = batch["src_text"][0] 88 | target_text = batch["tgt_text"][0] 89 | model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy()) 90 | 91 | source_texts.append(source_text) 92 | expected.append(target_text) 93 | predicted.append(model_out_text) 94 | 95 | # Print the source, target and model output 96 | print_msg('-'*console_width) 97 | print_msg(f"{f'SOURCE: ':>12}{source_text}") 98 | print_msg(f"{f'TARGET: ':>12}{target_text}") 99 | print_msg(f"{f'PREDICTED: ':>12}{model_out_text}") 100 | 101 | if count == num_examples: 102 | print_msg('-'*console_width) 103 | break 104 | 105 | 106 | # Evaluate the character error rate 107 | # Compute the char error rate 108 | metric = torchmetrics.CharErrorRate() 109 | cer = metric(predicted, expected) 110 | wandb.log({'validation/cer': cer, 'global_step': global_step}) 111 | 112 | # Compute the word error rate 113 | metric = torchmetrics.WordErrorRate() 114 | wer = metric(predicted, expected) 115 | wandb.log({'validation/wer': wer, 'global_step': global_step}) 116 | 117 | # Compute the BLEU metric 118 | metric = torchmetrics.BLEUScore() 119 | bleu = metric(predicted, expected) 120 | wandb.log({'validation/BLEU': bleu, 'global_step': global_step}) 121 | 122 | def get_all_sentences(ds, lang): 123 | for item in ds: 124 | yield item['translation'][lang] 125 | 126 | def get_or_build_tokenizer(config, ds, lang): 127 | tokenizer_path = Path(config['tokenizer_file'].format(lang)) 128 | if not Path.exists(tokenizer_path): 129 | # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour 130 | tokenizer = Tokenizer(WordLevel(unk_token="[UNK]")) 131 | tokenizer.pre_tokenizer = Whitespace() 132 | trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2) 133 | tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer) 134 | tokenizer.save(str(tokenizer_path)) 135 | else: 136 | tokenizer = Tokenizer.from_file(str(tokenizer_path)) 137 | return tokenizer 138 | 139 | def get_ds(config): 140 | # It only has the train split, so we divide it overselves 141 | ds_raw = load_dataset('opus_books', f"{config['lang_src']}-{config['lang_tgt']}", split='train') 142 | 143 | # 
Build tokenizers 144 | tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src']) 145 | tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt']) 146 | 147 | # Keep 90% for training, 10% for validation 148 | train_ds_size = int(0.9 * len(ds_raw)) 149 | val_ds_size = len(ds_raw) - train_ds_size 150 | train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size]) 151 | 152 | train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len']) 153 | val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len']) 154 | 155 | # Find the maximum length of each sentence in the source and target sentence 156 | max_len_src = 0 157 | max_len_tgt = 0 158 | 159 | for item in ds_raw: 160 | src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids 161 | tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids 162 | max_len_src = max(max_len_src, len(src_ids)) 163 | max_len_tgt = max(max_len_tgt, len(tgt_ids)) 164 | 165 | print(f'Max length of source sentence: {max_len_src}') 166 | print(f'Max length of target sentence: {max_len_tgt}') 167 | 168 | 169 | train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True) 170 | val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True) 171 | 172 | return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt 173 | 174 | def get_model(config, vocab_src_len, vocab_tgt_len): 175 | model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model']) 176 | return model 177 | 178 | def train_model(config): 179 | # Define the device 180 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 181 | print("Using device:", device) 182 | 183 | # Make sure the weights folder exists 184 | Path(config['model_folder']).mkdir(parents=True, exist_ok=True) 185 | 186 | train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config) 187 | model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device) 188 | 189 | optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9) 190 | 191 | # If the user specified a model to preload before training, load it 192 | initial_epoch = 0 193 | global_step = 0 194 | if config['preload']: 195 | model_filename = get_weights_file_path(config, config['preload']) 196 | print(f'Preloading model {model_filename}') 197 | state = torch.load(model_filename) 198 | model.load_state_dict(state['model_state_dict']) 199 | initial_epoch = state['epoch'] + 1 200 | optimizer.load_state_dict(state['optimizer_state_dict']) 201 | global_step = state['global_step'] 202 | del state 203 | 204 | loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device) 205 | 206 | # define our custom x axis metric 207 | wandb.define_metric("global_step") 208 | # define which metrics will be plotted against it 209 | wandb.define_metric("validation/*", step_metric="global_step") 210 | wandb.define_metric("train/*", step_metric="global_step") 211 | 212 | for epoch in range(initial_epoch, config['num_epochs']): 213 | torch.cuda.empty_cache() 214 | model.train() 215 | batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}") 216 | for batch in batch_iterator: 217 | 218 | encoder_input = batch['encoder_input'].to(device) # (b, seq_len) 219 | 
decoder_input = batch['decoder_input'].to(device) # (B, seq_len) 220 | encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len) 221 | decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len) 222 | 223 | # Run the tensors through the encoder, decoder and the projection layer 224 | encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model) 225 | decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model) 226 | proj_output = model.project(decoder_output) # (B, seq_len, vocab_size) 227 | 228 | # Compare the output with the label 229 | label = batch['label'].to(device) # (B, seq_len) 230 | 231 | # Compute the loss using a simple cross entropy 232 | loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1)) 233 | batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"}) 234 | 235 | # Log the loss 236 | wandb.log({'train/loss': loss.item(), 'global_step': global_step}) 237 | 238 | # Backpropagate the loss 239 | loss.backward() 240 | 241 | # Update the weights 242 | optimizer.step() 243 | optimizer.zero_grad(set_to_none=True) 244 | 245 | global_step += 1 246 | 247 | # Run validation at the end of every epoch 248 | run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step) 249 | 250 | # Save the model at the end of every epoch 251 | model_filename = get_weights_file_path(config, f"{epoch:02d}") 252 | torch.save({ 253 | 'epoch': epoch, 254 | 'model_state_dict': model.state_dict(), 255 | 'optimizer_state_dict': optimizer.state_dict(), 256 | 'global_step': global_step 257 | }, model_filename) 258 | 259 | 260 | if __name__ == '__main__': 261 | warnings.filterwarnings("ignore") 262 | config = get_config() 263 | config['num_epochs'] = 30 264 | config['preload'] = None 265 | 266 | wandb.init( 267 | # set the wandb project where this run will be logged 268 | project="pytorch-transformer", 269 | 270 | # track hyperparameters and run metadata 271 | config=config 272 | ) 273 | 274 | train_model(config) 275 | -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from config import get_config, latest_weights_file_path 3 | from model import build_transformer 4 | from tokenizers import Tokenizer 5 | from datasets import load_dataset 6 | from dataset import BilingualDataset 7 | import torch 8 | import sys 9 | 10 | def translate(sentence: str): 11 | # Define the device, tokenizers, and model 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | print("Using device:", device) 14 | config = get_config() 15 | tokenizer_src = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_src'])))) 16 | tokenizer_tgt = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_tgt'])))) 17 | model = build_transformer(tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size(), config["seq_len"], config['seq_len'], d_model=config['d_model']).to(device) 18 | 19 | # Load the pretrained weights 20 | model_filename = latest_weights_file_path(config) 21 | state = torch.load(model_filename) 22 | model.load_state_dict(state['model_state_dict']) 23 | 24 | # if the sentence is a number use it as an index to the test set 25 | label = "" 26 | if type(sentence) == int or sentence.isdigit(): 27 | id = 
int(sentence) 28 | ds = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='all') 29 | ds = BilingualDataset(ds, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len']) 30 | sentence = ds[id]['src_text'] 31 | label = ds[id]["tgt_text"] 32 | seq_len = config['seq_len'] 33 | 34 | # translate the sentence 35 | model.eval() 36 | with torch.no_grad(): 37 | # Precompute the encoder output and reuse it for every generation step 38 | source = tokenizer_src.encode(sentence) 39 | source = torch.cat([ 40 | torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64), 41 | torch.tensor(source.ids, dtype=torch.int64), 42 | torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64), 43 | torch.tensor([tokenizer_src.token_to_id('[PAD]')] * (seq_len - len(source.ids) - 2), dtype=torch.int64) 44 | ], dim=0).to(device) 45 | source_mask = (source != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int().to(device) 46 | encoder_output = model.encode(source, source_mask) 47 | 48 | # Initialize the decoder input with the sos token 49 | decoder_input = torch.empty(1, 1).fill_(tokenizer_tgt.token_to_id('[SOS]')).type_as(source).to(device) 50 | 51 | # Print the source sentence and target start prompt 52 | if label != "": print(f"{f'ID: ':>12}{id}") 53 | print(f"{f'SOURCE: ':>12}{sentence}") 54 | if label != "": print(f"{f'TARGET: ':>12}{label}") 55 | print(f"{f'PREDICTED: ':>12}", end='') 56 | 57 | # Generate the translation word by word 58 | while decoder_input.size(1) < seq_len: 59 | # build mask for target and calculate output 60 | decoder_mask = torch.triu(torch.ones((1, decoder_input.size(1), decoder_input.size(1))), diagonal=1).type(torch.int).type_as(source_mask).to(device) 61 | out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask) 62 | 63 | # project next token 64 | prob = model.project(out[:, -1]) 65 | _, next_word = torch.max(prob, dim=1) 66 | decoder_input = torch.cat([decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1) 67 | 68 | # print the translated word 69 | print(f"{tokenizer_tgt.decode([next_word.item()])}", end=' ') 70 | 71 | # break if we predict the end of sentence token 72 | if next_word == tokenizer_tgt.token_to_id('[EOS]'): 73 | break 74 | 75 | # convert ids to tokens 76 | return tokenizer_tgt.decode(decoder_input[0].tolist()) 77 | 78 | #read sentence from argument 79 | translate(sys.argv[1] if len(sys.argv) > 1 else "I am not a very good a student.") --------------------------------------------------------------------------------
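Quick start (a minimal sketch, not a verbatim part of the repository files above): with the packages from requirements.txt installed, running "python train.py" trains the en->it model on opus_books and saves one checkpoint per epoch under opus_books_weights/, and "python translate.py \"Some sentence\"" (or "python translate.py 34" to pick a dataset example by index) runs greedy translation with the latest checkpoint. The snippet below shows the equivalent programmatic path; it only calls functions already defined in config.py and train.py, and it assumes the tokenizer JSON files and at least one saved checkpoint already exist.

    import torch
    from config import get_config, latest_weights_file_path
    from train import get_ds, get_model, run_validation

    config = get_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Rebuild dataloaders, tokenizers and model exactly as train.py does
    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

    # Load the most recent checkpoint saved by train.py
    state = torch.load(latest_weights_file_path(config), map_location=device)
    model.load_state_dict(state['model_state_dict'])

    # Greedy-decode a couple of validation sentences and print source/target/prediction
    run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt,
                   config['seq_len'], device, print, global_step=0, writer=None)

Passing writer=None simply skips the TensorBoard metric logging inside run_validation, so the call prints the examples and returns.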