├── .gitignore ├── LICENSE ├── README.md ├── data └── cover.jpg └── src ├── attention.py ├── attention_decoder.py ├── encoder.py ├── etl.py ├── eval.py ├── helpers.py ├── language.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | __pycache__/* 3 | data/*params* 4 | data/*.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Mohamed Eid 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # *Neural Machine Translation* implemented in PyTorch 2 | 3 | 4 | 5 | This is a PyTorch implementation of *[Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/pdf/1508.04025.pdf)* using *[scheduled sampling](https://arxiv.org/pdf/1506.03099.pdf)* to improve the parameter estimation process. 6 | It uses tab-delimited bilingual sentence pairs acquired from [here](http://www.manythings.org/anki/) to train predictive language models. 7 | 8 | #### Implementation Architecture 9 | 10 | The model is trained end-to-end using *[stacked RNNs](https://cs224d.stanford.edu/reports/Lambert.pdf)* for sequence encoding and decoding. 11 | The decoder is additionally conditioned on a context vector when predicting the next token in the sequence. This vector is computed with an *[attention mechanism](https://www.quora.com/What-is-Attention-Mechanism-in-Neural-Networks)* at each time step. Intuitively, the decoder leverages the information gathered by the encoder by deciding how relevant each encoding is at every step of the decoding process; a minimal sketch of this computation appears after the results table below. 12 | 13 | ## Results 14 |
| Input Sequence (English) | Output Sequence (Spanish) |
| --- | --- |
| how are you doing | estas haciendo |
| i am going to the store | voy a la tienda |
| she is a scientist | ella es cientifico |
| he is an engineer | el es un ingeniero |
| i am going out to the city | voy al la de la ciudad |
| i am running out of ideas | me estoy quedando sin ideas |
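#### How the Context Vector is Computed

The following is a minimal sketch of the attention step described above, mirroring the logic in *src/attention.py* and *src/attention_decoder.py*. The function name `attention_step` and the `score` callable are illustrative only (the repository wires this logic into the `Attention` and `AttentionDecoderRNN` modules), and the tensor shapes assume the repository's batch size of 1.

```
import torch
import torch.nn.functional as F


def attention_step(decoder_state, encoder_outputs, score):
    """Compute attention weights and a context vector for one decoding step.

    decoder_state:   (hidden_size,) current decoder RNN output
    encoder_outputs: (seq_len, 1, hidden_size) all encoder outputs
    score:           callable implementing the 'dot', 'general', or 'concat' scoring
    """
    # Alignment score between the decoder state and every encoder output (src/attention.py)
    energies = torch.stack([score(decoder_state, out.squeeze(0)) for out in encoder_outputs])
    weights = F.softmax(energies, dim=0)  # one normalized weight per source position

    # Context vector: attention-weighted sum of the encoder outputs (src/attention_decoder.py)
    context = (weights.unsqueeze(1) * encoder_outputs.squeeze(1)).sum(dim=0)
    return context, weights
```

With the 'dot' method, `score` is simply `lambda h, e: h.dot(e)`; the 'general' and 'concat' methods first pass the encoder output (or the concatenated pair) through a learned linear layer. The decoder concatenates the resulting context vector with the embedding of the current input word before feeding its GRU at the next time step.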
41 | 42 | ## Prerequisites 43 | 44 | * [Python 3.5](https://www.python.org/downloads/release/python-350/) 45 | * [PyTorch](http://pytorch.org/) 46 | * [NumPy](http://www.numpy.org/) 47 | 48 | ## Usage 49 | 50 | To train a new language model, invoke *train.py* with the abbreviation of the language you would like to translate English to. For instance, specify 'spa' to translate English to Spanish; the file 'spa.txt' in the data directory will be used. Data sets for other languages can be acquired from [here](http://www.manythings.org/anki/). 51 | 52 | ``` 53 | python train.py langname 54 | ``` 55 | 56 | To translate an input sequence in English into another language, invoke *eval.py* and specify the desired language and sentence. The program will exit if the language model parameters are not found in the data directory or if the language prefix is mistyped. 57 | 58 | ``` 59 | python eval.py langname 'some english words' 60 | ``` 61 | 62 | ## Files 63 | 64 | * [attention.py](src/attention.py) 65 | 66 | Attention nn module that is responsible for computing the alignment scores. 67 | 68 | * [attention_decoder.py](src/attention_decoder.py) 69 | 70 | Recurrent neural network that makes use of gated recurrent units to translate encoded inputs using attention. 71 | 72 | * [encoder.py](src/encoder.py) 73 | 74 | Recurrent neural network that encodes a given input sequence. 75 | 76 | * [etl.py](src/etl.py) 77 | 78 | Helper functions for data extraction, transformation, and loading. 79 | 80 | * [eval.py](src/eval.py) 81 | 82 | Script for evaluating the sequence-to-sequence model. 83 | 84 | * [helpers.py](src/helpers.py) 85 | 86 | General helper functions. 87 | 88 | * [language.py](src/language.py) 89 | 90 | Class that keeps a record of a corpus. Attributes such as vocabulary counts and tokens are stored within instances of this class. 91 | 92 | * [train.py](src/train.py) 93 | 94 | Script for training a new sequence-to-sequence model. 95 | -------------------------------------------------------------------------------- /data/cover.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lingyongyan/Neural-Machine-Translation/f10cf1edaf456c0868a1baf120ec6f9d44279e64/data/cover.jpg -------------------------------------------------------------------------------- /src/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | class Attention(nn.Module): 8 | """Attention nn module that is responsible for computing the alignment scores.""" 9 | 10 | def __init__(self, method, hidden_size): 11 | super(Attention, self).__init__() 12 | self.method = method 13 | self.hidden_size = hidden_size 14 | 15 | # Define layers 16 | if self.method == 'general': 17 | self.attention = nn.Linear(self.hidden_size, self.hidden_size) 18 | elif self.method == 'concat': 19 | self.attention = nn.Linear(self.hidden_size * 2, self.hidden_size) 20 | self.other = nn.Parameter(torch.FloatTensor(1, self.hidden_size)) 21 | 22 | def forward(self, hidden, encoder_outputs): 23 | """Attend to all encoder outputs conditioned on the previous hidden state of the decoder. 24 | 25 | After creating variables to store the attention energies, calculate their 26 | values for each encoder output and return the normalized values.
27 | 28 | Args: 29 | hidden: decoder hidden output used for condition 30 | encoder_outputs: list of encoder outputs 31 | 32 | Returns: 33 | Normalized (0..1) energy values, re-sized to 1 x 1 x seq_len 34 | """ 35 | 36 | seq_len = len(encoder_outputs) 37 | energies = Variable(torch.zeros(seq_len)).cuda() 38 | for i in range(seq_len): 39 | energies[i] = self._score(hidden, encoder_outputs[i]) 40 | return F.softmax(energies).unsqueeze(0).unsqueeze(0) 41 | 42 | def _score(self, hidden, encoder_output): 43 | """Calculate the relevance of a particular encoder output with respect to the decoder hidden state.""" 44 | 45 | if self.method == 'dot': 46 | energy = hidden.dot(encoder_output) 47 | elif self.method == 'general': 48 | energy = self.attention(encoder_output) 49 | energy = hidden.dot(energy) 50 | elif self.method == 'concat': 51 | energy = self.attention(torch.cat((hidden, encoder_output), 1)) 52 | energy = self.other.dot(energy) 53 | return energy 54 | -------------------------------------------------------------------------------- /src/attention_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from attention import Attention 5 | 6 | 7 | class AttentionDecoderRNN(nn.Module): 8 | """Recurrent neural network that makes use of gated recurrent units to translate encoded inputs using attention.""" 9 | 10 | def __init__(self, attention_model, hidden_size, output_size, n_layers=1, dropout_p=.1): 11 | super(AttentionDecoderRNN, self).__init__() 12 | self.attention_model = attention_model 13 | self.hidden_size = hidden_size 14 | self.output_size = output_size 15 | self.n_layers = n_layers 16 | self.dropout_p = dropout_p 17 | 18 | # Define layers 19 | self.embedding = nn.Embedding(output_size, hidden_size) 20 | self.gru = nn.GRU(hidden_size * 2, hidden_size, n_layers, dropout=dropout_p) 21 | self.out = nn.Linear(hidden_size * 2, output_size) 22 | 23 | # Choose attention model 24 | if attention_model is not None: 25 | self.attention = Attention(attention_model, hidden_size) 26 | 27 | def forward(self, word_input, last_context, last_hidden, encoder_outputs): 28 | """Run forward propagation one step at a time. 29 | 30 | Get the embedding of the current input word (the last output word), of shape 1 x 1 x hidden_size, 31 | then combine it with the previous context. Use this as input and run through the RNN. Next, 32 | calculate the attention from the current RNN state and all encoder outputs. The final output 33 | is the next word prediction using the RNN hidden state and context vector.
34 | 35 | Args: 36 | word_input: torch Variable representing the word input constituent 37 | last_context: torch Variable representing the previous context 38 | last_hidden: torch Variable representing the previous hidden state output 39 | encoder_outputs: torch Variable containing the encoder output values 40 | 41 | Return: 42 | output: torch Variable representing the predicted word constituent 43 | context: torch Variable representing the context value 44 | hidden: torch Variable representing the hidden state of the RNN 45 | attention_weights: torch Variable retrieved from the attention model 46 | """ 47 | 48 | # Run through RNN 49 | word_embedded = self.embedding(word_input).view(1, 1, -1) 50 | rnn_input = torch.cat((word_embedded, last_context.unsqueeze(0)), 2) 51 | rnn_output, hidden = self.gru(rnn_input, last_hidden) 52 | 53 | # Calculate attention 54 | attention_weights = self.attention(rnn_output.squeeze(0), encoder_outputs) 55 | context = attention_weights.bmm(encoder_outputs.transpose(0, 1)) 56 | 57 | # Predict output 58 | rnn_output = rnn_output.squeeze(0) 59 | context = context.squeeze(1) 60 | output = F.log_softmax(self.out(torch.cat((rnn_output, context), 1))) 61 | return output, context, hidden, attention_weights 62 | -------------------------------------------------------------------------------- /src/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | class EncoderRNN(nn.Module): 7 | """Recurrent neural network that encodes a given input sequence.""" 8 | 9 | def __init__(self, input_size, hidden_size, n_layers=1): 10 | super(EncoderRNN, self).__init__() 11 | self.input_size = input_size 12 | self.hidden_size = hidden_size 13 | self.n_layers = n_layers 14 | 15 | self.embedding = nn.Embedding(input_size, hidden_size) 16 | self.gru = nn.GRU(hidden_size, hidden_size, n_layers) 17 | 18 | def forward(self, word_inputs, hidden): 19 | seq_len = len(word_inputs) 20 | embedded = self.embedding(word_inputs).view(seq_len, 1, -1) 21 | output, hidden = self.gru(embedded, hidden) 22 | return output, hidden 23 | 24 | def init_hidden(self): 25 | hidden = Variable(torch.zeros(self.n_layers, 1, self.hidden_size)) 26 | hidden = hidden.cuda() 27 | return hidden 28 | -------------------------------------------------------------------------------- /src/etl.py: -------------------------------------------------------------------------------- 1 | import helpers 2 | import torch 3 | from language import Language 4 | from torch.autograd import Variable 5 | 6 | """ 7 | Data Extraction 8 | """ 9 | 10 | max_length = 20 11 | 12 | 13 | def filter_pair(p): 14 | is_good_length = len(p[0].split(' ')) < max_length and len(p[1].split(' ')) < max_length 15 | return is_good_length 16 | 17 | 18 | def filter_pairs(pairs): 19 | return [pair for pair in pairs if filter_pair(pair)] 20 | 21 | 22 | def prepare_data(lang_name): 23 | 24 | # Read and filter sentences 25 | input_lang, output_lang, pairs = read_languages(lang_name) 26 | pairs = filter_pairs(pairs) 27 | 28 | # Index words 29 | for pair in pairs: 30 | input_lang.index_words(pair[0]) 31 | output_lang.index_words(pair[1]) 32 | 33 | return input_lang, output_lang, pairs 34 | 35 | 36 | def read_languages(lang): 37 | 38 | # Read and parse the text file 39 | doc = open('../data/%s.txt' % lang).read() 40 | lines = doc.strip().split('\n') 41 | 42 | # Transform the data and initialize language instances 43 | pairs = 
[[helpers.normalize_string(s) for s in l.split('\t')] for l in lines] 44 | input_lang = Language('eng')  # the source side of the tab-delimited data sets is English 45 | output_lang = Language(lang) 46 | 47 | return input_lang, output_lang, pairs 48 | 49 | 50 | """ 51 | Data Transformation 52 | """ 53 | 54 | 55 | # Returns a list of indexes, one for each word in the sentence 56 | def indexes_from_sentence(lang, sentence): 57 | return [lang.word2index[word] for word in sentence.split(' ')] 58 | 59 | 60 | def variable_from_sentence(lang, sentence): 61 | indexes = indexes_from_sentence(lang, sentence) 62 | indexes.append(Language.eos_token) 63 | var = Variable(torch.LongTensor(indexes).view(-1, 1)) 64 | var = var.cuda() 65 | return var 66 | 67 | 68 | def variables_from_pair(pair, input_lang, output_lang): 69 | input_variable = variable_from_sentence(input_lang, pair[0]) 70 | target_variable = variable_from_sentence(output_lang, pair[1]) 71 | return input_variable, target_variable 72 | 73 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import etl 3 | import helpers 4 | import torch 5 | from attention_decoder import AttentionDecoderRNN 6 | from encoder import EncoderRNN 7 | from language import Language 8 | from torch.autograd import Variable 9 | 10 | 11 | # Parse argument for input sentence 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('language') 14 | parser.add_argument('input') 15 | args = parser.parse_args() 16 | helpers.validate_language_params(args.language) 17 | 18 | input_lang, output_lang, pairs = etl.prepare_data(args.language) 19 | attn_model = 'general' 20 | hidden_size = 500 21 | n_layers = 2 22 | dropout_p = 0.05 23 | 24 | # Initialize models 25 | encoder = EncoderRNN(input_lang.n_words, hidden_size, n_layers) 26 | decoder = AttentionDecoderRNN(attn_model, hidden_size, output_lang.n_words, n_layers, dropout_p=dropout_p) 27 | 28 | # Load model parameters 29 | encoder.load_state_dict(torch.load('../data/encoder_params_{}'.format(args.language))) 30 | decoder.load_state_dict(torch.load('../data/decoder_params_{}'.format(args.language))) 31 | decoder.attention.load_state_dict(torch.load('../data/attention_params_{}'.format(args.language))) 32 | 33 | # Move models to GPU 34 | encoder.cuda() 35 | decoder.cuda() 36 | 37 | 38 | def evaluate(sentence, max_length=10): 39 | input_variable = etl.variable_from_sentence(input_lang, sentence) 40 | input_length = input_variable.size()[0] 41 | 42 | # Run through encoder 43 | encoder_hidden = encoder.init_hidden() 44 | encoder_outputs, encoder_hidden = encoder(input_variable, encoder_hidden) 45 | 46 | # Create starting vectors for decoder 47 | decoder_input = Variable(torch.LongTensor([[Language.sos_token]])) # SOS 48 | decoder_context = Variable(torch.zeros(1, decoder.hidden_size)) 49 | decoder_input = decoder_input.cuda() 50 | decoder_context = decoder_context.cuda() 51 | 52 | decoder_hidden = encoder_hidden 53 | 54 | decoded_words = [] 55 | decoder_attentions = torch.zeros(max_length, max_length) 56 | 57 | # Run through decoder 58 | for di in range(max_length): 59 | decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_context, 60 | decoder_hidden, encoder_outputs) 61 | decoder_attentions[di, :decoder_attention.size(2)] += decoder_attention.squeeze(0).squeeze(0).cpu().data 62 | 63 | # Choose top word from output 64 | topv, topi = decoder_output.data.topk(1) 65 | ni = topi[0][0] 66 | if ni ==
Language.eos_token: 67 | decoded_words.append('<EOS>') 68 | break 69 | else: 70 | decoded_words.append(output_lang.index2word[ni]) 71 | 72 | # Next input is chosen word 73 | decoder_input = Variable(torch.LongTensor([[ni]])) 74 | decoder_input = decoder_input.cuda() 75 | 76 | return decoded_words, decoder_attentions[:di + 1, :len(encoder_outputs)] 77 | 78 | sentence = helpers.normalize_string(args.input) 79 | output_words, decoder_attn = evaluate(sentence) 80 | output_sentence = ' '.join(output_words) 81 | print(output_sentence) 82 | -------------------------------------------------------------------------------- /src/helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import matplotlib.pyplot as plt 3 | import matplotlib.ticker as ticker 4 | import os 5 | import re 6 | import time 7 | import unicodedata 8 | 9 | 10 | def show_plot(points): 11 | plt.figure() 12 | fig, ax = plt.subplots() 13 | loc = ticker.MultipleLocator(base=0.2) # put ticks at regular intervals 14 | ax.yaxis.set_major_locator(loc) 15 | plt.plot(points) 16 | 17 | 18 | def as_minutes(s): 19 | m = math.floor(s / 60) 20 | s -= m * 60 21 | return '%dm %ds' % (m, s) 22 | 23 | 24 | def time_since(since, percent): 25 | now = time.time() 26 | s = now - since 27 | es = s / (percent) 28 | rs = es - s 29 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) 30 | 31 | 32 | # Lowercase, trim, and remove non-letter characters 33 | def normalize_string(s): 34 | s = unicode_to_ascii(s.lower().strip()) 35 | s = re.sub(r"([.!?])", r" \1", s) 36 | s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) 37 | return s 38 | 39 | 40 | # Turns a unicode string to plain ASCII (http://stackoverflow.com/a/518232/2809427) 41 | def unicode_to_ascii(s): 42 | chars = [c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn'] 43 | char_list = ''.join(chars) 44 | return char_list 45 | 46 | 47 | def validate_language(l): 48 | p = '../data/{}.txt'.format(l) 49 | p = os.path.abspath(p) 50 | print(p) 51 | 52 | if not os.path.exists(p): 53 | url = 'http://www.manythings.org/anki/' 54 | print("{}.txt does not exist in the data directory. Please go to '{}' and download the data set.".format(l, url)) 55 | exit(1) 56 | 57 | 58 | def validate_language_params(l): 59 | is_missing = (not os.path.exists('../data/attention_params_{}'.format(l)) 60 | or not os.path.exists('../data/decoder_params_{}'.format(l)) 61 | or not os.path.exists('../data/encoder_params_{}'.format(l))) 62 | 63 | if is_missing: 64 | print("Model params for language '{}' do not exist in the data directory.
Please train a new model for this language.".format(l)) 65 | exit(1) 66 | -------------------------------------------------------------------------------- /src/language.py: -------------------------------------------------------------------------------- 1 | class Language: 2 | sos_token = 0 3 | eos_token = 1 4 | 5 | def __init__(self, name): 6 | self.name = name 7 | self.word2index = {} 8 | self.word2count = {} 9 | self.index2word = {0: '<SOS>', 1: '<EOS>', 2: '<UNK>'}  # special tokens 10 | self.n_words = len(self.index2word) 11 | 12 | def index_words(self, sentence): 13 | for word in sentence.split(' '): 14 | self.index_word(word) 15 | 16 | def index_word(self, word): 17 | if word not in self.word2index: 18 | self.word2index[word] = self.n_words 19 | self.word2count[word] = 1 20 | self.index2word[self.n_words] = word 21 | self.n_words += 1 22 | else: 23 | self.word2count[word] += 1 24 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import etl 3 | import helpers 4 | import random 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | from attention_decoder import AttentionDecoderRNN 11 | from encoder import EncoderRNN 12 | 13 | # Parse argument for language to train 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('language') 16 | args = parser.parse_args() 17 | helpers.validate_language(args.language) 18 | 19 | teacher_forcing_ratio = .5 20 | clip = 5. 21 | 22 | 23 | def train(input_var, target_var, encoder, decoder, encoder_opt, decoder_opt, criterion): 24 | # Initialize optimizers and loss 25 | encoder_opt.zero_grad() 26 | decoder_opt.zero_grad() 27 | loss = 0 28 | 29 | # Get input and target seq lengths 30 | target_length = target_var.size()[0] 31 | 32 | # Run through encoder 33 | encoder_hidden = encoder.init_hidden() 34 | encoder_outputs, encoder_hidden = encoder(input_var, encoder_hidden) 35 | 36 | # Prepare input and output variables 37 | decoder_input = Variable(torch.LongTensor([0])) 38 | decoder_input = decoder_input.cuda() 39 | decoder_context = Variable(torch.zeros(1, decoder.hidden_size)) 40 | decoder_context = decoder_context.cuda() 41 | decoder_hidden = encoder_hidden 42 | 43 | # Scheduled sampling 44 | use_teacher_forcing = random.random() < teacher_forcing_ratio 45 | if use_teacher_forcing: 46 | # Feed target as the next input 47 | for di in range(target_length): 48 | decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(decoder_input, 49 | decoder_context, 50 | decoder_hidden, 51 | encoder_outputs) 52 | loss += criterion(decoder_output[0], target_var[di]) 53 | decoder_input = target_var[di] 54 | else: 55 | # Use previous prediction as next input 56 | for di in range(target_length): 57 | decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(decoder_input, 58 | decoder_context, 59 | decoder_hidden, 60 | encoder_outputs) 61 | loss += criterion(decoder_output[0], target_var[di]) 62 | 63 | topv, topi = decoder_output.data.topk(1) 64 | ni = topi[0][0] 65 | 66 | decoder_input = Variable(torch.LongTensor([[ni]])) 67 | decoder_input = decoder_input.cuda() 68 | 69 | if ni == 1: 70 | break 71 | 72 | # Backpropagation 73 | loss.backward() 74 | torch.nn.utils.clip_grad_norm(encoder.parameters(), clip) 75 | torch.nn.utils.clip_grad_norm(decoder.parameters(), clip) 76 | encoder_opt.step() 77 | decoder_opt.step() 78 | 79 |
return loss.data[0] / target_length 80 | 81 | input_lang, output_lang, pairs = etl.prepare_data(args.language) 82 | 83 | attn_model = 'general' 84 | hidden_size = 500 85 | n_layers = 2 86 | dropout_p = 0.05 87 | 88 | # Initialize models 89 | encoder = EncoderRNN(input_lang.n_words, hidden_size, n_layers) 90 | decoder = AttentionDecoderRNN(attn_model, hidden_size, output_lang.n_words, n_layers, dropout_p=dropout_p) 91 | 92 | # Move models to GPU 93 | encoder.cuda() 94 | decoder.cuda() 95 | 96 | # Initialize optimizers and criterion 97 | learning_rate = 0.0001 98 | encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate) 99 | decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate) 100 | criterion = nn.NLLLoss() 101 | 102 | # Configuring training 103 | n_epochs = 100000 104 | plot_every = 200 105 | print_every = 1000 106 | 107 | # Keep track of time elapsed and running averages 108 | start = time.time() 109 | plot_losses = [] 110 | print_loss_total = 0 # Reset every print_every 111 | plot_loss_total = 0 # Reset every plot_every 112 | 113 | # Begin training 114 | for epoch in range(1, n_epochs + 1): 115 | 116 | # Get training data for this cycle 117 | training_pair = etl.variables_from_pair(random.choice(pairs), input_lang, output_lang) 118 | input_variable = training_pair[0] 119 | target_variable = training_pair[1] 120 | 121 | # Run the train step 122 | loss = train(input_variable, target_variable, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion) 123 | 124 | # Keep track of loss 125 | print_loss_total += loss 126 | plot_loss_total += loss 127 | 128 | if epoch == 0: 129 | continue 130 | 131 | if epoch % print_every == 0: 132 | print_loss_avg = print_loss_total / print_every 133 | print_loss_total = 0 134 | time_since = helpers.time_since(start, epoch / n_epochs) 135 | print('%s (%d %d%%) %.4f' % (time_since, epoch, epoch / n_epochs * 100, print_loss_avg)) 136 | 137 | if epoch % plot_every == 0: 138 | plot_loss_avg = plot_loss_total / plot_every 139 | plot_losses.append(plot_loss_avg) 140 | plot_loss_total = 0 141 | 142 | 143 | # Save our models 144 | torch.save(encoder.state_dict(), '../data/encoder_params_{}'.format(args.language)) 145 | torch.save(decoder.state_dict(), '../data/decoder_params_{}'.format(args.language)) 146 | torch.save(decoder.attention.state_dict(), '../data/attention_params_{}'.format(args.language)) 147 | 148 | # Plot loss 149 | helpers.show_plot(plot_losses) 150 | --------------------------------------------------------------------------------