├── Model.py ├── config.py ├── data_loader.py ├── debug.py ├── gen_dict.py ├── gen_reader_data.py ├── gen_sim_data.py ├── plot.py ├── readme.md ├── setup_processed_data.sh ├── train_kv_mm.py ├── train_lstm.py ├── train_mlp.py └── utils.py /Model.py: -------------------------------------------------------------------------------- 1 | #/usr/bin/python 2 | from collections import namedtuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | from torch.autograd import Variable 9 | from torch.legacy.nn import CosineEmbeddingCriterion 10 | from torch.legacy.nn import Sum 11 | import pdb 12 | 13 | class MLP(nn.Module): 14 | def __init__(self, config): 15 | super(MLP, self).__init__() 16 | self.embed = nn.Embedding(config.n_embed, config.d_embed) 17 | self.cosine = nn.CosineEmbeddingLoss() 18 | #self.mean = Sum(0,True) 19 | 20 | def forward(self, x1, x2, y): 21 | #pdb.set_trace() 22 | v1 = Variable(x1) 23 | v2 = Variable(x2) 24 | y = Variable(y) 25 | v1 = self.embed(v1) 26 | v1 = v1.mean(1).squeeze(1) 27 | v2 = self.embed(v2) 28 | v2 = v2.mean(1).squeeze(1) 29 | #pdb.set_trace() 30 | loss = self.cosine(v1,v2,y) 31 | return loss 32 | 33 | def save(self, filename): 34 | tmp = [x for x in self.parameters()] 35 | with open(filename, "w") as f: 36 | torch.save(tmp[0], f) 37 | 38 | def load(self, filename): 39 | embed_t = None 40 | with open(filename) as f: 41 | embed_t = torch.load(f) 42 | self.embed.weight = embed_t 43 | 44 | 45 | class RnnReader(nn.Module): 46 | def __init__(self, config): 47 | super(RnnReader, self).__init__() 48 | self.config = config 49 | self.embed = nn.Embedding(config.n_embed, config.d_embed) 50 | self.rnn_doc = nn.GRU(config.d_embed, config.rnn_fea_size, batch_first=True, bidirectional=True) 51 | self.rnn_qus = nn.GRU(config.d_embed, config.rnn_fea_size, batch_first=True, bidirectional=True) 52 | self.h0_doc = torch.rand(2,1, self.config.rnn_fea_size) 53 | self.h0_q = Variable(torch.rand(2, 1, self.config.rnn_fea_size)) 54 | 55 | def forward(self, qu, w, e_p): 56 | qu = Variable(qu) 57 | w = Variable(w) 58 | embed_q = self.embed(qu) 59 | embed_w = self.embed(w) 60 | s_ = embed_w.size() 61 | b_size = s_[0] 62 | 63 | #pdb.set_trace() 64 | h0_doc = Variable(torch.cat([self.h0_doc for _ in range(b_size)], 1)) 65 | out_qus, h_qus = self.rnn_qus(embed_q, self.h0_q) 66 | out_doc, h_doc = self.rnn_doc(embed_w, h0_doc) 67 | 68 | q_state = torch.cat([out_qus[0,-1,:self.config.rnn_fea_size], out_qus[0,0,self.config.rnn_fea_size:]],0) 69 | 70 | # token attention 71 | doc_tit_ent_dot = [] 72 | doc_tit_ent = [] 73 | doc_states = [] 74 | for i,k in enumerate(e_p): 75 | # memory 76 | t_e_v = self.cat(out_doc[i,1], out_doc[i,k]) 77 | # dot product 78 | title = torch.dot(out_doc[i,1], q_state) 79 | entity = torch.dot(out_doc[i,k], q_state) 80 | token_att = torch.cat([title, entity],0).unsqueeze(0) 81 | s_m = F.softmax(token_att) 82 | att_v = torch.mm(s_m, t_e_v) 83 | doc_tit_ent.append(att_v) 84 | # concate start and end 85 | state_ = torch.cat([out_doc[i,-1,:self.config.rnn_fea_size], out_doc[i,0,self.config.rnn_fea_size:]],0) 86 | doc_states.append(state_.unsqueeze(0)) 87 | #pdb.set_trace() 88 | t_e_vecs = torch.cat(doc_tit_ent,0) 89 | 90 | # sentence attention 91 | doc_states_v = torch.cat(doc_states, 0) 92 | doc_dot = torch.mm(doc_states_v, q_state.unsqueeze(1)) 93 | doc_sm = F.softmax(doc_dot) 94 | t_doc_feat = torch.add(doc_states_v, t_e_vecs) 95 | doc_feat = 
torch.mm(doc_sm.view(1,-1), t_doc_feat) 96 | 97 | score = torch.mm(self.embed.weight, doc_feat.view(-1,1)).view(1,-1) 98 | score_n = F.log_softmax(score) 99 | 100 | return score_n 101 | 102 | def predict(self, q, w, e_p): 103 | score = self.forward(q, w, e_p) 104 | _, index = torch.max(score.squeeze(0), 0) 105 | return index.data.numpy()[0] 106 | 107 | # concat 1-D tensor 108 | def cat(self, t1, t2): 109 | return torch.cat([t1.unsqueeze(0), t2.unsqueeze(0)],0) 110 | 111 | def load_embed(self, filename): 112 | embed_t = None 113 | with open(filename) as f: 114 | embed_t = torch.load(f) 115 | self.embed.weight = embed_t 116 | 117 | class LocalReader(nn.Module): 118 | def __init__(self, config): 119 | super(LocalReader, self).__init__() 120 | self.config = config 121 | self.embed = nn.Embedding(config.n_embed, config.d_embed) 122 | self.rnn = nn.GRU(config.d_embed, config.rnn_fea_size, batch_first=True, bidirectional=True) 123 | #self.rnn_doc = nn.GRU(config.d_embed, config.rnn_fea_size, batch_first=True, bidirectional=True) 124 | #self.rnn_qus = nn.GRU(config.d_embed, config.rnn_fea_size, batch_first=True, bidirectional=True) 125 | self.h0_doc = torch.rand(2,1, self.config.rnn_fea_size) 126 | self.h0_q = Variable(torch.rand(2, 1, self.config.rnn_fea_size)) 127 | 128 | def forward(self, qu, w, e_p): 129 | qu = Variable(qu) 130 | w = Variable(w) 131 | embed_q = self.embed(qu) 132 | embed_w = self.embed(w) 133 | s_ = embed_w.size() 134 | b_size = s_[0] 135 | 136 | #pdb.set_trace() 137 | h0_doc = Variable(torch.cat([self.h0_doc for _ in range(b_size)], 1)) 138 | out_qus, h_qus = self.rnn_qus(embed_q, self.h0_q) 139 | out_doc, h_doc = self.rnn_doc(embed_w, h0_doc) 140 | 141 | q_state = torch.cat([out_qus[0,-1,:self.config.rnn_fea_size], out_qus[0,0,self.config.rnn_fea_size:]],0) 142 | #q_state = out_qus[:,-1,:] 143 | 144 | # token attention 145 | title_states = [] 146 | entity_states = [] 147 | candidate_states = [] 148 | doc_states = [] 149 | for i,bu in enumerate(e_p): 150 | k, t_i, e_i = bu 151 | # memory 152 | #title_states.append(out_doc[i,1].unsqueeze(0)) 153 | #entity_states.append(out_doc[i,k].unsqueeze(0)) 154 | #state_ = torch.cat([out_doc[i,-1,:self.config.rnn_fea_size], out_doc[i,0,self.config.rnn_fea_size:]],0) 155 | #doc_states.append(state_.unsqueeze(0)) 156 | 157 | if t_i < len(candidate_states): 158 | torch.add(candidate_states[t_i], out_doc[i,1].unsqueeze(0)) 159 | else: 160 | candidate_states.append(out_doc[i,1].unsqueeze(0)) 161 | if e_i < len(candidate_states): 162 | torch.add(candidate_states[e_i], out_doc[i,k].unsqueeze(0)) 163 | else: 164 | candidate_states.append(out_doc[i,k].unsqueeze(0)) 165 | #doc_states_v = torch.cat(doc_states, 0) 166 | #title_states_v = torch.cat(title_states, 0) 167 | #entity_states_v = torch.cat(entity_states, 0) 168 | cand_states_v = torch.cat(candidate_states, 0) 169 | 170 | # add 171 | #title_states_v = torch.add(doc_states_v, title_states_v) 172 | #entity_states_v = torch.add(doc_states_v, entity_states_v) 173 | # final feature 174 | #f_fea_v = torch.cat([title_states_v, entity_states_v], 0) 175 | #f_fea_v = torch.mm(q_state.unsqueeze(0), torch.transpose(f_fea_v,0,1)) 176 | f_fea_v = torch.mm(q_state.unsqueeze(0), torch.transpose(cand_states_v,0,1)) 177 | 178 | score_n = F.log_softmax(f_fea_v) 179 | return score_n 180 | 181 | def predict(self, q, w, e_p): 182 | score = self.forward(q, w, e_p) 183 | _, index = torch.max(score.squeeze(0), 0) 184 | return index.data.numpy()[0] 185 | 186 | # concat 1-D tensor 187 | def cat(self, t1, t2): 188 | 
return torch.cat([t1.unsqueeze(0), t2.unsqueeze(0)],0) 189 | 190 | def load_embed(self, filename): 191 | embed_t = None 192 | with open(filename) as f: 193 | embed_t = torch.load(f) 194 | self.embed.weight = embed_t 195 | 196 | class MemoryReader(nn.Module): 197 | def __init__(self, config): 198 | super(MemoryReader, self).__init__() 199 | self.config = config 200 | self.embed_A = nn.Embedding(config.n_embed, config.d_embed) 201 | self.embed_B = nn.Embedding(config.n_embed, config.d_embed) 202 | self.embed_C = nn.Embedding(config.n_embed, config.d_embed) 203 | self.H = nn.Linear(config.d_embed, config.d_embed) 204 | 205 | def forward(self, qu, w, cand): 206 | qu = Variable(qu) 207 | w = Variable(w) 208 | cand = Variable(cand) 209 | embed_q = self.embed_B(qu) 210 | embed_w1 = self.embed_A(w) 211 | embed_w2 = self.embed_C(w) 212 | embed_c = self.embed_C(cand) 213 | 214 | #pdb.set_trace() 215 | q_state = torch.sum(embed_q, 1).squeeze(1) 216 | w1_state = torch.sum(embed_w1, 1).squeeze(1) 217 | w2_state = torch.sum(embed_w2, 1).squeeze(1) 218 | 219 | for _ in range(self.config.hop): 220 | sent_dot = torch.mm(q_state, torch.transpose(w1_state, 0, 1)) 221 | sent_att = F.softmax(sent_dot) 222 | 223 | a_dot = torch.mm(sent_att, w2_state) 224 | a_dot = self.H(a_dot) 225 | q_state = torch.add(a_dot, q_state) 226 | 227 | f_feat = torch.mm(q_state, torch.transpose(embed_c, 0, 1)) 228 | score = F.log_softmax(f_feat) 229 | return score 230 | 231 | def predict(self, q, w, e_p): 232 | score = self.forward(q, w, e_p) 233 | _, index = torch.max(score.squeeze(0), 0) 234 | return index.data.numpy()[0] 235 | 236 | def load_embed(self, filename): 237 | embed_t = None 238 | with open(filename) as f: 239 | embed_t = torch.load(f) 240 | self.embed_A.weight = embed_t 241 | self.embed_B.weight = embed_t 242 | self.embed_C.weight = embed_t 243 | 244 | class RLReader(nn.Module): 245 | def __init__(self, config): 246 | super(RLReader, self).__init__() 247 | self.config = config 248 | self.embed_A = nn.Embedding(config.n_embed, config.d_embed) 249 | self.embed_B = nn.Embedding(config.n_embed, config.d_embed) 250 | self.embed_C = nn.Embedding(config.n_embed, config.d_embed) 251 | self.H = nn.Linear(config.d_embed, config.d_embed) 252 | self.rnn_doc = nn.GRU(config.d_embed, config.rnn_fea_size, batch_first=True, bidirectional=False) 253 | self.rnn_qus = nn.GRU(config.d_embed, config.rnn_fea_size, batch_first=True, bidirectional=False) 254 | self.h0_doc = Variable(torch.rand(2, 1, self.config.rnn_fea_size)) 255 | self.h0_q = Variable(torch.rand(2, 1, self.config.rnn_fea_size)) 256 | 257 | 258 | def forward(self, qu, w, cand): 259 | qu = Variable(qu) 260 | w = Variable(w) 261 | cand = Variable(cand) 262 | embed_q = self.embed_B(qu) 263 | embed_w1 = self.embed_A(w) 264 | embed_c = self.embed_C(cand) 265 | 266 | #pdb.set_trace() 267 | q_state = torch.sum(embed_q, 1).squeeze(1) 268 | w1_state = torch.sum(embed_w1, 1).squeeze(1) 269 | 270 | sent_dot = torch.mm(q_state, torch.transpose(w1_state, 0, 1)) 271 | sent_att = F.softmax(sent_dot) 272 | 273 | q_rnn_state = self.rnn_qus(embed_q, self.h0_q)[-1].squeeze(0) 274 | #pdb.set_trace() 275 | 276 | action = sent_att.multinomial() 277 | 278 | sent = embed_w1[action.data[0]] 279 | sent_state = self.rnn_doc(sent, self.h0_doc)[-1].squeeze(0) 280 | q_state = torch.add(q_state, sent_state) 281 | 282 | f_feat = torch.mm(q_state, torch.transpose(embed_c, 0, 1)) 283 | reward_prob = F.log_softmax(f_feat).squeeze(0) 284 | 285 | return action, reward_prob 286 | 287 | def predict(self, q, w, 
e_p): 288 | _, score = self.forward(q, w, e_p) 289 | _, index = torch.max(score.squeeze(0), 0) 290 | return index.data.numpy()[0] 291 | 292 | def load_embed(self, filename): 293 | embed_t = None 294 | with open(filename) as f: 295 | embed_t = torch.load(f) 296 | self.embed_A.weight = embed_t 297 | self.embed_B.weight = embed_t 298 | self.embed_C.weight = embed_t 299 | 300 | 301 | class PairReader(nn.Module): 302 | def __init__(self, config): 303 | super(PairReader, self).__init__() 304 | self.config = config 305 | self.embed = nn.Embedding(config.n_embed, config.d_embed) 306 | self.rnn = nn.LSTM(config.d_embed, config.rnn_fea_size, batch_first=True, bidirectional=False, dropout=0.1) 307 | self.h0 = Variable(torch.FloatTensor(1, 1, self.config.rnn_fea_size).zero_()) 308 | self.c0 = Variable(torch.FloatTensor(1, 1, self.config.rnn_fea_size).zero_()) 309 | 310 | def forward(self, qu, w, cand): 311 | qu = Variable(qu) 312 | cand = Variable(cand) 313 | embed_q = self.embed(qu) 314 | embed_cand = self.embed(cand) 315 | 316 | out, (self.h0, self.c0) = self.rnn(embed_q, (self.h0, self.c0)) 317 | self.h0.detach_() 318 | self.c0.detach_() 319 | q_state = out[:,-1,:] 320 | 321 | f_fea_v = torch.mm(q_state, torch.transpose(embed_cand,0,1)) 322 | 323 | score_n = F.log_softmax(f_fea_v) 324 | return score_n 325 | 326 | def predict(self, q, w, e_p): 327 | score = self.forward(q, w, e_p) 328 | _, index = torch.max(score.squeeze(0), 0) 329 | return index.data.numpy()[0] 330 | 331 | # concat 1-D tensor 332 | def cat(self, t1, t2): 333 | return torch.cat([t1.unsqueeze(0), t2.unsqueeze(0)],0) 334 | 335 | def load_embed(self, filename): 336 | embed_t = None 337 | with open(filename) as f: 338 | embed_t = torch.load(f) 339 | self.embed.weight = embed_t 340 | 341 | class KVMemoryReader(nn.Module): 342 | def __init__(self, config): 343 | super(KVMemoryReader, self).__init__() 344 | self.config = config 345 | self.embed_A = nn.Embedding(config.n_embed, config.d_embed) 346 | self.embed_B = nn.Embedding(config.n_embed, config.d_embed) 347 | self.embed_C = nn.Embedding(config.n_embed, config.d_embed) 348 | self.H = nn.Linear(config.d_embed, config.d_embed) 349 | 350 | def forward(self, qu, key, value, cand): 351 | qu = Variable(qu) 352 | key = Variable(key) 353 | value = Variable(value) 354 | cand = Variable(cand) 355 | embed_q = self.embed_B(qu) 356 | embed_w1 = self.embed_A(key) 357 | embed_w2 = self.embed_C(value) 358 | embed_c = self.embed_C(cand) 359 | 360 | #pdb.set_trace() 361 | q_state = torch.sum(embed_q, 1).squeeze(1) 362 | w1_state = torch.sum(embed_w1, 1).squeeze(1) 363 | w2_state = embed_w2 364 | 365 | for _ in range(self.config.hop): 366 | sent_dot = torch.mm(q_state, torch.transpose(w1_state, 0, 1)) 367 | sent_att = F.softmax(sent_dot) 368 | 369 | a_dot = torch.mm(sent_att, w2_state) 370 | a_dot = self.H(a_dot) 371 | q_state = torch.add(a_dot, q_state) 372 | 373 | f_feat = torch.mm(q_state, torch.transpose(embed_c, 0, 1)) 374 | score = F.log_softmax(f_feat) 375 | return score 376 | 377 | def predict(self, q, key, value, cand): 378 | score = self.forward(q, key, value, cand) 379 | _, index = torch.max(score.squeeze(0), 0) 380 | return index.data.numpy()[0] 381 | 382 | def load_embed(self, filename): 383 | embed_t = None 384 | with open(filename) as f: 385 | embed_t = torch.load(f) 386 | self.embed_A.weight = embed_t 387 | self.embed_B.weight = embed_t 388 | self.embed_C.weight = embed_t 389 | -------------------------------------------------------------------------------- /config.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | class Config(): 4 | def __init__(self): 5 | self.n_embed = 186841 6 | self.d_embed = 300 7 | self.sent_len = 20 8 | self.win_len = 7 + 3 9 | 10 | self.ir_size = 1000 11 | self.filter_size = 100 12 | 13 | self.batch_size = 128 14 | self.epoch = 100 15 | self.lr = 0.005 16 | self.l2 = 0.00001 17 | self.valid_every = self.batch_size * 100 18 | 19 | self.margin = 0.2 20 | 21 | self.rnn_fea_size = self.d_embed 22 | 23 | #model file 24 | self.pre_embed_file = "./model/{}/embedding.pre".format(self.d_embed) 25 | self.reader_model = "./model/reader_{}.torch".format(self.d_embed) 26 | 27 | self.title_dict = "./pkl/dict/title.dict" 28 | self.entity_dict = "./pkl/dict/entity.dict" 29 | 30 | #data dir 31 | self.data_dir = "./data" 32 | 33 | #memory network 34 | self.hop = 2 35 | 36 | # RL 37 | self.K = 1 38 | 39 | 40 | config = Config() 41 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import division 3 | from config import config 4 | import torchfile 5 | from utils import * 6 | import pdb 7 | 8 | data_dir = config.data_dir 9 | train_file = data_dir + "/torch/train_1.txt" 10 | dev_file = data_dir + "/torch/dev_1.txt" 11 | test_file = data_dir + "/torch/test_1.txt" 12 | wiki_file = data_dir + "/torch/wiki-w=0-d=3-i-m.txt" 13 | 14 | train_x = torchfile.load(train_file + ".vecarray.x") 15 | train_y = torchfile.load(train_file + ".vecarray.y") 16 | dev_x = torchfile.load(dev_file + ".vecarray.x") 17 | dev_y = torchfile.load(dev_file + ".vecarray.y") 18 | test_x = torchfile.load(test_file + ".vecarray.x") 19 | test_y = torchfile.load(test_file + ".vecarray.y") 20 | wiki_x = torchfile.load(wiki_file + ".hash.facts1_va") 21 | wiki_y = torchfile.load(wiki_file + ".hash.facts2_va") 22 | wiki_ind_ = torchfile.load(wiki_file + ".hash.facts_ind") 23 | wiki_hash_ = torchfile.load(wiki_file + ".hash.facts_hash_va") 24 | 25 | #pdb.set_trace() 26 | 27 | def get(x, i): 28 | start = int(x["idx"][i]) - 1 29 | length = int(x["len"][i]) 30 | ret = x["data"][start : start + length] 31 | return [ int(x) -1 for x in ret] 32 | 33 | def extract(x,y): 34 | q = [] 35 | a = [] 36 | l = int(x["cnt"][0]) 37 | for i in range(l): 38 | q_ = get(x,i) 39 | a_ = get(y,i) 40 | q.append(q_) 41 | a.append(a_) 42 | return q, a 43 | 44 | train_q, train_a = extract(train_x, train_y) 45 | dev_q, dev_a = extract(dev_x, dev_y) 46 | test_q, test_a = extract(test_x, test_y) 47 | wiki_q, wiki_a = extract(wiki_x, wiki_y) 48 | # -1 represents none ! 
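# Toy illustration (hypothetical data, not part of the original loader) of the Lua
# "vecarray" layout that get()/extract() above unpack: "idx" holds 1-based offsets
# into the flat "data" array, "len" holds span lengths, and every stored id is
# shifted down by 1 to become a 0-based Python index.
_toy = {
    "cnt": [2],               # number of stored sequences
    "idx": [1, 4],            # 1-based start offsets into "data"
    "len": [3, 2],            # length of each sequence
    "data": [5, 9, 2, 7, 7],
}
assert get(_toy, 0) == [4, 8, 1]   # ids 5, 9, 2 shifted to 0-based indices
assert get(_toy, 1) == [6, 6]      # ids 7, 7 shifted to 0-based indices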
49 | wiki_hash, _ = extract(wiki_hash_, wiki_hash_) 50 | #wiki_ind = extract(wiki_ind_, wiki_ind_) 51 | 52 | assert len(train_q) == 96185 53 | assert len(dev_q) == 10000 54 | -------------------------------------------------------------------------------- /debug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import operator 3 | import cPickle 4 | import pdb 5 | 6 | dict_file = "../WikiMovies/data/torch/dict.txt" 7 | pkl_file = "./pkl/dic.pkl" 8 | word2id = {} 9 | id2word = [] 10 | 11 | def gen_dict(): 12 | with open(dict_file) as f: 13 | for line in f: 14 | line = line.strip() 15 | ind = line.rfind("\t") 16 | w = line[:ind] 17 | word2id[w.lower()] = len(word2id) 18 | sort_l = sorted(word2id.items(), key=operator.itemgetter(1)) 19 | id2word = [x[0] for x in sort_l] 20 | print "Voc len,", len(word2id) 21 | with open(pkl_file, 'w') as f: 22 | cPickle.dump((word2id, id2word), f) 23 | 24 | def load_dict(): 25 | with open(pkl_file) as f: 26 | word2id, id2word = cPickle.load(f) 27 | print "Loading dict:", len(id2word) 28 | return word2id, id2word 29 | 30 | def toSent(sent): 31 | return " ".join([id2word[i] for i in sent]) 32 | 33 | #gen_dict() 34 | word2id, id2word = load_dict() 35 | #pdb.set_trace() -------------------------------------------------------------------------------- /gen_dict.py: -------------------------------------------------------------------------------- 1 | #!/usr/local/bin/python 2 | from __future__ import division 3 | import numpy as np 4 | import torchfile 5 | from config import config 6 | from utils import * 7 | from data_loader import * 8 | import pdb 9 | 10 | # generating title and entity_dic 11 | def make_dict(x,y): 12 | title_dict = {} 13 | entity_dict = {} 14 | cnt = 0 15 | for sent, a in zip(x,y): 16 | if sent[0] == 6: 17 | #movie 18 | sent = sent[1:] 19 | title_dict[tuple(sent)] = a[0] 20 | else: 21 | #window 22 | assert sent[0] == 7 23 | sent = sent[2:] 24 | entity_dict[tuple(sent)] = a[0] 25 | if a[0] not in sent: 26 | cnt += 1 27 | print "Pos bug", cnt / len(y) 28 | for sent, a in zip(x,y): 29 | if sent[0] == 6: 30 | #movie 31 | continue 32 | else: 33 | #window 34 | assert sent[0] == 7 35 | sent_ = sent[2:] 36 | k_ = tuple(sent_) 37 | if k_ not in title_dict: 38 | title_dict[k_] = sent[1] 39 | return title_dict, entity_dict 40 | 41 | d1, d2 = make_dict(wiki_q, wiki_a) 42 | #pdb.set_trace() 43 | dump_to_file(d1, config.title_dict) 44 | dump_to_file(d2, config.entity_dict) 45 | -------------------------------------------------------------------------------- /gen_reader_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import division 3 | from config import config 4 | from utils import * 5 | from data_loader import * 6 | import pdb 7 | from Model import MLP 8 | 9 | # dump path 10 | train_file_path = "./pkl/reader/300/train_pair.pkl" 11 | dev_file_path = "./pkl/reader/300/dev_pair.pkl" 12 | test_file_path = "./pkl/reader/300/test_pair.pkl" 13 | 14 | model = MLP(config) 15 | model.load(config.pre_embed_file) 16 | 17 | title_dict = load_from_file(config.title_dict) 18 | entity_dict= load_from_file(config.entity_dict) 19 | 20 | def predict_sim(x,y): 21 | x = torch.LongTensor([x]) 22 | y = torch.LongTensor([y]) 23 | sim = torch.LongTensor([1]) 24 | #pdb.set_trace() 25 | loss = model.forward(x,y,sim) 26 | return loss.data.numpy()[0] 27 | 28 | def extract_ans_pair(golds, wiki_ans): 29 | s1 = set(golds) 30 | s2 = set(wiki_ans) 31 
| s_ = s1.intersection(s2) 32 | return list(s_)[0] 33 | 34 | # the answer of test is list 35 | def make_pairs(x,y,test=False): 36 | eos = 1 # index 37 | unk = 2 38 | ret_pair = [] 39 | for q, a in zip(x,y): 40 | cand_facts = [] 41 | for token in q: 42 | if token >= len(wiki_hash): continue 43 | cand_inds = wiki_hash[token] 44 | if len(cand_inds) == 0: continue 45 | 46 | # for every cand in pre-select candidates 47 | for cand in cand_inds: 48 | if cand == -1: continue # none 49 | s_sim_pair = None # store the result 50 | 51 | # reform the sentence and extract the entity position 52 | sent = wiki_q[cand] 53 | ent_pos = 0 54 | if sent[0] == 6: 55 | #movie 56 | continue 57 | else: 58 | #window 59 | assert sent[0] == 7 60 | raw_sent = sent[2:] 61 | try: 62 | sent = [6, title_dict[tuple(raw_sent)],6] + sent[2:] 63 | except: 64 | sent = [6, entity_dict[tuple(raw_sent)], 6] + sent[2:] 65 | try: 66 | ent_pos = sent.index(entity_dict[tuple(raw_sent)]) 67 | except: 68 | # bug of preprocesing 69 | e_tmp_p = (len(sent) -3) // 2 + 3 70 | sent = sent[:e_tmp_p] + [entity_dict[tuple(raw_sent)]] + sent[e_tmp_p:] 71 | ent_pos = e_tmp_p 72 | #pdb.set_trace() 73 | #print "Pos bug" 74 | 75 | # sentence too long or too short 76 | if len(sent) < config.win_len: 77 | sent = sent + [unk for _ in range(config.win_len - len(sent))] 78 | else: 79 | # change based on the end_pos 80 | if ent_pos <= 6: 81 | sent = sent[:config.win_len] 82 | else: 83 | pre_len = len(sent) 84 | sent = sent[:3] + sent[pre_len-config.win_len+3:] 85 | ent_pos = ent_pos + config.win_len - pre_len 86 | assert len(sent) == config.win_len and ent_pos < config.win_len 87 | 88 | # add the pair to the list 89 | s_sim_pair = (predict_sim(q,sent), ent_pos) 90 | cand_facts.append((s_sim_pair, sent, ent_pos)) 91 | 92 | # filter out the top K facts 93 | cand_facts.sort(key=lambda x:x[0]) 94 | filter_size = min(len(cand_facts), config.filter_size) 95 | cand_qs = [x[1] for x in cand_facts[:filter_size]] 96 | cand_ent_pos = [x[2] for x in cand_facts[:filter_size]] 97 | if test: 98 | ret_pair.append((q, cand_qs, cand_ent_pos, a)) 99 | else: 100 | ret_pair.append((q, cand_qs, cand_ent_pos, a[0])) 101 | 102 | # convert to tensor for batch 103 | ret_pair.sort(key=lambda x:len(x[0])) 104 | ret_q = [] 105 | ret_w = [] 106 | ret_a = [] 107 | ret_ent_p = [] 108 | for q_, w_q_, w_e_p_, a_ in ret_pair: 109 | ret_q.append(torch.LongTensor(q_)) 110 | ret_w.append(torch.LongTensor(w_q_)) 111 | ret_ent_p.append(torch.LongTensor(w_e_p_)) 112 | if not test: 113 | ret_a.append(torch.LongTensor([a_])) # single value must have [] 114 | else: 115 | ret_a.append(a_) # test no need to be tensor 116 | return (ret_q, ret_w, ret_ent_p, ret_a) 117 | 118 | def evaluate(data): 119 | cnt1 = 0 120 | cnt2 = 0 121 | cnt3 = 0 122 | q, w, e_p, a = data 123 | for q_, w_, a_ in zip(q, w, a): 124 | #pdb.set_trace() 125 | w_np = w_.numpy() 126 | cnt3 += len(w_np) 127 | if len(w_np) == 0: continue 128 | 129 | 130 | # top 1 131 | best_w = w_np[0] 132 | for tmp_a in a_: 133 | sig2 = tmp_a in best_w 134 | if sig2: 135 | cnt2 += 1 136 | break 137 | 138 | # coverage 139 | for tmp_a in a_: 140 | sig1 = [tmp_a in tmp_w for tmp_w in w_np] 141 | if max(sig1): 142 | cnt1 += 1 143 | break 144 | return cnt1 / len(a), cnt2 / len(a), cnt3 / len(a) 145 | 146 | if __name__ == "__main__": 147 | train_pairs = make_pairs(train_q, train_a) 148 | dump_to_file(train_pairs, train_file_path) 149 | dev_pairs = make_pairs(dev_q, dev_a) 150 | dump_to_file(dev_pairs, dev_file_path) 151 | test_pairs = make_pairs(test_q, 
test_a, test=True) 152 | dump_to_file(test_pairs, test_file_path) 153 | 154 | #test_pairs = load_from_file(test_file_path) 155 | print evaluate(test_pairs) 156 | -------------------------------------------------------------------------------- /gen_sim_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/local/bin/python 2 | from __future__ import division 3 | import numpy as np 4 | from data_loader import * 5 | from config import config 6 | from utils import * 7 | import pdb 8 | 9 | # parameters 10 | train_file_path = "./pkl/sim/train_pair.pkl" 11 | dev_file_path = "./pkl/sim/dev_pair.pkl" 12 | test_file_path = "./pkl/sim/test_pair.pkl" 13 | 14 | title_dict = load_from_file(config.title_dict) 15 | entity_dict= load_from_file(config.entity_dict) 16 | 17 | # evaluate coverage 18 | def evaluate(x, y, notitle=False): 19 | cnt = 0 20 | for q, a in zip(x,y): 21 | for token in q: 22 | flag = False 23 | if token >= len(wiki_hash): continue 24 | cand_inds = wiki_hash[token] 25 | if len(cand_inds) == 0: continue 26 | for cand in cand_inds: 27 | if cand == -1: continue 28 | sig = False 29 | sent = wiki_q[cand] 30 | if sent[0] == 6: 31 | #movie 32 | if not notitle: 33 | sig = max([i_ in a for i_ in wiki_a[cand]]) 34 | else: 35 | #window 36 | assert sent[0] == 7 37 | sig = -1 38 | if title_dict: 39 | # change 1: to original 40 | k_ = tuple(sent[2:]) 41 | try: 42 | sent[1] = title_dict[tuple(sent[2:])] 43 | except: 44 | pass 45 | sig = max([i_ in a for i_ in sent]) 46 | else: 47 | sig = max([i_ in a for i_ in wiki_a[cand]]) 48 | if sig == True: 49 | cnt += 1 50 | flag = True 51 | break 52 | if flag: break 53 | return cnt / len(x) 54 | 55 | ''' 56 | 1. raw sent with sim 57 | 2. form a special sent with title 58 | ''' 59 | def make_pairs(x,y): 60 | pairs = {} 61 | for q, a in zip(x,y): 62 | cnt = 0 63 | for token in q: 64 | if token >= len(wiki_hash): continue 65 | cand_inds = wiki_hash[token] 66 | if len(cand_inds) == 0: continue 67 | for cand in cand_inds: 68 | if cand == -1: continue # none 69 | sent = wiki_q[cand] 70 | sim = -1 71 | if sent[0] == 6: 72 | #movie 73 | #sent = sent[1:] 74 | #sig = max([i_ in a for i_ in wiki_a[cand]]) 75 | #if sig: sim = 1 #TODO: fix the loss function 76 | continue 77 | else: 78 | #window 79 | assert sent[0] == 7 80 | sent = sent[2:] 81 | sig = max([i_ in a for i_ in wiki_a[cand]]) 82 | if sig: sim = 1 83 | pair = (tuple(q), tuple(sent)) 84 | if pair not in pairs: 85 | pairs[pair] = sim 86 | cnt += 1 87 | elif pairs[pair] != sim: 88 | # it's related to the notitle 89 | pairs[pair] = max(sim, pairs[pair]) 90 | if cnt > config.ir_size: print "Too many windows:{}".format(cnt) 91 | ret_pair = [] 92 | for k in pairs: 93 | #pdb.set_trace() 94 | q_,a_ = k 95 | try: 96 | a_ = [6, title_dict[a_], 6] + list(a_) 97 | except: 98 | a_ = [6, entity_dict[a_], 6] + list(a_) 99 | ret_pair.append((q_,a_,pairs[k])) 100 | # for batch 101 | ret_pair.sort(key=lambda x:len(x[0])) 102 | ret_q = [] 103 | ret_a = [] 104 | ret_sim = [] 105 | for q_, a_, s_ in ret_pair: 106 | ret_q.append(q_) 107 | # TODO: why extra length 108 | ret_a.append(a_[:config.win_len]) 109 | ret_sim.append(s_) 110 | return (ret_q, ret_a, ret_sim) 111 | 112 | if __name__ == "__main__": 113 | print "Evaluating coverage" 114 | print evaluate(test_q, test_a, notitle = False) 115 | print evaluate(test_q, test_a, notitle = True) 116 | 117 | train_pairs = make_pairs(train_q, train_a) 118 | dump_to_file(train_pairs, train_file_path) 119 | dev_pairs = make_pairs(dev_q, dev_a) 120 | 
dump_to_file(dev_pairs, dev_file_path) 121 | #test_pairs = make_pairs(test_q, test_a) 122 | #dump_to_file(test_pairs, dev_file_path) 123 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import division 3 | import matplotlib.pyplot as plt 4 | 5 | # TBD for document retrieval 6 | 7 | N = 9952 8 | nums = [5, 10, 20, 30, 40, 50, 60, 70, 80] 9 | cover = [6763, 7847, 8765, 9237, 9367, 9405, 9429, 9445, 9447] 10 | 11 | ratio = [x / N for x in cover] 12 | 13 | 14 | plt.figure(1) 15 | 16 | plt.plot(nums, ratio) 17 | plt.title("Answer hit (%) in top-N ranked sentences") 18 | plt.axis([0, 100, 0.6, 1]) 19 | plt.xlabel("Number of retrieved sentences") 20 | plt.ylabel("% of correct answers covered") 21 | plt.show() 22 | 23 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Implementation of Memory Networks in PyTorch 2 | 3 | This repo includes implementations of the End-to-End Memory Network and the Key-Value Memory Network. 4 | 5 | The dataset used here is WikiMovies. Note that the preprocessing scripts are borrowed from the [original repo](https://github.com/facebook/MemNN/tree/master/KVmemnn). 6 | 7 | # Code (mainly in PyTorch) 8 | * config.py: global configuration for data directories and training details 9 | * gen\_dict.py: generate the title and entity dictionaries 10 | * data\_loader.py: load the preprocessed data 11 | * Model.py: models for the similarity function and the readers 12 | * train\_\*.py: training scripts for the experimental models 13 | * plot.py: plot the graphs for the report 14 | 15 | # Run 16 | * run setup\_processed\_data.sh to download the processed data 17 | * run gen\_dict.py to generate the dictionary files 18 | * run gen\_sim\_data.py to generate the data for training the similarity function 19 | * run train\_mlp.py to train the similarity function 20 | * run gen\_reader\_data.py to generate the data for the readers 21 | * run train\_lstm.py, train\_memory\_network.py, or train\_kv\_mm.py to train the corresponding model (a condensed sketch of this pipeline appears as comments in setup\_processed\_data.sh below) 22 | -------------------------------------------------------------------------------- /setup_processed_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2004-present Facebook. All Rights Reserved.
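# A condensed sketch of the full pipeline from readme.md (hedged: the plain
# "python <script>.py" invocations below are assumptions, not verified commands),
# to be run after this script has downloaded and unpacked the data:
#   python gen_dict.py          # build the title/entity dictionaries
#   python gen_sim_data.py      # build (question, window, similarity) pairs
#   python train_mlp.py         # pre-train the embedding / similarity function
#   python gen_reader_data.py   # retrieve and rank candidate windows per question
#   python train_lstm.py or python train_kv_mm.py   # train a reader model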
3 | 4 | wget "https://s3.amazonaws.com/fair-data/memnn/kvmemnn/data.tar.gz" \ 5 | && tar -xzvf data.tar.gz && rm data.tar.gz 6 | -------------------------------------------------------------------------------- /train_kv_mm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import division 3 | from Model import KVMemoryReader 4 | from config import config 5 | import operator 6 | from utils import * 7 | import pdb 8 | from debug import * 9 | import torch 10 | import torch.autograd as autograd 11 | from torch.autograd import Variable 12 | import torch.nn as nn 13 | from torch.legacy.nn import CosineEmbeddingCriterion 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | import numpy as np 17 | 18 | train_q, train_w, train_e_p, train_a = load_from_file("./pkl/reader/{}/train_pair.pkl".format(config.d_embed)) 19 | dev_q, dev_w, dev_e_p, dev_a = load_from_file("./pkl/reader/{}/dev_pair.pkl".format(config.n_embed)) 20 | #train_q, train_w, train_e_p, train_a = load_from_file("./pkl/toy/reader/train_pair.pkl") 21 | #dev_q, dev_w, dev_e_p, dev_a = load_from_file("./pkl/toy/reader/dev_pair.pkl") 22 | 23 | def modify(q, wiki, pos, ans): 24 | tL = torch.LongTensor 25 | ret_q = [] 26 | ret_key = [] 27 | ret_value = [] 28 | ret_cand = [] 29 | ret_a = [] 30 | for qu,w,p,a_ in zip(q,wiki,pos,ans): 31 | # encoding the candidate 32 | can_dict = {} 33 | qu = qu.numpy() 34 | w = w.numpy() 35 | p = p.numpy() 36 | a_ = a_.numpy() 37 | 38 | # generate local candidate 39 | len_w = len(w) 40 | cand_ind = [] 41 | for i in range(len_w): 42 | if w[i][1] not in can_dict: 43 | can_dict[w[i][1]] = len(can_dict) 44 | if w[i][p[i]] not in can_dict: 45 | can_dict[w[i][p[i]]] = len(can_dict) 46 | if a_[0] not in can_dict: 47 | continue 48 | else: 49 | sort_l = sorted(can_dict.items(), key=operator.itemgetter(1)) 50 | cand_l = [x[0] for x in sort_l] 51 | 52 | # split into key value format 53 | #pdb.set_trace() 54 | key_m, val_m = transKV(w,p) 55 | ret_q.append(tL(qu)) 56 | ret_key.append(tL(key_m)) 57 | ret_value.append(tL(val_m)) 58 | ret_cand.append(tL(cand_l)) 59 | ret_a.append(tL([can_dict[a_[0]]])) 60 | print len(ret_q) / len(q) 61 | return ret_q, ret_key,ret_value,ret_cand, ret_a 62 | 63 | def transKV(sents, pos): 64 | unk = 2 65 | ret_k = [] 66 | ret_v = [] 67 | for sent, p in zip(sents,pos): 68 | k_ = sent[3:].tolist() + [unk] 69 | v_ = sent[1] 70 | #pdb.set_trace() 71 | ret_k.append(k_) 72 | ret_v.append(v_) 73 | #print toSent(k_),toSent([v_]) 74 | 75 | k_ = [sent[1]] + sent[3:].tolist() 76 | v_ = sent[p] 77 | ret_k.append(k_) 78 | ret_v.append(v_) 79 | #print toSent(k_),toSent([v_]) 80 | return np.array(ret_k), np.array(ret_v) 81 | 82 | def train(epoch): 83 | for e_ in range(epoch): 84 | if (e_ + 1) % 10 == 0: 85 | adjust_learning_rate(optimizer, e_) 86 | cnt = 0 87 | loss = Variable(torch.Tensor([0])) 88 | for i_q, i_k, i_v, i_cand, i_a in zip(train_q, train_key,train_value, train_cand, train_a): 89 | cnt += 1 90 | i_q = i_q.unsqueeze(0) # add dimension 91 | probs = model.forward(i_q, i_k, i_v,i_cand) 92 | i_a = Variable(i_a) 93 | curr_loss = loss_function(probs, i_a) 94 | loss = torch.add(loss, torch.div(curr_loss, config.batch_size)) 95 | 96 | # naive batch implemetation, the lr is divided by batch size 97 | if cnt % config.batch_size == 0: 98 | print "Training loss", loss.data.sum() 99 | loss.backward() 100 | optimizer.step() 101 | loss = Variable(torch.Tensor([0])) 102 | model.zero_grad() 103 | if cnt % 
config.valid_every == 0: 104 | print "Accuracy:",eval() 105 | 106 | def adjust_learning_rate(optimizer, epoch): 107 | lr = config.lr / (2 ** (epoch // 10)) 108 | print "Adjust lr to ", lr 109 | for param_group in optimizer.param_groups: 110 | param_group['lr'] = lr 111 | 112 | def eval(): 113 | cnt = 0 114 | for i_q, i_k, i_v, i_cand, i_a in zip(dev_q, dev_key, dev_value, dev_cand, dev_a): 115 | i_q = i_q.unsqueeze(0) # add dimension 116 | try: 117 | ind = model.predict(i_q, i_k, i_v, i_cand) 118 | except: 119 | continue 120 | if ind == i_a[0]: 121 | cnt += 1 122 | return cnt / len(dev_q) 123 | 124 | model = KVMemoryReader(config) 125 | model.load_embed(config.pre_embed_file) 126 | # here lr is divide by batch size since loss is accumulated 127 | optimizer = optim.SGD(model.parameters(), lr=config.lr) 128 | print "Training setting: lr {0}, batch size {1}".format(config.lr, config.batch_size) 129 | 130 | loss_function = nn.NLLLoss() 131 | 132 | print "{} batch expected".format(len(train_q) * config.epoch / config.batch_size) 133 | train_q, train_key, train_value, train_cand,train_a = modify(train_q, train_w, train_e_p, train_a) 134 | dev_q, dev_key, dev_value, dev_cand, dev_a = modify(dev_q, dev_w, dev_e_p, dev_a) 135 | train(config.epoch) 136 | dump_to_file(model, config.reader_model) 137 | -------------------------------------------------------------------------------- /train_lstm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import division 3 | from Model import PairReader 4 | from config import config 5 | import operator 6 | from utils import * 7 | import pdb 8 | import torch 9 | import torch.autograd as autograd 10 | from torch.autograd import Variable 11 | import torch.nn as nn 12 | from torch.legacy.nn import CosineEmbeddingCriterion 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | 16 | train_q, train_w, train_e_p, train_a = load_from_file("./pkl/reader/100_raw/train_pair.pkl") 17 | dev_q, dev_w, dev_e_p, dev_a = load_from_file("./pkl/reader/100_raw/dev_pair.pkl") 18 | #train_q, train_w, train_e_p, train_a = load_from_file("./pkl/toy/reader/train_pair.pkl") 19 | #dev_q, dev_w, dev_e_p, dev_a = load_from_file("./pkl/toy/reader/dev_pair.pkl") 20 | 21 | def modify(q, wiki, pos, ans): 22 | tL = torch.LongTensor 23 | ret_q = [] 24 | ret_wiki = [] 25 | ret_p = [] 26 | ret_a = [] 27 | for qu,w,p,a_ in zip(q,wiki,pos,ans): 28 | # encoding the candidate 29 | can_dict = {} 30 | qu = qu.numpy() 31 | w = w.numpy() 32 | p = p.numpy() 33 | a_ = a_.numpy() 34 | 35 | len_w = len(w) 36 | cand_ind = [] 37 | 38 | for i in range(len_w): 39 | if w[i][1] not in can_dict: 40 | can_dict[w[i][1]] = len(can_dict) 41 | if w[i][p[i]] not in can_dict: 42 | can_dict[w[i][p[i]]] = len(can_dict) 43 | if a_[0] not in can_dict: 44 | continue 45 | else: 46 | sort_l = sorted(can_dict.items(), key=operator.itemgetter(1)) 47 | cand_l = [x[0] for x in sort_l] 48 | 49 | ret_q.append(tL(qu)) 50 | ret_wiki.append(tL(w)) 51 | ret_p.append(tL(cand_l)) 52 | ret_a.append(tL([can_dict[a_[0]]])) 53 | print len(ret_q) / len(q) 54 | return ret_q, ret_wiki, ret_p, ret_a 55 | 56 | 57 | def train(epoch): 58 | for e_ in range(epoch): 59 | if (e_ + 1) % 10 == 0: 60 | adjust_learning_rate(optimizer, e_) 61 | cnt = 0 62 | loss = Variable(torch.Tensor([0])) 63 | for i_q, i_w, i_e_p, i_a in zip(train_q, train_w, train_e_p, train_a): 64 | cnt += 1 65 | i_q = i_q.unsqueeze(0) # add dimension 66 | probs = model.forward(i_q, i_w, i_e_p) 67 
| i_a = Variable(i_a) 68 | curr_loss = loss_function(probs, i_a) 69 | loss = torch.add(loss, torch.div(curr_loss, config.batch_size)) 70 | 71 | # naive batch implemetation, the lr is divided by batch size 72 | if cnt % config.batch_size == 0: 73 | print "Training loss", loss.data.sum() 74 | loss.backward() 75 | optimizer.step() 76 | loss = Variable(torch.Tensor([0])) 77 | model.zero_grad() 78 | if cnt % config.valid_every == 0: 79 | print "Accuracy:",eval() 80 | 81 | def adjust_learning_rate(optimizer, epoch): 82 | lr = config.lr / (2 ** (epoch // 10)) 83 | print "Adjust lr to ", lr 84 | for param_group in optimizer.param_groups: 85 | param_group['lr'] = lr 86 | 87 | def eval(): 88 | cnt = 0 89 | for i_q, i_w,i_e_p, i_a in zip(dev_q, dev_w, dev_e_p, dev_a): 90 | i_q = i_q.unsqueeze(0) # add dimension 91 | try: 92 | ind = model.predict(i_q, i_w, i_e_p) 93 | except: 94 | continue 95 | if ind == i_a[0]: 96 | cnt += 1 97 | return cnt / len(dev_q) 98 | 99 | model = PairReader(config) 100 | model.load_embed(config.pre_embed_file) 101 | # here lr is divide by batch size since loss is accumulated 102 | optimizer = optim.SGD(model.parameters(), lr=config.lr) 103 | print "Training setting: lr {0}, batch size {1}".format(config.lr, config.batch_size) 104 | 105 | loss_function = nn.NLLLoss() 106 | 107 | print "{} batch expected".format(len(train_q) * config.epoch / config.batch_size) 108 | train_q, train_w, train_e_p, train_a = modify(train_q, train_w, train_e_p, train_a) 109 | dev_q, dev_w, dev_e_p, dev_a = modify(dev_q, dev_w, dev_e_p, dev_a) 110 | train(config.epoch) 111 | dump_to_file(model, config.reader_model) 112 | -------------------------------------------------------------------------------- /train_mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import division 3 | from Model import MLP 4 | from config import config 5 | from utils import * 6 | import pdb 7 | import torch 8 | import torch.autograd as autograd 9 | import torch.nn as nn 10 | from torch.legacy.nn import CosineEmbeddingCriterion 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | 14 | train_q, train_a, train_sim = load_from_file("./pkl/sim/train_pair.pkl") 15 | dev_q, dev_a, dev_sim = load_from_file("./pkl/sim/dev_pair.pkl") 16 | 17 | train_q_iter = batch_sort_iter(train_q, config.batch_size, config.epoch, padding = True) 18 | train_a_iter = batch_sort_iter(train_a, config.batch_size, config.epoch, padding = True, sort=False) 19 | train_sim_iter = batch_sort_iter(train_sim, config.batch_size, config.epoch, padding = False) 20 | 21 | # large dev evaluation will result in memory issue 22 | N = 1000 23 | dev_q_t = to_tensor(dev_q[:N], padding = True) 24 | dev_a_t = to_tensor(dev_a[:N], padding = True, sort=False) 25 | dev_sim_t = torch.LongTensor(dev_sim[:N]) 26 | 27 | model = MLP(config) 28 | optimizer = optim.SGD(model.parameters(), lr=config.lr) 29 | #optimizer = nn.StochasticGradient(mlp, criterion) 30 | #criterion = CosineEmbeddingCriterion(config.margin) 31 | 32 | #pdb.set_trace() 33 | 34 | def train(): 35 | cnt = 0 36 | for i_q, i_a, i_s in zip(train_q_iter, train_a_iter, train_sim_iter): 37 | #pdb.set_trace() 38 | model.zero_grad() 39 | loss = model.forward(i_q, i_a, i_s) 40 | loss.backward() 41 | optimizer.step() 42 | print "Training loss", loss.data.sum() 43 | cnt += 1 44 | if cnt % config.valid_every == 0: 45 | loss = model.forward(dev_q_t, dev_a_t, dev_sim_t) 46 | print "Validation loss", loss.data.sum() 47 | train() 
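# Illustration (toy word ids, hypothetical; not part of the original script) of the
# objective that train() above optimises: MLP.forward wraps its LongTensor inputs in
# Variables, mean-pools the bag-of-words embeddings of each side, and returns the
# cosine-embedding loss against the +1/-1 similarity target.
_toy_q = torch.LongTensor([[3, 7, 11]])    # a question as word ids
_toy_w = torch.LongTensor([[3, 9, 20]])    # a candidate window as word ids
_toy_sim = torch.LongTensor([1])           # 1 = similar pair, -1 = dissimilar
_toy_loss = model.forward(_toy_q, _toy_w, _toy_sim)
print "Toy cosine-embedding loss:", _toy_loss.data.sum()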
48 | model.save(config.pre_embed_file) 49 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cPickle 3 | import torch 4 | from config import config 5 | import os 6 | 7 | def batch_iter(data, batch_size, num_epochs, shuffle=True): 8 | # Generates a batch iterator for a dataset. 9 | data = torch.Tensor(data) 10 | data_size = len(data) 11 | num_batches_per_epoch = int(len(data)/batch_size) + 1 12 | print "{} batches expected".format(num_batches_per_epoch * num_epochs) 13 | for epoch in range(num_epochs): 14 | # Shuffle the data at each epoch 15 | if shuffle: 16 | shuffle_indices = np.random.permutation(np.arange(data_size)) 17 | shuffled_data = data[torch.LongTensor(shuffle_indices)] 18 | else: 19 | shuffled_data = data 20 | for batch_num in range(num_batches_per_epoch): 21 | start_index = batch_num * batch_size 22 | end_index = min((batch_num + 1) * batch_size, data_size) 23 | yield shuffled_data[start_index:end_index] 24 | 25 | def batch_sort_iter(data, batch_size, num_epochs, padding=False, sort=True): 26 | # Batch iterator for variable-length data kept in length-sorted order (no shuffling) 27 | data_size = len(data) 28 | num_batches_per_epoch = int(len(data)/batch_size) + 1 29 | print "{} batches expected".format(num_batches_per_epoch * num_epochs) 30 | for epoch in range(num_epochs): 31 | # Emit the batches in order at each epoch 32 | for batch_num in range(num_batches_per_epoch): 33 | start_index = batch_num * batch_size 34 | end_index = min((batch_num + 1) * batch_size, data_size) 35 | yield to_tensor(data[start_index:end_index], padding = padding, sort = sort) 36 | 37 | # padding and to tensor 38 | def to_tensor(data, padding = False, sort=True): 39 | if padding: 40 | if not sort: return padding_list(data, config.win_len) #TODO: Fix 41 | else: return padding_list(data, len(data[-1])) 42 | else: return torch.LongTensor(data) 43 | 44 | # 2 for UNK 45 | # return LongTensor 46 | def padding_list(l, length, pad=2): 47 | ret_l = [] 48 | for item in l: 49 | if len(item) < length: 50 | pad_item = [pad for _ in range(length - len(item))] 51 | item = list(item) + list(pad_item) 52 | assert len(item) == length 53 | ret_l.append(item) 54 | return torch.LongTensor(ret_l) 55 | 56 | def dump_to_file(obj, filename): 57 | path = os.path.dirname(filename) 58 | if not os.path.exists(path): 59 | print "Warning: file path does not exist, writing to the current dir" 60 | filename = "./" + os.path.basename(filename) 61 | with open(filename, "w") as f: 62 | print "Dumping to:", filename 63 | cPickle.dump(obj, f) 64 | 65 | def load_from_file(filename): 66 | print "Loading: ", filename 67 | with open(filename) as f: 68 | return cPickle.load(f) 69 | 70 | --------------------------------------------------------------------------------
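A quick usage sketch of the padding helpers in utils.py (hypothetical call site; assumes utils.py and config.py are importable from the working directory): shorter sequences are right-padded with the UNK id (2) up to the target length and returned as a LongTensor.

from utils import padding_list, to_tensor

batch = [[4, 5], [7, 8, 9, 10]]            # length-sorted, as batch_sort_iter expects
padded = padding_list(batch, 4)            # LongTensor of shape (2, 4)
# rows: [4, 5, 2, 2] and [7, 8, 9, 10]
same = to_tensor(batch, padding=True, sort=True)   # pads to the length of the last (longest) item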