├── bAbI
│   └── .gitkeep
├── model
│   └── .gitkeep
├── .gitignore
├── memn2n
│   ├── train.py
│   ├── dataset.py
│   ├── model.py
│   ├── data_utils.py
│   └── trainer.py
└── README.md

--------------------------------------------------------------------------------
/bAbI/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/model/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.swp
__pycache__

--------------------------------------------------------------------------------
/memn2n/train.py:
--------------------------------------------------------------------------------
import argparse

import trainer


def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--dataset_dir", type=str, default="bAbI/tasks_1-20_v1-2/en/")
    parser.add_argument("--task", type=int, default=1)
    parser.add_argument("--max_hops", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--decay_interval", type=int, default=25)
    parser.add_argument("--decay_ratio", type=float, default=0.5)
    parser.add_argument("--max_clip", type=float, default=40.0)

    return parser.parse_args()


def main(config):
    t = trainer.Trainer(config)
    t.fit()


if __name__ == "__main__":
    config = parse_config()
    main(config)

--------------------------------------------------------------------------------
/memn2n/dataset.py:
--------------------------------------------------------------------------------
from itertools import chain

import numpy as np
import torch
import torch.utils.data as data

from data_utils import load_task, vectorize_data


class bAbIDataset(data.Dataset):
    def __init__(self, dataset_dir, task_id=1, memory_size=50, train=True):
        self.train = train
        self.task_id = task_id
        self.dataset_dir = dataset_dir

        train_data, test_data = load_task(self.dataset_dir, task_id)
        all_data = train_data + test_data  # build the vocabulary over both splits

        self.vocab = set()
        for story, query, answer in all_data:
            self.vocab |= set(list(chain.from_iterable(story)) + query + answer)
        self.vocab = sorted(self.vocab)
        word_idx = dict((word, i + 1) for i, word in enumerate(self.vocab))  # index 0 is the nil word

        self.max_story_size = max(len(story) for story, _, _ in all_data)
        self.query_size = max(len(query) for _, query, _ in all_data)
        self.sentence_size = max(len(row) for row in
                                 chain.from_iterable(story for story, _, _ in all_data))
        self.memory_size = min(memory_size, self.max_story_size)

        # Add time words/indexes; vectorize_data assumes they occupy the
        # last memory_size indices of the vocabulary.
        for i in range(self.memory_size):
            word_idx["time{}".format(i + 1)] = len(word_idx) + 1

        self.num_vocab = len(word_idx) + 1  # +1 for nil word
        self.sentence_size = max(self.query_size, self.sentence_size)  # for the position encoding
        self.sentence_size += 1  # +1 for time words
        self.word_idx = word_idx

        self.mean_story_size = int(np.mean([len(s) for s, _, _ in all_data]))

        if train:
            story, query, answer = vectorize_data(train_data, self.word_idx,
                                                  self.sentence_size, self.memory_size)
        else:
            story, query, answer = vectorize_data(test_data, self.word_idx,
                                                  self.sentence_size, self.memory_size)

        self.data_story = torch.LongTensor(story)
        self.data_query = torch.LongTensor(query)
        self.data_answer = torch.LongTensor(np.argmax(answer, axis=1))

    def __getitem__(self, idx):
        return self.data_story[idx], self.data_query[idx], self.data_answer[idx]

    def __len__(self):
        return len(self.data_story)
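

# A minimal usage sketch (not part of the original pipeline): builds the
# task-1 dataset and inspects one batch. It assumes the bAbI archive has
# been extracted under bAbI/ as described in the README; the printed
# shapes below are only illustrative.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = bAbIDataset("bAbI/tasks_1-20_v1-2/en/", task_id=1)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    story, query, answer = next(iter(loader))
    # story: (batch, memory_size, sentence_size), query: (batch, sentence_size),
    # answer: (batch,) -- e.g. torch.Size([32, 10, 7]) torch.Size([32, 7]) torch.Size([32])
    print(story.size(), query.size(), answer.size())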
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MemN2N-pytorch
PyTorch implementation of [End-To-End Memory Network](https://arxiv.org/abs/1503.08895). This code is heavily based on [memn2n](https://github.com/domluna/memn2n) by domluna.

## Dataset
```shell
cd bAbI
wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz
tar xzvf ./tasks_1-20_v1-2.tar.gz
```

## Training
```shell
python memn2n/train.py --task=3 --cuda
```

## Results (single-task only)
In all experiments, hyperparameters follow the settings in `memn2n/train.py` (e.g. lr=0.01).

Since I suspect training is really unstable, I train the model 100 times on each task with the fixed hyperparameters in `memn2n/train.py` and average the top-5 results (see the sketch below the table).

Task | Training Acc. | Test Acc. | Pass (test acc. ≥ 0.95)
------|-----------------|-------------|--------
1 | 1.00 | 1.00 | O
2 | 0.98 | 0.84 |
3 | 1.00 | 0.49 |
4 | 1.00 | 0.99 | O
5 | 1.00 | 0.94 |
6 | 1.00 | 0.93 |
7 | 0.96 | 0.95 | O
8 | 0.97 | 0.89 |
9 | 1.00 | 0.91 |
10 | 1.00 | 0.87 |
11 | 1.00 | 0.98 | O
12 | 1.00 | 1.00 | O
13 | 0.97 | 0.94 |
14 | 1.00 | 1.00 | O
15 | 1.00 | 1.00 | O
16 | 0.81 | 0.47 |
17 | 0.75 | 0.53 |
18 | 0.97 | 0.92 |
19 | 0.39 | 0.17 |
20 | 1.00 | 1.00 | O
mean | 0.94 | 0.84 |
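
The top-5 averaging above amounts to the following (a minimal sketch; `accs` is assumed to hold the 100 per-run test accuracies for one task):

```python
def top5_mean(accs):
    # mean of the 5 best runs out of 100
    return sum(sorted(accs, reverse=True)[:5]) / 5.0
```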

## Issues
- Model training seems to rely heavily on weight initialization (or training is simply very unstable). For example, the best test accuracy on task 2 is ~90%, yet the average over 100 runs is ~40% with the same model and hyperparameters.
- WHY?

## TODO
- Multi-task learning

--------------------------------------------------------------------------------
/memn2n/model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable


def position_encoding(sentence_size, embedding_dim):
    """Position Encoding described in section 4.1 of the paper."""
    encoding = np.ones((embedding_dim, sentence_size), dtype=np.float32)
    ls = sentence_size + 1
    le = embedding_dim + 1
    for i in range(1, le):
        for j in range(1, ls):
            encoding[i-1, j-1] = (i - (embedding_dim+1)/2) * (j - (sentence_size+1)/2)
    encoding = 1 + 4 * encoding / embedding_dim / sentence_size
    # Make the position encoding of the time words the identity so they are not modified
    encoding[:, -1] = 1.0
    return np.transpose(encoding)


class AttrProxy(object):
    """
    Translates index lookups into attribute lookups.
    A trick that makes a list of nn.Module submodules usable inside an
    nn.Module; see https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219/2
    """
    def __init__(self, module, prefix):
        self.module = module
        self.prefix = prefix

    def __getitem__(self, i):
        return getattr(self.module, self.prefix + str(i))


class MemN2N(nn.Module):
    def __init__(self, settings):
        super(MemN2N, self).__init__()

        use_cuda = settings["use_cuda"]
        num_vocab = settings["num_vocab"]
        embedding_dim = settings["embedding_dim"]
        sentence_size = settings["sentence_size"]
        self.max_hops = settings["max_hops"]

        # Adjacent weight sharing: C_k doubles as A_{k+1}, so max_hops+1
        # embeddings cover all hops plus the final answer projection.
        for hop in range(self.max_hops + 1):
            C = nn.Embedding(num_vocab, embedding_dim, padding_idx=0)
            C.weight.data.normal_(0, 0.1)
            self.add_module("C_{}".format(hop), C)
        self.C = AttrProxy(self, "C_")

        self.softmax = nn.Softmax(dim=1)  # attention over memories / vocabulary
        self.encoding = Variable(torch.FloatTensor(
            position_encoding(sentence_size, embedding_dim)), requires_grad=False)

        if use_cuda:
            self.encoding = self.encoding.cuda()

    def forward(self, story, query):
        story_size = story.size()

        u = list()
        query_embed = self.C[0](query)
        # position-weighted sum over words (a reduce-dot along the word axis)
        encoding = self.encoding.unsqueeze(0).expand_as(query_embed)
        u.append(torch.sum(query_embed * encoding, 1))

        for hop in range(self.max_hops):
            embed_A = self.C[hop](story.view(story.size(0), -1))
            embed_A = embed_A.view(story_size + (embed_A.size(-1),))

            encoding = self.encoding.unsqueeze(0).unsqueeze(1).expand_as(embed_A)
            m_A = torch.sum(embed_A * encoding, 2)

            u_temp = u[-1].unsqueeze(1).expand_as(m_A)
            prob = self.softmax(torch.sum(m_A * u_temp, 2))

            embed_C = self.C[hop + 1](story.view(story.size(0), -1))
            embed_C = embed_C.view(story_size + (embed_C.size(-1),))
            m_C = torch.sum(embed_C * encoding, 2)

            prob = prob.unsqueeze(2).expand_as(m_C)
            o_k = torch.sum(m_C * prob, 1)

            u_k = u[-1] + o_k
            u.append(u_k)

        a_hat = u[-1] @ self.C[self.max_hops].weight.transpose(0, 1)
        return a_hat, self.softmax(a_hat)
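

# A minimal smoke test (not part of the original code): runs one forward
# pass on random word indices to check output shapes. The sizes below
# (batch 2, 5 memory slots, 6-word sentences, vocab 30) are arbitrary
# assumptions, not values from the repo.
if __name__ == "__main__":
    settings = {
        "use_cuda": False,
        "num_vocab": 30,
        "embedding_dim": 20,
        "sentence_size": 6,
        "max_hops": 3,
    }
    model = MemN2N(settings)
    story = Variable(torch.LongTensor(2, 5, 6).random_(0, 30))
    query = Variable(torch.LongTensor(2, 6).random_(0, 30))
    a_hat, prob = model(story, query)
    print(a_hat.size(), prob.size())  # both (2, 30): one score per vocab word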
--------------------------------------------------------------------------------
/memn2n/data_utils.py:
--------------------------------------------------------------------------------
"""
Data util codes based on https://github.com/domluna/memn2n
"""

import os
import re

import numpy as np


def load_task(data_dir, task_id, only_supporting=False):
    """
    Load the nth task. There are 20 tasks in total.
    Returns a tuple containing the training and testing data for the task.
    """
    assert task_id > 0 and task_id < 21

    files = os.listdir(data_dir)
    files = [os.path.join(data_dir, f) for f in files]
    s = "qa{}_".format(task_id)
    train_file = [f for f in files if s in f and 'train' in f][0]
    test_file = [f for f in files if s in f and 'test' in f][0]
    train_data = get_stories(train_file, only_supporting)
    test_data = get_stories(test_file, only_supporting)
    return train_data, test_data


def tokenize(sent):
    """
    Return the tokens of a sentence including punctuation.
    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    """
    # Use a non-optional group: the original "(\W+)?" pattern can match the
    # empty string, which re.split() rejects on Python 3.7+.
    return [x.strip() for x in re.split(r"(\W+)", sent) if x.strip()]


def parse_stories(lines, only_supporting=False):
    """
    Parse stories provided in the bAbI tasks format.
    If only_supporting is true, only the sentences that support the answer are kept.
    """
    data = []
    story = []
    for line in lines:
        line = str.lower(line)
        nid, line = line.split(" ", 1)
        nid = int(nid)
        if nid == 1:
            story = []
        if "\t" in line:  # question
            q, a, supporting = line.split("\t")
            q = tokenize(q)
            # a = tokenize(a)
            # answer is one vocab word even if it's actually multiple words
            a = [a]
            substory = None

            # remove question marks
            if q[-1] == "?":
                q = q[:-1]

            if only_supporting:
                # Only select the related substory
                supporting = map(int, supporting.split())
                substory = [story[i - 1] for i in supporting]
            else:
                # Provide all the substories
                substory = [x for x in story if x]

            data.append((substory, q, a))
            story.append("")
        else:  # regular sentence
            # remove periods
            sent = tokenize(line)
            if sent[-1] == ".":
                sent = sent[:-1]
            story.append(sent)
    return data


def get_stories(f, only_supporting=False):
    """
    Given a file name, read the file, retrieve the stories,
    and then convert the sentences into a single story.
    """
    with open(f) as f:
        return parse_stories(f.readlines(), only_supporting=only_supporting)

def vectorize_data(data, word_idx, sentence_size, memory_size):
    """
    Vectorize stories and queries.
    If a sentence length < sentence_size, the sentence will be padded with 0's.
    If a story length < memory_size, the story will be padded with empty memories.
    Empty memories are 1-D arrays of length sentence_size filled with 0's.
    The answer array is returned as a one-hot encoding.
    """
    S, Q, A = [], [], []
    for story, query, answer in data:
        ss = []
        for i, sentence in enumerate(story, 1):
            ls = max(0, sentence_size - len(sentence))
            ss.append([word_idx[w] for w in sentence] + [0] * ls)

        # take only the most recent sentences that fit in memory
        ss = ss[::-1][:memory_size][::-1]

        # Make the last word of each sentence the time 'word', which
        # corresponds to a row of the embedding lookup table
        for i in range(len(ss)):
            ss[i][-1] = len(word_idx) - memory_size - i + len(ss)

        # pad to memory_size
        lm = max(0, memory_size - len(ss))
        for _ in range(lm):
            ss.append([0] * sentence_size)

        lq = max(0, sentence_size - len(query))
        q = [word_idx[w] for w in query] + [0] * lq

        y = np.zeros(len(word_idx) + 1)  # 0 is reserved for the nil word
        for a in answer:
            y[word_idx[a]] = 1

        S.append(ss)
        Q.append(q)
        A.append(y)
    return np.array(S), np.array(Q), np.array(A)
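

# A tiny self-contained demo (the sample lines are illustrative, not from
# the repo) of the bAbI line format handled by parse_stories(): numbered
# statements, and tab-separated question / answer / supporting-fact lines.
if __name__ == "__main__":
    sample = [
        "1 Mary moved to the bathroom.\n",
        "2 John went to the hallway.\n",
        "3 Where is Mary?\tbathroom\t1\n",
    ]
    for story, query, answer in parse_stories(sample):
        print(story, query, answer)
    # -> [['mary', 'moved', 'to', 'the', 'bathroom'],
    #     ['john', 'went', 'to', 'the', 'hallway']] ['where', 'is', 'mary'] ['bathroom']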

--------------------------------------------------------------------------------
/memn2n/trainer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader

from dataset import bAbIDataset
from model import MemN2N


class Trainer():
    def __init__(self, config):
        self.train_data = bAbIDataset(config.dataset_dir, config.task)
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=config.batch_size,
                                       num_workers=1,
                                       shuffle=True)

        self.test_data = bAbIDataset(config.dataset_dir, config.task, train=False)
        self.test_loader = DataLoader(self.test_data,
                                      batch_size=config.batch_size,
                                      num_workers=1,
                                      shuffle=False)

        settings = {
            "use_cuda": config.cuda,
            "num_vocab": self.train_data.num_vocab,
            "embedding_dim": 20,
            "sentence_size": self.train_data.sentence_size,
            "max_hops": config.max_hops
        }

        print("Longest sentence length", self.train_data.sentence_size)
        print("Longest story length", self.train_data.max_story_size)
        print("Average story length", self.train_data.mean_story_size)
        print("Number of vocab", self.train_data.num_vocab)

        self.mem_n2n = MemN2N(settings)
        self.ce_fn = nn.CrossEntropyLoss(size_average=False)
        self.opt = torch.optim.SGD(self.mem_n2n.parameters(), lr=config.lr)
        print(self.mem_n2n)

        if config.cuda:
            self.ce_fn = self.ce_fn.cuda()
            self.mem_n2n = self.mem_n2n.cuda()

        self.start_epoch = 0
        self.config = config

    def fit(self):
        config = self.config
        for epoch in range(self.start_epoch, config.max_epochs):
            loss = self._train_single_epoch(epoch)
            lr = self._decay_learning_rate(self.opt, epoch)

            if (epoch + 1) % 10 == 0:
                train_acc = self.evaluate("train")
                test_acc = self.evaluate("test")
                print(epoch + 1, loss, train_acc, test_acc)

    def load(self, directory):
        pass

    def evaluate(self, data="test"):
        correct = 0
        loader = self.train_loader if data == "train" else self.test_loader
        for step, (story, query, answer) in enumerate(loader):
            story = Variable(story)
            query = Variable(query)
            answer = Variable(answer)

            if self.config.cuda:
                story = story.cuda()
                query = query.cuda()
                answer = answer.cuda()

            pred_prob = self.mem_n2n(story, query)[1]
            pred = pred_prob.data.max(1)[1]  # max() returns (values, indices)
            correct += pred.eq(answer.data).cpu().sum()

        acc = correct / len(loader.dataset)
        return acc

    def _train_single_epoch(self, epoch):
        config = self.config
        for step, (story, query, answer) in enumerate(self.train_loader):
            story = Variable(story)
            query = Variable(query)
            answer = Variable(answer)

            if config.cuda:
                story = story.cuda()
                query = query.cuda()
                answer = answer.cuda()

            self.opt.zero_grad()
            loss = self.ce_fn(self.mem_n2n(story, query)[0], answer)
            loss.backward()

            self._gradient_noise_and_clip(self.mem_n2n.parameters(),
                                          noise_stddev=1e-3, max_clip=config.max_clip)
            self.opt.step()

        return loss.data[0]

    def _gradient_noise_and_clip(self, parameters,
                                 noise_stddev=1e-3, max_clip=40.0):
        parameters = list(filter(lambda p: p.grad is not None, parameters))
        nn.utils.clip_grad_norm(parameters, max_clip)

        # add Gaussian noise to the clipped gradients
        for p in parameters:
            noise = torch.randn(p.size()) * noise_stddev
            if self.config.cuda:
                noise = noise.cuda()
            p.grad.data.add_(noise)

    def _decay_learning_rate(self, opt, epoch):
        decay_interval = self.config.decay_interval
        decay_ratio = self.config.decay_ratio

        # step decay: lr = base_lr * ratio ** (epoch // interval)
        decay_count = max(0, epoch // decay_interval)
        lr = self.config.lr * (decay_ratio ** decay_count)
        for param_group in opt.param_groups:
            param_group["lr"] = lr

        return lr


# A minimal sketch of driving Trainer directly, without the argparse CLI in
# train.py; the Namespace fields simply mirror train.py's defaults.
if __name__ == "__main__":
    from argparse import Namespace

    config = Namespace(cuda=False, dataset_dir="bAbI/tasks_1-20_v1-2/en/",
                       task=1, max_hops=3, batch_size=32, max_epochs=100,
                       lr=0.01, decay_interval=25, decay_ratio=0.5,
                       max_clip=40.0)
    Trainer(config).fit()
--------------------------------------------------------------------------------