├── PreProcess.py
├── README.md
├── classification.py
├── gen_w2v.py
└── model.py
/PreProcess.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import pickle
3 | from collections import defaultdict
4 | import logging
5 | #import theano
6 | import gensim
7 | import numpy as np
8 | from random import shuffle
9 | from gensim.models.word2vec import Word2Vec
10 | import codecs
11 | logger = logging.getLogger('relevance_logger')
12 |
13 |
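# Input format: each line of the multi-turn training file is tab-separated as
#   label \t turn_1 \t turn_2 \t ... \t turn_n \t response
# The context turns are concatenated into one message string, delimited by "_t_".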
14 | def build_multiturn_data(trainfile, max_len = 100,isshuffle=False):
15 | revs = []
16 | vocab = defaultdict(float)
17 | total = 1
18 | with codecs.open(trainfile,'r','utf-8') as f:
19 | for line in f:
20 | line = line.replace("_","")
21 | parts = line.strip().split("\t")
22 |
23 | label = parts[0]
24 | message = ""
25 | words = set()
26 | for i in range(1,len(parts)-1,1):
27 | message += "_t_"
28 | message += parts[i]
29 | words.update(set(parts[i].split()))
30 |
31 | response = parts[-1]
32 |
33 | data = {"y" : label, "m":message,"r": response}
34 | revs.append(data)
35 | total += 1
36 | if total % 10000 == 0:
37 | print(total)
38 | #words = set(message.split())
39 | words.update(set(response.split()))
40 |
41 | for word in words:
42 | vocab[word] += 1
43 | logger.info("processed dataset with %d question-answer pairs " %(len(revs)))
44 | logger.info("vocab size: %d" %(len(vocab)))
45 | if isshuffle:
46 | shuffle(revs)
47 | return revs, vocab, max_len
48 |
49 |
50 | def build_data(trainfile, max_len = 20,isshuffle=False):
51 | revs = []
52 | vocab = defaultdict(float)
53 | total = 1
54 | with codecs.open(trainfile,'r','utf-8') as f:
55 | for line in f:
56 | line = line.replace("_","")
57 | parts = line.strip().split("\t")
58 |
59 | topic = parts[0]
60 | topic_r = parts[1]
61 | label = parts[2]
62 | message = parts[-2]
63 | response = parts[-1]
64 |
65 | data = {"y" : label, "m":message,"r": response,"t":topic,"t2":topic_r}
66 | revs.append(data)
67 | total += 1
68 |
69 | words = set(message.split())
70 | words.update(set(response.split()))
71 | for word in words:
72 | vocab[word] += 1
73 | logger.info("processed dataset with %d question-answer pairs " %(len(revs)))
74 | logger.info("vocab size: %d" %(len(vocab)))
75 | if isshuffle:
76 | shuffle(revs)
77 | return revs, vocab, max_len
78 |
79 | class WordVecs(object):
80 | def __init__(self, fname, vocab, binary, gensim):
81 | if gensim:
82 | word_vecs = self.load_gensim(fname,vocab)
83 | self.k = len(list(word_vecs.values())[0])
84 | self.W, self.word_idx_map = self.get_W(word_vecs, k=self.k)
85 |
86 | def get_W(self, word_vecs, k=300):
87 | """
88 | Get word matrix. W[i] is the vector for word indexed by i
89 | """
90 | vocab_size = len(word_vecs)
91 | word_idx_map = dict()
92 | W = np.zeros(shape=(vocab_size+1, k))
93 | W[0] = np.zeros(k)
94 | i = 1
95 | for word in word_vecs:
96 | W[i] = word_vecs[word]
97 | word_idx_map[word] = i
98 | i += 1
99 | return W, word_idx_map
100 |
101 | def load_gensim(self, fname, vocab):
102 | model = Word2Vec.load(fname)
103 | weights = [[0.] * model.vector_size]
104 | word_vecs = {}
105 | total_inside_new_embed = 0
106 | miss= 0
107 | for pair in vocab:
108 | word = gensim.utils.to_unicode(pair)
109 | if word in model:
110 | total_inside_new_embed += 1
111 | word_vecs[pair] = np.array([w for w in model[word]])
112 | #weights.append([w for w in model[word]])
113 | else:
114 | miss = miss + 1
115 | word_vecs[pair] = np.array([0.] * model.vector_size)
116 | #weights.append([0.] * model.vector_size)
117 | print('transferred', total_inside_new_embed, 'words from the embedding file out of', len(vocab), 'candidates')
118 | print('missed by word2vec:', miss)
119 | return word_vecs
120 |
121 | def ParseSingleTurn():
122 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
123 | revs, vocab, max_len = build_data(r"\\msra-sandvm-001\v-wuyu\Data\ubuntu_data\ubuntu_data\train.topic",isshuffle=True)
124 | word2vec = WordVecs(r"\\msra-sandvm-001\v-wuyu\Models\W2V\Ubuntu\word2vec.model", vocab, True, True)
125 | pickle.dump([revs, word2vec, max_len,createtopicvec()], open("ubuntu_data.test",'wb'))
126 | logger.info("dataset created!")
127 |
128 | def ParseMultiTurn():
129 | logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
130 | revs, vocab, max_len = build_multiturn_data(r"/root/桌面/DoubanConversaionCorpus/train.txt",isshuffle=False)
131 | word2vec = WordVecs(r"/root/桌面/DoubanConversaionCorpus/train_vec.model", vocab, True, True)
132 | pickle.dump([revs, word2vec, max_len], open("train_processed",'wb'))
133 | logger.info("dataset created!")
134 |
135 | if __name__=="__main__":
136 | ParseMultiTurn()
137 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MXNET-SMN
2 |
3 | Sequential Matching Network (ACL 2017), reimplemented in MXNet 1.0 with GPU support (CPU also works).
4 |
5 |
6 |
7 |
8 |
9 | # What's SMN
10 |
11 |
12 | A state-of-the-art model for multi-turn retrieval-based conversation systems.
13 |
14 | See the paper: [SMN](http://www.aclweb.org/anthology/P/P17/P17-1046.pdf)
15 |
16 | Please cite it if you use this code:
17 |
18 | Wu, Yu, et al. "Sequential Matching Network: A New Architecture for Multi-turn Response Selection in Retrieval-Based Chatbots." ACL. 2017.
19 |
20 |
21 | SMN is among the strongest models for multi-turn retrieval-based dialogue.
22 |
23 | If you work in this area, do read the paper [SMN](http://www.aclweb.org/anthology/P/P17/P17-1046.pdf); it is one of the few really good papers I have read this year.
24 |
25 | It is candid about where SMN fails, and it categorizes and summarizes the baselines thoroughly, so the ACL quality bar really shows.
26 |
27 |
28 |
29 | # How to use?
30 | **Just tell me how to use it**
31 |
32 |
33 | 1. Get the data from MSRA: [douban corpus](https://1drv.ms/u/s!AtcxwlQuQjw1jF0bjeaKHEUNwitA)
34 |
35 | 2. Pre-train the word2vec model: *python3.5 gen_w2v.py train.txt train_vec.model train_vec.vec*
36 |
37 | 3. Preprocess the data with PreProcess.py (Python 3 only)
38 |
39 | 4. Run model.py to train
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 | Bonus: what each .py file does
51 |
52 | |File|Purpose|
53 | |:------------------------------------:|------------------------------------|
54 | |gen_w2v.py|generates the pre-trained word vectors|
55 | |PreProcess.py|packages the preprocessed data|
56 | |model.py|trains the model|
57 |
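Step 3 writes `train_processed`, a pickle of `[revs, word2vec, max_len]` (see `ParseMultiTurn` in PreProcess.py). A minimal sketch for inspecting it, assuming the file sits in the working directory:

```python
import pickle

# train_processed is produced by PreProcess.ParseMultiTurn()
revs, word2vec, max_len = pickle.load(open("train_processed", "rb"), encoding="utf-8")

print(len(revs))         # number of (context, response, label) examples
print(revs[0].keys())    # dict_keys(['y', 'm', 'r']): label, context turns joined by "_t_", response
print(word2vec.W.shape)  # (vocab_size + 1, embedding_size) matrix consumed by model.py
```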
58 |
59 |
60 |
61 | # Params of Model
62 |
63 |
64 | batch_size = 1000 (on a single Titan Xp)
65 |
66 | embedding_size = 200
67 |
68 | gru_layer = 1
69 |
70 | max_turn = 10
71 |
72 | max_len = 50
73 |
74 | lr = 0.001
75 |
76 | **Warning: changing these parameters is painful, because the code is tightly coupled to them.**
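These values are not command-line flags; they are hard-coded in model.py. A reference sketch of where each one lives in the source:

```python
# model.py (hard-coded, no command-line flags)
batch_size = 1000                  # fits on one Titan Xp
max_word_per_utterence = 50        # tokens kept per utterance (max_len)
max_turn = 10                      # context turns per example
# nn.Embedding(136334, 200)             -> embedding_size = 200
# rnn.GRU(hidden_size=50, layout='NTC') -> one GRU layer
# gluon.Trainer(..., 'adam', {'learning_rate': 0.001})
```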
77 |
78 |
79 |
80 | # Why you use MXNET
81 |
82 |
83 | **Fast!**
84 |
85 | Fast to develop, fast to run, fast to train.
86 |
87 | > *Using MXNet feels like being the Flash sprinting through Central City.*
88 | > *- 杜存宵*
89 |
90 |
91 |
92 | # Other versions?
93 | **I only know that damn TensorFlow. What should I do?**
94 |
95 | [There is a Theano version.](https://github.com/MarkWuNLP/MultiTurnResponseSelection)
96 |
97 | You can learn Theano or MXNet, or implement your own version.
98 |
99 | *Or just give up.*
100 |
101 |
102 |
103 | # Your code is not pythonic!
104 | **Your code looks like shit**
105 |
106 | Sorry to hear that.
107 |
108 | You are welcome to contribute more pythonic code; sorry, this model was put together in about three days of spare time.
109 |
110 |
111 |
112 | # Can I use your code to build a chatbot?
113 |
114 |
115 | Yes, if you can use Lucene.
116 |
117 | For each query, retrieve a candidate list with Lucene, then rank the candidates with SMN.
118 |
119 |
120 |
121 | # How to tip
122 | **You're so handsome, I want to have your babies**
123 |
124 | You can send money through this [website](https://love.alipay.com/donate/index.htm).
125 |
126 | Donations go through Alipay E-Charity.
127 |
128 |
129 |
130 | ## You talk too much
131 |
132 | Sorry, sorry.
133 |
134 | ## TODO:
135 |
136 | 1. Save the model
137 |
138 | 2. Add dropout
139 |
140 | 3. Add evaluation metrics
141 |
142 | 4. Add inference
143 |
144 |
--------------------------------------------------------------------------------
/classification.py:
--------------------------------------------------------------------------------
1 | from fr_loss import ArcLoss
2 | from mxnet import init
3 | import mxnet as mx
4 | from mxnet.gluon import nn
5 | import numpy as np
6 | from mxnet import nd
7 | from mxnet.gluon import rnn
8 | import pickle
9 | from mxnet import gluon
10 | from mxnet import autograd
11 | """
12 | classification: original classification loss, 1v1
13 | """
14 | train_size = 31427
15 | batch_size = 100
16 | ctx = mx.gpu(0)
17 | padding_len = 200
18 |
19 |
20 | train = open('train_class.txt','r')
21 |
22 |
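# Each line of train_class.txt: class_label \t space-separated token ids.
# Sentences are truncated to padding_len tokens and left-padded with zeros.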
23 | def get_single_data_train(raw_data,batch_size):
24 | sent_all,class_id = np.zeros((train_size,padding_len)),[]
25 | t = 0
26 | for line in raw_data:
27 | line = line.strip()
28 | class_label = line.split('\t')[0]
29 | sent = line.split('\t')[1]
30 |
31 | def get_sent(sent):
32 | sent = sent.split()
33 | sent = sent[0:padding_len]
34 | sent = [int(k) for k in sent]
35 | return [0]*(padding_len-len(sent)) + sent
36 | sent_all[t] = get_sent(sent)
37 | class_id.append(class_label)
38 | t=t+1
39 | if(t%10000==0):
40 | print(t)
41 | sent_all = nd.array(sent_all, ctx = mx.cpu())
42 | class_id = nd.array(class_id, ctx = ctx)
43 | print("get data")
44 | train_dataset = gluon.data.ArrayDataset(sent_all,class_id)
45 | train_data_iter = gluon.data.DataLoader(train_dataset, batch_size,last_batch='discard', shuffle=True)
46 | return train_data_iter
47 |
48 |
49 | # for all
50 | def get_single_data_test_all(raw_data,val_size,test_bs):
51 | sent_all,class_id = np.zeros((val_size,padding_len)),[]
52 | t = 0
53 |
54 | for line in raw_data:
55 | line = line.strip()
56 | label = line.split('\t')[0]
57 | sent = line.split('\t')[1]
58 | sent = sent.split()
59 | sent = sent[0:padding_len]
60 | sent = [int(k) for k in sent]
61 |
62 | sent_all[t] = [0]*(padding_len-len(sent)) + sent
63 |
64 | class_id.append(int(label)) # when eval, logits-logits =0 when the logits!=0 , so keep the id>0
65 | t=t+1
66 | if(t%10000==0):
67 | print(t)
68 | sent_all = nd.array(sent_all, ctx = ctx)
69 | class_id = nd.array(class_id, ctx = ctx)
70 | print("get data")
71 | test_dataset = gluon.data.ArrayDataset(sent_all,class_id)
72 | test_data_iter = gluon.data.DataLoader(test_dataset, test_bs, shuffle=False)
73 | return test_data_iter
74 |
75 |
76 |
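# Sentence encoder: embedding -> bidirectional GRU -> padding mask ->
# global max-pooling over time -> L2 normalization. During training the
# normalized encoding is matched against the L2-normalized class centers W
# (cosine logits fed to ArcLoss); at test time the encoding itself is returned.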
77 | class SMN_Last(nn.Block):
78 | def __init__(self,**kwargs):
79 | super(SMN_Last,self).__init__(**kwargs)
80 | with self.name_scope():
81 |
82 | self.emb = nn.Embedding(19224+5,100)
83 | self.gru_1 = rnn.GRU(layout='NTC',hidden_size=100,bidirectional=True)
84 | self.pool = nn.GlobalMaxPool1D()
85 | self.W = self.params.get('param_test',shape=(12877,padding_len))
86 |
87 |
88 | def forward(self,question,train,label):
89 | if(train):
90 | anc = question[:,0:padding_len]
91 | def compute_ques(question):
92 | mask = question.clip(0,1)
93 | mask = mask.reshape(batch_size,-1,1)
94 | question = self.emb(question)
95 | question = self.gru_1(question)
96 | question = mask*question
97 | question = nd.transpose(question,(0,2,1))
98 | question = self.pool(question)
99 | # question = self.pool2(question)
100 | question = question.reshape(batch_size,-1)
101 | question = nd.L2Normalization(question,mode='instance')
102 | id_center = nd.L2Normalization(self.W.data(),mode='instance')
103 | res = nd.dot(question,id_center.T)
104 | return res
105 | anc = compute_ques(anc)
106 |
107 |
108 | #res = nd.dot(question, self.W.data().T)
109 | return anc
110 | else:
111 |
112 | q1 = question[:,0:padding_len]
113 | def compute_ques(question):
114 | mask = question.clip(0,1)
115 | mask = mask.reshape(-1,padding_len,1)
116 | question = self.emb(question)
117 | question = self.gru_1(question)
118 | question = mask*question
119 | question = nd.transpose(question,(0,2,1))
120 | question = self.pool(question)
121 | # question = self.pool2(question)
122 | question = question.reshape(-1,200)
123 | question = nd.L2Normalization(question,mode='instance')
124 | return question
125 | q1 = compute_ques(q1)
126 | return q1
127 |
128 |
129 |
130 | #Train Model
131 | SMN = SMN_Last()
132 | SMN.initialize(ctx=ctx)
133 |
134 |
135 |
136 | train_iter = get_single_data_train(train,batch_size)
137 | max_epoch = 3000
138 |
139 | Sloss = ArcLoss(12877,0.5,64,False)
140 | trainer = gluon.Trainer(SMN.collect_params(), 'adam', {'learning_rate': 0.001})
141 |
142 | val_post = open('mpost.txt','r')
143 | val_resp = open('mresp.txt','r')
144 | top_k = 1
145 | val_post_size = 1800
146 | random_size = 500 # test sample number
147 | val_resp_size = val_post_size*random_size
148 | val_size = val_post_size
149 |
150 |
151 | def test(SMN):
152 |
153 |
154 | for post,post_label in val_post_iter:
155 | post_encoding = SMN(post,False,"place_holder")
156 | post_encoding = post_encoding.reshape((val_post_size,1,-1))
157 | # val_post_size *1* 100
158 | xcount = 0
159 | all_count = 0
160 | for resp,label in val_resp_iter:
161 | res = SMN(resp,False,False) # each row is the prediction for the corresponding line
162 | res = res.reshape((3,random_size,-1))
163 | res = nd.transpose(res,(0,2,1))
164 | res = nd.batch_dot(post_encoding[xcount*3:(xcount+1)*3,].copyto(mx.cpu()),res.copyto(mx.cpu())) # matching scores
165 | res = res.reshape(3,-1)
166 | index = nd.topk(res, ret_typ='indices',k=top_k).reshape(-1,).asnumpy().tolist() # val_size*k,1
167 | index = [label.asnumpy().tolist()[int(indi)] for indi in index]
168 | xcount = xcount + 1
169 | zero_matrix = np.array(index)+1
170 | all_count = all_count + np.sum(zero_matrix==0)
171 | print(xcount)
172 | print("count: " + str(all_count))
173 | print(" percent: " + str(all_count/(val_size*top_k)))
174 |
175 | val_post_iter = get_single_data_test_all(val_post,val_post_size,val_post_size)
176 | val_resp_iter = get_single_data_test_all(val_resp,val_resp_size,val_resp_size//600) # batch size must be an int
177 |
178 | for epoch in range(max_epoch):
179 | train_loss = 0.
180 | count = 0
181 | for question,label in train_iter:
182 | question = question.copyto(ctx)
183 | label = label.copyto(ctx)
184 | with autograd.record():
185 | ques = SMN(question,True,label)
186 | loss = Sloss(ques,label)
187 | count = count + 1
188 | loss.backward()
189 | trainer.step(batch_size)
190 | if(True):
191 | print("loss of epoch "+str(epoch) +" batch "+str(count)+": ")
192 | print(nd.mean(loss).asscalar())
193 | if(epoch%4==3):
194 | test(SMN)
195 |
196 | # if(epoch%30==0):
197 | # test()
198 |
199 | # acc = mx.metric.Accuracy()#(top_k=1000)
200 | # acc10 = mx.metric.TopKAccuracy(top_k=5000)
201 | # acc100 = mx.metric.TopKAccuracy(top_k=100)
202 | # acc1000 = mx.metric.TopKAccuracy(top_k=1000)
203 |
204 | # pos_mask = nd.array(np.arange(10000),ctx=ctx)
205 | # neg_mask = np.flipud(np.arange(10000))
206 | # neg_mask = nd.array(neg_mask,ctx=ctx)
207 |
208 | # print(acc.get())
209 | # print(acc10.get())
210 |
--------------------------------------------------------------------------------
/gen_w2v.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
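# Usage: python gen_w2v.py <input_corpus> <output_model> <output_vectors>
# Trains a 200-dimensional word2vec model on a corpus with one sentence per line.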
4 | import logging
5 | import os
6 | import sys
7 | import multiprocessing
8 |
9 | from gensim.models import Word2Vec
10 | from gensim.models.word2vec import LineSentence
11 |
12 | if __name__ == '__main__':
13 | program = os.path.basename(sys.argv[0])
14 | logger = logging.getLogger(program)
15 |
16 | logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
17 | logging.root.setLevel(level=logging.INFO)
18 | logger.info("running %s" % ' '.join(sys.argv))
19 |
20 | # check and process input arguments
21 | if len(sys.argv) < 4:
22 | print('Usage: python gen_w2v.py <input_corpus> <output_model> <output_vectors>')
23 | sys.exit(1)
24 | inp, outp1, outp2 = sys.argv[1:4]
25 |
26 | model = Word2Vec(LineSentence(inp), size=200, window=5, min_count=5,
27 | workers=multiprocessing.cpu_count())
28 |
29 | # trim unneeded model memory = use(much) less RAM
30 | # model.init_sims(replace=True)
31 | model.save(outp1)
32 | model.wv.save_word2vec_format(outp2, binary=False)
33 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from mxnet import init
2 | import mxnet as mx
3 | from mxnet.gluon import nn
4 | from gensim.models.word2vec import Word2Vec
5 | from PreProcess import WordVecs
6 | import numpy as np
7 | from mxnet import nd
8 | from mxnet.gluon import rnn
9 | import pickle
10 | from mxnet import gluon
11 | from mxnet import autograd
12 | from sklearn.metrics import recall_score as r
13 | batch_size = 1000
14 | ctx = mx.gpu(0)
15 | #136334x200 shape for word embedding
16 | #Load data
17 | max_word_per_utterence = 50
18 | dataset = r"train_processed"
19 | test = r"test_processed"
20 | x = pickle.load(open(dataset,"rb"),encoding='utf-8')
21 | test_x = pickle.load(open(test,"rb"),encoding='utf-8')
22 | revs, wordvecs, max_l = x[0], x[1], x[2]
23 |
24 | max_turn = 10
25 |
26 | def test(Model):
27 | test_iter = get_data(test_x[0],wordvecs.word_idx_map,batch_size,max_l=50,val_size = 1000)
28 | y_pred, y_true = [],[]
29 | for test_data, ground_truth in test_iter:
30 | t_data,t_label = test_data,ground_truth
31 | test_y_hat = Model(t_data)[:,1]
32 | test_label = nd.array(t_label,ctx=ctx)
33 | test_y_hat = test_y_hat.reshape((-1,10))
34 | test_label = test_label.reshape((-1,10))
35 | test_y_hat = nd.topk(test_y_hat, ret_typ='indices', k=1)
36 | test_y_hat = test_y_hat.reshape(-1,)
37 | test_y_hat = nd.one_hot(test_y_hat,10)
38 | test_y_hat = test_y_hat.reshape(-1,).asnumpy().tolist()
39 | test_label = test_label.reshape(-1,).asnumpy().tolist()
40 | y_pred = y_pred+test_y_hat
41 | y_true = y_true+test_label
42 | y_pred = nd.array(y_pred)
43 | y_true = nd.array(y_true)
44 | #macro_recall = r(y_true, y_pred, average='macro')
45 | #micro_recall = r(y_true, y_pred, average='micro')
46 | print(str(nd.mean(y_pred*y_true).asscalar())) # = top-1 hits / (10 * number of candidate groups)
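# Each training example is flattened to 550 token ids:
# the last 10 context turns x 50 ids each (shorter contexts are zero-padded),
# followed by 50 ids for the candidate response.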
47 | def get_idx_from_sent_msg(sents, word_idx_map, max_l=50,mask = False):
48 | """
49 | Transforms sentence into a list of indices. Pad with zeroes.
50 | """
51 | turns = []
52 | for sent in sents.split('_t_'):
53 | x = [0] * max_l
54 | words = sent.split()
55 | length = len(words)
56 | for i, word in enumerate(words):
57 | if max_l - length + i < 0: continue
58 | if word in word_idx_map:
59 | x[max_l - length + i] = word_idx_map[word]
60 |
61 | turns.append(x)
62 |
63 | final = [0.] * (max_l * max_turn)
64 | for i in range(max_turn):
65 | if max_turn - i <= len(turns):
66 | for j in range(max_l ):
67 | final[i*(max_l) + j] = turns[-(max_turn-i)][j]
68 | return final
69 |
70 | def get_idx_from_sent(sent, word_idx_map, max_l=50,mask = False):
71 | """
72 | Transforms sentence into a list of indices. Pad with zeroes.
73 | """
74 | x = [0] * max_l
75 | x_mask = [0.] * max_l
76 | words = sent.split()
77 | length = len(words)
78 | for i, word in enumerate(words):
79 | if max_l - length + i < 0: continue
80 | if word in word_idx_map:
81 | x[max_l - length + i] = word_idx_map[word]
82 |
83 | return x
84 |
85 | def get_data(raw_data,word_idx_map,batch_size,max_l,val_size=0):
86 | X,y,X_val,y_val = np.zeros((999000+1000,550)),[],[],[]
87 | i = 0
88 | t = 0
89 | for rev in raw_data:
90 |
91 |
92 | sent = get_idx_from_sent_msg(rev["m"], word_idx_map, max_l, False)
93 | sent += get_idx_from_sent(rev["r"], word_idx_map, max_l, False)
94 | if(False):
95 | if(i%1000==0):
96 | print(i)
97 | i = i + 1
98 | X_val.append(sent)
99 | y_val.append(int(rev["y"]))
100 | else:
101 | X[t] = sent
102 | t = t + 1
103 | if(t%10000==0):
104 | print(t)
105 | y.append(int(rev["y"]))
106 | X = X[0:t]
107 | y = y[0:t]
108 | X = nd.array(X,ctx =ctx)
109 | y = nd.array(y,ctx = ctx)
110 | print("get data")
111 | # X_val = nd.array(X_val,ctx=ctx)
112 | # y_val = nd.array(y_val,ctx=ctx)
113 | # mx.ndarray.reshape(y_val,shape=(batch_size,1))
114 | train_dataset = gluon.data.ArrayDataset(X, y)
115 | train_data_iter = gluon.data.DataLoader(train_dataset, batch_size, shuffle=False)
116 | # val_dataset = gluon.data.ArrayDataset(X_val,y_val)
117 | # val_data_iter = gluon.data.DataLoader(val_dataset,batch_size,shuffle=True)
118 | return train_data_iter#,val_data_iter
119 |
120 | class SMN_Last(nn.Block):
121 | def __init__(self,**kwargs):
122 | super(SMN_Last,self).__init__(**kwargs)
123 | with self.name_scope():
124 | self.Embed = nn.Embedding(136334,200)
125 | self.conv = nn.Conv2D(channels=8, kernel_size=3, activation='relu')
126 | self.pooling = nn.MaxPool2D(pool_size=3, strides=3)
127 | self.mlp_1 = nn.Dense(units=50,activation='tanh',flatten=True)
128 | self.gru_1 = rnn.GRU(hidden_size=50,layout='NTC')
129 | self.gru_2 = rnn.GRU(layout='NTC',hidden_size=50)
130 | self.mlp_2 = nn.Dense(units=2,flatten=False)
131 | self.W = self.params.get('param_test',shape=(50,50))
132 | def forward(self,x):
133 | u_0,u_1,u_2,u_3,u_4,u_5,u_6,u_7,u_8,u_9,r=x[:,0:50],x[:,50:100],x[:,100:150],x[:,150:200],x[:,200:250],x[:,250:300],x[:,300:350],x[:,350:400],x[:,400:450],x[:,450:500],x[:,500:550]
134 |
135 | u_0 = self.Embed(u_0)
136 | u_1 = self.Embed(u_1)
137 | u_2 = self.Embed(u_2)
138 | u_3 = self.Embed(u_3)
139 | u_4 = self.Embed(u_4)
140 | u_5 = self.Embed(u_5)
141 | u_6 = self.Embed(u_6)
142 | u_7 = self.Embed(u_7)
143 | u_8 = self.Embed(u_8)
144 | u_9 = self.Embed(u_9)
145 | r = self.Embed(r)
146 | h_0 = nd.zeros((1,batch_size,50),ctx = ctx)
147 |
148 | gru_u_0,_ = self.gru_1(u_0, h_0)
149 | gru_u_1,_ = self.gru_1(u_1, h_0)
150 | gru_u_2,_ = self.gru_1(u_2, h_0)
151 | gru_u_3,_ = self.gru_1(u_3, h_0)
152 | gru_u_4,_ = self.gru_1(u_4, h_0)
153 | gru_u_5,_ = self.gru_1(u_5, h_0)
154 | gru_u_6,_ = self.gru_1(u_6, h_0)
155 | gru_u_7,_ = self.gru_1(u_7, h_0)
156 | gru_u_8,_ = self.gru_1(u_8, h_0)
157 | gru_u_9,_ = self.gru_1(u_9, h_0)
158 | gru_r,_ = self.gru_1(r, h_0)
159 |
160 | r_t = mx.nd.transpose(r,axes=(0,2,1))
161 | gru_r_t = mx.nd.transpose(gru_r,axes=(0, 2, 1))
162 |
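# Two matching channels per (utterance, response) pair, as in SMN:
#   M*1 = dot products of word embeddings (utterance tokens x response tokens)
#   M*2 = bilinear similarity of GRU hidden states through the learned matrix W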
163 | M01 = nd.batch_dot(u_0,r_t)
164 | M02 = nd.batch_dot(nd.dot(gru_u_0,self.W.data()),gru_r_t)
165 | M11 = nd.batch_dot(u_1,r_t)
166 | M12 = nd.batch_dot(nd.dot(gru_u_1,self.W.data()),gru_r_t)
167 | M21 = nd.batch_dot(u_2,r_t)
168 | M22 = nd.batch_dot(nd.dot(gru_u_2,self.W.data()),gru_r_t)
169 | M31 = nd.batch_dot(u_3,r_t)
170 | M32 = nd.batch_dot(nd.dot(gru_u_3,self.W.data()),gru_r_t)
171 | M41 = nd.batch_dot(u_4,r_t)
172 | M42 = nd.batch_dot(nd.dot(gru_u_4,self.W.data()),gru_r_t)
173 | M51 = nd.batch_dot(u_5,r_t)
174 | M52 = nd.batch_dot(nd.dot(gru_u_5,self.W.data()),gru_r_t)
175 | M61 = nd.batch_dot(u_6,r_t)
176 | M62 = nd.batch_dot(nd.dot(gru_u_6,self.W.data()),gru_r_t)
177 | M71 = nd.batch_dot(u_7,r_t)
178 | M72 = nd.batch_dot(nd.dot(gru_u_7,self.W.data()),gru_r_t)
179 | M81 = nd.batch_dot(u_8,r_t)
180 | M82 = nd.batch_dot(nd.dot(gru_u_8,self.W.data()),gru_r_t)
181 | M91 = nd.batch_dot(u_9,r_t)
182 | M92 = nd.batch_dot(nd.dot(gru_u_9,self.W.data()),gru_r_t)
183 |
184 | #input to conv layer
185 | M0 = nd.stack(M01, M02, axis=1)
186 | M1 = nd.stack(M11, M12, axis=1)
187 | M2 = nd.stack(M21, M22, axis=1)
188 | M3 = nd.stack(M31, M32, axis=1)
189 | M4 = nd.stack(M41, M42, axis=1)
190 | M5 = nd.stack(M51, M52, axis=1)
191 | M6 = nd.stack(M61, M62, axis=1)
192 | M7 = nd.stack(M71, M72, axis=1)
193 | M8 = nd.stack(M81, M82, axis=1)
194 | M9 = nd.stack(M91, M92, axis=1)
195 |
196 | #output of conv
197 |
198 | conv_out_0 = self.mlp_1(self.pooling(self.conv(M0)))
199 | conv_out_1 = self.mlp_1(self.pooling(self.conv(M1)))
200 | conv_out_2 = self.mlp_1(self.pooling(self.conv(M2)))
201 | conv_out_3 = self.mlp_1(self.pooling(self.conv(M3)))
202 | conv_out_4 = self.mlp_1(self.pooling(self.conv(M4)))
203 | conv_out_5 = self.mlp_1(self.pooling(self.conv(M5)))
204 | conv_out_6 = self.mlp_1(self.pooling(self.conv(M6)))
205 | conv_out_7 = self.mlp_1(self.pooling(self.conv(M7)))
206 | conv_out_8 = self.mlp_1(self.pooling(self.conv(M8)))
207 | conv_out_9 = self.mlp_1(self.pooling(self.conv(M9)))
208 | #TODO:figure out why the conv output is tuple?
209 | #concat as input to gru_2
210 | Concat_conv_out = nd.stack(conv_out_0,conv_out_1,conv_out_2,conv_out_3,
211 | conv_out_4,conv_out_5,conv_out_6,conv_out_7,
212 | conv_out_8,conv_out_9,axis=1)
213 | #output of gru_2
214 | h_1 = nd.zeros((1,batch_size,50),ctx=ctx)
215 | _,gru_out = self.gru_2(Concat_conv_out,h_1)
216 | #output of mlp(yhat)
217 | y_hat = self.mlp_2(gru_out[0])
218 | return y_hat[0]
219 |
220 | #Train Model
221 | SMN = SMN_Last()
222 | SMN.initialize(ctx=mx.gpu())
223 | word2vec = (nd.array(wordvecs.W)).copyto(mx.gpu(0))
224 | #print(word2vec)
225 | class MyInit(init.Initializer):
226 | def __init__(self):
227 | super(MyInit,self).__init__()
228 | self._verbose = True
229 | def _init_weight(self, _,arr):
230 | arr[:] = word2vec # copy the pre-trained embedding matrix into this parameter
231 | #params = SMN.collect_params()
232 | #print(params)
233 | #params['smn_last0_embedding0_weight'].initialize(force_reinit = True,ctx = mx.cpu())
234 | #params['smn_last0_embedding0_weight'].initialize(MyInit(),force_reinit = True,ctx = ctx)
235 | #SMN.Embed.weight.set_data(word2vec)
236 |
237 | train_iter = get_data(revs,wordvecs.word_idx_map,batch_size,max_l=50,val_size = 1000)
238 |
239 | max_epoch = 50
240 |
241 | softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
242 | trainer = gluon.Trainer(SMN.collect_params(), 'adam', {'learning_rate': 0.001})
243 |
244 |
245 |
246 | for epoch in range(max_epoch):
247 | train_loss = 0.
248 | train_acc = 0.
249 | for data, label in train_iter:
250 | with autograd.record():
251 | output = SMN(data)
252 | loss = softmax_cross_entropy(output, nd.array(label,ctx=ctx))
253 | loss.backward()
254 | trainer.step(batch_size)
255 | print("loss:")
256 | print(nd.mean(loss).asscalar())
257 | print("recall:")
258 | test(SMN)
259 | train_loss += nd.mean(loss).asscalar()
260 | train_acc = 0
261 | test_acc = 0
262 | print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
263 | epoch, train_loss/len(train_iter), train_acc/len(train_iter), test_acc))
264 |
--------------------------------------------------------------------------------