├── .gitignore ├── LICENSE ├── README.org ├── aoareader ├── AoAReader.py ├── Constants.py ├── Dataset.py ├── Dict.py └── __init__.py ├── preprocess.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2017 Kevin Ling 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /README.org: -------------------------------------------------------------------------------- 1 | #+TITLE: Attention-over-Attention Model for Reading Comprehension 2 | 3 | This is an implementation of Attention-over-Attention Model with PyTorch. This 4 | model was proposed by Cui et al. ([[https://arxiv.org/pdf/1607.04423.pdf][paper]]). 5 | * Prerequisite 6 | + PyTorch with cuda 7 | + Python 3.6+ 8 | + NLTK (with punkt data) 9 | * Usage 10 | This implementation uses facebook's children's book test data. 11 | ** Preprocessing 12 | Make sure the data files (train.txt, dev.txt, test.txt) are present in the =data= directory. 13 | 14 | To preprocess the data: 15 | #+BEGIN_SRC bash 16 | python preprocess.py 17 | #+END_SRC 18 | This will generate the dictonary(=dict.pt=) from all words appeared in the dataset and 19 | vectorize all data (=train.txt.pt=, =dev.txt.pt=, =test.txt.pt=). 20 | ** Train the model 21 | Below is an example of training a model, set the parameters as you like. 22 | #+BEGIN_SRC bash 23 | python train.py -traindata data/train.txt.pt -validdata data/test.txt.pt -dict data/dict.pt \ 24 | -save_model model1 -gru_size 384 -embed_size 384 -batch_size 64 -dropout 0.1 \ 25 | -epochs 13 -learning_rate 0.001 -weigth_decay 0.0001 -gpu 1 -log_interval 50 26 | #+END_SRC 27 | After each epoch, a checkpoint will be saved, to resume a training process 28 | from checkpoint: 29 | #+BEGIN_SRC bash 30 | python train.py -train_from xxx_model_xxx_epoch_x.pt 31 | #+END_SRC 32 | ** Testing 33 | #+BEGIN_SRC bash 34 | python test.py -testdata data/test.txt.pt -dict data/dict.pt -out result.txt -model models/xx_checkpoint_epochxx.pt 35 | #+END_SRC 36 | * License 37 | [[file:LICENSE][MIT License]] 38 | -------------------------------------------------------------------------------- /aoareader/AoAReader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 5 | from torch.nn.utils.rnn import pack_padded_sequence as pack 6 | import torch.nn.functional as F 7 | import torch.nn.init as weigth_init 8 | from aoareader import Constants 9 | 10 | 11 | def sort_batch(data, seq_len): 12 | sorted_seq_len, sorted_idx = torch.sort(seq_len, dim=0, descending=True) 13 | sorted_data = data[sorted_idx.data] 14 | _, reverse_idx = torch.sort(sorted_idx, dim=0, descending=False) 15 | return sorted_data, sorted_seq_len.cuda(), reverse_idx.cuda() 16 | 17 | 18 | 19 | 20 | def softmax_mask(input, mask, axis=1, epsilon=1e-12): 21 | shift, _ = torch.max(input, axis, keepdim=True) 22 | shift = shift.expand_as(input).cuda() 23 | 24 | target_exp = torch.exp(input - shift) * mask 25 | 26 | normalize = torch.sum(target_exp, axis, keepdim=True).expand_as(target_exp) 27 | softm = target_exp / (normalize + epsilon) 28 | 29 | return softm.cuda() 30 | 31 | 32 | class AoAReader(nn.Module): 33 | 34 | def __init__(self, vocab_dict, dropout_rate, embed_dim, hidden_dim, bidirectional=True): 35 | super(AoAReader, self).__init__() 36 | self.vocab_dict = vocab_dict 37 | self.hidden_dim = hidden_dim 38 | self.embed_dim = embed_dim 39 | self.dropout_rate = dropout_rate 40 | 41 | self.embedding = nn.Embedding(vocab_dict.size(), 42 | self.embed_dim, 43 | padding_idx=Constants.PAD) 44 | self.embedding.weight.data.uniform_(-0.05, 0.05) 45 | 46 | input_size = self.embed_dim 47 | self.gru = nn.GRU(input_size, hidden_size=self.hidden_dim, dropout=dropout_rate, 48 | bidirectional=bidirectional, batch_first=True) 49 | 50 | # try independent gru 51 | #self.query_gru = nn.GRU(input_size, hidden_size=self.hidden_dim, dropout=dropout_rate, 52 | # bidirectional=bidirectional, batch_first=True) 53 | 54 | for weight in self.gru.parameters(): 55 | if len(weight.size()) > 1: 56 | weigth_init.orthogonal(weight.data) 57 | 58 | def forward(self, docs_input, docs_len, doc_mask, 59 | querys_input, querys_len, query_mask, 60 | candidates=None, answers=None): 61 | s_docs, s_docs_len, reverse_docs_idx = sort_batch(docs_input, docs_len) 62 | s_querys, s_querys_len, reverse_querys_idx = sort_batch(querys_input, querys_len) 63 | 64 | docs_embedding = pack(self.embedding(s_docs), list(s_docs_len.data), batch_first=True) 65 | querys_embedding = pack(self.embedding(s_querys), list(s_querys_len.data), batch_first=True) 66 | 67 | # encode 68 | docs_outputs, _ = self.gru(docs_embedding, None) 69 | querys_outputs, _ = self.gru(querys_embedding, None) 70 | 71 | # unpack 72 | docs_outputs, _ = unpack(docs_outputs, batch_first=True) 73 | querys_outputs, _ = unpack(querys_outputs, batch_first=True) 74 | 75 | docs_outputs = docs_outputs[reverse_docs_idx.data] 76 | querys_outputs = querys_outputs[reverse_querys_idx.data] 77 | 78 | 79 | # transpose query for pair-wise dot product 80 | dos = docs_outputs 81 | doc_mask = doc_mask.unsqueeze(2) 82 | qos = torch.transpose(querys_outputs, 1, 2) 83 | query_mask = query_mask.unsqueeze(2) 84 | 85 | # pair-wise matching score 86 | M = torch.bmm(dos, qos) 87 | M_mask = torch.bmm(doc_mask, query_mask.transpose(1, 2)) 88 | # query-document attention 89 | alpha = softmax_mask(M, M_mask, axis=1) 90 | beta = softmax_mask(M, M_mask, axis=2) 91 | 92 | sum_beta = torch.sum(beta, dim=1, keepdim=True) 93 | 94 | docs_len = docs_len.unsqueeze(1).unsqueeze(2).expand_as(sum_beta) 95 | average_beta = sum_beta / docs_len.float() 96 | 97 | 98 | # attended document-level attention 99 | s = torch.bmm(alpha, average_beta.transpose(1, 2)) 100 | # predict the most possible answer from given candidates 101 | pred_answers = None 102 | #pred_locs = None 103 | probs = None 104 | if candidates is not None: 105 | pred_answers = [] 106 | pred_locs = [] 107 | for i, cands in enumerate(candidates): 108 | pb = [] 109 | document = docs_input[i].squeeze() 110 | for j, candidate in enumerate(cands): 111 | pointer = document == candidate.expand_as(document) 112 | pb.append(torch.sum(torch.masked_select(s[i].squeeze(), pointer), keepdim=True)) 113 | pb = torch.cat(pb, dim=0).squeeze() 114 | _ , max_loc = torch.max(pb, 0) 115 | pred_answers.append(cands.index_select(0, max_loc)) 116 | pred_locs.append(max_loc) 117 | pred_answers = torch.cat(pred_answers, dim=0).squeeze() 118 | #pred_locs = torch.cat(pred_locs, dim=0).squeeze() 119 | 120 | if answers is not None: 121 | probs = [] 122 | for i, answer in enumerate(answers): 123 | document = docs_input[i].squeeze() 124 | pointer = document == answer.expand_as(document) 125 | this_prob = torch.sum(torch.masked_select(s[i].squeeze(), pointer)) 126 | probs.append(this_prob) 127 | probs = torch.cat(probs, 0).squeeze() 128 | 129 | return pred_answers, probs 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /aoareader/Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | 4 | PAD_WORD = "" 5 | UNK_WORD = "" -------------------------------------------------------------------------------- /aoareader/Dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import math 4 | import random 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | import aoareader 10 | 11 | 12 | def create_mask(seq_lens): 13 | mask = torch.zeros(len(seq_lens), torch.max(seq_lens)) 14 | for i, seq_len in enumerate(seq_lens): 15 | mask[i][:seq_len] = 1 16 | 17 | return mask.float() 18 | 19 | class Dataset(object): 20 | 21 | def __init__(self, data: dict, batch_size, cuda, volatile=False): 22 | self.documents = data['documents'] 23 | self.querys = data['querys'] 24 | self.candidates = data['candidates'] 25 | self.answers = data.get('answers', None) 26 | 27 | # check if dimensions match 28 | assert len(self.documents) == len(self.querys) == len(self.candidates) 29 | 30 | if self.answers is not None: 31 | assert len(self.querys) == len(self.answers) 32 | 33 | self.cuda = cuda 34 | 35 | self.batch_size = batch_size 36 | self.numBatches = math.ceil(len(self.querys)/batch_size) 37 | self.volatile = volatile 38 | 39 | def _batchify(self, data, align_right=False, include_lengths=False): 40 | lengths = [x.size(0) for x in data] 41 | max_length = max(lengths) 42 | out = data[0].new(len(data), max_length).fill_(aoareader.Constants.PAD) 43 | for i in range(len(data)): 44 | data_length = data[i].size(0) 45 | offset = max_length - data_length if align_right else 0 46 | out[i].narrow(0, offset, data_length).copy_(data[i]) 47 | 48 | if include_lengths: 49 | return out, lengths 50 | else: 51 | return out 52 | 53 | def __getitem__(self, index): 54 | assert index < self.numBatches, "%d > %d" % (index, self.numBatches) 55 | 56 | documents, doc_lengths = self._batchify( 57 | self.documents[index*self.batch_size:(index+1)*self.batch_size], 58 | align_right=False, include_lengths=True) 59 | 60 | querys, q_lengths = self._batchify( 61 | self.querys[index*self.batch_size:(index+1)*self.batch_size], 62 | align_right=False, include_lengths=True) 63 | 64 | candidates = self._batchify( 65 | self.candidates[index*self.batch_size:(index+1)*self.batch_size], 66 | align_right=False, include_lengths=False) 67 | 68 | if self.answers is not None: 69 | answers = torch.LongTensor(self.answers[index*self.batch_size:(index+1)*self.batch_size]) 70 | else: 71 | answers = None 72 | 73 | def wrap(b: torch.LongTensor): 74 | if b is None: 75 | return b 76 | if len(b.size()) > 1: 77 | b = torch.stack(b, 0) 78 | b = b.contiguous() 79 | if self.cuda: 80 | b = b.cuda() 81 | b = Variable(b, volatile=self.volatile, requires_grad=False) 82 | return b 83 | 84 | doc_lengths = torch.LongTensor(doc_lengths) 85 | doc_mask = create_mask(doc_lengths) 86 | q_lengths = torch.LongTensor(q_lengths) 87 | q_mask = create_mask(q_lengths) 88 | 89 | return (wrap(documents), wrap(doc_lengths), wrap(doc_mask)), (wrap(querys), wrap(q_lengths), wrap(q_mask)), wrap(answers), wrap(candidates) 90 | 91 | def __len__(self): 92 | return self.numBatches 93 | 94 | def shuffle(self): 95 | data = list(zip(self.documents, self.querys, self.candidates, self.answers)) 96 | self.documents, self.querys, self.candidates, self.answers = zip(*[data[i] for i in torch.randperm(len(data))]) -------------------------------------------------------------------------------- /aoareader/Dict.py: -------------------------------------------------------------------------------- 1 | from aoareader import Constants 2 | import torch 3 | 4 | class Dict: 5 | 6 | def __init__(self, word2idx): 7 | self.word2idx = word2idx 8 | self.idx2word = {idx: word for word, idx in word2idx.items()} 9 | 10 | def getIdx(self, word): 11 | return self.word2idx.get(word, Constants.UNK) 12 | 13 | def getWord(self, idx): 14 | return self.idx2word.get(idx, Constants.UNK_WORD) 15 | 16 | def convert2idx(self, words): 17 | vec = [self.getIdx(word) for word in words] 18 | 19 | return torch.LongTensor(vec) 20 | 21 | def convert2word(self, idxs): 22 | vec = [self.getWord(idx) for idx in idxs] 23 | return vec 24 | 25 | def size(self): 26 | return len(self.idx2word) 27 | 28 | -------------------------------------------------------------------------------- /aoareader/__init__.py: -------------------------------------------------------------------------------- 1 | import aoareader.Constants 2 | 3 | from aoareader.Dataset import Dataset 4 | from aoareader.AoAReader import AoAReader -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/nschuc/alternating-reader-tf/blob/master/load_data.py, some modifications are made 2 | 3 | import json 4 | import os 5 | import numpy as np 6 | import torch 7 | from functools import reduce 8 | import itertools 9 | import time 10 | 11 | # parallel processing 12 | from joblib import Parallel, delayed 13 | 14 | import aoareader.Constants 15 | from aoareader.Dict import Dict as Vocabulary 16 | 17 | from nltk.tokenize import word_tokenize 18 | 19 | from sys import argv 20 | 21 | data_path = 'data/' 22 | data_filenames = { 23 | 'train': 'train.txt', 24 | 'valid': 'dev.txt', 25 | 'test': 'test.txt' 26 | } 27 | vocab_file = os.path.join(data_path, 'vocab.json') 28 | dict_file = os.path.join(data_path, 'dict.pt') 29 | 30 | def tokenize(sentence): 31 | return [s.strip().lower() for s in word_tokenize(sentence) if s.strip()] 32 | 33 | 34 | def parse_stories(lines, with_answer=True): 35 | stories = [] 36 | story = [] 37 | for line in lines: 38 | line = line.strip() 39 | if not line: 40 | story = [] 41 | else: 42 | _, line = line.split(' ', 1) 43 | if line: 44 | if '\t' in line: # query line 45 | answer = '' 46 | if with_answer: 47 | q, answer, _, candidates = line.split('\t') 48 | answer = answer.lower() 49 | else: 50 | q, _, candidates = line.split('\t') 51 | q = tokenize(q) 52 | 53 | # use the first 10 54 | candidates = [cand.lower() for cand in candidates.split('|')[:10]] 55 | stories.append((story, q, answer, candidates)) 56 | else: 57 | story.append(tokenize(line)) 58 | return stories 59 | 60 | 61 | def get_stories(story_lines, with_answer=True): 62 | stories = parse_stories(story_lines, with_answer=with_answer) 63 | flatten = lambda story: reduce(lambda x, y: x + y, story) 64 | stories = [(flatten(story), q, a, candidates) for story, q, a, candidates in stories] 65 | return stories 66 | 67 | 68 | def vectorize_stories(stories, vocab : Vocabulary): 69 | X = [] 70 | Q = [] 71 | C = [] 72 | A = [] 73 | 74 | for s, q, a, c in stories: 75 | x = vocab.convert2idx(s) 76 | xq = vocab.convert2idx(q) 77 | xc = vocab.convert2idx(c) 78 | X.append(x) 79 | Q.append(xq) 80 | C.append(xc) 81 | A.append(vocab.getIdx(a)) 82 | 83 | X = X 84 | Q = Q 85 | C = C 86 | A = torch.LongTensor(A) 87 | return X, Q, A, C 88 | 89 | 90 | def build_dict(stories): 91 | if os.path.isfile(vocab_file): 92 | with open(vocab_file, "r") as vf: 93 | word2idx = json.load(vf) 94 | else: 95 | 96 | vocab = sorted(set(itertools.chain(*(story + q + [answer] + candidates 97 | for story, q, answer, candidates in stories)))) 98 | vocab_size = len(vocab) + 2 # pad, unk 99 | print('Vocab size:', vocab_size) 100 | word2idx = dict((w, i + 2) for i,w in enumerate(vocab)) 101 | word2idx[aoareader.Constants.UNK_WORD] = 1 102 | word2idx[aoareader.Constants.PAD_WORD] = 0 103 | 104 | with open(vocab_file, "w") as vf: 105 | json.dump(word2idx, vf) 106 | 107 | return Vocabulary(word2idx) 108 | 109 | 110 | def main(): 111 | 112 | print('Preparing process dataset ...') 113 | train_filename = os.path.join(data_path, data_filenames['train']) 114 | valid_filename = os.path.join(data_path, data_filenames['valid']) 115 | test_filename = os.path.join(data_path, data_filenames['test']) 116 | 117 | 118 | with open(train_filename, 'r') as tf, open(valid_filename, 'r') as vf, open(test_filename, 'r') as tef: 119 | tlines = tf.readlines() 120 | vlines = vf.readlines() 121 | telines = tef.readlines() 122 | train_stories, valid_stories, test_stories = Parallel(n_jobs=2)(delayed(get_stories)(story_lines) 123 | for story_lines in [tlines, vlines, telines]) 124 | 125 | 126 | print('Preparing build dictionary ...') 127 | vocab_dict = build_dict(train_stories + valid_stories + test_stories) 128 | 129 | print('Preparing training, validation, testing ...') 130 | train = {} 131 | valid = {} 132 | test = {} 133 | 134 | train_data, valid_data, test_data = Parallel(n_jobs=2)(delayed(vectorize_stories)(stories, vocab_dict) 135 | for stories in [train_stories, valid_stories, test_stories]) 136 | train['documents'], train['querys'], train['answers'], train['candidates'] = train_data 137 | valid['documents'], valid['querys'], valid['answers'], valid['candidates'] = valid_data 138 | test['documents'], test['querys'], test['answers'], test['candidates'] = test_data 139 | 140 | 141 | print('Saving data to \'' + data_path + '\'...') 142 | torch.save(vocab_dict, dict_file) 143 | torch.save(train, train_filename + '.pt') 144 | torch.save(valid, valid_filename + '.pt') 145 | torch.save(test, test_filename + '.pt') 146 | 147 | if __name__ == '__main__': 148 | main() 149 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import aoareader as reader 2 | import torch 3 | import time 4 | import argparse 5 | import os 6 | 7 | from preprocess import get_stories, vectorize_stories 8 | 9 | 10 | parser = argparse.ArgumentParser(description="test.py") 11 | 12 | parser.add_argument('-testdata', default='data/test.txt.pt', 13 | help='Path to the test.txt.pt, test.txt.pt will be used if exists.') 14 | 15 | parser.add_argument('-dict', default="data/dict.pt", 16 | help='Path to the dictionary file, default value: data/dict.pt') 17 | 18 | parser.add_argument('-out', default='data/result.txt', 19 | help='output file name.') 20 | 21 | parser.add_argument('-model', required=True, help='path to the saved model.') 22 | 23 | 24 | testopt = parser.parse_args() 25 | print(testopt) 26 | 27 | 28 | def load_testdata(testfile, vocab_dict, with_answer=True): 29 | if os.path.exists(testfile + '.pt'): 30 | return torch.load(testfile + '.pt') 31 | else: 32 | testd = {} 33 | with open(testfile, 'r') as tf: 34 | tlines = tf.readlines() 35 | test_stories = get_stories(tlines, with_answer=with_answer) 36 | testd['documents'], testd['querys'], testd['answers'], testd['candidates'] = vectorize_stories(test_stories, vocab_dict) 37 | torch.save(testd, testfile + '.pt') 38 | return testd 39 | 40 | def evalulate(model, data, vocab_dict): 41 | 42 | def acc(answers, pred_answers): 43 | num_correct = (answers == pred_answers).sum().squeeze().data[0] 44 | return num_correct 45 | 46 | model.eval() 47 | answers = [] 48 | total_correct = 0 49 | total = 0 50 | for i in range(len(data)): 51 | (batch_docs, batch_docs_len, doc_mask), (batch_querys, batch_querys_len, query_mask), batch_answers , candidates = data[i] 52 | 53 | pred_answers, _ = model(batch_docs, batch_docs_len, doc_mask, 54 | batch_querys, batch_querys_len, query_mask, 55 | candidates=candidates, answers=batch_answers) 56 | 57 | answers.extend(pred_answers.data) 58 | num_correct = acc(batch_answers, pred_answers) 59 | 60 | total_in_minibatch = batch_answers.size(0) 61 | total_correct += num_correct 62 | total += total_in_minibatch 63 | del pred_answers 64 | 65 | print("Evaluating on test set:\nAccurary {:.2%}".format(total_correct / total)) 66 | return vocab_dict.convert2word(answers) 67 | 68 | def main(): 69 | print("Loading dict", testopt.dict) 70 | vocab_dict = torch.load(testopt.dict) 71 | 72 | print("Loading test data") 73 | test_data = torch.load(testopt.testdata) 74 | 75 | print("Loading model from ", testopt.model) 76 | ckp = torch.load(testopt.model) 77 | 78 | opt = ckp['opt'] 79 | model_state = ckp['model'] 80 | 81 | if opt.gpu: 82 | torch.cuda.set_device(opt.gpu) 83 | 84 | test_dataset = reader.Dataset(test_data, opt.batch_size, True, volatile=True) 85 | 86 | print(' * vocabulary size = %d' % 87 | (vocab_dict.size())) 88 | print(' * number of test samples. %d' % 89 | len(test_data['candidates'])) 90 | print(' * maximum batch size. %d' % opt.batch_size) 91 | 92 | print('Building model...') 93 | 94 | model = reader.AoAReader(vocab_dict, dropout_rate=opt.dropout, embed_dim=opt.embed_size, hidden_dim=opt.gru_size) 95 | # no way on CPU 96 | model.cuda() 97 | 98 | # load state 99 | model.load_state_dict(model_state) 100 | 101 | print('Evaluate on test data') 102 | answers = evalulate(model, test_dataset, vocab_dict) 103 | 104 | with open(testopt.out, 'w') as out: 105 | print('\n'.join(answers), file=out) 106 | 107 | if __name__ == '__main__': 108 | main() 109 | 110 | 111 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import aoareader as reader 2 | import torch 3 | import time 4 | import argparse 5 | 6 | # torch.backends.cudnn.enabled=True 7 | 8 | parser = argparse.ArgumentParser(description="train.py") 9 | 10 | # train options 11 | 12 | parser.add_argument('-traindata', default='data/train.txt.pt', 13 | help='Path to the *-train.pt file from preprocess.py, default value is \'data/train.txt.pt\'') 14 | parser.add_argument('-validdata', default='data/dev.txt.pt', 15 | help='Path to the *-dev.pt file from preprocess.py, default value is \'data/dev.txt.pt\'') 16 | parser.add_argument('-dict', default='data/dict.pt', 17 | help='Path to the dictionary file from preprocess.py, default value is \'data/dict.pt\'') 18 | parser.add_argument('-save_model', default='model', 19 | help="""Model filename (the model will be saved as 20 | _epochN_ACC.pt to 'models/' directory, where ACC is the 21 | validation accuracy""") 22 | parser.add_argument('-train_from', default='', type=str, 23 | help="""If training from a checkpoint then this is the 24 | path to the pre-trained model.""") 25 | 26 | # model parameters 27 | 28 | parser.add_argument('-gru_size', type=int, default=384, 29 | help='Size of GRU hidden states') 30 | parser.add_argument('-embed_size', type=int, default=384, 31 | help='Word embedding sizes') 32 | 33 | 34 | # optimization 35 | 36 | parser.add_argument('-batch_size', type=int, default=32, 37 | help='Maximum batch size') 38 | 39 | parser.add_argument('-dropout', type=float, default=0.1, 40 | help='Dropout probability; applied in bidirectional gru.') 41 | 42 | parser.add_argument('-start_epoch', type=int, default=1, 43 | help='The epoch from which to start') 44 | 45 | parser.add_argument('-epochs', type=int, default=13, 46 | help='Number of training epochs') 47 | 48 | parser.add_argument('-learning_rate', type=float, default=0.001, 49 | help="""Starting learning rate. Adam is 50 | used, this is the global learning rate.""") 51 | 52 | parser.add_argument('-weight_decay', type=float, default=0.0001, 53 | help="""weight decay (L2 penalty)""") 54 | 55 | # GPU 56 | 57 | parser.add_argument('-gpu', default=0, type=int, 58 | help="which gpu to use. (0, 1...)") 59 | 60 | # Log 61 | 62 | parser.add_argument('-log_interval', type=int, default=50, 63 | help="Print stats at this interval (minibatches).") 64 | 65 | 66 | opt = parser.parse_args() 67 | print(opt) 68 | 69 | if opt.gpu: 70 | torch.cuda.set_device(opt.gpu) 71 | 72 | 73 | def loss_func(answers, pred_answers, answer_probs): 74 | num_correct = (answers == pred_answers).sum().squeeze().data[0] 75 | loss = - torch.mean(torch.log(answer_probs), keepdim=True) 76 | return loss.cuda(), num_correct 77 | 78 | def eval(model, data): 79 | total_loss = 0 80 | total = 0 81 | total_correct = 0 82 | 83 | model.eval() 84 | for i in range(len(data)): 85 | (batch_docs, batch_docs_len, doc_mask), (batch_querys, batch_querys_len, query_mask), batch_answers, candidates = data[i] 86 | 87 | pred_answers, probs = model(batch_docs, batch_docs_len, doc_mask, 88 | batch_querys, batch_querys_len, query_mask, 89 | answers=batch_answers, candidates=candidates) 90 | 91 | loss, num_correct = loss_func(batch_answers, pred_answers, probs) 92 | 93 | total_in_minibatch = batch_answers.size(0) 94 | total_loss += loss.data[0] * total_in_minibatch 95 | total_correct += num_correct 96 | total += total_in_minibatch 97 | 98 | del loss, pred_answers, probs 99 | 100 | model.train() 101 | return total_loss / total, total_correct / total 102 | 103 | 104 | def trainModel(model, trainData, validData, optimizer: torch.optim.Adam): 105 | print(model) 106 | start_time = time.time() 107 | 108 | def trainEpoch(epoch): 109 | trainData.shuffle() 110 | 111 | total_loss, total, total_num_correct = 0, 0, 0 112 | report_loss, report_total, report_num_correct = 0, 0, 0 113 | for i in range(len(trainData)): 114 | (batch_docs, batch_docs_len, doc_mask), (batch_querys, batch_querys_len, query_mask), batch_answers, candidates = trainData[i] 115 | 116 | model.zero_grad() 117 | pred_answers, answer_probs = model(batch_docs, batch_docs_len, doc_mask, batch_querys, batch_querys_len, query_mask,answers=batch_answers, candidates=candidates) 118 | 119 | loss, num_correct = loss_func(batch_answers, pred_answers, answer_probs) 120 | 121 | loss.backward() 122 | for parameter in model.parameters(): 123 | parameter.grad.data.clamp_(-5.0, 5.0) 124 | # update the parameters 125 | optimizer.step() 126 | 127 | total_in_minibatch = batch_answers.size(0) 128 | 129 | report_loss += loss.data[0] * total_in_minibatch 130 | report_num_correct += num_correct 131 | report_total += total_in_minibatch 132 | 133 | total_loss += loss.data[0] * total_in_minibatch 134 | total_num_correct += num_correct 135 | total += total_in_minibatch 136 | if i % opt.log_interval == 0: 137 | print("Epoch %2d, %5d/%5d; avg loss: %.2f; acc: %6.2f; %6.0f s elapsed" % 138 | (epoch, i+1, len(trainData), 139 | report_loss / report_total, 140 | report_num_correct / report_total * 100, 141 | time.time()-start_time)) 142 | 143 | report_loss = report_total = report_num_correct = 0 144 | del loss, pred_answers, answer_probs 145 | 146 | return total_loss / total, total_num_correct / total 147 | 148 | for epoch in range(opt.start_epoch, opt.epochs + 1): 149 | print('') 150 | 151 | # (1) train for one epoch on the training set 152 | train_loss, train_acc = trainEpoch(epoch) 153 | print('Epoch %d:\t average loss: %.2f\t train accuracy: %g' % (epoch, train_loss, train_acc*100)) 154 | 155 | # (2) evaluate on the validation set 156 | valid_loss, valid_acc = eval(model, validData) 157 | print('=' * 20) 158 | print('Evaluating on validation set:') 159 | print('Validation loss: %.2f' % valid_loss) 160 | print('Validation accuracy: %g' % (valid_acc*100)) 161 | print('=' * 20) 162 | 163 | model_state_dict = model.state_dict() 164 | optimizer_state_dict = optimizer.state_dict() 165 | # (4) drop a checkpoint 166 | checkpoint = { 167 | 'model': model_state_dict, 168 | 'epoch': epoch, 169 | 'optimizer': optimizer_state_dict, 170 | 'opt': opt, 171 | } 172 | torch.save(checkpoint, 173 | 'models/%s_epoch%d_acc_%.2f.pt' % (opt.save_model, epoch, 100*valid_acc)) 174 | 175 | def main(): 176 | global opt 177 | train_from = opt.train_from 178 | if opt.train_from: 179 | train_from = True 180 | checkpoint = torch.load(opt.train_from) 181 | opt = checkpoint['opt'] 182 | 183 | 184 | print("Loading dictrionary from ", opt.dict) 185 | vocab_dict = torch.load(opt.dict) 186 | print("Loading train data from ", opt.traindata) 187 | train_data = torch.load(opt.traindata) 188 | print("Loading valid data from ", opt.validdata) 189 | valid_data = torch.load(opt.validdata) 190 | 191 | train_dataset = reader.Dataset(train_data, opt.batch_size, True) 192 | valid_dataset = reader.Dataset(valid_data, opt.batch_size, True, volatile=True) 193 | 194 | print(' * vocabulary size = %d' % 195 | (vocab_dict.size())) 196 | print(' * number of training samples. %d' % 197 | len(train_data['answers'])) 198 | print(' * maximum batch size. %d' % opt.batch_size) 199 | 200 | print('Building model...') 201 | 202 | model = reader.AoAReader(vocab_dict, dropout_rate=opt.dropout, embed_dim=opt.embed_size, hidden_dim=opt.gru_size) 203 | # no way on CPU 204 | model.cuda() 205 | 206 | if train_from: 207 | print('Loading model from checkpoint at %s' % opt.train_from) 208 | chk_model = checkpoint['model'] 209 | model.load_state_dict(chk_model) 210 | opt.start_epoch = checkpoint['epoch'] + 1 211 | 212 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay) 213 | 214 | if train_from: 215 | optimizer.load_state_dict(checkpoint['optimizer']) 216 | 217 | nParams = sum([p.nelement() for p in model.parameters()]) 218 | print('* number of parameters: %d' % nParams) 219 | 220 | 221 | trainModel(model, train_dataset, valid_dataset, optimizer) 222 | 223 | if __name__ == '__main__': 224 | main() 225 | --------------------------------------------------------------------------------