├── README.md ├── data_utils.py ├── eval.py ├── model.py ├── test ├── __init__.py └── test_vocab.py ├── train.py ├── vis_attentions.ipynb └── vocab.py /README.md: -------------------------------------------------------------------------------- 1 | # End-to-End Goal-Oriented Dialog 2 | 3 | This repo contains a PyTorch implementation of the End-to-End Memory Network as described in the paper *[Learning end-to-end goal-oriented dialog.](https://arxiv.org/pdf/1605.07683.pdf)* Also there is a code for replicating the results on T1-T5 bAbI tasks and a jupyter notebook file for visualizing memory attentions of a learned model. 4 | 5 | 6 | ### Requirements 7 | 8 | ``` 9 | - python 3.6 10 | - pytorch 0.3.0 11 | ``` 12 | 13 | ### Running 14 | 15 | First you need to download [bAbI dialog dataset](https://fb-public.box.com/s/chnq60iivzv5uckpvj2n2vijlyepze6w). 16 | 17 | To run the training, use the following pattern: 18 | ``` 19 | python train.py /path/to/dataset/train_set_file.txt /path/to/dataset/dev_set_file.txt /path/to/dataset/candidates_file.txt 20 | ``` 21 | There are different command line arguments for adjusting model and training parameters. For complete list, run 22 | ``` 23 | python train.py -h 24 | ``` 25 | For evaluation, use: 26 | ``` 27 | python eval.py /path/to/saved/model/dir /path/to/dataset/test_set_file.txt 28 | ``` 29 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | 4 | import torch 5 | 6 | from vocab import Vocab 7 | 8 | 9 | TOKEN_PATTERN = re.compile(r"[']*[^'.,?! ]+|[.,?!]") 10 | 11 | 12 | def tokenize(sent): 13 | 14 | if sent == '': 15 | return [sent] 16 | 17 | return re.findall(TOKEN_PATTERN, sent) 18 | 19 | 20 | def dialog_reader(path): 21 | """Reads dialogs that are given in Facebook bAbI dialog format.""" 22 | 23 | dialogs = [] 24 | 25 | with open(path) as f: 26 | for line in f: 27 | 28 | if line in ['\n', '\r\n']: 29 | 30 | yield dialogs 31 | 32 | dialogs = [] 33 | 34 | elif '\t' in line: 35 | 36 | match = re.search("^\d+ ([^\t]+)\t(.+$)", line) 37 | if not match: 38 | 39 | raise ValueError("Invalid dataset format.") 40 | 41 | if match[2] == '': 42 | raise ValueError("Invalid dataset format: Bot never keeps silence.") 43 | 44 | dialogs.append((match[1], match[2])) 45 | 46 | else: 47 | dialogs.append((line.split(' ', 1)[1][:-1], None)) 48 | 49 | if len(dialogs) > 0: 50 | yield dialogs 51 | 52 | 53 | def build_dialog_vocab(dialog_dataset_path, candidates_path, time_features=1000): 54 | """ 55 | Builds two vocabularies. One contains all dialog words along with some special tokens and the second contains 56 | candidate responses, where the word is a whole sentence. 57 | 58 | :param dialog_dataset_path: Path to the dialog dataset (must be in Facebook bAbI dialog format) 59 | :param candidates_path: Path to the file containing candidate responses. 60 | :param time_features: Number of time features to add to the dialog vocabulary. 61 | :return: tuple containing dialog and candidate response vocabularies, respectively. 62 | """ 63 | 64 | vocab = Vocab() 65 | 66 | # PAD token index must be zero so we add it first. 67 | vocab.add_special_token('') 68 | 69 | vocab.add_special_token('') 70 | 71 | vocab.add_special_token('') 72 | vocab.add_special_token('') 73 | 74 | # adding time features into the vocabulary 75 | for i in range(time_features): 76 | vocab.add_special_token('' % i) 77 | 78 | # adding user spoken words to the vocabulary 79 | for dialog in dialog_reader(dialog_dataset_path): 80 | 81 | for user_utter, _ in dialog: 82 | 83 | for word in tokenize(user_utter): 84 | vocab.add_word(word) 85 | 86 | candidate_vocab = Vocab() 87 | with open(candidates_path) as f: 88 | for line in f: 89 | 90 | sent = line[2:-1] 91 | 92 | candidate_vocab.add_word(sent) 93 | 94 | for word in tokenize(sent): 95 | vocab.add_word(word) 96 | 97 | vocab.make() 98 | candidate_vocab.make() 99 | 100 | return vocab, candidate_vocab 101 | 102 | 103 | def sent2vec(vocab, sent): 104 | """Returns vector representation of a sentence by substituting each word with its index from the vocabulary.""" 105 | 106 | vec = [] 107 | 108 | for word in tokenize(sent): 109 | 110 | idx = vocab.word_to_index(word) 111 | 112 | if idx == -1: 113 | idx = vocab.word_to_index('') 114 | 115 | vec.append(idx) 116 | 117 | return vec 118 | 119 | 120 | def vec2sent(vocab, vec): 121 | """Returns original sentence from its vector representation.""" 122 | 123 | return ' '.join(vocab.index_to_word(idx) for idx in vec) 124 | 125 | 126 | class DialogReader: 127 | """ 128 | Represents an iterator over the mini-batches of data samples for training/evaluation. 129 | 130 | Single data sample is a triple containing the memory (current dialog history), query (current user utterance) 131 | and the label (ground truth bot response), respectively. 132 | 133 | When dealing with mini-batches, data samples are sorted by memory length in advance, so that mini-batches are 134 | approximately same size for computation efficiency. 135 | """ 136 | 137 | def __init__(self, 138 | dialog_data_path, 139 | dialog_vocab, 140 | candidate_vocab, 141 | max_memory_size, 142 | batch_size, 143 | drop_last_batch=False, 144 | shuffle_data=False, 145 | eval_mode=False): 146 | """ 147 | :param dialog_data_path: Path to the dialog dataset. 148 | :param dialog_vocab: The dialog vocabulary (word level). 149 | :param candidate_vocab: The vocabulary of candidate responses (sent. level). 150 | :param max_memory_size: The maximum size of the dialog history. If exceeded, the earliest utterances are dropped. 151 | :param batch_size: The size of mini-batch. 152 | :param drop_last_batch: If the number of data samples isn't divisible by batch_size, the last smaller mini-batch is dropped. 153 | :param shuffle_data: Shuffle mini-batches before returning the iterator. 154 | :param eval_mode: If true, every mini-batch has size 1 (regardless batch_size) and comes with an unique dialog id, 155 | so that mini-batches from the same dialog have same ids. Useful when evaluating per dialog accuracy. 156 | """ 157 | 158 | self._dialog_data_path = dialog_data_path 159 | self._dialog_vocab = dialog_vocab 160 | self._candidate_vocab = candidate_vocab 161 | self._max_memory_size = max_memory_size 162 | self._batch_size = batch_size if not eval_mode else 1 163 | self._drop_last_batch = drop_last_batch 164 | self._shuffle_data = shuffle_data 165 | 166 | # In eval mode batch_size is automatically set to 1, dataset isn't sorted/shuffled, batch comes with dialog id. 167 | self._eval_mode = eval_mode 168 | 169 | self._load_data() 170 | 171 | if not eval_mode: 172 | self._dataset.sort(key=lambda x: len(x[0])) 173 | 174 | self._batches = [] 175 | 176 | batch = [] 177 | for sample in self._dataset: 178 | 179 | batch.append(sample) 180 | 181 | if len(batch) == self._batch_size: 182 | 183 | self._add_batch(batch) 184 | 185 | batch = [] 186 | 187 | if len(batch) > 0 and not self._drop_last_batch: 188 | self._add_batch(batch) 189 | 190 | def _add_batch(self, batch): 191 | 192 | if self._eval_mode: 193 | self._batches.append((batch[0][0], self._batch_to_tensor([batch[0][1]]))) 194 | else: 195 | self._batches.append(self._batch_to_tensor(batch)) 196 | 197 | def _load_data(self): 198 | 199 | # Vectorizing candidate responses. 200 | 201 | candidate_vec_max_len = 0 202 | candidate_vecs = [] 203 | for i in range(len(self._candidate_vocab)): 204 | 205 | sent = self._candidate_vocab.index_to_word(i) 206 | 207 | candidate_vec = [self._dialog_vocab.word_to_index(w) for w in tokenize(sent)] 208 | 209 | candidate_vec_max_len = max(candidate_vec_max_len, len(candidate_vec)) 210 | 211 | candidate_vecs.append(candidate_vec) 212 | 213 | # Creating tensor of (num_candidates, max_candidate_len) size to store all candidate responses. 214 | 215 | self._candidate_vecs = torch.LongTensor(len(candidate_vecs), candidate_vec_max_len).fill_(self._dialog_vocab.word_to_index('')) 216 | 217 | for i in range(len(candidate_vecs)): 218 | self._candidate_vecs[i,:len(candidate_vecs[i])] = torch.LongTensor(candidate_vecs[i]) 219 | 220 | # Building dialog dataset containing (current_meomry, query, label) triples. 221 | 222 | self._dataset = [] 223 | for dialog_i, dialog in enumerate(dialog_reader(self._dialog_data_path)): 224 | 225 | user_utters, bot_utters = zip(*dialog) 226 | 227 | i, tm = 0, 0 228 | memories = [] 229 | while i < len(dialog): 230 | 231 | if bot_utters[i]: 232 | 233 | query = sent2vec(self._dialog_vocab, user_utters[i]) 234 | 235 | label = self._candidate_vocab.word_to_index(bot_utters[i]) 236 | 237 | if self._eval_mode: 238 | self._dataset.append((dialog_i, (memories[:], query, label))) 239 | else: 240 | self._dataset.append((memories[:], query, label)) 241 | 242 | self._write_memory(memories, query, tm, 0) 243 | self._write_memory(memories, sent2vec(self._dialog_vocab, bot_utters[i]), tm + 1, 1) 244 | 245 | i, tm = i + 1, tm + 2 246 | 247 | # Handling 'displaying options' case. 248 | else: 249 | 250 | while not bot_utters[i]: 251 | 252 | self._write_memory(memories, sent2vec(self._dialog_vocab, user_utters[i]), tm, 0) 253 | 254 | i, tm = i + 1, tm + 1 255 | 256 | def _write_memory(self, memories, memory, time, speaker_id): 257 | 258 | memory = self._add_speaker_feature(memory, speaker_id) 259 | memory = self._add_time_feature(memory, time) 260 | 261 | if len(memories) == self._max_memory_size and self._max_memory_size > 0: 262 | del memories[0] 263 | 264 | if self._max_memory_size > 0: 265 | memories.append(memory) 266 | 267 | def _add_speaker_feature(self, vec, speaker_id): 268 | 269 | return [self._dialog_vocab.word_to_index(['', ''][speaker_id])] + vec 270 | 271 | def _add_time_feature(self, vec, time): 272 | 273 | return [self._dialog_vocab.word_to_index('' % time)] + vec 274 | 275 | def _batch_to_tensor(self, batch): 276 | 277 | pad = self._dialog_vocab.word_to_index('') 278 | 279 | memories, queries, labels = zip(*batch) 280 | 281 | batch_size = len(batch) 282 | max_mem_len = max(1, 1, *[len(m) for m in memories]) 283 | max_vec_len = max(1, 1, *[len(v) for m in memories for v in m]) 284 | max_query_len = max(len(q) for q in queries) 285 | 286 | mem_tensor = torch.LongTensor(batch_size, max_mem_len, max_vec_len).fill_(pad) 287 | 288 | query_tensor = torch.LongTensor(batch_size, max_query_len).fill_(pad) 289 | 290 | label_tensor = torch.stack([torch.LongTensor([label]) for label in labels]) 291 | 292 | for i in range(batch_size): 293 | for j in range(len(memories[i])): 294 | mem_tensor[i,j,:len(memories[i][j])] = torch.LongTensor(memories[i][j]) 295 | 296 | query_tensor[i,:len(queries[i])] = torch.LongTensor(queries[i]) 297 | 298 | return mem_tensor, query_tensor, label_tensor 299 | 300 | def __iter__(self): 301 | 302 | if not self._eval_mode and self._shuffle_data: 303 | random.shuffle(self._batches) 304 | 305 | return iter(self._batches) 306 | 307 | def __len__(self): 308 | 309 | return len(self._batches) 310 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | from data_utils import DialogReader 8 | from model import MemN2N 9 | from vocab import Vocab 10 | 11 | 12 | def calc_accuracy_per_response(model, data_reader, use_cuda): 13 | """ 14 | Calculates per response accuracy, that is, the ratio of correct responses out of all responses. 15 | 16 | :param model: Trained model used for prediction. 17 | :param data_reader: DialogReader instance that provides iterator over samples. 18 | :param use_cuda: If True, calculations will be performed on GPU. 19 | """ 20 | 21 | n_correct = 0 22 | samples_total = 0 23 | 24 | for i_batch, sample_batched in enumerate(data_reader): 25 | 26 | sample, query, label = sample_batched 27 | 28 | if use_cuda: 29 | sample = sample.cuda() 30 | query = query.cuda() 31 | label = label.cuda() 32 | 33 | pred, _ = model(Variable(sample), Variable(query)) 34 | 35 | n_correct += torch.sum(torch.max(pred.data, 1)[1] == label.squeeze(1)) 36 | samples_total += pred.size()[0] 37 | 38 | return n_correct / samples_total 39 | 40 | 41 | def calc_accuracy_per_dialog(model, data_reader): 42 | """ 43 | Calculates per dialog accuracy, that is, the ratio of dialogs where every response is correct out of all dialogs. 44 | 45 | :param model: Trained model used for prediction. 46 | :param data_reader: DialogReader instance that provides iterator over samples. 47 | """ 48 | 49 | acc = dict() 50 | 51 | for dialog_i, sample in data_reader: 52 | 53 | memory, query, label = sample 54 | 55 | pred, _ = model(Variable(memory), Variable(query)) 56 | 57 | pred = (torch.max(pred.data, 1)[1] == label.squeeze(1))[0] 58 | 59 | if dialog_i in acc: 60 | acc[dialog_i][0] += pred 61 | acc[dialog_i][1] += 1 62 | else: 63 | acc[dialog_i] = [pred, 1] 64 | 65 | return sum(1 if i[0] == i[1] else 0 for i in acc.values()) / len(acc) 66 | 67 | 68 | if __name__ == "__main__": 69 | 70 | parser = argparse.ArgumentParser(description='Goal-Oriented Chatbot using End-to-End Memory Networks', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 71 | 72 | parser.add_argument('model_dir', type=str, help='trained model path') 73 | parser.add_argument('test_path', type=str, help='test data path') 74 | 75 | parser.add_argument('--maxmemsize', type=int, metavar='N', default=100, help='memory capacity') 76 | 77 | args = parser.parse_args() 78 | 79 | # loading vocabularies and the trained model 80 | 81 | dialog_vocab = Vocab.load(os.path.join(args.model_dir, 'dialog_vocab')) 82 | candidates_vocab = Vocab.load(os.path.join(args.model_dir, 'candidates_vocab')) 83 | 84 | model = MemN2N.load(os.path.join(args.model_dir, 'model')) 85 | 86 | test_data_reader_per_resp = DialogReader(args.test_path, dialog_vocab, candidates_vocab, args.maxmemsize, 1, False, False, False) 87 | test_data_reader_per_dial = DialogReader(args.test_path, dialog_vocab, candidates_vocab, args.maxmemsize, 1, False, False, True) 88 | 89 | print("Per Response Accuracy: ", calc_accuracy_per_response(model, test_data_reader_per_resp, False)) 90 | print("Per Dialog Accuracy: ", calc_accuracy_per_dialog(model, test_data_reader_per_dial)) 91 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class MemN2N(nn.Module): 10 | """End-2-End Memory Network.""" 11 | 12 | def __init__(self, 13 | mem_cell_size, 14 | vocab_size, 15 | candidate_vecs, 16 | n_hops, 17 | init_std=0.1, 18 | nonlinearity=True): 19 | """ 20 | :param mem_cell_size: Size of the memory cell. 21 | :param vocab_size: Total number words in the vocabulary. 22 | :param candidate_vecs: Tensor containing vectors (vector of word indices) for each candidate response. 23 | :param n_hops: Number of iterative memory accesses. 24 | :param init_std: Initial std for weight initialization. 25 | :param nonlinearity: If true, performs softmax normalization of attention weights. 26 | """ 27 | 28 | super(MemN2N, self).__init__() 29 | 30 | self.mem_cell_size = mem_cell_size 31 | self.vocab_size = vocab_size 32 | self.candidate_vecs = candidate_vecs 33 | self.n_hops = n_hops 34 | self.init_std = init_std 35 | self.nonlinearity = nonlinearity 36 | 37 | self.query_emb = nn.Embedding(vocab_size, mem_cell_size, padding_idx=0) 38 | self.query_emb.weight.data.normal_(std=init_std) 39 | self.query_emb.weight.data[0] = 0 40 | 41 | self.out_transform = nn.Linear(mem_cell_size, mem_cell_size, bias=False) 42 | self.out_transform.weight.data.normal_(std=init_std) 43 | 44 | self.mem_emb = nn.ModuleList() 45 | for i in range(n_hops): 46 | 47 | mem_emb = nn.Embedding(vocab_size, mem_cell_size, padding_idx=0) 48 | mem_emb.weight.data.normal_(std=init_std) 49 | mem_emb.weight.data[0] = 0 50 | 51 | self.mem_emb.append(mem_emb) 52 | 53 | self.candidate_emb = nn.Embedding(vocab_size, mem_cell_size, padding_idx=0) 54 | self.candidate_emb.weight.data.normal_(std=init_std) 55 | self.candidate_emb.weight.data[0] = 0 56 | 57 | def forward(self, memory, query): 58 | """ 59 | :param memory: torch Variable, containing memory vectors. 60 | :param query: torch Variable, containing query vector. 61 | :return: Pair of log softmax predictions and attention layer activations. 62 | """ 63 | 64 | query_emb = torch.sum(self.query_emb(query), dim=1) 65 | 66 | u = [query_emb] 67 | attns = [] 68 | for hop in range(self.n_hops): 69 | 70 | mem_emb = self.embed_memory(memory, hop) 71 | 72 | attn_weights = torch.bmm(mem_emb, u[-1].unsqueeze(2)) 73 | 74 | if self.nonlinearity: 75 | attn_weights = F.softmax(attn_weights, 1) 76 | 77 | output_tmp = torch.bmm(attn_weights.permute(0, 2, 1), mem_emb).squeeze(1) 78 | output = self.out_transform(output_tmp) 79 | 80 | u.append(u[-1] + output) 81 | attns.append(attn_weights) 82 | 83 | candidate_emb = torch.sum(self.candidate_emb(self.candidate_vecs), dim=1) 84 | 85 | y_pred = candidate_emb @ u[-1].permute(1, 0) 86 | 87 | return F.log_softmax(y_pred, dim=0).permute(1, 0), torch.stack(attns, 3).squeeze(2) 88 | 89 | def embed_memory(self, memory, hop): 90 | 91 | emb = self.mem_emb[hop](memory.view(-1, memory.size()[2])) 92 | 93 | emb = torch.sum(emb.view(*memory.size(), -1), dim=2) 94 | 95 | return emb 96 | 97 | def save(self, path=None): 98 | """ 99 | Saves model state so that it can be restored later. 100 | 101 | :param path: Path to the save directory. 102 | """ 103 | 104 | if not path: 105 | path = os.path.join(os.getcwd(), 'model_' + str(time.time())) 106 | 107 | torch.save({ 108 | 'mem_cell_size': self.mem_cell_size, 109 | 'vocab_size': self.vocab_size, 110 | 'candidate_vecs': self.candidate_vecs, 111 | 'n_hops': self.n_hops, 112 | 'init_std': self.init_std, 113 | 'nonlinearity': self.nonlinearity, 114 | 'state_dict': self.state_dict() 115 | }, path) 116 | 117 | @staticmethod 118 | def load(path, load_weights=True): 119 | """ 120 | Static factory that builds the previously stored model. 121 | 122 | :param path: Path to the saved model. 123 | :param load_weights: If False, model weights (learnable parameters) aren't restored. 124 | """ 125 | 126 | model_params = torch.load(path, map_location=lambda storage, loc: storage.cpu()) 127 | 128 | model = MemN2N(model_params['mem_cell_size'], 129 | model_params['vocab_size'], 130 | model_params['candidate_vecs'], 131 | model_params['n_hops'], 132 | model_params['init_std'], 133 | model_params['nonlinearity']) 134 | 135 | if load_weights: 136 | model.load_state_dict(model_params['state_dict']) 137 | 138 | return model 139 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sandrobarna/pytorch_memn2n/db5e6e76c1978e0bf5436ab1e4ec8a1b22e53f8a/test/__init__.py -------------------------------------------------------------------------------- /test/test_vocab.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from vocab import Vocab 4 | 5 | 6 | class VocabTestCase(unittest.TestCase): 7 | 8 | def setUp(self): 9 | 10 | self._vocab = Vocab() 11 | 12 | self._spec_tokens = ['', ''] 13 | self._words = ['Hello', 'world', 'a', 'HELLO WORLD !!!', '2018', ':)'] 14 | 15 | for w in self._spec_tokens: 16 | self._vocab.add_special_token(w) 17 | 18 | for w in self._words: 19 | self._vocab.add_word(w) 20 | 21 | def test_len(self): 22 | 23 | self._vocab.make() 24 | 25 | self.assertEqual(len(self._vocab), len(self._spec_tokens) + len(self._words)) 26 | 27 | def test_word2index_index2word(self): 28 | 29 | self._vocab.make() 30 | 31 | for i, w in enumerate(self._spec_tokens + self._words): 32 | self.assertEqual(self._vocab.word_to_index(w), i) 33 | self.assertEqual(self._vocab.index_to_word(i), w) 34 | 35 | self.assertEqual(self._vocab.word_to_index('Unknown word'), -1) 36 | 37 | self.assertIsNone(self._vocab.index_to_word(-1)) 38 | self.assertIsNone(self._vocab.index_to_word(297196412)) 39 | 40 | def test_limited_vocab_size(self): 41 | 42 | for w in ['frequent word', 'frequent word']: 43 | self._vocab.add_word(w) 44 | 45 | n_words = 5 46 | 47 | self._vocab.make(n_words) 48 | 49 | self.assertEqual(len(self._vocab), n_words) 50 | 51 | for i, w in enumerate(self._spec_tokens + ['frequent word'] + self._words): 52 | self.assertEqual(self._vocab.word_to_index(w), i) 53 | self.assertEqual(self._vocab.index_to_word(i), w) 54 | 55 | if i == n_words - 1: 56 | break 57 | 58 | def test_spec_token_occurring_in_words(self): 59 | 60 | self._vocab.add_special_token('common word') 61 | self._vocab.add_word('common word') 62 | 63 | with self.assertRaises(ValueError): 64 | self._vocab.make() 65 | 66 | def test_spec_token_added_twice(self): 67 | 68 | self._vocab.add_special_token('token') 69 | self._vocab.add_special_token('token') 70 | 71 | with self.assertRaises(ValueError): 72 | self._vocab.make() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch import optim 8 | from torch.autograd import Variable 9 | from torch.optim.lr_scheduler import StepLR 10 | 11 | from data_utils import DialogReader 12 | from data_utils import build_dialog_vocab 13 | from eval import calc_accuracy_per_response 14 | from model import MemN2N 15 | from vocab import Vocab 16 | 17 | 18 | def to_time(seconds): 19 | 20 | seconds = int(seconds) 21 | h = seconds / 3600 22 | m = seconds / 60 % 60 23 | seconds %= 60 24 | return '%dh %dm %ds' % (h, m, seconds) 25 | 26 | 27 | def train_step(model, optimizer, loss_criterion, max_grad_norm, memory, query, target, use_cuda): 28 | 29 | optimizer.zero_grad() 30 | 31 | memory = memory.cuda() if use_cuda else memory 32 | 33 | query = query.cuda() if use_cuda else query 34 | 35 | target = Variable(target).squeeze(1) 36 | 37 | pred, _ = model(Variable(memory), Variable(query)) 38 | 39 | loss = loss_criterion(pred.cpu(), target) 40 | 41 | loss.backward() 42 | 43 | torch.nn.utils.clip_grad_norm(model.parameters(), max_grad_norm) 44 | 45 | optimizer.step() 46 | 47 | return loss.data[0] 48 | 49 | 50 | def train(model, 51 | data_loader, 52 | dev_set_reader, 53 | n_epochs, 54 | lr, 55 | decay_factor, 56 | decay_every, 57 | max_grad_norm, 58 | print_interval, 59 | summary_interval, 60 | use_cuda): 61 | 62 | start_ts = time.time() 63 | 64 | optimizer = optim.SGD(model.parameters(), lr=lr) 65 | 66 | lr_scheduler = StepLR(optimizer, step_size=decay_every, gamma=decay_factor) 67 | 68 | loss_criterion = nn.NLLLoss() 69 | loss_criterion = loss_criterion.cuda() if use_cuda else loss_criterion 70 | 71 | loss_total_print = 0 72 | loss_total_summary = 0 73 | losses = [] 74 | 75 | i = 0 76 | for i_epoch in range(n_epochs): 77 | 78 | for i_batch, sample_batched in enumerate(data_loader): 79 | 80 | sample, query, label = sample_batched 81 | 82 | loss = train_step(model, optimizer, loss_criterion, max_grad_norm, sample, query, label, use_cuda) 83 | 84 | loss_total_print += loss 85 | loss_total_summary += loss 86 | losses.append(loss) 87 | 88 | if i % print_interval == 0: 89 | print("Epoch: %d, Iter: %d, Loss: %.5f" % (i_epoch + 1, i_batch + 1, loss_total_print / print_interval)) 90 | loss_total_print = 0 91 | 92 | if (i + 1) % summary_interval == 0: 93 | 94 | avg_loss = loss_total_summary / summary_interval 95 | loss_total_summary = 0 96 | 97 | dev_acc = calc_accuracy_per_response(model, dev_set_reader, use_cuda) 98 | 99 | print('\n---------- SUMARRY ----------') 100 | print("Time elapsed: %s, Train Loss: %.5f, Dev Accuracy: %.5f\n" % (to_time(time.time() - start_ts), avg_loss, dev_acc)) 101 | 102 | i += 1 103 | 104 | lr_scheduler.step() 105 | 106 | return losses 107 | 108 | 109 | if __name__ == "__main__": 110 | 111 | parser = argparse.ArgumentParser(description='Goal-Oriented Chatbot using End-to-End Memory Networks', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 112 | 113 | parser.add_argument('train_path', type=str, help='training data path') 114 | parser.add_argument('dev_path', type=str, help='development data path') 115 | parser.add_argument('candidates_path', type=str, help='candidate responses data path') 116 | 117 | gr_model = parser.add_argument_group('Model Parameters') 118 | gr_model.add_argument('--edim', type=int, metavar='N', default=32, help='internal state dimension') 119 | gr_model.add_argument('--nhops', type=int, metavar='N', default=1, help='number of memory hops') 120 | gr_model.add_argument('--init_std', type=float, metavar='N', default=0.1, help='weight initialization std') 121 | 122 | gr_train = parser.add_argument_group('Training Parameters') 123 | gr_train.add_argument('--gpu', action="store_true", default=False, help='use GPU for training') 124 | gr_train.add_argument('--lr', type=float, metavar='N', default=0.01, help='initial learning rate') 125 | gr_train.add_argument('--decay_factor', type=float, default=0.5, help='learning rate decay factor') 126 | gr_train.add_argument('--decay_every', type=int, default=25, help='# of epochs learning rate is changed') 127 | gr_train.add_argument('--batchsize', type=int, metavar='N', default=32, help='minibatch size') 128 | gr_train.add_argument('--epochs', type=int, metavar='N', default=50, help='initial learning rate') 129 | gr_train.add_argument('--maxgradnorm', type=int, metavar='N', default=40, help='maximum gradient norm') 130 | gr_train.add_argument('--maxmemsize', type=int, metavar='N', default=100, help='memory capacity') 131 | gr_train.add_argument('--shuffle', action="store_true", default=True, help='shuffle batches before every epoch') 132 | gr_train.add_argument('--save_dir', type=str, default=None, help='path to save the model') 133 | 134 | args = parser.parse_args() 135 | 136 | # build data, initialize model and start training. 137 | 138 | dialog_vocab, candidates_vocab = build_dialog_vocab(args.train_path, args.candidates_path, 1000) 139 | 140 | trn_data_reader = DialogReader(args.train_path, dialog_vocab, candidates_vocab, args.maxmemsize, args.batchsize, False, args.shuffle, False) 141 | dev_data_reader = DialogReader(args.dev_path, dialog_vocab, candidates_vocab, args.maxmemsize, args.batchsize, False, False, False) 142 | 143 | candidate_vecs = Variable(trn_data_reader._candidate_vecs) 144 | candidate_vecs = candidate_vecs.cuda() if args.gpu else candidate_vecs 145 | 146 | model = MemN2N(args.edim, len(trn_data_reader._dialog_vocab), candidate_vecs, args.nhops, args.init_std) 147 | 148 | if args.gpu: 149 | model.cuda() 150 | 151 | train(model, trn_data_reader, dev_data_reader, args.epochs, args.lr, args.decay_factor, args.decay_every, args.maxgradnorm, 50, 500, args.gpu) 152 | 153 | # saving trained model and vocabularies. 154 | 155 | save_dir = args.save_dir 156 | if not save_dir: 157 | save_dir = os.getcwd() 158 | 159 | save_dir = os.path.join(save_dir, 'model_' + str(time.time())) 160 | 161 | if not os.path.exists(save_dir): 162 | os.makedirs(save_dir) 163 | else: 164 | raise ValueError("Model save path already exists") 165 | 166 | Vocab.save(dialog_vocab, os.path.join(save_dir, 'dialog_vocab')) 167 | Vocab.save(candidates_vocab, os.path.join(save_dir, 'candidates_vocab')) 168 | model.save(os.path.join(save_dir, 'model')) 169 | -------------------------------------------------------------------------------- /vocab.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from operator import itemgetter 3 | 4 | 5 | class Vocab: 6 | """Represents a vocabulary for storing words and their indices and provides mapping between them.""" 7 | 8 | def __init__(self): 9 | 10 | self._words = dict() 11 | self._spec_tokens = [] 12 | 13 | def add_special_token(self, token): 14 | 15 | self._spec_tokens.append(token) 16 | 17 | def add_word(self, word): 18 | 19 | self._words[word] = self._words.get(word, 0) + 1 20 | 21 | def make(self, n_words=-1): 22 | """ 23 | Builds mapping between words and their indices and vice versa. Should be called before using vocabulary. 24 | :param n_words: Number of top frequent words to be included in final vocabulary. 25 | If negative, all words are used. 26 | """ 27 | 28 | self._ensure_validity() 29 | 30 | vocab_size = len(self._spec_tokens) + len(self._words) 31 | 32 | self._vocab_size = min(n_words, vocab_size) if n_words > -1 else vocab_size 33 | 34 | # Spec tokens must preserve the order they have been added 35 | # Words are ordered by decreasing frequency (stable sort) 36 | self._idx2word = self._spec_tokens + [k for k, v in sorted(self._words.items(), key=itemgetter(1), reverse=True)[:self._vocab_size]] 37 | 38 | self._word2idx = dict((v, k) for k, v in enumerate(self._idx2word)) 39 | 40 | def word_to_index(self, word): 41 | 42 | return self._word2idx.get(word, -1) 43 | 44 | def index_to_word(self, index): 45 | 46 | return self._idx2word[index] if 0 <= index < self._vocab_size else None 47 | 48 | def _ensure_validity(self): 49 | 50 | unique_specials = set(self._spec_tokens) 51 | 52 | if len(unique_specials) != len(self._spec_tokens): 53 | raise ValueError("Single spec token was added more than once.") 54 | 55 | if len(unique_specials & self._words.keys()) != 0: 56 | raise ValueError("Spec tokens and words mustn't have common elements.") 57 | 58 | def __len__(self): 59 | 60 | return self._vocab_size 61 | 62 | @staticmethod 63 | def save(vocab, path): 64 | 65 | with open(path, 'wb') as f: 66 | pickle.dump(vocab, f) 67 | 68 | @staticmethod 69 | def load(path): 70 | 71 | with open(path, 'rb') as f: 72 | return pickle.load(f) --------------------------------------------------------------------------------