├── gif-demo ├── icon.png ├── discord.gif ├── gameplay.png └── huggingface.gif ├── LICENSE ├── discord_bot.js ├── README.md ├── .gitignore ├── discord_bot.py └── model_train_upload_workflow.ipynb /gif-demo/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shubham8550/twewy-discord-chatbot/main/gif-demo/icon.png -------------------------------------------------------------------------------- /gif-demo/discord.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shubham8550/twewy-discord-chatbot/main/gif-demo/discord.gif -------------------------------------------------------------------------------- /gif-demo/gameplay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shubham8550/twewy-discord-chatbot/main/gif-demo/gameplay.png -------------------------------------------------------------------------------- /gif-demo/huggingface.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shubham8550/twewy-discord-chatbot/main/gif-demo/huggingface.gif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Lynn Zheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 
included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /discord_bot.js: -------------------------------------------------------------------------------- 1 | // discord.js import 2 | const Discord = require('discord.js'); 3 | // node-fetch for making HTTP requests 4 | const fetch = require('node-fetch'); 5 | 6 | // initialize client 7 | const client = new Discord.Client(); 8 | // my model URL 9 | API_URL = 'https://api-inference.huggingface.co/models/r3dhummingbird/DialoGPT-medium-joshua'; 10 | 11 | // log out some info 12 | client.on('ready', () => { 13 | console.log(`Logged in as ${client.user.tag}!`); 14 | }); 15 | 16 | // when the bot receives a message 17 | // need async message because we are making HTTP requests 18 | client.on('message', async message => { 19 | // ignore messages from the bot itself 20 | if (message.author.bot) { 21 | return; 22 | } 23 | // form the payload 24 | const payload = { 25 | inputs: { 26 | text: message.content 27 | } 28 | }; 29 | // form the request headers with Hugging Face API key 30 | const headers = { 31 | 'Authorization': 'Bearer ' + process.env.HUGGINGFACE_TOKEN 32 | }; 33 | 34 | // set status to typing 35 | message.channel.startTyping(); 36 | // query the server 37 | const response = await fetch(API_URL, { 38 | method: 'post', 39 | body: JSON.stringify(payload), 40 | headers: headers 41 | }); 42 | const data = await response.json(); 43 
| let botResponse = ''; 44 | if (data.hasOwnProperty('generated_text')) { 45 | botResponse = data.generated_text; 46 | } else if (data.hasOwnProperty('error')) { // error condition 47 | botResponse = data.error; 48 | } 49 | // stop typing 50 | message.channel.stopTyping(); 51 | // send message to channel as a reply 52 | message.reply(botResponse); 53 | }) 54 | 55 | client.login(process.env.DISCORD_TOKEN); 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Build a Discord AI Chatbot that Speaks like Your Favorite Character! 2 | 3 |
4 | 5 |
6 | 7 | This is a Discord AI Chatbot that uses the [Microsoft DialoGPT conversational model](https://huggingface.co/microsoft/DialoGPT-medium) fine-tuned on the game transcript of [The World Ends With You](https://en.wikipedia.org/wiki/The_World_Ends_with_You) (TWEWY). Read [my tutorial on freeCodeCamp](https://www.freecodecamp.org/news/discord-ai-chatbot/) or watch [my video tutorial on YouTube](https://youtu.be/UBwvFuTC1ZE). I've also made [a JavaScript version of the tutorial using Discord.js](https://youtu.be/XR6JFRLxe5A). 8 | 9 | I trained the model using the lines of my favorite quirky character, Joshua (left in the image below). He has about 700 lines in total in the entire game. 10 | 11 |
12 | 13 | Here is a demo of the Discord bot in action. 14 | 15 |
16 | 17 | You can also directly chat with the model hosted on [Hugging Face's Model Hub](https://huggingface.co/r3dhummingbird/DialoGPT-medium-joshua). 18 | 19 |
20 | 21 | ## Structure of this Project 22 | 23 | - `model_train_upload_workflow.ipynb`: Notebook to be run in Google Colab to train and upload the model to Hugging Face's Model Hub 24 | - `discord_bot.py`: Script to be imported into a Repl.it Python Discord.py project 25 | - `discord_bot.js`: Script to be imported into a Repl.it JavaScript Discord.js project 26 | 27 | ## Resource Links 28 | 29 | - [15-min chat demo](https://youtu.be/-n6uWu8PZzo) 30 | - [My tutorial on freeCodeCamp](https://www.freecodecamp.org/news/discord-ai-chatbot/) 31 | - [My video tutorial on YouTube](https://youtu.be/UBwvFuTC1ZE) 32 | - [My JavaScript version of this tutorial on YouTube](https://youtu.be/XR6JFRLxe5A) 33 | - [My TWEWY dataset on Kaggle](https://www.kaggle.com/ruolinzheng/twewy-game-script) 34 | - [My Hugging Face Model](https://huggingface.co/r3dhummingbird/DialoGPT-medium-joshua) 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /discord_bot.py: -------------------------------------------------------------------------------- 1 | # the os module helps us access environment variables 2 | # i.e., our API keys 3 | import os 4 | 5 | # these modules are for querying the Hugging Face model 6 | import json 7 | import requests 8 | 9 | # the Discord Python API 10 | import discord 11 | 12 | # this is my Hugging Face profile link 13 | API_URL = 'https://api-inference.huggingface.co/models/r3dhummingbird/' 14 | 15 | class MyClient(discord.Client): 16 | def __init__(self, model_name): 17 | super().__init__() 18 | self.api_endpoint = API_URL + model_name 19 | # retrieve the secret API token from the system environment 20 | huggingface_token = os.environ['HUGGINGFACE_TOKEN'] 21 | # format the header in our request to Hugging Face 22 | self.request_headers = { 23 | 'Authorization': 'Bearer {}'.format(huggingface_token) 24 | } 25 | 26 | def query(self, payload): 27 | """ 28 | make request to the Hugging Face model API 29 | """ 30 | data = json.dumps(payload) 31 | response = requests.request('POST', 32 | self.api_endpoint, 33 | headers=self.request_headers, 34 | data=data) 35 | ret = json.loads(response.content.decode('utf-8')) 36 | return ret 37 | 38 | async def on_ready(self): 39 | # print out information when the bot wakes up 40 | print('Logged 
in as') 41 | print(self.user.name) 42 | print(self.user.id) 43 | print('------') 44 | # send a request to the model without caring about the response 45 | # just so that the model wakes up and starts loading 46 | self.query({'inputs': {'text': 'Hello!'}}) 47 | 48 | async def on_message(self, message): 49 | """ 50 | this function is called whenever the bot sees a message in a channel 51 | """ 52 | # ignore the message if it comes from the bot itself 53 | if message.author.id == self.user.id: 54 | return 55 | 56 | # form query payload with the content of the message 57 | payload = {'inputs': {'text': message.content}} 58 | 59 | # while the bot is waiting on a response from the model 60 | # set the its status as typing for user-friendliness 61 | async with message.channel.typing(): 62 | response = self.query(payload) 63 | bot_response = response.get('generated_text', None) 64 | 65 | # we may get ill-formed response if the model hasn't fully loaded 66 | # or has timed out 67 | if not bot_response: 68 | if 'error' in response: 69 | bot_response = '`Error: {}`'.format(response['error']) 70 | else: 71 | bot_response = 'Hmm... something is not right.' 
72 | 73 | # send the model's response to the Discord channel 74 | await message.channel.send(bot_response) 75 | 76 | def main(): 77 | # DialoGPT-medium-joshua is my model name 78 | client = MyClient('DialoGPT-medium-joshua') 79 | client.run(os.environ['DISCORD_TOKEN']) 80 | 81 | if __name__ == '__main__': 82 | main() -------------------------------------------------------------------------------- /model_train_upload_workflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "VTze-VbeU1c0" 7 | }, 8 | "source": [ 9 | "# Fine-tune a DialoGPT model\n", 10 | "\n", 11 | "Adapted from the notebook in [this Medium post](https://towardsdatascience.com/make-your-own-rick-sanchez-bot-with-transformers-and-dialogpt-fine-tuning-f85e6d1f4e30?gi=e4a72d1510f0)." 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": { 17 | "id": "Y17kuzFNUSrZ" 18 | }, 19 | "source": [ 20 | "## Setup" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": { 27 | "colab": { 28 | "base_uri": "https://localhost:8080/" 29 | }, 30 | "id": "GBfltjGHT6KG", 31 | "outputId": "7822e15b-9c77-412a-a6ed-20100243db13" 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "from google.colab import drive\n", 36 | "drive.mount('/content/drive/')" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": { 43 | "id": "T8fgmjaqUErq" 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "!pip -q install transformers" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": { 54 | "id": "EtCreyG8UG1s" 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "import os\n", 59 | "os.chdir(\"/content/drive/My Drive/Colab Notebooks\")" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": { 66 | "id": "dnv5kT-mLsB-" 67 | }, 68 | "outputs": [], 69 | 
"source": [ 70 | "# all the imports\n", 71 | "\n", 72 | "import glob\n", 73 | "import logging\n", 74 | "import os\n", 75 | "import pickle\n", 76 | "import random\n", 77 | "import re\n", 78 | "import shutil\n", 79 | "from typing import Dict, List, Tuple\n", 80 | "\n", 81 | "import numpy as np\n", 82 | "import pandas as pd\n", 83 | "\n", 84 | "from sklearn.model_selection import train_test_split\n", 85 | "\n", 86 | "from torch.nn.utils.rnn import pad_sequence\n", 87 | "from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler\n", 88 | "from torch.utils.data.distributed import DistributedSampler\n", 89 | "from tqdm.notebook import tqdm, trange\n", 90 | "\n", 91 | "from pathlib import Path\n", 92 | "\n", 93 | "from transformers import (\n", 94 | " MODEL_WITH_LM_HEAD_MAPPING,\n", 95 | " WEIGHTS_NAME,\n", 96 | " AdamW,\n", 97 | " AutoConfig,\n", 98 | " PreTrainedModel,\n", 99 | " PreTrainedTokenizer,\n", 100 | " get_linear_schedule_with_warmup,\n", 101 | ")\n", 102 | "\n", 103 | "\n", 104 | "try:\n", 105 | " from torch.utils.tensorboard import SummaryWriter\n", 106 | "except ImportError:\n", 107 | " from tensorboardX import SummaryWriter" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": { 113 | "id": "BmrbGB8aUmBm" 114 | }, 115 | "source": [ 116 | "## Get Data from Kaggle" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": { 123 | "colab": { 124 | "base_uri": "https://localhost:8080/" 125 | }, 126 | "id": "ftBYBoOoV_Er", 127 | "outputId": "07da0a13-6112-4c4e-cb49-51580c2d9e7a" 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "!mkdir ~/.kaggle\n", 132 | "!cp kaggle.json ~/.kaggle/kaggle.json" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "colab": { 140 | "base_uri": "https://localhost:8080/" 141 | }, 142 | "id": "fbITTMcLVbI_", 143 | "outputId": "fb4c8bf1-ff2d-4952-a451-62cdd0655aea" 144 | }, 145 | 
"outputs": [], 146 | "source": [ 147 | "!kaggle datasets download ruolinzheng/twewy-game-script -f twewy-name-line-full.csv" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": { 154 | "id": "RXdJTSVwWGHj" 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "data = pd.read_csv('twewy-name-line-full.csv')" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": { 165 | "colab": { 166 | "base_uri": "https://localhost:8080/", 167 | "height": 238 168 | }, 169 | "id": "h6kGx-9eG7qA", 170 | "outputId": "bd2efe43-1e50-4716-81a2-bf15a3dd03bd" 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "data.sample(6)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": { 181 | "id": "PG8v6--qWUwj" 182 | }, 183 | "outputs": [], 184 | "source": [ 185 | "CHARACTER_NAME = 'Joshua'" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": { 192 | "id": "GZUcEMd2WLDT" 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "contexted = []\n", 197 | "\n", 198 | "# context window of size 7\n", 199 | "n = 7\n", 200 | "\n", 201 | "for i in data[data.name == CHARACTER_NAME].index:\n", 202 | " if i < n:\n", 203 | " continue\n", 204 | " row = []\n", 205 | " prev = i - 1 - n # we additionally substract 1, so row will contain current responce and 7 previous responces \n", 206 | " for j in range(i, prev, -1):\n", 207 | " row.append(data.line[j])\n", 208 | " contexted.append(row)\n", 209 | "\n", 210 | "columns = ['response', 'context'] \n", 211 | "columns = columns + ['context/' + str(i) for i in range(n - 1)]\n", 212 | "\n", 213 | "df = pd.DataFrame.from_records(contexted, columns=columns)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": { 220 | "colab": { 221 | "base_uri": "https://localhost:8080/", 222 | "height": 446 223 | }, 224 | "id": "4T5OlNZHUxij", 
225 | "outputId": "895603a6-ca02-4301-c4b0-5bccbee8a3b8" 226 | }, 227 | "outputs": [], 228 | "source": [ 229 | "df.sample(6)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": { 236 | "colab": { 237 | "base_uri": "https://localhost:8080/", 238 | "height": 380 239 | }, 240 | "id": "NGy0MxMQVIAP", 241 | "outputId": "08b7f0eb-6a38-4b83-efdc-e53778d7547a" 242 | }, 243 | "outputs": [], 244 | "source": [ 245 | "trn_df, val_df = train_test_split(df, test_size=0.1)\n", 246 | "trn_df.head()" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": { 253 | "id": "aEeJQlAKWtiJ" 254 | }, 255 | "outputs": [], 256 | "source": [ 257 | "# create dataset suitable for our model\n", 258 | "def construct_conv(row, tokenizer, eos = True):\n", 259 | " flatten = lambda l: [item for sublist in l for item in sublist]\n", 260 | " conv = list(reversed([tokenizer.encode(x) + [tokenizer.eos_token_id] for x in row]))\n", 261 | " conv = flatten(conv)\n", 262 | " return conv\n", 263 | "\n", 264 | "class ConversationDataset(Dataset):\n", 265 | " def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):\n", 266 | "\n", 267 | " block_size = block_size - (tokenizer.model_max_length - tokenizer.max_len_single_sentence)\n", 268 | "\n", 269 | " directory = args.cache_dir\n", 270 | " cached_features_file = os.path.join(\n", 271 | " directory, args.model_type + \"_cached_lm_\" + str(block_size)\n", 272 | " )\n", 273 | "\n", 274 | " if os.path.exists(cached_features_file) and not args.overwrite_cache:\n", 275 | " logger.info(\"Loading features from cached file %s\", cached_features_file)\n", 276 | " with open(cached_features_file, \"rb\") as handle:\n", 277 | " self.examples = pickle.load(handle)\n", 278 | " else:\n", 279 | " logger.info(\"Creating features from dataset file at %s\", directory)\n", 280 | "\n", 281 | " self.examples = []\n", 282 | " for _, row in df.iterrows():\n", 283 | " conv 
= construct_conv(row, tokenizer)\n", 284 | " self.examples.append(conv)\n", 285 | "\n", 286 | " logger.info(\"Saving features into cached file %s\", cached_features_file)\n", 287 | " with open(cached_features_file, \"wb\") as handle:\n", 288 | " pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", 289 | "\n", 290 | " def __len__(self):\n", 291 | " return len(self.examples)\n", 292 | "\n", 293 | " def __getitem__(self, item):\n", 294 | " return torch.tensor(self.examples[item], dtype=torch.long)" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": { 301 | "id": "-3iHwoKlWyrs" 302 | }, 303 | "outputs": [], 304 | "source": [ 305 | "# Cacheing and storing of data/checkpoints\n", 306 | "\n", 307 | "def load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False):\n", 308 | " return ConversationDataset(tokenizer, args, df_val if evaluate else df_trn)\n", 309 | "\n", 310 | "\n", 311 | "def set_seed(args):\n", 312 | " random.seed(args.seed)\n", 313 | " np.random.seed(args.seed)\n", 314 | " torch.manual_seed(args.seed)\n", 315 | " if args.n_gpu > 0:\n", 316 | " torch.cuda.manual_seed_all(args.seed)\n", 317 | "\n", 318 | "\n", 319 | "def _sorted_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> List[str]:\n", 320 | " ordering_and_checkpoint_path = []\n", 321 | "\n", 322 | " glob_checkpoints = glob.glob(os.path.join(args.output_dir, \"{}-*\".format(checkpoint_prefix)))\n", 323 | "\n", 324 | " for path in glob_checkpoints:\n", 325 | " if use_mtime:\n", 326 | " ordering_and_checkpoint_path.append((os.path.getmtime(path), path))\n", 327 | " else:\n", 328 | " regex_match = re.match(\".*{}-([0-9]+)\".format(checkpoint_prefix), path)\n", 329 | " if regex_match and regex_match.groups():\n", 330 | " ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n", 331 | "\n", 332 | " checkpoints_sorted = sorted(ordering_and_checkpoint_path)\n", 333 | " checkpoints_sorted 
= [checkpoint[1] for checkpoint in checkpoints_sorted]\n", 334 | " return checkpoints_sorted\n", 335 | "\n", 336 | "\n", 337 | "def _rotate_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> None:\n", 338 | " if not args.save_total_limit:\n", 339 | " return\n", 340 | " if args.save_total_limit <= 0:\n", 341 | " return\n", 342 | "\n", 343 | " # Check if we should delete older checkpoint(s)\n", 344 | " checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)\n", 345 | " if len(checkpoints_sorted) <= args.save_total_limit:\n", 346 | " return\n", 347 | "\n", 348 | " number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)\n", 349 | " checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]\n", 350 | " for checkpoint in checkpoints_to_be_deleted:\n", 351 | " logger.info(\"Deleting older checkpoint [{}] due to args.save_total_limit\".format(checkpoint))\n", 352 | " shutil.rmtree(checkpoint)" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "metadata": { 358 | "id": "EEDdTJTqUwZJ" 359 | }, 360 | "source": [ 361 | "## Build Model" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": { 368 | "colab": { 369 | "base_uri": "https://localhost:8080/" 370 | }, 371 | "id": "r2cE0fY5UHpz", 372 | "outputId": "e4f382cd-57d9-49b7-9da4-4b44fe57df5b" 373 | }, 374 | "outputs": [], 375 | "source": [ 376 | "from transformers import AutoModelWithLMHead, AutoModelForCausalLM, AutoTokenizer\n", 377 | "import torch\n", 378 | "\n", 379 | "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-small\")\n", 380 | "model = AutoModelWithLMHead.from_pretrained(\"microsoft/DialoGPT-small\")" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "metadata": { 387 | "id": "ra2vsRp-UMXo" 388 | }, 389 | "outputs": [], 390 | "source": [ 391 | "\"\"\"\n", 392 | "Fine-tuning the library models for language 
modeling on a text file (GPT, GPT-2, BERT, RoBERTa).\n", 393 | "GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned\n", 394 | "using a masked language modeling (MLM) loss.\n", 395 | "\"\"\"\n", 396 | "\n", 397 | "# Configs\n", 398 | "logger = logging.getLogger(__name__)\n", 399 | "\n", 400 | "MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())\n", 401 | "MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": { 408 | "id": "2OnASqJjUNJa" 409 | }, 410 | "outputs": [], 411 | "source": [ 412 | "# Args to allow for easy convertion of python script to notebook\n", 413 | "class Args():\n", 414 | " def __init__(self):\n", 415 | " self.output_dir = 'output-small'\n", 416 | " self.model_type = 'gpt2'\n", 417 | " self.model_name_or_path = 'microsoft/DialoGPT-small'\n", 418 | " self.config_name = 'microsoft/DialoGPT-small'\n", 419 | " self.tokenizer_name = 'microsoft/DialoGPT-small'\n", 420 | " self.cache_dir = 'cached'\n", 421 | " self.block_size = 512\n", 422 | " self.do_train = True\n", 423 | " self.do_eval = True\n", 424 | " self.evaluate_during_training = False\n", 425 | " self.per_gpu_train_batch_size = 4\n", 426 | " self.per_gpu_eval_batch_size = 4\n", 427 | " self.gradient_accumulation_steps = 1\n", 428 | " self.learning_rate = 5e-5\n", 429 | " self.weight_decay = 0.0\n", 430 | " self.adam_epsilon = 1e-8\n", 431 | " self.max_grad_norm = 1.0\n", 432 | " self.num_train_epochs = 4\n", 433 | " self.max_steps = -1\n", 434 | " self.warmup_steps = 0\n", 435 | " self.logging_steps = 1000\n", 436 | " self.save_steps = 3500\n", 437 | " self.save_total_limit = None\n", 438 | " self.eval_all_checkpoints = False\n", 439 | " self.no_cuda = False\n", 440 | " self.overwrite_output_dir = True\n", 441 | " self.overwrite_cache = True\n", 442 | " self.should_continue = False\n", 443 | " self.seed = 42\n", 
444 | " self.local_rank = -1\n", 445 | " self.fp16 = False\n", 446 | " self.fp16_opt_level = 'O1'\n", 447 | "\n", 448 | "args = Args()" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "metadata": { 454 | "id": "9Q1dTFXxW9NE" 455 | }, 456 | "source": [ 457 | "## Train and Evaluate" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "metadata": { 464 | "id": "PaarIDZrW81h" 465 | }, 466 | "outputs": [], 467 | "source": [ 468 | "def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:\n", 469 | " \"\"\" Train the model \"\"\"\n", 470 | " if args.local_rank in [-1, 0]:\n", 471 | " tb_writer = SummaryWriter()\n", 472 | "\n", 473 | " args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)\n", 474 | "\n", 475 | " def collate(examples: List[torch.Tensor]):\n", 476 | " if tokenizer._pad_token is None:\n", 477 | " return pad_sequence(examples, batch_first=True)\n", 478 | " return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n", 479 | "\n", 480 | " train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)\n", 481 | " train_dataloader = DataLoader(\n", 482 | " train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate, drop_last = True\n", 483 | " )\n", 484 | "\n", 485 | " if args.max_steps > 0:\n", 486 | " t_total = args.max_steps\n", 487 | " args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1\n", 488 | " else:\n", 489 | " t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs\n", 490 | "\n", 491 | " model = model.module if hasattr(model, \"module\") else model # Take care of distributed/parallel training\n", 492 | " model.resize_token_embeddings(len(tokenizer))\n", 493 | " # add_special_tokens_(model, tokenizer)\n", 494 | "\n", 495 | "\n", 496 | " # 
Prepare optimizer and schedule (linear warmup and decay)\n", 497 | " no_decay = [\"bias\", \"LayerNorm.weight\"]\n", 498 | " optimizer_grouped_parameters = [\n", 499 | " {\n", 500 | " \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n", 501 | " \"weight_decay\": args.weight_decay,\n", 502 | " },\n", 503 | " {\"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], \"weight_decay\": 0.0},\n", 504 | " ]\n", 505 | " optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)\n", 506 | " scheduler = get_linear_schedule_with_warmup(\n", 507 | " optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total\n", 508 | " )\n", 509 | "\n", 510 | " # Check if saved optimizer or scheduler states exist\n", 511 | " if (\n", 512 | " args.model_name_or_path\n", 513 | " and os.path.isfile(os.path.join(args.model_name_or_path, \"optimizer.pt\"))\n", 514 | " and os.path.isfile(os.path.join(args.model_name_or_path, \"scheduler.pt\"))\n", 515 | " ):\n", 516 | " # Load in optimizer and scheduler states\n", 517 | " optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"optimizer.pt\")))\n", 518 | " scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"scheduler.pt\")))\n", 519 | "\n", 520 | " if args.fp16:\n", 521 | " try:\n", 522 | " from apex import amp\n", 523 | " except ImportError:\n", 524 | " raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.\")\n", 525 | " model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)\n", 526 | "\n", 527 | " # multi-gpu training (should be after apex fp16 initialization)\n", 528 | " if args.n_gpu > 1:\n", 529 | " model = torch.nn.DataParallel(model)\n", 530 | "\n", 531 | " # Distributed training (should be after apex fp16 initialization)\n", 532 | " if args.local_rank != -1:\n", 533 | " model = 
torch.nn.parallel.DistributedDataParallel(\n", 534 | " model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True\n", 535 | " )\n", 536 | "\n", 537 | " # Train!\n", 538 | " logger.info(\"***** Running training *****\")\n", 539 | " logger.info(\" Num examples = %d\", len(train_dataset))\n", 540 | " logger.info(\" Num Epochs = %d\", args.num_train_epochs)\n", 541 | " logger.info(\" Instantaneous batch size per GPU = %d\", args.per_gpu_train_batch_size)\n", 542 | " logger.info(\n", 543 | " \" Total train batch size (w. parallel, distributed & accumulation) = %d\",\n", 544 | " args.train_batch_size\n", 545 | " * args.gradient_accumulation_steps\n", 546 | " * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),\n", 547 | " )\n", 548 | " logger.info(\" Gradient Accumulation steps = %d\", args.gradient_accumulation_steps)\n", 549 | " logger.info(\" Total optimization steps = %d\", t_total)\n", 550 | "\n", 551 | " global_step = 0\n", 552 | " epochs_trained = 0\n", 553 | " steps_trained_in_current_epoch = 0\n", 554 | " # Check if continuing training from a checkpoint\n", 555 | " if args.model_name_or_path and os.path.exists(args.model_name_or_path):\n", 556 | " try:\n", 557 | " # set global_step to gobal_step of last saved checkpoint from model path\n", 558 | " checkpoint_suffix = args.model_name_or_path.split(\"-\")[-1].split(\"/\")[0]\n", 559 | " global_step = int(checkpoint_suffix)\n", 560 | " epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)\n", 561 | " steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)\n", 562 | "\n", 563 | " logger.info(\" Continuing training from checkpoint, will skip to saved global_step\")\n", 564 | " logger.info(\" Continuing training from epoch %d\", epochs_trained)\n", 565 | " logger.info(\" Continuing training from global step %d\", global_step)\n", 566 | " logger.info(\" Will skip the first %d 
steps in the first epoch\", steps_trained_in_current_epoch)\n", 567 | " except ValueError:\n", 568 | " logger.info(\" Starting fine-tuning.\")\n", 569 | "\n", 570 | " tr_loss, logging_loss = 0.0, 0.0\n", 571 | "\n", 572 | " model.zero_grad()\n", 573 | " train_iterator = trange(\n", 574 | " epochs_trained, int(args.num_train_epochs), desc=\"Epoch\", disable=args.local_rank not in [-1, 0]\n", 575 | " )\n", 576 | " set_seed(args) # Added here for reproducibility\n", 577 | " for _ in train_iterator:\n", 578 | " epoch_iterator = tqdm(train_dataloader, desc=\"Iteration\", disable=args.local_rank not in [-1, 0])\n", 579 | " for step, batch in enumerate(epoch_iterator):\n", 580 | "\n", 581 | " # Skip past any already trained steps if resuming training\n", 582 | " if steps_trained_in_current_epoch > 0:\n", 583 | " steps_trained_in_current_epoch -= 1\n", 584 | " continue\n", 585 | "\n", 586 | " inputs, labels = (batch, batch)\n", 587 | " if inputs.shape[1] > 1024: continue\n", 588 | " inputs = inputs.to(args.device)\n", 589 | " labels = labels.to(args.device)\n", 590 | " model.train()\n", 591 | " outputs = model(inputs, labels=labels)\n", 592 | " loss = outputs[0] # model outputs are always tuple in transformers (see doc)\n", 593 | "\n", 594 | " if args.n_gpu > 1:\n", 595 | " loss = loss.mean() # mean() to average on multi-gpu parallel training\n", 596 | " if args.gradient_accumulation_steps > 1:\n", 597 | " loss = loss / args.gradient_accumulation_steps\n", 598 | "\n", 599 | " if args.fp16:\n", 600 | " with amp.scale_loss(loss, optimizer) as scaled_loss:\n", 601 | " scaled_loss.backward()\n", 602 | " else:\n", 603 | " loss.backward()\n", 604 | "\n", 605 | " tr_loss += loss.item()\n", 606 | " if (step + 1) % args.gradient_accumulation_steps == 0:\n", 607 | " if args.fp16:\n", 608 | " torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)\n", 609 | " else:\n", 610 | " torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n", 611 | " 
optimizer.step()\n", 612 | " scheduler.step() # Update learning rate schedule\n", 613 | " model.zero_grad()\n", 614 | " global_step += 1\n", 615 | "\n", 616 | " if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:\n", 617 | " # Log metrics\n", 618 | " if (\n", 619 | " args.local_rank == -1 and args.evaluate_during_training\n", 620 | " ): # Only evaluate when single GPU otherwise metrics may not average well\n", 621 | " results = evaluate(args, model, tokenizer)\n", 622 | " for key, value in results.items():\n", 623 | " tb_writer.add_scalar(\"eval_{}\".format(key), value, global_step)\n", 624 | " tb_writer.add_scalar(\"lr\", scheduler.get_lr()[0], global_step)\n", 625 | " tb_writer.add_scalar(\"loss\", (tr_loss - logging_loss) / args.logging_steps, global_step)\n", 626 | " logging_loss = tr_loss\n", 627 | "\n", 628 | " if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:\n", 629 | " checkpoint_prefix = \"checkpoint\"\n", 630 | " # Save model checkpoint\n", 631 | " output_dir = os.path.join(args.output_dir, \"{}-{}\".format(checkpoint_prefix, global_step))\n", 632 | " os.makedirs(output_dir, exist_ok=True)\n", 633 | " model_to_save = (\n", 634 | " model.module if hasattr(model, \"module\") else model\n", 635 | " ) # Take care of distributed/parallel training\n", 636 | " model_to_save.save_pretrained(output_dir)\n", 637 | " tokenizer.save_pretrained(output_dir)\n", 638 | "\n", 639 | " torch.save(args, os.path.join(output_dir, \"training_args.bin\"))\n", 640 | " logger.info(\"Saving model checkpoint to %s\", output_dir)\n", 641 | "\n", 642 | " _rotate_checkpoints(args, checkpoint_prefix)\n", 643 | "\n", 644 | " torch.save(optimizer.state_dict(), os.path.join(output_dir, \"optimizer.pt\"))\n", 645 | " torch.save(scheduler.state_dict(), os.path.join(output_dir, \"scheduler.pt\"))\n", 646 | " logger.info(\"Saving optimizer and scheduler states to %s\", output_dir)\n", 647 | "\n", 648 | 
" if args.max_steps > 0 and global_step > args.max_steps:\n", 649 | " epoch_iterator.close()\n", 650 | " break\n", 651 | " if args.max_steps > 0 and global_step > args.max_steps:\n", 652 | " train_iterator.close()\n", 653 | " break\n", 654 | "\n", 655 | " if args.local_rank in [-1, 0]:\n", 656 | " tb_writer.close()\n", 657 | "\n", 658 | " return global_step, tr_loss / global_step\n", 659 | "\n", 660 | "# Evaluation of some model\n", 661 | "\n", 662 | "def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, df_trn, df_val, prefix=\"\") -> Dict:\n", 663 | " # Loop to handle MNLI double evaluation (matched, mis-matched)\n", 664 | " eval_output_dir = args.output_dir\n", 665 | "\n", 666 | " eval_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=True)\n", 667 | " os.makedirs(eval_output_dir, exist_ok=True)\n", 668 | " args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)\n", 669 | " # Note that DistributedSampler samples randomly\n", 670 | "\n", 671 | " def collate(examples: List[torch.Tensor]):\n", 672 | " if tokenizer._pad_token is None:\n", 673 | " return pad_sequence(examples, batch_first=True)\n", 674 | " return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n", 675 | "\n", 676 | " eval_sampler = SequentialSampler(eval_dataset)\n", 677 | " eval_dataloader = DataLoader(\n", 678 | " eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate, drop_last = True\n", 679 | " )\n", 680 | "\n", 681 | " # multi-gpu evaluate\n", 682 | " if args.n_gpu > 1:\n", 683 | " model = torch.nn.DataParallel(model)\n", 684 | "\n", 685 | " # Eval!\n", 686 | " logger.info(\"***** Running evaluation {} *****\".format(prefix))\n", 687 | " logger.info(\" Num examples = %d\", len(eval_dataset))\n", 688 | " logger.info(\" Batch size = %d\", args.eval_batch_size)\n", 689 | " eval_loss = 0.0\n", 690 | " nb_eval_steps = 0\n", 691 | " model.eval()\n", 692 | "\n", 693 | " for 
batch in tqdm(eval_dataloader, desc=\"Evaluating\"):\n", 694 | " inputs, labels = (batch, batch)\n", 695 | " inputs = inputs.to(args.device)\n", 696 | " labels = labels.to(args.device)\n", 697 | "\n", 698 | " with torch.no_grad():\n", 699 | " outputs = model(inputs, labels=labels)\n", 700 | " lm_loss = outputs[0]\n", 701 | " eval_loss += lm_loss.mean().item()\n", 702 | " nb_eval_steps += 1\n", 703 | "\n", 704 | " eval_loss = eval_loss / nb_eval_steps\n", 705 | " perplexity = torch.exp(torch.tensor(eval_loss))\n", 706 | "\n", 707 | " result = {\"perplexity\": perplexity}\n", 708 | "\n", 709 | " output_eval_file = os.path.join(eval_output_dir, prefix, \"eval_results.txt\")\n", 710 | " with open(output_eval_file, \"w\") as writer:\n", 711 | " logger.info(\"***** Eval results {} *****\".format(prefix))\n", 712 | " for key in sorted(result.keys()):\n", 713 | " logger.info(\" %s = %s\", key, str(result[key]))\n", 714 | " writer.write(\"%s = %s\\n\" % (key, str(result[key])))\n", 715 | "\n", 716 | " return result" 717 | ] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "execution_count": null, 722 | "metadata": { 723 | "id": "SCnGAJWbXD9C" 724 | }, 725 | "outputs": [], 726 | "source": [ 727 | "# Main runner\n", 728 | "\n", 729 | "def main(df_trn, df_val):\n", 730 | " args = Args()\n", 731 | " \n", 732 | " if args.should_continue:\n", 733 | " sorted_checkpoints = _sorted_checkpoints(args)\n", 734 | " if len(sorted_checkpoints) == 0:\n", 735 | " raise ValueError(\"Used --should_continue but no checkpoint was found in --output_dir.\")\n", 736 | " else:\n", 737 | " args.model_name_or_path = sorted_checkpoints[-1]\n", 738 | "\n", 739 | " if (\n", 740 | " os.path.exists(args.output_dir)\n", 741 | " and os.listdir(args.output_dir)\n", 742 | " and args.do_train\n", 743 | " and not args.overwrite_output_dir\n", 744 | " and not args.should_continue\n", 745 | " ):\n", 746 | " raise ValueError(\n", 747 | " \"Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.\".format(\n", 748 | " args.output_dir\n", 749 | " )\n", 750 | " )\n", 751 | "\n", 752 | " # Setup CUDA, GPU & distributed training\n", 753 | " device = torch.device(\"cuda\")\n", 754 | " args.n_gpu = torch.cuda.device_count()\n", 755 | " args.device = device\n", 756 | "\n", 757 | " # Setup logging\n", 758 | " logging.basicConfig(\n", 759 | " format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n", 760 | " datefmt=\"%m/%d/%Y %H:%M:%S\",\n", 761 | " level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,\n", 762 | " )\n", 763 | " logger.warning(\n", 764 | " \"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s\",\n", 765 | " args.local_rank,\n", 766 | " device,\n", 767 | " args.n_gpu,\n", 768 | " bool(args.local_rank != -1),\n", 769 | " args.fp16,\n", 770 | " )\n", 771 | "\n", 772 | " # Set seed\n", 773 | " set_seed(args)\n", 774 | "\n", 775 | " config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)\n", 776 | " tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)\n", 777 | " model = AutoModelWithLMHead.from_pretrained(\n", 778 | " args.model_name_or_path,\n", 779 | " from_tf=False,\n", 780 | " config=config,\n", 781 | " cache_dir=args.cache_dir,\n", 782 | " )\n", 783 | " model.to(args.device)\n", 784 | " \n", 785 | " logger.info(\"Training/evaluation parameters %s\", args)\n", 786 | "\n", 787 | " # Training\n", 788 | " if args.do_train:\n", 789 | " train_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False)\n", 790 | "\n", 791 | " global_step, tr_loss = train(args, train_dataset, model, tokenizer)\n", 792 | " logger.info(\" global_step = %s, average loss = %s\", global_step, tr_loss)\n", 793 | "\n", 794 | " # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()\n", 795 | " if args.do_train:\n", 796 | " # Create 
output directory if needed\n", 797 | " os.makedirs(args.output_dir, exist_ok=True)\n", 798 | "\n", 799 | " logger.info(\"Saving model checkpoint to %s\", args.output_dir)\n", 800 | " # Save a trained model, configuration and tokenizer using `save_pretrained()`.\n", 801 | " # They can then be reloaded using `from_pretrained()`\n", 802 | " model_to_save = (\n", 803 | " model.module if hasattr(model, \"module\") else model\n", 804 | " ) # Take care of distributed/parallel training\n", 805 | " model_to_save.save_pretrained(args.output_dir)\n", 806 | " tokenizer.save_pretrained(args.output_dir)\n", 807 | "\n", 808 | " # Good practice: save your training arguments together with the trained model\n", 809 | " torch.save(args, os.path.join(args.output_dir, \"training_args.bin\"))\n", 810 | "\n", 811 | " # Load a trained model and vocabulary that you have fine-tuned\n", 812 | " model = AutoModelWithLMHead.from_pretrained(args.output_dir)\n", 813 | " tokenizer = AutoTokenizer.from_pretrained(args.output_dir)\n", 814 | " model.to(args.device)\n", 815 | "\n", 816 | " # Evaluation\n", 817 | " results = {}\n", 818 | " if args.do_eval and args.local_rank in [-1, 0]:\n", 819 | " checkpoints = [args.output_dir]\n", 820 | " if args.eval_all_checkpoints:\n", 821 | " checkpoints = list(\n", 822 | " os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + \"/**/\" + WEIGHTS_NAME, recursive=True))\n", 823 | " )\n", 824 | " logging.getLogger(\"transformers.modeling_utils\").setLevel(logging.WARN) # Reduce logging\n", 825 | " logger.info(\"Evaluate the following checkpoints: %s\", checkpoints)\n", 826 | " for checkpoint in checkpoints:\n", 827 | " global_step = checkpoint.split(\"-\")[-1] if len(checkpoints) > 1 else \"\"\n", 828 | " prefix = checkpoint.split(\"/\")[-1] if checkpoint.find(\"checkpoint\") != -1 else \"\"\n", 829 | "\n", 830 | " model = AutoModelWithLMHead.from_pretrained(checkpoint)\n", 831 | " model.to(args.device)\n", 832 | " result = evaluate(args, model, 
tokenizer, df_trn, df_val, prefix=prefix)\n", 833 | " result = dict((k + \"_{}\".format(global_step), v) for k, v in result.items())\n", 834 | " results.update(result)\n", 835 | "\n", 836 | " return results" 837 | ] 838 | }, 839 | { 840 | "cell_type": "markdown", 841 | "metadata": { 842 | "id": "7NWvkdR-XHeB" 843 | }, 844 | "source": [ 845 | "## Run the Main Function" 846 | ] 847 | }, 848 | { 849 | "cell_type": "code", 850 | "execution_count": null, 851 | "metadata": { 852 | "colab": { 853 | "base_uri": "https://localhost:8080/", 854 | "height": 780, 855 | "referenced_widgets": [ 856 | "1d7f4c82687540f1ad69eb54ac3c25b4", 857 | "e7b9f3fc77a24259a87ef0dc735dfecb", 858 | "f3bf54733c2d4d9daa1cc9a7746ccb14", 859 | "aa40eb6346b54e7dac98e0b068cd4927", 860 | "021b771a270f479aa3b9e2b5f17e3d97", 861 | "450b0e7fd7a347c7beb78b7d72f64385", 862 | "9391d7abf6ed4400903995f56d7a1260", 863 | "ea6b919964d24c2f9de1c64c9cefaf23", 864 | "2fa1fa2407384cb98d79a912de2d5b8f", 865 | "dc27e2caf1ea4a4ab9ae3708fb06952f", 866 | "e38fb98fd7b3413392dc39c93a107a35", 867 | "855ca0a6125a4d698416214a9425ad98", 868 | "4699416338ae40a5b6abf19e45089aec", 869 | "43fdb31d3f314624ba07a15718b0c8f3", 870 | "de252cd193114c40ad5f5e9622b7abc7", 871 | "5e48b617cc3f41c3945efc28fc5e0c75", 872 | "68a9dc52819c48fb97259f318f9b5c6a", 873 | "b4e00059cf3a49929978ed780aae8358", 874 | "0ff5f4e3506b493a98d72008a467f35f", 875 | "77b97fa3271b48ac9f93665a102b4fd1", 876 | "a937f1dfeee5432ba31b3016fd30e9e2", 877 | "3c6d446f491c48fcae03e0034bfaaae9", 878 | "a193bb3a0b5b4cbba587e2460075a445", 879 | "75f8aebc30304fe198b5a2898a53a92d", 880 | "8b8a7c771d234f6c9d758a1f07f75a90", 881 | "c6518c4a721745bf97ee682f2ebe4635", 882 | "29cffa2b4f234e12802344eb53838641", 883 | "96243b7b227f465f83a289481680b925", 884 | "8c016a54f0a24fcdacf369baa9d24f1e", 885 | "7fe5b457ca0f417f90a20d235e9cec07", 886 | "fdffb26b99c24c978580f1cf97359fea", 887 | "8e3f1740c82f47949eefc2eb53052eae", 888 | "9cccd43f6acc4e25b4876fd0ae7a2ad6", 889 | 
"175e94deab7f4d20b99b419bea33583b", 890 | "41f26f7210e540479814e5d68de13ddb", 891 | "cf5cd281fa3b453093e210650bf81e9e", 892 | "e1fbe239c2394cbf973ac5b95e1e1491", 893 | "810ac22adad344b7bf8b556ded990122", 894 | "8b3a41c1900b45ebb9c56601deca0e84", 895 | "002f56aac3d64b33a0e799c0baf1e6b9", 896 | "a0f2a9a279734aa5bf146f0a5b33c43b", 897 | "850b5411122e4d608511fe26818bea68", 898 | "0663fb4bd85f4d87a7d61910b995be14", 899 | "cb7f52610fcf49bda46a14b296ff5bb5", 900 | "0ca29b4a62e04d9c937189ea19b25de8", 901 | "f871b83632974e0088bae65e78efaf28", 902 | "4cacf7fc20754a7ca7fe08c8ec187a81", 903 | "8bcc625c0f284398bbd287fe45021b17" 904 | ] 905 | }, 906 | "id": "e61zo2JtXGNX", 907 | "outputId": "22d4916e-7169-44b5-f9d8-79b9c43fab2e" 908 | }, 909 | "outputs": [], 910 | "source": [ 911 | "main(trn_df, val_df)" 912 | ] 913 | }, 914 | { 915 | "cell_type": "markdown", 916 | "metadata": { 917 | "id": "YRpQ_n2zXQj-" 918 | }, 919 | "source": [ 920 | "## Load the Trained Model" 921 | ] 922 | }, 923 | { 924 | "cell_type": "code", 925 | "execution_count": null, 926 | "metadata": { 927 | "colab": { 928 | "base_uri": "https://localhost:8080/" 929 | }, 930 | "id": "HGw3qgfaXQHX", 931 | "outputId": "93e84cfd-9718-42e5-bd11-418112c91d71" 932 | }, 933 | "outputs": [], 934 | "source": [ 935 | "tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')\n", 936 | "model = AutoModelWithLMHead.from_pretrained('output-small')" 937 | ] 938 | }, 939 | { 940 | "cell_type": "code", 941 | "execution_count": null, 942 | "metadata": { 943 | "colab": { 944 | "base_uri": "https://localhost:8080/" 945 | }, 946 | "id": "lAWsiAvNXbxd", 947 | "outputId": "0fd2541e-ee68-4976-b098-8483efe38d5e" 948 | }, 949 | "outputs": [], 950 | "source": [ 951 | "# Let's chat for 4 lines\n", 952 | "for step in range(4):\n", 953 | " # encode the new user input, add the eos_token and return a tensor in Pytorch\n", 954 | " new_user_input_ids = tokenizer.encode(input(\">> User:\") + tokenizer.eos_token, return_tensors='pt')\n", 
955 | " # print(new_user_input_ids)\n", 956 | "\n", 957 | " # append the new user input tokens to the chat history\n", 958 | " bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids\n", 959 | "\n", 960 | " # generate a response while limiting the total chat history (input + reply) to 200 tokens\n", 961 | " chat_history_ids = model.generate(\n", 962 | " bot_input_ids, max_length=200,\n", 963 | " pad_token_id=tokenizer.eos_token_id, \n", 964 | " no_repeat_ngram_size=3, \n", 965 | " do_sample=True, \n", 966 | " top_k=100, \n", 967 | " top_p=0.7,\n", 968 | " temperature=0.8\n", 969 | " )\n", 970 | " \n", 971 | " # pretty print the last output tokens from the bot\n", 972 | " print(\"JoshuaBot: {}\".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))" 973 | ] 974 | }, 975 | { 976 | "cell_type": "markdown", 977 | "metadata": { 978 | "id": "ANSQlQezXqwn" 979 | }, 980 | "source": [ 981 | "## Push Model to Hugging Face" 982 | ] 983 | }, 984 | { 985 | "cell_type": "code", 986 | "execution_count": null, 987 | "metadata": { 988 | "id": "VgnHRgHKXwDd" 989 | }, 990 | "outputs": [], 991 | "source": [ 992 | "!sudo apt-get install git-lfs" 993 | ] 994 | }, 995 | { 996 | "cell_type": "code", 997 | "execution_count": null, 998 | "metadata": { 999 | "id": "uhqMtvfmXei8" 1000 | }, 1001 | "outputs": [], 1002 | "source": [ 1003 | "!git config --global user.email \"lynnzheng08@outlook.com\"\n", 1004 | "# Tip: using the same email as your huggingface.co account will link your commits to your profile\n", 1005 | "!git config --global user.name \"Lynn Zheng\"" 1006 | ] 1007 | }, 1008 | { 1009 | "cell_type": "code", 1010 | "execution_count": null, 1011 | "metadata": { 1012 | "id": "tfUsrKR7YLT1" 1013 | }, 1014 | "outputs": [], 1015 | "source": [ 1016 | "MY_MODEL_NAME = 'DialoGPT-small-joshua'\n", 1017 | "with open('HuggingFace-API-key.txt', 'rt') as f:\n", 1018 | " HUGGINGFACE_API_KEY = f.read().strip()" 1019 | 
] 1020 | }, 1021 | { 1022 | "cell_type": "code", 1023 | "execution_count": null, 1024 | "metadata": { 1025 | "colab": { 1026 | "base_uri": "https://localhost:8080/", 1027 | "height": 1000 1028 | }, 1029 | "id": "_65nsiLcYNXI", 1030 | "outputId": "0dbf0cb1-957c-4adb-bf55-4222d2cc85bc" 1031 | }, 1032 | "outputs": [], 1033 | "source": [ 1034 | "model.push_to_hub(MY_MODEL_NAME, use_auth_token=HUGGINGFACE_API_KEY)\n", 1035 | "tokenizer.push_to_hub(MY_MODEL_NAME, use_auth_token=HUGGINGFACE_API_KEY)" 1036 | ] 1037 | }, 1038 | { 1039 | "cell_type": "markdown", 1040 | "metadata": { 1041 | "id": "D_XfXTCrZKmO" 1042 | }, 1043 | "source": [ 1044 | "## All Done!" 1045 | ] 1046 | }, 1047 | { 1048 | "cell_type": "code", 1049 | "execution_count": null, 1050 | "metadata": { 1051 | "id": "_tIwK7G8ZLrd" 1052 | }, 1053 | "outputs": [], 1054 | "source": [] 1055 | } 1056 | ], 1057 | "metadata": { 1058 | "accelerator": "GPU", 1059 | "colab": { 1060 | "collapsed_sections": [], 1061 | "name": "model_train_upload_workflow.ipynb", 1062 | "provenance": [] 1063 | }, 1064 | "kernelspec": { 1065 | "display_name": "Python 3", 1066 | "language": "python", 1067 | "name": "python3" 1068 | }, 1069 | "language_info": { 1070 | "codemirror_mode": { 1071 | "name": "ipython", 1072 | "version": 3 1073 | }, 1074 | "file_extension": ".py", 1075 | "mimetype": "text/x-python", 1076 | "name": "python", 1077 | "nbconvert_exporter": "python", 1078 | "pygments_lexer": "ipython3", 1079 | "version": "3.9.4" 1080 | }, 1081 | "widgets": { 1082 | "application/vnd.jupyter.widget-state+json": { 1083 | "002f56aac3d64b33a0e799c0baf1e6b9": { 1084 | "model_module": "@jupyter-widgets/base", 1085 | "model_name": "LayoutModel", 1086 | "state": { 1087 | "_model_module": "@jupyter-widgets/base", 1088 | "_model_module_version": "1.2.0", 1089 | "_model_name": "LayoutModel", 1090 | "_view_count": null, 1091 | "_view_module": "@jupyter-widgets/base", 1092 | "_view_module_version": "1.2.0", 1093 | "_view_name": "LayoutView", 1094 | 
"align_content": null, 1095 | "align_items": null, 1096 | "align_self": null, 1097 | "border": null, 1098 | "bottom": null, 1099 | "display": null, 1100 | "flex": null, 1101 | "flex_flow": null, 1102 | "grid_area": null, 1103 | "grid_auto_columns": null, 1104 | "grid_auto_flow": null, 1105 | "grid_auto_rows": null, 1106 | "grid_column": null, 1107 | "grid_gap": null, 1108 | "grid_row": null, 1109 | "grid_template_areas": null, 1110 | "grid_template_columns": null, 1111 | "grid_template_rows": null, 1112 | "height": null, 1113 | "justify_content": null, 1114 | "justify_items": null, 1115 | "left": null, 1116 | "margin": null, 1117 | "max_height": null, 1118 | "max_width": null, 1119 | "min_height": null, 1120 | "min_width": null, 1121 | "object_fit": null, 1122 | "object_position": null, 1123 | "order": null, 1124 | "overflow": null, 1125 | "overflow_x": null, 1126 | "overflow_y": null, 1127 | "padding": null, 1128 | "right": null, 1129 | "top": null, 1130 | "visibility": null, 1131 | "width": null 1132 | } 1133 | }, 1134 | "021b771a270f479aa3b9e2b5f17e3d97": { 1135 | "model_module": "@jupyter-widgets/controls", 1136 | "model_name": "ProgressStyleModel", 1137 | "state": { 1138 | "_model_module": "@jupyter-widgets/controls", 1139 | "_model_module_version": "1.5.0", 1140 | "_model_name": "ProgressStyleModel", 1141 | "_view_count": null, 1142 | "_view_module": "@jupyter-widgets/base", 1143 | "_view_module_version": "1.2.0", 1144 | "_view_name": "StyleView", 1145 | "bar_color": null, 1146 | "description_width": "initial" 1147 | } 1148 | }, 1149 | "0663fb4bd85f4d87a7d61910b995be14": { 1150 | "model_module": "@jupyter-widgets/controls", 1151 | "model_name": "FloatProgressModel", 1152 | "state": { 1153 | "_dom_classes": [], 1154 | "_model_module": "@jupyter-widgets/controls", 1155 | "_model_module_version": "1.5.0", 1156 | "_model_name": "FloatProgressModel", 1157 | "_view_count": null, 1158 | "_view_module": "@jupyter-widgets/controls", 1159 | "_view_module_version": 
"1.5.0", 1160 | "_view_name": "ProgressView", 1161 | "bar_style": "success", 1162 | "description": "Evaluating: 100%", 1163 | "description_tooltip": null, 1164 | "layout": "IPY_MODEL_f871b83632974e0088bae65e78efaf28", 1165 | "max": 21, 1166 | "min": 0, 1167 | "orientation": "horizontal", 1168 | "style": "IPY_MODEL_0ca29b4a62e04d9c937189ea19b25de8", 1169 | "value": 21 1170 | } 1171 | }, 1172 | "0ca29b4a62e04d9c937189ea19b25de8": { 1173 | "model_module": "@jupyter-widgets/controls", 1174 | "model_name": "ProgressStyleModel", 1175 | "state": { 1176 | "_model_module": "@jupyter-widgets/controls", 1177 | "_model_module_version": "1.5.0", 1178 | "_model_name": "ProgressStyleModel", 1179 | "_view_count": null, 1180 | "_view_module": "@jupyter-widgets/base", 1181 | "_view_module_version": "1.2.0", 1182 | "_view_name": "StyleView", 1183 | "bar_color": null, 1184 | "description_width": "initial" 1185 | } 1186 | }, 1187 | "0ff5f4e3506b493a98d72008a467f35f": { 1188 | "model_module": "@jupyter-widgets/controls", 1189 | "model_name": "FloatProgressModel", 1190 | "state": { 1191 | "_dom_classes": [], 1192 | "_model_module": "@jupyter-widgets/controls", 1193 | "_model_module_version": "1.5.0", 1194 | "_model_name": "FloatProgressModel", 1195 | "_view_count": null, 1196 | "_view_module": "@jupyter-widgets/controls", 1197 | "_view_module_version": "1.5.0", 1198 | "_view_name": "ProgressView", 1199 | "bar_style": "success", 1200 | "description": "Iteration: 100%", 1201 | "description_tooltip": null, 1202 | "layout": "IPY_MODEL_3c6d446f491c48fcae03e0034bfaaae9", 1203 | "max": 195, 1204 | "min": 0, 1205 | "orientation": "horizontal", 1206 | "style": "IPY_MODEL_a937f1dfeee5432ba31b3016fd30e9e2", 1207 | "value": 195 1208 | } 1209 | }, 1210 | "175e94deab7f4d20b99b419bea33583b": { 1211 | "model_module": "@jupyter-widgets/base", 1212 | "model_name": "LayoutModel", 1213 | "state": { 1214 | "_model_module": "@jupyter-widgets/base", 1215 | "_model_module_version": "1.2.0", 1216 | 
"_model_name": "LayoutModel", 1217 | "_view_count": null, 1218 | "_view_module": "@jupyter-widgets/base", 1219 | "_view_module_version": "1.2.0", 1220 | "_view_name": "LayoutView", 1221 | "align_content": null, 1222 | "align_items": null, 1223 | "align_self": null, 1224 | "border": null, 1225 | "bottom": null, 1226 | "display": null, 1227 | "flex": null, 1228 | "flex_flow": null, 1229 | "grid_area": null, 1230 | "grid_auto_columns": null, 1231 | "grid_auto_flow": null, 1232 | "grid_auto_rows": null, 1233 | "grid_column": null, 1234 | "grid_gap": null, 1235 | "grid_row": null, 1236 | "grid_template_areas": null, 1237 | "grid_template_columns": null, 1238 | "grid_template_rows": null, 1239 | "height": null, 1240 | "justify_content": null, 1241 | "justify_items": null, 1242 | "left": null, 1243 | "margin": null, 1244 | "max_height": null, 1245 | "max_width": null, 1246 | "min_height": null, 1247 | "min_width": null, 1248 | "object_fit": null, 1249 | "object_position": null, 1250 | "order": null, 1251 | "overflow": null, 1252 | "overflow_x": null, 1253 | "overflow_y": null, 1254 | "padding": null, 1255 | "right": null, 1256 | "top": null, 1257 | "visibility": null, 1258 | "width": null 1259 | } 1260 | }, 1261 | "1d7f4c82687540f1ad69eb54ac3c25b4": { 1262 | "model_module": "@jupyter-widgets/controls", 1263 | "model_name": "HBoxModel", 1264 | "state": { 1265 | "_dom_classes": [], 1266 | "_model_module": "@jupyter-widgets/controls", 1267 | "_model_module_version": "1.5.0", 1268 | "_model_name": "HBoxModel", 1269 | "_view_count": null, 1270 | "_view_module": "@jupyter-widgets/controls", 1271 | "_view_module_version": "1.5.0", 1272 | "_view_name": "HBoxView", 1273 | "box_style": "", 1274 | "children": [ 1275 | "IPY_MODEL_f3bf54733c2d4d9daa1cc9a7746ccb14", 1276 | "IPY_MODEL_aa40eb6346b54e7dac98e0b068cd4927" 1277 | ], 1278 | "layout": "IPY_MODEL_e7b9f3fc77a24259a87ef0dc735dfecb" 1279 | } 1280 | }, 1281 | "29cffa2b4f234e12802344eb53838641": { 1282 | "model_module": 
"@jupyter-widgets/controls", 1283 | "model_name": "FloatProgressModel", 1284 | "state": { 1285 | "_dom_classes": [], 1286 | "_model_module": "@jupyter-widgets/controls", 1287 | "_model_module_version": "1.5.0", 1288 | "_model_name": "FloatProgressModel", 1289 | "_view_count": null, 1290 | "_view_module": "@jupyter-widgets/controls", 1291 | "_view_module_version": "1.5.0", 1292 | "_view_name": "ProgressView", 1293 | "bar_style": "success", 1294 | "description": "Iteration: 100%", 1295 | "description_tooltip": null, 1296 | "layout": "IPY_MODEL_7fe5b457ca0f417f90a20d235e9cec07", 1297 | "max": 195, 1298 | "min": 0, 1299 | "orientation": "horizontal", 1300 | "style": "IPY_MODEL_8c016a54f0a24fcdacf369baa9d24f1e", 1301 | "value": 195 1302 | } 1303 | }, 1304 | "2fa1fa2407384cb98d79a912de2d5b8f": { 1305 | "model_module": "@jupyter-widgets/controls", 1306 | "model_name": "HBoxModel", 1307 | "state": { 1308 | "_dom_classes": [], 1309 | "_model_module": "@jupyter-widgets/controls", 1310 | "_model_module_version": "1.5.0", 1311 | "_model_name": "HBoxModel", 1312 | "_view_count": null, 1313 | "_view_module": "@jupyter-widgets/controls", 1314 | "_view_module_version": "1.5.0", 1315 | "_view_name": "HBoxView", 1316 | "box_style": "", 1317 | "children": [ 1318 | "IPY_MODEL_e38fb98fd7b3413392dc39c93a107a35", 1319 | "IPY_MODEL_855ca0a6125a4d698416214a9425ad98" 1320 | ], 1321 | "layout": "IPY_MODEL_dc27e2caf1ea4a4ab9ae3708fb06952f" 1322 | } 1323 | }, 1324 | "3c6d446f491c48fcae03e0034bfaaae9": { 1325 | "model_module": "@jupyter-widgets/base", 1326 | "model_name": "LayoutModel", 1327 | "state": { 1328 | "_model_module": "@jupyter-widgets/base", 1329 | "_model_module_version": "1.2.0", 1330 | "_model_name": "LayoutModel", 1331 | "_view_count": null, 1332 | "_view_module": "@jupyter-widgets/base", 1333 | "_view_module_version": "1.2.0", 1334 | "_view_name": "LayoutView", 1335 | "align_content": null, 1336 | "align_items": null, 1337 | "align_self": null, 1338 | "border": null, 1339 | 
"bottom": null, 1340 | "display": null, 1341 | "flex": null, 1342 | "flex_flow": null, 1343 | "grid_area": null, 1344 | "grid_auto_columns": null, 1345 | "grid_auto_flow": null, 1346 | "grid_auto_rows": null, 1347 | "grid_column": null, 1348 | "grid_gap": null, 1349 | "grid_row": null, 1350 | "grid_template_areas": null, 1351 | "grid_template_columns": null, 1352 | "grid_template_rows": null, 1353 | "height": null, 1354 | "justify_content": null, 1355 | "justify_items": null, 1356 | "left": null, 1357 | "margin": null, 1358 | "max_height": null, 1359 | "max_width": null, 1360 | "min_height": null, 1361 | "min_width": null, 1362 | "object_fit": null, 1363 | "object_position": null, 1364 | "order": null, 1365 | "overflow": null, 1366 | "overflow_x": null, 1367 | "overflow_y": null, 1368 | "padding": null, 1369 | "right": null, 1370 | "top": null, 1371 | "visibility": null, 1372 | "width": null 1373 | } 1374 | }, 1375 | "41f26f7210e540479814e5d68de13ddb": { 1376 | "model_module": "@jupyter-widgets/controls", 1377 | "model_name": "FloatProgressModel", 1378 | "state": { 1379 | "_dom_classes": [], 1380 | "_model_module": "@jupyter-widgets/controls", 1381 | "_model_module_version": "1.5.0", 1382 | "_model_name": "FloatProgressModel", 1383 | "_view_count": null, 1384 | "_view_module": "@jupyter-widgets/controls", 1385 | "_view_module_version": "1.5.0", 1386 | "_view_name": "ProgressView", 1387 | "bar_style": "success", 1388 | "description": "Iteration: 100%", 1389 | "description_tooltip": null, 1390 | "layout": "IPY_MODEL_810ac22adad344b7bf8b556ded990122", 1391 | "max": 195, 1392 | "min": 0, 1393 | "orientation": "horizontal", 1394 | "style": "IPY_MODEL_e1fbe239c2394cbf973ac5b95e1e1491", 1395 | "value": 195 1396 | } 1397 | }, 1398 | "43fdb31d3f314624ba07a15718b0c8f3": { 1399 | "model_module": "@jupyter-widgets/base", 1400 | "model_name": "LayoutModel", 1401 | "state": { 1402 | "_model_module": "@jupyter-widgets/base", 1403 | "_model_module_version": "1.2.0", 1404 | 
"_model_name": "LayoutModel", 1405 | "_view_count": null, 1406 | "_view_module": "@jupyter-widgets/base", 1407 | "_view_module_version": "1.2.0", 1408 | "_view_name": "LayoutView", 1409 | "align_content": null, 1410 | "align_items": null, 1411 | "align_self": null, 1412 | "border": null, 1413 | "bottom": null, 1414 | "display": null, 1415 | "flex": null, 1416 | "flex_flow": null, 1417 | "grid_area": null, 1418 | "grid_auto_columns": null, 1419 | "grid_auto_flow": null, 1420 | "grid_auto_rows": null, 1421 | "grid_column": null, 1422 | "grid_gap": null, 1423 | "grid_row": null, 1424 | "grid_template_areas": null, 1425 | "grid_template_columns": null, 1426 | "grid_template_rows": null, 1427 | "height": null, 1428 | "justify_content": null, 1429 | "justify_items": null, 1430 | "left": null, 1431 | "margin": null, 1432 | "max_height": null, 1433 | "max_width": null, 1434 | "min_height": null, 1435 | "min_width": null, 1436 | "object_fit": null, 1437 | "object_position": null, 1438 | "order": null, 1439 | "overflow": null, 1440 | "overflow_x": null, 1441 | "overflow_y": null, 1442 | "padding": null, 1443 | "right": null, 1444 | "top": null, 1445 | "visibility": null, 1446 | "width": null 1447 | } 1448 | }, 1449 | "450b0e7fd7a347c7beb78b7d72f64385": { 1450 | "model_module": "@jupyter-widgets/base", 1451 | "model_name": "LayoutModel", 1452 | "state": { 1453 | "_model_module": "@jupyter-widgets/base", 1454 | "_model_module_version": "1.2.0", 1455 | "_model_name": "LayoutModel", 1456 | "_view_count": null, 1457 | "_view_module": "@jupyter-widgets/base", 1458 | "_view_module_version": "1.2.0", 1459 | "_view_name": "LayoutView", 1460 | "align_content": null, 1461 | "align_items": null, 1462 | "align_self": null, 1463 | "border": null, 1464 | "bottom": null, 1465 | "display": null, 1466 | "flex": null, 1467 | "flex_flow": null, 1468 | "grid_area": null, 1469 | "grid_auto_columns": null, 1470 | "grid_auto_flow": null, 1471 | "grid_auto_rows": null, 1472 | "grid_column": null, 
1473 | "grid_gap": null, 1474 | "grid_row": null, 1475 | "grid_template_areas": null, 1476 | "grid_template_columns": null, 1477 | "grid_template_rows": null, 1478 | "height": null, 1479 | "justify_content": null, 1480 | "justify_items": null, 1481 | "left": null, 1482 | "margin": null, 1483 | "max_height": null, 1484 | "max_width": null, 1485 | "min_height": null, 1486 | "min_width": null, 1487 | "object_fit": null, 1488 | "object_position": null, 1489 | "order": null, 1490 | "overflow": null, 1491 | "overflow_x": null, 1492 | "overflow_y": null, 1493 | "padding": null, 1494 | "right": null, 1495 | "top": null, 1496 | "visibility": null, 1497 | "width": null 1498 | } 1499 | }, 1500 | "4699416338ae40a5b6abf19e45089aec": { 1501 | "model_module": "@jupyter-widgets/controls", 1502 | "model_name": "ProgressStyleModel", 1503 | "state": { 1504 | "_model_module": "@jupyter-widgets/controls", 1505 | "_model_module_version": "1.5.0", 1506 | "_model_name": "ProgressStyleModel", 1507 | "_view_count": null, 1508 | "_view_module": "@jupyter-widgets/base", 1509 | "_view_module_version": "1.2.0", 1510 | "_view_name": "StyleView", 1511 | "bar_color": null, 1512 | "description_width": "initial" 1513 | } 1514 | }, 1515 | "4cacf7fc20754a7ca7fe08c8ec187a81": { 1516 | "model_module": "@jupyter-widgets/controls", 1517 | "model_name": "DescriptionStyleModel", 1518 | "state": { 1519 | "_model_module": "@jupyter-widgets/controls", 1520 | "_model_module_version": "1.5.0", 1521 | "_model_name": "DescriptionStyleModel", 1522 | "_view_count": null, 1523 | "_view_module": "@jupyter-widgets/base", 1524 | "_view_module_version": "1.2.0", 1525 | "_view_name": "StyleView", 1526 | "description_width": "" 1527 | } 1528 | }, 1529 | "5e48b617cc3f41c3945efc28fc5e0c75": { 1530 | "model_module": "@jupyter-widgets/base", 1531 | "model_name": "LayoutModel", 1532 | "state": { 1533 | "_model_module": "@jupyter-widgets/base", 1534 | "_model_module_version": "1.2.0", 1535 | "_model_name": "LayoutModel", 1536 | 
"_view_count": null, 1537 | "_view_module": "@jupyter-widgets/base", 1538 | "_view_module_version": "1.2.0", 1539 | "_view_name": "LayoutView", 1540 | "align_content": null, 1541 | "align_items": null, 1542 | "align_self": null, 1543 | "border": null, 1544 | "bottom": null, 1545 | "display": null, 1546 | "flex": null, 1547 | "flex_flow": null, 1548 | "grid_area": null, 1549 | "grid_auto_columns": null, 1550 | "grid_auto_flow": null, 1551 | "grid_auto_rows": null, 1552 | "grid_column": null, 1553 | "grid_gap": null, 1554 | "grid_row": null, 1555 | "grid_template_areas": null, 1556 | "grid_template_columns": null, 1557 | "grid_template_rows": null, 1558 | "height": null, 1559 | "justify_content": null, 1560 | "justify_items": null, 1561 | "left": null, 1562 | "margin": null, 1563 | "max_height": null, 1564 | "max_width": null, 1565 | "min_height": null, 1566 | "min_width": null, 1567 | "object_fit": null, 1568 | "object_position": null, 1569 | "order": null, 1570 | "overflow": null, 1571 | "overflow_x": null, 1572 | "overflow_y": null, 1573 | "padding": null, 1574 | "right": null, 1575 | "top": null, 1576 | "visibility": null, 1577 | "width": null 1578 | } 1579 | }, 1580 | "68a9dc52819c48fb97259f318f9b5c6a": { 1581 | "model_module": "@jupyter-widgets/controls", 1582 | "model_name": "HBoxModel", 1583 | "state": { 1584 | "_dom_classes": [], 1585 | "_model_module": "@jupyter-widgets/controls", 1586 | "_model_module_version": "1.5.0", 1587 | "_model_name": "HBoxModel", 1588 | "_view_count": null, 1589 | "_view_module": "@jupyter-widgets/controls", 1590 | "_view_module_version": "1.5.0", 1591 | "_view_name": "HBoxView", 1592 | "box_style": "", 1593 | "children": [ 1594 | "IPY_MODEL_0ff5f4e3506b493a98d72008a467f35f", 1595 | "IPY_MODEL_77b97fa3271b48ac9f93665a102b4fd1" 1596 | ], 1597 | "layout": "IPY_MODEL_b4e00059cf3a49929978ed780aae8358" 1598 | } 1599 | }, 1600 | "75f8aebc30304fe198b5a2898a53a92d": { 1601 | "model_module": "@jupyter-widgets/base", 1602 | "model_name": 
"LayoutModel", 1603 | "state": { 1604 | "_model_module": "@jupyter-widgets/base", 1605 | "_model_module_version": "1.2.0", 1606 | "_model_name": "LayoutModel", 1607 | "_view_count": null, 1608 | "_view_module": "@jupyter-widgets/base", 1609 | "_view_module_version": "1.2.0", 1610 | "_view_name": "LayoutView", 1611 | "align_content": null, 1612 | "align_items": null, 1613 | "align_self": null, 1614 | "border": null, 1615 | "bottom": null, 1616 | "display": null, 1617 | "flex": null, 1618 | "flex_flow": null, 1619 | "grid_area": null, 1620 | "grid_auto_columns": null, 1621 | "grid_auto_flow": null, 1622 | "grid_auto_rows": null, 1623 | "grid_column": null, 1624 | "grid_gap": null, 1625 | "grid_row": null, 1626 | "grid_template_areas": null, 1627 | "grid_template_columns": null, 1628 | "grid_template_rows": null, 1629 | "height": null, 1630 | "justify_content": null, 1631 | "justify_items": null, 1632 | "left": null, 1633 | "margin": null, 1634 | "max_height": null, 1635 | "max_width": null, 1636 | "min_height": null, 1637 | "min_width": null, 1638 | "object_fit": null, 1639 | "object_position": null, 1640 | "order": null, 1641 | "overflow": null, 1642 | "overflow_x": null, 1643 | "overflow_y": null, 1644 | "padding": null, 1645 | "right": null, 1646 | "top": null, 1647 | "visibility": null, 1648 | "width": null 1649 | } 1650 | }, 1651 | "77b97fa3271b48ac9f93665a102b4fd1": { 1652 | "model_module": "@jupyter-widgets/controls", 1653 | "model_name": "HTMLModel", 1654 | "state": { 1655 | "_dom_classes": [], 1656 | "_model_module": "@jupyter-widgets/controls", 1657 | "_model_module_version": "1.5.0", 1658 | "_model_name": "HTMLModel", 1659 | "_view_count": null, 1660 | "_view_module": "@jupyter-widgets/controls", 1661 | "_view_module_version": "1.5.0", 1662 | "_view_name": "HTMLView", 1663 | "description": "", 1664 | "description_tooltip": null, 1665 | "layout": "IPY_MODEL_75f8aebc30304fe198b5a2898a53a92d", 1666 | "placeholder": "​", 1667 | "style": 
"IPY_MODEL_a193bb3a0b5b4cbba587e2460075a445", 1668 | "value": " 195/195 [00:35<00:00, 5.45it/s]" 1669 | } 1670 | }, 1671 | "7fe5b457ca0f417f90a20d235e9cec07": { 1672 | "model_module": "@jupyter-widgets/base", 1673 | "model_name": "LayoutModel", 1674 | "state": { 1675 | "_model_module": "@jupyter-widgets/base", 1676 | "_model_module_version": "1.2.0", 1677 | "_model_name": "LayoutModel", 1678 | "_view_count": null, 1679 | "_view_module": "@jupyter-widgets/base", 1680 | "_view_module_version": "1.2.0", 1681 | "_view_name": "LayoutView", 1682 | "align_content": null, 1683 | "align_items": null, 1684 | "align_self": null, 1685 | "border": null, 1686 | "bottom": null, 1687 | "display": null, 1688 | "flex": null, 1689 | "flex_flow": null, 1690 | "grid_area": null, 1691 | "grid_auto_columns": null, 1692 | "grid_auto_flow": null, 1693 | "grid_auto_rows": null, 1694 | "grid_column": null, 1695 | "grid_gap": null, 1696 | "grid_row": null, 1697 | "grid_template_areas": null, 1698 | "grid_template_columns": null, 1699 | "grid_template_rows": null, 1700 | "height": null, 1701 | "justify_content": null, 1702 | "justify_items": null, 1703 | "left": null, 1704 | "margin": null, 1705 | "max_height": null, 1706 | "max_width": null, 1707 | "min_height": null, 1708 | "min_width": null, 1709 | "object_fit": null, 1710 | "object_position": null, 1711 | "order": null, 1712 | "overflow": null, 1713 | "overflow_x": null, 1714 | "overflow_y": null, 1715 | "padding": null, 1716 | "right": null, 1717 | "top": null, 1718 | "visibility": null, 1719 | "width": null 1720 | } 1721 | }, 1722 | "810ac22adad344b7bf8b556ded990122": { 1723 | "model_module": "@jupyter-widgets/base", 1724 | "model_name": "LayoutModel", 1725 | "state": { 1726 | "_model_module": "@jupyter-widgets/base", 1727 | "_model_module_version": "1.2.0", 1728 | "_model_name": "LayoutModel", 1729 | "_view_count": null, 1730 | "_view_module": "@jupyter-widgets/base", 1731 | "_view_module_version": "1.2.0", 1732 | "_view_name": 
"LayoutView", 1733 | "align_content": null, 1734 | "align_items": null, 1735 | "align_self": null, 1736 | "border": null, 1737 | "bottom": null, 1738 | "display": null, 1739 | "flex": null, 1740 | "flex_flow": null, 1741 | "grid_area": null, 1742 | "grid_auto_columns": null, 1743 | "grid_auto_flow": null, 1744 | "grid_auto_rows": null, 1745 | "grid_column": null, 1746 | "grid_gap": null, 1747 | "grid_row": null, 1748 | "grid_template_areas": null, 1749 | "grid_template_columns": null, 1750 | "grid_template_rows": null, 1751 | "height": null, 1752 | "justify_content": null, 1753 | "justify_items": null, 1754 | "left": null, 1755 | "margin": null, 1756 | "max_height": null, 1757 | "max_width": null, 1758 | "min_height": null, 1759 | "min_width": null, 1760 | "object_fit": null, 1761 | "object_position": null, 1762 | "order": null, 1763 | "overflow": null, 1764 | "overflow_x": null, 1765 | "overflow_y": null, 1766 | "padding": null, 1767 | "right": null, 1768 | "top": null, 1769 | "visibility": null, 1770 | "width": null 1771 | } 1772 | }, 1773 | "850b5411122e4d608511fe26818bea68": { 1774 | "model_module": "@jupyter-widgets/base", 1775 | "model_name": "LayoutModel", 1776 | "state": { 1777 | "_model_module": "@jupyter-widgets/base", 1778 | "_model_module_version": "1.2.0", 1779 | "_model_name": "LayoutModel", 1780 | "_view_count": null, 1781 | "_view_module": "@jupyter-widgets/base", 1782 | "_view_module_version": "1.2.0", 1783 | "_view_name": "LayoutView", 1784 | "align_content": null, 1785 | "align_items": null, 1786 | "align_self": null, 1787 | "border": null, 1788 | "bottom": null, 1789 | "display": null, 1790 | "flex": null, 1791 | "flex_flow": null, 1792 | "grid_area": null, 1793 | "grid_auto_columns": null, 1794 | "grid_auto_flow": null, 1795 | "grid_auto_rows": null, 1796 | "grid_column": null, 1797 | "grid_gap": null, 1798 | "grid_row": null, 1799 | "grid_template_areas": null, 1800 | "grid_template_columns": null, 1801 | "grid_template_rows": null, 1802 | 
"height": null, 1803 | "justify_content": null, 1804 | "justify_items": null, 1805 | "left": null, 1806 | "margin": null, 1807 | "max_height": null, 1808 | "max_width": null, 1809 | "min_height": null, 1810 | "min_width": null, 1811 | "object_fit": null, 1812 | "object_position": null, 1813 | "order": null, 1814 | "overflow": null, 1815 | "overflow_x": null, 1816 | "overflow_y": null, 1817 | "padding": null, 1818 | "right": null, 1819 | "top": null, 1820 | "visibility": null, 1821 | "width": null 1822 | } 1823 | }, 1824 | "855ca0a6125a4d698416214a9425ad98": { 1825 | "model_module": "@jupyter-widgets/controls", 1826 | "model_name": "HTMLModel", 1827 | "state": { 1828 | "_dom_classes": [], 1829 | "_model_module": "@jupyter-widgets/controls", 1830 | "_model_module_version": "1.5.0", 1831 | "_model_name": "HTMLModel", 1832 | "_view_count": null, 1833 | "_view_module": "@jupyter-widgets/controls", 1834 | "_view_module_version": "1.5.0", 1835 | "_view_name": "HTMLView", 1836 | "description": "", 1837 | "description_tooltip": null, 1838 | "layout": "IPY_MODEL_5e48b617cc3f41c3945efc28fc5e0c75", 1839 | "placeholder": "​", 1840 | "style": "IPY_MODEL_de252cd193114c40ad5f5e9622b7abc7", 1841 | "value": " 195/195 [00:44<00:00, 4.39it/s]" 1842 | } 1843 | }, 1844 | "8b3a41c1900b45ebb9c56601deca0e84": { 1845 | "model_module": "@jupyter-widgets/controls", 1846 | "model_name": "DescriptionStyleModel", 1847 | "state": { 1848 | "_model_module": "@jupyter-widgets/controls", 1849 | "_model_module_version": "1.5.0", 1850 | "_model_name": "DescriptionStyleModel", 1851 | "_view_count": null, 1852 | "_view_module": "@jupyter-widgets/base", 1853 | "_view_module_version": "1.2.0", 1854 | "_view_name": "StyleView", 1855 | "description_width": "" 1856 | } 1857 | }, 1858 | "8b8a7c771d234f6c9d758a1f07f75a90": { 1859 | "model_module": "@jupyter-widgets/controls", 1860 | "model_name": "HBoxModel", 1861 | "state": { 1862 | "_dom_classes": [], 1863 | "_model_module": "@jupyter-widgets/controls", 1864 
| "_model_module_version": "1.5.0", 1865 | "_model_name": "HBoxModel", 1866 | "_view_count": null, 1867 | "_view_module": "@jupyter-widgets/controls", 1868 | "_view_module_version": "1.5.0", 1869 | "_view_name": "HBoxView", 1870 | "box_style": "", 1871 | "children": [ 1872 | "IPY_MODEL_29cffa2b4f234e12802344eb53838641", 1873 | "IPY_MODEL_96243b7b227f465f83a289481680b925" 1874 | ], 1875 | "layout": "IPY_MODEL_c6518c4a721745bf97ee682f2ebe4635" 1876 | } 1877 | }, 1878 | "8bcc625c0f284398bbd287fe45021b17": { 1879 | "model_module": "@jupyter-widgets/base", 1880 | "model_name": "LayoutModel", 1881 | "state": { 1882 | "_model_module": "@jupyter-widgets/base", 1883 | "_model_module_version": "1.2.0", 1884 | "_model_name": "LayoutModel", 1885 | "_view_count": null, 1886 | "_view_module": "@jupyter-widgets/base", 1887 | "_view_module_version": "1.2.0", 1888 | "_view_name": "LayoutView", 1889 | "align_content": null, 1890 | "align_items": null, 1891 | "align_self": null, 1892 | "border": null, 1893 | "bottom": null, 1894 | "display": null, 1895 | "flex": null, 1896 | "flex_flow": null, 1897 | "grid_area": null, 1898 | "grid_auto_columns": null, 1899 | "grid_auto_flow": null, 1900 | "grid_auto_rows": null, 1901 | "grid_column": null, 1902 | "grid_gap": null, 1903 | "grid_row": null, 1904 | "grid_template_areas": null, 1905 | "grid_template_columns": null, 1906 | "grid_template_rows": null, 1907 | "height": null, 1908 | "justify_content": null, 1909 | "justify_items": null, 1910 | "left": null, 1911 | "margin": null, 1912 | "max_height": null, 1913 | "max_width": null, 1914 | "min_height": null, 1915 | "min_width": null, 1916 | "object_fit": null, 1917 | "object_position": null, 1918 | "order": null, 1919 | "overflow": null, 1920 | "overflow_x": null, 1921 | "overflow_y": null, 1922 | "padding": null, 1923 | "right": null, 1924 | "top": null, 1925 | "visibility": null, 1926 | "width": null 1927 | } 1928 | }, 1929 | "8c016a54f0a24fcdacf369baa9d24f1e": { 1930 | "model_module": 
"@jupyter-widgets/controls", 1931 | "model_name": "ProgressStyleModel", 1932 | "state": { 1933 | "_model_module": "@jupyter-widgets/controls", 1934 | "_model_module_version": "1.5.0", 1935 | "_model_name": "ProgressStyleModel", 1936 | "_view_count": null, 1937 | "_view_module": "@jupyter-widgets/base", 1938 | "_view_module_version": "1.2.0", 1939 | "_view_name": "StyleView", 1940 | "bar_color": null, 1941 | "description_width": "initial" 1942 | } 1943 | }, 1944 | "8e3f1740c82f47949eefc2eb53052eae": { 1945 | "model_module": "@jupyter-widgets/base", 1946 | "model_name": "LayoutModel", 1947 | "state": { 1948 | "_model_module": "@jupyter-widgets/base", 1949 | "_model_module_version": "1.2.0", 1950 | "_model_name": "LayoutModel", 1951 | "_view_count": null, 1952 | "_view_module": "@jupyter-widgets/base", 1953 | "_view_module_version": "1.2.0", 1954 | "_view_name": "LayoutView", 1955 | "align_content": null, 1956 | "align_items": null, 1957 | "align_self": null, 1958 | "border": null, 1959 | "bottom": null, 1960 | "display": null, 1961 | "flex": null, 1962 | "flex_flow": null, 1963 | "grid_area": null, 1964 | "grid_auto_columns": null, 1965 | "grid_auto_flow": null, 1966 | "grid_auto_rows": null, 1967 | "grid_column": null, 1968 | "grid_gap": null, 1969 | "grid_row": null, 1970 | "grid_template_areas": null, 1971 | "grid_template_columns": null, 1972 | "grid_template_rows": null, 1973 | "height": null, 1974 | "justify_content": null, 1975 | "justify_items": null, 1976 | "left": null, 1977 | "margin": null, 1978 | "max_height": null, 1979 | "max_width": null, 1980 | "min_height": null, 1981 | "min_width": null, 1982 | "object_fit": null, 1983 | "object_position": null, 1984 | "order": null, 1985 | "overflow": null, 1986 | "overflow_x": null, 1987 | "overflow_y": null, 1988 | "padding": null, 1989 | "right": null, 1990 | "top": null, 1991 | "visibility": null, 1992 | "width": null 1993 | } 1994 | }, 1995 | "9391d7abf6ed4400903995f56d7a1260": { 1996 | "model_module": 
"@jupyter-widgets/controls", 1997 | "model_name": "DescriptionStyleModel", 1998 | "state": { 1999 | "_model_module": "@jupyter-widgets/controls", 2000 | "_model_module_version": "1.5.0", 2001 | "_model_name": "DescriptionStyleModel", 2002 | "_view_count": null, 2003 | "_view_module": "@jupyter-widgets/base", 2004 | "_view_module_version": "1.2.0", 2005 | "_view_name": "StyleView", 2006 | "description_width": "" 2007 | } 2008 | }, 2009 | "96243b7b227f465f83a289481680b925": { 2010 | "model_module": "@jupyter-widgets/controls", 2011 | "model_name": "HTMLModel", 2012 | "state": { 2013 | "_dom_classes": [], 2014 | "_model_module": "@jupyter-widgets/controls", 2015 | "_model_module_version": "1.5.0", 2016 | "_model_name": "HTMLModel", 2017 | "_view_count": null, 2018 | "_view_module": "@jupyter-widgets/controls", 2019 | "_view_module_version": "1.5.0", 2020 | "_view_name": "HTMLView", 2021 | "description": "", 2022 | "description_tooltip": null, 2023 | "layout": "IPY_MODEL_8e3f1740c82f47949eefc2eb53052eae", 2024 | "placeholder": "​", 2025 | "style": "IPY_MODEL_fdffb26b99c24c978580f1cf97359fea", 2026 | "value": " 195/195 [01:17<00:00, 2.53it/s]" 2027 | } 2028 | }, 2029 | "9cccd43f6acc4e25b4876fd0ae7a2ad6": { 2030 | "model_module": "@jupyter-widgets/controls", 2031 | "model_name": "HBoxModel", 2032 | "state": { 2033 | "_dom_classes": [], 2034 | "_model_module": "@jupyter-widgets/controls", 2035 | "_model_module_version": "1.5.0", 2036 | "_model_name": "HBoxModel", 2037 | "_view_count": null, 2038 | "_view_module": "@jupyter-widgets/controls", 2039 | "_view_module_version": "1.5.0", 2040 | "_view_name": "HBoxView", 2041 | "box_style": "", 2042 | "children": [ 2043 | "IPY_MODEL_41f26f7210e540479814e5d68de13ddb", 2044 | "IPY_MODEL_cf5cd281fa3b453093e210650bf81e9e" 2045 | ], 2046 | "layout": "IPY_MODEL_175e94deab7f4d20b99b419bea33583b" 2047 | } 2048 | }, 2049 | "a0f2a9a279734aa5bf146f0a5b33c43b": { 2050 | "model_module": "@jupyter-widgets/controls", 2051 | "model_name": 
"HBoxModel", 2052 | "state": { 2053 | "_dom_classes": [], 2054 | "_model_module": "@jupyter-widgets/controls", 2055 | "_model_module_version": "1.5.0", 2056 | "_model_name": "HBoxModel", 2057 | "_view_count": null, 2058 | "_view_module": "@jupyter-widgets/controls", 2059 | "_view_module_version": "1.5.0", 2060 | "_view_name": "HBoxView", 2061 | "box_style": "", 2062 | "children": [ 2063 | "IPY_MODEL_0663fb4bd85f4d87a7d61910b995be14", 2064 | "IPY_MODEL_cb7f52610fcf49bda46a14b296ff5bb5" 2065 | ], 2066 | "layout": "IPY_MODEL_850b5411122e4d608511fe26818bea68" 2067 | } 2068 | }, 2069 | "a193bb3a0b5b4cbba587e2460075a445": { 2070 | "model_module": "@jupyter-widgets/controls", 2071 | "model_name": "DescriptionStyleModel", 2072 | "state": { 2073 | "_model_module": "@jupyter-widgets/controls", 2074 | "_model_module_version": "1.5.0", 2075 | "_model_name": "DescriptionStyleModel", 2076 | "_view_count": null, 2077 | "_view_module": "@jupyter-widgets/base", 2078 | "_view_module_version": "1.2.0", 2079 | "_view_name": "StyleView", 2080 | "description_width": "" 2081 | } 2082 | }, 2083 | "a937f1dfeee5432ba31b3016fd30e9e2": { 2084 | "model_module": "@jupyter-widgets/controls", 2085 | "model_name": "ProgressStyleModel", 2086 | "state": { 2087 | "_model_module": "@jupyter-widgets/controls", 2088 | "_model_module_version": "1.5.0", 2089 | "_model_name": "ProgressStyleModel", 2090 | "_view_count": null, 2091 | "_view_module": "@jupyter-widgets/base", 2092 | "_view_module_version": "1.2.0", 2093 | "_view_name": "StyleView", 2094 | "bar_color": null, 2095 | "description_width": "initial" 2096 | } 2097 | }, 2098 | "aa40eb6346b54e7dac98e0b068cd4927": { 2099 | "model_module": "@jupyter-widgets/controls", 2100 | "model_name": "HTMLModel", 2101 | "state": { 2102 | "_dom_classes": [], 2103 | "_model_module": "@jupyter-widgets/controls", 2104 | "_model_module_version": "1.5.0", 2105 | "_model_name": "HTMLModel", 2106 | "_view_count": null, 2107 | "_view_module": "@jupyter-widgets/controls", 
2108 | "_view_module_version": "1.5.0", 2109 | "_view_name": "HTMLView", 2110 | "description": "", 2111 | "description_tooltip": null, 2112 | "layout": "IPY_MODEL_ea6b919964d24c2f9de1c64c9cefaf23", 2113 | "placeholder": "​", 2114 | "style": "IPY_MODEL_9391d7abf6ed4400903995f56d7a1260", 2115 | "value": " 4/4 [02:23<00:00, 36.00s/it]" 2116 | } 2117 | }, 2118 | "b4e00059cf3a49929978ed780aae8358": { 2119 | "model_module": "@jupyter-widgets/base", 2120 | "model_name": "LayoutModel", 2121 | "state": { 2122 | "_model_module": "@jupyter-widgets/base", 2123 | "_model_module_version": "1.2.0", 2124 | "_model_name": "LayoutModel", 2125 | "_view_count": null, 2126 | "_view_module": "@jupyter-widgets/base", 2127 | "_view_module_version": "1.2.0", 2128 | "_view_name": "LayoutView", 2129 | "align_content": null, 2130 | "align_items": null, 2131 | "align_self": null, 2132 | "border": null, 2133 | "bottom": null, 2134 | "display": null, 2135 | "flex": null, 2136 | "flex_flow": null, 2137 | "grid_area": null, 2138 | "grid_auto_columns": null, 2139 | "grid_auto_flow": null, 2140 | "grid_auto_rows": null, 2141 | "grid_column": null, 2142 | "grid_gap": null, 2143 | "grid_row": null, 2144 | "grid_template_areas": null, 2145 | "grid_template_columns": null, 2146 | "grid_template_rows": null, 2147 | "height": null, 2148 | "justify_content": null, 2149 | "justify_items": null, 2150 | "left": null, 2151 | "margin": null, 2152 | "max_height": null, 2153 | "max_width": null, 2154 | "min_height": null, 2155 | "min_width": null, 2156 | "object_fit": null, 2157 | "object_position": null, 2158 | "order": null, 2159 | "overflow": null, 2160 | "overflow_x": null, 2161 | "overflow_y": null, 2162 | "padding": null, 2163 | "right": null, 2164 | "top": null, 2165 | "visibility": null, 2166 | "width": null 2167 | } 2168 | }, 2169 | "c6518c4a721745bf97ee682f2ebe4635": { 2170 | "model_module": "@jupyter-widgets/base", 2171 | "model_name": "LayoutModel", 2172 | "state": { 2173 | "_model_module": 
"@jupyter-widgets/base", 2174 | "_model_module_version": "1.2.0", 2175 | "_model_name": "LayoutModel", 2176 | "_view_count": null, 2177 | "_view_module": "@jupyter-widgets/base", 2178 | "_view_module_version": "1.2.0", 2179 | "_view_name": "LayoutView", 2180 | "align_content": null, 2181 | "align_items": null, 2182 | "align_self": null, 2183 | "border": null, 2184 | "bottom": null, 2185 | "display": null, 2186 | "flex": null, 2187 | "flex_flow": null, 2188 | "grid_area": null, 2189 | "grid_auto_columns": null, 2190 | "grid_auto_flow": null, 2191 | "grid_auto_rows": null, 2192 | "grid_column": null, 2193 | "grid_gap": null, 2194 | "grid_row": null, 2195 | "grid_template_areas": null, 2196 | "grid_template_columns": null, 2197 | "grid_template_rows": null, 2198 | "height": null, 2199 | "justify_content": null, 2200 | "justify_items": null, 2201 | "left": null, 2202 | "margin": null, 2203 | "max_height": null, 2204 | "max_width": null, 2205 | "min_height": null, 2206 | "min_width": null, 2207 | "object_fit": null, 2208 | "object_position": null, 2209 | "order": null, 2210 | "overflow": null, 2211 | "overflow_x": null, 2212 | "overflow_y": null, 2213 | "padding": null, 2214 | "right": null, 2215 | "top": null, 2216 | "visibility": null, 2217 | "width": null 2218 | } 2219 | }, 2220 | "cb7f52610fcf49bda46a14b296ff5bb5": { 2221 | "model_module": "@jupyter-widgets/controls", 2222 | "model_name": "HTMLModel", 2223 | "state": { 2224 | "_dom_classes": [], 2225 | "_model_module": "@jupyter-widgets/controls", 2226 | "_model_module_version": "1.5.0", 2227 | "_model_name": "HTMLModel", 2228 | "_view_count": null, 2229 | "_view_module": "@jupyter-widgets/controls", 2230 | "_view_module_version": "1.5.0", 2231 | "_view_name": "HTMLView", 2232 | "description": "", 2233 | "description_tooltip": null, 2234 | "layout": "IPY_MODEL_8bcc625c0f284398bbd287fe45021b17", 2235 | "placeholder": "​", 2236 | "style": "IPY_MODEL_4cacf7fc20754a7ca7fe08c8ec187a81", 2237 | "value": " 21/21 
[00:01<00:00, 10.78it/s]" 2238 | } 2239 | }, 2240 | "cf5cd281fa3b453093e210650bf81e9e": { 2241 | "model_module": "@jupyter-widgets/controls", 2242 | "model_name": "HTMLModel", 2243 | "state": { 2244 | "_dom_classes": [], 2245 | "_model_module": "@jupyter-widgets/controls", 2246 | "_model_module_version": "1.5.0", 2247 | "_model_name": "HTMLModel", 2248 | "_view_count": null, 2249 | "_view_module": "@jupyter-widgets/controls", 2250 | "_view_module_version": "1.5.0", 2251 | "_view_name": "HTMLView", 2252 | "description": "", 2253 | "description_tooltip": null, 2254 | "layout": "IPY_MODEL_002f56aac3d64b33a0e799c0baf1e6b9", 2255 | "placeholder": "​", 2256 | "style": "IPY_MODEL_8b3a41c1900b45ebb9c56601deca0e84", 2257 | "value": " 195/195 [00:40<00:00, 4.84it/s]" 2258 | } 2259 | }, 2260 | "dc27e2caf1ea4a4ab9ae3708fb06952f": { 2261 | "model_module": "@jupyter-widgets/base", 2262 | "model_name": "LayoutModel", 2263 | "state": { 2264 | "_model_module": "@jupyter-widgets/base", 2265 | "_model_module_version": "1.2.0", 2266 | "_model_name": "LayoutModel", 2267 | "_view_count": null, 2268 | "_view_module": "@jupyter-widgets/base", 2269 | "_view_module_version": "1.2.0", 2270 | "_view_name": "LayoutView", 2271 | "align_content": null, 2272 | "align_items": null, 2273 | "align_self": null, 2274 | "border": null, 2275 | "bottom": null, 2276 | "display": null, 2277 | "flex": null, 2278 | "flex_flow": null, 2279 | "grid_area": null, 2280 | "grid_auto_columns": null, 2281 | "grid_auto_flow": null, 2282 | "grid_auto_rows": null, 2283 | "grid_column": null, 2284 | "grid_gap": null, 2285 | "grid_row": null, 2286 | "grid_template_areas": null, 2287 | "grid_template_columns": null, 2288 | "grid_template_rows": null, 2289 | "height": null, 2290 | "justify_content": null, 2291 | "justify_items": null, 2292 | "left": null, 2293 | "margin": null, 2294 | "max_height": null, 2295 | "max_width": null, 2296 | "min_height": null, 2297 | "min_width": null, 2298 | "object_fit": null, 2299 | 
"object_position": null, 2300 | "order": null, 2301 | "overflow": null, 2302 | "overflow_x": null, 2303 | "overflow_y": null, 2304 | "padding": null, 2305 | "right": null, 2306 | "top": null, 2307 | "visibility": null, 2308 | "width": null 2309 | } 2310 | }, 2311 | "de252cd193114c40ad5f5e9622b7abc7": { 2312 | "model_module": "@jupyter-widgets/controls", 2313 | "model_name": "DescriptionStyleModel", 2314 | "state": { 2315 | "_model_module": "@jupyter-widgets/controls", 2316 | "_model_module_version": "1.5.0", 2317 | "_model_name": "DescriptionStyleModel", 2318 | "_view_count": null, 2319 | "_view_module": "@jupyter-widgets/base", 2320 | "_view_module_version": "1.2.0", 2321 | "_view_name": "StyleView", 2322 | "description_width": "" 2323 | } 2324 | }, 2325 | "e1fbe239c2394cbf973ac5b95e1e1491": { 2326 | "model_module": "@jupyter-widgets/controls", 2327 | "model_name": "ProgressStyleModel", 2328 | "state": { 2329 | "_model_module": "@jupyter-widgets/controls", 2330 | "_model_module_version": "1.5.0", 2331 | "_model_name": "ProgressStyleModel", 2332 | "_view_count": null, 2333 | "_view_module": "@jupyter-widgets/base", 2334 | "_view_module_version": "1.2.0", 2335 | "_view_name": "StyleView", 2336 | "bar_color": null, 2337 | "description_width": "initial" 2338 | } 2339 | }, 2340 | "e38fb98fd7b3413392dc39c93a107a35": { 2341 | "model_module": "@jupyter-widgets/controls", 2342 | "model_name": "FloatProgressModel", 2343 | "state": { 2344 | "_dom_classes": [], 2345 | "_model_module": "@jupyter-widgets/controls", 2346 | "_model_module_version": "1.5.0", 2347 | "_model_name": "FloatProgressModel", 2348 | "_view_count": null, 2349 | "_view_module": "@jupyter-widgets/controls", 2350 | "_view_module_version": "1.5.0", 2351 | "_view_name": "ProgressView", 2352 | "bar_style": "success", 2353 | "description": "Iteration: 100%", 2354 | "description_tooltip": null, 2355 | "layout": "IPY_MODEL_43fdb31d3f314624ba07a15718b0c8f3", 2356 | "max": 195, 2357 | "min": 0, 2358 | "orientation": 
"horizontal", 2359 | "style": "IPY_MODEL_4699416338ae40a5b6abf19e45089aec", 2360 | "value": 195 2361 | } 2362 | }, 2363 | "e7b9f3fc77a24259a87ef0dc735dfecb": { 2364 | "model_module": "@jupyter-widgets/base", 2365 | "model_name": "LayoutModel", 2366 | "state": { 2367 | "_model_module": "@jupyter-widgets/base", 2368 | "_model_module_version": "1.2.0", 2369 | "_model_name": "LayoutModel", 2370 | "_view_count": null, 2371 | "_view_module": "@jupyter-widgets/base", 2372 | "_view_module_version": "1.2.0", 2373 | "_view_name": "LayoutView", 2374 | "align_content": null, 2375 | "align_items": null, 2376 | "align_self": null, 2377 | "border": null, 2378 | "bottom": null, 2379 | "display": null, 2380 | "flex": null, 2381 | "flex_flow": null, 2382 | "grid_area": null, 2383 | "grid_auto_columns": null, 2384 | "grid_auto_flow": null, 2385 | "grid_auto_rows": null, 2386 | "grid_column": null, 2387 | "grid_gap": null, 2388 | "grid_row": null, 2389 | "grid_template_areas": null, 2390 | "grid_template_columns": null, 2391 | "grid_template_rows": null, 2392 | "height": null, 2393 | "justify_content": null, 2394 | "justify_items": null, 2395 | "left": null, 2396 | "margin": null, 2397 | "max_height": null, 2398 | "max_width": null, 2399 | "min_height": null, 2400 | "min_width": null, 2401 | "object_fit": null, 2402 | "object_position": null, 2403 | "order": null, 2404 | "overflow": null, 2405 | "overflow_x": null, 2406 | "overflow_y": null, 2407 | "padding": null, 2408 | "right": null, 2409 | "top": null, 2410 | "visibility": null, 2411 | "width": null 2412 | } 2413 | }, 2414 | "ea6b919964d24c2f9de1c64c9cefaf23": { 2415 | "model_module": "@jupyter-widgets/base", 2416 | "model_name": "LayoutModel", 2417 | "state": { 2418 | "_model_module": "@jupyter-widgets/base", 2419 | "_model_module_version": "1.2.0", 2420 | "_model_name": "LayoutModel", 2421 | "_view_count": null, 2422 | "_view_module": "@jupyter-widgets/base", 2423 | "_view_module_version": "1.2.0", 2424 | "_view_name": 
"LayoutView", 2425 | "align_content": null, 2426 | "align_items": null, 2427 | "align_self": null, 2428 | "border": null, 2429 | "bottom": null, 2430 | "display": null, 2431 | "flex": null, 2432 | "flex_flow": null, 2433 | "grid_area": null, 2434 | "grid_auto_columns": null, 2435 | "grid_auto_flow": null, 2436 | "grid_auto_rows": null, 2437 | "grid_column": null, 2438 | "grid_gap": null, 2439 | "grid_row": null, 2440 | "grid_template_areas": null, 2441 | "grid_template_columns": null, 2442 | "grid_template_rows": null, 2443 | "height": null, 2444 | "justify_content": null, 2445 | "justify_items": null, 2446 | "left": null, 2447 | "margin": null, 2448 | "max_height": null, 2449 | "max_width": null, 2450 | "min_height": null, 2451 | "min_width": null, 2452 | "object_fit": null, 2453 | "object_position": null, 2454 | "order": null, 2455 | "overflow": null, 2456 | "overflow_x": null, 2457 | "overflow_y": null, 2458 | "padding": null, 2459 | "right": null, 2460 | "top": null, 2461 | "visibility": null, 2462 | "width": null 2463 | } 2464 | }, 2465 | "f3bf54733c2d4d9daa1cc9a7746ccb14": { 2466 | "model_module": "@jupyter-widgets/controls", 2467 | "model_name": "FloatProgressModel", 2468 | "state": { 2469 | "_dom_classes": [], 2470 | "_model_module": "@jupyter-widgets/controls", 2471 | "_model_module_version": "1.5.0", 2472 | "_model_name": "FloatProgressModel", 2473 | "_view_count": null, 2474 | "_view_module": "@jupyter-widgets/controls", 2475 | "_view_module_version": "1.5.0", 2476 | "_view_name": "ProgressView", 2477 | "bar_style": "success", 2478 | "description": "Epoch: 100%", 2479 | "description_tooltip": null, 2480 | "layout": "IPY_MODEL_450b0e7fd7a347c7beb78b7d72f64385", 2481 | "max": 4, 2482 | "min": 0, 2483 | "orientation": "horizontal", 2484 | "style": "IPY_MODEL_021b771a270f479aa3b9e2b5f17e3d97", 2485 | "value": 4 2486 | } 2487 | }, 2488 | "f871b83632974e0088bae65e78efaf28": { 2489 | "model_module": "@jupyter-widgets/base", 2490 | "model_name": "LayoutModel", 
2491 | "state": { 2492 | "_model_module": "@jupyter-widgets/base", 2493 | "_model_module_version": "1.2.0", 2494 | "_model_name": "LayoutModel", 2495 | "_view_count": null, 2496 | "_view_module": "@jupyter-widgets/base", 2497 | "_view_module_version": "1.2.0", 2498 | "_view_name": "LayoutView", 2499 | "align_content": null, 2500 | "align_items": null, 2501 | "align_self": null, 2502 | "border": null, 2503 | "bottom": null, 2504 | "display": null, 2505 | "flex": null, 2506 | "flex_flow": null, 2507 | "grid_area": null, 2508 | "grid_auto_columns": null, 2509 | "grid_auto_flow": null, 2510 | "grid_auto_rows": null, 2511 | "grid_column": null, 2512 | "grid_gap": null, 2513 | "grid_row": null, 2514 | "grid_template_areas": null, 2515 | "grid_template_columns": null, 2516 | "grid_template_rows": null, 2517 | "height": null, 2518 | "justify_content": null, 2519 | "justify_items": null, 2520 | "left": null, 2521 | "margin": null, 2522 | "max_height": null, 2523 | "max_width": null, 2524 | "min_height": null, 2525 | "min_width": null, 2526 | "object_fit": null, 2527 | "object_position": null, 2528 | "order": null, 2529 | "overflow": null, 2530 | "overflow_x": null, 2531 | "overflow_y": null, 2532 | "padding": null, 2533 | "right": null, 2534 | "top": null, 2535 | "visibility": null, 2536 | "width": null 2537 | } 2538 | }, 2539 | "fdffb26b99c24c978580f1cf97359fea": { 2540 | "model_module": "@jupyter-widgets/controls", 2541 | "model_name": "DescriptionStyleModel", 2542 | "state": { 2543 | "_model_module": "@jupyter-widgets/controls", 2544 | "_model_module_version": "1.5.0", 2545 | "_model_name": "DescriptionStyleModel", 2546 | "_view_count": null, 2547 | "_view_module": "@jupyter-widgets/base", 2548 | "_view_module_version": "1.2.0", 2549 | "_view_name": "StyleView", 2550 | "description_width": "" 2551 | } 2552 | } 2553 | } 2554 | } 2555 | }, 2556 | "nbformat": 4, 2557 | "nbformat_minor": 1 2558 | } 2559 | 
--------------------------------------------------------------------------------