├── CNN_MNIST.ipynb
├── Clustering.ipynb
├── GAN.ipynb
├── GAN_losses.ipynb
├── GloVe.ipynb
├── MLP_(scikit_learn).ipynb
├── MLP_(scikit_learn)_solution.ipynb
├── MLP_MNIST.ipynb
├── MLP_autoencoder.ipynb
├── MLP_regression.ipynb
├── README.md
├── SSD.ipynb
├── SVM (scikit-learn).ipynb
├── SVM_(scikit_learn)_solution.ipynb
├── dataset.ipynb
├── kNN_IRIS.ipynb
├── kNN_IRIS_solution.ipynb
├── linear_classifier.ipynb
├── linear_classifier_solution.ipynb
├── linear_regression.ipynb
├── linear_regression_solution.ipynb
├── python_core.ipynb
├── python_tutorial.ipynb
├── python_tutorial_solution.ipynb
├── sequence_prediction (LSTM per seq).ipynb
├── sequence_prediction (LSTM per time).ipynb
├── sequence_prediction (RNN).ipynb
├── skip-gram.ipynb
├── transfer_learning_(ants_and_bees).ipynb
└── tree_and_ensemble.ipynb
/CNN_MNIST.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "CNN_MNIST.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "collapsed_sections": [],
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | ""
27 | ]
28 | },
29 | {
30 | "metadata": {
31 | "id": "W7zsPERtr0LC",
32 | "colab_type": "text"
33 | },
34 | "cell_type": "markdown",
35 | "source": [
36 | "# Convolutional Neural Nets (CNN)"
37 | ]
38 | },
39 | {
40 | "metadata": {
41 | "id": "9iL5KH63r0LE",
42 | "colab_type": "code",
43 | "colab": {}
44 | },
45 | "cell_type": "code",
46 | "source": [
47 | "# If necessary, uncomment and run the following line to install PyTorch\n",
48 | "#!pip install torch torchvision"
49 | ],
50 | "execution_count": 0,
51 | "outputs": []
52 | },
53 | {
54 | "metadata": {
55 | "id": "dEBe8KzEr0LJ",
56 | "colab_type": "text"
57 | },
58 | "cell_type": "markdown",
59 | "source": [
60 | "In the past, Google Colab did not ship with PyTorch, so it had to be installed before it could be used in Colab.\n",
61 | "\n",
62 | "This is no longer necessary."
63 | ]
64 | },
65 | {
66 | "metadata": {
67 | "id": "YbUmP-VGVvcP",
68 | "colab_type": "code",
69 | "colab": {}
70 | },
71 | "cell_type": "code",
72 | "source": [
73 | "import numpy as np\n",
74 | "import datetime"
75 | ],
76 | "execution_count": 0,
77 | "outputs": []
78 | },
79 | {
80 | "metadata": {
81 | "id": "QoPiqjikr0LT",
82 | "colab_type": "text"
83 | },
84 | "cell_type": "markdown",
85 | "source": [
86 | "**numpy**: multidimensional arrays and basic vector/matrix operations;\n",
87 | "one of the most fundamental libraries for data science in Python.\n",
88 | "\n",
89 | "**datetime**: a package for measuring training/execution time"
90 | ]
91 | },
92 | {
93 | "metadata": {
94 | "id": "kjpK8daar0LK",
95 | "colab_type": "code",
96 | "outputId": "75c43b76-4d24-4ade-d0a6-2e0f5555410b",
97 | "colab": {
98 | "base_uri": "https://localhost:8080/",
99 | "height": 35
100 | }
101 | },
102 | "cell_type": "code",
103 | "source": [
104 | "import torch\n",
105 | "import torchvision \n",
106 | "import torch.nn as nn\n",
107 | "\n",
108 | "print(torch.__version__)"
109 | ],
110 | "execution_count": 4,
111 | "outputs": [
112 | {
113 | "output_type": "stream",
114 | "text": [
115 | "1.0.1.post2\n"
116 | ],
117 | "name": "stdout"
118 | }
119 | ]
120 | },
121 | {
122 | "metadata": {
123 | "id": "LhNkRvYkr0LN",
124 | "colab_type": "text"
125 | },
126 | "cell_type": "markdown",
127 | "source": [
128 | "**torch**: pytorch package\n",
129 | "\n",
130 | "**torch.nn**: contains the classes for building neural network models\n",
131 | "\n",
132 | "**torchvision**: contains datasets, models, and transforms commonly used in computer vision (https://pytorch.org/docs/stable/torchvision/index.html)"
133 | ]
134 | },
135 | {
136 | "metadata": {
137 | "id": "ntWD5II587og",
138 | "colab_type": "code",
139 | "colab": {}
140 | },
141 | "cell_type": "code",
142 | "source": [
143 | "from torch.utils.data import DataLoader"
144 | ],
145 | "execution_count": 0,
146 | "outputs": []
147 | },
148 | {
149 | "metadata": {
150 | "id": "YgsGVrkrAEPD",
151 | "colab_type": "text"
152 | },
153 | "cell_type": "markdown",
154 | "source": [
155 | "**DataLoader**: the class that loads data for training (it combines a Dataset, a Sampler, and an iterator)\n",
156 | "> * Dataset is an abstract class representing a dataset \n",
157 | "> * Sampler provides a way to iterate over indices of dataset elements\n",
158 | "\n",
159 | "\n",
160 | "See https://pytorch.org/docs/stable/data.html"
161 | ]
162 | },
163 | {
164 | "metadata": {
165 | "id": "_tUvuvvPr0LO",
166 | "colab_type": "code",
167 | "colab": {}
168 | },
169 | "cell_type": "code",
170 | "source": [
171 | "from torchvision import datasets\n",
172 | "from torchvision import transforms"
173 | ],
174 | "execution_count": 0,
175 | "outputs": []
176 | },
177 | {
178 | "metadata": {
179 | "id": "Odn-STAQAmKq",
180 | "colab_type": "text"
181 | },
182 | "cell_type": "markdown",
183 | "source": [
184 | "**datasets**: MNIST, Fashion-MNIST, COCO, LSUN, CIFAR, etc.\n",
185 | "\n",
186 | "**transforms**: algorithms for preprocessing or data augmentation\n",
187 | "\n",
188 | "See https://pytorch.org/docs/stable/torchvision/index.html for the datasets and transforms available in torchvision"
189 | ]
190 | },
191 | {
192 | "metadata": {
193 | "id": "mtKXCay2kN5Z",
194 | "colab_type": "text"
195 | },
196 | "cell_type": "markdown",
197 | "source": [
198 | "# Using (deep) neural networks with python\n",
199 | "\n",
200 | "\n",
201 | "1. Define a network model\n",
202 | "\n",
203 | "2. Prepare data\n",
204 | "\n",
205 | "3. Train the model\n",
206 | "\n",
207 | "4. Evaluate the model"
208 | ]
209 | },
210 | {
211 | "metadata": {
212 | "id": "eCc7O8M5r0LT",
213 | "colab_type": "code",
214 | "colab": {}
215 | },
216 | "cell_type": "code",
217 | "source": [
218 | "#import matplotlib\n",
219 | "\n",
220 | "%matplotlib inline\n",
221 | "from matplotlib.pyplot import imshow, imsave"
222 | ],
223 | "execution_count": 0,
224 | "outputs": []
225 | },
226 | {
227 | "metadata": {
228 | "id": "qgaevoCUr0LW",
229 | "colab_type": "text"
230 | },
231 | "cell_type": "markdown",
232 | "source": [
233 | "**matplotlib**: python visualization library"
234 | ]
235 | },
236 | {
237 | "metadata": {
238 | "id": "hYk9uq_Br0LW",
239 | "colab_type": "code",
240 | "outputId": "c4a75a2f-b82b-4ad4-bc0a-99fa2a486b37",
241 | "colab": {
242 | "base_uri": "https://localhost:8080/",
243 | "height": 35
244 | }
245 | },
246 | "cell_type": "code",
247 | "source": [
248 | "MODEL_NAME = 'DNN'\n",
249 | "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
250 | "print(\"MODEL_NAME = {}, DEVICE = {}\".format(MODEL_NAME, DEVICE))"
251 | ],
252 | "execution_count": 8,
253 | "outputs": [
254 | {
255 | "output_type": "stream",
256 | "text": [
257 | "MODEL_NAME = DNN, DEVICE = cuda\n"
258 | ],
259 | "name": "stdout"
260 | }
261 | ]
262 | },
263 | {
264 | "metadata": {
265 | "id": "AhUqWS9-r0La",
266 | "colab_type": "text"
267 | },
268 | "cell_type": "markdown",
269 | "source": [
270 | "If a GPU is available, it is used to accelerate training; otherwise training runs on the CPU. DEVICE is set accordingly.\n",
271 | "\n",
272 | "**torch.cuda.is_available()** returns whether a CUDA-capable GPU is available"
273 | ]
274 | },
275 | {
276 | "metadata": {
277 | "id": "8ljFCGEIXfCy",
278 | "colab_type": "text"
279 | },
280 | "cell_type": "markdown",
281 | "source": [
282 | "## Defining a Neural Network model using pytorch\n",
283 | "\n",
284 | "1. Define a neural net model\n",
285 | "\n",
286 | "> * Define a model class inheriting from **nn.Module**\n",
287 | "\n",
288 | ">> nn.Module is the base class of all layers/operators\n",
289 | "\n",
290 | ">* Define the **`__init__`** function (constructor)\n",
291 | "\n",
292 | ">> Create layers and operators\n",
293 | "\n",
294 | ">* Define the **forward** function (forward propagation)\n",
295 | "\n",
296 | ">> Define how to compute the output from the input\n",
297 | "\n",
298 | "> Example\n",
299 | "\n",
300 | "~~~~\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
301 | "class Model(nn.Module):\n",
302 | "    def __init__(self):\n",
303 | "        super(Model, self).__init__()\n",
304 | "        self.conv1 = nn.Conv2d(1, 20, 5)\n",
305 | "        self.conv2 = nn.Conv2d(20, 20, 5)\n",
306 | "\n",
307 | "    def forward(self, x):\n",
308 | "        x = F.relu(self.conv1(x))\n",
309 | "        return F.relu(self.conv2(x))\n",
310 | "~~~~\n",
311 | "\n",
312 | "> Note! You don't need to implement backpropagation yourself, because PyTorch computes gradients automatically with **autograd**"
313 | ]
314 | },
315 | {
316 | "metadata": {
317 | "id": "sbpvoKXzr0Lb",
318 | "colab_type": "code",
319 | "colab": {}
320 | },
321 | "cell_type": "code",
322 | "source": [
323 | "class HelloCNN(nn.Module):\n",
324 | "    \"\"\"\n",
325 | "    Simple CNN classifier\n",
326 | "    \"\"\"\n",
327 | "    def __init__(self, num_classes=10):\n",
328 | "        super(HelloCNN, self).__init__()\n",
329 | "        \n",
330 | "        self.conv = nn.Sequential(\n",
331 | "            # (N, 1, 28, 28)\n",
332 | "            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),\n",
333 | "            nn.ReLU(),\n",
334 | "            nn.MaxPool2d(2, 2),\n",
335 | "            # (N, 32, 14, 14)\n",
336 | "            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),\n",
337 | "            nn.ReLU(),\n",
338 | "            nn.MaxPool2d(2, 2),\n",
339 | "            # (N, 64, 7, 7)\n",
340 | "        )\n",
341 | "        self.fc = nn.Sequential(\n",
342 | "            nn.Linear(7*7*64, 512),\n",
343 | "            nn.Dropout(p=0.5),\n",
344 | "            nn.Linear(512, num_classes),\n",
345 | "        )\n",
346 | "        \n",
347 | "    def forward(self, x):\n",
348 | "        y_ = self.conv(x)  # (N, 64, 7, 7)\n",
349 | "        y_ = y_.view(y_.size(0), -1)  # (N, 64*7*7)\n",
350 | "        y_ = self.fc(y_)  # (N, num_classes), raw scores (logits)\n",
351 | "        return y_"
352 | ],
353 | "execution_count": 0,
354 | "outputs": []
355 | },
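The shape comments in `HelloCNN` follow from the standard output-size formula for convolution and pooling layers, out = floor((in + 2*padding - kernel) / stride) + 1. A small pure-Python sketch (the helper name `out_size` is our own, not a PyTorch function; it assumes dilation 1 and floor rounding) traces the 28x28 MNIST input through both conv blocks and recovers the 7*7*64 input size of the first `nn.Linear`:

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d/MaxPool2d layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

size, channels = 28, 1                      # MNIST input: (N, 1, 28, 28)
for out_channels in (32, 64):               # the two conv blocks in HelloCNN
    size = out_size(size, kernel=3, stride=1, padding=1)   # Conv2d keeps the spatial size
    channels = out_channels
    size = out_size(size, kernel=2, stride=2)              # MaxPool2d(2, 2) halves it
    print((channels, size, size))           # (32, 14, 14), then (64, 7, 7)

flat_features = channels * size * size
print(flat_features)                        # 3136 == 7*7*64
```

Because `padding=1` exactly offsets `kernel_size=3`, only the pooling layers change the spatial size, which is why the flattened feature count depends only on the number of pooling stages.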
356 | {
357 | "metadata": {
358 | "id": "BfM4wnwdr0Le",
359 | "colab_type": "text"
360 | },
361 | "cell_type": "markdown",
362 | "source": [
363 | "**nn.Sequential()**: a sequential container; its modules are applied in the order they are passed in.\n",
364 | "\n",
365 | "* Example of using Sequential\n",
366 | "\n",
367 | "~~~~\n",
368 | "model = nn.Sequential(\n",
369 | "    nn.Conv2d(1, 20, 5),\n",
370 | "    nn.ReLU(),\n",
371 | "    nn.Conv2d(20, 64, 5),\n",
372 | "    nn.ReLU()\n",
373 | ")\n",
374 | "~~~~\n",
375 | "* Example of using Sequential with OrderedDict\n",
376 | "\n",
377 | "~~~~\n",
"from collections import OrderedDict\n",
"\n",
378 | "model = nn.Sequential(OrderedDict([\n",
379 | "    ('conv1', nn.Conv2d(1, 20, 5)),\n",
380 | "    ('relu1', nn.ReLU()),\n",
381 | "    ('conv2', nn.Conv2d(20, 64, 5)),\n",
382 | "    ('relu2', nn.ReLU())\n",
383 | "]))\n",
384 | "~~~~\n",
385 | "\n",
386 | "**nn.ModuleList**: a list-like container; unlike a plain Python list, it registers its modules so their parameters are visible to the parent module\n",
387 | "\n",
388 | "* Example of using ModuleList\n",
389 | "\n",
390 | "~~~~\n",
391 | "class MyModule(nn.Module):\n",
392 | "    def __init__(self):\n",
393 | "        super(MyModule, self).__init__()\n",
394 | "        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])\n",
395 | "\n",
396 | "    def forward(self, x):\n",
397 | "        # ModuleList can act as an iterable, or be indexed using ints\n",
398 | "        for i, l in enumerate(self.linears):\n",
399 | "            x = self.linears[i // 2](x) + l(x)\n",
400 | "        return x\n",
401 | "\n",
402 | "~~~~"
403 | ]
404 | },
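Conceptually, `nn.Sequential` just chains its modules, feeding each output into the next. A toy pure-Python sketch of that idea (the `MiniSequential` class and the three layer functions are hypothetical, for illustration only; the real `nn.Sequential` additionally registers its submodules so their parameters are tracked):

```python
class MiniSequential:
    """Toy stand-in for nn.Sequential: calls each layer in order."""

    def __init__(self, *layers):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:   # output of one layer is the input of the next
            x = layer(x)
        return x


def double(v):
    return [2.0 * e for e in v]

def shift(v):
    return [e - 1.0 for e in v]

def relu(v):
    return [max(0.0, e) for e in v]

model = MiniSequential(double, shift, relu)
print(model([0.0, 0.25, 1.0]))      # [0.0, 0.0, 1.0]
```

The order matters: `MiniSequential(relu, double, shift)` would give a different result, just as swapping layers inside `nn.Sequential` changes the network.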
405 | {
406 | "metadata": {
407 | "id": "fyP-DP1Sr0Lh",
408 | "colab_type": "code",
409 | "colab": {}
410 | },
411 | "cell_type": "code",
412 | "source": [
413 | "model = HelloCNN().to(DEVICE)"
414 | ],
415 | "execution_count": 0,
416 | "outputs": []
417 | },
418 | {
419 | "metadata": {
420 | "id": "cyyU-712c1Np",
421 | "colab_type": "text"
422 | },
423 | "cell_type": "markdown",
424 | "source": [
425 | "`.to(DEVICE)` moves (and, if needed, casts) the model's parameters and buffers to the selected device (CPU or GPU)."
426 | ]
427 | },
428 | {
429 | "metadata": {
430 | "id": "Z9-QvNvbksQj",
431 | "colab_type": "text"
432 | },
433 | "cell_type": "markdown",
434 | "source": [
435 | "## Loading and preprocessing of data\n",
436 | "\n"
437 | ]
438 | },
439 | {
440 | "metadata": {
441 | "id": "Glhmh_OHlBMY",
442 | "colab_type": "text"
443 | },
444 | "cell_type": "markdown",
445 | "source": [
446 | "Transform of input data"
447 | ]
448 | },
449 | {
450 | "metadata": {
451 | "id": "oq2EYzdBr0Lk",
452 | "colab_type": "code",
453 | "colab": {}
454 | },
455 | "cell_type": "code",
456 | "source": [
457 | "transform = transforms.Compose(\n",
458 | " [transforms.ToTensor(), # image to tensor\n",
459 | " transforms.Normalize(mean=(0.1307,), std=(0.3081,)) # normalize to \"(x-mean)/std\"\n",
460 | " ])"
461 | ],
462 | "execution_count": 0,
463 | "outputs": []
464 | },
465 | {
466 | "metadata": {
467 | "id": "Vbb9TUIOz3tb",
468 | "colab_type": "text"
469 | },
470 | "cell_type": "markdown",
471 | "source": [
472 | ""
473 | ]
474 | },
475 | {
476 | "metadata": {
477 | "id": "2hE6ydS3r0Lm",
478 | "colab_type": "text"
479 | },
480 | "cell_type": "markdown",
481 | "source": [
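482 | "**transforms**: the package containing the transform functions provided by torchvision.\n",
483 | "\n",
484 | "**ToTensor**: converts a PIL image or uint8 numpy array (H x W x C, values 0 to 255) to a float torch tensor (C x H x W) scaled to [0, 1].\n",
485 | "\n",
486 | "**Normalize**: normalizes each channel, output[channel] = (input[channel] - mean[channel]) / std[channel]; 0.1307 and 0.3081 are the mean and standard deviation of the MNIST training images"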
487 | ]
488 | },
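A NumPy sketch of what the two transforms compute for a single-channel image. The helper names `to_tensor_like` and `normalize_like` are hypothetical; the real `ToTensor` also accepts PIL images and returns a torch tensor, and `Normalize` operates per channel on tensors:

```python
import numpy as np

MEAN, STD = 0.1307, 0.3081   # MNIST statistics used in the transform above

def to_tensor_like(img_uint8):
    """Mimic ToTensor on an (H, W, C) uint8 array: channels first, scaled to [0, 1]."""
    return img_uint8.astype(np.float32).transpose(2, 0, 1) / 255.0

def normalize_like(x, mean, std):
    """Mimic Normalize for a single channel: (x - mean) / std."""
    return (x - mean) / std

img = np.array([[0, 255], [128, 64]], dtype=np.uint8)[:, :, None]  # tiny 2x2 grayscale image, (H, W, 1)
x = to_tensor_like(img)             # shape (1, 2, 2), values in [0, 1]
z = normalize_like(x, MEAN, STD)
print(x.shape)                      # (1, 2, 2)
print(z[0, 0])                      # black pixel is about -0.4242, white pixel about 2.8215
```

After normalization the pixel values are roughly zero-centred with unit spread over the training set, which tends to make optimization better behaved than raw 0 to 255 intensities.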
489 | {
490 | "metadata": {
491 | "id": "buibF906r0Lm",
492 | "colab_type": "code",
493 | "outputId": "7179a4de-3cd8-4d6d-ff84-214d02ddc810",
494 | "colab": {
495 | "base_uri": "https://localhost:8080/",
496 | "height": 289
497 | }
498 | },
499 | "cell_type": "code",
500 | "source": [
501 | "mnist_train = datasets.MNIST(root='../data/', train=True, transform=transform, download=True)\n",
502 | "mnist_test = datasets.MNIST(root='../data/', train=False, transform=transform, download=True)"
503 | ],
504 | "execution_count": 12,
505 | "outputs": [
506 | {
507 | "output_type": "stream",
508 | "text": [
509 | "\r0it [00:00, ?it/s]"
510 | ],
511 | "name": "stderr"
512 | },
513 | {
514 | "output_type": "stream",
515 | "text": [
516 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz\n"
517 | ],
518 | "name": "stdout"
519 | },
520 | {
521 | "output_type": "stream",
522 | "text": [
523 | "9920512it [00:01, 9240944.23it/s] \n"
524 | ],
525 | "name": "stderr"
526 | },
527 | {
528 | "output_type": "stream",
529 | "text": [
530 | "Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz\n"
531 | ],
532 | "name": "stdout"
533 | },
534 | {
535 | "output_type": "stream",
536 | "text": [
537 | " 0%| | 0/28881 [00:00, ?it/s]"
538 | ],
539 | "name": "stderr"
540 | },
541 | {
542 | "output_type": "stream",
543 | "text": [
544 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
545 | ],
546 | "name": "stdout"
547 | },
548 | {
549 | "output_type": "stream",
550 | "text": [
551 | "32768it [00:00, 136930.80it/s] \n",
552 | " 0%| | 0/1648877 [00:00, ?it/s]"
553 | ],
554 | "name": "stderr"
555 | },
556 | {
557 | "output_type": "stream",
558 | "text": [
559 | "Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz\n",
560 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
561 | ],
562 | "name": "stdout"
563 | },
564 | {
565 | "output_type": "stream",
566 | "text": [
567 | "1654784it [00:00, 2248538.48it/s] \n",
568 | "0it [00:00, ?it/s]"
569 | ],
570 | "name": "stderr"
571 | },
572 | {
573 | "output_type": "stream",
574 | "text": [
575 | "Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz\n",
576 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
577 | ],
578 | "name": "stdout"
579 | },
580 | {
581 | "output_type": "stream",
582 | "text": [
583 | "8192it [00:00, 52056.11it/s] \n"
584 | ],
585 | "name": "stderr"
586 | },
587 | {
588 | "output_type": "stream",
589 | "text": [
590 | "Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n",
591 | "Processing...\n",
592 | "Done!\n"
593 | ],
594 | "name": "stdout"
595 | }
596 | ]
597 | },
598 | {
599 | "metadata": {
600 | "id": "lYL_NEavr0Lo",
601 | "colab_type": "text"
602 | },
603 | "cell_type": "markdown",
604 | "source": [
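605 | "**datasets** has built-in classes that download and process many common datasets. [Reference](https://pytorch.org/docs/stable/torchvision/datasets.html)\n",
606 | "\n",
607 | "With download=True, the data is downloaded if it is not already in the root folder; each sample is returned preprocessed by the transform defined above."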
608 | ]
609 | },
610 | {
611 | "metadata": {
612 | "id": "wGbeeCkCr0Lp",
613 | "colab_type": "code",
614 | "colab": {}
615 | },
616 | "cell_type": "code",
617 | "source": [
618 | "batch_size = 64"
619 | ],
620 | "execution_count": 0,
621 | "outputs": []
622 | },
623 | {
624 | "metadata": {
625 | "id": "6_n0AFCmr0Lr",
626 | "colab_type": "code",
627 | "colab": {}
628 | },
629 | "cell_type": "code",
630 | "source": [
631 | "train_loader = DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)\n",
632 | "test_loader = DataLoader(dataset=mnist_test, batch_size=100, shuffle=False, drop_last=False)"
633 | ],
634 | "execution_count": 0,
635 | "outputs": []
636 | },
637 | {
638 | "metadata": {
639 | "id": "GR9GEof0r0Ls",
640 | "colab_type": "text"
641 | },
642 | "cell_type": "markdown",
643 | "source": [
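644 | "**DataLoader**: a PyTorch class that efficiently loads the data in batches of batch_size during training. Used well, it keeps the GPU busy and raises GPU utilization.\n",
645 | "\n",
646 | "**shuffle**: randomly shuffles the order of the data at every epoch.\n",
647 | "\n",
648 | "**drop_last**: if the number of samples is not divisible by the batch size, the last (incomplete) batch is dropped. Usually used only for training."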
649 | ]
650 | },
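MNIST has 60,000 training and 10,000 test images, so with `batch_size=64` and `drop_last=True` one training epoch yields 937 full batches (the last 32 images are dropped), while the test loader yields 100 batches of 100. This is why step 1000 in the training log already belongs to epoch 1 (1000 // 937 = 1). A pure-Python sketch of the index batching (the function `iterate_batches` is hypothetical and ignores the collation and worker processes that the real `DataLoader` handles):

```python
import random

def iterate_batches(num_samples, batch_size, shuffle=True, drop_last=False, seed=None):
    """Yield lists of sample indices, the way a DataLoader groups sampler indices into batches."""
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)   # reshuffled on every call (every epoch) unless seed is fixed
    for start in range(0, num_samples, batch_size):
        batch = indices[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break                              # discard the final, incomplete batch
        yield batch

train_batches = list(iterate_batches(60000, 64, shuffle=True, drop_last=True, seed=0))
test_batches = list(iterate_batches(10000, 100, shuffle=False, drop_last=False))
print(len(train_batches), len(test_batches))   # 937 100
```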
651 | {
652 | "metadata": {
653 | "id": "YH-t5coUmuOM",
654 | "colab_type": "text"
655 | },
656 | "cell_type": "markdown",
657 | "source": [
658 | "## Training neural network model\n",
659 | "\n",
660 | "\n",
661 | "Training procedure\n",
662 | "~~~~\n",
663 | "for epoch in range(max_epoch):\n",
664 | "    for input, target in dataset:  # retrieve input data and target labels\n",
665 | "        optimizer.zero_grad()  # reset gradient\n",
666 | "        output = model(input)  # forward propagation\n",
667 | "        loss = loss_fn(output, target)  # get loss value\n",
668 | "        loss.backward()  # back-propagation (compute gradient)\n",
"        optimizer.step()  # update parameters with gradient\n",
669 | "~~~~"
670 | ]
671 | },
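The same five steps can be seen without PyTorch in a toy pure-Python example (our own illustration, not part of the notebook): a single parameter `w` is fitted to hypothetical data generated by y = 2x, and the derivative that `loss.backward()` would compute is written out by hand, since that is exactly the part autograd automates.

```python
data = [(x, 2.0 * x) for x in (0.5, 1.0, 1.5, 2.0)]   # hypothetical toy data, true w = 2
w = 0.0                                                 # the single model parameter
lr = 0.1                                                # learning rate
max_epoch = 50

for epoch in range(max_epoch):
    for x, target in data:
        grad = 0.0                             # like optimizer.zero_grad(): reset the gradient
        output = w * x                         # forward propagation
        loss = (output - target) ** 2          # squared-error loss
        grad += 2.0 * (output - target) * x    # like loss.backward(): d(loss)/dw by hand
        w -= lr * grad                         # like optimizer.step(): gradient-descent update

print(round(w, 4))                             # 2.0
```

Resetting `grad` before each step matters for the same reason as `zero_grad()` in PyTorch: gradients are accumulated with `+=`, so stale values would otherwise leak into the next update.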
672 | {
673 | "metadata": {
674 | "id": "eeNx2EZKFoBP",
675 | "colab_type": "code",
676 | "colab": {}
677 | },
678 | "cell_type": "code",
679 | "source": [
680 | "# utility function to measure time\n",
681 | "\n",
682 | "import time\n",
683 | "import math\n",
684 | "\n",
685 | "def timeSince(since):\n",
686 | "    now = time.time()\n",
687 | "    s = now - since\n",
688 | "    m = math.floor(s / 60)\n",
689 | "    s -= m * 60\n",
690 | "    return '%dm %ds' % (m, s)"
691 | ],
692 | "execution_count": 0,
693 | "outputs": []
694 | },
695 | {
696 | "metadata": {
697 | "id": "pEt4xYD4r0Lt",
698 | "colab_type": "code",
699 | "colab": {}
700 | },
701 | "cell_type": "code",
702 | "source": [
703 | "# set loss function and optimizer\n",
704 | "\n",
705 | "criterion = nn.CrossEntropyLoss()\n",
706 | "optim = torch.optim.Adam(model.parameters(), lr=0.001)"
707 | ],
708 | "execution_count": 0,
709 | "outputs": []
710 | },
711 | {
712 | "metadata": {
713 | "id": "waSS2hrKr0Lv",
714 | "colab_type": "text"
715 | },
716 | "cell_type": "markdown",
717 | "source": [
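718 | "**nn.CrossEntropyLoss**: the cross-entropy loss. It applies log-softmax internally (LogSoftmax + NLLLoss), so the model should output raw scores (logits) without a softmax layer.\n",
719 | "\n",
720 | "**optim.Adam**: torch.optim offers many optimizers; Adam is one of the most widely used."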
721 | ]
722 | },
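A NumPy sketch of the computation `nn.CrossEntropyLoss` performs with its default mean reduction: log-softmax of the raw scores, then the negative log-probability of the correct class, averaged over the batch (the function name `cross_entropy` is our own). With 10 equally scored classes the loss is log(10), about 2.30, which is close to the loss at step 0 of the training log.

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean cross-entropy: log-softmax of raw scores, then negative log-likelihood of the targets."""
    shifted = logits - logits.max(axis=1, keepdims=True)          # subtract row max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

uniform = np.zeros((4, 10))                 # 4 samples, 10 classes, identical scores
labels = np.array([3, 1, 4, 1])
print(cross_entropy(uniform, labels))       # about 2.3026, i.e. log(10)

confident = np.full((1, 10), -5.0)
confident[0, 7] = 5.0                       # large score for the correct class 7
print(cross_entropy(confident, np.array([7])))   # close to 0
```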
723 | {
724 | "metadata": {
725 | "id": "K4lF0ODwr0Lz",
726 | "colab_type": "text"
727 | },
728 | "cell_type": "markdown",
729 | "source": [
730 | "### Training procedure\n",
731 | "\n",
732 | "First for loop: repeats for the desired number of epochs.\n",
733 | "\n",
734 | "Second for loop: draws batches of batch_size from the training dataset until the dataset is exhausted (one epoch).\n",
735 | "\n",
736 | "**enumerate(train_loader)**: the MNIST DataLoader returns batches of images and labels.\n",
737 | "\n",
738 | "**.to(DEVICE)**: moves the images and labels to the device (GPU or CPU).\n",
739 | "\n",
740 | "**model(x)**: feeds the images to the model (forward propagation).\n",
741 | "\n",
742 | "**criterion(y_hat, y)**: computes the loss between the output y_hat and the true labels y.\n",
743 | "\n",
744 | "**zero_grad**: resets the gradients of the model parameters to 0; otherwise PyTorch would accumulate gradients across steps.\n",
745 | "\n",
746 | "**backward**: computes the gradients by following the computation graph that leads to the loss.\n",
747 | "\n",
748 | "**step**: applies the computed gradients to update all parameters.\n",
749 | "\n",
750 | "**eval**: switches the model to evaluation mode (changes the behavior of layers such as dropout and batch normalization).\n",
751 | "\n",
752 | "**torch.no_grad**: disables gradient tracking, which saves memory and computation.\n",
753 | "\n",
754 | "**torch.max**: returns the max values and their indices (i.e., the argmax).\n",
755 | "\n",
756 | "**train**: switches the model from evaluation mode back to training mode"
757 | ]
758 | },
759 | {
760 | "metadata": {
761 | "id": "zKCZDgDTNuY5",
762 | "colab_type": "code",
763 | "colab": {}
764 | },
765 | "cell_type": "code",
766 | "source": [
767 | "# reset loss history\n",
768 | "all_losses = []"
769 | ],
770 | "execution_count": 0,
771 | "outputs": []
772 | },
773 | {
774 | "metadata": {
775 | "scrolled": true,
776 | "id": "-KAxJESir0L0",
777 | "colab_type": "code",
778 | "outputId": "1c5bc7f7-5882-458f-f6ce-421638a6418a",
779 | "colab": {
780 | "base_uri": "https://localhost:8080/",
781 | "height": 467
782 | }
783 | },
784 | "cell_type": "code",
785 | "source": [
786 | "max_epoch = 5  # maximum number of epochs\n",
787 | "step = 0  # initialize step counter variable\n",
788 | "\n",
789 | "plot_every = 200\n",
790 | "total_loss = 0  # reset every plot_every iterations\n",
791 | "\n",
792 | "start = time.time()\n",
793 | "\n",
794 | "for epoch in range(max_epoch):\n",
795 | "    for idx, (images, labels) in enumerate(train_loader):\n",
796 | "        # move the mini-batch to the device\n",
797 | "        x, y = images.to(DEVICE), labels.to(DEVICE)  # (N, 1, 28, 28), (N, )\n",
798 | "        \n",
799 | "        y_hat = model(x)  # (N, 10), forward propagation\n",
800 | "        \n",
801 | "        loss = criterion(y_hat, y)  # computing loss\n",
802 | "        total_loss += loss.item()\n",
803 | "        \n",
804 | "        optim.zero_grad()  # reset gradient\n",
805 | "        loss.backward()  # back-propagation (compute gradient)\n",
806 | "        optim.step()  # update parameters with gradient\n",
807 | "        \n",
808 | "        # periodically print loss\n",
809 | "        if step % 500 == 0:\n",
810 | "            print('Epoch({}): {}/{}, Step: {}, Loss: {}'.format(timeSince(start), epoch, max_epoch, step, loss.item()))\n",
811 | "        \n",
812 | "        if (step + 1) % plot_every == 0:\n",
813 | "            all_losses.append(total_loss / plot_every)\n",
814 | "            total_loss = 0\n",
815 | "        \n",
816 | "        # periodically evaluate model on test data\n",
817 | "        if step % 1000 == 0:\n",
818 | "            model.eval()\n",
819 | "            acc = 0.\n",
820 | "            with torch.no_grad():  # disable autograd\n",
821 | "                for images, labels in test_loader:\n",
822 | "                    x, y = images.to(DEVICE), labels.to(DEVICE)  # (N, 1, 28, 28), (N, )\n",
823 | "                    y_hat = model(x)  # (N, 10)\n",
824 | "                    loss = criterion(y_hat, y)  # loss of the current test batch; the printed value is the last batch's\n",
825 | "                    _, indices = torch.max(y_hat, dim=-1)  # find maximum along the last axis (argmax of each row)\n",
826 | "                    # ex) max_value, max_idx = torch.max(input, dim)\n",
827 | "                    acc += torch.sum(indices == y).item()  # count correctly classified samples\n",
828 | "                    # torch.sum() returns a Tensor; Tensor.item() converts it to a Python number\n",
829 | "            print('*'*20, 'Test', '*'*20)\n",
830 | "            print('Step: {}, Loss: {}, Accuracy: {} %'.format(step, loss.item(), acc/len(mnist_test)*100))\n",
831 | "            print('*'*46)\n",
832 | "            model.train()  # switch back to training mode (re-enables dropout)\n",
833 | "        step += 1"
834 | ],
835 | "execution_count": 18,
836 | "outputs": [
837 | {
838 | "output_type": "stream",
839 | "text": [
840 | "Epoch(0m 0s): 0/5, Step: 0, Loss: 2.2575552463531494\n",
841 | "******************** Test ********************\n",
842 | "Step: 0, Loss: 2.9159839153289795, Accuracy: 9.58 %\n",
843 | "**********************************************\n",
844 | "Epoch(0m 7s): 0/5, Step: 500, Loss: 0.0489889457821846\n",
845 | "Epoch(0m 12s): 1/5, Step: 1000, Loss: 0.0709199458360672\n",
846 | "******************** Test ********************\n",
847 | "Step: 1000, Loss: 0.011706847697496414, Accuracy: 98.59 %\n",
848 | "**********************************************\n",
849 | "Epoch(0m 19s): 1/5, Step: 1500, Loss: 0.08090385049581528\n",
850 | "Epoch(0m 25s): 2/5, Step: 2000, Loss: 0.006270997226238251\n",
851 | "******************** Test ********************\n",
852 | "Step: 2000, Loss: 0.019123367965221405, Accuracy: 98.81 %\n",
853 | "**********************************************\n",
854 | "Epoch(0m 31s): 2/5, Step: 2500, Loss: 0.013506345450878143\n",
855 | "Epoch(0m 36s): 3/5, Step: 3000, Loss: 0.010808899998664856\n",
856 | "******************** Test ********************\n",
857 | "Step: 3000, Loss: 0.016738882288336754, Accuracy: 98.97 %\n",
858 | "**********************************************\n",
859 | "Epoch(0m 42s): 3/5, Step: 3500, Loss: 0.26307404041290283\n",
860 | "Epoch(0m 48s): 4/5, Step: 4000, Loss: 0.1541719138622284\n",
861 | "******************** Test ********************\n",
862 | "Step: 4000, Loss: 0.0011183833703398705, Accuracy: 99.11 %\n",
863 | "**********************************************\n",
864 | "Epoch(0m 55s): 4/5, Step: 4500, Loss: 0.003879919648170471\n"
865 | ],
866 | "name": "stdout"
867 | }
868 | ]
869 | },
870 | {
871 | "metadata": {
872 | "id": "HGfPwCZoN6LV",
873 | "colab_type": "code",
874 | "outputId": "ce67884e-cb84-4a1a-bb1f-82b0afff7996",
875 | "colab": {
876 | "base_uri": "https://localhost:8080/",
877 | "height": 287
878 | }
879 | },
880 | "cell_type": "code",
881 | "source": [
882 | "import matplotlib.pyplot as plt\n",
883 | "import matplotlib.ticker as ticker\n",
884 | "\n",
885 | "plt.figure()\n",
886 | "plt.plot(all_losses)"
887 | ],
888 | "execution_count": 19,
889 | "outputs": [
890 | {
891 | "output_type": "execute_result",
892 | "data": {
893 | "text/plain": [
894 | "[]"
895 | ]
896 | },
897 | "metadata": {
898 | "tags": []
899 | },
900 | "execution_count": 19
901 | },
902 | {
903 | "output_type": "display_data",
904 | "data": {
905 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAH9ZJREFUeJzt3Xt4XPV95/H3VyNpxpJmbN1tLMuS\nL1xsnNggywQSoI0hkG1woGlj0m7cNl1KNmxD026XblpI6dMmJZs2fVrShTRubg3eJJuL93kIhNA0\n3IJtAQbbgEG+W2DrZluyZF3nu3/MsTyWZTS2JY885/N6nnk058w546+G4XN++p3z+x1zd0REJBzy\nsl2AiIicPwp9EZEQUeiLiISIQl9EJEQU+iIiIaLQFxEJEYW+iEiIKPRFREJEoS8iEiL52S5gtIqK\nCq+rq8t2GSIiF5QXXnih3d0rx9tuyoV+XV0dTU1N2S5DROSCYmZ7MtlO3TsiIiGi0BcRCRGFvohI\niCj0RURCRKEvIhIiCn0RkRBR6IuIhEjOhH5X3yBf/tkbvLzvcLZLERGZsnIm9N3hyz97k027O7Nd\niojIlJUzoZ+I5RMryONgV1+2SxERmbJyJvTNjOpEjINd/dkuRURkysqZ0AeojsfU0hcReQcZhb6Z\n3WRm282s2czuGeP1O81si5ltNrNnzGxR2mt/Fuy33cw+MJHFj1aZiNLWrZa+iMjpjBv6ZhYBHgRu\nBhYBt6eHeuA77r7E3ZcCDwB/F+y7CFgNLAZuAr4SvN+kUEtfROSdZdLSbwSa3X2nuw8A64BV6Ru4\ne1faYjHgwfNVwDp373f3XUBz8H6TojoRpWdgmKP9Q5P1T4iIXNAyCf3ZwL605f3BupOY2afMbAep\nlv4fnuG+d5hZk5k1tbW1ZVr7KaoTMQC19kVETmPCTuS6+4PuPh/4H8Cfn+G+D7t7g7s3VFaOe+OX\n06qKRwFo1RU8IiJjyiT0W4A5acs1wbrTWQd8+Cz3PSdVQUu/tVstfRGRsWQS+puAhWZWb2aFpE7M\nrk/fwMwWpi3+J+DN4Pl6YLWZRc2sHlgIbDz3ssdWnUi19NW9IyIytnHvkevuQ2Z2F/A4EAHWuvs2\nM7sfaHL39cBdZrYSGAQOAWuCfbeZ2XeBV4Eh4FPuPjxJvwsl0XyKCiMaoCUichoZ3Rjd3R8FHh21\n7t60559+h33/Gvjrsy3wTJgZVfGoWvoiIqeRUyNyIdWv36oBWiIiY8q50K9OxGhVS19EZEy5F/rx\nKAe7+nH38TcWEQmZ3Av9RIxjg8N0a1SuiMgpci70qxIaoCUicjq5F/rxYICW+vVFRE6Rc6E/MkBL\no3JFRE6Rc6FfNTLpmrp3RERGy7nQL4nmU1wYUZ++iMgYci70IXUFj7p3REROlZOhX5WI6kSuiMgY\ncjL0qxMx9emLiIwhJ0O/Kh6ltbtPo3JFREbJydCvTsToG0zS1adRuSIi6XIy9EfuoKV+fRGRk+Rk\n6FfHj99BS/36IiLpcjL0da9cEZGx5Wboq6UvIjKmnAz94mg+8Wi+bpsoIjJKToY+BAO01L0jInKS\nnA19DdASETlVzob+8QFaIiJyQs6G/vGWvkblioickLOhX5WIMTCU5MixwWyXIiIyZeRs6I/cQUv9\n+iIiIzIKfTO7ycy2m1mzmd0zxuufMbNXzewVM3vSzOamvTZsZpuDx/qJLP6djNwrV/36IiIj8sfb\nwMwiwIPADcB+YJOZrXf3V9M2ewlocPdeM/sk8ADw0eC1Y+6+dILrHpda+iIip8qkpd8INLv7Tncf\nANYBq9I3cPefu3tvsPg8UDOxZZ654y19DdASETkhk9CfDexLW94frDudTwA/SVuOmVmTmT1vZh8e\nawczuyPYpqmtrS2DksY3rTBCIpavmT
ZFRNKM271zJszst4EG4Lq01XPdvcXM5gH/bmZb3H1H+n7u\n/jDwMEBDQ8OEXWNZlYjR2q3uHRGR4zJp6bcAc9KWa4J1JzGzlcBngVvcfSRp3b0l+LkT+A9g2TnU\ne0aqE1F174iIpMkk9DcBC82s3swKgdXASVfhmNky4CFSgd+atr7UzKLB8wrgGiD9BPCkqo5rKgYR\nkXTjdu+4+5CZ3QU8DkSAte6+zczuB5rcfT3wRaAE+J6ZAex191uAy4CHzCxJ6gDzhVFX/UyqqkSM\ntu7UqNygLhGRUMuoT9/dHwUeHbXu3rTnK0+z33PAknMp8FxUxaMMDCc53DtIaXFhtsoQEZkycnZE\nLqTm3wE4qAFaIiJAzoe+BmiJiKTL8dDXAC0RkXQ5HfqVwb1y23StvogIkOOhHyuIMH1agVr6IiKB\nnA590AAtEZF0IQh9DdASETku50O/Kh7TpGsiIoHcD/1ElLaj/SSTuleuiEjOh351PMrgsHOodyDb\npYiIZF3uh/7Itfrq1xcRyfnQr9JUDCIiI3I/9I8P0FJLX0QkBKE/Mv+OWvoiIjkf+tH8CKVFBere\nEREhBKEPGqAlInJcKEJfN0gXEUkJR+jHoxqVKyJCSEK/OhGltVujckVEQhL6MYaTTkePRuWKSLiF\nIvSr4qkBWq26gkdEQi4coR9cq9+qK3hEJORCEfq6V66ISEooQr+y5PioXLX0RSTcMgp9M7vJzLab\nWbOZ3TPG658xs1fN7BUze9LM5qa9tsbM3gweayay+EwV5udRXlyoPn0RCb1xQ9/MIsCDwM3AIuB2\nM1s0arOXgAZ3fxfwfeCBYN8y4D5gBdAI3GdmpRNXfuYq41G19EUk9DJp6TcCze6+090HgHXAqvQN\n3P3n7t4bLD4P1ATPPwA84e6d7n4IeAK4aWJKPzPViZha+iISepmE/mxgX9ry/mDd6XwC+MlZ7jtp\nqhNRncgVkdDLn8g3M7PfBhqA685wvzuAOwBqa2snsqQR1YkYbd39DCedSJ5Nyr8hIjLVZdLSbwHm\npC3XBOtOYmYrgc8Ct7h7/5ns6+4Pu3uDuzdUVlZmWvsZqYpHSTp09KhfX0TCK5PQ3wQsNLN6MysE\nVgPr0zcws2XAQ6QCvzXtpceBG82sNDiBe2Ow7rw7fttEDdASkTAbN/TdfQi4i1RYvwZ81923mdn9\nZnZLsNkXgRLge2a22czWB/t2An9F6sCxCbg/WHfeaYCWiEiGffru/ijw6Kh196Y9X/kO+64F1p5t\ngROlOqEBWiIioRiRC1BREsVMk66JSLiFJvQLIqlRuWrpi0iYhSb0ITXFsu6gJSJhFqrQr05EOaju\nHREJsZCFfkyXbIpIqIUq9KviUdqP9jM0nMx2KSIiWRGu0E/EglG5uleuiIRTqEJfA7REJOxCFvq6\nV66IhFuoQr8qHrT0dQWPiIRUqEK/oqQQM03FICLhFarQz4/kUVES1QAtEQmtUIU+pPr1W7vV0heR\ncApf6MdjunpHREIrdKFflYiqT19EQit8oR+P0dHTz6BG5YpICIUu9KsTMdyh/aha+yISPiEMfQ3Q\nEpHwCl3ojwzQ0slcEQmh0IX+yL1yddmmiIRQ6EK/vCRKnqEBWiISSqEL/UieURmPqk9fREIpdKEP\nqX59TbomImEUytCv1gAtEQmpUIZ+VSKmPn0RCaWMQt/MbjKz7WbWbGb3jPH6tWb2opkNmdlHRr02\nbGabg8f6iSr8XFTHY3T0DGhUroiETv54G5hZBHgQuAHYD2wys/Xu/mraZnuB3wH+ZIy3OObuSyeg\n1glTFVy22dbdz0UzpmW5GhGR8yeTln4j0OzuO919AFgHrErfwN13u/srwAXRdB65Vl9dPCISMpmE\n/mxgX9ry/mBdpmJm1mRmz5vZh8+ouklyYlSuTuaKSLiM270zAea6e4uZzQP+3cy2uPuO9A3M7A7g\nDo
Da2tpJL6g6kQr9Nl22KSIhk0lLvwWYk7ZcE6zLiLu3BD93Av8BLBtjm4fdvcHdGyorKzN967NW\nXlxIJM/U0heR0Mkk9DcBC82s3swKgdVARlfhmFmpmUWD5xXANcCr77zX5MvLMypLourTF5HQGTf0\n3X0IuAt4HHgN+K67bzOz+83sFgAzW25m+4HfAB4ys23B7pcBTWb2MvBz4AujrvrJmupEVJOuiUjo\nZNSn7+6PAo+OWndv2vNNpLp9Ru/3HLDkHGucFFWJGPs6e7NdhojIeRXKEbmQaum3qqUvIiET2tCv\nisfo7Bmgf2g426WIiJw3oQ396rRRuSIiYRHa0K9KaICWiIRPaEO/Oq4BWiISPqEN/aqR+XfU0heR\n8Aht6JcVFZKfZxqgJSKhEtrQz8szquK6g5aIhEtoQx+CO2ipT19EQiTUoV+diNKqlr6IhEioQ78q\nHuOgWvoiEiKhDv3qRJTDvYP0DWpUroiEQ6hDv2rkZirq4hGRcAh16B+/g5ZO5opIWIQ69KviGqAl\nIuES6tCvHpl/Ry19EQmHUId+aVEBBRHTvPoiEhqhDn0zS122qZa+iIREqEMfUhOvaYCWiIRF6EO/\nWi19EQkRhb7ulSsiIRL60K9KxDhyTKNyRSQcFPrBtfrq1xeRMAh96M+aPg2AZ3e0Z7kSEZHJF/rQ\nb6wvo7GujHt/vJWn3mjLdjkiIpMqo9A3s5vMbLuZNZvZPWO8fq2ZvWhmQ2b2kVGvrTGzN4PHmokq\nfKIU5ufx1TUNLKiKc+e3X2DzvsPZLklEZNKMG/pmFgEeBG4GFgG3m9miUZvtBX4H+M6ofcuA+4AV\nQCNwn5mVnnvZE2v6tAK+8XvLqSiJ8rv/upHm1u5slyQiMikyaek3As3uvtPdB4B1wKr0Ddx9t7u/\nAiRH7fsB4Al373T3Q8ATwE0TUPeEq4rH+NYnGonk5fHxr23krcPHsl2SiMiEyyT0ZwP70pb3B+sy\ncS77nndzy4v5+u8up7tviI+v3cihnoFslyQiMqGmxIlcM7vDzJrMrKmtLbsnUy+fPZ2vrmlgb2cv\nv/v1TfQODGW1HhGRiZRJ6LcAc9KWa4J1mchoX3d/2N0b3L2hsrIyw7eePFfNK+cfb1/GK/sPc+e3\nX2RgaHSvlYjIhSmT0N8ELDSzejMrBFYD6zN8/8eBG82sNDiBe2Owbsr7wOKZfP62JTz1Rht/8r2X\nSSY92yWJiJyz/PE2cPchM7uLVFhHgLXuvs3M7gea3H29mS0HfgiUAh8ys79098Xu3mlmf0XqwAFw\nv7t3TtLvMuE+uryWjp4BHnhsO2XFhdz3oUWYWbbLEhE5a+OGPoC7Pwo8OmrdvWnPN5Hquhlr37XA\n2nOoMas+ed18Oo4O8LVndlFRUshdv7ow2yWJiJy1jEI/zMyMz37wMg71DPC/fvoGZcVRPraiNttl\niYicFYV+BvLyjL/9yLs41DvAn/9oC6VFBdy8ZFa2yxIROWNT4pLNC0FBJI+v/NaVLKst5dPrNvNc\nsyZoE5ELj0L/DEwrjPC1NQ3UVRTxX77ZxJb9R7JdkojIGVHon6EZRYV88/dWMKOokI+v3cD/e/kt\n3HU5p4hcGBT6Z2Hm9Bjf/v0V1JQW8d8eeYmPr93I7vaebJclIjIuhf5Zqq8o5kefuoa/vGUxL+09\nzI1ffop/+Nmb9A/ptosiMnUp9M9BJM9Yc3UdT/7xddy4qJq//9kb3Pzlp3lWJ3lFZIpS6E+A6kSM\nf/rYFXzz9xoZdue3/mUDn173Eq3dfdkuTUTkJAr9CXTtxZU8fve1/OH7F/KTLQd4/5d+wbee38Ow\n5u0RkSlCoT/BYgURPnPDxTx29/t4V810/uJHW7ntK8+ytUWXd4pI9in0J8m8yhK+/YkV/MPqpbQc\nPsYt//QMn1u/je6+wWyXJiIhptCfRGbGqqWzefIz1/OxFbV845e7
ef+XfsEPX9qvOfpFJCtsqg0s\namho8KampmyXMSk27zvMZ3+4hW1vdVFeXMhHGmq4fXktdRXF2S5NRC5wZvaCuzeMu51C//waTjpP\nvdnGIxv28uTrrQwnnWsWlLN6eS03Lq4mmh/JdokicgFS6F8ADnb18b2mfTyycR8th49RVlzIR66s\nYfXyOcyrLMl2eSJyAVHoX0CGk84zze08smEvT7x2kOGk85555axunMNNl89U619ExqXQv0C1dvXx\nvRf2s27TXvZ1HqO0qIBfv6KG1Y21LKhS619ExqbQv8Alk86zO9p5ZONefrrtIENJ570LKrh75UIa\n6sqyXZ6ITDEK/RzS1t3P917Yx9pndtF+dID3Lazg7pUXc+Xc0myXJiJThEI/Bx0bGObbz+/hf/9i\nBx09A1x7cSV/tHIhy2oV/iJhp9DPYb0DQ3zrl3t46KmddPYMcP0lldy98mKWzpmR7dJEJEsU+iHQ\n0z/EN3+5h4ef2sGh3kF+9dIq7l65kHfVKPxFwkahHyJH+4f4xnO7+erTOzncO8jKy6q4e+XFXD57\n+hm9j7tztH+I1u5+OnsGqIpHqSktIpJnk1S5iEwUhX4IdfcNBuG/iyPHBrlhUTV3r1zIZTMTdPYO\n0NrVT2t3H63d/bQFj9buPlq7+mk72k9rVz/HBk++81dhJI+6iiLmV5akHlXFzK8sYV5lCSXR/Cz9\npiIy2oSGvpndBPwDEAH+xd2/MOr1KPBN4EqgA/iou+82szrgNWB7sOnz7n7nO/1bCv1z19U3yNef\n3c2/PL2Trr4hInk25pz+8Vg+lfEoVfEoVfHYieeJKKVFhbR297Oj7Sg7WnvY2XaUPZ29J71PdSJ6\n4mBQWcy8yhIWVJUwa3oMM/11IHI+ZRr64zbVzCwCPAjcAOwHNpnZend/NW2zTwCH3H2Bma0G/hb4\naPDaDndfesa/gZy1RKyAP3z/QtZcXce6jXvp7hs6KdArS1IBP63wzEb6Dgwl2dvZmzoQBAeDHW1H\n+dHmFrr7hka2mzU9RmN9GSvqy2msL2N+ZbEOAiJTRCZ/nzcCze6+E8DM1gGrgPTQXwV8Lnj+feCf\nTP+XZ930aQX8wXXzJ+z9CvPzWFBVcsrIYHen/egAO9qOsv1ANxt3d/Jscwc/3vwWABUlhScdBC6p\njpOn8wQiWZFJ6M8G9qUt7wdWnG4bdx8ysyNAefBavZm9BHQBf+7uT59byTLVmBmV8SiV8ShXzStn\nzdV1uDu72nvYuKuTDbs62bCzg0e3HABSB6PldWWsqC+jsb6MxRclyI+cuLVDMul09w3R2TtAZ08/\nnT2DHOoZoKNngEO9A3T2nHgURIyr5pVzzYIKltXO0DxFIuOY7DNxbwO17t5hZlcCPzKzxe7elb6R\nmd0B3AFQW1s7ySXJ+WBmzAtO+K5uTP033X+olw07O4MDQQc/e+0gAMWFES6eGaenfygV8L0Dp72v\ncGF+HuXFhZQFj66+IR78eTP/+O/NxAryaKwv55r5qYPAolkJ/UUhMkomod8CzElbrgnWjbXNfjPL\nB6YDHZ46S9wP4O4vmNkO4GLgpDO17v4w8DCkTuSexe8hF4Ca0iJqrizi16+sAVJTS2/Y1cnGXR00\ntx6lKl7MlXMLKS06EeqlxYWUF59YV1QYOeX8wJFjg2zY2cFzOzp4prmdz//kdQBKiwq4en4FVy8o\n55r5FcwtL9K5BQm9ca/eCUL8DeD9pMJ9E/Axd9+Wts2ngCXufmdwIvc2d/9NM6sEOt192MzmAU8H\n23We7t/T1Ttyrg529fFsczvPNnfwbHM7B7r6AJg9YxrXLAi6guaUctGM2EndSnJ+dBzt583Woyys\nKqG8JJrtcnLGhF29E/TR3wU8TuqSzbXuvs3M7gea3H098DXgW2bWDHQCq4PdrwXuN7NBIAnc+U6B\nLzIRqhMxbruihtuuqMHd2dne
ExwE2nls6wG+27QfgEiecdGMGLVlRdSWFVFTWjTyvLasiBlFBfrL\n4Bz0DQ7T3HqU197uYvuBbrYf7Oa1t7tpP9oPQCKWz+duWcyty2brcz6PNDhLQmU46WxtOcL2A93s\n7ewdeew/1Ev70YGTti2J5jOnrIjasmnUlhUxp6yIRbMSLKmZrhPGaZJJZ/+hY7x+oIvXD3Sz/UA3\nrx/oYld7D8dPzUTz81hYXcKlMxNcOjNObVkRDz+1k6Y9h3j/pVX8zW1LqE7EsvuLXOA0IlfkDPX0\nD7HvUC97O44fCI6NHBT2dfbSP5QEUgH27jkzWF5XyvK6Mq6cW0o8VpDl6s+/F/Yc4qFf7ODZ5nZ6\nBk6M5K4tK+KSmXEumxnnkpkJLp0Vp668+JTpPIaTztef280XH3+dwkge931oMbddoVb/2VLoi0yg\nZNJp7e5n877DNO3uZNPuTra+1cVw0skzuHRmgsb6MpbXlbG8vpSqeG62Wt2dp95s5ys/b2bDrk6m\nTyvgQ++exeKLpnPJzDgXV8fPeHqOXe09/On3X2bT7kP86qVV/M2tS5g5PTc/v8mk0BeZZD39Q7y0\n9zCbgoPAS3sPj8xdNLe8iOV1ZTTWlXHF3FJKovn0Dw0zMJSkf+RxYvnkn6n1SYd310znirmlxAqy\n2500nHR+svVt/vk/drDtrS5mJmL8/vvqub2xluIJmIMpGbT6H3j8dQqCVv+vq9V/RhT6IufZ4HCS\nrS1HaNp9iI27O2na3cmh3sFzft9ofh7L68q4ZkEF1ywoZ/FF08/bzKf9Q8P88MUWHnpqJ7vae6iv\nKObO6+bx4WWzJ+W8xu72Hv70+6+wcXcnv3JJJZ+/7V1q9WdIoS+SZcmks7P9KC/uPczQsBPNz6Mw\nPy/tZ2RkOTpquTA/j6FhZ9PuTp7dkbry6I2DR4HUiOb3zCsfufy0vmLi5zbq6R/ikY17+erTOznY\n1c/lsxP81+sX8IHFMyf9gJNMOt/85W7+9rHt5EeMe39tER+5skat/nEo9EVyTGt3H7/c0cEzb7bz\n3I4OWg4fA1IT3B3/K+Ca+RVUncNVMId6BvjX53bzjed2c+TYIO+ZV84nr5/P+xZWnPfQ3dPRw3//\n/its3NXJ9ZdU8vnbljBr+rTzWsOFRKEvksPcnT0dvTzT3M5zO1IHgcNBV9KcsmkkYgVMK4gQG3nk\njSxPK4wQy88jVhghlh8sF+SxZX8Xj2zcy7HBYW5YVM0nr5/PFVm+//LoVv9f/NoifmOKt/pbu/t4\n/e1uCiJ5FOYbhZHUX3AFEaMw+CuuMHLiZyTPJuT3UeiLhEgy6bz6dhfPNrezpeUIxwaG6RsaTv0c\nTNI3OEzf4DDHBlPLo2+WA6nBaquWXsQnr5vPwup4Fn6L09vTkerr37Crk/fMK+cvVy3m4ilUY+/A\nED/ddpAfvNTCM2+2cZqpo8ZkxshBYOmcGXzrE6Pns8z0fRT6InIa7k7/0PGDQeogEI/lUzGFp0VI\nJp1HNu3lgce209M/xO9cXcenVy7M2hiJ4aTz/M4OfvBiC49tfZuegWFmz5jGrctm896FFbinTu4P\nDCUZGP0zeD446rVZ02Pcce3ZTYeu0BeRnNTZM8AXH3+ddZv2UVkS5X9+8DJWLb3ovHX5bD/QzQ9e\n2s+PX3qLA119xKP5fHDJLG69YjaNdWVZm9lVoS8iOe3lfYe598dbeXn/ERrry7h/1WIunZmYlH+r\ntbuP9Zvf4gcvtvDq211E8ozrL67k1itms/Ky6qyPowCFvoiEQDLp/J+mfTzw2Ot09Q3x8ffM5Y9u\nuJjEBHT5HO4d4BdvtPGDF1t4Ouinf1fNdG5dNpsPvfuiKdcVptAXkdA43DvAFx/fznc27qW8OMqf\n3XzpGc/j03G0f+ROb8/v7GD7wW7cU1Nyf3jZRdy6rOaUW4VOJQp9EQmdLfuP8Bc/3srmfYdpmF
vK\n/asuZ9FFY3f5tHX3s2FXBxt2pu7kdnzwW6wgjyvnlrKivpyr55dzRW3pBXEHNoW+iIRSMul8/4X9\nfOGx1zncO8B/vmoun7nxEvoGh3l+Z8fIPZt3tPUAUFQY4cq5pVw1r5yr5pWxZPYMCvMvvJvrKPRF\nJNSO9A7ypSe28+3n95AfyWMgmBo7Hs2noa6UFfPKWVFfxuWzp1OQA3dQm7A7Z4mIXIimFxVw/6rL\n+ejyOfzbhr3UlxezYl4Zi2YlQn2bTIW+iOS0xRdN529uXZLtMqaM8B7uRERCSKEvIhIiCn0RkRBR\n6IuIhIhCX0QkRBT6IiIhotAXEQkRhb6ISIhMuWkYzKwN2HMOb1EBtE9QOblCn8mp9JmcSp/JqS6k\nz2Suu1eOt9GUC/1zZWZNmcw/ESb6TE6lz+RU+kxOlYufibp3RERCRKEvIhIiuRj6D2e7gClIn8mp\n9JmcSp/JqXLuM8m5Pn0RETm9XGzpi4jIaeRM6JvZTWa23cyazeyebNczFZjZbjPbYmabzSy0tyMz\ns7Vm1mpmW9PWlZnZE2b2ZvCzNJs1nm+n+Uw+Z2Ytwfdls5l9MJs1nm9mNsfMfm5mr5rZNjP7dLA+\np74rORH6ZhYBHgRuBhYBt5vZouxWNWX8irsvzbXLzs7Q14GbRq27B3jS3RcCTwbLYfJ1Tv1MAP4+\n+L4sdfdHz3NN2TYE/LG7LwKuAj4V5EhOfVdyIvSBRqDZ3Xe6+wCwDliV5ZpkinD3p4DOUatXAd8I\nnn8D+PB5LSrLTvOZhJq7v+3uLwbPu4HXgNnk2HclV0J/NrAvbXl/sC7sHPipmb1gZndku5gpptrd\n3w6eHwCqs1nMFHKXmb0SdP9c0N0Y58LM6oBlwAZy7LuSK6EvY3uvu19BqtvrU2Z2bbYLmoo8dQmb\nLmODfwbmA0uBt4EvZbec7DCzEuD/Ane7e1f6a7nwXcmV0G8B5qQt1wTrQs3dW4KfrcAPSXWDScpB\nM5sFEPxszXI9WefuB9192N2TwFcJ4ffFzApIBf6/ufsPgtU59V3JldDfBCw0s3ozKwRWA+uzXFNW\nmVmxmcWPPwduBLa+816hsh5YEzxfA/w4i7VMCceDLXArIfu+mJkBXwNec/e/S3spp74rOTM4K7i8\n7MtABFjr7n+d5ZKyyszmkWrdA+QD3wnrZ2JmjwDXk5ox8SBwH/Aj4LtALalZXX/T3UNzYvM0n8n1\npLp2HNgN/EFaX3bOM7P3Ak8DW4BksPp/kurXz5nvSs6EvoiIjC9XundERCQDCn0RkRBR6IuIhIhC\nX0QkRBT6IiIhotAXEQkRhb6ISIgo9EVEQuT/A6q5iSeartITAAAAAElFTkSuQmCC\n",
906 | "text/plain": [
907 | ""
908 | ]
909 | },
910 | "metadata": {
911 | "tags": []
912 | }
913 | }
914 | ]
915 | },
916 | {
917 | "metadata": {
918 | "id": "ta_UwyT1r0L6",
919 | "colab_type": "text"
920 | },
921 | "cell_type": "markdown",
922 | "source": [
923 | "## Test and Visualize"
924 | ]
925 | },
926 | {
927 | "metadata": {
928 | "id": "6AwZsqeJr0L6",
929 | "colab_type": "code",
930 | "outputId": "ad8f32f8-5934-4d01-fe13-4292c4c89137",
931 | "colab": {
932 | "base_uri": "https://localhost:8080/",
933 | "height": 71
934 | }
935 | },
936 | "cell_type": "code",
937 | "source": [
938 | "# Test\n",
939 | "model.eval()\n",
940 | "acc = 0.\n",
941 | "with torch.no_grad():\n",
942 | " for idx, (images, labels) in enumerate(test_loader):\n",
943 | " x, y = images.to(DEVICE), labels.to(DEVICE) # (N, 1, 28, 28), (N, )\n",
944 | " y_hat = model(x) # (N, 10)\n",
945 | " loss = criterion(y_hat, y)\n",
946 | " _, indices = torch.max(y_hat, dim=-1)\n",
947 | " acc += torch.sum(indices == y).item()\n",
948 | "print('*'*20, 'Test', '*'*20)\n",
949 | "print('Step: {}, Loss: {}, Accuracy: {} %'.format(step, loss.item(), acc/len(mnist_test)*100))\n",
950 | "print('*'*46)"
951 | ],
952 | "execution_count": 20,
953 | "outputs": [
954 | {
955 | "output_type": "stream",
956 | "text": [
957 | "******************** Test ********************\n",
958 | "Step: 4685, Loss: 0.0009241390507668257, Accuracy: 99.02 %\n",
959 | "**********************************************\n"
960 | ],
961 | "name": "stdout"
962 | }
963 | ]
964 | },
965 | {
966 | "metadata": {
967 | "id": "nFoQkM7Mr0L8",
968 | "colab_type": "code",
969 | "outputId": "3641821b-6e4e-42a2-866e-b99dcc420672",
970 | "colab": {
971 | "base_uri": "https://localhost:8080/",
972 | "height": 35
973 | }
974 | },
975 | "cell_type": "code",
976 | "source": [
977 | "idx = 7777 # 0 to 9999\n",
978 | "img, y = mnist_test[idx]\n",
979 | "img.shape, y"
980 | ],
981 | "execution_count": 21,
982 | "outputs": [
983 | {
984 | "output_type": "execute_result",
985 | "data": {
986 | "text/plain": [
987 | "(torch.Size([1, 28, 28]), 5)"
988 | ]
989 | },
990 | "metadata": {
991 | "tags": []
992 | },
993 | "execution_count": 21
994 | }
995 | ]
996 | },
997 | {
998 | "metadata": {
999 | "id": "qzrcr3vcr0L-",
1000 | "colab_type": "code",
1001 | "outputId": "d837ce23-ffbb-45ba-e687-613d5b05e63e",
1002 | "colab": {
1003 | "base_uri": "https://localhost:8080/",
1004 | "height": 287
1005 | }
1006 | },
1007 | "cell_type": "code",
1008 | "source": [
1009 | "imshow(img[0], cmap='gray')"
1010 | ],
1011 | "execution_count": 22,
1012 | "outputs": [
1013 | {
1014 | "output_type": "execute_result",
1015 | "data": {
1016 | "text/plain": [
1017 | ""
1018 | ]
1019 | },
1020 | "metadata": {
1021 | "tags": []
1022 | },
1023 | "execution_count": 22
1024 | },
1025 | {
1026 | "output_type": "display_data",
1027 | "data": {
1028 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADeRJREFUeJzt3X+MVPW5x/HPg4BGSiLcphtCzUUJ\nNkFjabOBm0huWq8CmibQP9RiYriRdFFRS6xGo3+IuWk0N7RV+YNkjUQwlELiD7BKW0rI5dY0DatR\n/NUWS2jKBhcVA6IJuPjcP/Zws8DOd4aZ82t93q9kszPnmTPnyWQ/e87M95z5mrsLQDxjqm4AQDUI\nPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoMaWuTEz43RCoGDubq08rqM9v5ktMLO/mtn7ZvZA\nJ88FoFzW7rn9ZnaepL9JulbSAUm7JS1293cT67DnBwpWxp5/tqT33X2fu5+Q9GtJCzt4PgAl6iT8\nUyX9c9j9A9my05hZj5n1mVlfB9sCkLPCP/Bz915JvRKH/UCddLLn75d08bD738yWARgFOgn/bkkz\nzOwSMxsv6UeStubTFoCitX3Y7+6DZnanpN9JOk/SWnd/J7fOABSq7aG+tjbGe36gcKWc5ANg9CL8\nQFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii\n/EBQhB8IivADQRF+IKhSp+hGMcaNG9ewdsEFFxS67WuuuSZZnz9/fsPasmXL8m7nNGvXrm1YGxgY\nSK67e/fuZP2VV15J1o8fP56s1wF7fiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IqqNZes1sv6RPJZ2U\nNOju3U0ezyy9bZg7d26yvnLlyoa1q6++uqNtm6UnfC1zluc6efTRR5P1hx56qKROztbqLL15nOTz\nfXf/KIfnAVAiDvuBoDoNv0v6vZm9ZmY9eTQEoBydHvbPdfd+M/uGpO1m9hd33zX8Adk/Bf4xADXT\n0Z7f3fuz34ckvSBp9giP6XX37mYfBgIoV9vhN7MJZjbx1G1J8yS9nVdjAIrVyWF/l6QXsqGgsZJ+\n5e6/zaUrAIVrO/zuvk/St3PsJazp06cn62vWrEnWL7/88jzbgaQjR44k6y+99FJJnRSHoT4gKMIP\nBEX4gaAIPxAU4QeCIvxAUB1d0nvOG+OS3rbccMMNyfqmTZsK23adL+k9evRosr5v3762n/uuu+5K\n1l999dW2n7torV7Sy54fCIrwA0ERfiAowg8ERfiBoAg/EBThB4Jiiu4amDx5crK+dOnSkjrJ3969\nexvW3nzzzeS627dvT9Y//PDDZP3FF19M1qNjzw8ERfiBoAg/EBThB4Ii/EBQhB8IivADQTHOXwNz\n5sxJ1ufNm1dSJ2c7efJksn7vvfcm6xs3bmxYGxgYaKsn5IM9PxAU4QeCIvxAUIQfCIrwA0ERfiAo\nwg8E1XSc38zWSvqBpEPufkW2bLKkTZKmSdov6UZ3/6S4Nr/a6ny9/s0335ysb968uaROkLdW9vzP\nSFpwxrIHJO1w9xmSdmT3AYwiTcPv7rskHT5j8UJJ67Lb6yQtyrkvAAVr9z1/l7sfzG5/IKkrp34A\nlKTjc/vd3VNz8JlZj6SeTrcDIF/t7vkHzGyKJGW/DzV6oLv3unu3u3e3uS0ABWg3/FslLcluL5G0\nJZ92AJSlafjNbKOkP0n6lpkdMLOlkh6TdK2Z7ZV0TXYfwChiZc6vnvpsILLFixcn6xs2bCipk7N9\n8kn69I3Dh88cCDrdyy+/3LC2c+fO5LpbtnBA2Q53t1Yexxl+QFCEHwiK8ANBEX4gKMIPBEX4gaAY\n6quBmTNnJuurVq1K1hcsOPOiy/yYp
UeNOvn7GRwcTNY//vjjZP3ZZ59N1lNDidu2bUuuO5ox1Acg\nifADQRF+ICjCDwRF+IGgCD8QFOEHgmKcfxQYOzb9bWsrVqxoWHv44YeT606YMCFZL3Kcv2ip6cUf\nf/zx5LqPPPJIsn7s2LG2eioD4/wAkgg/EBThB4Ii/EBQhB8IivADQRF+ICjG+YO76qqrkvV58+Yl\n6/fcc0+yPmZM4/3LhRdemFy3SqtXr07W77777pI6OXeM8wNIIvxAUIQfCIrwA0ERfiAowg8ERfiB\noJqO85vZWkk/kHTI3a/Ilq2U9GNJH2YPe9DdX2m6Mcb5w7nooosa1m655Zbkus3q3d3dbfXUij17\n9iTrc+bMSdaPHz+eZzvnJM9x/mckjTQrxC/dfVb20zT4AOqlafjdfZekwyX0AqBEnbznv9PM9pjZ\nWjOblFtHAErRbvjXSJouaZakg5J+3uiBZtZjZn1m1tfmtgAUoK3wu/uAu5909y8lPSVpduKxve7e\n7e7FfToD4Jy1FX4zmzLs7g8lvZ1POwDKkv5OaElmtlHS9yR93cwOSHpY0vfMbJYkl7Rf0rICewRQ\nAK7nR211dXUl67t27UrWZ8yYkWc7p5k4cWKy/tlnnxW27Wa4nh9AEuEHgiL8QFCEHwiK8ANBEX4g\nqKbj/CjeZZddlqzX+Suu58+fn6xPnTq1Ya3Z1183Gy4rcjht27ZtyXqVl+zmhT0/EBThB4Ii/EBQ\nhB8IivADQRF+ICjCDwTFOH9m/Pjxyfp1113XsHbbbbd1tO3Zsxt+EZIkadKk6r4i0Sx9dWizS8L7\n+/sb1jZv3pxc94477kjWZ82alax3YufOncn64OBgYdsuC3t+ICjCDwRF+IGgCD8QFOEHgiL8QFCE\nHwgqzFd3jx2bPqXhiSeeSNZvv/32PNsZNTod56+r1atXJ+v33Xdfsn7ixIk828kVX90NIInwA0ER\nfiAowg8ERfiBoAg/EBThB4Jqej2/mV0sab2kLkkuqdfdnzCzyZI2SZomab+kG939k+Ja7czy5cuT\n9ajj+F9lTz75ZMPa/fffn1y3zuP4eWllzz8o6afuPlPSv0labmYzJT0gaYe7z5C0I7sPYJRoGn53\nP+jur2e3P5X0nqSpkhZKWpc9bJ2kRUU1CSB/5/Se38ymSfqOpD9L6nL3g1npAw29LQAwSrT8HX5m\n9jVJz0la4e5Hh5/z7e7e6Lx9M+uR1NNpowDy1dKe38zGaSj4G9z9+WzxgJlNyepTJB0aaV1373X3\nbnfvzqNhAPloGn4b2sU/Lek9d//FsNJWSUuy20skbcm/PQBFaeWw/ypJt0h6y8zeyJY9KOkxSZvN\nbKmkf0i6sZgW8zEwMFB1C8hZs8tyU8N5X4UptjvVNPzu/kdJja4P/o982wFQFs7wA4Ii/EBQhB8I\nivADQRF+ICjCDwQV5qu7x4xJ/5/r7e1N1m+99dY82xk1Ov3q7iNHjjSsrV+/Prnupk2bkvW+vr5k\nPcJluSPhq7sBJBF+ICjCDwRF+IGgCD8QFOEHgiL8QFBhxvmbOf/885P1RYsafz/plVdemVz3pptu\nStYvvfTSZL1Iq1atSta/+OKLZP3YsWPJemrq888//zy5LtrDOD+AJMIPBEX4gaAIPxAU4QeCIvxA\nUIQfCIpxfuArhnF+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxBU0/Cb2cVmttPM3jWzd8zsJ9nylWbW\nb2ZvZD/XF98ugLw0PcnHzKZImuLur5vZREmvSVok6UZJx9w9/W0Qpz8XJ/kABWv1JJ+xLTzRQUkH\ns9ufmtl7kqZ21h6Aqp3Te34zmybpO5L+nC2608z2mNlaM5vUYJ0eM+szs/TcSgBK1fK5/Wb2NUn/\nI+ln7v68mXVJ+kiSS/ovDb01SE5ox2E/ULxWD/tbCr+ZjZP0G0m/c/dfjFCfJuk37n5Fk+ch/EDB\nc
ruwx4amaX1a0nvDg599EHjKDyW9fa5NAqhOK5/2z5X0v5LekvRltvhBSYslzdLQYf9+ScuyDwdT\nz8WeHyhYrof9eSH8QPG4nh9AEuEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQf\nCIrwA0ERfiCopl/gmbOPJP1j2P2vZ8vqqK691bUvid7alWdv/9rqA0u9nv+sjZv1uXt3ZQ0k1LW3\nuvYl0Vu7quqNw34gKMIPBFV1+Hsr3n5KXXura18SvbWrkt4qfc8PoDpV7/kBVKSS8JvZAjP7q5m9\nb2YPVNFDI2a238zeymYernSKsWwatENm9vawZZPNbLuZ7c1+jzhNWkW91WLm5sTM0pW+dnWb8br0\nw34zO0/S3yRdK+mApN2SFrv7u6U20oCZ7ZfU7e6Vjwmb2b9LOiZp/anZkMzsvyUddvfHsn+ck9z9\n/pr0tlLnOHNzQb01mln6P1Xha5fnjNd5qGLPP1vS++6+z91PSPq1pIUV9FF77r5L0uEzFi+UtC67\nvU5Dfzyla9BbLbj7QXd/Pbv9qaRTM0tX+tol+qpEFeGfKumfw+4fUL2m/HZJvzez18ysp+pmRtA1\nbGakDyR1VdnMCJrO3FymM2aWrs1r186M13njA7+zzXX370q6TtLy7PC2lnzoPVudhmvWSJquoWnc\nDkr6eZXNZDNLPydphbsfHV6r8rUboa9KXrcqwt8v6eJh97+ZLasFd+/Pfh+S9IKG3qbUycCpSVKz\n34cq7uf/ufuAu5909y8lPaUKX7tsZunnJG1w9+ezxZW/diP1VdXrVkX4d0uaYWaXmNl4ST+StLWC\nPs5iZhOyD2JkZhMkzVP9Zh/eKmlJdnuJpC0V9nKauszc3GhmaVX82tVuxmt3L/1H0vUa+sT/75Ie\nqqKHBn1dKunN7OedqnuTtFFDh4FfaOizkaWS/kXSDkl7Jf1B0uQa9fashmZz3qOhoE2pqLe5Gjqk\n3yPpjezn+qpfu0RflbxunOEHBMUHfkBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgvo/y76ebXSU\nlzQAAAAASUVORK5CYII=\n",
1029 | "text/plain": [
1030 | ""
1031 | ]
1032 | },
1033 | "metadata": {
1034 | "tags": []
1035 | }
1036 | }
1037 | ]
1038 | },
1039 | {
1040 | "metadata": {
1041 | "id": "E1qgio9Rr0MC",
1042 | "colab_type": "code",
1043 | "outputId": "cc9202b1-d9b4-4ca6-9dc9-bf3e7167cefd",
1044 | "colab": {
1045 | "base_uri": "https://localhost:8080/",
1046 | "height": 35
1047 | }
1048 | },
1049 | "cell_type": "code",
1050 | "source": [
1051 | "sample = img.unsqueeze(dim=0).to(DEVICE)\n",
1052 | "out = model(sample).cpu()\n",
1053 | "_, idx = out.max(dim=-1)\n",
1054 | "print(\"prediction = \", idx.item())"
1055 | ],
1056 | "execution_count": 28,
1057 | "outputs": [
1058 | {
1059 | "output_type": "stream",
1060 | "text": [
1061 | "prediction = 5\n"
1062 | ],
1063 | "name": "stdout"
1064 | }
1065 | ]
1066 | },
1067 | {
1068 | "metadata": {
1069 | "id": "8cR6hl9hr0MF",
1070 | "colab_type": "code",
1071 | "colab": {}
1072 | },
1073 | "cell_type": "code",
1074 | "source": [
1075 | "# Saving params.\n",
1076 | "torch.save(model.state_dict(), 'model.pkl')"
1077 | ],
1078 | "execution_count": 0,
1079 | "outputs": []
1080 | }
1081 | ]
1082 | }
--------------------------------------------------------------------------------
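The Test cell in CNN_MNIST.ipynb accumulates correct predictions over all of `mnist_test`. The `Loss` it prints, however, is `loss.item()` from the final test batch only.

The sketch below shows one way to report both numbers over the full test set. It is a hypothetical, framework-free illustration in plain Python, not the notebook's PyTorch code. It assumes the criterion is cross-entropy over the 10 class scores, which is not shown in this excerpt. The helper names `cross_entropy` and `evaluate` are made up for illustration.

```python
import math
from typing import Iterable, List, Sequence, Tuple

# Hypothetical, framework-free sketch (plain Python, not the notebook's PyTorch code).


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Per-sample cross-entropy, -log(softmax(logits)[label]), via a stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[label]


def evaluate(batches: Iterable[Tuple[List[List[float]], List[int]]]) -> Tuple[float, float]:
    """Return (mean loss per sample, accuracy in %) over every batch, not just the last one."""
    total_loss, correct, count = 0.0, 0, 0
    for logits_batch, labels in batches:
        for logits, y in zip(logits_batch, labels):
            total_loss += cross_entropy(logits, y)
            # predicted class = index of the largest score (the first one wins on ties)
            pred = max(range(len(logits)), key=logits.__getitem__)
            correct += int(pred == y)
            count += 1
    return total_loss / count, correct / count * 100


# Two batches of unequal size, as when the test set is not a multiple of the batch size.
batches = [
    ([[2.0, 0.0], [0.0, 3.0]], [0, 0]),
    ([[0.0, 0.0]], [1]),
]
mean_loss, accuracy = evaluate(batches)
print(f"Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f} %")
```

Summing per-sample values and dividing by the total count weights every test image equally. Averaging per-batch means instead would give a smaller final batch more weight than its size warrants.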
/GAN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "name": "python3",
7 | "display_name": "Python 3"
8 | },
9 | "language_info": {
10 | "codemirror_mode": {
11 | "name": "ipython",
12 | "version": 3
13 | },
14 | "file_extension": ".py",
15 | "mimetype": "text/x-python",
16 | "name": "python",
17 | "nbconvert_exporter": "python",
18 | "pygments_lexer": "ipython3",
19 | "version": "3.6.7"
20 | },
21 | "colab": {
22 | "name": "GAN.ipynb",
23 | "provenance": [],
24 | "include_colab_link": true
25 | },
26 | "accelerator": "GPU"
27 | },
28 | "cells": [
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "view-in-github",
33 | "colab_type": "text"
34 | },
35 | "source": [
36 | ""
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {
42 | "id": "c7cFO8VuKr16",
43 | "colab_type": "text"
44 | },
45 | "source": [
46 | "# Implementation of a Vanilla GAN Model\n",
47 | "Reference: https://arxiv.org/pdf/1406.2661.pdf"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "metadata": {
53 | "id": "i2TQr5qBKr18",
54 | "colab_type": "code",
55 | "colab": {}
56 | },
57 | "source": [
58 | "# Uncomment and run the line below only when using Google Colab\n",
59 | "# !pip install torch torchvision"
60 | ],
61 | "execution_count": 0,
62 | "outputs": []
63 | },
64 | {
65 | "cell_type": "code",
66 | "metadata": {
67 | "id": "1_C-4VbXKr1-",
68 | "colab_type": "code",
69 | "colab": {}
70 | },
71 | "source": [
72 | "import torch\n",
73 | "import torchvision\n",
74 | "import torch.nn as nn"
75 | ],
76 | "execution_count": 0,
77 | "outputs": []
78 | },
79 | {
80 | "cell_type": "code",
81 | "metadata": {
82 | "id": "4so3lTmqKr2A",
83 | "colab_type": "code",
84 | "colab": {}
85 | },
86 | "source": [
87 | "from torch.utils.data import DataLoader\n",
88 | "from torchvision import datasets\n",
89 | "from torchvision import transforms"
90 | ],
91 | "execution_count": 0,
92 | "outputs": []
93 | },
94 | {
95 | "cell_type": "code",
96 | "metadata": {
97 | "id": "DIQZBMekKr2C",
98 | "colab_type": "code",
99 | "colab": {}
100 | },
101 | "source": [
102 | "import numpy as np\n",
103 | "import datetime\n",
104 | "import os, sys"
105 | ],
106 | "execution_count": 0,
107 | "outputs": []
108 | },
109 | {
110 | "cell_type": "code",
111 | "metadata": {
112 | "id": "5UJYgSALKr2E",
113 | "colab_type": "code",
114 | "colab": {}
115 | },
116 | "source": [
117 | "from matplotlib.pyplot import imshow, imsave\n",
118 | "%matplotlib inline"
119 | ],
120 | "execution_count": 0,
121 | "outputs": []
122 | },
123 | {
124 | "cell_type": "code",
125 | "metadata": {
126 | "id": "OxrxlkiiKr2G",
127 | "colab_type": "code",
128 | "outputId": "2886f0bb-fd62-4b33-bb7a-c62ba8f09bcc",
129 | "colab": {
130 | "base_uri": "https://localhost:8080/",
131 | "height": 35
132 | }
133 | },
134 | "source": [
135 | "MODEL_NAME = 'GAN'\n",
136 | "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
137 | "print(\"DEVICE =\",DEVICE)"
138 | ],
139 | "execution_count": 6,
140 | "outputs": [
141 | {
142 | "output_type": "stream",
143 | "text": [
144 | "DEVICE = cuda\n"
145 | ],
146 | "name": "stdout"
147 | }
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "metadata": {
153 | "id": "eHLsegE6Kr2I",
154 | "colab_type": "code",
155 | "colab": {}
156 | },
157 | "source": [
158 | "# G: generator\n",
159 | "def get_sample_image(G, n_noise):\n",
160 | " \"\"\"\n",
161 | " synthesize samples from random noise\n",
162 | " \"\"\"\n",
163 | " z = torch.randn(100, n_noise).to(DEVICE) # generate 100 random noise vectors\n",
164 | " y_hat = G(z).view(100, 28, 28) # (100, 28, 28)\n",
165 | " result = y_hat.cpu().data.numpy()\n",
166 | " img = np.zeros([280, 280]) # 10x10 grid to tile 100 images on \n",
167 | " for j in range(10):\n",
168 | " img[j*28:(j+1)*28] = np.concatenate([x for x in result[j*10:(j+1)*10]], axis=-1)\n",
169 | " return img"
170 | ],
171 | "execution_count": 0,
172 | "outputs": []
173 | },
174 | {
175 | "cell_type": "markdown",
176 | "metadata": {
177 | "id": "wtTRuccjKr2J",
178 | "colab_type": "text"
179 | },
180 | "source": [
181 | "**get_sample_image**: a helper for checking, after the GAN has been trained, whether the Generator produces good images\n",
182 | "\n",
183 | "**Line 6**: sample 100 noise vectors (one per image in the 10x10 grid) as input to the Generator.\n",
184 | "\n",
185 | "**Line 7**: reshape the Generator output into image form.\n",
186 | "\n",
187 | "**Line 8**: move the data from the GPU to the CPU and convert the underlying tensor data, detached from the graph, into a numpy array.\n",
188 | "\n",
189 | "**Line 9**: allocate the array that will hold the visualization\n",
190 | "\n",
191 | "**Line 10-11**: to visualize all 100 generated images, copy them row by row into the array allocated on line 9"
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "metadata": {
197 | "id": "Y_m-0GbDKr2K",
198 | "colab_type": "text"
199 | },
200 | "source": [
201 | ""
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "metadata": {
207 | "id": "lxwPYqTFKr2K",
208 | "colab_type": "code",
209 | "colab": {}
210 | },
211 | "source": [
212 | "class Discriminator(nn.Module):\n",
213 | " \"\"\"\n",
214 | " Simple Discriminator w/ MLP\n",
215 | " \"\"\"\n",
216 | " def __init__(self, input_size=784, num_classes=1):\n",
217 | " super(Discriminator, self).__init__()\n",
218 | " self.layer = nn.Sequential(\n",
219 | " nn.Linear(input_size, 512),\n",
220 | " nn.LeakyReLU(0.2),\n",
221 | " nn.Linear(512, 256),\n",
222 | " nn.LeakyReLU(0.2),\n",
223 | " nn.Linear(256, num_classes),\n",
224 | " nn.Sigmoid(),\n",
225 | " )\n",
226 | " \n",
227 | " def forward(self, x):\n",
228 | " y_ = x.view(x.size(0), -1)\n",
229 | " y_ = self.layer(y_)\n",
230 | " return y_"
231 | ],
232 | "execution_count": 0,
233 | "outputs": []
234 | },
235 | {
236 | "cell_type": "markdown",
237 | "metadata": {
238 | "id": "ZRJxzVrYKr2M",
239 | "colab_type": "text"
240 | },
241 | "source": [
242 | "**Discriminator**: defines the GAN's Discriminator architecture. It consists of three fully-connected layers, and the last one is followed by a sigmoid so that the output can be read as a probability."
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "metadata": {
248 | "id": "6Ab_cCLvKr2M",
249 | "colab_type": "code",
250 | "colab": {}
251 | },
252 | "source": [
253 | "class Generator(nn.Module):\n",
254 | " \"\"\"\n",
255 | " Simple Generator w/ MLP\n",
256 | " \"\"\"\n",
257 | " def __init__(self, input_size=100, num_classes=784):\n",
258 | " super(Generator, self).__init__()\n",
259 | " self.layer = nn.Sequential(\n",
260 | " nn.Linear(input_size, 128),\n",
261 | " nn.LeakyReLU(0.2),\n",
262 | " nn.Linear(128, 256),\n",
263 | " nn.BatchNorm1d(256),\n",
264 | " nn.LeakyReLU(0.2),\n",
265 | " nn.Linear(256, 512),\n",
266 | " nn.BatchNorm1d(512),\n",
267 | " nn.LeakyReLU(0.2),\n",
268 | " nn.Linear(512, 1024),\n",
269 | " nn.BatchNorm1d(1024),\n",
270 | " nn.LeakyReLU(0.2),\n",
271 | " nn.Linear(1024, num_classes),\n",
272 | " nn.Tanh()\n",
273 | " )\n",
274 | " \n",
275 | " def forward(self, x):\n",
276 | " y_ = self.layer(x)\n",
277 | " y_ = y_.view(x.size(0), 1, 28, 28)\n",
278 | " return y_"
279 | ],
280 | "execution_count": 0,
281 | "outputs": []
282 | },
283 | {
284 | "cell_type": "markdown",
285 | "metadata": {
286 | "id": "-PnT3Dp0Kr2O",
287 | "colab_type": "text"
288 | },
289 | "source": [
290 | "**Generator**: the Generator is generally harder to train than the Discriminator, so it is made deeper, with 5 FC layers."
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "metadata": {
296 | "id": "MUwWx_QeKr2O",
297 | "colab_type": "code",
298 | "colab": {}
299 | },
300 | "source": [
301 | "# dimension of noise vector\n",
302 | "n_noise = 100"
303 | ],
304 | "execution_count": 0,
305 | "outputs": []
306 | },
307 | {
308 | "cell_type": "code",
309 | "metadata": {
310 | "id": "XNd1_mBUKr2R",
311 | "colab_type": "code",
312 | "colab": {}
313 | },
314 | "source": [
315 | "D = Discriminator().to(DEVICE)\n",
316 | "G = Generator(n_noise).to(DEVICE)"
317 | ],
318 | "execution_count": 0,
319 | "outputs": []
320 | },
321 | {
322 | "cell_type": "markdown",
323 | "metadata": {
324 | "id": "9fUu_RjDKr2U",
325 | "colab_type": "text"
326 | },
327 | "source": [
328 | "Create each model and move it to DEVICE (the GPU, when available)."
329 | ]
330 | },
331 | {
332 | "cell_type": "code",
333 | "metadata": {
334 | "id": "GG_o76a7Kr2V",
335 | "colab_type": "code",
336 | "colab": {}
337 | },
338 | "source": [
339 | "transform = transforms.Compose([transforms.ToTensor(),\n",
340 | " # map pixels from [0, 1] to [-1, 1] so real images share the range of the Generator's Tanh output\n",
341 | " transforms.Normalize(mean=(0.5,), std=(0.5,))\n",
342 | " ]\n",
343 | ")"
344 | ],
345 | "execution_count": 0,
346 | "outputs": []
347 | },
348 | {
349 | "cell_type": "markdown",
350 | "metadata": {
351 | "id": "AW26U8N2Kr2a",
352 | "colab_type": "text"
353 | },
354 | "source": [
355 | "**transforms**: the torchvision package that provides the transform functions.\n",
356 | "\n",
357 | "**ToTensor** converts a PIL image or numpy array into a torch tensor, scaling uint8 pixel values to [0, 1].\n",
358 | "\n",
359 | "**Normalize** computes, per channel: input[channel] = (input[channel] - mean[channel]) / std[channel]"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "metadata": {
365 | "id": "71g2OJRzKr2b",
366 | "colab_type": "code",
367 | "colab": {
368 | "base_uri": "https://localhost:8080/",
369 | "height": 289
370 | },
371 | "outputId": "296de745-f217-4bbc-f6e1-09ba2fbfb07b"
372 | },
373 | "source": [
374 | "mnist = datasets.MNIST(root='../data/', train=True, transform=transform, download=True)"
375 | ],
376 | "execution_count": 13,
377 | "outputs": [
378 | {
379 | "output_type": "stream",
380 | "text": [
381 | "\r0it [00:00, ?it/s]"
382 | ],
383 | "name": "stderr"
384 | },
385 | {
386 | "output_type": "stream",
387 | "text": [
388 | "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz\n"
389 | ],
390 | "name": "stdout"
391 | },
392 | {
393 | "output_type": "stream",
394 | "text": [
395 | "9920512it [00:01, 9271069.17it/s] \n"
396 | ],
397 | "name": "stderr"
398 | },
399 | {
400 | "output_type": "stream",
401 | "text": [
402 | "Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw\n"
403 | ],
404 | "name": "stdout"
405 | },
406 | {
407 | "output_type": "stream",
408 | "text": [
409 | " 0%| | 0/28881 [00:00, ?it/s]"
410 | ],
411 | "name": "stderr"
412 | },
413 | {
414 | "output_type": "stream",
415 | "text": [
416 | "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
417 | ],
418 | "name": "stdout"
419 | },
420 | {
421 | "output_type": "stream",
422 | "text": [
423 | "32768it [00:00, 143204.96it/s] \n",
424 | " 0%| | 0/1648877 [00:00, ?it/s]"
425 | ],
426 | "name": "stderr"
427 | },
428 | {
429 | "output_type": "stream",
430 | "text": [
431 | "Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw\n",
432 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
433 | ],
434 | "name": "stdout"
435 | },
436 | {
437 | "output_type": "stream",
438 | "text": [
439 | "1654784it [00:00, 2251986.41it/s] \n",
440 | "8192it [00:00, 54345.52it/s] \n"
441 | ],
442 | "name": "stderr"
443 | },
444 | {
445 | "output_type": "stream",
446 | "text": [
447 | "Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw\n",
448 | "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n",
449 | "Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw\n",
450 | "Processing...\n",
451 | "Done!\n"
452 | ],
453 | "name": "stdout"
454 | }
455 | ]
456 | },
457 | {
458 | "cell_type": "markdown",
459 | "metadata": {
460 | "id": "-UTLB2vbKr2e",
461 | "colab_type": "text"
462 | },
463 | "source": [
464 | "A GAN generates new images from noise samples, so there is no need to load a separate test set."
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "metadata": {
470 | "id": "dyq1l4XnKr2f",
471 | "colab_type": "code",
472 | "colab": {}
473 | },
474 | "source": [
475 | "batch_size = 64"
476 | ],
477 | "execution_count": 0,
478 | "outputs": []
479 | },
480 | {
481 | "cell_type": "code",
482 | "metadata": {
483 | "id": "5wTnXfUKKr2h",
484 | "colab_type": "code",
485 | "colab": {}
486 | },
487 | "source": [
488 | "data_loader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True, drop_last=True)"
489 | ],
490 | "execution_count": 0,
491 | "outputs": []
492 | },
493 | {
494 | "cell_type": "markdown",
495 | "metadata": {
496 | "id": "JKw0DIefKr2k",
497 | "colab_type": "text"
498 | },
499 | "source": [
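500 | "**DataLoader** is the PyTorch class that loads the data efficiently, one batch at a time, during training. Using it well raises GPU utilization.\n",
501 | "\n",
502 | "**shuffle**: randomly reshuffles the order of the data at every epoch.\n",
503 | "\n",
504 | "**drop_last**: if the number of samples is not divisible by the batch size, the last, incomplete batch is dropped. This is mainly used for training only. Here it is also required because D_labels and D_fakes have a fixed first dimension of batch_size."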
505 | ]
506 | },
507 | {
508 | "cell_type": "code",
509 | "metadata": {
510 | "id": "r94AbEHEKr2l",
511 | "colab_type": "code",
512 | "colab": {}
513 | },
514 | "source": [
515 | "criterion = nn.BCELoss()\n",
516 | "D_opt = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
517 | "G_opt = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))"
518 | ],
519 | "execution_count": 0,
520 | "outputs": []
521 | },
522 | {
523 | "cell_type": "markdown",
524 | "metadata": {
525 | "id": "EamQvm0gKr2n",
526 | "colab_type": "text"
527 | },
528 | "source": [
529 | "**GAN objective**\n",
530 | "\n",
531 | "$$\\min_G \\max_D V(D,G) = \\mathbb{E}_{x\\sim p_{data}(x)}[\\log D(x)] + \\mathbb{E}_{z\\sim p_z(z)}[\\log(1-D(G(z)))]$$\n",
532 | "\n",
533 | "**nn.BCELoss**: Binary Cross Entropy\n",
534 | "\n",
535 | "$$-{[y\\log(\\hat{y}) + (1 - y)\\log(1 - \\hat{y})]}$$\n",
536 | "\n",
537 | "**Adam betas**: exponential decay rates for the first and second moment estimates (default: (0.9, 0.999)). Here β1 is lowered to 0.5, the setting popularized by DCGAN.\n",
538 | "\n",
539 | "$$\\beta_1, \\beta_2 \\in [0,1)$$"
540 | ]
541 | },
542 | {
543 | "cell_type": "code",
544 | "metadata": {
545 | "id": "1LC42zZeKr2n",
546 | "colab_type": "code",
547 | "colab": {}
548 | },
549 | "source": [
550 | "max_epoch = 50 # need more than 10 epochs for training generator\n",
551 | "step = 0\n",
552 | "n_critic = 1 # update G once every n_critic D steps (n_critic > 1 trains D more often than G)"
553 | ],
554 | "execution_count": 0,
555 | "outputs": []
556 | },
557 | {
558 | "cell_type": "code",
559 | "metadata": {
560 | "id": "NbKnzrnKKr2q",
561 | "colab_type": "code",
562 | "colab": {}
563 | },
564 | "source": [
565 | "D_labels = torch.ones(batch_size, 1).to(DEVICE) # target label for real samples\n",
566 | "D_fakes = torch.zeros(batch_size, 1).to(DEVICE) # target label for fake samples"
567 | ],
568 | "execution_count": 0,
569 | "outputs": []
570 | },
571 | {
572 | "cell_type": "markdown",
573 | "metadata": {
574 | "id": "ccpkV-j_Kr2t",
575 | "colab_type": "text"
576 | },
577 | "source": [
578 | "When training the Discriminator, we want **D(x)** to output 1 and **D(G(z))** to output 0;\n",
579 | "\n",
580 | "when training the Generator, we want **D(G(z))** to output 1. These are the corresponding Discriminator target labels."
581 | ]
582 | },
583 | {
584 | "cell_type": "code",
585 | "metadata": {
586 | "id": "hJPlkSL4Kr2u",
587 | "colab_type": "code",
588 | "colab": {}
589 | },
590 | "source": [
591 | "if not os.path.exists('samples'):\n",
592 | " os.makedirs('samples')"
593 | ],
594 | "execution_count": 0,
595 | "outputs": []
596 | },
597 | {
598 | "cell_type": "markdown",
599 | "metadata": {
600 | "id": "qgwXALRMKr2v",
601 | "colab_type": "text"
602 | },
603 | "source": [
604 | "### Training Code\n",
605 | "\n"
606 | ]
607 | },
608 | {
609 | "cell_type": "markdown",
610 | "metadata": {
611 | "id": "qZzBqIytKr2w",
612 | "colab_type": "text"
613 | },
614 | "source": [
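615 | "**line 5**: GAN training does not need the MNIST class labels, so only the images are taken.\n",
616 | "\n",
617 | "**line 7-9**: forward the real samples **x** through D and compute the loss\n",
618 | "\n",
619 | "**line 11-13**: forward the fake samples **G(z)** through D and compute the loss\n",
620 | "\n",
621 | "**line 14**: add the two losses together\n",
622 | "\n",
623 | "**line 17-19**: reset the gradients --> backpropagate --> update the parameters\n",
624 | "\n",
625 | "**line 21**: the Generator's gradient is computed entirely through the Discriminator, so D has to be trained well. With n_critic > 1, D is updated n_critic times for every G update. This technique is used in WGAN.\n",
626 | "\n",
627 | "**line 23-29**: compute the Generator loss, the reverse of lines 11-13: the fake samples are scored against the real label. The reason for using the non-saturating loss **-log(D(G(z)))** is given below.\n",
628 | "\n",
629 | "**line 25**: in theory G should minimize **log(1-D(G(z)))**. Early in training, however, G still produces poor images, D rejects them confidently, and this term saturates, giving very small gradients. The original paper therefore proposes maximizing **log(D(G(z)))** instead.\n",
630 | "\n",
631 | "**line 36-40**: every 1000 steps, generate and save sample images to check that the Generator is training well"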
632 | ]
633 | },
634 | {
635 | "cell_type": "code",
636 | "metadata": {
637 | "id": "u-w-CzhlsswE",
638 | "colab_type": "code",
639 | "colab": {}
640 | },
641 | "source": [
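642 | "# run this cell to (re)initialize G and D; the optimizers are rebuilt too, otherwise they keep updating the old models\n",
643 | "D, G = Discriminator().to(DEVICE), Generator(n_noise).to(DEVICE)\n",
644 | "D_opt, G_opt = [torch.optim.Adam(m.parameters(), lr=0.0002, betas=(0.5, 0.999)) for m in (D, G)]"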
645 | ],
646 | "execution_count": 0,
647 | "outputs": []
648 | },
649 | {
650 | "cell_type": "code",
651 | "metadata": {
652 | "scrolled": true,
653 | "id": "LTgCYt8WKr2w",
654 | "colab_type": "code",
655 | "outputId": "937655c8-c3f1-4b1a-dad0-9b39135e661d",
656 | "colab": {
657 | "base_uri": "https://localhost:8080/",
658 | "height": 989
659 | }
660 | },
661 | "source": [
662 | "G_loss_list = []\n",
663 | "D_loss_list = []\n",
664 | "\n",
665 | "for epoch in range(max_epoch):\n",
666 | " for idx, (images, _) in enumerate(data_loader):\n",
667 | " # Training Discriminator\n",
668 | " x = images.to(DEVICE)\n",
669 | " x_outputs = D(x)\n",
670 | " D_x_loss = criterion(x_outputs, D_labels)\n",
671 | "\n",
672 | " z = torch.randn(batch_size, n_noise).to(DEVICE)\n",
673 | " z_outputs = D(G(z))\n",
674 | " D_z_loss = criterion(z_outputs, D_fakes)\n",
675 | " D_loss = D_x_loss + D_z_loss\n",
676 | "\n",
677 | " \n",
678 | " D.zero_grad()\n",
679 | " D_loss.backward()\n",
680 | " D_opt.step()\n",
681 | "\n",
682 | " if step % n_critic == 0:\n",
683 | " # Training Generator\n",
684 | " z = torch.randn(batch_size, n_noise).to(DEVICE)\n",
685 | " z_outputs = D(G(z))\n",
686 | " G_loss = criterion(z_outputs, D_labels)\n",
687 | "\n",
688 | " G.zero_grad()\n",
689 | " G_loss.backward()\n",
690 | " G_opt.step()\n",
691 | " \n",
692 | " if step % 500 == 0:\n",
693 | " print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))\n",
694 | " G_loss_list.append(G_loss.item())\n",
695 | " D_loss_list.append(D_loss.item())\n",
696 | " \n",
697 | " if step % 1000 == 0:\n",
698 | " G.eval()\n",
699 | " img = get_sample_image(G, n_noise)\n",
700 | " imsave('samples/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')\n",
701 | " G.train()\n",
702 | " step += 1"
703 | ],
704 | "execution_count": 0,
705 | "outputs": [
706 | {
707 | "output_type": "stream",
708 | "text": [
709 | "Epoch: 0/50, Step: 11000, D Loss: 0.017072804272174835, G Loss: 6.205540657043457\n",
710 | "Epoch: 0/50, Step: 11500, D Loss: 0.03962019830942154, G Loss: 8.96841049194336\n",
711 | "Epoch: 1/50, Step: 12000, D Loss: 0.009823893196880817, G Loss: 8.245841979980469\n",
712 | "Epoch: 1/50, Step: 12500, D Loss: 0.005063357762992382, G Loss: 7.6561479568481445\n",
713 | "Epoch: 2/50, Step: 13000, D Loss: 0.002445463789626956, G Loss: 7.686934947967529\n",
714 | "Epoch: 2/50, Step: 13500, D Loss: 0.07487145066261292, G Loss: 5.739070892333984\n",
715 | "Epoch: 3/50, Step: 14000, D Loss: 0.0020278235897421837, G Loss: 8.868522644042969\n",
716 | "Epoch: 3/50, Step: 14500, D Loss: 0.009442816488444805, G Loss: 7.217374801635742\n",
717 | "Epoch: 4/50, Step: 15000, D Loss: 0.005332832224667072, G Loss: 10.235108375549316\n",
718 | "Epoch: 4/50, Step: 15500, D Loss: 0.014287389814853668, G Loss: 6.958813667297363\n",
719 | "Epoch: 5/50, Step: 16000, D Loss: 0.004800278227776289, G Loss: 6.8694047927856445\n",
720 | "Epoch: 5/50, Step: 16500, D Loss: 0.1734900027513504, G Loss: 5.07503604888916\n",
721 | "Epoch: 6/50, Step: 17000, D Loss: 0.029605243355035782, G Loss: 5.75632381439209\n",
722 | "Epoch: 6/50, Step: 17500, D Loss: 0.01813247799873352, G Loss: 8.339418411254883\n",
723 | "Epoch: 7/50, Step: 18000, D Loss: 0.004233901854604483, G Loss: 7.612527847290039\n",
724 | "Epoch: 8/50, Step: 18500, D Loss: 0.01574200764298439, G Loss: 10.41611385345459\n",
725 | "Epoch: 8/50, Step: 19000, D Loss: 0.008210530504584312, G Loss: 7.209020614624023\n",
726 | "Epoch: 9/50, Step: 19500, D Loss: 0.0008720408077351749, G Loss: 8.900108337402344\n",
727 | "Epoch: 9/50, Step: 20000, D Loss: 0.009653720073401928, G Loss: 8.031501770019531\n",
728 | "Epoch: 10/50, Step: 20500, D Loss: 0.014342648908495903, G Loss: 5.6937713623046875\n",
729 | "Epoch: 10/50, Step: 21000, D Loss: 0.08091247826814651, G Loss: 7.744855880737305\n",
730 | "Epoch: 11/50, Step: 21500, D Loss: 0.0034302675630897284, G Loss: 8.516478538513184\n",
731 | "Epoch: 11/50, Step: 22000, D Loss: 0.002264691051095724, G Loss: 7.806923866271973\n",
732 | "Epoch: 12/50, Step: 22500, D Loss: 0.009318589232861996, G Loss: 9.593717575073242\n",
733 | "Epoch: 12/50, Step: 23000, D Loss: 0.030135272070765495, G Loss: 7.173668384552002\n",
734 | "Epoch: 13/50, Step: 23500, D Loss: 0.0029895356856286526, G Loss: 8.691886901855469\n",
735 | "Epoch: 13/50, Step: 24000, D Loss: 0.007704799063503742, G Loss: 7.48193359375\n",
736 | "Epoch: 14/50, Step: 24500, D Loss: 0.003973798826336861, G Loss: 7.132073879241943\n",
737 | "Epoch: 14/50, Step: 25000, D Loss: 0.003595698159188032, G Loss: 6.773451805114746\n",
738 | "Epoch: 15/50, Step: 25500, D Loss: 0.0008625903865322471, G Loss: 9.139907836914062\n",
739 | "Epoch: 16/50, Step: 26000, D Loss: 0.002710268134251237, G Loss: 8.221826553344727\n",
740 | "Epoch: 16/50, Step: 26500, D Loss: 0.02643599361181259, G Loss: 21.747425079345703\n",
741 | "Epoch: 17/50, Step: 27000, D Loss: 0.005619286559522152, G Loss: 7.38516902923584\n",
742 | "Epoch: 17/50, Step: 27500, D Loss: 0.0014947939198464155, G Loss: 8.017923355102539\n",
743 | "Epoch: 18/50, Step: 28000, D Loss: 0.018260205164551735, G Loss: 9.57470703125\n",
744 | "Epoch: 18/50, Step: 28500, D Loss: 0.0028417676221579313, G Loss: 7.531101226806641\n",
745 | "Epoch: 19/50, Step: 29000, D Loss: 0.017717955633997917, G Loss: 11.651924133300781\n",
746 | "Epoch: 19/50, Step: 29500, D Loss: 0.004795431159436703, G Loss: 8.255350112915039\n",
747 | "Epoch: 20/50, Step: 30000, D Loss: 0.0006288132863119245, G Loss: 10.273591995239258\n",
748 | "Epoch: 20/50, Step: 30500, D Loss: 0.0030029569752514362, G Loss: 8.368170738220215\n",
749 | "Epoch: 21/50, Step: 31000, D Loss: 0.0011793775483965874, G Loss: 7.983173370361328\n",
750 | "Epoch: 21/50, Step: 31500, D Loss: 0.006279204972088337, G Loss: 6.6801252365112305\n",
751 | "Epoch: 22/50, Step: 32000, D Loss: 0.01721380278468132, G Loss: 8.927534103393555\n",
752 | "Epoch: 22/50, Step: 32500, D Loss: 0.00010531823500059545, G Loss: 11.091259002685547\n",
753 | "Epoch: 23/50, Step: 33000, D Loss: 0.0005428884760476649, G Loss: 9.457132339477539\n",
754 | "Epoch: 24/50, Step: 33500, D Loss: 0.026661192998290062, G Loss: 42.57115173339844\n",
755 | "Epoch: 24/50, Step: 34000, D Loss: 0.0012832947541028261, G Loss: 7.962080001831055\n",
756 | "Epoch: 25/50, Step: 34500, D Loss: 0.002940555103123188, G Loss: 8.231060028076172\n",
757 | "Epoch: 25/50, Step: 35000, D Loss: 0.0009441556758247316, G Loss: 9.213018417358398\n",
758 | "Epoch: 26/50, Step: 35500, D Loss: 0.0306414645165205, G Loss: 6.855942249298096\n",
759 | "Epoch: 26/50, Step: 36000, D Loss: 0.0021525942720472813, G Loss: 12.038948059082031\n",
760 | "Epoch: 27/50, Step: 36500, D Loss: 0.000524238683283329, G Loss: 8.796098709106445\n",
761 | "Epoch: 27/50, Step: 37000, D Loss: 0.014627888798713684, G Loss: 9.935175895690918\n",
762 | "Epoch: 28/50, Step: 37500, D Loss: 0.021666239947080612, G Loss: 7.52801513671875\n"
763 | ],
764 | "name": "stdout"
765 | }
766 | ]
767 | },
768 | {
769 | "cell_type": "code",
770 | "metadata": {
771 | "id": "qdRvG-oPsIyM",
772 | "colab_type": "code",
773 | "colab": {}
774 | },
775 | "source": [
776 | "import matplotlib.pyplot as plt\n",
777 | "plt.ion()\n",
778 | "\n",
779 | "fig = plt.figure()\n",
780 | "plt.plot(G_loss_list, label='generator')\n",
781 | "plt.plot(D_loss_list, label='discriminator')\n",
782 | "plt.xlabel('Steps (x500)')\n",
783 | "plt.ylabel('Loss')\n",
784 | "plt.legend()\n",
785 | "plt.show()"
786 | ],
787 | "execution_count": 0,
788 | "outputs": []
789 | },
790 | {
791 | "cell_type": "markdown",
792 | "metadata": {
793 | "id": "TA1DWc2LKr2y",
794 | "colab_type": "text"
795 | },
796 | "source": [
797 | "## Visualize Sample"
798 | ]
799 | },
800 | {
801 | "cell_type": "code",
802 | "metadata": {
803 | "id": "5Z1c8k4XKr2z",
804 | "colab_type": "code",
805 | "colab": {}
806 | },
807 | "source": [
808 | "# Generate samples from G and display them as an image\n",
809 | "G.eval()\n",
810 | "imshow(get_sample_image(G, n_noise), cmap='gray')"
811 | ],
812 | "execution_count": 0,
813 | "outputs": []
814 | },
815 | {
816 | "cell_type": "code",
817 | "metadata": {
818 | "id": "syY8Vn-_Kr21",
819 | "colab_type": "code",
820 | "colab": {}
821 | },
822 | "source": [
823 | "# Saving params.\n",
824 | "torch.save(D.state_dict(), 'D.pkl')\n",
825 | "torch.save(G.state_dict(), 'G.pkl')"
826 | ],
827 | "execution_count": 0,
828 | "outputs": []
829 | },
830 | {
831 | "cell_type": "code",
832 | "metadata": {
833 | "id": "R6zgG3o_Kr22",
834 | "colab_type": "code",
835 | "colab": {}
836 | },
837 | "source": [
838 | ""
839 | ],
840 | "execution_count": 0,
841 | "outputs": []
842 | }
843 | ]
844 | }
845 |
--------------------------------------------------------------------------------
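The training loop above scores discriminator outputs with `criterion` against real targets (`D_labels`) and fake targets (`D_fakes`). The sketch below assumes `criterion` is `nn.BCELoss` and that `D_labels`/`D_fakes` hold ones/zeros. Their definitions are not shown in this excerpt. Under those assumptions, the logged `D Loss`/`G Loss` values can be read back as discriminator probabilities. This is a minimal NumPy sketch, not the notebook's code, and the helper names `bce`, `gan_losses`, and `implied_fake_prob` are hypothetical:

```python
import numpy as np

def bce(p, target, eps=1e-12):
    """Mean binary cross-entropy (the formula torch.nn.BCELoss uses with mean reduction)."""
    p = np.clip(p, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))))

def gan_losses(d_real, d_fake):
    """Hypothetical helper mirroring the loop above; returns (D_loss, G_loss).

    d_real: discriminator outputs D(x) on real images, values in (0, 1)
    d_fake: discriminator outputs D(G(z)) on generated images, values in (0, 1)
    """
    d_x_loss = bce(d_real, np.ones_like(d_real))   # real images labeled 1
    d_z_loss = bce(d_fake, np.zeros_like(d_fake))  # fake images labeled 0
    g_loss = bce(d_fake, np.ones_like(d_fake))     # generator wants fakes labeled 1
    return d_x_loss + d_z_loss, g_loss

def implied_fake_prob(g_loss):
    """Geometric-mean D(G(z)) implied by a generator loss of -mean(log D(G(z)))."""
    return float(np.exp(-g_loss))

d_loss, g_loss = gan_losses(np.full(4, 0.9), np.full(4, 0.1))
print(round(d_loss, 4), round(g_loss, 4))
print(implied_fake_prob(8.0))
```

Training G against real labels makes `G_loss` equal to `-mean(log D(G(z)))`. A logged G loss around 8 therefore means the discriminator assigns fakes a typical probability on the order of exp(-8), roughly 3e-4. Together with D losses near 0.01, this suggests the discriminator is dominating the logged run.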
/GAN_losses.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "GAN_losses.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "collapsed_sections": [],
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | }
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "view-in-github",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | ""
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "metadata": {
31 | "id": "7V5WK9jaCvKL",
32 | "colab_type": "code",
33 | "colab": {}
34 | },
35 | "source": [
36 | "import numpy as np\n",
37 | "import matplotlib.pyplot as plt"
38 | ],
39 | "execution_count": 0,
40 | "outputs": []
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {
45 | "id": "g5efuFGPJRg-",
46 | "colab_type": "text"
47 | },
48 | "source": [
49 | "# Comparison of original GAN loss, alternative GAN loss, and WGAN loss\n",
50 | "\n",
51 | "* $org\\_loss: V(G,D)=E_{x \\sim P_{data}}[\\log D(x)]+E_{z \\sim P_z}[\\log (1-D(G(z)))]$\n",
52 | "\n",
53 | "* $alternative\\_loss: V(G,D)=E_{x \\sim P_{data}}[\\log D(x)]-E_{z \\sim P_z}[\\log D(G(z))]$\n",
54 | "\n",
55 | "* $wgan\\_loss: V(G,D)=E_{x \\sim P_{data}}[D(x)]-E_{z \\sim P_z}[D(G(z))]$"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "metadata": {
61 | "id": "Plgjq4IuDtLO",
62 | "colab_type": "code",
63 | "colab": {}
64 | },
65 | "source": [
66 | "# Consider only the loss for the generator\n",
67 | "\n",
68 | "x = np.linspace(0.01, 0.99, 100) # x corresponds to D(G(z))\n",
69 | "ones = np.ones(x.shape)\n",
70 | "org_loss = np.log(ones-x) # original GAN\n",
71 | "\n",
72 | "alternative_loss = -np.log(x) # alternative GAN loss\n",
73 | "\n",
74 | "wgan_loss = -x # WGAN loss"
75 | ],
76 | "execution_count": 0,
77 | "outputs": []
78 | },
79 | {
80 | "cell_type": "code",
81 | "metadata": {
82 | "id": "A_woz9uwDnqO",
83 | "colab_type": "code",
84 | "colab": {
85 | "base_uri": "https://localhost:8080/",
86 | "height": 287
87 | },
88 | "outputId": "fb2e71fc-e916-4f98-8b8e-f3ffd5167e21"
89 | },
90 | "source": [
91 | "plt.plot(x, org_loss)\n",
92 | "plt.plot(x, alternative_loss)\n",
93 | "plt.plot(x, wgan_loss)\n",
94 | "plt.legend([\"original gan loss\", \"alternative gan loss\", \"wgan loss\"])"
95 | ],
96 | "execution_count": 3,
97 | "outputs": [
98 | {
99 | "output_type": "execute_result",
100 | "data": {
101 | "text/plain": [
102 | ""
103 | ]
104 | },
105 | "metadata": {
106 | "tags": []
107 | },
108 | "execution_count": 3
109 | },
110 | {
111 | "output_type": "display_data",
112 | "data": {
113 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8VNX9//HXmckkk2Qmkz1kAQLI\nngQIEOCLAYWKC26gaF1q64JVf9b2a8sXv9ZavqX9/mpd66+tVmvdW21VXAotQgVFRREQRBZZA4SE\n7Jlkkky2Ob8/7iQEyJ5JJpl8no/Hfczkzp17z03gfc8999xzldYaIYQQgcPk7wIIIYTwLQl2IYQI\nMBLsQggRYCTYhRAiwEiwCyFEgJFgF0KIACPBLoQQAUaCXQghAowEuxBCBJggf2w0NjZWp6am+mPT\nQggxYG3btq1Yax3X0XJ+CfbU1FS2bt3qj00LIcSApZQ62pnlpClGCCECjAS7EEIEGAl2IYQIMH5p\nYxdCGOrr68nNzcXtdvu7KKIfsVqtpKSkYLFYuvV9CXYh/Cg3Nxe73U5qaipKKX8XR/QDWmtKSkrI\nzc1lxIgR3VqHNMUI4Udut5uYmBgJddFMKUVMTEyPzuIk2IXwMwl1caae/psYWMF+YB1seszfpRBC\niH5tYAX74Y2w8dfQUOfvkgghvJ5++mleeumlbn03JyeHtLQ0H5eo/2zPXwbWxdPkTGishcI9kDTZ\n36URYlDQWqO1xmRqvR54xx139HGJREcGVo09KdN4zdvu33IIEUAee+wx0tLSSEtL44knngCMmu3Y\nsWO56aabSEtL4/jx4zz33HOMGTOGrKwsli5dyt133w3AihUreOSRRwA477zzWL58OVlZWYwZM4ZN\nmzY1ry87O5vMzEwyMzP59NNP2y2Tx+PhrrvuYty4cVxwwQVccsklvPHGGwD84he/YPr06aSlpXH7\n7bejtW53221xu93cfPPNpKenM2XKFDZs2ADA7t27ycrKYvLkyWRkZHDgwAGqqqpYuHAhkyZNIi0t\njddff72bv+2+MbBq7FGpEBoNJ7bBtFv8XRohfOp/3tvNnrwKn65zQlIEP79sYpufb9u2jeeff57P\nP/8crTUzZsxg7ty5REVFceDAAV588UVmzpxJXl4eK1euZPv27djtdubNm8ekSZNaXWdDQwNbtmxh\nzZo1/M///A/r168nPj6edevWYbVaOXDgANddd12740W99dZb5OTksGfPHgoLCxk/fjy33GL8n7/7\n7rt58MEHAfjOd77DP/7xDy677LI2t92W3//+9yil2LVrF/v27WPBggXs37+fp59+mh/+8IfccMMN\n1NXV0djYyJo1a0hKSmL16tUAOJ3O9n/xfjawauxKGc0xJ770d0mECAgff/wxixYtIjw8HJvNxuLF\ni5trusOHD2fmzJkAbNmyhblz5xIdHY3FYmHJkiVtrnPx4sUATJ06lZycHMC4EWvp0qWkp6ezZMkS\n9uzZ02G5lixZgslkYsiQIZx//vnNn23YsIEZM2aQnp7OBx98wO7du9vddnvbuPHGGwEYN24cw4cP\nZ//+/cyaNYv//d//5aGHHuLo0aOEhoaSnp7OunXrWL58OZs2bcLhcLS7bn8bWDV2MJpjDj0CdVUQ\nHO7v0gjhM+3VrP0hPLx7/79CQkIAMJvNNDQ0APD444+TkJDAzp078Xg8WK3Wbq3b7XZz1113sXXr\nVoYOHcqKFStO6+/d2ra76vrrr2fGjBmsXr2aSy65hD/+8Y/MmzeP7du3s2bNGh544AHmz5/ffNbQ\nHw2sGjsYNXbtgfyv/F0SIQa87Oxs3n77baqrq6mqqmLVqlVkZ2eftdz06dP58MMPKSsro6GhgTff\nfLNL23E6nSQmJmIymXj55ZdpbGxsd/
nZs2fz5ptv4vF4KCgoYOPGjQDNIR4bG4vL5Wpud++O7Oxs\nXn31VQD279/PsWPHGDt2LIcPH2bkyJHcc889XHHFFXz11Vfk5eURFhbGjTfeyLJly9i+vX9f5xuY\nNXYw2tmHz/JvWYQY4DIzM/ne975HVlYWALfddhtTpkw5qxkjOTmZ+++/n6ysLKKjoxk3blyXmiPu\nuusurrrqKl566SUuuuiiDs8GrrrqKv79738zYcIEhg4dSmZmJg6Hg8jISJYuXUpaWhpDhgxh+vTp\nXd7nlmW68847SU9PJygoiBdeeIGQkBD+9re/8fLLL2OxWBgyZAj3338/X3zxBcuWLcNkMmGxWHjq\nqae6vd2+oJquKPeladOm6R49aOOxiTBsBlz9Z98VSgg/2Lt3L+PHj/d3MTrF5XJhs9loaGhg0aJF\n3HLLLSxatKjXt1dSUkJWVhaffPIJQ4YM6bXt9Tet/dtQSm3TWk/r6LsDr8YOkDwFTvTvUyEhAs2K\nFStYv349brebBQsWcOWVV/bq9i699FLKy8upq6vjZz/72aAK9Z7yWbArpczAVuCE1vpSX623VclT\nYe97UF0KYdG9uikhhKGpr3pfaWpXF13ny4unPwT2+nB9bZMblYQQok0+CXalVAqwEPiTL9bXoabh\nBKQ/uxBCnMVXNfYngP8CPG0toJS6XSm1VSm1taioqGdbszogZrTU2IUQohU9Dnal1KVAodZ6W3vL\naa2f0VpP01pPi4uL6+lmjXb2E9vAD716hBCiP/NFjX02cLlSKgd4DZinlHrFB+tt34hscBVI7xgh\neklqairFxcWUl5fzhz/8oU+2uXHjxtMGCOvJkMB9oeUAaP1Jj4Nda/3fWusUrXUq8G3gA631jT0u\nWUfGXwbmENj1t17flBCDWXeCXWuNx9Nmy2ybzgz2O+64g5tuuqnL6xnsBt6QAk2sDhizAL5+Cxq7\nNyaEEAKuvPJKpk6dysSJE3nmmWfO+vy+++7j0KFDTJ48mWXLlgHw8MMPM336dDIyMvj5z38OtD7U\nr81m46c//SmTJk1i5syZFBQUAPDee+8xY8YMpkyZwre+9S0KCgrIycnh6aef5vHHH2fy5Mls2rSp\nuUa8b9++5rtjm7aVnp4OGCNUzp07l6lTp3LhhReSn59/1j4cOnSImTNnkp6ezgMPPIDNZgOMm6Dm\nz59PZmYm6enpvPPOO83rHz9+PEuXLmXixIksWLCAmpqadn+PO3bsYObMmWRkZLBo0SLKysoAePLJ\nJ5kwYQIZGRl8+9vfBuDDDz9k8uTJTJ48mSlTplBZWdn5P1hnNA2i35fT1KlTtU/sfkfrn0dofWC9\nb9YnRB/bs2fPqR/WLNf6z5f4dlqzvMMylJSUaK21rq6u1hMnTtTFxcVaa62HDx+ui4qK9JEjR/TE\niRObl1+7dq1eunSp9ng8urGxUS9cuFB/+OGH+siRI1oppTdv3ty8LKDfffddrbXWy5Yt0ytXrtRa\na11aWqo9Ho/WWutnn31W33vvvVprrX/+85/rhx9+uPn7LX+eNGmSPnz4sNZa61//+td65cqVuq6u\nTs+aNUsXFhZqrbV+7bXX9M0333zWPi5cuFD/5S9/0Vpr/dRTT+nw8HCttdb19fXa6XRqrbUuKirS\no0aN0h6PRx85ckSbzWb95Zdfaq21XrJkiX755ZfPWm/L8qWnp+uNGzdqrbX+2c9+pn/4wx9qrbVO\nTEzUbrdba611WVmZ1lrrSy+9VH/88cdaa60rKyt1fX39Wes+7d/Gqd/nVt2JjB24NXaA0QsgxAG7\n/u7vkggxYD355JPNNerjx49z4MCBdpd///33ef/995kyZQqZmZns27ev+Tsth/oFCA4O5tJLjfsV\nWw6lm5uby4UXXkh6ejoPP/zwaUPvtuWaa65pfsDF66+/zrXXXss333zD119/zQUXXMDkyZP55S9/\nSW
5u7lnf3bx5c/NQw9dff33zfK01999/PxkZGXzrW9/ixIkTzWcVI0aMYPLkyWeVvTVOp5Py8nLm\nzp0LwHe/+10++ugjADIyMrjhhht45ZVXCAoy7gmdPXs29957L08++STl5eXN831lYA4p0MRihQmX\nw+5VsPAxCA7zd4mE6L6Lf93nm9y4cSPr169n8+bNhIWFcd555502DG5rtNb893//N9///vdPm5+T\nk3PW4F4WiwWlFHD6ULo/+MEPuPfee7n88svZuHEjK1as6LCs1157LUuWLGHx4sUopRg9ejS7du1i\n4sSJbN68uQt7fcqrr75KUVER27Ztw2KxkJqa2rz/TUMAN5W9o6aYtqxevZqPPvqI9957j1/96lfs\n2rWL++67j4ULF7JmzRpmz57N2rVrGTduXLfW35qBXWMHyLgG6lyw/5/+LokQA47T6SQqKoqwsDD2\n7dvHZ599dtYydrv9tDbgCy+8kD//+c+4XC4ATpw4QWFhYZe3m5ycDMCLL77Y5rZaGjVqFGazmZUr\nV3LttdcCMHbsWIqKipqDvb6+vtXa/8yZM5uHGn7ttddOK0d8fDwWi4UNGzZw9OjRLu1HE4fDQVRU\nVPNDSl5++WXmzp2Lx+Ph+PHjnH/++Tz00EM4nU5cLheHDh0iPT2d5cuXM336dPbt29et7bZl4Af7\n8NlgT4SvpDlGiK666KKLaGhoYPz48dx3332nNaM0iYmJYfbs2aSlpbFs2TIWLFjA9ddfz6xZs0hP\nT+fqq6/u8sW/FStWsGTJEqZOnUpsbGzz/Msuu4xVq1Y1Xzw907XXXssrr7zCNddcAxhNPW+88QbL\nly9n0qRJTJ48udXnqT7xxBM89thjZGRkcPDgweYhh2+44Qa2bt1Keno6L730Uo9qzS+++CLLli0j\nIyODHTt28OCDD9LY2MiNN97Y/FzVe+65h8jISJ544gnS0tLIyMjAYrFw8cUXd3u7rRmYw/aeae1P\n4fOn4UdfQ0Si79YrRC8bSMP2DmTV1dWEhoailOK1117jr3/9a3MPmP6qJ8P2DvwaO8D024w7UD/5\nrb9LIoToh7Zt28bkyZPJyMjgD3/4A48++qi/i9SrBvbF0ybRI2DydbDteTj3R2CXcZuFEKdkZ2ez\nc+dOfxejzwRGjR0g+yfQWA8fP+HvkgghhF8FTrBHj4BJ3lp75Ul/l0YIIfwmcIIdYM6PpdYuhBj0\nAivYo0eeqrWXH/N3aYQQwi8CK9gBzrsPlBn+ca+M1S5EP9Y0LLDwvcAL9sihMP9BOLhOxpARQgxK\ngRfsAFlLIWU6/HM5VEmNQIi2PPzwwzz55JMA/Od//ifz5s0D4IMPPuCGG24A4LnnnmPMmDFkZWWx\ndOlS7r77bqD1oXfBuKv0lltu4bzzzmPkyJHN62/PY489RlpaGmlpaTzxhHGNrKqqioULFzJp0iTS\n0tKaBwC77777mofB/clPfuLbX0iACIx+7GcymeHy/wdPZ8O/7oOr+uYZ20L0xENbHmJfqW/HDBkX\nPY7lWcvb/Dw7O5tHH32Ue+65h61bt1JbW0t9fT2bNm1izpw55OXlsXLlSrZv347dbmfevHlMmjQJ\ngHPPPZfPPvsMpRR/+tOf+M1vftN848++ffvYsGEDlZWVjB07ljvvvBOLxdJqGbZt28bzzz/P559/\njtaaGTNmMHfuXA4fPkxSUhKrV68GjHFdSkpKWLVqFfv27UMpRXl5uU9/X4EiMGvsAPHjIfvHRnPM\n12/6uzRC9EtTp05l27ZtVFRUEBISwqxZs9i6dSubNm0iOzubLVu2MHfuXKKjo7FYLM1D30L7Q+8u\nXLiQkJAQYmNjiY+Pb67Nt+bjjz9m0aJFhIeHY7PZWLx4MZs2bSI9PZ1169axfPlyNm3ahMPhwOFw\nYLVaufXWW3nrrbcIC5MRXVsTmDX2Jtk/hsMb4J0fQPwEI+yF6Kfa
q1n3FovFwogRI3jhhRf4j//4\nDzIyMtiwYQMHDx5k/Pjx7N+/v83vtjf07plD3jYN19sVY8aMYfv27axZs4YHHniA+fPn8+CDD7Jl\nyxb+/e9/88Ybb/C73/2ODz74oMvrDnSBW2MHCAqGJS9CcDi8dgO4nf4ukRD9TnZ2No888ghz5swh\nOzubp59+milTpqCUYvr06Xz44YeUlZXR0NDQPPQttD30bne2//bbb1NdXU1VVRWrVq0iOzubvLw8\nwsLCuPHGG1m2bBnbt2/H5XLhdDq55JJLePzxxwfVMAFdEdg1djBGe7zmRXjxMlh1B1z7KpgC+3gm\nRFdkZ2fzq1/9ilmzZhEeHo7VaiU7OxuA5ORk7r//frKysoiOjmbcuHHNQ942Db0bFRXFvHnzOHLk\nSLe2n5mZyfe+973mZ5redtttTJkyhbVr17Js2TJMJhMWi4WnnnqKyspKrrjiCtxuN1prHnvsMd/8\nEgJMYAzb2xmfPQ3/Wg6z7oYFvwTvU12E8KeBMGyvy+XCZrPR0NDAokWLuOWWW1i0aJG/ixXwejJs\nb+DX2JvM+D6UHobNv4PQKJgj3aSE6IwVK1awfv163G43CxYs4Morr/R3kUQHBk+wKwUX/Rrc5fDB\nSgiNNMZxF0K065FHHvF3EUQXDZ5gB6Nt/Yrfg7sCVv8EzMGQeZO/SyUGOa118wOfhQDj30RPDL6r\niGYLLHkeRs2Dd38An/4/f5dIDGJWq5WSkpIe/0cWgUNrTUlJCVartdvrGFw19iaWULjuNXhrKbz/\nANSUwbyfyQVV0edSUlLIzc2lqKjI30UR/YjVaiUlJaXb3x+cwQ5GH/er/wz/cMCmR6EiHy59HCzd\nP0oK0VVNNwgJ4UuDN9jBGFPmst9CRBJs/L9QcsDo525P8HfJhBCi2wZfG/uZlDLGcF/yIhTshmfP\nh9xt/i6VEEJ0mwR7k4lXwi1rjYd0/HmBcVHV4/F3qYQQossk2FtKzIDvfwhjLjIuqv71WhnPXQgx\n4EiwnyksGq59BS55BA5/CL+fAbvf9nephBCi0yTYW6OU8RSm2zcaj9r7+3fhbzeBS7qkCSH6Pwn2\n9iRMgFvXG89Q/eaf8Ltp8MVz4Gn0d8mEEKJNPQ52pdRQpdQGpdQepdRupdQPfVGwfsMcZDyw446P\nYUg6rL4X/jQfTkjPGSFE/+SLGnsD8GOt9QRgJvB/lFITfLDe/iVuLHz3PVj8J6jIg2fnwZtLofy4\nv0smhBCn6XGwa63ztdbbve8rgb1Ack/X2y8pBRlL4O6tRi1+77tG88y6n0N1qb9LJ4QQgI/b2JVS\nqcAU4HNfrrffsUYY7e53b4Xxl8Mnv4XfToKNDxkjRwohhB/5LNiVUjbgTeBHWuuz0k0pdbtSaqtS\namvADHgUORSuehbu/ARGzIGN/wtPpMPGX0sNXgjhNz55NJ5SygL8A1irte7wIYR+eTReXzixHT56\nBL5ZDcE2mHYLzLzTGItGCCF6qLOPxvNFrxgFPAfs7UyoB7TkTLjuL3Dnp8bdq5t/Z9Tg3/o+nNzl\n79IJIQaJHtfYlVLnApuAXUDT4Cr3a63XtPWdgK2xn6ksx3iI9vaXoL4Khs82bnwad6nxwA8hhOiC\nztbYfdIU01WDJtib1JTBl6/Almeh/CjYE2HKdyDzOxA5zN+lE0IMEBLs/ZGnEQ6uNwL+4Hpj3jnz\nYcqNMPYSCArxb/mEEP1aZ4N9cD9oo6+ZzDDmQmMqP2bU4re/DH//HoRGQfoSyPi20VYvj+kTQnST\n1Nj9zdMIhzfCjr/A3vegsRaiR0HGNZB2NcSe4+8SCiH6CWmKGYjcTiPcv3odjmwCtDE+zcTFxoNA\nokf6u4RCCD+SYB/oKvKMceB3vwW5XxjzEtJh/GXGFD9emmuEGGQk2ANJ+TGjJr/nXTj+OaAhKtXo\nNjn2Yhg6Q7pPCjEISLAHqsqT
xtjw36wx2uYb68DqgHO+BaMXwKj5YIvzdymFEL1Agn0wqK00wn3/\nv2D/+1BVaMxPmmIE/DnzIWW61OaFCBAS7IONxwMnd8KB9XBwndEurz0QbIfUc2HkeTByLsSNk7Z5\nIQYo6cc+2JhMRk09aQrMXQY15ZCzCQ7+21ur/6exnC3BCPrUcyF1DsSMkqAXIsBIsAeq0MhTPWgA\nyo7CkQ/hyEdGV8qv3zTm2xKMMWxSZ8OwWRA33jhICCEGLAn2wSJqOETdBJk3gdZQchByPoajn0DO\nJ0a3SgBrpNHLZtgM4zUpE4LD/Ft2IUSXSLAPRkpB7GhjmnazEfRlOXDsMzj2qfF6YK2xrCkIEtJg\naBakZEHKVIgaIc03QvRjcvFUtK66FI5vMfrN534BJ7ZBfbXxWVgMJE81pqRMY2yb8Fj/lleIQUAu\nnoqeCYuGsRcZE0BjAxTugRNbIXeb8XpgHeCtGDiGQdJkY0qcbFzEDYv2W/GFGMwk2EXnmIMgMcOY\npt1izKuthPydxiMB8740pr3vnvqOYygkToIhGcaYN4kZEJEszThC9DIJdtF9IfZTXSeb1JRB/leQ\nvwPydhiPBNy3muaavTXSCPkh6UbbfcJEo2+9xeqXXRAiEEmwC98KjTJuhBo599S8WhcU7IaCXUbQ\n538FW5+Hhhrjc2U2+tPHTzDCPn68MUWlGmPYCyG6RIJd9L4Qm9F9ctiMU/M8jVB6xAj7gt1QsMeo\n5e95+9QyQVaIHWOEfNw47zRWAl+IDkiwC/8wmY2HiMSeAxMXnZpf64Kib4wLtYV7oWiv0d/+q9dP\nLWMO8XbXHGOEfdP7mFFgCe37fRGin5FgF/1LiM3oK58y9fT5bicU7YeifcZUvN/ogrl7Fc3t9yiI\nHOoN+dHGQSPmHOO9PVHuqBWDhgS7GBisDhg63ZhaqquG0kNG0BcfMKaSA3B0M9RXnVouKNSo0UeP\n9L62eG9LkJ46IqBIsIuBLTjsVC+blrSGyvxTQV9y2DgAFO4xxrL3NJxa1hIO0SOMKWqEEfhN7x0p\n0p4vBhwJdhGYlIKIJGNq2UMHjJutnMe8Yd80HYLCfbB/rfHwkiYmi9G8E5V6aooc7n0/3OgFJEQ/\nI8EuBh9zkLdW3srDwT2NxvNmy44Y4+eUel/LciBvldFPv6UQB0QNM8I+ctjZk9XRBzskxOkk2IVo\nyWQ2auiRQ2HEnLM/dzuNIZDLcqD8qPG+/KjR5HPog1Pj6TQJcRjrcgw99epIMULfkQLh8XJRV/ic\nBLsQXWF1nBpa4UxaQ3WJEfbOY1B+3HgQudP7evQTqK04/TsmCziSjcCPSDbC3pEMEU2vycY25eKu\n6AIJdiF8RSljlMvw2LO7azZxO8GZa4S+87jx3nkcnCeM/vqV+aAbT/+OJdwb8klG4DddO4hIhohE\n4zU0SsJfNJNgF6IvWR3GlDCx9c8bG8B10gj6ilyjvb/5fb7R3OM6aTzPtqUgK9iHgD3JCHu7d4pI\nNObZhxg/y5g8g4IEuxD9iTnI2xyTAsxofZnGBnAVGKHfFPiVecbPlSeN0TYr86HBffZ3Q6PANuRU\n0NsTjFdbgjGv6VXu4B3QJNiFGGjMQd52+WRgeuvLaA3u8hahn2/U9CtbTMUHjHkt+/Q3CXGALf5U\n2NsSjIOALcGYb/POD42Si7/9kAS7EIFIKSN0Q6MgYULby3k8UFNq1PArC06Fv6vAOxUaQze4Cs7u\n8QPGyJzhcd6wjzfCvunn8HiwxXlf4yE0Wg4CfUSCXYjBzGQ6dcH3zLt3z1RbaYR/VaER9C3fuwqN\nqXCv8eqpP/v7yuzdVtwZU2yLg0EshHl/loeod5sEuxCic0LsxhR7TvvLaW3cyFVVZIR8VSG4iozX\nqqJT70sPQ1Xx6WP6tGQJOz3ow2ON5+02z2t6jTFeg8OlZ5CXT4JdKXUR8FvADPxJa/1rX6xXCD
EA\nKWU87zYs2hg/vyN1VUbgV5V4X4ugutgIfVeh8b4y33hIS3Xx6UM+tBRkNYK/Ofy9gR8W4y1PTIsp\n2mgaCgr27b73Ez0OdqWUGfg9cAGQC3yhlHpXa72np+sWQgwCweHGFJXa8bJaG01C1cXGgaDpAFBd\n0mKedyo9DNVlUOtse30hEadCvmXgh0Ub1yfOmhc9IJqIfFFjzwIOaq0PAyilXgOuACTYhRC+pRRY\nI4yptbF+WtNQZ1wgbgr8qmLvz03zvK9VRVD8jXEwqKtse31BVu+F6aawjzTeN12sbjootJwXGtWn\n9xD4ItiTgeMtfs6lzQ64PfPMF//kq4JviLJGEhsWSYItigRbDCkR0aRExBAaHNIbmxVCDGRBwd5+\n+0M6/52GOuM6QXXJqYNAy9eaMqgpN34uPnjqs9YuGjeXI9QI+EVPnz3iqI/12cVTpdTtwO0Aw4YN\n69Y63jnwT47Vb2h7AU8wSocRRDghyobVbCM8KAJ7cASOEAfRoQ5iQ6OID49iiD2KlIhYkuzR2ENs\nKLnoIoRoEhTsvXkrofPf0dq4XlBT1iL8y4zAd5ef+tnWhXV2ky+C/QQwtMXPKd55p9FaPwM8AzBt\n2jR95uedseqaxzjpKuOEs5R8VykFrlKKq8spqSmnzO2ksq4CV0MF1Q2V1HpclNXnUVy/H11bg6pq\n5SaM5sKZTh0QTDZCzTbCg+zYgyOIDHEQHRpJfHgUCbZoEm3RJEfEEGl1EBESgcVk6c6uCCECjVLG\nox1DbMZInn7ki2D/AhitlBqBEejfBq73wXrPEhwUxLDIOIZFxnXpe40eTXFVJcedJeRVlHLSVUph\nVRklNeWUu52U15bjqq+kpsGFu76S6rpSCjkO5mow1aJU28chkw4hSNmweg8I9uAIIoIdRFsdxIZ5\nDwb2KGLDonCEOIjwnj2EBYXJWYIQolf0ONi11g1KqbuBtRjdHf+std7d45L5kNmkSLBHkGCPAEZ0\n6jtaa6rrGimuqiHPWU5uRTGFVWUUuEqbzxAq6ipw1VVQ3ViJy+OiTFeBuQhlrkGZqlGmxjbXrzAT\nrMIJNduxWSJwhEQ0XzuIC4skJiyq+SDQ8oAQERxBkEluPxBCtM0nCaG1XgOs8cW6+gulFOEhQYSH\n2Bkebef01qbWaa2pqGmgpKqW0qpa8isqyK8s46SrjKLqUkprnJTXOqmodVLVUEm1pxKXqqbYVIMy\n56LM+42DgrmVwZtaCDGFEW6x4wiOIDo0kujQyLPC3xHiwBHsOO3AEBoUKmcJQgwCUvXzIaUUjjAL\njjALI+NsQAztnSForXHVNlBaVUdJVR3FlbWUVtVR5Kohr6KUoupySmrKKK1xUlHnpKq+EszV1Jlq\ncJlrKDRXg7kQk/ko5qBqtKlczBw/AAAVLklEQVQGVNtnCUGmoNNCPyIkojn8I0IizjogNP1sD7bL\nWYIQA4j8b/UjpRR2qwW71cLwmPAOl2/0aO9BoJbiyjqKXbUUu2opchk/F7ncFLkqKaouo7y2Aq2q\nUOZqlLkGTDUocw21lhoqQ2opsLhR5jK0qqaeKup1KwM8tWCz2JrD/rQDQhvNRXKWIIT/SLAPIGaT\nIs4eQpw9BDrokuvxaMqq604P/cpaCiuMA0FRZS2F5cars6YeaDQOAOYalKkGs6WGiLB6bGG1hFrr\nsODG3FhDfV0NhXUujnvyqWl0UVlXQYNuu8eRxWRpN/zbakKyB9sxm8y+/QUKMUhIsAcok0kRYwsh\nxtbxQaC2odEIem/wF1a6m18LKowDQF6Fm5Kqs8fosJghLkIRE9FIpK0ee1g94aF1hATXYrbUYDLX\n0EAVVQ2VVNRWUFBdwP6y/VTUVVDV1uBPXnaL3Tg7aOf6QcsmpKZXq9kqZwliUJNgF4QEmUmJCiMl\nqv0xMOoaPBS7aimocHunWk5WuClwuimodJNXaMxz1Z5dg4
8OD2ZIhJUhDiujHVYSE6zEOyw4wuoI\nC63DGlKH2+PCWWv0NqqoraC8tpyKugqctU6cdU5OVp1s/qy9s4RgU3C71w+aPmu5jCPEgc1ik7ME\nERAk2EWnBQeZSIoMJSmy/cemuWobOOl0G1OFm5POGvKdxsEgr9zNl8fKKKs++9brqDALSZGRJDoS\nSY60khQZSpp3e8mRocTbQzCZlNEVtaEaZ62zOfwraitw1jlPe206KOS58thXt4+K2gqqG9q+lqBQ\n2IJtHV4/aO0agzVIniUq+g8JduFztpAgzom3cU68rc1l3PWNnHS6yXPWkF/uJt9ZQ57TTX55DcdL\nq/n8SAmV7tNr5RazYojDSnJkKMmRYSRHWkmJspMcFc/YqFASHaEEB7X/hJ76xnoj8OucxtlBbYv3\nrRwg8qrymg8SjbrtHkch5pBOXz9oecZgD7ZjUvJUIeFbEuzCL6wWM6mx4aTGtt0bqMJdT365m7zy\nGnLLa8jzTifKavj0UDEFFW48LW4KVgqGRFhJiQr1Ni2FMjQqjJRo4zXRYcVithATGkNMaEyXyqu1\nxlXvOtU05G0eanlm0LIJ6YTrBHtK9lBRV0FNQ02b61Uo7MH21sO/vQNDiIMQswx6J1qntO7WsC09\nMm3aNL1169Y+364ILPWNHk463eSW1ZBbVu19PfU+31lzWvAHmRTJUaEMiw5jaHQYw1pOMWFEWHtn\n3J+6xrpWw7/lQaLpLOLMA4VHe9pcr9Vs7dT1g9Pmea8lyFnCwKSU2qa1ntbRclJjFwOWxWxiqDek\njZvBTlff6CG/3M2x0mqOl1VzvLTa+76Gf319ktIzevlEhVkYFhNOakwYw894jQ4P7nZPm2BzMLGh\nscSGxnbpex7tMc4S2rl+UF5b3vzzcddxKkoqOjxLMCmTcZbQ4jrBmQeEMy8sNy0XbA7MJw4FGqmx\ni0Gr0l3PsdJqjpVUc7S0mqMl1RwrrSKnuJo8Zw0t/2vYrUGMiA1nRGw4qTHhze9HxoVj76Wafk/U\nNtY2X0NoeQbQsgmp5XWGijrjLKKyrhJN25kQGhTa4bWE1i4u2ywyNLYvdLbGLsEuRCtqGxo5XlpD\nTnEVOSVVHC2pJqekiiPFVZwoPz30Y20hjIwLZ1RcOKPibM1TclQoZtPACjOP9lBZV3naAaHlWULL\nXkhn9kiqbaxtc71mZW6+ltB0RtDeAaLlWYPF3P8OnP4iTTFC9EBIkLnNnj3u+kaOl1ZzuNgI+sNF\nLg4XVbF2dwGlVaceJhYcZGJkrDfsvesaHW9jZFw4IUH9s7+8SZmaQ7arahtr275+cEavo1J3KTnO\nHCrqKjp1ltDRheTWPgu3hA/aswSpsQvhQ2VVdRzyBv2hIhcHC10cLHJxvLS6+UKuSUFqTDjnxNsY\nk2BndIKNsUPsjIy1ddhdMxA1ehqbryW0PCs488Bw1plDbQV1nrPvhm4SpIJO9Thq7QzhjNFPW15f\n6K8P0JEauxB+EBUezLTwaKalRp82313fyJHiKg4UujhYUMn+Ahf7Cyv5975CGr2JH2RSjIwLZ0yC\nnfGJEYxNsDMu0U5yZGAPpGY2mbt9luBucHcY/k2flbnLyHHm4Kxz4qpztXuWEBYU1unrB/3xATpS\nYxfCj2obGjlUWMWBwkr2F1TyzclK9p2sJLfsVK8WuzWI8UMiGJ9oZ0JSBBMSHYxOsGG19M/mnIGg\n6SyhZfi3df2gZVNSeW059e08sDpIBZ3eBbWVpqPzh55Pki2pW+WWGrsQA0BIkNkI66SI0+ZXuuvZ\nX1DJ3vxK9p2sYG9+JW9sy6Vqs3H3a5BJcU68jQmJEUxMdpCWZLzaQuS/dGd09yxBa4270d1mr6Iz\nm5CKqos4VH6IitoKKusrARjhGNHtYO8sqbELMUB4PJpjpdXsya9gT14Fu/Oc7M6roLDS6I2iFIyI\nDScj2UF6SiQZKQ4mJk
UQFixh3x80eBqorKskzBLW7buGpbujEINEYYWb3XkV7DrhNKZcJycrjMcr\nmhSMSbAzKSWSSUMjmTw0kjEJNoLMg+8ibSCQphghBon4CCvxEVbOHxffPK+wws2uE052Hi9nR66T\ntXtO8vpWoytmWLCZjBQHmcOimDo8isxhUUSFyx2lgUSCXYgAFB9hZX6ElfnjEwCjbfhoSTU7jpfz\n5bEyth8r548fHW7ukTMyLpxpw6OYlhpNVmo0w2P6R+8O0T3SFCPEIFVT18jO3HK2HytjW04Z246V\nUe4dJz/OHkLWiGhmjohmxsgYRsfLkAD9gTTFCCHaFRpsZubIGGaONAZQ83g0h4pcbMkpZcsRY1r9\nVT4AMeHBxrKjYpg9KoYRsYP3rs6BQGrsQohWaa05XlrDZ0dK+OxQCZsPl5DvNC7KJjqszD4nluzR\nscw+J5ZYm4wN3xekV4wQwqea2uk/OVTMJweL+fRQSXPTzcSkCOaMieO8MXFkDo/CIr1ueoUEuxCi\nVzV6NLvznGw6UMyH+4vYfrSMBo/Gbg1izug45o2L57yxccRIbd5nJNiFEH2qwl3PpweL2fhNER/s\nK6SwshalIHNYFBdMSGDBhARGxrX9HFzRMQl2IYTfeDyaPfkVrN9bwPq9BXx9ogKAc+JtXJw2hIvS\nhjAhMUIuwHaRBLsQot84UV7Dut0n+dfuk2w5UopHQ2pMGJdmJHHppETGJtgl5DtBgl0I0S8Vu2pZ\nt6eANbvy+eRgMR4No+NtXDklmcsnJXmfYStaI8EuhOj3il21/PPrk7y3I48tOaUAZKVGc/XUFC7J\nSJTRKs8gwS6EGFCOl1bzzo4TvLX9BIeLqwgLNnNxWiLXzxhK5rAoaapBgl0IMUBprdl+rJw3th3n\nvZ35uGobGJNg47qsYVw1NYUIa/98bF1fkGAXQgx4VbUN/OOrPP6y5Tg7j5cTFmxmcWYyN81KZUyC\n3d/F63N9EuxKqYeBy4A64BBws9a6vKPvSbALIbrq6xNOXvw0h3d25lHX4GHOmDiWZo/g3HNiB00z\nTV8F+wLgA611g1LqIQCt9fKOvifBLoTortKqOv665RgvfppDYWUt44bYufO8UVyakYTZFNgB3+dN\nMUqpRcDVWusbOlpWgl0I0VO1DY28uyOPZz46zIFCF6kxYdx53igWZ6YE7Fg1nQ12X+79LcA/fbg+\nIYRoU0iQmSXThrL2R3N4+sap2K0Wlr+5i/mPfsib23KbHyIyGHVYY1dKrQeGtPLRT7XW73iX+Skw\nDVis21ihUup24HaAYcOGTT169GhPyi2EEKfRWrPxmyIeef8bdudVMCounP+6aBwLJiQETBt8nzXF\nKKW+B3wfmK+1ru7Md6QpRgjRWzwezdrdJ3nk/W84VFRF1ohoHlg4noyUSH8Xrcf6pClGKXUR8F/A\n5Z0NdSGE6E0mk+Li9ETW/mgOK69M41Chi8t/9wnL/r6TEletv4vXJ3raK+YgEAKUeGd9prW+o6Pv\nSY1dCNFXKt31/O6Dgzz38RHCQ4JYduFYrssaNiB70MgNSkII0cKBgkoeePtrPj9SytThUfzm6gxG\nDbDx4f3RK0YIIfqt0Ql2Xrt9Jo8umcTBQheX/HYTz350OCB7z0iwCyEGDaUUV01NYd1/ziF7dBy/\nWrOX6575jLzyGn8Xzack2IUQg058hJVnb5rKY9dMYneek0ue3MS6PQX+LpbPSLALIQYlpRSLM1N4\n7wfnkhwZytKXtrLyH3toaPT4u2g9JsEuhBjURsbZeOuu/+CmWcN57uMj3PzCFzir6/1drB6RYBdC\nDHohQWZ+cUUav16czmeHS1j0h084VOTyd7G6TYJdCCG8vp01jFdvm4mzpp6rnvqUHcc7HIW8X5Jg\nF0KIFrJGRLPqrtlEWC1c/+xnfHKw2N9F6jIJdiGEOMOwmDDeuGMWQ6PCuPn5L1i7+6S/
i9QlEuxC\nCNGK+Agrr39/JhOSIvg/r25n4zeF/i5Sp0mwCyFEGyLDgnnp1izGJNi545VtbM0p9XeROkWCXQgh\n2hFhtfDSrVkkOUK5+YUv2JNX4e8idUiCXQghOhBrC+Hl22ZgCwniu89voaDC7e8itUuCXQghOiE5\nMpQXbs7C5W7grle3U9fQf+9QlWAXQohOGjvEzm+uzmDb0TJ+tXqPv4vTJgl2IYTogssmJXHbuSN4\ncfNR3tqe6+/itEqCXQghuui+i8cxc2Q096/axbGS/vdUUAl2IYTooiCzicevnUyQycT9q3bhjyfR\ntUeCXQghuiHREcp9F4/j44PF/H1b/2qSkWAXQohuuj5rGFmp0fzyH3sorOw/XSAl2IUQoptMJsX/\nvSodd4OHFe/u9ndxmkmwCyFED4yKs3HPvHNYs+skX/STIQck2IUQooduPXcksbZgfrv+gL+LAkiw\nCyFEj4UGm7lj7ig+PljcL2rtEuxCCOEDN8wY3m9q7RLsQgjhA6HBZr4/p3/U2iXYhRDCR26YOaxf\n1Nol2IUQwkfCgoO4fc5IPj5Y7Ndx2yXYhRDCh66eOpQgk+KdHSf8VgYJdiGE8KHo8GDmjonj3Z15\neDz+GUNGgl0IIXzs8slJ5DvdfH7EPxdRJdiFEMLHLpiQQFiw2W/NMRLsQgjhY2HBQVw4cQhrduVT\n29DY59uXYBdCiF5wxeQkKtwNbPymqM+37ZNgV0r9WCmllVKxvlifEEIMdOeeE0usLdgvzTE9Dnal\n1FBgAXCs58URQojAEGQ2cWlGEuv3FlLhru/Tbfuixv448F9A/3o2lBBC+NlFaUOoa/DwRR/3julR\nsCulrgBOaK13+qg8QggRMDJSHJgU7Mx19ul2gzpaQCm1HhjSykc/Be7HaIbpkFLqduB2gGHDhnWh\niEIIMTCFBQcxJsHOV7nlfbrdDoNda/2t1uYrpdKBEcBOpRRACrBdKZWltT7ZynqeAZ4BmDZtmjTb\nCCEGhYwUB+v2FKC1xpuVva7bTTFa611a63itdarWOhXIBTJbC3UhhBisMlIiKauuJ7esps+2Kf3Y\nhRCiF01KiQRgZx82x/gs2L0192JfrU8IIQLB2CF2goNMfNWHF1Clxi6EEL0oOMjEhMQIdhwfgDV2\nIYQQrZuU4uDrE04a+2gYXwl2IYToZRkpkVTXNXKoyNUn25NgF0KIXjZpqHEBta+aYyTYhRCil42M\nDcceEtRnNypJsAshRC8zmRRpyY4+6xkjwS6EEH0gY6iDvfkVffLgDQl2IYToA5NTIqlv1OzNr+z1\nbUmwCyFEH5g8LJILJiRg7oPxYjocBEwIIUTPJTpCefamaX2yLamxCyFEgJFgF0KIACPBLoQQAUaC\nXQghAowEuxBCBBgJdiGECDAS7EIIEWAk2IUQIsAorftm4PfTNqpUEXC0C1+JBQbjY/dkvweXwbrf\nMHj3vav7PVxrHdfRQn4J9q5SSm3VWvfNLVv9iOz34DJY9xsG77731n5LU4wQQgQYCXYhhAgwAyXY\nn/F3AfxE9ntwGaz7DYN333tlvwdEG7sQQojOGyg1diGEEJ3Ur4JdKXWRUuobpdRBpdR9rXweopR6\n3fv550qp1L4vpe91Yr/vVUrtUUp9pZT6t1JquD/K6Wsd7XeL5a5SSmmlVED0mujMfiulrvH+zXcr\npf7S12XsDZ34dz5MKbVBKfWl99/6Jf4op68ppf6slCpUSn3dxudKKfWk9/fylVIqs8cb1Vr3iwkw\nA4eAkUAwsBOYcMYydwFPe99/G3jd3+Xuo/0+Hwjzvr9zsOy3dzk78BHwGTDN3+Xuo7/3aOBLIMr7\nc7y/y91H+/0McKf3/QQgx9/l9tG+zwEyga/b+PwS4J+AAmYCn/d0m/2pxp4FHNRaH9Za1wGvAVec\nscwVwIve928A85Xqg+dM9a4O91trvUFrXe398TMg
pY/L2Bs68/cGWAk8BLj7snC9qDP7vRT4vda6\nDEBrXdjHZewNndlvDUR43zuAvD4sX6/RWn8ElLazyBXAS9rwGRCplErsyTb7U7AnA8db/Jzrndfq\nMlrrBsAJxPRJ6XpPZ/a7pVsxju4DXYf77T0lHaq1Xt2XBetlnfl7jwHGKKU+UUp9ppS6qM9K13s6\ns98rgBuVUrnAGuAHfVM0v+tqBnRInnk6gCilbgSmAXP9XZbeppQyAY8B3/NzUfwhCKM55jyMs7OP\nlFLpWutyv5aq910HvKC1flQpNQt4WSmVprX2+LtgA01/qrGfAIa2+DnFO6/VZZRSQRinayV9Urre\n05n9Rin1LeCnwOVa69o+Kltv6mi/7UAasFEplYPR9vhuAFxA7czfOxd4V2tdr7U+AuzHCPqBrDP7\nfSvwNwCt9WbAijGWSqDrVAZ0RX8K9i+A0UqpEUqpYIyLo++escy7wHe9768GPtDeqw8DWIf7rZSa\nAvwRI9QDob0VOthvrbVTax2rtU7VWqdiXFu4XGu91T/F9ZnO/Dt/G6O2jlIqFqNp5nBfFrIXdGa/\njwHzAZRS4zGCvahPS+kf7wI3eXvHzAScWuv8Hq3R31eMW7k6vB/j6vlPvfN+gfEfGow/9N+Bg8AW\nYKS/y9xH+70eKAB2eKd3/V3mvtjvM5bdSAD0iunk31thNEPtAXYB3/Z3mftovycAn2D0mNkBLPB3\nmX20338F8oF6jLOxW4E7gDta/L1/7/297PLFv3O581QIIQJMf2qKEUII4QMS7EIIEWAk2IUQIsBI\nsAshRICRYBdCiAAjwS6EEAFGgl0IIQKMBLsQQgSY/w/UXnupHkJuagAAAABJRU5ErkJggg==\n",
114 | "text/plain": [
115 | ""
116 | ]
117 | },
118 | "metadata": {
119 | "tags": []
120 | }
121 | }
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "metadata": {
127 | "id": "ExqkmRRCD4n1",
128 | "colab_type": "code",
129 | "colab": {}
130 | },
131 | "source": [
132 | ""
133 | ],
134 | "execution_count": 0,
135 | "outputs": []
136 | }
137 | ]
138 | }
--------------------------------------------------------------------------------
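The three curves in the GAN_losses comparison differ most in their slope near x = 0. That is where D(G(z)) sits when the discriminator confidently rejects early generator samples. The small NumPy sketch below reuses the notebook's three generator losses. It compares their analytic derivatives with respect to x = D(G(z)). The `grads` helper and the finite-difference check are additions for illustration, not part of the notebook:

```python
import numpy as np

# Same generator losses as the notebook, as functions of x = D(G(z))
def org_loss(x):
    return np.log(1.0 - x)   # original (saturating) loss, minimized by G

def alternative_loss(x):
    return -np.log(x)        # alternative (non-saturating) loss

def wgan_loss(x):
    return -x                # WGAN generator loss

def grads(x):
    """Hypothetical helper: analytic derivatives of each loss with respect to x."""
    return {
        "org": -1.0 / (1.0 - x),
        "alternative": -1.0 / x,
        "wgan": -np.ones_like(x),
    }

x = np.array([0.01, 0.5, 0.99])
g = grads(x)

# Central finite differences as a sanity check on the analytic derivatives
h = 1e-6
for name, fn in [("org", org_loss), ("alternative", alternative_loss), ("wgan", wgan_loss)]:
    numeric = (fn(x + h) - fn(x - h)) / (2 * h)
    assert np.allclose(numeric, g[name], rtol=1e-4)

for name, values in g.items():
    print(name, np.round(values, 3))
```

At x = 0.01 the original loss has a slope of only about -1.01. A confident discriminator therefore gives the generator a weak gradient. The alternative loss has a slope of -1/x, here -100. This is the usual motivation for training G with real labels, that is, minimizing -log D(G(z)). For WGAN the critic output is unbounded rather than a probability, so the (0, 1) range here is only illustrative.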
/GloVe.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "GloVe.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {
20 | "id": "view-in-github",
21 | "colab_type": "text"
22 | },
23 | "source": [
24 | ""
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "metadata": {
30 | "id": "vw42L3S_bSqn",
31 | "colab_type": "code",
32 | "colab": {}
33 | },
34 | "source": [
35 | "import torch\n",
36 | "import torch.nn as nn\n",
37 | "from torch.autograd import Variable\n",
38 | "import torch.optim as optim\n",
39 | "import torch.nn.functional as F\n",
40 | "import nltk\n",
41 | "import random\n",
42 | "import numpy as np\n",
43 | "from collections import Counter\n",
44 | "\n",
45 | "# Lambda that flattens a list of lists into a single list\n",
46 | "flatten = lambda l: [item for sublist in l for item in sublist]\n",
47 | "random.seed(1024)"
48 | ],
49 | "execution_count": 0,
50 | "outputs": []
51 | },
52 | {
53 | "cell_type": "code",
54 | "metadata": {
55 | "id": "vQcgtkefgVt1",
56 | "colab_type": "code",
57 | "outputId": "c55be4cc-e22a-49a2-edca-3ce11b2f0745",
58 | "colab": {
59 | "base_uri": "https://localhost:8080/",
60 | "height": 111
61 | }
62 | },
63 | "source": [
64 | "nltk.download('gutenberg')\n",
65 | "nltk.download('punkt')"
66 | ],
67 | "execution_count": 0,
68 | "outputs": [
69 | {
70 | "output_type": "stream",
71 | "text": [
72 | "[nltk_data] Downloading package gutenberg to /root/nltk_data...\n",
73 | "[nltk_data] Package gutenberg is already up-to-date!\n",
74 | "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
75 | "[nltk_data] Package punkt is already up-to-date!\n"
76 | ],
77 | "name": "stdout"
78 | },
79 | {
80 | "output_type": "execute_result",
81 | "data": {
82 | "text/plain": [
83 | "True"
84 | ]
85 | },
86 | "metadata": {
87 | "tags": []
88 | },
89 | "execution_count": 10
90 | }
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "metadata": {
96 | "id": "9qUF3R4acHUs",
97 | "colab_type": "code",
98 | "colab": {}
99 | },
100 | "source": [
101 | "FloatTensor = torch.FloatTensor\n",
102 | "LongTensor = torch.LongTensor\n",
103 | "ByteTensor = torch.ByteTensor"
104 | ],
105 | "execution_count": 0,
106 | "outputs": []
107 | },
108 | {
109 | "cell_type": "code",
110 | "metadata": {
111 | "id": "DYrZeL0ycTFc",
112 | "colab_type": "code",
113 | "colab": {}
114 | },
115 | "source": [
116 | "def getBatch(batch_size, train_data):\n",
117 | " random.shuffle(train_data)\n",
118 | " sindex = 0\n",
119 | " eindex = batch_size\n",
120 | " while eindex < len(train_data):\n",
121 | " batch = train_data[sindex:eindex]\n",
122 | " temp = eindex\n",
123 | " eindex = eindex + batch_size\n",
124 | " sindex = temp\n",
125 | " yield batch\n",
126 | " \n",
127 | " if eindex >= len(train_data):\n",
128 | " batch = train_data[sindex:]\n",
129 | " yield batch"
130 | ],
131 | "execution_count": 0,
132 | "outputs": []
133 | },
134 | {
135 | "cell_type": "code",
136 | "metadata": {
137 | "id": "5bwILUq1cUAt",
138 | "colab_type": "code",
139 | "colab": {}
140 | },
141 | "source": [
142 | "def prepare_sequence(seq, word2index):\n",
143 | " idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index[\"\"], seq))\n",
144 | " return Variable(LongTensor(idxs))\n",
145 | "\n",
146 | "def prepare_word(word, word2index):\n",
147 | " return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index[\"\"]]))"
148 | ],
149 | "execution_count": 0,
150 | "outputs": []
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {
155 | "id": "ofQkouNLcWkR",
156 | "colab_type": "text"
157 | },
158 | "source": [
159 | "# **Data load and Preprocessing**"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "metadata": {
165 | "id": "H_ZeSgWhcY0G",
166 | "colab_type": "code",
167 | "colab": {}
168 | },
169 | "source": [
170 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]\n",
171 | "corpus = [[word.lower() for word in sent] for sent in corpus]"
172 | ],
173 | "execution_count": 0,
174 | "outputs": []
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "metadata": {
179 | "id": "EU-M4zFIcqX5",
180 | "colab_type": "text"
181 | },
182 | "source": [
183 | "**Build vocab**"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "metadata": {
189 | "id": "C7v0OHeHcwHb",
190 | "colab_type": "code",
191 | "colab": {}
192 | },
193 | "source": [
194 | "vocab = list(set(flatten(corpus)))"
195 | ],
196 | "execution_count": 0,
197 | "outputs": []
198 | },
199 | {
200 | "cell_type": "code",
201 | "metadata": {
202 | "id": "gWgAjQNAczKz",
203 | "colab_type": "code",
204 | "colab": {}
205 | },
206 | "source": [
207 | "word2index = {}\n",
208 | "for vo in vocab:\n",
209 | " if word2index.get(vo) is None:\n",
210 | " word2index[vo] = len(word2index)\n",
211 | " \n",
212 | "index2word={v:k for k, v in word2index.items()}"
213 | ],
214 | "execution_count": 0,
215 | "outputs": []
216 | },
217 | {
218 | "cell_type": "code",
219 | "metadata": {
220 | "id": "l1XT0wM3iggR",
221 | "colab_type": "code",
222 | "colab": {}
223 | },
224 | "source": [
225 | "WINDOW_SIZE = 5\n",
226 | "windows = flatten([list(nltk.ngrams([''] * WINDOW_SIZE + c + [''] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])\n",
227 | "\n",
228 | "window_data = []\n",
229 | "\n",
230 | "for window in windows:\n",
231 | " for i in range(WINDOW_SIZE * 2 + 1):\n",
232 | " if i == WINDOW_SIZE or window[i] == '': \n",
233 | " continue \n",
234 | " window_data.append((window[WINDOW_SIZE], window[i])) \n"
235 | ],
236 | "execution_count": 0,
237 | "outputs": []
238 | },
239 | {
240 | "cell_type": "markdown",
241 | "metadata": {
242 | "id": "K3OKmleWc9Mv",
243 | "colab_type": "text"
244 | },
245 | "source": [
246 | "\n",
247 | "\n",
248 | "**Weighting Function**\n",
249 | "\n",
250 | "f(X_ij) keeps the weight of a word pair from growing without bound: it increases as (X_ij / x_max)^alpha for X_ij < x_max and is capped at 1 otherwise.\n",
251 | "\n",
252 | "\n",
253 | ""
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "metadata": {
259 | "id": "OIR9J6RVdOCB",
260 | "colab_type": "code",
261 | "colab": {}
262 | },
263 | "source": [
264 | "def weighting(w_i, w_j):\n",
265 | " try:\n",
266 | " x_ij = X_ik[(w_i, w_j)]\n",
267 | " except KeyError: # pair never co-occurred\n",
268 | " x_ij = 1\n",
269 | " \n",
270 | " x_max = 100 # value fixed in the GloVe paper\n",
271 | " alpha = 0.75\n",
272 | " \n",
273 | " if x_ij < x_max:\n",
274 | " result = (x_ij/x_max)**alpha\n",
275 | " else:\n",
276 | " result = 1\n",
277 | " \n",
278 | " return result"
279 | ],
280 | "execution_count": 0,
281 | "outputs": []
282 | },
283 | {
284 | "cell_type": "markdown",
285 | "metadata": {
286 | "id": "u7UpGTpqdQWq",
287 | "colab_type": "text"
288 | },
289 | "source": [
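290 | "# **Build Co-occurrence Matrix X**\n",
291 | "\n",
292 | "Because the training cost grows with the number of nonzero elements of X, it is important to determine whether a tighter bound can be placed on that number. Only the nonzero entries (word pairs that actually co-occur within the window) are stored below."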
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "metadata": {
298 | "id": "pTSJO-ordgD0",
299 | "colab_type": "code",
300 | "colab": {}
301 | },
302 | "source": [
303 | "X_i = Counter(flatten(corpus)) # X_i, dictionary"
304 | ],
305 | "execution_count": 0,
306 | "outputs": []
307 | },
308 | {
309 | "cell_type": "code",
310 | "metadata": {
311 | "id": "8_9ZaOmXdhb2",
312 | "colab_type": "code",
313 | "colab": {}
314 | },
315 | "source": [
316 | "X_ik_window_5 = Counter(window_data) # co-occurrence counts for window size 5 (dict-like Counter)"
317 | ],
318 | "execution_count": 0,
319 | "outputs": []
320 | },
321 | {
322 | "cell_type": "code",
323 | "metadata": {
324 | "id": "NOUivMbQdite",
325 | "colab_type": "code",
326 | "colab": {}
327 | },
328 | "source": [
329 | "X_ik = {}\n",
330 | "weighting_dic = {}"
331 | ],
332 | "execution_count": 0,
333 | "outputs": []
334 | },
335 | {
336 | "cell_type": "code",
337 | "metadata": {
338 | "id": "_xcTe97Jdw2Q",
339 | "colab_type": "code",
340 | "colab": {}
341 | },
342 | "source": [
343 | "from itertools import combinations_with_replacement\n",
344 | "# combinations_with_replacement('ABCD', 2)\n",
345 | "# AA AB AC AD BB BC BD CC CD DD"
346 | ],
347 | "execution_count": 0,
348 | "outputs": []
349 | },
350 | {
351 | "cell_type": "code",
352 | "metadata": {
353 | "id": "iy4xGsEldxiZ",
354 | "colab_type": "code",
355 | "colab": {}
356 | },
357 | "source": [
358 | "for bigram in combinations_with_replacement(vocab, 2): # bigram : tuple\n",
359 | " if X_ik_window_5.get(bigram) is not None: # nonzero elements\n",
360 | " co_occer = X_ik_window_5[bigram]\n",
361 | " X_ik[bigram] = co_occer + 1 # log(Xik) -> log(Xik+1) to prevent divergence\n",
362 | " X_ik[(bigram[1],bigram[0])] = co_occer+1 # to satisfy X_ik = X_ki\n",
363 | " \n",
364 | " else:\n",
365 | " pass\n",
366 | " \n",
367 | " weighting_dic[bigram] = weighting(bigram[0], bigram[1])\n",
368 | " weighting_dic[(bigram[1], bigram[0])] = weighting(bigram[1], bigram[0])"
369 | ],
370 | "execution_count": 0,
371 | "outputs": []
372 | },
373 | {
374 | "cell_type": "code",
375 | "metadata": {
376 | "id": "THamZfl_dz2Z",
377 | "colab_type": "code",
378 | "outputId": "e3b0b1ed-c424-49d4-eb9d-47b4ba906ad8",
379 | "colab": {
380 | "base_uri": "https://localhost:8080/",
381 | "height": 55
382 | }
383 | },
384 | "source": [
385 | "test = random.choice(window_data)\n",
386 | "print(test)\n",
387 | "try:\n",
388 | " print(X_ik[(test[0], test[1])] == X_ik[(test[1], test[0])]) #check X_ik = X_ki\n",
389 | "except KeyError:\n",
390 | " pass # pair not found in X_ik"
391 | ],
392 | "execution_count": 0,
393 | "outputs": [
394 | {
395 | "output_type": "stream",
396 | "text": [
397 | "('sacred', 'any')\n",
398 | "True\n"
399 | ],
400 | "name": "stdout"
401 | }
402 | ]
403 | },
404 | {
405 | "cell_type": "markdown",
406 | "metadata": {
407 | "id": "v7DnHLQOwocg",
408 | "colab_type": "text"
409 | },
410 | "source": [
411 | "# **Prepare train data**"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "metadata": {
417 | "id": "I-yG8b_iwtOt",
418 | "colab_type": "code",
419 | "outputId": "84549bc7-61f8-4037-9424-b11255096606",
420 | "colab": {
421 | "base_uri": "https://localhost:8080/",
422 | "height": 36
423 | }
424 | },
425 | "source": [
426 | "u_p = [] # center word indices\n",
427 | "v_p = [] # context word indices\n",
428 | "co_p = [] # log(x_ij)\n",
429 | "weight_p = [] # f(x_ij)\n",
430 | "\n",
431 | "for pair in window_data: \n",
432 | " u_p.append(prepare_word(pair[0], word2index).view(1, -1))\n",
433 | " v_p.append(prepare_word(pair[1], word2index).view(1, -1))\n",
434 | " \n",
435 | " try:\n",
436 | " cooc = X_ik[pair]\n",
437 | " except:\n",
438 | " cooc = 1\n",
439 | "\n",
440 | " co_p.append(torch.log(Variable(FloatTensor([cooc]))).view(1, -1))\n",
441 | " weight_p.append(Variable(FloatTensor([weighting_dic[pair]])).view(1, -1))\n",
442 | "\n",
443 | " \n",
444 | "train_data = list(zip(u_p, v_p, co_p, weight_p))\n",
445 | "del u_p\n",
446 | "del v_p\n",
447 | "del co_p\n",
448 | "del weight_p\n",
449 | "print(train_data[0]) # tuple (center word index i, context word index j, log(x_ij), weight f(x_ij))"
450 | ],
451 | "execution_count": 0,
452 | "outputs": [
453 | {
454 | "output_type": "stream",
455 | "text": [
456 | "(tensor([[1394]]), tensor([[134]]), tensor([[0.6931]]), tensor([[0.0532]]))\n"
457 | ],
458 | "name": "stdout"
459 | }
460 | ]
461 | },
462 | {
463 | "cell_type": "markdown",
464 | "metadata": {
465 | "id": "CXcS4npmz1YT",
466 | "colab_type": "text"
467 | },
468 | "source": [
469 | "# **Modeling**\n",
470 | "\n",
471 | ""
472 | ]
473 | },
474 | {
475 | "cell_type": "code",
476 | "metadata": {
477 | "id": "PiuUgUNR0CCa",
478 | "colab_type": "code",
479 | "colab": {}
480 | },
481 | "source": [
482 | "class GloVe(nn.Module):\n",
483 | " \n",
484 | " def __init__(self, vocab_size,projection_dim):\n",
485 | " super(GloVe,self).__init__()\n",
486 | " self.embedding_v = nn.Embedding(vocab_size, projection_dim) # center embedding\n",
487 | " self.embedding_u = nn.Embedding(vocab_size, projection_dim) # out embedding\n",
488 | " \n",
489 | " self.v_bias = nn.Embedding(vocab_size, 1)\n",
490 | " self.u_bias = nn.Embedding(vocab_size, 1)\n",
491 | " \n",
492 | " initrange = (2.0 / (vocab_size + projection_dim))**0.5 # Xavier-style scale sqrt(2 / (fan_in + fan_out)), used as a uniform bound\n",
493 | " self.embedding_v.weight.data.uniform_(-initrange, initrange) # init\n",
494 | " self.embedding_u.weight.data.uniform_(-initrange, initrange) # init\n",
495 | " self.v_bias.weight.data.uniform_(-initrange, initrange) # init\n",
496 | " self.u_bias.weight.data.uniform_(-initrange, initrange) # init\n",
497 | " \n",
498 | " def forward(self, center_words, target_words, coocs, weights):\n",
499 | " center_embeds = self.embedding_v(center_words) # B x 1 x D\n",
500 | " target_embeds = self.embedding_u(target_words) # B x 1 x D\n",
501 | " \n",
502 | " center_bias = self.v_bias(center_words).squeeze(1)\n",
503 | " target_bias = self.u_bias(target_words).squeeze(1)\n",
504 | " \n",
505 | " inner_product = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2) # Bx1\n",
506 | " \n",
507 | " loss = weights*torch.pow(inner_product +center_bias + target_bias - coocs, 2)\n",
508 | " \n",
509 | " return torch.sum(loss)\n",
510 | " \n",
511 | " def prediction(self, inputs):\n",
512 | " v_embeds = self.embedding_v(inputs) # B x 1 x D\n",
513 | " u_embeds = self.embedding_u(inputs) # B x 1 x D\n",
514 | " \n",
515 | " return v_embeds+u_embeds # final embed\n"
516 | ],
517 | "execution_count": 0,
518 | "outputs": []
519 | },
520 | {
521 | "cell_type": "markdown",
522 | "metadata": {
523 | "id": "fD-RgBgO1iCf",
524 | "colab_type": "text"
525 | },
526 | "source": [
527 | "# **Train**"
528 | ]
529 | },
530 | {
531 | "cell_type": "code",
532 | "metadata": {
533 | "id": "oi4Vn1y81kRh",
534 | "colab_type": "code",
535 | "colab": {}
536 | },
537 | "source": [
538 | "EMBEDDING_SIZE = 50\n",
539 | "BATCH_SIZE = 256\n",
540 | "EPOCH = 50"
541 | ],
542 | "execution_count": 0,
543 | "outputs": []
544 | },
545 | {
546 | "cell_type": "code",
547 | "metadata": {
548 | "id": "8q8MpqZX2DQv",
549 | "colab_type": "code",
550 | "colab": {}
551 | },
552 | "source": [
553 | "losses = []\n",
554 | "model = GloVe(len(word2index), EMBEDDING_SIZE)\n",
555 | "optimizer = optim.Adam(model.parameters(), lr=0.001)"
556 | ],
557 | "execution_count": 0,
558 | "outputs": []
559 | },
560 | {
561 | "cell_type": "code",
562 | "metadata": {
563 | "id": "DtHG8eJA-bUr",
564 | "colab_type": "code",
565 | "outputId": "075e33c5-b708-4b55-e315-ea41e909ef1d",
566 | "colab": {
567 | "base_uri": "https://localhost:8080/",
568 | "height": 111
569 | }
570 | },
571 | "source": [
572 | "for epoch in range(EPOCH):\n",
573 | " for i,batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
574 | " \n",
575 | " inputs, targets, coocs, weights = zip(*batch)\n",
576 | " \n",
577 | " inputs = torch.cat(inputs) # B x 1\n",
578 | " targets = torch.cat(targets) # B x 1\n",
579 | " coocs = torch.cat(coocs)\n",
580 | " weights = torch.cat(weights)\n",
581 | " model.zero_grad()\n",
582 | "\n",
583 | " loss = model(inputs, targets, coocs, weights)\n",
584 | " \n",
585 | " loss.backward()\n",
586 | " optimizer.step()\n",
587 | "\n",
588 | " #print(loss.data.tolist())\n",
589 | "\n",
590 | " losses.append(loss.data.tolist())\n",
591 | " #losses.append(loss.data.tolist()[0])\n",
592 | " if epoch % 10 == 0:\n",
593 | " print(\"Epoch : %d, mean_loss : %.02f\" % (epoch, np.mean(losses)))\n",
594 | " losses = []"
595 | ],
596 | "execution_count": 0,
597 | "outputs": [
598 | {
599 | "output_type": "stream",
600 | "text": [
601 | "Epoch : 0, mean_loss : 228.11\n",
602 | "Epoch : 10, mean_loss : 2.21\n",
603 | "Epoch : 20, mean_loss : 0.50\n",
604 | "Epoch : 30, mean_loss : 0.12\n",
605 | "Epoch : 40, mean_loss : 0.04\n"
606 | ],
607 | "name": "stdout"
608 | }
609 | ]
610 | },
611 | {
612 | "cell_type": "markdown",
613 | "metadata": {
614 | "id": "EOdn3U4A_I6s",
615 | "colab_type": "text"
616 | },
617 | "source": [
618 | "# **Test**"
619 | ]
620 | },
621 | {
622 | "cell_type": "code",
623 | "metadata": {
624 | "id": "9GqPL2he_LMM",
625 | "colab_type": "code",
626 | "colab": {}
627 | },
628 | "source": [
629 | "def word_similarity(target, vocab):\n",
630 | " target_V = model.prediction(prepare_word(target, word2index))\n",
631 | " similarities = []\n",
632 | " for i in range(len(vocab)):\n",
633 | " if vocab[i] == target: \n",
634 | " continue\n",
635 | " \n",
636 | " vector = model.prediction(prepare_word(vocab[i], word2index))\n",
637 | " \n",
638 | " cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0] \n",
639 | " similarities.append([vocab[i], cosine_sim])\n",
640 | " return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]"
641 | ],
642 | "execution_count": 0,
643 | "outputs": []
644 | },
645 | {
646 | "cell_type": "code",
647 | "metadata": {
648 | "id": "pmjCpY_N_q2a",
649 | "colab_type": "code",
650 | "outputId": "ae5cf1e4-332d-41c0-e49b-c08ddd69c067",
651 | "colab": {
652 | "base_uri": "https://localhost:8080/",
653 | "height": 36
654 | }
655 | },
656 | "source": [
657 | "test = random.choice(list(vocab))\n",
658 | "test"
659 | ],
660 | "execution_count": 0,
661 | "outputs": [
662 | {
663 | "output_type": "execute_result",
664 | "data": {
665 | "text/plain": [
666 | "'since'"
667 | ]
668 | },
669 | "metadata": {
670 | "tags": []
671 | },
672 | "execution_count": 31
673 | }
674 | ]
675 | },
676 | {
677 | "cell_type": "code",
678 | "metadata": {
679 | "id": "dx7xm6tE_sez",
680 | "colab_type": "code",
681 | "outputId": "d2d0f135-e632-46b4-f080-a35008e9293b",
682 | "colab": {
683 | "base_uri": "https://localhost:8080/",
684 | "height": 204
685 | }
686 | },
687 | "source": [
688 | "word_similarity(test, vocab)"
689 | ],
690 | "execution_count": 0,
691 | "outputs": [
692 | {
693 | "output_type": "execute_result",
694 | "data": {
695 | "text/plain": [
696 | "[['learned', 0.677182674407959],\n",
697 | " ['hosmannus', 0.6526056528091431],\n",
698 | " ['doubt', 0.6380100846290588],\n",
699 | " ['justly', 0.5834043622016907],\n",
700 | " ['work', 0.5606750249862671],\n",
701 | " ['lazarus', 0.5067567825317383],\n",
702 | " ['hazy', 0.5046504735946655],\n",
703 | " ['measure', 0.47750324010849],\n",
704 | " ['insular', 0.4746555685997009],\n",
705 | " ['head', 0.4561583995819092]]"
706 | ]
707 | },
708 | "metadata": {
709 | "tags": []
710 | },
711 | "execution_count": 32
712 | }
713 | ]
714 | }
715 | ]
716 | }
--------------------------------------------------------------------------------
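The weighting function and the per-pair objective that `GloVe.ipynb` above optimizes can be checked in isolation. The sketch below is an assumed re-statement in plain Python without PyTorch. `glove_weight` and `glove_pair_loss` are hypothetical helper names. The defaults x_max = 100 and alpha = 0.75 mirror the notebook's `weighting` function, and the squared term matches the `forward` loss of the `GloVe` class.

```python
import math

def glove_weight(x_ij, x_max=100.0, alpha=0.75):
    """f(X_ij): (X_ij / x_max) ** alpha below x_max, capped at 1 from x_max on."""
    if x_ij < x_max:
        return (x_ij / x_max) ** alpha
    return 1.0

def glove_pair_loss(w_i, w_j, b_i, b_j, x_ij, x_max=100.0, alpha=0.75):
    """Weighted squared error f(X_ij) * (w_i . w_j + b_i + b_j - log X_ij) ** 2."""
    dot = sum(a * b for a, b in zip(w_i, w_j))      # inner product of the two embeddings
    residual = dot + b_i + b_j - math.log(x_ij)     # distance from the log co-occurrence
    return glove_weight(x_ij, x_max, alpha) * residual ** 2

# Rare pairs get small weights; frequent pairs saturate at 1.
print(round(glove_weight(2), 4))  # 0.0532
print(glove_weight(100))          # 1.0
```

With the notebook's +1 smoothing, a pair stored as X_ij = 2 gets weight of about 0.0532. This matches the weight printed for `train_data[0]` above, whose log count is log 2 = 0.6931.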
/README.md:
--------------------------------------------------------------------------------
1 | # Example codes for Machine Learning Education (https://github.com/callee2006/MachineLearning)
2 |
3 | * How to open and run the notebooks in Colab
4 |
5 | 1. Open Colab (colab.research.google.com)
6 | 1. Sign in with a Google account
7 | 1. File -> Open notebook -> GitHub
8 |
9 | Enter "https://github.com/callee2006/MachineLearning" as the GitHub URL
10 |
11 | 1. Select a notebook and run it
12 |
13 | (After installing git, you can download the examples with "git clone https://github.com/callee2006/MachineLearning", but running them locally requires scikit-learn, matplotlib, graphviz, pytorch, mglearn, etc. to be installed.)
14 |
15 |
16 | List of example notebooks
17 |
18 | * python quickstart examples
19 | - python_core.ipynb (minimum tutorial)
20 | - python_tutorial.ipynb (quick tutorial)
21 |
22 | * accessing datasets in python
23 | - dataset.ipynb
24 |
25 | * k-nearest neighbor
26 | - kNN_IRIS.ipynb
27 |
28 | * linear and regularized regression (including Ridge, Lasso, Elastic-Net)
29 | - linear_regression.ipynb
30 |
31 | * Decision Tree, Random Forests, XGBoost
32 | - tree_and_ensemble.ipynb
33 |
34 | * Support Vector Machines
35 | - SVM (scikit-learn).ipynb
36 |
37 | * MLP examples - scikit-learn
38 | - MLP_(scikit_learn).ipynb
39 |
40 | * MLP examples - pytorch
41 | - MLP_MNIST.ipynb # digit image classification
42 | - MLP_regression.ipynb # regression using MLP
43 | - MLP_autoencoder.ipynb # autoencoder
44 |
45 | * CNN example - pytorch
46 | - CNN_MNIST.ipynb # CPU and GPU
47 |
48 | * RNN (sequence prediction) - pytorch
49 | - sequence_prediction (RNN).ipynb
50 | - sequence_prediction (LSTM per time).ipynb
51 | - sequence_prediction (LSTM per seq).ipynb # CPU and GPU
52 |
53 |
--------------------------------------------------------------------------------
/python_core.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "python_core.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "collapsed_sections": [],
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | }
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "view-in-github",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | "
"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "metadata": {
31 | "id": "kfqNWrFEr2hd",
32 | "colab_type": "code",
33 | "colab": {}
34 | },
35 | "source": [
36 | "import numpy as np"
37 | ],
38 | "execution_count": 0,
39 | "outputs": []
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "p3GZN4NGslWM",
45 | "colab_type": "text"
46 | },
47 | "source": [
48 | "# int and float types"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "metadata": {
54 | "id": "xt4GL6DlofE7",
55 | "colab_type": "code",
56 | "outputId": "77261b4b-d9f6-48b9-9bf0-fd802c6b24f7",
57 | "colab": {
58 | "base_uri": "https://localhost:8080/",
59 | "height": 35
60 | }
61 | },
62 | "source": [
63 | "a = 5\n",
64 | "type(a)"
65 | ],
66 | "execution_count": 0,
67 | "outputs": [
68 | {
69 | "output_type": "execute_result",
70 | "data": {
71 | "text/plain": [
72 | "int"
73 | ]
74 | },
75 | "metadata": {
76 | "tags": []
77 | },
78 | "execution_count": 1
79 | }
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "metadata": {
85 | "id": "wk0NXRfbohxr",
86 | "colab_type": "code",
87 | "outputId": "83f52d0f-f936-4ca4-a1a9-aa4228909374",
88 | "colab": {
89 | "base_uri": "https://localhost:8080/",
90 | "height": 35
91 | }
92 | },
93 | "source": [
94 | "a = 2.0\n",
95 | "type(a)"
96 | ],
97 | "execution_count": 1,
98 | "outputs": [
99 | {
100 | "output_type": "execute_result",
101 | "data": {
102 | "text/plain": [
103 | "float"
104 | ]
105 | },
106 | "metadata": {
107 | "tags": []
108 | },
109 | "execution_count": 1
110 | }
111 | ]
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "metadata": {
116 | "id": "i7ljhEufoWHW",
117 | "colab_type": "text"
118 | },
119 | "source": [
120 | "# input and output"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "metadata": {
126 | "id": "7j7vBjB-n5jl",
127 | "colab_type": "code",
128 | "colab": {
129 | "base_uri": "https://localhost:8080/",
130 | "height": 107
131 | },
132 | "outputId": "243aed46-699c-46c9-9f91-d55ac5ca9147"
133 | },
134 | "source": [
135 | "print(\"Input a: \")\n",
136 | "a = int(input())\n",
137 | "\n",
138 | "print(\"Input b: \")\n",
139 | "b = int(input())\n",
140 | "\n",
141 | "c = a + b\n",
142 | "print(\"{} + {} = {}\".format(a, b, c))"
143 | ],
144 | "execution_count": 3,
145 | "outputs": [
146 | {
147 | "output_type": "stream",
148 | "text": [
149 | "Input a: \n",
150 | "10\n",
151 | "Input b: \n",
152 | "20\n",
153 | "10 + 20 = 30\n"
154 | ],
155 | "name": "stdout"
156 | }
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {
162 | "id": "Xt2TKs8Zsq2j",
163 | "colab_type": "text"
164 | },
165 | "source": [
166 | "# for - Loop"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "metadata": {
172 | "id": "J-IH-6sBooYn",
173 | "colab_type": "code",
174 | "outputId": "19b85c9f-26d9-4152-837f-d64852b68192",
175 | "colab": {
176 | "base_uri": "https://localhost:8080/",
177 | "height": 107
178 | }
179 | },
180 | "source": [
181 | "# repeating 5 times\n",
182 | "for i in range(5):\n",
183 | " print(\"i = \", i)"
184 | ],
185 | "execution_count": 0,
186 | "outputs": [
187 | {
188 | "output_type": "stream",
189 | "text": [
190 | "i = 0\n",
191 | "i = 1\n",
192 | "i = 2\n",
193 | "i = 3\n",
194 | "i = 4\n"
195 | ],
196 | "name": "stdout"
197 | }
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "metadata": {
203 | "id": "omN7_F-JrKf5",
204 | "colab_type": "code",
205 | "outputId": "40e9c4b0-28f8-447b-8293-391b6f0f305e",
206 | "colab": {
207 | "base_uri": "https://localhost:8080/",
208 | "height": 161
209 | }
210 | },
211 | "source": [
212 | "# repeating from 20 up to (but not including) 100, step = 10\n",
213 | "for i in range(20, 100, 10): # start = 20, stop = 100 (excluded), step = 10\n",
214 | " print(\"i = \", i)"
215 | ],
216 | "execution_count": 0,
217 | "outputs": [
218 | {
219 | "output_type": "stream",
220 | "text": [
221 | "i = 20\n",
222 | "i = 30\n",
223 | "i = 40\n",
224 | "i = 50\n",
225 | "i = 60\n",
226 | "i = 70\n",
227 | "i = 80\n",
228 | "i = 90\n"
229 | ],
230 | "name": "stdout"
231 | }
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "metadata": {
237 | "id": "mpYCCGcjvXgI",
238 | "colab_type": "text"
239 | },
240 | "source": [
241 | "# Lists\n",
242 | "\n",
243 | "* Container of heterogeneous objects\n",
244 | " \n",
245 | "* Flexible (e.g. a list can contain another list)\n",
246 | "\n",
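247 | "* Mutable and resizable: appending to or removing from the end is efficient, while insertion / deletion in the middle shifts the following elements\n",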
248 | "\n"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "metadata": {
254 | "id": "wrQLfW1Cs8S1",
255 | "colab_type": "code",
256 | "outputId": "272fbf00-fa1d-4600-b7cf-3af19c853615",
257 | "colab": {
258 | "base_uri": "https://localhost:8080/",
259 | "height": 35
260 | }
261 | },
262 | "source": [
263 | "# creating a list\n",
264 | "a = [5,10,15,20,25,30,35,40]\n",
265 | "print(a)"
266 | ],
267 | "execution_count": 0,
268 | "outputs": [
269 | {
270 | "output_type": "stream",
271 | "text": [
272 | "[5, 10, 15, 20, 25, 30, 35, 40]\n"
273 | ],
274 | "name": "stdout"
275 | }
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "metadata": {
281 | "id": "mNMcB6p29M1n",
282 | "colab_type": "code",
283 | "outputId": "06d746d6-b6af-43c8-e2f1-b013e56b7958",
284 | "colab": {
285 | "base_uri": "https://localhost:8080/",
286 | "height": 35
287 | }
288 | },
289 | "source": [
290 | "\n",
291 | "b = [3.14, 100, 'hello', a]\n",
292 | "print(b)"
293 | ],
294 | "execution_count": 0,
295 | "outputs": [
296 | {
297 | "output_type": "stream",
298 | "text": [
299 | "[3.14, 100, 'hello', [5, 10, 15, 20, 25, 30, 35, 40]]\n"
300 | ],
301 | "name": "stdout"
302 | }
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "metadata": {
308 | "id": "HHrAvT7QvdZq",
309 | "colab_type": "code",
310 | "outputId": "aedb8d94-dcfc-4896-f613-295c392be4ac",
311 | "colab": {
312 | "base_uri": "https://localhost:8080/",
313 | "height": 35
314 | }
315 | },
316 | "source": [
317 | "# element access\n",
318 | "print(\"a[2] = \", a[2])"
319 | ],
320 | "execution_count": 0,
321 | "outputs": [
322 | {
323 | "output_type": "stream",
324 | "text": [
325 | "a[2] = 15\n"
326 | ],
327 | "name": "stdout"
328 | }
329 | ]
330 | },
331 | {
332 | "cell_type": "code",
333 | "metadata": {
334 | "id": "gMEWuYgBo0dj",
335 | "colab_type": "code",
336 | "outputId": "82d8c941-60c7-4de3-d037-d4a92e3fada8",
337 | "colab": {
338 | "base_uri": "https://localhost:8080/",
339 | "height": 35
340 | }
341 | },
342 | "source": [
343 | "# slicing from index 0 up to index 3 (exclusive)\n",
344 | "print(\"a[0:3] = \", a[0:3])"
345 | ],
346 | "execution_count": 0,
347 | "outputs": [
348 | {
349 | "output_type": "stream",
350 | "text": [
351 | "a[0:3] = [5, 10, 15]\n"
352 | ],
353 | "name": "stdout"
354 | }
355 | ]
356 | },
357 | {
358 | "cell_type": "code",
359 | "metadata": {
360 | "id": "W_HAWol8vV_s",
361 | "colab_type": "code",
362 | "outputId": "7a20b17a-56b8-4516-c953-608952ec5eec",
363 | "colab": {
364 | "base_uri": "https://localhost:8080/",
365 | "height": 35
366 | }
367 | },
368 | "source": [
369 | "# slicing from 5 to the end of the list\n",
370 | "print(\"a[5:] = \", a[5:])"
371 | ],
372 | "execution_count": 0,
373 | "outputs": [
374 | {
375 | "output_type": "stream",
376 | "text": [
377 | "a[5:] = [30, 35, 40]\n"
378 | ],
379 | "name": "stdout"
380 | }
381 | ]
382 | },
383 | {
384 | "cell_type": "markdown",
385 | "metadata": {
386 | "id": "PnN7xe2a4nOx",
387 | "colab_type": "text"
388 | },
389 | "source": [
390 | "# Tuples\n",
391 | "\n",
392 | "* Container of heterogeneous objects \n",
393 | "\n",
394 | "* Immutable"
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "metadata": {
400 | "id": "v4UVKjqL4mYQ",
401 | "colab_type": "code",
402 | "outputId": "48229023-ab0f-4198-cb4a-e4576e0abce4",
403 | "colab": {
404 | "base_uri": "https://localhost:8080/",
405 | "height": 35
406 | }
407 | },
408 | "source": [
409 | "# creating a tuple\n",
410 | "t = (5,'program', 1+3j)\n",
411 | "print(t)"
412 | ],
413 | "execution_count": 0,
414 | "outputs": [
415 | {
416 | "output_type": "stream",
417 | "text": [
418 | "(5, 'program', (1+3j))\n"
419 | ],
420 | "name": "stdout"
421 | }
422 | ]
423 | },
424 | {
425 | "cell_type": "code",
426 | "metadata": {
427 | "id": "qwi_ybso4pZL",
428 | "colab_type": "code",
429 | "outputId": "d08cb184-845d-4a44-d581-352073852943",
430 | "colab": {
431 | "base_uri": "https://localhost:8080/",
432 | "height": 53
433 | }
434 | },
435 | "source": [
436 | "# accessing elements\n",
437 | "\n",
438 | "print(\"t[1] = \", t[1])\n",
439 | "# t[1] = 'program'\n",
440 | "\n",
441 | "print(\"t[0:3] = \", t[0:3])\n",
442 | "# t[0:3] = (5, 'program', (1+3j))"
443 | ],
444 | "execution_count": 0,
445 | "outputs": [
446 | {
447 | "output_type": "stream",
448 | "text": [
449 | "t[1] = program\n",
450 | "t[0:3] = (5, 'program', (1+3j))\n"
451 | ],
452 | "name": "stdout"
453 | }
454 | ]
455 | },
456 | {
457 | "cell_type": "code",
458 | "metadata": {
459 | "id": "z6qrV3U28NG7",
460 | "colab_type": "code",
461 | "outputId": "2a2dbdcb-0dfb-4f25-95c0-25d89624d2b7",
462 | "colab": {
463 | "base_uri": "https://localhost:8080/",
464 | "height": 172
465 | }
466 | },
467 | "source": [
468 | "# Tuples are immutable\n",
469 | "# The following code raises a TypeError\n",
470 | "t[0] = 10"
471 | ],
472 | "execution_count": 0,
473 | "outputs": [
474 | {
475 | "output_type": "error",
476 | "ename": "TypeError",
477 | "evalue": "ignored",
478 | "traceback": [
479 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
480 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
481 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
482 | "\u001b[0;31mTypeError\u001b[0m: 'tuple' object does not support item assignment"
483 | ]
484 | }
485 | ]
486 | },
487 | {
488 | "cell_type": "markdown",
489 | "metadata": {
490 | "id": "HCwwrT8Nsef3",
491 | "colab_type": "text"
492 | },
493 | "source": [
494 | "# Function"
495 | ]
496 | },
497 | {
498 | "cell_type": "code",
499 | "metadata": {
500 | "id": "42vld-6KpSqX",
501 | "colab_type": "code",
502 | "colab": {}
503 | },
504 | "source": [
505 | "def is_even(x):\n",
506 | " if x % 2 == 0:\n",
507 | " return True\n",
508 | " else:\n",
509 | " return False"
510 | ],
511 | "execution_count": 0,
512 | "outputs": []
513 | },
514 | {
515 | "cell_type": "code",
516 | "metadata": {
517 | "id": "_0Bee6LPsTVl",
518 | "colab_type": "code",
519 | "outputId": "f9bb2d44-3503-4cd6-9540-d5e399941038",
520 | "colab": {
521 | "base_uri": "https://localhost:8080/",
522 | "height": 53
523 | }
524 | },
525 | "source": [
526 | "print(\"is_even(20) = {}\".format(is_even(20)))\n",
527 | "print(\"is_even(35) = {}\".format(is_even(35)))"
528 | ],
529 | "execution_count": 0,
530 | "outputs": [
531 | {
532 | "output_type": "stream",
533 | "text": [
534 | "is_even(20) = True\n",
535 | "is_even(35) = False\n"
536 | ],
537 | "name": "stdout"
538 | }
539 | ]
540 | },
541 | {
542 | "cell_type": "markdown",
543 | "metadata": {
544 | "id": "Ev2rm5PHvnqh",
545 | "colab_type": "text"
546 | },
547 | "source": [
548 | "# Class\n",
549 | "\n"
550 | ]
551 | },
552 | {
553 | "cell_type": "code",
554 | "metadata": {
555 | "id": "OuEe_PMfschP",
556 | "colab_type": "code",
557 | "colab": {}
558 | },
559 | "source": [
560 | "# defining a class\n",
561 | "class Hello:\n",
562 | " def __init__(self, message):\n",
563 | " self.mesg = message \n",
564 | " \n",
565 | " def display(self):\n",
566 | " print('=' * 20, end = '')\n",
567 | " print(self.mesg, end = '')\n",
568 | " print('=' * 20, end = '')\n",
569 | " \n",
570 | " def get_message(self):\n",
571 | " return self.mesg"
572 | ],
573 | "execution_count": 0,
574 | "outputs": []
575 | },
576 | {
577 | "cell_type": "code",
578 | "metadata": {
579 | "id": "w1FAgPbnwL8m",
580 | "colab_type": "code",
581 | "colab": {}
582 | },
583 | "source": [
584 | "# creating a class object\n",
585 | "hello = Hello('nice to see you!')"
586 | ],
587 | "execution_count": 0,
588 | "outputs": []
589 | },
590 | {
591 | "cell_type": "code",
592 | "metadata": {
593 | "id": "OINhVC65wTLc",
594 | "colab_type": "code",
595 | "outputId": "1fdda1bb-cc6b-4160-e28f-02cf9d65ebc3",
596 | "colab": {
597 | "base_uri": "https://localhost:8080/",
598 | "height": 35
599 | }
600 | },
601 | "source": [
602 | "# calling methods\n",
603 | "hello.display()\n",
604 | "print(\"message = \", hello.get_message())"
605 | ],
606 | "execution_count": 0,
607 | "outputs": [
608 | {
609 | "output_type": "stream",
610 | "text": [
611 | "====================nice to see you!====================message = nice to see you!\n"
612 | ],
613 | "name": "stdout"
614 | }
615 | ]
616 | },
617 | {
618 | "cell_type": "code",
619 | "metadata": {
620 | "id": "SiyKRh1fxMEk",
621 | "colab_type": "code",
622 | "colab": {}
623 | },
624 | "source": [
625 | ""
626 | ],
627 | "execution_count": 0,
628 | "outputs": []
629 | }
630 | ]
631 | }
--------------------------------------------------------------------------------
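The `range(20, 100, 10)` loop and the slicing cells in `python_core.ipynb` above both rely on Python's half-open convention: the start is included and the stop is excluded. The self-contained check below reuses the values from those notebook cells.

```python
# range(start, stop, step) includes start and excludes stop
values = list(range(20, 100, 10))
print(values)  # [20, 30, 40, 50, 60, 70, 80, 90]

# Slices follow the same half-open [start, stop) rule
a = [5, 10, 15, 20, 25, 30, 35, 40]
print(a[0:3])  # [5, 10, 15]
print(a[5:])   # [30, 35, 40]

# Consequence: two adjacent slices split a list with no overlap and no gap
k = 3
assert a[:k] + a[k:] == a

# The number of loop iterations is (stop - start) // step when it divides evenly
print(len(range(20, 100, 10)))  # 8
```

Because the stop value is excluded, `a[:k]` and `a[k:]` always recombine into the original list. This is why the convention is convenient for splitting data, for example into training and test parts.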
/skip-gram.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Skip-gram 최종.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {
20 | "id": "view-in-github",
21 | "colab_type": "text"
22 | },
23 | "source": [
24 | "
"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {
30 | "id": "zjwDBfJERa07",
31 | "colab_type": "text"
32 | },
33 | "source": [
34 | "# **Skip-gram with negative sampling**"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "metadata": {
40 | "id": "du0GK2oNRItO",
41 | "colab_type": "code",
42 | "colab": {}
43 | },
44 | "source": [
45 | "import torch\n",
46 | "import torch.nn as nn\n",
47 | "from torch.autograd import Variable\n",
48 | "import torch.optim as optim\n",
49 | "import torch.nn.functional as F\n",
50 | "import nltk\n",
51 | "import random\n",
52 | "import numpy as np\n",
53 | "from collections import Counter\n",
54 | "flatten = lambda l: [item for sublist in l for item in sublist]\n",
55 | "random.seed(1024)"
56 | ],
57 | "execution_count": 0,
58 | "outputs": []
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {
63 | "id": "b17jQohbvXhM",
64 | "colab_type": "text"
65 | },
66 | "source": [
67 | "## NLTK (Natural Language Toolkit)\n",
68 | "\n",
69 | "https://www.nltk.org/\n",
70 | "\n",
71 | "NLTK is a leading platform for building Python programs to work with human language data. It provides easy-to-use interfaces to over 50 corpora and lexical resources such as WordNet, along with a suite of text processing libraries for classification, tokenization, stemming, tagging, parsing, and semantic reasoning, wrappers for industrial-strength NLP libraries, and an active discussion forum."
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "metadata": {
77 | "id": "NZrLs03Yx3OJ",
78 | "colab_type": "code",
79 | "outputId": "61c1b22b-3f39-4137-97dd-d3d797884b8c",
80 | "colab": {
81 | "base_uri": "https://localhost:8080/",
82 | "height": 107
83 | }
84 | },
85 | "source": [
86 | "nltk.download('gutenberg')\n",
87 | "nltk.download('punkt')"
88 | ],
89 | "execution_count": 2,
90 | "outputs": [
91 | {
92 | "output_type": "stream",
93 | "text": [
94 | "[nltk_data] Downloading package gutenberg to /root/nltk_data...\n",
95 | "[nltk_data] Unzipping corpora/gutenberg.zip.\n",
96 | "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
97 | "[nltk_data] Unzipping tokenizers/punkt.zip.\n"
98 | ],
99 | "name": "stdout"
100 | },
101 | {
102 | "output_type": "execute_result",
103 | "data": {
104 | "text/plain": [
105 | "True"
106 | ]
107 | },
108 | "metadata": {
109 | "tags": []
110 | },
111 | "execution_count": 2
112 | }
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "metadata": {
118 | "id": "L92lZOsbRR01",
119 | "colab_type": "code",
120 | "colab": {}
121 | },
122 | "source": [
123 | "FloatTensor = torch.FloatTensor\n",
124 | "LongTensor = torch.LongTensor\n",
125 | "ByteTensor = torch.ByteTensor"
126 | ],
127 | "execution_count": 0,
128 | "outputs": []
129 | },
130 | {
131 | "cell_type": "code",
132 | "metadata": {
133 | "id": "irN4kKwcRkqE",
134 | "colab_type": "code",
135 | "colab": {}
136 | },
137 | "source": [
138 | "def getBatch(batch_size, train_data):\n",
139 | "    random.shuffle(train_data)\n",
140 | "    sindex = 0\n",
141 | "    eindex = batch_size\n",
142 | "    while eindex < len(train_data):\n",
143 | "        batch = train_data[sindex: eindex]\n",
144 | "        temp = eindex\n",
145 | "        eindex = eindex + batch_size\n",
146 | "        sindex = temp\n",
147 | "        yield batch\n",
148 | "\n",
149 | "    if eindex >= len(train_data):\n",
150 | "        batch = train_data[sindex:]\n",
151 | "        yield batch"
152 | ],
153 | "execution_count": 0,
154 | "outputs": []
155 | },
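`getBatch` shuffles the data and then walks through it in fixed-size slices. The final slice may be shorter. A more compact sketch of the same idea steps `range` by `batch_size`. The name `iter_batches` and the `seed` parameter are hypothetical. One difference is deliberate: this sketch shuffles a copy, whereas `getBatch` reorders the caller's list in place.

```python
import random

def iter_batches(data, batch_size, seed=None):
    """Yield shuffled mini-batches of size batch_size; the last may be smaller."""
    items = list(data)                  # copy, so the caller's list keeps its order
    random.Random(seed).shuffle(items)  # seeded RNG makes runs reproducible
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

batches = list(iter_batches(range(10), batch_size=4, seed=0))
print([len(b) for b in batches])  # [4, 4, 2]
```

For an empty dataset this sketch yields no batches at all, while `getBatch` yields one empty batch. Either way, the training loop should tolerate that edge case.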
156 | {
157 | "cell_type": "code",
158 | "metadata": {
159 | "id": "9KO8YBrTRnMt",
160 | "colab_type": "code",
161 | "colab": {}
162 | },
163 | "source": [
164 | "# Return a tensor holding the vocabulary index of each word in seq\n",
165 | "def prepare_sequence(seq, word2index):\n",
166 | "    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index[\"\"], seq))  # map(function, iterable)\n",
167 | "    return Variable(LongTensor(idxs))\n",
168 | "\n",
169 | "# Return a tensor holding the vocabulary index of a single word\n",
170 | "def prepare_word(word, word2index):\n",
171 | "    return Variable(LongTensor([word2index[word]]) if word2index.get(word) is not None else LongTensor([word2index[\"\"]]))\n"
172 | ],
173 | "execution_count": 0,
174 | "outputs": []
175 | },
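Both helpers above map words to indices and fall back to a reserved entry for unknown words. That fallback only works if the reserved key is actually present in `word2index`; otherwise the lookup raises `KeyError`. The pure-Python sketch below shows the same lookup with `dict.get(key, default)`. The `UNK` token and the `to_indices` name are hypothetical, and the token must be added to the vocabulary explicitly.

```python
UNK = "<unk>"  # hypothetical unknown-word token; it must be in the vocabulary

def to_indices(words, word2index, unk=UNK):
    """Map each word to its index, using the unknown-word index for OOV words."""
    unk_index = word2index[unk]
    return [word2index.get(w, unk_index) for w in words]

word2index = {UNK: 0, "whale": 1, "sea": 2}
print(to_indices(["the", "whale", "sea"], word2index))  # [0, 1, 2]
```

In PyTorch the resulting list can be wrapped with `torch.tensor(idxs, dtype=torch.long)`. Since PyTorch 0.4, tensors track gradients themselves, so the `Variable` wrapper is no longer required.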
176 | {
177 | "cell_type": "markdown",
178 | "metadata": {
179 | "id": "g67k1BuKRpoX",
180 | "colab_type": "text"
181 | },
182 | "source": [
183 | "# **Data Loading and Preprocessing**"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "metadata": {
189 | "id": "JNFiZcQSRunn",
190 | "colab_type": "code",
191 | "colab": {}
192 | },
193 | "source": [
194 | "corpus = list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:500]\n",
195 | "corpus = [[word.lower() for word in sent] for sent in corpus]"
196 | ],
197 | "execution_count": 0,
198 | "outputs": []
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {
203 | "id": "fChJK81CTMnO",
204 | "colab_type": "text"
205 | },
206 | "source": [
207 | "**Exclude rare words (fewer than `MIN_COUNT` occurrences)**"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "metadata": {
213 | "id": "L3d718UOTQG_",
214 | "colab_type": "code",
215 | "colab": {}
216 | },
217 | "source": [
218 | "word_count = Counter(flatten(corpus))"
219 | ],
220 | "execution_count": 0,
221 | "outputs": []
222 | },
223 | {
224 | "cell_type": "code",
225 | "metadata": {
226 | "id": "UeFxpKBDxd3l",
227 | "colab_type": "code",
228 | "colab": {}
229 | },
230 | "source": [
231 | "MIN_COUNT = 3\n",
232 | "exclude = []"
233 | ],
234 | "execution_count": 0,
235 | "outputs": []
236 | },
237 | {
238 | "cell_type": "code",
239 | "metadata": {
240 | "id": "8C-58KIRxf0G",
241 | "colab_type": "code",
242 | "colab": {}
243 | },
244 | "source": [
245 | "for w, c in word_count.items():\n",
246 | "    if c < MIN_COUNT:\n",
247 | "        exclude.append(w)\n"
248 | ],
249 | "execution_count": 0,
250 | "outputs": []
251 | },
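The three cells above count words and collect every word seen fewer than `MIN_COUNT` times. `exclude` is later queried once per window position (`window[i] in exclude`), so storing it as a set rather than a list turns each check from a linear scan into an average O(1) lookup. A minimal sketch on a toy corpus follows. The sentences are made up for illustration.

```python
from collections import Counter

MIN_COUNT = 3

# Toy tokenized corpus (illustrative only).
sentences = [["call", "me", "ishmael"], ["call", "me", "maybe"], ["call", "me"]]
word_count = Counter(w for sent in sentences for w in sent)

# A set gives O(1) average membership tests, unlike a list.
exclude = {w for w, c in word_count.items() if c < MIN_COUNT}
vocab = sorted(set(word_count) - exclude)
print(vocab)  # ['call', 'me']
```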
252 | {
253 | "cell_type": "markdown",
254 | "metadata": {
255 | "id": "wdqakO7-yCwT",
256 | "colab_type": "text"
257 | },
258 | "source": [
259 | "Prepare training data"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "metadata": {
265 | "id": "HELs6M7VyEiF",
266 | "colab_type": "code",
267 | "colab": {}
268 | },
269 | "source": [
270 | "vocab = list(set(flatten(corpus)) - set(exclude))"
271 | ],
272 | "execution_count": 0,
273 | "outputs": []
274 | },
275 | {
276 | "cell_type": "code",
277 | "metadata": {
278 | "id": "mrWqRzyjyLbW",
279 | "colab_type": "code",
280 | "colab": {}
281 | },
282 | "source": [
283 | "word2index = {}\n",
284 | "for vo in vocab:\n",
285 | "    if word2index.get(vo) is None:\n",
286 | "        word2index[vo] = len(word2index)\n",
287 | " \n",
288 | "index2word = {v:k for k, v in word2index.items()}"
289 | ],
290 | "execution_count": 0,
291 | "outputs": []
292 | },
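`vocab` is built from a `set`, and the iteration order of a set of strings depends on Python's per-process hash randomization. As a result, the index a word receives can change from one run to the next. Sorting before numbering is one way to make `word2index` reproducible. A minimal sketch with a toy vocabulary follows.

```python
vocab = ["whale", "sea", "ship"]  # toy vocabulary

# Sorting first makes the index assignment identical on every run.
word2index = {word: i for i, word in enumerate(sorted(set(vocab)))}
index2word = {i: word for word, i in word2index.items()}
print(word2index)  # {'sea': 0, 'ship': 1, 'whale': 2}
```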
293 | {
294 | "cell_type": "code",
295 | "metadata": {
296 | "id": "RRoEerE6yOBE",
297 | "colab_type": "code",
298 | "colab": {}
299 | },
300 | "source": [
301 | "WINDOW_SIZE = 5\n",
302 | "windows = flatten([list(nltk.ngrams([''] * WINDOW_SIZE + c + [''] * WINDOW_SIZE, WINDOW_SIZE * 2 + 1)) for c in corpus])\n",
303 | "\n",
304 | "train_data = []\n",
305 | "\n",
306 | "for window in windows:\n",
307 | "    for i in range(WINDOW_SIZE * 2 + 1):\n",
308 | "        if window[i] in exclude or window[WINDOW_SIZE] in exclude:\n",
309 | "            continue  # skip words below MIN_COUNT\n",
310 | "        if i == WINDOW_SIZE or window[i] == '':\n",
311 | "            continue  # skip the center word itself and the padding\n",
312 | "        train_data.append((window[WINDOW_SIZE], window[i]))\n",
313 | "\n",
314 | "X_p = []\n",
315 | "y_p = []\n",
316 | "\n",
317 | "for tr in train_data:\n",
318 | "    X_p.append(prepare_word(tr[0], word2index).view(1, -1))\n",
319 | "    y_p.append(prepare_word(tr[1], word2index).view(1, -1))\n",
320 | "\n",
321 | "# Pair each center-word index tensor with its context-word index tensor\n",
322 | "train_data = list(zip(X_p, y_p))\n",
323 | "\n"
324 | ],
325 | "execution_count": 0,
326 | "outputs": []
327 | },
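The cell above pads each sentence with `WINDOW_SIZE` filler tokens on each side and slides an `nltk.ngrams` window over it. The center word of each window is paired with every non-padding neighbour. The same (center, context) pairs can be produced by clamping the window at the sentence boundaries instead of padding. The sketch below does this without NLTK. The `skipgram_pairs` name is hypothetical, and `exclude` plays the role of the `MIN_COUNT` filter.

```python
def skipgram_pairs(sentence, window_size, exclude=frozenset()):
    """Return (center, context) pairs within +/- window_size positions.

    Words in `exclude` are skipped both as centers and as contexts.
    """
    pairs = []
    for i, center in enumerate(sentence):
        if center in exclude:
            continue
        lo = max(0, i - window_size)                  # clamp at sentence start
        hi = min(len(sentence), i + window_size + 1)  # clamp at sentence end
        for j in range(lo, hi):
            if j != i and sentence[j] not in exclude:
                pairs.append((center, sentence[j]))
    return pairs

print(skipgram_pairs(["a", "b", "c"], window_size=1))
# [('a', 'b'), ('b', 'a'), ('b', 'c'), ('c', 'b')]
```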
328 | {
329 | "cell_type": "markdown",
330 | "metadata": {
331 | "id": "vYLCe2AT6gOV",
332 | "colab_type": "text"
333 | },
334 | "source": [
335 | "# **Build the Unigram Distribution Raised to the 0.75 Power**\n"
336 | ]
337 | },
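Negative words are drawn from the unigram distribution raised to the 3/4 power, P_n(w) ∝ count(w)^0.75. This lowers the share of very frequent words and raises the share of rare ones. The notebook approximates P_n(w) with a lookup table in which each word appears int((count/N)^0.75 / Z) times. Because (c/N)^0.75 = c^0.75 / N^0.75, the N term cancels after normalization, so raw counts can be used directly. The sketch below builds the normalized distribution and samples noise words while rejecting the current target. The names `noise_distribution` and `sample_negatives` and the toy counts are hypothetical.

```python
import random
from collections import Counter

def noise_distribution(word_count, power=0.75):
    """P_n(w) proportional to count(w) ** power, normalized to sum to 1."""
    weights = {w: c ** power for w, c in word_count.items()}
    total = sum(weights.values())
    return {w: x / total for w, x in weights.items()}

def sample_negatives(dist, k, exclude_word, rng=random):
    """Draw k noise words with replacement, rejecting exclude_word.

    Assumes dist contains at least one word other than exclude_word.
    """
    words = list(dist)
    probs = [dist[w] for w in words]
    negatives = []
    while len(negatives) < k:
        w = rng.choices(words, weights=probs)[0]
        if w != exclude_word:
            negatives.append(w)
    return negatives

counts = Counter({"the": 100, "whale": 10, "ahab": 1})  # toy counts
dist = noise_distribution(counts)
print({w: round(p, 3) for w, p in dist.items()})
print(sample_negatives(dist, k=3, exclude_word="whale", rng=random.Random(0)))
```

Compared with the raw frequencies (100/111, 10/111, 1/111), the 0.75 power gives "the" a smaller share and "ahab" a larger one.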
338 | {
339 | "cell_type": "code",
340 | "metadata": {
341 | "id": "tj2_OxyH7UZ7",
342 | "colab_type": "code",
343 | "colab": {}
344 | },
345 | "source": [
346 | "Z = 0.001\n",
347 | "word_count = Counter(flatten(corpus))\n",
348 | "num_total_words = sum([c for w, c in word_count.items() if w not in exclude])"
349 | ],
350 | "execution_count": 0,
351 | "outputs": []
352 | },
353 | {
354 | "cell_type": "code",
355 | "metadata": {
356 | "id": "BGHIylEF7hmp",
357 | "colab_type": "code",
358 | "outputId": "36399f72-38dc-4520-be73-8379fdb5b9fe",
359 | "colab": {
360 | "base_uri": "https://localhost:8080/",
361 | "height": 55
362 | }
363 | },
364 | "source": [
365 | "unigram_table = []\n",
366 | "\n",
367 | "for vo in vocab:\n",
368 | " unigram_table.extend([vo] * int(((word_count[vo]/num_total_words)**0.75)/Z))\n",
369 | "\n",
370 | "print(unigram_table)"
371 | ],
372 | "execution_count": 14,
373 | "outputs": [
374 | {
375 | "output_type": "stream",
376 | "text": [
377 | "['that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'that', 'or', 'or', 'or', 'or', 'or', 'or', 'or', 'or', 'or', 'or', 'or', 'or', 'or', 'or', 'or', 'or', 'or', 'world', 'world', 'world', 'world', 'world', 'world', 'dives', 'dives', 'picture', 'picture', 'picture', 'picture', 'one', 'one', 'one', 'one', 'one', 'one', 'one', 'one', 'one', 'one', 'one', 'one', 'one', 'one', 'one', 'one', 'one', 'one', 'when', 'when', 'when', 'when', 'when', 'when', 'when', 'when', 'came', 'came', 'came', 'came', 'came', 'them', 'them', 'them', 'them', 'them', 'them', 'them', 'them', 'them', 'among', 'among', 'among', 'among', 'marvellous', 'marvellous', 'captain', 'captain', 'captain', 'london', 'london', 'yet', 'yet', 'yet', 'yet', 'yet', 'yet', 'english', 'english', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', 'air', 'air', 'air', 'air', 'air', 'whom', 'whom', 'whom', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'of', 'perhaps', 'perhaps', 'better', 'better', 'better', 'kind', 'kind', 'summer', 'summer', 'having', 'having', 'having', 'against', 'against', 'against', 'brought', 'brought', 'vessel', 'vessel', 'vessel', 
'entering', 'entering', ';--', ';--', 'blue', 'blue', 'thou', 'thou', 'thou', 'thou', 'my', 'my', 'my', 'my', 'my', 'my', 'my', 'my', 'my', 'my', 'my', 'my', 'my', 'my', 'my', 'however', 'however', 'however', 'however', 'order', 'order', 'order', 'order', 'order', 'only', 'only', 'only', 'only', 'only', 'mouth', 'mouth', 'mouth', 'mouth', 'mouth', 'how', 'how', 'how', 'how', 'how', 'were', 'were', 'were', 'were', 'were', 'were', 'were', 'were', 'were', 'were', 'were', 'were', 'were', 'chace', 'chace', 'seemed', 'seemed', 'seemed', 'entry', 'entry', 'entry', 'last', 'last', 'last', 'last', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'from', 'history', 'history', 'history', 'two', 'two', 'two', 'two', 'two', 'two', 'two', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', ',', 'set', 'set', 'set', 'set', 'armed', 'armed', 'window', 'window', 'passage', 'passage', 'passage', 'purse', 'purse', 'here', 'here', 'here', 'here', 'here', 'here', 'here', 'here', 'here', 'american', 'american', 'are', 'are', 'are', 'are', 'are', 'are', 'are', 'are', 'are', 'are', 'are', 'are', 'image', 'image', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 'be', 
'be', 'be', 'be', 'be', 'said', 'said', 'said', 'said', 'said', 'once', 'once', 'once', 'once', 'once', 'ay', 'ay', 'ay', 'though', 'though', 'though', 'though', 'though', 'though', 'though', 'though', 'ribs', 'ribs', 'thing', 'thing', 'thing', 'thing', 'find', 'find', 'find', 'find', 'boats', 'boats', 'boats', 'ever', 'ever', 'ever', 'ever', 'ever', 'ever', 'ever', 'ever', 'same', 'same', 'same', 'same', 'same', 'extracts', 'extracts', 'extracts', 'else', 'else', 'else', '.--', '.--', '.--', '.--', 'particular', 'particular', 'particular', 'again', 'again', 'sail', 'sail', 'sail', 'sail', 'mast', 'mast', 'mast', 'mast', 'mast', 'not', 'not', 'not', 'not', 'not', 'not', 'not', 'not', 'not', 'not', 'not', 'not', 'not', 'not', 'not', 'told', 'told', 'told', 'told', 'besides', 'besides', 'besides', 'besides', 'wind', 'wind', 'wind', 'north', 'north', 'say', 'say', 'say', 'say', 'say', 'maketh', 'maketh', 'maketh', 'years', 'years', 'years', 'years', 'years', 'their', 'their', 'their', 'their', 'their', 'their', 'their', 'their', 'their', 'their', 'their', 'their', 'their', 'whenever', 'whenever', 'whenever', 'whenever', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'as', 'hear', 'hear', 'found', 'found', 'found', 'found', 'account', 'account', 'account', 'must', 'must', 'must', 'must', 'must', 'does', 'does', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'whale', 'euroclydon', 'euroclydon', 'euroclydon', 'euroclydon', 'tell', 'tell', 'tell', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'but', 'bright', 'bright', 'bright', 'country', 
'country', 'country', 'men', 'men', 'men', 'men', 'almost', 'almost', 'almost', 'almost', 'almost', 'spouter', 'spouter', 'spouter', 'spouter', 'than', 'than', 'than', 'than', 'than', 'than', 'than', 'than', 'than', 'soul', 'soul', 'supplied', 'supplied', 'tears', 'tears', 'supper', 'supper', 'yourself', 'yourself', 'voyage', 'voyage', 'voyage', 'voyage', 'voyage', 'voyage', 'voyage', 'voyage', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'in', 'stranded', 'stranded', 'less', 'less', 'who', 'who', 'who', 'who', 'who', 'who', 'looked', 'looked', 'looked', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', '\"', 'money', 'money', 'money', 'money', 'green', 'green', 'after', 'after', 'grow', 'grow', 'grow', 'make', 'make', 'make', 'make', 'sperma', 'sperma', 'like', 'like', 'like', 'like', 'like', 'like', 'like', 'like', 'like', 'like', 'like', 'half', 'half', 'half', 'penny', 'penny', 'oil', 'oil', 'oil', 'over', 'over', 'over', 'over', 'over', 'over', 'over', 'over', 'deep', 'deep', 'deep', 'deep', 'hands', 'hands', 'hill', 'hill', 'shore', 'shore', 'shore', 'goes', 'goes', 'think', 'think', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'is', 'they', 'they', 'they', 'they', 'they', 'they', 'they', 'they', 'they', 'they', 'they', 'they', 'they', 'they', 'they', 'they', 'ye', 'ye', 'ye', 'ye', 'ye', 'down', 'down', 
'down', 'down', 'down', 'down', 'down', 'where', 'where', 'where', 'where', 'where', 'where', 'where', 'between', 'between', 'heads', 'heads', 'artist', 'artist', 'thinks', 'thinks', 'swallow', 'swallow', 'life', 'life', 'life', 'fifty', 'fifty', 'commodore', 'commodore', 'fixed', 'fixed', 'did', 'did', 'did', 'did', 'did', 'did', 'did', 'form', 'form', 'huge', 'huge', '--\"', '--\"', '--\"', '--\"', 'would', 'would', 'would', 'would', 'would', 'would', 'would', 'would', 'would', 'would', 'own', 'own', 'own', 'own', 'own', 'hand', 'hand', 'hand', 'hand', 'water', 'water', 'water', 'water', 'water', 'water', 'water', 'water', 'water', 'up', 'up', 'up', 'up', 'up', 'up', 'up', 'up', 'up', 'up', 'up', 'up', 'up', 'up', 'glass', 'glass', 'glass', 'behind', 'behind', 'killed', 'killed', 'killed', 'killed', '?', '?', '?', '?', '?', '?', '?', '?', '?', '?', '?', '?', '?', '?', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'it', 'going', 'going', 'going', 'going', 'way', 'way', 'way', 'way', 'way', 'way', 'way', 'way', 'old', 'old', 'old', 'old', 'old', 'old', 'old', 'old', 'old', 'old', 'old', 'old', 'old', 'old', 'altogether', 'altogether', ',\"', ',\"', ',\"', ',\"', ',\"', ',\"', 'itself', 'itself', 'because', 'because', 'because', 'because', 'arched', 'arched', 'which', 'which', 'which', 'which', 'which', 'which', 'which', 'which', 'which', 'which', 'which', 'which', 'which', 'which', 'which', 'which', 'enter', 'enter', 'can', 'can', 'can', 'can', 'can', 'can', 'can', 'monster', 'monster', 'monster', 'being', 'being', 'being', 'being', 'being', 'being', 'being', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'this', 'whales', 'whales', 'whales', 'whales', 'whales', 'whales', 'whales', 
'whales', 'whales', 'whales', 'whales', 'whales', 'whales', 'streets', 'streets', 'streets', 'streets', 'nigh', 'nigh', 'sleep', 'sleep', 'sleep', 'animal', 'animal', 'animal', 'animal', 'sailor', 'sailor', 'sailor', 'sailor', 'light', 'light', 'light', 'vast', 'vast', 'vast', 'vast', 'vast', 'right', 'right', 'right', 'right', 'right', 'right', 'such', 'such', 'such', 'such', 'such', 'such', 'such', 'such', 'such', 'lay', 'lay', 'themselves', 'themselves', 'land', 'land', 'land', 'land', 'land', 'land', 'ceti', 'ceti', 'immense', 'immense', 'man', 'man', 'man', 'man', 'man', 'man', 'towards', 'towards', 'towards', 'towards', 'rather', 'rather', 'rather', 'rather', 'quantity', 'quantity', 'then', 'then', 'then', 'then', 'then', 'craft', 'craft', 'craft', 'craft', 'ocean', 'ocean', 'ocean', 'ocean', 'ocean', 'ocean', 'ocean', 'should', 'should', 'should', 'should', 'should', 'should', 'also', 'also', 'what', 'what', 'what', 'what', 'what', 'what', 'what', 'what', 'what', 'what', 'what', 'what', 'what', 'she', 'she', 'she', 'she', 'dim', 'dim', 'dim', '...', '...', '...', '...', '...', '...', '...', 'high', 'high', 'high', 'high', 'well', 'well', 'well', 'well', 'well', 'known', 'known', 'known', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'was', 'upon', 'upon', 'upon', 'upon', 'upon', 'upon', 'upon', 'upon', 'upon', 'upon', 'upon', 'till', 'till', 'till', 'mighty', 'mighty', 'poor', 'poor', 'poor', 'poor', 'poor', 'globe', 'globe', 'globe', 'globe', 'city', 'city', 'city', '(', '(', '(', '(', '(', '(', '(', '(', '(', '(', 'hearts', 'hearts', '),', '),', '),', '),', 'him', 'him', 'him', 'him', 'him', 'him', 'him', 'him', 'him', 'shall', 'shall', 'shall', 'shall', 'young', 'young', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'for', 'winds', 'winds', 
'within', 'within', 'within', 'within', 'within', 'sir', 'sir', 'sir', 'sir', 'sir', 'sir', 'stern', 'stern', 'sign', 'sign', 'sign', 'sign', 'bag', 'bag', 'things', 'things', 'things', 'things', 'heart', 'heart', 'body', 'body', 'body', 'beneath', 'beneath', 'your', 'your', 'your', 'your', 'your', 'your', 'your', 'your', 'your', 'no', 'no', 'no', 'no', 'no', 'no', 'no', 'no', 'no', 'no', 'sperm', 'sperm', 'sperm', 'sperm', 'sperm', 'leviathan', 'leviathan', 'leviathan', 'leviathan', 'leviathan', 'leviathan', 'leviathan', 'leviathan', 'leviathan', 'me', 'me', 'me', 'me', 'me', 'me', 'me', 'me', 'me', 'me', 'me', 'me', 'me', 'me', 'me', 'me', 'me', 'voyages', 'voyages', 'voyages', 'harpooneer', 'harpooneer', 'harpooneer', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'other', 'harpoons', 'harpoons', 'house', 'house', 'house', 'house', 'mind', 'mind', 'floating', 'floating', 'get', 'get', 'get', 'get', 'get', 't', 't', 't', 't', 't', 't', 'letter', 'letter', 'had', 'had', 'had', 'had', 'had', 'had', 'had', 'had', 'had', 'had', 'had', 'stone', 'stone', 'fish', 'fish', 'fish', 'fish', 'fish', 'door', 'door', 'door', 'door', 'thought', 'thought', 'thought', 'thought', 'thought', 'thought', 'either', 'either', 'either', 'either', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'to', 'few', 'few', 'true', 'true', 'true', 'true', 'gathered', 'gathered', 'around', 'around', 'around', 'inn', 'inn', 'inn', 'whaling', 'whaling', 'whaling', 'whaling', 'whaling', 'whaling', 'whaling', 'whaling', 'forty', 'forty', 'new', 'new', 'new', 'new', 'new', 'new', 'its', 'its', 'its', 'its', 'its', 'its', 'near', 'near', 'near', 'near', 'each', 'each', 'some', 'some', 'some', 
'some', 'some', 'some', 'some', 'some', 'some', 'some', 'some', 'some', 'some', 'some', 'let', 'let', 'let', 'let', 'let', 'day', 'day', 'aloft', 'aloft', 'doubtless', 'doubtless', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', ';', 'ibid', 'ibid', 'black', 'black', 'black', 'black', 'about', 'about', 'about', 'about', 'about', 'about', 'about', 'about', 'about', 'moving', 'moving', 'he', 'he', 'he', 'he', 'he', 'he', 'he', 'he', 'he', 'he', 'he', 'he', 'he', 'he', 'he', 'he', 'he', 'he', 'an', 'an', 'an', 'an', 'an', 'an', 'an', 'an', 'an', 'an', '?\"', '?\"', '?\"', '?\"', 'gone', 'gone', 'gone', 'lead', 'lead', 'passenger', 'passenger', 'passenger', 'tempestuous', 'tempestuous', 'myself', 'myself', 'myself', 'myself', 'myself', 'sleeps', 'sleeps', 'making', 'making', 'red', 'red', 'others', 'others', 'ship', 'ship', 'ship', 'ship', 'ship', 'ship', 'ship', 'ship', 'without', 'without', 'without', 'without', 'without', 'been', 'been', 'been', 'been', 'been', 'been', 'been', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 's', 'might', 'might', 'might', 'might', 'might', 'little', 'little', 'little', 'little', 'little', 'little', 'little', 'these', 'these', 'these', 'these', 'these', 'these', 'these', 'these', 'these', 'these', 'lazarus', 'lazarus', 'lazarus', 'stream', 'stream', 'stream', ':', ':', ':', ':', 'matter', 'matter', 'matter', 'open', 'open', 'open', 'open', 'open', 'look', 'look', 'look', 'look', 'never', 'never', 'never', 'never', 'never', 'never', 'wild', 'wild', 'three', 'three', 'three', 'three', 'blows', 'blows', 'blows', 'too', 'too', 'too', 'too', 'too', 'too', 'through', 'through', 'through', 'through', 'through', 'through', 'boat', 'boat', 'boat', 'boat', 'large', 'large', 'large', 'large', 'room', 'room', 'by', 'by', 'by', 'by', 
'by', 'by', 'by', 'by', 'by', 'by', 'by', 'by', 'by', 'by', 'by', 'by', 'by', 'by', 'by', 'by', 'by', 'miles', 'miles', 'miles', 'miles', 'teeth', 'teeth', 'teeth', 'long', 'long', 'long', 'long', 'long', 'long', 'ago', 'ago', 'ago', 'ago', 'feet', 'feet', 'feet', 'feet', 'feet', 'many', 'many', 'many', 'many', 'low', 'low', 'low', 'low', 'thomas', 'thomas', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'all', 'more', 'more', 'more', 'more', 'more', 'more', 'more', 'more', 'board', 'board', 'himself', 'himself', 'himself', 'himself', 'put', 'put', 'put', 'at', 'at', 'at', 'at', 'at', 'at', 'at', 'at', 'at', 'at', 'at', 'at', 'at', 'at', 'at', 'at', 'chapter', 'chapter', 'chapter', 'destroyed', 'destroyed', 'bones', 'bones', 'night', 'night', 'night', 'night', 'night', 'night', 'has', 'has', 'has', 'has', 'still', 'still', 'still', 'still', 'street', 'street', 'street', 'jolly', 'jolly', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'there', 'see', 'see', 'see', 'see', 'see', 'see', 'see', 'eyes', 'eyes', 'death', 'death', 'death', 'death', 'persons', 'persons', 'stop', 'stop', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', '--', 'full', 'full', 'full', 'passengers', 'passengers', 'further', 'further', 'much', 'much', 'much', 'much', 'much', 'much', 'whose', 'whose', 'portentous', 'portentous', 'requires', 'requires', 'wide', 'wide', 'wide', 'wide', 'ishmael', 'ishmael', 'ishmael', 'ishmael', 'another', 'another', 'length', 'length', 'off', 'off', 'off', 'off', 'off', 'off', 'battle', 'battle', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 
'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'why', 'why', 'why', 'why', 'why', 'parts', 'parts', 'will', 'will', 'will', 'will', 'will', 'will', 'will', 'will', 'will', 'works', 'works', 'run', 'run', 'made', 'made', 'made', 'made', 'webster', 'webster', 'pains', 'pains', 'seas', 'seas', 'seas', 'seas', 'those', 'those', 'those', 'those', 'those', 'those', 'while', 'while', 'while', 'we', 'we', 'we', 'we', 'we', 'we', 'we', 'we', 'we', 'we', 'ice', 'ice', 'give', 'give', 'give', 'give', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', 'fishes', 'fishes', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'on', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 
'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', 'i', ')', ')', ')', ')', ')', ')', ')', 'her', 'her', 'her', 'her', 'jaws', 'jaws', 'jaws', 'jaws', 'bed', 'bed', 'pale', 'pale', 'compare', 'compare', 'd', 'd', 'd', 'd', 'd', 'place', 'place', 'place', 'place', 'place', 'place', \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", \"'\", 'days', 'days', 'out', 'out', 'out', 'out', 'out', 'out', 'out', 'out', 'out', 'out', 'out', 'whether', 'whether', 'whether', 'whether', 'before', 'before', 'before', 'before', 'before', 'before', 'before', 'created', 'created', 'saw', 'saw', 'saw', 'saw', 'saw', 'saw', 'saw', 'saw', 'sometimes', 'sometimes', 'ha', 'ha', 'beast', 'beast', 'our', 'our', 'our', 'our', 'four', 'four', 'time', 'time', 'time', 'time', 'time', 'time', 'time', 'even', 'even', 'even', 'even', 'nuee', 'nuee', 'nuee', 'cape', 'cape', 'cape', '!', '!', '!', '!', '!', '!', '!', '!', '!', '!', '!', '!', '!', '!', '!', 'into', 'into', 'into', 'into', 'into', 'into', 'into', 'into', 'into', 'into', 'part', 'part', 'part', 'part', 'part', 'nothing', 'nothing', 'nothing', '!\"', '!\"', '!\"', '!\"', 'round', 'round', 'round', 'round', 'round', 'round', 'paying', 'paying', 'stood', 'stood', 'stood', 'something', 'something', 'something', 'something', 'something', 'head', 'head', 'head', 'head', 'head', 'head', 'head', 'head', 'coffin', 'coffin', 'coffin', 'stand', 'stand', 'stand', 'stand', 'stand', 'every', 'every', 'every', 'every', 'every', 'every', 'every', 'broiled', 'broiled', 'bedford', 'bedford', 'bedford', 'any', 'any', 'any', 'any', 'any', 'any', 'strong', 'strong', 'strong', 'come', 'come', 'come', 'come', 'come', 'could', 'could', 'could', 'could', 'could', 
'called', 'called', 'called', 'called', 'jonah', 'jonah', 'sort', 'sort', 'sort', 'sort', 'sort', 'sort', 'cook', 'cook', 'cook', 'cannot', 'cannot', 'jaw', 'jaw', 'god', 'god', 'god', 'purpose', 'purpose', 'especially', 'especially', 'don', 'don', 'glasses', 'glasses', 'glasses', 'us', 'us', 'us', 'us', 'us', 'grand', 'grand', 'grand', 'grand', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'if', 'if', 'if', 'if', 'if', 'if', 'if', 'if', 'if', 'if', 'if', 'spermacetti', 'spermacetti', 'e', 'e', 'now', 'now', 'now', 'now', 'now', 'now', 'now', 'now', 'now', 'now', 'idea', 'idea', 'idea', 'most', 'most', 'most', 'most', 'most', 'most', 'most', 'most', 'am', 'am', 'am', 'am', 'very', 'very', 'very', 'very', 'very', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'you', 'island', 'island', 'northern', 'northern', 'away', 'away', 'thousand', 'thousand', 'dictionary', 'dictionary', 'go', 'go', 'go', 'go', 'go', 'go', 'go', 'go', 'go', 'go', 'so', 'so', 'so', 'so', 'so', 'so', 'so', 'so', 'so', 'so', 'so', 'so', 'so', 'so', 'so', 'name', 'name', 'narrative', 'narrative', 'narrative', 'late', 'late', 'late', 'monsters', 'monsters', 'take', 'take', 'take', 'take', 'take', 'take', 'take', 'take', 'sword', 'sword', 'sword', 'town', 'town', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'his', 'under', 'under', 'under', '?--', '?--', '?--', '?--', 'first', 'first', 'first', 'first', 'first', 'first', 
'first', 'first', 'may', 'may', 'may', 'may', 'may', 'may', 'may', ',--', ',--', ',--', 'king', 'king', 'king', 'king', 'king', 'seen', 'seen', 'seen', 'enough', 'enough', 'enough', 'enough', 'ships', 'ships', 'ships', 'ships', 'have', 'have', 'have', 'have', 'have', 'have', 'have', 'have', 'have', 'have', 'have', 'have', 'have', 'have', 'sea', 'sea', 'sea', 'sea', 'sea', 'sea', 'sea', 'sea', 'sea', 'sea', 'sea', 'sea', 'sea', 'sea', 'sea', 'sea', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', '.\"', 'monstrous', 'monstrous', 'monstrous', 'view', 'view', 'view', 'lord', 'lord', 'frost', 'frost', 'frost', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'with', 'great', 'great', 'great', 'great', 'great', 'great', 'great', 'great', 'great', 'great', 'great', 'great', 'sub', 'sub', 'sub', 'sub', 'sub', 'ten', 'ten', 'ten', 'point', 'point', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'and', 'south', 'south', 'south', 'far', 'far', 'far', 'far', 'nantucket', 'nantucket', 'nantucket', 'nantucket', 'nantucket', 'nantucket', 'nantucket', 'nantucket', 'nantucket', 'earth', 'earth', 'do', 'do', 'do', 'do', 'do', 'do', 'pacific', 
'pacific', 'looking', 'looking', 'looking', 'side', 'side', 'side', 'side', 'sing', 'sing', 'tail', 'tail', 'tail', 'royal', 'royal']\n"
378 | ],
379 | "name": "stdout"
380 | }
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "metadata": {
386 | "id": "IVtb0mic7okz",
387 | "colab_type": "code",
388 | "outputId": "353c3b20-4074-4ea1-b5e6-fe95a86b641f",
389 | "colab": {
390 | "base_uri": "https://localhost:8080/",
391 | "height": 35
392 | }
393 | },
394 | "source": [
395 | "print(len(vocab), len(unigram_table))"
396 | ],
397 | "execution_count": 15,
398 | "outputs": [
399 | {
400 | "output_type": "stream",
401 | "text": [
402 | "478 3500\n"
403 | ],
404 | "name": "stdout"
405 | }
406 | ]
407 | },
408 | {
409 | "cell_type": "markdown",
410 | "metadata": {
411 | "id": "vC7ha5lm_UYc",
412 | "colab_type": "text"
413 | },
414 | "source": [
415 | "# **Negative Sampling**"
416 | ]
417 | },
418 | {
419 | "cell_type": "code",
420 | "metadata": {
421 | "id": "zPmASeWB_W_J",
422 | "colab_type": "code",
423 | "colab": {}
424 | },
425 | "source": [
426 | "def negative_sampling(targets, unigram_table, k):\n",
427 | "    batch_size = targets.size(0)\n",
428 | "    neg_samples = []\n",
429 | "    for i in range(batch_size):\n",
430 | "        nsample = []\n",
431 | "        target_index = targets[i, 0].item()  # index of the true context word\n",
432 | "        while len(nsample) < k:  # draw until k negatives are collected\n",
433 | "            neg = random.choice(unigram_table)\n",
434 | "            if word2index[neg] == target_index:  # never use the true context word as a negative\n",
435 | "                continue\n",
436 | "            nsample.append(neg)\n",
437 | "        neg_samples.append(prepare_sequence(nsample, word2index).view(1, -1))  # 1 x k\n",
438 | "    \n",
439 | "    return torch.cat(neg_samples)  # B x k"
440 | ],
441 | "execution_count": 0,
442 | "outputs": []
443 | },
444 | {
445 | "cell_type": "markdown",
446 | "metadata": {
447 | "id": "ZMAYtc28A2KJ",
448 | "colab_type": "text"
449 | },
450 | "source": [
451 | "# **Modeling**\n",
459 | ""
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "metadata": {
465 | "id": "rw_x7YXRA_nQ",
466 | "colab_type": "code",
467 | "colab": {}
468 | },
469 | "source": [
470 | "class SkipgramNegSampling(nn.Module):\n",
471 | "    \n",
472 | "    def __init__(self, vocab_size, projection_dim):\n",
473 | "        super(SkipgramNegSampling, self).__init__()\n",
474 | "        self.embedding_v = nn.Embedding(vocab_size, projection_dim)  # center (input) embeddings v\n",
475 | "        self.embedding_u = nn.Embedding(vocab_size, projection_dim)  # context (output) embeddings u\n",
476 | "        self.logsigmoid = nn.LogSigmoid()\n",
477 | "        \n",
478 | "        initrange = (2.0 / (vocab_size + projection_dim))**0.5  # Xavier/Glorot-style scale\n",
479 | "        self.embedding_v.weight.data.uniform_(-initrange, initrange)  # small random v\n",
480 | "        self.embedding_u.weight.data.zero_()  # u starts at zero\n",
481 | "        \n",
482 | "    def forward(self, center_words, target_words, negative_words):\n",
483 | "        center_embeds = self.embedding_v(center_words)  # B x 1 x D\n",
484 | "        target_embeds = self.embedding_u(target_words)  # B x 1 x D\n",
485 | "        \n",
486 | "        neg_embeds = -self.embedding_u(negative_words)  # B x K x D (K = number of negatives), negated to score sigma(-u_k . v_c)\n",
487 | "        \n",
488 | "        positive_score = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2)  # B x 1\n",
489 | "        negative_score = neg_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2)  # B x K\n",
490 | "        # log sigma(u_o . v_c) + sum_k log sigma(-u_k . v_c): one log-sigmoid per negative, then sum\n",
491 | "        loss = self.logsigmoid(positive_score) + self.logsigmoid(negative_score).sum(1, keepdim=True)  # B x 1\n",
492 | "        \n",
493 | "        return -torch.mean(loss)\n",
494 | "    \n",
495 | "    def prediction(self, inputs):\n",
496 | "        embeds = self.embedding_v(inputs)  # the center embeddings serve as the word vectors\n",
497 | "        \n",
498 | "        return embeds"
499 | ],
500 | "execution_count": 0,
501 | "outputs": []
502 | },
503 | {
504 | "cell_type": "markdown",
505 | "metadata": {
506 | "id": "bg_xFERTB6BU",
507 | "colab_type": "text"
508 | },
509 | "source": [
510 | "# **Train**"
511 | ]
512 | },
513 | {
514 | "cell_type": "code",
515 | "metadata": {
516 | "id": "uIKrQT4JB8GN",
517 | "colab_type": "code",
518 | "colab": {}
519 | },
520 | "source": [
521 | "EMBEDDING_SIZE = 30  # dimension D of each word vector\n",
522 | "BATCH_SIZE = 256\n",
523 | "EPOCH = 100\n",
524 | "NEG = 10  # number of negative samples K per (center, context) pair"
525 | ],
526 | "execution_count": 0,
527 | "outputs": []
528 | },
529 | {
530 | "cell_type": "code",
531 | "metadata": {
532 | "id": "OmWjxt90B_Re",
533 | "colab_type": "code",
534 | "colab": {}
535 | },
536 | "source": [
537 | "losses = []\n",
538 | "model = SkipgramNegSampling(len(word2index), EMBEDDING_SIZE)\n",
539 | "optimizer = optim.Adam(model.parameters(), lr=0.001)"
540 | ],
541 | "execution_count": 0,
542 | "outputs": []
543 | },
544 | {
545 | "cell_type": "code",
546 | "metadata": {
547 | "id": "jyyMZD2WCEDO",
548 | "colab_type": "code",
549 | "outputId": "67183f9b-5444-4105-f6ec-fec456b52790",
550 | "colab": {
551 | "base_uri": "https://localhost:8080/",
552 | "height": 89
553 | }
554 | },
555 | "source": [
556 | "for epoch in range(EPOCH):\n",
557 | "    for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
558 | "        \n",
559 | "        inputs, targets = zip(*batch)\n",
560 | "        \n",
561 | "        inputs = torch.cat(inputs)  # B x 1 center words\n",
562 | "        targets = torch.cat(targets)  # B x 1 context words\n",
563 | "        negs = negative_sampling(targets, unigram_table, NEG)  # B x NEG\n",
564 | "        model.zero_grad()\n",
565 | "\n",
566 | "        loss = model(inputs, targets, negs)\n",
567 | "        \n",
568 | "        loss.backward()\n",
569 | "        optimizer.step()\n",
570 | "        \n",
571 | "        losses.append(loss.item())\n",
572 | "    if epoch % 10 == 0:  # report the mean loss since the last report\n",
573 | "        print(\"Epoch : %d, mean_loss : %.02f\" % (epoch, np.mean(losses)))\n",
574 | "        losses = []"
575 | ],
576 | "execution_count": 0,
577 | "outputs": [
578 | {
579 | "output_type": "stream",
580 | "text": [
581 | "Epoch : 0, mean_loss : 1.06\n",
582 | "Epoch : 10, mean_loss : 0.86\n",
583 | "Epoch : 20, mean_loss : 0.80\n",
584 | "Epoch : 30, mean_loss : 0.74\n"
585 | ],
586 | "name": "stdout"
587 | }
588 | ]
589 | },
590 | {
591 | "cell_type": "markdown",
592 | "metadata": {
593 | "id": "al_-FzyrCKui",
594 | "colab_type": "text"
595 | },
596 | "source": [
597 | "# **Test**"
598 | ]
599 | },
600 | {
601 | "cell_type": "code",
602 | "metadata": {
603 | "id": "qJeIrSb_CMxw",
604 | "colab_type": "code",
605 | "colab": {}
606 | },
607 | "source": [
608 | "def word_similarity(target, vocab):\n",
609 | "    target_V = model.prediction(prepare_word(target, word2index))  # 1 x D\n",
610 | "    similarities = []\n",
611 | "    for word in vocab:\n",
612 | "        if word == target:  # skip the query word itself\n",
613 | "            continue\n",
614 | "        \n",
615 | "        vector = model.prediction(prepare_word(word, word2index))\n",
616 | "        \n",
617 | "        cosine_sim = F.cosine_similarity(target_V, vector).item()\n",
618 | "        similarities.append([word, cosine_sim])\n",
619 | "    return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]  # 10 nearest words"
620 | ],
621 | "execution_count": 0,
622 | "outputs": []
623 | },
624 | {
625 | "cell_type": "code",
626 | "metadata": {
627 | "id": "kOhDlBrjCO6Q",
628 | "colab_type": "code",
629 | "colab": {}
630 | },
631 | "source": [
632 | "test = random.choice(list(vocab))\n",
633 | "test"
634 | ],
635 | "execution_count": 0,
636 | "outputs": []
637 | },
638 | {
639 | "cell_type": "code",
640 | "metadata": {
641 | "id": "GT30HLR6CQXz",
642 | "colab_type": "code",
643 | "colab": {}
644 | },
645 | "source": [
646 | "word_similarity(test, vocab)"
647 | ],
648 | "execution_count": 0,
649 | "outputs": []
650 | }
651 | ]
652 | }
--------------------------------------------------------------------------------
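The `SkipgramNegSampling` model in `skip-gram.ipynb` is trained with the skip-gram negative-sampling objective. For one center word c, its true context word o and K noise words k, that objective is log σ(u_o·v_c) + Σ_k log σ(-u_k·v_c). Here v comes from the center embedding and u comes from the output embedding. Below is a minimal standalone sketch of that loss on plain tensors. It is handy for checking shapes and values outside the notebook. The function name `neg_sampling_loss` and its B x D / B x K x D input shapes are assumptions made for illustration, not the notebook's API. It applies log σ to each negative separately (the per-negative form used by word2vec). Summing the K scores first and taking a single log σ would give a different quantity.

```python
import torch
import torch.nn.functional as F


def neg_sampling_loss(center, context, negatives):
    """Batch-averaged skip-gram negative-sampling loss (illustrative sketch).

    center:    B x D     center-word vectors v_c
    context:   B x D     output vectors u_o of the true context words
    negatives: B x K x D output vectors u_k of the K sampled noise words
    """
    pos = (context * center).sum(dim=1)                         # B: u_o . v_c
    neg = torch.bmm(negatives, center.unsqueeze(2)).squeeze(2)  # B x K: u_k . v_c
    # log sigma(u_o . v_c) + sum_k log sigma(-u_k . v_c), one log-sigmoid per negative
    objective = F.logsigmoid(pos) + F.logsigmoid(-neg).sum(dim=1)  # B
    return -objective.mean()


torch.manual_seed(0)
B, K, D = 4, 10, 30
center = torch.randn(B, D)

# Zero output vectors, mirroring the notebook's zero-initialised embedding_u
init_loss = neg_sampling_loss(center, torch.zeros(B, D), torch.zeros(B, K, D))
print(round(init_loss.item(), 2))  # (K + 1) * ln 2 -> 7.62

# Random output vectors: the loss is always a finite positive number
print(neg_sampling_loss(center, torch.randn(B, D), torch.randn(B, K, D)).item())
```

When all output vectors are zero, every score is 0 and log σ(0) = -ln 2. The starting loss is therefore (K + 1)·ln 2, about 7.62 for K = 10. This gives a quick sanity check for the first reported epoch of any implementation that uses the per-negative form.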