├── .gitignore
├── README.md
├── attn.py
├── data.py
├── evaluate.py
├── model.py
├── service.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
*.swp
*.pyc
*.pt
data/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch Seq2Seq Intent Parsing

Reframing intent parsing as a human-to-machine translation task. A work-in-progress successor to [torch-seq2seq-intent-parsing](https://github.com/spro/torch-seq2seq-intent-parsing).

## The command language

This is a simple command language developed for the "home assistant" [Maia](https://github.com/withmaia) living in my apartment. She's designed as a collection of microservices, with services for lights (Hue), switches (WeMo), and info such as weather and market prices.

A command consists of a "service", a "method", and some number of arguments.

```
lights setState office_light on
switches getState teapot
weather getWeather "San Francisco"
price getPrice TSLA
```

These can be represented with variable placeholders:

```
lights setState $device $state
switches getState $device
weather getWeather $location
price getPrice $symbol
```

We can imagine a bunch of human sentences that would map to a single command:

```
"Turn the office light on."
"Please turn on the light in the office."
"Maia could you set the office light on, thank you."
```

These sentences could similarly be represented with placeholders.

## TODO: Specific vs. freeform variables

A shortcoming of the approach so far is that the model has to learn translations of specific values, for example mapping every spoken device name to its canonical `device_name` token. If we added a "basement light", the model would have no `basement_light` in the output vocabulary unless it was re-trained.

The bigger the potential input space, the more obvious the problem: consider the `getWeather` command, where the model would need to be trained with every possible location we might ask about. Worse yet, consider a `playMusic` command that could take any song or artist name...

This can be solved with a technique I have [implemented in Torch here](https://github.com/spro/torch-seq2seq-intent-parsing): the training pairs have "variable placeholders" in the output translation, which the model generates during an initial pass. The network then fills in the values of these placeholders with an additional pass over the input.
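To make the idea concrete, here is a minimal sketch (not the code in this repo) of how slot values recovered by the second pass would be substituted into the placeholder command produced by the first pass; `fill_command`, `command`, and `slots` are hypothetical stand-ins:

```
# Illustrative only: `command` stands in for the first pass's decoded output,
# `slots` for the placeholder values recovered by the second pass.
def fill_command(command_tokens, slot_values):
    """Replace $placeholders with their recovered values."""
    return [slot_values.get(tok, tok) for tok in command_tokens]

command = "lights setState $device $state".split()
slots = {"$device": "office_light", "$state": "on"}
print(" ".join(fill_command(command, slots)))
# lights setState office_light on
```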
![](https://camo.githubusercontent.com/4125995f183d3158103b46eeb5ffdea4eef0ef52/68747470733a2f2f692e696d6775722e636f6d2f56316c747668492e706e67)
--------------------------------------------------------------------------------
/attn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

USE_CUDA = False

class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()

        self.method = method
        self.hidden_size = hidden_size

        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)

        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.other = nn.Parameter(torch.FloatTensor(1, hidden_size))

    def forward(self, hidden, encoder_outputs):
        seq_len = len(encoder_outputs)

        # Create variable to store attention energies, one per encoder output
        attn_energies = Variable(torch.zeros(seq_len)) # S
        if USE_CUDA: attn_energies = attn_energies.cuda()

        # Calculate energies for each encoder output
        for i in range(seq_len):
            attn_energies[i] = self.score(hidden, encoder_outputs[i])

        # Normalize energies to weights in range 0 to 1, resize to 1 x 1 x seq_len
        return F.softmax(attn_energies).unsqueeze(0).unsqueeze(0)

    def score(self, hidden, encoder_output):
        if self.method == 'dot':
            energy = hidden.dot(encoder_output)
            return energy

        elif self.method == 'general':
            energy = self.attn(encoder_output)
            energy = hidden.dot(energy)
            return energy

        elif self.method == 'concat':
            energy = self.attn(torch.cat((hidden, encoder_output), 1))
            energy = self.other.dot(energy)
            return energy
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
import random
import re

# ## Generating training data

templates = [
    ("lights setState $light_name $light_state", [
        "~turn $light_state $light_name",
        "~turn $light_name $light_state",
    ]),
    ("lights setState $light_name $light_amount", [
        "~turn $light_name $light_amount",
    ]),
    ("lights setStates $group_name $light_state", [
        "~turn $light_state $group_name",
        "~turn $group_name $light_state",
    ]),
    ("lights setStates $group_name $light_amount", [
        "~turn $group_name $light_amount",
    ]),
    ("lights getState $light_name", [
        "is $light_name on",
    ]),
    ("lights getStates $group_name", [
        "are $group_name on",
    ]),
    ("music setVolume $volume", [
        "~turn the music $volume",
        "~turn it $volume",
    ]),
    ("time getTime", [
        "time",
        "what time is it",
        "~whatis the time",
    ]),
    ("price getPrice $asset", [
        "$asset",
        "how much is $asset",
        "~price of $asset",
        "~whatis the ~price of $asset",
    ]),
    ("weather getWeather $location", [
        "tell me the weather in $location",
        "~whatis it like in $location",
    ]),
    ("greeting", [
        "hi", "hello", "how are you", "what's up", "hey maia",
    ]),
    ("thanks", [
        "thanks", "thank you", "thank you so much", "thx", "you're great",
    ]),
]
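# A note on the markup above (illustrative expansion, using the tables defined
# below): words beginning with "$" are variable slots and words beginning with
# "~" are synonym groups, both expanded by fill_template further down. For
# example, the pair
#   ("lights setState $light_name $light_state", "~turn $light_state $light_name")
# can yield the training example
#   input:  "turn on the office light"
#   output: "lights setState office_light on"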
variables = {
    "$light_name": [
        ("office_light", ["the office light", "the light in the office"]),
        ("kitchen_light", ["the kitchen light", "the light in the kitchen"]),
        ("living_room_light", ["the living room light", "the light in the living room", "the light in the den"]),
        ("outside_light", ["the outside light", "the outdoor light", "the light outside", "the porch light"]),
    ],
    "$group_name": [
        ("all_lights", ["all the lights", "everything"]),
        ("bedroom_lights", ["my lights", "the bedroom lights", "the lights", "the lights in my room"]),
    ],
    "$light_state": [
        ("on", ["on"]),
        ("off", ["off", "out"]),
        ("blue", ["blue"]),
        ("green", ["green"]),
        ("red", ["red"]),
        ("purple", ["purple"]),
        ("orange", ["orange"]),
        ("white", ["white", "normal"]),
    ],
    "$light_amount": [
        ("low", ["low", "dim", "dark"]),
        ("high", ["high", "bright"]),
        ("down", ["down", "lower", "darker"]),
        ("up", ["up", "brighter"]),
    ],
    "$volume": [
        ("high", ["high", "loud"]),
        ("low", ["low", "quiet", "down"]),
        ("up", ["up", "louder"]),
        ("down", ["down", "quieter"]),
    ],
    "$location": [
        ("san_francisco", ["sf", "san francisco", "the city"]),
        ("new_hampshire", ["nh", "new hampshire", "the northeast"]),
    ],
    "$asset": [
        ("btc", ["btc", "bitcoin", "bitcoins"]),
        ("eth", ["eth", "ethereum", "etherium", "ether"]),
        ("usd", ["usd", "dollars", "us dollars", "the fed"]),
        ("pesos", ["pesos", "mexican dollars"]),
    ],
}

synonyms = {
    "~turn": ["turn", "set", "make", "put", "change"],
    "~whatis": ["what is", "what's", "whats", "tell me", "tell us", "tell me about", "what about", "how about", "show me"],
    "~price": ["price", "value", "exchange rate", "dollar amount"],
}

prefixes = ["please", "pls", "hey maia", "hi", "could you", "would you", "hey", "yo", "ey", "oy", "excuse me please"]
suffixes = ["thanks", "thank you", "please", "plz", "pls", "plox", "ok", "thank you so much"]

# Choose a random pair of templates (output, input)
def choose_templates():
    output_template, input_templates = random.choice(templates)
    input_template = random.choice(input_templates)
    return output_template, input_template

output_template, input_template = choose_templates()
print('input template =', input_template)
print('output template =', output_template)

# We'll assume the input and output templates use the same set of variables,
# so values chosen from the input template can fill the output template.
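# choose_variables (below) returns two parallel dicts keyed by placeholder
# name: the output side gets the canonical token (e.g. "office_light") while
# the input side gets one of its natural phrasings (e.g. "the light in the office").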
# Choose variable values to fill a template with (output, input)
def choose_variables(template):
    variable_names = [word for word in template.split(' ') if word[0] == '$']
    input_variables = {}
    output_variables = {}
    for variable_name in variable_names:
        variable = random.choice(variables[variable_name])
        output_variables[variable_name] = variable[0]
        input_variables[variable_name] = random.choice(variable[1])
    return output_variables, input_variables

output_variables, input_variables = choose_variables(input_template)
print('input variables =', input_variables)
print('output variables =', output_variables)

def fill_template(template, template_variables):
    filled = []
    for word in template.split(' '):
        # Variable slot: substitute the chosen value
        if word[0] == '$':
            filled.append(template_variables[word])
        # Synonym group: substitute a random synonym
        elif word[0] == '~':
            filled.append(random.choice(synonyms[word]))
        # Regular word
        else:
            filled.append(word)
    return ' '.join(filled)

PREFIX_PROB = 0.3
SUFFIX_PROB = 0.3

# Randomly wrap the input sentence with a conversational prefix and/or suffix,
# e.g. "turn on the office light" -> "hey maia turn on the office light thanks"
def add_fixes(sentence):
    if random.random() < PREFIX_PROB:
        sentence = random.choice(prefixes) + ' ' + sentence
    if random.random() < SUFFIX_PROB:
        sentence += ' ' + random.choice(suffixes)
    return sentence

def random_training_pair():
    output_template, input_template = choose_templates()
    output_variables, input_variables = choose_variables(input_template)

    output_string = fill_template(output_template, output_variables)
    input_string = fill_template(input_template, input_variables)
    input_string = add_fixes(input_string)

    return input_string, output_string

for i in range(10):
    print('\n', random_training_pair())

# ## Keeping track of the output vocabulary

SOS_token = 0
EOS_token = 1

def tokenize_sentence(s):
    s = re.sub(r'[^\w\s]', '', s)
    s = re.sub(r'\s+', ' ', s)
    return s.strip().split(' ')

class DictionaryLang:
    def __init__(self):
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.size = 2 # Count SOS and EOS

    def __str__(self):
        return "%s(size = %d)" % (self.__class__.__name__, self.size)

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.size
            self.word2count[word] = 1
            self.index2word[self.size] = word
            self.size += 1
        else:
            self.word2count[word] += 1

    # Note: raises KeyError for words not seen during vocabulary building
    def get_word(self, word):
        return self.word2index[word]

    def indexes_from_sentence(self, sentence):
        return [self.get_word(word) for word in tokenize_sentence(sentence)]

    def variable_from_sentence(self, sentence):
        indexes = self.indexes_from_sentence(sentence)
        indexes.append(EOS_token)
        return Variable(torch.LongTensor(indexes).view(-1, 1))

# First turn the generated data into Lang for input (english) and output (command) languages

output_lang = DictionaryLang()

def add_words(lang, template):
    for word in template.split(' '):
        if word[0] != '$':
            lang.add_word(word)

# Add words from templates

for output_template, input_templates in templates:
    add_words(output_lang, output_template)

# Add values of variables

for variable_name in variables:
    for output_variable, input_variables in variables[variable_name]:
        add_words(output_lang, output_variable)

print("output lang = %s" % output_lang)
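# Quick illustrative check (in the same spirit as the demo prints above): a
# full command encodes as a column of word indexes terminated by EOS_token.
print(output_lang.variable_from_sentence("lights setState office_light on").size())
# torch.Size([5, 1]) -- four word indexes plus EOS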
# ## Using word vectors for the input vocabulary

# (load_word_vectors is an older torchtext API; newer releases expose
# torchtext.vocab.GloVe instead)
from torchtext.vocab import load_word_vectors

class GloVeLang:
    def __init__(self, size):
        self.size = size
        glove_dict, glove_arr, glove_size = load_word_vectors('data/', 'glove.twitter.27B', size)
        self.glove_dict = glove_dict
        self.glove_arr = glove_arr

    def __str__(self):
        return "%s(size = %d)" % (self.__class__.__name__, self.size)

    def vector_from_word(self, word):
        # Unknown words get a zero vector
        if word in self.glove_dict:
            return self.glove_arr[self.glove_dict[word]]
        else:
            return torch.zeros(self.size)

    def variable_from_sentence(self, sentence):
        words = tokenize_sentence(sentence.lower())
        tensor = torch.zeros(len(words), 1, self.size)
        for wi in range(len(words)):
            word = words[wi]
            tensor[wi][0] = self.vector_from_word(word)
        return Variable(tensor)

input_lang = GloVeLang(100)
print("input lang = %s" % input_lang)
print(input_lang.variable_from_sentence('turn on the light').size())

# Now we can use these Lang objects to create tensors from sentences:

def variables_from_pair(pair):
    input_variable = input_lang.variable_from_sentence(pair[0])
    target_variable = output_lang.variable_from_sentence(pair[1])
    return (input_variable, target_variable)

def generate_training_pairs(n_iters):
    pairs = []
    for i in range(n_iters):
        pairs.append(variables_from_pair(random_training_pair()))
    return pairs
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
from data import *
from model import *

MIN_PROB = -0.1

# # Evaluating the trained model

def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    encoder.train(False)
    decoder.train(False)

    input_variable = input_lang.variable_from_sentence(sentence)
    input_length = input_variable.size()[0]

    encoder_outputs, encoder_hidden = encoder(input_variable)

    decoder_input = Variable(torch.LongTensor([[SOS_token]])) # SOS
    decoder_hidden = encoder_hidden

    decoded_words = []
    decoder_attentions = torch.zeros(max_length, input_length)

    total_prob = 0

    for di in range(max_length):
        decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
        decoder_attentions[di] = decoder_attention.data[-1]
        topv, topi = decoder_output.data.topk(1)
        ni = topi[0][0]
        if ni == EOS_token:
            break
        else:
            total_prob += topv[0][0]
            decoded_words.append(output_lang.index2word[ni])
            decoder_input = Variable(torch.LongTensor([[ni]]))

    encoder.train(True)
    decoder.train(True)

    return decoded_words, total_prob, decoder_attentions[:di+1]

test_sentences = [
    'um can you turn on the office light',
    'hey maia please turn off all the lights thanks',
    'how are you today',
    'thank you',
    'please make the music loud',
    'whats the weather in minnesota',
    'whats the weather in sf',
    'are you on',
    'is my light on',
]

def evaluate_tests(encoder, decoder):
    for test_sentence in test_sentences:
        command, prob, attn = evaluate(encoder, decoder, test_sentence)
        command_str = ' '.join(command)
        if prob < MIN_PROB:
            command_str += ' (???)'
        print(test_sentence, '\n %.4f : %s' % (prob, command_str))

if __name__ == '__main__':
    import sys
    sentence = sys.argv[1]
    print('input', sentence)

    encoder = torch.load('seq2seq-encoder.pt')
    decoder = torch.load('seq2seq-decoder.pt')

    command, prob, attn = evaluate(encoder, decoder, sentence)
    if prob > -0.05:
        print(prob, command)
    else:
        print(prob, "UNKNOWN")
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

MAX_LENGTH = 20

from attn import *

# # Defining the models

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1):
        super(EncoderRNN, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers

        self.embedding = nn.Linear(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)

    def forward(self, input):
        seq_len = input.size(0)
        batch_size = input.size(1)
        embedded = self.embedding(input.view(seq_len * batch_size, -1)) # Process seq x batch at once
        output = embedded.view(seq_len, batch_size, -1) # Resize back to seq x batch for RNN
        output, hidden = self.gru(output)
        return output, hidden

# ## Decoder with Attention

class AttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, hidden_size, output_size, n_layers=1, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()

        # Keep parameters for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p

        # Define layers
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=dropout_p)
        self.out = nn.Linear(hidden_size * 2, output_size)

        # Choose attention model
        if attn_model != 'none':
            self.attn = Attn(attn_model, hidden_size)

    def forward(self, word_input, last_hidden, encoder_outputs):
        # Note: we run this one step at a time

        # Get the embedding of the current input word (last output word)
        word_embedded = self.embedding(word_input).view(1, 1, -1) # S=1 x B x N

        # Combine embedded input word and last context, run through RNN
        rnn_output, hidden = self.gru(word_embedded, last_hidden)

        # Calculate attention from current RNN state and all encoder outputs; apply to encoder outputs
        attn_weights = self.attn(rnn_output.squeeze(0), encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) # B x 1 x N

        # Final output layer (next word prediction) using the RNN hidden state and context vector
        rnn_output = rnn_output.squeeze(0) # S=1 x B x N -> B x N
        context = context.squeeze(1) # B x S=1 x N -> B x N
        output = F.log_softmax(self.out(torch.cat((rnn_output, context), 1)))

        # Return final output, hidden state, and attention weights (for visualization)
        return output, hidden, attn_weights

if __name__ == '__main__':
    print("Testing models...")
    n_layers = 2
    input_size = 10
    hidden_size = 50
    output_size = 10
    encoder = EncoderRNN(input_size, hidden_size, n_layers=n_layers)
    decoder = AttnDecoderRNN('dot', hidden_size, output_size, n_layers=n_layers)

    # Test encoder
    inp = Variable(torch.rand(5, 1, input_size))
    encoder_outputs, encoder_hidden = encoder(inp)
    print('encoder_outputs', encoder_outputs.size())
    print('encoder_hidden', encoder_hidden.size())

    # Test decoder
    decoder_input = Variable(torch.LongTensor([[0]])) # SOS
    decoder_hidden = encoder_hidden
    decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
--------------------------------------------------------------------------------
/service.py:
--------------------------------------------------------------------------------
import somata
from evaluate import *

print("Loading encoder...")
encoder = torch.load('seq2seq-encoder.pt')
print("Loading decoder...")
decoder = torch.load('seq2seq-decoder.pt')
print("Loaded models.")

def parse(body, cb):
    print('[parse]', body)
    parsed, prob, attn = evaluate(encoder, decoder, body)
    print(parsed, prob)
    cb({'parsed': parsed, 'prob': prob})

service = somata.Service('maia:parser', {'parse': parse}, {'bind_port': 7181})
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F

import math
import time
import os
import argparse

import sconce

# Parse command line arguments
argparser = argparse.ArgumentParser()
argparser.add_argument('--n_epochs', type=int, default=200)
argparser.add_argument('--n_iters', type=int, default=200)
argparser.add_argument('--hidden_size', type=int, default=100)
argparser.add_argument('--n_layers', type=int, default=2)
argparser.add_argument('--dropout_p', type=float, default=0.1)
argparser.add_argument('--learning_rate', type=float, default=0.001)
args = argparser.parse_args()

job = sconce.Job('seq2seq-intent-parsing', vars(args))
job.log_every = args.n_iters * 10

from data import *
from model import *
from evaluate import *

# # Training

def train(input_variable, target_variable):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_variable.size()[0]
    target_length = target_variable.size()[0]

    encoder_outputs, encoder_hidden = encoder(input_variable)

    decoder_input = Variable(torch.LongTensor([[SOS_token]]))
    decoder_hidden = encoder_hidden

    loss = 0

    for di in range(target_length):
        decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
        loss += criterion(decoder_output[0], target_variable[di])
        decoder_input = target_variable[di] # Teacher forcing

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.data[0] / target_length
def save_model(model, filename):
    torch.save(model, filename)
    print('Saved %s as %s' % (model.__class__.__name__, filename))

def save():
    save_model(encoder, 'seq2seq-encoder.pt')
    save_model(decoder, 'seq2seq-decoder.pt')

# The encoder and decoder must have the same number of layers, since the
# encoder's final hidden state initializes the decoder's hidden state
encoder = EncoderRNN(input_lang.size, args.hidden_size, n_layers=args.n_layers)
decoder = AttnDecoderRNN('dot', args.hidden_size, output_lang.size, args.n_layers, dropout_p=args.dropout_p)

encoder_optimizer = optim.Adam(encoder.parameters(), lr=args.learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=args.learning_rate)
criterion = nn.NLLLoss()

try:
    print("Training for %d epochs..." % args.n_epochs)

    for epoch in range(args.n_epochs):
        training_pairs = generate_training_pairs(args.n_iters)

        for i in range(args.n_iters):
            input_variable = training_pairs[i][0]
            target_variable = training_pairs[i][1]
            loss = train(input_variable, target_variable)

            job.record((args.n_iters * epoch) + i, loss)

        evaluate_tests(encoder, decoder)

    print("Saving...")
    save()

except KeyboardInterrupt:
    print("Saving before quit...")
    save()
--------------------------------------------------------------------------------
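Usage, inferred from the argument parser in `train.py` and the `__main__` block in `evaluate.py` (a sketch, not documented by the repo itself): train first so the two `.pt` model files exist, then parse a sentence from the command line.

```
python train.py --n_epochs 200 --n_iters 200 --hidden_size 100
python evaluate.py "please turn on the office light"
```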