├── .gitignore
├── README.md
├── data
│   ├── a.txt
│   ├── b.txt
│   ├── users.txt
│   └── vocab.txt
├── datautils.py
├── lstm.py
├── main.py
└── seq2seq.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | dataset.py
3 | data/process_a.txt
4 | data/process_b.txt
5 | data/test.txt
6 | lstm_test.py
7 | lstm_bucketing.py
8 | bucket_io.py
9 | misc.py
10 | *.iml
11 | test.py
12 | lstm_bak.py
13 | seq2seq_bak.py
14 | *.pickle
15 | *.pyc
16 | .DS_STORE
17 | bucket_io.pyc
18 | data/data.pickle
19 | lstm_copy.pyc
20 | lstm_test.pyc
21 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # mxnet-seq2seq
2 | 
3 | This project implements sequence-to-sequence learning with MXNet for an open-domain chatbot.
4 | 
5 | ## Sequence to Sequence learning with LSTM encoder-decoder
6 | 
7 | The seq2seq encoder-decoder architecture was introduced in [Sequence to Sequence Learning with Neural Networks](http://arxiv.org/abs/1409.3215).
8 | 
9 | This implementation borrows ideas from **lstm_bucketing**; I slightly modified it and reconstructed the embedding layer.
10 | 
11 | ## How to run
12 | 
13 | First, process the data with
14 | ```
15 | python datautils.py
16 | ```
17 | then train the model with
18 | ```
19 | python main.py
20 | ```
21 | 
22 | ## The architecture
23 | 
24 | The **seq2seq encoder-decoder** architecture consists of two RNNs (LSTMs): one encodes the source sequence and the other decodes the target sequence.
25 | 
26 | For NLP-related tasks, the sequence is typically a natural-language sentence. As a result, the encoder and decoder should **share the word embedding layer**.
27 | 
28 | Bucketing is a great way to handle variable sequence lengths. I zero-pad the encoder sequence to a fixed length and build buckets over the decoder sequences.
29 | 
30 | The data is formatted as:
31 | 
32 | ```
33 | 0 0 ... 0 23 12 121 832 || 2 3432 898 7 323
34 | 0 0 ... 0 43 98 233 323 || 7 4423 833 1 232
35 | 0 0 ... 0 32 44 133 555 || 2 4534 545 6 767
36 | ---
37 | 0 0 ... 0 23 12 121 832 || 2 3432 898 7
38 | 0 0 ... 0 23 12 121 832 || 2 3432 898 7
39 | 0 0 ... 0 23 12 121 832 || 2 3432 898 7
40 | ---
41 | 
42 | ```
43 | The input shape of the embedding layer is **(batch\_size, seq\_len)**; the input shape of the LSTM encoder is **(batch\_size, seq\_len, embed\_dim)**.
44 | 
45 | 
46 | 
47 | ## More details coming soon
48 | 
49 | For any questions, please send me an email.
50 | 
51 | yoosan.zhou at gmail dot com
--------------------------------------------------------------------------------
/data/users.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoosan/mxnet-seq2seq/3dad92bf616782f411332cb0248f73d4e43c53e3/data/users.txt
--------------------------------------------------------------------------------
/data/vocab.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yoosan/mxnet-seq2seq/3dad92bf616782f411332cb0248f73d4e43c53e3/data/vocab.txt
--------------------------------------------------------------------------------
/datautils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | 
3 | import sys, nltk
4 | import numpy as np
5 | import mxnet as mx
6 | import pickle
7 | 
8 | 
9 | def Perplexity( label, pred ):
10 |     label = label.T.reshape((-1,))
11 |     loss = 0.
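    # perplexity = exp(average negative log-likelihood of the target tokens);
    # `pred` holds one row of per-token probabilities for every entry of the
    # flattened `label` vector, clipped at 1e-10 before taking the log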
12 | for i in range(pred.shape[0]): 13 | loss += -np.log(max(1e-10, pred[i][int(label[i])])) 14 | return np.exp(loss / label.size) 15 | 16 | 17 | def default_read_content( path ): 18 | with open(path) as ins: 19 | content = ins.read() 20 | return content 21 | 22 | 23 | def default_build_vocab( vocab_path ): 24 | vocab = default_read_content(vocab_path) 25 | vocab = vocab.split('\n') 26 | idx = 1 27 | vocab_rsd = {} 28 | for word in vocab: 29 | vocab_rsd[word] = idx 30 | idx += 1 31 | return vocab, vocab_rsd 32 | 33 | 34 | def default_gen_buckets( sentences, batch_size, the_vocab ): 35 | len_dict = {} 36 | max_len = -1 37 | for sentence in sentences: 38 | words = default_text2id(sentence, the_vocab) 39 | if len(words) == 0: 40 | continue 41 | if len(words) > max_len: 42 | max_len = len(words) 43 | if len(words) in len_dict: 44 | len_dict[len(words)] += 1 45 | else: 46 | len_dict[len(words)] = 1 47 | # print(len_dict) 48 | 49 | tl = 0 50 | buckets = [] 51 | for l, n in len_dict.items(): 52 | if n + tl >= batch_size: 53 | buckets.append(l) 54 | tl = 0 55 | else: 56 | tl += n 57 | if tl > 0: 58 | buckets.append(max_len) 59 | return buckets 60 | 61 | 62 | def default_text2id( sentence, the_vocab, max_len, vocab ): 63 | sentence = sentence.lower() 64 | words = nltk.word_tokenize(sentence) 65 | tokens = [] 66 | for w in words: 67 | if len(w) == 0: continue 68 | if len(tokens) >= max_len: break 69 | if not w in vocab: 70 | tokens.append(the_vocab['']) 71 | else: 72 | tokens.append(the_vocab[w]) 73 | return tokens 74 | 75 | 76 | def default_gen_buckets( len_dict, batch_size ): 77 | tl = 0 78 | buckets = [] 79 | for l, n in len_dict.items(): 80 | if n + tl >= batch_size: 81 | buckets.append(l) 82 | tl = 0 83 | else: 84 | tl += n 85 | return buckets 86 | 87 | 88 | class Seq2SeqIter(mx.io.DataIter): 89 | def __init__( self, data_path, source_path, target_path, vocab, vocab_rsd, batch_size, 90 | max_len, data_name='data', label_name='label', split_char='\n', 91 | text2id=None, read_content=None, model_parallel=False, ctx=mx.cpu() ): 92 | super(Seq2SeqIter, self).__init__() 93 | 94 | self.ctx = ctx 95 | self.iter_data = [] 96 | if data_path is not None: 97 | print 'loading data set' 98 | with open(data_path, 'r') as f: 99 | self.iter_data = pickle.load(f) 100 | self.size = len(self.iter_data) 101 | self.vocab = vocab 102 | self.vocab_rsd = vocab_rsd 103 | self.vocab_size = len(vocab) 104 | self.data_name = data_name 105 | self.label_name = label_name 106 | self.model_parallel = model_parallel 107 | self.batch_size = batch_size 108 | self.max_len = max_len 109 | self.source_path = source_path 110 | self.target_path = target_path 111 | self.split_char = split_char 112 | 113 | if text2id is None: 114 | self.text2id = default_text2id 115 | else: 116 | self.text2id = text2id 117 | if read_content is None: 118 | self.read_content = default_read_content 119 | else: 120 | self.read_content = read_content 121 | 122 | if len(self.iter_data) == 0: 123 | self.iter_data = self.make_data_iter_plan() 124 | 125 | def make_data_iter_plan( self ): 126 | print 'processing the raw data ' 127 | source = self.read_content(self.source_path) 128 | source_lines = source.split(self.split_char) 129 | 130 | target = self.read_content(self.target_path) 131 | target_lines = target.split(self.split_char) 132 | 133 | self.size = len(source_lines) 134 | self.suffer_ids = np.random.permutation(self.size) 135 | 136 | self.enc_inputs = [] 137 | self.dec_inputs = [] 138 | self.dec_targets = [] 139 | len_dict = {} 140 | cnt = 0 141 | for i 
in range(self.size):
142 |             source = source_lines[i]
143 |             target = target_lines[i]
144 |             t1 = source.split('\t\t')
145 |             t2 = target.split('\t\t')
146 |             if len(t1) < 2 or len(t2) < 2: continue
147 |             st, su = t1[0], t1[1]
148 |             tt, tu = t2[0], t2[1]
149 |             dec_input = []
150 |             dec_target = []
151 |             s_tokens = self.text2id(st, self.vocab_rsd, self.max_len, self.vocab)
152 |             t_tokens = self.text2id(tt, self.vocab_rsd, self.max_len, self.vocab)
153 |             self.enc_inputs.append(s_tokens)
154 |             dec_input.append(self.vocab_rsd[''])  # start-of-sequence marker
155 |             dec_input[1:len(t_tokens) + 1] = t_tokens[:]
156 |             self.dec_inputs.append(dec_input)
157 |             dec_target[:len(t_tokens)] = t_tokens[:]
158 |             dec_target.append(self.vocab_rsd[''])  # end-of-sequence marker
159 |             self.dec_targets.append(dec_target)
160 |             if len(dec_input) < 3: continue
161 |             if not len(dec_input) in len_dict.keys():
162 |                 len_dict[len(dec_input)] = 1
163 |             else:
164 |                 len_dict[len(dec_input)] += 1
165 |             cnt += 1
166 |         self.buckets = default_gen_buckets(len_dict, self.batch_size)
167 |         self.len_dict = len_dict
168 |         self.size = cnt
169 |         bucket_n_batches = {}
170 |         for l, n in self.len_dict.items():
171 |             if l < 3:
172 |                 continue
173 |             bucket_n_batches[l] = n / self.batch_size
174 |         # print bucket_n_batches
175 | 
176 |         data_buffer = {}
177 |         for i in range(len(self.dec_inputs)):  # walk over all stored samples, not only the first self.size (= cnt)
178 |             dec_input = self.dec_inputs[i]
179 |             if len(dec_input) < 3:
180 |                 continue
181 |             enc_input = self.enc_inputs[i]
182 |             dec_target = self.dec_targets[i]
183 |             if not len(dec_input) in data_buffer.keys():
184 |                 data_buffer[len(dec_input)] = []
185 |                 data_buffer[len(dec_input)].append({
186 |                     'enc_input': enc_input,
187 |                     'dec_input': dec_input,
188 |                     'dec_target': dec_target
189 |                 })
190 |             else:
191 |                 data_buffer[len(dec_input)].append({
192 |                     'enc_input': enc_input,
193 |                     'dec_input': dec_input,
194 |                     'dec_target': dec_target
195 |                 })
196 | 
197 |         iter_data = []
198 |         for l, n in self.len_dict.items():
199 |             for k in range(0, n, self.batch_size):
200 |                 if k + self.batch_size > n: break  # stop once the bucket has no full batch left
201 |                 encin_batch = np.zeros((self.batch_size, self.max_len))
202 |                 decin_batch = np.zeros((self.batch_size, l))
203 |                 dectr_batch = np.zeros((self.batch_size, l))
204 |                 if n < self.batch_size: break
205 |                 for j in range(self.batch_size):
206 |                     one = data_buffer[l][k + j]  # advance through the bucket instead of reusing the first batch
207 |                     encin = one['enc_input']
208 |                     offset = self.max_len - len(encin)
209 |                     encin_batch[j][offset:] = encin
210 |                     decin_batch[j] = one['dec_input']
211 |                     dectr_batch[j] = one['dec_target']
212 | 
213 |                 iter_data.append({
214 |                     'enc_batch_in': encin_batch,
215 |                     'dec_batch_in': decin_batch,
216 |                     'dec_batch_tr': dectr_batch
217 |                 })
218 |         with open('./data/data.pickle', 'w') as f:
219 |             print 'dumping data ...'
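            # iter_data is a list of dicts of numpy arrays:
            #   'enc_batch_in'  -> (batch_size, max_len), zero-padded on the left
            #   'dec_batch_in' / 'dec_batch_tr' -> (batch_size, bucket_len)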
220 | pickle.dump(iter_data, f) 221 | return iter_data 222 | 223 | def __iter__( self ): 224 | for batch in self.iter_data: 225 | yield batch 226 | 227 | 228 | class SimpleBatch(object): 229 | def __init__( self, data_names, data, label_names, label, bucket_key ): 230 | self.data = data 231 | self.label = label 232 | self.data_names = data_names 233 | self.label_names = label_names 234 | self.bucket_key = bucket_key 235 | 236 | self.pad = 0 237 | self.index = None 238 | 239 | @property 240 | def provide_data( self ): 241 | return [(n, x.shape) for n, x in zip(self.data_names, self.data)] 242 | 243 | @property 244 | def provide_label( self ): 245 | return [(n, x.shape) for n, x in zip(self.label_names, self.label)] 246 | 247 | 248 | if __name__ == '__main__': 249 | vocab, vocab_rsd = default_build_vocab('./data/vocab.txt') 250 | data = Seq2SeqIter(data_path=None, source_path='./data/a.txt', target_path='./data/b.txt', 251 | vocab=vocab, vocab_rsd=vocab_rsd, batch_size=10, max_len=25, 252 | data_name='data', label_name='label', split_char='\n', 253 | text2id=None, read_content=None, model_parallel=False) 254 | for iter in data: 255 | print 'enc input size is (%d, %d), and dec size is (%d, %d)' % \ 256 | (iter['enc_batch_in'].shape[0], iter['enc_batch_in'].shape[1], 257 | iter['dec_batch_in'].shape[0], iter['dec_batch_in'].shape[1]) 258 | -------------------------------------------------------------------------------- /lstm.py: -------------------------------------------------------------------------------- 1 | # pylint:skip-file 2 | import sys 3 | 4 | sys.path.insert(0, "../../python") 5 | import mxnet as mx 6 | import numpy as np 7 | from collections import namedtuple 8 | import time 9 | import math 10 | 11 | LSTMState = namedtuple("LSTMState", ["c", "h"]) 12 | LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias", 13 | "h2h_weight", "h2h_bias"]) 14 | LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol", 15 | "init_states", "last_states", 16 | "seq_data", "seq_labels", "seq_outputs", 17 | "param_blocks"]) 18 | 19 | 20 | def lstm( num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0. ): 21 | """LSTM Cell symbol""" 22 | if dropout > 0.: 23 | indata = mx.sym.Dropout(data=indata, p=dropout) 24 | i2h = mx.sym.FullyConnected(data=indata, 25 | weight=param.i2h_weight, 26 | bias=param.i2h_bias, 27 | num_hidden=num_hidden * 4, 28 | name="t%d_l%d_i2h" % (seqidx, layeridx)) 29 | h2h = mx.sym.FullyConnected(data=prev_state.h, 30 | weight=param.h2h_weight, 31 | bias=param.h2h_bias, 32 | num_hidden=num_hidden * 4, 33 | name="t%d_l%d_h2h" % (seqidx, layeridx)) 34 | gates = i2h + h2h 35 | slice_gates = mx.sym.SliceChannel(gates, num_outputs=4, 36 | name="t%d_l%d_slice" % (seqidx, layeridx)) 37 | in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid") 38 | in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh") 39 | forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid") 40 | out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid") 41 | next_c = (forget_gate * prev_state.c) + (in_gate * in_transform) 42 | next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh") 43 | return LSTMState(c=next_c, h=next_h) 44 | 45 | 46 | # we define a new unrolling function here because the original 47 | # one in lstm_bak.py concats all the labels at the last layer together, 48 | # making the mini-batch size of the label different from the data. 
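# (The unrolling below keeps one hidden state per time step; the label is
# transposed and flattened so that it lines up with the (seq_len * batch)
# rows of the concatenated hidden states fed to the softmax.)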
49 | # I think the existing data-parallelization code needs some modification
50 | # to allow this situation to work properly
51 | def dec_lstm_unroll( num_lstm_layer, seq_len, num_hidden, num_label, dropout=0., is_train=True):
52 |     cls_weight = mx.sym.Variable("cls_weight")
53 |     cls_bias = mx.sym.Variable("cls_bias")
54 |     param_cells = []
55 |     last_states = []
56 |     for i in range(num_lstm_layer):
57 |         param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
58 |                                      i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
59 |                                      h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
60 |                                      h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
61 |         state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
62 |                           h=mx.sym.Variable("l%d_init_h" % i))
63 |         last_states.append(state)
64 |     assert (len(last_states) == num_lstm_layer)
65 | 
66 |     # embedding layer
67 |     data = mx.sym.Variable('data')
68 |     label = mx.sym.Variable('softmax_label')
69 |     wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
70 | 
71 |     hidden_all = []
72 |     for seqidx in range(seq_len):
73 |         hidden = wordvec[seqidx]
74 | 
75 |         # stack LSTM
76 |         for i in range(num_lstm_layer):
77 |             if i == 0:
78 |                 dp_ratio = 0.
79 |             else:
80 |                 dp_ratio = dropout
81 |             next_state = lstm(num_hidden, indata=hidden,
82 |                               prev_state=last_states[i],
83 |                               param=param_cells[i],
84 |                               seqidx=seqidx, layeridx=i, dropout=dp_ratio)
85 |             hidden = next_state.h
86 |             last_states[i] = next_state
87 |         # collect the hidden state for the softmax output layer
88 |         if dropout > 0.:
89 |             hidden = mx.sym.Dropout(data=hidden, p=dropout)
90 |         hidden_all.append(hidden)
91 | 
92 |     hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
93 |     pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
94 |                                  weight=cls_weight, bias=cls_bias, name='pred')
95 | 
96 |     label = mx.sym.transpose(data=label)
97 |     label = mx.sym.Reshape(data=label, target_shape=(0,))
98 |     sm = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
99 | 
100 |     return sm
101 | 
102 | 
103 | # same as dec_lstm_unroll, but without the softmax layer; returns the last hidden state
104 | def enc_lstm_unroll( num_lstm_layer, seq_len, num_hidden, dropout=0. ):
105 |     param_cells = []
106 |     last_states = []
107 |     for i in range(num_lstm_layer):
108 |         param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
109 |                                      i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
110 |                                      h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
111 |                                      h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
112 |         state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
113 |                           h=mx.sym.Variable("l%d_init_h" % i))
114 |         last_states.append(state)
115 |     assert (len(last_states) == num_lstm_layer)
116 | 
117 |     data = mx.sym.Variable('data')
118 |     wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
119 | 
120 |     hidden_all = []
121 |     for seqidx in range(seq_len):
122 |         hidden = wordvec[seqidx]
123 | 
124 |         # stack LSTM
125 |         for i in range(num_lstm_layer):
126 |             if i == 0:
127 |                 dp_ratio = 0.
128 |             else:
129 |                 dp_ratio = dropout
130 |             next_state = lstm(num_hidden, indata=hidden,
131 |                               prev_state=last_states[i],
132 |                               param=param_cells[i],
133 |                               seqidx=seqidx, layeridx=i, dropout=dp_ratio)
134 |             hidden = next_state.h
135 |             last_states[i] = next_state
136 |         # collect the hidden state of the top layer
137 |         if dropout > 0.:
138 |             hidden = mx.sym.Dropout(data=hidden, p=dropout)
139 |         hidden_all.append(hidden)
140 | 
141 |     hidden_concat = mx.sym.Concat(*hidden_all, dim=0)  # unused here; only the final state is returned
142 | 
143 |     return hidden  # hidden state of the last time step, used to initialize the decoder
144 | 
145 | 
146 | def perplexity( label, pred ):
147 |     label = label.T.reshape((-1,))
148 |     loss = 0.
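    # same metric as Perplexity() in datautils.py: exp of the average
    # clipped negative log-likelihood per target token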
149 | for i in range(pred.shape[0]): 150 | loss += -np.log(max(1e-10, pred[i][int(label[i])])) 151 | return np.exp(loss / label.size) 152 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import numpy as np 4 | import mxnet as mx 5 | from datautils import Seq2SeqIter, default_build_vocab 6 | from seq2seq import Seq2Seq 7 | 8 | 9 | CTX = mx.cpu() 10 | 11 | def main(**args): 12 | vocab, vocab_rsd = default_build_vocab('./data/vocab.txt') 13 | vocab_size = len(vocab) 14 | print 'vocabulary size is %d' % vocab_size 15 | data = Seq2SeqIter(data_path='./data/data.pickle', source_path='./data/a.txt', 16 | target_path='./data/b.txt', vocab=vocab, 17 | vocab_rsd=vocab_rsd, batch_size=10, max_len=25, 18 | data_name='data', label_name='label', split_char='\n', 19 | text2id=None, read_content=None, model_parallel=False) 20 | print 'training data size is %d' % data.size 21 | model = Seq2Seq(seq_len=25, batch_size=10, num_layers=1, 22 | input_size=vocab_size, embed_size=150, hidden_size=150, 23 | output_size=vocab_size, dropout=0.0, mx_ctx=CTX) 24 | model.train(dataset=data, epoch=5) 25 | 26 | 27 | if __name__ == "__main__": 28 | main() -------------------------------------------------------------------------------- /seq2seq.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import numpy as np 4 | import mxnet as mx 5 | from tqdm import tqdm 6 | from lstm import enc_lstm_unroll, dec_lstm_unroll 7 | from datautils import Seq2SeqIter, default_build_vocab 8 | from datautils import SimpleBatch, Perplexity, default_text2id 9 | 10 | 11 | class Seq2Seq(object): 12 | def __init__( self, seq_len, batch_size, num_layers, 13 | input_size, embed_size, hidden_size, 14 | output_size, dropout, mx_ctx=mx.cpu() ): 15 | self.embed_dict = {} 16 | self.eval_embed_dict = {} 17 | self.seq_len = seq_len 18 | self.batch_size = batch_size 19 | self.num_layers = num_layers 20 | self.input_size = input_size 21 | self.embed_size = embed_size 22 | self.hidden_size = hidden_size 23 | self.dropout = dropout 24 | self.output_size = output_size 25 | self.ctx = mx_ctx 26 | # for training 27 | self.embed = self.build_embed_dict(self.seq_len + 1) 28 | self.encoder = self.build_lstm_encoder() 29 | self.decoder = self.build_lstm_decoder() 30 | self.init_h = mx.nd.zeros((self.batch_size, self.hidden_size), self.ctx) 31 | self.init_c = mx.nd.zeros((self.batch_size, self.hidden_size), self.ctx) 32 | # for evaluation 33 | # self.eval_embed = self.build_embed_dict(self.seq_len+1, is_train=False) 34 | # self.eval_encoder = self.build_lstm_encoder(is_train=False) 35 | # self.eval_decoder = self.build_lstm_decoder(is_train=False) 36 | 37 | def gen_embed_sym( self ): 38 | data = mx.sym.Variable('data') 39 | embed_weight = mx.sym.Variable("embed_weight") 40 | embed_sym = mx.sym.Embedding(data=data, input_dim=self.input_size, 41 | weight=embed_weight, 42 | output_dim=self.embed_size, name='embed') 43 | return embed_sym 44 | 45 | def build_embed_layer( self, default_bucket, is_train=True, bef_args=None ): 46 | 47 | embed_sym = self.gen_embed_sym() 48 | if is_train: 49 | embed = mx.mod.Module(symbol=embed_sym, data_names=('data',), label_names=None, context=self.ctx) 50 | 51 | embed.bind(data_shapes=[('data', (self.batch_size, default_bucket)), ], for_training=is_train) 52 | 53 | 
embed.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), arg_params=bef_args) 54 | embed.init_optimizer( 55 | optimizer='adam', 56 | optimizer_params={ 57 | 'learning_rate': 0.02, 58 | 'wd': 0., 59 | 'beta1': 0.5, 60 | }) 61 | else: 62 | batch = 1 63 | embed = mx.mod.Module(symbol=embed_sym, data_names=('data',), label_names=None, context=self.ctx) 64 | embed.bind(data_shapes=[('data', (batch, default_bucket)), ], for_training=is_train) 65 | embed.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), arg_params=bef_args) 66 | return embed 67 | 68 | def build_embed_dict( self, default_bucket, is_train=True ): 69 | sym = self.gen_embed_sym() 70 | batch = self.batch_size if is_train else 1 71 | if len(self.embed_dict.keys()) > 1: 72 | default_embed = self.embed_dict[0] 73 | module = mx.mod.Module(symbol=sym, data_names=('data',), label_names=None, context=self.ctx) 74 | module.bind(data_shapes=[('data', (batch, default_bucket))], label_shapes=None, 75 | for_training=is_train, force_rebind=False, 76 | shared_module=default_embed) 77 | else: 78 | default_embed = self.build_embed_layer(default_bucket, is_train=is_train) 79 | 80 | self.embed_dict[default_bucket] = default_embed 81 | 82 | for i in range(1, self.seq_len + 1): 83 | module = mx.mod.Module(symbol=sym, data_names=('data',), label_names=None, context=self.ctx) 84 | module.bind(data_shapes=[('data', (batch, i))], label_shapes=None, 85 | for_training=default_embed.for_training, 86 | inputs_need_grad=default_embed.inputs_need_grad, 87 | force_rebind=False, shared_module=default_embed) 88 | module.borrow_optimizer(default_embed) 89 | self.embed_dict[i] = module 90 | return self.embed_dict 91 | 92 | def build_lstm_encoder( self, is_train=True, bef_args=None ): 93 | enc_lstm_sym = enc_lstm_unroll(num_lstm_layer=self.num_layers, 94 | seq_len=self.seq_len, num_hidden=self.hidden_size) 95 | if is_train: 96 | encoder = mx.mod.Module(symbol=enc_lstm_sym, data_names=('data', 'l0_init_c', 'l0_init_h'), 97 | label_names=None, context=self.ctx) 98 | 99 | encoder.bind(data_shapes=[('data', (self.batch_size, self.seq_len, self.embed_size)), 100 | ('l0_init_c', (self.batch_size, self.hidden_size)), 101 | ('l0_init_h', (self.batch_size, self.hidden_size))], 102 | inputs_need_grad=True, 103 | for_training=is_train) 104 | 105 | encoder.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), arg_params=bef_args) 106 | encoder.init_optimizer( 107 | optimizer='adam', 108 | optimizer_params={ 109 | 'learning_rate': 0.02, 110 | 'wd': 0., 111 | 'beta1': 0.5, 112 | }) 113 | else: 114 | batch = 1 115 | encoder = mx.mod.Module(symbol=enc_lstm_sym, data_names=('data', 'l0_init_c', 'l0_init_h'), 116 | label_names=None, context=self.ctx) 117 | 118 | encoder.bind(data_shapes=[('data', (batch, self.seq_len, self.embed_size)), 119 | ('l0_init_c', (batch, self.hidden_size)), 120 | ('l0_init_h', (batch, self.hidden_size))], 121 | for_training=is_train) 122 | 123 | encoder.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), arg_params=bef_args) 124 | 125 | return encoder 126 | 127 | def build_lstm_decoder( self, is_train=True, bef_args=None ): 128 | def gen_dec_sym( seq_len ): 129 | sym = dec_lstm_unroll(1, seq_len, self.hidden_size, self.input_size, 0., is_train=is_train) 130 | data_names = ['data'] + ['l0_init_c', 'l0_init_h'] 131 | label_names = ['softmax_label'] 132 | return (sym, data_names, label_names) 133 | 134 | if is_train: 135 | decoder = mx.mod.BucketingModule(gen_dec_sym, 
default_bucket_key=self.seq_len + 1, context=self.ctx)
136 |             decoder.bind(data_shapes=[('data', (self.batch_size, self.seq_len + 1, self.embed_size)),
137 |                                       ('l0_init_c', (self.batch_size, self.hidden_size)),
138 |                                       ('l0_init_h', (self.batch_size, self.hidden_size))],
139 |                          label_shapes=[('softmax_label', (self.batch_size, self.seq_len + 1))],
140 |                          inputs_need_grad=True,
141 |                          for_training=is_train, )
142 | 
143 |             decoder.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), arg_params=bef_args)
144 |             decoder.init_optimizer(
145 |                 optimizer='adam',
146 |                 optimizer_params={
147 |                     'learning_rate': 0.02,
148 |                     'wd': 0.,
149 |                     'beta1': 0.5,
150 |                 })
151 |         else:
152 |             batch = 1
153 |             decoder = mx.mod.BucketingModule(gen_dec_sym, default_bucket_key=self.seq_len + 1, context=self.ctx)
154 | 
155 |             decoder.bind(data_shapes=[('data', (batch, self.seq_len + 1, self.embed_size)),
156 |                                       ('l0_init_c', (batch, self.hidden_size)),
157 |                                       ('l0_init_h', (batch, self.hidden_size))],
158 |                          label_shapes=[('softmax_label', (batch, self.seq_len + 1))],
159 |                          for_training=False)
160 | 
161 |             decoder.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), arg_params=bef_args)
162 | 
163 |         return decoder
164 | 
165 |     def train_batch( self, enc_input_batch, dec_input_batch, dec_target_batch, is_train=True ):
166 | 
167 |         self.embed[self.seq_len].forward(mx.io.DataBatch([enc_input_batch], []))
168 |         enc_word_vecs = self.embed[self.seq_len].get_outputs()[0]
169 | 
170 |         self.encoder.forward(mx.io.DataBatch([enc_word_vecs, self.init_c, self.init_h], []))
171 |         enc_last_h = self.encoder.get_outputs()[0]
172 | 
173 |         dec_seq_len = dec_input_batch.shape[1]
174 | 
175 |         self.embed[dec_seq_len].forward(mx.io.DataBatch([dec_input_batch], []))
176 |         dec_word_vecs = self.embed[dec_seq_len].get_outputs()[0]
177 | 
178 |         self.decoder.forward(SimpleBatch(data_names=['data', 'l0_init_c', 'l0_init_h'],
179 |                                          data=[dec_word_vecs, self.init_c, enc_last_h],
180 |                                          label_names=['softmax_label'],
181 |                                          label=[dec_target_batch],
182 |                                          bucket_key=dec_seq_len))
183 |         output = self.decoder.get_outputs()[0]
184 |         ppl = Perplexity(dec_target_batch.asnumpy(), output.asnumpy())
185 |         self.decoder.backward()
186 |         dec_word_vecs_grad = self.decoder.get_input_grads()[0]
187 |         grad_last_h = self.decoder.get_input_grads()[2]  # gradient w.r.t. l0_init_h, i.e. the encoder's last hidden state
188 |         self.decoder.update()
189 |         self.embed_dict[dec_seq_len].backward([dec_word_vecs_grad])
190 |         self.embed_dict[dec_seq_len].update()
191 |         self.encoder.backward([grad_last_h])  # propagate the decoder's gradient back through the encoder
192 |         enc_word_vecs_grad = self.encoder.get_input_grads()[0]
193 |         self.encoder.update()
194 |         self.embed_dict[self.seq_len].backward([enc_word_vecs_grad])
195 |         self.embed_dict[self.seq_len].update()
196 |         return ppl
197 | 
198 |     def train( self, dataset, epoch ):
199 |         for i in range(epoch):
200 |             ppl = 0
201 |             for batch in tqdm(dataset):
202 |                 enc_in = mx.nd.array(batch['enc_batch_in'], self.ctx)
203 |                 dec_in = mx.nd.array(batch['dec_batch_in'], self.ctx)
204 |                 dec_tr = mx.nd.array(batch['dec_batch_tr'], self.ctx)
205 |                 cur_ppl = self.train_batch(enc_input_batch=enc_in,
206 |                                            dec_input_batch=dec_in,
207 |                                            dec_target_batch=dec_tr)
208 |                 ppl = ppl + cur_ppl
209 | 
210 |             print 'epoch %d, avg batch ppl is %f' % (i, ppl / len(dataset.iter_data))
211 | 
212 |     # TODO
213 |     def eval( self, sentence, vocab_rsd, vocab ):
214 |         ids = default_text2id(sentence, vocab_rsd, 15, vocab)
215 |         print ids
216 | 
--------------------------------------------------------------------------------
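
Below is a minimal, self-contained sketch of the batch layout that `Seq2SeqIter` produces (it is not part of the repository; the token ids and the `GO`/`EOS` placeholder ids are made up, and both toy samples share the same decoder length, mirroring how the iterator buckets samples by decoder length):

```
import numpy as np

max_len = 6                                    # fixed encoder length (main.py uses max_len=25; scaled down here)
sources = [[23, 12, 121, 832], [43, 98, 233]]  # toy source token ids
targets = [[3432, 898, 7], [4423, 833, 1]]     # toy target token ids, same length -> same bucket
GO, EOS = 1, 2                                 # placeholders for the start/end markers looked up in vocab_rsd

batch_size = len(sources)
bucket_len = len(targets[0]) + 1               # decoder sequences carry one extra marker token

enc_batch_in = np.zeros((batch_size, max_len))
dec_batch_in = np.zeros((batch_size, bucket_len))
dec_batch_tr = np.zeros((batch_size, bucket_len))

for i, (src, tgt) in enumerate(zip(sources, targets)):
    enc_batch_in[i, max_len - len(src):] = src  # left-pad the encoder input with zeros
    dec_batch_in[i, 0] = GO                     # decoder input: GO followed by the target tokens
    dec_batch_in[i, 1:] = tgt
    dec_batch_tr[i, :len(tgt)] = tgt            # decoder target: the same tokens shifted left, closed by EOS
    dec_batch_tr[i, len(tgt)] = EOS

print enc_batch_in  # (batch_size, max_len), zeros on the left
print dec_batch_in  # (batch_size, bucket_len)
print dec_batch_tr
```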