├── .gitignore ├── README.md ├── fetch_data.sh ├── grid_search.sh ├── main.py └── results └── .keep /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | data/ 3 | *.txt 4 | *.gz 5 | *.pyc 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## This repository is deprecated in favor of End-To-End Memory Networks: [npow/MemN2N](https://github.com/npow/MemN2N) 2 | 3 | # Description 4 | This is an implementation of [Memory Networks](http://arxiv.org/abs/1410.3916) in Theano. The dataset can be found [here](http://fb.ai/babi). 5 | 6 | Below are the results obtained from the current implementation on [the set of toy tasks](http://arxiv.org/abs/1502.05698). Performance is better than LSTMs on most of the tasks, but it is still not quite at the level of the MemNN results reported in the original paper. Note that the "This Repo" mean is taken over the 17 tasks that have a result (tasks 7, 8 and 19 are N/A), whereas the other columns average all 20 tasks.
7 | 8 | | Task#| N-gram Classifier | LSTM | MemNN (Weston 2014) | This Repo | 9 | |------|-------------------|------|---------------------|-----------| 10 | | 1 | 36 | 50 | 100 | 76 | 11 | | 2 | 2 | 20 | 100 | 70 | 12 | | 3 | 7 | 20 | 20 | 18 | 13 | | 4 | 50 | 61 | 71 | 67 | 14 | | 5 | 20 | 70 | 83 | 81 | 15 | | 6 | 49 | 48 | 47 | 45 | 16 | | 7 | 52 | 49 | 68 | N/A | 17 | | 8 | 40 | 45 | 77 | N/A | 18 | | 9 | 62 | 64 | 65 | 67 | 19 | | 10 | 45 | 54 | 59 | 54 | 20 | | 11 | 29 | 72 | 100 | 54 | 21 | | 12 | 9 | 74 | 100 | 63 | 22 | | 13 | 26 | 94 | 100 | 46 | 23 | | 14 | 19 | 27 | 99 | 100 | 24 | | 15 | 20 | 21 | 74 | 100 | 25 | | 16 | 43 | 23 | 27 | 40 | 26 | | 17 | 46 | 51 | 54 | 56 | 27 | | 18 | 52 | 52 | 57 | 53 | 28 | | 19 | 0 | 8 | 0 | N/A | 29 | | 20 | 76 | 91 | 100 | 100 | 30 | | Mean | 34 | 49 | 75 | 64.1 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /fetch_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | url=http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz 4 | fname=`basename $url` 5 | 6 | wget $url 7 | tar zxvf $fname 8 | mv tasks_1-20_v1-2 data 9 | -------------------------------------------------------------------------------- /grid_search.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export THEANO_FLAGS=device=cpu,floatX=float32 3 | 4 | task=$1 5 | if [[ -z $1 ]]; then 6 | task=3 7 | fi 8 | 9 | for embedding_size in 10 50 100 200 300 400 500 1000; do 10 | for lr in 0.1 0.01 0.001; do 11 | for gamma in 10 1 0.1 0.01 0.001; do 12 | echo "RUNNING task: $task, gamma: $gamma, embedding_size: $embedding_size, lr: $lr" 13 | time python -u main.py --task $task --embedding_size $embedding_size --gamma $gamma --lr $lr > results/q${task}_gamma${gamma}_lr${lr}_d${embedding_size}.txt & 14 | sleep 1 15 | done 16 | wait 17 | done 18 | done 19 |
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import argparse 3 | import glob 4 | import numpy as np 5 | import sys 6 | from collections import OrderedDict 7 | from sklearn import metrics 8 | from sklearn.feature_extraction.text import * 9 | from sklearn.preprocessing import * 10 | from theano.ifelse import ifelse 11 | import theano 12 | import theano.tensor as T 13 | 14 | def zeros(shape, dtype=np.float32): 15 | return np.zeros(shape, dtype) 16 | 17 | # TODO: convert this to a theano function 18 | def O_t(xs, L, s): 19 | t = 0 20 | for i in xrange(len(L)-1): # last element is the answer, so we can skip it 21 | if s(xs, i, t, L) > 0: 22 | t = i 23 | return t 24 | 25 | def sgd(cost, params, learning_rate): 26 | grads = T.grad(cost, params) 27 | updates = OrderedDict() 28 | 29 | for param, grad in zip(params, grads): 30 | updates[param] = param - learning_rate * grad 31 | 32 | return updates 33 | 34 | class Model: 35 | def __init__(self, train_file, test_file, D=50, gamma=1, lr=0.001): 36 | self.train_lines, self.test_lines = self.get_lines(train_file), self.get_lines(test_file) 37 | lines = np.concatenate([self.train_lines, self.test_lines], axis=0) 38 | 39 | self.vectorizer = CountVectorizer(lowercase=False) 40 | self.vectorizer.fit([x['text'] + ' ' + x['answer'] if 'answer' in x else x['text'] for x in lines]) 41 | 42 | L = self.vectorizer.transform([x['text'] for x in lines]).toarray().astype(np.float32) 43 | self.L_train, self.L_test = L[:len(self.train_lines)], L[len(self.train_lines):] 44 | 45 | self.train_model = None 46 | self.D = D 47 | self.gamma = gamma 48 | self.lr = lr 49 | self.H = None 50 | self.V = None 51 | 52 | def create_train(self, lenW, n_facts): 53 | ONE = theano.shared(np.float32(1)) 54 | ZERO = theano.shared(np.float32(0)) 55 | def phi_x1(x_t, L): 56 | return 
T.concatenate([L[x_t].reshape((-1,)), zeros((2*lenW,)), zeros((3,))], axis=0) 57 | def phi_x2(x_t, L): 58 | return T.concatenate([zeros((lenW,)), L[x_t].reshape((-1,)), zeros((lenW,)), zeros((3,))], axis=0) 59 | def phi_y(x_t, L): 60 | return T.concatenate([zeros((2*lenW,)), L[x_t].reshape((-1,)), zeros((3,))], axis=0) 61 | def phi_t(x_t, y_t, yp_t, L): 62 | return T.concatenate([zeros(3*lenW,), T.stack(T.switch(T.lt(x_t,y_t), ONE, ZERO), T.switch(T.lt(x_t,yp_t), ONE, ZERO), T.switch(T.lt(y_t,yp_t), ONE, ZERO))], axis=0) 63 | def s_Ot(xs, y_t, yp_t, L): 64 | result, updates = theano.scan( 65 | lambda x_t, t: T.dot(T.dot(T.switch(T.eq(t, 0), phi_x1(x_t, L).reshape((1,-1)), phi_x2(x_t, L).reshape((1,-1))), self.U_Ot.T), 66 | T.dot(self.U_Ot, (phi_y(y_t, L) - phi_y(yp_t, L) + phi_t(x_t, y_t, yp_t, L)))), 67 | sequences=[xs, T.arange(T.shape(xs)[0])]) 68 | return result.sum() 69 | def sR(xs, y_t, L, V): 70 | result, updates = theano.scan( 71 | lambda x_t, t: T.dot(T.dot(T.switch(T.eq(t, 0), phi_x1(x_t, L).reshape((1,-1)), phi_x2(x_t, L).reshape((1,-1))), self.U_R.T), 72 | T.dot(self.U_R, phi_y(y_t, V))), 73 | sequences=[xs, T.arange(T.shape(xs)[0])]) 74 | return result.sum() 75 | 76 | x_t = T.iscalar('x_t') 77 | y_t = T.iscalar('y_t') 78 | yp_t = T.iscalar('yp_t') 79 | xs = T.ivector('xs') 80 | m = [x_t] + [T.iscalar('m_o%d' % i) for i in xrange(n_facts)] 81 | f = [T.iscalar('f%d_t' % i) for i in xrange(n_facts)] 82 | r_t = T.iscalar('r_t') 83 | gamma = T.scalar('gamma') 84 | L = T.fmatrix('L') # list of messages 85 | V = T.fmatrix('V') # vocab 86 | r_args = T.stack(*m) 87 | 88 | cost_arr = [0] * 2 * (len(m)-1) 89 | for i in xrange(len(m)-1): 90 | cost_arr[2*i], _ = theano.scan( 91 | lambda f_bar, t: T.switch(T.or_(T.eq(t, f[i]), T.eq(t, T.shape(L)[0]-1)), 0, T.largest(gamma - s_Ot(T.stack(*m[:i+1]), f[i], t, L), 0)), 92 | sequences=[L, T.arange(T.shape(L)[0])]) 93 | cost_arr[2*i] /= T.shape(L)[0] 94 | cost_arr[2*i+1], _ = theano.scan( 95 | lambda f_bar, t: 
T.switch(T.or_(T.eq(t, f[i]), T.eq(t, T.shape(L)[0]-1)), 0, T.largest(gamma + s_Ot(T.stack(*m[:i+1]), t, f[i], L), 0)), 96 | sequences=[L, T.arange(T.shape(L)[0])]) 97 | cost_arr[2*i+1] /= T.shape(L)[0] 98 | 99 | cost1, _ = theano.scan( 100 | lambda r_bar, t: T.switch(T.eq(r_t, t), 0, T.largest(gamma - sR(r_args, r_t, L, V) + sR(r_args, t, L, V), 0)), 101 | sequences=[V, T.arange(T.shape(V)[0])]) 102 | cost1 /= T.shape(V)[0] 103 | 104 | cost = cost1.sum() 105 | for c in cost_arr: 106 | cost += c.sum() 107 | 108 | updates = sgd(cost, [self.U_Ot, self.U_R], learning_rate=self.lr) 109 | 110 | self.train_model = theano.function( 111 | inputs=[r_t, gamma, L, V] + m + f, 112 | outputs=[cost], 113 | updates=updates) 114 | 115 | self.sR = theano.function([xs, y_t, L, V], sR(xs, y_t, L, V)) 116 | self.s_Ot = theano.function([xs, y_t, yp_t, L], s_Ot(xs, y_t, yp_t, L)) 117 | 118 | def train(self, n_epochs): 119 | lenW = len(self.vectorizer.vocabulary_) 120 | self.H = {} 121 | for i,v in enumerate(self.vectorizer.vocabulary_): 122 | self.H[v] = i 123 | self.V = self.vectorizer.transform([v for v in self.vectorizer.vocabulary_]).toarray().astype(np.float32) 124 | 125 | W = 3*lenW + 3 126 | self.U_Ot = theano.shared(np.random.uniform(-0.1, 0.1, (self.D, W)).astype(np.float32)) 127 | self.U_R = theano.shared(np.random.uniform(-0.1, 0.1, (self.D, W)).astype(np.float32)) 128 | 129 | prev_err = None 130 | for epoch in range(n_epochs): 131 | total_err = 0 132 | print "*" * 80 133 | print "epoch: ", epoch 134 | n_wrong = 0 135 | 136 | for i,line in enumerate(self.train_lines): 137 | if i > 0 and i % 1000 == 0: 138 | print "i: ", i, " nwrong: ", n_wrong 139 | if line['type'] == 'q': 140 | refs = line['refs'] 141 | f = [ref - 1 for ref in refs] 142 | id = line['id']-1 143 | indices = [idx for idx in range(i-id, i+1)] 144 | memory_list = self.L_train[indices] 145 | # print "REFS: ", self.train_lines[indices][f], "\nMEMORY: ", self.train_lines[indices], '\n', '*' * 80 146 | 147 | if 
self.train_model is None: 148 | self.create_train(lenW, len(f)) 149 | 150 | m = f 151 | mm = [] 152 | for j in xrange(len(f)): 153 | mm.append(O_t([id]+m[:j], memory_list, self.s_Ot)) 154 | 155 | if mm[0] != f[0]: 156 | n_wrong += 1 157 | 158 | err = self.train_model(self.H[line['answer']], self.gamma, memory_list, self.V, id, *(m + f))[0] 159 | total_err += err 160 | print "i: ", i, " nwrong: ", n_wrong 161 | print "epoch: ", epoch, " err: ", (total_err/len(self.train_lines)) 162 | 163 | # TODO: use validation set 164 | if prev_err is not None and total_err > prev_err: 165 | break 166 | else: 167 | prev_err = total_err 168 | self.test() 169 | 170 | def test(self): 171 | lenW = len(self.vectorizer.vocabulary_) 172 | W = 3*lenW 173 | Y_true = [] 174 | Y_pred = [] 175 | for i,line in enumerate(self.test_lines): 176 | if line['type'] == 'q': 177 | r = line['answer'] 178 | id = line['id']-1 179 | indices = [idx for idx in range(i-id, i+1)] 180 | memory_list = self.L_test[indices] 181 | 182 | m_o1 = O_t([id], memory_list, self.s_Ot) 183 | m_o2 = O_t([id, m_o1], memory_list, self.s_Ot) 184 | 185 | bestVal = None 186 | best = None 187 | for w in self.vectorizer.vocabulary_: 188 | val = self.sR([id, m_o1, m_o2], self.H[w], memory_list, self.V) 189 | if bestVal is None or val > bestVal: 190 | bestVal = val 191 | best = w 192 | Y_true.append(r) 193 | Y_pred.append(best) 194 | print metrics.classification_report(Y_true, Y_pred) 195 | 196 | def get_lines(self, fname): 197 | lines = [] 198 | for i,line in enumerate(open(fname)): 199 | id = int(line[0:line.find(' ')]) 200 | line = line.strip() 201 | line = line[line.find(' ')+1:] 202 | if line.find('?') == -1: 203 | lines.append({'type':'s', 'text': line}) 204 | else: 205 | idx = line.find('?') 206 | tmp = line[idx+1:].split('\t') 207 | lines.append({'id':id, 'type':'q', 'text': line[:idx], 'answer': tmp[1].strip(), 'refs': [int(x) for x in tmp[2:][0].split(' ')]}) 208 | if False and i > 1000: 209 | break 210 | return 
np.array(lines) 211 | 212 | def str2bool(v): 213 | return v.lower() in ("yes", "true", "t", "1") 214 | 215 | def main(): 216 | parser = argparse.ArgumentParser() 217 | parser.register('type','bool',str2bool) 218 | parser.add_argument('--task', type=int, default=1, help='Task#') 219 | parser.add_argument('--train_file', type=str, default='', help='Train file') 220 | parser.add_argument('--test_file', type=str, default='', help='Test file') 221 | parser.add_argument('--gamma', type=float, default=1, help='Gamma') 222 | parser.add_argument('--lr', type=float, default=0.1, help='Learning rate') 223 | parser.add_argument('--embedding_size', type=int, default=50, help='Embedding size') 224 | parser.add_argument('--n_epochs', type=int, default=10, help='Num epochs') 225 | args = parser.parse_args() 226 | print "args: ", args 227 | 228 | train_file = glob.glob('data/en-10k/qa%d_*train.txt' % args.task)[0] 229 | test_file = glob.glob('data/en-10k/qa%d_*test.txt' % args.task)[0] 230 | if args.train_file != '' and args.test_file != '': 231 | train_file, test_file = args.train_file, args.test_file 232 | 233 | model = Model(train_file, test_file, D=args.embedding_size, gamma=args.gamma, lr=args.lr) 234 | model.train(args.n_epochs) 235 | 236 | if __name__ == '__main__': 237 | main() 238 | -------------------------------------------------------------------------------- /results/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/npow/MemNN/ad4ca96dc2154f964a763818ee64bb377d555af0/results/.keep --------------------------------------------------------------------------------