├── gif-demo
│   ├── icon.png
│   ├── discord.gif
│   ├── gameplay.png
│   └── huggingface.gif
├── LICENSE
├── discord_bot.js
├── README.md
├── .gitignore
├── discord_bot.py
└── model_train_upload_workflow.ipynb
/gif-demo/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham8550/twewy-discord-chatbot/main/gif-demo/icon.png
--------------------------------------------------------------------------------
/gif-demo/discord.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham8550/twewy-discord-chatbot/main/gif-demo/discord.gif
--------------------------------------------------------------------------------
/gif-demo/gameplay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham8550/twewy-discord-chatbot/main/gif-demo/gameplay.png
--------------------------------------------------------------------------------
/gif-demo/huggingface.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham8550/twewy-discord-chatbot/main/gif-demo/huggingface.gif
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Lynn Zheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/discord_bot.js:
--------------------------------------------------------------------------------
1 | // discord.js import
2 | const Discord = require('discord.js');
3 | // node-fetch for making HTTP requests
4 | const fetch = require('node-fetch');
5 |
6 | // initialize client
7 | const client = new Discord.Client();
8 | // my model URL
9 | API_URL = 'https://api-inference.huggingface.co/models/r3dhummingbird/DialoGPT-medium-joshua';
10 |
11 | // log out some info
12 | client.on('ready', () => {
13 | console.log(`Logged in as ${client.user.tag}!`);
14 | });
15 |
16 | // when the bot receives a message
17 | // need async message because we are making HTTP requests
18 | client.on('message', async message => {
19 | // ignore messages from the bot itself
20 | if (message.author.bot) {
21 | return;
22 | }
23 | // form the payload
24 | const payload = {
25 | inputs: {
26 | text: message.content
27 | }
28 | };
29 | // form the request headers with Hugging Face API key
30 | const headers = {
31 | 'Authorization': 'Bearer ' + process.env.HUGGINGFACE_TOKEN
32 | };
33 |
34 | // set status to typing
35 | message.channel.startTyping();
36 | // query the server
37 | const response = await fetch(API_URL, {
38 | method: 'post',
39 | body: JSON.stringify(payload),
40 | headers: headers
41 | });
42 | const data = await response.json();
43 | let botResponse = '';
44 | if (data.hasOwnProperty('generated_text')) {
45 | botResponse = data.generated_text;
46 | } else if (data.hasOwnProperty('error')) { // error condition
47 | botResponse = data.error;
48 | }
49 | // stop typing
50 | message.channel.stopTyping();
51 | // send message to channel as a reply
52 | message.reply(botResponse);
53 | })
54 |
55 | client.login(process.env.DISCORD_TOKEN);
56 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Build a Discord AI Chatbot that Speaks like Your Favorite Character!
2 |
3 |
4 |

5 |
6 |
7 | This is a Discord AI Chatbot that uses the [Microsoft DialoGPT conversational model](https://huggingface.co/microsoft/DialoGPT-medium) fine-tuned on the game transcript of [The World Ends With You](https://en.wikipedia.org/wiki/The_World_Ends_with_You) (TWEWY). Read [my tutorial on freeCodeCamp](https://www.freecodecamp.org/news/discord-ai-chatbot/) or watch [my video tutorial on YouTube](https://youtu.be/UBwvFuTC1ZE). I've also made [a JavaScript version of the tutorial using Discord.js](https://youtu.be/XR6JFRLxe5A).
8 |
9 | I trained the model using the lines of my favorite quirky character, Joshua (left in the image below). He has about 700 lines in total in the entire game.
10 |
11 | 
12 |
13 | Here is a demo of the Discord bot in action.
14 |
15 | 
16 |
17 | You can also directly chat with the model hosted on [Hugging Face's Model Hub](https://huggingface.co/r3dhummingbird/DialoGPT-medium-joshua).
18 |
19 | 
20 |
21 | ## Structure of this Project
22 |
23 | - `model_train_upload_workflow.ipynb`: Notebook to be run in Google Colab to train and upload the model to Hugging Face's Model Hub
24 | - `discord_bot.py`: Script to be imported into a Repl.it Python Discord.py project
25 | - `discord_bot.js`: Script to be imported into a Repl.it JavaScript Discord.js project
26 |
27 | ## Resource Links
28 |
29 | - [15-min chat demo](https://youtu.be/-n6uWu8PZzo)
30 | - [My tutorial on freeCodeCamp](https://www.freecodecamp.org/news/discord-ai-chatbot/)
31 | - [My video tutorial on YouTube](https://youtu.be/UBwvFuTC1ZE)
32 | - [My JavaScript version of this tutorial on YouTube](https://youtu.be/XR6JFRLxe5A)
33 | - [My TWEWY dataset on Kaggle](https://www.kaggle.com/ruolinzheng/twewy-game-script)
34 | - [My Hugging Face Model](https://huggingface.co/r3dhummingbird/DialoGPT-medium-joshua)
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/discord_bot.py:
--------------------------------------------------------------------------------
1 | # the os module helps us access environment variables
2 | # i.e., our API keys
3 | import os
4 |
5 | # these modules are for querying the Hugging Face model
6 | import json
7 | import requests
8 |
9 | # the Discord Python API
10 | import discord
11 |
12 | # this is my Hugging Face profile link
13 | API_URL = 'https://api-inference.huggingface.co/models/r3dhummingbird/'
14 |
15 | class MyClient(discord.Client):
16 | def __init__(self, model_name):
17 | super().__init__()
18 | self.api_endpoint = API_URL + model_name
19 | # retrieve the secret API token from the system environment
20 | huggingface_token = os.environ['HUGGINGFACE_TOKEN']
21 | # format the header in our request to Hugging Face
22 | self.request_headers = {
23 | 'Authorization': 'Bearer {}'.format(huggingface_token)
24 | }
25 |
26 | def query(self, payload):
27 | """
28 | make request to the Hugging Face model API
29 | """
30 | data = json.dumps(payload)
31 | response = requests.request('POST',
32 | self.api_endpoint,
33 | headers=self.request_headers,
34 | data=data)
35 | ret = json.loads(response.content.decode('utf-8'))
36 | return ret
37 |
38 | async def on_ready(self):
39 | # print out information when the bot wakes up
40 | print('Logged in as')
41 | print(self.user.name)
42 | print(self.user.id)
43 | print('------')
44 | # send a request to the model without caring about the response
45 | # just so that the model wakes up and starts loading
46 | self.query({'inputs': {'text': 'Hello!'}})
47 |
48 | async def on_message(self, message):
49 | """
50 | this function is called whenever the bot sees a message in a channel
51 | """
52 | # ignore the message if it comes from the bot itself
53 | if message.author.id == self.user.id:
54 | return
55 |
56 | # form query payload with the content of the message
57 | payload = {'inputs': {'text': message.content}}
58 |
59 | # while the bot is waiting on a response from the model
60 | # set the its status as typing for user-friendliness
61 | async with message.channel.typing():
62 | response = self.query(payload)
63 | bot_response = response.get('generated_text', None)
64 |
65 | # we may get ill-formed response if the model hasn't fully loaded
66 | # or has timed out
67 | if not bot_response:
68 | if 'error' in response:
69 | bot_response = '`Error: {}`'.format(response['error'])
70 | else:
71 | bot_response = 'Hmm... something is not right.'
72 |
73 | # send the model's response to the Discord channel
74 | await message.channel.send(bot_response)
75 |
76 | def main():
77 | # DialoGPT-medium-joshua is my model name
78 | client = MyClient('DialoGPT-medium-joshua')
79 | client.run(os.environ['DISCORD_TOKEN'])
80 |
81 | if __name__ == '__main__':
82 | main()
--------------------------------------------------------------------------------
/model_train_upload_workflow.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "VTze-VbeU1c0"
7 | },
8 | "source": [
9 | "# Fine-tune a DialoGPT model\n",
10 | "\n",
11 | "Adapted from the notebook in [this Medium post](https://towardsdatascience.com/make-your-own-rick-sanchez-bot-with-transformers-and-dialogpt-fine-tuning-f85e6d1f4e30?gi=e4a72d1510f0)."
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {
17 | "id": "Y17kuzFNUSrZ"
18 | },
19 | "source": [
20 | "## Setup"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "metadata": {
27 | "colab": {
28 | "base_uri": "https://localhost:8080/"
29 | },
30 | "id": "GBfltjGHT6KG",
31 | "outputId": "7822e15b-9c77-412a-a6ed-20100243db13"
32 | },
33 | "outputs": [],
34 | "source": [
35 | "from google.colab import drive\n",
36 | "drive.mount('/content/drive/')"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {
43 | "id": "T8fgmjaqUErq"
44 | },
45 | "outputs": [],
46 | "source": [
47 | "!pip -q install transformers"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {
54 | "id": "EtCreyG8UG1s"
55 | },
56 | "outputs": [],
57 | "source": [
58 | "import os\n",
59 | "os.chdir(\"/content/drive/My Drive/Colab Notebooks\")"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {
66 | "id": "dnv5kT-mLsB-"
67 | },
68 | "outputs": [],
69 | "source": [
70 | "# all the imports\n",
71 | "\n",
72 | "import glob\n",
73 | "import logging\n",
74 | "import os\n",
75 | "import pickle\n",
76 | "import random\n",
77 | "import re\n",
78 | "import shutil\n",
79 | "from typing import Dict, List, Tuple\n",
80 | "\n",
81 | "import numpy as np\n",
82 | "import pandas as pd\n",
83 | "\n",
84 | "from sklearn.model_selection import train_test_split\n",
85 | "\n",
86 | "from torch.nn.utils.rnn import pad_sequence\n",
87 | "from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler\n",
88 | "from torch.utils.data.distributed import DistributedSampler\n",
89 | "from tqdm.notebook import tqdm, trange\n",
90 | "\n",
91 | "from pathlib import Path\n",
92 | "\n",
93 | "from transformers import (\n",
94 | " MODEL_WITH_LM_HEAD_MAPPING,\n",
95 | " WEIGHTS_NAME,\n",
96 | " AdamW,\n",
97 | " AutoConfig,\n",
98 | " PreTrainedModel,\n",
99 | " PreTrainedTokenizer,\n",
100 | " get_linear_schedule_with_warmup,\n",
101 | ")\n",
102 | "\n",
103 | "\n",
104 | "try:\n",
105 | " from torch.utils.tensorboard import SummaryWriter\n",
106 | "except ImportError:\n",
107 | " from tensorboardX import SummaryWriter"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {
113 | "id": "BmrbGB8aUmBm"
114 | },
115 | "source": [
116 | "## Get Data from Kaggle"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {
123 | "colab": {
124 | "base_uri": "https://localhost:8080/"
125 | },
126 | "id": "ftBYBoOoV_Er",
127 | "outputId": "07da0a13-6112-4c4e-cb49-51580c2d9e7a"
128 | },
129 | "outputs": [],
130 | "source": [
131 | "!mkdir ~/.kaggle\n",
132 | "!cp kaggle.json ~/.kaggle/kaggle.json"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "metadata": {
139 | "colab": {
140 | "base_uri": "https://localhost:8080/"
141 | },
142 | "id": "fbITTMcLVbI_",
143 | "outputId": "fb4c8bf1-ff2d-4952-a451-62cdd0655aea"
144 | },
145 | "outputs": [],
146 | "source": [
147 | "!kaggle datasets download ruolinzheng/twewy-game-script -f twewy-name-line-full.csv"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "metadata": {
154 | "id": "RXdJTSVwWGHj"
155 | },
156 | "outputs": [],
157 | "source": [
158 | "data = pd.read_csv('twewy-name-line-full.csv')"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "metadata": {
165 | "colab": {
166 | "base_uri": "https://localhost:8080/",
167 | "height": 238
168 | },
169 | "id": "h6kGx-9eG7qA",
170 | "outputId": "bd2efe43-1e50-4716-81a2-bf15a3dd03bd"
171 | },
172 | "outputs": [],
173 | "source": [
174 | "data.sample(6)"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": null,
180 | "metadata": {
181 | "id": "PG8v6--qWUwj"
182 | },
183 | "outputs": [],
184 | "source": [
185 | "CHARACTER_NAME = 'Joshua'"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": null,
191 | "metadata": {
192 | "id": "GZUcEMd2WLDT"
193 | },
194 | "outputs": [],
195 | "source": [
196 | "contexted = []\n",
197 | "\n",
198 | "# context window of size 7\n",
199 | "n = 7\n",
200 | "\n",
201 | "for i in data[data.name == CHARACTER_NAME].index:\n",
202 | " if i < n:\n",
203 | " continue\n",
204 | " row = []\n",
205 | " prev = i - 1 - n # we additionally subtract 1, so row will contain current response and 7 previous responses \n",
206 | " for j in range(i, prev, -1):\n",
207 | " row.append(data.line[j])\n",
208 | " contexted.append(row)\n",
209 | "\n",
210 | "columns = ['response', 'context'] \n",
211 | "columns = columns + ['context/' + str(i) for i in range(n - 1)]\n",
212 | "\n",
213 | "df = pd.DataFrame.from_records(contexted, columns=columns)"
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": null,
219 | "metadata": {
220 | "colab": {
221 | "base_uri": "https://localhost:8080/",
222 | "height": 446
223 | },
224 | "id": "4T5OlNZHUxij",
225 | "outputId": "895603a6-ca02-4301-c4b0-5bccbee8a3b8"
226 | },
227 | "outputs": [],
228 | "source": [
229 | "df.sample(6)"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "metadata": {
236 | "colab": {
237 | "base_uri": "https://localhost:8080/",
238 | "height": 380
239 | },
240 | "id": "NGy0MxMQVIAP",
241 | "outputId": "08b7f0eb-6a38-4b83-efdc-e53778d7547a"
242 | },
243 | "outputs": [],
244 | "source": [
245 | "trn_df, val_df = train_test_split(df, test_size=0.1)\n",
246 | "trn_df.head()"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": null,
252 | "metadata": {
253 | "id": "aEeJQlAKWtiJ"
254 | },
255 | "outputs": [],
256 | "source": [
257 | "# create dataset suitable for our model\n",
258 | "def construct_conv(row, tokenizer, eos = True):\n",
259 | " flatten = lambda l: [item for sublist in l for item in sublist]\n",
260 | " conv = list(reversed([tokenizer.encode(x) + [tokenizer.eos_token_id] for x in row]))\n",
261 | " conv = flatten(conv)\n",
262 | " return conv\n",
263 | "\n",
264 | "class ConversationDataset(Dataset):\n",
265 | " def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):\n",
266 | "\n",
267 | " block_size = block_size - (tokenizer.model_max_length - tokenizer.max_len_single_sentence)\n",
268 | "\n",
269 | " directory = args.cache_dir\n",
270 | " cached_features_file = os.path.join(\n",
271 | " directory, args.model_type + \"_cached_lm_\" + str(block_size)\n",
272 | " )\n",
273 | "\n",
274 | " if os.path.exists(cached_features_file) and not args.overwrite_cache:\n",
275 | " logger.info(\"Loading features from cached file %s\", cached_features_file)\n",
276 | " with open(cached_features_file, \"rb\") as handle:\n",
277 | " self.examples = pickle.load(handle)\n",
278 | " else:\n",
279 | " logger.info(\"Creating features from dataset file at %s\", directory)\n",
280 | "\n",
281 | " self.examples = []\n",
282 | " for _, row in df.iterrows():\n",
283 | " conv = construct_conv(row, tokenizer)\n",
284 | " self.examples.append(conv)\n",
285 | "\n",
286 | " logger.info(\"Saving features into cached file %s\", cached_features_file)\n",
287 | " with open(cached_features_file, \"wb\") as handle:\n",
288 | " pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
289 | "\n",
290 | " def __len__(self):\n",
291 | " return len(self.examples)\n",
292 | "\n",
293 | " def __getitem__(self, item):\n",
294 | " return torch.tensor(self.examples[item], dtype=torch.long)"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": null,
300 | "metadata": {
301 | "id": "-3iHwoKlWyrs"
302 | },
303 | "outputs": [],
304 | "source": [
305 | "# Caching and storing of data/checkpoints\n",
306 | "\n",
307 | "def load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False):\n",
308 | " return ConversationDataset(tokenizer, args, df_val if evaluate else df_trn)\n",
309 | "\n",
310 | "\n",
311 | "def set_seed(args):\n",
312 | " random.seed(args.seed)\n",
313 | " np.random.seed(args.seed)\n",
314 | " torch.manual_seed(args.seed)\n",
315 | " if args.n_gpu > 0:\n",
316 | " torch.cuda.manual_seed_all(args.seed)\n",
317 | "\n",
318 | "\n",
319 | "def _sorted_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> List[str]:\n",
320 | " ordering_and_checkpoint_path = []\n",
321 | "\n",
322 | " glob_checkpoints = glob.glob(os.path.join(args.output_dir, \"{}-*\".format(checkpoint_prefix)))\n",
323 | "\n",
324 | " for path in glob_checkpoints:\n",
325 | " if use_mtime:\n",
326 | " ordering_and_checkpoint_path.append((os.path.getmtime(path), path))\n",
327 | " else:\n",
328 | " regex_match = re.match(\".*{}-([0-9]+)\".format(checkpoint_prefix), path)\n",
329 | " if regex_match and regex_match.groups():\n",
330 | " ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))\n",
331 | "\n",
332 | " checkpoints_sorted = sorted(ordering_and_checkpoint_path)\n",
333 | " checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]\n",
334 | " return checkpoints_sorted\n",
335 | "\n",
336 | "\n",
337 | "def _rotate_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> None:\n",
338 | " if not args.save_total_limit:\n",
339 | " return\n",
340 | " if args.save_total_limit <= 0:\n",
341 | " return\n",
342 | "\n",
343 | " # Check if we should delete older checkpoint(s)\n",
344 | " checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)\n",
345 | " if len(checkpoints_sorted) <= args.save_total_limit:\n",
346 | " return\n",
347 | "\n",
348 | " number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)\n",
349 | " checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]\n",
350 | " for checkpoint in checkpoints_to_be_deleted:\n",
351 | " logger.info(\"Deleting older checkpoint [{}] due to args.save_total_limit\".format(checkpoint))\n",
352 | " shutil.rmtree(checkpoint)"
353 | ]
354 | },
355 | {
356 | "cell_type": "markdown",
357 | "metadata": {
358 | "id": "EEDdTJTqUwZJ"
359 | },
360 | "source": [
361 | "## Build Model"
362 | ]
363 | },
364 | {
365 | "cell_type": "code",
366 | "execution_count": null,
367 | "metadata": {
368 | "colab": {
369 | "base_uri": "https://localhost:8080/"
370 | },
371 | "id": "r2cE0fY5UHpz",
372 | "outputId": "e4f382cd-57d9-49b7-9da4-4b44fe57df5b"
373 | },
374 | "outputs": [],
375 | "source": [
376 | "from transformers import AutoModelWithLMHead, AutoModelForCausalLM, AutoTokenizer\n",
377 | "import torch\n",
378 | "\n",
379 | "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-small\")\n",
380 | "model = AutoModelWithLMHead.from_pretrained(\"microsoft/DialoGPT-small\")"
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "execution_count": null,
386 | "metadata": {
387 | "id": "ra2vsRp-UMXo"
388 | },
389 | "outputs": [],
390 | "source": [
391 | "\"\"\"\n",
392 | "Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).\n",
393 | "GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned\n",
394 | "using a masked language modeling (MLM) loss.\n",
395 | "\"\"\"\n",
396 | "\n",
397 | "# Configs\n",
398 | "logger = logging.getLogger(__name__)\n",
399 | "\n",
400 | "MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())\n",
401 | "MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)"
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "execution_count": null,
407 | "metadata": {
408 | "id": "2OnASqJjUNJa"
409 | },
410 | "outputs": [],
411 | "source": [
412 | "# Args to allow for easy conversion of python script to notebook\n",
413 | "class Args():\n",
414 | " def __init__(self):\n",
415 | " self.output_dir = 'output-small'\n",
416 | " self.model_type = 'gpt2'\n",
417 | " self.model_name_or_path = 'microsoft/DialoGPT-small'\n",
418 | " self.config_name = 'microsoft/DialoGPT-small'\n",
419 | " self.tokenizer_name = 'microsoft/DialoGPT-small'\n",
420 | " self.cache_dir = 'cached'\n",
421 | " self.block_size = 512\n",
422 | " self.do_train = True\n",
423 | " self.do_eval = True\n",
424 | " self.evaluate_during_training = False\n",
425 | " self.per_gpu_train_batch_size = 4\n",
426 | " self.per_gpu_eval_batch_size = 4\n",
427 | " self.gradient_accumulation_steps = 1\n",
428 | " self.learning_rate = 5e-5\n",
429 | " self.weight_decay = 0.0\n",
430 | " self.adam_epsilon = 1e-8\n",
431 | " self.max_grad_norm = 1.0\n",
432 | " self.num_train_epochs = 4\n",
433 | " self.max_steps = -1\n",
434 | " self.warmup_steps = 0\n",
435 | " self.logging_steps = 1000\n",
436 | " self.save_steps = 3500\n",
437 | " self.save_total_limit = None\n",
438 | " self.eval_all_checkpoints = False\n",
439 | " self.no_cuda = False\n",
440 | " self.overwrite_output_dir = True\n",
441 | " self.overwrite_cache = True\n",
442 | " self.should_continue = False\n",
443 | " self.seed = 42\n",
444 | " self.local_rank = -1\n",
445 | " self.fp16 = False\n",
446 | " self.fp16_opt_level = 'O1'\n",
447 | "\n",
448 | "args = Args()"
449 | ]
450 | },
451 | {
452 | "cell_type": "markdown",
453 | "metadata": {
454 | "id": "9Q1dTFXxW9NE"
455 | },
456 | "source": [
457 | "## Train and Evaluate"
458 | ]
459 | },
460 | {
461 | "cell_type": "code",
462 | "execution_count": null,
463 | "metadata": {
464 | "id": "PaarIDZrW81h"
465 | },
466 | "outputs": [],
467 | "source": [
468 | "def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:\n",
469 | " \"\"\" Train the model \"\"\"\n",
470 | " if args.local_rank in [-1, 0]:\n",
471 | " tb_writer = SummaryWriter()\n",
472 | "\n",
473 | " args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)\n",
474 | "\n",
475 | " def collate(examples: List[torch.Tensor]):\n",
476 | " if tokenizer._pad_token is None:\n",
477 | " return pad_sequence(examples, batch_first=True)\n",
478 | " return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
479 | "\n",
480 | " train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)\n",
481 | " train_dataloader = DataLoader(\n",
482 | " train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate, drop_last = True\n",
483 | " )\n",
484 | "\n",
485 | " if args.max_steps > 0:\n",
486 | " t_total = args.max_steps\n",
487 | " args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1\n",
488 | " else:\n",
489 | " t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs\n",
490 | "\n",
491 | " model = model.module if hasattr(model, \"module\") else model # Take care of distributed/parallel training\n",
492 | " model.resize_token_embeddings(len(tokenizer))\n",
493 | " # add_special_tokens_(model, tokenizer)\n",
494 | "\n",
495 | "\n",
496 | " # Prepare optimizer and schedule (linear warmup and decay)\n",
497 | " no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
498 | " optimizer_grouped_parameters = [\n",
499 | " {\n",
500 | " \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
501 | " \"weight_decay\": args.weight_decay,\n",
502 | " },\n",
503 | " {\"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], \"weight_decay\": 0.0},\n",
504 | " ]\n",
505 | " optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)\n",
506 | " scheduler = get_linear_schedule_with_warmup(\n",
507 | " optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total\n",
508 | " )\n",
509 | "\n",
510 | " # Check if saved optimizer or scheduler states exist\n",
511 | " if (\n",
512 | " args.model_name_or_path\n",
513 | " and os.path.isfile(os.path.join(args.model_name_or_path, \"optimizer.pt\"))\n",
514 | " and os.path.isfile(os.path.join(args.model_name_or_path, \"scheduler.pt\"))\n",
515 | " ):\n",
516 | " # Load in optimizer and scheduler states\n",
517 | " optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"optimizer.pt\")))\n",
518 | " scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"scheduler.pt\")))\n",
519 | "\n",
520 | " if args.fp16:\n",
521 | " try:\n",
522 | " from apex import amp\n",
523 | " except ImportError:\n",
524 | " raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.\")\n",
525 | " model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)\n",
526 | "\n",
527 | " # multi-gpu training (should be after apex fp16 initialization)\n",
528 | " if args.n_gpu > 1:\n",
529 | " model = torch.nn.DataParallel(model)\n",
530 | "\n",
531 | " # Distributed training (should be after apex fp16 initialization)\n",
532 | " if args.local_rank != -1:\n",
533 | " model = torch.nn.parallel.DistributedDataParallel(\n",
534 | " model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True\n",
535 | " )\n",
536 | "\n",
537 | " # Train!\n",
538 | " logger.info(\"***** Running training *****\")\n",
539 | " logger.info(\" Num examples = %d\", len(train_dataset))\n",
540 | " logger.info(\" Num Epochs = %d\", args.num_train_epochs)\n",
541 | " logger.info(\" Instantaneous batch size per GPU = %d\", args.per_gpu_train_batch_size)\n",
542 | " logger.info(\n",
543 | " \" Total train batch size (w. parallel, distributed & accumulation) = %d\",\n",
544 | " args.train_batch_size\n",
545 | " * args.gradient_accumulation_steps\n",
546 | " * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),\n",
547 | " )\n",
548 | " logger.info(\" Gradient Accumulation steps = %d\", args.gradient_accumulation_steps)\n",
549 | " logger.info(\" Total optimization steps = %d\", t_total)\n",
550 | "\n",
551 | " global_step = 0\n",
552 | " epochs_trained = 0\n",
553 | " steps_trained_in_current_epoch = 0\n",
554 | " # Check if continuing training from a checkpoint\n",
555 | " if args.model_name_or_path and os.path.exists(args.model_name_or_path):\n",
556 | " try:\n",
557 | " # set global_step to global_step of last saved checkpoint from model path\n",
558 | " checkpoint_suffix = args.model_name_or_path.split(\"-\")[-1].split(\"/\")[0]\n",
559 | " global_step = int(checkpoint_suffix)\n",
560 | " epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)\n",
561 | " steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)\n",
562 | "\n",
563 | " logger.info(\" Continuing training from checkpoint, will skip to saved global_step\")\n",
564 | " logger.info(\" Continuing training from epoch %d\", epochs_trained)\n",
565 | " logger.info(\" Continuing training from global step %d\", global_step)\n",
566 | " logger.info(\" Will skip the first %d steps in the first epoch\", steps_trained_in_current_epoch)\n",
567 | " except ValueError:\n",
568 | " logger.info(\" Starting fine-tuning.\")\n",
569 | "\n",
570 | " tr_loss, logging_loss = 0.0, 0.0\n",
571 | "\n",
572 | " model.zero_grad()\n",
573 | " train_iterator = trange(\n",
574 | " epochs_trained, int(args.num_train_epochs), desc=\"Epoch\", disable=args.local_rank not in [-1, 0]\n",
575 | " )\n",
576 | " set_seed(args) # Added here for reproducibility\n",
577 | " for _ in train_iterator:\n",
578 | " epoch_iterator = tqdm(train_dataloader, desc=\"Iteration\", disable=args.local_rank not in [-1, 0])\n",
579 | " for step, batch in enumerate(epoch_iterator):\n",
580 | "\n",
581 | " # Skip past any already trained steps if resuming training\n",
582 | " if steps_trained_in_current_epoch > 0:\n",
583 | " steps_trained_in_current_epoch -= 1\n",
584 | " continue\n",
585 | "\n",
586 | " inputs, labels = (batch, batch)\n",
587 | " if inputs.shape[1] > 1024: continue\n",
588 | " inputs = inputs.to(args.device)\n",
589 | " labels = labels.to(args.device)\n",
590 | " model.train()\n",
591 | " outputs = model(inputs, labels=labels)\n",
592 | " loss = outputs[0] # model outputs are always tuple in transformers (see doc)\n",
593 | "\n",
594 | " if args.n_gpu > 1:\n",
595 | " loss = loss.mean() # mean() to average on multi-gpu parallel training\n",
596 | " if args.gradient_accumulation_steps > 1:\n",
597 | " loss = loss / args.gradient_accumulation_steps\n",
598 | "\n",
599 | " if args.fp16:\n",
600 | " with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
601 | " scaled_loss.backward()\n",
602 | " else:\n",
603 | " loss.backward()\n",
604 | "\n",
605 | " tr_loss += loss.item()\n",
606 | " if (step + 1) % args.gradient_accumulation_steps == 0:\n",
607 | " if args.fp16:\n",
608 | " torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)\n",
609 | " else:\n",
610 | " torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n",
611 | " optimizer.step()\n",
612 | " scheduler.step() # Update learning rate schedule\n",
613 | " model.zero_grad()\n",
614 | " global_step += 1\n",
615 | "\n",
616 | " if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:\n",
617 | " # Log metrics\n",
618 | " if (\n",
619 | " args.local_rank == -1 and args.evaluate_during_training\n",
620 | " ): # Only evaluate when single GPU otherwise metrics may not average well\n",
621 | " results = evaluate(args, model, tokenizer)\n",
622 | " for key, value in results.items():\n",
623 | " tb_writer.add_scalar(\"eval_{}\".format(key), value, global_step)\n",
624 | " tb_writer.add_scalar(\"lr\", scheduler.get_lr()[0], global_step)\n",
625 | " tb_writer.add_scalar(\"loss\", (tr_loss - logging_loss) / args.logging_steps, global_step)\n",
626 | " logging_loss = tr_loss\n",
627 | "\n",
628 | " if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:\n",
629 | " checkpoint_prefix = \"checkpoint\"\n",
630 | " # Save model checkpoint\n",
631 | " output_dir = os.path.join(args.output_dir, \"{}-{}\".format(checkpoint_prefix, global_step))\n",
632 | " os.makedirs(output_dir, exist_ok=True)\n",
633 | " model_to_save = (\n",
634 | " model.module if hasattr(model, \"module\") else model\n",
635 | " ) # Take care of distributed/parallel training\n",
636 | " model_to_save.save_pretrained(output_dir)\n",
637 | " tokenizer.save_pretrained(output_dir)\n",
638 | "\n",
639 | " torch.save(args, os.path.join(output_dir, \"training_args.bin\"))\n",
640 | " logger.info(\"Saving model checkpoint to %s\", output_dir)\n",
641 | "\n",
642 | " _rotate_checkpoints(args, checkpoint_prefix)\n",
643 | "\n",
644 | " torch.save(optimizer.state_dict(), os.path.join(output_dir, \"optimizer.pt\"))\n",
645 | " torch.save(scheduler.state_dict(), os.path.join(output_dir, \"scheduler.pt\"))\n",
646 | " logger.info(\"Saving optimizer and scheduler states to %s\", output_dir)\n",
647 | "\n",
648 | " if args.max_steps > 0 and global_step > args.max_steps:\n",
649 | " epoch_iterator.close()\n",
650 | " break\n",
651 | " if args.max_steps > 0 and global_step > args.max_steps:\n",
652 | " train_iterator.close()\n",
653 | " break\n",
654 | "\n",
655 | " if args.local_rank in [-1, 0]:\n",
656 | " tb_writer.close()\n",
657 | "\n",
658 | " return global_step, tr_loss / global_step\n",
659 | "\n",
660 | "# Evaluate a model: compute its perplexity on the evaluation dataset\n",
661 | "\n",
662 | "def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, df_trn, df_val, prefix=\"\") -> Dict:\n",
663 | " # Evaluation results are written under args.output_dir\n",
664 | " eval_output_dir = args.output_dir\n",
665 | "\n",
666 | " eval_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=True)\n",
667 | " os.makedirs(eval_output_dir, exist_ok=True)\n",
668 | " args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)\n",
669 | " # Note that DistributedSampler samples randomly\n",
670 | "\n",
671 | " def collate(examples: List[torch.Tensor]):\n",
672 | " if tokenizer._pad_token is None:\n",
673 | " return pad_sequence(examples, batch_first=True)\n",
674 | " return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
675 | "\n",
676 | " eval_sampler = SequentialSampler(eval_dataset)\n",
677 | " eval_dataloader = DataLoader(\n",
678 | " eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate, drop_last = True\n",
679 | " )\n",
680 | "\n",
681 | " # multi-gpu evaluate\n",
682 | " if args.n_gpu > 1:\n",
683 | " model = torch.nn.DataParallel(model)\n",
684 | "\n",
685 | " # Eval!\n",
686 | " logger.info(\"***** Running evaluation {} *****\".format(prefix))\n",
687 | " logger.info(\" Num examples = %d\", len(eval_dataset))\n",
688 | " logger.info(\" Batch size = %d\", args.eval_batch_size)\n",
689 | " eval_loss = 0.0\n",
690 | " nb_eval_steps = 0\n",
691 | " model.eval()\n",
692 | "\n",
693 | " for batch in tqdm(eval_dataloader, desc=\"Evaluating\"):\n",
694 | " inputs, labels = (batch, batch)\n",
695 | " inputs = inputs.to(args.device)\n",
696 | " labels = labels.to(args.device)\n",
697 | "\n",
698 | " with torch.no_grad():\n",
699 | " outputs = model(inputs, labels=labels)\n",
700 | " lm_loss = outputs[0]\n",
701 | " eval_loss += lm_loss.mean().item()\n",
702 | " nb_eval_steps += 1\n",
703 | "\n",
704 | " eval_loss = eval_loss / nb_eval_steps\n",
705 | " perplexity = torch.exp(torch.tensor(eval_loss))\n",
706 | "\n",
707 | " result = {\"perplexity\": perplexity}\n",
708 | "\n",
709 | " output_eval_file = os.path.join(eval_output_dir, prefix, \"eval_results.txt\")\n",
710 | " with open(output_eval_file, \"w\") as writer:\n",
711 | " logger.info(\"***** Eval results {} *****\".format(prefix))\n",
712 | " for key in sorted(result.keys()):\n",
713 | " logger.info(\" %s = %s\", key, str(result[key]))\n",
714 | " writer.write(\"%s = %s\\n\" % (key, str(result[key])))\n",
715 | "\n",
716 | " return result"
717 | ]
718 | },
719 | {
720 | "cell_type": "code",
721 | "execution_count": null,
722 | "metadata": {
723 | "id": "SCnGAJWbXD9C"
724 | },
725 | "outputs": [],
726 | "source": [
727 | "# Main runner\n",
728 | "\n",
729 | "def main(df_trn, df_val):\n",
730 | " args = Args()\n",
731 | " \n",
732 | " if args.should_continue:\n",
733 | " sorted_checkpoints = _sorted_checkpoints(args)\n",
734 | " if len(sorted_checkpoints) == 0:\n",
735 | " raise ValueError(\"Used --should_continue but no checkpoint was found in --output_dir.\")\n",
736 | " else:\n",
737 | " args.model_name_or_path = sorted_checkpoints[-1]\n",
738 | "\n",
739 | " if (\n",
740 | " os.path.exists(args.output_dir)\n",
741 | " and os.listdir(args.output_dir)\n",
742 | " and args.do_train\n",
743 | " and not args.overwrite_output_dir\n",
744 | " and not args.should_continue\n",
745 | " ):\n",
746 | " raise ValueError(\n",
747 | " \"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.\".format(\n",
748 | " args.output_dir\n",
749 | " )\n",
750 | " )\n",
751 | "\n",
752 | " # Setup CUDA, GPU & distributed training\n",
753 | " device = torch.device(\"cuda\")\n",
754 | " args.n_gpu = torch.cuda.device_count()\n",
755 | " args.device = device\n",
756 | "\n",
757 | " # Setup logging\n",
758 | " logging.basicConfig(\n",
759 | " format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
760 | " datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
761 | " level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,\n",
762 | " )\n",
763 | " logger.warning(\n",
764 | " \"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s\",\n",
765 | " args.local_rank,\n",
766 | " device,\n",
767 | " args.n_gpu,\n",
768 | " bool(args.local_rank != -1),\n",
769 | " args.fp16,\n",
770 | " )\n",
771 | "\n",
772 | " # Set seed\n",
773 | " set_seed(args)\n",
774 | "\n",
775 | " config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)\n",
776 | " tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)\n",
777 | " model = AutoModelWithLMHead.from_pretrained(\n",
778 | " args.model_name_or_path,\n",
779 | " from_tf=False,\n",
780 | " config=config,\n",
781 | " cache_dir=args.cache_dir,\n",
782 | " )\n",
783 | " model.to(args.device)\n",
784 | " \n",
785 | " logger.info(\"Training/evaluation parameters %s\", args)\n",
786 | "\n",
787 | " # Training\n",
788 | " if args.do_train:\n",
789 | " train_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False)\n",
790 | "\n",
791 | " global_step, tr_loss = train(args, train_dataset, model, tokenizer)\n",
792 | " logger.info(\" global_step = %s, average loss = %s\", global_step, tr_loss)\n",
793 | "\n",
794 | " # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()\n",
795 | " if args.do_train:\n",
796 | " # Create output directory if needed\n",
797 | " os.makedirs(args.output_dir, exist_ok=True)\n",
798 | "\n",
799 | " logger.info(\"Saving model checkpoint to %s\", args.output_dir)\n",
800 | " # Save a trained model, configuration and tokenizer using `save_pretrained()`.\n",
801 | " # They can then be reloaded using `from_pretrained()`\n",
802 | " model_to_save = (\n",
803 | " model.module if hasattr(model, \"module\") else model\n",
804 | " ) # Take care of distributed/parallel training\n",
805 | " model_to_save.save_pretrained(args.output_dir)\n",
806 | " tokenizer.save_pretrained(args.output_dir)\n",
807 | "\n",
808 | " # Good practice: save your training arguments together with the trained model\n",
809 | " torch.save(args, os.path.join(args.output_dir, \"training_args.bin\"))\n",
810 | "\n",
811 | " # Load a trained model and vocabulary that you have fine-tuned\n",
812 | " model = AutoModelWithLMHead.from_pretrained(args.output_dir)\n",
813 | " tokenizer = AutoTokenizer.from_pretrained(args.output_dir)\n",
814 | " model.to(args.device)\n",
815 | "\n",
816 | " # Evaluation\n",
817 | " results = {}\n",
818 | " if args.do_eval and args.local_rank in [-1, 0]:\n",
819 | " checkpoints = [args.output_dir]\n",
820 | " if args.eval_all_checkpoints:\n",
821 | " checkpoints = list(\n",
822 | " os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + \"/**/\" + WEIGHTS_NAME, recursive=True))\n",
823 | " )\n",
824 | " logging.getLogger(\"transformers.modeling_utils\").setLevel(logging.WARN) # Reduce logging\n",
825 | " logger.info(\"Evaluate the following checkpoints: %s\", checkpoints)\n",
826 | " for checkpoint in checkpoints:\n",
827 | " global_step = checkpoint.split(\"-\")[-1] if len(checkpoints) > 1 else \"\"\n",
828 | " prefix = checkpoint.split(\"/\")[-1] if checkpoint.find(\"checkpoint\") != -1 else \"\"\n",
829 | "\n",
830 | " model = AutoModelWithLMHead.from_pretrained(checkpoint)\n",
831 | " model.to(args.device)\n",
832 | " result = evaluate(args, model, tokenizer, df_trn, df_val, prefix=prefix)\n",
833 | " result = dict((k + \"_{}\".format(global_step), v) for k, v in result.items())\n",
834 | " results.update(result)\n",
835 | "\n",
836 | " return results"
837 | ]
838 | },
839 | {
840 | "cell_type": "markdown",
841 | "metadata": {
842 | "id": "7NWvkdR-XHeB"
843 | },
844 | "source": [
845 | "## Run the Main Function"
846 | ]
847 | },
848 | {
849 | "cell_type": "code",
850 | "execution_count": null,
851 | "metadata": {
852 | "colab": {
853 | "base_uri": "https://localhost:8080/",
854 | "height": 780,
855 | "referenced_widgets": [
856 | "1d7f4c82687540f1ad69eb54ac3c25b4",
857 | "e7b9f3fc77a24259a87ef0dc735dfecb",
858 | "f3bf54733c2d4d9daa1cc9a7746ccb14",
859 | "aa40eb6346b54e7dac98e0b068cd4927",
860 | "021b771a270f479aa3b9e2b5f17e3d97",
861 | "450b0e7fd7a347c7beb78b7d72f64385",
862 | "9391d7abf6ed4400903995f56d7a1260",
863 | "ea6b919964d24c2f9de1c64c9cefaf23",
864 | "2fa1fa2407384cb98d79a912de2d5b8f",
865 | "dc27e2caf1ea4a4ab9ae3708fb06952f",
866 | "e38fb98fd7b3413392dc39c93a107a35",
867 | "855ca0a6125a4d698416214a9425ad98",
868 | "4699416338ae40a5b6abf19e45089aec",
869 | "43fdb31d3f314624ba07a15718b0c8f3",
870 | "de252cd193114c40ad5f5e9622b7abc7",
871 | "5e48b617cc3f41c3945efc28fc5e0c75",
872 | "68a9dc52819c48fb97259f318f9b5c6a",
873 | "b4e00059cf3a49929978ed780aae8358",
874 | "0ff5f4e3506b493a98d72008a467f35f",
875 | "77b97fa3271b48ac9f93665a102b4fd1",
876 | "a937f1dfeee5432ba31b3016fd30e9e2",
877 | "3c6d446f491c48fcae03e0034bfaaae9",
878 | "a193bb3a0b5b4cbba587e2460075a445",
879 | "75f8aebc30304fe198b5a2898a53a92d",
880 | "8b8a7c771d234f6c9d758a1f07f75a90",
881 | "c6518c4a721745bf97ee682f2ebe4635",
882 | "29cffa2b4f234e12802344eb53838641",
883 | "96243b7b227f465f83a289481680b925",
884 | "8c016a54f0a24fcdacf369baa9d24f1e",
885 | "7fe5b457ca0f417f90a20d235e9cec07",
886 | "fdffb26b99c24c978580f1cf97359fea",
887 | "8e3f1740c82f47949eefc2eb53052eae",
888 | "9cccd43f6acc4e25b4876fd0ae7a2ad6",
889 | "175e94deab7f4d20b99b419bea33583b",
890 | "41f26f7210e540479814e5d68de13ddb",
891 | "cf5cd281fa3b453093e210650bf81e9e",
892 | "e1fbe239c2394cbf973ac5b95e1e1491",
893 | "810ac22adad344b7bf8b556ded990122",
894 | "8b3a41c1900b45ebb9c56601deca0e84",
895 | "002f56aac3d64b33a0e799c0baf1e6b9",
896 | "a0f2a9a279734aa5bf146f0a5b33c43b",
897 | "850b5411122e4d608511fe26818bea68",
898 | "0663fb4bd85f4d87a7d61910b995be14",
899 | "cb7f52610fcf49bda46a14b296ff5bb5",
900 | "0ca29b4a62e04d9c937189ea19b25de8",
901 | "f871b83632974e0088bae65e78efaf28",
902 | "4cacf7fc20754a7ca7fe08c8ec187a81",
903 | "8bcc625c0f284398bbd287fe45021b17"
904 | ]
905 | },
906 | "id": "e61zo2JtXGNX",
907 | "outputId": "22d4916e-7169-44b5-f9d8-79b9c43fab2e"
908 | },
909 | "outputs": [],
910 | "source": [
911 | "main(trn_df, val_df)"
912 | ]
913 | },
914 | {
915 | "cell_type": "markdown",
916 | "metadata": {
917 | "id": "YRpQ_n2zXQj-"
918 | },
919 | "source": [
920 | "## Load the Trained Model"
921 | ]
922 | },
923 | {
924 | "cell_type": "code",
925 | "execution_count": null,
926 | "metadata": {
927 | "colab": {
928 | "base_uri": "https://localhost:8080/"
929 | },
930 | "id": "HGw3qgfaXQHX",
931 | "outputId": "93e84cfd-9718-42e5-bd11-418112c91d71"
932 | },
933 | "outputs": [],
934 | "source": [
935 | "tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')\n",
936 | "model = AutoModelWithLMHead.from_pretrained('output-small')"
937 | ]
938 | },
939 | {
940 | "cell_type": "code",
941 | "execution_count": null,
942 | "metadata": {
943 | "colab": {
944 | "base_uri": "https://localhost:8080/"
945 | },
946 | "id": "lAWsiAvNXbxd",
947 | "outputId": "0fd2541e-ee68-4976-b098-8483efe38d5e"
948 | },
949 | "outputs": [],
950 | "source": [
951 | "# Let's chat for 4 lines\n",
952 | "for step in range(4):\n",
953 | " # encode the new user input, add the eos_token and return a PyTorch tensor\n",
954 | " new_user_input_ids = tokenizer.encode(input(\">> User:\") + tokenizer.eos_token, return_tensors='pt')\n",
955 | " # print(new_user_input_ids)\n",
956 | "\n",
957 | " # append the new user input tokens to the chat history\n",
958 | " bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids\n",
959 | "\n",
960 | " # generate a response while limiting the total chat history to 200 tokens (max_length)\n",
961 | " chat_history_ids = model.generate(\n",
962 | " bot_input_ids, max_length=200,\n",
963 | " pad_token_id=tokenizer.eos_token_id, \n",
964 | " no_repeat_ngram_size=3, \n",
965 | " do_sample=True, \n",
966 | " top_k=100, \n",
967 | " top_p=0.7,\n",
968 | " temperature=0.8\n",
969 | " )\n",
970 | " \n",
971 | " # pretty print the last output tokens from the bot\n",
972 | " print(\"JoshuaBot: {}\".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))"
973 | ]
974 | },
975 | {
976 | "cell_type": "markdown",
977 | "metadata": {
978 | "id": "ANSQlQezXqwn"
979 | },
980 | "source": [
981 | "## Push Model to Hugging Face"
982 | ]
983 | },
984 | {
985 | "cell_type": "code",
986 | "execution_count": null,
987 | "metadata": {
988 | "id": "VgnHRgHKXwDd"
989 | },
990 | "outputs": [],
991 | "source": [
992 | "!sudo apt-get install git-lfs"
993 | ]
994 | },
995 | {
996 | "cell_type": "code",
997 | "execution_count": null,
998 | "metadata": {
999 | "id": "uhqMtvfmXei8"
1000 | },
1001 | "outputs": [],
1002 | "source": [
1003 | "!git config --global user.email \"lynnzheng08@outlook.com\"\n",
1004 | "# Tip: using the same email as your huggingface.co account will link your commits to your profile\n",
1005 | "!git config --global user.name \"Lynn Zheng\""
1006 | ]
1007 | },
1008 | {
1009 | "cell_type": "code",
1010 | "execution_count": null,
1011 | "metadata": {
1012 | "id": "tfUsrKR7YLT1"
1013 | },
1014 | "outputs": [],
1015 | "source": [
1016 | "MY_MODEL_NAME = 'DialoGPT-small-joshua'\n",
1017 | "with open('HuggingFace-API-key.txt', 'rt') as f:\n",
1018 | " HUGGINGFACE_API_KEY = f.read().strip()"
1019 | ]
1020 | },
1021 | {
1022 | "cell_type": "code",
1023 | "execution_count": null,
1024 | "metadata": {
1025 | "colab": {
1026 | "base_uri": "https://localhost:8080/",
1027 | "height": 1000
1028 | },
1029 | "id": "_65nsiLcYNXI",
1030 | "outputId": "0dbf0cb1-957c-4adb-bf55-4222d2cc85bc"
1031 | },
1032 | "outputs": [],
1033 | "source": [
1034 | "model.push_to_hub(MY_MODEL_NAME, use_auth_token=HUGGINGFACE_API_KEY)\n",
1035 | "tokenizer.push_to_hub(MY_MODEL_NAME, use_auth_token=HUGGINGFACE_API_KEY)"
1036 | ]
1037 | },
1038 | {
1039 | "cell_type": "markdown",
1040 | "metadata": {
1041 | "id": "D_XfXTCrZKmO"
1042 | },
1043 | "source": [
1044 | "## All Done!"
1045 | ]
1046 | },
1047 | {
1048 | "cell_type": "code",
1049 | "execution_count": null,
1050 | "metadata": {
1051 | "id": "_tIwK7G8ZLrd"
1052 | },
1053 | "outputs": [],
1054 | "source": []
1055 | }
1056 | ],
1057 | "metadata": {
1058 | "accelerator": "GPU",
1059 | "colab": {
1060 | "collapsed_sections": [],
1061 | "name": "model_train_upload_workflow.ipynb",
1062 | "provenance": []
1063 | },
1064 | "kernelspec": {
1065 | "display_name": "Python 3",
1066 | "language": "python",
1067 | "name": "python3"
1068 | },
1069 | "language_info": {
1070 | "codemirror_mode": {
1071 | "name": "ipython",
1072 | "version": 3
1073 | },
1074 | "file_extension": ".py",
1075 | "mimetype": "text/x-python",
1076 | "name": "python",
1077 | "nbconvert_exporter": "python",
1078 | "pygments_lexer": "ipython3",
1079 | "version": "3.9.4"
1080 | },
1081 | "widgets": {
1082 | "application/vnd.jupyter.widget-state+json": {
1083 | "002f56aac3d64b33a0e799c0baf1e6b9": {
1084 | "model_module": "@jupyter-widgets/base",
1085 | "model_name": "LayoutModel",
1086 | "state": {
1087 | "_model_module": "@jupyter-widgets/base",
1088 | "_model_module_version": "1.2.0",
1089 | "_model_name": "LayoutModel",
1090 | "_view_count": null,
1091 | "_view_module": "@jupyter-widgets/base",
1092 | "_view_module_version": "1.2.0",
1093 | "_view_name": "LayoutView",
1094 | "align_content": null,
1095 | "align_items": null,
1096 | "align_self": null,
1097 | "border": null,
1098 | "bottom": null,
1099 | "display": null,
1100 | "flex": null,
1101 | "flex_flow": null,
1102 | "grid_area": null,
1103 | "grid_auto_columns": null,
1104 | "grid_auto_flow": null,
1105 | "grid_auto_rows": null,
1106 | "grid_column": null,
1107 | "grid_gap": null,
1108 | "grid_row": null,
1109 | "grid_template_areas": null,
1110 | "grid_template_columns": null,
1111 | "grid_template_rows": null,
1112 | "height": null,
1113 | "justify_content": null,
1114 | "justify_items": null,
1115 | "left": null,
1116 | "margin": null,
1117 | "max_height": null,
1118 | "max_width": null,
1119 | "min_height": null,
1120 | "min_width": null,
1121 | "object_fit": null,
1122 | "object_position": null,
1123 | "order": null,
1124 | "overflow": null,
1125 | "overflow_x": null,
1126 | "overflow_y": null,
1127 | "padding": null,
1128 | "right": null,
1129 | "top": null,
1130 | "visibility": null,
1131 | "width": null
1132 | }
1133 | },
1134 | "021b771a270f479aa3b9e2b5f17e3d97": {
1135 | "model_module": "@jupyter-widgets/controls",
1136 | "model_name": "ProgressStyleModel",
1137 | "state": {
1138 | "_model_module": "@jupyter-widgets/controls",
1139 | "_model_module_version": "1.5.0",
1140 | "_model_name": "ProgressStyleModel",
1141 | "_view_count": null,
1142 | "_view_module": "@jupyter-widgets/base",
1143 | "_view_module_version": "1.2.0",
1144 | "_view_name": "StyleView",
1145 | "bar_color": null,
1146 | "description_width": "initial"
1147 | }
1148 | },
1149 | "0663fb4bd85f4d87a7d61910b995be14": {
1150 | "model_module": "@jupyter-widgets/controls",
1151 | "model_name": "FloatProgressModel",
1152 | "state": {
1153 | "_dom_classes": [],
1154 | "_model_module": "@jupyter-widgets/controls",
1155 | "_model_module_version": "1.5.0",
1156 | "_model_name": "FloatProgressModel",
1157 | "_view_count": null,
1158 | "_view_module": "@jupyter-widgets/controls",
1159 | "_view_module_version": "1.5.0",
1160 | "_view_name": "ProgressView",
1161 | "bar_style": "success",
1162 | "description": "Evaluating: 100%",
1163 | "description_tooltip": null,
1164 | "layout": "IPY_MODEL_f871b83632974e0088bae65e78efaf28",
1165 | "max": 21,
1166 | "min": 0,
1167 | "orientation": "horizontal",
1168 | "style": "IPY_MODEL_0ca29b4a62e04d9c937189ea19b25de8",
1169 | "value": 21
1170 | }
1171 | },
1172 | "0ca29b4a62e04d9c937189ea19b25de8": {
1173 | "model_module": "@jupyter-widgets/controls",
1174 | "model_name": "ProgressStyleModel",
1175 | "state": {
1176 | "_model_module": "@jupyter-widgets/controls",
1177 | "_model_module_version": "1.5.0",
1178 | "_model_name": "ProgressStyleModel",
1179 | "_view_count": null,
1180 | "_view_module": "@jupyter-widgets/base",
1181 | "_view_module_version": "1.2.0",
1182 | "_view_name": "StyleView",
1183 | "bar_color": null,
1184 | "description_width": "initial"
1185 | }
1186 | },
1187 | "0ff5f4e3506b493a98d72008a467f35f": {
1188 | "model_module": "@jupyter-widgets/controls",
1189 | "model_name": "FloatProgressModel",
1190 | "state": {
1191 | "_dom_classes": [],
1192 | "_model_module": "@jupyter-widgets/controls",
1193 | "_model_module_version": "1.5.0",
1194 | "_model_name": "FloatProgressModel",
1195 | "_view_count": null,
1196 | "_view_module": "@jupyter-widgets/controls",
1197 | "_view_module_version": "1.5.0",
1198 | "_view_name": "ProgressView",
1199 | "bar_style": "success",
1200 | "description": "Iteration: 100%",
1201 | "description_tooltip": null,
1202 | "layout": "IPY_MODEL_3c6d446f491c48fcae03e0034bfaaae9",
1203 | "max": 195,
1204 | "min": 0,
1205 | "orientation": "horizontal",
1206 | "style": "IPY_MODEL_a937f1dfeee5432ba31b3016fd30e9e2",
1207 | "value": 195
1208 | }
1209 | },
1210 | "175e94deab7f4d20b99b419bea33583b": {
1211 | "model_module": "@jupyter-widgets/base",
1212 | "model_name": "LayoutModel",
1213 | "state": {
1214 | "_model_module": "@jupyter-widgets/base",
1215 | "_model_module_version": "1.2.0",
1216 | "_model_name": "LayoutModel",
1217 | "_view_count": null,
1218 | "_view_module": "@jupyter-widgets/base",
1219 | "_view_module_version": "1.2.0",
1220 | "_view_name": "LayoutView",
1221 | "align_content": null,
1222 | "align_items": null,
1223 | "align_self": null,
1224 | "border": null,
1225 | "bottom": null,
1226 | "display": null,
1227 | "flex": null,
1228 | "flex_flow": null,
1229 | "grid_area": null,
1230 | "grid_auto_columns": null,
1231 | "grid_auto_flow": null,
1232 | "grid_auto_rows": null,
1233 | "grid_column": null,
1234 | "grid_gap": null,
1235 | "grid_row": null,
1236 | "grid_template_areas": null,
1237 | "grid_template_columns": null,
1238 | "grid_template_rows": null,
1239 | "height": null,
1240 | "justify_content": null,
1241 | "justify_items": null,
1242 | "left": null,
1243 | "margin": null,
1244 | "max_height": null,
1245 | "max_width": null,
1246 | "min_height": null,
1247 | "min_width": null,
1248 | "object_fit": null,
1249 | "object_position": null,
1250 | "order": null,
1251 | "overflow": null,
1252 | "overflow_x": null,
1253 | "overflow_y": null,
1254 | "padding": null,
1255 | "right": null,
1256 | "top": null,
1257 | "visibility": null,
1258 | "width": null
1259 | }
1260 | },
1261 | "1d7f4c82687540f1ad69eb54ac3c25b4": {
1262 | "model_module": "@jupyter-widgets/controls",
1263 | "model_name": "HBoxModel",
1264 | "state": {
1265 | "_dom_classes": [],
1266 | "_model_module": "@jupyter-widgets/controls",
1267 | "_model_module_version": "1.5.0",
1268 | "_model_name": "HBoxModel",
1269 | "_view_count": null,
1270 | "_view_module": "@jupyter-widgets/controls",
1271 | "_view_module_version": "1.5.0",
1272 | "_view_name": "HBoxView",
1273 | "box_style": "",
1274 | "children": [
1275 | "IPY_MODEL_f3bf54733c2d4d9daa1cc9a7746ccb14",
1276 | "IPY_MODEL_aa40eb6346b54e7dac98e0b068cd4927"
1277 | ],
1278 | "layout": "IPY_MODEL_e7b9f3fc77a24259a87ef0dc735dfecb"
1279 | }
1280 | },
1281 | "29cffa2b4f234e12802344eb53838641": {
1282 | "model_module": "@jupyter-widgets/controls",
1283 | "model_name": "FloatProgressModel",
1284 | "state": {
1285 | "_dom_classes": [],
1286 | "_model_module": "@jupyter-widgets/controls",
1287 | "_model_module_version": "1.5.0",
1288 | "_model_name": "FloatProgressModel",
1289 | "_view_count": null,
1290 | "_view_module": "@jupyter-widgets/controls",
1291 | "_view_module_version": "1.5.0",
1292 | "_view_name": "ProgressView",
1293 | "bar_style": "success",
1294 | "description": "Iteration: 100%",
1295 | "description_tooltip": null,
1296 | "layout": "IPY_MODEL_7fe5b457ca0f417f90a20d235e9cec07",
1297 | "max": 195,
1298 | "min": 0,
1299 | "orientation": "horizontal",
1300 | "style": "IPY_MODEL_8c016a54f0a24fcdacf369baa9d24f1e",
1301 | "value": 195
1302 | }
1303 | },
1304 | "2fa1fa2407384cb98d79a912de2d5b8f": {
1305 | "model_module": "@jupyter-widgets/controls",
1306 | "model_name": "HBoxModel",
1307 | "state": {
1308 | "_dom_classes": [],
1309 | "_model_module": "@jupyter-widgets/controls",
1310 | "_model_module_version": "1.5.0",
1311 | "_model_name": "HBoxModel",
1312 | "_view_count": null,
1313 | "_view_module": "@jupyter-widgets/controls",
1314 | "_view_module_version": "1.5.0",
1315 | "_view_name": "HBoxView",
1316 | "box_style": "",
1317 | "children": [
1318 | "IPY_MODEL_e38fb98fd7b3413392dc39c93a107a35",
1319 | "IPY_MODEL_855ca0a6125a4d698416214a9425ad98"
1320 | ],
1321 | "layout": "IPY_MODEL_dc27e2caf1ea4a4ab9ae3708fb06952f"
1322 | }
1323 | },
1324 | "3c6d446f491c48fcae03e0034bfaaae9": {
1325 | "model_module": "@jupyter-widgets/base",
1326 | "model_name": "LayoutModel",
1327 | "state": {
1328 | "_model_module": "@jupyter-widgets/base",
1329 | "_model_module_version": "1.2.0",
1330 | "_model_name": "LayoutModel",
1331 | "_view_count": null,
1332 | "_view_module": "@jupyter-widgets/base",
1333 | "_view_module_version": "1.2.0",
1334 | "_view_name": "LayoutView",
1335 | "align_content": null,
1336 | "align_items": null,
1337 | "align_self": null,
1338 | "border": null,
1339 | "bottom": null,
1340 | "display": null,
1341 | "flex": null,
1342 | "flex_flow": null,
1343 | "grid_area": null,
1344 | "grid_auto_columns": null,
1345 | "grid_auto_flow": null,
1346 | "grid_auto_rows": null,
1347 | "grid_column": null,
1348 | "grid_gap": null,
1349 | "grid_row": null,
1350 | "grid_template_areas": null,
1351 | "grid_template_columns": null,
1352 | "grid_template_rows": null,
1353 | "height": null,
1354 | "justify_content": null,
1355 | "justify_items": null,
1356 | "left": null,
1357 | "margin": null,
1358 | "max_height": null,
1359 | "max_width": null,
1360 | "min_height": null,
1361 | "min_width": null,
1362 | "object_fit": null,
1363 | "object_position": null,
1364 | "order": null,
1365 | "overflow": null,
1366 | "overflow_x": null,
1367 | "overflow_y": null,
1368 | "padding": null,
1369 | "right": null,
1370 | "top": null,
1371 | "visibility": null,
1372 | "width": null
1373 | }
1374 | },
1375 | "41f26f7210e540479814e5d68de13ddb": {
1376 | "model_module": "@jupyter-widgets/controls",
1377 | "model_name": "FloatProgressModel",
1378 | "state": {
1379 | "_dom_classes": [],
1380 | "_model_module": "@jupyter-widgets/controls",
1381 | "_model_module_version": "1.5.0",
1382 | "_model_name": "FloatProgressModel",
1383 | "_view_count": null,
1384 | "_view_module": "@jupyter-widgets/controls",
1385 | "_view_module_version": "1.5.0",
1386 | "_view_name": "ProgressView",
1387 | "bar_style": "success",
1388 | "description": "Iteration: 100%",
1389 | "description_tooltip": null,
1390 | "layout": "IPY_MODEL_810ac22adad344b7bf8b556ded990122",
1391 | "max": 195,
1392 | "min": 0,
1393 | "orientation": "horizontal",
1394 | "style": "IPY_MODEL_e1fbe239c2394cbf973ac5b95e1e1491",
1395 | "value": 195
1396 | }
1397 | },
1398 | "43fdb31d3f314624ba07a15718b0c8f3": {
1399 | "model_module": "@jupyter-widgets/base",
1400 | "model_name": "LayoutModel",
1401 | "state": {
1402 | "_model_module": "@jupyter-widgets/base",
1403 | "_model_module_version": "1.2.0",
1404 | "_model_name": "LayoutModel",
1405 | "_view_count": null,
1406 | "_view_module": "@jupyter-widgets/base",
1407 | "_view_module_version": "1.2.0",
1408 | "_view_name": "LayoutView",
1409 | "align_content": null,
1410 | "align_items": null,
1411 | "align_self": null,
1412 | "border": null,
1413 | "bottom": null,
1414 | "display": null,
1415 | "flex": null,
1416 | "flex_flow": null,
1417 | "grid_area": null,
1418 | "grid_auto_columns": null,
1419 | "grid_auto_flow": null,
1420 | "grid_auto_rows": null,
1421 | "grid_column": null,
1422 | "grid_gap": null,
1423 | "grid_row": null,
1424 | "grid_template_areas": null,
1425 | "grid_template_columns": null,
1426 | "grid_template_rows": null,
1427 | "height": null,
1428 | "justify_content": null,
1429 | "justify_items": null,
1430 | "left": null,
1431 | "margin": null,
1432 | "max_height": null,
1433 | "max_width": null,
1434 | "min_height": null,
1435 | "min_width": null,
1436 | "object_fit": null,
1437 | "object_position": null,
1438 | "order": null,
1439 | "overflow": null,
1440 | "overflow_x": null,
1441 | "overflow_y": null,
1442 | "padding": null,
1443 | "right": null,
1444 | "top": null,
1445 | "visibility": null,
1446 | "width": null
1447 | }
1448 | },
1449 | "450b0e7fd7a347c7beb78b7d72f64385": {
1450 | "model_module": "@jupyter-widgets/base",
1451 | "model_name": "LayoutModel",
1452 | "state": {
1453 | "_model_module": "@jupyter-widgets/base",
1454 | "_model_module_version": "1.2.0",
1455 | "_model_name": "LayoutModel",
1456 | "_view_count": null,
1457 | "_view_module": "@jupyter-widgets/base",
1458 | "_view_module_version": "1.2.0",
1459 | "_view_name": "LayoutView",
1460 | "align_content": null,
1461 | "align_items": null,
1462 | "align_self": null,
1463 | "border": null,
1464 | "bottom": null,
1465 | "display": null,
1466 | "flex": null,
1467 | "flex_flow": null,
1468 | "grid_area": null,
1469 | "grid_auto_columns": null,
1470 | "grid_auto_flow": null,
1471 | "grid_auto_rows": null,
1472 | "grid_column": null,
1473 | "grid_gap": null,
1474 | "grid_row": null,
1475 | "grid_template_areas": null,
1476 | "grid_template_columns": null,
1477 | "grid_template_rows": null,
1478 | "height": null,
1479 | "justify_content": null,
1480 | "justify_items": null,
1481 | "left": null,
1482 | "margin": null,
1483 | "max_height": null,
1484 | "max_width": null,
1485 | "min_height": null,
1486 | "min_width": null,
1487 | "object_fit": null,
1488 | "object_position": null,
1489 | "order": null,
1490 | "overflow": null,
1491 | "overflow_x": null,
1492 | "overflow_y": null,
1493 | "padding": null,
1494 | "right": null,
1495 | "top": null,
1496 | "visibility": null,
1497 | "width": null
1498 | }
1499 | },
1500 | "4699416338ae40a5b6abf19e45089aec": {
1501 | "model_module": "@jupyter-widgets/controls",
1502 | "model_name": "ProgressStyleModel",
1503 | "state": {
1504 | "_model_module": "@jupyter-widgets/controls",
1505 | "_model_module_version": "1.5.0",
1506 | "_model_name": "ProgressStyleModel",
1507 | "_view_count": null,
1508 | "_view_module": "@jupyter-widgets/base",
1509 | "_view_module_version": "1.2.0",
1510 | "_view_name": "StyleView",
1511 | "bar_color": null,
1512 | "description_width": "initial"
1513 | }
1514 | },
1515 | "4cacf7fc20754a7ca7fe08c8ec187a81": {
1516 | "model_module": "@jupyter-widgets/controls",
1517 | "model_name": "DescriptionStyleModel",
1518 | "state": {
1519 | "_model_module": "@jupyter-widgets/controls",
1520 | "_model_module_version": "1.5.0",
1521 | "_model_name": "DescriptionStyleModel",
1522 | "_view_count": null,
1523 | "_view_module": "@jupyter-widgets/base",
1524 | "_view_module_version": "1.2.0",
1525 | "_view_name": "StyleView",
1526 | "description_width": ""
1527 | }
1528 | },
1529 | "5e48b617cc3f41c3945efc28fc5e0c75": {
1530 | "model_module": "@jupyter-widgets/base",
1531 | "model_name": "LayoutModel",
1532 | "state": {
1533 | "_model_module": "@jupyter-widgets/base",
1534 | "_model_module_version": "1.2.0",
1535 | "_model_name": "LayoutModel",
1536 | "_view_count": null,
1537 | "_view_module": "@jupyter-widgets/base",
1538 | "_view_module_version": "1.2.0",
1539 | "_view_name": "LayoutView",
1540 | "align_content": null,
1541 | "align_items": null,
1542 | "align_self": null,
1543 | "border": null,
1544 | "bottom": null,
1545 | "display": null,
1546 | "flex": null,
1547 | "flex_flow": null,
1548 | "grid_area": null,
1549 | "grid_auto_columns": null,
1550 | "grid_auto_flow": null,
1551 | "grid_auto_rows": null,
1552 | "grid_column": null,
1553 | "grid_gap": null,
1554 | "grid_row": null,
1555 | "grid_template_areas": null,
1556 | "grid_template_columns": null,
1557 | "grid_template_rows": null,
1558 | "height": null,
1559 | "justify_content": null,
1560 | "justify_items": null,
1561 | "left": null,
1562 | "margin": null,
1563 | "max_height": null,
1564 | "max_width": null,
1565 | "min_height": null,
1566 | "min_width": null,
1567 | "object_fit": null,
1568 | "object_position": null,
1569 | "order": null,
1570 | "overflow": null,
1571 | "overflow_x": null,
1572 | "overflow_y": null,
1573 | "padding": null,
1574 | "right": null,
1575 | "top": null,
1576 | "visibility": null,
1577 | "width": null
1578 | }
1579 | },
1580 | "68a9dc52819c48fb97259f318f9b5c6a": {
1581 | "model_module": "@jupyter-widgets/controls",
1582 | "model_name": "HBoxModel",
1583 | "state": {
1584 | "_dom_classes": [],
1585 | "_model_module": "@jupyter-widgets/controls",
1586 | "_model_module_version": "1.5.0",
1587 | "_model_name": "HBoxModel",
1588 | "_view_count": null,
1589 | "_view_module": "@jupyter-widgets/controls",
1590 | "_view_module_version": "1.5.0",
1591 | "_view_name": "HBoxView",
1592 | "box_style": "",
1593 | "children": [
1594 | "IPY_MODEL_0ff5f4e3506b493a98d72008a467f35f",
1595 | "IPY_MODEL_77b97fa3271b48ac9f93665a102b4fd1"
1596 | ],
1597 | "layout": "IPY_MODEL_b4e00059cf3a49929978ed780aae8358"
1598 | }
1599 | },
1600 | "75f8aebc30304fe198b5a2898a53a92d": {
1601 | "model_module": "@jupyter-widgets/base",
1602 | "model_name": "LayoutModel",
1603 | "state": {
1604 | "_model_module": "@jupyter-widgets/base",
1605 | "_model_module_version": "1.2.0",
1606 | "_model_name": "LayoutModel",
1607 | "_view_count": null,
1608 | "_view_module": "@jupyter-widgets/base",
1609 | "_view_module_version": "1.2.0",
1610 | "_view_name": "LayoutView",
1611 | "align_content": null,
1612 | "align_items": null,
1613 | "align_self": null,
1614 | "border": null,
1615 | "bottom": null,
1616 | "display": null,
1617 | "flex": null,
1618 | "flex_flow": null,
1619 | "grid_area": null,
1620 | "grid_auto_columns": null,
1621 | "grid_auto_flow": null,
1622 | "grid_auto_rows": null,
1623 | "grid_column": null,
1624 | "grid_gap": null,
1625 | "grid_row": null,
1626 | "grid_template_areas": null,
1627 | "grid_template_columns": null,
1628 | "grid_template_rows": null,
1629 | "height": null,
1630 | "justify_content": null,
1631 | "justify_items": null,
1632 | "left": null,
1633 | "margin": null,
1634 | "max_height": null,
1635 | "max_width": null,
1636 | "min_height": null,
1637 | "min_width": null,
1638 | "object_fit": null,
1639 | "object_position": null,
1640 | "order": null,
1641 | "overflow": null,
1642 | "overflow_x": null,
1643 | "overflow_y": null,
1644 | "padding": null,
1645 | "right": null,
1646 | "top": null,
1647 | "visibility": null,
1648 | "width": null
1649 | }
1650 | },
1651 | "77b97fa3271b48ac9f93665a102b4fd1": {
1652 | "model_module": "@jupyter-widgets/controls",
1653 | "model_name": "HTMLModel",
1654 | "state": {
1655 | "_dom_classes": [],
1656 | "_model_module": "@jupyter-widgets/controls",
1657 | "_model_module_version": "1.5.0",
1658 | "_model_name": "HTMLModel",
1659 | "_view_count": null,
1660 | "_view_module": "@jupyter-widgets/controls",
1661 | "_view_module_version": "1.5.0",
1662 | "_view_name": "HTMLView",
1663 | "description": "",
1664 | "description_tooltip": null,
1665 | "layout": "IPY_MODEL_75f8aebc30304fe198b5a2898a53a92d",
1666 | "placeholder": "",
1667 | "style": "IPY_MODEL_a193bb3a0b5b4cbba587e2460075a445",
1668 | "value": " 195/195 [00:35<00:00, 5.45it/s]"
1669 | }
1670 | },
1671 | "7fe5b457ca0f417f90a20d235e9cec07": {
1672 | "model_module": "@jupyter-widgets/base",
1673 | "model_name": "LayoutModel",
1674 | "state": {
1675 | "_model_module": "@jupyter-widgets/base",
1676 | "_model_module_version": "1.2.0",
1677 | "_model_name": "LayoutModel",
1678 | "_view_count": null,
1679 | "_view_module": "@jupyter-widgets/base",
1680 | "_view_module_version": "1.2.0",
1681 | "_view_name": "LayoutView",
1682 | "align_content": null,
1683 | "align_items": null,
1684 | "align_self": null,
1685 | "border": null,
1686 | "bottom": null,
1687 | "display": null,
1688 | "flex": null,
1689 | "flex_flow": null,
1690 | "grid_area": null,
1691 | "grid_auto_columns": null,
1692 | "grid_auto_flow": null,
1693 | "grid_auto_rows": null,
1694 | "grid_column": null,
1695 | "grid_gap": null,
1696 | "grid_row": null,
1697 | "grid_template_areas": null,
1698 | "grid_template_columns": null,
1699 | "grid_template_rows": null,
1700 | "height": null,
1701 | "justify_content": null,
1702 | "justify_items": null,
1703 | "left": null,
1704 | "margin": null,
1705 | "max_height": null,
1706 | "max_width": null,
1707 | "min_height": null,
1708 | "min_width": null,
1709 | "object_fit": null,
1710 | "object_position": null,
1711 | "order": null,
1712 | "overflow": null,
1713 | "overflow_x": null,
1714 | "overflow_y": null,
1715 | "padding": null,
1716 | "right": null,
1717 | "top": null,
1718 | "visibility": null,
1719 | "width": null
1720 | }
1721 | },
1722 | "810ac22adad344b7bf8b556ded990122": {
1723 | "model_module": "@jupyter-widgets/base",
1724 | "model_name": "LayoutModel",
1725 | "state": {
1726 | "_model_module": "@jupyter-widgets/base",
1727 | "_model_module_version": "1.2.0",
1728 | "_model_name": "LayoutModel",
1729 | "_view_count": null,
1730 | "_view_module": "@jupyter-widgets/base",
1731 | "_view_module_version": "1.2.0",
1732 | "_view_name": "LayoutView",
1733 | "align_content": null,
1734 | "align_items": null,
1735 | "align_self": null,
1736 | "border": null,
1737 | "bottom": null,
1738 | "display": null,
1739 | "flex": null,
1740 | "flex_flow": null,
1741 | "grid_area": null,
1742 | "grid_auto_columns": null,
1743 | "grid_auto_flow": null,
1744 | "grid_auto_rows": null,
1745 | "grid_column": null,
1746 | "grid_gap": null,
1747 | "grid_row": null,
1748 | "grid_template_areas": null,
1749 | "grid_template_columns": null,
1750 | "grid_template_rows": null,
1751 | "height": null,
1752 | "justify_content": null,
1753 | "justify_items": null,
1754 | "left": null,
1755 | "margin": null,
1756 | "max_height": null,
1757 | "max_width": null,
1758 | "min_height": null,
1759 | "min_width": null,
1760 | "object_fit": null,
1761 | "object_position": null,
1762 | "order": null,
1763 | "overflow": null,
1764 | "overflow_x": null,
1765 | "overflow_y": null,
1766 | "padding": null,
1767 | "right": null,
1768 | "top": null,
1769 | "visibility": null,
1770 | "width": null
1771 | }
1772 | },
1773 | "850b5411122e4d608511fe26818bea68": {
1774 | "model_module": "@jupyter-widgets/base",
1775 | "model_name": "LayoutModel",
1776 | "state": {
1777 | "_model_module": "@jupyter-widgets/base",
1778 | "_model_module_version": "1.2.0",
1779 | "_model_name": "LayoutModel",
1780 | "_view_count": null,
1781 | "_view_module": "@jupyter-widgets/base",
1782 | "_view_module_version": "1.2.0",
1783 | "_view_name": "LayoutView",
1784 | "align_content": null,
1785 | "align_items": null,
1786 | "align_self": null,
1787 | "border": null,
1788 | "bottom": null,
1789 | "display": null,
1790 | "flex": null,
1791 | "flex_flow": null,
1792 | "grid_area": null,
1793 | "grid_auto_columns": null,
1794 | "grid_auto_flow": null,
1795 | "grid_auto_rows": null,
1796 | "grid_column": null,
1797 | "grid_gap": null,
1798 | "grid_row": null,
1799 | "grid_template_areas": null,
1800 | "grid_template_columns": null,
1801 | "grid_template_rows": null,
1802 | "height": null,
1803 | "justify_content": null,
1804 | "justify_items": null,
1805 | "left": null,
1806 | "margin": null,
1807 | "max_height": null,
1808 | "max_width": null,
1809 | "min_height": null,
1810 | "min_width": null,
1811 | "object_fit": null,
1812 | "object_position": null,
1813 | "order": null,
1814 | "overflow": null,
1815 | "overflow_x": null,
1816 | "overflow_y": null,
1817 | "padding": null,
1818 | "right": null,
1819 | "top": null,
1820 | "visibility": null,
1821 | "width": null
1822 | }
1823 | },
1824 | "855ca0a6125a4d698416214a9425ad98": {
1825 | "model_module": "@jupyter-widgets/controls",
1826 | "model_name": "HTMLModel",
1827 | "state": {
1828 | "_dom_classes": [],
1829 | "_model_module": "@jupyter-widgets/controls",
1830 | "_model_module_version": "1.5.0",
1831 | "_model_name": "HTMLModel",
1832 | "_view_count": null,
1833 | "_view_module": "@jupyter-widgets/controls",
1834 | "_view_module_version": "1.5.0",
1835 | "_view_name": "HTMLView",
1836 | "description": "",
1837 | "description_tooltip": null,
1838 | "layout": "IPY_MODEL_5e48b617cc3f41c3945efc28fc5e0c75",
1839 | "placeholder": "",
1840 | "style": "IPY_MODEL_de252cd193114c40ad5f5e9622b7abc7",
1841 | "value": " 195/195 [00:44<00:00, 4.39it/s]"
1842 | }
1843 | },
1844 | "8b3a41c1900b45ebb9c56601deca0e84": {
1845 | "model_module": "@jupyter-widgets/controls",
1846 | "model_name": "DescriptionStyleModel",
1847 | "state": {
1848 | "_model_module": "@jupyter-widgets/controls",
1849 | "_model_module_version": "1.5.0",
1850 | "_model_name": "DescriptionStyleModel",
1851 | "_view_count": null,
1852 | "_view_module": "@jupyter-widgets/base",
1853 | "_view_module_version": "1.2.0",
1854 | "_view_name": "StyleView",
1855 | "description_width": ""
1856 | }
1857 | },
1858 | "8b8a7c771d234f6c9d758a1f07f75a90": {
1859 | "model_module": "@jupyter-widgets/controls",
1860 | "model_name": "HBoxModel",
1861 | "state": {
1862 | "_dom_classes": [],
1863 | "_model_module": "@jupyter-widgets/controls",
1864 | "_model_module_version": "1.5.0",
1865 | "_model_name": "HBoxModel",
1866 | "_view_count": null,
1867 | "_view_module": "@jupyter-widgets/controls",
1868 | "_view_module_version": "1.5.0",
1869 | "_view_name": "HBoxView",
1870 | "box_style": "",
1871 | "children": [
1872 | "IPY_MODEL_29cffa2b4f234e12802344eb53838641",
1873 | "IPY_MODEL_96243b7b227f465f83a289481680b925"
1874 | ],
1875 | "layout": "IPY_MODEL_c6518c4a721745bf97ee682f2ebe4635"
1876 | }
1877 | },
1878 | "8bcc625c0f284398bbd287fe45021b17": {
1879 | "model_module": "@jupyter-widgets/base",
1880 | "model_name": "LayoutModel",
1881 | "state": {
1882 | "_model_module": "@jupyter-widgets/base",
1883 | "_model_module_version": "1.2.0",
1884 | "_model_name": "LayoutModel",
1885 | "_view_count": null,
1886 | "_view_module": "@jupyter-widgets/base",
1887 | "_view_module_version": "1.2.0",
1888 | "_view_name": "LayoutView",
1889 | "align_content": null,
1890 | "align_items": null,
1891 | "align_self": null,
1892 | "border": null,
1893 | "bottom": null,
1894 | "display": null,
1895 | "flex": null,
1896 | "flex_flow": null,
1897 | "grid_area": null,
1898 | "grid_auto_columns": null,
1899 | "grid_auto_flow": null,
1900 | "grid_auto_rows": null,
1901 | "grid_column": null,
1902 | "grid_gap": null,
1903 | "grid_row": null,
1904 | "grid_template_areas": null,
1905 | "grid_template_columns": null,
1906 | "grid_template_rows": null,
1907 | "height": null,
1908 | "justify_content": null,
1909 | "justify_items": null,
1910 | "left": null,
1911 | "margin": null,
1912 | "max_height": null,
1913 | "max_width": null,
1914 | "min_height": null,
1915 | "min_width": null,
1916 | "object_fit": null,
1917 | "object_position": null,
1918 | "order": null,
1919 | "overflow": null,
1920 | "overflow_x": null,
1921 | "overflow_y": null,
1922 | "padding": null,
1923 | "right": null,
1924 | "top": null,
1925 | "visibility": null,
1926 | "width": null
1927 | }
1928 | },
1929 | "8c016a54f0a24fcdacf369baa9d24f1e": {
1930 | "model_module": "@jupyter-widgets/controls",
1931 | "model_name": "ProgressStyleModel",
1932 | "state": {
1933 | "_model_module": "@jupyter-widgets/controls",
1934 | "_model_module_version": "1.5.0",
1935 | "_model_name": "ProgressStyleModel",
1936 | "_view_count": null,
1937 | "_view_module": "@jupyter-widgets/base",
1938 | "_view_module_version": "1.2.0",
1939 | "_view_name": "StyleView",
1940 | "bar_color": null,
1941 | "description_width": "initial"
1942 | }
1943 | },
1944 | "8e3f1740c82f47949eefc2eb53052eae": {
1945 | "model_module": "@jupyter-widgets/base",
1946 | "model_name": "LayoutModel",
1947 | "state": {
1948 | "_model_module": "@jupyter-widgets/base",
1949 | "_model_module_version": "1.2.0",
1950 | "_model_name": "LayoutModel",
1951 | "_view_count": null,
1952 | "_view_module": "@jupyter-widgets/base",
1953 | "_view_module_version": "1.2.0",
1954 | "_view_name": "LayoutView",
1955 | "align_content": null,
1956 | "align_items": null,
1957 | "align_self": null,
1958 | "border": null,
1959 | "bottom": null,
1960 | "display": null,
1961 | "flex": null,
1962 | "flex_flow": null,
1963 | "grid_area": null,
1964 | "grid_auto_columns": null,
1965 | "grid_auto_flow": null,
1966 | "grid_auto_rows": null,
1967 | "grid_column": null,
1968 | "grid_gap": null,
1969 | "grid_row": null,
1970 | "grid_template_areas": null,
1971 | "grid_template_columns": null,
1972 | "grid_template_rows": null,
1973 | "height": null,
1974 | "justify_content": null,
1975 | "justify_items": null,
1976 | "left": null,
1977 | "margin": null,
1978 | "max_height": null,
1979 | "max_width": null,
1980 | "min_height": null,
1981 | "min_width": null,
1982 | "object_fit": null,
1983 | "object_position": null,
1984 | "order": null,
1985 | "overflow": null,
1986 | "overflow_x": null,
1987 | "overflow_y": null,
1988 | "padding": null,
1989 | "right": null,
1990 | "top": null,
1991 | "visibility": null,
1992 | "width": null
1993 | }
1994 | },
1995 | "9391d7abf6ed4400903995f56d7a1260": {
1996 | "model_module": "@jupyter-widgets/controls",
1997 | "model_name": "DescriptionStyleModel",
1998 | "state": {
1999 | "_model_module": "@jupyter-widgets/controls",
2000 | "_model_module_version": "1.5.0",
2001 | "_model_name": "DescriptionStyleModel",
2002 | "_view_count": null,
2003 | "_view_module": "@jupyter-widgets/base",
2004 | "_view_module_version": "1.2.0",
2005 | "_view_name": "StyleView",
2006 | "description_width": ""
2007 | }
2008 | },
2009 | "96243b7b227f465f83a289481680b925": {
2010 | "model_module": "@jupyter-widgets/controls",
2011 | "model_name": "HTMLModel",
2012 | "state": {
2013 | "_dom_classes": [],
2014 | "_model_module": "@jupyter-widgets/controls",
2015 | "_model_module_version": "1.5.0",
2016 | "_model_name": "HTMLModel",
2017 | "_view_count": null,
2018 | "_view_module": "@jupyter-widgets/controls",
2019 | "_view_module_version": "1.5.0",
2020 | "_view_name": "HTMLView",
2021 | "description": "",
2022 | "description_tooltip": null,
2023 | "layout": "IPY_MODEL_8e3f1740c82f47949eefc2eb53052eae",
2024 | "placeholder": "",
2025 | "style": "IPY_MODEL_fdffb26b99c24c978580f1cf97359fea",
2026 | "value": " 195/195 [01:17<00:00, 2.53it/s]"
2027 | }
2028 | },
2029 | "9cccd43f6acc4e25b4876fd0ae7a2ad6": {
2030 | "model_module": "@jupyter-widgets/controls",
2031 | "model_name": "HBoxModel",
2032 | "state": {
2033 | "_dom_classes": [],
2034 | "_model_module": "@jupyter-widgets/controls",
2035 | "_model_module_version": "1.5.0",
2036 | "_model_name": "HBoxModel",
2037 | "_view_count": null,
2038 | "_view_module": "@jupyter-widgets/controls",
2039 | "_view_module_version": "1.5.0",
2040 | "_view_name": "HBoxView",
2041 | "box_style": "",
2042 | "children": [
2043 | "IPY_MODEL_41f26f7210e540479814e5d68de13ddb",
2044 | "IPY_MODEL_cf5cd281fa3b453093e210650bf81e9e"
2045 | ],
2046 | "layout": "IPY_MODEL_175e94deab7f4d20b99b419bea33583b"
2047 | }
2048 | },
2049 | "a0f2a9a279734aa5bf146f0a5b33c43b": {
2050 | "model_module": "@jupyter-widgets/controls",
2051 | "model_name": "HBoxModel",
2052 | "state": {
2053 | "_dom_classes": [],
2054 | "_model_module": "@jupyter-widgets/controls",
2055 | "_model_module_version": "1.5.0",
2056 | "_model_name": "HBoxModel",
2057 | "_view_count": null,
2058 | "_view_module": "@jupyter-widgets/controls",
2059 | "_view_module_version": "1.5.0",
2060 | "_view_name": "HBoxView",
2061 | "box_style": "",
2062 | "children": [
2063 | "IPY_MODEL_0663fb4bd85f4d87a7d61910b995be14",
2064 | "IPY_MODEL_cb7f52610fcf49bda46a14b296ff5bb5"
2065 | ],
2066 | "layout": "IPY_MODEL_850b5411122e4d608511fe26818bea68"
2067 | }
2068 | },
2069 | "a193bb3a0b5b4cbba587e2460075a445": {
2070 | "model_module": "@jupyter-widgets/controls",
2071 | "model_name": "DescriptionStyleModel",
2072 | "state": {
2073 | "_model_module": "@jupyter-widgets/controls",
2074 | "_model_module_version": "1.5.0",
2075 | "_model_name": "DescriptionStyleModel",
2076 | "_view_count": null,
2077 | "_view_module": "@jupyter-widgets/base",
2078 | "_view_module_version": "1.2.0",
2079 | "_view_name": "StyleView",
2080 | "description_width": ""
2081 | }
2082 | },
2083 | "a937f1dfeee5432ba31b3016fd30e9e2": {
2084 | "model_module": "@jupyter-widgets/controls",
2085 | "model_name": "ProgressStyleModel",
2086 | "state": {
2087 | "_model_module": "@jupyter-widgets/controls",
2088 | "_model_module_version": "1.5.0",
2089 | "_model_name": "ProgressStyleModel",
2090 | "_view_count": null,
2091 | "_view_module": "@jupyter-widgets/base",
2092 | "_view_module_version": "1.2.0",
2093 | "_view_name": "StyleView",
2094 | "bar_color": null,
2095 | "description_width": "initial"
2096 | }
2097 | },
2098 | "aa40eb6346b54e7dac98e0b068cd4927": {
2099 | "model_module": "@jupyter-widgets/controls",
2100 | "model_name": "HTMLModel",
2101 | "state": {
2102 | "_dom_classes": [],
2103 | "_model_module": "@jupyter-widgets/controls",
2104 | "_model_module_version": "1.5.0",
2105 | "_model_name": "HTMLModel",
2106 | "_view_count": null,
2107 | "_view_module": "@jupyter-widgets/controls",
2108 | "_view_module_version": "1.5.0",
2109 | "_view_name": "HTMLView",
2110 | "description": "",
2111 | "description_tooltip": null,
2112 | "layout": "IPY_MODEL_ea6b919964d24c2f9de1c64c9cefaf23",
2113 | "placeholder": "",
2114 | "style": "IPY_MODEL_9391d7abf6ed4400903995f56d7a1260",
2115 | "value": " 4/4 [02:23<00:00, 36.00s/it]"
2116 | }
2117 | },
2118 | "b4e00059cf3a49929978ed780aae8358": {
2119 | "model_module": "@jupyter-widgets/base",
2120 | "model_name": "LayoutModel",
2121 | "state": {
2122 | "_model_module": "@jupyter-widgets/base",
2123 | "_model_module_version": "1.2.0",
2124 | "_model_name": "LayoutModel",
2125 | "_view_count": null,
2126 | "_view_module": "@jupyter-widgets/base",
2127 | "_view_module_version": "1.2.0",
2128 | "_view_name": "LayoutView",
2129 | "align_content": null,
2130 | "align_items": null,
2131 | "align_self": null,
2132 | "border": null,
2133 | "bottom": null,
2134 | "display": null,
2135 | "flex": null,
2136 | "flex_flow": null,
2137 | "grid_area": null,
2138 | "grid_auto_columns": null,
2139 | "grid_auto_flow": null,
2140 | "grid_auto_rows": null,
2141 | "grid_column": null,
2142 | "grid_gap": null,
2143 | "grid_row": null,
2144 | "grid_template_areas": null,
2145 | "grid_template_columns": null,
2146 | "grid_template_rows": null,
2147 | "height": null,
2148 | "justify_content": null,
2149 | "justify_items": null,
2150 | "left": null,
2151 | "margin": null,
2152 | "max_height": null,
2153 | "max_width": null,
2154 | "min_height": null,
2155 | "min_width": null,
2156 | "object_fit": null,
2157 | "object_position": null,
2158 | "order": null,
2159 | "overflow": null,
2160 | "overflow_x": null,
2161 | "overflow_y": null,
2162 | "padding": null,
2163 | "right": null,
2164 | "top": null,
2165 | "visibility": null,
2166 | "width": null
2167 | }
2168 | },
2169 | "c6518c4a721745bf97ee682f2ebe4635": {
2170 | "model_module": "@jupyter-widgets/base",
2171 | "model_name": "LayoutModel",
2172 | "state": {
2173 | "_model_module": "@jupyter-widgets/base",
2174 | "_model_module_version": "1.2.0",
2175 | "_model_name": "LayoutModel",
2176 | "_view_count": null,
2177 | "_view_module": "@jupyter-widgets/base",
2178 | "_view_module_version": "1.2.0",
2179 | "_view_name": "LayoutView",
2180 | "align_content": null,
2181 | "align_items": null,
2182 | "align_self": null,
2183 | "border": null,
2184 | "bottom": null,
2185 | "display": null,
2186 | "flex": null,
2187 | "flex_flow": null,
2188 | "grid_area": null,
2189 | "grid_auto_columns": null,
2190 | "grid_auto_flow": null,
2191 | "grid_auto_rows": null,
2192 | "grid_column": null,
2193 | "grid_gap": null,
2194 | "grid_row": null,
2195 | "grid_template_areas": null,
2196 | "grid_template_columns": null,
2197 | "grid_template_rows": null,
2198 | "height": null,
2199 | "justify_content": null,
2200 | "justify_items": null,
2201 | "left": null,
2202 | "margin": null,
2203 | "max_height": null,
2204 | "max_width": null,
2205 | "min_height": null,
2206 | "min_width": null,
2207 | "object_fit": null,
2208 | "object_position": null,
2209 | "order": null,
2210 | "overflow": null,
2211 | "overflow_x": null,
2212 | "overflow_y": null,
2213 | "padding": null,
2214 | "right": null,
2215 | "top": null,
2216 | "visibility": null,
2217 | "width": null
2218 | }
2219 | },
2220 | "cb7f52610fcf49bda46a14b296ff5bb5": {
2221 | "model_module": "@jupyter-widgets/controls",
2222 | "model_name": "HTMLModel",
2223 | "state": {
2224 | "_dom_classes": [],
2225 | "_model_module": "@jupyter-widgets/controls",
2226 | "_model_module_version": "1.5.0",
2227 | "_model_name": "HTMLModel",
2228 | "_view_count": null,
2229 | "_view_module": "@jupyter-widgets/controls",
2230 | "_view_module_version": "1.5.0",
2231 | "_view_name": "HTMLView",
2232 | "description": "",
2233 | "description_tooltip": null,
2234 | "layout": "IPY_MODEL_8bcc625c0f284398bbd287fe45021b17",
2235 | "placeholder": "",
2236 | "style": "IPY_MODEL_4cacf7fc20754a7ca7fe08c8ec187a81",
2237 | "value": " 21/21 [00:01<00:00, 10.78it/s]"
2238 | }
2239 | },
2240 | "cf5cd281fa3b453093e210650bf81e9e": {
2241 | "model_module": "@jupyter-widgets/controls",
2242 | "model_name": "HTMLModel",
2243 | "state": {
2244 | "_dom_classes": [],
2245 | "_model_module": "@jupyter-widgets/controls",
2246 | "_model_module_version": "1.5.0",
2247 | "_model_name": "HTMLModel",
2248 | "_view_count": null,
2249 | "_view_module": "@jupyter-widgets/controls",
2250 | "_view_module_version": "1.5.0",
2251 | "_view_name": "HTMLView",
2252 | "description": "",
2253 | "description_tooltip": null,
2254 | "layout": "IPY_MODEL_002f56aac3d64b33a0e799c0baf1e6b9",
2255 | "placeholder": "",
2256 | "style": "IPY_MODEL_8b3a41c1900b45ebb9c56601deca0e84",
2257 | "value": " 195/195 [00:40<00:00, 4.84it/s]"
2258 | }
2259 | },
2260 | "dc27e2caf1ea4a4ab9ae3708fb06952f": {
2261 | "model_module": "@jupyter-widgets/base",
2262 | "model_name": "LayoutModel",
2263 | "state": {
2264 | "_model_module": "@jupyter-widgets/base",
2265 | "_model_module_version": "1.2.0",
2266 | "_model_name": "LayoutModel",
2267 | "_view_count": null,
2268 | "_view_module": "@jupyter-widgets/base",
2269 | "_view_module_version": "1.2.0",
2270 | "_view_name": "LayoutView",
2271 | "align_content": null,
2272 | "align_items": null,
2273 | "align_self": null,
2274 | "border": null,
2275 | "bottom": null,
2276 | "display": null,
2277 | "flex": null,
2278 | "flex_flow": null,
2279 | "grid_area": null,
2280 | "grid_auto_columns": null,
2281 | "grid_auto_flow": null,
2282 | "grid_auto_rows": null,
2283 | "grid_column": null,
2284 | "grid_gap": null,
2285 | "grid_row": null,
2286 | "grid_template_areas": null,
2287 | "grid_template_columns": null,
2288 | "grid_template_rows": null,
2289 | "height": null,
2290 | "justify_content": null,
2291 | "justify_items": null,
2292 | "left": null,
2293 | "margin": null,
2294 | "max_height": null,
2295 | "max_width": null,
2296 | "min_height": null,
2297 | "min_width": null,
2298 | "object_fit": null,
2299 | "object_position": null,
2300 | "order": null,
2301 | "overflow": null,
2302 | "overflow_x": null,
2303 | "overflow_y": null,
2304 | "padding": null,
2305 | "right": null,
2306 | "top": null,
2307 | "visibility": null,
2308 | "width": null
2309 | }
2310 | },
2311 | "de252cd193114c40ad5f5e9622b7abc7": {
2312 | "model_module": "@jupyter-widgets/controls",
2313 | "model_name": "DescriptionStyleModel",
2314 | "state": {
2315 | "_model_module": "@jupyter-widgets/controls",
2316 | "_model_module_version": "1.5.0",
2317 | "_model_name": "DescriptionStyleModel",
2318 | "_view_count": null,
2319 | "_view_module": "@jupyter-widgets/base",
2320 | "_view_module_version": "1.2.0",
2321 | "_view_name": "StyleView",
2322 | "description_width": ""
2323 | }
2324 | },
2325 | "e1fbe239c2394cbf973ac5b95e1e1491": {
2326 | "model_module": "@jupyter-widgets/controls",
2327 | "model_name": "ProgressStyleModel",
2328 | "state": {
2329 | "_model_module": "@jupyter-widgets/controls",
2330 | "_model_module_version": "1.5.0",
2331 | "_model_name": "ProgressStyleModel",
2332 | "_view_count": null,
2333 | "_view_module": "@jupyter-widgets/base",
2334 | "_view_module_version": "1.2.0",
2335 | "_view_name": "StyleView",
2336 | "bar_color": null,
2337 | "description_width": "initial"
2338 | }
2339 | },
2340 | "e38fb98fd7b3413392dc39c93a107a35": {
2341 | "model_module": "@jupyter-widgets/controls",
2342 | "model_name": "FloatProgressModel",
2343 | "state": {
2344 | "_dom_classes": [],
2345 | "_model_module": "@jupyter-widgets/controls",
2346 | "_model_module_version": "1.5.0",
2347 | "_model_name": "FloatProgressModel",
2348 | "_view_count": null,
2349 | "_view_module": "@jupyter-widgets/controls",
2350 | "_view_module_version": "1.5.0",
2351 | "_view_name": "ProgressView",
2352 | "bar_style": "success",
2353 | "description": "Iteration: 100%",
2354 | "description_tooltip": null,
2355 | "layout": "IPY_MODEL_43fdb31d3f314624ba07a15718b0c8f3",
2356 | "max": 195,
2357 | "min": 0,
2358 | "orientation": "horizontal",
2359 | "style": "IPY_MODEL_4699416338ae40a5b6abf19e45089aec",
2360 | "value": 195
2361 | }
2362 | },
2363 | "e7b9f3fc77a24259a87ef0dc735dfecb": {
2364 | "model_module": "@jupyter-widgets/base",
2365 | "model_name": "LayoutModel",
2366 | "state": {
2367 | "_model_module": "@jupyter-widgets/base",
2368 | "_model_module_version": "1.2.0",
2369 | "_model_name": "LayoutModel",
2370 | "_view_count": null,
2371 | "_view_module": "@jupyter-widgets/base",
2372 | "_view_module_version": "1.2.0",
2373 | "_view_name": "LayoutView",
2374 | "align_content": null,
2375 | "align_items": null,
2376 | "align_self": null,
2377 | "border": null,
2378 | "bottom": null,
2379 | "display": null,
2380 | "flex": null,
2381 | "flex_flow": null,
2382 | "grid_area": null,
2383 | "grid_auto_columns": null,
2384 | "grid_auto_flow": null,
2385 | "grid_auto_rows": null,
2386 | "grid_column": null,
2387 | "grid_gap": null,
2388 | "grid_row": null,
2389 | "grid_template_areas": null,
2390 | "grid_template_columns": null,
2391 | "grid_template_rows": null,
2392 | "height": null,
2393 | "justify_content": null,
2394 | "justify_items": null,
2395 | "left": null,
2396 | "margin": null,
2397 | "max_height": null,
2398 | "max_width": null,
2399 | "min_height": null,
2400 | "min_width": null,
2401 | "object_fit": null,
2402 | "object_position": null,
2403 | "order": null,
2404 | "overflow": null,
2405 | "overflow_x": null,
2406 | "overflow_y": null,
2407 | "padding": null,
2408 | "right": null,
2409 | "top": null,
2410 | "visibility": null,
2411 | "width": null
2412 | }
2413 | },
2414 | "ea6b919964d24c2f9de1c64c9cefaf23": {
2415 | "model_module": "@jupyter-widgets/base",
2416 | "model_name": "LayoutModel",
2417 | "state": {
2418 | "_model_module": "@jupyter-widgets/base",
2419 | "_model_module_version": "1.2.0",
2420 | "_model_name": "LayoutModel",
2421 | "_view_count": null,
2422 | "_view_module": "@jupyter-widgets/base",
2423 | "_view_module_version": "1.2.0",
2424 | "_view_name": "LayoutView",
2425 | "align_content": null,
2426 | "align_items": null,
2427 | "align_self": null,
2428 | "border": null,
2429 | "bottom": null,
2430 | "display": null,
2431 | "flex": null,
2432 | "flex_flow": null,
2433 | "grid_area": null,
2434 | "grid_auto_columns": null,
2435 | "grid_auto_flow": null,
2436 | "grid_auto_rows": null,
2437 | "grid_column": null,
2438 | "grid_gap": null,
2439 | "grid_row": null,
2440 | "grid_template_areas": null,
2441 | "grid_template_columns": null,
2442 | "grid_template_rows": null,
2443 | "height": null,
2444 | "justify_content": null,
2445 | "justify_items": null,
2446 | "left": null,
2447 | "margin": null,
2448 | "max_height": null,
2449 | "max_width": null,
2450 | "min_height": null,
2451 | "min_width": null,
2452 | "object_fit": null,
2453 | "object_position": null,
2454 | "order": null,
2455 | "overflow": null,
2456 | "overflow_x": null,
2457 | "overflow_y": null,
2458 | "padding": null,
2459 | "right": null,
2460 | "top": null,
2461 | "visibility": null,
2462 | "width": null
2463 | }
2464 | },
2465 | "f3bf54733c2d4d9daa1cc9a7746ccb14": {
2466 | "model_module": "@jupyter-widgets/controls",
2467 | "model_name": "FloatProgressModel",
2468 | "state": {
2469 | "_dom_classes": [],
2470 | "_model_module": "@jupyter-widgets/controls",
2471 | "_model_module_version": "1.5.0",
2472 | "_model_name": "FloatProgressModel",
2473 | "_view_count": null,
2474 | "_view_module": "@jupyter-widgets/controls",
2475 | "_view_module_version": "1.5.0",
2476 | "_view_name": "ProgressView",
2477 | "bar_style": "success",
2478 | "description": "Epoch: 100%",
2479 | "description_tooltip": null,
2480 | "layout": "IPY_MODEL_450b0e7fd7a347c7beb78b7d72f64385",
2481 | "max": 4,
2482 | "min": 0,
2483 | "orientation": "horizontal",
2484 | "style": "IPY_MODEL_021b771a270f479aa3b9e2b5f17e3d97",
2485 | "value": 4
2486 | }
2487 | },
2488 | "f871b83632974e0088bae65e78efaf28": {
2489 | "model_module": "@jupyter-widgets/base",
2490 | "model_name": "LayoutModel",
2491 | "state": {
2492 | "_model_module": "@jupyter-widgets/base",
2493 | "_model_module_version": "1.2.0",
2494 | "_model_name": "LayoutModel",
2495 | "_view_count": null,
2496 | "_view_module": "@jupyter-widgets/base",
2497 | "_view_module_version": "1.2.0",
2498 | "_view_name": "LayoutView",
2499 | "align_content": null,
2500 | "align_items": null,
2501 | "align_self": null,
2502 | "border": null,
2503 | "bottom": null,
2504 | "display": null,
2505 | "flex": null,
2506 | "flex_flow": null,
2507 | "grid_area": null,
2508 | "grid_auto_columns": null,
2509 | "grid_auto_flow": null,
2510 | "grid_auto_rows": null,
2511 | "grid_column": null,
2512 | "grid_gap": null,
2513 | "grid_row": null,
2514 | "grid_template_areas": null,
2515 | "grid_template_columns": null,
2516 | "grid_template_rows": null,
2517 | "height": null,
2518 | "justify_content": null,
2519 | "justify_items": null,
2520 | "left": null,
2521 | "margin": null,
2522 | "max_height": null,
2523 | "max_width": null,
2524 | "min_height": null,
2525 | "min_width": null,
2526 | "object_fit": null,
2527 | "object_position": null,
2528 | "order": null,
2529 | "overflow": null,
2530 | "overflow_x": null,
2531 | "overflow_y": null,
2532 | "padding": null,
2533 | "right": null,
2534 | "top": null,
2535 | "visibility": null,
2536 | "width": null
2537 | }
2538 | },
2539 | "fdffb26b99c24c978580f1cf97359fea": {
2540 | "model_module": "@jupyter-widgets/controls",
2541 | "model_name": "DescriptionStyleModel",
2542 | "state": {
2543 | "_model_module": "@jupyter-widgets/controls",
2544 | "_model_module_version": "1.5.0",
2545 | "_model_name": "DescriptionStyleModel",
2546 | "_view_count": null,
2547 | "_view_module": "@jupyter-widgets/base",
2548 | "_view_module_version": "1.2.0",
2549 | "_view_name": "StyleView",
2550 | "description_width": ""
2551 | }
2552 | }
2553 | }
2554 | }
2555 | },
2556 | "nbformat": 4,
2557 | "nbformat_minor": 1
2558 | }
2559 |
--------------------------------------------------------------------------------