├── PreProcess.py ├── README.md ├── classification.py ├── gen_w2v.py └── model.py /PreProcess.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import pickle 3 | from collections import defaultdict 4 | import logging 5 | #import theano 6 | import gensim 7 | import numpy as np 8 | from random import shuffle 9 | from gensim.models.word2vec import Word2Vec 10 | import codecs 11 | logger = logging.getLogger('relevance_logger') 12 | 13 | 14 | def build_multiturn_data(trainfile, max_len = 100,isshuffle=False): 15 | revs = [] 16 | vocab = defaultdict(float) 17 | total = 1 18 | with codecs.open(trainfile,'r','utf-8') as f: 19 | for line in f: 20 | line = line.replace("_","") 21 | parts = line.strip().split("\t") 22 | 23 | lable = parts[0] 24 | message = "" 25 | words = set() 26 | for i in range(1,len(parts)-1,1): 27 | message += "_t_" 28 | message += parts[i] 29 | words.update(set(parts[i].split())) 30 | 31 | response = parts[-1] 32 | 33 | data = {"y" : lable, "m":message,"r": response} 34 | revs.append(data) 35 | total += 1 36 | if total % 10000 == 0: 37 | print(total) 38 | #words = set(message.split()) 39 | words.update(set(response.split())) 40 | 41 | for word in words: 42 | vocab[word] += 1 43 | logger.info("processed dataset with %d question-answer pairs " %(len(revs))) 44 | logger.info("vocab size: %d" %(len(vocab))) 45 | if isshuffle == True: 46 | shuffle(revs) 47 | return revs, vocab, max_len 48 | 49 | 50 | def build_data(trainfile, max_len = 20,isshuffle=False): 51 | revs = [] 52 | vocab = defaultdict(float) 53 | total = 1 54 | with codecs.open(trainfile,'r','utf-8') as f: 55 | for line in f: 56 | line = line.replace("_","") 57 | parts = line.strip().split("\t") 58 | 59 | topic = parts[0] 60 | topic_r = parts[1] 61 | lable = parts[2] 62 | message = parts[-2] 63 | response = parts[-1] 64 | 65 | data = {"y" : lable, "m":message,"r": response,"t":topic,"t2":topic_r} 66 | revs.append(data) 67 | total += 1 68 | 69 | words = set(message.split()) 70 | words.update(set(response.split())) 71 | for word in words: 72 | vocab[word] += 1 73 | logger.info("processed dataset with %d question-answer pairs " %(len(revs))) 74 | logger.info("vocab size: %d" %(len(vocab))) 75 | if isshuffle == True: 76 | shuffle(revs) 77 | return revs, vocab, max_len 78 | 79 | class WordVecs(object): 80 | def __init__(self, fname, vocab, binary, gensim): 81 | if gensim: 82 | word_vecs = self.load_gensim(fname,vocab) 83 | self.k = len(list(word_vecs.values())[0]) 84 | self.W, self.word_idx_map = self.get_W(word_vecs, k=self.k) 85 | 86 | def get_W(self, word_vecs, k=300): 87 | """ 88 | Get word matrix. W[i] is the vector for word indexed by i 89 | """ 90 | vocab_size = len(word_vecs) 91 | word_idx_map = dict() 92 | W = np.zeros(shape=(vocab_size+1, k)) 93 | W[0] = np.zeros(k) 94 | i = 1 95 | for word in word_vecs: 96 | W[i] = word_vecs[word] 97 | word_idx_map[word] = i 98 | i += 1 99 | return W, word_idx_map 100 | 101 | def load_gensim(self, fname, vocab): 102 | model = Word2Vec.load(fname) 103 | weights = [[0.] * model.vector_size] 104 | word_vecs = {} 105 | total_inside_new_embed = 0 106 | miss= 0 107 | for pair in vocab: 108 | word = gensim.utils.to_unicode(pair) 109 | if word in model: 110 | total_inside_new_embed += 1 111 | word_vecs[pair] = np.array([w for w in model[word]]) 112 | #weights.append([w for w in model[word]]) 113 | else: 114 | miss = miss + 1 115 | word_vecs[pair] = np.array([0.] * model.vector_size) 116 | #weights.append([0.] 
* model.vector_size) 117 | print('transferred', total_inside_new_embed, 'words from the embedding file, out of', len(vocab), 'candidates') 118 | print('missing from word2vec:', miss) 119 | return word_vecs 120 | 121 | def ParseSingleTurn(): 122 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) 123 | revs, vocab, max_len = build_data(r"\\msra-sandvm-001\v-wuyu\Data\ubuntu_data\ubuntu_data\train.topic",isshuffle=True) 124 | word2vec = WordVecs(r"\\msra-sandvm-001\v-wuyu\Models\W2V\Ubuntu\word2vec.model", vocab, True, True) 125 | pickle.dump([revs, word2vec, max_len,createtopicvec()], open("ubuntu_data.test",'wb')) 126 | logger.info("dataset created!") 127 | 128 | def ParseMultiTurn(): 129 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) 130 | revs, vocab, max_len = build_multiturn_data(r"/root/桌面/DoubanConversaionCorpus/train.txt",isshuffle=False) 131 | word2vec = WordVecs(r"/root/桌面/DoubanConversaionCorpus/train_vec.model", vocab, True, True) 132 | pickle.dump([revs, word2vec, max_len], open("train_processed",'wb')) 133 | logger.info("dataset created!") 134 | 135 | if __name__=="__main__": 136 | ParseMultiTurn() 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MXNET-SMN 2 | 3 | Sequential Matching Network (ACL 2017), reimplemented in MXNET 1.0 for GPU (but you can always run it on CPU). 4 | 5 | 6 | 7 |
8 | 9 | # What's SMN 10 | 11 | 12 | The state-of-the-art model for multi-turn retrieval-based conversation systems. 13 | 14 | You can read the paper here: [SMN](http://www.aclweb.org/anthology/P/P17/P17-1046.pdf) 15 | 16 | Don't forget to cite it if you use this code. 17 | 18 | Wu, Yu, et al. "Sequential Matching Network: A New Architecture for Multi-turn Response Selection in Retrieval-based Chatbots." ACL. 2017. 19 | 20 | 21 | 22 | 23 | If you are serious about this area, go read the [SMN](http://www.aclweb.org/anthology/P/P17/P17-1046.pdf) paper; it is one of the few really good papers I have read this year. 24 | 25 | It is honest about where SMN falls short and gives a complete categorization and summary of the baselines, so the quality of ACL papers really is high. 26 | 27 |
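For readers who want the gist without the paper: for every (utterance, response) pair SMN builds two similarity matrices, one on word embeddings and one on GRU hidden states, stacks them as a 2-channel image for a small CNN, and lets a second GRU accumulate the per-utterance matching vectors into a score. Below is a shape-level sketch in plain numpy; the random arrays stand in for the embeddings and GRU states, and the real implementation is the MXNet code in model.py further down.

```python
# Shape-level sketch of one SMN matching block (numpy stand-in for the MXNet ops in model.py).
import numpy as np

max_len, emb, hidden = 50, 200, 50
U = np.random.randn(max_len, emb)      # one context utterance, embedded
R = np.random.randn(max_len, emb)      # the candidate response, embedded
Hu = np.random.randn(max_len, hidden)  # GRU states of the utterance (stand-in)
Hr = np.random.randn(max_len, hidden)  # GRU states of the response (stand-in)
A = np.random.randn(hidden, hidden)    # learned bilinear matrix (self.W in model.py)

M1 = U @ R.T           # word-word similarity, shape (50, 50)
M2 = (Hu @ A) @ Hr.T   # segment-segment similarity, shape (50, 50)
pair = np.stack([M1, M2])  # the 2-channel "image" fed to the CNN
print(pair.shape)          # (2, 50, 50); CNN + pooling turn this into one matching vector,
                           # and a second GRU accumulates the vectors over the 10 turns.
```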
28 | 29 | # How to use? 30 | 31 | 32 | 33 | 1. Get the data from MSRA: [douban corpus](https://1drv.ms/u/s!AtcxwlQuQjw1jF0bjeaKHEUNwitA) 34 | 35 | 2. Pre-train the word2vec model: *python3.5 gen_w2v.py train.txt train_vec.model train_vec.vec* 36 | 37 | 3. Build the processed data with PreProcess.py (also Python 3 only); see the loading sketch below for what it produces. 38 | 39 | 4. Run the model (model.py). 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | Bonus: what each .py file does 51 | 52 | |Code |Purpose | 53 | |:------------------------------------:|------------------------------------| 54 | |gen_w2v.py|generates the pre-trained word vectors| 55 | |PreProcess.py|packages the data| 56 | |model.py|trains the model| 57 | 58 | 59 |
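To sanity-check step 3: PreProcess.py pickles a list `[revs, wordvecs, max_len]`, which is exactly what model.py loads. A minimal inspection sketch, assuming the preprocessing step has already produced `train_processed` in the working directory:

```python
# Inspect the output of PreProcess.py (assumes train_processed already exists).
import pickle
from PreProcess import WordVecs  # so pickle can rebuild the WordVecs object, as in model.py

with open("train_processed", "rb") as f:
    revs, wordvecs, max_len = pickle.load(f, encoding="utf-8")

print(len(revs), "examples, max_len =", max_len)
print("embedding matrix:", wordvecs.W.shape)  # (vocab_size + 1, 200) with gen_w2v.py defaults
example = revs[0]
print("label:", example["y"])                 # "0" or "1"
print("context:", example["m"][:80])          # utterances joined by the _t_ separator
print("response:", example["r"][:80])
```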
60 | 61 | # Params of Model 62 | 63 | 64 | batch_size = 1000 (with one Titan Xp) 65 | 66 | embedding_size = 200 67 | 68 | gru_layer = 1 69 | 70 | max_turn = 10 71 | 72 | max_len = 50 73 | 74 | lr = 0.001 75 | 76 | **Warning: changing these parameters will be painful, because the code is highly coupled (see the shape arithmetic below).** 77 | 78 |
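To make that warning concrete: each training row that model.py consumes is max_turn * max_len ids of context plus max_len ids of response, i.e. 10 * 50 + 50 = 550 columns, and the forward pass slices x[:, 0:50] ... x[:, 500:550] by hand, so these numbers have to be changed consistently in several places at once. A tiny sketch of the bookkeeping (variable names here are illustrative, not taken from the code):

```python
# Shape bookkeeping that model.py hard-codes; change max_turn/max_len and all of this must move together.
max_turn = 10   # context utterances kept per example
max_len = 50    # tokens per utterance and per response

row_width = max_turn * max_len + max_len          # columns per training row
slices = [(i * max_len, (i + 1) * max_len) for i in range(max_turn + 1)]
print(row_width)   # 550
print(slices)      # (0, 50), (50, 100), ..., (500, 550): the u_0 ... u_9, r slices in forward()
```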
79 | 80 | # Why do you use MXNET? 81 | 82 | 83 | **Fast!** 84 | 85 | Not only fast to develop, but also fast to run and fast to train. 86 | 87 | > *When you use MXNET, you feel like the Flash running through Central City.* 88 | > *- 杜存宵* 89 | 90 |
91 | 92 | # Other versions? 93 | **I only know how to use the damn TensorFlow. What should I do?** 94 | 95 | [There is a Theano version.](https://github.com/MarkWuNLP/MultiTurnResponseSelection) 96 | 97 | You can choose to learn Theano or MXNET, or implement a version yourself. 98 | 99 | *Or just give up.* 100 | 101 |
102 | 103 | # Your code is not pythonic! 104 | **Your code looks like shit** 105 | 106 | Sorry to hear that. 107 | 108 | You are welcome to contribute more pythonic code, but sorry, I only spent three days of spare time building this model. 109 | 110 |
111 | 112 | # Can I use your code to build a chatbot? 113 | 114 | 115 | If you can use Lucene, then yes. 116 | 117 | For each query you need to use Lucene to retrieve a candidate list, and then rank the candidates with SMN (a minimal sketch follows below). 118 | 119 |
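In other words, SMN is only the ranking half of a chatbot: a retrieval step (e.g. Lucene or Elasticsearch) has to propose candidate responses first, and SMN scores them against the conversation context. A minimal retrieve-then-rank sketch; `retrieve_candidates` and `smn_score` are hypothetical placeholders that you would back with your own index and the trained model, they are not part of this repo:

```python
# Retrieve-then-rank sketch; both helpers are hypothetical stand-ins, not repo functions.
def retrieve_candidates(context, k=10):
    # Real system: query a Lucene/Elasticsearch index with the most recent utterance(s).
    return ["candidate response %d" % i for i in range(k)]

def smn_score(context, response):
    # Real system: run the trained SMN and return the matching score for (context, response).
    return -abs(len(response) - len(context[-1]))  # dummy score, keeps the sketch runnable

def reply(context):
    candidates = retrieve_candidates(context)
    return max(candidates, key=lambda cand: smn_score(context, cand))

print(reply(["hi", "how do I install mxnet on ubuntu?"]))
```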
120 | 121 | # How to tip 122 | **You're awesome, I want to have your babies** 123 | 124 | You can send your money through this [website](https://love.alipay.com/donate/index.htm). 125 | 126 | That link goes to Alipay's E-charity (支付宝E公益) donation page. 127 | 128 |
129 | 130 | ## You talk too much nonsense 131 | 132 | Sorry, sorry. 133 | 134 | ## TODO: 135 | 136 | 1. save model 137 | 138 | 2. add dropout 139 | 140 | 3. add evaluations 141 | 142 | 4. add inference 143 | 144 | -------------------------------------------------------------------------------- /classification.py: -------------------------------------------------------------------------------- 1 | from fr_loss import ArcLoss 2 | from mxnet import init 3 | import mxnet as mx 4 | from mxnet.gluon import nn 5 | import numpy as np 6 | from mxnet import nd 7 | from mxnet.gluon import rnn 8 | import pickle 9 | from mxnet import gluon 10 | from mxnet import autograd 11 | """ 12 | classification: original classification loss, 1v1 13 | """ 14 | train_size = 31427 15 | batch_size = 100 16 | ctx = mx.gpu(0) 17 | padding_len = 200 18 | 19 | 20 | train = open('train_class.txt','r') 21 | 22 | 23 | def get_single_data_train(raw_data,batch_size): 24 | sent_all,class_id = np.zeros((train_size,padding_len)),[] 25 | t = 0 26 | for line in raw_data: 27 | line = line.strip() 28 | class_label = line.split('\t')[0] 29 | sent = line.split('\t')[1] 30 | 31 | def get_sent(sent): 32 | sent = sent.split() 33 | sent = sent[0:padding_len] 34 | sent = [int(k) for k in sent] 35 | return [0]*(padding_len-len(sent)) + sent 36 | sent_all[t] = get_sent(sent) 37 | class_id.append(class_label) 38 | t=t+1 39 | if(t%10000==0): 40 | print(t) 41 | sent_all = nd.array(sent_all, ctx = mx.cpu()) 42 | class_id = nd.array(class_id, ctx = ctx) 43 | print("get data") 44 | train_dataset = gluon.data.ArrayDataset(sent_all,class_id) 45 | train_data_iter = gluon.data.DataLoader(train_dataset, batch_size,last_batch='discard', shuffle=True) 46 | return train_data_iter 47 | 48 | 49 | # for all 50 | def get_single_data_test_all(raw_data,val_size,test_bs): 51 | sent_all,class_id = np.zeros((val_size,padding_len)),[] 52 | t = 0 53 | 54 | for line in raw_data: 55 | line = line.strip() 56 | label = line.split('\t')[0] 57 | sent = line.split('\t')[1] 58 | sent = sent.split() 59 | sent = sent[0:padding_len] 60 | sent = [int(k) for k in sent] 61 | 62 | sent_all[t] = [0]*(padding_len-len(sent)) + sent 63 | 64 | class_id.append(int(label)) # when eval, logits-logits =0 when the logits!=0 , so keep the id>0 65 | t=t+1 66 | if(t%10000==0): 67 | print(t) 68 | sent_all = nd.array(sent_all, ctx = ctx) 69 | class_id = nd.array(class_id, ctx = ctx) 70 | print("get data") 71 | test_dataset = gluon.data.ArrayDataset(sent_all,class_id) 72 | test_data_iter = gluon.data.DataLoader(test_dataset, test_bs, shuffle=False) 73 | return test_data_iter 74 | 75 | 76 | 77 | class SMN_Last(nn.Block): 78 | def __init__(self,**kwargs): 79 | super(SMN_Last,self).__init__(**kwargs) 80 | with self.name_scope(): 81 | 82 | self.emb = nn.Embedding(19224+5,100) 83 | self.gru_1 = rnn.GRU(layout='NTC',hidden_size=100,bidirectional=True) 84 | self.pool = nn.GlobalMaxPool1D() 85 | self.W = self.params.get('param_test',shape=(12877,padding_len)) 86 | 87 | 88 | def forward(self,question,train,label): 89 | if(train): 90 | anc = question[:,0:padding_len] 91 | def compute_ques(question): 92 | mask = question.clip(0,1) 93 | mask = mask.reshape(batch_size,-1,1) 94 | question = self.emb(question) 95 | question = self.gru_1(question) 96 | question = mask*question 97 | question = nd.transpose(question,(0,2,1)) 98 | question = self.pool(question) 99 | # question = self.pool2(question) 100 | question = question.reshape(batch_size,-1) 101 | question = nd.L2Normalization(question,mode='instance') 102 | id_center =
nd.L2Normalization(self.W.data(),mode='instance') 103 | res = nd.dot(question,id_center.T) 104 | return res 105 | anc = compute_ques(anc) 106 | 107 | 108 | #res = nd.dot(question, self.W.data().T) 109 | return anc 110 | else: 111 | 112 | q1 = question[:,0:padding_len] 113 | def compute_ques(question): 114 | mask = question.clip(0,1) 115 | mask = mask.reshape(-1,padding_len,1) 116 | question = self.emb(question) 117 | question = self.gru_1(question) 118 | question = mask*question 119 | question = nd.transpose(question,(0,2,1)) 120 | question = self.pool(question) 121 | # question = self.pool2(question) 122 | question = question.reshape(-1,200) 123 | question = nd.L2Normalization(question,mode='instance') 124 | return question 125 | q1 = compute_ques(q1) 126 | return q1 127 | 128 | 129 | 130 | #Train Model 131 | SMN = SMN_Last() 132 | SMN.initialize(ctx=ctx) 133 | 134 | 135 | 136 | train_iter = get_single_data_train(train,batch_size) 137 | max_epoch = 3000 138 | 139 | Sloss = ArcLoss(12877,0.5,64,False) 140 | trainer = gluon.Trainer(SMN.collect_params(), 'adam', {'learning_rate': 0.001}) 141 | 142 | val_post = open('mpost.txt','r') 143 | val_resp = open('mresp.txt','r') 144 | top_k = 1 145 | val_post_size = 1800 146 | random_size = 500 # test sample number 147 | val_resp_size = val_post_size*random_size 148 | val_size = val_post_size 149 | 150 | 151 | def test(SMN): 152 | 153 | 154 | for post,post_label in val_post_iter: 155 | post_encoding = SMN(post,False,"place_holder") 156 | post_encoding = post_encoding.reshape((val_post_size,1,-1)) 157 | # val_post_size *1* 100 158 | xcount = 0 159 | all_count = 0 160 | for resp,label in val_resp_iter: 161 | res = SMN(resp,False,False) # every raw is the predict for the line 162 | res = res.reshape((3,random_size,-1)) 163 | res = nd.transpose(res,(0,2,1)) 164 | res = nd.batch_dot(post_encoding[xcount*3:(xcount+1)*3,].copyto(mx.cpu()),res.copyto(mx.cpu()))# yunsuanjieguo 165 | res = res.reshape(3,-1) 166 | index = nd.topk(res, ret_typ='indices',k=top_k).reshape(-1,).asnumpy().tolist() # val_size*k,1 167 | index = [label.asnumpy().tolist()[int(indi)] for indi in index] 168 | xcount = xcount + 1 169 | zero_matrix = np.array(index)+1 170 | all_count = all_count + np.sum(zero_matrix==0) 171 | print(xcount) 172 | print("count: " + str(all_count)) 173 | print(" percent: " + str(all_count/(val_size*top_k))) 174 | 175 | val_post_iter = get_single_data_test_all(val_post,val_post_size,val_post_size) 176 | val_resp_iter = get_single_data_test_all(val_resp,val_resp_size,val_resp_size/600) 177 | 178 | for epoch in range(max_epoch): 179 | train_loss = 0. 
180 | count = 0 181 | for question,label in train_iter: 182 | question = question.copyto(ctx) 183 | label = label.copyto(ctx) 184 | with autograd.record(): 185 | ques = SMN(question,True,label) 186 | loss = Sloss(ques,label) 187 | count = count + 1 188 | loss.backward() 189 | trainer.step(batch_size) 190 | if(True): 191 | print("loss of epoch "+str(epoch) +" batch "+str(count)+": ") 192 | print(nd.mean(loss).asscalar()) 193 | if(epoch%4==3): 194 | test(SMN) 195 | 196 | # if(epoch%30==0): 197 | # test() 198 | 199 | # acc = mx.metric.Accuracy()#(top_k=1000) 200 | # acc10 = mx.metric.TopKAccuracy(top_k=5000) 201 | # acc100 = mx.metric.TopKAccuracy(top_k=100) 202 | # acc1000 = mx.metric.TopKAccuracy(top_k=1000) 203 | 204 | # pos_mask = nd.array(np.arange(10000),ctx=ctx) 205 | # neg_mask = np.flipud(np.arange(10000)) 206 | # neg_mask = nd.array(neg_mask,ctx=ctx) 207 | 208 | # print(acc.get()) 209 | # print(acc10.get()) 210 | -------------------------------------------------------------------------------- /gen_w2v.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import logging 5 | import os 6 | import sys 7 | import multiprocessing 8 | 9 | from gensim.models import Word2Vec 10 | from gensim.models.word2vec import LineSentence 11 | 12 | if __name__ == '__main__': 13 | program = os.path.basename(sys.argv[0]) 14 | logger = logging.getLogger(program) 15 | 16 | logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s') 17 | logging.root.setLevel(level=logging.INFO) 18 | logger.info("running %s" % ' '.join(sys.argv)) 19 | 20 | # check and process input arguments 21 | if len(sys.argv) < 4: 22 | print(globals()['__doc__'] % locals()) 23 | sys.exit(1) 24 | inp, outp1, outp2 = sys.argv[1:4] 25 | 26 | model = Word2Vec(LineSentence(inp), size=200, window=5, min_count=5, 27 | workers=multiprocessing.cpu_count()) 28 | 29 | # trim unneeded model memory = use(much) less RAM 30 | # model.init_sims(replace=True) 31 | model.save(outp1) 32 | model.wv.save_word2vec_format(outp2, binary=False) 33 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from mxnet import init 2 | import mxnet as mx 3 | from mxnet.gluon import nn 4 | from gensim.models.word2vec import Word2Vec 5 | from PreProcess import WordVecs 6 | import numpy as np 7 | from mxnet import nd 8 | from mxnet.gluon import rnn 9 | import pickle 10 | from mxnet import gluon 11 | from mxnet import autograd 12 | from sklearn.metrics import recall_score as r 13 | batch_size = 1000 14 | ctx = mx.gpu(0) 15 | #136334x200 shape for word embedding 16 | #Load data 17 | max_word_per_utterence = 50 18 | dataset = r"train_processed" 19 | test = r"test_processed" 20 | x = pickle.load(open(dataset,"rb"),encoding='utf-8') 21 | test_x = pickle.load(open(test,"rb"),encoding='utf-8') 22 | revs, wordvecs, max_l = x[0], x[1], x[2] 23 | 24 | max_turn = 10 25 | 26 | def test(Model): 27 | test_iter = get_data(test_x[0],wordvecs.word_idx_map,batch_size,max_l=50,val_size = 1000) 28 | y_pred, y_true = [],[] 29 | for test_data, ground_truth in test_iter: 30 | t_data,t_label = test_data,ground_truth 31 | test_y_hat = SMN(t_data)[:,1] 32 | test_label = nd.array(t_label,ctx=ctx) 33 | test_y_hat = test_y_hat.reshape((-1,10)) 34 | test_label = test_label.reshape((-1,10)) 35 | test_y_hat = nd.topk(test_y_hat, ret_typ='indices', k=1) 36 | test_y_hat = 
test_y_hat.reshape(-1,) 37 | test_y_hat = nd.one_hot(test_y_hat,10) 38 | test_y_hat = test_y_hat.reshape(-1,).asnumpy().tolist() 39 | test_label = test_label.reshape(-1,).asnumpy().tolist() 40 | y_pred = y_pred+test_y_hat 41 | y_true = y_true+test_label 42 | y_pred = nd.array(y_pred) 43 | y_true = nd.array(y_true) 44 | #macro_recall = r(y_true, y_pred, average='macro') 45 | #micro_recall = r(y_true, y_pred, average='micro') 46 | print(str(nd.mean(y_pred*y_true).asscalar())) 47 | def get_idx_from_sent_msg(sents, word_idx_map, max_l=50,mask = False): 48 | """ 49 | Transforms sentence into a list of indices. Pad with zeroes. 50 | """ 51 | turns = [] 52 | for sent in sents.split('_t_'): 53 | x = [0] * max_l 54 | words = sent.split() 55 | length = len(words) 56 | for i, word in enumerate(words): 57 | if max_l - length + i < 0: continue 58 | if word in word_idx_map: 59 | x[max_l - length + i] = word_idx_map[word] 60 | 61 | turns.append(x) 62 | 63 | final = [0.] * (max_l * max_turn) 64 | for i in range(max_turn): 65 | if max_turn - i <= len(turns): 66 | for j in range(max_l ): 67 | final[i*(max_l) + j] = turns[-(max_turn-i)][j] 68 | return final 69 | 70 | def get_idx_from_sent(sent, word_idx_map, max_l=50,mask = False): 71 | """ 72 | Transforms sentence into a list of indices. Pad with zeroes. 73 | """ 74 | x = [0] * max_l 75 | x_mask = [0.] * max_l 76 | words = sent.split() 77 | length = len(words) 78 | for i, word in enumerate(words): 79 | if max_l - length + i < 0: continue 80 | if word in word_idx_map: 81 | x[max_l - length + i] = word_idx_map[word] 82 | 83 | return x 84 | 85 | def get_data(raw_data,word_idx_map,batch_size,max_l,val_size=0): 86 | X,y,X_val,y_val = np.zeros((999000+1000,550)),[],[],[] 87 | i = 0 88 | t = 0 89 | for rev in raw_data: 90 | 91 | 92 | sent = get_idx_from_sent_msg(rev["m"], word_idx_map, max_l, False) 93 | sent += get_idx_from_sent(rev["r"], word_idx_map, max_l, False) 94 | if(False): 95 | if(i%1000==0): 96 | print(i) 97 | i = i + 1 98 | X_val.append(sent) 99 | y_val.append(int(rev["y"])) 100 | else: 101 | X[t] = sent 102 | t = t + 1 103 | if(t%10000==0): 104 | print(t) 105 | y.append(int(rev["y"])) 106 | X = X[0:t] 107 | y = y[0:t] 108 | X = nd.array(X,ctx =ctx) 109 | y = nd.array(y,ctx = ctx) 110 | print("get data") 111 | # X_val = nd.array(X_val,ctx=ctx) 112 | # y_val = nd.array(y_val,ctx=ctx) 113 | # mx.ndarray.reshape(y_val,shape=(batch_size,1)) 114 | train_dataset = gluon.data.ArrayDataset(X, y) 115 | train_data_iter = gluon.data.DataLoader(train_dataset, batch_size, shuffle=False) 116 | # val_dataset = gluon.data.ArrayDataset(X_val,y_val) 117 | # val_data_iter = gluon.data.DataLoader(val_dataset,batch_size,shuffle=True) 118 | return train_data_iter#,val_data_iter 119 | 120 | class SMN_Last(nn.Block): 121 | def __init__(self,**kwargs): 122 | super(SMN_Last,self).__init__(**kwargs) 123 | with self.name_scope(): 124 | self.Embed = nn.Embedding(136334,200) 125 | self.conv = nn.Conv2D(channels=8, kernel_size=3, activation='relu') 126 | self.pooling = nn.MaxPool2D(pool_size=3, strides=3) 127 | self.mlp_1 = nn.Dense(units=50,activation='tanh',flatten=True) 128 | self.gru_1 = rnn.GRU(hidden_size=50,layout='NTC') 129 | self.gru_2 = rnn.GRU(layout='NTC',hidden_size=50) 130 | self.mlp_2 = nn.Dense(units=2,flatten=False) 131 | self.W = self.params.get('param_test',shape=(50,50)) 132 | def forward(self,x): 133 | 
u_0,u_1,u_2,u_3,u_4,u_5,u_6,u_7,u_8,u_9,r=x[:,0:50],x[:,50:100],x[:,100:150],x[:,150:200],x[:,200:250],x[:,250:300],x[:,300:350],x[:,350:400],x[:,400:450],x[:,450:500],x[:,500:550] 134 | 135 | u_0 = self.Embed(u_0) 136 | u_1 = self.Embed(u_1) 137 | u_2 = self.Embed(u_2) 138 | u_3 = self.Embed(u_3) 139 | u_4 = self.Embed(u_4) 140 | u_5 = self.Embed(u_5) 141 | u_6 = self.Embed(u_6) 142 | u_7 = self.Embed(u_7) 143 | u_8 = self.Embed(u_8) 144 | u_9 = self.Embed(u_9) 145 | r = self.Embed(r) 146 | h_0 = nd.zeros((1,batch_size,50),ctx = ctx) 147 | 148 | gru_u_0,_ = self.gru_1(u_0, h_0) 149 | gru_u_1,_ = self.gru_1(u_1, h_0) 150 | gru_u_2,_ = self.gru_1(u_2, h_0) 151 | gru_u_3,_ = self.gru_1(u_3, h_0) 152 | gru_u_4,_ = self.gru_1(u_4, h_0) 153 | gru_u_5,_ = self.gru_1(u_5, h_0) 154 | gru_u_6,_ = self.gru_1(u_6, h_0) 155 | gru_u_7,_ = self.gru_1(u_7, h_0) 156 | gru_u_8,_ = self.gru_1(u_8, h_0) 157 | gru_u_9,_ = self.gru_1(u_9, h_0) 158 | gru_r,_ = self.gru_1(r, h_0) 159 | 160 | r_t = mx.nd.transpose(r,axes=(0,2,1)) 161 | gru_r_t = mx.nd.transpose(gru_r,axes=(0, 2, 1)) 162 | 163 | M01 = nd.batch_dot(u_0,r_t) 164 | M02 = nd.batch_dot(nd.dot(gru_u_0,self.W.data()),gru_r_t) 165 | M11 = nd.batch_dot(u_1,r_t) 166 | M12 = nd.batch_dot(nd.dot(gru_u_1,self.W.data()),gru_r_t) 167 | M21 = nd.batch_dot(u_2,r_t) 168 | M22 = nd.batch_dot(nd.dot(gru_u_2,self.W.data()),gru_r_t) 169 | M31 = nd.batch_dot(u_3,r_t) 170 | M32 = nd.batch_dot(nd.dot(gru_u_3,self.W.data()),gru_r_t) 171 | M41 = nd.batch_dot(u_4,r_t) 172 | M42 = nd.batch_dot(nd.dot(gru_u_4,self.W.data()),gru_r_t) 173 | M51 = nd.batch_dot(u_5,r_t) 174 | M52 = nd.batch_dot(nd.dot(gru_u_5,self.W.data()),gru_r_t) 175 | M61 = nd.batch_dot(u_6,r_t) 176 | M62 = nd.batch_dot(nd.dot(gru_u_6,self.W.data()),gru_r_t) 177 | M71 = nd.batch_dot(u_7,r_t) 178 | M72 = nd.batch_dot(nd.dot(gru_u_7,self.W.data()),gru_r_t) 179 | M81 = nd.batch_dot(u_8,r_t) 180 | M82 = nd.batch_dot(nd.dot(gru_u_8,self.W.data()),gru_r_t) 181 | M91 = nd.batch_dot(u_9,r_t) 182 | M92 = nd.batch_dot(nd.dot(gru_u_9,self.W.data()),gru_r_t) 183 | 184 | #input to conv layer 185 | M0 = nd.stack(M01, M02, axis=1) 186 | M1 = nd.stack(M11, M12, axis=1) 187 | M2 = nd.stack(M21, M22, axis=1) 188 | M3 = nd.stack(M31, M32, axis=1) 189 | M4 = nd.stack(M41, M42, axis=1) 190 | M5 = nd.stack(M51, M52, axis=1) 191 | M6 = nd.stack(M61, M62, axis=1) 192 | M7 = nd.stack(M71, M72, axis=1) 193 | M8 = nd.stack(M81, M82, axis=1) 194 | M9 = nd.stack(M91, M92, axis=1) 195 | 196 | #output of conv 197 | 198 | conv_out_0 = self.mlp_1(self.pooling(self.conv(M0))) 199 | conv_out_1 = self.mlp_1(self.pooling(self.conv(M1))) 200 | conv_out_2 = self.mlp_1(self.pooling(self.conv(M2))) 201 | conv_out_3 = self.mlp_1(self.pooling(self.conv(M3))) 202 | conv_out_4 = self.mlp_1(self.pooling(self.conv(M4))) 203 | conv_out_5 = self.mlp_1(self.pooling(self.conv(M5))) 204 | conv_out_6 = self.mlp_1(self.pooling(self.conv(M6))) 205 | conv_out_7 = self.mlp_1(self.pooling(self.conv(M7))) 206 | conv_out_8 = self.mlp_1(self.pooling(self.conv(M8))) 207 | conv_out_9 = self.mlp_1(self.pooling(self.conv(M9))) 208 | #TODO:figure out why the conv output is tuple? 
209 | #concat as input to gru_2 210 | Concat_conv_out = nd.stack(conv_out_0,conv_out_1,conv_out_2,conv_out_3, 211 | conv_out_4,conv_out_5,conv_out_6,conv_out_7, 212 | conv_out_8,conv_out_9,axis=1) 213 | #output of gru_2 214 | h_1 = nd.zeros((1,batch_size,50),ctx=ctx) 215 | _,gru_out = self.gru_2(Concat_conv_out,h_1) 216 | #output of mlp(yhat) 217 | y_hat = self.mlp_2(gru_out[0]) 218 | return y_hat[0] 219 | 220 | #Train Model 221 | SMN = SMN_Last() 222 | SMN.initialize(ctx=mx.gpu()) 223 | word2vec = (nd.array(wordvecs.W)).copyto(mx.gpu(0)) 224 | #print(word2vec) 225 | class MyInit(init.Initializer): 226 | def __init__(self): 227 | super(MyInit,self).__init__() 228 | self._verbose = True 229 | def _init_weight(self, _,arr): 230 | word2vec 231 | #params = SMN.collect_params() 232 | #print(params) 233 | #params['smn_last0_embedding0_weight'].initialize(force_reinit = True,ctx = mx.cpu()) 234 | #params['smn_last0_embedding0_weight'].initialize(MyInit(),force_reinit = True,ctx = ctx) 235 | #SMN.Embed.weight.set_data(word2vec) 236 | 237 | train_iter = get_data(revs,wordvecs.word_idx_map,batch_size,max_l=50,val_size = 1000) 238 | 239 | max_epoch = 50 240 | 241 | softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss() 242 | trainer = gluon.Trainer(SMN.collect_params(), 'adam', {'learning_rate': 0.001}) 243 | 244 | 245 | 246 | for epoch in range(max_epoch): 247 | train_loss = 0. 248 | train_acc = 0. 249 | for data, label in train_iter: 250 | with autograd.record(): 251 | output = SMN(data) 252 | loss = softmax_cross_entropy(output, nd.array(label,ctx=ctx)) 253 | loss.backward() 254 | trainer.step(batch_size) 255 | print("loss:") 256 | print(nd.mean(loss).asscalar()) 257 | print("recall:") 258 | test(SMN) 259 | train_loss += nd.mean(loss).asscalar() 260 | train_acc = 0 261 | test_acc = 0 262 | print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % ( 263 | epoch, train_loss/len(train_iter), train_acc/len(train_iter), test_acc)) 264 | --------------------------------------------------------------------------------