├── LICENSE
├── NMT_in_PyTorch.ipynb
├── README.md
└── nmt_pytorch.gif
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 elvis
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/NMT_in_PyTorch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "NMT in PyTorch.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "collapsed_sections": [],
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "metadata": {
31 | "id": "gau9xEXMGY8s",
32 | "colab_type": "text"
33 | },
34 | "cell_type": "markdown",
35 | "source": [
36 | "# Neural Machine Translation with Attention Using PyTorch\n",
37 | "In this notebook we are going to perform machine translation using a deep learning based approach and attention mechanism. All the code is based on PyTorch and it was adopted from the tutorial provided on the official documentation of [TensorFlow](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n",
38 | "\n",
39 | "Specifically, we are going to train a sequence to sequence model for Spanish to English translation. If you are not familiar with sequence to sequence models, I have provided some references at the end of this tutorial to familiarize yourself with the concept. Even if you are not familiar with seq2seq models, you can still proceed with the coding exercise. I will explain tiny details that are important as we proceed. \n",
40 | "\n",
41 | "The tutorial is very brief and I encourage you to also take a look at the official TensorFlow [notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb) for more detailed explanations. The purpose of this tutorial is to understand how to convert certain code blocks into a deep learning framework like PyTorch. You will soon realize that the frameworks are very similar to some extent. The data preparation part is slightly different so I would emphasize that you spend more time analyzing this part of the code. \n",
42 | "\n",
43 | "If you have questions you can also reach out to me at ellfae@gmail.com or Twitter ([@omarsar0](https://twitter.com/omarsar0))."
44 | ]
45 | },
46 | {
47 | "metadata": {
48 | "id": "Z579-ISl9Zj6",
49 | "colab_type": "text"
50 | },
51 | "cell_type": "markdown",
52 | "source": [
53 | "## Import libraries"
54 | ]
55 | },
56 | {
57 | "metadata": {
58 | "id": "XRldb3db1Bg0",
59 | "colab_type": "code",
60 | "colab": {
61 | "base_uri": "https://localhost:8080/",
62 | "height": 377
63 | },
64 | "outputId": "504abe37-6f7e-4046-ac4e-cc036821f1d6"
65 | },
66 | "cell_type": "code",
67 | "source": [
68 | "!pip3 install http://download.pytorch.org/whl/cu80/torch-0.4.1-cp36-cp36m-linux_x86_64.whl \n",
69 | "!pip3 install torchvision"
70 | ],
71 | "execution_count": 1,
72 | "outputs": [
73 | {
74 | "output_type": "stream",
75 | "text": [
76 | "Collecting torch==0.4.1 from http://download.pytorch.org/whl/cu80/torch-0.4.1-cp36-cp36m-linux_x86_64.whl\n",
77 | "\u001b[?25l Downloading http://download.pytorch.org/whl/cu80/torch-0.4.1-cp36-cp36m-linux_x86_64.whl (483.0MB)\n",
78 | "\u001b[K 100% |████████████████████████████████| 483.0MB 51.1MB/s \n",
79 | "tcmalloc: large alloc 1073750016 bytes == 0x56070000 @ 0x7f1bbde7e2a4 0x594e17 0x626104 0x51190a 0x4f5277 0x510c78 0x5119bd 0x4f5277 0x4f3338 0x510fb0 0x5119bd 0x4f5277 0x4f3338 0x510fb0 0x5119bd 0x4f5277 0x4f3338 0x510fb0 0x5119bd 0x4f6070 0x510c78 0x5119bd 0x4f5277 0x4f3338 0x510fb0 0x5119bd 0x4f6070 0x4f3338 0x510fb0 0x5119bd 0x4f6070\n",
80 | "\u001b[?25hInstalling collected packages: torch\n",
81 | "Successfully installed torch-0.4.1\n",
82 | "Collecting torchvision\n",
83 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/ca/0d/f00b2885711e08bd71242ebe7b96561e6f6d01fdb4b9dcf4d37e2e13c5e1/torchvision-0.2.1-py2.py3-none-any.whl (54kB)\n",
84 | "\u001b[K 100% |████████████████████████████████| 61kB 3.7MB/s \n",
85 | "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.14.6)\n",
86 | "Collecting pillow>=4.1.1 (from torchvision)\n",
87 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/62/94/5430ebaa83f91cc7a9f687ff5238e26164a779cca2ef9903232268b0a318/Pillow-5.3.0-cp36-cp36m-manylinux1_x86_64.whl (2.0MB)\n",
88 | "\u001b[K 100% |████████████████████████████████| 2.0MB 10.8MB/s \n",
89 | "\u001b[?25hRequirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.11.0)\n",
90 | "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from torchvision) (0.4.1)\n",
91 | "Installing collected packages: pillow, torchvision\n",
92 | " Found existing installation: Pillow 4.0.0\n",
93 | " Uninstalling Pillow-4.0.0:\n",
94 | " Successfully uninstalled Pillow-4.0.0\n",
95 | "Successfully installed pillow-5.3.0 torchvision-0.2.1\n"
96 | ],
97 | "name": "stdout"
98 | }
99 | ]
100 | },
101 | {
102 | "metadata": {
103 | "id": "qT20LFmb3jSW",
104 | "colab_type": "code",
105 | "colab": {
106 | "base_uri": "https://localhost:8080/",
107 | "height": 34
108 | },
109 | "outputId": "595c3efa-50dd-4f5a-9695-02ab51af4e61"
110 | },
111 | "cell_type": "code",
112 | "source": [
113 | "import torch\n",
114 | "import torch.functional as F\n",
115 | "import torch.nn as nn\n",
116 | "import torch.optim as optim\n",
117 | "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
118 | "\n",
119 | "import pandas as pd\n",
120 | "from sklearn.model_selection import train_test_split\n",
121 | "import numpy as np\n",
122 | "import unicodedata\n",
123 | "import re\n",
124 | "import time\n",
125 | "\n",
126 | "print(torch.__version__)"
127 | ],
128 | "execution_count": 2,
129 | "outputs": [
130 | {
131 | "output_type": "stream",
132 | "text": [
133 | "0.4.1\n"
134 | ],
135 | "name": "stdout"
136 | }
137 | ]
138 | },
139 | {
140 | "metadata": {
141 | "id": "NAuXJjo9NuT8",
142 | "colab_type": "text"
143 | },
144 | "cell_type": "markdown",
145 | "source": [
146 | "## Import Data from Google Drive\n",
147 | "I stored the data on my Google Drive, but you can also obtain it from [here](http://www.manythings.org/anki/) as well. "
148 | ]
149 | },
150 | {
151 | "metadata": {
152 | "id": "63Ox1YURzVhF",
153 | "colab_type": "code",
154 | "colab": {
155 | "base_uri": "https://localhost:8080/",
156 | "height": 122
157 | },
158 | "outputId": "1fe23f41-d672-4ddf-9e4f-cabca9de294b"
159 | },
160 | "cell_type": "code",
161 | "source": [
162 | "from google.colab import drive\n",
163 | "drive.mount('/gdrive')"
164 | ],
165 | "execution_count": 3,
166 | "outputs": [
167 | {
168 | "output_type": "stream",
169 | "text": [
170 | "Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code\n",
171 | "\n",
172 | "Enter your authorization code:\n",
173 | "··········\n",
174 | "Mounted at /gdrive\n"
175 | ],
176 | "name": "stdout"
177 | }
178 | ]
179 | },
180 | {
181 | "metadata": {
182 | "id": "JD8Qy0eC0ZtA",
183 | "colab_type": "code",
184 | "colab": {}
185 | },
186 | "cell_type": "code",
187 | "source": [
188 | "f = open('/gdrive/My Drive/DAIR RESOURCES/PyTorch/Neural Machine Translation with PyTorch/spa.txt', encoding='UTF-8').read().strip().split('\\n') "
189 | ],
190 | "execution_count": 0,
191 | "outputs": []
192 | },
193 | {
194 | "metadata": {
195 | "id": "0mVlB0W14b4G",
196 | "colab_type": "code",
197 | "colab": {}
198 | },
199 | "cell_type": "code",
200 | "source": [
201 | "lines = f"
202 | ],
203 | "execution_count": 0,
204 | "outputs": []
205 | },
206 | {
207 | "metadata": {
208 | "id": "JouLb6Eo4f28",
209 | "colab_type": "code",
210 | "colab": {}
211 | },
212 | "cell_type": "code",
213 | "source": [
214 | "# sample size (try with smaller sample size to reduce computation)\n",
215 | "num_examples = 30000 \n",
216 | "\n",
217 | "# creates lists containing each pair\n",
218 | "original_word_pairs = [[w for w in l.split('\\t')] for l in lines[:num_examples]]"
219 | ],
220 | "execution_count": 0,
221 | "outputs": []
222 | },
223 | {
224 | "metadata": {
225 | "id": "as2-5vGn4jUa",
226 | "colab_type": "code",
227 | "colab": {}
228 | },
229 | "cell_type": "code",
230 | "source": [
231 | "data = pd.DataFrame(original_word_pairs, columns=[\"eng\", \"es\"])"
232 | ],
233 | "execution_count": 0,
234 | "outputs": []
235 | },
236 | {
237 | "metadata": {
238 | "id": "913VSLih4lY3",
239 | "colab_type": "code",
240 | "colab": {
241 | "base_uri": "https://localhost:8080/",
242 | "height": 204
243 | },
244 | "outputId": "3e2dc267-8f7d-40a0-f4e4-42e4d6d7b09b"
245 | },
246 | "cell_type": "code",
247 | "source": [
248 | "data.head(5)"
249 | ],
250 | "execution_count": 8,
251 | "outputs": [
252 | {
253 | "output_type": "execute_result",
254 | "data": {
255 | "text/html": [
256 | "
\n",
257 | "\n",
270 | "
\n",
271 | " \n",
272 | " \n",
273 | " | \n",
274 | " eng | \n",
275 | " es | \n",
276 | "
\n",
277 | " \n",
278 | " \n",
279 | " \n",
280 | " 0 | \n",
281 | " Go. | \n",
282 | " Ve. | \n",
283 | "
\n",
284 | " \n",
285 | " 1 | \n",
286 | " Go. | \n",
287 | " Vete. | \n",
288 | "
\n",
289 | " \n",
290 | " 2 | \n",
291 | " Go. | \n",
292 | " Vaya. | \n",
293 | "
\n",
294 | " \n",
295 | " 3 | \n",
296 | " Go. | \n",
297 | " Váyase. | \n",
298 | "
\n",
299 | " \n",
300 | " 4 | \n",
301 | " Hi. | \n",
302 | " Hola. | \n",
303 | "
\n",
304 | " \n",
305 | "
\n",
306 | "
"
307 | ],
308 | "text/plain": [
309 | " eng es\n",
310 | "0 Go. Ve.\n",
311 | "1 Go. Vete.\n",
312 | "2 Go. Vaya.\n",
313 | "3 Go. Váyase.\n",
314 | "4 Hi. Hola."
315 | ]
316 | },
317 | "metadata": {
318 | "tags": []
319 | },
320 | "execution_count": 8
321 | }
322 | ]
323 | },
324 | {
325 | "metadata": {
326 | "id": "jCUSf31E4m6t",
327 | "colab_type": "code",
328 | "colab": {}
329 | },
330 | "cell_type": "code",
331 | "source": [
332 | "# Converts the unicode file to ascii\n",
333 | "def unicode_to_ascii(s):\n",
334 | " \"\"\"\n",
335 | " Normalizes latin chars with accent to their canonical decomposition\n",
336 | " \"\"\"\n",
337 | " return ''.join(c for c in unicodedata.normalize('NFD', s)\n",
338 | " if unicodedata.category(c) != 'Mn')\n",
339 | "\n",
340 | "def preprocess_sentence(w):\n",
341 | " w = unicode_to_ascii(w.lower().strip())\n",
342 | " \n",
343 | " # creating a space between a word and the punctuation following it\n",
344 | " # eg: \"he is a boy.\" => \"he is a boy .\" \n",
345 | " # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n",
346 | " w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n",
347 | " w = re.sub(r'[\" \"]+', \" \", w)\n",
348 | " \n",
349 | " # replacing everything with space except (a-z, A-Z, \".\", \"?\", \"!\", \",\")\n",
350 | " w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n",
351 | " \n",
352 | " w = w.rstrip().strip()\n",
353 | " \n",
354 | " # adding a start and an end token to the sentence\n",
355 | " # so that the model know when to start and stop predicting.\n",
356 | " w = ' ' + w + ' '\n",
357 | " return w"
358 | ],
359 | "execution_count": 0,
360 | "outputs": []
361 | },
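362 | {
363 | "metadata": {
364 | "id": "preproc_demo",
365 | "colab_type": "code",
366 | "colab": {}
367 | },
368 | "cell_type": "code",
369 | "source": [
370 | "# Quick sanity check of preprocess_sentence (a sketch, left unexecuted here;\n",
371 | "# the sample sentence is just an illustration):\n",
372 | "print(preprocess_sentence(u\"¿Puedo tomar prestado este libro?\"))\n",
373 | "# expected: <start> ¿ puedo tomar prestado este libro ? <end>\n"
374 | ],
375 | "execution_count": 0,
376 | "outputs": []
377 | },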
362 | {
363 | "metadata": {
364 | "id": "CN2pLaZkNqrv",
365 | "colab_type": "text"
366 | },
367 | "cell_type": "markdown",
368 | "source": [
369 | "## Data Exploration\n",
370 | "Let's explore the dataset a bit."
371 | ]
372 | },
373 | {
374 | "metadata": {
375 | "id": "QFLV4RCR4pXa",
376 | "colab_type": "code",
377 | "colab": {
378 | "base_uri": "https://localhost:8080/",
379 | "height": 359
380 | },
381 | "outputId": "31a441b5-9ab1-4b1b-93c3-2f95ca51684c"
382 | },
383 | "cell_type": "code",
384 | "source": [
385 | "# Now we do the preprocessing using pandas and lambdas\n",
386 | "data[\"eng\"] = data.eng.apply(lambda w: preprocess_sentence(w))\n",
387 | "data[\"es\"] = data.es.apply(lambda w: preprocess_sentence(w))\n",
388 | "data.sample(10)\n"
389 | ],
390 | "execution_count": 10,
391 | "outputs": [
392 | {
393 | "output_type": "execute_result",
394 | "data": {
395 | "text/html": [
396 | "\n",
397 | "\n",
410 | "
\n",
411 | " \n",
412 | " \n",
413 | " | \n",
414 | " eng | \n",
415 | " es | \n",
416 | "
\n",
417 | " \n",
418 | " \n",
419 | " \n",
420 | " 14479 | \n",
421 | " <start> the cat scared me . <end> | \n",
422 | " <start> el gato me espanto . <end> | \n",
423 | "
\n",
424 | " \n",
425 | " 20413 | \n",
426 | " <start> i wanted to see you . <end> | \n",
427 | " <start> queria verte . <end> | \n",
428 | "
\n",
429 | " \n",
430 | " 5675 | \n",
431 | " <start> don t help tom . <end> | \n",
432 | " <start> no ayudes a tom . <end> | \n",
433 | "
\n",
434 | " \n",
435 | " 13393 | \n",
436 | " <start> i have it at home . <end> | \n",
437 | " <start> lo tengo en casa . <end> | \n",
438 | "
\n",
439 | " \n",
440 | " 13404 | \n",
441 | " <start> i have to go home . <end> | \n",
442 | " <start> tengo que ir a casa . <end> | \n",
443 | "
\n",
444 | " \n",
445 | " 18717 | \n",
446 | " <start> where is your room ? <end> | \n",
447 | " <start> ¿ donde esta vuestra habitacion ? <end> | \n",
448 | "
\n",
449 | " \n",
450 | " 6015 | \n",
451 | " <start> i just want it . <end> | \n",
452 | " <start> lo quiero ya . <end> | \n",
453 | "
\n",
454 | " \n",
455 | " 25393 | \n",
456 | " <start> this is so hilarious . <end> | \n",
457 | " <start> esto es tan chistoso . <end> | \n",
458 | "
\n",
459 | " \n",
460 | " 25097 | \n",
461 | " <start> the baby is crawling . <end> | \n",
462 | " <start> el bebe esta gateando . <end> | \n",
463 | "
\n",
464 | " \n",
465 | " 16464 | \n",
466 | " <start> i have a glass eye . <end> | \n",
467 | " <start> tengo un ojo de cristal . <end> | \n",
468 | "
\n",
469 | " \n",
470 | "
\n",
471 | "
"
472 | ],
473 | "text/plain": [
474 | " eng \\\n",
475 | "14479 the cat scared me . \n",
476 | "20413 i wanted to see you . \n",
477 | "5675 don t help tom . \n",
478 | "13393 i have it at home . \n",
479 | "13404 i have to go home . \n",
480 | "18717 where is your room ? \n",
481 | "6015 i just want it . \n",
482 | "25393 this is so hilarious . \n",
483 | "25097 the baby is crawling . \n",
484 | "16464 i have a glass eye . \n",
485 | "\n",
486 | " es \n",
487 | "14479 el gato me espanto . \n",
488 | "20413 queria verte . \n",
489 | "5675 no ayudes a tom . \n",
490 | "13393 lo tengo en casa . \n",
491 | "13404 tengo que ir a casa . \n",
492 | "18717 ¿ donde esta vuestra habitacion ? \n",
493 | "6015 lo quiero ya . \n",
494 | "25393 esto es tan chistoso . \n",
495 | "25097 el bebe esta gateando . \n",
496 | "16464 tengo un ojo de cristal . "
497 | ]
498 | },
499 | "metadata": {
500 | "tags": []
501 | },
502 | "execution_count": 10
503 | }
504 | ]
505 | },
506 | {
507 | "metadata": {
508 | "id": "hqM7ZncM8V9B",
509 | "colab_type": "text"
510 | },
511 | "cell_type": "markdown",
512 | "source": [
513 | "## Building Vocabulary Index\n",
514 | "The class below is useful for creating the vocabular and index mappings which will be used to convert out inputs into indexed sequences. "
515 | ]
516 | },
517 | {
518 | "metadata": {
519 | "id": "2rXA7-N34sok",
520 | "colab_type": "code",
521 | "colab": {}
522 | },
523 | "cell_type": "code",
524 | "source": [
525 | "# This class creates a word -> index mapping (e.g,. \"dad\" -> 5) and vice-versa \n",
526 | "# (e.g., 5 -> \"dad\") for each language,\n",
527 | "class LanguageIndex():\n",
528 | " def __init__(self, lang):\n",
529 | " \"\"\" lang are the list of phrases from each language\"\"\"\n",
530 | " self.lang = lang\n",
531 | " self.word2idx = {}\n",
532 | " self.idx2word = {}\n",
533 | " self.vocab = set()\n",
534 | " \n",
535 | " self.create_index()\n",
536 | " \n",
537 | " def create_index(self):\n",
538 | " for phrase in self.lang:\n",
539 | " # update with individual tokens\n",
540 | " self.vocab.update(phrase.split(' '))\n",
541 | " \n",
542 | " # sort the vocab\n",
543 | " self.vocab = sorted(self.vocab)\n",
544 | "\n",
545 | " # add a padding token with index 0\n",
546 | " self.word2idx[''] = 0\n",
547 | " \n",
548 | " # word to index mapping\n",
549 | " for index, word in enumerate(self.vocab):\n",
550 | " self.word2idx[word] = index + 1 # +1 because of pad token\n",
551 | " \n",
552 | " # index to word mapping\n",
553 | " for word, index in self.word2idx.items():\n",
554 | " self.idx2word[index] = word "
555 | ],
556 | "execution_count": 0,
557 | "outputs": []
558 | },
559 | {
560 | "metadata": {
561 | "id": "Fesymsn34v7z",
562 | "colab_type": "code",
563 | "colab": {
564 | "base_uri": "https://localhost:8080/",
565 | "height": 187
566 | },
567 | "outputId": "43e7a6c8-8505-4dd5-a2d7-4a82efae0ab3"
568 | },
569 | "cell_type": "code",
570 | "source": [
571 | "# index language using the class above\n",
572 | "inp_lang = LanguageIndex(data[\"es\"].values.tolist())\n",
573 | "targ_lang = LanguageIndex(data[\"eng\"].values.tolist())\n",
574 | "# Vectorize the input and target languages\n",
575 | "input_tensor = [[inp_lang.word2idx[s] for s in es.split(' ')] for es in data[\"es\"].values.tolist()]\n",
576 | "target_tensor = [[targ_lang.word2idx[s] for s in eng.split(' ')] for eng in data[\"eng\"].values.tolist()]\n",
577 | "input_tensor[:10]"
578 | ],
579 | "execution_count": 12,
580 | "outputs": [
581 | {
582 | "output_type": "execute_result",
583 | "data": {
584 | "text/plain": [
585 | "[[5, 9090, 3, 4],\n",
586 | " [5, 9204, 3, 4],\n",
587 | " [5, 9082, 3, 4],\n",
588 | " [5, 9089, 3, 4],\n",
589 | " [5, 4702, 3, 4],\n",
590 | " [5, 2299, 1, 4],\n",
591 | " [5, 2304, 3, 4],\n",
592 | " [5, 9413, 7433, 6, 4],\n",
593 | " [5, 4270, 1, 4],\n",
594 | " [5, 4881, 1, 4]]"
595 | ]
596 | },
597 | "metadata": {
598 | "tags": []
599 | },
600 | "execution_count": 12
601 | }
602 | ]
603 | },
604 | {
605 | "metadata": {
606 | "id": "uPordlA-N4qR",
607 | "colab_type": "code",
608 | "colab": {
609 | "base_uri": "https://localhost:8080/",
610 | "height": 187
611 | },
612 | "outputId": "b15fd270-d92c-4e85-cc83-ff9fb478780c"
613 | },
614 | "cell_type": "code",
615 | "source": [
616 | "target_tensor[:10]"
617 | ],
618 | "execution_count": 13,
619 | "outputs": [
620 | {
621 | "output_type": "execute_result",
622 | "data": {
623 | "text/plain": [
624 | "[[5, 1857, 3, 4],\n",
625 | " [5, 1857, 3, 4],\n",
626 | " [5, 1857, 3, 4],\n",
627 | " [5, 1857, 3, 4],\n",
628 | " [5, 2058, 3, 4],\n",
629 | " [5, 3655, 1, 4],\n",
630 | " [5, 3655, 3, 4],\n",
631 | " [5, 4815, 6, 4],\n",
632 | " [5, 1636, 1, 4],\n",
633 | " [5, 1636, 1, 4]]"
634 | ]
635 | },
636 | "metadata": {
637 | "tags": []
638 | },
639 | "execution_count": 13
640 | }
641 | ]
642 | },
643 | {
644 | "metadata": {
645 | "id": "8cwX-0rt4zmN",
646 | "colab_type": "code",
647 | "colab": {}
648 | },
649 | "cell_type": "code",
650 | "source": [
651 | "def max_length(tensor):\n",
652 | " return max(len(t) for t in tensor)"
653 | ],
654 | "execution_count": 0,
655 | "outputs": []
656 | },
657 | {
658 | "metadata": {
659 | "id": "ycYy5gq641Uy",
660 | "colab_type": "code",
661 | "colab": {}
662 | },
663 | "cell_type": "code",
664 | "source": [
665 | "# calculate the max_length of input and output tensor\n",
666 | "max_length_inp, max_length_tar = max_length(input_tensor), max_length(target_tensor)"
667 | ],
668 | "execution_count": 0,
669 | "outputs": []
670 | },
671 | {
672 | "metadata": {
673 | "id": "q05E5IwH42_1",
674 | "colab_type": "code",
675 | "colab": {}
676 | },
677 | "cell_type": "code",
678 | "source": [
679 | "def pad_sequences(x, max_len):\n",
680 | " padded = np.zeros((max_len), dtype=np.int64)\n",
681 | " if len(x) > max_len: padded[:] = x[:max_len]\n",
682 | " else: padded[:len(x)] = x\n",
683 | " return padded"
684 | ],
685 | "execution_count": 0,
686 | "outputs": []
687 | },
688 | {
689 | "metadata": {
690 | "id": "66dJPqzV44jd",
691 | "colab_type": "code",
692 | "colab": {
693 | "base_uri": "https://localhost:8080/",
694 | "height": 34
695 | },
696 | "outputId": "8e5fe7eb-f39d-4365-a9fb-63fdaeff8f34"
697 | },
698 | "cell_type": "code",
699 | "source": [
700 | "# inplace padding\n",
701 | "input_tensor = [pad_sequences(x, max_length_inp) for x in input_tensor]\n",
702 | "target_tensor = [pad_sequences(x, max_length_tar) for x in target_tensor]\n",
703 | "len(target_tensor)"
704 | ],
705 | "execution_count": 17,
706 | "outputs": [
707 | {
708 | "output_type": "execute_result",
709 | "data": {
710 | "text/plain": [
711 | "30000"
712 | ]
713 | },
714 | "metadata": {
715 | "tags": []
716 | },
717 | "execution_count": 17
718 | }
719 | ]
720 | },
721 | {
722 | "metadata": {
723 | "id": "zvatfCWS46T-",
724 | "colab_type": "code",
725 | "colab": {
726 | "base_uri": "https://localhost:8080/",
727 | "height": 34
728 | },
729 | "outputId": "2eeff7ec-2c61-4e8e-d084-ed14251a3bca"
730 | },
731 | "cell_type": "code",
732 | "source": [
733 | "# Creating training and validation sets using an 80-20 split\n",
734 | "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n",
735 | "\n",
736 | "# Show length\n",
737 | "len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)"
738 | ],
739 | "execution_count": 18,
740 | "outputs": [
741 | {
742 | "output_type": "execute_result",
743 | "data": {
744 | "text/plain": [
745 | "(24000, 24000, 6000, 6000)"
746 | ]
747 | },
748 | "metadata": {
749 | "tags": []
750 | },
751 | "execution_count": 18
752 | }
753 | ]
754 | },
755 | {
756 | "metadata": {
757 | "id": "HNFO3obpOsoB",
758 | "colab_type": "text"
759 | },
760 | "cell_type": "markdown",
761 | "source": [
762 | "## Load data into DataLoader for Batching\n",
763 | "This is just preparing the dataset so that it can be efficiently fed into the model through batches."
764 | ]
765 | },
766 | {
767 | "metadata": {
768 | "id": "-QRQKwxf479Q",
769 | "colab_type": "code",
770 | "colab": {}
771 | },
772 | "cell_type": "code",
773 | "source": [
774 | "from torch.utils.data import Dataset, DataLoader"
775 | ],
776 | "execution_count": 0,
777 | "outputs": []
778 | },
779 | {
780 | "metadata": {
781 | "id": "IDSxA4OM5Qlp",
782 | "colab_type": "code",
783 | "colab": {}
784 | },
785 | "cell_type": "code",
786 | "source": [
787 | "# conver the data to tensors and pass to the Dataloader \n",
788 | "# to create an batch iterator\n",
789 | "\n",
790 | "class MyData(Dataset):\n",
791 | " def __init__(self, X, y):\n",
792 | " self.data = X\n",
793 | " self.target = y\n",
794 | " # TODO: convert this into torch code is possible\n",
795 | " self.length = [ np.sum(1 - np.equal(x, 0)) for x in X]\n",
796 | " \n",
797 | " def __getitem__(self, index):\n",
798 | " x = self.data[index]\n",
799 | " y = self.target[index]\n",
800 | " x_len = self.length[index]\n",
801 | " return x,y,x_len\n",
802 | " \n",
803 | " def __len__(self):\n",
804 | " return len(self.data)"
805 | ],
806 | "execution_count": 0,
807 | "outputs": []
808 | },
809 | {
810 | "metadata": {
811 | "id": "D2WukeVF8NVn",
812 | "colab_type": "text"
813 | },
814 | "cell_type": "markdown",
815 | "source": [
816 | "## Parameters\n",
817 | "Let's define the hyperparameters and other things we need for training our NMT model."
818 | ]
819 | },
820 | {
821 | "metadata": {
822 | "id": "s3Be7lOZ5R-d",
823 | "colab_type": "code",
824 | "colab": {}
825 | },
826 | "cell_type": "code",
827 | "source": [
828 | "BUFFER_SIZE = len(input_tensor_train)\n",
829 | "BATCH_SIZE = 64\n",
830 | "N_BATCH = BUFFER_SIZE//BATCH_SIZE\n",
831 | "embedding_dim = 256\n",
832 | "units = 1024\n",
833 | "vocab_inp_size = len(inp_lang.word2idx)\n",
834 | "vocab_tar_size = len(targ_lang.word2idx)\n",
835 | "\n",
836 | "train_dataset = MyData(input_tensor_train, target_tensor_train)\n",
837 | "val_dataset = MyData(input_tensor_val, target_tensor_val)\n",
838 | "\n",
839 | "dataset = DataLoader(train_dataset, batch_size = BATCH_SIZE, \n",
840 | " drop_last=True,\n",
841 | " shuffle=True)"
842 | ],
843 | "execution_count": 0,
844 | "outputs": []
845 | },
846 | {
847 | "metadata": {
848 | "id": "blYXo7pv5TOu",
849 | "colab_type": "code",
850 | "colab": {}
851 | },
852 | "cell_type": "code",
853 | "source": [
854 | "class Encoder(nn.Module):\n",
855 | " def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n",
856 | " super(Encoder, self).__init__()\n",
857 | " self.batch_sz = batch_sz\n",
858 | " self.enc_units = enc_units\n",
859 | " self.vocab_size = vocab_size\n",
860 | " self.embedding_dim = embedding_dim\n",
861 | " self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)\n",
862 | " self.gru = nn.GRU(self.embedding_dim, self.enc_units)\n",
863 | " \n",
864 | " def forward(self, x, lens, device):\n",
865 | " # x: batch_size, max_length \n",
866 | " \n",
867 | " # x: batch_size, max_length, embedding_dim\n",
868 | " x = self.embedding(x) \n",
869 | " \n",
870 | " # x transformed = max_len X batch_size X embedding_dim\n",
871 | " # x = x.permute(1,0,2)\n",
872 | " x = pack_padded_sequence(x, lens) # unpad\n",
873 | " \n",
874 | " self.hidden = self.initialize_hidden_state(device)\n",
875 | " \n",
876 | " # output: max_length, batch_size, enc_units\n",
877 | " # self.hidden: 1, batch_size, enc_units\n",
878 | " output, self.hidden = self.gru(x, self.hidden) # gru returns hidden state of all timesteps as well as hidden state at last timestep\n",
879 | " \n",
880 | " # pad the sequence to the max length in the batch\n",
881 | " output, _ = pad_packed_sequence(output)\n",
882 | " \n",
883 | " return output, self.hidden\n",
884 | "\n",
885 | " def initialize_hidden_state(self, device):\n",
886 | " return torch.zeros((1, self.batch_sz, self.enc_units)).to(device)"
887 | ],
888 | "execution_count": 0,
889 | "outputs": []
890 | },
891 | {
892 | "metadata": {
893 | "id": "SrsQ7dTg5V__",
894 | "colab_type": "code",
895 | "colab": {}
896 | },
897 | "cell_type": "code",
898 | "source": [
899 | "### sort batch function to be able to use with pad_packed_sequence\n",
900 | "def sort_batch(X, y, lengths):\n",
901 | " lengths, indx = lengths.sort(dim=0, descending=True)\n",
902 | " X = X[indx]\n",
903 | " y = y[indx]\n",
904 | " return X.transpose(0,1), y, lengths # transpose (batch x seq) to (seq x batch)"
905 | ],
906 | "execution_count": 0,
907 | "outputs": []
908 | },
909 | {
910 | "metadata": {
911 | "id": "2X1h155CPQ1Y",
912 | "colab_type": "text"
913 | },
914 | "cell_type": "markdown",
915 | "source": [
916 | "## Testing the Encoder\n",
917 | "Before proceeding with training, we should always try to test out model behavior such as the size of outputs just to make that things are going as expected. In PyTorch this can be done easily since everything comes in eager execution by default."
918 | ]
919 | },
920 | {
921 | "metadata": {
922 | "id": "rbSLACY45Xz-",
923 | "colab_type": "code",
924 | "colab": {
925 | "base_uri": "https://localhost:8080/",
926 | "height": 34
927 | },
928 | "outputId": "d6461fcf-eba8-4843-f2d2-e5d77e788696"
929 | },
930 | "cell_type": "code",
931 | "source": [
932 | "### Testing Encoder part\n",
933 | "# TODO: put whether GPU is available or not\n",
934 | "# Device\n",
935 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
936 | "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
937 | "\n",
938 | "encoder.to(device)\n",
939 | "# obtain one sample from the data iterator\n",
940 | "it = iter(dataset)\n",
941 | "x, y, x_len = next(it)\n",
942 | "\n",
943 | "# sort the batch first to be able to use with pac_pack_sequence\n",
944 | "xs, ys, lens = sort_batch(x, y, x_len)\n",
945 | "\n",
946 | "enc_output, enc_hidden = encoder(xs.to(device), lens, device)\n",
947 | "\n",
948 | "print(enc_output.size()) # max_length, batch_size, enc_units"
949 | ],
950 | "execution_count": 24,
951 | "outputs": [
952 | {
953 | "output_type": "stream",
954 | "text": [
955 | "torch.Size([12, 64, 1024])\n"
956 | ],
957 | "name": "stdout"
958 | }
959 | ]
960 | },
961 | {
962 | "metadata": {
963 | "id": "t4djvgil5bMQ",
964 | "colab_type": "code",
965 | "colab": {}
966 | },
967 | "cell_type": "code",
968 | "source": [
969 | "class Decoder(nn.Module):\n",
970 | " def __init__(self, vocab_size, embedding_dim, dec_units, enc_units, batch_sz):\n",
971 | " super(Decoder, self).__init__()\n",
972 | " self.batch_sz = batch_sz\n",
973 | " self.dec_units = dec_units\n",
974 | " self.enc_units = enc_units\n",
975 | " self.vocab_size = vocab_size\n",
976 | " self.embedding_dim = embedding_dim\n",
977 | " self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)\n",
978 | " self.gru = nn.GRU(self.embedding_dim + self.enc_units, \n",
979 | " self.dec_units,\n",
980 | " batch_first=True)\n",
981 | " self.fc = nn.Linear(self.enc_units, self.vocab_size)\n",
982 | " \n",
983 | " # used for attention\n",
984 | " self.W1 = nn.Linear(self.enc_units, self.dec_units)\n",
985 | " self.W2 = nn.Linear(self.enc_units, self.dec_units)\n",
986 | " self.V = nn.Linear(self.enc_units, 1)\n",
987 | " \n",
988 | " def forward(self, x, hidden, enc_output):\n",
989 | " # enc_output original: (max_length, batch_size, enc_units)\n",
990 | " # enc_output converted == (batch_size, max_length, hidden_size)\n",
991 | " enc_output = enc_output.permute(1,0,2)\n",
992 | " # hidden shape == (batch_size, hidden size)\n",
993 | " # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n",
994 | " # we are doing this to perform addition to calculate the score\n",
995 | " \n",
996 | " # hidden shape == (batch_size, hidden size)\n",
997 | " # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n",
998 | " hidden_with_time_axis = hidden.permute(1, 0, 2)\n",
999 | " \n",
1000 | " # score: (batch_size, max_length, hidden_size)\n",
1001 | " score = torch.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))\n",
1002 | " \n",
1003 | " #score = torch.tanh(self.W2(hidden_with_time_axis) + self.W1(enc_output))\n",
1004 | " \n",
1005 | " # attention_weights shape == (batch_size, max_length, 1)\n",
1006 | " # we get 1 at the last axis because we are applying score to self.V\n",
1007 | " attention_weights = torch.softmax(self.V(score), dim=1)\n",
1008 | " \n",
1009 | " # context_vector shape after sum == (batch_size, hidden_size)\n",
1010 | " context_vector = attention_weights * enc_output\n",
1011 | " context_vector = torch.sum(context_vector, dim=1)\n",
1012 | " \n",
1013 | " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n",
1014 | " x = self.embedding(x)\n",
1015 | " \n",
1016 | " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
1017 | " #x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
1018 | " # ? Looks like attention vector in diagram of source\n",
1019 | " x = torch.cat((context_vector.unsqueeze(1), x), -1)\n",
1020 | " \n",
1021 | " # passing the concatenated vector to the GRU\n",
1022 | " # output: (batch_size, 1, hidden_size)\n",
1023 | " output, state = self.gru(x)\n",
1024 | " \n",
1025 | " \n",
1026 | " # output shape == (batch_size * 1, hidden_size)\n",
1027 | " output = output.view(-1, output.size(2))\n",
1028 | " \n",
1029 | " # output shape == (batch_size * 1, vocab)\n",
1030 | " x = self.fc(output)\n",
1031 | " \n",
1032 | " return x, state, attention_weights\n",
1033 | " \n",
1034 | " def initialize_hidden_state(self):\n",
1035 | " return torch.zeros((1, self.batch_sz, self.dec_units))"
1036 | ],
1037 | "execution_count": 0,
1038 | "outputs": []
1039 | },
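1040 | {
1041 | "metadata": {
1042 | "id": "attn_eqs_note",
1043 | "colab_type": "text"
1044 | },
1045 | "cell_type": "markdown",
1046 | "source": [
1047 | "A quick note on the math (my notation, matching the code above): the decoder uses Bahdanau-style additive attention. For encoder outputs $h_s$ and the current decoder hidden state $h_t$,\n",
1048 | "\n",
1049 | "$$\\mathrm{score}(h_t, h_s) = V^\\top \\tanh(W_1 h_s + W_2 h_t), \\qquad \\alpha_{ts} = \\mathrm{softmax}_s(\\mathrm{score}(h_t, h_s)), \\qquad c_t = \\sum_s \\alpha_{ts} h_s$$\n",
1050 | "\n",
1051 | "The context vector $c_t$ is concatenated with the embedded input token before being fed to the GRU, and `fc` projects the GRU output to vocabulary logits."
1052 | ]
1053 | },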
1040 | {
1041 | "metadata": {
1042 | "id": "HsG5We7Sk_UR",
1043 | "colab_type": "text"
1044 | },
1045 | "cell_type": "markdown",
1046 | "source": [
1047 | "## Testing the Decoder\n",
1048 | "Similarily, try to test the decoder."
1049 | ]
1050 | },
1051 | {
1052 | "metadata": {
1053 | "id": "lmipPRVx5fqO",
1054 | "colab_type": "code",
1055 | "colab": {
1056 | "base_uri": "https://localhost:8080/",
1057 | "height": 170
1058 | },
1059 | "outputId": "d9c4cbf1-02b9-4be1-8f6f-340a42bd38ef"
1060 | },
1061 | "cell_type": "code",
1062 | "source": [
1063 | "# Device\n",
1064 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
1065 | "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
1066 | "\n",
1067 | "encoder.to(device)\n",
1068 | "# obtain one sample from the data iterator\n",
1069 | "it = iter(dataset)\n",
1070 | "x, y, x_len = next(it)\n",
1071 | "\n",
1072 | "print(\"Input: \", x.shape)\n",
1073 | "print(\"Output: \", y.shape)\n",
1074 | "\n",
1075 | "# sort the batch first to be able to use with pac_pack_sequence\n",
1076 | "xs, ys, lens = sort_batch(x, y, x_len)\n",
1077 | "\n",
1078 | "enc_output, enc_hidden = encoder(xs.to(device), lens, device)\n",
1079 | "print(\"Encoder Output: \", enc_output.shape) # batch_size X max_length X enc_units\n",
1080 | "print(\"Encoder Hidden: \", enc_hidden.shape) # batch_size X enc_units (corresponds to the last state)\n",
1081 | "\n",
1082 | "decoder = Decoder(vocab_tar_size, embedding_dim, units, units, BATCH_SIZE)\n",
1083 | "decoder = decoder.to(device)\n",
1084 | "\n",
1085 | "#print(enc_hidden.squeeze(0).shape)\n",
1086 | "\n",
1087 | "dec_hidden = enc_hidden#.squeeze(0)\n",
1088 | "dec_input = torch.tensor([[targ_lang.word2idx['']]] * BATCH_SIZE)\n",
1089 | "print(\"Decoder Input: \", dec_input.shape)\n",
1090 | "print(\"--------\")\n",
1091 | "\n",
1092 | "for t in range(1, y.size(1)):\n",
1093 | " # enc_hidden: 1, batch_size, enc_units\n",
1094 | " # output: max_length, batch_size, enc_units\n",
1095 | " predictions, dec_hidden, _ = decoder(dec_input.to(device), \n",
1096 | " dec_hidden.to(device), \n",
1097 | " enc_output.to(device))\n",
1098 | " \n",
1099 | " print(\"Prediction: \", predictions.shape)\n",
1100 | " print(\"Decoder Hidden: \", dec_hidden.shape)\n",
1101 | " \n",
1102 | " #loss += loss_function(y[:, t].to(device), predictions.to(device))\n",
1103 | " \n",
1104 | " dec_input = y[:, t].unsqueeze(1)\n",
1105 | " print(dec_input.shape)\n",
1106 | " break"
1107 | ],
1108 | "execution_count": 26,
1109 | "outputs": [
1110 | {
1111 | "output_type": "stream",
1112 | "text": [
1113 | "Input: torch.Size([64, 16])\n",
1114 | "Output: torch.Size([64, 11])\n",
1115 | "Encoder Output: torch.Size([11, 64, 1024])\n",
1116 | "Encoder Hidden: torch.Size([1, 64, 1024])\n",
1117 | "Decoder Input: torch.Size([64, 1])\n",
1118 | "--------\n",
1119 | "Prediction: torch.Size([64, 4935])\n",
1120 | "Decoder Hidden: torch.Size([1, 64, 1024])\n",
1121 | "torch.Size([64, 1])\n"
1122 | ],
1123 | "name": "stdout"
1124 | }
1125 | ]
1126 | },
1127 | {
1128 | "metadata": {
1129 | "id": "QclyWIop5dRG",
1130 | "colab_type": "code",
1131 | "colab": {}
1132 | },
1133 | "cell_type": "code",
1134 | "source": [
1135 | "criterion = nn.CrossEntropyLoss()\n",
1136 | "\n",
1137 | "def loss_function(real, pred):\n",
1138 | " \"\"\" Only consider non-zero inputs in the loss; mask needed \"\"\"\n",
1139 | " #mask = 1 - np.equal(real, 0) # assign 0 to all above 0 and 1 to all 0s\n",
1140 | " #print(mask)\n",
1141 | " mask = real.ge(1).type(torch.cuda.FloatTensor)\n",
1142 | " \n",
1143 | " loss_ = criterion(pred, real) * mask \n",
1144 | " return torch.mean(loss_)"
1145 | ],
1146 | "execution_count": 0,
1147 | "outputs": []
1148 | },
1149 | {
1150 | "metadata": {
1151 | "id": "LjMMYJv85hVT",
1152 | "colab_type": "code",
1153 | "colab": {}
1154 | },
1155 | "cell_type": "code",
1156 | "source": [
1157 | "# Device\n",
1158 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
1159 | "\n",
1160 | "## TODO: Combine the encoder and decoder into one class\n",
1161 | "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
1162 | "decoder = Decoder(vocab_tar_size, embedding_dim, units, units, BATCH_SIZE)\n",
1163 | "\n",
1164 | "encoder.to(device)\n",
1165 | "decoder.to(device)\n",
1166 | "\n",
1167 | "optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), \n",
1168 | " lr=0.001)"
1169 | ],
1170 | "execution_count": 0,
1171 | "outputs": []
1172 | },
1173 | {
1174 | "metadata": {
1175 | "id": "x6_WoDZM7reU",
1176 | "colab_type": "text"
1177 | },
1178 | "cell_type": "markdown",
1179 | "source": [
1180 | "## Training\n",
1181 | "Now we start the training. We are only using 10 epochs but you can expand this to keep trainining the model for a longer period of time. Note that in this case we are teacher forcing during training. Find a more detailed explanation in the official TensorFlow [implementation](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb) of this notebook provided by the TensorFlow team. "
1182 | ]
1183 | },
1184 | {
1185 | "metadata": {
1186 | "id": "KN8G-3YY8ADm",
1187 | "colab_type": "code",
1188 | "colab": {
1189 | "base_uri": "https://localhost:8080/",
1190 | "height": 1207
1191 | },
1192 | "outputId": "52c3acf4-10be-48c5-b305-4c6d24627517"
1193 | },
1194 | "cell_type": "code",
1195 | "source": [
1196 | "EPOCHS = 10\n",
1197 | "\n",
1198 | "for epoch in range(EPOCHS):\n",
1199 | " start = time.time()\n",
1200 | " \n",
1201 | " encoder.train()\n",
1202 | " decoder.train()\n",
1203 | " \n",
1204 | " total_loss = 0\n",
1205 | " \n",
1206 | " for (batch, (inp, targ, inp_len)) in enumerate(dataset):\n",
1207 | " loss = 0\n",
1208 | " \n",
1209 | " xs, ys, lens = sort_batch(inp, targ, inp_len)\n",
1210 | " enc_output, enc_hidden = encoder(xs.to(device), lens, device)\n",
1211 | " dec_hidden = enc_hidden\n",
1212 | " dec_input = torch.tensor([[targ_lang.word2idx['']]] * BATCH_SIZE)\n",
1213 | " \n",
1214 | " for t in range(1, ys.size(1)):\n",
1215 | " predictions, dec_hidden, _ = decoder(dec_input.to(device), \n",
1216 | " dec_hidden.to(device), \n",
1217 | " enc_output.to(device))\n",
1218 | " loss += loss_function(ys[:, t].to(device), predictions.to(device))\n",
1219 | " #loss += loss_\n",
1220 | " dec_input = ys[:, t].unsqueeze(1)\n",
1221 | " \n",
1222 | " \n",
1223 | " batch_loss = (loss / int(ys.size(1)))\n",
1224 | " total_loss += batch_loss\n",
1225 | " \n",
1226 | " optimizer.zero_grad()\n",
1227 | " \n",
1228 | " loss.backward()\n",
1229 | "\n",
1230 | " ### UPDATE MODEL PARAMETERS\n",
1231 | " optimizer.step()\n",
1232 | " \n",
1233 | " if batch % 100 == 0:\n",
1234 | " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
1235 | " batch,\n",
1236 | " batch_loss.detach().item()))\n",
1237 | " \n",
1238 | " \n",
1239 | " ### TODO: Save checkpoint for model\n",
1240 | " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
1241 | " total_loss / N_BATCH))\n",
1242 | " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))\n",
1243 | " \n",
1244 | " \n",
1245 | " \n",
1246 | " "
1247 | ],
1248 | "execution_count": 29,
1249 | "outputs": [
1250 | {
1251 | "output_type": "stream",
1252 | "text": [
1253 | "Epoch 1 Batch 0 Loss 4.6125\n",
1254 | "Epoch 1 Batch 100 Loss 1.6659\n",
1255 | "Epoch 1 Batch 200 Loss 1.1911\n",
1256 | "Epoch 1 Batch 300 Loss 1.1147\n",
1257 | "Epoch 1 Loss 1.4692\n",
1258 | "Time taken for 1 epoch 51.93851661682129 sec\n",
1259 | "\n",
1260 | "Epoch 2 Batch 0 Loss 0.8044\n",
1261 | "Epoch 2 Batch 100 Loss 0.6910\n",
1262 | "Epoch 2 Batch 200 Loss 0.7195\n",
1263 | "Epoch 2 Batch 300 Loss 0.6796\n",
1264 | "Epoch 2 Loss 0.7119\n",
1265 | "Time taken for 1 epoch 52.01424741744995 sec\n",
1266 | "\n",
1267 | "Epoch 3 Batch 0 Loss 0.3663\n",
1268 | "Epoch 3 Batch 100 Loss 0.3461\n",
1269 | "Epoch 3 Batch 200 Loss 0.4014\n",
1270 | "Epoch 3 Batch 300 Loss 0.3965\n",
1271 | "Epoch 3 Loss 0.3881\n",
1272 | "Time taken for 1 epoch 52.2163941860199 sec\n",
1273 | "\n",
1274 | "Epoch 4 Batch 0 Loss 0.2015\n",
1275 | "Epoch 4 Batch 100 Loss 0.2091\n",
1276 | "Epoch 4 Batch 200 Loss 0.2374\n",
1277 | "Epoch 4 Batch 300 Loss 0.2498\n",
1278 | "Epoch 4 Loss 0.2204\n",
1279 | "Time taken for 1 epoch 52.22896146774292 sec\n",
1280 | "\n",
1281 | "Epoch 5 Batch 0 Loss 0.1061\n",
1282 | "Epoch 5 Batch 100 Loss 0.1226\n",
1283 | "Epoch 5 Batch 200 Loss 0.1127\n",
1284 | "Epoch 5 Batch 300 Loss 0.1286\n",
1285 | "Epoch 5 Loss 0.1395\n",
1286 | "Time taken for 1 epoch 52.19645810127258 sec\n",
1287 | "\n",
1288 | "Epoch 6 Batch 0 Loss 0.1089\n",
1289 | "Epoch 6 Batch 100 Loss 0.1074\n",
1290 | "Epoch 6 Batch 200 Loss 0.1066\n",
1291 | "Epoch 6 Batch 300 Loss 0.1371\n",
1292 | "Epoch 6 Loss 0.1025\n",
1293 | "Time taken for 1 epoch 52.23947715759277 sec\n",
1294 | "\n",
1295 | "Epoch 7 Batch 0 Loss 0.0439\n",
1296 | "Epoch 7 Batch 100 Loss 0.0801\n",
1297 | "Epoch 7 Batch 200 Loss 0.0868\n",
1298 | "Epoch 7 Batch 300 Loss 0.0746\n",
1299 | "Epoch 7 Loss 0.0832\n",
1300 | "Time taken for 1 epoch 52.218220233917236 sec\n",
1301 | "\n",
1302 | "Epoch 8 Batch 0 Loss 0.0718\n",
1303 | "Epoch 8 Batch 100 Loss 0.0752\n",
1304 | "Epoch 8 Batch 200 Loss 0.0482\n",
1305 | "Epoch 8 Batch 300 Loss 0.1020\n",
1306 | "Epoch 8 Loss 0.0712\n",
1307 | "Time taken for 1 epoch 52.18943977355957 sec\n",
1308 | "\n",
1309 | "Epoch 9 Batch 0 Loss 0.0623\n",
1310 | "Epoch 9 Batch 100 Loss 0.0690\n",
1311 | "Epoch 9 Batch 200 Loss 0.0778\n",
1312 | "Epoch 9 Batch 300 Loss 0.0705\n",
1313 | "Epoch 9 Loss 0.0650\n",
1314 | "Time taken for 1 epoch 52.27435898780823 sec\n",
1315 | "\n",
1316 | "Epoch 10 Batch 0 Loss 0.0552\n",
1317 | "Epoch 10 Batch 100 Loss 0.0621\n",
1318 | "Epoch 10 Batch 200 Loss 0.0684\n",
1319 | "Epoch 10 Batch 300 Loss 0.0439\n",
1320 | "Epoch 10 Loss 0.0635\n",
1321 | "Time taken for 1 epoch 52.313509464263916 sec\n",
1322 | "\n"
1323 | ],
1324 | "name": "stdout"
1325 | }
1326 | ]
1327 | },
1328 | {
1329 | "metadata": {
1330 | "id": "_tF5jMP0-cmv",
1331 | "colab_type": "text"
1332 | },
1333 | "cell_type": "markdown",
1334 | "source": [
1335 | "## Final Words\n",
1336 | "Notice that we only trained the model and that's it. In fact, this notebook is in experimental phase, so there could also be some bugs or something I missed during the process of converting code or training. Please comment your concerns here or submit it as an issue in the GitHub version of this notebook. I will appreciate it!\n",
1337 | "\n",
1338 | "We didn't evaluate the model or analyzed it. To encourage you to practice what you have learned in the notebook, I will suggest that you try to convert the TensorFlow code used in the [original notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb) and complete this notebook. I believe the code should be straightforward, the hard part was already done in this notebook. If you manage to complete it, please submit a PR on the GitHub version of this notebook. I will gladly accept your PR. Thanks for reading and hope this notebook was useful. Keep tuned for notebooks like this on my Twitter ([omarsar0](https://twitter.com/omarsar0)). "
1339 | ]
1340 | },
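1341 | {
1342 | "metadata": {
1343 | "id": "greedy_decode_sketch",
1344 | "colab_type": "code",
1345 | "colab": {}
1346 | },
1347 | "cell_type": "code",
1348 | "source": [
1349 | "# A rough, unexecuted sketch of greedy decoding to get you started on the\n",
1350 | "# evaluation part (an editor's addition, not part of the original exercise).\n",
1351 | "# It reuses the objects defined above (encoder, decoder, inp_lang, targ_lang,\n",
1352 | "# max_length_inp, max_length_tar, device, BATCH_SIZE) and assumes every word\n",
1353 | "# of the input appears in the training vocabulary and that the sentence has\n",
1354 | "# at most max_length_inp tokens after preprocessing. Because the Encoder\n",
1355 | "# initializes its hidden state with BATCH_SIZE, we replicate the sentence\n",
1356 | "# across the whole batch and read off row 0.\n",
1357 | "def translate(sentence):\n",
1358 | "    encoder.eval()\n",
1359 | "    decoder.eval()\n",
1360 | "    with torch.no_grad():\n",
1361 | "        s = preprocess_sentence(sentence)\n",
1362 | "        idxs = [inp_lang.word2idx[w] for w in s.split(' ')]\n",
1363 | "        x = torch.tensor(pad_sequences(idxs, max_length_inp))\n",
1364 | "        x = x.unsqueeze(1).repeat(1, BATCH_SIZE) # (max_length, batch_size)\n",
1365 | "        lens = torch.tensor([len(idxs)] * BATCH_SIZE)\n",
1366 | "        enc_output, enc_hidden = encoder(x.to(device), lens, device)\n",
1367 | "        dec_hidden = enc_hidden\n",
1368 | "        dec_input = torch.tensor([[targ_lang.word2idx['<start>']]] * BATCH_SIZE)\n",
1369 | "        result = []\n",
1370 | "        for t in range(max_length_tar):\n",
1371 | "            predictions, dec_hidden, _ = decoder(dec_input.to(device),\n",
1372 | "                                                 dec_hidden.to(device),\n",
1373 | "                                                 enc_output.to(device))\n",
1374 | "            word = targ_lang.idx2word[predictions[0].argmax().item()]\n",
1375 | "            if word == '<end>':\n",
1376 | "                break\n",
1377 | "            result.append(word)\n",
1378 | "            # greedy decoding: feed the argmax token back in as the next input\n",
1379 | "            dec_input = predictions.argmax(dim=1).unsqueeze(1)\n",
1380 | "        return ' '.join(result)\n",
1381 | "\n",
1382 | "# e.g. translate(u'esta es mi vida .')\n"
1383 | ],
1384 | "execution_count": 0,
1385 | "outputs": []
1386 | },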
1341 | {
1342 | "metadata": {
1343 | "id": "cl4ZgMd-KyTU",
1344 | "colab_type": "text"
1345 | },
1346 | "cell_type": "markdown",
1347 | "source": [
1348 | "## References\n",
1349 | "\n",
1350 | "### Seq2Seq:\n",
1351 | " - Sutskever et al. (2014) - [Sequence to Sequence Learning with Neural Networks](Sequence to Sequence Learning with Neural Networks)\n",
1352 | " - [Sequence to sequence model: Introduction and concepts](https://towardsdatascience.com/sequence-to-sequence-model-introduction-and-concepts-44d9b41cd42d)\n",
1353 | " - [Blog on seq2seq](https://guillaumegenthial.github.io/sequence-to-sequence.html)\n",
1354 | " - [Bahdanau et al. (2016) NMT jointly learning to align and translate](https://arxiv.org/pdf/1409.0473.pdf)\n",
1355 | " - [Attention is all you need](https://arxiv.org/pdf/1706.03762.pdf)"
1356 | ]
1357 | }
1358 | ]
1359 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Neural Machine Translation with Attention Using PyTorch
2 | In this notebook we are going to perform machine translation using a deep learning-based approach with an attention mechanism. All the code is based on PyTorch and was adapted from the tutorial provided in the official documentation of [TensorFlow](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).
3 |
4 | Specifically, we are going to train a sequence-to-sequence (seq2seq) model for Spanish-to-English translation. If you are not familiar with seq2seq models, I have provided some references at the end of this tutorial to familiarize yourself with the concept. Even if you are not, you can still proceed with the coding exercise; I will explain the small but important details as we go.
5 |
6 | The tutorial is deliberately brief, and I encourage you to also take a look at the official TensorFlow [notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb) for more detailed explanations. The purpose of this tutorial is to show how to convert code blocks from one deep learning framework to another, in this case PyTorch. You will soon realize just how similar the frameworks are. The data preparation part differs the most, so I would encourage you to spend extra time analyzing that part of the code.
7 |
8 | [Colab Notebook](https://colab.research.google.com/drive/1uFJBO1pgsiFwCGIJwZlhUzaJ2srDbtw-)
9 |
10 | If you have questions you can also reach out to me at ellfae@gmail.com or on Twitter ([@omarsar0](https://twitter.com/omarsar0)). A screenshot of the notebook is shown below:
11 |
12 | 
13 |
--------------------------------------------------------------------------------
/nmt_pytorch.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarsar/pytorch_neural_machine_translation_attention/1678f0b42512509840da228b0357a20b7860f900/nmt_pytorch.gif
--------------------------------------------------------------------------------