├── .gitignore
├── 1 - Sequence to Sequence Learning with Neural Networks.ipynb
├── 2 - Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation.ipynb
├── 3 - Neural Machine Translation by Jointly Learning to Align and Translate.ipynb
├── 4 - Packed Padded Sequences, Masking, Inference and BLEU.ipynb
├── 5 - Convolutional Sequence to Sequence Learning.ipynb
├── 6 - Attention is All You Need.ipynb
├── LICENSE
├── README.md
└── assets
    ├── convseq2seq0.png
    ├── convseq2seq0.xml
    ├── convseq2seq1.png
    ├── convseq2seq1.xml
    ├── convseq2seq2.png
    ├── convseq2seq2.xml
    ├── convseq2seq3.png
    ├── convseq2seq3.xml
    ├── convseq2seq4.png
    ├── convseq2seq4.xml
    ├── convseq2seq5.png
    ├── convseq2seq5.xml
    ├── seq2seq1.png
    ├── seq2seq1.xml
    ├── seq2seq10.png
    ├── seq2seq10.xml
    ├── seq2seq2.png
    ├── seq2seq2.xml
    ├── seq2seq3.png
    ├── seq2seq3.xml
    ├── seq2seq4.png
    ├── seq2seq4.xml
    ├── seq2seq5.png
    ├── seq2seq5.xml
    ├── seq2seq6.png
    ├── seq2seq6.xml
    ├── seq2seq7.png
    ├── seq2seq7.xml
    ├── seq2seq8.png
    ├── seq2seq8.xml
    ├── seq2seq9.png
    ├── seq2seq9.xml
    ├── transformer-attention.png
    ├── transformer-decoder.png
    ├── transformer-encoder.png
    └── transformer1.png
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | #data 107 | *.pt 108 | .data/ 109 | data/ 110 | models/ 111 | .vector_cache/ 112 | 113 | .vscode/ 114 | -------------------------------------------------------------------------------- /1 - Sequence to Sequence Learning with Neural Networks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 1 - Sequence to Sequence Learning with Neural Networks\n", 8 | "\n", 9 | "In this series, we'll be building a machine learning model to go from one sequence to another, using PyTorch and torchtext. This will be done on German to English translations, but the models can be applied to any problem that involves going from one sequence to another, such as summarization, i.e.
going from a sequence to a shorter sequence in the same language.\n", 10 | "\n", 11 | "In this first notebook, we'll start simple to understand the general concepts by implementing the model from the [Sequence to Sequence Learning with Neural Networks](https://arxiv.org/abs/1409.3215) paper. \n", 12 | "\n", 13 | "## Introduction\n", 14 | "\n", 15 | "The most common sequence-to-sequence (seq2seq) models are *encoder-decoder* models, which commonly use a *recurrent neural network* (RNN) to *encode* the source (input) sentence into a single vector. In this notebook, we'll refer to this single vector as a *context vector*. We can think of the context vector as being an abstract representation of the entire input sentence. This vector is then *decoded* by a second RNN which learns to output the target (output) sentence by generating it one word at a time.\n", 16 | "\n", 17 | "![](assets/seq2seq1.png)\n", 18 | "\n", 19 | "The above image shows an example translation. The input/source sentence, \"guten morgen\", is passed through the embedding layer (yellow) and then input into the encoder (green). We also add a *start of sequence* (`<sos>`) and *end of sequence* (`<eos>`) token to the start and end of the sentence, respectively. At each time-step, the encoder RNN takes as input both the embedding, $e$, of the current word, $e(x_t)$, and the hidden state from the previous time-step, $h_{t-1}$, and outputs a new hidden state, $h_t$. We can think of the hidden state as a vector representation of the sentence so far. The RNN can be represented as a function of both $e(x_t)$ and $h_{t-1}$:\n", 20 | "\n", 21 | "$$h_t = \\text{EncoderRNN}(e(x_t), h_{t-1})$$\n", 22 | "\n", 23 | "We're using the term RNN generally here; it could be any recurrent architecture, such as an *LSTM* (Long Short-Term Memory) or a *GRU* (Gated Recurrent Unit). \n", 24 | "\n", 25 | "Here, we have $X = \\{x_1, x_2, ..., x_T\\}$, where $x_1 = \\text{<sos>}, x_2 = \\text{guten}$, etc.
The initial hidden state, $h_0$, is usually either initialized to zeros or is a learned parameter.\n", 26 | "\n", 27 | "Once the final word, $x_T$, has been passed into the RNN via the embedding layer, we use the final hidden state, $h_T$, as the context vector, i.e. $h_T = z$. This is a vector representation of the entire source sentence.\n", 28 | "\n", 29 | "Now that we have our context vector, $z$, we can start decoding it to get the output/target sentence, \"good morning\". Again, we add start and end of sequence tokens to the target sentence. At each time-step, the input to the decoder RNN (blue) is the embedding, $d$, of the current word, $d(y_t)$, as well as the hidden state from the previous time-step, $s_{t-1}$, where the initial decoder hidden state, $s_0$, is the context vector, $s_0 = z = h_T$, i.e. the initial decoder hidden state is the final encoder hidden state. Thus, similar to the encoder, we can represent the decoder as:\n", 30 | "\n", 31 | "$$s_t = \\text{DecoderRNN}(d(y_t), s_{t-1})$$\n", 32 | "\n", 33 | "Although the input/source embedding layer, $e$, and the output/target embedding layer, $d$, are both shown in yellow in the diagram, they are two different embedding layers with their own parameters.\n", 34 | "\n", 35 | "In the decoder, we need to go from the hidden state to an actual word; therefore, at each time-step we use $s_t$ to predict (by passing it through a `Linear` layer, shown in purple) what we think is the next word in the sequence, $\\hat{y}_t$. \n", 36 | "\n", 37 | "$$\\hat{y}_t = f(s_t)$$\n", 38 | "\n", 39 | "The words in the decoder are always generated one after another, with one per time-step. We always use `<sos>` for the first input to the decoder, $y_1$, but for subsequent inputs, $y_{t>1}$, we will sometimes use the actual, ground truth next word in the sequence, $y_t$, and sometimes use the word predicted by our decoder, $\\hat{y}_{t-1}$.
This is called *teacher forcing*; you can read a bit more about it [here](https://machinelearningmastery.com/teacher-forcing-for-recurrent-neural-networks/). \n", 40 | "\n", 41 | "When training/testing our model, we always know how many words are in our target sentence, so we stop generating words once we hit that many. During inference, it is common to keep generating words until the model outputs an `<eos>` token or until a certain number of words have been generated.\n", 42 | "\n", 43 | "Once we have our predicted target sentence, $\\hat{Y} = \\{ \\hat{y}_1, \\hat{y}_2, ..., \\hat{y}_T \\}$, we compare it against our actual target sentence, $Y = \\{ y_1, y_2, ..., y_T \\}$, to calculate our loss. We then use this loss to update all of the parameters in our model.\n", 44 | "\n", 45 | "## Preparing Data\n", 46 | "\n", 47 | "We'll be coding up the models in PyTorch and using torchtext to help us do all of the pre-processing required. We'll also be using spaCy to assist in the tokenization of the data." 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 1, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "import torch\n", 57 | "import torch.nn as nn\n", 58 | "import torch.optim as optim\n", 59 | "\n", 60 | "from torchtext.legacy.datasets import Multi30k\n", 61 | "from torchtext.legacy.data import Field, BucketIterator\n", 62 | "\n", 63 | "import spacy\n", 64 | "import numpy as np\n", 65 | "\n", 66 | "import random\n", 67 | "import math\n", 68 | "import time" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "We'll set the random seeds for deterministic results."
76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 2, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "SEED = 1234\n", 85 | "\n", 86 | "random.seed(SEED)\n", 87 | "np.random.seed(SEED)\n", 88 | "torch.manual_seed(SEED)\n", 89 | "torch.cuda.manual_seed(SEED)\n", 90 | "torch.backends.cudnn.deterministic = True" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "Next, we'll create the tokenizers. A tokenizer is used to turn a string containing a sentence into a list of individual tokens that make up that string, e.g. \"good morning!\" becomes [\"good\", \"morning\", \"!\"]. From now on, we'll talk about sentences as sequences of tokens instead of sequences of words. What's the difference? Well, \"good\" and \"morning\" are both words and tokens, but \"!\" is a token, not a word. \n", 98 | "\n", 99 | "spaCy has a model for each language (\"de_core_news_sm\" for German and \"en_core_web_sm\" for English), and each model needs to be loaded so we can access its tokenizer. \n", 100 | "\n", 101 | "**Note**: the models must first be downloaded using the following on the command line: \n", 102 | "```\n", 103 | "python -m spacy download en_core_web_sm\n", 104 | "python -m spacy download de_core_news_sm\n", 105 | "```\n", 106 | "\n", 107 | "We load the models as such:" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 3, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "spacy_de = spacy.load('de_core_news_sm')\n", 117 | "spacy_en = spacy.load('en_core_web_sm')" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "Next, we create the tokenizer functions.
These can be passed to torchtext and will take in the sentence as a string and return the sentence as a list of tokens.\n", 125 | "\n", 126 | "In the paper we are implementing, they find it beneficial to reverse the order of the input which they believe \"introduces many short term dependencies in the data that make the optimization problem much easier\". We copy this by reversing the German sentence after it has been transformed into a list of tokens." 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 4, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "def tokenize_de(text):\n", 136 | " \"\"\"\n", 137 | " Tokenizes German text from a string into a list of strings (tokens) and reverses it\n", 138 | " \"\"\"\n", 139 | " return [tok.text for tok in spacy_de.tokenizer(text)][::-1]\n", 140 | "\n", 141 | "def tokenize_en(text):\n", 142 | " \"\"\"\n", 143 | " Tokenizes English text from a string into a list of strings (tokens)\n", 144 | " \"\"\"\n", 145 | " return [tok.text for tok in spacy_en.tokenizer(text)]" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "torchtext's `Field`s handle how data should be processed. All of the possible arguments are detailed [here](https://github.com/pytorch/text/blob/master/torchtext/data/field.py#L61). \n", 153 | "\n", 154 | "We set the `tokenize` argument to the correct tokenization function for each, with German being the `SRC` (source) field and English being the `TRG` (target) field. The field also appends the \"start of sequence\" and \"end of sequence\" tokens via the `init_token` and `eos_token` arguments, and converts all words to lowercase." 
155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 5, 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "name": "stderr", 164 | "output_type": "stream", 165 | "text": [ 166 | "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 167 | " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "SRC = Field(tokenize = tokenize_de, \n", 173 | " init_token = '<sos>', \n", 174 | " eos_token = '<eos>', \n", 175 | " lower = True)\n", 176 | "\n", 177 | "TRG = Field(tokenize = tokenize_en, \n", 178 | " init_token = '<sos>', \n", 179 | " eos_token = '<eos>', \n", 180 | " lower = True)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "Next, we download and load the train, validation and test data. \n", 188 | "\n", 189 | "The dataset we'll be using is the [Multi30k dataset](https://github.com/multi30k/dataset). This is a dataset with ~30,000 parallel English, German and French sentences, each with ~12 words per sentence. \n", 190 | "\n", 191 | "`exts` specifies which languages to use as the source and target (source goes first) and `fields` specifies which field to use for the source and target."
192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 6, 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "name": "stderr", 201 | "output_type": "stream", 202 | "text": [ 203 | "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 204 | " warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), \n", 210 | " fields = (SRC, TRG))" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "We can double check that we've loaded the right number of examples:" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 7, 223 | "metadata": {}, 224 | "outputs": [ 225 | { 226 | "name": "stdout", 227 | "output_type": "stream", 228 | "text": [ 229 | "Number of training examples: 29000\n", 230 | "Number of validation examples: 1014\n", 231 | "Number of testing examples: 1000\n" 232 | ] 233 | } 234 | ], 235 | "source": [ 236 | "print(f\"Number of training examples: {len(train_data.examples)}\")\n", 237 | "print(f\"Number of validation examples: {len(valid_data.examples)}\")\n", 238 | "print(f\"Number of testing examples: {len(test_data.examples)}\")" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "We can also print out an example, making sure the source sentence is reversed:" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 8, 251 | "metadata": {}, 252 | "outputs": [ 253 | { 254 | "name": "stdout", 255 | "output_type": "stream", 256 | "text": [ 
257 | "{'src': ['.', 'büsche', 'vieler', 'nähe', 'der', 'in', 'freien', 'im', 'sind', 'männer', 'weiße', 'junge', 'zwei'], 'trg': ['two', 'young', ',', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.']}\n" 258 | ] 259 | } 260 | ], 261 | "source": [ 262 | "print(vars(train_data.examples[0]))" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "The period is at the beginning of the German (src) sentence, so it looks like the sentence has been correctly reversed.\n", 270 | "\n", 271 | "Next, we'll build the *vocabulary* for the source and target languages. The vocabulary is used to associate each unique token with an index (an integer). The vocabularies of the source and target languages are distinct.\n", 272 | "\n", 273 | "Using the `min_freq` argument, we only allow tokens that appear at least 2 times to appear in our vocabulary. Tokens that appear only once are converted into an `<unk>` (unknown) token.\n", 274 | "\n", 275 | "It is important to note that our vocabulary should only be built from the training set and not the validation/test set. This prevents \"information leakage\" into our model, which would give us artificially inflated validation/test scores."
276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 9, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "SRC.build_vocab(train_data, min_freq = 2)\n", 285 | "TRG.build_vocab(train_data, min_freq = 2)" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 10, 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "name": "stdout", 295 | "output_type": "stream", 296 | "text": [ 297 | "Unique tokens in source (de) vocabulary: 7853\n", 298 | "Unique tokens in target (en) vocabulary: 5893\n" 299 | ] 300 | } 301 | ], 302 | "source": [ 303 | "print(f\"Unique tokens in source (de) vocabulary: {len(SRC.vocab)}\")\n", 304 | "print(f\"Unique tokens in target (en) vocabulary: {len(TRG.vocab)}\")" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "metadata": {}, 310 | "source": [ 311 | "The final step of preparing the data is to create the iterators. These can be iterated on to return a batch of data which will have a `src` attribute (a PyTorch tensor containing a batch of numericalized source sentences) and a `trg` attribute (a PyTorch tensor containing a batch of numericalized target sentences). Numericalized is just a fancy way of saying they have been converted from a sequence of readable tokens to a sequence of corresponding indexes, using the vocabulary. \n", 312 | "\n", 313 | "We also need to define a `torch.device`. This is used to tell torchtext whether or not to put the tensors on the GPU. We use the `torch.cuda.is_available()` function, which will return `True` if a GPU is detected on our computer. We pass this `device` to the iterator.\n", 314 | "\n", 315 | "When we get a batch of examples using an iterator, we need to make sure that all of the source sentences are padded to the same length, and likewise for the target sentences. Luckily, torchtext iterators handle this for us!
\n", 316 | "\n", 317 | "We use a `BucketIterator` instead of the standard `Iterator` as it creates batches in a way that minimizes the amount of padding in both the source and target sentences. " 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 11, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 12, 332 | "metadata": {}, 333 | "outputs": [ 334 | { 335 | "name": "stderr", 336 | "output_type": "stream", 337 | "text": [ 338 | "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 339 | " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" 340 | ] 341 | } 342 | ], 343 | "source": [ 344 | "BATCH_SIZE = 128\n", 345 | "\n", 346 | "train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n", 347 | " (train_data, valid_data, test_data), \n", 348 | " batch_size = BATCH_SIZE, \n", 349 | " device = device)" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": {}, 355 | "source": [ 356 | "## Building the Seq2Seq Model\n", 357 | "\n", 358 | "We'll be building our model in three parts: the encoder, the decoder, and a seq2seq model that encapsulates the encoder and decoder and provides a way to interface with each.\n", 359 | "\n", 360 | "### Encoder\n", 361 | "\n", 362 | "First, the encoder, a 2-layer LSTM. The paper we are implementing uses a 4-layer LSTM, but in the interest of training time we cut this down to 2 layers.
Extending a multi-layer RNN from 2 to 4 layers is straightforward. \n", 363 | "\n", 364 | "For a multi-layer RNN, the input sentence, $X$, after being embedded, goes into the first (bottom) layer of the RNN, and the hidden states, $H=\\{h_1, h_2, ..., h_T\\}$, output by this layer are used as inputs to the RNN in the layer above. Thus, representing each layer with a superscript, the hidden states in the first layer are given by:\n", 365 | "\n", 366 | "$$h_t^1 = \\text{EncoderRNN}^1(e(x_t), h_{t-1}^1)$$\n", 367 | "\n", 368 | "The hidden states in the second layer are given by:\n", 369 | "\n", 370 | "$$h_t^2 = \\text{EncoderRNN}^2(h_t^1, h_{t-1}^2)$$\n", 371 | "\n", 372 | "Using a multi-layer RNN also means we'll need an initial hidden state as input per layer, $h_0^l$, and we will also output a context vector per layer, $z^l$.\n", 373 | "\n", 374 | "Without going into too much detail about LSTMs (see [this](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) blog post to learn more about them), all we need to know is that they're a type of RNN which, instead of just taking in a hidden state and returning a new hidden state per time-step, also takes in and returns a *cell state*, $c_t$, per time-step.\n", 375 | "\n", 376 | "$$\\begin{align*}\n", 377 | "h_t &= \\text{RNN}(e(x_t), h_{t-1})\\\\\n", 378 | "(h_t, c_t) &= \\text{LSTM}(e(x_t), (h_{t-1}, c_{t-1}))\n", 379 | "\\end{align*}$$\n", 380 | "\n", 381 | "We can just think of $c_t$ as another type of hidden state. Similar to $h_0^l$, $c_0^l$ will be initialized to a tensor of all zeros. Also, our context vector will now be both the final hidden state and the final cell state, i.e.
$z^l = (h_T^l, c_T^l)$.\n", 382 | "\n", 383 | "Extending our multi-layer equations to LSTMs, we get:\n", 384 | "\n", 385 | "$$\\begin{align*}\n", 386 | "(h_t^1, c_t^1) &= \\text{EncoderLSTM}^1(e(x_t), (h_{t-1}^1, c_{t-1}^1))\\\\\n", 387 | "(h_t^2, c_t^2) &= \\text{EncoderLSTM}^2(h_t^1, (h_{t-1}^2, c_{t-1}^2))\n", 388 | "\\end{align*}$$\n", 389 | "\n", 390 | "Note how only the hidden state from the first layer is passed as input to the second layer, not the cell state.\n", 391 | "\n", 392 | "So our encoder looks something like this: \n", 393 | "\n", 394 | "![](assets/seq2seq2.png)\n", 395 | "\n", 396 | "We create this in code by making an `Encoder` module, which requires that we inherit from `torch.nn.Module` and call `super().__init__()` as boilerplate. The encoder takes the following arguments:\n", 397 | "- `input_dim` is the size/dimensionality of the one-hot vectors that will be input to the encoder. This is equal to the input (source) vocabulary size.\n", 398 | "- `emb_dim` is the dimensionality of the embedding layer. This layer converts the one-hot vectors into dense vectors with `emb_dim` dimensions. \n", 399 | "- `hid_dim` is the dimensionality of the hidden and cell states.\n", 400 | "- `n_layers` is the number of layers in the RNN.\n", 401 | "- `dropout` is the amount of dropout to use. This is a regularization parameter to prevent overfitting. Check out [this](https://www.coursera.org/lecture/deep-neural-network/understanding-dropout-YaGbR) for more details about dropout.\n", 402 | "\n", 403 | "We aren't going to discuss the embedding layer in detail during these tutorials. All we need to know is that there is a step before the words - technically, the indexes of the words - are passed into the RNN, where the words are transformed into vectors.
To read more about word embeddings, check these articles: [1](https://monkeylearn.com/blog/word-embeddings-transform-text-numbers/), [2](http://p.migdal.pl/2017/01/06/king-man-woman-queen-why.html), [3](http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/), [4](http://mccormickml.com/2017/01/11/word2vec-tutorial-part-2-negative-sampling/). \n", 404 | "\n", 405 | "The embedding layer is created using `nn.Embedding`, the LSTM with `nn.LSTM`, and a dropout layer with `nn.Dropout`. Check the PyTorch [documentation](https://pytorch.org/docs/stable/nn.html) for more about these.\n", 406 | "\n", 407 | "One thing to note is that the `dropout` argument to the LSTM is how much dropout to apply between the layers of a multi-layer RNN, i.e. between the hidden states output from layer $l$ and those same hidden states being used for the input of layer $l+1$.\n", 408 | "\n", 409 | "In the `forward` method, we pass in the source sentence, $X$, which is converted into dense vectors using the `embedding` layer, and then dropout is applied. These embeddings are then passed into the RNN. As we pass a whole sequence to the RNN, it will automatically do the recurrent calculation of the hidden states over the whole sequence for us! Notice that we do not pass an initial hidden or cell state to the RNN. This is because, as noted in the [documentation](https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM), if no hidden/cell state is passed to the RNN, it will automatically create an initial hidden/cell state as a tensor of all zeros. \n", 410 | "\n", 411 | "The RNN returns: `outputs` (the top-layer hidden state for each time-step), `hidden` (the final hidden state for each layer, $h_T$, stacked on top of each other) and `cell` (the final cell state for each layer, $c_T$, stacked on top of each other).\n", 412 | "\n", 413 | "As we only need the final hidden and cell states (to make our context vector), `forward` only returns `hidden` and `cell`.
\n", 414 | "\n", 415 | "The size of each tensor is left as a comment in the code. In this implementation, `n_directions` will always be 1; however, note that bidirectional RNNs (covered in tutorial 3) will have `n_directions` as 2." 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 13, 421 | "metadata": {}, 422 | "outputs": [], 423 | "source": [ 424 | "class Encoder(nn.Module):\n", 425 | " def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):\n", 426 | " super().__init__()\n", 427 | " \n", 428 | " self.hid_dim = hid_dim\n", 429 | " self.n_layers = n_layers\n", 430 | " \n", 431 | " self.embedding = nn.Embedding(input_dim, emb_dim)\n", 432 | " \n", 433 | " self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)\n", 434 | " \n", 435 | " self.dropout = nn.Dropout(dropout)\n", 436 | " \n", 437 | " def forward(self, src):\n", 438 | " \n", 439 | " #src = [src len, batch size]\n", 440 | " \n", 441 | " embedded = self.dropout(self.embedding(src))\n", 442 | " \n", 443 | " #embedded = [src len, batch size, emb dim]\n", 444 | " \n", 445 | " outputs, (hidden, cell) = self.rnn(embedded)\n", 446 | " \n", 447 | " #outputs = [src len, batch size, hid dim * n directions]\n", 448 | " #hidden = [n layers * n directions, batch size, hid dim]\n", 449 | " #cell = [n layers * n directions, batch size, hid dim]\n", 450 | " \n", 451 | " #outputs are always from the top hidden layer\n", 452 | " \n", 453 | " return hidden, cell" 454 | ] 455 | }, 456 | { 457 | "cell_type": "markdown", 458 | "metadata": {}, 459 | "source": [ 460 | "### Decoder\n", 461 | "\n", 462 | "Next, we'll build our decoder, which will also be a 2-layer (4 in the paper) LSTM.\n", 463 | "\n", 464 | "![](assets/seq2seq3.png)\n", 465 | "\n", 466 | "The `Decoder` class does a single step of decoding, i.e. it outputs a single token per time-step.
The first layer receives the hidden and cell states from the previous time-step, $(s_{t-1}^1, c_{t-1}^1)$, and feeds them through the LSTM with the current embedded token, $d(y_t)$, to produce new hidden and cell states, $(s_t^1, c_t^1)$. The subsequent layers will use the hidden state from the layer below, $s_t^{l-1}$, and the previous hidden and cell states from their layer, $(s_{t-1}^l, c_{t-1}^l)$. This provides equations very similar to those in the encoder.\n", 467 | "\n", 468 | "$$\\begin{align*}\n", 469 | "(s_t^1, c_t^1) &= \\text{DecoderLSTM}^1(d(y_t), (s_{t-1}^1, c_{t-1}^1))\\\\\n", 470 | "(s_t^2, c_t^2) &= \\text{DecoderLSTM}^2(s_t^1, (s_{t-1}^2, c_{t-1}^2))\n", 471 | "\\end{align*}$$\n", 472 | "\n", 473 | "Remember that the initial hidden and cell states of our decoder are our context vectors, which are the final hidden and cell states of our encoder from the same layer, i.e. $(s_0^l,c_0^l)=z^l=(h_T^l,c_T^l)$.\n", 474 | "\n", 475 | "We then pass the hidden state from the top layer of the RNN, $s_t^L$, through a linear layer, $f$, to make a prediction of what the next token in the target (output) sequence should be, $\\hat{y}_{t+1}$. \n", 476 | "\n", 477 | "$$\\hat{y}_{t+1} = f(s_t^L)$$\n", 478 | "\n", 479 | "The arguments and initialization are similar to the `Encoder` class, except we now have an `output_dim` which is the size of the vocabulary for the output/target. There is also the addition of the `Linear` layer, used to make the predictions from the top layer hidden state.\n", 480 | "\n", 481 | "Within the `forward` method, we accept a batch of input tokens, previous hidden states and previous cell states. As we are only decoding one token at a time, the input tokens will always have a sequence length of 1. We `unsqueeze` the input tokens to add a sentence length dimension of 1. Then, similar to the encoder, we pass them through an embedding layer and apply dropout.
This batch of embedded tokens is then passed into the RNN with the previous hidden and cell states. This produces an `output` (hidden state from the top layer of the RNN), a new `hidden` state (one for each layer, stacked on top of each other) and a new `cell` state (also one per layer, stacked on top of each other). We then pass the `output` (after getting rid of the sentence length dimension) through the linear layer to receive our `prediction`. We then return the `prediction`, the new `hidden` state and the new `cell` state.\n", 482 | "\n", 483 | "**Note**: as we always have a sequence length of 1, we could use `nn.LSTMCell`, instead of `nn.LSTM`, as it is designed to handle a batch of inputs that aren't necessarily in a sequence. `nn.LSTMCell` is just a single cell and `nn.LSTM` is a wrapper around potentially multiple cells. Using the `nn.LSTMCell` in this case would mean we don't have to `unsqueeze` to add a fake sequence length dimension, but we would need one `nn.LSTMCell` per layer in the decoder and to ensure each `nn.LSTMCell` receives the correct initial hidden state from the encoder. All of this makes the code less concise - hence the decision to stick with the regular `nn.LSTM`." 
484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": 14, 489 | "metadata": {}, 490 | "outputs": [], 491 | "source": [ 492 | "class Decoder(nn.Module):\n", 493 | " def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):\n", 494 | " super().__init__()\n", 495 | " \n", 496 | " self.output_dim = output_dim\n", 497 | " self.hid_dim = hid_dim\n", 498 | " self.n_layers = n_layers\n", 499 | " \n", 500 | " self.embedding = nn.Embedding(output_dim, emb_dim)\n", 501 | " \n", 502 | " self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)\n", 503 | " \n", 504 | " self.fc_out = nn.Linear(hid_dim, output_dim)\n", 505 | " \n", 506 | " self.dropout = nn.Dropout(dropout)\n", 507 | " \n", 508 | " def forward(self, input, hidden, cell):\n", 509 | " \n", 510 | " #input = [batch size]\n", 511 | " #hidden = [n layers * n directions, batch size, hid dim]\n", 512 | " #cell = [n layers * n directions, batch size, hid dim]\n", 513 | " \n", 514 | " #n directions in the decoder will always be 1, therefore:\n", 515 | " #hidden = [n layers, batch size, hid dim]\n", 516 | " #cell = [n layers, batch size, hid dim]\n", 517 | " \n", 518 | " input = input.unsqueeze(0)\n", 519 | " \n", 520 | " #input = [1, batch size]\n", 521 | " \n", 522 | " embedded = self.dropout(self.embedding(input))\n", 523 | " \n", 524 | " #embedded = [1, batch size, emb dim]\n", 525 | " \n", 526 | " output, (hidden, cell) = self.rnn(embedded, (hidden, cell))\n", 527 | " \n", 528 | " #output = [seq len, batch size, hid dim * n directions]\n", 529 | " #hidden = [n layers * n directions, batch size, hid dim]\n", 530 | " #cell = [n layers * n directions, batch size, hid dim]\n", 531 | " \n", 532 | " #seq len and n directions will always be 1 in the decoder, therefore:\n", 533 | " #output = [1, batch size, hid dim]\n", 534 | " #hidden = [n layers, batch size, hid dim]\n", 535 | " #cell = [n layers, batch size, hid dim]\n", 536 | " \n", 537 | " prediction =
self.fc_out(output.squeeze(0))\n", 538 | " \n", 539 | " #prediction = [batch size, output dim]\n", 540 | " \n", 541 | " return prediction, hidden, cell" 542 | ] 543 | }, 544 | { 545 | "cell_type": "markdown", 546 | "metadata": {}, 547 | "source": [ 548 | "### Seq2Seq\n", 549 | "\n", 550 | "For the final part of the implementation, we'll implement the seq2seq model. This will handle: \n", 551 | "- receiving the input/source sentence\n", 552 | "- using the encoder to produce the context vectors \n", 553 | "- using the decoder to produce the predicted output/target sentence\n", 554 | "\n", 555 | "Our full model will look like this:\n", 556 | "\n", 557 | "![](assets/seq2seq4.png)\n", 558 | "\n", 559 | "The `Seq2Seq` model takes in an `Encoder`, `Decoder`, and a `device` (used to place tensors on the GPU, if it exists).\n", 560 | "\n", 561 | "For this implementation, we have to ensure that the number of layers and the hidden (and cell) dimensions are equal in the `Encoder` and `Decoder`. This is not a general requirement: a sequence-to-sequence model does not necessarily need the same number of layers or the same hidden dimension sizes in the encoder and decoder. However, if we used a different number of layers, then we would need to decide how to handle the mismatch. For example, if our encoder has 2 layers and our decoder only has 1, how is this handled? Do we average the two context vectors output by the encoder? Do we pass both through a linear layer? Do we only use the context vector from the highest layer? Etc.\n", 562 | "\n", 563 | "Our `forward` method takes the source sentence, target sentence and a teacher-forcing ratio. The teacher forcing ratio is used when training our model. When decoding, at each time-step we will predict what the next token in the target sequence will be from the previous tokens decoded, $\\hat{y}_{t+1}=f(s_t^L)$. 
With probability equal to the teacher forcing ratio (`teacher_forcing_ratio`) we will use the actual ground-truth next token in the sequence as the input to the decoder during the next time-step. However, with probability `1 - teacher_forcing_ratio`, we will use the token that the model predicted as the next input to the model, even if it doesn't match the actual next token in the sequence. \n", 564 | "\n", 565 | "The first thing we do in the `forward` method is to create an `outputs` tensor that will store all of our predictions, $\\hat{Y}$.\n", 566 | "\n", 567 | "We then feed the input/source sentence, `src`, into the encoder and receive our final hidden and cell states.\n", 568 | "\n", 569 | "The first input to the decoder is the start-of-sequence token. As our `trg` tensor already has this token prepended (all the way back when we defined the `init_token` in our `TRG` field), we get our $y_1$ by slicing into it. We know how long our target sentences should be (`trg_len`), so we loop over the remaining `trg_len - 1` time-steps. The last token input into the decoder is the one **before** the end-of-sequence token; the end-of-sequence token itself is never input into the decoder. 
\n", 570 | "\n", 571 | "During each iteration of the loop, we:\n", 572 | "- pass the input, previous hidden and previous cell states ($y_t, s_{t-1}, c_{t-1}$) into the decoder\n", 573 | "- receive a prediction, next hidden state and next cell state ($\\hat{y}_{t+1}, s_{t}, c_{t}$) from the decoder\n", 574 | "- place our prediction, $\\hat{y}_{t+1}$/`output` in our tensor of predictions, $\\hat{Y}$/`outputs`\n", 575 | "- decide if we are going to \"teacher force\" or not\n", 576 | " - if we do, the next `input` is the ground-truth next token in the sequence, $y_{t+1}$/`trg[t]`\n", 577 | " - if we don't, the next `input` is the predicted next token in the sequence, $\\hat{y}_{t+1}$/`top1`, which we get by doing an `argmax` over the output tensor\n", 578 | " \n", 579 | "Once we've made all of our predictions, we return our tensor full of predictions, $\\hat{Y}$/`outputs`.\n", 580 | "\n", 581 | "**Note**: our decoder loop starts at 1, not 0. This means the 0th element of our `outputs` tensor remains all zeros. 
So our `trg` and `outputs` look something like:\n", 582 | "\n", 583 | "$$\\begin{align*}\n", 584 | "\\text{trg} = [, &y_1, y_2, y_3, ]\\\\\n", 585 | "\\text{outputs} = [0, &\\hat{y}_1, \\hat{y}_2, \\hat{y}_3, ]\n", 586 | "\\end{align*}$$\n", 587 | "\n", 588 | "Later on when we calculate the loss, we cut off the first element of each tensor to get:\n", 589 | "\n", 590 | "$$\\begin{align*}\n", 591 | "\\text{trg} = [&y_1, y_2, y_3, ]\\\\\n", 592 | "\\text{outputs} = [&\\hat{y}_1, \\hat{y}_2, \\hat{y}_3, ]\n", 593 | "\\end{align*}$$" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": 15, 599 | "metadata": {}, 600 | "outputs": [], 601 | "source": [ 602 | "class Seq2Seq(nn.Module):\n", 603 | " def __init__(self, encoder, decoder, device):\n", 604 | " super().__init__()\n", 605 | " \n", 606 | " self.encoder = encoder\n", 607 | " self.decoder = decoder\n", 608 | " self.device = device\n", 609 | " \n", 610 | " assert encoder.hid_dim == decoder.hid_dim, \\\n", 611 | " \"Hidden dimensions of encoder and decoder must be equal!\"\n", 612 | " assert encoder.n_layers == decoder.n_layers, \\\n", 613 | " \"Encoder and decoder must have equal number of layers!\"\n", 614 | " \n", 615 | " def forward(self, src, trg, teacher_forcing_ratio = 0.5):\n", 616 | " \n", 617 | " #src = [src len, batch size]\n", 618 | " #trg = [trg len, batch size]\n", 619 | " #teacher_forcing_ratio is the probability of using teacher forcing\n", 620 | " #e.g. 
if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time\n", 621 | " \n", 622 | " batch_size = trg.shape[1]\n", 623 | " trg_len = trg.shape[0]\n", 624 | " trg_vocab_size = self.decoder.output_dim\n", 625 | " \n", 626 | " #tensor to store decoder outputs\n", 627 | " outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)\n", 628 | " \n", 629 | " #last hidden state of the encoder is used as the initial hidden state of the decoder\n", 630 | " hidden, cell = self.encoder(src)\n", 631 | " \n", 632 | " #first input to the decoder is the start-of-sequence tokens\n", 633 | " input = trg[0,:]\n", 634 | " \n", 635 | " for t in range(1, trg_len):\n", 636 | " \n", 637 | " #insert input token, previous hidden and previous cell states\n", 638 | " #receive output tensor (predictions) and new hidden and cell states\n", 639 | " output, hidden, cell = self.decoder(input, hidden, cell)\n", 640 | " \n", 641 | " #place predictions in a tensor holding predictions for each token\n", 642 | " outputs[t] = output\n", 643 | " \n", 644 | " #decide if we are going to use teacher forcing or not\n", 645 | " teacher_force = random.random() < teacher_forcing_ratio\n", 646 | " \n", 647 | " #get the highest predicted token from our predictions\n", 648 | " top1 = output.argmax(1) \n", 649 | " \n", 650 | " #if teacher forcing, use actual next token as next input\n", 651 | " #if not, use predicted token\n", 652 | " input = trg[t] if teacher_force else top1\n", 653 | " \n", 654 | " return outputs" 655 | ] 656 | }, 657 | { 658 | "cell_type": "markdown", 659 | "metadata": {}, 660 | "source": [ 661 | "# Training the Seq2Seq Model\n", 662 | "\n", 663 | "Now that we have our model implemented, we can begin training it. \n", 664 | "\n", 665 | "First, we'll initialize our model. As mentioned before, the input and output dimensions are defined by the size of the vocabulary. 
The embedding dimensions and dropout for the encoder and decoder can be different, but the number of layers and the size of the hidden/cell states must be the same. \n", 666 | "\n", 667 | "We then define the encoder, decoder and then our Seq2Seq model, which we place on the `device`." 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 16, 673 | "metadata": {}, 674 | "outputs": [], 675 | "source": [ 676 | "INPUT_DIM = len(SRC.vocab)\n", 677 | "OUTPUT_DIM = len(TRG.vocab)\n", 678 | "ENC_EMB_DIM = 256\n", 679 | "DEC_EMB_DIM = 256\n", 680 | "HID_DIM = 512\n", 681 | "N_LAYERS = 2\n", 682 | "ENC_DROPOUT = 0.5\n", 683 | "DEC_DROPOUT = 0.5\n", 684 | "\n", 685 | "enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)\n", 686 | "dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)\n", 687 | "\n", 688 | "model = Seq2Seq(enc, dec, device).to(device)" 689 | ] 690 | }, 691 | { 692 | "cell_type": "markdown", 693 | "metadata": {}, 694 | "source": [ 695 | "Next up is initializing the weights of our model. In the paper they state they initialize all weights from a uniform distribution between -0.08 and +0.08, i.e. $\\mathcal{U}(-0.08, 0.08)$.\n", 696 | "\n", 697 | "We initialize weights in PyTorch by creating a function which we `apply` to our model. When using `apply`, the `init_weights` function will be called on every module and sub-module within our model. For each module we loop through all of the parameters and sample them from a uniform distribution with `nn.init.uniform_`."
698 | ] 699 | }, 700 | { 701 | "cell_type": "code", 702 | "execution_count": 17, 703 | "metadata": {}, 704 | "outputs": [ 705 | { 706 | "data": { 707 | "text/plain": [ 708 | "Seq2Seq(\n", 709 | " (encoder): Encoder(\n", 710 | " (embedding): Embedding(7853, 256)\n", 711 | " (rnn): LSTM(256, 512, num_layers=2, dropout=0.5)\n", 712 | " (dropout): Dropout(p=0.5, inplace=False)\n", 713 | " )\n", 714 | " (decoder): Decoder(\n", 715 | " (embedding): Embedding(5893, 256)\n", 716 | " (rnn): LSTM(256, 512, num_layers=2, dropout=0.5)\n", 717 | " (fc_out): Linear(in_features=512, out_features=5893, bias=True)\n", 718 | " (dropout): Dropout(p=0.5, inplace=False)\n", 719 | " )\n", 720 | ")" 721 | ] 722 | }, 723 | "execution_count": 17, 724 | "metadata": {}, 725 | "output_type": "execute_result" 726 | } 727 | ], 728 | "source": [ 729 | "def init_weights(m):\n", 730 | " for name, param in m.named_parameters():\n", 731 | " nn.init.uniform_(param.data, -0.08, 0.08)\n", 732 | " \n", 733 | "model.apply(init_weights)" 734 | ] 735 | }, 736 | { 737 | "cell_type": "markdown", 738 | "metadata": {}, 739 | "source": [ 740 | "We also define a function that will calculate the number of trainable parameters in the model." 741 | ] 742 | }, 743 | { 744 | "cell_type": "code", 745 | "execution_count": 18, 746 | "metadata": {}, 747 | "outputs": [ 748 | { 749 | "name": "stdout", 750 | "output_type": "stream", 751 | "text": [ 752 | "The model has 13,898,501 trainable parameters\n" 753 | ] 754 | } 755 | ], 756 | "source": [ 757 | "def count_parameters(model):\n", 758 | " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 759 | "\n", 760 | "print(f'The model has {count_parameters(model):,} trainable parameters')" 761 | ] 762 | }, 763 | { 764 | "cell_type": "markdown", 765 | "metadata": {}, 766 | "source": [ 767 | "We define our optimizer, which we use to update our parameters in the training loop. 
Check out [this](http://ruder.io/optimizing-gradient-descent/) post for information about different optimizers. Here, we'll use Adam." 768 | ] 769 | }, 770 | { 771 | "cell_type": "code", 772 | "execution_count": 19, 773 | "metadata": {}, 774 | "outputs": [], 775 | "source": [ 776 | "optimizer = optim.Adam(model.parameters())" 777 | ] 778 | }, 779 | { 780 | "cell_type": "markdown", 781 | "metadata": {}, 782 | "source": [ 783 | "Next, we define our loss function. The `CrossEntropyLoss` function calculates both the log softmax and the negative log-likelihood of our predictions. \n", 784 | "\n", 785 | "Our loss function calculates the average loss per token; however, by passing the index of the padding token as the `ignore_index` argument, we ignore the loss whenever the target token is a padding token. " 786 | ] 787 | }, 788 | { 789 | "cell_type": "code", 790 | "execution_count": 20, 791 | "metadata": {}, 792 | "outputs": [], 793 | "source": [ 794 | "TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]\n", 795 | "\n", 796 | "criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)" 797 | ] 798 | }, 799 | { 800 | "cell_type": "markdown", 801 | "metadata": {}, 802 | "source": [ 803 | "Next, we'll define our training loop. \n", 804 | "\n", 805 | "First, we'll set the model into \"training mode\" with `model.train()`. This will turn on dropout (and batch normalization, which we aren't using). We then iterate through our data iterator.\n", 806 | "\n", 807 | "As stated before, our decoder loop starts at 1, not 0. This means the 0th element of our `outputs` tensor remains all zeros. 
So our `trg` and `outputs` look something like:\n", 808 | "\n", 809 | "$$\\begin{align*}\n", 810 | "\\text{trg} = [, &y_1, y_2, y_3, ]\\\\\n", 811 | "\\text{outputs} = [0, &\\hat{y}_1, \\hat{y}_2, \\hat{y}_3, ]\n", 812 | "\\end{align*}$$\n", 813 | "\n", 814 | "Here, when we calculate the loss, we cut off the first element of each tensor to get:\n", 815 | "\n", 816 | "$$\\begin{align*}\n", 817 | "\\text{trg} = [&y_1, y_2, y_3, ]\\\\\n", 818 | "\\text{outputs} = [&\\hat{y}_1, \\hat{y}_2, \\hat{y}_3, ]\n", 819 | "\\end{align*}$$\n", 820 | "\n", 821 | "At each iteration:\n", 822 | "- get the source and target sentences from the batch, $X$ and $Y$\n", 823 | "- zero the gradients calculated from the last batch\n", 824 | "- feed the source and target into the model to get the output, $\\hat{Y}$\n", 825 | "- as the loss function only works on 2d inputs with 1d targets we need to flatten each of them with `.view`\n", 826 | " - we slice off the first element (time-step) of the output and target tensors as mentioned above\n", 827 | "- calculate the gradients with `loss.backward()`\n", 828 | "- clip the gradients to prevent them from exploding (a common issue in RNNs)\n", 829 | "- update the parameters of our model by doing an optimizer step\n", 830 | "- add the loss value to a running total\n", 831 | "\n", 832 | "Finally, we return the loss that is averaged over all batches."
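   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a standalone illustration of the reshaping and loss steps listed above (not part of the original training code), the next cell builds toy tensors with the same layout as `outputs` and `trg`, using random numbers in place of real predictions. The sizes and the padding index `PAD_IDX = 1` are made-up values for this sketch, and the random numbers come from a separate `torch.Generator` so the notebook's seeded random state should be left untouched. The cell slices off the first time-step, flattens with `.view`, and checks that `nn.CrossEntropyLoss` with `ignore_index` gives the same value as averaging the cross-entropy over the non-padding targets only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "#made-up toy sizes: 4 time-steps (row 0 plays the start token), batch size 2, vocab of 5\n",
    "toy_trg_len, toy_batch_size, toy_output_dim = 4, 2, 5\n",
    "PAD_IDX = 1 #assumed padding index for this toy example only\n",
    "\n",
    "#separate generator, so the global random state is not consumed\n",
    "gen = torch.Generator().manual_seed(0)\n",
    "\n",
    "#random stand-in for model(src, trg) = [trg len, batch size, output dim]\n",
    "toy_output = torch.randn(toy_trg_len, toy_batch_size, toy_output_dim, generator = gen)\n",
    "\n",
    "#toy_trg = [trg len, batch size], the second sentence ends in padding\n",
    "toy_trg = torch.tensor([[2, 2],\n",
    "                        [0, 3],\n",
    "                        [4, PAD_IDX],\n",
    "                        [3, PAD_IDX]])\n",
    "\n",
    "#cut off the first time-step, then flatten\n",
    "flat_output = toy_output[1:].view(-1, toy_output_dim)\n",
    "flat_trg = toy_trg[1:].view(-1)\n",
    "\n",
    "#flat_output = [(trg len - 1) * batch size, output dim] = [6, 5]\n",
    "#flat_trg = [(trg len - 1) * batch size] = [6]\n",
    "\n",
    "toy_criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX)\n",
    "loss = toy_criterion(flat_output, flat_trg)\n",
    "\n",
    "#same value averaged by hand over the non-padding positions only\n",
    "mask = flat_trg != PAD_IDX\n",
    "manual_loss = nn.functional.cross_entropy(flat_output[mask], flat_trg[mask])\n",
    "\n",
    "print(flat_output.shape, flat_trg.shape, torch.allclose(loss, manual_loss))"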
833 | ] 834 | }, 835 | { 836 | "cell_type": "code", 837 | "execution_count": 21, 838 | "metadata": {}, 839 | "outputs": [], 840 | "source": [ 841 | "def train(model, iterator, optimizer, criterion, clip):\n", 842 | " \n", 843 | " model.train()\n", 844 | " \n", 845 | " epoch_loss = 0\n", 846 | " \n", 847 | " for i, batch in enumerate(iterator):\n", 848 | " \n", 849 | " src = batch.src\n", 850 | " trg = batch.trg\n", 851 | " \n", 852 | " optimizer.zero_grad()\n", 853 | " \n", 854 | " output = model(src, trg)\n", 855 | " \n", 856 | " #trg = [trg len, batch size]\n", 857 | " #output = [trg len, batch size, output dim]\n", 858 | " \n", 859 | " output_dim = output.shape[-1]\n", 860 | " \n", 861 | " output = output[1:].view(-1, output_dim)\n", 862 | " trg = trg[1:].view(-1)\n", 863 | " \n", 864 | " #trg = [(trg len - 1) * batch size]\n", 865 | " #output = [(trg len - 1) * batch size, output dim]\n", 866 | " \n", 867 | " loss = criterion(output, trg)\n", 868 | " \n", 869 | " loss.backward()\n", 870 | " \n", 871 | " torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n", 872 | " \n", 873 | " optimizer.step()\n", 874 | " \n", 875 | " epoch_loss += loss.item()\n", 876 | " \n", 877 | " return epoch_loss / len(iterator)" 878 | ] 879 | }, 880 | { 881 | "cell_type": "markdown", 882 | "metadata": {}, 883 | "source": [ 884 | "Our evaluation loop is similar to our training loop; however, as we aren't updating any parameters, we don't need to pass an optimizer or a clip value.\n", 885 | "\n", 886 | "We must remember to set the model to evaluation mode with `model.eval()`. This will turn off dropout (and batch normalization, if used).\n", 887 | "\n", 888 | "We use the `with torch.no_grad()` block to ensure no gradients are calculated within the block. This reduces memory consumption and speeds things up. \n", 889 | "\n", 890 | "The iteration loop is similar (without the parameter updates); however, we must ensure we turn teacher forcing off for evaluation. 
This will cause the model to use only its own predictions to make further predictions within a sentence, which mirrors how it would be used in deployment." 891 | ] 892 | }, 893 | { 894 | "cell_type": "code", 895 | "execution_count": 22, 896 | "metadata": {}, 897 | "outputs": [], 898 | "source": [ 899 | "def evaluate(model, iterator, criterion):\n", 900 | " \n", 901 | " model.eval()\n", 902 | " \n", 903 | " epoch_loss = 0\n", 904 | " \n", 905 | " with torch.no_grad():\n", 906 | " \n", 907 | " for i, batch in enumerate(iterator):\n", 908 | "\n", 909 | " src = batch.src\n", 910 | " trg = batch.trg\n", 911 | "\n", 912 | " output = model(src, trg, 0) #turn off teacher forcing\n", 913 | "\n", 914 | " #trg = [trg len, batch size]\n", 915 | " #output = [trg len, batch size, output dim]\n", 916 | "\n", 917 | " output_dim = output.shape[-1]\n", 918 | " \n", 919 | " output = output[1:].view(-1, output_dim)\n", 920 | " trg = trg[1:].view(-1)\n", 921 | "\n", 922 | " #trg = [(trg len - 1) * batch size]\n", 923 | " #output = [(trg len - 1) * batch size, output dim]\n", 924 | "\n", 925 | " loss = criterion(output, trg)\n", 926 | " \n", 927 | " epoch_loss += loss.item()\n", 928 | " \n", 929 | " return epoch_loss / len(iterator)" 930 | ] 931 | }, 932 | { 933 | "cell_type": "markdown", 934 | "metadata": {}, 935 | "source": [ 936 | "Next, we'll create a function that we'll use to tell us how long an epoch takes."
937 | ] 938 | }, 939 | { 940 | "cell_type": "code", 941 | "execution_count": 23, 942 | "metadata": {}, 943 | "outputs": [], 944 | "source": [ 945 | "def epoch_time(start_time, end_time):\n", 946 | " elapsed_time = end_time - start_time\n", 947 | " elapsed_mins = int(elapsed_time / 60)\n", 948 | " elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n", 949 | " return elapsed_mins, elapsed_secs" 950 | ] 951 | }, 952 | { 953 | "cell_type": "markdown", 954 | "metadata": {}, 955 | "source": [ 956 | "We can finally start training our model!\n", 957 | "\n", 958 | "At each epoch, we'll be checking if our model has achieved the best validation loss so far. If it has, we'll update our best validation loss and save the parameters of our model (called `state_dict` in PyTorch). Then, when we come to test our model, we'll use the saved parameters that achieved the best validation loss. \n", 959 | "\n", 960 | "We'll be printing out both the loss and the perplexity at each epoch. The perplexity is the exponential of the loss, so its values are much bigger and a change in perplexity is easier to see than a change in loss." 961 | ] 962 | }, 963 | { 964 | "cell_type": "code", 965 | "execution_count": 24, 966 | "metadata": {}, 967 | "outputs": [ 968 | { 969 | "name": "stderr", 970 | "output_type": "stream", 971 | "text": [ 972 | "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 973 | " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" 974 | ] 975 | }, 976 | { 977 | "name": "stdout", 978 | "output_type": "stream", 979 | "text": [ 980 | "Epoch: 01 | Time: 0m 26s\n", 981 | "\tTrain Loss: 5.052 | Train PPL: 156.386\n", 982 | "\t Val. Loss: 4.916 | Val. 
PPL: 136.446\n", 983 | "Epoch: 02 | Time: 0m 26s\n", 984 | "\tTrain Loss: 4.483 | Train PPL: 88.521\n", 985 | "\t Val. Loss: 4.789 | Val. PPL: 120.154\n", 986 | "Epoch: 03 | Time: 0m 25s\n", 987 | "\tTrain Loss: 4.195 | Train PPL: 66.363\n", 988 | "\t Val. Loss: 4.552 | Val. PPL: 94.854\n", 989 | "Epoch: 04 | Time: 0m 25s\n", 990 | "\tTrain Loss: 3.963 | Train PPL: 52.625\n", 991 | "\t Val. Loss: 4.485 | Val. PPL: 88.672\n", 992 | "Epoch: 05 | Time: 0m 25s\n", 993 | "\tTrain Loss: 3.783 | Train PPL: 43.955\n", 994 | "\t Val. Loss: 4.375 | Val. PPL: 79.466\n", 995 | "Epoch: 06 | Time: 0m 25s\n", 996 | "\tTrain Loss: 3.636 | Train PPL: 37.957\n", 997 | "\t Val. Loss: 4.234 | Val. PPL: 69.011\n", 998 | "Epoch: 07 | Time: 0m 26s\n", 999 | "\tTrain Loss: 3.506 | Train PPL: 33.329\n", 1000 | "\t Val. Loss: 4.077 | Val. PPL: 58.948\n", 1001 | "Epoch: 08 | Time: 0m 27s\n", 1002 | "\tTrain Loss: 3.370 | Train PPL: 29.090\n", 1003 | "\t Val. Loss: 4.018 | Val. PPL: 55.581\n", 1004 | "Epoch: 09 | Time: 0m 26s\n", 1005 | "\tTrain Loss: 3.241 | Train PPL: 25.569\n", 1006 | "\t Val. Loss: 3.934 | Val. PPL: 51.113\n", 1007 | "Epoch: 10 | Time: 0m 26s\n", 1008 | "\tTrain Loss: 3.157 | Train PPL: 23.492\n", 1009 | "\t Val. Loss: 3.927 | Val. 
PPL: 50.743\n" 1010 | ] 1011 | } 1012 | ], 1013 | "source": [ 1014 | "N_EPOCHS = 10\n", 1015 | "CLIP = 1\n", 1016 | "\n", 1017 | "best_valid_loss = float('inf')\n", 1018 | "\n", 1019 | "for epoch in range(N_EPOCHS):\n", 1020 | " \n", 1021 | " start_time = time.time()\n", 1022 | " \n", 1023 | " train_loss = train(model, train_iterator, optimizer, criterion, CLIP)\n", 1024 | " valid_loss = evaluate(model, valid_iterator, criterion)\n", 1025 | " \n", 1026 | " end_time = time.time()\n", 1027 | " \n", 1028 | " epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n", 1029 | " \n", 1030 | " if valid_loss < best_valid_loss:\n", 1031 | " best_valid_loss = valid_loss\n", 1032 | " torch.save(model.state_dict(), 'tut1-model.pt')\n", 1033 | " \n", 1034 | " print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')\n", 1035 | " print(f'\\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')\n", 1036 | " print(f'\\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')" 1037 | ] 1038 | }, 1039 | { 1040 | "cell_type": "markdown", 1041 | "metadata": {}, 1042 | "source": [ 1043 | "We'll load the parameters (`state_dict`) that gave our model the best validation loss and run the model on the test set."
1044 | ] 1045 | }, 1046 | { 1047 | "cell_type": "code", 1048 | "execution_count": 25, 1049 | "metadata": {}, 1050 | "outputs": [ 1051 | { 1052 | "name": "stdout", 1053 | "output_type": "stream", 1054 | "text": [ 1055 | "| Test Loss: 3.951 | Test PPL: 52.001 |\n" 1056 | ] 1057 | } 1058 | ], 1059 | "source": [ 1060 | "model.load_state_dict(torch.load('tut1-model.pt'))\n", 1061 | "\n", 1062 | "test_loss = evaluate(model, test_iterator, criterion)\n", 1063 | "\n", 1064 | "print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')" 1065 | ] 1066 | }, 1067 | { 1068 | "cell_type": "markdown", 1069 | "metadata": {}, 1070 | "source": [ 1071 | "In the following notebook we'll implement a model that achieves improved test perplexity, but only uses a single layer in the encoder and the decoder." 1072 | ] 1073 | } 1074 | ], 1075 | "metadata": { 1076 | "kernelspec": { 1077 | "display_name": "Python 3", 1078 | "language": "python", 1079 | "name": "python3" 1080 | }, 1081 | "language_info": { 1082 | "codemirror_mode": { 1083 | "name": "ipython", 1084 | "version": 3 1085 | }, 1086 | "file_extension": ".py", 1087 | "mimetype": "text/x-python", 1088 | "name": "python", 1089 | "nbconvert_exporter": "python", 1090 | "pygments_lexer": "ipython3", 1091 | "version": "3.8.5" 1092 | } 1093 | }, 1094 | "nbformat": 4, 1095 | "nbformat_minor": 2 1096 | } -------------------------------------------------------------------------------- /2 - Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 2 - Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation\n", 8 | "\n", 9 | "In this second notebook on sequence-to-sequence models using PyTorch and TorchText, we'll be implementing the model from [Learning Phrase 
Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078). This model will achieve improved test perplexity whilst only using a single layer RNN in both the encoder and the decoder.\n", 10 | "\n", 11 | "## Introduction\n", 12 | "\n", 13 | "Let's remind ourselves of the general encoder-decoder model.\n", 14 | "\n", 15 | "![](assets/seq2seq1.png)\n", 16 | "\n", 17 | "We use our encoder (green) over the embedded source sequence (yellow) to create a context vector (red). We then use that context vector with the decoder (blue) and a linear layer (purple) to generate the target sentence.\n", 18 | "\n", 19 | "In the previous model, we used a multi-layered LSTM as the encoder and decoder.\n", 20 | "\n", 21 | "![](assets/seq2seq4.png)\n", 22 | "\n", 23 | "One downside of the previous model is that the decoder is trying to cram lots of information into the hidden states. Whilst decoding, the hidden state will need to contain information about the whole of the source sequence, as well as all of the tokens that have been decoded so far. By alleviating some of this information compression, we can create a better model!\n", 24 | "\n", 25 | "We'll also be using a GRU (Gated Recurrent Unit) instead of an LSTM (Long Short-Term Memory). Why? Mainly because that's what they did in the paper (this paper also introduced GRUs) and also because we used LSTMs last time. To understand how GRUs (and LSTMs) differ from standard RNNs, check out [this](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) link. Is a GRU better than an LSTM? [Research](https://arxiv.org/abs/1412.3555) has shown they're pretty much the same, and both are better than standard RNNs. \n", 26 | "\n", 27 | "## Preparing Data\n", 28 | "\n", 29 | "All of the data preparation will be (almost) the same as last time, so we'll very briefly detail what each code block does. 
See the previous notebook for a recap.\n", 30 | "\n", 31 | "We'll import PyTorch, TorchText, spaCy and a few standard modules." 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 1, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import torch\n", 41 | "import torch.nn as nn\n", 42 | "import torch.optim as optim\n", 43 | "\n", 44 | "from torchtext.legacy.datasets import Multi30k\n", 45 | "from torchtext.legacy.data import Field, BucketIterator\n", 46 | "\n", 47 | "import spacy\n", 48 | "import numpy as np\n", 49 | "\n", 50 | "import random\n", 51 | "import math\n", 52 | "import time" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "Then set a random seed for deterministic results/reproducibility." 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 2, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "SEED = 1234\n", 69 | "\n", 70 | "random.seed(SEED)\n", 71 | "np.random.seed(SEED)\n", 72 | "torch.manual_seed(SEED)\n", 73 | "torch.cuda.manual_seed(SEED)\n", 74 | "torch.backends.cudnn.deterministic = True" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "Instantiate our German and English spaCy models." 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "spacy_de = spacy.load('de_core_news_sm')\n", 91 | "spacy_en = spacy.load('en_core_web_sm')" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "Previously we reversed the source (German) sentence; however, the paper we are implementing doesn't do this, so neither will we."
99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 4, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "def tokenize_de(text):\n", 108 | " \"\"\"\n", 109 | " Tokenizes German text from a string into a list of strings\n", 110 | " \"\"\"\n", 111 | " return [tok.text for tok in spacy_de.tokenizer(text)]\n", 112 | "\n", 113 | "def tokenize_en(text):\n", 114 | " \"\"\"\n", 115 | " Tokenizes English text from a string into a list of strings\n", 116 | " \"\"\"\n", 117 | " return [tok.text for tok in spacy_en.tokenizer(text)]" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "Create our fields to process our data. This will add the \"start of sentence\" and \"end of sentence\" tokens and convert all words to lowercase." 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 5, 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "name": "stderr", 134 | "output_type": "stream", 135 | "text": [ 136 | "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 137 | " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" 138 | ] 139 | } 140 | ], 141 | "source": [ 142 | "SRC = Field(tokenize=tokenize_de, \n", 143 | " init_token='', \n", 144 | " eos_token='', \n", 145 | " lower=True)\n", 146 | "\n", 147 | "TRG = Field(tokenize = tokenize_en, \n", 148 | " init_token='', \n", 149 | " eos_token='', \n", 150 | " lower=True)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "Load our data."
158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 6, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "name": "stderr", 167 | "output_type": "stream", 168 | "text": [ 169 | "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 170 | " warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), \n", 176 | " fields = (SRC, TRG))" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "We'll also print out an example just to double check they're not reversed." 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 7, 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "name": "stdout", 193 | "output_type": "stream", 194 | "text": [ 195 | "{'src': ['zwei', 'junge', 'weiße', 'männer', 'sind', 'im', 'freien', 'in', 'der', 'nähe', 'vieler', 'büsche', '.'], 'trg': ['two', 'young', ',', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.']}\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "print(vars(train_data.examples[0]))" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "Then create our vocabulary, converting any token that appears fewer than two times into the unknown token."
208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 8, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "SRC.build_vocab(train_data, min_freq = 2)\n", 217 | "TRG.build_vocab(train_data, min_freq = 2)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": {}, 223 | "source": [ 224 | "Finally, define the `device` and create our iterators." 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 9, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 10, 239 | "metadata": {}, 240 | "outputs": [ 241 | { 242 | "name": "stderr", 243 | "output_type": "stream", 244 | "text": [ 245 | "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 246 | " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" 247 | ] 248 | } 249 | ], 250 | "source": [ 251 | "BATCH_SIZE = 128\n", 252 | "\n", 253 | "train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n", 254 | " (train_data, valid_data, test_data), \n", 255 | " batch_size = BATCH_SIZE, \n", 256 | " device = device)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": {}, 262 | "source": [ 263 | "## Building the Seq2Seq Model\n", 264 | "\n", 265 | "### Encoder\n", 266 | "\n", 267 | "The encoder is similar to the previous one, with the multi-layer LSTM swapped for a single-layer GRU. 
We also don't pass a dropout value to the GRU, as that dropout is only applied between the layers of a multi-layered RNN. As we only have a single layer, PyTorch will display a warning if we try to pass a dropout value to it.\n", 268 | "\n", 269 | "Another thing to note about the GRU is that it only requires and returns a hidden state; there is no cell state like in the LSTM.\n", 270 | "\n", 271 | "$$\\begin{align*}\n", 272 | "h_t &= \\text{GRU}(e(x_t), h_{t-1})\\\\\n", 273 | "(h_t, c_t) &= \\text{LSTM}(e(x_t), h_{t-1}, c_{t-1})\\\\\n", 274 | "h_t &= \\text{RNN}(e(x_t), h_{t-1})\n", 275 | "\\end{align*}$$\n", 276 | "\n", 277 | "From the equations above, it looks like the RNN and the GRU are identical. Inside the GRU, however, are a number of *gating mechanisms* that control the information flow into and out of the hidden state (similar to an LSTM). Again, for more info, check out [this](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) excellent post. \n", 278 | "\n", 279 | "The rest of the encoder should be very familiar from the last tutorial: it takes in a sequence, $X = \\{x_1, x_2, ...
, x_T\\}$, passes it through the embedding layer, recurrently calculates hidden states, $H = \\{h_1, h_2, ..., h_T\\}$, and returns a context vector (the final hidden state), $z=h_T$.\n", 280 | "\n", 281 | "$$h_t = \\text{EncoderGRU}(e(x_t), h_{t-1})$$\n", 282 | "\n", 283 | "This is identical to the encoder of the general seq2seq model, with all the \"magic\" happening inside the GRU (green).\n", 284 | "\n", 285 | "![](assets/seq2seq5.png)" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 11, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "class Encoder(nn.Module):\n", 295 | " def __init__(self, input_dim, emb_dim, hid_dim, dropout):\n", 296 | " super().__init__()\n", 297 | "\n", 298 | " self.hid_dim = hid_dim\n", 299 | " \n", 300 | " self.embedding = nn.Embedding(input_dim, emb_dim)\n", 301 | " \n", 302 | " self.rnn = nn.GRU(emb_dim, hid_dim) #no dropout as only one layer!\n", 303 | " \n", 304 | " self.dropout = nn.Dropout(dropout)\n", 305 | " \n", 306 | " def forward(self, src):\n", 307 | " \n", 308 | " #src = [src len, batch size]\n", 309 | " \n", 310 | " embedded = self.dropout(self.embedding(src))\n", 311 | " \n", 312 | " #embedded = [src len, batch size, emb dim]\n", 313 | " \n", 314 | " outputs, hidden = self.rnn(embedded) #no cell state!\n", 315 | " \n", 316 | " #outputs = [src len, batch size, hid dim * n directions]\n", 317 | " #hidden = [n layers * n directions, batch size, hid dim]\n", 318 | " \n", 319 | " #outputs are always from the top hidden layer\n", 320 | " \n", 321 | " return hidden" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "metadata": {}, 327 | "source": [ 328 | "### Decoder\n", 329 | "\n", 330 | "The decoder is where the implementation differs significantly from the previous model, and where we alleviate some of the information compression.\n", 331 | "\n", 332 | "Instead of the GRU in the decoder taking just the embedded target token, $d(y_t)$, and the previous hidden state, $s_{t-1}$,
as inputs, it also takes the context vector $z$. \n", 333 | "\n", 334 | "$$s_t = \\text{DecoderGRU}(d(y_t), s_{t-1}, z)$$\n", 335 | "\n", 336 | "Note how this context vector, $z$, does not have a $t$ subscript, meaning we re-use the same context vector returned by the encoder for every time-step in the decoder. \n", 337 | "\n", 338 | "Before, we predicted the next token, $\\hat{y}_{t+1}$, with the linear layer, $f$, only using the top-layer decoder hidden state at that time-step, $s_t^L$, as $\\hat{y}_{t+1}=f(s_t^L)$. Now, we also pass the embedding of the current token, $d(y_t)$, and the context vector, $z$, to the linear layer.\n", 339 | "\n", 340 | "$$\\hat{y}_{t+1} = f(d(y_t), s_t, z)$$\n", 341 | "\n", 342 | "Thus, our decoder now looks something like this:\n", 343 | "\n", 344 | "![](assets/seq2seq6.png)\n", 345 | "\n", 346 | "Note that the initial hidden state, $s_0$, is still the context vector, $z$, so when generating the first token we are actually inputting two identical context vectors into the GRU.\n", 347 | "\n", 348 | "How do these two changes reduce the information compression? Well, hypothetically the decoder hidden states, $s_t$, no longer need to contain information about the source sequence, as it is always available as an input. Thus, they only need to contain information about which tokens have been generated so far. The addition of $d(y_t)$ to the linear layer also means this layer can directly see what the current token is, without having to get this information from the hidden state. \n", 349 | "\n", 350 | "However, this is just a hypothesis: it is impossible to determine how the model actually uses the information provided to it (don't listen to anyone who says differently).
Nevertheless, it is a solid intuition and the results seem to indicate that these modifications are a good idea!\n", 351 | "\n", 352 | "Within the implementation, we will pass $d(y_t)$ and $z$ to the GRU by concatenating them together, so the input dimensions to the GRU are now `emb_dim + hid_dim` (as the context vector will be of size `hid_dim`). The linear layer will take $d(y_t)$, $s_t$ and $z$, also by concatenating them together, so its input dimensions are now `emb_dim + hid_dim*2`. We also don't pass a value of dropout to the GRU, as it only uses a single layer.\n", 353 | "\n", 354 | "`forward` now takes a `context` argument. Inside of `forward`, we concatenate $d(y_t)$ and $z$ as `emb_con` before feeding it to the GRU, and we concatenate $d(y_t)$, $s_t$ and $z$ together as `output` before feeding it through the linear layer to receive our predictions, $\\hat{y}_{t+1}$." 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": 12, 360 | "metadata": {}, 361 | "outputs": [], 362 | "source": [ 363 | "class Decoder(nn.Module):\n", 364 | " def __init__(self, output_dim, emb_dim, hid_dim, dropout):\n", 365 | " super().__init__()\n", 366 | "\n", 367 | " self.hid_dim = hid_dim\n", 368 | " self.output_dim = output_dim\n", 369 | " \n", 370 | " self.embedding = nn.Embedding(output_dim, emb_dim)\n", 371 | " \n", 372 | " self.rnn = nn.GRU(emb_dim + hid_dim, hid_dim)\n", 373 | " \n", 374 | " self.fc_out = nn.Linear(emb_dim + hid_dim * 2, output_dim)\n", 375 | " \n", 376 | " self.dropout = nn.Dropout(dropout)\n", 377 | " \n", 378 | " def forward(self, input, hidden, context):\n", 379 | " \n", 380 | " #input = [batch size]\n", 381 | " #hidden = [n layers * n directions, batch size, hid dim]\n", 382 | " #context = [n layers * n directions, batch size, hid dim]\n", 383 | " \n", 384 | " #n layers and n directions in the decoder will both always be 1, therefore:\n", 385 | " #hidden = [1, batch size, hid dim]\n", 386 | " #context = [1, batch size, hid dim]\n", 387 | "
\n", 388 | " input = input.unsqueeze(0)\n", 389 | " \n", 390 | " #input = [1, batch size]\n", 391 | " \n", 392 | " embedded = self.dropout(self.embedding(input))\n", 393 | " \n", 394 | " #embedded = [1, batch size, emb dim]\n", 395 | " \n", 396 | " emb_con = torch.cat((embedded, context), dim = 2)\n", 397 | " \n", 398 | " #emb_con = [1, batch size, emb dim + hid dim]\n", 399 | " \n", 400 | " output, hidden = self.rnn(emb_con, hidden)\n", 401 | " \n", 402 | " #output = [seq len, batch size, hid dim * n directions]\n", 403 | " #hidden = [n layers * n directions, batch size, hid dim]\n", 404 | " \n", 405 | " #seq len, n layers and n directions will always be 1 in the decoder, therefore:\n", 406 | " #output = [1, batch size, hid dim]\n", 407 | " #hidden = [1, batch size, hid dim]\n", 408 | " \n", 409 | " output = torch.cat((embedded.squeeze(0), hidden.squeeze(0), context.squeeze(0)), \n", 410 | " dim = 1)\n", 411 | " \n", 412 | " #output = [batch size, emb dim + hid dim * 2]\n", 413 | " \n", 414 | " prediction = self.fc_out(output)\n", 415 | " \n", 416 | " #prediction = [batch size, output dim]\n", 417 | " \n", 418 | " return prediction, hidden" 419 | ] 420 | }, 421 | { 422 | "cell_type": "markdown", 423 | "metadata": {}, 424 | "source": [ 425 | "### Seq2Seq Model\n", 426 | "\n", 427 | "Putting the encoder and decoder together, we get:\n", 428 | "\n", 429 | "![](assets/seq2seq7.png)\n", 430 | "\n", 431 | "Again, in this implementation we need to ensure the hidden dimensions in both the encoder and the decoder are the same.\n", 432 | "\n", 433 | "Briefly going over all of the steps:\n", 434 | "- the `outputs` tensor is created to hold all predictions, $\\hat{Y}$\n", 435 | "- the source sequence, $X$, is fed into the encoder to receive a `context` vector\n", 436 | "- the initial decoder hidden state is set to be the `context` vector, $s_0 = z = h_T$\n", 437 | "- we use a batch of `<sos>` tokens as the first `input`, $y_1$\n", 438 | "- we then decode within a loop:\n", 439 | "
- inserting the input token $y_t$, previous hidden state, $s_{t-1}$, and the context vector, $z$, into the decoder\n", 440 | " - receiving a prediction, $\\hat{y}_{t+1}$, and a new hidden state, $s_t$\n", 441 | " - we then decide if we are going to teacher force or not, setting the next input as appropriate (either the ground truth next token in the target sequence or the highest predicted next token)" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": 13, 447 | "metadata": {}, 448 | "outputs": [], 449 | "source": [ 450 | "class Seq2Seq(nn.Module):\n", 451 | " def __init__(self, encoder, decoder, device):\n", 452 | " super().__init__()\n", 453 | " \n", 454 | " self.encoder = encoder\n", 455 | " self.decoder = decoder\n", 456 | " self.device = device\n", 457 | " \n", 458 | " assert encoder.hid_dim == decoder.hid_dim, \\\n", 459 | " \"Hidden dimensions of encoder and decoder must be equal!\"\n", 460 | " \n", 461 | " def forward(self, src, trg, teacher_forcing_ratio = 0.5):\n", 462 | " \n", 463 | " #src = [src len, batch size]\n", 464 | " #trg = [trg len, batch size]\n", 465 | " #teacher_forcing_ratio is probability to use teacher forcing\n", 466 | " #e.g. 
if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time\n", 467 | " \n", 468 | " batch_size = trg.shape[1]\n", 469 | " trg_len = trg.shape[0]\n", 470 | " trg_vocab_size = self.decoder.output_dim\n", 471 | " \n", 472 | " #tensor to store decoder outputs\n", 473 | " outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)\n", 474 | " \n", 475 | " #last hidden state of the encoder is the context\n", 476 | " context = self.encoder(src)\n", 477 | " \n", 478 | " #context also used as the initial hidden state of the decoder\n", 479 | " hidden = context\n", 480 | " \n", 481 | " #first input to the decoder is the <sos> tokens\n", 482 | " input = trg[0,:]\n", 483 | " \n", 484 | " for t in range(1, trg_len):\n", 485 | " \n", 486 | " #insert input token embedding, previous hidden state and the context state\n", 487 | " #receive output tensor (predictions) and new hidden state\n", 488 | " output, hidden = self.decoder(input, hidden, context)\n", 489 | " \n", 490 | " #place predictions in a tensor holding predictions for each token\n", 491 | " outputs[t] = output\n", 492 | " \n", 493 | " #decide if we are going to use teacher forcing or not\n", 494 | " teacher_force = random.random() < teacher_forcing_ratio\n", 495 | " \n", 496 | " #get the highest predicted token from our predictions\n", 497 | " top1 = output.argmax(1) \n", 498 | " \n", 499 | " #if teacher forcing, use actual next token as next input\n", 500 | " #if not, use predicted token\n", 501 | " input = trg[t] if teacher_force else top1\n", 502 | "\n", 503 | " return outputs" 504 | ] 505 | }, 506 | { 507 | "cell_type": "markdown", 508 | "metadata": {}, 509 | "source": [ 510 | "## Training the Seq2Seq Model\n", 511 | "\n", 512 | "The rest of this tutorial is very similar to the previous one. \n", 513 | "\n", 514 | "We initialize our encoder, decoder and seq2seq model (placing it on the GPU if we have one).
As before, the embedding dimensions and the amount of dropout used can be different between the encoder and the decoder, but the hidden dimensions must remain the same." 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 14, 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [ 523 | "INPUT_DIM = len(SRC.vocab)\n", 524 | "OUTPUT_DIM = len(TRG.vocab)\n", 525 | "ENC_EMB_DIM = 256\n", 526 | "DEC_EMB_DIM = 256\n", 527 | "HID_DIM = 512\n", 528 | "ENC_DROPOUT = 0.5\n", 529 | "DEC_DROPOUT = 0.5\n", 530 | "\n", 531 | "enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, ENC_DROPOUT)\n", 532 | "dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, DEC_DROPOUT)\n", 533 | "\n", 534 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 535 | "\n", 536 | "model = Seq2Seq(enc, dec, device).to(device)" 537 | ] 538 | }, 539 | { 540 | "cell_type": "markdown", 541 | "metadata": {}, 542 | "source": [ 543 | "Next, we initialize our parameters. The paper states the parameters are initialized from a normal distribution with a mean of 0 and a standard deviation of 0.01, i.e. $\\mathcal{N}(0, 0.01)$. \n", 544 | "\n", 545 | "It also states that the recurrent parameters should be given a special initialization; however, to keep things simple, we'll also initialize them to $\\mathcal{N}(0, 0.01)$."
546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 15, 551 | "metadata": {}, 552 | "outputs": [ 553 | { 554 | "data": { 555 | "text/plain": [ 556 | "Seq2Seq(\n", 557 | " (encoder): Encoder(\n", 558 | " (embedding): Embedding(7853, 256)\n", 559 | " (rnn): GRU(256, 512)\n", 560 | " (dropout): Dropout(p=0.5, inplace=False)\n", 561 | " )\n", 562 | " (decoder): Decoder(\n", 563 | " (embedding): Embedding(5893, 256)\n", 564 | " (rnn): GRU(768, 512)\n", 565 | " (fc_out): Linear(in_features=1280, out_features=5893, bias=True)\n", 566 | " (dropout): Dropout(p=0.5, inplace=False)\n", 567 | " )\n", 568 | ")" 569 | ] 570 | }, 571 | "execution_count": 15, 572 | "metadata": {}, 573 | "output_type": "execute_result" 574 | } 575 | ], 576 | "source": [ 577 | "def init_weights(m):\n", 578 | " for name, param in m.named_parameters():\n", 579 | " nn.init.normal_(param.data, mean=0, std=0.01)\n", 580 | " \n", 581 | "model.apply(init_weights)" 582 | ] 583 | }, 584 | { 585 | "cell_type": "markdown", 586 | "metadata": {}, 587 | "source": [ 588 | "We print out the number of parameters.\n", 589 | "\n", 590 | "Even though we only have a single-layer RNN for our encoder and decoder, we actually have **more** parameters than the last model. This is due to the increased size of the inputs to the GRU and the linear layer. However, the increase is not significant and only adds a small amount of training time (~3 seconds extra per epoch)."
591 | ] 592 | }, 593 | { 594 | "cell_type": "code", 595 | "execution_count": 16, 596 | "metadata": {}, 597 | "outputs": [ 598 | { 599 | "name": "stdout", 600 | "output_type": "stream", 601 | "text": [ 602 | "The model has 14,219,781 trainable parameters\n" 603 | ] 604 | } 605 | ], 606 | "source": [ 607 | "def count_parameters(model):\n", 608 | " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 609 | "\n", 610 | "print(f'The model has {count_parameters(model):,} trainable parameters')" 611 | ] 612 | }, 613 | { 614 | "cell_type": "markdown", 615 | "metadata": {}, 616 | "source": [ 617 | "We initialize our optimizer." 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": 17, 623 | "metadata": {}, 624 | "outputs": [], 625 | "source": [ 626 | "optimizer = optim.Adam(model.parameters())" 627 | ] 628 | }, 629 | { 630 | "cell_type": "markdown", 631 | "metadata": {}, 632 | "source": [ 633 | "We also initialize the loss function, making sure to ignore the loss on `<pad>` tokens." 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": 18, 639 | "metadata": {}, 640 | "outputs": [], 641 | "source": [ 642 | "TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]\n", 643 | "\n", 644 | "criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)" 645 | ] 646 | }, 647 | { 648 | "cell_type": "markdown", 649 | "metadata": {}, 650 | "source": [ 651 | "We then create the training loop..."
652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": 19, 657 | "metadata": {}, 658 | "outputs": [], 659 | "source": [ 660 | "def train(model, iterator, optimizer, criterion, clip):\n", 661 | " \n", 662 | " model.train()\n", 663 | " \n", 664 | " epoch_loss = 0\n", 665 | " \n", 666 | " for i, batch in enumerate(iterator):\n", 667 | " \n", 668 | " src = batch.src\n", 669 | " trg = batch.trg\n", 670 | " \n", 671 | " optimizer.zero_grad()\n", 672 | " \n", 673 | " output = model(src, trg)\n", 674 | " \n", 675 | " #trg = [trg len, batch size]\n", 676 | " #output = [trg len, batch size, output dim]\n", 677 | " \n", 678 | " output_dim = output.shape[-1]\n", 679 | " \n", 680 | " output = output[1:].view(-1, output_dim)\n", 681 | " trg = trg[1:].view(-1)\n", 682 | " \n", 683 | " #trg = [(trg len - 1) * batch size]\n", 684 | " #output = [(trg len - 1) * batch size, output dim]\n", 685 | " \n", 686 | " loss = criterion(output, trg)\n", 687 | " \n", 688 | " loss.backward()\n", 689 | " \n", 690 | " torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n", 691 | " \n", 692 | " optimizer.step()\n", 693 | " \n", 694 | " epoch_loss += loss.item()\n", 695 | " \n", 696 | " return epoch_loss / len(iterator)" 697 | ] 698 | }, 699 | { 700 | "cell_type": "markdown", 701 | "metadata": {}, 702 | "source": [ 703 | "...and the evaluation loop, remembering to set the model to `eval` mode and turn off teacher forcing."
704 | ] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": 20, 709 | "metadata": {}, 710 | "outputs": [], 711 | "source": [ 712 | "def evaluate(model, iterator, criterion):\n", 713 | " \n", 714 | " model.eval()\n", 715 | " \n", 716 | " epoch_loss = 0\n", 717 | " \n", 718 | " with torch.no_grad():\n", 719 | " \n", 720 | " for i, batch in enumerate(iterator):\n", 721 | "\n", 722 | " src = batch.src\n", 723 | " trg = batch.trg\n", 724 | "\n", 725 | " output = model(src, trg, 0) #turn off teacher forcing\n", 726 | "\n", 727 | " #trg = [trg len, batch size]\n", 728 | " #output = [trg len, batch size, output dim]\n", 729 | "\n", 730 | " output_dim = output.shape[-1]\n", 731 | " \n", 732 | " output = output[1:].view(-1, output_dim)\n", 733 | " trg = trg[1:].view(-1)\n", 734 | "\n", 735 | " #trg = [(trg len - 1) * batch size]\n", 736 | " #output = [(trg len - 1) * batch size, output dim]\n", 737 | "\n", 738 | " loss = criterion(output, trg)\n", 739 | "\n", 740 | " epoch_loss += loss.item()\n", 741 | " \n", 742 | " return epoch_loss / len(iterator)" 743 | ] 744 | }, 745 | { 746 | "cell_type": "markdown", 747 | "metadata": {}, 748 | "source": [ 749 | "We'll also define the function that calculates how long an epoch takes." 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": 21, 755 | "metadata": {}, 756 | "outputs": [], 757 | "source": [ 758 | "def epoch_time(start_time, end_time):\n", 759 | " elapsed_time = end_time - start_time\n", 760 | " elapsed_mins = int(elapsed_time / 60)\n", 761 | " elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n", 762 | " return elapsed_mins, elapsed_secs" 763 | ] 764 | }, 765 | { 766 | "cell_type": "markdown", 767 | "metadata": {}, 768 | "source": [ 769 | "Then, we train our model, saving the parameters that give us the best validation loss." 
770 | ] 771 | }, 772 | { 773 | "cell_type": "code", 774 | "execution_count": 22, 775 | "metadata": {}, 776 | "outputs": [ 777 | { 778 | "name": "stderr", 779 | "output_type": "stream", 780 | "text": [ 781 | "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 782 | " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" 783 | ] 784 | }, 785 | { 786 | "name": "stdout", 787 | "output_type": "stream", 788 | "text": [ 789 | "Epoch: 01 | Time: 0m 28s\n", 790 | "\tTrain Loss: 5.072 | Train PPL: 159.429\n", 791 | "\t Val. Loss: 5.041 | Val. PPL: 154.646\n", 792 | "Epoch: 02 | Time: 0m 28s\n", 793 | "\tTrain Loss: 4.366 | Train PPL: 78.718\n", 794 | "\t Val. Loss: 5.114 | Val. PPL: 166.280\n", 795 | "Epoch: 03 | Time: 0m 28s\n", 796 | "\tTrain Loss: 4.011 | Train PPL: 55.202\n", 797 | "\t Val. Loss: 4.613 | Val. PPL: 100.795\n", 798 | "Epoch: 04 | Time: 0m 28s\n", 799 | "\tTrain Loss: 3.612 | Train PPL: 37.050\n", 800 | "\t Val. Loss: 4.195 | Val. PPL: 66.323\n", 801 | "Epoch: 05 | Time: 0m 27s\n", 802 | "\tTrain Loss: 3.252 | Train PPL: 25.848\n", 803 | "\t Val. Loss: 3.981 | Val. PPL: 53.584\n", 804 | "Epoch: 06 | Time: 0m 28s\n", 805 | "\tTrain Loss: 2.953 | Train PPL: 19.173\n", 806 | "\t Val. Loss: 3.798 | Val. PPL: 44.601\n", 807 | "Epoch: 07 | Time: 0m 27s\n", 808 | "\tTrain Loss: 2.701 | Train PPL: 14.892\n", 809 | "\t Val. Loss: 3.653 | Val. PPL: 38.593\n", 810 | "Epoch: 08 | Time: 0m 28s\n", 811 | "\tTrain Loss: 2.463 | Train PPL: 11.735\n", 812 | "\t Val. Loss: 3.599 | Val. PPL: 36.558\n", 813 | "Epoch: 09 | Time: 0m 28s\n", 814 | "\tTrain Loss: 2.247 | Train PPL: 9.456\n", 815 | "\t Val. 
Loss: 3.563 | Val. PPL: 35.269\n", 816 | "Epoch: 10 | Time: 0m 28s\n", 817 | "\tTrain Loss: 2.090 | Train PPL: 8.086\n", 818 | "\t Val. Loss: 3.639 | Val. PPL: 38.051\n" 819 | ] 820 | } 821 | ], 822 | "source": [ 823 | "N_EPOCHS = 10\n", 824 | "CLIP = 1\n", 825 | "\n", 826 | "best_valid_loss = float('inf')\n", 827 | "\n", 828 | "for epoch in range(N_EPOCHS):\n", 829 | " \n", 830 | " start_time = time.time()\n", 831 | " \n", 832 | " train_loss = train(model, train_iterator, optimizer, criterion, CLIP)\n", 833 | " valid_loss = evaluate(model, valid_iterator, criterion)\n", 834 | " \n", 835 | " end_time = time.time()\n", 836 | " \n", 837 | " epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n", 838 | " \n", 839 | " if valid_loss < best_valid_loss:\n", 840 | " best_valid_loss = valid_loss\n", 841 | " torch.save(model.state_dict(), 'tut2-model.pt')\n", 842 | " \n", 843 | " print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')\n", 844 | " print(f'\\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')\n", 845 | " print(f'\\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')" 846 | ] 847 | }, 848 | { 849 | "cell_type": "markdown", 850 | "metadata": {}, 851 | "source": [ 852 | "Finally, we test the model on the test set using these \"best\" parameters." 
853 | ] 854 | }, 855 | { 856 | "cell_type": "code", 857 | "execution_count": 23, 858 | "metadata": {}, 859 | "outputs": [ 860 | { 861 | "name": "stdout", 862 | "output_type": "stream", 863 | "text": [ 864 | "| Test Loss: 3.546 | Test PPL: 34.662 |\n" 865 | ] 866 | } 867 | ], 868 | "source": [ 869 | "model.load_state_dict(torch.load('tut2-model.pt'))\n", 870 | "\n", 871 | "test_loss = evaluate(model, test_iterator, criterion)\n", 872 | "\n", 873 | "print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')" 874 | ] 875 | }, 876 | { 877 | "cell_type": "markdown", 878 | "metadata": {}, 879 | "source": [ 880 | "Just looking at the test loss, we get better performance than the previous model. This is a pretty good sign that this model architecture is doing something right! Relieving the information compression seems like the way forward, and in the next tutorial we'll expand on this even further with *attention*." 881 | ] 882 | } 883 | ], 884 | "metadata": { 885 | "kernelspec": { 886 | "display_name": "Python 3", 887 | "language": "python", 888 | "name": "python3" 889 | }, 890 | "language_info": { 891 | "codemirror_mode": { 892 | "name": "ipython", 893 | "version": 3 894 | }, 895 | "file_extension": ".py", 896 | "mimetype": "text/x-python", 897 | "name": "python", 898 | "nbconvert_exporter": "python", 899 | "pygments_lexer": "ipython3", 900 | "version": "3.8.5" 901 | } 902 | }, 903 | "nbformat": 4, 904 | "nbformat_minor": 2 905 | } -------------------------------------------------------------------------------- /3 - Neural Machine Translation by Jointly Learning to Align and Translate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 3 - Neural Machine Translation by Jointly Learning to Align and Translate\n", 8 | "\n", 9 | "In this third notebook on sequence-to-sequence models using PyTorch and TorchText, we'll be
implementing the model from [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473). This model achieves our best perplexity yet, ~27 compared to ~34 for the previous model.\n", 10 | "\n", 11 | "## Introduction\n", 12 | "\n", 13 | "As a reminder, here is the general encoder-decoder model:\n", 14 | "\n", 15 | "![](assets/seq2seq1.png)\n", 16 | "\n", 17 | "In the previous model, our architecture was set up in a way to reduce \"information compression\" by explicitly passing the context vector, $z$, to the decoder at every time-step and by passing both the context vector and embedded input word, $d(y_t)$, along with the hidden state, $s_t$, to the linear layer, $f$, to make a prediction.\n", 18 | "\n", 19 | "![](assets/seq2seq7.png)\n", 20 | "\n", 21 | "Even though we have reduced some of this compression, our context vector still needs to contain all of the information about the source sentence. The model implemented in this notebook avoids this compression by allowing the decoder to look at the entire source sentence (via its hidden states) at each decoding step! How does it do this? It uses *attention*. \n", 22 | "\n", 23 | "Attention works by first calculating an attention vector, $a$, that is the length of the source sentence. The attention vector has the property that each element is between 0 and 1, and the entire vector sums to 1. We then calculate a weighted sum of our source sentence hidden states, $H$, to get a weighted source vector, $w$. \n", 24 | "\n", 25 | "$$w = \\sum_{i}a_ih_i$$\n", 26 | "\n", 27 | "We calculate a new weighted source vector at every decoding time-step, using it as input to our decoder RNN as well as to the linear layer that makes the prediction. We'll explain how to do all of this during the tutorial.\n", 28 | "\n", 29 | "## Preparing Data\n", 30 | "\n", 31 | "Again, the preparation is similar to last time.\n", 32 | "\n", 33 | "First, we import all the required modules."
34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 1, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import torch\n", 43 | "import torch.nn as nn\n", 44 | "import torch.optim as optim\n", 45 | "import torch.nn.functional as F\n", 46 | "\n", 47 | "from torchtext.legacy.datasets import Multi30k\n", 48 | "from torchtext.legacy.data import Field, BucketIterator\n", 49 | "\n", 50 | "import spacy\n", 51 | "import numpy as np\n", 52 | "\n", 53 | "import random\n", 54 | "import math\n", 55 | "import time" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "Set the random seeds for reproducibility." 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 2, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "SEED = 1234\n", 72 | "\n", 73 | "random.seed(SEED)\n", 74 | "np.random.seed(SEED)\n", 75 | "torch.manual_seed(SEED)\n", 76 | "torch.cuda.manual_seed(SEED)\n", 77 | "torch.backends.cudnn.deterministic = True" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "Load the German and English spaCy models." 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 3, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "spacy_de = spacy.load('de_core_news_sm')\n", 94 | "spacy_en = spacy.load('en_core_web_sm')" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "We create the tokenizers."
102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 4, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "def tokenize_de(text):\n", 111 | " \"\"\"\n", 112 | " Tokenizes German text from a string into a list of strings\n", 113 | " \"\"\"\n", 114 | " return [tok.text for tok in spacy_de.tokenizer(text)]\n", 115 | "\n", 116 | "def tokenize_en(text):\n", 117 | " \"\"\"\n", 118 | " Tokenizes English text from a string into a list of strings\n", 119 | " \"\"\"\n", 120 | " return [tok.text for tok in spacy_en.tokenizer(text)]" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "The fields remain the same as before." 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 5, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "name": "stderr", 137 | "output_type": "stream", 138 | "text": [ 139 | "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 140 | " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" 141 | ] 142 | } 143 | ], 144 | "source": [ 145 | "SRC = Field(tokenize = tokenize_de, \n", 146 | " init_token = '<sos>', \n", 147 | " eos_token = '<eos>', \n", 148 | " lower = True)\n", 149 | "\n", 150 | "TRG = Field(tokenize = tokenize_en, \n", 151 | " init_token = '<sos>', \n", 152 | " eos_token = '<eos>', \n", 153 | " lower = True)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "Load the data."
161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 6, 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "name": "stderr", 170 | "output_type": "stream", 171 | "text": [ 172 | "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 173 | " warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n" 174 | ] 175 | } 176 | ], 177 | "source": [ 178 | "train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), \n", 179 | " fields = (SRC, TRG))" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "Build the vocabulary." 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 7, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "SRC.build_vocab(train_data, min_freq = 2)\n", 196 | "TRG.build_vocab(train_data, min_freq = 2)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "Define the device." 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 8, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": {}, 218 | "source": [ 219 | "Create the iterators." 
220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 9, 225 | "metadata": {}, 226 | "outputs": [ 227 | { 228 | "name": "stderr", 229 | "output_type": "stream", 230 | "text": [ 231 | "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 232 | " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "BATCH_SIZE = 128\n", 238 | "\n", 239 | "train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n", 240 | " (train_data, valid_data, test_data), \n", 241 | " batch_size = BATCH_SIZE,\n", 242 | " device = device)" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [ 249 | "## Building the Seq2Seq Model\n", 250 | "\n", 251 | "### Encoder\n", 252 | "\n", 253 | "First, we'll build the encoder. Similar to the previous model, we only use a single-layer GRU; however, we now use a *bidirectional RNN*. With a bidirectional RNN, we have two RNNs in each layer: a *forward RNN* going over the embedded sentence from left to right (shown below in green), and a *backward RNN* going over the embedded sentence from right to left (teal). All we need to do in code is set `bidirectional = True` and then pass the embedded sentence to the RNN as before.
\n", 254 | "\n", 255 | "![](assets/seq2seq8.png)\n", 256 | "\n", 257 | "We now have:\n", 258 | "\n", 259 | "$$\\begin{align*}\n", 260 | "h_t^\\rightarrow &= \\text{EncoderGRU}^\\rightarrow(e(x_t^\\rightarrow),h_{t-1}^\\rightarrow)\\\\\n", 261 | "h_t^\\leftarrow &= \\text{EncoderGRU}^\\leftarrow(e(x_t^\\leftarrow),h_{t-1}^\\leftarrow)\n", 262 | "\\end{align*}$$\n", 263 | "\n", 264 | "Where $x_0^\\rightarrow = \\text{}, x_1^\\rightarrow = \\text{guten}$ and $x_0^\\leftarrow = \\text{}, x_1^\\leftarrow = \\text{morgen}$.\n", 265 | "\n", 266 | "As before, we only pass an input (`embedded`) to the RNN, which tells PyTorch to initialize both the forward and backward initial hidden states ($h_0^\\rightarrow$ and $h_0^\\leftarrow$, respectively) to a tensor of all zeros. We'll also get two context vectors, one from the forward RNN after it has seen the final word in the sentence, $z^\\rightarrow=h_T^\\rightarrow$, and one from the backward RNN after it has seen the first word in the sentence, $z^\\leftarrow=h_T^\\leftarrow$.\n", 267 | "\n", 268 | "The RNN returns `outputs` and `hidden`. \n", 269 | "\n", 270 | "`outputs` is of size **[src len, batch size, hid dim * num directions]**, where the first `hid_dim` elements in the third axis are the hidden states from the top layer forward RNN, and the last `hid_dim` elements are the hidden states from the top layer backward RNN. We can think of the third axis as being the forward and backward hidden states concatenated together, i.e. $h_1 = [h_1^\\rightarrow; h_{T}^\\leftarrow]$, $h_2 = [h_2^\\rightarrow; h_{T-1}^\\leftarrow]$, and we can denote all encoder hidden states (forward and backward concatenated together) as $H=\\{ h_1, h_2, ..., h_T\\}$.\n", 271 | "\n", 272 | "`hidden` is of size **[n layers * num directions, batch size, hid dim]**, where **[-2, :, :]** gives the top layer forward RNN hidden state after the final time-step (i.e. 
after it has seen the last word in the sentence) and **[-1, :, :]** gives the top layer backward RNN hidden state after the final time-step (i.e. after it has seen the first word in the sentence).\n", 273 | "\n", 274 | "As the decoder is not bidirectional, it only needs a single context vector, $z$, to use as its initial hidden state, $s_0$, but we currently have two: a forward one and a backward one ($z^\\rightarrow=h_T^\\rightarrow$ and $z^\\leftarrow=h_T^\\leftarrow$, respectively). We solve this by concatenating the two context vectors together, passing them through a linear layer, $g$, and applying the $\\tanh$ activation function. \n", 275 | "\n", 276 | "$$z=\\tanh(g(h_T^\\rightarrow, h_T^\\leftarrow)) = \\tanh(g(z^\\rightarrow, z^\\leftarrow)) = s_0$$\n", 277 | "\n", 278 | "**Note**: this is actually a deviation from the paper. In the paper, they instead feed only the first backward RNN hidden state through a linear layer to get the context vector/decoder initial hidden state. This doesn't seem to make sense to me, so we have changed it.\n", 279 | "\n", 280 | "As we want our model to look back over the whole of the source sentence, we return `outputs`, the stacked forward and backward hidden states for every token in the source sentence. We also return `hidden`, which acts as our initial hidden state in the decoder." 
281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 10, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "class Encoder(nn.Module):\n", 290 | " def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):\n", 291 | " super().__init__()\n", 292 | " \n", 293 | " self.embedding = nn.Embedding(input_dim, emb_dim)\n", 294 | " \n", 295 | " self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)\n", 296 | " \n", 297 | " self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)\n", 298 | " \n", 299 | " self.dropout = nn.Dropout(dropout)\n", 300 | " \n", 301 | " def forward(self, src):\n", 302 | " \n", 303 | " #src = [src len, batch size]\n", 304 | " \n", 305 | " embedded = self.dropout(self.embedding(src))\n", 306 | " \n", 307 | " #embedded = [src len, batch size, emb dim]\n", 308 | " \n", 309 | " outputs, hidden = self.rnn(embedded)\n", 310 | " \n", 311 | " #outputs = [src len, batch size, hid dim * num directions]\n", 312 | " #hidden = [n layers * num directions, batch size, hid dim]\n", 313 | " \n", 314 | " #hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]\n", 315 | " #outputs are always from the last layer\n", 316 | " \n", 317 | " #hidden [-2, :, : ] is the last of the forwards RNN \n", 318 | " #hidden [-1, :, : ] is the last of the backwards RNN\n", 319 | " \n", 320 | " #initial decoder hidden is final hidden state of the forwards and backwards \n", 321 | " # encoder RNNs fed through a linear layer\n", 322 | " hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))\n", 323 | " \n", 324 | " #outputs = [src len, batch size, enc hid dim * 2]\n", 325 | " #hidden = [batch size, dec hid dim]\n", 326 | " \n", 327 | " return outputs, hidden" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "metadata": {}, 333 | "source": [ 334 | "### Attention\n", 335 | "\n", 336 | "Next up is the attention layer. 
This will take in the previous hidden state of the decoder, $s_{t-1}$, and all of the stacked forward and backward hidden states from the encoder, $H$. The layer will output an attention vector, $a_t$, that is the length of the source sentence; each element is between 0 and 1, and the entire vector sums to 1.\n", 337 | "\n", 338 | "Intuitively, this layer takes what we have decoded so far, $s_{t-1}$, and all of what we have encoded, $H$, to produce a vector, $a_t$, that represents which words in the source sentence we should pay the most attention to in order to correctly predict the next word to decode, $\\hat{y}_{t+1}$. \n", 339 | "\n", 340 | "First, we calculate the *energy* between the previous decoder hidden state and the encoder hidden states. As our encoder hidden states are a sequence of $T$ tensors, and our previous decoder hidden state is a single tensor, the first thing we do is `repeat` the previous decoder hidden state $T$ times. We then calculate the energy, $E_t$, between them by concatenating them together and passing them through a linear layer (`attn`) and a $\\tanh$ activation function. \n", 341 | "\n", 342 | "$$E_t = \\tanh(\\text{attn}(s_{t-1}, H))$$ \n", 343 | "\n", 344 | "This can be thought of as calculating how well each encoder hidden state \"matches\" the previous decoder hidden state.\n", 345 | "\n", 346 | "We currently have a **[dec hid dim, src len]** tensor for each example in the batch. We want this to be **[src len]** for each example in the batch, as the attention should be over the length of the source sentence. This is achieved by multiplying the `energy` by a **[1, dec hid dim]** tensor, $v$.\n", 347 | "\n", 348 | "$$\\hat{a}_t = v E_t$$\n", 349 | "\n", 350 | "We can think of $v$ as the weights for a weighted sum over the `dec hid dim` features of the energy at each source position, which gives a single score per source token. These scores tell us how much we should attend to each token in the source sequence. 
The parameters of $v$ are initialized randomly but learned with the rest of the model via backpropagation. Note how $v$ is not dependent on time, and the same $v$ is used for each time-step of the decoding. We implement $v$ as a linear layer without a bias.\n", 351 | "\n", 352 | "Finally, we ensure the attention vector fits the constraints of having all elements between 0 and 1 and the vector summing to 1 by passing it through a $\\text{softmax}$ layer.\n", 353 | "\n", 354 | "$$a_t = \\text{softmax}(\\hat{a}_t)$$\n", 355 | "\n", 356 | "This gives us the attention over the source sentence!\n", 357 | "\n", 358 | "Graphically, this looks something like the figure below. This is for calculating the very first attention vector, where $s_{t-1} = s_0 = z$. The green/teal blocks represent the hidden states from both the forward and backward RNNs, and the attention computation is all done within the pink block.\n", 359 | "\n", 360 | "![](assets/seq2seq9.png)" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 11, 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "class Attention(nn.Module):\n", 370 | " def __init__(self, enc_hid_dim, dec_hid_dim):\n", 371 | " super().__init__()\n", 372 | " \n", 373 | " self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)\n", 374 | " self.v = nn.Linear(dec_hid_dim, 1, bias = False)\n", 375 | " \n", 376 | " def forward(self, hidden, encoder_outputs):\n", 377 | " \n", 378 | " #hidden = [batch size, dec hid dim]\n", 379 | " #encoder_outputs = [src len, batch size, enc hid dim * 2]\n", 380 | " \n", 381 | " batch_size = encoder_outputs.shape[1]\n", 382 | " src_len = encoder_outputs.shape[0]\n", 383 | " \n", 384 | " #repeat decoder hidden state src_len times\n", 385 | " hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)\n", 386 | " \n", 387 | " encoder_outputs = encoder_outputs.permute(1, 0, 2)\n", 388 | " \n", 389 | " #hidden = [batch size, src len, dec hid dim]\n", 390 | " #encoder_outputs = [batch 
size, src len, enc hid dim * 2]\n", 391 | " \n", 392 | " energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) \n", 393 | " \n", 394 | " #energy = [batch size, src len, dec hid dim]\n", 395 | "\n", 396 | " attention = self.v(energy).squeeze(2)\n", 397 | " \n", 398 | " #attention= [batch size, src len]\n", 399 | " \n", 400 | " return F.softmax(attention, dim=1)" 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "metadata": {}, 406 | "source": [ 407 | "### Decoder\n", 408 | "\n", 409 | "Next up is the decoder. \n", 410 | "\n", 411 | "The decoder contains the attention layer, `attention`, which takes the previous hidden state, $s_{t-1}$, all of the encoder hidden states, $H$, and returns the attention vector, $a_t$.\n", 412 | "\n", 413 | "We then use this attention vector to create a weighted source vector, $w_t$, denoted by `weighted`, which is a weighted sum of the encoder hidden states, $H$, using $a_t$ as the weights.\n", 414 | "\n", 415 | "$$w_t = a_t H$$\n", 416 | "\n", 417 | "The embedded input word, $d(y_t)$, the weighted source vector, $w_t$, and the previous decoder hidden state, $s_{t-1}$, are then all passed into the decoder RNN, with $d(y_t)$ and $w_t$ being concatenated together.\n", 418 | "\n", 419 | "$$s_t = \\text{DecoderGRU}(d(y_t), w_t, s_{t-1})$$\n", 420 | "\n", 421 | "We then pass $d(y_t)$, $w_t$ and $s_t$ through the linear layer, $f$, to make a prediction of the next word in the target sentence, $\\hat{y}_{t+1}$. 
This is done by concatenating them all together.\n", 422 | "\n", 423 | "$$\\hat{y}_{t+1} = f(d(y_t), w_t, s_t)$$\n", 424 | "\n", 425 | "The image below shows decoding the first word in an example translation.\n", 426 | "\n", 427 | "![](assets/seq2seq10.png)\n", 428 | "\n", 429 | "The green/teal blocks show the forward/backward encoder RNNs which output $H$, the red block shows the context vector, $z = h_T = \\tanh(g(h^\\rightarrow_T,h^\\leftarrow_T)) = \\tanh(g(z^\\rightarrow, z^\\leftarrow)) = s_0$, the blue block shows the decoder RNN which outputs $s_t$, the purple block shows the linear layer, $f$, which outputs $\\hat{y}_{t+1}$ and the orange block shows the calculation of the weighted sum over $H$ by $a_t$ and outputs $w_t$. Not shown is the calculation of $a_t$." 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 12, 435 | "metadata": {}, 436 | "outputs": [], 437 | "source": [ 438 | "class Decoder(nn.Module):\n", 439 | " def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):\n", 440 | " super().__init__()\n", 441 | "\n", 442 | " self.output_dim = output_dim\n", 443 | " self.attention = attention\n", 444 | " \n", 445 | " self.embedding = nn.Embedding(output_dim, emb_dim)\n", 446 | " \n", 447 | " self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)\n", 448 | " \n", 449 | " self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)\n", 450 | " \n", 451 | " self.dropout = nn.Dropout(dropout)\n", 452 | " \n", 453 | " def forward(self, input, hidden, encoder_outputs):\n", 454 | " \n", 455 | " #input = [batch size]\n", 456 | " #hidden = [batch size, dec hid dim]\n", 457 | " #encoder_outputs = [src len, batch size, enc hid dim * 2]\n", 458 | " \n", 459 | " input = input.unsqueeze(0)\n", 460 | " \n", 461 | " #input = [1, batch size]\n", 462 | " \n", 463 | " embedded = self.dropout(self.embedding(input))\n", 464 | " \n", 465 | " #embedded = [1, batch size, emb dim]\n", 466 | " \n", 
467 | " a = self.attention(hidden, encoder_outputs)\n", 468 | " \n", 469 | " #a = [batch size, src len]\n", 470 | " \n", 471 | " a = a.unsqueeze(1)\n", 472 | " \n", 473 | " #a = [batch size, 1, src len]\n", 474 | " \n", 475 | " encoder_outputs = encoder_outputs.permute(1, 0, 2)\n", 476 | " \n", 477 | " #encoder_outputs = [batch size, src len, enc hid dim * 2]\n", 478 | " \n", 479 | " weighted = torch.bmm(a, encoder_outputs)\n", 480 | " \n", 481 | " #weighted = [batch size, 1, enc hid dim * 2]\n", 482 | " \n", 483 | " weighted = weighted.permute(1, 0, 2)\n", 484 | " \n", 485 | " #weighted = [1, batch size, enc hid dim * 2]\n", 486 | " \n", 487 | " rnn_input = torch.cat((embedded, weighted), dim = 2)\n", 488 | " \n", 489 | " #rnn_input = [1, batch size, (enc hid dim * 2) + emb dim]\n", 490 | " \n", 491 | " output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))\n", 492 | " \n", 493 | " #output = [seq len, batch size, dec hid dim * n directions]\n", 494 | " #hidden = [n layers * n directions, batch size, dec hid dim]\n", 495 | " \n", 496 | " #seq len, n layers and n directions will always be 1 in this decoder, therefore:\n", 497 | " #output = [1, batch size, dec hid dim]\n", 498 | " #hidden = [1, batch size, dec hid dim]\n", 499 | " #this also means that output == hidden\n", 500 | " assert (output == hidden).all()\n", 501 | " \n", 502 | " embedded = embedded.squeeze(0)\n", 503 | " output = output.squeeze(0)\n", 504 | " weighted = weighted.squeeze(0)\n", 505 | " \n", 506 | " prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))\n", 507 | " \n", 508 | " #prediction = [batch size, output dim]\n", 509 | " \n", 510 | " return prediction, hidden.squeeze(0)" 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": {}, 516 | "source": [ 517 | "### Seq2Seq\n", 518 | "\n", 519 | "This is the first model where we don't have to have the encoder RNN and decoder RNN have the same hidden dimensions, however the encoder has to be bidirectional. 
This requirement can be removed by changing all occurrences of `enc_hid_dim * 2` to `enc_hid_dim * 2 if encoder_is_bidirectional else enc_hid_dim`. \n", 520 | "\n", 521 | "This seq2seq encapsulator is similar to the last two. The only difference is that the `encoder` returns both the final hidden state (which is the final hidden state from both the forward and backward encoder RNNs passed through a linear layer), to be used as the initial hidden state for the decoder, and every hidden state (which are the forward and backward hidden states stacked on top of each other). We also need to ensure that `hidden` and `encoder_outputs` are passed to the decoder. \n", 522 | "\n", 523 | "Briefly going over all of the steps:\n", 524 | "- the `outputs` tensor is created to hold all predictions, $\\hat{Y}$\n", 525 | "- the source sequence, $X$, is fed into the encoder to receive $z$ and $H$\n", 526 | "- the initial decoder hidden state is set to be the `context` vector, $s_0 = z = h_T$\n", 527 | "- we use a batch of `` tokens as the first `input`, $y_1$\n", 528 | "- we then decode within a loop:\n", 529 | " - inserting the input token $y_t$, previous hidden state, $s_{t-1}$, and all encoder outputs, $H$, into the decoder\n", 530 | " - receiving a prediction, $\\hat{y}_{t+1}$, and a new hidden state, $s_t$\n", 531 | " - we then decide if we are going to teacher force or not, setting the next input as appropriate" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": 13, 537 | "metadata": {}, 538 | "outputs": [], 539 | "source": [ 540 | "class Seq2Seq(nn.Module):\n", 541 | " def __init__(self, encoder, decoder, device):\n", 542 | " super().__init__()\n", 543 | " \n", 544 | " self.encoder = encoder\n", 545 | " self.decoder = decoder\n", 546 | " self.device = device\n", 547 | " \n", 548 | " def forward(self, src, trg, teacher_forcing_ratio = 0.5):\n", 549 | " \n", 550 | " #src = [src len, batch size]\n", 551 | " #trg = [trg len, batch size]\n", 552 | "
#teacher_forcing_ratio is probability to use teacher forcing\n", 553 | " #e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time\n", 554 | " \n", 555 | " batch_size = src.shape[1]\n", 556 | " trg_len = trg.shape[0]\n", 557 | " trg_vocab_size = self.decoder.output_dim\n", 558 | " \n", 559 | " #tensor to store decoder outputs\n", 560 | " outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)\n", 561 | " \n", 562 | " #encoder_outputs is all hidden states of the input sequence, back and forwards\n", 563 | " #hidden is the final forward and backward hidden states, passed through a linear layer\n", 564 | " encoder_outputs, hidden = self.encoder(src)\n", 565 | " \n", 566 | " #first input to the decoder is the tokens\n", 567 | " input = trg[0,:]\n", 568 | " \n", 569 | " for t in range(1, trg_len):\n", 570 | " \n", 571 | " #insert input token embedding, previous hidden state and all encoder hidden states\n", 572 | " #receive output tensor (predictions) and new hidden state\n", 573 | " output, hidden = self.decoder(input, hidden, encoder_outputs)\n", 574 | " \n", 575 | " #place predictions in a tensor holding predictions for each token\n", 576 | " outputs[t] = output\n", 577 | " \n", 578 | " #decide if we are going to use teacher forcing or not\n", 579 | " teacher_force = random.random() < teacher_forcing_ratio\n", 580 | " \n", 581 | " #get the highest predicted token from our predictions\n", 582 | " top1 = output.argmax(1) \n", 583 | " \n", 584 | " #if teacher forcing, use actual next token as next input\n", 585 | " #if not, use predicted token\n", 586 | " input = trg[t] if teacher_force else top1\n", 587 | "\n", 588 | " return outputs" 589 | ] 590 | }, 591 | { 592 | "cell_type": "markdown", 593 | "metadata": {}, 594 | "source": [ 595 | "## Training the Seq2Seq Model\n", 596 | "\n", 597 | "The rest of this tutorial is very similar to the previous one.\n", 598 | "\n", 599 | "We initialise our parameters, encoder, decoder and seq2seq 
model (placing it on the GPU if we have one). " 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": 14, 605 | "metadata": {}, 606 | "outputs": [], 607 | "source": [ 608 | "INPUT_DIM = len(SRC.vocab)\n", 609 | "OUTPUT_DIM = len(TRG.vocab)\n", 610 | "ENC_EMB_DIM = 256\n", 611 | "DEC_EMB_DIM = 256\n", 612 | "ENC_HID_DIM = 512\n", 613 | "DEC_HID_DIM = 512\n", 614 | "ENC_DROPOUT = 0.5\n", 615 | "DEC_DROPOUT = 0.5\n", 616 | "\n", 617 | "attn = Attention(ENC_HID_DIM, DEC_HID_DIM)\n", 618 | "enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)\n", 619 | "dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)\n", 620 | "\n", 621 | "model = Seq2Seq(enc, dec, device).to(device)" 622 | ] 623 | }, 624 | { 625 | "cell_type": "markdown", 626 | "metadata": {}, 627 | "source": [ 628 | "We use a simplified version of the weight initialization scheme used in the paper. Here, we will initialize all biases to zero and all weights from $\\mathcal{N}(0, 0.01)$." 
629 | ] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "execution_count": 15, 634 | "metadata": {}, 635 | "outputs": [ 636 | { 637 | "data": { 638 | "text/plain": [ 639 | "Seq2Seq(\n", 640 | " (encoder): Encoder(\n", 641 | " (embedding): Embedding(7853, 256)\n", 642 | " (rnn): GRU(256, 512, bidirectional=True)\n", 643 | " (fc): Linear(in_features=1024, out_features=512, bias=True)\n", 644 | " (dropout): Dropout(p=0.5, inplace=False)\n", 645 | " )\n", 646 | " (decoder): Decoder(\n", 647 | " (attention): Attention(\n", 648 | " (attn): Linear(in_features=1536, out_features=512, bias=True)\n", 649 | " (v): Linear(in_features=512, out_features=1, bias=False)\n", 650 | " )\n", 651 | " (embedding): Embedding(5893, 256)\n", 652 | " (rnn): GRU(1280, 512)\n", 653 | " (fc_out): Linear(in_features=1792, out_features=5893, bias=True)\n", 654 | " (dropout): Dropout(p=0.5, inplace=False)\n", 655 | " )\n", 656 | ")" 657 | ] 658 | }, 659 | "execution_count": 15, 660 | "metadata": {}, 661 | "output_type": "execute_result" 662 | } 663 | ], 664 | "source": [ 665 | "def init_weights(m):\n", 666 | " for name, param in m.named_parameters():\n", 667 | " if 'weight' in name:\n", 668 | " nn.init.normal_(param.data, mean=0, std=0.01)\n", 669 | " else:\n", 670 | " nn.init.constant_(param.data, 0)\n", 671 | " \n", 672 | "model.apply(init_weights)" 673 | ] 674 | }, 675 | { 676 | "cell_type": "markdown", 677 | "metadata": {}, 678 | "source": [ 679 | "Calculate the number of parameters. This is an increase of almost 50% in the number of parameters over the last model. 
" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": 16, 685 | "metadata": {}, 686 | "outputs": [ 687 | { 688 | "name": "stdout", 689 | "output_type": "stream", 690 | "text": [ 691 | "The model has 20,518,405 trainable parameters\n" 692 | ] 693 | } 694 | ], 695 | "source": [ 696 | "def count_parameters(model):\n", 697 | " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 698 | "\n", 699 | "print(f'The model has {count_parameters(model):,} trainable parameters')" 700 | ] 701 | }, 702 | { 703 | "cell_type": "markdown", 704 | "metadata": {}, 705 | "source": [ 706 | "We create an optimizer." 707 | ] 708 | }, 709 | { 710 | "cell_type": "code", 711 | "execution_count": 17, 712 | "metadata": {}, 713 | "outputs": [], 714 | "source": [ 715 | "optimizer = optim.Adam(model.parameters())" 716 | ] 717 | }, 718 | { 719 | "cell_type": "markdown", 720 | "metadata": {}, 721 | "source": [ 722 | "We initialize the loss function." 723 | ] 724 | }, 725 | { 726 | "cell_type": "code", 727 | "execution_count": 18, 728 | "metadata": {}, 729 | "outputs": [], 730 | "source": [ 731 | "TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]\n", 732 | "\n", 733 | "criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)" 734 | ] 735 | }, 736 | { 737 | "cell_type": "markdown", 738 | "metadata": {}, 739 | "source": [ 740 | "We then create the training loop..." 
741 | ] 742 | }, 743 | { 744 | "cell_type": "code", 745 | "execution_count": 19, 746 | "metadata": {}, 747 | "outputs": [], 748 | "source": [ 749 | "def train(model, iterator, optimizer, criterion, clip):\n", 750 | " \n", 751 | " model.train()\n", 752 | " \n", 753 | " epoch_loss = 0\n", 754 | " \n", 755 | " for i, batch in enumerate(iterator):\n", 756 | " \n", 757 | " src = batch.src\n", 758 | " trg = batch.trg\n", 759 | " \n", 760 | " optimizer.zero_grad()\n", 761 | " \n", 762 | " output = model(src, trg)\n", 763 | " \n", 764 | " #trg = [trg len, batch size]\n", 765 | " #output = [trg len, batch size, output dim]\n", 766 | " \n", 767 | " output_dim = output.shape[-1]\n", 768 | " \n", 769 | " output = output[1:].view(-1, output_dim)\n", 770 | " trg = trg[1:].view(-1)\n", 771 | " \n", 772 | " #trg = [(trg len - 1) * batch size]\n", 773 | " #output = [(trg len - 1) * batch size, output dim]\n", 774 | " \n", 775 | " loss = criterion(output, trg)\n", 776 | " \n", 777 | " loss.backward()\n", 778 | " \n", 779 | " torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n", 780 | " \n", 781 | " optimizer.step()\n", 782 | " \n", 783 | " epoch_loss += loss.item()\n", 784 | " \n", 785 | " return epoch_loss / len(iterator)" 786 | ] 787 | }, 788 | { 789 | "cell_type": "markdown", 790 | "metadata": {}, 791 | "source": [ 792 | "...and the evaluation loop, remembering to set the model to `eval` mode and turn off teacher forcing." 
793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": 20, 798 | "metadata": {}, 799 | "outputs": [], 800 | "source": [ 801 | "def evaluate(model, iterator, criterion):\n", 802 | " \n", 803 | " model.eval()\n", 804 | " \n", 805 | " epoch_loss = 0\n", 806 | " \n", 807 | " with torch.no_grad():\n", 808 | " \n", 809 | " for i, batch in enumerate(iterator):\n", 810 | "\n", 811 | " src = batch.src\n", 812 | " trg = batch.trg\n", 813 | "\n", 814 | " output = model(src, trg, 0) #turn off teacher forcing\n", 815 | "\n", 816 | " #trg = [trg len, batch size]\n", 817 | " #output = [trg len, batch size, output dim]\n", 818 | "\n", 819 | " output_dim = output.shape[-1]\n", 820 | " \n", 821 | " output = output[1:].view(-1, output_dim)\n", 822 | " trg = trg[1:].view(-1)\n", 823 | "\n", 824 | " #trg = [(trg len - 1) * batch size]\n", 825 | " #output = [(trg len - 1) * batch size, output dim]\n", 826 | "\n", 827 | " loss = criterion(output, trg)\n", 828 | "\n", 829 | " epoch_loss += loss.item()\n", 830 | " \n", 831 | " return epoch_loss / len(iterator)" 832 | ] 833 | }, 834 | { 835 | "cell_type": "markdown", 836 | "metadata": {}, 837 | "source": [ 838 | "Finally, define a timing function." 839 | ] 840 | }, 841 | { 842 | "cell_type": "code", 843 | "execution_count": 21, 844 | "metadata": {}, 845 | "outputs": [], 846 | "source": [ 847 | "def epoch_time(start_time, end_time):\n", 848 | " elapsed_time = end_time - start_time\n", 849 | " elapsed_mins = int(elapsed_time / 60)\n", 850 | " elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n", 851 | " return elapsed_mins, elapsed_secs" 852 | ] 853 | }, 854 | { 855 | "cell_type": "markdown", 856 | "metadata": {}, 857 | "source": [ 858 | "Then, we train our model, saving the parameters that give us the best validation loss." 
859 | ] 860 | }, 861 | { 862 | "cell_type": "code", 863 | "execution_count": 22, 864 | "metadata": {}, 865 | "outputs": [ 866 | { 867 | "name": "stderr", 868 | "output_type": "stream", 869 | "text": [ 870 | "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n", 871 | " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n" 872 | ] 873 | }, 874 | { 875 | "name": "stdout", 876 | "output_type": "stream", 877 | "text": [ 878 | "Epoch: 01 | Time: 0m 55s\n", 879 | "\tTrain Loss: 5.018 | Train PPL: 151.167\n", 880 | "\t Val. Loss: 4.869 | Val. PPL: 130.233\n", 881 | "Epoch: 02 | Time: 0m 55s\n", 882 | "\tTrain Loss: 4.143 | Train PPL: 63.018\n", 883 | "\t Val. Loss: 4.677 | Val. PPL: 107.422\n", 884 | "Epoch: 03 | Time: 0m 55s\n", 885 | "\tTrain Loss: 3.490 | Train PPL: 32.780\n", 886 | "\t Val. Loss: 3.803 | Val. PPL: 44.853\n", 887 | "Epoch: 04 | Time: 0m 55s\n", 888 | "\tTrain Loss: 2.913 | Train PPL: 18.421\n", 889 | "\t Val. Loss: 3.438 | Val. PPL: 31.138\n", 890 | "Epoch: 05 | Time: 0m 55s\n", 891 | "\tTrain Loss: 2.505 | Train PPL: 12.238\n", 892 | "\t Val. Loss: 3.300 | Val. PPL: 27.115\n", 893 | "Epoch: 06 | Time: 0m 56s\n", 894 | "\tTrain Loss: 2.207 | Train PPL: 9.088\n", 895 | "\t Val. Loss: 3.267 | Val. PPL: 26.227\n", 896 | "Epoch: 07 | Time: 0m 55s\n", 897 | "\tTrain Loss: 1.960 | Train PPL: 7.103\n", 898 | "\t Val. Loss: 3.220 | Val. PPL: 25.033\n", 899 | "Epoch: 08 | Time: 0m 55s\n", 900 | "\tTrain Loss: 1.745 | Train PPL: 5.725\n", 901 | "\t Val. Loss: 3.234 | Val. PPL: 25.376\n", 902 | "Epoch: 09 | Time: 0m 55s\n", 903 | "\tTrain Loss: 1.570 | Train PPL: 4.806\n", 904 | "\t Val. 
Loss: 3.249 | Val. PPL: 25.760\n", 905 | "Epoch: 10 | Time: 0m 55s\n", 906 | "\tTrain Loss: 1.461 | Train PPL: 4.311\n", 907 | "\t Val. Loss: 3.362 | Val. PPL: 28.854\n" 908 | ] 909 | } 910 | ], 911 | "source": [ 912 | "N_EPOCHS = 10\n", 913 | "CLIP = 1\n", 914 | "\n", 915 | "best_valid_loss = float('inf')\n", 916 | "\n", 917 | "for epoch in range(N_EPOCHS):\n", 918 | " \n", 919 | " start_time = time.time()\n", 920 | " \n", 921 | " train_loss = train(model, train_iterator, optimizer, criterion, CLIP)\n", 922 | " valid_loss = evaluate(model, valid_iterator, criterion)\n", 923 | " \n", 924 | " end_time = time.time()\n", 925 | " \n", 926 | " epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n", 927 | " \n", 928 | " if valid_loss < best_valid_loss:\n", 929 | " best_valid_loss = valid_loss\n", 930 | " torch.save(model.state_dict(), 'tut3-model.pt')\n", 931 | " \n", 932 | " print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')\n", 933 | " print(f'\\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')\n", 934 | " print(f'\\t Val. Loss: {valid_loss:.3f} | Val. PPL: {math.exp(valid_loss):7.3f}')" 935 | ] 936 | }, 937 | { 938 | "cell_type": "markdown", 939 | "metadata": {}, 940 | "source": [ 941 | "Finally, we test the model on the test set using these \"best\" parameters." 
942 | ] 943 | }, 944 | { 945 | "cell_type": "code", 946 | "execution_count": 23, 947 | "metadata": {}, 948 | "outputs": [ 949 | { 950 | "name": "stdout", 951 | "output_type": "stream", 952 | "text": [ 953 | "| Test Loss: 3.179 | Test PPL: 24.027 |\n" 954 | ] 955 | } 956 | ], 957 | "source": [ 958 | "model.load_state_dict(torch.load('tut3-model.pt'))\n", 959 | "\n", 960 | "test_loss = evaluate(model, test_iterator, criterion)\n", 961 | "\n", 962 | "print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')" 963 | ] 964 | }, 965 | { 966 | "cell_type": "markdown", 967 | "metadata": {}, 968 | "source": [ 969 | "We've improved on the previous model, but this came at the cost of doubling the training time.\n", 970 | "\n", 971 | "In the next notebook, we'll be using the same architecture but using a few tricks that are applicable to all RNN architectures - packed padded sequences and masking. We'll also implement code which will allow us to look at what words in the input the RNN is paying attention to when decoding the output." 
972 | ] 973 | } 974 | ], 975 | "metadata": { 976 | "kernelspec": { 977 | "display_name": "Python 3", 978 | "language": "python", 979 | "name": "python3" 980 | }, 981 | "language_info": { 982 | "codemirror_mode": { 983 | "name": "ipython", 984 | "version": 3 985 | }, 986 | "file_extension": ".py", 987 | "mimetype": "text/x-python", 988 | "name": "python", 989 | "nbconvert_exporter": "python", 990 | "pygments_lexer": "ipython3", 991 | "version": "3.8.5" 992 | } 993 | }, 994 | "nbformat": 4, 995 | "nbformat_minor": 2 996 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ben Trevett 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Seq2Seq 2 | 3 | ## Note: This repo only works with torchtext 0.9 or above, which requires PyTorch 1.8 or above. If you are using torchtext 0.8, please use [this](https://github.com/bentrevett/pytorch-seq2seq/tree/torchtext08) branch. 4 | 5 | This repo contains tutorials on understanding and implementing sequence-to-sequence (seq2seq) models using [PyTorch](https://github.com/pytorch/pytorch) 1.8, [torchtext](https://github.com/pytorch/text) 0.9 and [spaCy](https://spacy.io/) 3.0, with Python 3.8. 6 | 7 | **If you find any mistakes or disagree with any of the explanations, please do not hesitate to [submit an issue](https://github.com/bentrevett/pytorch-seq2seq/issues/new). I welcome any feedback, positive or negative!** 8 | 9 | ## Getting Started 10 | 11 | To install PyTorch, see the installation instructions on the [PyTorch website](https://pytorch.org). 12 | 13 | To install torchtext: 14 | 15 | ``` bash 16 | pip install torchtext 17 | ``` 18 | 19 | We'll also make use of spaCy to tokenize our data.
To install spaCy, follow the instructions [here](https://spacy.io/usage/), making sure to install both the English and German models with: 20 | 21 | ``` bash 22 | python -m spacy download en_core_web_sm 23 | python -m spacy download de_core_news_sm 24 | ``` 25 | 26 | ## Tutorials 27 | 28 | * 1 - [Sequence to Sequence Learning with Neural Networks](https://github.com/bentrevett/pytorch-seq2seq/blob/master/1%20-%20Sequence%20to%20Sequence%20Learning%20with%20Neural%20Networks.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-seq2seq/blob/master/1%20-%20Sequence%20to%20Sequence%20Learning%20with%20Neural%20Networks.ipynb) 29 | 30 | This first tutorial covers the workflow of a seq2seq project using PyTorch and torchtext. We'll cover the basics of seq2seq networks using encoder-decoder models, how to implement these models in PyTorch, and how to use torchtext to do all of the heavy lifting with regard to text processing. The model itself will be based on an implementation of [Sequence to Sequence Learning with Neural Networks](https://arxiv.org/abs/1409.3215), which uses multi-layer LSTMs. 31 | 32 | * 2 - [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://github.com/bentrevett/pytorch-seq2seq/blob/master/2%20-%20Learning%20Phrase%20Representations%20using%20RNN%20Encoder-Decoder%20for%20Statistical%20Machine%20Translation.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-seq2seq/blob/master/2%20-%20Learning%20Phrase%20Representations%20using%20RNN%20Encoder-Decoder%20for%20Statistical%20Machine%20Translation.ipynb) 33 | 34 | Now that we have the basic workflow covered, this tutorial will focus on improving our results.
Building on our knowledge of PyTorch and torchtext gained from the previous tutorial, we'll cover a second model, which helps with the information compression problem faced by encoder-decoder models. This model will be based on an implementation of [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078), which uses GRUs. 35 | 36 | * 3 - [Neural Machine Translation by Jointly Learning to Align and Translate](https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb) 37 | 38 | Next, we learn about attention by implementing [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473). This further alleviates the information compression problem by allowing the decoder to "look back" at the input sentence through context vectors, which are weighted sums of the encoder hidden states. The weights for this weighted sum are calculated via an attention mechanism, where the decoder learns to pay attention to the most relevant words in the input sentence.
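The context-vector idea in tutorial 3 can be sketched in a few lines: score each encoder hidden state against the current decoder hidden state, turn the scores into weights with a softmax, and take the weighted sum of the encoder states. This is a minimal, framework-free illustration under simplifying assumptions, not the code from the notebook. The helper names `softmax`, `dot`, and `attend` are hypothetical, and the scores use a plain dot product, whereas the Bahdanau et al. paper (and tutorial 3) compute them with a learned additive scoring function.

```python
import math
from typing import List, Tuple

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax: subtract the max score before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a: Vector, b: Vector) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def attend(decoder_hidden: Vector, encoder_states: List[Vector]) -> Tuple[Vector, Vector]:
    """Return (attention weights, context vector) for one decoding step.

    Assumption: scores are plain dot products. Bahdanau et al. (and tutorial 3)
    use a learned additive scoring function instead.
    """
    scores = [dot(decoder_hidden, h) for h in encoder_states]
    weights = softmax(scores)  # one weight per source position; they sum to 1
    dim = len(encoder_states[0])
    # Context vector: the attention-weighted sum of the encoder hidden states.
    context = [
        sum(w * h[d] for w, h in zip(weights, encoder_states))
        for d in range(dim)
    ]
    return weights, context


# Toy example: three source positions with 2-dimensional hidden states.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
decoder_hidden = [2.0, 0.0]
weights, context = attend(decoder_hidden, encoder_states)
print("weights:", [round(w, 3) for w in weights])
print("context:", [round(c, 3) for c in context])
```

In this toy example the first and third encoder states score equally against the decoder state, so they receive equal, larger weights and the context vector is pulled toward them. Replacing `dot` with a learned scoring function changes how the scores are produced, but the softmax-then-weighted-sum step stays the same.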
39 | 40 | * 4 - [Packed Padded Sequences, Masking, Inference and BLEU](https://github.com/bentrevett/pytorch-seq2seq/blob/master/4%20-%20Packed%20Padded%20Sequences%2C%20Masking%2C%20Inference%20and%20BLEU.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-seq2seq/blob/master/4%20-%20Packed%20Padded%20Sequences%2C%20Masking%2C%20Inference%20and%20BLEU.ipynb) 41 | 42 | In this notebook, we will improve the previous model architecture by adding *packed padded sequences* and *masking*. These are two methods commonly used in NLP. Packed padded sequences allow our RNN to process only the non-padded elements of our input sentence. Masking is used to force the model to ignore certain elements we do not want it to look at, such as padded elements when computing attention. Together, these give us a small performance boost. We also cover a very basic way of using the model for inference, which lets us translate any sentence we give to the model, and show how to view the attention values over the source sequence for those translations. Finally, we show how to calculate the BLEU metric from our translations. 43 | 44 | * 5 - [Convolutional Sequence to Sequence Learning](https://github.com/bentrevett/pytorch-seq2seq/blob/master/5%20-%20Convolutional%20Sequence%20to%20Sequence%20Learning.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-seq2seq/blob/master/5%20-%20Convolutional%20Sequence%20to%20Sequence%20Learning.ipynb) 45 | 46 | We finally move away from RNN-based models and implement a fully convolutional model. One of the downsides of RNNs is that they are sequential. That is, before a word is processed by the RNN, all previous words must also be processed. Convolutional models can be fully parallelized during training, which allows them to be trained much more quickly.
We will be implementing the [Convolutional Sequence to Sequence](https://arxiv.org/abs/1705.03122) model, which uses multiple convolutional layers in both the encoder and decoder, with an attention mechanism between them. 47 | 48 | * 6 - [Attention Is All You Need](https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb) 49 | 50 | Continuing with the non-RNN-based models, we implement the Transformer model from [Attention Is All You Need](https://arxiv.org/abs/1706.03762). This model is based solely on attention mechanisms and introduces Multi-Head Attention. The encoder and decoder are made of multiple layers, with each layer consisting of Multi-Head Attention and Positionwise Feedforward sublayers. This model is currently used in many state-of-the-art sequence-to-sequence and transfer learning tasks. 51 | 52 | ## References 53 | 54 | Here are some things I looked at while making these tutorials. Some of it may be out of date.
55 | 56 | - https://github.com/spro/practical-pytorch 57 | - https://github.com/keon/seq2seq 58 | - https://github.com/pengshuang/CNN-Seq2Seq 59 | - https://github.com/pytorch/fairseq 60 | - https://github.com/jadore801120/attention-is-all-you-need-pytorch 61 | - http://nlp.seas.harvard.edu/2018/04/03/attention.html 62 | - https://www.analyticsvidhya.com/blog/2019/06/understanding-transformers-nlp-state-of-the-art-models/ 63 | -------------------------------------------------------------------------------- /assets/convseq2seq0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/convseq2seq0.png -------------------------------------------------------------------------------- /assets/convseq2seq0.xml: -------------------------------------------------------------------------------- 1 | 7V1bk5u4Ev418+gpxJ3HTSbZ3dpsKrU5OZvkZQsbzZiNDT4Yz4zz6w+YyxipAdkWkhich4mNMeC+99et1o3xdv38a+Jvln/GAV7d6FrwfGPc3ei67tpa9l9+ZF8ccfTywEMSBsUh9HLgc/gTlwer03ZhgLeNE9M4XqXhpnlwEUcRXqSNY36SxE/N0+7jVfOuG/8BUwc+L/wVffTvMEiXxVFXd16O/4bDh2V1Z2R7xSdrvzq5/CXbpR/ET0eHjHc3xtskjtPi1fr5LV7lxKvoUnzvfcun9YMlOEpZvmB9C57jSPv56cP+a6I97L2/v36ZIWQV13n0V7vyJ5ePm+4rGiTxLgpwfhl0Y7x5WoYp/rzxF/mnTxnXs2PLdL0qP6Yfq3zSR5yk+PnoUPmYv+J4jdNkn51SfmpoJclKmako+PTCAMcrjy2PiG9XJ/ol0x/qS7/QJXtRkuYkMpnqkWnmEnRyaTp5OkAny+ZAp9j9/s/+LsEf9qu/8Jf40/b7n+9nlkVRBQeZPpVv4yRdxg9x5K/evRx980I3LXv3cs6HON6U1PoXp+m+NA7+Lo2btMTPYfo1//qtVb77Vl4sf333fPxmX72Jst979KX87bfqevmbl68d3lXfu4+j9L2/Dlf5gd/w6hGn4cLvYug23iUL3EG0yqz5yQNOO84z3eLEnKKd8pHglZ+Gj00Lxp3XlbwyaYTWrxH34Wr1Nl7FyeG7BkaBhZ3s+DZN4h/46BPPdgzf5mRqzKYKWYCtMQEVMofSIGQ5Z6jQYpc81panSfVFvA4X5et/d+tNdaGHA/0PRwq9ss9UvjMVqXaH+ZtC9j/hJMyoiJNcSMLoAVK5t5k2hTjJCPQRP5Ufls+GjIvU0GBUQ0NTSg0NSg3Tp5gSoUwd0ibfErwNf/rzwwk5lTdxGKWHh7Pe3Fh32ZGc19uStPnbVfgQZa8XGWUzBhhvci3LbN/ql/KDNJeYN9uCd/85iM/MPIV/Jh+dNu2mTtsam07ziB5gr2iPxSu2erdztfwCr+hKc4vlVz/lGnHkKtymWJlkFFUaseJbhMjUj3GB
Z3CHdbj39/f6YgE53MCe2xYvh+sq53DP0c6rw73c4erMHtdQyuPqtMtFlAhNyeFaqN/hQnnoYA63sszqO1yF0tAqru1VR0tXSh2r5x53IkoGrQr4xXOwnKtfvNwvGjqrX7TVUkSdUsQNjjcZkyftHG2r1zk6Qp0jHawo6hwVykYN1liVv3OEs1GLwP5NV2w2atBBMFevKyYbJamogNc1r15Xjte1WL2up5bXpQu4tCZOyeHahmrZKM0PRR2uStmozepw1QKHqucedzZqqYbSGjT4fY+jg5eYtK1zVUsujLHYOpWSC0+arYOTi7qRqkouHMHJhTesERWTXJBUlG5ETbplh1bXKZlPx1ItVDwn+5t6qGiy4qVVbq1IqFg997hDRUc1CMVkQKYEKNXpgj60V3dIf8TYwNJ7odpsVhcqVHaw8MAeTZexQkGmyYpv8beSLVJEtN6K7qcyT1r5oGqQSVJRvvmV0+woLBzp70e0BOmP5d26xsu/FiSAv03+9MbUnkNsLn9/a38097///O/TfiYV40THBrk2z72BK9KbsnKraX3ycnhXFyrKg4G/XdbllwssNGvHq+uIkTATaRlJjv7ZDREzDDaD/UuS+Puj08r8seu2oGN4kdjiiufK7/v73R+b8MePH4/7BP3nV+0f44+grmeqH1OAvB9Y4hz3QgG7zKVw73xus/oUcnG6T39I/CDMvn702Tt0Z71zjj67CxO8SMM4v9UT3qZ8QgGDaCEzTXGhAKxTo0EzOOpU1QDSr1OeTJ2qHvOqU506ZZKLPqTr1GhyX546xdp95WowNwXpFPeeqFepUzbZuiZdp0azPo+nTpmsOiUVJLdOGrwwWZ0iIVmROgVjtHQfDF7hdfYLn8Itzvm3W1OM5Fh7XOH7/FJnVR6PpeK46mhX78snbtc+dsZ5ehO3MS2acQgBnBusBunQkWGayXuU3ysK8t+b0fsgwLq9ymk8TxpctP+3y0faHCg1KzjzS3YCsjfPLx9mrx7y//F6joNA6TYesdLQj+YiS6RvRFOMNx1WXMQTlcPBvJli3HIxFAWDjobj3bqWg0y9+Os19NAgre3AtSlHHcTLXWA46pm7Vu6sBoGu6veyUgI0RejKYe1DRJqoPBtmzjkjkl4Hc/jXeDyl7B73xk+17R4JL8q3e1Msg1VMYLB7orAQmDnuZJnD3e7ZSCW7V0ngVOweCQHLt3s0vRVVrcs6B/moJGsrFNJ0qfbSGwtTuTOHf7eZ0bSXjlx7Sa9kedX2koT3RdpLmAFQ82wBBwfhI4gH54DurCRnDgjXFL3Jx7LTWPEBJCbg4uoe240fMYPO0IVWYYT9HLde+ft8jVFx2YwUxZWbd8sOH35U8+gofidez7OPg3Cd/Z1lB/11LvjlPbRlxuPy05MIcC3bkEA9pJAIWjpG9kxy00iPLpR+PDzvYW+Ex9ub/Gkq8Ziv4sWP7bG4XFlKVuJ0D2DpUJU4MH6ZIrzvsWOQUvvnPGWQqsDCbmBCEYirzw2b19xeAqHXGYdqD5axTREDRhozUoWkdsLVDzoV9SCBXOnqMUWosIRnmdRDalNb/aBTUQ8S75OuHqNBhhSA+0oUj0mtROF93Q86FbUiYSGRatXCAagFfkR4yVRwoQPy0/njLr5FjTZF8+3mQBcYhDqCqK4g1EUgFKT+g4FQoFudYrNoCUQ03SO4Rt6Q6xyVQSzm/sINDMg56oZpWsEwiEWtCbJizim2FNaAaa9yWFKVAymDV4hRDhKvkK4cU+w7K1EIBuVw5CqHMmiFGOUg0QrpykGTewrKAaAOoHJIrQPVjzkV5SAxB5HK0cIBB+DAWbk4Oxxwed5/YqtGvEupLPmaDnemw5BkDrZ4EhbN6gnUt9wKzfNFzKP2kGJ7cddPPu6Rvq5qg8tRRcbRzfRlEHZBM/1dclCo6O2rgW1zb4d0W8pPn3dV26gI6aOBqBQarIx01snKA7irFl0nJwEK1/VXMVuZJKMCfvCVD1dmUDaH
Vdl07nOYL+QdnabS5nZS/s+R5/9gDHo0PaUqdAUx7xePio2hpSke/w3iX+UIOTJBkD6WcTzgCVe1ulRZWkJCS6X1mwjYK16WTgpZwEkql/wF79NULtahp6iYCSvPZykz9lRMJyupH9IbxBFF7imoB7DpMlhyLcajydMO7pssq11zJbVDfs3VOmd90WKXPNb75DT5s4jX4aJ8/e9uvaku9HDg1OFIoTS2WAhju/SD+Kl8UwZF1e4/n4sM+gYAO7qyZ+MyAMTyQI3nHjAatnnraaZjo+Kv25BA4dWD6mc3i/xFjbzoV4+3ZNF8wtiKYfZjK1ApfLjagn1OwHm1GBwsho3EWAxTU8piVD/7yGLQIjgpm+C1bBEoAG+Nvm0/Br99NXbR5+9vfGv5JVjp49kzkG2Dv6H1y0LerXOE4RgNbpoW8anRvH7L/o/0XYiqNCs0dOouf+R9iF3+mJ+L166AoIRK7Uofn4TahggJtS0xEkrep09C255rUAmV2rExPgl1SBkcREIdR4yEkvfpk9C25xpUQqWuCR2fhLqudqvXZRrXakbRuYTqHCTU9cRIKHkfs8zjT30uXhLavi63ueHRU0xJ7aRSBbLziQ1v5JEptC8ObXBog+NNRvgpM8m0rT4mOSKZRPc33eNoofTeXgK4ZLlqcYnGUafdhe0YvaZuKFCko7TWAXTjK9Dd0UQPsM8eiH3Wt+A5jrSfnz7svybaw977++uX8dSN5TT1gjQrccDjqjN4nrzVXl2PPe61XjNb4mIvmMtnaNC1TnRzSp0IpDvQGgVbOLW0kHaXP59wOGmfOLOIARZQpWeo/A1k0mjmgslZVwbSDOjFEuQVYWhmZhEzgwZcVdZFkHGvKaOoKN3fnrPo5epvL/e3wB7TsL81lfK39M7EdMozKWdrWv3OdigEAeTQq59tPUAKCo0phL2trpQ28p9bKCUJJQNW6U4RnYPjXL3i5V4RmvoGn+iqpYg0GrTG0XaxxNG03SNL1+FQBRCYU6MZNapQMopY0SH+/rElGzXJGSeu2GwUrNOMLx0lySjf856DFV09LwfPywo48e+EukxgaFyI1sRpuVyJg8VgFr36wSpDpKSsAJGhFkCEaISIq2MUlJIaquG01W72jf6qxTKdeoKh2/0JxlBLHWH/OJrRGgolGNDMX0HWriXBIJfQmo7YBGPoWcKCEgySjNLtaEXGI7rSiMC0LChSLF4EGKKoBVUoXtSZIRpHqXhRHxhIERQv6qoBKSwzeJWcmT64Z9dJl8TYyNJ/pdpy9qzH4hYjyJkHPPJQExghLMhQtokR0YMrurMKGE08xlCTJKN8E3zOlKwRjevuh4qFaZCm3SL75Z+ly7XLo+nBUcguQwOl4ZZvUYGCRa5eHm5MbidF5M85VHpy9YzcsFHk6GqYcddlYGfov86q/9y3TLqM2TSgdNVSqNWAXD8gXUuvMNMZWsoKM5nS9l/pfO6rlnZrKTn6Q76WjmbxmEpayopxmGqBwYZUQOuUaVHcmc02foof5+1hcvNTJ0S9TAYmwK9W0JX8AnJ7vkDiQsQXLh4q1cmHq7/p9jckhC7f39D9S3iF19lPfAq3OGfgbk1xcoLbbc8cYrKRCcxeQQhg3WB1Y4NukUoziY/ye0VB/oMzgh9EuNxEfZ402FhtyZ6TalawJt+xHdmbZ3C/dhwESo+3EisOBFgGoe+Dbb8ObyA2murYZWDrqZu/dIKqvXGDZQkKEEGemqPJ2bnzZgB8m0BkBtwGrlPmFIiRhOwCRwHV0reB0199cs1TDQ1WEylq8zjYRNJKpChPufOGv4kk18cJN5H0cOLXbSLJKoF8Eyl1xL4sdbJYTZ0j1dSNps7KnTf8TR3ZqS/c1DE0rL4qU0eWWuSbutEkVypEg6xAe7GFgzQTOZrVaNx5M0CboS7ZRKpTVBBjIsnqgEgTCXMAWt1cgMlB+AiiyTkcPCvpmcPJNUmz1xDSfICYCbC5usd240fMkDV0oVUYYT9H
vVf+Pl9WVlw2I0Vx5ebdssOHH9U8Oorfidfz7OMgXGd/Z+S8fG2Z8bj89CQCXKs+FMwPaWS9P/WxShpDqWRVUzhSyY+HB85EMI4eb2/yp6nkY76KFz+2x/Jy5SlVydM9gKdDVfLgSHQ0E0M4Ri0Wa9Jtm4Iiys7HlB+EBBZ2AxMKQlx9btjctkggUHudcZz6cHmanHVFkrWDdT6VLapA2fmYk9EOErCVrx2jWXzFUzuY2/5E1aY6H3My2kFifNK1w5giHmSxzlKzRZUzOh9zMtpBwjsitQPmAN1BOCrYYyrwzgHA6fxxF9+iBo2i+XZzoAuMJR0hTVcs6TIsCVJ/sViSrcwqscDH7j0IsNsLF8/vB8pta3LLssC2OuiCGA6Q+ZN8DqiTwYrhABmjy+eAOlmSGA6QcaB8DkhF2NRZh9g7Mf94vH6rJPBbVF7ZJtmLFvWWvRDb1yCSVr5v0SLZjSVk0aItFToTIfYN6W7Vgc79GBo7N1wg86zjjqqYVLbMW/qJMk9+oVfmyfUoQmS+Gus8dZlvlWTZcmeS05365I78Qq/ckU3+nXIHfJ1cwmxrt8bxP6dJkpbhcwzP1XPhgbvPHHq5Vo5RBDs//+oijqJygfgV28h7KgjeGXRICy6HHWx1tDOahQwKDVJxWHsyXLV2YXG4wyhSpmoj1XYPcBjAESWnavdLuqhZrYgIC1nHb3NzYzS6cjuky1J/Awik2oZhzmganBSan+t40lxVm6a3RO7CNB2q5nJ0goLmmpNklO4EXTkpszpRocs6LKWSQEWiQpfO4OikYFKuz1PM87lSwsmRJ2kua/nA5V4+uIzZUtfbqlNlEsD5Ss/lg5lgRNSBZTbO74UyyVhBCITuKtM1o/asSzLzFTnrEm6Vf/XeBtT6M20N6wIUV+qC/2oIpvo85c6b4TNY0Qv+XWV6sQQt+CdNpPSZKOPZyoWnOrGulyjCKmnx89TWS5DaIX29RBVWK8ABMZ2KJAekdyp6o7FPCgEGns5o35DGvQnmMm7rvPVNzmbJigHaHsOUzXFWdb2haz2SirgeBDmUK7MOy61wvCWXak0Z59abE14gnFtoM5I3msHTClV4PVYEYgC3BSfIkjeu9rinxzIKvMp5QznDlxQKD5nTX7XqSR6dD9NmdkpuT7W+Jm802KxC2lg7Mwa3p9bOwDV6fC3wimC+Ghsanrgy8MR1gTK2Mqx5IB/lU7O+q0w599VjgaCGn2tZWLfzQZrULc+M0TRD8+cO94QVebeu5SBTL/96DXNqZB97munYqPjril1gV4ukAoZ22HKvhOouTHEkddDweRHascHUbh2rx2Y2IrH8zSc/zXgeHY7omn5hfAYskIvd7//s7xL8Yb/6C3+JP22///l+5qoRnpG7R1cRVVvARZ5vIL4RF2zypZQfpJt81rWWZagvL61TZ0bVsH0HqrQZIERv5auUEdY5GOF+Db9Et4D5wCqbaZ3sYzb78mLCrHtCEuOpDOpSpdsFaeckQ4td8lgrX5Mfx/OreoddCaxqNIYMdY4jYsbUjQstCPelbLAmzyzPutU8K3NhxV/UUGzRjQb1727tNNheOw0aE0X1/qKL0FaDeueaq80QbTOQoDaEmWUpZTOq331kM2gZnJRVsOz+BiShpViEuI16pzefixZxcCBoZirmYZRZjusQ8BbJICf81NZa2pysOsgdWDKix6tcdMiFReBuQOpRz60RJBcQOs9HLqrr5FQEL0TsWmDqm2f6Ku8q6aqFqrjedSOCHu9kaoR3cgBhcwBhG2wjAqSDDc4KCdsdvgrbOcJmk5FQtZ+iCFkDqw1VtPpaqw0gzomqGksvIooMUdttdgKy8jHG+3tstzTfON5cu7QmXPVskntCuQB+YIsEHU3uY8TU5gAF4UAssF0+LMjeJnHuWF4S8Ixeyz8z35Kf8X8= 
-------------------------------------------------------------------------------- /assets/convseq2seq1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/convseq2seq1.png -------------------------------------------------------------------------------- /assets/convseq2seq1.xml: -------------------------------------------------------------------------------- 1 | 7V1Lk6M4Ev41dawKhMTr2F3VPXOYnujY3o3dmcsGBtlmBoMX43JV//qVMGCQsMEYJFGmD9VGiJcyv0zlQ6kH+Lx5+yVxt+tvsY/DB13z3x7gy4OuA93WyH+05f3YYkL72LBKAj/vdGr4EfzEeWN+3Wof+HhX65jGcZgG23qjF0cR9tJam5sk8aHebRmH9adu3RXmGn54bsi3/jvw0/Wx1datU/uvOFitiycD0zme2bhF5/xLdmvXjw+VJvjlAT4ncZwef23ennFIB68Yl+N1X8+cLV8swVHa5YLY/vO/7y8J/u09/Af+V/x99+e3r4+Gkb9c+l58MfbJAOSHcZKu41UcueGXU+vnJN5HPqa31cjRqc9vcbwljYA0/oXT9D2nprtPY9K0Tjdhfha/Bel/6OVPRn70R34z+vvlrXrwXhxEafJeuYge/lHcjx6cLsuOiuuWcZR+dTdBSBt+xeErTgPPJSf48cuHdBfvEw9fGLSCD91khdML/VDO6nREK0/IyfMLjjeYvCnpkODQTYPXOsu5Oeeuyn4n4pIfOX2voHV+31c33OdP4mhfp+xhHaT4x9bNBuNA8F2n4jIIw+c4jJPsWoiBb2CLtO/SJP4bV844pgVd89KQv+IkxW8Xx6g4WwiUXJ4glB8fTugsmtYVYBZtw4+qYfWAkLdPXrNxBhyevHgTePnvv/abbXGjVTb+WcsRV2ZP8PUEUim/6MGR97/jJCCjiBPKJEG0aoLcM0FTgBMyQL/jQ34yfzcAb4Ih7AhDqCkFQ8jB8OcBBxwPETykdcIleBf8dBdZBzrM2ziI0uztjM8PxgtpocTe5WNLD8NgFZHfHhlaQgH4mcKMCL/wU34ipSzzeXck3j8z/nlE1xAQDQNqqNVBbejdQK2PBWrDnIpaPKve+sL8BrVoS9OL+aXfKSJObFXyR85W0Gb4JZdix6sYlilf4wbVYI+rcZfLpe55TRrXNxemMZDGZUdRAY3bB52zxr1d4+qdVS5USuXqvM4FHAvdlcI12hWuLlLhmhpHD0UVrkJ2aDGxbYWjoSsFx+K9p22JspNWBfRiH2fOrBdv14tQ76oXTbWAqHNA3OBo561xdNfqEWlGq3q0hKpHfrqiqHpUyB6FXWerw6vHZnsUIkZjaGLtUchPgwfVu2LsUW4U5etdNOtdOXrX6Kp3HbX0rsEBkUfiXSlcSzV7lKeHogpXJXvU7Kpw1XIPFe89bXsUqeanhbz7e4m9dXrnxoUB240LQ6isg1ORdSoZF440WddsXCCzDn/oCDYunHGFqBjjgh1F6UIU8Vk7PFzvSnw6qk0V+1h/9z5VRF09poVtrchUsXjvaU8VDdVcKKiDZ0oAqK5n9LG1usHqo44pLK03Mlh5eITsaNMDczKJxgpNMlFX/9bwUvIMFzHZt6IzqhDvOJvgJJMdRfniV066o7DpSHtGoiEIP47zZMPTvzOeAEEy2ZoN/x4yuWuWqyUqyxUwMpkVE2PL5MG
zXM8Rm7NRr5feq8T1A3J55dwX8GJ8sSrnXoIEe2kQ00cd8C4dRujrTLoQNCULfVvKFHvidmuRKtCOfrUijsV7zyC9CFLIrhiQDtLJqGiVQNo18aeYh6sC0sHzcz4kSBFSDaSzB7gHSFFXkA7uarst2wtIFcmgSuuS8oKo7bu7dZmzJoD0AOhD0z679FOSuO+VDnl07KyJZbISx2ZKLbT0B9rl/qxDiOlPfhzfeFC7raDCrG0uahvWcS5d25h8thIO8YZ84SHYYUq//YYj5IAR4hAv6a16xYerXFGNDZvFcf7GDcVNrq8UUXeuQYunGwANhBstUGzxNlhK2D2iz4p8+rlkuDP+1c2QDvEiqRHR/N+e1orJBurxSJhPpAMwt2+nk+TXiv6PNwvs+zQjd0RWuC1ZQCgztHvcQVPS1XhzRnuCc0bNcK6ZR2RHZYb4eFNJq6tL1raUmkpa6nhObQ83a+CFbSDjose8vwu0LA4mS5U6019WdQNqBvdQNk9u2fU6HDlHjk8U33kvKGN9mPJRNplAwwgoG97F2IwyNnFVNMqK77wXlLFOSPkom0ykYAyUiUrHsyWjbPAajmqjjHW+SEeZ3ZSAdrTW/eC10Vyn9vZjPpzUXi9H9IGWo+VN+cyGZ6z54hnEoo46+wSabhQGEXapWyF032me/vG2ZCiOd64/jTRnH1VvncR34s2CnPaDDfn7SBrdDWX8/BnaOvCLs1cNwOxUA+14BE2rL+BogLSUkYhLbJ7xYlvOQhtIIrL1b0ubukIB0xY68eDjCL9n75sV5X59eqD8UOBzEcbe37sqXmdMMY7qEmPSHNXOnKt/fQqxI8idwopg0HdFCOGfJ2DpCEDLgUBjcpGJXHnSTNsxDSPrhOoPGTkz2ZlM/V2FktiLqLEIj566nNOnGvvMOYLsZ2QqzDlTjMjJ5xxdDOcYtsKco0yEwTew7aMmG8TWF9AcqiI4E8crLT5pFdg0ZbzPYijAxngUoIAynkkxFGD9/wpQQL8vCrC+YQUo0JS/PyGn6b04hzP378WPu/kRpcs5Wuy22bg0e6IrfurZE32d16wd/WI90UCb3WZ9DAhZfrOOxZev9JsRnpNoiQBtdpz14EGghONMNuvMnrM+rKOE50w268yusz6so4TrTDbrqOM7c7G9bIzfm56NF8thrFbWdya79FOp++6FAKzrTD4B1PGcCSEA6zmTTwB1HGdCCMDVlJROAKn1E9RZUd+610t1Y5izjDBceZSxJrfXLr83GJHdtvye7d+2/J5dOyFk+T0AUk11EVxfY+6zELi4kVBty6H+LN+5aN9Y7qhrWV4/U2j1rFPqcgWJ1q2xBbG8VA+pOix/lpElcx1iaxS2cB3bv43r2KWgF7mOv5qtxMFmPp6xVVtfQ/S28IDPWafhGn/v0mu9OIryOiYsWO4wzMPIKYOfqzbWbRgtNxqApkX7vYKf+lGnNcY/zwY/F41P6HqLvIXm5RMpVemhxft0u08r4caFvBDkeGVFBmBIh40jwQaWbLKfRmTJJh/a5FhyswiimSl7MqXWLiZNoTxZ3HkOCVxhpoDuwYNC6ihSyqZ882lvjGOqtv0XKIZxcjvjdGB2QTvjmex2G6Ln+zqfF/ikrqoSsIebqdp2v0CfTAlfhbbCAHrX/YlGUFfNWLfOeDDEYf1D7FDEDqMCevCDb1HUAWxWV7Dpg+9mdCPteH8XL27vSf9ZQDn9N2dw9YBkodU66D9dMUg6MumtTshdDPHNUSY/10Z3rDPxk3NhErZ/W3SHnTSIiSkWZJCfvaN0GXvWCJZexh7Aey6+O4JC6AZK0RULyw9VAJ9CShayQJNeshDAyVQGVSk7HwraUpkAVNnCFuUgyAevmDXlLHjlrymHk1lUqhR4x5lxXwleyUtrIL9RkCzwCslr5+I80rE7r6jsMzPunKuL1NrOpXzzacfALRZG8nHUYY3gRGPgxaCNrahsNqAjOi6Gmtw0eTGULJMLxzu2Oso9Rwi
s+lSiKUIgNu8VTd9dIyFCjnR5yuyMJDizmEqcJBjcISQjQs4Oo3wtieS4eRSaOaLO+woXPKjKzBHxfhbe7r8n/WfDOrzkR8jR7IjpA8nOSSto8IWXN9L7wy/PHT1C3p34xjipwNdGyO0rV92y/dsi5OykQUyEHCm0RY7CEXLWRJYfIUeTScoaRXUIKgzJglJ4hBx1cG99pAg5CzT5EfLCk6I+0FQKshmCSh/aCm9UVA6CfPCKiZCz4JUfITekBAMmD15Bi+8ug1dyhNxoKtgvB7xCIuRcFEg6dvvMcL198lqap3XyVAuktVZTE+jtrJWxuljwqrOjDd42uzZElV3X0JOjIcsEx792nQFFBx+Mprl2NQy5m8OQ1U0W7HY3rNgwpNlnqj5LjAEkhilouq8jpSSGyU/weRa8J5mgw/bUhIFCM+QwiWkxoBNFyZRr/S32Me3xfw== -------------------------------------------------------------------------------- /assets/convseq2seq2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/convseq2seq2.png -------------------------------------------------------------------------------- /assets/convseq2seq2.xml: -------------------------------------------------------------------------------- 1 | 7V1bl5s2EP41+7h7kLj6cTdpkrZpT85Je5o89WAb27QYXIz3kl9fYYNtJIxlY5gBtC9rhIwF880nzUXDnf5u+foxdleL36KpF9xRbfp6p7+/o1TX9BH7l7a87VqoYRm7lnnsT3dt5NDw1f/hZY1a1rrxp9660DGJoiDxV8XGSRSG3iQptLlxHL0Uu82ioPirK3fuCQ1fJ24gtv7lT5PFrtWh9qH9k+fPF/kvEyu746Wbd87uZL1wp9HLUZP+053+Lo6iZPdp+frOC9Knlz+X3fc+nDi7H1jshYnMFx614Ocn/fdv/4W/fHr9/PYrGf8I7qmzu8yzG2yyO85Gm7zlj+DZixOfPZHP7tgLvkRrP/GjkJ0aR0kSLe/0p7zDY+DP0xNJtGKti2QZsAPCPrJbX6UXW77OU5g8jN21P3lwJ5vE+zuJfTecp7/1tIWN9mCyj3G0Cafe9PD93aPT2MEkWvqT7PM6iaN/vXdREMXbserWxPHGM3Zm5gfBUfvU9ZwZ+9ZTtHInfpKi0UwvkN09uwHv9eRzJXtpMZx70dJL4jfWJfuCaWUCziBumtklXg54oVrWZ3GEFTNrczOIzveXPkiRfcgEeYlQbSXUmkI1HHRCtZRQ6wpVQydUUwm1plB1ik6ouhJqTaEyVcUmVKqEWlOoBN1CydQFGXpTtvzPDqM4WUTzKHSDnw6tR088fS6HPp+jVJpbMfzjJclbZsu4myQqCtl79ZNvR5+/74WZHr1/za68PXjLD0J2v98OHdPD7/k10oPD17ZH+fem7nqxR8fuZtM7rJYgeyDRJp54FU/uhKRjL3AT/7l4/TKxZV/9Evnslw9qb44eHP3wZxbwwqbvB+vozzaK10/ceO4l2SUPyHiMY/ftqNsq7bA+PQidWyVY2X18kOxvnOlPrcr+7MNuxAdY7x/w9Ugn59mriOuXhZ94XxkHpGdfGBkVMTyLwuSDu/SD9CY+ecGzlxIbO+Fm1DZhkPJikV9mzsSbTMoYaeyYRlPzh6WJVGOUMI3RGNMYimmuYxraC
NPoOgamcS5kGucyptENAKaRWCf1iWl48wOeaUzFNNcxjd4I0xgEAdMY1mVMw/c/xzQGBWAaCTO7T0zDe6/gmcZSTHMd0xjNMM0IAdOY/Mr7DHPw/c8xjVltbTXDNMbAmIZfl4Izja2Y5jqmMRthGtNGwDQWv/I+wxx8/7NMU21tNcM0EvGgPjENHzuHZxpHMc11TGM1wjQWBo+wfaFHmO9/jmksCI+wRDpBn5jGwuYRFhOvvjy+xyuCeexOffb1/FwYhSn1FWKMx8HHAs3cPpZoY3O7jbolzjZlxcd9wWVlEEEwLU7yh4n9+/G8ftNJ/obzOsnmhd1cWtGRjmquAGrJNB8m/Hy2z6A4nXNxA6USMmQotFJJCKAFLbsl8nVJ5OtauahaQj4aP3g7yBcS/sCRLyGAjiHfkEV+3TygeshH45dtB/lC/jI48iUE0DHkm7LIp6DIR+MnbAn5/HYMcORLCKBjyLdkkV83T6Ie8tH4rdpBvrC7DBz5EgLoGPJtWeTXjdvXQ77Err4+IZ/32MIjX0IAiJB/81gMn7hOHfJgFi+y0zUh4HL+Uga7lHb448Z2Io5zs+T4kUR6PCLJymRNjyQ5jdig03k+TgSkZnrO1CgjNYeOdctqxm1nGNCkJlEAABH0b05qfI58DVITLgVMarRTkpVJ0NakSQ10pZaPcyikxntk4UlNYlZBBP2bkxqfjl+D1IRLAZNa73ztOpEmtbrJnPVIDc0W0nZIjXe2g5OaKTGrIIL+zUmNz/yvQWrCpYBJrXdhFJ1Kk1rdvNF6pCaxRu4VqfE57eCk1i3Py+1JzbkdqfGXAia13kXIdNl8IGLboKSGJyGoFVLjQ2TwpNYtz8vNSY3fz1CD1IRLAZNa74KfumyqF7EdUFLDk+vVCqnx0U9wUiNUYqmmKuBdNFORfKYAq4BHciEqqd7MqMIgVVWC9Nb+PwxSlZgDlVQrpcqHqjBIVdWAritVPqsCg1RVufa6UuX3jWKQqipEc6WdSWhd8/FEJmFlgQjClYegnBuioZLBNFN22ZLBZ7oL2ZKF7s2UhyAyb4ZBs529gTmEaiLbtGxwjxTZXLtzt+6O3BN6W1k1uC2y4eM858hG6F/dndeDNsgGT8YcTNVgeLLJ03sU2VxONrQRsqkuHNwS2fCFgM+xh9i/ujuvB62QDR0W2fCeLARko967cjXZQNQObols+FrA59hD7F/dvdLqaops0MQtYWoHIyAb9UKEq8mmmaKe1eWDWyIbvhzwOfYQ+1eTTaXV1RTZoCmOAFM+GAHZKAfx1WTTjIO4uoJwS2TDVwQ+xx5i/2pugnAQ6wNzEPMZWQjIpmNFZ5HXEIYXqCG6QVELFLKKMAJpgbpMu1ZGOK8Pez7R2IQtsyRTyLZPdZaE5BkKrli92+RN8kqZEugH3RBJZEp69gn9QkIgPPp7t3GOGLK1JYkJusmEyBQ37BP6hSRnePR3q8ieFPodWfRbsIXkZWqB9Qr9/MYNePR3q3CRVEFt6aJdFgVFv0x9lT6hX9iMBo7+jhWDkEK/tNVrwVq9MnvW+4R+3puLAP39s3pre3LKwx1CPvyDdlwSYB+Ea6sqgInHZoYpYguuOv0zmWu7gcpVR8juhlYdPAY3TKlUcNXpn71d24dUrjpCrjK06uCx1mEKcoKrTv+M9doOqBPJcHw6C7DqWHhMfZiyj9CqY/XP0q/tvTqhOnweKbTq4PETwBQXBFed/rkJaru+JLMioVVnYG4CoYQduOqgeWsmTEmDe5KbMHAiQGPvw2z0xiACNO8RhNn+ikEEaEx3mE2BGEQgYbz3SQT8OhaBCGw0VjjMBhIMIiiz5qyA/e7T1H8uiML6bxOl7ez+k/vsgT6yHtkz3Z9nn+bZ/+111is3zNsC9y3tmlbt3p1koz4+f9S8/fliaxMjmkThcxRstoX3LhsTB9R0EEU0xt7a/+GOtx1SBGV7xFhv8+nOfJ/CkllV652BRY5QGniz9FLlxf/W6caKc
P7H1ky7N6pQn57I7Ddi5cfZiMkN8Mwtbcrez0vKyvM1BmealxFSOzKkXnst+z4LSmBz0m00L7SYzWa0fKKYWmPLbCi+rlta/v4EqJmCwtaH6pxqyW73oAT0VTH7gQ5FtXh7HIVqgThwO6taprRqwe4lsfF4G1tRLd7PgkK1LKVaF6iW7DYtSkawqoXHi9yOajkIVQu00k3nVMuWVS0KuwfMxhMdaEW1eL8oCtVylGpdoFqy2ytPQqEt1cIT9WlFtXh/NwLVIiOJmEPHsmfskST+s/wHOPzjCbmBuPHu83pdYOh3+pexnK+MJdAP68R2JJinT+gXMl/g0S8x/3YN/UQa/bB+5nygQ0G/kHQEj/7+bfRyqDT6YV3B+UAHg34+3wse/b17cTtxZMP3Wa4bHPoHFr4XUu3g0Q/6CoSuOZQc6VqNNqyv1hlY8F5IoARXrFHZompI+ZMfP/+ZPtgJQ7urkiivwDQpvo/AyItKHydR0nYxXWYmDwnTKeamGzd9UpMoDL2JAvYVwL4vAlu3SoBtthpW00ZqFSS9CqKatHkBuQbaDxPBGghkT6flgAfVKFFp95colnRuMKxi4TEuQHbqolAslXR/iWLJZgbrsIqFJy8YZP81CsUC8f52VrFk84JBSxjvhzkYxXIQKhbom707p1iyWcGgiVv7YQ5FsfgADgrFMpRiXaBYsjnBoDlh+2EORbH4AA4KxRrYqoF3IFGHwMtgYBMM72tAIYOBcRFvlqKQAZ4MeRALBoMMctUcigz4xS4KGeDJFoZ5tTwGGdASGQwpDWDlTlXQ/1JCLwb9LU0M+lvtorgsPNwKirVok6w2aTnq9Bao5S5TWYXj9aqIIFiQ5wNOn5EfbserIH8R5IlA3QLkid4u5ssit61ACAmoVVHHGu6BIoMT+KKOtCxeOiQ4qxzb2xYqBU+xZYuggUNapdjWwnXF/oeWc2opzAYIPEhWVmJtK5HCW4kwWx7okZW46zWOS9CjQHXFwvWeaCXkSEpgRZuDVZk7vx3ng7Llb4Ahq8yWLzN+roAQO4yjVKT7cx9jd7X4LZp6aY//AQ== -------------------------------------------------------------------------------- /assets/convseq2seq3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/convseq2seq3.png -------------------------------------------------------------------------------- /assets/convseq2seq3.xml: -------------------------------------------------------------------------------- 1 | 
7V1bk5s4Fv41/dguhLg+pruTyVRNUqnNzs4kL1sYZJsMBi/G7XZ+/YqLMEiyjTEIEZOHjhE3oXO+c3Q+HUkP8Hn99lvsbFafIg8FD6rivT3AlwdVNQ0L/00LDnkBtGBesIx9Ly8Cx4Kv/k9UFCpF6c730LZ2YRJFQeJv6oVuFIbITWplThxH+/pliyiov3XjLBFT8NV1Arb0L99LVnmppZrH8o/IX67Im4Fh52fWDrm4+JLtyvGifaUIvn+Az3EUJfmv9dszCtK2I+2S3/fhxNmyYjEKkyY3RNb3/x5eYvTHIfgX+jP6sv3+6cOjrheVSw7ki5GHG6A4jOJkFS2j0AneH0uf4mgXeih9rIKPjtf8EUUbXAhw4Q+UJIdCms4uiXDRKlkHxVn05id/p7fP9OLoW/Gw9PfLW/XgQA7CJD5UbkoPv5HnpQfH27Ijct8iCpMPztoP0oKPKHhFie86+ATbfkWTbqNd7KIzjUb00ImXKDlznVboftqilTcU4vkNRWuEa4oviFHgJP5rXeWcQnOX5XVH4eIfhXyvkHXx3Fcn2BVvYmRfl+x+5Sfo68bJGmOP4V2X4sIPgucoiOLsXoiApyMTl2+TOPoHVc7Yhgkd41yTv6I4QW9n24ictQo4FfZEJ/DaH9GpFUWrCjBJWfetqpstIOTu4tesnQGDJzda+27x+8duvSEPWmbtn5XkuDJagq8lkEr7lR7kuv8FxT5uRRSnSuKHSx7knjGafBTjBvqM9sXJom4A3gRD2BCGUJEKhpCBYbKPGBXCcEjqcovR1v/pzLML0lbeRH6YZJXTnx70F1ySynpbNG16GPjLEP92cctiAcCnFGXY9gXvihNJqjFP21x2/87U51G7Rn5aN5iGSh3ThtIM02pfmNaNsXjFk96tLcpv8IrWYG6xuPVLioijWpX6UaiVZlD6Uhix/C5KZcpq3OAZrH4d7mKxUF2X53A9Y27oHTlcuhUlcLht0Dk53NsdrtrY40KpPK7KulzAqNBdOVz9ssNVRTpcYpnld7gShaGkX3sRjroqFRxJvccdiNKdVgn8YhsuZ/KLt/tFqDb1i4ZcQFQZIG5QtMFCvmfnqCn6RedoCnWObGdFUucoUTQKm/ZVu3eO/GgUalQ0aomNRiHbCe7U64qJRulWlMDrapPXHcbr6k29ri2X19UZILJIvCuHa8oWjbLykNThyhSNGk0drlzkEKn3uKNRTTaWFrLk9wKFmZe4Z1unq7IFF3Astk6m4MIezNbxg4tyaIsEF6bg4MLu14iKCS7oVhzciGpsyg4L17syn7ZsXcU20d+9dxW1pnwpia0l6SqSeo+7q6jLRqFoDZgpAaC6XtH79uo67Y8aJrBcfFBpNsmDcsj21j0wRpNlLFEnU2vKb3VvJU9oEZV6KzqfSmOJsxF2MulWHN78DpPsKKw7cjkfUReEH9ueWfD47wQTIMgmm1Pg38ImN81xNUXluALKJtORTnc2+cuTprz5SFv9/mx81g6///zP/vA4KFMOqipUKtTF8AeodW2ZKcoljcmOyuGuotBztqtyEE+ATlmmIJ3SFdwklX9GTcUgbKZi7+LYOVQuK1iIc6/ldi+OGps/sds+Rec52qeMFcOxXN/7WMaO5+PbK+fegxf9vVk59+LHyE38KH3VHm2TbjotKpXspmkDd1qsQQ3PSHkXkupy2XvJNWJO6j2B9CxIIT3fZXCQjqaLKRNImyaukThSFpB2nl/2S4JUo9MABwfpNILRAqRaU5B2ThXfBlJtAmkTjo4eTh8apAabpIQCtMZfuPe3KJXfbs0IssOB4QAt0ke1GhauakV1SNggx0WNOQuaXL86RJ1T03RWbgBwBNfb+LDJdl0TrO5h+q7QSz8XN3emv6oRpE08j2tCNP63S9eHyRrqMRfMO3wBMDZvx5P41zL9H63nyPOkTrESqgyXiXag94RiLlk2mnkc7R2jMN6TinVgf7zn2e+UwJFaLuI70rmlpwawFwIIGgN7
RHs0YOqh+9k9P8NHGT3bRjjKOk8rlBtlNIMzPMpGQ7P2gLLuCRY+yui0U9EoI995LyijKZjhUTYanrQPlIlKprMGRlnnyy/KjTKaQxkcZRYvfSwPuj3/lRt1p2HzY9GcadhdtuhDupIsG5FnoTgVlJN34MA4bBza8x4U+CFyUnYgcA5pln3+WNwU+ZPrb8PF2UfVS0fxnWg9x6c9f43/PuJCZ50qfvEOZeV75OxVDTBxY+AyHgFv7gSd7tEdIE1pLOICGSfIaNOeKx1ZRHrpWtViJVBaTTEdD3Y44HNW32w97dfZQ6oPBJ/zIHL/2VbxOmGK4ptVm4MpoXyzPWXaX58AbAuiU2gTrLadz4H1ZwZMVQPQtCFQqExibFdmimHZhq5nF2n1l/ScV2yPZu1ciVLQgSKIa4CaxJrTZiH1SXMExc+aIbHmWJPmtNAcVYzm6JbEmiPNCIOnI8vTeDGIpc6h0dVq3tQ4ntpwqf3+1k9TpGGfxUiAHuORQALSMJNiJEDz/xJIQL0vCdDcsAQS4GUvj4g0vRdyOKN/z37cza8oKedwvt1k7cJnois89cREX8eaXUa/WCYaKBNt1iaAGIg3Aw2XTr6WN1OGjESAMhFnLXQQyEGcDaw6E3PWRnXkYM4GVp2JOmujOnJQZ6JUhz+bYTRG56SgL+9pKKyLg3smiq3jMLv4a9YkbdozW9FMA+R/SeaEqI0MFWlY0rnjWh7k8RPYQ2u61w0/QbOkZTQigJ/gY200fcMbsCayK1fFGpAJakAaOlwM1Gg6fHiojaZDdAvUBHZ95YWaNOMeYqBGj3sMDzX7HqAmMFSQF2rSDHCJgRo9wCUSaickwEu2bjXw03zs6fZBpisnB0S75IEekpnGXs6NvfAUs7dFEfiaSWogvxeQaAEi0Jyrgp2HVTfKu/Nu1xC7KBiy7RUDSDOObhuFBsouaBslg16bXfAK9+WHVmQ469NrSb/hjyHb3pBAHc16eRKtmw7UpptZ9OCu+Fg36SWNhWP9l9jOgm5GCfzgL76fRQOwmU3Bpna+9cWNsmOjVNbc3pP/M4F0/m80/LhMkCRerYH/65wvvE3epObDU3dSLxlLhw6DLxkL4D2vkNcDjJp1I0UvK1R+qAT4FLKuEA20wdcVAnA0y3fJlEIHBe1aiAEq7ezTshGGB6+YiV80eIef+AVHM/NDKvAaUoB30PzX8Yxb3dKNEihoabMXILv7wkBWWkz2Am2lh89e0NuQDu4ufi33Z6zLx43Wvlv8/rFbb8iDlpmkspIcgIZYNnC7crxoXxwUak12nfyak1E8q36OiIK3GQBdVAK8otUxXtNA4QNx5LPr6TJ5tkk+zTja0uknd0xT0ksb8mhKXlZJfzSl0cY7TxajA4thCMotVjWpLAb57IrFYFXwnmyCemqDcwFDF+G37Wfv499wF379/uToqz+9QB3PXtXNNpbufZqMbs9MSzeBpmZ/Cc1FCO4Tpy9EX+xrzHYE6rXbS9PvobaXblyvrraj5qrooEzi+FRUM4WoqGaLUVH6PZdU9FS9elXRQdOfxqeiuk3pYC8aatCLu/ekofR7LmnoqXr1qqGDMrrj01BTVWYqUUD8t96PTjVU7UBDTShGQ+n3aEUkf229etVQMvNwIOr4ChUtw9IyDs3vUssg9URc2ky3q9Ekt6WaZrIRSrhjsFytffT41gX7yFxvC9C+zrfdoDhqBUFk8DhqBVmKZXUTXVr0QrYC8234mB52OKgDTJuCMG2PG9NkU8ymmIZAAKY7X6CFwjTwvIXCwzRQTGijfjANBY478THNpmeg0I28jDg7Tk5N98rBDTvtjFOVpWbPLHj8p9WtNW9jdp5oeyMDATv7gi/a9dwPJ+FeI1wIOcLl7WvVhXC5qRnsLLpkH/UpOPnJeHqaTjPL2puEWPRtULTBDX/PQoKKfklIpkghsYHCAtvIdETxnqWkqXJJie363feUYd28aOr6Gnbkr3LEy/GtppKgKZXkzIxv
jvg66kvgwzhK+3THQAzHQatPuBOYXvF/ -------------------------------------------------------------------------------- /assets/convseq2seq4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/convseq2seq4.png -------------------------------------------------------------------------------- /assets/convseq2seq4.xml: -------------------------------------------------------------------------------- 1 | 7V1bl5s4Ev41/Zg+iLsf+5b0zHTnZDazu8m+zMFGttnByINx3379SLawQZINbkCSG/ohASGwqPq+QlUqSRfWzeLlSxos548ohPGFaYQvF9bthWmCEXDwf6TkdVvimKNtwSyNQlppX/A9eoO00KCl6yiEq1LFDKE4i5blwglKEjjJSmVBmqLncrUpisu/ugxmkCv4PglivvS/UZjNt6W+6e3L72E0m+e/DFz6fosgr0zfZDUPQvRcKLLuLqybFKFse7R4uYExEV4ul+19nw9c3TUshUlW54YrI/7l2vr64+/k1/uXh9ffwPgt/pS/x1MQr+kb09Zmr7kInmCaRVgiD8EYxt/QKsoilOBLY5RlaHFhXecVruJoRi5kaIlL59kixicAH+JXX5KHLV5mBCWX42AVTS6DyTqDf2ZpFCQz8lvXG9QYlw4+TNE6CWG4v38rOgOfTNAimtDjVZaiv+ANilG6aavlTnw4nuIr0yiOC+VhAP0pvusaLYNJlBEwOuQBvBCpXMk7wZdCERXqF4gWMEtfcRV61fapginCXYueP+/xYhq0bF7AikPLAgrR2e7Rey3iA6rIU5TqDkptqlRDO6U6g1IbKtUytVOqNSi1oVIxVXVTqjkotaFSgaubUh2b0yEMcU+RnqI0m6MZSoL4bl9akDiRy77OAyLa3Kjh/zDLXmm3N1hnqKxk+BJlP+jt5PjnTpnk7PalcOn2NT9J8Pv+2Fckpz/z55GT/W2bs/y+MFjNd+g4qMMVWqcTeAz923pZkM5gdqQe8KjtI2I8iokUxkEWPZU75CIF01u/oQg3em/1rdGlb+3/nBKy8If+0i38eXb5+dv3oI/cY+gqTYPXQrUlqbA60gimP+EZTGe9or5N3/tQfdM9Wh8fbFu8J8BOwA0MXbWdKzPgeR5l8Du2FuTqMzZbZbRPUZJ9DhZRTF7iHsZPkJhAfCGgRnCC0QhT3hJN/QmcTES2a+w7dmfdB583SrbAJtkt2KRHcH//mzW+TZa///7H7G7669V/prvew2CTKmySVdsm2Z3YpJOthX+itfBZ9lfUt4/W78Za1OjqfiRrwXmQEq2FUP52z+TPhmVUy7+G//6R5O+wXXjV8q8RFPtI8ndZv1i1/Eec/L9d3eqrgpJ/XOqHSHB3VevKBpxiZPYsd73Jn8XOZKs9y4Naq+xMgroebj7S1V5nspFO83arN4C7wNHhUFMLpOICg4ZqUtVQgASWNUB+XT8qj9bpgnxtXA85yOcCFcqRX0MBeiPfrot8oBfytXH65CCfc7qVI7+GAvRGvlMX+aZeyNfG3ZaEfDbcoRz5NRSgN/LdushvfSCrGfK1CXTIQT4XaJKIfOGwDBic50qfuEgqoRBbH4hpltLBB69UccqBfmiLOOWbY8t1u/GdHVM1p5R8PM6DU7m3XckpRytO5c3uC6dYr1w9p9SmD+jNqXyuQRWnXL04la
OyJ5xi/X31nFKbJqg3p8yanPL04pTZM075unHKGTh1EJyCIRmhEH29OKXPiIwUTrExCpmcOpDpKLJqbox/9zqMnkqacP9ekzmK1/hds09Unle4BhXp7jo+mtH/N89ZLYMkL4uDV1LVAPlF3Ori9ULx5ufLpV20aIKSJxSvN/MuTmsTg1PSiDIYU7iK3oLxpgLBLk33xLWd6wvnlqASm5vV1vKAAkhjOCWPEs/9WJG0k2T2x8Z+fbKPgZ5coIYNuPk5bbFg3ujp4YHqiBsQzc7oDM5m7kwO34gjMy2qM1aAXu7MruXqvxPT6dQUZ4yF7th1Ooq7WS5BidIvhZn7lAO1jkwYqEEtvbyaXcv7Qi02/KYFtYao9hGA1s08MIFezs2u5X2hFhuF04Ja7kCtwwCtm9pgAs1ymD1tkhskUcvXkFreQK3DAPXqUivvj2hDrRpLQX0karExOQ2oBUYiHfQpKvfl4d9EsBOM92AIzb1nKll5eQtRoBmYcjHt9xzTBHPhOiCSoqsUDsA+HdifysC2XMAD25HaDwJnE3N+f3/GNLpZEUOwSg+z5IXJ6K2lhXjYhXI8u2KpDLZdp9aXsbTGTksa9N2kTG5n4+Sey3/m5DpF55P03cQYtJ5BKiaRDeQYA3YdnCpyc+06tb4cY6BP+FHJulwaGIOzies3MQatjxUfINFIjjGwWRhVkZtt16n15RgDfQKmSpbd0sAYnE2afRNj0ProtphEjifHGDgsjCrIzbXr1PpyjIE+IV4la8BpYAzOZn5AE2PQ+ni8mESuI8kYsDCqIDfXrlPryzEGorjsBzYG7HiPBsagZ30zNmpj+kD1iNsuLa8vOmCdZS100DNLxPooWuhAmxUO1HQNddBBTs2+6ID9ImuhA1Oggz6NVi+DcBibPtWYlMemRcv3unJRLJo3IwXFBlpnyzVxbsgrmG6wILpKxqtlGUFqQZ43mMgoSjbtHSB/EuQBZ7o5yANLLuZFPRg5mB8g9B4ImWWrKfLHhbNI2ShPe/vpcAAaVqg/4Lor3/tI2QdOky8YJvZkHQcZJMJNQvLvcolRpusXN8gyXHtIfGyc+GhbojR1IOpful2RDwwzQA7Kpu6E4A72P2ykU30yCIPQC8RfPtt1DLubD5pty5z9IWaVP7DqkGzqzgXuYAe/RjrVJxVPCqvY0QUdWDUaWHVINnWnAe9mtGvCKn3GTaWwih0v0oBVOc8HVh1cvrIGq5SttnSs3b1hFbtprgasOptJJPJZ5ddmlbKFlo61uy+sYsd0dWCV2pzLIqeUbpPe+USNA7uW7x6x5TiXjlk5E8M2mLa0NeODzQis2GadrV+1zTr3Hi3ndYrRrna15N6gnUsnfS/a2akGXaHdYSc7VqCdrV+Fdu49pKBd7TJfvUE7u5v4+9HO9ro7QrvLzuarQDtbvxLt7HvIQDsw1YaI3wN3UAS7cek5p+CdnHwjQ4dpsikhGTtcxxY6gERIBB1bY/PXyO0AeU35K+wfACoLPK8jArFL1ldNYmE7R7RdHRMid1DUu1fVi5bN0iCM8O2Fa3fg1rnzCtduo5QuEIRbB1fZMfDW98pY0Dimcq8MCDR1ZrbMPDtbpmy3kJq2zJFjyyz/RFvmSLFlrS92X8bW2IAWdIXYgr7h+92YGtFCXZINzfnFfxhDc3Z2RjBcr5WdcevZmfao3XoeAENtEIZToTqB4Vkj2A21NYjtApMfC4bJBIWbdM19aiLZ7QfLtpCqyEi/h4mFnjm69K39n1222x5vt4FIt52lYwOTH5EU63YxjpJBu6do13ZqZ5F2qF9+xgZcjGEYYhEOKqxSoZNPrykocNSR/v76V/blwXr7+r/Pj9/9tzR5vLr5qmr/4EYjz0xPo7J/I3xxk+/dHFuJr2l3pumqhj7tfbTlvRyTSYHK2TPqksS7OMy7aHyDdRxtviNf4TND5f1IOEdVAV4Ostfax1hyRTg8Y7v6oooZq8YTUs/Y2rt0bZfLk81YdqBYCm
P5hP8lREus8l6T1ndqkNaTSlo1A/rqSVt7k6LtsnayScuOd0shLe/tTrFL1HF/WXvW2q5+rFWTmKCetXXzNOn6c9JZ6ytgLZ8Detlrvjpena6xaJua7vjKB2T7wdd8IY8ahLXbIKxoRw4WDBaj5Y72/pDC/Vy+5ZUKtpP0t7Pk0WpX0Pmkd+1Ng2nV+ZSLNvp5h2nApykiEeO9wtNgOX9EISQ1/gE= -------------------------------------------------------------------------------- /assets/convseq2seq5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/convseq2seq5.png -------------------------------------------------------------------------------- /assets/convseq2seq5.xml: -------------------------------------------------------------------------------- 1 | 7Vzbcqs2FP0aPyaDuZk8+pJL2yQ9U3em6XnpyCCDcmTEAfn69ZWwsMFgjJsY1OAnoy2BpbX2kvaWsDvacLZ6DEHgvRAH4o6qOKuONuqoaldRLPbBLeutRdfUrcENkSMa7Q1jtIHJncI6Rw6MMg0pIZiiIGu0ie9Dm2ZsIAzJMttsSnD2WwPgwpxhbAOct/6FHOptrZba29ufIHK95Ju75t22ZgaSxmIkkQccskyZtPuONgwJodur2WoIMQcvwWV738OR2l3HQujTKjf0FfzLQHt9++n/+rR6Xv/WnWzwTTKOBcBzMWLRW7pOIFjAkCKGyDOYQPyNRIgi4rOqCaGUzDraIGnQx8jlFZQEzOrRGWaFLrtkQw/4w2Yrl3vJ7QREyL4F9pzCf2iIgO/y7xrEXqPcGuwyJHPfgc7+/i10CivYZIZscR3RkPyAQ4JJGPdVM20LTqasZoowTtkdAK0pu2tAAmAjyp3R4A/Igyhw5WOCq5RJgPoIyQzScM2aiFrdEgQLDzc1UV7u/UVVhM1L+YohbEC4qLt79J5FdiGIPIdU80rqR0lVpCPVuJL6QVI1VTpStSupHySVSVU2UtUrqR8ktWtKR+ppTvf4chSWHqJwzJDhtUtGUZa+KfHpA5ghzEf4BPECcrpZBRCE2ww9GOZRn1o2tO0iniaWoV9sqrTyBOgF+OuXwr/CRPmV8M/FH03jr7cM/8Ogvmn8K0R/Xwl/43ABaBr/uxz+3/ojeSnIrO4OiLzdul/DYt00V11pFutdiHQ8qPoEAnIhsNI0AdKs1vUQkIuWGidAmuW6HgJy4VLjBEizXtdEwGG81DgBFfZAvxIBuYCpcQJ6RauwiSkPSNAiw4T5c85PQQZsrPRG4NlnLQSku3p25YrP+DlRAPzEhsGaN1W6SSXrdbo+ZY6/Pmu9RI9s4i8Insc7O+f16cBPeSeyzhjCCG3AJG7AfTcgyKcxg8agY4y4V84pibanad2Uk2I4pUd3lyIeGvrun7wwYhlfidPzCnFYx5QmyqLHBSdT54c0p725W7T/8xnuTL7fj931OwiWb57zvjAW8Pe7m4LpAzouTMZMQuoRl/gA3++tg+wEs2/zTGKIOTHvkNK1QJJzluUZrhB96yQbeLz0t3gYvx6t0oW1KBxFPyLz0IYlklUVsWpRELqQlqCRBHgcglI6Q4gBRYvsYernb84p0iy39aTHhxF/z6xvti+Wh9YGeZh5eZRtFkujDmlioWY2r5tXh9EGdfQqqkOXTB0VXoH5Suo4
TNWaV4fZBnVYFdVhSKYOq13qOMyj61THkbcJ2p5HB8C55s/nTvJGxouLjmTMer04f4BWkxcryA/m9OpB56aYWQ8qmgcLd2DUS4UJVXZ0a4gbKsUAP/6gj8/a5vX7w8vY2oT+S3/4mmSFp7PHJF+rPwYo63cKd7okl5TILgL4TyIZMh5QLP5XuDwQin6ZBNMyqgUJl9NGhRxGam3oeW0UD7Sx5LGs2ynYA0gCBmar1WFll44idfRqVUeFHEZqdVTed0z2MiSRR37FnkKfe2er9aGbsumjKDb/P+mj6t7KbjdDEn3k56XbVivD6J2Oq9RalVHhXFFqZSQ/o60gjcY2Hks7nt0vADPuqNtsn0Q7w8WTdumVo2qn1xSjVuVI8wpoMyfydb4EXXyqojcxU13sVKU0Gzw5uzX2ukppt9uijcO3c5s/ccwvLtdfc4has76JjBX3f38R16X+RES7/xc= -------------------------------------------------------------------------------- /assets/seq2seq1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/seq2seq1.png -------------------------------------------------------------------------------- /assets/seq2seq1.xml: -------------------------------------------------------------------------------- 1 | 7Z1Ne6M2EIB/jY/NYxAIfNzNfvTQbfs0h9099cFGtmkwymIcO/n1FbYwIMkOMWLAiXwJiA8bjd6Z0cyIjNDtavc1DR6W32hI4pE9Dncj9Glk2xZyXfYnb3k6tHgeb1ikUchPKhvuomfCG8e8dROFZF07MaM0zqKHeuOMJgmZZbW2IE3ptn7anMb1b30IFkRquJsFsdz6PQqz5aHVt3HZ/juJFsvimy08ORyZBrP7RUo3Cf++kY3m+8/h8Coo7sUfdL0MQrqtNKHPI3SbUpodtla7WxLnfVt02+G6LyeOHn93SpKsyQWIX7HOnopnJyHrCr6b0IT9+bh/IJJfMmZ7M7qKZnx7ma1itmmxTbKLsh+V7Z/5KTduvpdk6dMPfsV+pzx27ID82H8ky574aAg2GWVNNM2WdEGTIP6D0gd++zlNsi/BKorz4XVLN2lEUvYQf5ItP8jvYWG2L3dJ8czswhl/aJsPsiBdEH4WOjTl3VG5jHfjV0JXhD0IOyElcZBFj/WRE/ABuDieVwqBbXA5qGXCf8tjEG9IMYoEGdUFsl1GGbl7CPYPs2VE1gVT6+JFHKzXCkGus5Tek1sa03T/Fci3pwjn/TeP4rjSHrrED53XCiGIo0XCdmZMCOy0M1J5JGlGdmd7vDiKOURcydgO39+WyBZNywqtRVsbGSFZJu+SGyRzU5wGDw4y4DQC50hAD+BY1stCupSkgTDh6h7//NK/acS+tVR/fl2KCAviOUDJrxIkdPwZjYTmSjJbbDKSSIJj4y8TRFKjgwuzCgZvkoZ5Ppoj5pB94AdWURjmX6Nktj46GmGrR+yXGyvHasacrYO58ZtnDsMwZ2Mw5rAksxVl9zbQtTB0oNBNtHuI3CssfMTh0OfD0OeOBb9l3Bl9DtIuvav077lkq/79RLewmxLlv2zFANx7PPPJdK5w7wPiz2dDcO8dvz/3/hjXegdaTzsIJ7SeKM4OtZ5jtF5FsrVoYG9qb2LUXiO15/YYDrQcSUjrkY3j3DFfb6Zsc5FvoqKN3a/SLMnTuP
QnpYwFl97y4Fx6Bxv1uO/LIiVX0Y+O9qhXY5ejQUTRhH1zE2b36BfaLwvpykNQ+hOGan8QOWAxKGVGBQerHJi9HSN0fWwwpqwlj5DRKaR/nnadpgzLpszuy5I57yhkCBWxt+BihoqQvXFEVIqvz/gUBGKl3ivV4M+BOStAIXsRvy6dFVV8uOqsrI2zoo1ZUF9Ff1TyKplFQBMMMc3WIbPopQmGYVZjIggUWv2xsuuEFqgaTMwSdQmtohyM0tCQeXmuApRMH4DModAHlKPFYu6pu1lm8UgmH/jKTBHkNNPRv6zmKq2foz2W1oy/Dq1f8Uj1wswkShbGAF4OJ6QBdDztcF5l7LtI2VZj39rDQo2FIjuVz33YtanvOu5Ytmtzf0Zmg7BrjhjRBrRr7jteLlpLEQFlL5zxDap8HN+ti96d3ODKx3PqX3CgXcsKIHkJyVJR4DQ2BU4X432UKZwldGVfRiVVy0i1rdIWojOQVWvFuxje6nwfep1l07n9yzcSLbJGdS2nolVg2wbstmCL6zpBydY/kXlXZIurOS8mW7pRh2TLWW4V2abSvDXZ4uJRULLfeCUYdI3sxWRLN+qObNxsiuUYstuSLRbeQpKN5YoH1UohM+VqHyfrcc6Fm60HM/53aymLmXRIKXtvfGZde8EKVB2LEN0W7a3GF6zIGQdJnAAJhwn2UKAo1yZW6BLvtbKEeJ+A34wwHfkGD6KKbCCEFYoLulKsO8I82aM1hDUxYpCEySIZ7jrMcxUvJ8VQzdt5QPNPT7BiluhuNJ1/ioUtlljYom/+6b3xaGJN1wKNA1F8HeraYayCH76uFSuPIHWtvpeKXYOuBfJnsCeoSHShrhUdo6a69kOaBk+V0x7yE9bNHTCLj61yFB3ueLEq0Lde6gqGGdTEFJ+Q2uuHmegbdGjS5WSRWanRMooIGF6SK/1NpXH7ACGcAIuhY16fAvImMDjBko/z7/+sfv37c/frr2Cysbzn7f1v/vna2NJUfi5bhX5WmlOF2W22PIRtf9pVd56KnTM2d79TXrbfK65TGtoX51rqzhrLlvxQVNvCcLcT3zCCVSGeYlcxgZrP57a2+nMJMoUQm5YyHpOpABMopeAm52dVhrvaKLcV3PWKnW2wa4SdWGjYP3eyoAx3J4e54h8loV65G8Y/Sxo+d2IZYP/cGT/zFdw5CnsHtabx7C8y4L2ySq9/8MYGvObgnVpc3Bt3wyj3GT53YsSzd+7888XwhrtaZyleaDvplbthvBV1+NyJger+uTv/oijDXa2zPJk71C94ngGvEXhiIqFD8Nhu+b/dDwld1s3LbzQk+Rn/Aw== -------------------------------------------------------------------------------- /assets/seq2seq10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/seq2seq10.png -------------------------------------------------------------------------------- /assets/seq2seq10.xml: -------------------------------------------------------------------------------- 1 | 
7V1Lc6M4EP41Oc4UIMD4OOO8DjtbW5WqnclpSwYZM8GWC8uP5NevMMIGScbYgOSM8SWokQDp09fdajXkDoxm26cELqY/cIDiO8sItnfg/s6yTOA49E8qec8kA8vNBGESBazSQfASfSAmNJh0FQVoWapIMI5JtCgLfTyfI5+UZDBJ8KZcbYLj8l0XMESC4MWHsSj9GQVkmkm9vBep/BlF4TS/s+kOszNj6L+FCV7N2f3uLDDZ/bLTM5hfi3V0OYUB3hRE4OEOjBKMSXY0245QnI5tPmxZu8cjZ/fPnaA5qdMAsBZL8p73HQV0KFhxjuf0z/ddh1DaxKAlH88inx1PySymhyY9RNuI/Cocv6ZVvjppaU6S91+sxa5wOLcfgPTcb0TIO5sNcEUwFeGETHGI5zD+C+MFu/wEz8kjnEVxOr1GeJVEKKGd+Btt2El2DdOlZXFI8j7Thj7rtMUmGUxCxGqBTJQOR6EZG8YnhGeIdoRWSFAMSbQuzxzIJmC4r3cAgR4wHOSYsGdZw3iF8lnEYVQGZDONCHpZwF1nNpSRZWBKQxzGcLmUALkkCX5DIxzjZHcL4Flj4KbjN4niuCAPHOQF9rkgwDgK57TgUxBoNfocCQwiWihc+bvxAB7cwrn7KKHkjvB8NxOTlDrH0VyjhKBtJVL5WZeRjykny2blzYHquWhaYHkua4ItELG8Sb4BkW95NfWEAz3hOiXcnjkaCGeap8G9lIFXwiWnbd6wpv/giN71oDa9MorA5eDJyMxacQjtH6MWaI6AWbgiaC4AR+cf4SApsYqBWSQUEwn0SGdzRB3Ab+zELAqC9DZSrpdnRy26twP75UbONutxzmqDc8YfzzlXDecsVxnnXAGzGabX7knXwNCpJJ3t9p5lN26knJq2YXwFhZ9TdnGc4Ve38BvYXfHWquHg9N5rA1IDS5/3WicW8MktaftxFjlfga3MlEoXlC6cpUSLU9uJ8HIvCMl+pHojewkfVRpZYAhI3aaRZc5GKX7TNo/rgnL/8nv7+t/jEm9mw/Xzv69Pjh9+GVS7Q4dxejhIOeAkYDGAcrhOgHVd+MiHqfXgQW3TJi44er+lTb/F0+i2DFtXk1XMO2jN1yvzbTw1vo1pqPNtvBO+zbL3bVrjrErXxrnhneCSG6MotGeXwwe2d14AIVMzrQTbxWjtlDI4I/NqTA93NDZyGb1eQdyzuwa795gqpLMYOZChavaoNtXZfDaBSp3tVOrsOoucs92u7tV0lZE4ipfecK8JLgr3ntxjFbzxI2r/W5LA90K1RVphWeEwcvcxbYObd9kVL7Yp4uJOpn2sXvs01T4mn1uhUv0MevXzR6kfwZZ1pH74+7SufsRFqkz9gF79NFU/Fr85plL9VEeZevXz2dQPv0XXlfrh79O2+nHrrajtXv00VT/A0Kd+5PspzV2iWoE0fTFtebcVbeLbXgdJN+eqD5v3ktx21UflCBd0yocw1c7bL+O3wsaeYzsGp0nSF5g8H/m+BnbbPLsVbmFJQfDAJ2a3v0rWu8dIWwZwOd0XeLwnyN3hLcyQYDAcG9zemyVsvjVQI4zPp7etFW2sOfwM9Lip1ZZKcSo9kpP1gXmi/rCyfjcqa3iJLSzN0guoUzkzS7O+E5KdTyUFbPHsttnSzEcSX+/YtGzK9qN7DA/dpsy0NJuyYfV2bxvkFFfP1+a16rIiRj0r0prvKMvZaJVvru+h8UTkWwCRN9HCN43ZT3IdXJ052kE+opgVdW30G+qhX4fZUdJ+moaYiyrFclAB5kCDMyPQ3Nj9eEOQPvH9Nn/ktPDOCrXmxgT8nH48rckz/EGIvYrfXkaPX1xbm29zkQPv8DEBS0VMYCjR630uXkcmRGUynnyqi1a7dRNiVpiQK/TgPEkC35XluVc+d4G58Eq5qXulpDBn7oj1rg5jsCEvWcRWF0PHrbAqknVjYk/6abwv3rGblnezu1XS0B0AKIkSITNw0OAKTJynOyhxiYWrGZS4mnWR
Lpf2JN9Mfl3UXo65HG1xm3z56VOUdVPYHGi3l2K0KcQ46PGS46UOLulCe9A8LiXVlRKdWi9+dTSYUKXBd4VDs12pHITglOmFUQlDVOHZ+0IKFhKVT6T5LdrAHbuOdP9rYp2/33L0Ldrjb8oKBJQAfJST/IfNFO7NyEFtnth4Q5yUvcOnlZLX8UGez01J/vN1+jnZPLHvhjhpiZwEWjlp9ZxszEn+63b6Odn8Gy83xMmjn1zWRsrr+O7y5yYl/zUs/aRs/k7yDZGy7jZ3lsKjjamth8ZvkKl85KdDptLi4V+IZOHaw/9pAQ//Aw== -------------------------------------------------------------------------------- /assets/seq2seq2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/seq2seq2.png -------------------------------------------------------------------------------- /assets/seq2seq2.xml: -------------------------------------------------------------------------------- 1 | 7Vxtj5s4EP41kXofdgU2EPi4zW5bqdequr3q2k8nAg6hJThHnE3SX38GTMC8JGxqDLshWmnxYAzM48cz45lkAmer/fvIXi8/YRcFE6C4+wm8nwCgTg2F/oslh1SiK2oq8CLfZZ1ywaP/CzEhu87b+i7acB0JxgHx17zQwWGIHMLJ7CjCO77bAgf8Xde2hyqCR8cOqtJ/fJcsU6kJjFz+AfneMruzaljpmbnt/PQivA3Z/SYALpJPenplZ2OxO2yWtot3qSh5d/gwgbMIY5IerfYzFMS6zdSWKuhdw9njc0coJG0ugNlzkEP27silqmDNEIf039vkhVB8iUJbDl75DjteklVAD1V6iPY++VY4/h53udXjVkiiwzd2RdLIzx0VEJ/7gQg5sNlgbwmmIhyRJfZwaAd/Yrxmwy9wSN7ZKz+Ip9cMbyMfRfQlPqMdO8nGUA3arqoke2d6ocNeGrBJZkceYr1gKorVUbiMqfE9witEX4R2iFBgE/+Jnzk2m4DesV8OAj1gONRjwp7lyQ62KJtFJYx4QHZLn6DHtZ28zI4ykgeGU7EX2JtNDZAbEuGfaIYDHCW3gCaYQyPW38IPgoLc1ZHpas8FwQ58L6QNh4JAu51A5QlFBO1PapydBQpbLdgiAzTW3uWUzUTLAlsz2e9gBKuYXCVvYJU3WTf5xIEjcdoRx+yPOKp6HqRLmTQQTuii5z+79Av26V1zFAGPIjRK8KSkZFeVEDo+RivQ9Apm3pagsAIcnX+kBAnHDgZmkRhMVJnm8Wz2qUN2x06sfNeNb1PLWX52tKKtGNgvN1aa2o5zQATnlFfPOUMO56AijXNGBbMVpmOPpPsNQyeTdJoxeohFd7DoImrCTWTr2KqF+zH6iMWVrQcfsU0E/MLtlfjdhQaDZUozWLXhl2GvYsIEsYVCeHMUeOSoqdGUXcJHmaYMKqMpS3RpVE0Z6MuSaZZwUBgQGSzDU3zXq6WqlYye0t1qWfXvR0ekbuFTQY+OiAyO5Qtfvg5+H5i3YvbDvy69FfOMt7IZvRVxpJUaeFd3KymQCaau/0QPGZoa/fs10R/U9DDrQ4fnuslfk+emrulKdU1emA5ynCGsyVrZTkpck/Urzlhzjqckn8hUbmHho5k88rp1axQ+U40fP7UiQnIQ+uv2d2Wni9r6tucHKrNaIObTEXOR6YqLMa
8M1CHmrzyulb3jdznmTVUC4jE3xw2mVA81dt7srQ4te5pxi+KZuX3VkOcOm1VMrpM6NZVoptYbdcZStMsy9DKpY8GROkWaFKljgd6oo43UuSgjKNXqiA/GXyZ1pjVWp7f6Z1N8ydGgQqc6zctOSajlLWpxKYnslcaV75nZBakrn/m6OSZ7G7Itn84PVM4VCtyeGLekhG5DXox5ZaDuMLfEb0ldF+blwO5izJsixA4wBzUGOEsTgyxNLN8iv7yEsNQw/opjES7ukOQRl/lodsZHYpHwx9uPaBp9+vuv+8/BcnGn31jN5R2x2riJYPy3xdmJm02i0DvaQTXW+/wkXxHyZvmvkrKdPp7ipI0/KiUi6a0aakSGURgkgNY3sIT1tEprVe+ojqce/bqySeHoq0X01RF9BqreO/p1EbJQ9NUi99Vr5r6hc+BbvWNvdY09KGIPrhl7VRsY+GoW+HWHPiyiD68ZfVBa9weAfl0lglD0tSL62jWjD5X+uK+E1pMym97N72fkoxZ9hf7Dl5z7DQFeHl495NKSXmpDsJpQrd0XNujx/b7YOEwKgWHDVzmSRn5Z0jpwsJUCt7MpkFptZeFRtVRYQgLq5BP1nNZwjbmh16Q1FosFELaJUqFZDYit0xoQyNtEaWDe6Vr7kXncPG/6dmhvxBvGVwyHT7xy/d4AiFdFaiRe4zyvKcKAvRJvGIn8F0C88q9T9E+80zWBI/G4eW7WWDxZxU4nn2hk3jOLBztkHm3mv2ecZoHyH42GD/8D -------------------------------------------------------------------------------- /assets/seq2seq3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/seq2seq3.png -------------------------------------------------------------------------------- /assets/seq2seq3.xml: -------------------------------------------------------------------------------- 1 | 
7VxNc6M4EP01rto9JAWIz2MmycwedmenNofJ7GUKg4zZAHJAju359StAMghkhzhY4Ao5xKgRSOqn1+puyZ6B23j7JXVXy7+QD6OZpvjbGbibaZoKDIN85JJdKbE0sxQEaejTSpXgIfwFqVCh0nXow4yriBGKcLjihR5KEuhhTuamKdrw1RYo4ltduQFsCR48N2pLv4c+XpZSm40il/8Bw2DJWlZNp7wzd72nIEXrhLY308Ci+Ctvxy57F20hW7o+2pSiYuzgfgZuU4RweRVvb2GU65aprVTQ5wN39/1OYYI7PUD7neEdGzv0iSpoMUEJ+fhUDAjmjyik5KE49Oj1EscRuVTJJdyG+DEXXxu09INW2g8yL/wHMd5RxN01RkSEUrxEAUrc6E+EVvR1C5Tgz24cRvkUukXrNIQp6ehXuKE36TtUk5Tbw2b6JQ96dGB2KcrHV6tD9fIFohjidEcqpDBycfjCTwWXzqhgX2//6DcUklY1hc1+nU5iOvc1ReFfgd00gJg+VWFDLmrdqEQFYmL0dNA7emoduxLJhPTnkT5RFKp740K2VGwpcvoGm4PpCCa0Ly9utIbMBjQw4gHZLEMMH1ZuMZgNsac8MJyKg8jNMgGQGU7RE7xFEUqLJoDp2XC+yPUZRlFN7rvQXnhvBcGNwiAhBY+AQKodQeUFphhuj2qc3dUaNGG02VQGl4mWNVurK+/HiDX9Eaxe70QQWz2tCecZrZ4+Wb0asnWrBwYze85k9jqZvT0vBjB7oH/itM1eRZyKRz9EoA5HHACG8QR14Fw7jm1Yqq4V//VzmUg2Qo6OphvnJItwjgTK9oIA7xVXmxlkPuMGJTm20dlSJxoVtWiTsyMkwdUNvRGHvp83I7QB/PTrZAb6mSGnuy667VwbVicaa33Q2JxoXOjBGMa1kUhjo0XjACF/4urp661srtoSuDoWPkoKNUDThTpfqAEmt7Yjzezh3FpdbWHyIddDXRuGf/LWQzbCGh1jlCZhEkxL4ulclbwk6lbvdL3I9I1Onbt6/qb3LYrOoLQ9zZlGeKz/mhn39HKIlW9uG7qhtFe+he1BbxQrX3O3R+LCZxkfx7k0JQV7rZzN2XxLU0Q5+QxzTAu4ZpthUPUNaI2BYc10iy2RYTJSLSNhmKXIYVgrnXI2hrERTQx7Y5JEJsPakJzKsPPz6FgQeBCGun9nqXI4BhqrmNp03A+Q7NVYTzUbLyotSC9s7T8uGK+tlTUPmqH6+WwtMzqTrX1b9C3T1vZ3XOwSbK0kfwaoDRMJTrS1Tceoq629SVN3V6u2yitk3R0wFvJUs6h848mmoL9jBhcwzWQFpuAAam+eZs0I95xLevsw5LR/+b7gV+22XPSRqLXam19Ttv3dsZVEAJlrcfAUEJxOAfXnyMkEtv+01EVuodiC9djufT3ujIqMsx4XsA8tRGUwUMYRl47/BEfT0dj7mxICU0fG9zEugTlamznDHelnvZmo80ZPQCZ17Ok0Bbfsc4uONhh1ZHyt9hIMmgCV4ezZODbcL8CeKQO6Ah0WnfFmEzuRwhH5x71vCHVWuMhBNov0gB++kEuaGGCHwVR2GKysQ17PVZPPp4s7IiaTTk9xAH5uPX35Mn/8+/k5UhX87UoVZPYYnrk+OAjN5zViN66yQlM3pIJqrrbVTX6a/Jb9VMtjg6R/ilcWfm/Nm7KpAxNnHImnHsBn+wMsS2S1wVeNM+WJxOiL04J9oq/V0dc+MvqqPj74Dxv8nuAHdfjBR4Zf0waEXwsf/8meV98/fQ3/BXeREV3drK+c42eGKxfovpI29CJ0kwTuVLcghlzfbeuF3ez18KYoVI8VpR0HW8ORetUpEytLkI+xZTlmR3s0cAzjm3PTEByzWSwWWm8+V4tmAhA7pzN1RZ7TJUbu+J7NRDxOWUBAPGNQ4ol+UWEiXocvXA9OvOMZ0ol4nLJ0AfG0QYmnT8TrRrzm1zrPRzxSrH6isjwXVv0OKLj/Hw== 
-------------------------------------------------------------------------------- /assets/seq2seq4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/seq2seq4.png -------------------------------------------------------------------------------- /assets/seq2seq4.xml: -------------------------------------------------------------------------------- 1 | 7V3dc5s4EP9rPHP3kAxCgPFjmn7czF3uOtOZa/vUIUa2uWIrwSRx+tefMMJ8SBAlFkI2ykMLQgak3d/uane1TOD1evcpCe5WNzhE8cS2wt0Evp/Ytg18QP7LWp7zlqnt5Q3LJArzJlA2fIl+Idpo0daHKETbWscU4ziN7uqNc7zZoHlaawuSBD/Vuy1wXH/qXbBETMOXeRCzrV+jMF3lrX4xiqz9DxQtV8WTgTfLr9wG85/LBD9s6PMmNlzs//LL66C4F33CdhWE+Clv2o8dfpjA6wTjND9a765RnM1tMW35BH1suXp47wRtUqEf0Pfeps/F2FFIpoKebvCG/PduPyCU/cQiZ3O8jub0eJWuY3IIyCHaRem3rPnSpWffaafDILOT/1CaPlOKBw8pJk04SVd4iTdB/BfGd/R2C7xJPwbrKM5Y6Bo/JBFKyIv+jZ7oRXoP4JFzdtjF/JIfzunA/LwpG1+lD52XTwivUZo8kw4JioM0eqyzQkA5annod/jpZxyRp9oW5X7gUCamvG9bVv0WaZAsUUp/VdKGHFReo2zaU4xPPQdKpx6o0i6n5Ia8zzf6i/1JeU0vyuYTmzfNZBO7RqYOmtB3eQziB1TIgAaN6gR5WkUp+nIX7AfzRORpnTC1KV7GwXbLIeQ2TfBPdI1jnOwfAb25j24X2XxGcVxpDwPkL+avJUIQR8sNOZkTIpBuHVR5REmKdp0zXly1GzApYPNUCtyiaVWRtY51PI2KR49B6kkHAl/q2U1y9ij1HCP1KpStSj04mNibGbEnJPYOuBhA7EH5wGHFXgmcEkffeUQdDjgQDmMJOo516di+OwX0X6cvEVmMsAZHL1hnIIvTjBJ4e2hYpoeJq3AG4ee0Acka2ii3VIFGmxjYZOiIyOLqil5YR2GYPYYrA+rsJyQG5HDI200XZyaGYVsGhj2D4f08uMPYNQox7DIYXmIcGqC+XdkqBaqvAKi6gFHRIgM2jaf+FhnQGLSCGPOHM2gdwNBklMrQsYfBnzplWIywAsc1TjbRZmn04duxqlIfOlPpWD1Jr41Dzbqq20Z6ZEKYKKyNObEJiJ1fE/cDPRxC7d36ruNarNpb+HM010LtNYM8CrXe1B2PZekpWuYxrpreDEuPBzn1CJt5Uxh4LMIQCF001QFhTS+LrxBhKpwsmiBsaqlBGONI6Q1hxYgMwl7pHlGJMJYkb0VY/zjqWgG2kqFq302BGozBhhYDTcO9BWQvLvSA17hRLkGkoFX+ukBfWauKD5rr9P5kbSF0jKx93dJbpayVlyV2CrJWkT0DQUNEwjfK2qZhJCprr5IkeK50u8s6bMUNsGLJU3JRfsc3iwJ52QUnwGaqFqawhWqvZrPmCrdPlc7mQJrI5XGLXyCmLmQ4aqds5Mu42o9eWykkYGFatCb/IJP8I8+QU0lY+W6pkwyh+Bx97EvXx8JUUZHocQJBaC5VBiOKHutS/dM3mobGwd5UsDAtos+jR47NIme4TP7ibQx0Xmk
JqISOb7Ipamq/pnTswaCjYjftKQg0DlWGk2d6BNxPQJ5ZA5oCAkpHX2+iEChmPPtYekBIeMJ5BrK3dw+E0SM5pI6BIhkMFMlgeR9y+1o39Xg6uRQxlXD6uV7CH7u5s3q8/fbP/X0MrPTzBehh79BJGg0tszNl8dnSczBd1v3mA2s4376FHicIG7rIDx0dEHnBqDgwu3SHRqV8/9E5oZKzy53f0xnMA9j95gaVL6GS8UFpgEqHF0Q5L4O0ZeCKcpQuQHMzSn9ZSt0jrYapH1K0YehsomHiGlXllqIWsp7/UrJl4IrKWlwAdhugYuSy5S3WmDzSQPcYtasBdE1lJ6UGLh/fTb+FcnTr4bA9AcO5SSkd7GZvrNp3qkj7sgVpFMOT520y6WWytHJzI+rwStn4pzpmx6XzLuA1li4gjqTriEo/d9Ouf5Ht1UFtTw+KWpXUdgX8Wcao2hOrmdajgVXlmkJHXbNjD4Rj5bZXMVJT17MfZ+aA9ZFaCD5ij8jLJpUyBZ7pAK/y59Y1BJxdgsrVYiek/J11LdNw5gWBlIeixO2zl+9VMR+U8cOZl69RHuA4ih/aU06U8cOZl9hQ7XI7hh3aUx2UcYPx53TNjnC+kTdY7m73m5sV/mtzFmw4/Aq/WFkaVHZhTQCV0pMgjqSrHhtRTwCVzAfBdEAlSyyDyh6xJpSQYDcr1/TtdfN4X9MxGH45eKkFhEfsQROAsCeqWAvzUhvFeuZOL3HaKQ97gJny8KUnkFZkRPCEE8PQQgYbB5RcB7U4AgWc3VC5C8o7829jKXdQH8UP7d4QZfxw5qlIqh3Ux7BD+ypcFTcUNaXHaLELOLhUWX3NyrO2q54TbI7VlxegyGa0xiPe/QMuLlxs93N9RToA725XXqzXtfht9cPKv3NE3s+a5ye/M4Uu8ke1VLrQI51GhuHowAa9HZ7heMi5UJP/MuV+yFg2B4AqBwDDAYeKJTpwgNM3B4CqDACjlgEQ1jObiGWmAQfw9oNJ5QC7ygH2qDnAdnXkAJ73RyoHwCoHwFFzAJhaGnIAd9OZTA5wqhzgjJoDZjqKAF52i0QG2Boz4MAAjfTmg0AYjvr8TxrIpL4xAcoPxehH/vZilZLIb/R/+ZmoAcnvfdpZf85ubpzN1L76/O7+X7B2L2bd3uHSG/ehbG3MC9djx/Hsie2TI8fvd9WT59r0NzxzL3/jmDtom08l2bFc2F8yTce4Bg/kht6t53ICuYvFwpZWPZYBHYcVxAO5jsUCsa8wLpd0h0DRuIAI+wEis09INRD1SGo7BSAyX1MdHIjdEbIzBaLTExCbG7RUA5Hn4DZAFClYNzwQWVKNAIhuT6bpwDjUo+zcCeCwWaJqeBh2f+FaIxh21VjZn5Q/25/1AV9OyrevKpe7840M8l5yzemnALt3Vxjk1fic84UR/1hlehzy9PhyiP7I03AN2F3kwyCvxuc+B3nHeliPQ55vkCeGPHVuUHKa4Cw+VC4ryDSvbnCIsh7/Aw== -------------------------------------------------------------------------------- /assets/seq2seq5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/seq2seq5.png -------------------------------------------------------------------------------- /assets/seq2seq5.xml: -------------------------------------------------------------------------------- 1 | 
7Ztdk5MwFIZ/TS/dAQJZeql11Rl1R2cv1MsUUsAF0oF02/rrDSV8hMQutjRUl17skkMCJW+ec06SMgOLZPc+Q+vwM/FxPLMMfzcDb2eWNQc2+1sY9qXBMczSEGSRX5pahofoF+ZGg1s3kY9zoSIlJKbRWjR6JE2xRwUbyjKyFautSCzedY0CLBkePBTL1m+RT8PS6lqwsX/AURBWdzbhvDyzRN5jkJFNyu83s8Dq8ClPJ6i6Fn/QPEQ+2bZM4G4GFhkhtDxKdgscF11bdVvZ7t0fztbfO8Mp7dMA8BY53VfPjn3WFbyYkpT9e3N4IFw0MVjJI0nk8eOQJjE7NNkh3kX0e+v4R1HlxilKKc3233mLQ6E5V3dAce4npnTPRwPaUMJMJKMhCUiK4k+ErPnlVySl71ASxcXwWpBNFuGMPcQ93vKT/BomZGW5S6pnZg09/tAWH2QoCzCvBUpT0R2tZrwb32OSYPYgrEKGY0SjJ3HkID4Ag7peIwI74DqoNeHf5QnFG1yNoo5GoiDbMKL4YY0OD7NlQIrCCF0cxCjPFULmNCOPeEFikh1uAVxrCWDRf6sojlt238Gub/+tCCiOgpQVPCYCq3ZElSecUbw72uPVWcgh4k7Gsnl52yBbmcIWrZXtHI2ArMmL5AbI3FTV9IMDJnB6gVMTMAI4pvm8SKeSdCVMOEOPf970C4nYXRv354oqAtiRp4SSt+ooVH+NXqI5kmbBhuJUEo6NP9qRRKCDi9kGg5ukYV6M5oglZK/5iSTy/eI2SmbF0dEL22FkPz1Y2WY/5qwhmDP+e+agHuYsqI05KGmWEHbtCbozAp1O6Gw4ZYjtdLCdItqDh8jec6se6ceUIxaezRovR+wzA/7H49XwqwvqgAVsbQFLOf2CKCmAiYsIhUleGwJa99QUyk7hUWcoA4ak1MsMZVAOZdZYkcyeDy4KF6KS5fo6/uJTaqMT9IzLeUs5v58SEZXjc0fMQ3Qg1vi9xg3+uLJkxR0Hv0smK+4zyUo+JSuDMat12i2vVf4aw7EuXcd2DNmxrlwPe941OFa7G+w0elbnBW87C9mjpsTGNm5A62O7jii9M7+Brc+tLd6gDAaD7CTIS9Eh87Oly90s2eHB2RqVjV2vZZ58cA+8a031OV1HXrBRqWpOqg4cWc1bjSI7GrLhEf2y7v3avpPL5y/UjcgDumt5lqoC25rAPhfs7v6wVrJvJ7KH3BU+mWzpQhckW54Aq8gGE9nnkt3dhNZK9n++SKx7++xksqULXY5s2G+KZU9kn0t2d09OJ9n3+dfN/Tz3fmLr0XiAH1+Duf0K2kdxb0C8a6ydjlbCqoC639o2O367axf2VeHIqveh0DQ7lKp2SuyfXcBWdxZXQV6P0bB7d/Qbjbwn5MMldBR7QqvVyhps6VKiTCFi31lwzaGGlUu1cMd/6zVxJ3SWah10VOyu4zdh149dd446PnfHJ64Td0JnWTJ3YFTu5FXjibs+M8jxuXMn7vpz98d35EYD7zpelLt+8LoTvAuCx4rNW9/lrL95sx7c/QY= -------------------------------------------------------------------------------- /assets/seq2seq6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/seq2seq6.png -------------------------------------------------------------------------------- /assets/seq2seq6.xml: 
-------------------------------------------------------------------------------- 1 | 7Vzfk5s2EP5r/JiOQULgx/TukrbTdDpzkybpGwYZk8PIxfLZzl9fYYQBgTnZFoK7kIcLWn5K3+6n3dXKE3C32n9M3PXyE/FxNDGn/n4C7iemOQOQ/U0Fh0xgTY1MECShn4lKgsfwB+bCKZduQx9vKhdSQiIarqtCj8Qx9mhF5iYJ2VUvW5Co+ta1G+Ca4NFzo7r0S+jTZSZ1TFTIf8NhsMzfbKBZdmbuek9BQrYxf9/EBIvjv+z0ys2fxTu6Wbo+2ZVE4GEC7hJCaHa02t/hKB3afNiy+z6cOXv67gTHVOoG/t0besj7jn02FLwZk5j99+uxQzi9ZcpaHlmFHj9e0lXEDg12iPch/ZqKf7F46xu/6NTJtPEdU3rgiLtbSpiIJHRJAhK70Z+ErPnjFiSmH9xVGKUqdEe2SYgT9qF/4R0/yZ9hINaud5uPxIbd6PGOOZko7V/pGj4uHzFZYZoc2AUJjlwaPldVweUaFZyuO936NwnZW83pPtd1ji3XfTNv54+gbhJgyu8qsGEHpc8oREfEmtGDQDl6Rhm7DMmYfc9XfsexUZwbFrLZwHIKUg12BaYWTPi3PLvRFuccIGBUBWS3DCl+XLvHzuwYnVaBqQxxELmbTQOQG5qQJ3xHIpIcXwGQ5+D5Ih3PMIpKct/FzsK7FAQ3CoOYNTwGArusBZVnnFC8bx1xfhY6gplA3t4VhJuLliWuhdPbMTqx70/AesoN4QzriXB2yHpwZL0SsmXWA73R3mykPSnas1B/tGfAGkibiYkimg7lds4Og/QQ5DL2vJK4hifrMRVAq+DBDbEMBRfVBjYdv5C53+/5iVXo++lrGrWkqkdSiqLG+ORRRrCKsmHLoWwqQBmop8f65FbQY8GW35pMtz96BEDTzCf4+wB1NvM5UtGat02ej8gaMsjaLdDa7dj67mZ5elEnQKffdb/PPyxtHHhDSgU+zf+ZhfAzfff79PP9v398cXb/Pb0zTFCfOJHVibK8TxL3ULpgnSrBpkWXLGF2qIbv7CB74rUKlNtEZZpG7iql1WwaIJuTYGT9G0MaaGhkfTSy/nEcujHkF+OdDll/ZshAeyPrvwlih5qwnwnenWjAZ7C/dDZAoOPZwKrNBgEh/kj514dzWinf0UD5Q6F1TWksJIbn3aWxwJgyuS6Y1pkygVJz79t3q6DZj/116VbVzW0Mpi/3ubgPUQ6mbaMTZbnUfbINgTeAWvcpt4kSga9IEodxMHpQ17O7Tg8K2srZ/VWuJMEGI1ZeLSG9gC6Vv76Ml3uOdgXXauF42POanLG5Y0FLmpGhplwH1BTvQiHezVdNVBG2Y42aNSzNsmxNmiW+x1KsWVLZ19s0q+RKmn27kh2r3QCdScsUNAgp1iApV+R6DXp7SqK9osnpiJuE2RUYijVLKk04apY+zaolGTvSLEsIrYDi9QNYXz/4UdM1DSnME8rn9KL3FCYUa8I1pjBRfc2/qerLGKu+bkbZaY5fdKQykFxtnzmifCvK4pKfTpRtqQj2bSz5dVSS9WL9nsi56lYcUH2+7GO6nCEbuA1uFDZ8C9uDmC4FInX0zZa2jjqqgVhYTly6I6nuLCzv0WhhF05iOi2sDsm1Fta9HbUtzZ+FQUP6qr7YKcxispnUF4NTQ1yBzxhEibWqX34bLtdq0gMRvg65NiedkWsvW+HWybXqNoi/Bq7V5M8gcTUJXMm1omMky7UXJwLF90C1iUBb3ZazV6BmugJTdAa1y9VM9A06nNLrP38wlpTfmEXUmF6qlySPFW23Jwj1AZirztn9fXjc36dwV7c+YE+bVsdV3JciLcPoaeN3ZwUCYk38VK33dhqwcR13yOu4Z+rv26s7Cit9KKQC1zZacoPFy1Wpnd1k0OZ3HxvFbcdWdXOCYMwKt/47N5LDjfg17cjXb3o
+miOriZQXC1OZ6dXsrAFFabc4L2bp0fTa05ej6TVuzq79SF9/plcvmRhNTyagGYDpta/Sjqb34h490LPtDaMUYvi2J8acHdoeaxY/PJzFEcWPO4OH/wE= -------------------------------------------------------------------------------- /assets/seq2seq7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/seq2seq7.png -------------------------------------------------------------------------------- /assets/seq2seq7.xml: -------------------------------------------------------------------------------- 1 | 7V1bk5s2FP41fmwGcfdjurm0M0mnM9uZJo8YZEyCkYvxrje/vmAEBknGsi0JvKt9CcjcpO9856YjZWY9rPef82Cz+ooimM5MI9rPrA8z0wSW45T/VC0vdYtnunVDnCcRvujY8Jj8grjRwK27JILb3oUFQmmRbPqNIcoyGBa9tiDP0XP/siVK+2/dBDGkGh7DIKVb/02iYlW3+k0vqvY/YBKvmjcDd17/sgjCn3GOdhl+38y0loe/+ud10DwLd3S7CiL03GmyPs6shxyhoj5a7x9gWo1tM2z1fZ9O/Np+dw6zgusG/N3b4qXpO4zKocCnGcrKf34/dAhWtxjlWYjWSYiPV8U6LQ9BeQj3SfGtan7n4LPv+KK2k9XJD1gULxjxYFegsgnlxQrFKAvSLwht8OOWKCs+BeskrUToAe3yBOblh/4Fn/GP+BnALc/pbuOR2JY3hrhjft1U9a9zDR6XzxCtYZG/lBfkMA2K5KkvCgGWqLi9rr31b5SUbzUNLP2OgbHFsm82580jiiCPYYHvOmJTHnQ+49h0QIyNnm0JRw90sauRzMrv+YbvOJwcf5sWsvXA1k1z0WD3YBrABH/LU5DuYKMDCIz6gDyvkgI+boJDZ55LfdoHpjfEcRpstwwgt0WOfsIHlKL88ArLDX24WFbjmaRppz0KoL8MLwUhSJM4K0/CEoTysgFUnmBewP3giONfbZ+giY3Pn48Kt2ladXStbdyOUat934DWE06EE1qPhFOi1rO11usg29V61mhqb67VHpfac9zx1B6wKZC2M9NNi2ood4vyMK4OraatfF6nmcKz7HFBgNbDAxOxCwVuoga2Gr+kdL/f4x/WSRRVr2FKSV+OuARFDPn4UXbtPsrA40PZFICyJV490sbtqB6P2vI7i7rjqUfLUmT5CH/fcqVZPp8rWgt3+dMBWcCDrDcArTeMbRRsV+2LpABdfdeHffNh1ckLPuESgf2Pfx428L8vtr3Pk+9/zjMvWvwGTI82nK4jRVje53nw0rlgUwnBdkCWHMI69MP38qB+4rUC1HCiZ6bdYF2p1doMoG3boLX+jSGNDRRqfVdr/cM4yCHy2XhHotafAx5ob9T6r0Kxq8rwzQnvjiTwCewvtQauJdkaOJQ1iBGKtMq/PpxTqvJ9BSp/KmpdURrLJcNzeWksS6dMrgumVaZMbC7b+/rdKtsch38y3SqabjqYvtznYmShPSBFWC51nzxA6A1LrPvUcKKjwNcoz5Is1h7U9dpdpQdle8K1+13OJNkOTWLhsRT3BDpX/voyvTxytEu4Vks/hGHIcsYWvmM73BrZVpTrsBXFuzYR7zazJqIUtu9oyZqWZDmeIski3+MIliyu7OttktVxJc2xXUnJYjdBZ9IxCQlyB
UsQlytyvQTdt5Aor17yJekhwpJaYECK6LvNM19Zc0nITDNXblGLo6IspCRxdIjYyxqaYBhTHG16puIXJaAKkqWtaJwSptGTpTZZfa4wWerS1QWs+jKg68tuRtlnR0oqkiYuXxWhqVG+FWVyclElyh5XrPw6JhclFX+drRQkda64uQ2XtpdjmMu561kBw/eCIHKgNwlzSShSX5219FRUbE2EYY3iUh3HyWNY0yPNsAuNmEqG0ZBcyzD5PBoqAjgJg4JEGT2tSlgx3pzt2SgXkHP94uJHT/xE33R17UhJCom6tlE6WtdeNpeuUteKW4p+D7pWkT/jkvNW1pW6lnSMeHXtxRlF8j222GkST9zitjsQM1WBqXsCtcvFjPQNJJp0eqMFXbx+YxZRYXqJLn7WtXO3JwjVAdiIzsmVhFCvJBS4flwdsHPxaam7rImcM1YJz8fbXsNj0E19COSbC8tlhEBlAORH9hRCIDDi/hpz8Quy7pM6jDUBwBDuy3LDMo2FVtPnjjniQitgsByKuwn1uHgBDEWpOaBucXrbp24ctitgpt29642WyjUwwDDfAPEU7QVkqlu+2PapH0LHmnm3mDy1zNMbGfZ9w57DCIRzlh+YaZTPTN9jtMgSYKUeo/sGDJenxnBZZCW4TMPFymXo1KEgi0ZSUq1F0xkQbLwM2qLNhXOZH5c3tM1/O/bSo2x1G/23fdIeyRn1N+Ku8gDozYLwQJjjMFCm19L0SW+dKsNrGXHnVGByla7pta5qFsVIW3pNpuaMia51bRMferHrlBe7AvCGs3f9RJ0if9s23lmdP7vZzqcB35m/czt/TamoDIbS+cEVYwmsoZfAXs3wFlSFjgCgE4osWPX6ddHunsqFzQC88rVAqosLeNMe5x9EmmWRGpuuFmdR29TUvpXaZC2DWm6/8jyn6vqFq7lNPUgitxvJOcNt/d+a3cxtslpCKbdNFdnVV8xtcor3am5TD5LJbTrryuK2rbl9K7fJeWOV3GZvLW8N52WPVPx4bCVGmklXBq35Zl5O7p8/NCdzODnedjhr7mMS/9qN+PFoMVIzCmaYh79p5FnLyF24DistvlyawhKZFNMYOPLGwy0XFeQxT3BvOLmpudeXc5vmXl3tMR716F36NPV44tUJcG946z3Nvb6gM2qH6xVf43FvGuXD0+ceGU9OgHvDq5819/qC7jJ8zrp4fjzycVSFa/LN6IBvAuQzNPkuIB9jgwJ/XOpNY5OC6VOP3HhnAtQbTq5q6jH/o+mezzku9ej5TU09ni2TJkA9GilNvdOCzthbxBqZe9PYX2T63CN3tZLIvfI0R6joTjyVw7z6iiJYXfE/ -------------------------------------------------------------------------------- /assets/seq2seq8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/seq2seq8.png -------------------------------------------------------------------------------- /assets/seq2seq8.xml: -------------------------------------------------------------------------------- 1 | 
7V1bc5s4FP41fkwHEGB4TJ206cxuJ7PZmU2fOsTImBYjF8uxnV+/XAQ2SICwucVRHhIQ4qbvnO9cdEQmYLbafw2s9fJvZENvokj2fgLuJooiT3Up/BO1HJIWTZKTBidwbdLp2PDkvkHSSM5ztq4NN7mOGCEPu+t84xz5PpzjXJsVBGiX77ZAXv6ua8uBVMPT3PLo1v9cGy+TVkPRj+0P0HWW6Z1l3UyOvFjz306Atj6530QBi/gnObyy0muRO2yWlo12SVP87uB+AmYBQjjZWu1n0IvGNh22ZIC+lBzNnjuAPuY5YePq4Nuvx2d/c/jnz5+7B+3pYNwoBnnPDT6kIwLtcIDILgrwEjnIt7z7Y+vn+K1hdF0p3FvilRduyuEm3Lv4OWr+pJG9H6TTL4jxgaBvbTEKm47X/guhNbmCbW2W8ZWjHfoV07FE22AOK99rmvSMXubkVDI0XyFaQRwcwg4B9CzsvualwSJC5WT9slMfkRs+jCKlCpDKPxF/WZPyl8BW4EBMzjrCE26cPMaxKQatAYBABp0AeApfAqYfPuhz2i3aOR7LgcaLdAvgJmNbNTqS0rYY5ABsilb25K+Wt4UpbxTgy4OzW7oYPq2teER2IQfnActYJerreNZmQ7bnaOXOyfYGB+g3nCEPBfEtgAQNyTDCIwvX807aXyQIoB61Ix9/sVauF4n1LATEhUH4lN/hjhwkAMtRZ8tzHT/cmYdIht0qoH2FAYb7SiDSo6CgV1OyvzuSdKp6yxN+1qXOoDOvkylNqR+mVAZnSrUfpuybDVMAa9mwA6N4mUqlTy7YsIYNldGxoSnXQ9dUu05BKgOVV58aQcZHpUPoVwmVynlxAHIB54Q0KCq9DQLrcNJtHXXYNLiPVAgxapzhQv9wI3mCVnldMYEQxcGsOuhJFKn71Ihi0dvoSRSvNJo2eyI2MLiPqPXjI44pmjZ5o+kOSOVCJ0RE03z+Ixif/2gIoz0Yzao9GW3qPjVGu0j/vRhtkE5YXJnR7iD3yUZtcJt9OXOMMa+T4ceR5R5XXid7cmGX6+zyyMwykC6fTSqxwpnHe+IA/3hvFro3TpV7stDUfWos9CAGWu4mqm6H4KsiuXj3bp/eLN47kL134SaanwzJ1BSQ/NZz2CuS9ElSp7pspr+5JJS+iyQVr9uv/yCxosjwZdS3iXY/0WYeXOC40mQy/TyZ3iXHKIkMGR8XjFbOHvnIhwVTRJoowxLZD3duebfkwMq17VimWVYyL+dchrIRc7Zg5W60gpkzaTOnMsyc0p2ZY0WfuhcBaLuv4aaD41ePkF7+VItyQCQgOSF8gtw5Qi545UItiIUyuFiYTcQCCLHoRCzk0clF6iRRFmL5UykRAgE399z2+OBmzW0ncMsC7ktD3gHhBtWZPgIBXzhLooh0mzG1k/n8x2PdB7qlCJ3mmxiJJdB2SHHZ9IBiVCLVe062ybD2Xf3RYcJ1HHk8Q3kBuk7n8WwNGrY6hjyebBRiWAapdZXIAzQmH5LUAE1qabfRsNr0HbJa65aBr8CnQ1bjqBcTrBapz4CsxvTBW6K5kRCWcXv/ePv478Pzd9l/2QP4MLv9dqMzfLPW/QpeEDQKA2eLoU8BIcKdcs/AzOuQqvcX7jDTFh9Bh1RahwaLb3QKgxUKH0soURNDNKASqdWTj+241yPRJv55cNnoxytU9copSK16CrLjmcNMud+VP99+UMRX/NdlmoLDWRQefWQFB/ToeZJJV+mN6DSpDpeYYAa/urWKFCCet4RokzWIqcum+jWgo5LWEn70PKDC0reh1K3EbxjZpwQajWvfpY9d+g10cCb8BhavDek2VOvKObRGK8qwxckX+BaMTz+0r6fcUDGr6k58i43wLc7XwQFdC7YVq3E4zrRizXyOVi2cNipt0ljZ29JiRCkpTAoi4BnFyqIm8UzlSy2e2p+2aaxAuRR4mQJeQN4O38p5N7THOjWNZUhzRcc03ALYc2cT+kRWZ5F6bpWBQLa9
9MsISonL1hZlFluUEp+/fkQpwD34uiK5wbqiN2G4OxKLaSFb0aPvViIWjdYVKUIwuhGMYi1On+ahMn0zcP7R1l90jZF/XCwWynzeFKiS/CMFEwNM3vwjYGh0V/lHNnAtLiQZNPVYv0K9kPioXbJeIeidTxWoRcB516DLcj7YU7ubc6is+hRE0HChxfBM0PDjo2OsauBR8VyelFfrlU60vrj+SeX8NkrthTIz05fWq0LrubR+fPafroy/eq1nlFVUSHXrWl9cH3S21hcv1LvWj6PoYPxaXyxWHF7rOT43e21az1pWWSHWnVcRn632xQu1p/bh7vG/eCXdj/8qDdz/Dw== -------------------------------------------------------------------------------- /assets/seq2seq9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/seq2seq9.png -------------------------------------------------------------------------------- /assets/seq2seq9.xml: -------------------------------------------------------------------------------- 1 | 7VzLcqM4FP0aL5MCBApZJn6kF9OpqcpUprOakkHGdAC5hRzb/fUjQJiXbOOAIemQjdHVC+noHF2uICMw9rcPFK2W34mNvZGm2NsRmIw0TVW1W/4TWXaJBSpGYnCoa4tCmeHJ/Y2FURHWtWvjsFCQEeIxd1U0WiQIsMUKNkQp2RSLLYhX7HWFHFwxPFnIq1r/dW22TKymBjP7N+w6y7RnFYoBz5H16lCyDkR/Iw0s4r8k20dpW2Kg4RLZZJMzgekIjCkhLLnyt2PsRXObTltSb3Ygd3/fFAesTgUgaoRsl44d23wqRDIgAf+5jweEoyoKT1nEdy1xvWS+xy9Vfom3LvuRu36JilwbUSpgdPdD1IgTWd5+AqK8n5ixnVgNaM0INxHKlsQhAfL+ImQlml+QgM2Q73rR8hqTNXUx5YN4xBuRKdpQIU9XpyQdM69oiUFrYpEh6mBRCiSmaDpy1cQ0PmDiYz4QXoBiDzH3rbhykFiAzr5cBgK/EDjIMRH38oa8NU5XUQmjIiCbpcvw0wrFg9lwRhaBKUyx46EwlAAZMkpe8Zh4hMZdAFObAxjN38L1vJzdNrBp6+eCgDzXCXjC4iDwYvw+KLJdnsi1fK9MwRTm8iYu5eR2SRCvRBpR5zCab5gyvD2KVJoLBfmEOGm6SG8yqqemZY7lqa0JtqCK5ZfkG6jyLS3WPeHAQLiLEm7PnB4Ip6qnwX0vAz8Il4y2eSOq/k1c3msmm2YRRQBL8CRkFrVKCO1voxZoRgUzZ81wUAGOrz9WgqTAKgFmnlDCVKFHtJpd7gDeiQzfte2oGynXi6ujFt3bgf39m5yu1uOc1gbnlD+ec7AbzmmwM87BCmY+4W0PpGuw0XVJOh0OnuVl3Eg5NXVFuQa5P6Po4hi31zD3d6NfirdaDQdn8F4bkBpo/XmvdWIBn3wnbT/OIucr0DvbSqUPlBD5EdG8aO/EJNwbHLafqWGTfQ8fu9xkgVJB6mtussLZKMRv2uZxXVAmTz+3L//NQrLxb9++Pb88GJZzdXNcKLN5mmbWEnASsARAKVwnwPpY+Minqbegm+SBY/Bb2vRbzB7dltvWZfIY8zLVfPlgvo3ZjW+jKt35NuYJ3yYcfJvWONula2N84ZPgghvTUWhPL4YPdPO8AEIiM60E26vR2i
VncELm9ZxfxjRWUhtvL2ce2F2D3XtMO6RzNXIgQ1UdUG2q2eW3CbrUbOOoZtd5yDnb7bq8TB/bJA7i1W+4VwXvCveePGOteOMHZP+OUrTLFVtFBcIjDmOpH1VXSusuafHde0r14U6mPtqgPk3VRy2/W9Gl/NwM8vNHyU9lL7uQ/JT7aV1+qg+pMvkBg/w0lR+tfDjWpfwcjzIN8vPZ5Kd8RHcp+Sn307b8wHpP1PogP03lByj9yY/0PAU216RagbRWpKc/OTG7kBNdb9JLewE2+UKRxtFjOYgwKywh+GtN0oyrMAb9jhdQ4WqbZWbyIYLvwTyMflBOZZKWa8sMxbwzNI8LRCtT6CovbdyPjEnUOF96oaB0Tm88vGAStWHRsrwPucq4gfNPvEavTh6xtaASevlpF1RVAnSpEumpa2tnn6XDy9lsPJ7NztXjytnobAZh3EwbGJSUWtXqKXUbx5FyDPRLKPV5x5T9nUzKp0TyVc4He2vg6H3n6PS7GZ8qVJibhm4oVaYtTAtbVg+eT5lPXR7ve4/Pz5sJ1EL1hn4PXsNf6uPtfkNrwCfp2pdwpB7v+PVkm0/sajAyTmTV4lRaT7rmTzJNPlkHDx474NDRO+r5dRwbzqEheR1nsVho5/Ps4Os4h1+5qTBQAnDd43vQ4R4nB7V5gPYLcVKTcLJXSmoDJRtTsvwdXP+cPP6dzsDJwmRJPFLQKyc/xrfin5uT5c/k+udk8/covhAndck+2dXz39E7GkjZgJTlz2ouSEqezP7DURJQzf6NFJj+Dw== -------------------------------------------------------------------------------- /assets/transformer-attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/transformer-attention.png -------------------------------------------------------------------------------- /assets/transformer-decoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/transformer-decoder.png -------------------------------------------------------------------------------- /assets/transformer-encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/transformer-encoder.png -------------------------------------------------------------------------------- /assets/transformer1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/StatQuest/pytorch-seq2seq/49df8404d938a6edbf729876405558cc2c2b3013/assets/transformer1.png --------------------------------------------------------------------------------