├── README.md
├── LICENSE
├── .gitignore
├── utils.py
├── train.py
├── model.py
└── NLG_development.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # NLG_Autoencoder
2 | 
3 | Natural language generation with a denoising autoencoder.
4 | 
5 | My PyTorch implementation of ["Unsupervised Natural Language Generation with Denoising Autoencoders"](https://arxiv.org/abs/1804.07899) by Markus Freitag and Scott Roy. The model trains a denoising autoencoder built around a bidirectional LSTM to generate natural-sounding language.
6 | 
7 | Still a work in progress, but it trains successfully.
8 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Mat Leonard
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import random 3 | 4 | import numpy as np 5 | 6 | punc_tokens = {'!': ' ', 7 | '.': ' ', 8 | '?': ' ', 9 | ',': ' ', 10 | '(': ' ', 11 | ')': ' ', 12 | '"': ' ', 13 | ';': ' ', 14 | '\n': ' ', 15 | '\t': ' ', 16 | '~': ' ', 17 | '-': ' ', 18 | '\'': ' ', 19 | ':': ' ' 20 | } 21 | 22 | 23 | def replace_punctuation(dataset): 24 | return [''.join([punc_tokens.get(char, char) for char in seq]) for seq in dataset] 25 | 26 | 27 | def extract_ngrams(sequence, n=2): 28 | """ Extract n-grams from a sequence """ 29 | ngrams = list(zip(*[sequence[ii:] for ii in range(n)])) 30 | 31 | return ngrams 32 | 33 | 34 | def corrupt(dataset, p_drop=0.6): 35 | """ Corrupt sequences in a dataset by randomly dropping words """ 36 | values, counts = np.unique(np.concatenate(dataset), return_counts=True) 37 | to_drop = set(values[counts > 100]) 38 | 39 | out_seq = [[each for each in seq if np.random.rand() > p_drop*int(each in to_drop)] for seq in dataset] 40 | 41 | return out_seq 42 | 43 | 44 | def shuffle(original_seq, corrupted): 45 | """ Shuffle elements in a corrupted sequence while keeping bigrams 46 | appearing in original sequence. 47 | """ 48 | 49 | if not corrupted: 50 | return corrupted 51 | 52 | # Need to swap words around now but keep bigrams 53 | # Get bigrams for original sequence 54 | seq_grams = extract_ngrams(original_seq) 55 | # Copy this 56 | cor = corrupted.copy() 57 | 58 | # Here I need to collect the tokens into n-grams that show up in the 59 | # original sequence. That way when I shuffle, 2-grams, 3-grams, etc 60 | # will stay together during the randomization. 
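    # Illustrative example (annotation; shown with words instead of integer token ids):
    # if original_seq = ["the", "cat", "sat", "down"], its bigrams are ("the", "cat"),
    # ("cat", "sat"), ("sat", "down"). For corrupted = ["the", "cat", "down"], the grouping
    # below produces [["the", "cat"], ["down"]], so after random.shuffle() a possible
    # output is ["down", "the", "cat"] -- the surviving bigram stays intact.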
61 |     to_shuffle = [[cor.pop(0)]]
62 |     while cor:
63 |         if len(cor) == 1:
64 |             to_shuffle.append([cor.pop()])
65 |         elif (to_shuffle[-1][-1], cor[0]) not in seq_grams:
66 |             to_shuffle.append([cor.pop(0)])
67 |         else:
68 |             to_shuffle[-1].append(cor.pop(0))
69 | 
70 |     random.shuffle(to_shuffle)
71 |     flattened = [elem for lst in to_shuffle for elem in lst]
72 |     return flattened
73 | 
74 | 
75 | def get_tokens(dataset):
76 |     # Tokenize our dataset
77 |     corpus = " ".join(dataset)
78 |     vocab_counter = Counter(corpus.split())
79 |     vocab = vocab_counter.keys()
80 |     total_words = sum(vocab_counter.values())
81 | 
82 |     vocab_freqs = {word: count/total_words for word, count in vocab_counter.items()}
83 |     vocab_sorted = sorted(vocab, key=vocab_freqs.get, reverse=True)
84 | 
85 |     # Starting at 3 here to reserve special tokens
86 |     vocab_to_int = dict(zip(vocab_sorted, range(3, len(vocab)+3)))
87 | 
88 |     vocab_to_int["<SOS>"] = 0  # Start of sentence
89 |     vocab_to_int["<EOS>"] = 1  # End of sentence
90 |     vocab_to_int["<UNK>"] = 2  # Unknown word
91 | 
92 |     int_to_vocab = {val: key for key, val in vocab_to_int.items()}
93 | 
94 |     return vocab_to_int, int_to_vocab
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | 
6 | def train(dataset, encoder, decoder, enc_opt, dec_opt, criterion,
7 |           max_length=50, print_every=1000, plot_every=100,
8 |           teacher_forcing=0.5, device=None):
9 |     # NOTE: dataloader() and int_to_vocab are currently defined in NLG_development.ipynb
10 |     if device is None:
11 |         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 | 
13 |     steps = 0
14 |     print_loss_total = 0  # Reset every print_every steps
15 |     plot_loss_total = 0   # Reset every plot_every steps
16 |     for input_tensor, target_tensor in dataloader(dataset):
17 |         loss = 0
18 |         steps += 1
19 | 
20 |         input_tensor = input_tensor.to(device)
21 |         target_tensor = target_tensor.to(device)
22 | 
23 |         enc_opt.zero_grad()
24 |         dec_opt.zero_grad()
25 | 
26 |         h, c = encoder.init_hidden(device=device)
27 | 
28 |         encoder_outputs = torch.zeros(max_length, 2*encoder.hidden_size).to(device)
29 | 
30 |         # Run input through encoder
31 |         enc_outputs, enc_hidden = encoder.forward(input_tensor, (h, c))
32 | 
33 |         # Prepare encoder_outputs for attention (pad or truncate to max_length)
34 |         encoder_outputs[:min(enc_outputs.shape[0], max_length)] = enc_outputs[:max_length, 0, :]
35 | 
36 |         # First decoder input is the <SOS> token
37 |         dec_input = torch.Tensor([[0]]).type(torch.LongTensor).to(device)
38 |         dec_hidden = enc_hidden
39 | 
40 |         dec_outputs = []
41 |         for ii in range(target_tensor.shape[0]):
42 |             # Pass in previous output and hidden state
43 |             dec_out, dec_hidden, dec_attn = decoder.forward(dec_input, dec_hidden, encoder_outputs)
44 |             _, out_token = dec_out.topk(1)
45 | 
46 |             # Teacher forcing: sometimes use the correct token from the target sequence
47 |             # as the next input, sometimes use the decoder's own output
48 |             if np.random.rand() < teacher_forcing:
49 |                 dec_input = target_tensor[ii].view(*out_token.shape)
50 |             else:
51 |                 dec_input = out_token.detach().to(device)  # detach from history as input
52 | 
53 |             dec_outputs.append(out_token)
54 | 
55 |             loss += criterion(dec_out, target_tensor[ii])
56 | 
57 |             # If the input is the <EOS> token (end of sentence), stop decoding
58 |             if dec_input.item() == 1:
59 |                 break
60 | 
61 |         loss.backward()
62 | 
63 |         # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
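        # (It rescales the gradients in place so that their combined norm is at most 5
        # before each optimizer step.)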
64 |         nn.utils.clip_grad_norm_(encoder.parameters(), 5)
65 |         nn.utils.clip_grad_norm_(decoder.parameters(), 5)
66 | 
67 |         enc_opt.step()
68 |         dec_opt.step()
69 | 
70 |         print_loss_total += loss.item()
71 |         plot_loss_total += loss.item()
72 | 
73 |         if steps % print_every == 0:
74 |             print_loss_avg = print_loss_total / print_every
75 |             print_loss_total = 0
76 |             print(f"Loss avg. = {print_loss_avg}")
77 |             print([int_to_vocab[each.item()] for each in input_tensor])
78 |             print([int_to_vocab[each.item()] for each in dec_outputs])
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class Encoder(nn.Module):
7 |     """ Sequence to sequence bidirectional LSTM encoder network """
8 |     def __init__(self, vocab_size, embedding_size=300, hidden_size=256,
9 |                  num_layers=2, drop_p=0.5):
10 |         super().__init__()
11 |         self.hidden_size = hidden_size
12 |         self.num_layers = num_layers
13 | 
14 |         self.embedding = nn.Embedding(vocab_size, embedding_size)
15 |         self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers=num_layers,
16 |                             dropout=drop_p, bidirectional=True)
17 | 
18 |     def forward(self, input, hidden):
19 |         embedded = self.embedding(input)
20 |         output, hidden = self.lstm(embedded, hidden)
21 |         return output, hidden
22 | 
23 |     def init_hidden(self, device='cpu'):
24 |         """ Create two tensors with shape (num_layers * num_directions, batch, hidden_size)
25 |             for the hidden state and cell state
26 |         """
27 |         h_0, c_0 = torch.zeros(2, 2*self.num_layers, 1, self.hidden_size, device=device)
28 | 
29 |         return h_0, c_0
30 | 
31 | 
32 | 
33 | class Decoder(nn.Module):
34 |     """ Sequence to sequence bidirectional LSTM decoder network with attention
35 |         Attention implementation from:
36 |         http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
37 |     """
38 | 
39 |     def __init__(self, vocab_size, embedding_size=300, hidden_size=256,
40 |                  num_layers=2, drop_p=0.1, max_length=50):
41 | 
42 |         super().__init__()
43 |         self.hidden_size = hidden_size
44 |         self.num_layers = num_layers
45 |         self.max_length = max_length
46 | 
47 |         self.embedding = nn.Embedding(vocab_size, embedding_size)
48 |         self.attn = nn.Linear(self.hidden_size + embedding_size, self.max_length)
49 |         self.attn_combine = nn.Linear(self.hidden_size * 2 + embedding_size, self.hidden_size)
50 |         self.dropout = nn.Dropout(drop_p)
51 |         self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers,
52 |                             dropout=drop_p, bidirectional=True)
53 | 
54 |         self.out = nn.Linear(2 * hidden_size, vocab_size)
55 |         self.softmax = nn.LogSoftmax(dim=1)
56 | 
57 |     def forward(self, input, hidden, encoder_outputs):
58 |         embedded = self.embedding(input)
59 |         embedded = self.dropout(embedded)
60 | 
61 |         # Learns the attention vector (a probability distribution) here for weighting
62 |         # encoder outputs based on the decoder input and encoder hidden vector
63 |         attn_weights = F.softmax(self.attn(torch.cat((embedded[0], hidden[0][0]), 1)), dim=1)
64 | 
65 |         # Applies the attention vector (again, a probability distribution) to the
66 |         # encoder outputs, weighting them by relevance
67 |         attn_applied = torch.bmm(attn_weights.unsqueeze(0),
68 |                                  encoder_outputs.unsqueeze(0))
69 | 
70 |         # Now the decoder input is combined with the weighted encoder_outputs and
71 |         # passed through a linear transformation as input to the LSTM layer
72 |         output = torch.cat((embedded[0], attn_applied[0]), 1)
73 |         output = 
self.attn_combine(output).unsqueeze(0) 74 | output = F.relu(output) 75 | 76 | output, hidden = self.lstm(output, hidden) 77 | output = self.out(output).view(1, -1) 78 | output = self.softmax(output) 79 | 80 | return output, hidden, attn_weights 81 | 82 | def init_hidden(self, device='cpu'): 83 | """ Create two tensors with shape (num_layers * num_directions, batch, hidden_size) 84 | for the hidden state and cell state 85 | """ 86 | h_0, c_0 = torch.zeros(2, 2*self.num_layers, 1, self.hidden_size, device=device) 87 | return h_0, c_0 -------------------------------------------------------------------------------- /NLG_development.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Implementing this paper: [Unsupervised Natural Language Generation with Denoising Autoencoders](https://arxiv.org/pdf/1804.07899.pdf)\n", 8 | "\n", 9 | "Data from here: http://www.macs.hw.ac.uk/InteractionLab/E2E/#\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from collections import Counter\n", 19 | "import random\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "import torch\n", 24 | "import torch.nn as nn\n", 25 | "import torch.nn.functional as F\n", 26 | "import torch.optim as optim\n", 27 | "\n", 28 | "import utils" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# For E2E Dataset\n", 38 | "trainset = pd.read_csv('e2e-dataset/trainset.csv')\n", 39 | "trainset = trainset.assign(clean=utils.replace_punctuation(trainset['ref']))\n", 40 | "vocab_to_int, int_to_vocab = utils.get_tokens(trainset['clean'])\n", 41 | "as_tokens = trainset['clean'].apply(lambda x: [vocab_to_int[each] for each in x.split()])\n", 42 | "trainset = trainset.assign(tokenized=as_tokens)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "def dataloader(dataset, p_drop=0.6, max_length=50):\n", 52 | " \n", 53 | " # Corrupt dataset by randomly dropping words\n", 54 | " corrupted = utils.corrupt(dataset)\n", 55 | " # Shuffle words in each sequence\n", 56 | " shuffled = [utils.shuffle(seq, cor_seq) for seq, cor_seq in zip(dataset, corrupted)]\n", 57 | "\n", 58 | " for shuffled_seq, original_seq in zip(shuffled, dataset):\n", 59 | " # need to make sure our input_tensors have at least one element\n", 60 | " if len(shuffled_seq) == 0:\n", 61 | " shuffled_seq = [original_seq[np.random.randint(0, len(original_seq))]]\n", 62 | " \n", 63 | " input_tensor = torch.Tensor(shuffled_seq).view(-1, 1).type(torch.LongTensor)\n", 64 | " \n", 65 | " # Append token to the end of original sequence\n", 66 | " target = original_seq.copy()\n", 67 | " target.append(1)\n", 68 | " target_tensor = torch.Tensor(target).view(-1, 1).type(torch.LongTensor)\n", 69 | " \n", 70 | " yield input_tensor, target_tensor" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 4, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "class Encoder(nn.Module):\n", 80 | " \n", 81 | " def __init__(self, vocab_size, embedding_size=300, hidden_size=256, num_layers=2, drop_p=0.5):\n", 82 | " super().__init__()\n", 83 | " self.hidden_size = hidden_size\n", 84 | " self.num_layers = num_layers\n", 85 | " \n", 86 | " self.embedding = 
nn.Embedding(vocab_size, embedding_size)\n", 87 | " self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers=num_layers, \n", 88 | " dropout=drop_p, bidirectional=True)\n", 89 | " \n", 90 | " def forward(self, input, hidden):\n", 91 | " embedded = self.embedding(input)\n", 92 | " output, hidden = self.lstm(embedded, hidden)\n", 93 | " return output, hidden\n", 94 | " \n", 95 | " def init_hidden(self, device='cpu'):\n", 96 | " \"\"\" Create two tensors with shape (num_layers * num_directions, batch, hidden_size)\n", 97 | " for the hidden state and cell state\n", 98 | " \"\"\"\n", 99 | " h_0, c_0 = torch.zeros(2, 2*self.num_layers, 1, self.hidden_size, device=device)\n", 100 | " \n", 101 | " return h_0, c_0" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 5, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "# Attention network from http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html\n", 111 | "class Decoder(nn.Module):\n", 112 | " \n", 113 | " def __init__(self, vocab_size, embedding_size=300, hidden_size=256, \n", 114 | " num_layers=2, drop_p=0.1, max_length=50):\n", 115 | " \n", 116 | " super().__init__()\n", 117 | " self.hidden_size = hidden_size\n", 118 | " self.num_layers = num_layers\n", 119 | " self.max_length = max_length\n", 120 | "\n", 121 | " self.embedding = nn.Embedding(vocab_size, embedding_size)\n", 122 | " self.attn = nn.Linear(self.hidden_size + embedding_size, self.max_length)\n", 123 | " self.attn_combine = nn.Linear(self.hidden_size * 2 + embedding_size, self.hidden_size)\n", 124 | " self.dropout = nn.Dropout(drop_p)\n", 125 | " self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers, \n", 126 | " dropout=drop_p, bidirectional=True)\n", 127 | " \n", 128 | " self.out = nn.Linear(2 * hidden_size, vocab_size)\n", 129 | " self.log_softmax = nn.LogSoftmax(dim=1)\n", 130 | " \n", 131 | " def forward(self, input, hidden, encoder_outputs):\n", 132 | " embedded = self.embedding(input)\n", 133 | " embedded = self.dropout(embedded)\n", 134 | " \n", 135 | " # Learns the attention vector (a probability distribution) here for weighting\n", 136 | " # encoder outputs based on the decoder input and encoder hidden vector\n", 137 | " attn_weights = F.softmax(self.attn(torch.cat((embedded[0], hidden[0][0]), 1)), dim=1)\n", 138 | " \n", 139 | " # Applies the attention vector (again, a probability distribution) to the encoder\n", 140 | " # outputs which weight the encoder_outputs\n", 141 | " attn_applied = torch.bmm(attn_weights.unsqueeze(0),\n", 142 | " encoder_outputs.unsqueeze(0))\n", 143 | " \n", 144 | " # Now the decoder input is combined with the weighted encoder_outputs and\n", 145 | " # passed through a linear transformation as input to the LSTM layer\n", 146 | " output = torch.cat((embedded[0], attn_applied[0]), 1)\n", 147 | " output = self.attn_combine(output).unsqueeze(0)\n", 148 | " output = F.relu(output)\n", 149 | " \n", 150 | " output, hidden = self.lstm(output, hidden)\n", 151 | " output = self.out(output).view(1, -1)\n", 152 | " output = self.log_softmax(output)\n", 153 | " \n", 154 | " return output, hidden, attn_weights\n", 155 | " \n", 156 | " def init_hidden(self, device='cpu'):\n", 157 | " \"\"\" Create two tensors with shape (num_layers * num_directions, batch, hidden_size)\n", 158 | " for the hidden state and cell state\n", 159 | " \"\"\"\n", 160 | " h_0, c_0 = torch.zeros(2, 2*self.num_layers, 1, self.hidden_size, device=device)\n", 161 | " return h_0, c_0" 162 | ] 163 | }, 164 | { 
165 | "cell_type": "code", 166 | "execution_count": 22, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "def train(dataset, encoder, decoder, enc_opt, dec_opt, criterion, \n", 171 | " max_length=50, print_every=1000, plot_every=100, \n", 172 | " teacher_forcing=0.5, device=None):\n", 173 | " \n", 174 | " if device is None:\n", 175 | " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 176 | " \n", 177 | " steps = 0\n", 178 | " plot_losses = []\n", 179 | " for input_tensor, target_tensor in dataloader(dataset):\n", 180 | " loss = 0\n", 181 | " print_loss_total = 0 # Reset every print_every\n", 182 | " plot_loss_total = 0 # Reset every plot_every\n", 183 | " \n", 184 | " steps += 1\n", 185 | " \n", 186 | " input_tensor = input_tensor.to(device)\n", 187 | " target_tensor = target_tensor.to(device)\n", 188 | "\n", 189 | " enc_opt.zero_grad()\n", 190 | " dec_opt.zero_grad()\n", 191 | "\n", 192 | " h, c = encoder.init_hidden(device=device)\n", 193 | " encoder_outputs = torch.zeros(max_length, 2*encoder.hidden_size).to(device)\n", 194 | "\n", 195 | " # Run input through encoder\n", 196 | " enc_outputs, enc_hidden = encoder.forward(input_tensor, (h, c))\n", 197 | " \n", 198 | " # Prepare encoder_outputs for attention\n", 199 | " encoder_outputs[:min(enc_outputs.shape[0], max_length)] = enc_outputs[:max_length,0,:]\n", 200 | "\n", 201 | " # First decoder input is the token\n", 202 | " dec_input = torch.Tensor([[0]]).type(torch.LongTensor).to(device)\n", 203 | " dec_hidden = enc_hidden\n", 204 | "\n", 205 | " dec_outputs = []\n", 206 | " for ii in range(target_tensor.shape[0]):\n", 207 | " # Pass in previous output and hidden state\n", 208 | " dec_out, dec_hidden, dec_attn = decoder.forward(dec_input, dec_hidden, encoder_outputs)\n", 209 | " _, out_token = dec_out.topk(1)\n", 210 | " \n", 211 | " # Curriculum learning, sometimes use the decoder output as the next input,\n", 212 | " # sometimes use the correct token from the target sequence\n", 213 | " if np.random.rand() < teacher_forcing:\n", 214 | " dec_input = target_tensor[ii].view(*out_token.shape)\n", 215 | " else:\n", 216 | " dec_input = out_token.detach().to(device) # detach from history as input\n", 217 | " \n", 218 | " dec_outputs.append(out_token)\n", 219 | "\n", 220 | " loss += criterion(dec_out, target_tensor[ii])\n", 221 | " \n", 222 | " # If the input is the token (end of sentence)...\n", 223 | " if dec_input.item() == 1:\n", 224 | " break\n", 225 | "\n", 226 | " loss.backward()\n", 227 | " \n", 228 | " # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.\n", 229 | " nn.utils.clip_grad_norm_(encoder.parameters(), 5)\n", 230 | " nn.utils.clip_grad_norm_(decoder.parameters(), 5)\n", 231 | "\n", 232 | " enc_opt.step()\n", 233 | " dec_opt.step()\n", 234 | " \n", 235 | " print_loss_total += loss\n", 236 | " plot_loss_total += loss\n", 237 | "\n", 238 | " if steps % print_every == 0:\n", 239 | " print_loss_avg = print_loss_total / print_every\n", 240 | " print_loss_total = 0\n", 241 | " print(f\"Loss avg. 
= {print_loss_avg}\")\n", 242 | " print([int_to_vocab[each.item()] for each in input_tensor])\n", 243 | " print([int_to_vocab[each.item()] for each in dec_outputs])" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 23, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 253 | "\n", 254 | "# max length for attention\n", 255 | "max_length = 50\n", 256 | "\n", 257 | "encoder = Encoder(len(vocab_to_int), hidden_size=512, drop_p=0.1).to(device)\n", 258 | "decoder = Decoder(len(vocab_to_int), hidden_size=512, drop_p=0.1, max_length=max_length).to(device)\n", 259 | "\n", 260 | "enc_opt = optim.Adam(encoder.parameters(), lr=0.001, amsgrad=True)\n", 261 | "dec_opt = optim.Adam(decoder.parameters(), lr=0.001, amsgrad=True)\n", 262 | "criterion = nn.NLLLoss()" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 24, 268 | "metadata": { 269 | "scrolled": true 270 | }, 271 | "outputs": [ 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "Starting epoch 1\n", 277 | "Loss avg. = 0.007070066407322884\n", 278 | "['coffee', 'city', 'centre', 'is', 'a', 'an']\n", 279 | "['The', 'Rice', 'Curry', 'is', 'a', 'family', 'restaurant', 'restaurant', 'shop', 'located', 'near', 'the', 'riverside', 'centre', 'near', 'near', 'has', 'a', 'average', 'customer', 'rating', '', '']\n", 280 | "Loss avg. = 0.03228044509887695\n", 281 | "['', 'The', 'customer', 'Rating', 'that', 'out', 'a', 'located', 'near', 'Plaza', 'Browns', 'Cambridge', 'is', 'is', 'of', '', 'The']\n", 282 | "['The', 'Cambridge', 'is', 'a', 'coffee', 'shop', '', 'in', 'the', '', 'near', '', 'Hotel', '', 'It', 'price', 'is', 'is', 'not', 'for', 'children', 'food', 'It', 'price', 'is', 'not', 'friendly', '', 'has', 'Rice', 'rating', 'is', 'low', 'low', 'out', 'of', 'of', '5', '', '']\n", 283 | "Loss avg. 
= 0.007794644217938185\n", 284 | "['in', 'one', 'star', 'Curry', 'is', 'a', 'family', 'near', 'the']\n", 285 | "['The', 'Waterman', 'Curry', 'is', 'a', 'family', 'friendly', 'restaurant', 'star', 'restaurant', 'located', 'located', 'the', 'river', 'Rouge', '', 'the', 'city', 'area', '', '']\n" 286 | ] 287 | }, 288 | { 289 | "ename": "KeyboardInterrupt", 290 | "evalue": "", 291 | "output_type": "error", 292 | "traceback": [ 293 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 294 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 295 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m train(trainset['tokenized'], encoder, decoder, enc_opt, dec_opt, criterion, \n\u001b[1;32m 5\u001b[0m \u001b[0mteacher_forcing\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.9\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprint_every\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4200\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m max_length=max_length)\n\u001b[0m", 296 | "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(dataset, encoder, decoder, enc_opt, dec_opt, criterion, max_length, print_every, plot_every, teacher_forcing, device)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mii\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget_tensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;31m# Pass in previous output and hidden state\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 39\u001b[0;31m \u001b[0mdec_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdec_hidden\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdec_attn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdecoder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdec_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdec_hidden\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mencoder_outputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 40\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_token\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdec_out\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtopk\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 297 | "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, hidden, encoder_outputs)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlstm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0moutput\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 298 | "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 490\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 491\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 492\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 493\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 299 | "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0mflat_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mflat_weight\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m )\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall_weights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_packed\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mPackedSequence\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 300 | "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(input, *fargs, 
**fkwargs)\u001b[0m\n\u001b[1;32m 321\u001b[0m \u001b[0mfunc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdecorator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 322\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 323\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mfargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 324\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 301 | "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(input, weight, hx, batch_sizes)\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mbatch_first\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdropout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbidirectional\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 286\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_sizes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mvariable_length\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 287\u001b[0;31m dropout_ts)\n\u001b[0m\u001b[1;32m 288\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 289\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcx\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 302 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 303 | ] 304 | }, 305 | { 306 | "name": "stdout", 307 | "output_type": "stream", 308 | "text": [ 309 | "> \u001b[0;32m/home/mat/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/_functions/rnn.py\u001b[0m(287)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n", 310 | "\u001b[0;32m 285 \u001b[0;31m \u001b[0mbatch_first\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdropout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbidirectional\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 311 | "\u001b[0m\u001b[0;32m 286 \u001b[0;31m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_sizes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mvariable_length\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 312 | "\u001b[0m\u001b[0;32m--> 287 \u001b[0;31m dropout_ts)\n", 313 | "\u001b[0m\u001b[0;32m 288 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n", 314 | "\u001b[0m\u001b[0;32m 289 \u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mcx\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 315 | "\u001b[0m\n", 316 | "ipdb> q\n" 317 | ] 318 | } 319 | ], 320 | "source": [ 321 | "epochs = 10\n", 322 | "for e in range(1, epochs+1):\n", 323 | " print(f\"Starting epoch {e}\")\n", 324 | " 
train(trainset['tokenized'], encoder, decoder, enc_opt, dec_opt, criterion, \n",
325 |     "          teacher_forcing=0.9/e, device=device, print_every=4200,\n",
326 |     "          max_length=max_length)"
327 |    ]
328 |   },
329 |   {
330 |    "cell_type": "code",
331 |    "execution_count": 15,
332 |    "metadata": {},
333 |    "outputs": [],
334 |    "source": [
335 |     "checkpoint = {\"hidden_size\": 512,\n",
336 |     "              \"num_layers\": 2,\n",
337 |     "              \"encoder_sd\": encoder.state_dict(),\n",
338 |     "              \"decoder_sd\": decoder.state_dict(),\n",
339 |     "              \"epochs\": 5}\n",
340 |     "\n",
341 |     "torch.save(checkpoint, \"nlg_07052018.pth\")"
342 |    ]
343 |   },
344 |   {
345 |    "cell_type": "code",
346 |    "execution_count": null,
347 |    "metadata": {},
348 |    "outputs": [],
349 |    "source": []
350 |   }
351 |  ],
352 |  "metadata": {
353 |   "kernelspec": {
354 |    "display_name": "Python 3",
355 |    "language": "python",
356 |    "name": "python3"
357 |   },
358 |   "language_info": {
359 |    "codemirror_mode": {
360 |     "name": "ipython",
361 |     "version": 3
362 |    },
363 |    "file_extension": ".py",
364 |    "mimetype": "text/x-python",
365 |    "name": "python",
366 |    "nbconvert_exporter": "python",
367 |    "pygments_lexer": "ipython3",
368 |    "version": "3.6.5"
369 |   }
370 |  },
371 |  "nbformat": 4,
372 |  "nbformat_minor": 2
373 | }
374 | 
--------------------------------------------------------------------------------
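Editor's note: the repository trains the denoising autoencoder but does not yet include a generation script. Below is a minimal greedy-decoding sketch, not a file in this repository. It assumes the Encoder/Decoder classes from model.py, a vocab_to_int / int_to_vocab pair from utils.get_tokens, and an encoder/decoder trained with matching hidden_size and num_layers (as in the notebook); the function name greedy_decode and the keyword-list input are illustrative only. Token id 0 is treated as <SOS>, 1 as <EOS>, and 2 as <UNK>, matching train.py.

import torch


def greedy_decode(keywords, encoder, decoder, vocab_to_int, int_to_vocab,
                  max_length=50, device="cpu"):
    """ Greedily decode a sentence from a (possibly shuffled) list of keyword strings. """
    encoder.eval()
    decoder.eval()

    # Map keywords to token ids, falling back to <UNK> (id 2) for unknown words
    tokens = [vocab_to_int.get(word, 2) for word in keywords]
    input_tensor = torch.tensor(tokens, dtype=torch.long, device=device).view(-1, 1)

    with torch.no_grad():
        h, c = encoder.init_hidden(device=device)
        enc_outputs, enc_hidden = encoder(input_tensor, (h, c))

        # Pad or truncate encoder outputs to the fixed attention length, as in train.py
        encoder_outputs = torch.zeros(max_length, 2 * encoder.hidden_size, device=device)
        encoder_outputs[:min(enc_outputs.shape[0], max_length)] = enc_outputs[:max_length, 0, :]

        dec_input = torch.tensor([[0]], dtype=torch.long, device=device)  # <SOS>
        dec_hidden = enc_hidden

        words = []
        for _ in range(max_length):
            dec_out, dec_hidden, _ = decoder(dec_input, dec_hidden, encoder_outputs)
            _, out_token = dec_out.topk(1)
            if out_token.item() == 1:  # <EOS>
                break
            words.append(int_to_vocab[out_token.item()])
            dec_input = out_token

    return " ".join(words)

Usage would mirror the notebook: build the vocab with utils.get_tokens, construct and train the networks, then call greedy_decode(["family", "friendly", "restaurant", "riverside"], encoder, decoder, vocab_to_int, int_to_vocab, device=device) to turn a bag of keywords into a sentence.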