├── README.md ├── mt5_soft_prompt_tuning_large.ipynb └── mt5_soft_prompt_tuning.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # mt5-soft-prompt-tuning 2 | 3 | 下面的链接与本 repo 里面的 ipynb 内容相同 4 | 5 | [Colab mt5-base](https://colab.research.google.com/drive/1PRx6tABbx2BwI38pSfOlKuXruAdhfTLj?usp=sharing) 6 | 7 | [Colab mt5-large](https://colab.research.google.com/drive/15lWNnfs4FdC379IJh-SVKG_sFl55GpaO?usp=sharing) 8 | 9 | Code copied and adapted from: [Repo: soft-prompt-tuning](https://github.com/kipgparker/soft-prompt-tuning) 10 | 11 | [Paper: The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/pdf/2104.08691.pdf) 12 | 13 | [Paper: mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/pdf/2010.11934) 14 | 15 | [Repo: mT5: Multilingual T5](https://github.com/google-research/multilingual-t5) 16 | -------------------------------------------------------------------------------- /mt5_soft_prompt_tuning_large.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "mt5-soft-prompt-tuning-large.ipynb", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "machine_shape": "hm" 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "name": "python3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "code", 23 | "metadata": { 24 | "colab": { 25 | "base_uri": "https://localhost:8080/" 26 | }, 27 | "id": "u8xBtDGlxvz4", 28 | "outputId": "93c696ef-c394-4cbb-a928-0de74c29551c" 29 | }, 30 | "source": [ 31 | "!nvidia-smi" 32 | ], 33 | "execution_count": null, 34 | "outputs": [ 35 | { 36 | "output_type": "stream", 37 | "name": "stdout", 38 | "text": [ 39 | "Sat Sep 11 09:17:34 2021 \n", 40 | 
"+-----------------------------------------------------------------------------+\n", 41 | "| NVIDIA-SMI 470.63.01 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", 42 | "|-------------------------------+----------------------+----------------------+\n", 43 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 44 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 45 | "| | | MIG M. |\n", 46 | "|===============================+======================+======================|\n", 47 | "| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |\n", 48 | "| N/A 37C P0 50W / 300W | 0MiB / 16160MiB | 0% Default |\n", 49 | "| | | N/A |\n", 50 | "+-------------------------------+----------------------+----------------------+\n", 51 | " \n", 52 | "+-----------------------------------------------------------------------------+\n", 53 | "| Processes: |\n", 54 | "| GPU GI CI PID Type Process name GPU Memory |\n", 55 | "| ID ID Usage |\n", 56 | "|=============================================================================|\n", 57 | "| No running processes found |\n", 58 | "+-----------------------------------------------------------------------------+\n" 59 | ] 60 | } 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "metadata": { 66 | "colab": { 67 | "base_uri": "https://localhost:8080/" 68 | }, 69 | "id": "rbpX3_9PxzVE", 70 | "outputId": "e945ce69-cc50-4e8c-fe75-d453e7a704b8" 71 | }, 72 | "source": [ 73 | "!pip install transformers SentencePiece torch tqdm" 74 | ], 75 | "execution_count": null, 76 | "outputs": [ 77 | { 78 | "output_type": "stream", 79 | "name": "stdout", 80 | "text": [ 81 | "Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.10.2)\n", 82 | "Requirement already satisfied: SentencePiece in /usr/local/lib/python3.7/dist-packages (0.1.96)\n", 83 | "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (1.9.0+cu102)\n", 84 | "Requirement already satisfied: tqdm 
in /usr/local/lib/python3.7/dist-packages (4.62.0)\n", 85 | "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (5.4.1)\n", 86 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n", 87 | "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n", 88 | "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.6.4)\n", 89 | "Requirement already satisfied: huggingface-hub>=0.0.12 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.0.16)\n", 90 | "Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers) (0.0.45)\n", 91 | "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.0.12)\n", 92 | "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers) (21.0)\n", 93 | "Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.10.3)\n", 94 | "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n", 95 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.12->transformers) (3.7.4.3)\n", 96 | "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers) (2.4.7)\n", 97 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.5.0)\n", 98 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n", 99 | "Requirement already satisfied: idna<3,>=2.5 in 
/usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n", 100 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n", 101 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.5.30)\n", 102 | "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n", 103 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.15.0)\n", 104 | "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.0.1)\n" 105 | ] 106 | } 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "metadata": { 112 | "id": "sqiBMWGHx3Yx" 113 | }, 114 | "source": [ 115 | "import math\n", 116 | "\n", 117 | "from tqdm import tqdm\n", 118 | "import numpy as np\n", 119 | "from transformers import MT5ForConditionalGeneration, T5Tokenizer\n", 120 | "import torch\n", 121 | "import torch.nn as nn\n", 122 | "from sklearn.metrics import accuracy_score" 123 | ], 124 | "execution_count": null, 125 | "outputs": [] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "metadata": { 130 | "id": "w9dorhOHyydQ" 131 | }, 132 | "source": [ 133 | "class SoftEmbedding(nn.Module):\n", 134 | " def __init__(self, \n", 135 | " wte: nn.Embedding,\n", 136 | " n_tokens: int = 10, \n", 137 | " random_range: float = 0.5,\n", 138 | " initialize_from_vocab: bool = True):\n", 139 | " \"\"\"prepends learned embedding to the input embeddings\n", 140 | " Args:\n", 141 | " wte (nn.Embedding): original transformer word embedding\n", 142 | " n_tokens (int, optional): number of tokens for task. Defaults to 10.\n", 143 | " random_range (float, optional): range to init embedding (if not initialize from vocab). Defaults to 0.5.\n", 144 | " initialize_from_vocab (bool, optional): initializes from default vocab. 
Defaults to True.\n", 145 | " \"\"\"\n", 146 | " super(SoftEmbedding, self).__init__()\n", 147 | " self.wte = wte\n", 148 | " self.n_tokens = n_tokens\n", 149 | " self.learned_embedding = nn.parameter.Parameter(self.initialize_embedding(wte,\n", 150 | " n_tokens, \n", 151 | " random_range, \n", 152 | " initialize_from_vocab))\n", 153 | " \n", 154 | " def initialize_embedding(self, \n", 155 | " wte: nn.Embedding,\n", 156 | " n_tokens: int = 10, \n", 157 | " random_range: float = 0.5, \n", 158 | " initialize_from_vocab: bool = True):\n", 159 | " \"\"\"initializes learned embedding\n", 160 | " Args:\n", 161 | " same as __init__\n", 162 | " Returns:\n", 163 | " torch.float: initialized using original schemes\n", 164 | " \"\"\"\n", 165 | " if initialize_from_vocab:\n", 166 | " return self.wte.weight[:n_tokens].clone().detach()\n", 167 | " return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(-random_range, random_range)\n", 168 | " \n", 169 | " def forward(self, tokens):\n", 170 | " \"\"\"run forward pass\n", 171 | " Args:\n", 172 | " tokens (torch.long): input tokens before encoding\n", 173 | " Returns:\n", 174 | " torch.float: encoding of text concatenated with learned task specific embedding\n", 175 | " \"\"\"\n", 176 | " input_embedding = self.wte(tokens[:, self.n_tokens:])\n", 177 | " learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)\n", 178 | " return torch.cat([learned_embedding, input_embedding], 1)" 179 | ], 180 | "execution_count": null, 181 | "outputs": [] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "metadata": { 186 | "colab": { 187 | "base_uri": "https://localhost:8080/" 188 | }, 189 | "id": "We6vNt5ukkBJ", 190 | "outputId": "fee4408f-e1c1-43e3-9c0b-fb29a7e04e85" 191 | }, 192 | "source": [ 193 | "!pip install zh-dataset-inews" 194 | ], 195 | "execution_count": null, 196 | "outputs": [ 197 | { 198 | "output_type": "stream", 199 | "name": "stdout", 200 | "text": [ 201 | "Requirement already satisfied: 
zh-dataset-inews in /usr/local/lib/python3.7/dist-packages (0.0.2)\n" 202 | ] 203 | } 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "metadata": { 209 | "id": "biQD2_aB6v8s" 210 | }, 211 | "source": [ 212 | "from zh_dataset_inews import title_train, label_train, title_dev, label_dev, title_test, label_test" 213 | ], 214 | "execution_count": null, 215 | "outputs": [] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "metadata": { 220 | "id": "LfF_Z4qu0z3a" 221 | }, 222 | "source": [ 223 | "def generate_data(batch_size, n_tokens, title_data, label_data):\n", 224 | "\n", 225 | " labels = [\n", 226 | " torch.tensor([[3]]), # \\x00\n", 227 | " torch.tensor([[4]]), # \\x01\n", 228 | " torch.tensor([[5]]), # \\x02\n", 229 | " ]\n", 230 | "\n", 231 | " def yield_data(x_batch, y_batch, l_batch):\n", 232 | " x = torch.nn.utils.rnn.pad_sequence(x_batch, batch_first=True)\n", 233 | " y = torch.cat(y_batch, dim=0)\n", 234 | " m = (x > 0).to(torch.float32)\n", 235 | " decoder_input_ids = torch.full((x.size(0), n_tokens), 1)\n", 236 | " if torch.cuda.is_available():\n", 237 | " x = x.cuda()\n", 238 | " y = y.cuda()\n", 239 | " m = m.cuda()\n", 240 | " decoder_input_ids = decoder_input_ids.cuda()\n", 241 | " return x, y, m, decoder_input_ids, l_batch\n", 242 | "\n", 243 | " x_batch, y_batch, l_batch = [], [], []\n", 244 | " for x, y in zip(title_data, label_data):\n", 245 | " context = x\n", 246 | " inputs = tokenizer(context, return_tensors=\"pt\")\n", 247 | " inputs['input_ids'] = torch.cat([torch.full((1, n_tokens), 1), inputs['input_ids']], 1)\n", 248 | " l_batch.append(y)\n", 249 | " y = labels[y]\n", 250 | " y = torch.cat([torch.full((1, n_tokens - 1), -100), y], 1)\n", 251 | " x_batch.append(inputs['input_ids'][0])\n", 252 | " y_batch.append(y)\n", 253 | " if len(x_batch) >= batch_size:\n", 254 | " yield yield_data(x_batch, y_batch, l_batch)\n", 255 | " x_batch, y_batch, l_batch = [], [], []\n", 256 | "\n", 257 | " if len(x_batch) > 0:\n", 258 | " yield 
yield_data(x_batch, y_batch, l_batch)\n", 259 | " x_batch, y_batch, l_batch = [], [], []" 260 | ], 261 | "execution_count": null, 262 | "outputs": [] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "metadata": { 267 | "id": "5NKTnNidXnzS" 268 | }, 269 | "source": [ 270 | "model = MT5ForConditionalGeneration.from_pretrained(\"google/mt5-large\")\n", 271 | "tokenizer = T5Tokenizer.from_pretrained(\"google/mt5-large\")\n", 272 | "n_tokens = 100\n", 273 | "s_wte = SoftEmbedding(model.get_input_embeddings(), \n", 274 | " n_tokens=n_tokens, \n", 275 | " initialize_from_vocab=True)\n", 276 | "model.set_input_embeddings(s_wte)\n", 277 | "if torch.cuda.is_available():\n", 278 | " model = model.cuda()" 279 | ], 280 | "execution_count": null, 281 | "outputs": [] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "metadata": { 286 | "id": "eaKnkA8M4Am4" 287 | }, 288 | "source": [ 289 | "parameters = list(model.parameters())\n", 290 | "for x in parameters[1:]:\n", 291 | " x.requires_grad = False" 292 | ], 293 | "execution_count": null, 294 | "outputs": [] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "metadata": { 299 | "colab": { 300 | "base_uri": "https://localhost:8080/" 301 | }, 302 | "id": "YR1cPDym4LYi", 303 | "outputId": "e9d9ef6a-4450-48d5-e06d-d88b8721da79" 304 | }, 305 | "source": [ 306 | "parameters[0]" 307 | ], 308 | "execution_count": null, 309 | "outputs": [ 310 | { 311 | "output_type": "execute_result", 312 | "data": { 313 | "text/plain": [ 314 | "Parameter containing:\n", 315 | "tensor([[ -1.0312, -4.2500, 7.0000, ..., 6.0938, -8.0625, -9.5000],\n", 316 | " [ -7.7500, -12.1250, -2.3438, ..., -7.8438, 9.1875, 4.4375],\n", 317 | " [ 0.9805, 1.0781, -0.3867, ..., -1.0156, -0.4785, 0.8008],\n", 318 | " ...,\n", 319 | " [ -1.4922, 0.1895, -0.2041, ..., 0.6250, 0.0131, -1.8828],\n", 320 | " [ 0.8789, 0.1108, 1.1953, ..., 0.8281, 1.4844, 0.3418],\n", 321 | " [ 0.1436, -0.3867, -0.7734, ..., 0.5078, -0.0157, 0.1060]],\n", 322 | " device='cuda:0', 
requires_grad=True)" 323 | ] 324 | }, 325 | "metadata": {}, 326 | "execution_count": 10 327 | } 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "metadata": { 333 | "colab": { 334 | "base_uri": "https://localhost:8080/" 335 | }, 336 | "id": "b20PaXQA4L4n", 337 | "outputId": "3b5eebd4-e309-4076-e4e7-4f23423ab05e" 338 | }, 339 | "source": [ 340 | "parameters[2]" 341 | ], 342 | "execution_count": null, 343 | "outputs": [ 344 | { 345 | "output_type": "execute_result", 346 | "data": { 347 | "text/plain": [ 348 | "Parameter containing:\n", 349 | "tensor([[ 0.0099, 0.0084, 0.0172, ..., 0.0220, 0.0435, -0.0337],\n", 350 | " [ 0.0112, -0.0181, -0.0107, ..., 0.0227, 0.0190, 0.0033],\n", 351 | " [ 0.0061, 0.0430, 0.0625, ..., -0.0334, -0.0130, 0.0205],\n", 352 | " ...,\n", 353 | " [ 0.0034, 0.0228, 0.0003, ..., 0.0113, -0.0045, -0.0222],\n", 354 | " [ 0.0297, -0.0042, -0.0393, ..., 0.0037, -0.0145, -0.0023],\n", 355 | " [ 0.0053, -0.0029, 0.0157, ..., -0.0125, 0.0068, 0.0106]],\n", 356 | " device='cuda:0')" 357 | ] 358 | }, 359 | "metadata": {}, 360 | "execution_count": 11 361 | } 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "metadata": { 367 | "id": "c07hoCo44MP-" 368 | }, 369 | "source": [ 370 | "" 371 | ], 372 | "execution_count": null, 373 | "outputs": [] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "metadata": { 378 | "id": "uXKtiZOPkKwZ" 379 | }, 380 | "source": [ 381 | "for x, y, m, dii, true_labels in generate_data(2, n_tokens, title_train, label_train):\n", 382 | " assert dii.shape == y.shape\n", 383 | " outputs = model(input_ids=x, labels=y, attention_mask=m, decoder_input_ids=dii)\n", 384 | " assert outputs['logits'].shape[:2] == y.shape\n", 385 | " pred_labels = outputs['logits'][:, -1, 3:6].argmax(-1).detach().cpu().numpy().tolist()\n", 386 | " break" 387 | ], 388 | "execution_count": null, 389 | "outputs": [] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "metadata": { 394 | "colab": { 395 | "base_uri": "https://localhost:8080/" 396 
| }, 397 | "id": "RdZ3dTfK3-ob", 398 | "outputId": "60d1d5a9-cf2f-4317-8f03-d140edbf6b8f" 399 | }, 400 | "source": [ 401 | "batch_size = 2\n", 402 | "n_epoch = 50\n", 403 | "total_batch = math.ceil(len(title_train) / batch_size)\n", 404 | "dev_total_batch = math.ceil(len(title_dev) / batch_size)\n", 405 | "use_ce_loss = False\n", 406 | "ce_loss = nn.CrossEntropyLoss()\n", 407 | "optimizer = torch.optim.Adam(s_wte.parameters(), lr=0.5)\n", 408 | "\n", 409 | "for epoch in range(n_epoch):\n", 410 | " print('epoch', epoch)\n", 411 | "\n", 412 | " all_true_labels = []\n", 413 | " all_pred_labels = []\n", 414 | " losses = []\n", 415 | " pbar = tqdm(enumerate(generate_data(batch_size, n_tokens, title_train, label_train)), total=total_batch)\n", 416 | " for i, (x, y, m, dii, true_labels) in pbar:\n", 417 | " all_true_labels += true_labels\n", 418 | " \n", 419 | " optimizer.zero_grad()\n", 420 | " outputs = model(input_ids=x, labels=y, attention_mask=m, decoder_input_ids=dii)\n", 421 | " pred_labels = outputs['logits'][:, -1, 3:6].argmax(-1).detach().cpu().numpy().tolist()\n", 422 | " all_pred_labels += pred_labels\n", 423 | "\n", 424 | " if use_ce_loss:\n", 425 | " logits = outputs['logits'][:, -1, 3:6]\n", 426 | " true_labels_tensor = torch.tensor(true_labels, dtype=torch.long).cuda()\n", 427 | " loss = ce_loss(logits, true_labels_tensor)\n", 428 | " else:\n", 429 | " loss = outputs.loss\n", 430 | " loss.backward()\n", 431 | " optimizer.step()\n", 432 | " loss_value = float(loss.detach().cpu().numpy().tolist()) / batch_size\n", 433 | " losses.append(loss_value)\n", 434 | "\n", 435 | " acc = accuracy_score(all_true_labels, all_pred_labels)\n", 436 | " pbar.set_description(f'train: loss={np.mean(losses):.4f}, acc={acc:.4f}')\n", 437 | "\n", 438 | " all_true_labels = []\n", 439 | " all_pred_labels = []\n", 440 | " losses = []\n", 441 | " with torch.no_grad():\n", 442 | " pbar = tqdm(enumerate(generate_data(batch_size, n_tokens, title_dev, label_dev)), 
total=dev_total_batch)\n", 443 | " for i, (x, y, m, dii, true_labels) in pbar:\n", 444 | " all_true_labels += true_labels\n", 445 | " outputs = model(input_ids=x, labels=y, attention_mask=m, decoder_input_ids=dii)\n", 446 | " loss = outputs.loss\n", 447 | " loss_value = float(loss.detach().cpu().numpy().tolist()) / batch_size\n", 448 | " losses.append(loss_value)\n", 449 | " pred_labels = outputs['logits'][:, -1, 3:6].argmax(-1).detach().cpu().numpy().tolist()\n", 450 | " all_pred_labels += pred_labels\n", 451 | " acc = accuracy_score(all_true_labels, all_pred_labels)\n", 452 | " pbar.set_description(f'dev: loss={np.mean(losses):.4f}, acc={acc:.4f}')" 453 | ], 454 | "execution_count": 13, 455 | "outputs": [ 456 | { 457 | "metadata": { 458 | "tags": null 459 | }, 460 | "name": "stdout", 461 | "output_type": "stream", 462 | "text": [ 463 | "epoch 0\n" 464 | ] 465 | }, 466 | { 467 | "metadata": { 468 | "tags": null 469 | }, 470 | "name": "stderr", 471 | "output_type": "stream", 472 | "text": [ 473 | "train: loss=27.0702, acc=0.1079: 100%|██████████| 2678/2678 [06:10<00:00, 7.23it/s]\n", 474 | "dev: loss=24.9345, acc=0.0861: 100%|██████████| 500/500 [00:32<00:00, 15.35it/s]\n" 475 | ] 476 | }, 477 | { 478 | "metadata": { 479 | "tags": null 480 | }, 481 | "name": "stdout", 482 | "output_type": "stream", 483 | "text": [ 484 | "epoch 1\n" 485 | ] 486 | }, 487 | { 488 | "metadata": { 489 | "tags": null 490 | }, 491 | "name": "stderr", 492 | "output_type": "stream", 493 | "text": [ 494 | "train: loss=18.2214, acc=0.1617: 100%|██████████| 2678/2678 [06:08<00:00, 7.26it/s]\n", 495 | "dev: loss=16.0809, acc=0.4344: 100%|██████████| 500/500 [00:32<00:00, 15.47it/s]\n" 496 | ] 497 | }, 498 | { 499 | "metadata": { 500 | "tags": null 501 | }, 502 | "name": "stdout", 503 | "output_type": "stream", 504 | "text": [ 505 | "epoch 2\n" 506 | ] 507 | }, 508 | { 509 | "metadata": { 510 | "tags": null 511 | }, 512 | "name": "stderr", 513 | "output_type": "stream", 514 | "text": [ 515 | 
"train: loss=24.5954, acc=0.2510: 100%|██████████| 2678/2678 [06:08<00:00, 7.26it/s]\n", 516 | "dev: loss=31.2844, acc=0.0861: 100%|██████████| 500/500 [00:32<00:00, 15.45it/s]\n" 517 | ] 518 | }, 519 | { 520 | "metadata": { 521 | "tags": null 522 | }, 523 | "name": "stdout", 524 | "output_type": "stream", 525 | "text": [ 526 | "epoch 3\n" 527 | ] 528 | }, 529 | { 530 | "metadata": { 531 | "tags": null 532 | }, 533 | "name": "stderr", 534 | "output_type": "stream", 535 | "text": [ 536 | "train: loss=27.7627, acc=0.1079: 100%|██████████| 2678/2678 [06:07<00:00, 7.29it/s]\n", 537 | "dev: loss=17.8570, acc=0.0861: 100%|██████████| 500/500 [00:32<00:00, 15.51it/s]\n" 538 | ] 539 | }, 540 | { 541 | "metadata": { 542 | "tags": null 543 | }, 544 | "name": "stdout", 545 | "output_type": "stream", 546 | "text": [ 547 | "epoch 4\n" 548 | ] 549 | }, 550 | { 551 | "metadata": { 552 | "tags": null 553 | }, 554 | "name": "stderr", 555 | "output_type": "stream", 556 | "text": [ 557 | "train: loss=5.3752, acc=0.4288: 100%|██████████| 2678/2678 [06:08<00:00, 7.27it/s]\n", 558 | "dev: loss=2.5420, acc=0.4935: 100%|██████████| 500/500 [00:32<00:00, 15.43it/s]\n" 559 | ] 560 | }, 561 | { 562 | "metadata": { 563 | "tags": null 564 | }, 565 | "name": "stdout", 566 | "output_type": "stream", 567 | "text": [ 568 | "epoch 5\n" 569 | ] 570 | }, 571 | { 572 | "metadata": { 573 | "tags": null 574 | }, 575 | "name": "stderr", 576 | "output_type": "stream", 577 | "text": [ 578 | "train: loss=1.6852, acc=0.4426: 100%|██████████| 2678/2678 [06:07<00:00, 7.29it/s]\n", 579 | "dev: loss=1.8112, acc=0.4925: 100%|██████████| 500/500 [00:32<00:00, 15.38it/s]\n" 580 | ] 581 | }, 582 | { 583 | "metadata": { 584 | "tags": null 585 | }, 586 | "name": "stdout", 587 | "output_type": "stream", 588 | "text": [ 589 | "epoch 6\n" 590 | ] 591 | }, 592 | { 593 | "metadata": { 594 | "tags": null 595 | }, 596 | "name": "stderr", 597 | "output_type": "stream", 598 | "text": [ 599 | "train: loss=6.8852, acc=0.3473: 
100%|██████████| 2678/2678 [06:07<00:00, 7.29it/s]\n", 600 | "dev: loss=15.4861, acc=0.0861: 100%|██████████| 500/500 [00:32<00:00, 15.24it/s]\n" 601 | ] 602 | }, 603 | { 604 | "metadata": { 605 | "tags": null 606 | }, 607 | "name": "stdout", 608 | "output_type": "stream", 609 | "text": [ 610 | "epoch 7\n" 611 | ] 612 | }, 613 | { 614 | "metadata": { 615 | "tags": null 616 | }, 617 | "name": "stderr", 618 | "output_type": "stream", 619 | "text": [ 620 | "train: loss=5.7757, acc=0.4325: 100%|██████████| 2678/2678 [06:07<00:00, 7.29it/s]\n", 621 | "dev: loss=2.9378, acc=0.4895: 100%|██████████| 500/500 [00:32<00:00, 15.45it/s]\n" 622 | ] 623 | }, 624 | { 625 | "metadata": { 626 | "tags": null 627 | }, 628 | "name": "stdout", 629 | "output_type": "stream", 630 | "text": [ 631 | "epoch 8\n" 632 | ] 633 | }, 634 | { 635 | "metadata": { 636 | "tags": null 637 | }, 638 | "name": "stderr", 639 | "output_type": "stream", 640 | "text": [ 641 | "train: loss=3.1350, acc=0.4405: 100%|██████████| 2678/2678 [06:08<00:00, 7.27it/s]\n", 642 | "dev: loss=1.9451, acc=0.4925: 100%|██████████| 500/500 [00:32<00:00, 15.25it/s]\n" 643 | ] 644 | }, 645 | { 646 | "metadata": { 647 | "tags": null 648 | }, 649 | "name": "stdout", 650 | "output_type": "stream", 651 | "text": [ 652 | "epoch 9\n" 653 | ] 654 | }, 655 | { 656 | "metadata": { 657 | "tags": null 658 | }, 659 | "name": "stderr", 660 | "output_type": "stream", 661 | "text": [ 662 | "train: loss=1.4360, acc=0.4459: 100%|██████████| 2678/2678 [06:07<00:00, 7.28it/s]\n", 663 | "dev: loss=1.5743, acc=0.4925: 100%|██████████| 500/500 [00:32<00:00, 15.34it/s]\n" 664 | ] 665 | }, 666 | { 667 | "metadata": { 668 | "tags": null 669 | }, 670 | "name": "stdout", 671 | "output_type": "stream", 672 | "text": [ 673 | "epoch 10\n" 674 | ] 675 | }, 676 | { 677 | "metadata": { 678 | "tags": null 679 | }, 680 | "name": "stderr", 681 | "output_type": "stream", 682 | "text": [ 683 | "train: loss=0.7307, acc=0.4702: 100%|██████████| 2678/2678 
[06:07<00:00, 7.28it/s]\n", 684 | "dev: loss=2.3358, acc=0.4925: 100%|██████████| 500/500 [00:32<00:00, 15.38it/s]\n" 685 | ] 686 | }, 687 | { 688 | "metadata": { 689 | "tags": null 690 | }, 691 | "name": "stdout", 692 | "output_type": "stream", 693 | "text": [ 694 | "epoch 11\n" 695 | ] 696 | }, 697 | { 698 | "metadata": { 699 | "tags": null 700 | }, 701 | "name": "stderr", 702 | "output_type": "stream", 703 | "text": [ 704 | "train: loss=0.7700, acc=0.4480: 100%|██████████| 2678/2678 [06:05<00:00, 7.32it/s]\n", 705 | "dev: loss=1.3577, acc=0.4925: 100%|██████████| 500/500 [00:32<00:00, 15.50it/s]\n" 706 | ] 707 | }, 708 | { 709 | "metadata": { 710 | "tags": null 711 | }, 712 | "name": "stdout", 713 | "output_type": "stream", 714 | "text": [ 715 | "epoch 12\n" 716 | ] 717 | }, 718 | { 719 | "metadata": { 720 | "tags": null 721 | }, 722 | "name": "stderr", 723 | "output_type": "stream", 724 | "text": [ 725 | "train: loss=0.6409, acc=0.4652: 100%|██████████| 2678/2678 [06:06<00:00, 7.31it/s]\n", 726 | "dev: loss=1.1933, acc=0.4925: 100%|██████████| 500/500 [00:32<00:00, 15.43it/s]\n" 727 | ] 728 | }, 729 | { 730 | "metadata": { 731 | "tags": null 732 | }, 733 | "name": "stdout", 734 | "output_type": "stream", 735 | "text": [ 736 | "epoch 13\n" 737 | ] 738 | }, 739 | { 740 | "metadata": { 741 | "tags": null 742 | }, 743 | "name": "stderr", 744 | "output_type": "stream", 745 | "text": [ 746 | "train: loss=0.5962, acc=0.5100: 100%|██████████| 2678/2678 [06:05<00:00, 7.32it/s]\n", 747 | "dev: loss=1.0476, acc=0.4925: 100%|██████████| 500/500 [00:31<00:00, 15.64it/s]\n" 748 | ] 749 | }, 750 | { 751 | "metadata": { 752 | "tags": null 753 | }, 754 | "name": "stdout", 755 | "output_type": "stream", 756 | "text": [ 757 | "epoch 14\n" 758 | ] 759 | }, 760 | { 761 | "metadata": { 762 | "tags": null 763 | }, 764 | "name": "stderr", 765 | "output_type": "stream", 766 | "text": [ 767 | "train: loss=0.5385, acc=0.5968: 100%|██████████| 2678/2678 [06:05<00:00, 7.32it/s]\n", 768 | 
"dev: loss=1.1322, acc=0.4925: 100%|██████████| 500/500 [00:31<00:00, 15.69it/s]\n" 769 | ] 770 | }, 771 | { 772 | "metadata": { 773 | "tags": null 774 | }, 775 | "name": "stdout", 776 | "output_type": "stream", 777 | "text": [ 778 | "epoch 15\n" 779 | ] 780 | }, 781 | { 782 | "metadata": { 783 | "tags": null 784 | }, 785 | "name": "stderr", 786 | "output_type": "stream", 787 | "text": [ 788 | "train: loss=0.4824, acc=0.6390: 100%|██████████| 2678/2678 [06:02<00:00, 7.39it/s]\n", 789 | "dev: loss=0.8160, acc=0.4925: 100%|██████████| 500/500 [00:31<00:00, 15.64it/s]\n" 790 | ] 791 | }, 792 | { 793 | "metadata": { 794 | "tags": null 795 | }, 796 | "name": "stdout", 797 | "output_type": "stream", 798 | "text": [ 799 | "epoch 16\n" 800 | ] 801 | }, 802 | { 803 | "metadata": { 804 | "tags": null 805 | }, 806 | "name": "stderr", 807 | "output_type": "stream", 808 | "text": [ 809 | "train: loss=0.4478, acc=0.6753: 100%|██████████| 2678/2678 [06:03<00:00, 7.38it/s]\n", 810 | "dev: loss=0.6584, acc=0.5726: 100%|██████████| 500/500 [00:32<00:00, 15.60it/s]\n" 811 | ] 812 | }, 813 | { 814 | "metadata": { 815 | "tags": null 816 | }, 817 | "name": "stdout", 818 | "output_type": "stream", 819 | "text": [ 820 | "epoch 17\n" 821 | ] 822 | }, 823 | { 824 | "metadata": { 825 | "tags": null 826 | }, 827 | "name": "stderr", 828 | "output_type": "stream", 829 | "text": [ 830 | "train: loss=0.4170, acc=0.6930: 100%|██████████| 2678/2678 [06:02<00:00, 7.39it/s]\n", 831 | "dev: loss=0.6231, acc=0.6336: 100%|██████████| 500/500 [00:32<00:00, 15.52it/s]\n" 832 | ] 833 | }, 834 | { 835 | "metadata": { 836 | "tags": null 837 | }, 838 | "name": "stdout", 839 | "output_type": "stream", 840 | "text": [ 841 | "epoch 18\n" 842 | ] 843 | }, 844 | { 845 | "metadata": { 846 | "tags": null 847 | }, 848 | "name": "stderr", 849 | "output_type": "stream", 850 | "text": [ 851 | "train: loss=0.3995, acc=0.7029: 100%|██████████| 2678/2678 [06:01<00:00, 7.41it/s]\n", 852 | "dev: loss=0.5515, acc=0.6667: 
100%|██████████| 500/500 [00:32<00:00, 15.56it/s]\n" 853 | ] 854 | }, 855 | { 856 | "metadata": { 857 | "tags": null 858 | }, 859 | "name": "stdout", 860 | "output_type": "stream", 861 | "text": [ 862 | "epoch 19\n" 863 | ] 864 | }, 865 | { 866 | "metadata": { 867 | "tags": null 868 | }, 869 | "name": "stderr", 870 | "output_type": "stream", 871 | "text": [ 872 | "train: loss=0.3821, acc=0.7107: 100%|██████████| 2678/2678 [06:03<00:00, 7.38it/s]\n", 873 | "dev: loss=0.5967, acc=0.6466: 100%|██████████| 500/500 [00:31<00:00, 15.63it/s]\n" 874 | ] 875 | }, 876 | { 877 | "metadata": { 878 | "tags": null 879 | }, 880 | "name": "stdout", 881 | "output_type": "stream", 882 | "text": [ 883 | "epoch 20\n" 884 | ] 885 | }, 886 | { 887 | "metadata": { 888 | "tags": null 889 | }, 890 | "name": "stderr", 891 | "output_type": "stream", 892 | "text": [ 893 | "train: loss=0.3647, acc=0.7303: 100%|██████████| 2678/2678 [06:01<00:00, 7.40it/s]\n", 894 | "dev: loss=0.4908, acc=0.6557: 100%|██████████| 500/500 [00:32<00:00, 15.57it/s]\n" 895 | ] 896 | }, 897 | { 898 | "metadata": { 899 | "tags": null 900 | }, 901 | "name": "stdout", 902 | "output_type": "stream", 903 | "text": [ 904 | "epoch 21\n" 905 | ] 906 | }, 907 | { 908 | "metadata": { 909 | "tags": null 910 | }, 911 | "name": "stderr", 912 | "output_type": "stream", 913 | "text": [ 914 | "train: loss=0.3547, acc=0.7315: 100%|██████████| 2678/2678 [06:01<00:00, 7.42it/s]\n", 915 | "dev: loss=0.4713, acc=0.7067: 100%|██████████| 500/500 [00:31<00:00, 15.65it/s]\n" 916 | ] 917 | }, 918 | { 919 | "metadata": { 920 | "tags": null 921 | }, 922 | "name": "stdout", 923 | "output_type": "stream", 924 | "text": [ 925 | "epoch 22\n" 926 | ] 927 | }, 928 | { 929 | "output_type": "stream", 930 | "name": "stderr", 931 | "text": [ 932 | "train: loss=0.3406, acc=0.7415: 100%|██████████| 2678/2678 [06:01<00:00, 7.40it/s]\n", 933 | "dev: loss=0.4078, acc=0.7417: 100%|██████████| 500/500 [00:31<00:00, 15.64it/s]\n" 934 | ] 935 | }, 936 | { 937 | 
"output_type": "stream", 938 | "name": "stdout", 939 | "text": [ 940 | "epoch 23\n" 941 | ] 942 | }, 943 | { 944 | "output_type": "stream", 945 | "name": "stderr", 946 | "text": [ 947 | "train: loss=0.3255, acc=0.7565: 100%|██████████| 2678/2678 [06:02<00:00, 7.39it/s]\n", 948 | "dev: loss=0.3752, acc=0.7337: 100%|██████████| 500/500 [00:32<00:00, 15.61it/s]\n" 949 | ] 950 | }, 951 | { 952 | "output_type": "stream", 953 | "name": "stdout", 954 | "text": [ 955 | "epoch 24\n" 956 | ] 957 | }, 958 | { 959 | "output_type": "stream", 960 | "name": "stderr", 961 | "text": [ 962 | "train: loss=0.3098, acc=0.7621: 100%|██████████| 2678/2678 [06:02<00:00, 7.38it/s]\n", 963 | "dev: loss=0.3692, acc=0.7407: 100%|██████████| 500/500 [00:32<00:00, 15.57it/s]\n" 964 | ] 965 | }, 966 | { 967 | "output_type": "stream", 968 | "name": "stdout", 969 | "text": [ 970 | "epoch 25\n" 971 | ] 972 | }, 973 | { 974 | "output_type": "stream", 975 | "name": "stderr", 976 | "text": [ 977 | "train: loss=0.2986, acc=0.7811: 100%|██████████| 2678/2678 [06:02<00:00, 7.40it/s]\n", 978 | "dev: loss=0.3479, acc=0.7628: 100%|██████████| 500/500 [00:32<00:00, 15.50it/s]\n" 979 | ] 980 | }, 981 | { 982 | "output_type": "stream", 983 | "name": "stdout", 984 | "text": [ 985 | "epoch 26\n" 986 | ] 987 | }, 988 | { 989 | "output_type": "stream", 990 | "name": "stderr", 991 | "text": [ 992 | "train: loss=0.2924, acc=0.7824: 100%|██████████| 2678/2678 [06:02<00:00, 7.39it/s]\n", 993 | "dev: loss=0.3496, acc=0.7548: 100%|██████████| 500/500 [00:32<00:00, 15.62it/s]\n" 994 | ] 995 | }, 996 | { 997 | "output_type": "stream", 998 | "name": "stdout", 999 | "text": [ 1000 | "epoch 27\n" 1001 | ] 1002 | }, 1003 | { 1004 | "output_type": "stream", 1005 | "name": "stderr", 1006 | "text": [ 1007 | "train: loss=0.2865, acc=0.7826: 100%|██████████| 2678/2678 [06:00<00:00, 7.42it/s]\n", 1008 | "dev: loss=0.3458, acc=0.7608: 100%|██████████| 500/500 [00:31<00:00, 15.69it/s]\n" 1009 | ] 1010 | }, 1011 | { 1012 | 
"output_type": "stream", 1013 | "name": "stdout", 1014 | "text": [ 1015 | "epoch 28\n" 1016 | ] 1017 | }, 1018 | { 1019 | "output_type": "stream", 1020 | "name": "stderr", 1021 | "text": [ 1022 | "train: loss=0.2746, acc=0.7950: 100%|██████████| 2678/2678 [06:03<00:00, 7.37it/s]\n", 1023 | "dev: loss=0.3276, acc=0.7648: 100%|██████████| 500/500 [00:31<00:00, 15.63it/s]\n" 1024 | ] 1025 | }, 1026 | { 1027 | "output_type": "stream", 1028 | "name": "stdout", 1029 | "text": [ 1030 | "epoch 29\n" 1031 | ] 1032 | }, 1033 | { 1034 | "output_type": "stream", 1035 | "name": "stderr", 1036 | "text": [ 1037 | "train: loss=0.2642, acc=0.8011: 100%|██████████| 2678/2678 [06:06<00:00, 7.31it/s]\n", 1038 | "dev: loss=0.3346, acc=0.7578: 100%|██████████| 500/500 [00:32<00:00, 15.58it/s]\n" 1039 | ] 1040 | }, 1041 | { 1042 | "output_type": "stream", 1043 | "name": "stdout", 1044 | "text": [ 1045 | "epoch 30\n" 1046 | ] 1047 | }, 1048 | { 1049 | "output_type": "stream", 1050 | "name": "stderr", 1051 | "text": [ 1052 | "train: loss=0.2633, acc=0.8021: 100%|██████████| 2678/2678 [06:05<00:00, 7.33it/s]\n", 1053 | "dev: loss=0.3168, acc=0.7477: 100%|██████████| 500/500 [00:32<00:00, 15.35it/s]\n" 1054 | ] 1055 | }, 1056 | { 1057 | "output_type": "stream", 1058 | "name": "stdout", 1059 | "text": [ 1060 | "epoch 31\n" 1061 | ] 1062 | }, 1063 | { 1064 | "output_type": "stream", 1065 | "name": "stderr", 1066 | "text": [ 1067 | "train: loss=0.2450, acc=0.8121: 100%|██████████| 2678/2678 [06:05<00:00, 7.32it/s]\n", 1068 | "dev: loss=0.3279, acc=0.7628: 100%|██████████| 500/500 [00:32<00:00, 15.46it/s]\n" 1069 | ] 1070 | }, 1071 | { 1072 | "output_type": "stream", 1073 | "name": "stdout", 1074 | "text": [ 1075 | "epoch 32\n" 1076 | ] 1077 | }, 1078 | { 1079 | "output_type": "stream", 1080 | "name": "stderr", 1081 | "text": [ 1082 | "train: loss=0.2345, acc=0.8239: 100%|██████████| 2678/2678 [06:04<00:00, 7.34it/s]\n", 1083 | "dev: loss=0.3625, acc=0.7538: 100%|██████████| 500/500 
[00:32<00:00, 15.55it/s]\n" 1084 | ] 1085 | }, 1086 | { 1087 | "output_type": "stream", 1088 | "name": "stdout", 1089 | "text": [ 1090 | "epoch 33\n" 1091 | ] 1092 | }, 1093 | { 1094 | "output_type": "stream", 1095 | "name": "stderr", 1096 | "text": [ 1097 | "train: loss=0.2345, acc=0.8243: 100%|██████████| 2678/2678 [06:03<00:00, 7.36it/s]\n", 1098 | "dev: loss=0.3248, acc=0.7618: 100%|██████████| 500/500 [00:32<00:00, 15.56it/s]\n" 1099 | ] 1100 | }, 1101 | { 1102 | "output_type": "stream", 1103 | "name": "stdout", 1104 | "text": [ 1105 | "epoch 34\n" 1106 | ] 1107 | }, 1108 | { 1109 | "output_type": "stream", 1110 | "name": "stderr", 1111 | "text": [ 1112 | "train: loss=0.2715, acc=0.7938: 100%|██████████| 2678/2678 [06:05<00:00, 7.32it/s]\n", 1113 | "dev: loss=0.3177, acc=0.7487: 100%|██████████| 500/500 [00:32<00:00, 15.54it/s]\n" 1114 | ] 1115 | }, 1116 | { 1117 | "output_type": "stream", 1118 | "name": "stdout", 1119 | "text": [ 1120 | "epoch 35\n" 1121 | ] 1122 | }, 1123 | { 1124 | "output_type": "stream", 1125 | "name": "stderr", 1126 | "text": [ 1127 | "train: loss=0.2448, acc=0.8136: 100%|██████████| 2678/2678 [06:05<00:00, 7.34it/s]\n", 1128 | "dev: loss=0.3142, acc=0.7638: 100%|██████████| 500/500 [00:32<00:00, 15.44it/s]\n" 1129 | ] 1130 | }, 1131 | { 1132 | "output_type": "stream", 1133 | "name": "stdout", 1134 | "text": [ 1135 | "epoch 36\n" 1136 | ] 1137 | }, 1138 | { 1139 | "output_type": "stream", 1140 | "name": "stderr", 1141 | "text": [ 1142 | "train: loss=0.2281, acc=0.8288: 100%|██████████| 2678/2678 [06:06<00:00, 7.31it/s]\n", 1143 | "dev: loss=0.3680, acc=0.7518: 100%|██████████| 500/500 [00:32<00:00, 15.35it/s]\n" 1144 | ] 1145 | }, 1146 | { 1147 | "output_type": "stream", 1148 | "name": "stdout", 1149 | "text": [ 1150 | "epoch 37\n" 1151 | ] 1152 | }, 1153 | { 1154 | "output_type": "stream", 1155 | "name": "stderr", 1156 | "text": [ 1157 | "train: loss=0.2265, acc=0.8317: 100%|██████████| 2678/2678 [06:06<00:00, 7.31it/s]\n", 1158 | "dev: 
loss=0.3473, acc=0.7588: 100%|██████████| 500/500 [00:32<00:00, 15.46it/s]\n" 1159 | ] 1160 | }, 1161 | { 1162 | "output_type": "stream", 1163 | "name": "stdout", 1164 | "text": [ 1165 | "epoch 38\n" 1166 | ] 1167 | }, 1168 | { 1169 | "output_type": "stream", 1170 | "name": "stderr", 1171 | "text": [ 1172 | "train: loss=0.2330, acc=0.8263: 100%|██████████| 2678/2678 [06:04<00:00, 7.34it/s]\n", 1173 | "dev: loss=0.3554, acc=0.7497: 100%|██████████| 500/500 [00:32<00:00, 15.37it/s]\n" 1174 | ] 1175 | }, 1176 | { 1177 | "output_type": "stream", 1178 | "name": "stdout", 1179 | "text": [ 1180 | "epoch 39\n" 1181 | ] 1182 | }, 1183 | { 1184 | "output_type": "stream", 1185 | "name": "stderr", 1186 | "text": [ 1187 | "train: loss=0.2277, acc=0.8261: 100%|██████████| 2678/2678 [06:05<00:00, 7.33it/s]\n", 1188 | "dev: loss=0.3790, acc=0.7407: 100%|██████████| 500/500 [00:32<00:00, 15.42it/s]\n" 1189 | ] 1190 | }, 1191 | { 1192 | "output_type": "stream", 1193 | "name": "stdout", 1194 | "text": [ 1195 | "epoch 40\n" 1196 | ] 1197 | }, 1198 | { 1199 | "output_type": "stream", 1200 | "name": "stderr", 1201 | "text": [ 1202 | "train: loss=0.2221, acc=0.8304: 100%|██████████| 2678/2678 [06:05<00:00, 7.32it/s]\n", 1203 | "dev: loss=0.3345, acc=0.7407: 100%|██████████| 500/500 [00:32<00:00, 15.40it/s]\n" 1204 | ] 1205 | }, 1206 | { 1207 | "output_type": "stream", 1208 | "name": "stdout", 1209 | "text": [ 1210 | "epoch 41\n" 1211 | ] 1212 | }, 1213 | { 1214 | "output_type": "stream", 1215 | "name": "stderr", 1216 | "text": [ 1217 | "train: loss=0.2417, acc=0.8144: 100%|██████████| 2678/2678 [06:06<00:00, 7.31it/s]\n", 1218 | "dev: loss=0.3752, acc=0.7347: 100%|██████████| 500/500 [00:32<00:00, 15.50it/s]\n" 1219 | ] 1220 | }, 1221 | { 1222 | "output_type": "stream", 1223 | "name": "stdout", 1224 | "text": [ 1225 | "epoch 42\n" 1226 | ] 1227 | }, 1228 | { 1229 | "output_type": "stream", 1230 | "name": "stderr", 1231 | "text": [ 1232 | "train: loss=0.2384, acc=0.8174: 100%|██████████| 
2678/2678 [06:06<00:00, 7.31it/s]\n", 1233 | "dev: loss=0.3535, acc=0.7588: 100%|██████████| 500/500 [00:32<00:00, 15.43it/s]\n" 1234 | ] 1235 | }, 1236 | { 1237 | "output_type": "stream", 1238 | "name": "stdout", 1239 | "text": [ 1240 | "epoch 43\n" 1241 | ] 1242 | }, 1243 | { 1244 | "output_type": "stream", 1245 | "name": "stderr", 1246 | "text": [ 1247 | "train: loss=0.2198, acc=0.8303: 100%|██████████| 2678/2678 [06:06<00:00, 7.32it/s]\n", 1248 | "dev: loss=0.3527, acc=0.7467: 100%|██████████| 500/500 [00:32<00:00, 15.48it/s]\n" 1249 | ] 1250 | }, 1251 | { 1252 | "output_type": "stream", 1253 | "name": "stdout", 1254 | "text": [ 1255 | "epoch 44\n" 1256 | ] 1257 | }, 1258 | { 1259 | "output_type": "stream", 1260 | "name": "stderr", 1261 | "text": [ 1262 | "train: loss=0.2125, acc=0.8400: 100%|██████████| 2678/2678 [06:05<00:00, 7.33it/s]\n", 1263 | "dev: loss=0.3540, acc=0.7447: 100%|██████████| 500/500 [00:32<00:00, 15.47it/s]\n" 1264 | ] 1265 | }, 1266 | { 1267 | "output_type": "stream", 1268 | "name": "stdout", 1269 | "text": [ 1270 | "epoch 45\n" 1271 | ] 1272 | }, 1273 | { 1274 | "output_type": "stream", 1275 | "name": "stderr", 1276 | "text": [ 1277 | "train: loss=0.2063, acc=0.8476: 100%|██████████| 2678/2678 [06:04<00:00, 7.34it/s]\n", 1278 | "dev: loss=0.3499, acc=0.7518: 100%|██████████| 500/500 [00:32<00:00, 15.53it/s]\n" 1279 | ] 1280 | }, 1281 | { 1282 | "output_type": "stream", 1283 | "name": "stdout", 1284 | "text": [ 1285 | "epoch 46\n" 1286 | ] 1287 | }, 1288 | { 1289 | "output_type": "stream", 1290 | "name": "stderr", 1291 | "text": [ 1292 | "train: loss=0.2001, acc=0.8514: 100%|██████████| 2678/2678 [06:05<00:00, 7.32it/s]\n", 1293 | "dev: loss=0.3517, acc=0.7467: 100%|██████████| 500/500 [00:32<00:00, 15.48it/s]\n" 1294 | ] 1295 | }, 1296 | { 1297 | "output_type": "stream", 1298 | "name": "stdout", 1299 | "text": [ 1300 | "epoch 47\n" 1301 | ] 1302 | }, 1303 | { 1304 | "output_type": "stream", 1305 | "name": "stderr", 1306 | "text": [ 1307 | 
"train: loss=0.1885, acc=0.8646: 100%|██████████| 2678/2678 [06:05<00:00, 7.33it/s]\n", 1308 | "dev: loss=0.3732, acc=0.7477: 100%|██████████| 500/500 [00:32<00:00, 15.37it/s]\n" 1309 | ] 1310 | }, 1311 | { 1312 | "output_type": "stream", 1313 | "name": "stdout", 1314 | "text": [ 1315 | "epoch 48\n" 1316 | ] 1317 | }, 1318 | { 1319 | "output_type": "stream", 1320 | "name": "stderr", 1321 | "text": [ 1322 | "train: loss=0.1811, acc=0.8650: 100%|██████████| 2678/2678 [06:05<00:00, 7.32it/s]\n", 1323 | "dev: loss=0.4366, acc=0.7227: 100%|██████████| 500/500 [00:32<00:00, 15.33it/s]\n" 1324 | ] 1325 | }, 1326 | { 1327 | "output_type": "stream", 1328 | "name": "stdout", 1329 | "text": [ 1330 | "epoch 49\n" 1331 | ] 1332 | }, 1333 | { 1334 | "output_type": "stream", 1335 | "name": "stderr", 1336 | "text": [ 1337 | "train: loss=0.1723, acc=0.8726: 100%|██████████| 2678/2678 [06:04<00:00, 7.34it/s]\n", 1338 | "dev: loss=0.4849, acc=0.7387: 100%|██████████| 500/500 [00:32<00:00, 15.41it/s]\n" 1339 | ] 1340 | } 1341 | ] 1342 | }, 1343 | { 1344 | "cell_type": "code", 1345 | "metadata": { 1346 | "id": "Zrxp8Im04beQ" 1347 | }, 1348 | "source": [ 1349 | "parameters2 = list(model.parameters())" 1350 | ], 1351 | "execution_count": 14, 1352 | "outputs": [] 1353 | }, 1354 | { 1355 | "cell_type": "code", 1356 | "metadata": { 1357 | "id": "rRBcvIbY4s-R", 1358 | "colab": { 1359 | "base_uri": "https://localhost:8080/" 1360 | }, 1361 | "outputId": "baf9152a-f81c-4192-9966-b453dd07833c" 1362 | }, 1363 | "source": [ 1364 | "parameters2[0]" 1365 | ], 1366 | "execution_count": 15, 1367 | "outputs": [ 1368 | { 1369 | "output_type": "execute_result", 1370 | "data": { 1371 | "text/plain": [ 1372 | "Parameter containing:\n", 1373 | "tensor([[ 39.9078, -138.6385, 217.1636, ..., 29.5207, 145.8943,\n", 1374 | " 144.5315],\n", 1375 | " [ 515.6390, -31.6162, 51.5134, ..., -53.0618, 245.4292,\n", 1376 | " -69.3007],\n", 1377 | " [ 236.2896, 43.0374, -19.2581, ..., -127.3152, 130.7397,\n", 1378 | " 
31.1689],\n", 1379 | " ...,\n", 1380 | " [ 88.1486, 49.1501, 125.5696, ..., 113.4881, 96.0846,\n", 1381 | " 368.0652],\n", 1382 | " [ 100.6963, -102.7619, -35.8637, ..., -144.5385, -25.3403,\n", 1383 | " 173.1718],\n", 1384 | " [-164.1508, -81.5056, 152.1980, ..., -178.5098, 6.0514,\n", 1385 | " -129.9609]], device='cuda:0', requires_grad=True)" 1386 | ] 1387 | }, 1388 | "metadata": {}, 1389 | "execution_count": 15 1390 | } 1391 | ] 1392 | }, 1393 | { 1394 | "cell_type": "code", 1395 | "metadata": { 1396 | "id": "Mp_0DmKr4tZA", 1397 | "colab": { 1398 | "base_uri": "https://localhost:8080/" 1399 | }, 1400 | "outputId": "a48ec7e3-794a-49ea-d069-8da58e3ab218" 1401 | }, 1402 | "source": [ 1403 | "parameters2[2]" 1404 | ], 1405 | "execution_count": 16, 1406 | "outputs": [ 1407 | { 1408 | "output_type": "execute_result", 1409 | "data": { 1410 | "text/plain": [ 1411 | "Parameter containing:\n", 1412 | "tensor([[ 0.0099, 0.0084, 0.0172, ..., 0.0220, 0.0435, -0.0337],\n", 1413 | " [ 0.0112, -0.0181, -0.0107, ..., 0.0227, 0.0190, 0.0033],\n", 1414 | " [ 0.0061, 0.0430, 0.0625, ..., -0.0334, -0.0130, 0.0205],\n", 1415 | " ...,\n", 1416 | " [ 0.0034, 0.0228, 0.0003, ..., 0.0113, -0.0045, -0.0222],\n", 1417 | " [ 0.0297, -0.0042, -0.0393, ..., 0.0037, -0.0145, -0.0023],\n", 1418 | " [ 0.0053, -0.0029, 0.0157, ..., -0.0125, 0.0068, 0.0106]],\n", 1419 | " device='cuda:0')" 1420 | ] 1421 | }, 1422 | "metadata": {}, 1423 | "execution_count": 16 1424 | } 1425 | ] 1426 | }, 1427 | { 1428 | "cell_type": "code", 1429 | "metadata": { 1430 | "id": "Oqdi4A7tFnxM" 1431 | }, 1432 | "source": [ 1433 | "def predict(text):\n", 1434 | " inputs = tokenizer(text, return_tensors='pt')\n", 1435 | " inputs['input_ids'] = torch.cat([torch.full((1, n_tokens), 1), inputs['input_ids']], 1)\n", 1436 | "\n", 1437 | " decoder_input_ids = torch.full((1, n_tokens), 1)\n", 1438 | " with torch.no_grad():\n", 1439 | " outputs = model(input_ids=inputs['input_ids'].cuda(), 
decoder_input_ids=decoder_input_ids.cuda())\n", 1440 | " logits = outputs['logits'][:, -1, 3:6]\n", 1441 | " pred = logits.argmax(-1).detach().cpu().numpy()[0]\n", 1442 | " # print(logits)\n", 1443 | " return pred" 1444 | ], 1445 | "execution_count": 17, 1446 | "outputs": [] 1447 | }, 1448 | { 1449 | "cell_type": "code", 1450 | "metadata": { 1451 | "id": "ohlT29oGueUD", 1452 | "colab": { 1453 | "base_uri": "https://localhost:8080/" 1454 | }, 1455 | "outputId": "ef705633-ccb3-4a66-d732-3e272834f6f2" 1456 | }, 1457 | "source": [ 1458 | "train_rets = []\n", 1459 | "for i in tqdm(range(len(title_train))):\n", 1460 | " pred = predict(title_train[i])\n", 1461 | " train_rets.append((label_train[i], pred, title_train[i]))" 1462 | ], 1463 | "execution_count": 18, 1464 | "outputs": [ 1465 | { 1466 | "output_type": "stream", 1467 | "name": "stderr", 1468 | "text": [ 1469 | "100%|██████████| 5355/5355 [04:39<00:00, 19.19it/s]\n" 1470 | ] 1471 | } 1472 | ] 1473 | }, 1474 | { 1475 | "cell_type": "code", 1476 | "metadata": { 1477 | "id": "NE84AeC9BC8U" 1478 | }, 1479 | "source": [ 1480 | "rets = []\n", 1481 | "for i in tqdm(range(len(title_test))):\n", 1482 | " pred = predict(title_test[i])\n", 1483 | " rets.append((label_test[i], pred, title_test[i]))" 1484 | ], 1485 | "execution_count": null, 1486 | "outputs": [] 1487 | }, 1488 | { 1489 | "cell_type": "code", 1490 | "metadata": { 1491 | "id": "DZOnkUtKuiwa", 1492 | "colab": { 1493 | "base_uri": "https://localhost:8080/" 1494 | }, 1495 | "outputId": "83bda13e-9dc1-41c3-ff0e-6b696dc65ccc" 1496 | }, 1497 | "source": [ 1498 | "print(\n", 1499 | " accuracy_score(\n", 1500 | " [x[0] for x in train_rets],\n", 1501 | " [x[1] for x in train_rets],\n", 1502 | " )\n", 1503 | ")" 1504 | ], 1505 | "execution_count": 23, 1506 | "outputs": [ 1507 | { 1508 | "output_type": "stream", 1509 | "name": "stdout", 1510 | "text": [ 1511 | "0.861624649859944\n" 1512 | ] 1513 | } 1514 | ] 1515 | }, 1516 | { 1517 | "cell_type": "code", 1518 | "metadata": 
{ 1519 | "id": "74CWYzXXuW-W", 1520 | "colab": { 1521 | "base_uri": "https://localhost:8080/" 1522 | }, 1523 | "outputId": "c1a7e002-279d-4602-9ebf-19df96135ea1" 1524 | }, 1525 | "source": [ 1526 | "print(\n", 1527 | " accuracy_score(\n", 1528 | " [x[0] for x in rets],\n", 1529 | " [x[1] for x in rets],\n", 1530 | " )\n", 1531 | ")" 1532 | ], 1533 | "execution_count": 24, 1534 | "outputs": [ 1535 | { 1536 | "output_type": "stream", 1537 | "name": "stdout", 1538 | "text": [ 1539 | "0.7447447447447447\n" 1540 | ] 1541 | } 1542 | ] 1543 | }, 1544 | { 1545 | "cell_type": "code", 1546 | "metadata": { 1547 | "id": "D8vIKsvAuCpX", 1548 | "colab": { 1549 | "base_uri": "https://localhost:8080/" 1550 | }, 1551 | "outputId": "73f9bd61-f349-41af-b08b-96c20a5ea304" 1552 | }, 1553 | "source": [ 1554 | "print(\n", 1555 | " accuracy_score(\n", 1556 | " [x[0] for x in rets],\n", 1557 | " [0] * len(rets),\n", 1558 | " ),\n", 1559 | " accuracy_score(\n", 1560 | " [x[0] for x in rets],\n", 1561 | " [1] * len(rets),\n", 1562 | " ),\n", 1563 | " accuracy_score(\n", 1564 | " [x[0] for x in rets],\n", 1565 | " [2] * len(rets),\n", 1566 | " )\n", 1567 | ")" 1568 | ], 1569 | "execution_count": 25, 1570 | "outputs": [ 1571 | { 1572 | "output_type": "stream", 1573 | "name": "stdout", 1574 | "text": [ 1575 | "0.0990990990990991 0.4944944944944945 0.4064064064064064\n" 1576 | ] 1577 | } 1578 | ] 1579 | }, 1580 | { 1581 | "cell_type": "code", 1582 | "metadata": { 1583 | "id": "1O1uwg8irIdc" 1584 | }, 1585 | "source": [ 1586 | "" 1587 | ], 1588 | "execution_count": 22, 1589 | "outputs": [] 1590 | } 1591 | ] 1592 | } -------------------------------------------------------------------------------- /mt5_soft_prompt_tuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "mt5-soft-prompt-tuning.ipynb", 8 | "provenance": [], 9 | 
"collapsed_sections": [], 10 | "machine_shape": "hm" 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "name": "python3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "widgets": { 20 | "application/vnd.jupyter.widget-state+json": { 21 | "88c88a80fe4244dc8519ed9a9d4a7fad": { 22 | "model_module": "@jupyter-widgets/controls", 23 | "model_name": "HBoxModel", 24 | "model_module_version": "1.5.0", 25 | "state": { 26 | "_view_name": "HBoxView", 27 | "_dom_classes": [], 28 | "_model_name": "HBoxModel", 29 | "_view_module": "@jupyter-widgets/controls", 30 | "_model_module_version": "1.5.0", 31 | "_view_count": null, 32 | "_view_module_version": "1.5.0", 33 | "box_style": "", 34 | "layout": "IPY_MODEL_012192e864f6471e9336264b77f9d047", 35 | "_model_module": "@jupyter-widgets/controls", 36 | "children": [ 37 | "IPY_MODEL_b71062f17e064547b91634af273e6d40", 38 | "IPY_MODEL_8ce567f99b614fdc99ef7053ce718b23", 39 | "IPY_MODEL_05fab04f6d6f46f6bccc8cbfa22a88a2" 40 | ] 41 | } 42 | }, 43 | "012192e864f6471e9336264b77f9d047": { 44 | "model_module": "@jupyter-widgets/base", 45 | "model_name": "LayoutModel", 46 | "model_module_version": "1.2.0", 47 | "state": { 48 | "_view_name": "LayoutView", 49 | "grid_template_rows": null, 50 | "right": null, 51 | "justify_content": null, 52 | "_view_module": "@jupyter-widgets/base", 53 | "overflow": null, 54 | "_model_module_version": "1.2.0", 55 | "_view_count": null, 56 | "flex_flow": null, 57 | "width": null, 58 | "min_width": null, 59 | "border": null, 60 | "align_items": null, 61 | "bottom": null, 62 | "_model_module": "@jupyter-widgets/base", 63 | "top": null, 64 | "grid_column": null, 65 | "overflow_y": null, 66 | "overflow_x": null, 67 | "grid_auto_flow": null, 68 | "grid_area": null, 69 | "grid_template_columns": null, 70 | "flex": null, 71 | "_model_name": "LayoutModel", 72 | "justify_items": null, 73 | "grid_row": null, 74 | "max_height": null, 75 | "align_content": null, 76 | "visibility": null, 77 | 
"align_self": null, 78 | "height": null, 79 | "min_height": null, 80 | "padding": null, 81 | "grid_auto_rows": null, 82 | "grid_gap": null, 83 | "max_width": null, 84 | "order": null, 85 | "_view_module_version": "1.2.0", 86 | "grid_template_areas": null, 87 | "object_position": null, 88 | "object_fit": null, 89 | "grid_auto_columns": null, 90 | "margin": null, 91 | "display": null, 92 | "left": null 93 | } 94 | }, 95 | "b71062f17e064547b91634af273e6d40": { 96 | "model_module": "@jupyter-widgets/controls", 97 | "model_name": "HTMLModel", 98 | "model_module_version": "1.5.0", 99 | "state": { 100 | "_view_name": "HTMLView", 101 | "style": "IPY_MODEL_caa18d30298b4ad7be2e9ceda2f395b8", 102 | "_dom_classes": [], 103 | "description": "", 104 | "_model_name": "HTMLModel", 105 | "placeholder": "​", 106 | "_view_module": "@jupyter-widgets/controls", 107 | "_model_module_version": "1.5.0", 108 | "value": "Downloading: 100%", 109 | "_view_count": null, 110 | "_view_module_version": "1.5.0", 111 | "description_tooltip": null, 112 | "_model_module": "@jupyter-widgets/controls", 113 | "layout": "IPY_MODEL_0f7b84c5523b48fabd4f79d8b8d80fa9" 114 | } 115 | }, 116 | "8ce567f99b614fdc99ef7053ce718b23": { 117 | "model_module": "@jupyter-widgets/controls", 118 | "model_name": "FloatProgressModel", 119 | "model_module_version": "1.5.0", 120 | "state": { 121 | "_view_name": "ProgressView", 122 | "style": "IPY_MODEL_f982a7f0aad440fbb5c4497c3fd1d15e", 123 | "_dom_classes": [], 124 | "description": "", 125 | "_model_name": "FloatProgressModel", 126 | "bar_style": "success", 127 | "max": 702, 128 | "_view_module": "@jupyter-widgets/controls", 129 | "_model_module_version": "1.5.0", 130 | "value": 702, 131 | "_view_count": null, 132 | "_view_module_version": "1.5.0", 133 | "orientation": "horizontal", 134 | "min": 0, 135 | "description_tooltip": null, 136 | "_model_module": "@jupyter-widgets/controls", 137 | "layout": "IPY_MODEL_1ebe5477942940abaf722f6963f9b0e8" 138 | } 139 | }, 140 | 
"05fab04f6d6f46f6bccc8cbfa22a88a2": { 141 | "model_module": "@jupyter-widgets/controls", 142 | "model_name": "HTMLModel", 143 | "model_module_version": "1.5.0", 144 | "state": { 145 | "_view_name": "HTMLView", 146 | "style": "IPY_MODEL_60d985a36d434ee996f616952a7c7612", 147 | "_dom_classes": [], 148 | "description": "", 149 | "_model_name": "HTMLModel", 150 | "placeholder": "​", 151 | "_view_module": "@jupyter-widgets/controls", 152 | "_model_module_version": "1.5.0", 153 | "value": " 702/702 [00:00<00:00, 21.8kB/s]", 154 | "_view_count": null, 155 | "_view_module_version": "1.5.0", 156 | "description_tooltip": null, 157 | "_model_module": "@jupyter-widgets/controls", 158 | "layout": "IPY_MODEL_bde84bfa5da2408ca9eda7c9b39291ad" 159 | } 160 | }, 161 | "caa18d30298b4ad7be2e9ceda2f395b8": { 162 | "model_module": "@jupyter-widgets/controls", 163 | "model_name": "DescriptionStyleModel", 164 | "model_module_version": "1.5.0", 165 | "state": { 166 | "_view_name": "StyleView", 167 | "_model_name": "DescriptionStyleModel", 168 | "description_width": "", 169 | "_view_module": "@jupyter-widgets/base", 170 | "_model_module_version": "1.5.0", 171 | "_view_count": null, 172 | "_view_module_version": "1.2.0", 173 | "_model_module": "@jupyter-widgets/controls" 174 | } 175 | }, 176 | "0f7b84c5523b48fabd4f79d8b8d80fa9": { 177 | "model_module": "@jupyter-widgets/base", 178 | "model_name": "LayoutModel", 179 | "model_module_version": "1.2.0", 180 | "state": { 181 | "_view_name": "LayoutView", 182 | "grid_template_rows": null, 183 | "right": null, 184 | "justify_content": null, 185 | "_view_module": "@jupyter-widgets/base", 186 | "overflow": null, 187 | "_model_module_version": "1.2.0", 188 | "_view_count": null, 189 | "flex_flow": null, 190 | "width": null, 191 | "min_width": null, 192 | "border": null, 193 | "align_items": null, 194 | "bottom": null, 195 | "_model_module": "@jupyter-widgets/base", 196 | "top": null, 197 | "grid_column": null, 198 | "overflow_y": null, 199 | 
"overflow_x": null, 200 | "grid_auto_flow": null, 201 | "grid_area": null, 202 | "grid_template_columns": null, 203 | "flex": null, 204 | "_model_name": "LayoutModel", 205 | "justify_items": null, 206 | "grid_row": null, 207 | "max_height": null, 208 | "align_content": null, 209 | "visibility": null, 210 | "align_self": null, 211 | "height": null, 212 | "min_height": null, 213 | "padding": null, 214 | "grid_auto_rows": null, 215 | "grid_gap": null, 216 | "max_width": null, 217 | "order": null, 218 | "_view_module_version": "1.2.0", 219 | "grid_template_areas": null, 220 | "object_position": null, 221 | "object_fit": null, 222 | "grid_auto_columns": null, 223 | "margin": null, 224 | "display": null, 225 | "left": null 226 | } 227 | }, 228 | "f982a7f0aad440fbb5c4497c3fd1d15e": { 229 | "model_module": "@jupyter-widgets/controls", 230 | "model_name": "ProgressStyleModel", 231 | "model_module_version": "1.5.0", 232 | "state": { 233 | "_view_name": "StyleView", 234 | "_model_name": "ProgressStyleModel", 235 | "description_width": "", 236 | "_view_module": "@jupyter-widgets/base", 237 | "_model_module_version": "1.5.0", 238 | "_view_count": null, 239 | "_view_module_version": "1.2.0", 240 | "bar_color": null, 241 | "_model_module": "@jupyter-widgets/controls" 242 | } 243 | }, 244 | "1ebe5477942940abaf722f6963f9b0e8": { 245 | "model_module": "@jupyter-widgets/base", 246 | "model_name": "LayoutModel", 247 | "model_module_version": "1.2.0", 248 | "state": { 249 | "_view_name": "LayoutView", 250 | "grid_template_rows": null, 251 | "right": null, 252 | "justify_content": null, 253 | "_view_module": "@jupyter-widgets/base", 254 | "overflow": null, 255 | "_model_module_version": "1.2.0", 256 | "_view_count": null, 257 | "flex_flow": null, 258 | "width": null, 259 | "min_width": null, 260 | "border": null, 261 | "align_items": null, 262 | "bottom": null, 263 | "_model_module": "@jupyter-widgets/base", 264 | "top": null, 265 | "grid_column": null, 266 | "overflow_y": null, 267 | 
"overflow_x": null, 268 | "grid_auto_flow": null, 269 | "grid_area": null, 270 | "grid_template_columns": null, 271 | "flex": null, 272 | "_model_name": "LayoutModel", 273 | "justify_items": null, 274 | "grid_row": null, 275 | "max_height": null, 276 | "align_content": null, 277 | "visibility": null, 278 | "align_self": null, 279 | "height": null, 280 | "min_height": null, 281 | "padding": null, 282 | "grid_auto_rows": null, 283 | "grid_gap": null, 284 | "max_width": null, 285 | "order": null, 286 | "_view_module_version": "1.2.0", 287 | "grid_template_areas": null, 288 | "object_position": null, 289 | "object_fit": null, 290 | "grid_auto_columns": null, 291 | "margin": null, 292 | "display": null, 293 | "left": null 294 | } 295 | }, 296 | "60d985a36d434ee996f616952a7c7612": { 297 | "model_module": "@jupyter-widgets/controls", 298 | "model_name": "DescriptionStyleModel", 299 | "model_module_version": "1.5.0", 300 | "state": { 301 | "_view_name": "StyleView", 302 | "_model_name": "DescriptionStyleModel", 303 | "description_width": "", 304 | "_view_module": "@jupyter-widgets/base", 305 | "_model_module_version": "1.5.0", 306 | "_view_count": null, 307 | "_view_module_version": "1.2.0", 308 | "_model_module": "@jupyter-widgets/controls" 309 | } 310 | }, 311 | "bde84bfa5da2408ca9eda7c9b39291ad": { 312 | "model_module": "@jupyter-widgets/base", 313 | "model_name": "LayoutModel", 314 | "model_module_version": "1.2.0", 315 | "state": { 316 | "_view_name": "LayoutView", 317 | "grid_template_rows": null, 318 | "right": null, 319 | "justify_content": null, 320 | "_view_module": "@jupyter-widgets/base", 321 | "overflow": null, 322 | "_model_module_version": "1.2.0", 323 | "_view_count": null, 324 | "flex_flow": null, 325 | "width": null, 326 | "min_width": null, 327 | "border": null, 328 | "align_items": null, 329 | "bottom": null, 330 | "_model_module": "@jupyter-widgets/base", 331 | "top": null, 332 | "grid_column": null, 333 | "overflow_y": null, 334 | "overflow_x": null, 
335 | "grid_auto_flow": null, 336 | "grid_area": null, 337 | "grid_template_columns": null, 338 | "flex": null, 339 | "_model_name": "LayoutModel", 340 | "justify_items": null, 341 | "grid_row": null, 342 | "max_height": null, 343 | "align_content": null, 344 | "visibility": null, 345 | "align_self": null, 346 | "height": null, 347 | "min_height": null, 348 | "padding": null, 349 | "grid_auto_rows": null, 350 | "grid_gap": null, 351 | "max_width": null, 352 | "order": null, 353 | "_view_module_version": "1.2.0", 354 | "grid_template_areas": null, 355 | "object_position": null, 356 | "object_fit": null, 357 | "grid_auto_columns": null, 358 | "margin": null, 359 | "display": null, 360 | "left": null 361 | } 362 | }, 363 | "ac7114b8298b45bda89d9e822d0f0ffd": { 364 | "model_module": "@jupyter-widgets/controls", 365 | "model_name": "HBoxModel", 366 | "model_module_version": "1.5.0", 367 | "state": { 368 | "_view_name": "HBoxView", 369 | "_dom_classes": [], 370 | "_model_name": "HBoxModel", 371 | "_view_module": "@jupyter-widgets/controls", 372 | "_model_module_version": "1.5.0", 373 | "_view_count": null, 374 | "_view_module_version": "1.5.0", 375 | "box_style": "", 376 | "layout": "IPY_MODEL_e6ae18a1e8144dada46879db37e8cdef", 377 | "_model_module": "@jupyter-widgets/controls", 378 | "children": [ 379 | "IPY_MODEL_702c7b9308b646c1bdd0ca2387776c25", 380 | "IPY_MODEL_686aca97abba439e81e430f1c33224dc", 381 | "IPY_MODEL_61ee6d1f21bf41c89488dbcfb99f4d5d" 382 | ] 383 | } 384 | }, 385 | "e6ae18a1e8144dada46879db37e8cdef": { 386 | "model_module": "@jupyter-widgets/base", 387 | "model_name": "LayoutModel", 388 | "model_module_version": "1.2.0", 389 | "state": { 390 | "_view_name": "LayoutView", 391 | "grid_template_rows": null, 392 | "right": null, 393 | "justify_content": null, 394 | "_view_module": "@jupyter-widgets/base", 395 | "overflow": null, 396 | "_model_module_version": "1.2.0", 397 | "_view_count": null, 398 | "flex_flow": null, 399 | "width": null, 400 | "min_width": 
null, 401 | "border": null, 402 | "align_items": null, 403 | "bottom": null, 404 | "_model_module": "@jupyter-widgets/base", 405 | "top": null, 406 | "grid_column": null, 407 | "overflow_y": null, 408 | "overflow_x": null, 409 | "grid_auto_flow": null, 410 | "grid_area": null, 411 | "grid_template_columns": null, 412 | "flex": null, 413 | "_model_name": "LayoutModel", 414 | "justify_items": null, 415 | "grid_row": null, 416 | "max_height": null, 417 | "align_content": null, 418 | "visibility": null, 419 | "align_self": null, 420 | "height": null, 421 | "min_height": null, 422 | "padding": null, 423 | "grid_auto_rows": null, 424 | "grid_gap": null, 425 | "max_width": null, 426 | "order": null, 427 | "_view_module_version": "1.2.0", 428 | "grid_template_areas": null, 429 | "object_position": null, 430 | "object_fit": null, 431 | "grid_auto_columns": null, 432 | "margin": null, 433 | "display": null, 434 | "left": null 435 | } 436 | }, 437 | "702c7b9308b646c1bdd0ca2387776c25": { 438 | "model_module": "@jupyter-widgets/controls", 439 | "model_name": "HTMLModel", 440 | "model_module_version": "1.5.0", 441 | "state": { 442 | "_view_name": "HTMLView", 443 | "style": "IPY_MODEL_abd066dddd434d63ae737bb56c67de79", 444 | "_dom_classes": [], 445 | "description": "", 446 | "_model_name": "HTMLModel", 447 | "placeholder": "​", 448 | "_view_module": "@jupyter-widgets/controls", 449 | "_model_module_version": "1.5.0", 450 | "value": "Downloading: 100%", 451 | "_view_count": null, 452 | "_view_module_version": "1.5.0", 453 | "description_tooltip": null, 454 | "_model_module": "@jupyter-widgets/controls", 455 | "layout": "IPY_MODEL_9269f66db7d24a6da268fa98c33d6f02" 456 | } 457 | }, 458 | "686aca97abba439e81e430f1c33224dc": { 459 | "model_module": "@jupyter-widgets/controls", 460 | "model_name": "FloatProgressModel", 461 | "model_module_version": "1.5.0", 462 | "state": { 463 | "_view_name": "ProgressView", 464 | "style": "IPY_MODEL_81d5b9f593dd4f36894e32a5d0360d1d", 465 | 
"_dom_classes": [], 466 | "description": "", 467 | "_model_name": "FloatProgressModel", 468 | "bar_style": "success", 469 | "max": 2329735129, 470 | "_view_module": "@jupyter-widgets/controls", 471 | "_model_module_version": "1.5.0", 472 | "value": 2329735129, 473 | "_view_count": null, 474 | "_view_module_version": "1.5.0", 475 | "orientation": "horizontal", 476 | "min": 0, 477 | "description_tooltip": null, 478 | "_model_module": "@jupyter-widgets/controls", 479 | "layout": "IPY_MODEL_f15396c431f34835bbe2eeb65dfd16a4" 480 | } 481 | }, 482 | "61ee6d1f21bf41c89488dbcfb99f4d5d": { 483 | "model_module": "@jupyter-widgets/controls", 484 | "model_name": "HTMLModel", 485 | "model_module_version": "1.5.0", 486 | "state": { 487 | "_view_name": "HTMLView", 488 | "style": "IPY_MODEL_0ce264acdaa8432b9715bdc42344b0aa", 489 | "_dom_classes": [], 490 | "description": "", 491 | "_model_name": "HTMLModel", 492 | "placeholder": "​", 493 | "_view_module": "@jupyter-widgets/controls", 494 | "_model_module_version": "1.5.0", 495 | "value": " 2.33G/2.33G [00:57<00:00, 29.6MB/s]", 496 | "_view_count": null, 497 | "_view_module_version": "1.5.0", 498 | "description_tooltip": null, 499 | "_model_module": "@jupyter-widgets/controls", 500 | "layout": "IPY_MODEL_5a46f6fdc4784422b652653b05d1c691" 501 | } 502 | }, 503 | "abd066dddd434d63ae737bb56c67de79": { 504 | "model_module": "@jupyter-widgets/controls", 505 | "model_name": "DescriptionStyleModel", 506 | "model_module_version": "1.5.0", 507 | "state": { 508 | "_view_name": "StyleView", 509 | "_model_name": "DescriptionStyleModel", 510 | "description_width": "", 511 | "_view_module": "@jupyter-widgets/base", 512 | "_model_module_version": "1.5.0", 513 | "_view_count": null, 514 | "_view_module_version": "1.2.0", 515 | "_model_module": "@jupyter-widgets/controls" 516 | } 517 | }, 518 | "9269f66db7d24a6da268fa98c33d6f02": { 519 | "model_module": "@jupyter-widgets/base", 520 | "model_name": "LayoutModel", 521 | "model_module_version": "1.2.0", 
522 | "state": { 523 | "_view_name": "LayoutView", 524 | "grid_template_rows": null, 525 | "right": null, 526 | "justify_content": null, 527 | "_view_module": "@jupyter-widgets/base", 528 | "overflow": null, 529 | "_model_module_version": "1.2.0", 530 | "_view_count": null, 531 | "flex_flow": null, 532 | "width": null, 533 | "min_width": null, 534 | "border": null, 535 | "align_items": null, 536 | "bottom": null, 537 | "_model_module": "@jupyter-widgets/base", 538 | "top": null, 539 | "grid_column": null, 540 | "overflow_y": null, 541 | "overflow_x": null, 542 | "grid_auto_flow": null, 543 | "grid_area": null, 544 | "grid_template_columns": null, 545 | "flex": null, 546 | "_model_name": "LayoutModel", 547 | "justify_items": null, 548 | "grid_row": null, 549 | "max_height": null, 550 | "align_content": null, 551 | "visibility": null, 552 | "align_self": null, 553 | "height": null, 554 | "min_height": null, 555 | "padding": null, 556 | "grid_auto_rows": null, 557 | "grid_gap": null, 558 | "max_width": null, 559 | "order": null, 560 | "_view_module_version": "1.2.0", 561 | "grid_template_areas": null, 562 | "object_position": null, 563 | "object_fit": null, 564 | "grid_auto_columns": null, 565 | "margin": null, 566 | "display": null, 567 | "left": null 568 | } 569 | }, 570 | "81d5b9f593dd4f36894e32a5d0360d1d": { 571 | "model_module": "@jupyter-widgets/controls", 572 | "model_name": "ProgressStyleModel", 573 | "model_module_version": "1.5.0", 574 | "state": { 575 | "_view_name": "StyleView", 576 | "_model_name": "ProgressStyleModel", 577 | "description_width": "", 578 | "_view_module": "@jupyter-widgets/base", 579 | "_model_module_version": "1.5.0", 580 | "_view_count": null, 581 | "_view_module_version": "1.2.0", 582 | "bar_color": null, 583 | "_model_module": "@jupyter-widgets/controls" 584 | } 585 | }, 586 | "f15396c431f34835bbe2eeb65dfd16a4": { 587 | "model_module": "@jupyter-widgets/base", 588 | "model_name": "LayoutModel", 589 | "model_module_version": "1.2.0", 
590 | "state": { 591 | "_view_name": "LayoutView", 592 | "grid_template_rows": null, 593 | "right": null, 594 | "justify_content": null, 595 | "_view_module": "@jupyter-widgets/base", 596 | "overflow": null, 597 | "_model_module_version": "1.2.0", 598 | "_view_count": null, 599 | "flex_flow": null, 600 | "width": null, 601 | "min_width": null, 602 | "border": null, 603 | "align_items": null, 604 | "bottom": null, 605 | "_model_module": "@jupyter-widgets/base", 606 | "top": null, 607 | "grid_column": null, 608 | "overflow_y": null, 609 | "overflow_x": null, 610 | "grid_auto_flow": null, 611 | "grid_area": null, 612 | "grid_template_columns": null, 613 | "flex": null, 614 | "_model_name": "LayoutModel", 615 | "justify_items": null, 616 | "grid_row": null, 617 | "max_height": null, 618 | "align_content": null, 619 | "visibility": null, 620 | "align_self": null, 621 | "height": null, 622 | "min_height": null, 623 | "padding": null, 624 | "grid_auto_rows": null, 625 | "grid_gap": null, 626 | "max_width": null, 627 | "order": null, 628 | "_view_module_version": "1.2.0", 629 | "grid_template_areas": null, 630 | "object_position": null, 631 | "object_fit": null, 632 | "grid_auto_columns": null, 633 | "margin": null, 634 | "display": null, 635 | "left": null 636 | } 637 | }, 638 | "0ce264acdaa8432b9715bdc42344b0aa": { 639 | "model_module": "@jupyter-widgets/controls", 640 | "model_name": "DescriptionStyleModel", 641 | "model_module_version": "1.5.0", 642 | "state": { 643 | "_view_name": "StyleView", 644 | "_model_name": "DescriptionStyleModel", 645 | "description_width": "", 646 | "_view_module": "@jupyter-widgets/base", 647 | "_model_module_version": "1.5.0", 648 | "_view_count": null, 649 | "_view_module_version": "1.2.0", 650 | "_model_module": "@jupyter-widgets/controls" 651 | } 652 | }, 653 | "5a46f6fdc4784422b652653b05d1c691": { 654 | "model_module": "@jupyter-widgets/base", 655 | "model_name": "LayoutModel", 656 | "model_module_version": "1.2.0", 657 | "state": { 658 
| "_view_name": "LayoutView", 659 | "grid_template_rows": null, 660 | "right": null, 661 | "justify_content": null, 662 | "_view_module": "@jupyter-widgets/base", 663 | "overflow": null, 664 | "_model_module_version": "1.2.0", 665 | "_view_count": null, 666 | "flex_flow": null, 667 | "width": null, 668 | "min_width": null, 669 | "border": null, 670 | "align_items": null, 671 | "bottom": null, 672 | "_model_module": "@jupyter-widgets/base", 673 | "top": null, 674 | "grid_column": null, 675 | "overflow_y": null, 676 | "overflow_x": null, 677 | "grid_auto_flow": null, 678 | "grid_area": null, 679 | "grid_template_columns": null, 680 | "flex": null, 681 | "_model_name": "LayoutModel", 682 | "justify_items": null, 683 | "grid_row": null, 684 | "max_height": null, 685 | "align_content": null, 686 | "visibility": null, 687 | "align_self": null, 688 | "height": null, 689 | "min_height": null, 690 | "padding": null, 691 | "grid_auto_rows": null, 692 | "grid_gap": null, 693 | "max_width": null, 694 | "order": null, 695 | "_view_module_version": "1.2.0", 696 | "grid_template_areas": null, 697 | "object_position": null, 698 | "object_fit": null, 699 | "grid_auto_columns": null, 700 | "margin": null, 701 | "display": null, 702 | "left": null 703 | } 704 | }, 705 | "268502a8a7204b5cbbae12de9c63ed1c": { 706 | "model_module": "@jupyter-widgets/controls", 707 | "model_name": "HBoxModel", 708 | "model_module_version": "1.5.0", 709 | "state": { 710 | "_view_name": "HBoxView", 711 | "_dom_classes": [], 712 | "_model_name": "HBoxModel", 713 | "_view_module": "@jupyter-widgets/controls", 714 | "_model_module_version": "1.5.0", 715 | "_view_count": null, 716 | "_view_module_version": "1.5.0", 717 | "box_style": "", 718 | "layout": "IPY_MODEL_71118489f6ff49a989d91f5325359b77", 719 | "_model_module": "@jupyter-widgets/controls", 720 | "children": [ 721 | "IPY_MODEL_542cfc0efdeb436da7d4c789ae3a722d", 722 | "IPY_MODEL_f77f512b44bf481cb8846404350b9ab2", 723 | 
"IPY_MODEL_2d4e4125f1ed417bb2cfe878468b2966" 724 | ] 725 | } 726 | }, 727 | "71118489f6ff49a989d91f5325359b77": { 728 | "model_module": "@jupyter-widgets/base", 729 | "model_name": "LayoutModel", 730 | "model_module_version": "1.2.0", 731 | "state": { 732 | "_view_name": "LayoutView", 733 | "grid_template_rows": null, 734 | "right": null, 735 | "justify_content": null, 736 | "_view_module": "@jupyter-widgets/base", 737 | "overflow": null, 738 | "_model_module_version": "1.2.0", 739 | "_view_count": null, 740 | "flex_flow": null, 741 | "width": null, 742 | "min_width": null, 743 | "border": null, 744 | "align_items": null, 745 | "bottom": null, 746 | "_model_module": "@jupyter-widgets/base", 747 | "top": null, 748 | "grid_column": null, 749 | "overflow_y": null, 750 | "overflow_x": null, 751 | "grid_auto_flow": null, 752 | "grid_area": null, 753 | "grid_template_columns": null, 754 | "flex": null, 755 | "_model_name": "LayoutModel", 756 | "justify_items": null, 757 | "grid_row": null, 758 | "max_height": null, 759 | "align_content": null, 760 | "visibility": null, 761 | "align_self": null, 762 | "height": null, 763 | "min_height": null, 764 | "padding": null, 765 | "grid_auto_rows": null, 766 | "grid_gap": null, 767 | "max_width": null, 768 | "order": null, 769 | "_view_module_version": "1.2.0", 770 | "grid_template_areas": null, 771 | "object_position": null, 772 | "object_fit": null, 773 | "grid_auto_columns": null, 774 | "margin": null, 775 | "display": null, 776 | "left": null 777 | } 778 | }, 779 | "542cfc0efdeb436da7d4c789ae3a722d": { 780 | "model_module": "@jupyter-widgets/controls", 781 | "model_name": "HTMLModel", 782 | "model_module_version": "1.5.0", 783 | "state": { 784 | "_view_name": "HTMLView", 785 | "style": "IPY_MODEL_354e1cdb54874df192f9f4f190035efa", 786 | "_dom_classes": [], 787 | "description": "", 788 | "_model_name": "HTMLModel", 789 | "placeholder": "​", 790 | "_view_module": "@jupyter-widgets/controls", 791 | "_model_module_version": 
"1.5.0", 792 | "value": "Downloading: 100%", 793 | "_view_count": null, 794 | "_view_module_version": "1.5.0", 795 | "description_tooltip": null, 796 | "_model_module": "@jupyter-widgets/controls", 797 | "layout": "IPY_MODEL_3e2b5c8ec31149838dcacfcc78f2087a" 798 | } 799 | }, 800 | "f77f512b44bf481cb8846404350b9ab2": { 801 | "model_module": "@jupyter-widgets/controls", 802 | "model_name": "FloatProgressModel", 803 | "model_module_version": "1.5.0", 804 | "state": { 805 | "_view_name": "ProgressView", 806 | "style": "IPY_MODEL_562e3a57010c4ca5af045c3c33500a28", 807 | "_dom_classes": [], 808 | "description": "", 809 | "_model_name": "FloatProgressModel", 810 | "bar_style": "success", 811 | "max": 4309802, 812 | "_view_module": "@jupyter-widgets/controls", 813 | "_model_module_version": "1.5.0", 814 | "value": 4309802, 815 | "_view_count": null, 816 | "_view_module_version": "1.5.0", 817 | "orientation": "horizontal", 818 | "min": 0, 819 | "description_tooltip": null, 820 | "_model_module": "@jupyter-widgets/controls", 821 | "layout": "IPY_MODEL_53b33d44e6354bb6af52775f1ed6b9dd" 822 | } 823 | }, 824 | "2d4e4125f1ed417bb2cfe878468b2966": { 825 | "model_module": "@jupyter-widgets/controls", 826 | "model_name": "HTMLModel", 827 | "model_module_version": "1.5.0", 828 | "state": { 829 | "_view_name": "HTMLView", 830 | "style": "IPY_MODEL_e26e613b94f94d539d027c3ceb496903", 831 | "_dom_classes": [], 832 | "description": "", 833 | "_model_name": "HTMLModel", 834 | "placeholder": "​", 835 | "_view_module": "@jupyter-widgets/controls", 836 | "_model_module_version": "1.5.0", 837 | "value": " 4.31M/4.31M [00:00<00:00, 6.60MB/s]", 838 | "_view_count": null, 839 | "_view_module_version": "1.5.0", 840 | "description_tooltip": null, 841 | "_model_module": "@jupyter-widgets/controls", 842 | "layout": "IPY_MODEL_65b77629789f45b6ab5105f893c7b1cc" 843 | } 844 | }, 845 | "354e1cdb54874df192f9f4f190035efa": { 846 | "model_module": "@jupyter-widgets/controls", 847 | "model_name": 
"DescriptionStyleModel", 848 | "model_module_version": "1.5.0", 849 | "state": { 850 | "_view_name": "StyleView", 851 | "_model_name": "DescriptionStyleModel", 852 | "description_width": "", 853 | "_view_module": "@jupyter-widgets/base", 854 | "_model_module_version": "1.5.0", 855 | "_view_count": null, 856 | "_view_module_version": "1.2.0", 857 | "_model_module": "@jupyter-widgets/controls" 858 | } 859 | }, 860 | "3e2b5c8ec31149838dcacfcc78f2087a": { 861 | "model_module": "@jupyter-widgets/base", 862 | "model_name": "LayoutModel", 863 | "model_module_version": "1.2.0", 864 | "state": { 865 | "_view_name": "LayoutView", 866 | "grid_template_rows": null, 867 | "right": null, 868 | "justify_content": null, 869 | "_view_module": "@jupyter-widgets/base", 870 | "overflow": null, 871 | "_model_module_version": "1.2.0", 872 | "_view_count": null, 873 | "flex_flow": null, 874 | "width": null, 875 | "min_width": null, 876 | "border": null, 877 | "align_items": null, 878 | "bottom": null, 879 | "_model_module": "@jupyter-widgets/base", 880 | "top": null, 881 | "grid_column": null, 882 | "overflow_y": null, 883 | "overflow_x": null, 884 | "grid_auto_flow": null, 885 | "grid_area": null, 886 | "grid_template_columns": null, 887 | "flex": null, 888 | "_model_name": "LayoutModel", 889 | "justify_items": null, 890 | "grid_row": null, 891 | "max_height": null, 892 | "align_content": null, 893 | "visibility": null, 894 | "align_self": null, 895 | "height": null, 896 | "min_height": null, 897 | "padding": null, 898 | "grid_auto_rows": null, 899 | "grid_gap": null, 900 | "max_width": null, 901 | "order": null, 902 | "_view_module_version": "1.2.0", 903 | "grid_template_areas": null, 904 | "object_position": null, 905 | "object_fit": null, 906 | "grid_auto_columns": null, 907 | "margin": null, 908 | "display": null, 909 | "left": null 910 | } 911 | }, 912 | "562e3a57010c4ca5af045c3c33500a28": { 913 | "model_module": "@jupyter-widgets/controls", 914 | "model_name": 
"ProgressStyleModel", 915 | "model_module_version": "1.5.0", 916 | "state": { 917 | "_view_name": "StyleView", 918 | "_model_name": "ProgressStyleModel", 919 | "description_width": "", 920 | "_view_module": "@jupyter-widgets/base", 921 | "_model_module_version": "1.5.0", 922 | "_view_count": null, 923 | "_view_module_version": "1.2.0", 924 | "bar_color": null, 925 | "_model_module": "@jupyter-widgets/controls" 926 | } 927 | }, 928 | "53b33d44e6354bb6af52775f1ed6b9dd": { 929 | "model_module": "@jupyter-widgets/base", 930 | "model_name": "LayoutModel", 931 | "model_module_version": "1.2.0", 932 | "state": { 933 | "_view_name": "LayoutView", 934 | "grid_template_rows": null, 935 | "right": null, 936 | "justify_content": null, 937 | "_view_module": "@jupyter-widgets/base", 938 | "overflow": null, 939 | "_model_module_version": "1.2.0", 940 | "_view_count": null, 941 | "flex_flow": null, 942 | "width": null, 943 | "min_width": null, 944 | "border": null, 945 | "align_items": null, 946 | "bottom": null, 947 | "_model_module": "@jupyter-widgets/base", 948 | "top": null, 949 | "grid_column": null, 950 | "overflow_y": null, 951 | "overflow_x": null, 952 | "grid_auto_flow": null, 953 | "grid_area": null, 954 | "grid_template_columns": null, 955 | "flex": null, 956 | "_model_name": "LayoutModel", 957 | "justify_items": null, 958 | "grid_row": null, 959 | "max_height": null, 960 | "align_content": null, 961 | "visibility": null, 962 | "align_self": null, 963 | "height": null, 964 | "min_height": null, 965 | "padding": null, 966 | "grid_auto_rows": null, 967 | "grid_gap": null, 968 | "max_width": null, 969 | "order": null, 970 | "_view_module_version": "1.2.0", 971 | "grid_template_areas": null, 972 | "object_position": null, 973 | "object_fit": null, 974 | "grid_auto_columns": null, 975 | "margin": null, 976 | "display": null, 977 | "left": null 978 | } 979 | }, 980 | "e26e613b94f94d539d027c3ceb496903": { 981 | "model_module": "@jupyter-widgets/controls", 982 | "model_name": 
"DescriptionStyleModel", 983 | "model_module_version": "1.5.0", 984 | "state": { 985 | "_view_name": "StyleView", 986 | "_model_name": "DescriptionStyleModel", 987 | "description_width": "", 988 | "_view_module": "@jupyter-widgets/base", 989 | "_model_module_version": "1.5.0", 990 | "_view_count": null, 991 | "_view_module_version": "1.2.0", 992 | "_model_module": "@jupyter-widgets/controls" 993 | } 994 | }, 995 | "65b77629789f45b6ab5105f893c7b1cc": { 996 | "model_module": "@jupyter-widgets/base", 997 | "model_name": "LayoutModel", 998 | "model_module_version": "1.2.0", 999 | "state": { 1000 | "_view_name": "LayoutView", 1001 | "grid_template_rows": null, 1002 | "right": null, 1003 | "justify_content": null, 1004 | "_view_module": "@jupyter-widgets/base", 1005 | "overflow": null, 1006 | "_model_module_version": "1.2.0", 1007 | "_view_count": null, 1008 | "flex_flow": null, 1009 | "width": null, 1010 | "min_width": null, 1011 | "border": null, 1012 | "align_items": null, 1013 | "bottom": null, 1014 | "_model_module": "@jupyter-widgets/base", 1015 | "top": null, 1016 | "grid_column": null, 1017 | "overflow_y": null, 1018 | "overflow_x": null, 1019 | "grid_auto_flow": null, 1020 | "grid_area": null, 1021 | "grid_template_columns": null, 1022 | "flex": null, 1023 | "_model_name": "LayoutModel", 1024 | "justify_items": null, 1025 | "grid_row": null, 1026 | "max_height": null, 1027 | "align_content": null, 1028 | "visibility": null, 1029 | "align_self": null, 1030 | "height": null, 1031 | "min_height": null, 1032 | "padding": null, 1033 | "grid_auto_rows": null, 1034 | "grid_gap": null, 1035 | "max_width": null, 1036 | "order": null, 1037 | "_view_module_version": "1.2.0", 1038 | "grid_template_areas": null, 1039 | "object_position": null, 1040 | "object_fit": null, 1041 | "grid_auto_columns": null, 1042 | "margin": null, 1043 | "display": null, 1044 | "left": null 1045 | } 1046 | }, 1047 | "c95d9ed3d1ed43bebf121392906414fa": { 1048 | "model_module": 
"@jupyter-widgets/controls", 1049 | "model_name": "HBoxModel", 1050 | "model_module_version": "1.5.0", 1051 | "state": { 1052 | "_view_name": "HBoxView", 1053 | "_dom_classes": [], 1054 | "_model_name": "HBoxModel", 1055 | "_view_module": "@jupyter-widgets/controls", 1056 | "_model_module_version": "1.5.0", 1057 | "_view_count": null, 1058 | "_view_module_version": "1.5.0", 1059 | "box_style": "", 1060 | "layout": "IPY_MODEL_f258c1cc75474d1392a9368c51550c45", 1061 | "_model_module": "@jupyter-widgets/controls", 1062 | "children": [ 1063 | "IPY_MODEL_2c9e861a76054c738afefb8620500773", 1064 | "IPY_MODEL_e6af025bf3744cb2a3017199a8dc185c", 1065 | "IPY_MODEL_d3371fb6455440d48684cfe9d4829eff" 1066 | ] 1067 | } 1068 | }, 1069 | "f258c1cc75474d1392a9368c51550c45": { 1070 | "model_module": "@jupyter-widgets/base", 1071 | "model_name": "LayoutModel", 1072 | "model_module_version": "1.2.0", 1073 | "state": { 1074 | "_view_name": "LayoutView", 1075 | "grid_template_rows": null, 1076 | "right": null, 1077 | "justify_content": null, 1078 | "_view_module": "@jupyter-widgets/base", 1079 | "overflow": null, 1080 | "_model_module_version": "1.2.0", 1081 | "_view_count": null, 1082 | "flex_flow": null, 1083 | "width": null, 1084 | "min_width": null, 1085 | "border": null, 1086 | "align_items": null, 1087 | "bottom": null, 1088 | "_model_module": "@jupyter-widgets/base", 1089 | "top": null, 1090 | "grid_column": null, 1091 | "overflow_y": null, 1092 | "overflow_x": null, 1093 | "grid_auto_flow": null, 1094 | "grid_area": null, 1095 | "grid_template_columns": null, 1096 | "flex": null, 1097 | "_model_name": "LayoutModel", 1098 | "justify_items": null, 1099 | "grid_row": null, 1100 | "max_height": null, 1101 | "align_content": null, 1102 | "visibility": null, 1103 | "align_self": null, 1104 | "height": null, 1105 | "min_height": null, 1106 | "padding": null, 1107 | "grid_auto_rows": null, 1108 | "grid_gap": null, 1109 | "max_width": null, 1110 | "order": null, 1111 | 
"_view_module_version": "1.2.0", 1112 | "grid_template_areas": null, 1113 | "object_position": null, 1114 | "object_fit": null, 1115 | "grid_auto_columns": null, 1116 | "margin": null, 1117 | "display": null, 1118 | "left": null 1119 | } 1120 | }, 1121 | "2c9e861a76054c738afefb8620500773": { 1122 | "model_module": "@jupyter-widgets/controls", 1123 | "model_name": "HTMLModel", 1124 | "model_module_version": "1.5.0", 1125 | "state": { 1126 | "_view_name": "HTMLView", 1127 | "style": "IPY_MODEL_33dab107c44d407689a375e420f8710d", 1128 | "_dom_classes": [], 1129 | "description": "", 1130 | "_model_name": "HTMLModel", 1131 | "placeholder": "​", 1132 | "_view_module": "@jupyter-widgets/controls", 1133 | "_model_module_version": "1.5.0", 1134 | "value": "Downloading: 100%", 1135 | "_view_count": null, 1136 | "_view_module_version": "1.5.0", 1137 | "description_tooltip": null, 1138 | "_model_module": "@jupyter-widgets/controls", 1139 | "layout": "IPY_MODEL_3649756bf03e40698c7dc32941daada0" 1140 | } 1141 | }, 1142 | "e6af025bf3744cb2a3017199a8dc185c": { 1143 | "model_module": "@jupyter-widgets/controls", 1144 | "model_name": "FloatProgressModel", 1145 | "model_module_version": "1.5.0", 1146 | "state": { 1147 | "_view_name": "ProgressView", 1148 | "style": "IPY_MODEL_1a6214d3c8984950a0a86fb5fcdf0e6d", 1149 | "_dom_classes": [], 1150 | "description": "", 1151 | "_model_name": "FloatProgressModel", 1152 | "bar_style": "success", 1153 | "max": 65, 1154 | "_view_module": "@jupyter-widgets/controls", 1155 | "_model_module_version": "1.5.0", 1156 | "value": 65, 1157 | "_view_count": null, 1158 | "_view_module_version": "1.5.0", 1159 | "orientation": "horizontal", 1160 | "min": 0, 1161 | "description_tooltip": null, 1162 | "_model_module": "@jupyter-widgets/controls", 1163 | "layout": "IPY_MODEL_e2a23791a7a5484fb3e3b08931555a3f" 1164 | } 1165 | }, 1166 | "d3371fb6455440d48684cfe9d4829eff": { 1167 | "model_module": "@jupyter-widgets/controls", 1168 | "model_name": "HTMLModel", 1169 | 
"model_module_version": "1.5.0", 1170 | "state": { 1171 | "_view_name": "HTMLView", 1172 | "style": "IPY_MODEL_5f7b4216db414a2a8c024b814049d1c1", 1173 | "_dom_classes": [], 1174 | "description": "", 1175 | "_model_name": "HTMLModel", 1176 | "placeholder": "​", 1177 | "_view_module": "@jupyter-widgets/controls", 1178 | "_model_module_version": "1.5.0", 1179 | "value": " 65.0/65.0 [00:00<00:00, 2.42kB/s]", 1180 | "_view_count": null, 1181 | "_view_module_version": "1.5.0", 1182 | "description_tooltip": null, 1183 | "_model_module": "@jupyter-widgets/controls", 1184 | "layout": "IPY_MODEL_9e6dc5881f1a446d915fa00ef1f778ce" 1185 | } 1186 | }, 1187 | "33dab107c44d407689a375e420f8710d": { 1188 | "model_module": "@jupyter-widgets/controls", 1189 | "model_name": "DescriptionStyleModel", 1190 | "model_module_version": "1.5.0", 1191 | "state": { 1192 | "_view_name": "StyleView", 1193 | "_model_name": "DescriptionStyleModel", 1194 | "description_width": "", 1195 | "_view_module": "@jupyter-widgets/base", 1196 | "_model_module_version": "1.5.0", 1197 | "_view_count": null, 1198 | "_view_module_version": "1.2.0", 1199 | "_model_module": "@jupyter-widgets/controls" 1200 | } 1201 | }, 1202 | "3649756bf03e40698c7dc32941daada0": { 1203 | "model_module": "@jupyter-widgets/base", 1204 | "model_name": "LayoutModel", 1205 | "model_module_version": "1.2.0", 1206 | "state": { 1207 | "_view_name": "LayoutView", 1208 | "grid_template_rows": null, 1209 | "right": null, 1210 | "justify_content": null, 1211 | "_view_module": "@jupyter-widgets/base", 1212 | "overflow": null, 1213 | "_model_module_version": "1.2.0", 1214 | "_view_count": null, 1215 | "flex_flow": null, 1216 | "width": null, 1217 | "min_width": null, 1218 | "border": null, 1219 | "align_items": null, 1220 | "bottom": null, 1221 | "_model_module": "@jupyter-widgets/base", 1222 | "top": null, 1223 | "grid_column": null, 1224 | "overflow_y": null, 1225 | "overflow_x": null, 1226 | "grid_auto_flow": null, 1227 | "grid_area": null, 
1228 | "grid_template_columns": null, 1229 | "flex": null, 1230 | "_model_name": "LayoutModel", 1231 | "justify_items": null, 1232 | "grid_row": null, 1233 | "max_height": null, 1234 | "align_content": null, 1235 | "visibility": null, 1236 | "align_self": null, 1237 | "height": null, 1238 | "min_height": null, 1239 | "padding": null, 1240 | "grid_auto_rows": null, 1241 | "grid_gap": null, 1242 | "max_width": null, 1243 | "order": null, 1244 | "_view_module_version": "1.2.0", 1245 | "grid_template_areas": null, 1246 | "object_position": null, 1247 | "object_fit": null, 1248 | "grid_auto_columns": null, 1249 | "margin": null, 1250 | "display": null, 1251 | "left": null 1252 | } 1253 | }, 1254 | "1a6214d3c8984950a0a86fb5fcdf0e6d": { 1255 | "model_module": "@jupyter-widgets/controls", 1256 | "model_name": "ProgressStyleModel", 1257 | "model_module_version": "1.5.0", 1258 | "state": { 1259 | "_view_name": "StyleView", 1260 | "_model_name": "ProgressStyleModel", 1261 | "description_width": "", 1262 | "_view_module": "@jupyter-widgets/base", 1263 | "_model_module_version": "1.5.0", 1264 | "_view_count": null, 1265 | "_view_module_version": "1.2.0", 1266 | "bar_color": null, 1267 | "_model_module": "@jupyter-widgets/controls" 1268 | } 1269 | }, 1270 | "e2a23791a7a5484fb3e3b08931555a3f": { 1271 | "model_module": "@jupyter-widgets/base", 1272 | "model_name": "LayoutModel", 1273 | "model_module_version": "1.2.0", 1274 | "state": { 1275 | "_view_name": "LayoutView", 1276 | "grid_template_rows": null, 1277 | "right": null, 1278 | "justify_content": null, 1279 | "_view_module": "@jupyter-widgets/base", 1280 | "overflow": null, 1281 | "_model_module_version": "1.2.0", 1282 | "_view_count": null, 1283 | "flex_flow": null, 1284 | "width": null, 1285 | "min_width": null, 1286 | "border": null, 1287 | "align_items": null, 1288 | "bottom": null, 1289 | "_model_module": "@jupyter-widgets/base", 1290 | "top": null, 1291 | "grid_column": null, 1292 | "overflow_y": null, 1293 | 
"overflow_x": null, 1294 | "grid_auto_flow": null, 1295 | "grid_area": null, 1296 | "grid_template_columns": null, 1297 | "flex": null, 1298 | "_model_name": "LayoutModel", 1299 | "justify_items": null, 1300 | "grid_row": null, 1301 | "max_height": null, 1302 | "align_content": null, 1303 | "visibility": null, 1304 | "align_self": null, 1305 | "height": null, 1306 | "min_height": null, 1307 | "padding": null, 1308 | "grid_auto_rows": null, 1309 | "grid_gap": null, 1310 | "max_width": null, 1311 | "order": null, 1312 | "_view_module_version": "1.2.0", 1313 | "grid_template_areas": null, 1314 | "object_position": null, 1315 | "object_fit": null, 1316 | "grid_auto_columns": null, 1317 | "margin": null, 1318 | "display": null, 1319 | "left": null 1320 | } 1321 | }, 1322 | "5f7b4216db414a2a8c024b814049d1c1": { 1323 | "model_module": "@jupyter-widgets/controls", 1324 | "model_name": "DescriptionStyleModel", 1325 | "model_module_version": "1.5.0", 1326 | "state": { 1327 | "_view_name": "StyleView", 1328 | "_model_name": "DescriptionStyleModel", 1329 | "description_width": "", 1330 | "_view_module": "@jupyter-widgets/base", 1331 | "_model_module_version": "1.5.0", 1332 | "_view_count": null, 1333 | "_view_module_version": "1.2.0", 1334 | "_model_module": "@jupyter-widgets/controls" 1335 | } 1336 | }, 1337 | "9e6dc5881f1a446d915fa00ef1f778ce": { 1338 | "model_module": "@jupyter-widgets/base", 1339 | "model_name": "LayoutModel", 1340 | "model_module_version": "1.2.0", 1341 | "state": { 1342 | "_view_name": "LayoutView", 1343 | "grid_template_rows": null, 1344 | "right": null, 1345 | "justify_content": null, 1346 | "_view_module": "@jupyter-widgets/base", 1347 | "overflow": null, 1348 | "_model_module_version": "1.2.0", 1349 | "_view_count": null, 1350 | "flex_flow": null, 1351 | "width": null, 1352 | "min_width": null, 1353 | "border": null, 1354 | "align_items": null, 1355 | "bottom": null, 1356 | "_model_module": "@jupyter-widgets/base", 1357 | "top": null, 1358 | 
"grid_column": null, 1359 | "overflow_y": null, 1360 | "overflow_x": null, 1361 | "grid_auto_flow": null, 1362 | "grid_area": null, 1363 | "grid_template_columns": null, 1364 | "flex": null, 1365 | "_model_name": "LayoutModel", 1366 | "justify_items": null, 1367 | "grid_row": null, 1368 | "max_height": null, 1369 | "align_content": null, 1370 | "visibility": null, 1371 | "align_self": null, 1372 | "height": null, 1373 | "min_height": null, 1374 | "padding": null, 1375 | "grid_auto_rows": null, 1376 | "grid_gap": null, 1377 | "max_width": null, 1378 | "order": null, 1379 | "_view_module_version": "1.2.0", 1380 | "grid_template_areas": null, 1381 | "object_position": null, 1382 | "object_fit": null, 1383 | "grid_auto_columns": null, 1384 | "margin": null, 1385 | "display": null, 1386 | "left": null 1387 | } 1388 | }, 1389 | "f4cfe43af5c2497aaefef587324c861e": { 1390 | "model_module": "@jupyter-widgets/controls", 1391 | "model_name": "HBoxModel", 1392 | "model_module_version": "1.5.0", 1393 | "state": { 1394 | "_view_name": "HBoxView", 1395 | "_dom_classes": [], 1396 | "_model_name": "HBoxModel", 1397 | "_view_module": "@jupyter-widgets/controls", 1398 | "_model_module_version": "1.5.0", 1399 | "_view_count": null, 1400 | "_view_module_version": "1.5.0", 1401 | "box_style": "", 1402 | "layout": "IPY_MODEL_b67d7a6803da4ece97c3ab69aa7547e8", 1403 | "_model_module": "@jupyter-widgets/controls", 1404 | "children": [ 1405 | "IPY_MODEL_2237a78e60eb437d8562d9b74bd98360", 1406 | "IPY_MODEL_f5c1da4f42d54e5ba9d4f7782a7554b0", 1407 | "IPY_MODEL_41517d0e2fd74aa08b04a5b9a4ac665d" 1408 | ] 1409 | } 1410 | }, 1411 | "b67d7a6803da4ece97c3ab69aa7547e8": { 1412 | "model_module": "@jupyter-widgets/base", 1413 | "model_name": "LayoutModel", 1414 | "model_module_version": "1.2.0", 1415 | "state": { 1416 | "_view_name": "LayoutView", 1417 | "grid_template_rows": null, 1418 | "right": null, 1419 | "justify_content": null, 1420 | "_view_module": "@jupyter-widgets/base", 1421 | "overflow": 
null, 1422 | "_model_module_version": "1.2.0", 1423 | "_view_count": null, 1424 | "flex_flow": null, 1425 | "width": null, 1426 | "min_width": null, 1427 | "border": null, 1428 | "align_items": null, 1429 | "bottom": null, 1430 | "_model_module": "@jupyter-widgets/base", 1431 | "top": null, 1432 | "grid_column": null, 1433 | "overflow_y": null, 1434 | "overflow_x": null, 1435 | "grid_auto_flow": null, 1436 | "grid_area": null, 1437 | "grid_template_columns": null, 1438 | "flex": null, 1439 | "_model_name": "LayoutModel", 1440 | "justify_items": null, 1441 | "grid_row": null, 1442 | "max_height": null, 1443 | "align_content": null, 1444 | "visibility": null, 1445 | "align_self": null, 1446 | "height": null, 1447 | "min_height": null, 1448 | "padding": null, 1449 | "grid_auto_rows": null, 1450 | "grid_gap": null, 1451 | "max_width": null, 1452 | "order": null, 1453 | "_view_module_version": "1.2.0", 1454 | "grid_template_areas": null, 1455 | "object_position": null, 1456 | "object_fit": null, 1457 | "grid_auto_columns": null, 1458 | "margin": null, 1459 | "display": null, 1460 | "left": null 1461 | } 1462 | }, 1463 | "2237a78e60eb437d8562d9b74bd98360": { 1464 | "model_module": "@jupyter-widgets/controls", 1465 | "model_name": "HTMLModel", 1466 | "model_module_version": "1.5.0", 1467 | "state": { 1468 | "_view_name": "HTMLView", 1469 | "style": "IPY_MODEL_d2b1ed49e9d946c289b9db7fadf94992", 1470 | "_dom_classes": [], 1471 | "description": "", 1472 | "_model_name": "HTMLModel", 1473 | "placeholder": "​", 1474 | "_view_module": "@jupyter-widgets/controls", 1475 | "_model_module_version": "1.5.0", 1476 | "value": "Downloading: 100%", 1477 | "_view_count": null, 1478 | "_view_module_version": "1.5.0", 1479 | "description_tooltip": null, 1480 | "_model_module": "@jupyter-widgets/controls", 1481 | "layout": "IPY_MODEL_72d55badd64a441eb02e9d547281c82e" 1482 | } 1483 | }, 1484 | "f5c1da4f42d54e5ba9d4f7782a7554b0": { 1485 | "model_module": "@jupyter-widgets/controls", 1486 | 
"model_name": "FloatProgressModel", 1487 | "model_module_version": "1.5.0", 1488 | "state": { 1489 | "_view_name": "ProgressView", 1490 | "style": "IPY_MODEL_3fd60e410f434de0bcb061a7fca394ba", 1491 | "_dom_classes": [], 1492 | "description": "", 1493 | "_model_name": "FloatProgressModel", 1494 | "bar_style": "success", 1495 | "max": 376, 1496 | "_view_module": "@jupyter-widgets/controls", 1497 | "_model_module_version": "1.5.0", 1498 | "value": 376, 1499 | "_view_count": null, 1500 | "_view_module_version": "1.5.0", 1501 | "orientation": "horizontal", 1502 | "min": 0, 1503 | "description_tooltip": null, 1504 | "_model_module": "@jupyter-widgets/controls", 1505 | "layout": "IPY_MODEL_7ebab85482134e1db1a62d584dd1f1e6" 1506 | } 1507 | }, 1508 | "41517d0e2fd74aa08b04a5b9a4ac665d": { 1509 | "model_module": "@jupyter-widgets/controls", 1510 | "model_name": "HTMLModel", 1511 | "model_module_version": "1.5.0", 1512 | "state": { 1513 | "_view_name": "HTMLView", 1514 | "style": "IPY_MODEL_62345a490f2c48129bc3e29b42638439", 1515 | "_dom_classes": [], 1516 | "description": "", 1517 | "_model_name": "HTMLModel", 1518 | "placeholder": "​", 1519 | "_view_module": "@jupyter-widgets/controls", 1520 | "_model_module_version": "1.5.0", 1521 | "value": " 376/376 [00:00<00:00, 12.6kB/s]", 1522 | "_view_count": null, 1523 | "_view_module_version": "1.5.0", 1524 | "description_tooltip": null, 1525 | "_model_module": "@jupyter-widgets/controls", 1526 | "layout": "IPY_MODEL_7af5f84e9e21429d993d06062b50209a" 1527 | } 1528 | }, 1529 | "d2b1ed49e9d946c289b9db7fadf94992": { 1530 | "model_module": "@jupyter-widgets/controls", 1531 | "model_name": "DescriptionStyleModel", 1532 | "model_module_version": "1.5.0", 1533 | "state": { 1534 | "_view_name": "StyleView", 1535 | "_model_name": "DescriptionStyleModel", 1536 | "description_width": "", 1537 | "_view_module": "@jupyter-widgets/base", 1538 | "_model_module_version": "1.5.0", 1539 | "_view_count": null, 1540 | "_view_module_version": "1.2.0", 
1541 | "_model_module": "@jupyter-widgets/controls" 1542 | } 1543 | }, 1544 | "72d55badd64a441eb02e9d547281c82e": { 1545 | "model_module": "@jupyter-widgets/base", 1546 | "model_name": "LayoutModel", 1547 | "model_module_version": "1.2.0", 1548 | "state": { 1549 | "_view_name": "LayoutView", 1550 | "grid_template_rows": null, 1551 | "right": null, 1552 | "justify_content": null, 1553 | "_view_module": "@jupyter-widgets/base", 1554 | "overflow": null, 1555 | "_model_module_version": "1.2.0", 1556 | "_view_count": null, 1557 | "flex_flow": null, 1558 | "width": null, 1559 | "min_width": null, 1560 | "border": null, 1561 | "align_items": null, 1562 | "bottom": null, 1563 | "_model_module": "@jupyter-widgets/base", 1564 | "top": null, 1565 | "grid_column": null, 1566 | "overflow_y": null, 1567 | "overflow_x": null, 1568 | "grid_auto_flow": null, 1569 | "grid_area": null, 1570 | "grid_template_columns": null, 1571 | "flex": null, 1572 | "_model_name": "LayoutModel", 1573 | "justify_items": null, 1574 | "grid_row": null, 1575 | "max_height": null, 1576 | "align_content": null, 1577 | "visibility": null, 1578 | "align_self": null, 1579 | "height": null, 1580 | "min_height": null, 1581 | "padding": null, 1582 | "grid_auto_rows": null, 1583 | "grid_gap": null, 1584 | "max_width": null, 1585 | "order": null, 1586 | "_view_module_version": "1.2.0", 1587 | "grid_template_areas": null, 1588 | "object_position": null, 1589 | "object_fit": null, 1590 | "grid_auto_columns": null, 1591 | "margin": null, 1592 | "display": null, 1593 | "left": null 1594 | } 1595 | }, 1596 | "3fd60e410f434de0bcb061a7fca394ba": { 1597 | "model_module": "@jupyter-widgets/controls", 1598 | "model_name": "ProgressStyleModel", 1599 | "model_module_version": "1.5.0", 1600 | "state": { 1601 | "_view_name": "StyleView", 1602 | "_model_name": "ProgressStyleModel", 1603 | "description_width": "", 1604 | "_view_module": "@jupyter-widgets/base", 1605 | "_model_module_version": "1.5.0", 1606 | "_view_count": null, 
1607 | "_view_module_version": "1.2.0", 1608 | "bar_color": null, 1609 | "_model_module": "@jupyter-widgets/controls" 1610 | } 1611 | }, 1612 | "7ebab85482134e1db1a62d584dd1f1e6": { 1613 | "model_module": "@jupyter-widgets/base", 1614 | "model_name": "LayoutModel", 1615 | "model_module_version": "1.2.0", 1616 | "state": { 1617 | "_view_name": "LayoutView", 1618 | "grid_template_rows": null, 1619 | "right": null, 1620 | "justify_content": null, 1621 | "_view_module": "@jupyter-widgets/base", 1622 | "overflow": null, 1623 | "_model_module_version": "1.2.0", 1624 | "_view_count": null, 1625 | "flex_flow": null, 1626 | "width": null, 1627 | "min_width": null, 1628 | "border": null, 1629 | "align_items": null, 1630 | "bottom": null, 1631 | "_model_module": "@jupyter-widgets/base", 1632 | "top": null, 1633 | "grid_column": null, 1634 | "overflow_y": null, 1635 | "overflow_x": null, 1636 | "grid_auto_flow": null, 1637 | "grid_area": null, 1638 | "grid_template_columns": null, 1639 | "flex": null, 1640 | "_model_name": "LayoutModel", 1641 | "justify_items": null, 1642 | "grid_row": null, 1643 | "max_height": null, 1644 | "align_content": null, 1645 | "visibility": null, 1646 | "align_self": null, 1647 | "height": null, 1648 | "min_height": null, 1649 | "padding": null, 1650 | "grid_auto_rows": null, 1651 | "grid_gap": null, 1652 | "max_width": null, 1653 | "order": null, 1654 | "_view_module_version": "1.2.0", 1655 | "grid_template_areas": null, 1656 | "object_position": null, 1657 | "object_fit": null, 1658 | "grid_auto_columns": null, 1659 | "margin": null, 1660 | "display": null, 1661 | "left": null 1662 | } 1663 | }, 1664 | "62345a490f2c48129bc3e29b42638439": { 1665 | "model_module": "@jupyter-widgets/controls", 1666 | "model_name": "DescriptionStyleModel", 1667 | "model_module_version": "1.5.0", 1668 | "state": { 1669 | "_view_name": "StyleView", 1670 | "_model_name": "DescriptionStyleModel", 1671 | "description_width": "", 1672 | "_view_module": 
"@jupyter-widgets/base", 1673 | "_model_module_version": "1.5.0", 1674 | "_view_count": null, 1675 | "_view_module_version": "1.2.0", 1676 | "_model_module": "@jupyter-widgets/controls" 1677 | } 1678 | }, 1679 | "7af5f84e9e21429d993d06062b50209a": { 1680 | "model_module": "@jupyter-widgets/base", 1681 | "model_name": "LayoutModel", 1682 | "model_module_version": "1.2.0", 1683 | "state": { 1684 | "_view_name": "LayoutView", 1685 | "grid_template_rows": null, 1686 | "right": null, 1687 | "justify_content": null, 1688 | "_view_module": "@jupyter-widgets/base", 1689 | "overflow": null, 1690 | "_model_module_version": "1.2.0", 1691 | "_view_count": null, 1692 | "flex_flow": null, 1693 | "width": null, 1694 | "min_width": null, 1695 | "border": null, 1696 | "align_items": null, 1697 | "bottom": null, 1698 | "_model_module": "@jupyter-widgets/base", 1699 | "top": null, 1700 | "grid_column": null, 1701 | "overflow_y": null, 1702 | "overflow_x": null, 1703 | "grid_auto_flow": null, 1704 | "grid_area": null, 1705 | "grid_template_columns": null, 1706 | "flex": null, 1707 | "_model_name": "LayoutModel", 1708 | "justify_items": null, 1709 | "grid_row": null, 1710 | "max_height": null, 1711 | "align_content": null, 1712 | "visibility": null, 1713 | "align_self": null, 1714 | "height": null, 1715 | "min_height": null, 1716 | "padding": null, 1717 | "grid_auto_rows": null, 1718 | "grid_gap": null, 1719 | "max_width": null, 1720 | "order": null, 1721 | "_view_module_version": "1.2.0", 1722 | "grid_template_areas": null, 1723 | "object_position": null, 1724 | "object_fit": null, 1725 | "grid_auto_columns": null, 1726 | "margin": null, 1727 | "display": null, 1728 | "left": null 1729 | } 1730 | } 1731 | } 1732 | } 1733 | }, 1734 | "cells": [ 1735 | { 1736 | "cell_type": "code", 1737 | "metadata": { 1738 | "colab": { 1739 | "base_uri": "https://localhost:8080/" 1740 | }, 1741 | "id": "u8xBtDGlxvz4", 1742 | "outputId": "a0837408-171a-4437-ff03-3d43b82c0730" 1743 | }, 1744 | "source": 
[ 1745 | "!nvidia-smi" 1746 | ], 1747 | "execution_count": 1, 1748 | "outputs": [ 1749 | { 1750 | "output_type": "stream", 1751 | "name": "stdout", 1752 | "text": [ 1753 | "Sat Sep 11 03:45:27 2021 \n", 1754 | "+-----------------------------------------------------------------------------+\n", 1755 | "| NVIDIA-SMI 470.63.01 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", 1756 | "|-------------------------------+----------------------+----------------------+\n", 1757 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 1758 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 1759 | "| | | MIG M. |\n", 1760 | "|===============================+======================+======================|\n", 1761 | "| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n", 1762 | "| N/A 40C P0 27W / 250W | 0MiB / 16280MiB | 0% Default |\n", 1763 | "| | | N/A |\n", 1764 | "+-------------------------------+----------------------+----------------------+\n", 1765 | " \n", 1766 | "+-----------------------------------------------------------------------------+\n", 1767 | "| Processes: |\n", 1768 | "| GPU GI CI PID Type Process name GPU Memory |\n", 1769 | "| ID ID Usage |\n", 1770 | "|=============================================================================|\n", 1771 | "| No running processes found |\n", 1772 | "+-----------------------------------------------------------------------------+\n" 1773 | ] 1774 | } 1775 | ] 1776 | }, 1777 | { 1778 | "cell_type": "code", 1779 | "metadata": { 1780 | "colab": { 1781 | "base_uri": "https://localhost:8080/" 1782 | }, 1783 | "id": "rbpX3_9PxzVE", 1784 | "outputId": "60193622-45ea-4e77-c7a1-e452637533f2" 1785 | }, 1786 | "source": [ 1787 | "!pip install transformers SentencePiece torch tqdm" 1788 | ], 1789 | "execution_count": 2, 1790 | "outputs": [ 1791 | { 1792 | "output_type": "stream", 1793 | "name": "stdout", 1794 | "text": [ 1795 | "Collecting transformers\n", 1796 | " Downloading 
transformers-4.10.2-py3-none-any.whl (2.8 MB)\n", 1797 | "\u001b[K |████████████████████████████████| 2.8 MB 5.0 MB/s \n", 1798 | "\u001b[?25hCollecting SentencePiece\n", 1799 | " Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)\n", 1800 | "\u001b[K |████████████████████████████████| 1.2 MB 34.6 MB/s \n", 1801 | "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (1.9.0+cu102)\n", 1802 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.62.0)\n", 1803 | "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from transformers) (21.0)\n", 1804 | "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.6.4)\n", 1805 | "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n", 1806 | "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.0.12)\n", 1807 | "Collecting huggingface-hub>=0.0.12\n", 1808 | " Downloading huggingface_hub-0.0.16-py3-none-any.whl (50 kB)\n", 1809 | "\u001b[K |████████████████████████████████| 50 kB 4.0 MB/s \n", 1810 | "\u001b[?25hCollecting tokenizers<0.11,>=0.10.1\n", 1811 | " Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)\n", 1812 | "\u001b[K |████████████████████████████████| 3.3 MB 53.8 MB/s \n", 1813 | "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n", 1814 | "Collecting sacremoses\n", 1815 | " Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)\n", 1816 | "\u001b[K |████████████████████████████████| 895 kB 61.5 MB/s \n", 1817 | "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) 
(2.23.0)\n", 1818 | "Collecting pyyaml>=5.1\n", 1819 | " Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)\n", 1820 | "\u001b[K |████████████████████████████████| 636 kB 65.1 MB/s \n", 1821 | "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.12->transformers) (3.7.4.3)\n", 1822 | "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->transformers) (2.4.7)\n", 1823 | "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.5.0)\n", 1824 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n", 1825 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n", 1826 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.5.30)\n", 1827 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n", 1828 | "Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n", 1829 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.15.0)\n", 1830 | "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.0.1)\n", 1831 | "Installing collected packages: tokenizers, sacremoses, pyyaml, huggingface-hub, transformers, SentencePiece\n", 1832 | " Attempting uninstall: pyyaml\n", 1833 | " Found existing installation: PyYAML 3.13\n", 1834 | " Uninstalling PyYAML-3.13:\n", 1835 | " Successfully uninstalled PyYAML-3.13\n", 1836 | "Successfully installed SentencePiece-0.1.96 
class SoftEmbedding(nn.Module):
    """Embedding wrapper that puts a learned soft prompt in the first ``n_tokens`` positions.

    The module wraps an existing word-embedding layer. In ``forward`` the first
    ``n_tokens`` columns of the incoming id tensor are treated as placeholders:
    their ids are dropped and the trainable ``learned_embedding`` rows are
    emitted in their place, followed by the ordinary embeddings of the
    remaining ids.
    """

    def __init__(self,
                 wte: nn.Embedding,
                 n_tokens: int = 10,
                 random_range: float = 0.5,
                 initialize_from_vocab: bool = True):
        """Wrap ``wte`` and create the trainable prompt embedding.

        Args:
            wte (nn.Embedding): original transformer word embedding.
            n_tokens (int, optional): number of soft-prompt positions. Defaults to 10.
            random_range (float, optional): half-width of the uniform range used
                when not initializing from the vocabulary. Defaults to 0.5.
            initialize_from_vocab (bool, optional): copy the first ``n_tokens``
                rows of ``wte`` as the initial prompt. Defaults to True.

        Raises:
            ValueError: if ``initialize_from_vocab`` is True and ``wte`` has
                fewer than ``n_tokens`` rows.
        """
        super(SoftEmbedding, self).__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(
            self.initialize_embedding(wte,
                                      n_tokens,
                                      random_range,
                                      initialize_from_vocab))

    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = True):
        """Build the initial value of the learned prompt embedding.

        Args:
            same as __init__

        Returns:
            torch.Tensor: tensor of shape ``(n_tokens, embedding_dim)`` with the
            same dtype and device as ``wte.weight``.

        Raises:
            ValueError: if ``initialize_from_vocab`` is True and ``wte`` has
                fewer than ``n_tokens`` rows.
        """
        if initialize_from_vocab:
            vocab_size = wte.weight.size(0)
            if n_tokens > vocab_size:
                # Slicing would silently return fewer rows than n_tokens, and
                # forward() would then drop more positions than it prepends.
                raise ValueError(
                    "n_tokens (%d) exceeds the embedding vocabulary size (%d)"
                    % (n_tokens, vocab_size))
            # Use the embedding that was passed in (not self.wte) so the method
            # honours its own argument; detach so the copy is a fresh leaf.
            return wte.weight[:n_tokens].clone().detach()
        # Follow the wrapped embedding's dtype/device instead of the legacy
        # torch.FloatTensor constructor, which is always CPU float32.
        return torch.empty(n_tokens,
                           wte.weight.size(1),
                           dtype=wte.weight.dtype,
                           device=wte.weight.device).uniform_(-random_range, random_range)

    def forward(self, tokens):
        """Embed ``tokens`` with the learned prompt in the leading positions.

        Args:
            tokens (torch.long): input ids, indexed as ``(batch, sequence)``;
                the ids in the first ``n_tokens`` columns are ignored.

        Returns:
            torch.float: the learned task-specific embedding (repeated for each
            batch row) concatenated along dim 1 with the embeddings of the
            remaining tokens.
        """
        # Only the ids after the placeholder columns go through the real embedding.
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        # One copy of the prompt per batch row.
        learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
        return torch.cat([learned_embedding, input_embedding], 1)
def generate_data(batch_size, n_tokens, title_data, label_data):
    """Yield padded training/evaluation batches for the soft-prompt model.

    Each title is tokenized with the module-level ``tokenizer`` and prefixed
    with ``n_tokens`` placeholder ids (value 1). Each label is turned into a
    target row of ``n_tokens - 1`` entries of -100 followed by the label's
    token id.

    Args:
        batch_size: number of examples per yielded batch; the last batch may
            be smaller.
        n_tokens: number of placeholder positions prepended to every input and
            length of the decoder input / target rows.
        title_data: iterable of input texts.
        label_data: iterable of integer labels used to index the three-entry
            ``labels`` table below (so 0, 1 or 2). Iterated in lockstep with
            ``title_data`` via ``zip``.

    Yields:
        tuple ``(x, y, m, decoder_input_ids, l_batch)``:
            x: input ids, zero-padded to the longest example in the batch.
            y: target ids of shape ``(batch, n_tokens)``.
            m: float32 mask, 1.0 where ``x > 0`` and 0.0 on padding.
            decoder_input_ids: ``(batch, n_tokens)`` tensor filled with 1.
            l_batch: list of the raw labels for the batch.
        The tensors are moved to CUDA when it is available.
    """

    # Target token id for each class; the trailing comments are the original
    # author's annotation of what each id stands for.
    labels = [
        torch.tensor([[3]]),  # \x00
        torch.tensor([[4]]),  # \x01
        torch.tensor([[5]]),  # \x02
    ]

    def yield_data(x_batch, y_batch, l_batch):
        # pad_sequence pads with 0, which is what the mask below keys on.
        x = torch.nn.utils.rnn.pad_sequence(x_batch, batch_first=True)
        y = torch.cat(y_batch, dim=0)
        m = (x > 0).to(torch.float32)
        # Explicit dtype: do not rely on torch.full inferring an integer dtype
        # from the fill value, which is version-dependent.
        decoder_input_ids = torch.full((x.size(0), n_tokens), 1, dtype=torch.long)
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()
            m = m.cuda()
            decoder_input_ids = decoder_input_ids.cuda()
        return x, y, m, decoder_input_ids, l_batch

    x_batch, y_batch, l_batch = [], [], []
    for x, y in zip(title_data, label_data):
        context = x
        inputs = tokenizer(context, return_tensors="pt")
        # Reserve the first n_tokens positions with placeholder id 1.
        inputs['input_ids'] = torch.cat(
            [torch.full((1, n_tokens), 1, dtype=torch.long), inputs['input_ids']], 1)
        l_batch.append(y)
        y = labels[y]
        # Only the last target position carries a real label id; the first
        # n_tokens - 1 are -100 (presumably the loss ignore index - confirm
        # against the training loop).
        y = torch.cat([torch.full((1, n_tokens - 1), -100, dtype=torch.long), y], 1)
        x_batch.append(inputs['input_ids'][0])
        y_batch.append(y)
        if len(x_batch) >= batch_size:
            yield yield_data(x_batch, y_batch, l_batch)
            x_batch, y_batch, l_batch = [], [], []

    # Flush the final, possibly smaller, batch.
    if len(x_batch) > 0:
        yield yield_data(x_batch, y_batch, l_batch)
# Number of soft-prompt positions; generate_data must be called with the same value.
n_tokens = 100

# Load the pretrained mT5 checkpoint and its matching tokenizer.
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-base")

# Wrap the model's current input embedding in a SoftEmbedding whose prompt is
# initialized from the first n_tokens vocabulary rows, then install the wrapper
# as the model's input embedding.
s_wte = SoftEmbedding(
    model.get_input_embeddings(),
    n_tokens=n_tokens,
    initialize_from_vocab=True,
)
model.set_input_embeddings(s_wte)

# Move the model to the GPU when one is available.
model = model.cuda() if torch.cuda.is_available() else model