├── .gitignore
├── README.md
├── babi_loader.py
├── babi_main.py
├── fetch_data.sh
└── pretrained_models
    ├── task10_acc1.0.pth
    ├── task11_acc1.0.pth
    ├── task12_acc1.0.pth
    ├── task13_acc1.0.pth
    ├── task14_acc0.9900000095367432.pth
    ├── task15_acc1.0.pth
    ├── task16_acc0.5169999986886978.pth
    ├── task17_acc0.8649999916553497.pth
    ├── task18_acc0.9790000081062317.pth
    ├── task19_acc0.997000002861023.pth
    ├── task1_acc1.0.pth
    ├── task20_acc1.0.pth
    ├── task2_acc0.9680000007152557.pth
    ├── task3_acc0.8919999957084656.pth
    ├── task4_acc1.0.pth
    ├── task5_acc0.9950000047683716.pth
    ├── task6_acc1.0.pth
    ├── task7_acc0.978000009059906.pth
    ├── task8_acc1.0.pth
    └── task9_acc1.0.pth

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
dataset/
data/
models/
__pycache__/
log.txt

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Dynamic-memory-networks-plus-Pytorch

[DMN+](https://arxiv.org/abs/1603.01417) implementation in PyTorch for question answering on the bAbI 10k dataset.

## Contents
| file | description |
| --- | --- |
| `babi_loader.py` | declaration of the bAbI PyTorch Dataset class |
| `babi_main.py` | contains the DMN+ model and training code |
| `fetch_data.sh` | shell script to fetch the bAbI tasks (from [DMNs in Theano](https://github.com/YerevaNN/Dynamic-memory-networks-in-Theano)) |

## Usage
Install [PyTorch v0.1.12](http://pytorch.org/) and [Python 3.6.x](https://www.python.org/downloads/) (the code uses [literal string interpolation](https://www.python.org/dev/peps/pep-0498/), a Python 3.6 feature).

Run the included shell script to fetch the data:

    chmod +x fetch_data.sh
    ./fetch_data.sh

Run the main Python script:

    python babi_main.py

## Benchmarks

The lower accuracies compared to Xiong et al.'s may be due to a different weight decay setting or to the model's instability:

> On some tasks, the accuracy was not stable across multiple runs. This was particularly problematic on QA3, QA17, and QA18. To solve this, we repeated training 10 times using random initializations and evaluated the model that achieved the lowest validation set loss.

You can find pretrained models [here](https://github.com/dandelin/Dynamic-memory-networks-plus-Pytorch/tree/master/pretrained_models).

| Task ID | This Repo | Xiong et al. |
| :---: | :---: | :---: |
| 1 | 100% | 100% |
| 2 | 96.8% | 99.7% |
| 3 | 89.2% | 98.9% |
| 4 | 100% | 100% |
| 5 | 99.5% | 99.5% |
| 6 | 100% | 100% |
| 7 | 97.8% | 97.6% |
| 8 | 100% | 100% |
| 9 | 100% | 100% |
| 10 | 100% | 100% |
| 11 | 100% | 100% |
| 12 | 100% | 100% |
| 13 | 100% | 100% |
| 14 | 99% | 99.8% |
| 15 | 100% | 100% |
| 16 | 51.6% | 54.7% |
| 17 | 86.4% | 95.8% |
| 18 | 97.9% | 97.9% |
| 19 | 99.7% | 100% |
| 20 | 100% | 100% |
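To evaluate one of the pretrained checkpoints, a minimal sketch (assumes a CUDA device; `BabiDataset` rebuilds the vocabulary from the task's data files, so it must match the vocabulary the checkpoint was trained with):

    import torch
    from babi_loader import BabiDataset
    from babi_main import DMNPlus

    dset = BabiDataset(1)
    model = DMNPlus(80, len(dset.QA.VOCAB), num_hop=3, qa=dset.QA)
    model.cuda()
    model.load_state_dict(torch.load('pretrained_models/task1_acc1.0.pth'))
    model.eval()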
--------------------------------------------------------------------------------
/babi_loader.py:
--------------------------------------------------------------------------------
from glob import glob
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
import re
import numpy as np

class adict(dict):
    def __init__(self, *av, **kav):
        dict.__init__(self, *av, **kav)
        self.__dict__ = self

def pad_collate(batch):
    max_context_sen_len = float('-inf')
    max_context_len = float('-inf')
    max_question_len = float('-inf')
    for elem in batch:
        context, question, _ = elem
        max_context_len = max_context_len if max_context_len > len(context) else len(context)
        max_question_len = max_question_len if max_question_len > len(question) else len(question)
        for sen in context:
            max_context_sen_len = max_context_sen_len if max_context_sen_len > len(sen) else len(sen)
    max_context_len = min(max_context_len, 70)
    for i, elem in enumerate(batch):
        _context, question, answer = elem
        _context = _context[-max_context_len:]
        context = np.zeros((max_context_len, max_context_sen_len))
        for j, sen in enumerate(_context):
            context[j] = np.pad(sen, (0, max_context_sen_len - len(sen)), 'constant', constant_values=0)
        question = np.pad(question, (0, max_question_len - len(question)), 'constant', constant_values=0)
        batch[i] = (context, question, answer)
    return default_collate(batch)
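# Example (added for illustration): for a batch of two samples whose contexts
# contain 3 and 5 sentences with at most 6 tokens each, pad_collate zero-pads
# every sentence to 6 tokens and every context to 5 sentences, so the collated
# context tensor has shape (2, 5, 6). Contexts longer than 70 sentences keep
# only their last 70 sentences.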
class BabiDataset(Dataset):
    def __init__(self, task_id, mode='train'):
        self.vocab_path = 'dataset/babi{}_vocab.pkl'.format(task_id)
        self.mode = mode
        raw_train, raw_test = get_raw_babi(task_id)
        self.QA = adict()
        self.QA.VOCAB = {'<PAD>': 0, '<EOS>': 1}
        self.QA.IVOCAB = {0: '<PAD>', 1: '<EOS>'}
        self.train = self.get_indexed_qa(raw_train)
        self.valid = [self.train[i][int(-len(self.train[i]) / 10):] for i in range(3)]
        self.train = [self.train[i][:int(9 * len(self.train[i]) / 10)] for i in range(3)]
        self.test = self.get_indexed_qa(raw_test)

    def set_mode(self, mode):
        self.mode = mode

    def __len__(self):
        if self.mode == 'train':
            return len(self.train[0])
        elif self.mode == 'valid':
            return len(self.valid[0])
        elif self.mode == 'test':
            return len(self.test[0])

    def __getitem__(self, index):
        if self.mode == 'train':
            contexts, questions, answers = self.train
        elif self.mode == 'valid':
            contexts, questions, answers = self.valid
        elif self.mode == 'test':
            contexts, questions, answers = self.test
        return contexts[index], questions[index], answers[index]

    def get_indexed_qa(self, raw_babi):
        unindexed = get_unindexed_qa(raw_babi)
        questions = []
        contexts = []
        answers = []
        for qa in unindexed:
            context = [c.lower().split() + ['<EOS>'] for c in qa['C']]

            for con in context:
                for token in con:
                    self.build_vocab(token)
            context = [[self.QA.VOCAB[token] for token in sentence] for sentence in context]
            question = qa['Q'].lower().split() + ['<EOS>']

            for token in question:
                self.build_vocab(token)
            question = [self.QA.VOCAB[token] for token in question]

            self.build_vocab(qa['A'].lower())
            answer = self.QA.VOCAB[qa['A'].lower()]

            contexts.append(context)
            questions.append(question)
            answers.append(answer)
        return (contexts, questions, answers)

    def build_vocab(self, token):
        if token not in self.QA.VOCAB:
            next_index = len(self.QA.VOCAB)
            self.QA.VOCAB[token] = next_index
            self.QA.IVOCAB[next_index] = token


def get_raw_babi(taskid):
    paths = glob('data/en-10k/qa{}_*'.format(taskid))
    for path in paths:
        if 'train' in path:
            with open(path, 'r') as fp:
                train = fp.read()
        elif 'test' in path:
            with open(path, 'r') as fp:
                test = fp.read()
    return train, test

def build_vocab(raw_babi):
    lowered = raw_babi.lower()
    tokens = re.findall('[a-zA-Z]+', lowered)
    types = set(tokens)
    return types

# adapted from https://github.com/YerevaNN/Dynamic-memory-networks-in-Theano/
def get_unindexed_qa(raw_babi):
    tasks = []
    task = None
    babi = raw_babi.strip().split('\n')
    for i, line in enumerate(babi):
        id = int(line[0:line.find(' ')])
        if id == 1:
            task = {"C": "", "Q": "", "A": "", "S": ""}
            counter = 0
            id_map = {}

        line = line.strip()
        line = line.replace('.', ' . ')
        line = line[line.find(' ') + 1:]
        # if not a question
        if line.find('?') == -1:
            task["C"] += line + '<line>'
            id_map[id] = counter
            counter += 1
        else:
            idx = line.find('?')
            tmp = line[idx + 1:].split('\t')
            task["Q"] = line[:idx]
            task["A"] = tmp[1].strip()
            task["S"] = []  # supporting facts
            for num in tmp[2].split():
                task["S"].append(id_map[int(num.strip())])
            tc = task.copy()
            tc['C'] = tc['C'].split('<line>')[:-1]
            tasks.append(tc)
    return tasks
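# Example (added for illustration): given the raw bAbI lines (answer and
# supporting-fact fields are tab-separated)
#   1 Mary moved to the bathroom.
#   2 John went to the hallway.
#   3 Where is Mary?    bathroom    1
# get_unindexed_qa yields one entry of roughly the form
#   {'C': ['Mary moved to the bathroom . ', 'John went to the hallway . '],
#    'Q': 'Where is Mary', 'A': 'bathroom', 'S': [0]}
# where 'S' holds the zero-based indices of the supporting facts.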
if __name__ == '__main__':
    dset_train = BabiDataset(20)
    train_loader = DataLoader(dset_train, batch_size=2, shuffle=True, collate_fn=pad_collate)
    for batch_idx, data in enumerate(train_loader):
        contexts, questions, answers = data
        break

--------------------------------------------------------------------------------
/babi_main.py:
--------------------------------------------------------------------------------
from babi_loader import BabiDataset, pad_collate
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.utils.data import DataLoader

def position_encoding(embedded_sentence):
    '''
    embedded_sentence.size() -> (#batch, #sentence, #token, #embedding)
    l.size() -> (#token, #embedding)
    output.size() -> (#batch, #sentence, #embedding)
    '''
    _, _, slen, elen = embedded_sentence.size()

    l = [[(1 - s/(slen-1)) - (e/(elen-1)) * (1 - 2*s/(slen-1)) for e in range(elen)] for s in range(slen)]
    l = torch.FloatTensor(l)
    l = l.unsqueeze(0)  # for #batch
    l = l.unsqueeze(1)  # for #sentence
    l = l.expand_as(embedded_sentence)
    weighted = embedded_sentence * Variable(l.cuda())
    return torch.sum(weighted, dim=2).squeeze(2)  # sum over tokens
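# Note (added for reference): the weight given above to embedding dimension e of
# the token at position s within a sentence is
#   l[s][e] = (1 - s/(S-1)) - (e/(E-1)) * (1 - 2*s/(S-1)),
# with S tokens per sentence and E embedding dimensions, which is the position
# encoding from Xiong et al. (2016); the weighted sum over tokens yields an
# order-aware sentence vector rather than a plain bag of words.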
class AttentionGRUCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(AttentionGRUCell, self).__init__()
        self.hidden_size = hidden_size
        self.Wr = nn.Linear(input_size, hidden_size)
        init.xavier_normal(self.Wr.state_dict()['weight'])
        self.Ur = nn.Linear(hidden_size, hidden_size)
        init.xavier_normal(self.Ur.state_dict()['weight'])
        self.W = nn.Linear(input_size, hidden_size)
        init.xavier_normal(self.W.state_dict()['weight'])
        self.U = nn.Linear(hidden_size, hidden_size)
        init.xavier_normal(self.U.state_dict()['weight'])

    def forward(self, fact, C, g):
        '''
        fact.size() -> (#batch, #hidden = #embedding)
        C.size() -> (#hidden, ) -> (#batch, #hidden = #embedding)
        r.size() -> (#batch, #hidden = #embedding)
        h_tilda.size() -> (#batch, #hidden = #embedding)
        g.size() -> (#batch, )
        '''
        r = F.sigmoid(self.Wr(fact) + self.Ur(C))
        h_tilda = F.tanh(self.W(fact) + r * self.U(C))
        g = g.unsqueeze(1).expand_as(h_tilda)
        h = g * h_tilda + (1 - g) * C
        return h

class AttentionGRU(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(AttentionGRU, self).__init__()
        self.hidden_size = hidden_size
        self.AGRUCell = AttentionGRUCell(input_size, hidden_size)

    def forward(self, facts, G):
        '''
        facts.size() -> (#batch, #sentence, #hidden = #embedding)
        fact.size() -> (#batch, #hidden = #embedding)
        G.size() -> (#batch, #sentence)
        g.size() -> (#batch, )
        C.size() -> (#batch, #hidden)
        '''
        batch_num, sen_num, embedding_size = facts.size()
        C = Variable(torch.zeros(self.hidden_size)).cuda()
        for sid in range(sen_num):
            fact = facts[:, sid, :]
            g = G[:, sid]
            if sid == 0:
                C = C.unsqueeze(0).expand_as(fact)
            C = self.AGRUCell(fact, C, g)
        return C

class EpisodicMemory(nn.Module):
    def __init__(self, hidden_size):
        super(EpisodicMemory, self).__init__()
        self.AGRU = AttentionGRU(hidden_size, hidden_size)
        self.z1 = nn.Linear(4 * hidden_size, hidden_size)
        self.z2 = nn.Linear(hidden_size, 1)
        self.next_mem = nn.Linear(3 * hidden_size, hidden_size)
        init.xavier_normal(self.z1.state_dict()['weight'])
        init.xavier_normal(self.z2.state_dict()['weight'])
        init.xavier_normal(self.next_mem.state_dict()['weight'])

    def make_interaction(self, facts, questions, prevM):
        '''
        facts.size() -> (#batch, #sentence, #hidden = #embedding)
        questions.size() -> (#batch, 1, #hidden)
        prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding)
        z.size() -> (#batch, #sentence, 4 x #embedding)
        G.size() -> (#batch, #sentence)
        '''
        batch_num, sen_num, embedding_size = facts.size()
        questions = questions.expand_as(facts)
        prevM = prevM.expand_as(facts)

        z = torch.cat([
            facts * questions,
            facts * prevM,
            torch.abs(facts - questions),
            torch.abs(facts - prevM)
        ], dim=2)

        z = z.view(-1, 4 * embedding_size)

        G = F.tanh(self.z1(z))
        G = self.z2(G)
        G = G.view(batch_num, -1)
        G = F.softmax(G)

        return G

    def forward(self, facts, questions, prevM):
        '''
        facts.size() -> (#batch, #sentence, #hidden = #embedding)
        questions.size() -> (#batch, #sentence = 1, #hidden)
        prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding)
        G.size() -> (#batch, #sentence)
        C.size() -> (#batch, #hidden)
        concat.size() -> (#batch, 3 x #embedding)
        '''
        G = self.make_interaction(facts, questions, prevM)
        C = self.AGRU(facts, G)
        concat = torch.cat([prevM.squeeze(1), C, questions.squeeze(1)], dim=1)
        next_mem = F.relu(self.next_mem(concat))
        next_mem = next_mem.unsqueeze(1)
        return next_mem
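# Note (added for reference): each EpisodicMemory pass computes soft attention
# gates G over the facts from (facts, question, previous memory), summarizes the
# facts into an episode C with the attention GRU, and updates the memory as
#   m_t = ReLU(W [m_{t-1} ; C ; q] + b)
# as in Xiong et al. (2016); here the same next_mem weights are reused on every hop.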
class QuestionModule(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(QuestionModule, self).__init__()
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, questions, word_embedding):
        '''
        questions.size() -> (#batch, #token)
        word_embedding() -> (#batch, #token, #embedding)
        gru() -> (1, #batch, #hidden)
        '''
        questions = word_embedding(questions)
        _, questions = self.gru(questions)
        questions = questions.transpose(0, 1)
        return questions

class InputModule(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(InputModule, self).__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, bidirectional=True, batch_first=True)
        for name, param in self.gru.state_dict().items():
            if 'weight' in name:
                init.xavier_normal(param)
        self.dropout = nn.Dropout(0.1)

    def forward(self, contexts, word_embedding):
        '''
        contexts.size() -> (#batch, #sentence, #token)
        word_embedding() -> (#batch, #sentence x #token, #embedding)
        position_encoding() -> (#batch, #sentence, #embedding)
        facts.size() -> (#batch, #sentence, #hidden = #embedding)
        '''
        batch_num, sen_num, token_num = contexts.size()

        contexts = contexts.view(batch_num, -1)
        contexts = word_embedding(contexts)

        contexts = contexts.view(batch_num, sen_num, token_num, -1)
        contexts = position_encoding(contexts)
        contexts = self.dropout(contexts)

        h0 = Variable(torch.zeros(2, batch_num, self.hidden_size).cuda())
        facts, hdn = self.gru(contexts, h0)
        # sum the forward and backward GRU outputs
        facts = facts[:, :, :self.hidden_size] + facts[:, :, self.hidden_size:]
        return facts

class AnswerModule(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(AnswerModule, self).__init__()
        self.z = nn.Linear(2 * hidden_size, vocab_size)
        init.xavier_normal(self.z.state_dict()['weight'])
        self.dropout = nn.Dropout(0.1)

    def forward(self, M, questions):
        M = self.dropout(M)
        concat = torch.cat([M, questions], dim=2).squeeze(1)
        z = self.z(concat)
        return z

class DMNPlus(nn.Module):
    def __init__(self, hidden_size, vocab_size, num_hop=3, qa=None):
        super(DMNPlus, self).__init__()
        self.num_hop = num_hop
        self.qa = qa
        self.word_embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0, sparse=True).cuda()
        init.uniform(self.word_embedding.state_dict()['weight'], a=-(3**0.5), b=3**0.5)
        self.criterion = nn.CrossEntropyLoss(size_average=False)

        self.input_module = InputModule(vocab_size, hidden_size)
        self.question_module = QuestionModule(vocab_size, hidden_size)
        self.memory = EpisodicMemory(hidden_size)
        self.answer_module = AnswerModule(vocab_size, hidden_size)

    def forward(self, contexts, questions):
        '''
        contexts.size() -> (#batch, #sentence, #token) -> (#batch, #sentence, #hidden = #embedding)
        questions.size() -> (#batch, #token) -> (#batch, 1, #hidden)
        '''
        facts = self.input_module(contexts, self.word_embedding)
        questions = self.question_module(questions, self.word_embedding)
        M = questions
        for hop in range(self.num_hop):
            M = self.memory(facts, questions, M)
        preds = self.answer_module(M, questions)
        return preds

    def interpret_indexed_tensor(self, var):
        if len(var.size()) == 3:
            # var -> n x #sen x #token
            for n, sentences in enumerate(var):
                for i, sentence in enumerate(sentences):
                    s = ' '.join([self.qa.IVOCAB[elem.data[0]] for elem in sentence])
                    print(f'{n}th of batch, {i}th sentence, {s}')
        elif len(var.size()) == 2:
            # var -> n x #token
            for n, sentence in enumerate(var):
                s = ' '.join([self.qa.IVOCAB[elem.data[0]] for elem in sentence])
                print(f'{n}th of batch, {s}')
        elif len(var.size()) == 1:
            # var -> n (one token per batch)
            for n, token in enumerate(var):
                s = self.qa.IVOCAB[token.data[0]]
                print(f'{n}th of batch, {s}')

    def get_loss(self, contexts, questions, targets):
        output = self.forward(contexts, questions)
        loss = self.criterion(output, targets)
        reg_loss = 0
        for param in self.parameters():
            reg_loss += 0.001 * torch.sum(param * param)
        preds = F.softmax(output)
        _, pred_ids = torch.max(preds, dim=1)
        corrects = (pred_ids.data == targets.data)
        acc = torch.mean(corrects.float())
        return loss + reg_loss, acc
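# Illustrative smoke test (added; not part of the training script). Assumes a
# CUDA device and made-up sizes:
#   model = DMNPlus(hidden_size=80, vocab_size=40, num_hop=3)
#   model.cuda()
#   contexts = Variable(torch.ones(2, 5, 7).long().cuda())  # (#batch, #sentence, #token)
#   questions = Variable(torch.ones(2, 4).long().cuda())    # (#batch, #token)
#   preds = model(contexts, questions)                      # -> (#batch, vocab_size)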
if __name__ == '__main__':
    for run in range(10):
        for task_id in range(1, 21):
            dset = BabiDataset(task_id)
            vocab_size = len(dset.QA.VOCAB)
            hidden_size = 80

            model = DMNPlus(hidden_size, vocab_size, num_hop=3, qa=dset.QA)
            model.cuda()
            early_stopping_cnt = 0
            early_stopping_flag = False
            best_acc = 0
            optim = torch.optim.Adam(model.parameters())

            for epoch in range(256):
                dset.set_mode('train')
                train_loader = DataLoader(
                    dset, batch_size=100, shuffle=True, collate_fn=pad_collate
                )

                model.train()
                if not early_stopping_flag:
                    total_acc = 0
                    cnt = 0
                    for batch_idx, data in enumerate(train_loader):
                        optim.zero_grad()
                        contexts, questions, answers = data
                        batch_size = contexts.size()[0]
                        contexts = Variable(contexts.long().cuda())
                        questions = Variable(questions.long().cuda())
                        answers = Variable(answers.cuda())

                        loss, acc = model.get_loss(contexts, questions, answers)
                        loss.backward()
                        total_acc += acc * batch_size
                        cnt += batch_size

                        if batch_idx % 20 == 0:
                            print(f'[Task {task_id}, Epoch {epoch}] [Training] loss : {loss.data[0]: {10}.{8}}, acc : {total_acc / cnt: {5}.{4}}, batch_idx : {batch_idx}')
                        optim.step()

                    dset.set_mode('valid')
                    valid_loader = DataLoader(
                        dset, batch_size=100, shuffle=False, collate_fn=pad_collate
                    )

                    model.eval()
                    total_acc = 0
                    cnt = 0
                    for batch_idx, data in enumerate(valid_loader):
                        contexts, questions, answers = data
                        batch_size = contexts.size()[0]
                        contexts = Variable(contexts.long().cuda())
                        questions = Variable(questions.long().cuda())
                        answers = Variable(answers.cuda())

                        _, acc = model.get_loss(contexts, questions, answers)
                        total_acc += acc * batch_size
                        cnt += batch_size

                    total_acc = total_acc / cnt
                    if total_acc > best_acc:
                        best_acc = total_acc
                        best_state = model.state_dict()
                        early_stopping_cnt = 0
                    else:
                        early_stopping_cnt += 1
                        if early_stopping_cnt > 20:
                            early_stopping_flag = True

                    print(f'[Run {run}, Task {task_id}, Epoch {epoch}] [Validate] Accuracy : {total_acc: {5}.{4}}')
                    with open('log.txt', 'a') as fp:
                        fp.write(f'[Run {run}, Task {task_id}, Epoch {epoch}] [Validate] Accuracy : {total_acc: {5}.{4}}' + '\n')
                    if total_acc == 1.0:
                        break
                else:
                    print(f'[Run {run}, Task {task_id}] Early Stopping at Epoch {epoch}, Valid Accuracy : {best_acc: {5}.{4}}')
                    break

            dset.set_mode('test')
            test_loader = DataLoader(
                dset, batch_size=100, shuffle=False, collate_fn=pad_collate
            )
            test_acc = 0
            cnt = 0

            # evaluate the best validation checkpoint on the test set
            model.load_state_dict(best_state)
            for batch_idx, data in enumerate(test_loader):
                contexts, questions, answers = data
                batch_size = contexts.size()[0]
                contexts = Variable(contexts.long().cuda())
                questions = Variable(questions.long().cuda())
                answers = Variable(answers.cuda())

                _, acc = model.get_loss(contexts, questions, answers)
                test_acc += acc * batch_size
                cnt += batch_size
            print(f'[Run {run}, Task {task_id}, Epoch {epoch}] [Test] Accuracy : {test_acc / cnt: {5}.{4}}')
            os.makedirs('models', exist_ok=True)
            with open(f'models/task{task_id}_epoch{epoch}_run{run}_acc{test_acc/cnt}.pth', 'wb') as fp:
                torch.save(model.state_dict(), fp)
            with open('log.txt', 'a') as fp:
                fp.write(f'[Run {run}, Task {task_id}, Epoch {epoch}] [Test] Accuracy : {test_acc / cnt: {5}.{4}}' + '\n')
--------------------------------------------------------------------------------
/fetch_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash

url=http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz
fname=`basename $url`

curl -SLO "$url"
tar zxvf "$fname"
mkdir -p data
mv tasks_1-20_v1-2/* data/
rm -r tasks_1-20_v1-2
rm "$fname"

--------------------------------------------------------------------------------
/pretrained_models/task10_acc1.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task10_acc1.0.pth
--------------------------------------------------------------------------------
/pretrained_models/task11_acc1.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task11_acc1.0.pth
--------------------------------------------------------------------------------
/pretrained_models/task12_acc1.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task12_acc1.0.pth
--------------------------------------------------------------------------------
/pretrained_models/task13_acc1.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task13_acc1.0.pth
--------------------------------------------------------------------------------
/pretrained_models/task14_acc0.9900000095367432.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task14_acc0.9900000095367432.pth
--------------------------------------------------------------------------------
/pretrained_models/task15_acc1.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task15_acc1.0.pth
--------------------------------------------------------------------------------
/pretrained_models/task16_acc0.5169999986886978.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task16_acc0.5169999986886978.pth
--------------------------------------------------------------------------------
/pretrained_models/task17_acc0.8649999916553497.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task17_acc0.8649999916553497.pth
--------------------------------------------------------------------------------
/pretrained_models/task18_acc0.9790000081062317.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task18_acc0.9790000081062317.pth
--------------------------------------------------------------------------------
/pretrained_models/task19_acc0.997000002861023.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task19_acc0.997000002861023.pth
--------------------------------------------------------------------------------
/pretrained_models/task1_acc1.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task1_acc1.0.pth
--------------------------------------------------------------------------------
/pretrained_models/task20_acc1.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task20_acc1.0.pth
--------------------------------------------------------------------------------
/pretrained_models/task2_acc0.9680000007152557.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task2_acc0.9680000007152557.pth
--------------------------------------------------------------------------------
/pretrained_models/task3_acc0.8919999957084656.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task3_acc0.8919999957084656.pth
--------------------------------------------------------------------------------
/pretrained_models/task4_acc1.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task4_acc1.0.pth
--------------------------------------------------------------------------------
/pretrained_models/task5_acc0.9950000047683716.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task5_acc0.9950000047683716.pth
--------------------------------------------------------------------------------
/pretrained_models/task6_acc1.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task6_acc1.0.pth
--------------------------------------------------------------------------------
/pretrained_models/task7_acc0.978000009059906.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task7_acc0.978000009059906.pth
--------------------------------------------------------------------------------
/pretrained_models/task8_acc1.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task8_acc1.0.pth
--------------------------------------------------------------------------------
/pretrained_models/task9_acc1.0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dandelin/Dynamic-memory-networks-plus-Pytorch/ad49955f907c03aade2f6c8ed13370ce7288d5a7/pretrained_models/task9_acc1.0.pth
--------------------------------------------------------------------------------