├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── cover.jpg
└── src
    ├── attention.py
    ├── attention_decoder.py
    ├── encoder.py
    ├── etl.py
    ├── eval.py
    ├── helpers.py
    ├── language.py
    └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/*
2 | __pycache__/*
3 | data/*params*
4 | data/*.txt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Mohamed Eid
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # *Neural Machine Translation* implemented in PyTorch
2 |
3 |
4 |
5 | This is a PyTorch implementation of *[Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/pdf/1508.04025.pdf)* using *[scheduled sampling](https://arxiv.org/pdf/1506.03099.pdf)* to improve the parameter estimation process.
6 | It uses tab-delimited bilingual sentence pairs acquired from [here](http://www.manythings.org/anki/) to train predictive language models.
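Scheduled sampling here amounts to a per-sequence coin flip between teacher forcing and feeding the decoder its own predictions (see *src/train.py*). Below is a minimal, self-contained sketch of that idea, with a made-up `predict_next` callback standing in for the real decoder:

```
import random

teacher_forcing_ratio = 0.5  # same value used in src/train.py


def decode_with_scheduled_sampling(gold_tokens, predict_next):
    """Decide once per target sequence whether to feed the gold token or the
    model's own previous prediction back into the decoder (as src/train.py does)."""
    use_teacher_forcing = random.random() < teacher_forcing_ratio
    previous, outputs = '<SOS>', []
    for gold in gold_tokens:
        predicted = predict_next(previous)
        outputs.append(predicted)
        previous = gold if use_teacher_forcing else predicted
    return outputs


# Toy usage with a dummy "decoder" that always predicts 'la'
print(decode_with_scheduled_sampling(['voy', 'a', 'la', 'tienda'], lambda prev: 'la'))
```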
7 |
8 | #### Implementation Architecture
9 |
10 | The model is trained end-to-end using *[stacked RNNs](https://cs224d.stanford.edu/reports/Lambert.pdf)* for sequence encoding and decoding.
11 | The decoder is additionally conditioned on a context vector when predicting the next token in the sequence. This vector is computed using an *[attention mechanism](https://www.quora.com/What-is-Attention-Mechanism-in-Neural-Networks)* at each time step. Intuitively, the decoder leverages the information aggregated by the encoder by deciding how relevant each encoding is at each step of the decoding process.
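As a rough, self-contained illustration of the attention computation described above (toy tensors and shapes only, not the repository's modules; it mirrors the 'dot' scoring method in *src/attention.py*):

```
import torch
import torch.nn.functional as F

seq_len, hidden_size = 5, 8                          # toy dimensions
encoder_outputs = torch.randn(seq_len, hidden_size)  # one encoding per source token
decoder_hidden = torch.randn(hidden_size)            # current decoder state

# 'dot' scoring: how relevant is each encoding to the current decoder state?
energies = encoder_outputs @ decoder_hidden          # shape: (seq_len,)
attention_weights = F.softmax(energies, dim=0)       # normalized to sum to 1

# Context vector: attention-weighted sum of the encoder outputs,
# used by the decoder when predicting the next token.
context = attention_weights @ encoder_outputs        # shape: (hidden_size,)
```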
12 |
13 | ## Results
14 |
15 | | Input Sequence (English)    | Output Sequence (Spanish)   |
16 | | --------------------------- | --------------------------- |
17 | | how are you doing           | estas haciendo              |
18 | | i am going to the store     | voy a la tienda             |
19 | | she is a scientist          | ella es cientifico          |
20 | | he is an engineer           | el es un ingeniero          |
21 | | i am going out to the city  | voy al la de la ciudad      |
22 | | i am running out of ideas   | me estoy quedando sin ideas |
23 |
42 | ## Prerequisites
43 |
44 | * [Python 3.5](https://www.python.org/downloads/release/python-350/)
45 | * [PyTorch](http://pytorch.org/)
46 | * [NumPy](http://www.numpy.org/)
47 |
48 | ## Usage
49 |
50 | To train a new language model, invoke *train.py* with the abbreviation of the language you would like to translate English into. For instance, Spanish is selected by specifying 'spa', in which case 'spa.txt' in the data directory will be used. Data sets for other languages can be acquired from [here](http://www.manythings.org/anki/).
51 |
52 | ```
53 | python train.py langname
54 | ```
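For example, assuming 'spa.txt' has been placed in the data directory, `python train.py spa` trains an English-to-Spanish model and saves the encoder, decoder, and attention parameters under *data/*.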
55 |
56 | To translate an English input sequence into another language, invoke *eval.py* and specify the desired language and sentence. The program will exit if the language model parameters are not found in the data directory or if the language prefix is mistyped.
57 |
58 | ```
59 | python eval.py langname 'some english words'
60 | ```
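For example, `python eval.py spa 'how are you doing'` prints the model's Spanish translation of the given sentence.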
61 |
62 | ## Files
63 |
64 | * [attention.py](src/attention.py)
65 |
66 | Attention nn module that is responsible for computing the alignment scores.
67 |
68 | * [attention_decoder.py](src/attention_decoder.py)
69 |
70 | Recurrent neural network that makes use of gated recurrent units to translate encoded inputs using attention.
71 |
72 | * [encoder.py](src/encoder.py)
73 |
74 | Recurrent neural network that encodes a given input sequence.
75 |
76 | * [etl.py](src/etl.py)
77 |
78 | Helper functions for data extraction, transformation, and loading.
79 |
80 | * [eval.py](src/eval.py)
81 |
82 | Script for evaluating the sequence-to-sequence model.
83 |
84 | * [helpers.py](src/helpers.py)
85 |
86 | General helper functions.
87 |
88 | * [language.py](src/language.py)
89 |
90 | Class that keeps a record of a corpus. Attributes such as vocabulary counts and word-index mappings are stored within instances of this class.
91 |
92 | * [train.py](src/train.py)
93 |
94 | Script for training a new sequence-to-sequence model.
95 |
--------------------------------------------------------------------------------
/data/cover.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lingyongyan/Neural-Machine-Translation/f10cf1edaf456c0868a1baf120ec6f9d44279e64/data/cover.jpg
--------------------------------------------------------------------------------
/src/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 |
7 | class Attention(nn.Module):
8 | """Attention nn module that is responsible for computing the alignment scores."""
9 |
10 | def __init__(self, method, hidden_size):
11 | super(Attention, self).__init__()
12 | self.method = method
13 | self.hidden_size = hidden_size
14 |
15 | # Define layers
16 | if self.method == 'general':
17 | self.attention = nn.Linear(self.hidden_size, self.hidden_size)
18 | elif self.method == 'concat':
19 | self.attention = nn.Linear(self.hidden_size * 2, self.hidden_size)
20 | self.other = nn.Parameter(torch.FloatTensor(1, self.hidden_size))
21 |
22 | def forward(self, hidden, encoder_outputs):
23 | """Attend all encoder inputs conditioned on the previous hidden state of the decoder.
24 |
25 | After creating variables to store the attention energies, calculate their
26 | values for each encoder output and return the normalized values.
27 |
28 | Args:
29 | hidden: decoder hidden output used for condition
30 | encoder_outputs: list of encoder outputs
31 |
32 | Returns:
33 | Normalized (0..1) energy values, re-sized to 1 x 1 x seq_len
34 | """
35 |
36 | seq_len = len(encoder_outputs)
37 | energies = Variable(torch.zeros(seq_len)).cuda()
38 | for i in range(seq_len):
39 | energies[i] = self._score(hidden, encoder_outputs[i])
40 | return F.softmax(energies).unsqueeze(0).unsqueeze(0)
41 |
42 | def _score(self, hidden, encoder_output):
43 | """Calculate the relevance of a particular encoder output in respect to the decoder hidden."""
44 |
45 | if self.method == 'dot':
46 | energy = hidden.dot(encoder_output)
47 | elif self.method == 'general':
48 | energy = self.attention(encoder_output)
49 | energy = hidden.dot(energy)
50 | elif self.method == 'concat':
51 | energy = self.attention(torch.cat((hidden, encoder_output), 1))
52 | energy = self.other.dor(energy)
53 | return energy
54 |
--------------------------------------------------------------------------------
/src/attention_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from attention import Attention
5 |
6 |
7 | class AttentionDecoderRNN(nn.Module):
8 | """Recurrent neural network that makes use of gated recurrent units to translate encoded inputs using attention."""
9 |
10 | def __init__(self, attention_model, hidden_size, output_size, n_layers=1, dropout_p=.1):
11 | super(AttentionDecoderRNN, self).__init__()
12 | self.attention_model = attention_model
13 | self.hidden_size = hidden_size
14 | self.output_size = output_size
15 | self.n_layers = n_layers
16 | self.dropout_p = dropout_p
17 |
18 | # Define layers
19 | self.embedding = nn.Embedding(output_size, hidden_size)
20 | self.gru = nn.GRU(hidden_size * 2, hidden_size, n_layers, dropout=dropout_p)
21 | self.out = nn.Linear(hidden_size * 2, output_size)
22 |
23 | # Choose attention model
24 | if attention_model is not None:
25 | self.attention = Attention(attention_model, hidden_size)
26 |
27 | def forward(self, word_input, last_context, last_hidden, encoder_outputs):
28 | """Run forward propagation one step at a time.
29 |
30 | Get the embedding of the current input word (last output word) [s = 1 x batch_size x seq_len]
31 | then combine them with the previous context. Use this as input and run through the RNN. Next,
32 | calculate the attention from the current RNN state and all encoder outputs. The final output
33 | is the next word prediction using the RNN hidden state and context vector.
34 |
35 | Args:
36 | word_input: torch Variable representing the word input constituent
37 | last_context: torch Variable representing the previous context
38 | last_hidden: torch Variable representing the previous hidden state output
39 | encoder_outputs: torch Variable containing the encoder output values
40 |
41 | Return:
42 | output: torch Variable representing the predicted word constituent
43 | context: torch Variable representing the context value
44 | hidden: torch Variable representing the hidden state of the RNN
45 | attention_weights: torch Variable retrieved from the attention model
46 | """
47 |
48 | # Run through RNN
49 | word_embedded = self.embedding(word_input).view(1, 1, -1)
50 | rnn_input = torch.cat((word_embedded, last_context.unsqueeze(0)), 2)
51 | rnn_output, hidden = self.gru(rnn_input, last_hidden)
52 |
53 | # Calculate attention
54 | attention_weights = self.attention(rnn_output.squeeze(0), encoder_outputs)
55 | context = attention_weights.bmm(encoder_outputs.transpose(0, 1))
56 |
57 | # Predict output
58 | rnn_output = rnn_output.squeeze(0)
59 | context = context.squeeze(1)
60 | output = F.log_softmax(self.out(torch.cat((rnn_output, context), 1)))
61 | return output, context, hidden, attention_weights
62 |
--------------------------------------------------------------------------------
/src/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
5 |
6 | class EncoderRNN(nn.Module):
7 | """Recurrent neural network that encodes a given input sequence."""
8 |
9 | def __init__(self, input_size, hidden_size, n_layers=1):
10 | super(EncoderRNN, self).__init__()
11 | self.input_size = input_size
12 | self.hidden_size = hidden_size
13 | self.n_layers = n_layers
14 |
15 | self.embedding = nn.Embedding(input_size, hidden_size)
16 | self.gru = nn.GRU(hidden_size, hidden_size, n_layers)
17 |
18 | def forward(self, word_inputs, hidden):
19 | seq_len = len(word_inputs)
20 | embedded = self.embedding(word_inputs).view(seq_len, 1, -1)
21 | output, hidden = self.gru(embedded, hidden)
22 | return output, hidden
23 |
24 | def init_hidden(self):
25 | hidden = Variable(torch.zeros(self.n_layers, 1, self.hidden_size))
26 | hidden = hidden.cuda()
27 | return hidden
28 |
--------------------------------------------------------------------------------
/src/etl.py:
--------------------------------------------------------------------------------
1 | import helpers
2 | import torch
3 | from language import Language
4 | from torch.autograd import Variable
5 |
6 | """
7 | Data Extraction
8 | """
9 |
10 | max_length = 20
11 |
12 |
13 | def filter_pair(p):
14 |     is_good_length = len(p[0].split(' ')) < max_length and len(p[1].split(' ')) < max_length
15 |     return is_good_length
16 |
17 |
18 | def filter_pairs(pairs):
19 |     return [pair for pair in pairs if filter_pair(pair)]
20 |
21 |
22 | def prepare_data(lang_name):
23 |
24 |     # Read and filter sentences
25 |     input_lang, output_lang, pairs = read_languages(lang_name)
26 |     pairs = filter_pairs(pairs)
27 |
28 |     # Index words
29 |     for pair in pairs:
30 |         input_lang.index_words(pair[0])
31 |         output_lang.index_words(pair[1])
32 |
33 |     return input_lang, output_lang, pairs
34 |
35 |
36 | def read_languages(lang):
37 |
38 |     # Read and parse the text file
39 |     doc = open('../data/%s.txt' % lang).read()
40 |     lines = doc.strip().split('\n')
41 |
42 |     # Transform the data and initialize language instances
43 |     pairs = [[helpers.normalize_string(s) for s in l.split('\t')] for l in lines]
44 |     input_lang = Language('eng')
45 |     output_lang = Language(lang)
46 |
47 |     return input_lang, output_lang, pairs
48 |
49 |
50 | """
51 | Data Transformation
52 | """
53 |
54 |
55 | # Returns a list of indexes, one for each word in the sentence
56 | def indexes_from_sentence(lang, sentence):
57 |     return [lang.word2index[word] for word in sentence.split(' ')]
58 |
59 |
60 | def variable_from_sentence(lang, sentence):
61 |     indexes = indexes_from_sentence(lang, sentence)
62 |     indexes.append(Language.eos_token)
63 |     var = Variable(torch.LongTensor(indexes).view(-1, 1))
64 |     var = var.cuda()
65 |     return var
66 |
67 |
68 | def variables_from_pair(pair, input_lang, output_lang):
69 |     input_variable = variable_from_sentence(input_lang, pair[0])
70 |     target_variable = variable_from_sentence(output_lang, pair[1])
71 |     return input_variable, target_variable
72 |
73 |
--------------------------------------------------------------------------------
/src/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import etl
3 | import helpers
4 | import torch
5 | from attention_decoder import AttentionDecoderRNN
6 | from encoder import EncoderRNN
7 | from language import Language
8 | from torch.autograd import Variable
9 |
10 |
11 | # Parse argument for input sentence
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('language')
14 | parser.add_argument('input')
15 | args = parser.parse_args()
16 | helpers.validate_language_params(args.language)
17 |
18 | input_lang, output_lang, pairs = etl.prepare_data(args.language)
19 | attn_model = 'general'
20 | hidden_size = 500
21 | n_layers = 2
22 | dropout_p = 0.05
23 |
24 | # Initialize models
25 | encoder = EncoderRNN(input_lang.n_words, hidden_size, n_layers)
26 | decoder = AttentionDecoderRNN(attn_model, hidden_size, output_lang.n_words, n_layers, dropout_p=dropout_p)
27 |
28 | # Load model parameters
29 | encoder.load_state_dict(torch.load('../data/encoder_params_{}'.format(args.language)))
30 | decoder.load_state_dict(torch.load('../data/decoder_params_{}'.format(args.language)))
31 | decoder.attention.load_state_dict(torch.load('../data/attention_params_{}'.format(args.language)))
32 |
33 | # Move models to GPU
34 | encoder.cuda()
35 | decoder.cuda()
36 |
37 |
38 | def evaluate(sentence, max_length=10):
39 |     input_variable = etl.variable_from_sentence(input_lang, sentence)
40 |     input_length = input_variable.size()[0]
41 |
42 |     # Run through encoder
43 |     encoder_hidden = encoder.init_hidden()
44 |     encoder_outputs, encoder_hidden = encoder(input_variable, encoder_hidden)
45 |
46 |     # Create starting vectors for decoder
47 |     decoder_input = Variable(torch.LongTensor([[Language.sos_token]]))  # SOS
48 |     decoder_context = Variable(torch.zeros(1, decoder.hidden_size))
49 |     decoder_input = decoder_input.cuda()
50 |     decoder_context = decoder_context.cuda()
51 |
52 |     decoder_hidden = encoder_hidden
53 |
54 |     decoded_words = []
55 |     decoder_attentions = torch.zeros(max_length, max_length)
56 |
57 |     # Run through decoder
58 |     for di in range(max_length):
59 |         decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_context,
60 |                                                                                      decoder_hidden, encoder_outputs)
61 |         decoder_attentions[di, :decoder_attention.size(2)] += decoder_attention.squeeze(0).squeeze(0).cpu().data
62 |
63 |         # Choose top word from output
64 |         topv, topi = decoder_output.data.topk(1)
65 |         ni = topi[0][0]
66 |         if ni == Language.eos_token:
67 |             decoded_words.append('<EOS>')
68 |             break
69 |         else:
70 |             decoded_words.append(output_lang.index2word[ni])
71 |
72 |         # Next input is chosen word
73 |         decoder_input = Variable(torch.LongTensor([[ni]]))
74 |         decoder_input = decoder_input.cuda()
75 |
76 |     return decoded_words, decoder_attentions[:di + 1, :len(encoder_outputs)]
77 |
78 | sentence = helpers.normalize_string(args.input)
79 | output_words, decoder_attn = evaluate(sentence)
80 | output_sentence = ' '.join(output_words)
81 | print(output_sentence)
82 |
--------------------------------------------------------------------------------
/src/helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import matplotlib.pyplot as plt
3 | import matplotlib.ticker as ticker
4 | import os
5 | import re
6 | import time
7 | import unicodedata
8 |
9 |
10 | def show_plot(points):
11 |     plt.figure()
12 |     fig, ax = plt.subplots()
13 |     loc = ticker.MultipleLocator(base=0.2)  # put ticks at regular intervals
14 |     ax.yaxis.set_major_locator(loc)
15 |     plt.plot(points)
16 |
17 |
18 | def as_minutes(s):
19 |     m = math.floor(s / 60)
20 |     s -= m * 60
21 |     return '%dm %ds' % (m, s)
22 |
23 |
24 | def time_since(since, percent):
25 |     now = time.time()
26 |     s = now - since
27 |     es = s / (percent)
28 |     rs = es - s
29 |     return '%s (- %s)' % (as_minutes(s), as_minutes(rs))
30 |
31 |
32 | # Lowercase, trim, and remove non-letter characters
33 | def normalize_string(s):
34 |     s = unicode_to_ascii(s.lower().strip())
35 |     s = re.sub(r"([.!?])", r" \1", s)
36 |     s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
37 |     return s
38 |
39 |
40 | # Turns a unicode string to plain ASCII (http://stackoverflow.com/a/518232/2809427)
41 | def unicode_to_ascii(s):
42 |     chars = [c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn']
43 |     char_list = ''.join(chars)
44 |     return char_list
45 |
46 |
47 | def validate_language(l):
48 |     p = '../data/{}.txt'.format(l)
49 |     p = os.path.abspath(p)
50 |     print(p)
51 |
52 |     if not os.path.exists(p):
53 |         url = 'http://www.manythings.org/anki/'
54 |         print("{}.txt does not exist in the data directory. Please go to '{}' and download the data set.".format(l, url))
55 |         exit(1)
56 |
57 |
58 | def validate_language_params(l):
59 |     is_missing = (not os.path.exists('../data/attention_params_{}'.format(l))
60 |                   or not os.path.exists('../data/decoder_params_{}'.format(l))
61 |                   or not os.path.exists('../data/encoder_params_{}'.format(l)))
62 |
63 |     if is_missing:
64 |         print("Model params for language '{}' do not exist in the data directory. Please train a new model for this language.".format(l))
65 |         exit(1)
66 |
--------------------------------------------------------------------------------
/src/language.py:
--------------------------------------------------------------------------------
1 | class Language:
2 |     sos_token = 0
3 |     eos_token = 1
4 |
5 |     def __init__(self, name):
6 |         self.name = name
7 |         self.word2index = {}
8 |         self.word2count = {}
9 |         self.index2word = {0: '<SOS>', 1: '<EOS>', 2: '<UNK>'}
10 |         self.n_words = len(self.index2word)
11 |
12 |     def index_words(self, sentence):
13 |         for word in sentence.split(' '):
14 |             self.index_word(word)
15 |
16 |     def index_word(self, word):
17 |         if word not in self.word2index:
18 |             self.word2index[word] = self.n_words
19 |             self.word2count[word] = 1
20 |             self.index2word[self.n_words] = word
21 |             self.n_words += 1
22 |         else:
23 |             self.word2count[word] += 1
24 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import etl
3 | import helpers
4 | import random
5 | import time
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch.autograd import Variable
10 | from attention_decoder import AttentionDecoderRNN
11 | from encoder import EncoderRNN
12 |
13 | # Parse argument for language to train
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('language')
16 | args = parser.parse_args()
17 | helpers.validate_language(args.language)
18 |
19 | teacher_forcing_ratio = .5
20 | clip = 5.
21 |
22 |
23 | def train(input_var, target_var, encoder, decoder, encoder_opt, decoder_opt, criterion):
24 |     # Initialize optimizers and loss
25 |     encoder_opt.zero_grad()
26 |     decoder_opt.zero_grad()
27 |     loss = 0
28 |
29 |     # Get the target sequence length
30 |     target_length = target_var.size()[0]
31 |
32 |     # Run through encoder
33 |     encoder_hidden = encoder.init_hidden()
34 |     encoder_outputs, encoder_hidden = encoder(input_var, encoder_hidden)
35 |
36 |     # Prepare input and output variables
37 |     decoder_input = Variable(torch.LongTensor([0]))  # start with the SOS token
38 |     decoder_input = decoder_input.cuda()
39 |     decoder_context = Variable(torch.zeros(1, decoder.hidden_size))
40 |     decoder_context = decoder_context.cuda()
41 |     decoder_hidden = encoder_hidden
42 |
43 |     # Scheduled sampling
44 |     use_teacher_forcing = random.random() < teacher_forcing_ratio
45 |     if use_teacher_forcing:
46 |         # Feed target as the next input
47 |         for di in range(target_length):
48 |             decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(decoder_input,
49 |                                                                                          decoder_context,
50 |                                                                                          decoder_hidden,
51 |                                                                                          encoder_outputs)
52 |             loss += criterion(decoder_output[0], target_var[di])
53 |             decoder_input = target_var[di]
54 |     else:
55 |         # Use previous prediction as next input
56 |         for di in range(target_length):
57 |             decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(decoder_input,
58 |                                                                                          decoder_context,
59 |                                                                                          decoder_hidden,
60 |                                                                                          encoder_outputs)
61 |             loss += criterion(decoder_output[0], target_var[di])
62 |
63 |             topv, topi = decoder_output.data.topk(1)
64 |             ni = topi[0][0]
65 |
66 |             decoder_input = Variable(torch.LongTensor([[ni]]))
67 |             decoder_input = decoder_input.cuda()
68 |
69 |             if ni == 1:  # stop at the EOS token
70 |                 break
71 |
72 |     # Backpropagation
73 |     loss.backward()
74 |     torch.nn.utils.clip_grad_norm(encoder.parameters(), clip)
75 |     torch.nn.utils.clip_grad_norm(decoder.parameters(), clip)
76 |     encoder_opt.step()
77 |     decoder_opt.step()
78 |
79 |     return loss.data[0] / target_length
80 |
81 | input_lang, output_lang, pairs = etl.prepare_data(args.language)
82 |
83 | attn_model = 'general'
84 | hidden_size = 500
85 | n_layers = 2
86 | dropout_p = 0.05
87 |
88 | # Initialize models
89 | encoder = EncoderRNN(input_lang.n_words, hidden_size, n_layers)
90 | decoder = AttentionDecoderRNN(attn_model, hidden_size, output_lang.n_words, n_layers, dropout_p=dropout_p)
91 |
92 | # Move models to GPU
93 | encoder.cuda()
94 | decoder.cuda()
95 |
96 | # Initialize optimizers and criterion
97 | learning_rate = 0.0001
98 | encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
99 | decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
100 | criterion = nn.NLLLoss()
101 |
102 | # Configuring training
103 | n_epochs = 100000
104 | plot_every = 200
105 | print_every = 1000
106 |
107 | # Keep track of time elapsed and running averages
108 | start = time.time()
109 | plot_losses = []
110 | print_loss_total = 0 # Reset every print_every
111 | plot_loss_total = 0 # Reset every plot_every
112 |
113 | # Begin training
114 | for epoch in range(1, n_epochs + 1):
115 |
116 |     # Get training data for this cycle
117 |     training_pair = etl.variables_from_pair(random.choice(pairs), input_lang, output_lang)
118 |     input_variable = training_pair[0]
119 |     target_variable = training_pair[1]
120 |
121 |     # Run the train step
122 |     loss = train(input_variable, target_variable, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
123 |
124 |     # Keep track of loss
125 |     print_loss_total += loss
126 |     plot_loss_total += loss
127 |
128 |     if epoch == 0:
129 |         continue
130 |
131 |     if epoch % print_every == 0:
132 |         print_loss_avg = print_loss_total / print_every
133 |         print_loss_total = 0
134 |         time_since = helpers.time_since(start, epoch / n_epochs)
135 |         print('%s (%d %d%%) %.4f' % (time_since, epoch, epoch / n_epochs * 100, print_loss_avg))
136 |
137 |     if epoch % plot_every == 0:
138 |         plot_loss_avg = plot_loss_total / plot_every
139 |         plot_losses.append(plot_loss_avg)
140 |         plot_loss_total = 0
141 |
142 |
143 | # Save our models
144 | torch.save(encoder.state_dict(), '../data/encoder_params_{}'.format(args.language))
145 | torch.save(decoder.state_dict(), '../data/decoder_params_{}'.format(args.language))
146 | torch.save(decoder.attention.state_dict(), '../data/attention_params_{}'.format(args.language))
147 |
148 | # Plot loss
149 | helpers.show_plot(plot_losses)
150 |
--------------------------------------------------------------------------------