├── 1.png ├── Model.py ├── README.md ├── TIGS_Inferece.ipynb ├── TIGS_train.ipynb ├── Util ├── DiverseDecode.py ├── GSutil.py ├── __pycache__ │ ├── bleu.cpython-35.pyc │ ├── myAttWrapper.cpython-35.pyc │ ├── myResidualCell.cpython-35.pyc │ ├── myUtil.cpython-35.pyc │ └── my_helper.cpython-35.pyc ├── bleu.py ├── myAttLM.py ├── myAttLM_Diverse.py ├── myAttWrapper.py ├── myAttmoLM.py ├── myAttoLM.py ├── myLM.py ├── myResidualCell.py ├── myUtil.py ├── my_helper.py ├── my_seq2seq.py └── my_seq2seq_Diverse.py ├── bleu.py ├── overview_2.pdf └── results └── _URNN-f_res.pkl /1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dayihengliu/Text-Infilling-Gradient-Search/af89b634fa9d74222d29ed5ef0b91da46533dd62/1.png -------------------------------------------------------------------------------- /Model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import copy 4 | import itertools 5 | import random 6 | import pickle as cPickle 7 | import matplotlib.pyplot as plt 8 | 9 | import tensorflow as tf 10 | from tensorflow.python.layers import core as core_layers 11 | 12 | 13 | import os 14 | from Util.myAttWrapper import SelfAttWrapper 15 | from Util import myResidualCell 16 | from Util.bleu import BLEU 17 | from Util.myUtil import * 18 | from Util import my_helper 19 | 20 | 21 | class LM: 22 | def __init__(self, dp, rnn_size, n_layers, decoder_embedding_dim, max_infer_length, is_jieba, 23 | sess, qid_list, close_loss_rate=0.0, l2_reg_lambda=0.015, l1_reg_lambda=0.0, att_type='B', lr=0.001, grad_clip=5.0, beam_width=5, force_teaching_ratio=1.0, beam_penalty=1.0, 24 | residual=False, output_keep_prob=0.5, input_keep_prob=0.9, cell_type='lstm', reverse=False, is_save=True, 25 | decay_scheme='luong234'): 26 | 27 | self.rnn_size = rnn_size 28 | self.n_layers = n_layers 29 | self.is_jieba = is_jieba 30 | self.grad_clip = grad_clip 31 | self.dp = dp 32 | self.qid_list = qid_list 33 | self.l2_reg_lambda = l2_reg_lambda 34 | self.l1_reg_lambda = l1_reg_lambda 35 | self.decoder_embedding_dim = decoder_embedding_dim 36 | self.beam_width = beam_width 37 | self.beam_penalty = beam_penalty 38 | self.max_infer_length = max_infer_length 39 | self.residual = residual 40 | self.decay_scheme = decay_scheme 41 | if self.residual: 42 | assert decoder_embedding_dim == rnn_size 43 | self.reverse = reverse 44 | self.cell_type = cell_type 45 | self.force_teaching_ratio = force_teaching_ratio 46 | self._output_keep_prob = output_keep_prob 47 | self._input_keep_prob = input_keep_prob 48 | self.is_save = is_save 49 | self.sess = sess 50 | self.att_type = att_type 51 | self.lr=lr 52 | self.close_loss_rate = close_loss_rate 53 | self.build_graph() 54 | self.sess.run(tf.global_variables_initializer()) 55 | self.opt_var = [x for x in tf.global_variables() if 'Ftrl' in x.name or 'Momentum' in x.name] 56 | #print(len(tf.global_variables()), len([x for x in tf.global_variables() if x != self.extra_embedding])) 57 | self.saver = tf.train.Saver([x for x in tf.trainable_variables() if x not in self.extra_embedding_list], max_to_keep = 15) 58 | self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary() 59 | 60 | # end constructor 61 | 62 | def build_graph(self): 63 | self.register_symbols() 64 | self.add_input_layer() 65 | with tf.variable_scope('decode'): 66 | self.add_decoder_for_training() 67 | with tf.variable_scope('decode', reuse=True): 68 | 
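# The training decoder built above creates the shared 'decode' variables; the
# prefix-inference, sample, and prefix-sample decoders below re-enter the same
# scope with reuse=True, so all four decoding heads share one set of weights.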
self.add_decoder_for_prefix_inference() 69 | with tf.variable_scope('decode', reuse=True): 70 | self.add_decoder_for_sample() 71 | with tf.variable_scope('decode', reuse=True): 72 | self.add_decoder_for_prefix_sample() 73 | self.build_embop() 74 | self.build_l2_distance() 75 | self.build_projection() 76 | self.add_assign() 77 | self.build_nearst() 78 | self.add_backward_path() 79 | # end method 80 | 81 | def add_assign(self): 82 | self.assgin_placeholder_list = [] 83 | for i in range(len(self.qid_list)): 84 | self.assgin_placeholder_list.append(tf.placeholder(tf.float32, [1, self.decoder_embedding_dim], name='assgin_placeholder_%d' % i)) 85 | self.assign_op_list = [] 86 | for i in range(len(self.qid_list)): 87 | self.assign_op_list.append(self.extra_embedding_list[i].assign(self.assgin_placeholder_list[i])) 88 | 89 | def add_input_layer(self): 90 | self.X = tf.placeholder(tf.int32, [None, None], name="X") 91 | self.Y = tf.placeholder(tf.int32, [None, None], name="Y") 92 | self.X_seq_len = tf.placeholder(tf.int32, [None], name="X_seq_len") 93 | self.Y_seq_len = tf.placeholder(tf.int32, [None], name="Y_seq_len") 94 | self.input_keep_prob = tf.placeholder(tf.float32,name="input_keep_prob") 95 | self.output_keep_prob = tf.placeholder(tf.float32,name="output_keep_prob") 96 | self.batch_size = tf.shape(self.X)[0] 97 | self.init_memory = tf.zeros([self.batch_size, 1, self.rnn_size]) 98 | self.init_attention = tf.zeros([self.batch_size, self.rnn_size]) 99 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 100 | # end method 101 | 102 | def single_cell(self, reuse=False): 103 | if self.cell_type == 'lstm': 104 | cell = tf.contrib.rnn.LayerNormBasicLSTMCell(self.rnn_size, reuse=reuse) 105 | else: 106 | cell = tf.contrib.rnn.GRUBlockCell(self.rnn_size) 107 | cell = tf.contrib.rnn.DropoutWrapper(cell, self.output_keep_prob, self.input_keep_prob) 108 | if self.residual: 109 | cell = myResidualCell.ResidualWrapper(cell) 110 | return cell 111 | 112 | def processed_decoder_input(self): 113 | main = tf.strided_slice(self.X, [0, 0], [self.batch_size, -1], [1, 1]) # remove last char 114 | decoder_input = tf.concat([tf.fill([self.batch_size, 1], self._x_go), main], 1) 115 | return decoder_input 116 | 117 | def add_decoder_for_training(self): 118 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 119 | self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention, self.init_memory, att_layer = core_layers.Dense(self.rnn_size, name='att_dense'), att_type=self.att_type) 120 | self.decoder_embedding = tf.get_variable('word_embedding', [len(self.dp.X_w2id), self.decoder_embedding_dim], 121 | tf.float32, tf.random_uniform_initializer(-1.0, 1.0)) 122 | #print(decoder_embedding) 123 | #print(decoder_embedding[:1], decoder_embedding[2:]) 124 | self.extra_embedding_list = [] 125 | for i in self.qid_list: 126 | self.extra_embedding_list.append(tf.get_variable('extra_embedding_%d' % i, [1, self.decoder_embedding_dim], 127 | tf.float32, tf.random_uniform_initializer(-1.0, 1.0))) 128 | #print(self.extra_embedding_list) 129 | self.extra_embeddings = tf.concat(self.extra_embedding_list, axis=0) 130 | for i,extra_embedding in enumerate(self.extra_embedding_list): 131 | self.decoder_embedding = tf.concat([self.decoder_embedding[:self.qid_list[i]], extra_embedding, self.decoder_embedding[self.qid_list[i]+1:]], axis=0) 132 | #print(self.decoder_embedding) 133 | training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper( 134 | 
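# ScheduledEmbeddingTrainingHelper mixes teacher forcing with scheduled
# sampling: at each step it feeds the ground-truth embedding with probability
# force_teaching_ratio and an embedding of the model's own sample otherwise.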
inputs = tf.nn.embedding_lookup(self.decoder_embedding, self.processed_decoder_input()), 135 | sequence_length = self.X_seq_len, 136 | embedding = self.decoder_embedding, 137 | sampling_probability = 1 - self.force_teaching_ratio, 138 | time_major = False) 139 | training_decoder = tf.contrib.seq2seq.BasicDecoder( 140 | cell = self.decoder_cell, 141 | helper = training_helper, 142 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32), #.clone(cell_state=self.encoder_state), 143 | output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense')) 144 | training_decoder_output, training_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 145 | decoder = training_decoder, 146 | impute_finished = True, 147 | maximum_iterations = tf.reduce_max(self.X_seq_len)) 148 | self.training_logits = training_decoder_output.rnn_output 149 | self.init_prefix_state = training_final_state 150 | self.output_prob = tf.nn.softmax(self.training_logits, -1) 151 | 152 | def add_decoder_for_sample(self): 153 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 154 | self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention, self.init_memory, att_layer = core_layers.Dense(self.rnn_size, name='att_dense', _reuse=True), att_type=self.att_type) 155 | word_embedding = tf.get_variable('word_embedding') 156 | sample_helper = tf.contrib.seq2seq.SampleEmbeddingHelper( 157 | embedding= word_embedding, 158 | start_tokens = tf.tile(tf.constant([self._x_go], dtype=tf.int32), [self.batch_size]), 159 | end_token = self._x_eos) 160 | sample_decoder = tf.contrib.seq2seq.BasicDecoder( 161 | cell = self.decoder_cell, 162 | helper = sample_helper, 163 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32),#.clone(cell_state=self.encoder_state), 164 | output_layer = core_layers.Dense(len(self.dp.X_w2id),name='output_dense', _reuse=True)) 165 | sample_decoder_output, self.sample_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 166 | decoder = sample_decoder, 167 | impute_finished = False, 168 | maximum_iterations = self.max_infer_length) 169 | self.sample_output = sample_decoder_output.sample_id 170 | 171 | def add_decoder_for_prefix_sample(self): 172 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 173 | self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention, self.init_memory, att_layer = core_layers.Dense(self.rnn_size, name='att_dense', _reuse=True), att_type=self.att_type) 174 | word_embedding = tf.get_variable('word_embedding') 175 | prefix_sample_helper = my_helper.MyHelper( 176 | inputs = self.processed_decoder_input(), 177 | sequence_length = self.X_seq_len, 178 | embedding= word_embedding, 179 | end_token = self._x_eos) 180 | sample_prefix_decoder = tf.contrib.seq2seq.BasicDecoder( 181 | cell = self.decoder_cell, 182 | helper = prefix_sample_helper, 183 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32),#.clone(cell_state=self.encoder_state), 184 | output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense', _reuse=True)) 185 | sample_decoder_prefix_output, self.sample_prefix_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 186 | decoder = sample_prefix_decoder, 187 | impute_finished = False, 188 | maximum_iterations = self.max_infer_length) 189 | self.sample_prefix_output = sample_decoder_prefix_output.sample_id 190 | 191 | def add_decoder_for_prefix_inference(self): 192 | self.decoder_cell 
= tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 193 | self.init_attention_tiled = tf.contrib.seq2seq.tile_batch(self.init_attention, self.beam_width) 194 | self.init_memory_tiled = tf.contrib.seq2seq.tile_batch(self.init_memory, self.beam_width) 195 | 196 | self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention_tiled, self.init_memory_tiled, att_layer = core_layers.Dense(self.rnn_size, name='att_dense', _reuse=True),att_type=self.att_type) 197 | self.beam_init_state = tf.contrib.seq2seq.tile_batch(self.init_prefix_state, self.beam_width) 198 | my_decoder = tf.contrib.seq2seq.BeamSearchDecoder( 199 | cell = self.decoder_cell, 200 | embedding = self.decoder_embedding, 201 | start_tokens = tf.tile(tf.constant([self._x_go], dtype=tf.int32), [self.batch_size]), 202 | end_token = self._x_eos, 203 | initial_state = self.beam_init_state, 204 | beam_width = self.beam_width, 205 | output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense', _reuse=True), 206 | length_penalty_weight = self.beam_penalty) 207 | 208 | self.prefix_go = tf.placeholder(tf.int32, [None]) 209 | prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1), [1, self.beam_width]) 210 | prefix_emb = tf.nn.embedding_lookup(self.decoder_embedding, prefix_go_beam) 211 | my_decoder._start_inputs = prefix_emb 212 | predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode( 213 | decoder = my_decoder, 214 | impute_finished = False, 215 | maximum_iterations = self.max_infer_length) 216 | self.prefix_infer_outputs = predicting_decoder_output.predicted_ids 217 | self.score = predicting_decoder_output.beam_search_decoder_output.scores 218 | 219 | 220 | 221 | def add_backward_path(self): 222 | masks = tf.sequence_mask(self.X_seq_len, tf.reduce_max(self.X_seq_len), dtype=tf.float32) 223 | self.loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 224 | targets = self.X, 225 | weights = masks) 226 | 227 | l1_regularizer = tf.contrib.layers.l1_regularizer(self.l1_reg_lambda) 228 | l2_regularizer = tf.contrib.layers.l2_regularizer(self.l2_reg_lambda) 229 | self.l1_loss = tf.contrib.layers.apply_regularization(l1_regularizer, self.extra_embedding_list) 230 | self.l2_loss = tf.contrib.layers.apply_regularization(l2_regularizer, self.extra_embedding_list) 231 | #print(tf.norm(self.nearest_emb_placeholder-self.extra_embeddings, ord='euclidean')) 232 | self.close_loss = tf.reduce_sum(tf.norm(self.nearest_emb_placeholder-self.extra_embeddings, ord='euclidean')) * self.close_loss_rate 233 | if self.close_loss_rate > 0.0: 234 | self.update_loss = self.close_loss + self.l2_loss + self.l1_loss + tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 235 | targets = self.Y, 236 | weights = masks) 237 | else: 238 | self.update_loss = self.l2_loss + self.l1_loss + tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 239 | targets = self.Y, 240 | weights = masks) 241 | self.update_batch_loss = self.l2_loss + self.l1_loss + tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 242 | targets = self.Y, 243 | weights = masks, 244 | average_across_batch=False) 245 | 246 | self.batch_loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 247 | targets = self.X, 248 | weights = masks, 249 | average_across_batch=False) 250 | self.time_batch_loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 251 | targets = self.X, 252 | weights = masks, 253 | average_across_batch=False, 254 | 
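# with both averaging flags off, sequence_loss returns a per-example,
# per-timestep [batch, time] loss matrix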
average_across_timesteps=False) 255 | self.time_loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 256 | targets = self.X, 257 | weights = masks, 258 | average_across_timesteps=False) 259 | params = tf.trainable_variables() 260 | gradients = tf.gradients(self.loss, params) 261 | 262 | update_params = self.extra_embedding_list 263 | self.Dgrad = tf.gradients(self.update_loss, update_params) 264 | self.Dgrad_list = [tf.gradients(self.update_loss, [update_params[k]]) for k in range(len(update_params))] 265 | clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.grad_clip) 266 | self.lbfgs_op = tf.contrib.opt.ScipyOptimizerInterface(self.update_loss, var_list=update_params, method='L-BFGS-B', options={'maxiter': 100,'disp': 0}) 267 | #update_clipped_gradients, _ = tf.clip_by_global_norm(self.Dgrad, 5.0) 268 | #print(self.lbfgs_op) 269 | 270 | self.learning_rate = tf.constant(self.lr) 271 | self.learning_rate = self.get_learning_rate_decay(self.decay_scheme) # decay 272 | self.train_op = tf.train.AdamOptimizer(self.learning_rate).apply_gradients(zip(clipped_gradients, params), global_step=self.global_step) 273 | self.update_op = dict() 274 | lr = self.lr 275 | self.update_op['Adam'] = tf.train.AdamOptimizer(lr).apply_gradients(zip(self.Dgrad, update_params)) 276 | self.update_op['Adadelta'] = tf.train.AdadeltaOptimizer(lr).apply_gradients(zip(self.Dgrad, update_params)) 277 | self.update_op['Adagrad'] = tf.train.AdagradOptimizer(lr).apply_gradients(zip(self.Dgrad, update_params)) 278 | self.update_op['GradientDescent'] = tf.train.GradientDescentOptimizer(lr).apply_gradients(zip(self.Dgrad, update_params)) 279 | self.update_op['Momentum'] = tf.train.MomentumOptimizer(lr, 0.9).apply_gradients(zip(self.Dgrad, update_params)) 280 | self.update_op['Nesterov'] = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).apply_gradients(zip(self.Dgrad, update_params)) 281 | self.update_op['Ftrl'] = tf.train.FtrlOptimizer(lr).apply_gradients(zip(self.Dgrad, update_params)) 282 | self.update_op['ProximalAdagrad'] = tf.train.ProximalAdagradOptimizer(lr).apply_gradients(zip(self.Dgrad, update_params)) 283 | self.update_op['ProximalGradientDescent'] = tf.train.ProximalGradientDescentOptimizer(lr).apply_gradients(zip(self.Dgrad, update_params)) 284 | self.update_op['RMSProp'] = tf.train.RMSPropOptimizer(lr).apply_gradients(zip(self.Dgrad, update_params)) 285 | for k,extra_emb in enumerate(self.extra_embedding_list): 286 | self.update_op['Ftrl_%d' % k] = tf.train.FtrlOptimizer(lr).apply_gradients(zip(self.Dgrad_list[k], [update_params[k]])) 287 | self.update_op['Nesterov_%d' % k] = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True).apply_gradients(zip(self.Dgrad_list[k], [update_params[k]])) 288 | self.update_op['GradientDescent_%d' % k] = tf.train.GradientDescentOptimizer(lr).apply_gradients(zip(self.Dgrad_list[k], [update_params[k]])) 289 | self.update_op['Momentum_%d' % k] = tf.train.MomentumOptimizer(lr,0.9).apply_gradients(zip(self.Dgrad_list[k], [update_params[k]])) 290 | self.update_op['Adam_%d' % k] = tf.train.AdamOptimizer(lr).apply_gradients(zip(self.Dgrad_list[k], [update_params[k]])) 291 | 292 | 293 | def build_nearst(self): 294 | self.nearest_emb_placeholder = tf.placeholder(tf.float32, [1, self.decoder_embedding_dim], name='nearest_emb') 295 | 296 | def build_embop(self): 297 | nemb = tf.nn.l2_normalize(self.decoder_embedding, 1) 298 | self.nearby_word = tf.placeholder(dtype=tf.int32) # word id 299 | nearby_emb = tf.gather(nemb, self.nearby_word) 300 | 301 | 
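# cosine similarity between the query word's embedding and every vocabulary
# embedding (the rows of nemb are L2-normalized above)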
nearby_dist = tf.matmul(nearby_emb, nemb, transpose_b=True) 302 | #print('cos', nearby_emb,nearby_dist) 303 | self.nearby_dist = nearby_dist 304 | self.nearby_val, self.nearby_idx = tf.nn.top_k(nearby_dist, len(self.dp.X_w2id)) 305 | 306 | def build_l2_distance(self): 307 | nemb = self.decoder_embedding 308 | nearby_emb = tf.gather(nemb, self.nearby_word) 309 | euclidean = tf.sqrt(tf.reduce_sum(tf.square(nearby_emb-nemb), 1)) 310 | self.eu_nearby_val, self.eu_nearby_idx = tf.nn.top_k(-euclidean, len(self.dp.X_w2id)) 311 | 312 | def build_projection(self): 313 | nemb = self.decoder_embedding 314 | nearby_emb = tf.gather(nemb, self.nearby_word) 315 | z_x = (nemb - nearby_emb) 316 | self.g = tf.placeholder(tf.float32, [1, self.decoder_embedding_dim], name='g') 317 | g_norm = tf.nn.l2_normalize(self.g) 318 | #print(z_x, g_norm) 319 | self.proj = tf.matmul(g_norm, z_x, transpose_b=True) 320 | #print(self.proj) 321 | self.nearby_pro_val, self.nearby_pro_idx = tf.nn.top_k(self.proj, len(self.dp.X_w2id)) 322 | 323 | def find_nearnest(self, idx, topk=10): 324 | nearby_val, nearby_idx = self.sess.run([self.nearby_val, self.nearby_idx], {self.nearby_word:idx}) 325 | return nearby_val[:topk], nearby_idx[:topk] 326 | 327 | def register_symbols(self): 328 | self._x_go = self.dp.X_w2id[''] 329 | self._x_eos = self.dp.X_w2id[''] 330 | self._x_pad = self.dp.X_w2id[''] 331 | self._x_unk = self.dp.X_w2id[''] 332 | 333 | 334 | def infer(self, input_word, batch_size=1, is_show=True): 335 | #return ["pass"] 336 | if self.is_jieba: 337 | input_index = list(jieba.cut(input_word)) 338 | else: 339 | input_index = input_word.split(' ') 340 | xx = [char for char in input_index] 341 | if self.reverse: 342 | xx = xx[::-1] 343 | length = [len(xx),] * batch_size 344 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in xx]] * batch_size 345 | prefix_go = [] 346 | for ipt in input_indices: 347 | prefix_go.append(ipt[-1]) 348 | out_indices, scores = self.sess.run([self.prefix_infer_outputs, self.score], { 349 | self.X: input_indices, self.X_seq_len: length, self.prefix_go: prefix_go, self.input_keep_prob:1, 350 | self.output_keep_prob:1}) 351 | outputs = [] 352 | for idx in range(out_indices.shape[-1]): 353 | eos_id = self.dp.X_w2id[''] 354 | ot = out_indices[0,:,idx] 355 | if eos_id in ot: 356 | ot = ot.tolist() 357 | ot = ot[:ot.index(eos_id)] 358 | if self.reverse: 359 | ot = ot[::-1] 360 | if self.reverse: 361 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) + ' '+ input_word 362 | else: 363 | output_str = input_word+' ' + ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 364 | outputs.append(output_str) 365 | return outputs 366 | 367 | def infer_with_scores(self, input_word, batch_size=1, is_show=True): 368 | #return ["pass"] 369 | if self.is_jieba: 370 | input_index = list(jieba.cut(input_word)) 371 | else: 372 | input_index = input_word.split(' ') 373 | xx = [char for char in input_index] 374 | if self.reverse: 375 | xx = xx[::-1] 376 | length = [len(xx),] * batch_size 377 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in xx]] * batch_size 378 | prefix_go = [] 379 | for ipt in input_indices: 380 | prefix_go.append(ipt[-1]) 381 | out_indices, scores = self.sess.run([self.prefix_infer_outputs, self.score], { 382 | self.X: input_indices, self.X_seq_len: length, self.prefix_go: prefix_go, self.input_keep_prob:1, 383 | self.output_keep_prob:1}) 384 | outputs = [] 385 | for idx in range(out_indices.shape[-1]): 386 | eos_id = self.dp.X_w2id[''] 387 | ot = 
out_indices[0,:,idx] 388 | if eos_id in ot: 389 | ot = ot.tolist() 390 | ot = ot[:ot.index(eos_id)] 391 | if self.reverse: 392 | ot = ot[::-1] 393 | if self.reverse: 394 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) + ' '+ input_word 395 | else: 396 | output_str = input_word+' ' + ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 397 | outputs.append(output_str) 398 | return outputs, scores 399 | 400 | def generate(self, batch_size=1, is_show=True): 401 | fake_x = [[1] for _ in range(batch_size)] 402 | out_indices = self.sess.run(self.sample_output, {self.X: fake_x, self.input_keep_prob:1, self.output_keep_prob:1}) 403 | #print(out_indices.shape) 404 | outputs = [] 405 | for ot in out_indices: 406 | eos_id = self.dp.X_w2id[''] 407 | if eos_id in ot: 408 | ot = ot.tolist() 409 | ot = ot[:ot.index(eos_id)] 410 | if self.reverse: 411 | ot = ot[::-1] 412 | if self.reverse: 413 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 414 | else: 415 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 416 | outputs.append(output_str) 417 | return out_indices, outputs 418 | 419 | def rollout(self, input_word, batch_size=1, is_show=True): 420 | if self.is_jieba: 421 | input_index = list(jieba.cut(input_word)) 422 | else: 423 | input_index = input_word 424 | xx = [char for char in input_index] 425 | if self.reverse: 426 | xx = xx[::-1] 427 | length = [len(xx)+1] * batch_size 428 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in xx]] * batch_size 429 | input_indices = [x+[self.dp.X_w2id[''],] for x in input_indices] 430 | #print(input_indices) 431 | out_indices = self.sess.run(self.sample_prefix_output, { 432 | self.X: input_indices, self.X_seq_len: length, self.input_keep_prob:1, 433 | self.output_keep_prob:1}) 434 | outputs = [] 435 | for ot in out_indices: 436 | eos_id = self.dp.X_w2id[''] 437 | if eos_id in ot: 438 | ot = ot.tolist() 439 | ot = ot[:ot.index(eos_id)] 440 | if self.reverse: 441 | ot = ot[::-1] 442 | 443 | if self.reverse: 444 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 445 | else: 446 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 447 | outputs.append(output_str) 448 | return outputs 449 | 450 | def restore(self, path): 451 | self.saver.restore(self.sess, path) 452 | #print('restore %s success' % path) 453 | 454 | def get_learning_rate_decay(self, decay_scheme='luong234'): 455 | num_train_steps = self.dp.num_steps 456 | if decay_scheme == "luong10": 457 | start_decay_step = int(num_train_steps / 2) 458 | remain_steps = num_train_steps - start_decay_step 459 | decay_steps = int(remain_steps / 10) # decay 10 times 460 | decay_factor = 0.5 461 | else: 462 | start_decay_step = int(num_train_steps * 2 / 3) 463 | remain_steps = num_train_steps - start_decay_step 464 | decay_steps = int(remain_steps / 4) # decay 4 times 465 | decay_factor = 0.5 466 | return tf.cond( 467 | self.global_step < start_decay_step, 468 | lambda: self.learning_rate, 469 | lambda: tf.train.exponential_decay( 470 | self.learning_rate, 471 | (self.global_step - start_decay_step), 472 | decay_steps, decay_factor, staircase=True), 473 | name="learning_rate_decay_cond") 474 | 475 | def _opt_init(self): 476 | self.sess.run(tf.variables_initializer(self.opt_var)) 477 | 478 | def setup_summary(self): 479 | train_loss = tf.Variable(0.) 480 | tf.summary.scalar('Train_loss', train_loss) 481 | 482 | test_loss = tf.Variable(0.) 
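# each logged scalar is backed by a dummy Variable that is assigned from a
# placeholder (see update_ops below), so Python-side metrics can be pushed
# into the graph before summary_op is evaluated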
483 | tf.summary.scalar('Test_loss', test_loss) 484 | 485 | bleu_score = tf.Variable(0.) 486 | tf.summary.scalar('BLEU_score', bleu_score) 487 | 488 | tf.summary.scalar('lr_rate', self.learning_rate) 489 | 490 | summary_vars = [train_loss, test_loss, bleu_score] 491 | summary_placeholders = [tf.placeholder(tf.float32) for _ in range(len(summary_vars))] 492 | update_ops = [summary_vars[i].assign(summary_placeholders[i]) for i in range(len(summary_vars))] 493 | summary_op = tf.summary.merge_all() 494 | return summary_placeholders, update_ops, summary_op -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # TIGS: An Inference Algorithm for Text Infilling with Gradient Search 3 | 4 | This repo contains the code and data of the following paper: 5 | >**TIGS: An Inference Algorithm for Text Infilling with Gradient Search**, *Dayiheng Liu, Jie Fu, Pengfei Liu, Jiancheng Lv*, Association for Computational Linguistics. **ACL** 2019 [[arXiv]](https://arxiv.org/abs/1905.10752) 6 | 7 | ## Overview 8 |
![Overview of TIGS](1.png) 9 | 10 | Given a well-trained sequential generative model, generating the missing symbols conditioned on the context is challenging for existing greedy approximate inference algorithms. We propose a dramatically different inference approach called Text Infilling with Gradient Search (**TIGS**), in which we search for the infilled words based on gradient information. To the best of our knowledge, this could be the first inference algorithm that requires no modification or training of the model and that can be broadly applied to any sequence generative model to solve fill-in-the-blank tasks.
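Concretely, TIGS alternates a continuous optimization step (O-step), which updates the embedding placed in each blank by gradient descent on the model's loss, with a projection step (P-step), which snaps that embedding back to nearby vocabulary words and keeps the candidate with the lowest sentence loss. Below is a minimal, framework-agnostic sketch of this loop; `E`, `nll`, and `grad_nll` are illustrative stand-ins rather than this repository's API (the full TensorFlow implementation is `cal_optimizer_gibbs` in `TIGS_Inferece.ipynb`).

```python
import numpy as np

def tigs_infill(E, nll, grad_nll, emb_seq, blank_pos,
                n_epochs=10, lr=0.1, top_k=100):
    """Sketch of TIGS. Assumed stand-ins: E is the (V, d) embedding table,
    nll(seq) returns the sentence loss for a (T, d) embedding sequence, and
    grad_nll(seq, p) returns the loss gradient w.r.t. the embedding at p."""
    emb_seq = emb_seq.copy()
    for _ in range(n_epochs):
        for p in blank_pos:
            # O-step: move the blank's embedding along the loss gradient
            emb_seq[p] -= lr * grad_nll(emb_seq, p)
            # P-step: restrict to the top_k nearest vocabulary words ...
            dist = np.linalg.norm(E - emb_seq[p], axis=1)
            cand = np.argsort(dist)[:top_k]
            # ... and keep the candidate that minimizes the sentence loss
            losses = [nll(np.vstack([emb_seq[:p], E[[w]], emb_seq[p + 1:]]))
                      for w in cand]
            emb_seq[p] = E[cand[int(np.argmin(losses))]]
    # each blank now holds an exact vocabulary embedding; read off its id
    return [int(np.argmin(np.linalg.norm(E - emb_seq[p], axis=1)))
            for p in blank_pos]
```

In `Model.py`, the O-step runs on dedicated trainable embedding slots (the `qid_list` entries spliced into the embedding table), so the gradient updates touch only the blank positions while the rest of the model stays frozen.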
11 | 12 | 13 | ## Dependencies 14 | 15 | - Jupyter notebook 4.4.0 16 | - Python 3.6 17 | - TensorFlow 1.6.0+ 18 | 19 | ## Quick Start 20 | - Training: Run `TIGS_train.ipynb` 21 | - Inference: Run `TIGS_Inferece.ipynb` 22 | 23 | ## Trained Models 24 | Download the trained models from https://drive.google.com/open?id=1IABzc6ovkR6Uprnl3isSAWf6ax2fLHgH 25 | - The APRC trained model can be found in `Model/APRC` 26 | - The Poem trained model can be found in `Model/Poem` 27 | - The Daily trained model can be found in `Model/Daily` 28 | 29 | ## Datasets 30 | Download the datasets from https://drive.google.com/open?id=1GKyBtU0pPysB10wdsqMxYDoQ5CRQIXI8 31 | - The APRC dataset can be found in `Data/APRC` 32 | - The Poem dataset can be found in `Data/Poem` 33 | - The Daily dataset can be found in `Data/Daily` 34 | -------------------------------------------------------------------------------- /TIGS_Inferece.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Import" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "ExecuteTime": { 15 | "end_time": "2019-05-26T01:38:20.588780Z", 16 | "start_time": "2019-05-26T01:38:20.579169Z" 17 | }, 18 | "code_folding": [], 19 | "run_control": { 20 | "marked": false 21 | } 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "import os\n", 26 | "import numpy as np\n", 27 | "import time\n", 28 | "\n", 29 | "from Util import my_helper\n", 30 | "import copy\n", 31 | "import itertools\n", 32 | "import random\n", 33 | "import pickle as cPickle\n", 34 | "import matplotlib.pyplot as plt\n", 35 | "\n", 36 | "import tensorflow as tf\n", 37 | "from tensorflow.python.layers import core as core_layers\n", 38 | "\n", 39 | "\n", 40 | "from Model import LM\n", 41 | "from Util.myAttWrapper import SelfAttWrapper\n", 42 | "from Util import myResidualCell\n", 43 | "from Util.bleu import BLEU\n", 44 | "from Util.myUtil import *\n", 45 | "\n", 46 | "\n", 47 | "tf.logging.set_verbosity(tf.logging.INFO)\n", 48 | "sess_conf = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))\n", 49 | "\n", 50 | "\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "# Util" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": { 64 | "ExecuteTime": { 65 | "end_time": "2019-05-26T01:38:21.877344Z", 66 | "start_time": "2019-05-26T01:38:21.834864Z" 67 | }, 68 | "code_folding": [] 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "def _construct_blank(num, length, style='random'):\n", 73 | " data = X_indices[:num]\n", 74 | " blank_data = []\n", 75 | " for d in data:\n", 76 | " if style == 'random':\n", 77 | " pos_list = np.sort(random.sample(range(len(d)-2), length)).tolist()\n", 78 | " blank_data.append((d, 
pos_list))\n", 79 | " elif style == 'middle':\n", 80 | " l = int(length * (len(d) - 1))\n", 81 | " pos_list = [int((len(d)-1-l) / 2.0) + i for i in range(l)]\n", 82 | " blank_data.append((d, pos_list))\n", 83 | " return blank_data\n", 84 | "\n", 85 | "def show_blank(idx, pos):\n", 86 | " t = [len(id2w)+5 if i in pos else a for i,a in enumerate(idx)]\n", 87 | " s = ' '.join([id2w.get(tt,'_') for tt in t])\n", 88 | " return s\n", 89 | "\n", 90 | "def idx2str(idx):\n", 91 | " return \" \".join(id2w.get(idxx, '_') for idxx in idx)\n", 92 | "\n", 93 | "def str2idx(idx):\n", 94 | " idx = idx.strip()\n", 95 | " return [w2id[idxx] for idxx in idx.split(' ')]\n", 96 | "\n", 97 | "def cal_candidate_list(inputs, pos):\n", 98 | " idx = copy.deepcopy(inputs)\n", 99 | " idx[pos] = model.qid_list[0]\n", 100 | " candidate_list = []\n", 101 | " for t in range(len(model.dp.X_w2id)):\n", 102 | " temp = copy.deepcopy(idx)\n", 103 | " temp = [k if k!=model.qid_list[0] else t for k in temp]\n", 104 | " candidate_list.append(temp)\n", 105 | " return candidate_list\n", 106 | "\n", 107 | "\n", 108 | "def replace_list(idx, pos_list, target):\n", 109 | " t = [idxx for idxx in idx]\n", 110 | " if target:\n", 111 | " for i,p in enumerate(pos_list):\n", 112 | " t[p] = target[i]\n", 113 | " else:\n", 114 | " for i,p in enumerate(pos_list):\n", 115 | " t[p] = -1\n", 116 | " return t\n", 117 | "\n", 118 | "def cal_optimal(idx, pos, max_it=-1):\n", 119 | " X_batch = cal_candidate_list(idx, pos)\n", 120 | " X_batch_len = [len(x) for x in X_batch]\n", 121 | " \n", 122 | " if max_it > 0:\n", 123 | " batch_loss = []\n", 124 | " t = 0\n", 125 | " while t+max_it'])\n", 163 | " else:\n", 164 | " o_idx.append(str2idx(infer)[i])\n", 165 | " else:\n", 166 | " o_idx.append(c_idx[i])\n", 167 | " init_word = [id2w[o_idx[i]] for i in pos] \n", 168 | " return o_idx, idx2str(o_idx), init_word\n", 169 | "\n", 170 | "def _init_data(name):\n", 171 | " w2id, id2w = cPickle.load(open('Data/%s/w2id_id2w.pkl' % name,'rb'))\n", 172 | " X_indices = cPickle.load(open('Data/%s/index.pkl' % name,'rb'))\n", 173 | " return X_indices, w2id, id2w\n", 174 | "\n", 175 | "def _init_model(name, lr=10.0, l1_reg_lambda=0.00, l2_reg_lambda=0.00, close_loss_rate=0.00):\n", 176 | " qid_list = cPickle.load(open('Data/%s/qid_list.pkl'%name,'rb'))\n", 177 | " qid_list = [w2id[w] for w in qid_list]\n", 178 | " rnn_size = dict()\n", 179 | " rnn_size['Poem'] = 512\n", 180 | " rnn_size['Daily'] = 512 \n", 181 | " rnn_size['APRC'] = 1024\n", 182 | " \n", 183 | " num_layer = dict()\n", 184 | " num_layer['Poem'] = 2\n", 185 | " num_layer['Daily'] = 1\n", 186 | " num_layer['APRC'] = 1\n", 187 | " \n", 188 | " max_infer_length = dict()\n", 189 | " max_infer_length['Poem'] = 33\n", 190 | " max_infer_length['Daily'] = 50\n", 191 | " max_infer_length['APRC'] = 36\n", 192 | " \n", 193 | " model_iter = dict()\n", 194 | " model_iter['Poem'] = 30\n", 195 | " model_iter['Daily'] = 30 \n", 196 | " model_iter['APRC'] = 20\n", 197 | " \n", 198 | " assert name in ['Poem','Daily', 'APRC']\n", 199 | "\n", 200 | " BATCH_SIZE = 256\n", 201 | " NUM_EPOCH = 30\n", 202 | " train_dir ='Model/%s' % name\n", 203 | " dp = LM_DP(X_indices, w2id, BATCH_SIZE, n_epoch=NUM_EPOCH)\n", 204 | " g = tf.Graph() \n", 205 | " sess = tf.Session(graph=g, config=sess_conf) \n", 206 | " with sess.as_default():\n", 207 | " with sess.graph.as_default():\n", 208 | " model = LM(\n", 209 | " dp = dp,\n", 210 | " rnn_size = rnn_size[name],\n", 211 | " n_layers = num_layer[name],\n", 212 | " decoder_embedding_dim = 
rnn_size[name],\n", 213 | " cell_type='lstm',\n", 214 | " close_loss_rate = close_loss_rate,\n", 215 | " max_infer_length = max_infer_length[name],\n", 216 | " att_type='B',\n", 217 | " qid_list = qid_list,\n", 218 | " lr = lr,\n", 219 | " l1_reg_lambda = l1_reg_lambda,\n", 220 | " l2_reg_lambda = l2_reg_lambda,\n", 221 | " is_save = False,\n", 222 | " residual = True,\n", 223 | " is_jieba = False,\n", 224 | " sess=sess\n", 225 | " )\n", 226 | "\n", 227 | "\n", 228 | " util = LM_util(dp=dp, model=model)\n", 229 | " model.restore('Model/%s/model-%d'% (name,model_iter[name])) # restore pre-train model\n", 230 | " return model\n", 231 | "\n", 232 | "\n", 233 | "\n", 234 | "def _reload(name):\n", 235 | " rnn_size = dict()\n", 236 | " rnn_size['Poem'] = 512\n", 237 | " rnn_size['Daily'] = 512 \n", 238 | " rnn_size['APRC'] = 1024\n", 239 | " \n", 240 | " num_layer = dict()\n", 241 | " num_layer['Poem'] = 2\n", 242 | " num_layer['Daily'] = 1\n", 243 | " num_layer['APRC'] = 1\n", 244 | " \n", 245 | " max_infer_length = dict()\n", 246 | " max_infer_length['Poem'] = 33\n", 247 | " max_infer_length['Daily'] = 50\n", 248 | " max_infer_length['APRC'] = 36\n", 249 | " \n", 250 | " model_iter = dict()\n", 251 | " model_iter['Poem'] = 30\n", 252 | " model_iter['Daily'] = 30 \n", 253 | " model_iter['APRC'] = 20\n", 254 | " \n", 255 | " assert name in ['Poem','Daily', 'APRC']\n", 256 | "\n", 257 | " model.restore('Model/%s/model-%d'% (name,model_iter[name]))" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "# TIGS" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "metadata": { 271 | "ExecuteTime": { 272 | "end_time": "2019-05-26T01:38:22.745667Z", 273 | "start_time": "2019-05-26T01:38:22.718086Z" 274 | } 275 | }, 276 | "outputs": [], 277 | "source": [ 278 | "def cal_optimizer_gibbs(inputs, pos_list, init_word, name='Nesterov', epoch=10, is_show=True, top_k=100, upper_size=2, distance='l2'):\n", 279 | " total_tic = time.time()\n", 280 | " init_time = 0.0\n", 281 | " update_time = 0.0\n", 282 | " assign_time = 0.0\n", 283 | " search_time = 0.0\n", 284 | " cal_pre_time = 0.0\n", 285 | " cal_next_time = 0.0\n", 286 | " \n", 287 | " tic = time.time()\n", 288 | " upper_cnt = 0\n", 289 | " \n", 290 | " # prepare\n", 291 | " K = len(pos_list)\n", 292 | " assert K <= len(model.qid_list)\n", 293 | " idx = copy.deepcopy(inputs)\n", 294 | " for i,pos in enumerate(pos_list):\n", 295 | " idx[pos] = model.qid_list[i]\n", 296 | " if init_word:\n", 297 | " pre_p_list = [w2id[t] for t in init_word]\n", 298 | " else:\n", 299 | " pre_p_list = model.qid_list[:K]\n", 300 | " next_p_list = []\n", 301 | " pre_sentence = replace_list(idx, pos_list, pre_p_list)\n", 302 | " epoch_sentence = replace_list(idx, pos_list, pre_p_list)\n", 303 | " word_emb = model.sess.run(model.decoder_embedding)\n", 304 | " \n", 305 | " # init specific embedding \n", 306 | " if init_word:\n", 307 | " feed_dict = dict()\n", 308 | " for j in range(K):\n", 309 | " feed_dict[model.assgin_placeholder_list[j]] = word_emb[[w2id[init_word[j]]]]\n", 310 | " model.sess.run(model.assign_op_list[:K],feed_dict)\n", 311 | " init_time += time.time()-tic\n", 312 | "\n", 313 | " # search\n", 314 | " for i in range(epoch):\n", 315 | " if i > 1 and epoch_sentence == replace_list(idx, pos_list, pre_p_list):\n", 316 | " upper_cnt += 1\n", 317 | " if upper_cnt >= upper_size:\n", 318 | " if is_show:\n", 319 | " print('total_epoch %d'% (i+1))\n", 320 | " break\n", 321 | " else:\n", 322 
| " upper_cnt = 0\n", 323 | " epoch_sentence = replace_list(idx, pos_list, pre_p_list)\n", 324 | " if is_show:\n", 325 | " print('epoch %d :' %(i+1), idx2str(epoch_sentence))\n", 326 | " ep_tic = time.time() \n", 327 | " for k in range(K):\n", 328 | " # O-step\n", 329 | " pre_sentence = replace_list(idx, pos_list, pre_p_list)\n", 330 | " tic = time.time()\n", 331 | " if distance == 'cos':\n", 332 | " v, o = model.sess.run([model.nearby_val, model.nearby_idx], {model.nearby_word:[model.qid_list[k]]})\n", 333 | " nearset = o[0][1]\n", 334 | " else:\n", 335 | " v, o = model.sess.run([model.eu_nearby_val, model.eu_nearby_idx], {model.nearby_word:[model.qid_list[k]]})\n", 336 | " nearset = o[1]\n", 337 | " loss, _ = model.sess.run([model.update_loss, \n", 338 | " model.update_op[name+'_%d' % k]], \n", 339 | " {model.X: [idx], \n", 340 | " model.X_seq_len: [len(idx)], \n", 341 | " model.Y:[pre_sentence],\n", 342 | " model.output_keep_prob:1,\n", 343 | " model.input_keep_prob:1,\n", 344 | " model.nearest_emb_placeholder:word_emb[[nearset]]})\n", 345 | " update_time += time.time() - tic\n", 346 | "\n", 347 | "\n", 348 | " # P-step\n", 349 | " if i % 1 == 0:\n", 350 | " # candidate\n", 351 | " tic = time.time()\n", 352 | " if distance == 'cos':\n", 353 | " v, o = model.sess.run([model.nearby_val, model.nearby_idx], {model.nearby_word:[model.qid_list[k]]})\n", 354 | " candi_pos = o[0][1:top_k+1].tolist() + [pre_p_list[k]]\n", 355 | " else:\n", 356 | " v, o = model.sess.run([model.eu_nearby_val, model.eu_nearby_idx], {model.nearby_word:[model.qid_list[k]]})\n", 357 | " candi_pos = o[1:top_k+1].tolist() + [pre_p_list[k]]\n", 358 | " \n", 359 | " candi_list = [[pre_p_list[j] if j!=k else t for j in range(len(pre_p_list))] for t in candi_pos]\n", 360 | " next_sentences = [replace_list(idx, pos_list, candi) for candi in candi_list]\n", 361 | " search_time += time.time() - tic\n", 362 | " \n", 363 | " # cal loss\n", 364 | " tic = time.time()\n", 365 | " next_loss_list = model.sess.run(model.batch_loss, {model.X: next_sentences, \n", 366 | " model.X_seq_len: [len(idx) for j in range(len(next_sentences))], \n", 367 | " model.output_keep_prob:1,\n", 368 | " model.input_keep_prob:1})\n", 369 | " argmin_idx = np.argmin(next_loss_list)\n", 370 | " next_p_pos = candi_pos[argmin_idx]\n", 371 | " cal_next_time += time.time()-tic\n", 372 | " # update\n", 373 | " tic = time.time()\n", 374 | " if next_p_pos != pre_p_list[k]:\n", 375 | " model.sess.run(model.assign_op_list[k],{model.assgin_placeholder_list[k]:word_emb[[next_p_pos]]})\n", 376 | " pre_p_list[k] = next_p_pos\n", 377 | " \n", 378 | " assign_time += time.time() - tic\n", 379 | " if is_show:\n", 380 | " print('epoch %d_%d :' % (i+1, k),idx2str(replace_list(idx, pos_list, pre_p_list)))\n", 381 | " \n", 382 | " tic = time.time() \n", 383 | " pre_sentence = replace_list(idx, pos_list, pre_p_list)\n", 384 | " loss = model.sess.run(model.loss, {model.X: [pre_sentence], \n", 385 | " model.X_seq_len: [len(idx)], \n", 386 | " model.output_keep_prob:1,\n", 387 | " model.input_keep_prob:1})\n", 388 | " cal_pre_time += time.time() - tic\n", 389 | " total_time = time.time() - total_tic\n", 390 | " \n", 391 | " return pre_p_list, idx2str(pre_p_list),loss" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": { 398 | "ExecuteTime": { 399 | "end_time": "2019-05-26T01:38:23.739860Z", 400 | "start_time": "2019-05-26T01:38:23.434859Z" 401 | } 402 | }, 403 | "outputs": [], 404 | "source": [ 405 | "# initialize blank with 
left-to-right greedy beam search\n", 406 | "f_init = cPickle.load(open('results/_URNN-f_res.pkl','rb')) " 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "metadata": {}, 412 | "source": [ 413 | "# Inference" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "metadata": { 420 | "ExecuteTime": { 421 | "end_time": "2019-05-26T01:42:54.580893Z", 422 | "start_time": "2019-05-26T01:42:01.554969Z" 423 | } 424 | }, 425 | "outputs": [], 426 | "source": [ 427 | "task_name = 'APRC' # 'Daily', 'Poem'\n", 428 | "assert task_name in ['Poem','Daily', 'APRC']\n", 429 | "X_indices, w2id, id2w = _init_data(task_name)\n", 430 | "model = _init_model(task_name, lr=10.0)\n", 431 | "\n", 432 | " " 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": null, 438 | "metadata": { 439 | "ExecuteTime": { 440 | "end_time": "2019-05-26T01:44:42.226194Z", 441 | "start_time": "2019-05-26T01:44:12.389464Z" 442 | } 443 | }, 444 | "outputs": [], 445 | "source": [ 446 | "import random\n", 447 | "is_init = True\n", 448 | "for length_ratio in [0.25, 0.5, 0.75]:\n", 449 | " for style in ['random', 'middle']:\n", 450 | " blank_data = cPickle.load(open('Data/%s/%d_%s.pkl'%(task_name, int(length_ratio*100), style),'rb'))\n", 451 | " i = random.sample(range(5000), 1)[0]\n", 452 | " idx, pos_list = blank_data[i]\n", 453 | " prefix = '%s_%d_%s' % (task_name, int(length_ratio*100), style)\n", 454 | " model._opt_init()\n", 455 | " if is_init:\n", 456 | " init_word = [id2w[f_init[prefix+'_URNN-f'][i][p]] for p in pos_list]\n", 457 | "\n", 458 | " else:\n", 459 | " init_word = random.sample(w2id.keys(), len(pos_list))\n", 460 | " sid, sw, loss = cal_optimizer_gibbs(idx, pos_list, init_word = init_word, is_show=False)\n", 461 | " print('Template:', show_blank(idx, pos_list))\n", 462 | " print('GroundTruth:', idx2str(idx))\n", 463 | " print('TIGS:', idx2str(replace_list(idx, pos_list, sid)))\n", 464 | " print('')" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": null, 470 | "metadata": {}, 471 | "outputs": [], 472 | "source": [] 473 | } 474 | ], 475 | "metadata": { 476 | "hide_input": false, 477 | "kernelspec": { 478 | "display_name": "Python 3", 479 | "language": "python", 480 | "name": "python3" 481 | }, 482 | "language_info": { 483 | "codemirror_mode": { 484 | "name": "ipython", 485 | "version": 3 486 | }, 487 | "file_extension": ".py", 488 | "mimetype": "text/x-python", 489 | "name": "python", 490 | "nbconvert_exporter": "python", 491 | "pygments_lexer": "ipython3", 492 | "version": "3.5.2" 493 | }, 494 | "toc": { 495 | "base_numbering": 1, 496 | "nav_menu": {}, 497 | "number_sections": true, 498 | "sideBar": true, 499 | "skip_h1_title": false, 500 | "title_cell": "Table of Contents", 501 | "title_sidebar": "Contents", 502 | "toc_cell": false, 503 | "toc_position": { 504 | "height": "calc(100% - 180px)", 505 | "left": "10px", 506 | "top": "150px", 507 | "width": "165px" 508 | }, 509 | "toc_section_display": true, 510 | "toc_window_display": true 511 | }, 512 | "varInspector": { 513 | "cols": { 514 | "lenName": 16, 515 | "lenType": 16, 516 | "lenVar": 40 517 | }, 518 | "kernels_config": { 519 | "python": { 520 | "delete_cmd_postfix": "", 521 | "delete_cmd_prefix": "del ", 522 | "library": "var_list.py", 523 | "varRefreshCmd": "print(var_dic_list())" 524 | }, 525 | "r": { 526 | "delete_cmd_postfix": ") ", 527 | "delete_cmd_prefix": "rm(", 528 | "library": "var_list.r", 529 | "varRefreshCmd": "cat(var_dic_list()) " 530 | } 531 | }, 532 | 
"types_to_exclude": [ 533 | "module", 534 | "function", 535 | "builtin_function_or_method", 536 | "instance", 537 | "_Feature" 538 | ], 539 | "window_display": false 540 | } 541 | }, 542 | "nbformat": 4, 543 | "nbformat_minor": 2 544 | } 545 | -------------------------------------------------------------------------------- /TIGS_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2019-05-26T01:58:12.294665Z", 9 | "start_time": "2019-05-26T01:58:10.019832Z" 10 | }, 11 | "code_folding": [], 12 | "run_control": { 13 | "marked": false 14 | } 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "import numpy as np\n", 20 | "import time\n", 21 | "\n", 22 | "from Util import my_helper\n", 23 | "import copy\n", 24 | "import itertools\n", 25 | "import random\n", 26 | "import pickle as cPickle\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "\n", 29 | "import tensorflow as tf\n", 30 | "from tensorflow.python.layers import core as core_layers\n", 31 | "\n", 32 | "\n", 33 | "from Model import LM\n", 34 | "from Util.myAttWrapper import SelfAttWrapper\n", 35 | "from Util import myResidualCell\n", 36 | "from Util.bleu import BLEU\n", 37 | "from Util.myUtil import *\n", 38 | "\n", 39 | "\n", 40 | "tf.logging.set_verbosity(tf.logging.INFO)\n", 41 | "sess_conf = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))\n", 42 | "\n", 43 | "\n" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": { 50 | "ExecuteTime": { 51 | "end_time": "2019-05-26T01:58:12.309216Z", 52 | "start_time": "2019-05-26T01:58:12.297026Z" 53 | } 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "def idx2str(idx):\n", 58 | " return \" \".join(id2w[idxx] for idxx in idx)\n", 59 | "\n", 60 | "def str2idx(idx):\n", 61 | " return [w2id[idxx] for idxx in idx]\n", 62 | "\n", 63 | "def _init_data(name):\n", 64 | " w2id, id2w = cPickle.load(open('Data/%s/w2id_id2w.pkl' % name,'rb'))\n", 65 | " X_indices = cPickle.load(open('Data/%s/index.pkl' % name,'rb'))\n", 66 | " return X_indices, w2id, id2w\n", 67 | "\n", 68 | "def _train_model(name, lr=10.0, l1_reg_lambda=0.00, l2_reg_lambda=0.00, close_loss_rate=0.00):\n", 69 | " qid_list = cPickle.load(open('Data/%s/qid_list.pkl'%name,'rb'))\n", 70 | " qid_list = [w2id[w] for w in qid_list]\n", 71 | " rnn_size = dict()\n", 72 | " rnn_size['Poem'] = 512\n", 73 | " rnn_size['Daily'] = 512 \n", 74 | " rnn_size['APRC'] = 1024\n", 75 | " \n", 76 | " num_layer = dict()\n", 77 | " num_layer['Poem'] = 2\n", 78 | " num_layer['Daily'] = 1\n", 79 | " num_layer['APRC'] = 1\n", 80 | " \n", 81 | " max_infer_length = dict()\n", 82 | " max_infer_length['Poem'] = 33\n", 83 | " max_infer_length['Daily'] = 50\n", 84 | " max_infer_length['APRC'] = 36\n", 85 | " \n", 86 | " model_iter = dict()\n", 87 | " model_iter['Poem'] = 30\n", 88 | " model_iter['Daily'] = 30 \n", 89 | " model_iter['APRC'] = 20\n", 90 | " \n", 91 | " assert name in ['Poem','Daily', 'APRC']\n", 92 | "\n", 93 | " BATCH_SIZE = 256\n", 94 | " NUM_EPOCH = 30\n", 95 | " train_dir ='Model/%s' % name\n", 96 | " dp = LM_DP(X_indices, w2id, BATCH_SIZE, n_epoch=NUM_EPOCH)\n", 97 | " g = tf.Graph() \n", 98 | " sess = tf.Session(graph=g, config=sess_conf) \n", 99 | " with sess.as_default():\n", 100 | " with sess.graph.as_default():\n", 101 | " model = LM(\n", 102 | " dp = dp,\n", 103 | " rnn_size = rnn_size[name],\n", 104 
| " n_layers = num_layer[name],\n", 105 | " decoder_embedding_dim = rnn_size[name],\n", 106 | " cell_type='lstm',\n", 107 | " close_loss_rate = close_loss_rate,\n", 108 | " max_infer_length = max_infer_length[name],\n", 109 | " att_type='B',\n", 110 | " qid_list = qid_list,\n", 111 | " lr = lr,\n", 112 | " l1_reg_lambda = l1_reg_lambda,\n", 113 | " l2_reg_lambda = l2_reg_lambda,\n", 114 | " is_save = True,\n", 115 | " residual = True,\n", 116 | " is_jieba = False,\n", 117 | " sess=sess\n", 118 | " )\n", 119 | "\n", 120 | "\n", 121 | " util = LM_util(dp=dp, model=model)\n", 122 | " util.fit(train_dir=train_dir, is_bleu=False)\n", 123 | " return model" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": { 130 | "ExecuteTime": { 131 | "start_time": "2019-05-26T01:58:09.594Z" 132 | } 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "task_name = 'APRC' # 'Daily', 'Poem'\n", 137 | "assert task_name in ['Poem','Daily', 'APRC']\n", 138 | "X_indices, w2id, id2w = _init_data(task_name)\n", 139 | "model = _train_model(task_name, lr=10.0)\n", 140 | "\n", 141 | " " 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [] 150 | } 151 | ], 152 | "metadata": { 153 | "hide_input": false, 154 | "kernelspec": { 155 | "display_name": "Python 3", 156 | "language": "python", 157 | "name": "python3" 158 | }, 159 | "language_info": { 160 | "codemirror_mode": { 161 | "name": "ipython", 162 | "version": 3 163 | }, 164 | "file_extension": ".py", 165 | "mimetype": "text/x-python", 166 | "name": "python", 167 | "nbconvert_exporter": "python", 168 | "pygments_lexer": "ipython3", 169 | "version": "3.5.2" 170 | }, 171 | "toc": { 172 | "base_numbering": 1, 173 | "nav_menu": {}, 174 | "number_sections": true, 175 | "sideBar": true, 176 | "skip_h1_title": false, 177 | "title_cell": "Table of Contents", 178 | "title_sidebar": "Contents", 179 | "toc_cell": false, 180 | "toc_position": {}, 181 | "toc_section_display": true, 182 | "toc_window_display": false 183 | }, 184 | "varInspector": { 185 | "cols": { 186 | "lenName": 16, 187 | "lenType": 16, 188 | "lenVar": 40 189 | }, 190 | "kernels_config": { 191 | "python": { 192 | "delete_cmd_postfix": "", 193 | "delete_cmd_prefix": "del ", 194 | "library": "var_list.py", 195 | "varRefreshCmd": "print(var_dic_list())" 196 | }, 197 | "r": { 198 | "delete_cmd_postfix": ") ", 199 | "delete_cmd_prefix": "rm(", 200 | "library": "var_list.r", 201 | "varRefreshCmd": "cat(var_dic_list()) " 202 | } 203 | }, 204 | "types_to_exclude": [ 205 | "module", 206 | "function", 207 | "builtin_function_or_method", 208 | "instance", 209 | "_Feature" 210 | ], 211 | "window_display": false 212 | } 213 | }, 214 | "nbformat": 4, 215 | "nbformat_minor": 2 216 | } 217 | -------------------------------------------------------------------------------- /Util/GSutil.py: -------------------------------------------------------------------------------- 1 | def _construct_blank(X_indices, num, length, style='random'): 2 | data = X_indices[:num] 3 | blank_data = [] 4 | for d in data: 5 | if style == 'random': 6 | pos_list = np.sort(random.sample(range(len(d)-2), length)).tolist() 7 | blank_data.append((d, pos_list)) 8 | elif style == 'middle': 9 | l = int(length * (len(d) - 1)) 10 | pos_list = [int((len(d)-1-l) / 2.0) + i for i in range(l)] 11 | blank_data.append((d, pos_list)) 12 | return blank_data 13 | 14 | def show_blank(idx, pos): 15 | t = [len(id2w)+5 if i in pos else a for 
i,a in enumerate(idx)] 16 | s = ' '.join([id2w.get(tt,'_') for tt in t]) 17 | return s 18 | 19 | def idx2str(idx, id2w): 20 | return " ".join(id2w.get(idxx, '_') for idxx in idx) 21 | 22 | def str2idx(idx, w2id): 23 | idx = idx.strip() 24 | return [w2id[idxx] for idxx in idx.split(' ')] 25 | 26 | def cal_candidate_list(model, inputs, pos): 27 | idx = copy.deepcopy(inputs) 28 | idx[pos] = model.qid_list[0] 29 | candidate_list = [] 30 | for t in range(len(model.dp.X_w2id)): 31 | temp = copy.deepcopy(idx) 32 | temp = [k if k!=model.qid_list[0] else t for k in temp] 33 | candidate_list.append(temp) 34 | return candidate_list 35 | 36 | 37 | def replace_list(idx, pos_list, target): 38 | t = [idxx for idxx in idx] 39 | if target: 40 | for i,p in enumerate(pos_list): 41 | t[p] = target[i] 42 | else: 43 | for i,p in enumerate(pos_list): 44 | t[p] = -1 45 | return t 46 | 47 | def cal_optimal(model, idx, pos, max_it=-1): 48 | X_batch = cal_candidate_list(idx, pos) 49 | X_batch_len = [len(x) for x in X_batch] 50 | 51 | if max_it > 0: 52 | batch_loss = [] 53 | t = 0 54 | while t+max_it']) 95 | else: 96 | o_idx.append(str2idx(infer)[i]) 97 | else: 98 | o_idx.append(c_idx[i]) 99 | init_word = [id2w[o_idx[i]] for i in pos] 100 | return o_idx, idx2str(o_idx), init_word 101 | 102 | def _init_data(name): 103 | w2id, id2w = cPickle.load(open('Data/%s/w2id_id2w.pkl' % name,'rb')) 104 | X_indices = cPickle.load(open('Data/%s/index.pkl' % name,'rb')) 105 | return X_indices, w2id, id2w 106 | 107 | def _init_model(name, lr=5.0, l1_reg_lambda=0.01, l2_reg_lambda=0.01, qid_list=[]): 108 | rnn_size = dict() 109 | rnn_size['SM'] = 1024 110 | rnn_size['Poem'] = 512 111 | rnn_size['Daily'] = 512 112 | rnn_size['APRC'] = 1024 113 | 114 | num_layer = dict() 115 | num_layer['SM'] = 2 116 | num_layer['Poem'] = 2 117 | num_layer['Daily'] = 1 118 | num_layer['APRC'] = 1 119 | 120 | max_infer_length = dict() 121 | max_infer_length['SM'] = 35 122 | max_infer_length['Poem'] = 33 123 | max_infer_length['Daily'] = 50 124 | max_infer_length['APRC'] = 36 125 | 126 | model_iter = dict() 127 | model_iter['SM'] = 30 128 | model_iter['Poem'] = 30 129 | model_iter['Daily'] = 30 130 | model_iter['APRC'] = 20 131 | 132 | assert name in ['SM','Poem','Daily', 'APRC'] 133 | 134 | BATCH_SIZE = 256 135 | NUM_EPOCH = 30 136 | train_dir ='Model/%s' % name 137 | dp = LM_DP(X_indices, w2id, BATCH_SIZE, n_epoch=NUM_EPOCH) 138 | g = tf.Graph() 139 | sess = tf.Session(graph=g, config=sess_conf) 140 | with sess.as_default(): 141 | with sess.graph.as_default(): 142 | model = LM( 143 | dp = dp, 144 | rnn_size = rnn_size[name], 145 | n_layers = num_layer[name], 146 | decoder_embedding_dim = rnn_size[name], 147 | cell_type='lstm', 148 | max_infer_length = max_infer_length[name], 149 | att_type='B', 150 | qid_list = qid_list, 151 | lr = lr, 152 | l1_reg_lambda = l1_reg_lambda, 153 | l2_reg_lambda = l2_reg_lambda, 154 | is_save = False, 155 | residual = True, 156 | is_jieba = False, 157 | sess=sess 158 | ) 159 | #print(tf.global_variables()) 160 | #print([var for var in tf.global_variables() if 'Nesterov' in var.name]) 161 | 162 | util = LM_util(dp=dp, model=model) 163 | model.restore('Model/%s/model-%d'% (name,model_iter[name])) 164 | return model#, tf.global_variables() 165 | 166 | 167 | 168 | def _reload(model, name): 169 | rnn_size = dict() 170 | rnn_size['SM'] = 1024 171 | rnn_size['Poem'] = 512 172 | rnn_size['Daily'] = 512 173 | rnn_size['APRC'] = 1024 174 | 175 | num_layer = dict() 176 | num_layer['SM'] = 2 177 | num_layer['Poem'] = 2 178 | 
num_layer['Daily'] = 1 179 | num_layer['APRC'] = 1 180 | 181 | max_infer_length = dict() 182 | max_infer_length['SM'] = 35 183 | max_infer_length['Poem'] = 33 184 | max_infer_length['Daily'] = 50 185 | max_infer_length['APRC'] = 36 186 | 187 | model_iter = dict() 188 | model_iter['SM'] = 30 189 | model_iter['Poem'] = 30 190 | model_iter['Daily'] = 30 191 | model_iter['APRC'] = 20 192 | 193 | assert name in ['SM','Poem','Daily', 'APRC'] 194 | 195 | model.restore('Model/%s/model-%d'% (name,model_iter[name])) -------------------------------------------------------------------------------- /Util/__pycache__/bleu.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dayihengliu/Text-Infilling-Gradient-Search/af89b634fa9d74222d29ed5ef0b91da46533dd62/Util/__pycache__/bleu.cpython-35.pyc -------------------------------------------------------------------------------- /Util/__pycache__/myAttWrapper.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dayihengliu/Text-Infilling-Gradient-Search/af89b634fa9d74222d29ed5ef0b91da46533dd62/Util/__pycache__/myAttWrapper.cpython-35.pyc -------------------------------------------------------------------------------- /Util/__pycache__/myResidualCell.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dayihengliu/Text-Infilling-Gradient-Search/af89b634fa9d74222d29ed5ef0b91da46533dd62/Util/__pycache__/myResidualCell.cpython-35.pyc -------------------------------------------------------------------------------- /Util/__pycache__/myUtil.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dayihengliu/Text-Infilling-Gradient-Search/af89b634fa9d74222d29ed5ef0b91da46533dd62/Util/__pycache__/myUtil.cpython-35.pyc -------------------------------------------------------------------------------- /Util/__pycache__/my_helper.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dayihengliu/Text-Infilling-Gradient-Search/af89b634fa9d74222d29ed5ef0b91da46533dd62/Util/__pycache__/my_helper.cpython-35.pyc -------------------------------------------------------------------------------- /Util/bleu.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import codecs 3 | import os 4 | import math 5 | import operator 6 | import json 7 | 8 | 9 | 10 | def fetch_data(cand, ref): 11 | """ Store each reference and candidate sentences as a list """ 12 | references = [] 13 | if '.txt' in ref: 14 | reference_file = codecs.open(ref, 'r', 'utf-8') 15 | references.append(reference_file.readlines()) 16 | else: 17 | for root, dirs, files in os.walk(ref): 18 | for f in files: 19 | reference_file = codecs.open(os.path.join(root, f), 'r', 'utf-8') 20 | references.append(reference_file.readlines()) 21 | candidate_file = codecs.open(cand, 'r', 'utf-8') 22 | candidate = candidate_file.readlines() 23 | return candidate, references 24 | 25 | 26 | def count_ngram(candidate, references, n): 27 | clipped_count = 0 28 | count = 0 29 | r = 0 30 | c = 0 31 | for si in range(len(candidate)): 32 | # Calculate precision for each sentence 33 | ref_counts = [] 34 | ref_lengths = [] 35 | # Build dictionary of ngram counts 36 | for reference in references: 37 | ref_sentence = 
reference[si] 38 | ngram_d = {} 39 | words = ref_sentence.strip().split() 40 | ref_lengths.append(len(words)) 41 | limits = len(words) - n + 1 42 | # loop over the sentence, counting each ngram of length n 43 | for i in range(limits): 44 | ngram = ' '.join(words[i:i+n]).lower() 45 | if ngram in ngram_d: 46 | ngram_d[ngram] += 1 47 | else: 48 | ngram_d[ngram] = 1 49 | ref_counts.append(ngram_d) 50 | # candidate 51 | cand_sentence = candidate[si] 52 | cand_dict = {} 53 | words = cand_sentence.strip().split() 54 | limits = len(words) - n + 1 55 | for i in range(0, limits): 56 | ngram = ' '.join(words[i:i + n]).lower() 57 | if ngram in cand_dict: 58 | cand_dict[ngram] += 1 59 | else: 60 | cand_dict[ngram] = 1 61 | clipped_count += clip_count(cand_dict, ref_counts) 62 | count += limits 63 | r += best_length_match(ref_lengths, len(words)) 64 | c += len(words) 65 | if clipped_count == 0: 66 | pr = 0 67 | else: 68 | pr = float(clipped_count) / count 69 | bp = brevity_penalty(c, r) 70 | return pr, bp 71 | 72 | 73 | def clip_count(cand_d, ref_ds): 74 | """Clipped ngram count, considering all references""" 75 | count = 0 76 | for m in cand_d: 77 | m_w = cand_d[m] 78 | m_max = 0 79 | for ref in ref_ds: 80 | if m in ref: 81 | m_max = max(m_max, ref[m]) 82 | m_w = min(m_w, m_max) 83 | count += m_w 84 | return count 85 | 86 | 87 | def best_length_match(ref_l, cand_l): 88 | """Find the reference length closest to the candidate length""" 89 | least_diff = abs(cand_l-ref_l[0]) 90 | best = ref_l[0] 91 | for ref in ref_l: 92 | if abs(cand_l-ref) < least_diff: 93 | least_diff = abs(cand_l-ref) 94 | best = ref 95 | return best 96 | 97 | 98 | def brevity_penalty(c, r): 99 | if c > r: 100 | bp = 1 101 | else: 102 | bp = math.exp(1-(float(r)/c)) # assumes a non-empty candidate, i.e. c > 0 103 | return bp 104 | 105 | 106 | def geometric_mean(precisions): 107 | from functools import reduce # reduce is not a builtin on Python 3 108 | return (reduce(operator.mul, precisions)) ** (1.0 / len(precisions)) 109 | 110 | 111 | def BLEU(candidate, references, gram=4): 112 | precisions = [] 113 | for i in range(gram): 114 | pr, bp = count_ngram(candidate, references, i+1) 115 | #print pr, bp 116 | precisions.append(pr) 117 | #print geometric_mean(precisions), bp 118 | bleu = geometric_mean(precisions) * bp 119 | return bleu 120 | -------------------------------------------------------------------------------- /Util/myAttLM.py: -------------------------------------------------------------------------------- 1 | import os 2 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1" 3 | 4 | from tensorflow.python.layers import core as core_layers 5 | from myAttWrapper import SelfAttWrapper 6 | import tensorflow as tf 7 | import my_helper 8 | import numpy as np 9 | import time 10 | import myResidualCell 11 | #import DiverseDecode 12 | #import jieba 13 | from bleu import BLEU 14 | import random 15 | import pickle as cPickle 16 | import matplotlib.pyplot as plt 17 | 18 | 19 | 20 | class LM: 21 | def __init__(self, dp, rnn_size, n_layers, decoder_embedding_dim, max_infer_length, is_jieba, 22 | sess, att_type='B', lr=0.001, grad_clip=5.0, beam_width=5, force_teaching_ratio=1.0, beam_penalty=1.0, 23 | residual=False, output_keep_prob=0.5, input_keep_prob=0.9, cell_type='lstm', reverse=False, is_save=True, 24 | decay_scheme='luong234'): 25 | 26 | self.rnn_size = rnn_size 27 | self.n_layers = n_layers 28 | self.is_jieba = is_jieba 29 | self.grad_clip = grad_clip 30 | self.dp = dp 31 | self.decoder_embedding_dim = decoder_embedding_dim 32 | self.beam_width = beam_width 33 | self.beam_penalty = beam_penalty 34 | self.max_infer_length =
max_infer_length 35 | self.residual = residual 36 | self.decay_scheme = decay_scheme 37 | if self.residual: 38 | assert decoder_embedding_dim == rnn_size 39 | self.reverse = reverse 40 | self.cell_type = cell_type 41 | self.force_teaching_ratio = force_teaching_ratio 42 | self._output_keep_prob = output_keep_prob 43 | self._input_keep_prob = input_keep_prob 44 | self.is_save = is_save 45 | self.sess = sess 46 | self.att_type = att_type 47 | self.lr=lr 48 | self.build_graph() 49 | self.sess.run(tf.global_variables_initializer()) 50 | self.saver = tf.train.Saver(tf.trainable_variables(), max_to_keep = 15) 51 | 52 | self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary() 53 | 54 | # end constructor 55 | 56 | def build_graph(self): 57 | self.register_symbols() 58 | self.add_input_layer() 59 | with tf.variable_scope('decode'): 60 | self.add_decoder_for_training() 61 | 62 | with tf.variable_scope('decode', reuse=True): 63 | self.add_decoder_for_prefix_inference() 64 | 65 | with tf.variable_scope('decode', reuse=True): 66 | self.add_decoder_for_sample() 67 | 68 | with tf.variable_scope('decode', reuse=True): 69 | self.add_decoder_for_prefix_sample() 70 | self.add_backward_path() 71 | # end method 72 | 73 | def add_input_layer(self): 74 | self.X = tf.placeholder(tf.int32, [None, None], name="X") 75 | self.X_seq_len = tf.placeholder(tf.int32, [None], name="X_seq_len") 76 | self.input_keep_prob = tf.placeholder(tf.float32,name="input_keep_prob") 77 | self.output_keep_prob = tf.placeholder(tf.float32,name="output_keep_prob") 78 | self.batch_size = tf.shape(self.X)[0] 79 | self.init_memory = tf.zeros([self.batch_size, 1, self.rnn_size]) 80 | self.init_attention = tf.zeros([self.batch_size, self.rnn_size]) 81 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 82 | # end method 83 | 84 | def single_cell(self, reuse=False): 85 | if self.cell_type == 'lstm': 86 | cell = tf.contrib.rnn.LayerNormBasicLSTMCell(self.rnn_size, reuse=reuse) 87 | else: 88 | cell = tf.contrib.rnn.GRUBlockCell(self.rnn_size) 89 | cell = tf.contrib.rnn.DropoutWrapper(cell, self.output_keep_prob, self.input_keep_prob) 90 | if self.residual: 91 | cell = myResidualCell.ResidualWrapper(cell) 92 | return cell 93 | 94 | def processed_decoder_input(self): 95 | main = tf.strided_slice(self.X, [0, 0], [self.batch_size, -1], [1, 1]) # remove last char 96 | decoder_input = tf.concat([tf.fill([self.batch_size, 1], self._x_go), main], 1) 97 | return decoder_input 98 | 99 | def add_decoder_for_training(self): 100 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 101 | self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention, self.init_memory, att_layer = core_layers.Dense(self.rnn_size, name='att_dense'), att_type=self.att_type) 102 | decoder_embedding = tf.get_variable('word_embedding', [len(self.dp.X_w2id), self.decoder_embedding_dim], 103 | tf.float32, tf.random_uniform_initializer(-1.0, 1.0)) 104 | training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper( 105 | inputs = tf.nn.embedding_lookup(decoder_embedding, self.processed_decoder_input()), 106 | sequence_length = self.X_seq_len, 107 | embedding = decoder_embedding, 108 | sampling_probability = 1 - self.force_teaching_ratio, 109 | time_major = False) 110 | training_decoder = tf.contrib.seq2seq.BasicDecoder( 111 | cell = self.decoder_cell, 112 | helper = training_helper, 113 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32), 
#.clone(cell_state=self.encoder_state), 114 | output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense')) 115 | training_decoder_output, training_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 116 | decoder = training_decoder, 117 | impute_finished = True, 118 | maximum_iterations = tf.reduce_max(self.X_seq_len)) 119 | self.training_logits = training_decoder_output.rnn_output 120 | self.init_prefix_state = training_final_state 121 | 122 | 123 | def add_decoder_for_prefix_inference(self): 124 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 125 | self.init_attention_tiled = tf.contrib.seq2seq.tile_batch(self.init_attention, self.beam_width) 126 | self.init_memory_tiled = tf.contrib.seq2seq.tile_batch(self.init_memory, self.beam_width) 127 | 128 | self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention_tiled, self.init_memory_tiled, att_layer = core_layers.Dense(self.rnn_size, name='att_dense', _reuse=True),att_type=self.att_type) 129 | self.beam_init_state = tf.contrib.seq2seq.tile_batch(self.init_prefix_state, self.beam_width) 130 | my_decoder = tf.contrib.seq2seq.BeamSearchDecoder( 131 | cell = self.decoder_cell, 132 | embedding = tf.get_variable('word_embedding'), 133 | start_tokens = tf.tile(tf.constant([self._x_go], dtype=tf.int32), [self.batch_size]), 134 | end_token = self._x_eos, 135 | initial_state = self.beam_init_state, 136 | beam_width = self.beam_width, 137 | output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense', _reuse=True), 138 | length_penalty_weight = self.beam_penalty) 139 | 140 | self.prefix_go = tf.placeholder(tf.int32, [None]) 141 | prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1), [1, self.beam_width]) 142 | prefix_emb = tf.nn.embedding_lookup(tf.get_variable('word_embedding'), prefix_go_beam) 143 | my_decoder._start_inputs = prefix_emb 144 | predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode( 145 | decoder = my_decoder, 146 | impute_finished = False, 147 | maximum_iterations = self.max_infer_length) 148 | self.prefix_infer_outputs = predicting_decoder_output.predicted_ids 149 | self.score = predicting_decoder_output.beam_search_decoder_output.scores 150 | 151 | def add_decoder_for_sample(self): 152 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 153 | self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention, self.init_memory, att_layer = core_layers.Dense(self.rnn_size, name='att_dense', _reuse=True), att_type=self.att_type) 154 | word_embedding = tf.get_variable('word_embedding') 155 | sample_helper = tf.contrib.seq2seq.SampleEmbeddingHelper( 156 | embedding= word_embedding, 157 | start_tokens = tf.tile(tf.constant([self._x_go], dtype=tf.int32), [self.batch_size]), 158 | end_token = self._x_eos) 159 | sample_decoder = tf.contrib.seq2seq.BasicDecoder( 160 | cell = self.decoder_cell, 161 | helper = sample_helper, 162 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32),#.clone(cell_state=self.encoder_state), 163 | output_layer = core_layers.Dense(len(self.dp.X_w2id),name='output_dense', _reuse=True)) 164 | sample_decoder_output, self.sample_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 165 | decoder = sample_decoder, 166 | impute_finished = False, 167 | maximum_iterations = self.max_infer_length) 168 | self.sample_output = sample_decoder_output.sample_id 169 | 170 | def add_decoder_for_prefix_sample(self): 171 | 
self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 172 | self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention, self.init_memory, att_layer = core_layers.Dense(self.rnn_size, name='att_dense', _reuse=True), att_type=self.att_type) 173 | word_embedding = tf.get_variable('word_embedding') 174 | prefix_sample_helper = my_helper.MyHelper( 175 | inputs = self.processed_decoder_input(), 176 | sequence_length = self.X_seq_len, 177 | embedding= word_embedding, 178 | end_token = self._x_eos) 179 | sample_prefix_decoder = tf.contrib.seq2seq.BasicDecoder( 180 | cell = self.decoder_cell, 181 | helper = prefix_sample_helper, 182 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32),#.clone(cell_state=self.encoder_state), 183 | output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense', _reuse=True)) 184 | sample_decoder_prefix_output, self.sample_prefix_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 185 | decoder = sample_prefix_decoder, 186 | impute_finished = False, 187 | maximum_iterations = self.max_infer_length) 188 | self.sample_prefix_output = sample_decoder_prefix_output.sample_id 189 | 190 | def add_backward_path(self): 191 | masks = tf.sequence_mask(self.X_seq_len, tf.reduce_max(self.X_seq_len), dtype=tf.float32) 192 | self.loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 193 | targets = self.X, 194 | weights = masks) 195 | self.batch_loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 196 | targets = self.X, 197 | weights = masks, 198 | average_across_batch=False) 199 | params = tf.trainable_variables() 200 | gradients = tf.gradients(self.loss, params) 201 | clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.grad_clip) 202 | self.learning_rate = tf.constant(self.lr) 203 | self.learning_rate = self.get_learning_rate_decay(self.decay_scheme) # decay 204 | self.train_op = tf.train.AdamOptimizer(self.learning_rate).apply_gradients(zip(clipped_gradients, params), global_step=self.global_step) 205 | 206 | def register_symbols(self): 207 | self._x_go = self.dp.X_w2id[''] 208 | self._x_eos = self.dp.X_w2id[''] 209 | self._x_pad = self.dp.X_w2id[''] 210 | self._x_unk = self.dp.X_w2id[''] 211 | 212 | def infer(self, input_word, batch_size=1, is_show=True): 213 | #return ["pass"] 214 | if self.is_jieba: 215 | input_index = list(jieba.cut(input_word)) 216 | else: 217 | input_index = input_word 218 | xx = [char for char in input_index] 219 | if self.reverse: 220 | xx = xx[::-1] 221 | #print(xx) 222 | length = [len(xx),] * batch_size 223 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in xx]] * batch_size 224 | prefix_go = [] 225 | for ipt in input_indices: 226 | prefix_go.append(ipt[-1]) 227 | out_indices, scores = self.sess.run([self.prefix_infer_outputs, self.score], { 228 | self.X: input_indices, self.X_seq_len: length, self.prefix_go: prefix_go, self.input_keep_prob:1, 229 | self.output_keep_prob:1}) 230 | outputs = [] 231 | for idx in range(out_indices.shape[-1]): 232 | eos_id = self.dp.X_w2id[''] 233 | ot = out_indices[0,:,idx] 234 | if eos_id in ot: 235 | ot = ot.tolist() 236 | ot = ot[:ot.index(eos_id)] 237 | #print(ot) 238 | if self.reverse: 239 | ot = ot[::-1] 240 | if self.reverse: 241 | output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) + input_word 242 | else: 243 | output_str = input_word+''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 244 | outputs.append(output_str) 245 | return outputs 
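    # --- Added illustrative note; not part of the original source ---
    # A minimal sketch of driving `infer`, assuming a model restored from a
    # checkpoint (the prefix and outputs below are hypothetical):
    #
    #     outs = model.infer('how are', batch_size=1)
    #     # `outs` holds `beam_width` strings, each beginning with the prefix,
    #     # e.g. 'how are you doing today'
    #
    # Internally, the prefix is run through the training decoder to obtain
    # `init_prefix_state`, and beam search then continues from the prefix's
    # last token (`prefix_go`); `batch_infer` below does the same for a list
    # of prefixes but keeps only the top beam per input.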
246 | 247 | 248 | def batch_infer(self, input_words, is_show=True): 249 | #return ["pass"] 250 | #xx = [char for char in input_index] 251 | #if self.reverse: 252 | # xx = xx[::-1] 253 | length = [len(xx) for xx in input_words] 254 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in s] for s in input_words] 255 | prefix_go = [] 256 | #print(length) 257 | for ipt in input_indices: 258 | prefix_go.append(ipt[-1]) 259 | #print(prefix_go) 260 | out_indices, scores = self.sess.run([self.prefix_infer_outputs, self.score], { 261 | self.X: input_indices, self.X_seq_len: length, self.prefix_go: prefix_go, self.input_keep_prob:1, 262 | self.output_keep_prob:1}) 263 | outputs = [] 264 | for b in range(len(input_indices)): 265 | eos_id = self.dp.X_w2id[''] 266 | ot = out_indices[b,:,0] 267 | if eos_id in ot: 268 | ot = ot.tolist() 269 | ot = ot[:ot.index(eos_id)] 270 | #if self.reverse: 271 | # ot = ot[::-1] 272 | #if self.reverse: 273 | # output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) + input_words[b] 274 | #else: 275 | output_str = input_words[b] +''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 276 | outputs.append(output_str) 277 | return outputs 278 | 279 | def generate(self, batch_size=1, is_show=True): 280 | fake_x = [[1] for _ in range(batch_size)] 281 | out_indices = self.sess.run(self.sample_output, {self.X: fake_x, self.input_keep_prob:1, self.output_keep_prob:1}) 282 | #print(out_indices.shape) 283 | outputs = [] 284 | for ot in out_indices: 285 | eos_id = self.dp.X_w2id[''] 286 | if eos_id in ot: 287 | ot = ot.tolist() 288 | ot = ot[:ot.index(eos_id)] 289 | if self.reverse: 290 | ot = ot[::-1] 291 | if self.reverse: 292 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 293 | else: 294 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 295 | outputs.append(output_str) 296 | return out_indices, outputs 297 | 298 | def rollout_batch(self, input_indices): 299 | length = [len(ind)+1 for ind in input_indices] 300 | input_indices = [x.tolist()+[self.dp.X_w2id[''],] for x in input_indices] 301 | #print(input_indices) 302 | 303 | ## show 304 | 305 | #for _ in input_indices: 306 | # print(" ".join([self.dp.X_id2w.get(i, '<-1>') for i in _])) 307 | 308 | out_indices = self.sess.run(self.sample_prefix_output, { 309 | self.X: input_indices, self.X_seq_len: length, self.input_keep_prob:1, 310 | self.output_keep_prob:1}) 311 | outputs = [] 312 | for ot in out_indices: 313 | eos_id = self.dp.X_w2id[''] 314 | if eos_id in ot: 315 | ot = ot.tolist() 316 | ot = ot[:ot.index(eos_id)] 317 | if self.reverse: 318 | ot = ot[::-1] 319 | 320 | if self.reverse: 321 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 322 | else: 323 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 324 | outputs.append(output_str) 325 | return out_indices, outputs 326 | 327 | def rollout(self, input_word, batch_size=1, is_show=True): 328 | if self.is_jieba: 329 | input_index = list(jieba.cut(input_word)) 330 | else: 331 | input_index = input_word 332 | xx = [char for char in input_index] 333 | if self.reverse: 334 | xx = xx[::-1] 335 | length = [len(xx)+1] * batch_size 336 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in xx]] * batch_size 337 | input_indices = [x+[self.dp.X_w2id[''],] for x in input_indices] 338 | #print(input_indices) 339 | out_indices = self.sess.run(self.sample_prefix_output, { 340 | self.X: input_indices, self.X_seq_len: length, self.input_keep_prob:1, 341 | 
self.output_keep_prob:1}) 342 | outputs = [] 343 | for ot in out_indices: 344 | eos_id = self.dp.X_w2id[''] 345 | if eos_id in ot: 346 | ot = ot.tolist() 347 | ot = ot[:ot.index(eos_id)] 348 | if self.reverse: 349 | ot = ot[::-1] 350 | 351 | if self.reverse: 352 | output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 353 | else: 354 | output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 355 | outputs.append(output_str) 356 | return outputs 357 | 358 | def restore(self, path): 359 | self.saver.restore(self.sess, path) 360 | print('restore %s success' % path) 361 | 362 | def get_learning_rate_decay(self, decay_scheme='luong234'): 363 | num_train_steps = self.dp.num_steps 364 | if decay_scheme == "luong10": 365 | start_decay_step = int(num_train_steps / 2) 366 | remain_steps = num_train_steps - start_decay_step 367 | decay_steps = int(remain_steps / 10) # decay 10 times 368 | decay_factor = 0.5 369 | else: 370 | start_decay_step = int(num_train_steps * 2 / 3) 371 | remain_steps = num_train_steps - start_decay_step 372 | decay_steps = int(remain_steps / 4) # decay 4 times 373 | decay_factor = 0.5 374 | return tf.cond( 375 | self.global_step < start_decay_step, 376 | lambda: self.learning_rate, 377 | lambda: tf.train.exponential_decay( 378 | self.learning_rate, 379 | (self.global_step - start_decay_step), 380 | decay_steps, decay_factor, staircase=True), 381 | name="learning_rate_decay_cond") 382 | 383 | def setup_summary(self): 384 | train_loss = tf.Variable(0.) 385 | tf.summary.scalar('Train_loss', train_loss) 386 | 387 | test_loss = tf.Variable(0.) 388 | tf.summary.scalar('Test_loss', test_loss) 389 | 390 | bleu_score = tf.Variable(0.) 391 | tf.summary.scalar('BLEU_score', bleu_score) 392 | 393 | tf.summary.scalar('lr_rate', self.learning_rate) 394 | 395 | summary_vars = [train_loss, test_loss, bleu_score] 396 | summary_placeholders = [tf.placeholder(tf.float32) for _ in range(len(summary_vars))] 397 | update_ops = [summary_vars[i].assign(summary_placeholders[i]) for i in range(len(summary_vars))] 398 | summary_op = tf.summary.merge_all() 399 | return summary_placeholders, update_ops, summary_op -------------------------------------------------------------------------------- /Util/myAttLM_Diverse.py: -------------------------------------------------------------------------------- 1 | import os 2 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1" 3 | 4 | from tensorflow.python.layers import core as core_layers 5 | from myAttWrapper import SelfAttWrapper 6 | import tensorflow as tf 7 | import my_helper 8 | import numpy as np 9 | import time 10 | import myResidualCell 11 | import DiverseDecode 12 | #import jieba 13 | from bleu import BLEU 14 | import random 15 | import pickle as cPickle 16 | import matplotlib.pyplot as plt 17 | 18 | 19 | 20 | class LM: 21 | def __init__(self, dp, rnn_size, n_layers, decoder_embedding_dim, max_infer_length, is_jieba, gamma, 22 | sess, att_type='B', lr=0.001, grad_clip=5.0, beam_width=5, force_teaching_ratio=1.0, beam_penalty=1.0, 23 | residual=False, output_keep_prob=0.5, input_keep_prob=0.9, cell_type='lstm', reverse=False, is_save=True, 24 | decay_scheme='luong234'): 25 | 26 | self.rnn_size = rnn_size 27 | self.n_layers = n_layers 28 | self.is_jieba = is_jieba 29 | self.grad_clip = grad_clip 30 | self.dp = dp 31 | self.decoder_embedding_dim = decoder_embedding_dim 32 | self.beam_width = beam_width 33 | self.beam_penalty = beam_penalty 34 | self.max_infer_length = max_infer_length 35 | self.residual = residual 36 | self.decay_scheme = 
decay_scheme 37 | if self.residual: 38 | assert decoder_embedding_dim == rnn_size 39 | self.reverse = reverse 40 | self.cell_type = cell_type 41 | self.force_teaching_ratio = force_teaching_ratio 42 | self._output_keep_prob = output_keep_prob 43 | self._input_keep_prob = input_keep_prob 44 | self.is_save = is_save 45 | self.sess = sess 46 | self.gamma = gamma 47 | self.att_type = att_type 48 | self.lr=lr 49 | self.build_graph() 50 | self.sess.run(tf.global_variables_initializer()) 51 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = 15) 52 | self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary() 53 | 54 | # end constructor 55 | 56 | def build_graph(self): 57 | self.register_symbols() 58 | self.add_input_layer() 59 | with tf.variable_scope('decode'): 60 | self.add_decoder_for_training() 61 | 62 | with tf.variable_scope('decode', reuse=True): 63 | self.add_decoder_for_prefix_inference() 64 | 65 | with tf.variable_scope('decode', reuse=True): 66 | self.add_decoder_for_sample() 67 | 68 | with tf.variable_scope('decode', reuse=True): 69 | self.add_decoder_for_prefix_sample() 70 | self.add_backward_path() 71 | # end method 72 | 73 | def add_input_layer(self): 74 | self.X = tf.placeholder(tf.int32, [1, None], name="X") 75 | self.X_seq_len = tf.placeholder(tf.int32, [1], name="X_seq_len") 76 | self.input_keep_prob = tf.placeholder(tf.float32,name="input_keep_prob") 77 | self.output_keep_prob = tf.placeholder(tf.float32,name="output_keep_prob") 78 | self.batch_size = 1 79 | self.init_memory = tf.zeros([self.batch_size, 1, self.rnn_size]) 80 | self.init_attention = tf.zeros([self.batch_size, self.rnn_size]) 81 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 82 | # end method 83 | 84 | def single_cell(self, reuse=False): 85 | if self.cell_type == 'lstm': 86 | cell = tf.contrib.rnn.LayerNormBasicLSTMCell(self.rnn_size, reuse=reuse) 87 | else: 88 | cell = tf.contrib.rnn.GRUBlockCell(self.rnn_size) 89 | cell = tf.contrib.rnn.DropoutWrapper(cell, self.output_keep_prob, self.input_keep_prob) 90 | if self.residual: 91 | cell = myResidualCell.ResidualWrapper(cell) 92 | return cell 93 | 94 | def processed_decoder_input(self): 95 | main = tf.strided_slice(self.X, [0, 0], [self.batch_size, -1], [1, 1]) # remove last char 96 | decoder_input = tf.concat([tf.fill([self.batch_size, 1], self._x_go), main], 1) 97 | return decoder_input 98 | 99 | def add_decoder_for_training(self): 100 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 101 | self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention, self.init_memory, att_layer = core_layers.Dense(self.rnn_size, name='att_dense'), att_type=self.att_type) 102 | decoder_embedding = tf.get_variable('word_embedding', [len(self.dp.X_w2id), self.decoder_embedding_dim], 103 | tf.float32, tf.random_uniform_initializer(-1.0, 1.0)) 104 | training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper( 105 | inputs = tf.nn.embedding_lookup(decoder_embedding, self.processed_decoder_input()), 106 | sequence_length = self.X_seq_len, 107 | embedding = decoder_embedding, 108 | sampling_probability = 1 - self.force_teaching_ratio, 109 | time_major = False) 110 | training_decoder = tf.contrib.seq2seq.BasicDecoder( 111 | cell = self.decoder_cell, 112 | helper = training_helper, 113 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32), #.clone(cell_state=self.encoder_state), 114 | output_layer = 
core_layers.Dense(len(self.dp.X_w2id), name='output_dense')) 115 | training_decoder_output, training_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 116 | decoder = training_decoder, 117 | impute_finished = True, 118 | maximum_iterations = tf.reduce_max(self.X_seq_len)) 119 | self.training_logits = training_decoder_output.rnn_output 120 | self.init_prefix_state = training_final_state 121 | 122 | 123 | def add_decoder_for_prefix_inference(self): 124 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 125 | self.init_attention_tiled = tf.contrib.seq2seq.tile_batch(self.init_attention, self.beam_width) 126 | self.init_memory_tiled = tf.contrib.seq2seq.tile_batch(self.init_memory, self.beam_width) 127 | 128 | self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention_tiled, self.init_memory_tiled, att_layer = core_layers.Dense(self.rnn_size, name='att_dense', _reuse=True),att_type=self.att_type) 129 | self.beam_init_state = tf.contrib.seq2seq.tile_batch(self.init_prefix_state, self.beam_width) 130 | my_decoder = DiverseDecode.BeamSearchDecoder( 131 | cell = self.decoder_cell, 132 | embedding = tf.get_variable('word_embedding'), 133 | start_tokens = tf.tile(tf.constant([self._x_go], dtype=tf.int32), [self.batch_size]), 134 | end_token = self._x_eos, 135 | gamma = self.gamma, 136 | initial_state = self.beam_init_state, 137 | beam_width = self.beam_width, 138 | vocab_size = len(self.dp.X_w2id), 139 | output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense', _reuse=True), 140 | length_penalty_weight = self.beam_penalty) 141 | 142 | self.prefix_go = tf.placeholder(tf.int32, [None]) 143 | prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1), [1, self.beam_width]) 144 | prefix_emb = tf.nn.embedding_lookup(tf.get_variable('word_embedding'), prefix_go_beam) 145 | my_decoder._start_inputs = prefix_emb 146 | predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode( 147 | decoder = my_decoder, 148 | impute_finished = False, 149 | maximum_iterations = self.max_infer_length) 150 | self.prefix_infer_outputs = predicting_decoder_output.predicted_ids 151 | self.score = predicting_decoder_output.beam_search_decoder_output.scores 152 | 153 | def add_decoder_for_sample(self): 154 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 155 | self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention, self.init_memory, att_layer = core_layers.Dense(self.rnn_size, name='att_dense', _reuse=True), att_type=self.att_type) 156 | word_embedding = tf.get_variable('word_embedding') 157 | sample_helper = tf.contrib.seq2seq.SampleEmbeddingHelper( 158 | embedding= word_embedding, 159 | start_tokens = tf.tile(tf.constant([self._x_go], dtype=tf.int32), [self.batch_size]), 160 | end_token = self._x_eos) 161 | sample_decoder = tf.contrib.seq2seq.BasicDecoder( 162 | cell = self.decoder_cell, 163 | helper = sample_helper, 164 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32),#.clone(cell_state=self.encoder_state), 165 | output_layer = core_layers.Dense(len(self.dp.X_w2id),name='output_dense', _reuse=True)) 166 | sample_decoder_output, self.sample_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 167 | decoder = sample_decoder, 168 | impute_finished = False, 169 | maximum_iterations = self.max_infer_length) 170 | self.sample_output = sample_decoder_output.sample_id 171 | 172 | def add_decoder_for_prefix_sample(self): 173 | 
self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 174 | self.decoder_cell = SelfAttWrapper(self.decoder_cell, self.init_attention, self.init_memory, att_layer = core_layers.Dense(self.rnn_size, name='att_dense', _reuse=True), att_type=self.att_type) 175 | word_embedding = tf.get_variable('word_embedding') 176 | prefix_sample_helper = my_helper.MyHelper( 177 | inputs = self.processed_decoder_input(), 178 | sequence_length = self.X_seq_len, 179 | embedding= word_embedding, 180 | end_token = self._x_eos) 181 | sample_prefix_decoder = tf.contrib.seq2seq.BasicDecoder( 182 | cell = self.decoder_cell, 183 | helper = prefix_sample_helper, 184 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32),#.clone(cell_state=self.encoder_state), 185 | output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense', _reuse=True)) 186 | sample_decoder_prefix_output, self.sample_prefix_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 187 | decoder = sample_prefix_decoder, 188 | impute_finished = False, 189 | maximum_iterations = self.max_infer_length) 190 | self.sample_prefix_output = sample_decoder_prefix_output.sample_id 191 | 192 | def add_backward_path(self): 193 | masks = tf.sequence_mask(self.X_seq_len, tf.reduce_max(self.X_seq_len), dtype=tf.float32) 194 | self.loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 195 | targets = self.X, 196 | weights = masks) 197 | self.batch_loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 198 | targets = self.X, 199 | weights = masks, 200 | average_across_batch=False) 201 | params = tf.trainable_variables() 202 | gradients = tf.gradients(self.loss, params) 203 | clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.grad_clip) 204 | self.learning_rate = tf.constant(self.lr) 205 | self.learning_rate = self.get_learning_rate_decay(self.decay_scheme) # decay 206 | self.train_op = tf.train.AdamOptimizer(self.learning_rate).apply_gradients(zip(clipped_gradients, params), global_step=self.global_step) 207 | 208 | def register_symbols(self): 209 | self._x_go = self.dp.X_w2id[''] 210 | self._x_eos = self.dp.X_w2id[''] 211 | self._x_pad = self.dp.X_w2id[''] 212 | self._x_unk = self.dp.X_w2id[''] 213 | 214 | def infer(self, input_word, batch_size=1, is_show=True): 215 | #return ["pass"] 216 | if self.is_jieba: 217 | input_index = list(jieba.cut(input_word)) 218 | else: 219 | input_index = input_word 220 | xx = [char for char in input_index] 221 | if self.reverse: 222 | xx = xx[::-1] 223 | #print(xx) 224 | length = [len(xx),] * batch_size 225 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in xx]] * batch_size 226 | prefix_go = [] 227 | for ipt in input_indices: 228 | prefix_go.append(ipt[-1]) 229 | out_indices, scores = self.sess.run([self.prefix_infer_outputs, self.score], { 230 | self.X: input_indices, self.X_seq_len: length, self.prefix_go: prefix_go, self.input_keep_prob:1, 231 | self.output_keep_prob:1}) 232 | outputs = [] 233 | for idx in range(out_indices.shape[-1]): 234 | eos_id = self.dp.X_w2id[''] 235 | ot = out_indices[0,:,idx] 236 | if eos_id in ot: 237 | ot = ot.tolist() 238 | ot = ot[:ot.index(eos_id)] 239 | #print(ot) 240 | if self.reverse: 241 | ot = ot[::-1] 242 | if self.reverse: 243 | output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) + input_word 244 | else: 245 | output_str = input_word+''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 246 | outputs.append(output_str) 247 | return outputs 
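    # --- Added illustrative note; not part of the original source ---
    # This file mirrors Util/myAttLM.py, except that prefix inference goes
    # through DiverseDecode.BeamSearchDecoder with an extra `gamma` knob;
    # presumably gamma penalizes sibling expansions of the same parent
    # hypothesis (diverse beam search in the spirit of Li et al., 2016), so
    # larger values trade per-beam likelihood for more varied candidates.
    # Note that this graph pins the batch size to 1 (`X` has shape [1, None]).
    # A hypothetical construction, with gamma=0.5 chosen only for illustration:
    #
    #     model = LM(dp=dp, rnn_size=512, n_layers=2,
    #                decoder_embedding_dim=512, max_infer_length=35,
    #                is_jieba=False, gamma=0.5, sess=sess)
    #     outs = model.infer('how are')   # beam_width diverse continuations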
248 | 249 | 250 | def batch_infer(self, input_words, is_show=True): 251 | #return ["pass"] 252 | #xx = [char for char in input_index] 253 | #if self.reverse: 254 | # xx = xx[::-1] 255 | length = [len(xx) for xx in input_words] 256 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in s] for s in input_words] 257 | prefix_go = [] 258 | #print(length) 259 | for ipt in input_indices: 260 | prefix_go.append(ipt[-1]) 261 | #print(prefix_go) 262 | out_indices, scores = self.sess.run([self.prefix_infer_outputs, self.score], { 263 | self.X: input_indices, self.X_seq_len: length, self.prefix_go: prefix_go, self.input_keep_prob:1, 264 | self.output_keep_prob:1}) 265 | outputs = [] 266 | for b in range(len(input_indices)): 267 | eos_id = self.dp.X_w2id[''] 268 | ot = out_indices[b,:,0] 269 | if eos_id in ot: 270 | ot = ot.tolist() 271 | ot = ot[:ot.index(eos_id)] 272 | #if self.reverse: 273 | # ot = ot[::-1] 274 | #if self.reverse: 275 | # output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) + input_words[b] 276 | #else: 277 | output_str = input_words[b] +''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 278 | outputs.append(output_str) 279 | return outputs 280 | 281 | def generate(self, batch_size=1, is_show=True): 282 | fake_x = [[1] for _ in range(batch_size)] 283 | out_indices = self.sess.run(self.sample_output, {self.X: fake_x, self.input_keep_prob:1, self.output_keep_prob:1}) 284 | #print(out_indices.shape) 285 | outputs = [] 286 | for ot in out_indices: 287 | eos_id = self.dp.X_w2id[''] 288 | if eos_id in ot: 289 | ot = ot.tolist() 290 | ot = ot[:ot.index(eos_id)] 291 | if self.reverse: 292 | ot = ot[::-1] 293 | if self.reverse: 294 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 295 | else: 296 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 297 | outputs.append(output_str) 298 | return out_indices, outputs 299 | 300 | def rollout_batch(self, input_indices): 301 | length = [len(ind)+1 for ind in input_indices] 302 | input_indices = [x.tolist()+[self.dp.X_w2id[''],] for x in input_indices] 303 | #print(input_indices) 304 | 305 | ## show 306 | 307 | #for _ in input_indices: 308 | # print(" ".join([self.dp.X_id2w.get(i, '<-1>') for i in _])) 309 | 310 | out_indices = self.sess.run(self.sample_prefix_output, { 311 | self.X: input_indices, self.X_seq_len: length, self.input_keep_prob:1, 312 | self.output_keep_prob:1}) 313 | outputs = [] 314 | for ot in out_indices: 315 | eos_id = self.dp.X_w2id[''] 316 | if eos_id in ot: 317 | ot = ot.tolist() 318 | ot = ot[:ot.index(eos_id)] 319 | if self.reverse: 320 | ot = ot[::-1] 321 | 322 | if self.reverse: 323 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 324 | else: 325 | output_str = ' '.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 326 | outputs.append(output_str) 327 | return out_indices, outputs 328 | 329 | def rollout(self, input_word, batch_size=1, is_show=True): 330 | if self.is_jieba: 331 | input_index = list(jieba.cut(input_word)) 332 | else: 333 | input_index = input_word 334 | xx = [char for char in input_index] 335 | if self.reverse: 336 | xx = xx[::-1] 337 | length = [len(xx)+1] * batch_size 338 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in xx]] * batch_size 339 | input_indices = [x+[self.dp.X_w2id[''],] for x in input_indices] 340 | #print(input_indices) 341 | out_indices = self.sess.run(self.sample_prefix_output, { 342 | self.X: input_indices, self.X_seq_len: length, self.input_keep_prob:1, 343 | 
self.output_keep_prob:1}) 344 | outputs = [] 345 | for ot in out_indices: 346 | eos_id = self.dp.X_w2id[''] 347 | if eos_id in ot: 348 | ot = ot.tolist() 349 | ot = ot[:ot.index(eos_id)] 350 | if self.reverse: 351 | ot = ot[::-1] 352 | 353 | if self.reverse: 354 | output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 355 | else: 356 | output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 357 | outputs.append(output_str) 358 | return outputs 359 | 360 | def restore(self, path): 361 | self.saver.restore(self.sess, path) 362 | print('restore %s success' % path) 363 | 364 | def get_learning_rate_decay(self, decay_scheme='luong234'): 365 | num_train_steps = self.dp.num_steps 366 | if decay_scheme == "luong10": 367 | start_decay_step = int(num_train_steps / 2) 368 | remain_steps = num_train_steps - start_decay_step 369 | decay_steps = int(remain_steps / 10) # decay 10 times 370 | decay_factor = 0.5 371 | else: 372 | start_decay_step = int(num_train_steps * 2 / 3) 373 | remain_steps = num_train_steps - start_decay_step 374 | decay_steps = int(remain_steps / 4) # decay 4 times 375 | decay_factor = 0.5 376 | return tf.cond( 377 | self.global_step < start_decay_step, 378 | lambda: self.learning_rate, 379 | lambda: tf.train.exponential_decay( 380 | self.learning_rate, 381 | (self.global_step - start_decay_step), 382 | decay_steps, decay_factor, staircase=True), 383 | name="learning_rate_decay_cond") 384 | 385 | def setup_summary(self): 386 | train_loss = tf.Variable(0.) 387 | tf.summary.scalar('Train_loss', train_loss) 388 | 389 | test_loss = tf.Variable(0.) 390 | tf.summary.scalar('Test_loss', test_loss) 391 | 392 | bleu_score = tf.Variable(0.) 393 | tf.summary.scalar('BLEU_score', bleu_score) 394 | 395 | tf.summary.scalar('lr_rate', self.learning_rate) 396 | 397 | summary_vars = [train_loss, test_loss, bleu_score] 398 | summary_placeholders = [tf.placeholder(tf.float32) for _ in range(len(summary_vars))] 399 | update_ops = [summary_vars[i].assign(summary_placeholders[i]) for i in range(len(summary_vars))] 400 | summary_op = tf.summary.merge_all() 401 | return summary_placeholders, update_ops, summary_op -------------------------------------------------------------------------------- /Util/myAttWrapper.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import collections 7 | import functools 8 | import math 9 | import numpy as np 10 | from tensorflow.python.ops.rnn_cell_impl import * 11 | from tensorflow.contrib.seq2seq.python.ops import decoder 12 | from tensorflow.contrib.seq2seq.python.ops import helper as helper_py 13 | from tensorflow.python.framework import ops 14 | from tensorflow.python.framework import tensor_shape 15 | from tensorflow.python.layers import base as layers_base 16 | from tensorflow.python.ops import rnn_cell_impl 17 | from tensorflow.python.util import nest 18 | from tensorflow.contrib.framework.python.framework import tensor_util 19 | from tensorflow.python.framework import dtypes 20 | from tensorflow.python.layers import core as layers_core 21 | from tensorflow.python.ops import array_ops 22 | from tensorflow.python.ops import check_ops 23 | from tensorflow.python.ops import clip_ops 24 | from tensorflow.python.ops import functional_ops 25 | from tensorflow.python.ops import init_ops 26 | from tensorflow.python.ops import math_ops 27 | from tensorflow.python.ops import nn_ops 28 | from 
tensorflow.python.ops import random_ops 29 | from tensorflow.python.ops import tensor_array_ops 30 | from tensorflow.python.ops import variable_scope 31 | 32 | def _luong_score(query, keys, scale=True): 33 | """Implements Luong-style (multiplicative) scoring function. 34 | This attention has two forms. The first is standard Luong attention, 35 | as described in: 36 | Minh-Thang Luong, Hieu Pham, Christopher D. Manning. 37 | "Effective Approaches to Attention-based Neural Machine Translation." 38 | EMNLP 2015. https://arxiv.org/abs/1508.04025 39 | The second is the scaled form inspired partly by the normalized form of 40 | Bahdanau attention. 41 | To enable the second form, call this function with `scale=True`. 42 | Args: 43 | query: Tensor, shape `[batch_size, num_units]` to compare to keys. 44 | keys: Processed memory, shape `[batch_size, max_time, num_units]`. 45 | scale: Whether to apply a scale to the score function. 46 | Returns: 47 | A `[batch_size, max_time]` tensor of unnormalized score values. 48 | Raises: 49 | ValueError: If `key` and `query` depths do not match. 50 | """ 51 | depth = query.get_shape()[-1] 52 | key_units = keys.get_shape()[-1] 53 | if depth != key_units: 54 | raise ValueError( 55 | "Incompatible or unknown inner dimensions between query and keys. " 56 | "Query (%s) has units: %s. Keys (%s) have units: %s. " 57 | "Perhaps you need to set num_units to the keys' dimension (%s)?" 58 | % (query, depth, keys, key_units, key_units)) 59 | dtype = query.dtype 60 | 61 | # Reshape from [batch_size, depth] to [batch_size, 1, depth] 62 | # for matmul. 63 | query = array_ops.expand_dims(query, 1) 64 | 65 | # Inner product along the query units dimension. 66 | # matmul shapes: query is [batch_size, 1, depth] and 67 | # keys is [batch_size, max_time, depth]. 68 | # the inner product is asked to **transpose keys' inner shape** to get a 69 | # batched matmul on: 70 | # [batch_size, 1, depth] . [batch_size, depth, max_time] 71 | # resulting in an output shape of: 72 | # [batch_size, 1, max_time]. 73 | # we then squeeze out the center singleton dimension. 74 | score = math_ops.matmul(query, keys, transpose_b=True) 75 | score = array_ops.squeeze(score, [1]) 76 | 77 | if scale: 78 | # Scalar used in weight scaling 79 | g = variable_scope.get_variable( 80 | "attention_g", dtype=dtype, initializer=1.) 81 | score = g * score 82 | return score 83 | 84 | def _bahdanau_score(processed_query, keys, normalize=True): 85 | """Implements Bahdanau-style (additive) scoring function. 86 | This attention has two forms. The first is Bahdanau attention, 87 | as described in: 88 | Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio. 89 | "Neural Machine Translation by Jointly Learning to Align and Translate." 90 | ICLR 2015. https://arxiv.org/abs/1409.0473 91 | The second is the normalized form. This form is inspired by the 92 | weight normalization article: 93 | Tim Salimans, Diederik P. Kingma. 94 | "Weight Normalization: A Simple Reparameterization to Accelerate 95 | Training of Deep Neural Networks." 96 | https://arxiv.org/abs/1602.07868 97 | To enable the second form, set `normalize=True`. 98 | Args: 99 | processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys. 100 | keys: Processed memory, shape `[batch_size, max_time, num_units]`. 101 | normalize: Whether to normalize the score function. 102 | Returns: 103 | A `[batch_size, max_time]` tensor of unnormalized score values.
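    (Added note, restating the code below: for a processed query q and key
    k_t, the returned score is sum_d v_d * tanh(k_t + q)_d; with
    `normalize=True` it becomes sum_d (g * v / ||v||)_d * tanh(k_t + q + b)_d,
    the weight-normalized form of Salimans & Kingma.)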
104 | """ 105 | dtype = processed_query.dtype 106 | # Get the number of hidden units from the trailing dimension of keys 107 | num_units = keys.shape[2].value or array_ops.shape(keys)[2] 108 | # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting. 109 | processed_query = array_ops.expand_dims(processed_query, 1) 110 | v = variable_scope.get_variable( 111 | "attention_v", [num_units], dtype=dtype) 112 | if normalize: 113 | # Scalar used in weight normalization 114 | g = variable_scope.get_variable( 115 | "attention_g", dtype=dtype, 116 | initializer=math.sqrt((1. / num_units))) 117 | # Bias added prior to the nonlinearity 118 | b = variable_scope.get_variable( 119 | "attention_b", [num_units], dtype=dtype, 120 | initializer=init_ops.zeros_initializer()) 121 | # normed_v = g * v / ||v|| 122 | normed_v = g * v * math_ops.rsqrt( 123 | math_ops.reduce_sum(math_ops.square(v))) 124 | return math_ops.reduce_sum( 125 | normed_v * math_ops.tanh(keys + processed_query + b), [2]) 126 | else: 127 | return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2]) 128 | 129 | 130 | class SelfAttWrapper(RNNCell): 131 | """RNNCell wrapper that ensures cell inputs are added to the outputs.""" 132 | 133 | def __init__(self, cell, initial_attention, initial_memory, att_layer, att_type='B'): 134 | """Constructs a `ResidualWrapper` for `cell`. 135 | Args: 136 | cell: An instance of `RNNCell`. 137 | residual_fn: (Optional) The function to map raw cell inputs and raw cell 138 | outputs to the actual cell outputs of the residual network. 139 | Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs 140 | and outputs. 141 | """ 142 | self._cell = cell 143 | self._memory_list = [initial_memory,] 144 | self._attention_list = [initial_attention,] 145 | assert(att_type=='B' or att_type=='L') 146 | if att_type == 'B': 147 | self._att_func = _bahdanau_score 148 | else: 149 | self._att_func = _luong_score 150 | self._att_layer = att_layer 151 | 152 | @property 153 | def state_size(self): 154 | return self._cell.state_size 155 | 156 | @property 157 | def output_size(self): 158 | return self._cell.output_size 159 | 160 | def zero_state(self, batch_size, dtype): 161 | with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 162 | return self._cell.zero_state(batch_size, dtype) 163 | 164 | def __call__(self, inputs, state, scope=None): 165 | inputs = array_ops.concat([inputs, self._attention_list[-1]], 1) 166 | cell_outputs, new_state = self._cell(inputs, state) 167 | if self._att_layer is not None: 168 | query = self._att_layer(cell_outputs) 169 | else: 170 | query = cell_outputs 171 | memory = self._memory_list[-1] 172 | expand_query = array_ops.expand_dims(query, 1) 173 | memory = array_ops.concat([memory, expand_query], 1) # [batch, time+1, depth] 174 | # [batch_size, 1, depth] * [batch_size, depth, time+1] 175 | #expanded_alignments = math_ops.matmul(expand_query, memory, transpose_b=True) #[batch_size, 1, time+1] 176 | alignments = self._att_func(query, memory) 177 | expanded_alignments = array_ops.expand_dims(alignments, 1) 178 | expanded_attention = math_ops.matmul(expanded_alignments, memory) 179 | attention = array_ops.squeeze(expanded_attention, [1]) # [batch_size, depth] 180 | self._attention_list.append(attention) 181 | self._memory_list.append(memory) 182 | 183 | return (cell_outputs, new_state) 184 | 185 | class SelfAttOtWrapper(RNNCell): 186 | """RNNCell wrapper that ensures cell inputs are added to the outputs.""" 187 | 188 | def __init__(self, 
cell, initial_memory, att_layer, out_layer, att_type='B'): 189 | """Constructs a `ResidualWrapper` for `cell`. 190 | Args: 191 | cell: An instance of `RNNCell`. 192 | residual_fn: (Optional) The function to map raw cell inputs and raw cell 193 | outputs to the actual cell outputs of the residual network. 194 | Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs 195 | and outputs. 196 | """ 197 | self._cell = cell 198 | self._memory_list = [initial_memory,] 199 | self._out_layer = out_layer 200 | #self._attention_list = [initial_attention,] 201 | assert(att_type=='B' or att_type=='L') 202 | if att_type == 'B': 203 | self._att_func = _bahdanau_score 204 | else: 205 | self._att_func = _luong_score 206 | self._att_layer = att_layer 207 | 208 | @property 209 | def state_size(self): 210 | return self._cell.state_size 211 | 212 | @property 213 | def output_size(self): 214 | return self._cell.output_size 215 | 216 | def zero_state(self, batch_size, dtype): 217 | with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 218 | return self._cell.zero_state(batch_size, dtype) 219 | 220 | def __call__(self, inputs, state, scope=None): 221 | #inputs = array_ops.concat([inputs, self._attention_list[-1]], 1) 222 | cell_outputs, new_state = self._cell(inputs, state) 223 | if self._att_layer is not None: 224 | #print('cell_outputs', cell_outputs) 225 | query = self._att_layer(cell_outputs) 226 | #print('query', query) 227 | else: 228 | query = cell_outputs 229 | memory = self._memory_list[-1] 230 | expand_query = array_ops.expand_dims(query, 1) 231 | # [batch_size, 1, depth] * [batch_size, depth, time+1] 232 | #expanded_alignments = math_ops.matmul(expand_query, memory, transpose_b=True) #[batch_size, 1, time+1] 233 | alignments = self._att_func(query, memory) 234 | expanded_alignments = array_ops.expand_dims(alignments, 1) 235 | expanded_attention = math_ops.matmul(expanded_alignments, memory) 236 | attention = array_ops.squeeze(expanded_attention, [1]) # [batch_size, depth] 237 | #print('attention', attention, 'concat', array_ops.concat([cell_outputs, attention], 1)) 238 | if self._out_layer is not None: 239 | new_outputs = self._out_layer(array_ops.concat([cell_outputs, attention], 1)) 240 | #print('new_outputs', new_outputs) 241 | else: 242 | new_outputs = cell_outputs 243 | #self._attention_list.append(attention) 244 | new_memory = array_ops.concat([memory, expand_query], 1) # [batch, time+1, depth] 245 | self._memory_list.append(new_memory) 246 | 247 | return (new_outputs, new_state) 248 | 249 | 250 | class SelfAttMulOtWrapper(RNNCell): 251 | """RNNCell wrapper that ensures cell inputs are added to the outputs.""" 252 | 253 | def __init__(self, cell, initial_memory, att_layer, out_layer, att_type='B'): 254 | """Constructs a `ResidualWrapper` for `cell`. 255 | Args: 256 | cell: An instance of `RNNCell`. 257 | residual_fn: (Optional) The function to map raw cell inputs and raw cell 258 | outputs to the actual cell outputs of the residual network. 259 | Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs 260 | and outputs. 
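    (Added note: the Args above were copied from TensorFlow's ResidualWrapper
    and do not match this class. The actual arguments are `cell`; an
    `initial_memory` tensor; an optional `att_layer`, applied to the
    concatenated (output, input) pair before it is appended to the attention
    memory; an optional `out_layer`, applied to the concatenated
    (output, attention) pair to form the emitted output; and `att_type`,
    'B' for Bahdanau or 'L' for Luong scoring.)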
261 | """ 262 | self._cell = cell 263 | self._memory_list = [initial_memory,] 264 | self._out_layer = out_layer 265 | #self._attention_list = [initial_attention,] 266 | assert(att_type=='B' or att_type=='L') 267 | if att_type == 'B': 268 | self._att_func = _bahdanau_score 269 | else: 270 | self._att_func = _luong_score 271 | self._att_layer = att_layer 272 | 273 | @property 274 | def state_size(self): 275 | return self._cell.state_size 276 | 277 | @property 278 | def output_size(self): 279 | return self._cell.output_size 280 | 281 | def zero_state(self, batch_size, dtype): 282 | with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 283 | return self._cell.zero_state(batch_size, dtype) 284 | 285 | def __call__(self, inputs, state, scope=None): 286 | #inputs = array_ops.concat([inputs, self._attention_list[-1]], 1) 287 | cell_outputs, new_state = self._cell(inputs, state) 288 | query = cell_outputs 289 | memory = self._memory_list[-1] 290 | expand_query = array_ops.expand_dims(query, 1) 291 | alignments = self._att_func(query, memory) 292 | expanded_alignments = array_ops.expand_dims(alignments, 1) 293 | expanded_attention = math_ops.matmul(expanded_alignments, memory) 294 | attention = array_ops.squeeze(expanded_attention, [1]) # [batch_size, depth] 295 | #print('attention', attention, 'concat', array_ops.concat([cell_outputs, attention], 1)) 296 | if self._out_layer is not None: 297 | new_outputs = self._out_layer(array_ops.concat([cell_outputs, attention], 1)) 298 | #print('new_outputs', new_outputs) 299 | else: 300 | new_outputs = cell_outputs 301 | #self._attention_list.append(attention) 302 | expand_current_memory = array_ops.expand_dims(array_ops.concat([query, inputs], 1),1) 303 | #print('expand_current_memory',expand_current_memory) 304 | if self._att_layer is not None: 305 | #print('cell_outputs', cell_outputs) 306 | expand_current_memory = self._att_layer(expand_current_memory) 307 | #print('expand_current_memory_att',expand_current_memory) 308 | #print('query', query) 309 | new_memory = array_ops.concat([memory, expand_current_memory], 1) # [batch, time+1, depth] 310 | self._memory_list.append(new_memory) 311 | 312 | return (new_outputs, new_state) -------------------------------------------------------------------------------- /Util/myAttmoLM.py: -------------------------------------------------------------------------------- 1 | import os 2 | #os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | 4 | from tensorflow.python.layers import core as core_layers 5 | from myAttWrapper import SelfAttMulOtWrapper 6 | import tensorflow as tf 7 | import numpy as np 8 | import time 9 | import myResidualCell 10 | #import jieba 11 | from bleu import BLEU 12 | import random 13 | import pickle as cPickle 14 | import matplotlib.pyplot as plt 15 | 16 | 17 | 18 | class LM: 19 | def __init__(self, dp, rnn_size, n_layers, decoder_embedding_dim, max_infer_length, is_jieba, 20 | sess, att_type='B', lr=0.001, grad_clip=5.0, beam_width=5, force_teaching_ratio=1.0, beam_penalty=1.0, 21 | residual=False, output_keep_prob=0.5, input_keep_prob=0.9, cell_type='lstm', reverse=False, is_save=True, 22 | decay_scheme='luong234'): 23 | 24 | self.rnn_size = rnn_size 25 | self.n_layers = n_layers 26 | self.is_jieba = is_jieba 27 | self.grad_clip = grad_clip 28 | self.dp = dp 29 | self.decoder_embedding_dim = decoder_embedding_dim 30 | self.beam_width = beam_width 31 | self.beam_penalty = beam_penalty 32 | self.max_infer_length = max_infer_length 33 | self.residual = residual 34 | self.decay_scheme = 
decay_scheme 35 | if self.residual: 36 | assert decoder_embedding_dim == rnn_size 37 | self.reverse = reverse 38 | self.cell_type = cell_type 39 | self.force_teaching_ratio = force_teaching_ratio 40 | self._output_keep_prob = output_keep_prob 41 | self._input_keep_prob = input_keep_prob 42 | self.is_save = is_save 43 | self.sess = sess 44 | self.att_type = att_type 45 | self.lr=lr 46 | self.build_graph() 47 | self.sess.run(tf.global_variables_initializer()) 48 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = 15) 49 | self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary() 50 | 51 | # end constructor 52 | 53 | def build_graph(self): 54 | self.register_symbols() 55 | self.add_input_layer() 56 | with tf.variable_scope('decode'): 57 | self.add_decoder_for_training() 58 | with tf.variable_scope('decode', reuse=True): 59 | self.add_decoder_for_prefix_inference() 60 | self.add_backward_path() 61 | # end method 62 | 63 | def add_input_layer(self): 64 | self.X = tf.placeholder(tf.int32, [None, None], name="X") 65 | self.X_seq_len = tf.placeholder(tf.int32, [None], name="X_seq_len") 66 | self.input_keep_prob = tf.placeholder(tf.float32,name="input_keep_prob") 67 | self.output_keep_prob = tf.placeholder(tf.float32,name="output_keep_prob") 68 | self.batch_size = tf.shape(self.X)[0] 69 | self.init_memory = tf.zeros([self.batch_size, 1, self.rnn_size]) 70 | #self.init_attention = tf.zeros([self.batch_size, self.rnn_size]) 71 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 72 | # end method 73 | 74 | def single_cell(self, reuse=False): 75 | if self.cell_type == 'lstm': 76 | cell = tf.contrib.rnn.LayerNormBasicLSTMCell(self.rnn_size, reuse=reuse) 77 | else: 78 | cell = tf.contrib.rnn.GRUBlockCell(self.rnn_size) 79 | cell = tf.contrib.rnn.DropoutWrapper(cell, self.output_keep_prob, self.input_keep_prob) 80 | if self.residual: 81 | cell = myResidualCell.ResidualWrapper(cell) 82 | return cell 83 | 84 | def processed_decoder_input(self): 85 | main = tf.strided_slice(self.X, [0, 0], [self.batch_size, -1], [1, 1]) # remove last char 86 | decoder_input = tf.concat([tf.fill([self.batch_size, 1], self._x_go), main], 1) 87 | return decoder_input 88 | 89 | def add_decoder_for_training(self): 90 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 91 | self.decoder_cell = SelfAttMulOtWrapper(self.decoder_cell, self.init_memory, 92 | att_layer = core_layers.Dense(self.rnn_size, name='att_dense'), 93 | out_layer = core_layers.Dense(self.rnn_size, name='out_dense'), 94 | att_type=self.att_type) 95 | decoder_embedding = tf.get_variable('word_embedding', [len(self.dp.X_w2id), self.decoder_embedding_dim], 96 | tf.float32, tf.random_uniform_initializer(-1.0, 1.0)) 97 | training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper( 98 | inputs = tf.nn.embedding_lookup(decoder_embedding, self.processed_decoder_input()), 99 | sequence_length = self.X_seq_len, 100 | embedding = decoder_embedding, 101 | sampling_probability = 1 - self.force_teaching_ratio, 102 | time_major = False) 103 | training_decoder = tf.contrib.seq2seq.BasicDecoder( 104 | cell = self.decoder_cell, 105 | helper = training_helper, 106 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32), #.clone(cell_state=self.encoder_state), 107 | output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense')) 108 | training_decoder_output, training_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 109 | 
decoder = training_decoder, 110 | impute_finished = True, 111 | maximum_iterations = tf.reduce_max(self.X_seq_len)) 112 | self.training_logits = training_decoder_output.rnn_output 113 | self.init_prefix_state = training_final_state 114 | 115 | 116 | def add_decoder_for_prefix_inference(self): 117 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 118 | #self.init_attention_tiled = tf.contrib.seq2seq.tile_batch(self.init_attention, self.beam_width) 119 | self.init_memory_tiled = tf.contrib.seq2seq.tile_batch(self.init_memory, self.beam_width) 120 | 121 | self.decoder_cell = SelfAttMulOtWrapper(self.decoder_cell, 122 | self.init_memory_tiled, 123 | att_layer = core_layers.Dense(self.rnn_size, name='att_dense', _reuse=True), 124 | out_layer = core_layers.Dense(self.rnn_size, name='out_dense', _reuse=True), 125 | att_type=self.att_type) 126 | self.beam_init_state = tf.contrib.seq2seq.tile_batch(self.init_prefix_state, self.beam_width) 127 | my_decoder = tf.contrib.seq2seq.BeamSearchDecoder( 128 | cell = self.decoder_cell, 129 | embedding = tf.get_variable('word_embedding'), 130 | start_tokens = tf.tile(tf.constant([self._x_go], dtype=tf.int32), [self.batch_size]), 131 | end_token = self._x_eos, 132 | initial_state = self.beam_init_state, 133 | beam_width = self.beam_width, 134 | output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense', _reuse=True), 135 | length_penalty_weight = self.beam_penalty) 136 | 137 | self.prefix_go = tf.placeholder(tf.int32, [None]) 138 | prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1), [1, self.beam_width]) 139 | prefix_emb = tf.nn.embedding_lookup(tf.get_variable('word_embedding'), prefix_go_beam) 140 | my_decoder._start_inputs = prefix_emb 141 | predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode( 142 | decoder = my_decoder, 143 | impute_finished = False, 144 | maximum_iterations = self.max_infer_length) 145 | self.prefix_infer_outputs = predicting_decoder_output.predicted_ids 146 | self.score = predicting_decoder_output.beam_search_decoder_output.scores 147 | 148 | def add_backward_path(self): 149 | masks = tf.sequence_mask(self.X_seq_len, tf.reduce_max(self.X_seq_len), dtype=tf.float32) 150 | self.loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 151 | targets = self.X, 152 | weights = masks) 153 | self.batch_loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 154 | targets = self.X, 155 | weights = masks, 156 | average_across_batch=False) 157 | params = tf.trainable_variables() 158 | gradients = tf.gradients(self.loss, params) 159 | clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.grad_clip) 160 | self.learning_rate = tf.constant(self.lr) 161 | self.learning_rate = self.get_learning_rate_decay(self.decay_scheme) # decay 162 | self.train_op = tf.train.AdamOptimizer(self.learning_rate).apply_gradients(zip(clipped_gradients, params), global_step=self.global_step) 163 | 164 | def register_symbols(self): 165 | self._x_go = self.dp.X_w2id[''] 166 | self._x_eos = self.dp.X_w2id[''] 167 | self._x_pad = self.dp.X_w2id[''] 168 | self._x_unk = self.dp.X_w2id[''] 169 | 170 | def infer(self, input_word, batch_size=1, is_show=True): 171 | #return ["pass"] 172 | if self.is_jieba: 173 | input_index = list(jieba.cut(input_word)) 174 | else: 175 | input_index = input_word 176 | xx = [char for char in input_index] 177 | if self.reverse: 178 | xx = xx[::-1] 179 | length = [len(xx),] * batch_size 180 | input_indices = 
[[self.dp.X_w2id.get(char, self._x_unk) for char in xx]] * batch_size 181 | prefix_go = [] 182 | for ipt in input_indices: 183 | prefix_go.append(ipt[-1]) 184 | out_indices, scores = self.sess.run([self.prefix_infer_outputs, self.score], { 185 | self.X: input_indices, self.X_seq_len: length, self.prefix_go: prefix_go, self.input_keep_prob:1, 186 | self.output_keep_prob:1}) 187 | outputs = [] 188 | for idx in range(out_indices.shape[-1]): 189 | eos_id = self.dp.X_w2id['<EOS>'] 190 | ot = out_indices[0,:,idx] 191 | if eos_id in ot: 192 | ot = ot.tolist() 193 | ot = ot[:ot.index(eos_id)] 194 | if self.reverse: 195 | ot = ot[::-1] 196 | if self.reverse: 197 | output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) + input_word 198 | else: 199 | output_str = input_word+''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 200 | outputs.append(output_str) 201 | return outputs 202 | 203 | def batch_infer(self, input_words, is_show=True): 204 | #return ["pass"] 205 | #xx = [char for char in input_index] 206 | #if self.reverse: 207 | # xx = xx[::-1] 208 | length = [len(xx) for xx in input_words] 209 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in s] for s in input_words] 210 | prefix_go = [] 211 | #print(length) 212 | for ipt in input_indices: 213 | prefix_go.append(ipt[-1]) 214 | #print(prefix_go) 215 | out_indices, scores = self.sess.run([self.prefix_infer_outputs, self.score], { 216 | self.X: input_indices, self.X_seq_len: length, self.prefix_go: prefix_go, self.input_keep_prob:1, 217 | self.output_keep_prob:1}) 218 | outputs = [] 219 | for b in range(len(input_indices)): 220 | eos_id = self.dp.X_w2id['<EOS>'] 221 | ot = out_indices[b,:,0] 222 | if eos_id in ot: 223 | ot = ot.tolist() 224 | ot = ot[:ot.index(eos_id)] 225 | #if self.reverse: 226 | # ot = ot[::-1] 227 | #if self.reverse: 228 | # output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) + input_words[b] 229 | #else: 230 | output_str = input_words[b] +''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 231 | outputs.append(output_str) 232 | return outputs 233 | 234 | def restore(self, path): 235 | self.saver.restore(self.sess, path) 236 | print('restore %s success' % path) 237 | 238 | def get_learning_rate_decay(self, decay_scheme='luong234'): 239 | num_train_steps = self.dp.num_steps 240 | if decay_scheme == "luong10": 241 | start_decay_step = int(num_train_steps / 2) 242 | remain_steps = num_train_steps - start_decay_step 243 | decay_steps = int(remain_steps / 10) # decay 10 times 244 | decay_factor = 0.5 245 | else: 246 | start_decay_step = int(num_train_steps * 2 / 3) 247 | remain_steps = num_train_steps - start_decay_step 248 | decay_steps = int(remain_steps / 4) # decay 4 times 249 | decay_factor = 0.5 250 | return tf.cond( 251 | self.global_step < start_decay_step, 252 | lambda: self.learning_rate, 253 | lambda: tf.train.exponential_decay( 254 | self.learning_rate, 255 | (self.global_step - start_decay_step), 256 | decay_steps, decay_factor, staircase=True), 257 | name="learning_rate_decay_cond") 258 | 259 | def setup_summary(self): 260 | train_loss = tf.Variable(0.) 261 | tf.summary.scalar('Train_loss', train_loss) 262 | 263 | test_loss = tf.Variable(0.) 264 | tf.summary.scalar('Test_loss', test_loss) 265 | 266 | bleu_score = tf.Variable(0.)
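        # Note on the pattern in this method: each logged scalar is a dummy
        # tf.Variable that the training loop overwrites through a matching
        # placeholder/assign pair (summary_placeholders/update_ops, returned at
        # the end of this method); the merged summary_op is then evaluated and
        # written out, which is exactly what LM_util.fit in Util/myUtil.py does.
        # A minimal usage sketch, assuming a built `model`, a tf.summary.FileWriter
        # `writer`, and Python floats `train_loss_val`, `test_loss_val`, `bleu_val`:
        #
        #   stats = [train_loss_val, test_loss_val, bleu_val]
        #   for ph, op, v in zip(model.summary_placeholders, model.update_ops, stats):
        #       model.sess.run(op, feed_dict={ph: float(v)})
        #   writer.add_summary(model.sess.run(model.summary_op), global_step=epoch)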
267 | tf.summary.scalar('BLEU_score', bleu_score) 268 | 269 | tf.summary.scalar('lr_rate', self.learning_rate) 270 | 271 | summary_vars = [train_loss, test_loss, bleu_score] 272 | summary_placeholders = [tf.placeholder(tf.float32) for _ in range(len(summary_vars))] 273 | update_ops = [summary_vars[i].assign(summary_placeholders[i]) for i in range(len(summary_vars))] 274 | summary_op = tf.summary.merge_all() 275 | return summary_placeholders, update_ops, summary_op -------------------------------------------------------------------------------- /Util/myAttoLM.py: -------------------------------------------------------------------------------- 1 | import os 2 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1" 3 | 4 | from tensorflow.python.layers import core as core_layers 5 | from myAttWrapper import SelfAttOtWrapper 6 | import tensorflow as tf 7 | import numpy as np 8 | import time 9 | import myResidualCell 10 | #import jieba 11 | from bleu import BLEU 12 | import random 13 | import pickle as cPickle 14 | import matplotlib.pyplot as plt 15 | 16 | 17 | class LM: 18 | def __init__(self, dp, rnn_size, n_layers, decoder_embedding_dim, max_infer_length, is_jieba, 19 | sess, att_type='B', lr=0.001, grad_clip=5.0, beam_width=5, force_teaching_ratio=1.0, beam_penalty=1.0, 20 | residual=False, output_keep_prob=0.5, input_keep_prob=0.9, cell_type='lstm', reverse=False, is_save=True, 21 | decay_scheme='luong234'): 22 | 23 | self.rnn_size = rnn_size 24 | self.n_layers = n_layers 25 | self.is_jieba = is_jieba 26 | self.grad_clip = grad_clip 27 | self.dp = dp 28 | self.decoder_embedding_dim = decoder_embedding_dim 29 | self.beam_width = beam_width 30 | self.beam_penalty = beam_penalty 31 | self.max_infer_length = max_infer_length 32 | self.residual = residual 33 | self.decay_scheme = decay_scheme 34 | if self.residual: 35 | assert decoder_embedding_dim == rnn_size 36 | self.reverse = reverse 37 | self.cell_type = cell_type 38 | self.force_teaching_ratio = force_teaching_ratio 39 | self._output_keep_prob = output_keep_prob 40 | self._input_keep_prob = input_keep_prob 41 | self.is_save = is_save 42 | self.sess = sess 43 | self.att_type = att_type 44 | self.lr=lr 45 | self.build_graph() 46 | self.sess.run(tf.global_variables_initializer()) 47 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = 15) 48 | self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary() 49 | 50 | # end constructor 51 | 52 | def build_graph(self): 53 | self.register_symbols() 54 | self.add_input_layer() 55 | with tf.variable_scope('decode'): 56 | self.add_decoder_for_training() 57 | with tf.variable_scope('decode', reuse=True): 58 | self.add_decoder_for_prefix_inference() 59 | self.add_backward_path() 60 | # end method 61 | 62 | def add_input_layer(self): 63 | self.X = tf.placeholder(tf.int32, [None, None], name="X") 64 | self.X_seq_len = tf.placeholder(tf.int32, [None], name="X_seq_len") 65 | self.input_keep_prob = tf.placeholder(tf.float32,name="input_keep_prob") 66 | self.output_keep_prob = tf.placeholder(tf.float32,name="output_keep_prob") 67 | self.batch_size = tf.shape(self.X)[0] 68 | self.init_memory = tf.zeros([self.batch_size, 1, self.rnn_size]) 69 | #self.init_attention = tf.zeros([self.batch_size, self.rnn_size]) 70 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 71 | # end method 72 | 73 | def single_cell(self, reuse=False): 74 | if self.cell_type == 'lstm': 75 | cell = tf.contrib.rnn.LayerNormBasicLSTMCell(self.rnn_size, reuse=reuse) 76 | else: 77 | cell = 
tf.contrib.rnn.GRUBlockCell(self.rnn_size) 78 | cell = tf.contrib.rnn.DropoutWrapper(cell, self.output_keep_prob, self.input_keep_prob) 79 | if self.residual: 80 | cell = myResidualCell.ResidualWrapper(cell) 81 | return cell 82 | 83 | def processed_decoder_input(self): 84 | main = tf.strided_slice(self.X, [0, 0], [self.batch_size, -1], [1, 1]) # remove last char 85 | decoder_input = tf.concat([tf.fill([self.batch_size, 1], self._x_go), main], 1) 86 | return decoder_input 87 | 88 | def add_decoder_for_training(self): 89 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 90 | self.decoder_cell = SelfAttOtWrapper(self.decoder_cell, self.init_memory, 91 | att_layer = core_layers.Dense(self.rnn_size, name='att_dense'), 92 | out_layer = core_layers.Dense(self.rnn_size, name='out_dense'), 93 | att_type=self.att_type) 94 | decoder_embedding = tf.get_variable('word_embedding', [len(self.dp.X_w2id), self.decoder_embedding_dim], 95 | tf.float32, tf.random_uniform_initializer(-1.0, 1.0)) 96 | training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper( 97 | inputs = tf.nn.embedding_lookup(decoder_embedding, self.processed_decoder_input()), 98 | sequence_length = self.X_seq_len, 99 | embedding = decoder_embedding, 100 | sampling_probability = 1 - self.force_teaching_ratio, 101 | time_major = False) 102 | training_decoder = tf.contrib.seq2seq.BasicDecoder( 103 | cell = self.decoder_cell, 104 | helper = training_helper, 105 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32), #.clone(cell_state=self.encoder_state), 106 | output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense')) 107 | training_decoder_output, training_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 108 | decoder = training_decoder, 109 | impute_finished = True, 110 | maximum_iterations = tf.reduce_max(self.X_seq_len)) 111 | self.training_logits = training_decoder_output.rnn_output 112 | self.init_prefix_state = training_final_state 113 | 114 | 115 | def add_decoder_for_prefix_inference(self): 116 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 117 | #self.init_attention_tiled = tf.contrib.seq2seq.tile_batch(self.init_attention, self.beam_width) 118 | self.init_memory_tiled = tf.contrib.seq2seq.tile_batch(self.init_memory, self.beam_width) 119 | 120 | self.decoder_cell = SelfAttOtWrapper(self.decoder_cell, 121 | self.init_memory_tiled, 122 | att_layer = core_layers.Dense(self.rnn_size, name='att_dense', _reuse=True), 123 | out_layer = core_layers.Dense(self.rnn_size, name='out_dense', _reuse=True), 124 | att_type=self.att_type) 125 | self.beam_init_state = tf.contrib.seq2seq.tile_batch(self.init_prefix_state, self.beam_width) 126 | my_decoder = tf.contrib.seq2seq.BeamSearchDecoder( 127 | cell = self.decoder_cell, 128 | embedding = tf.get_variable('word_embedding'), 129 | start_tokens = tf.tile(tf.constant([self._x_go], dtype=tf.int32), [self.batch_size]), 130 | end_token = self._x_eos, 131 | initial_state = self.beam_init_state, 132 | beam_width = self.beam_width, 133 | output_layer = core_layers.Dense(len(self.dp.X_w2id), name='output_dense', _reuse=True), 134 | length_penalty_weight = self.beam_penalty) 135 | 136 | self.prefix_go = tf.placeholder(tf.int32, [None]) 137 | prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1), [1, self.beam_width]) 138 | prefix_emb = tf.nn.embedding_lookup(tf.get_variable('word_embedding'), prefix_go_beam) 139 | my_decoder._start_inputs 
= prefix_emb 140 | predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode( 141 | decoder = my_decoder, 142 | impute_finished = False, 143 | maximum_iterations = self.max_infer_length) 144 | self.prefix_infer_outputs = predicting_decoder_output.predicted_ids 145 | self.score = predicting_decoder_output.beam_search_decoder_output.scores 146 | 147 | def add_backward_path(self): 148 | masks = tf.sequence_mask(self.X_seq_len, tf.reduce_max(self.X_seq_len), dtype=tf.float32) 149 | self.loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 150 | targets = self.X, 151 | weights = masks) 152 | self.batch_loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 153 | targets = self.X, 154 | weights = masks, 155 | average_across_batch=False) 156 | params = tf.trainable_variables() 157 | gradients = tf.gradients(self.loss, params) 158 | clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.grad_clip) 159 | self.learning_rate = tf.constant(self.lr) 160 | self.learning_rate = self.get_learning_rate_decay(self.decay_scheme) # decay 161 | self.train_op = tf.train.AdamOptimizer(self.learning_rate).apply_gradients(zip(clipped_gradients, params), global_step=self.global_step) 162 | 163 | def register_symbols(self): 164 | self._x_go = self.dp.X_w2id['<GO>'] 165 | self._x_eos = self.dp.X_w2id['<EOS>'] 166 | self._x_pad = self.dp.X_w2id['<PAD>'] 167 | self._x_unk = self.dp.X_w2id['<UNK>'] 168 | 169 | def infer(self, input_word, batch_size=1, is_show=True): 170 | #return ["pass"] 171 | if self.is_jieba: 172 | input_index = list(jieba.cut(input_word)) 173 | else: 174 | input_index = input_word 175 | xx = [char for char in input_index] 176 | if self.reverse: 177 | xx = xx[::-1] 178 | length = [len(xx),] * batch_size 179 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in xx]] * batch_size 180 | prefix_go = [] 181 | for ipt in input_indices: 182 | prefix_go.append(ipt[-1]) 183 | out_indices, scores = self.sess.run([self.prefix_infer_outputs, self.score], { 184 | self.X: input_indices, self.X_seq_len: length, self.prefix_go: prefix_go, self.input_keep_prob:1, 185 | self.output_keep_prob:1}) 186 | outputs = [] 187 | for idx in range(out_indices.shape[-1]): 188 | eos_id = self.dp.X_w2id['<EOS>'] 189 | ot = out_indices[0,:,idx] 190 | if eos_id in ot: 191 | ot = ot.tolist() 192 | ot = ot[:ot.index(eos_id)] 193 | if self.reverse: 194 | ot = ot[::-1] 195 | if self.reverse: 196 | output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) + input_word 197 | else: 198 | output_str = input_word+''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 199 | outputs.append(output_str) 200 | return outputs 201 | 202 | def batch_infer(self, input_words, is_show=True): 203 | #return ["pass"] 204 | #xx = [char for char in input_index] 205 | #if self.reverse: 206 | # xx = xx[::-1] 207 | length = [len(xx) for xx in input_words] 208 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in s] for s in input_words] 209 | prefix_go = [] 210 | #print(length) 211 | for ipt in input_indices: 212 | prefix_go.append(ipt[-1]) 213 | #print(prefix_go) 214 | out_indices, scores = self.sess.run([self.prefix_infer_outputs, self.score], { 215 | self.X: input_indices, self.X_seq_len: length, self.prefix_go: prefix_go, self.input_keep_prob:1, 216 | self.output_keep_prob:1}) 217 | outputs = [] 218 | for b in range(len(input_indices)): 219 | eos_id = self.dp.X_w2id['<EOS>'] 220 | ot = out_indices[b,:,0] 221 | if eos_id in ot: 222 | ot = ot.tolist() 223 | ot = ot[:ot.index(eos_id)] 224 | #if
self.reverse: 225 | # ot = ot[::-1] 226 | #if self.reverse: 227 | # output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) + input_words[b] 228 | #else: 229 | output_str = input_words[b] +''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 230 | outputs.append(output_str) 231 | return outputs 232 | 233 | def restore(self, path): 234 | self.saver.restore(self.sess, path) 235 | print('restore %s success' % path) 236 | 237 | def get_learning_rate_decay(self, decay_scheme='luong234'): 238 | num_train_steps = self.dp.num_steps 239 | if decay_scheme == "luong10": 240 | start_decay_step = int(num_train_steps / 2) 241 | remain_steps = num_train_steps - start_decay_step 242 | decay_steps = int(remain_steps / 10) # decay 10 times 243 | decay_factor = 0.5 244 | else: 245 | start_decay_step = int(num_train_steps * 2 / 3) 246 | remain_steps = num_train_steps - start_decay_step 247 | decay_steps = int(remain_steps / 4) # decay 4 times 248 | decay_factor = 0.5 249 | return tf.cond( 250 | self.global_step < start_decay_step, 251 | lambda: self.learning_rate, 252 | lambda: tf.train.exponential_decay( 253 | self.learning_rate, 254 | (self.global_step - start_decay_step), 255 | decay_steps, decay_factor, staircase=True), 256 | name="learning_rate_decay_cond") 257 | 258 | def setup_summary(self): 259 | train_loss = tf.Variable(0.) 260 | tf.summary.scalar('Train_loss', train_loss) 261 | 262 | test_loss = tf.Variable(0.) 263 | tf.summary.scalar('Test_loss', test_loss) 264 | 265 | bleu_score = tf.Variable(0.) 266 | tf.summary.scalar('BLEU_score', bleu_score) 267 | 268 | tf.summary.scalar('lr_rate', self.learning_rate) 269 | 270 | summary_vars = [train_loss, test_loss, bleu_score] 271 | summary_placeholders = [tf.placeholder(tf.float32) for _ in range(len(summary_vars))] 272 | update_ops = [summary_vars[i].assign(summary_placeholders[i]) for i in range(len(summary_vars))] 273 | summary_op = tf.summary.merge_all() 274 | return summary_placeholders, update_ops, summary_op -------------------------------------------------------------------------------- /Util/myLM.py: -------------------------------------------------------------------------------- 1 | import os 2 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1" 3 | 4 | from tensorflow.python.layers import core as core_layers 5 | import tensorflow as tf 6 | import numpy as np 7 | import time 8 | import myResidualCell 9 | #import jieba 10 | from bleu import BLEU 11 | import random 12 | import pickle as cPickle 13 | import matplotlib.pyplot as plt 14 | 15 | 16 | class LM: 17 | def __init__(self, dp, rnn_size, n_layers, decoder_embedding_dim, max_infer_length, is_jieba, 18 | sess, lr=0.001, grad_clip=5.0, beam_width=5, force_teaching_ratio=1.0, beam_penalty=1.0, 19 | residual=False, output_keep_prob=0.5, input_keep_prob=0.9, cell_type='lstm', reverse=False, is_save=True, 20 | decay_scheme='luong234'): 21 | 22 | self.rnn_size = rnn_size 23 | self.n_layers = n_layers 24 | self.is_jieba = is_jieba 25 | self.grad_clip = grad_clip 26 | self.dp = dp 27 | self.decoder_embedding_dim = decoder_embedding_dim 28 | self.beam_width = beam_width 29 | self.beam_penalty = beam_penalty 30 | self.max_infer_length = max_infer_length 31 | self.residual = residual 32 | self.decay_scheme = decay_scheme 33 | if self.residual: 34 | assert decoder_embedding_dim == rnn_size 35 | self.reverse = reverse 36 | self.cell_type = cell_type 37 | self.force_teaching_ratio = force_teaching_ratio 38 | self._output_keep_prob = output_keep_prob 39 | self._input_keep_prob = input_keep_prob 40 | 
self.is_save = is_save 41 | self.sess = sess 42 | self.lr=lr 43 | self.build_graph() 44 | self.sess.run(tf.global_variables_initializer()) 45 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = 15) 46 | self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary() 47 | 48 | # end constructor 49 | 50 | def build_graph(self): 51 | self.register_symbols() 52 | self.add_input_layer() 53 | with tf.variable_scope('decode'): 54 | self.add_decoder_for_training() 55 | with tf.variable_scope('decode', reuse=True): 56 | self.add_decoder_for_prefix_inference() 57 | self.add_backward_path() 58 | # end method 59 | 60 | def add_input_layer(self): 61 | self.X = tf.placeholder(tf.int32, [None, None], name="X") 62 | self.X_seq_len = tf.placeholder(tf.int32, [None], name="X_seq_len") 63 | self.input_keep_prob = tf.placeholder(tf.float32,name="input_keep_prob") 64 | self.output_keep_prob = tf.placeholder(tf.float32,name="output_keep_prob") 65 | self.batch_size = tf.shape(self.X)[0] 66 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 67 | # end method 68 | 69 | def single_cell(self, reuse=False): 70 | if self.cell_type == 'lstm': 71 | cell = tf.contrib.rnn.LayerNormBasicLSTMCell(self.rnn_size, reuse=reuse) 72 | else: 73 | cell = tf.contrib.rnn.GRUBlockCell(self.rnn_size) 74 | cell = tf.contrib.rnn.DropoutWrapper(cell, self.output_keep_prob, self.input_keep_prob) 75 | if self.residual: 76 | cell = myResidualCell.ResidualWrapper(cell) 77 | return cell 78 | 79 | def processed_decoder_input(self): 80 | main = tf.strided_slice(self.X, [0, 0], [self.batch_size, -1], [1, 1]) # remove last char 81 | decoder_input = tf.concat([tf.fill([self.batch_size, 1], self._x_go), main], 1) 82 | return decoder_input 83 | 84 | def add_decoder_for_training(self): 85 | self.decoder_cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(1 * self.n_layers)]) 86 | decoder_embedding = tf.get_variable('word_embedding', [len(self.dp.X_w2id), self.decoder_embedding_dim], 87 | tf.float32, tf.random_uniform_initializer(-1.0, 1.0)) 88 | training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper( 89 | inputs = tf.nn.embedding_lookup(decoder_embedding, self.processed_decoder_input()), 90 | sequence_length = self.X_seq_len, 91 | embedding = decoder_embedding, 92 | sampling_probability = 1 - self.force_teaching_ratio, 93 | time_major = False) 94 | training_decoder = tf.contrib.seq2seq.BasicDecoder( 95 | cell = self.decoder_cell, 96 | helper = training_helper, 97 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32), #.clone(cell_state=self.encoder_state), 98 | output_layer = core_layers.Dense(len(self.dp.X_w2id))) 99 | training_decoder_output, training_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 100 | decoder = training_decoder, 101 | impute_finished = True, 102 | maximum_iterations = tf.reduce_max(self.X_seq_len)) 103 | self.training_logits = training_decoder_output.rnn_output 104 | self.init_prefix_state = training_final_state 105 | 106 | 107 | def add_decoder_for_prefix_inference(self): 108 | self.beam_init_state = tf.contrib.seq2seq.tile_batch(self.init_prefix_state, self.beam_width) 109 | my_decoder = tf.contrib.seq2seq.BeamSearchDecoder( 110 | cell = self.decoder_cell, 111 | embedding = tf.get_variable('word_embedding'), 112 | start_tokens = tf.tile(tf.constant([self._x_go], dtype=tf.int32), [self.batch_size]), 113 | end_token = self._x_eos, 114 | initial_state = self.beam_init_state, 115 | beam_width = self.beam_width, 116 | 
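            # Prefix-continuation trick (same pattern as the attention variants
            # above): initial_state is the training decoder's final state tiled
            # across beams, and a few lines below my_decoder._start_inputs is
            # overwritten with the embedding of the last prefix token, so beam
            # search resumes the sentence instead of starting from a fresh <GO>.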
output_layer = core_layers.Dense(len(self.dp.X_w2id), _reuse=True), 117 | length_penalty_weight = self.beam_penalty) 118 | 119 | self.prefix_go = tf.placeholder(tf.int32, [None]) 120 | prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1), [1, self.beam_width]) 121 | prefix_emb = tf.nn.embedding_lookup(tf.get_variable('word_embedding'), prefix_go_beam) 122 | my_decoder._start_inputs = prefix_emb 123 | predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode( 124 | decoder = my_decoder, 125 | impute_finished = False, 126 | maximum_iterations = self.max_infer_length) 127 | self.prefix_infer_outputs = predicting_decoder_output.predicted_ids 128 | self.score = predicting_decoder_output.beam_search_decoder_output.scores 129 | 130 | def add_backward_path(self): 131 | masks = tf.sequence_mask(self.X_seq_len, tf.reduce_max(self.X_seq_len), dtype=tf.float32) 132 | self.loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 133 | targets = self.X, 134 | weights = masks) 135 | self.batch_loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits, 136 | targets = self.X, 137 | weights = masks, 138 | average_across_batch=False) 139 | params = tf.trainable_variables() 140 | gradients = tf.gradients(self.loss, params) 141 | clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.grad_clip) 142 | self.learning_rate = tf.constant(self.lr) 143 | self.learning_rate = self.get_learning_rate_decay(self.decay_scheme) # decay 144 | self.train_op = tf.train.AdamOptimizer(self.learning_rate).apply_gradients(zip(clipped_gradients, params), global_step=self.global_step) 145 | 146 | def register_symbols(self): 147 | self._x_go = self.dp.X_w2id['<GO>'] 148 | self._x_eos = self.dp.X_w2id['<EOS>'] 149 | self._x_pad = self.dp.X_w2id['<PAD>'] 150 | self._x_unk = self.dp.X_w2id['<UNK>'] 151 | 152 | def infer(self, input_word, batch_size=1, is_show=True): 153 | if self.is_jieba: 154 | input_index = list(jieba.cut(input_word)) 155 | else: 156 | input_index = input_word 157 | xx = [char for char in input_index] 158 | if self.reverse: 159 | xx = xx[::-1] 160 | length = [len(xx),] * batch_size 161 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in xx]] * batch_size 162 | prefix_go = [] 163 | for ipt in input_indices: 164 | prefix_go.append(ipt[-1]) 165 | out_indices, scores = self.sess.run([self.prefix_infer_outputs, self.score], { 166 | self.X: input_indices, self.X_seq_len: length, self.prefix_go: prefix_go, self.input_keep_prob:1, 167 | self.output_keep_prob:1}) 168 | outputs = [] 169 | for idx in range(out_indices.shape[-1]): 170 | eos_id = self.dp.X_w2id['<EOS>'] 171 | ot = out_indices[0,:,idx] 172 | if eos_id in ot: 173 | ot = ot.tolist() 174 | ot = ot[:ot.index(eos_id)] 175 | if self.reverse: 176 | ot = ot[::-1] 177 | if self.reverse: 178 | output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) + input_word 179 | else: 180 | output_str = input_word+''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 181 | outputs.append(output_str) 182 | return outputs 183 | 184 | def batch_infer(self, input_words, is_show=True): 185 | #return ["pass"] 186 | #xx = [char for char in input_index] 187 | #if self.reverse: 188 | # xx = xx[::-1] 189 | length = [len(xx) for xx in input_words] 190 | input_indices = [[self.dp.X_w2id.get(char, self._x_unk) for char in s] for s in input_words] 191 | prefix_go = [] 192 | #print(length) 193 | for ipt in input_indices: 194 | prefix_go.append(ipt[-1]) 195 | #print(prefix_go) 196 | out_indices, scores = self.sess.run([self.prefix_infer_outputs,
self.score], { 197 | self.X: input_indices, self.X_seq_len: length, self.prefix_go: prefix_go, self.input_keep_prob:1, 198 | self.output_keep_prob:1}) 199 | outputs = [] 200 | for b in range(len(input_indices)): 201 | eos_id = self.dp.X_w2id['<EOS>'] 202 | ot = out_indices[b,:,0] 203 | if eos_id in ot: 204 | ot = ot.tolist() 205 | ot = ot[:ot.index(eos_id)] 206 | #if self.reverse: 207 | # ot = ot[::-1] 208 | #if self.reverse: 209 | # output_str = ''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) + input_words[b] 210 | #else: 211 | output_str = input_words[b] +''.join([self.dp.X_id2w.get(i, '<-1>') for i in ot]) 212 | outputs.append(output_str) 213 | return outputs 214 | 215 | def restore(self, path): 216 | self.saver.restore(self.sess, path) 217 | print('restore %s success' % path) 218 | 219 | def get_learning_rate_decay(self, decay_scheme='luong234'): 220 | num_train_steps = self.dp.num_steps 221 | if decay_scheme == "luong10": 222 | start_decay_step = int(num_train_steps / 2) 223 | remain_steps = num_train_steps - start_decay_step 224 | decay_steps = int(remain_steps / 10) # decay 10 times 225 | decay_factor = 0.5 226 | else: 227 | start_decay_step = int(num_train_steps * 2 / 3) 228 | remain_steps = num_train_steps - start_decay_step 229 | decay_steps = int(remain_steps / 4) # decay 4 times 230 | decay_factor = 0.5 231 | return tf.cond( 232 | self.global_step < start_decay_step, 233 | lambda: self.learning_rate, 234 | lambda: tf.train.exponential_decay( 235 | self.learning_rate, 236 | (self.global_step - start_decay_step), 237 | decay_steps, decay_factor, staircase=True), 238 | name="learning_rate_decay_cond") 239 | 240 | def setup_summary(self): 241 | train_loss = tf.Variable(0.) 242 | tf.summary.scalar('Train_loss', train_loss) 243 | 244 | test_loss = tf.Variable(0.) 245 | tf.summary.scalar('Test_loss', test_loss) 246 | 247 | bleu_score = tf.Variable(0.) 248 | tf.summary.scalar('BLEU_score', bleu_score) 249 | 250 | tf.summary.scalar('lr_rate', self.learning_rate) 251 | 252 | summary_vars = [train_loss, test_loss, bleu_score] 253 | summary_placeholders = [tf.placeholder(tf.float32) for _ in range(len(summary_vars))] 254 | update_ops = [summary_vars[i].assign(summary_placeholders[i]) for i in range(len(summary_vars))] 255 | summary_op = tf.summary.merge_all() 256 | return summary_placeholders, update_ops, summary_op -------------------------------------------------------------------------------- /Util/myResidualCell.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Module implementing RNN Cells.
16 | This module provides a number of basic commonly used RNN cells, such as LSTM 17 | (Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of 18 | operators that allow adding dropouts, projections, or embeddings for inputs. 19 | Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by 20 | calling the `rnn` ops several times. 21 | """ 22 | import tensorflow as tf 23 | 24 | from tensorflow.python.framework import ops 25 | from tensorflow.python.framework import tensor_shape 26 | from tensorflow.python.framework import tensor_util 27 | from tensorflow.python.layers import base as base_layer 28 | from tensorflow.python.ops import array_ops 29 | from tensorflow.python.ops import clip_ops 30 | from tensorflow.python.ops import init_ops 31 | from tensorflow.python.ops import math_ops 32 | from tensorflow.python.ops import nn_ops 33 | from tensorflow.python.ops import partitioned_variables 34 | from tensorflow.python.ops import random_ops 35 | from tensorflow.python.ops import tensor_array_ops 36 | from tensorflow.python.ops import variable_scope as vs 37 | from tensorflow.python.ops import variables as tf_variables 38 | from tensorflow.python.platform import tf_logging as logging 39 | from tensorflow.python.util import nest 40 | 41 | 42 | def gnmt_residual_fn(inputs, outputs): 43 | """Residual function that handles different inputs and outputs inner dims. 44 | Args: 45 | inputs: cell inputs, this is actual inputs concatenated with the attention 46 | vector. 47 | outputs: cell outputs 48 | Returns: 49 | outputs + actual inputs 50 | """ 51 | def split_input(inp, out): 52 | out_dim = out.get_shape().as_list()[-1] 53 | inp_dim = inp.get_shape().as_list()[-1] 54 | return tf.split(inp, [out_dim, inp_dim - out_dim], axis=1) 55 | actual_inputs, _ = nest.map_structure(split_input, inputs, outputs) 56 | 57 | def assert_shape_match(inp, out): 58 | inp.get_shape().assert_is_compatible_with(out.get_shape()) 59 | nest.assert_same_structure(actual_inputs, outputs) 60 | nest.map_structure(assert_shape_match, actual_inputs, outputs) 61 | return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs) 62 | 63 | class ResidualWrapper(tf.contrib.rnn.RNNCell): 64 | """RNNCell wrapper that ensures cell inputs are added to the outputs.""" 65 | 66 | def __init__(self, cell): 67 | """Constructs a `ResidualWrapper` for `cell`. 68 | Args: 69 | cell: An instance of `RNNCell`. 70 | residual_fn: (Optional) The function to map raw cell inputs and raw cell 71 | outputs to the actual cell outputs of the residual network. 72 | Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs 73 | and outputs. 74 | """ 75 | self._cell = cell 76 | self._residual_fn = gnmt_residual_fn 77 | 78 | @property 79 | def state_size(self): 80 | return self._cell.state_size 81 | 82 | @property 83 | def output_size(self): 84 | return self._cell.output_size 85 | 86 | def zero_state(self, batch_size, dtype): 87 | with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 88 | return self._cell.zero_state(batch_size, dtype) 89 | 90 | def __call__(self, inputs, state, scope=None): 91 | """Run the cell and then apply the residual_fn on its inputs to its outputs. 92 | Args: 93 | inputs: cell inputs. 94 | state: cell state. 95 | scope: optional cell scope. 96 | Returns: 97 | Tuple of cell outputs and new state. 98 | Raises: 99 | TypeError: If cell inputs and outputs have different structure (type). 
100 | ValueError: If cell inputs and outputs have different structure (value). 101 | """ 102 | outputs, new_state = self._cell(inputs, state, scope=scope) 103 | # Ensure shapes match 104 | def assert_shape_match(inp, out): 105 | inp.get_shape().assert_is_compatible_with(out.get_shape()) 106 | def default_residual_fn(inputs, outputs): 107 | nest.assert_same_structure(inputs, outputs) 108 | nest.map_structure(assert_shape_match, inputs, outputs) 109 | return nest.map_structure(lambda inp, out: inp + out, inputs, outputs) 110 | res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs) 111 | return (res_outputs, new_state) 112 | -------------------------------------------------------------------------------- /Util/myUtil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import random 4 | import time 5 | import os 6 | import pickle as cPickle 7 | import matplotlib.pyplot as plt  # used by show_res and test_all below 8 | class LM_DP: 9 | def __init__(self, X_indices, X_w2id, BATCH_SIZE, n_epoch): 10 | num_test = int(len(X_indices) * 0.1) 11 | self.n_epoch = n_epoch 12 | self.X_train = np.array(X_indices[num_test:]) 13 | self.X_test = np.array(X_indices[:num_test]) 14 | self.num_batch = int(len(self.X_train) / BATCH_SIZE) 15 | self.num_steps = self.num_batch * self.n_epoch 16 | self.batch_size = BATCH_SIZE 17 | self.X_w2id = X_w2id 18 | self.X_id2w = dict(zip(X_w2id.values(), X_w2id.keys())) 19 | self._x_pad = self.X_w2id['<PAD>'] 20 | print('Train_data: %d | Test_data: %d | Batch_size: %d | Num_batch: %d | X_vocab_size: %d ' % (len(self.X_train), len(self.X_test), BATCH_SIZE, self.num_batch, len(self.X_w2id))) 21 | 22 | def next_batch(self, X): 23 | r = np.random.permutation(len(X)) 24 | X = X[r] 25 | for i in range(0, len(X) - len(X) % self.batch_size, self.batch_size): 26 | X_batch = X[i : i + self.batch_size] 27 | padded_X_batch, X_batch_lens = self.pad_sentence_batch(X_batch, self._x_pad) 28 | yield (np.array(padded_X_batch), 29 | X_batch_lens) 30 | 31 | def sample_test_batch(self): 32 | padded_X_batch, X_batch_lens = self.pad_sentence_batch(self.X_test[: self.batch_size], self._x_pad) 33 | return np.array(padded_X_batch), X_batch_lens 34 | 35 | def pad_sentence_batch(self, sentence_batch, pad_int): 36 | padded_seqs = [] 37 | seq_lens = [] 38 | max_sentence_len = max([len(sentence) for sentence in sentence_batch]) 39 | for sentence in sentence_batch: 40 | padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence))) 41 | seq_lens.append(len(sentence)) 42 | return padded_seqs, seq_lens 43 | 44 | 45 | class LM_util: 46 | def __init__(self, dp, model, display_freq=3): 47 | self.display_freq = display_freq 48 | self.dp = dp 49 | self.model = model 50 | 51 | def train(self, epoch): 52 | avg_loss = 0.0 53 | tic = time.time() 54 | X_test_batch, X_test_batch_lens = self.dp.sample_test_batch() 55 | for local_step, (X_train_batch, X_train_batch_lens) in enumerate( 56 | self.dp.next_batch(self.dp.X_train)): 57 | self.model.step, _, loss = self.model.sess.run([self.model.global_step, self.model.train_op, self.model.loss], 58 | {self.model.X: X_train_batch, 59 | self.model.X_seq_len: X_train_batch_lens, 60 | self.model.output_keep_prob:self.model._output_keep_prob, 61 | self.model.input_keep_prob:self.model._input_keep_prob}) 62 | avg_loss += loss 63 | """ 64 | stats = [loss] 65 | for i in xrange(len(stats)): 66 | self.model.sess.run(self.model.update_ops[i], feed_dict={ 67 | self.model.summary_placeholders[i]: float(stats[i]) 68 | }) 69 | summary_str =
self.model.sess.run([self.model.summary_op]) 70 | self.summary_writer.add_summary(summary_str, self.model.step + 1) 71 | """ 72 | if local_step % (self.dp.num_batch / self.display_freq) == 0: 73 | val_loss = self.model.sess.run(self.model.loss, {self.model.X: X_test_batch, 74 | self.model.X_seq_len: X_test_batch_lens, 75 | self.model.output_keep_prob:1, 76 | self.model.input_keep_prob:1}) 77 | print("Epoch %d/%d | Batch %d/%d | Train_loss: %.3f | Test_loss: %.3f | Time_cost:%.3f" % (epoch, self.n_epoch, local_step, self.dp.num_batch, avg_loss / (local_step + 1), val_loss, time.time()-tic)) 78 | self.cal() 79 | tic = time.time() 80 | return avg_loss / self.dp.num_batch 81 | 82 | def test(self): 83 | avg_loss = 0.0 84 | local_step = 0 85 | for local_step, (X_test_batch, X_test_batch_lens) in enumerate( 86 | self.dp.next_batch(self.dp.X_test)): 87 | val_loss = self.model.sess.run(self.model.loss, {self.model.X: X_test_batch, 88 | self.model.X_seq_len: X_test_batch_lens, 89 | self.model.output_keep_prob:1, 90 | self.model.input_keep_prob:1}) 91 | avg_loss += val_loss 92 | return avg_loss / (local_step + 1) 93 | 94 | def fit(self, train_dir, is_bleu, init_epoch=0): 95 | self.n_epoch = self.dp.n_epoch 96 | test_loss_list = [] 97 | train_loss_list = [] 98 | time_cost_list = [] 99 | bleu_list = [] 100 | #timestamp = str(int(time.time())) 101 | #out_dir = os.path.abspath(os.path.join(train_dir, "runs", timestamp)) 102 | out_dir = train_dir 103 | if not os.path.exists(out_dir): 104 | os.makedirs(out_dir) 105 | print("Writing to %s" % out_dir) 106 | checkpoint_prefix = os.path.join(out_dir, "model") 107 | self.summary_writer = tf.summary.FileWriter(os.path.join(out_dir, 'Summary'), self.model.sess.graph) 108 | for epoch in range(init_epoch, init_epoch+self.n_epoch+1): 109 | tic = time.time() 110 | train_loss = self.train(epoch) 111 | train_loss_list.append(train_loss) 112 | test_loss = self.test() 113 | test_loss_list.append(test_loss) 114 | toc = time.time() 115 | time_cost_list.append((toc - tic)) 116 | if is_bleu: 117 | bleu = self.test_bleu() 118 | bleu_list.append(bleu) 119 | print("Epoch %d/%d | Train_loss: %.3f | Test_loss: %.3f | Bleu: %.3f" % (epoch, self.n_epoch, train_loss, test_loss, bleu)) 120 | else: 121 | bleu = 0.0 122 | print("Epoch %d/%d | Train_loss: %.3f | Test_loss: %.3f" % (epoch, self.n_epoch, train_loss, test_loss)) 123 | 124 | print('============================================') 125 | stats = [train_loss, test_loss, bleu] 126 | for i in range(len(stats)): 127 | self.model.sess.run(self.model.update_ops[i], feed_dict={ 128 | self.model.summary_placeholders[i]: float(stats[i]) 129 | }) 130 | summary_str = self.model.sess.run(self.model.summary_op) 131 | self.summary_writer.add_summary(summary_str, epoch) 132 | if self.model.is_save: 133 | cPickle.dump((train_loss_list, test_loss_list, time_cost_list, bleu_list), open(os.path.join(out_dir,"res.pkl"),'wb')) 134 | path = self.model.saver.save(self.model.sess, checkpoint_prefix, global_step=epoch) 135 | print("Saved model checkpoint to %s" % path) 136 | 137 | def show(self, sent, id2w): 138 | if self.model.reverse: 139 | return " ".join([id2w.get(idx, u'&') for idx in sent])[::-1] 140 | else: 141 | return " ".join([id2w.get(idx, u'&') for idx in sent]) 142 | 143 | def cal(self, n_example=5): 144 | train_n_example = int(n_example / 2) 145 | test_n_example = n_example - train_n_example 146 | train_examples = random.sample(list(self.dp.X_train), train_n_example) 147 | test_examples = random.sample(list(self.dp.X_test), 
test_n_example) 148 | for _ in range(train_n_example): 149 | example = self.show(train_examples[_][:-1], self.dp.X_id2w) 150 | if len(example) < 3: 151 | continue 152 | length = random.randint(1, len(example)-2) 153 | if self.model.reverse: 154 | o = self.model.infer(example[-length:])[0] 155 | print('Train_Input: %s | Output: %s | GroundTruth: %s' % (example[-length:], o, example)) 156 | else: 157 | o = self.model.infer(example[:length])[0] 158 | print('Train_Input: %s | Output: %s | GroundTruth: %s' % (example[:length], o, example)) 159 | 160 | for _ in range(test_n_example): 161 | example = self.show(test_examples[_][:-1], self.dp.X_id2w) 162 | if len(example) < 3: 163 | continue 164 | length = random.randint(1, len(example)-2) 165 | if self.model.reverse: 166 | o = self.model.infer(example[-length:])[0] 167 | print('Test_Input: %s | Output: %s | GroundTruth: %s' % (example[-length:], o, example)) 168 | else: 169 | o = self.model.infer(example[:length])[0] 170 | print('Test_Input: %s | Output: %s | GroundTruth: %s' % (example[:length], o, example)) 171 | print("") 172 | """ 173 | def test_bleu(self, N=300, gram=4): 174 | all_score = [] 175 | for i in range(N): 176 | input_indices = self.show(self.dp.X_test[i][:-1], self.dp.X_id2w) 177 | o = self.model.infer(input_indices)[0] 178 | refer4bleu = [[' '.join([self.dp.X_id2w.get(w, u'&') for w in self.dp.X_test[i]])]] 179 | candi = [' '.join(w for w in o)] 180 | score = BLEU(candi, refer4bleu, gram=gram) 181 | all_score.append(score) 182 | return np.mean(all_score) 183 | """ 184 | def show_res(self, path): 185 | res = cPickle.load(open(path, 'rb')) 186 | plt.figure(1) 187 | plt.title('The results') 188 | l1, = plt.plot(res[0], 'g') 189 | l2, = plt.plot(res[1], 'r') 190 | l3, = plt.plot(res[3], 'b') 191 | plt.legend(handles = [l1, l2, l3], labels = ["Train_loss","Test_loss","BLEU"], loc = 'best') 192 | plt.show() 193 | 194 | def test_all(self, path, epoch_range, is_bleu=True): 195 | val_loss_list = [] 196 | bleu_list = [] 197 | for i in range(epoch_range[0], epoch_range[-1]): 198 | self.model.restore(path + str(i)) 199 | val_loss = self.test() 200 | val_loss_list.append(val_loss) 201 | if is_bleu: 202 | bleu_score = self.test_bleu() 203 | bleu_list.append(bleu_score) 204 | plt.figure(1) 205 | plt.title('The results') 206 | l1, = plt.plot(val_loss_list,'r') 207 | l2, = plt.plot(bleu_list,'b') 208 | plt.legend(handles = [l1, l2], labels = ["Test_loss","BLEU"], loc = 'best') 209 | plt.show() 210 | 211 | -------------------------------------------------------------------------------- /Util/my_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A library of helpers for use with SamplingDecoders.
16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import abc 23 | 24 | import six 25 | 26 | from tensorflow.contrib.seq2seq.python.ops import decoder 27 | from tensorflow.python.framework import dtypes 28 | from tensorflow.python.framework import ops 29 | from tensorflow.python.layers import base as layers_base 30 | from tensorflow.python.ops import array_ops 31 | from tensorflow.python.ops import control_flow_ops 32 | from tensorflow.python.ops import embedding_ops 33 | from tensorflow.python.ops import gen_array_ops 34 | from tensorflow.python.ops import math_ops 35 | from tensorflow.python.ops import tensor_array_ops 36 | from tensorflow.python.ops.distributions import bernoulli 37 | from tensorflow.python.ops.distributions import categorical 38 | from tensorflow.python.util import nest 39 | from tensorflow.python.framework import tensor_shape 40 | from tensorflow.contrib.seq2seq.python.ops import helper 41 | __all__ = [ 42 | "myHelper", 43 | ] 44 | 45 | _transpose_batch_time = decoder._transpose_batch_time # pylint: disable=protected-access 46 | 47 | 48 | def _unstack_ta(inp): 49 | return tensor_array_ops.TensorArray( 50 | dtype=inp.dtype, size=array_ops.shape(inp)[0], 51 | element_shape=inp.get_shape()[1:]).unstack(inp) 52 | 53 | 54 | class MyHelper(helper.Helper): 55 | """A helper for first generate given prefix, then generate the reminder by sampling. 56 | """ 57 | def __init__(self, inputs, sequence_length, end_token, embedding, seed=None, time_major=False, name=None,sample_ids_shape=None, sample_ids_dtype=None): 58 | 59 | with ops.name_scope(name, "MyHelper", [inputs, sequence_length]): 60 | inputs = ops.convert_to_tensor(inputs, name="inputs") 61 | self._inputs = inputs 62 | if not time_major: 63 | inputs = nest.map_structure(_transpose_batch_time, inputs) 64 | 65 | self._input_tas = nest.map_structure(_unstack_ta, inputs) 66 | self._sequence_length = ops.convert_to_tensor( 67 | sequence_length, name="sequence_length") 68 | 69 | if self._sequence_length.get_shape().ndims != 1: 70 | raise ValueError( 71 | "Expected sequence_length to be a vector, but received shape: %s" % 72 | self._sequence_length.get_shape()) 73 | 74 | if callable(embedding): 75 | self._embedding_fn = embedding 76 | else: 77 | self._embedding_fn = ( 78 | lambda ids: embedding_ops.embedding_lookup(embedding, ids)) 79 | 80 | self._seed = seed 81 | 82 | self._end_token = ops.convert_to_tensor( 83 | end_token, dtype=dtypes.int32, name="end_token") 84 | if self._end_token.get_shape().ndims != 0: 85 | raise ValueError("end_token must be a scalar") 86 | 87 | 88 | #self._zero_inputs = nest.map_structure( 89 | # lambda inp: array_ops.zeros_like(inp[0, :]), inputs) 90 | # !!!! 
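      # How this helper behaves at decode step `time` (see sample() and
      # next_inputs() below): while time + 1 < sequence_length, the next ids are
      # read verbatim from the unstacked prefix TensorArray; once every
      # sequence's prefix is exhausted, ids are drawn from a Categorical over
      # the decoder logits. A minimal usage sketch, with hypothetical tensors
      # `prefix_ids`, `prefix_lens`, `eos_id`, and `emb`:
      #
      #   helper = MyHelper(inputs=prefix_ids, sequence_length=prefix_lens,
      #                     end_token=eos_id, embedding=emb)
      #   my_decoder = tf.contrib.seq2seq.BasicDecoder(cell, helper,
      #                                                initial_state, output_layer)
      #   outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(my_decoder)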
91 | self._batch_size = array_ops.size(sequence_length) 92 | self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or []) 93 | self._sample_ids_dtype = sample_ids_dtype or dtypes.int32 94 | 95 | 96 | @property 97 | def inputs(self): 98 | return self._inputs 99 | 100 | @property 101 | def sequence_length(self): 102 | return self._sequence_length 103 | 104 | @property 105 | def batch_size(self): 106 | return self._batch_size 107 | 108 | @property 109 | def sample_ids_shape(self): 110 | return tensor_shape.TensorShape([]) 111 | 112 | @property 113 | def sample_ids_dtype(self): 114 | return dtypes.int32 115 | """ 116 | def initialize(self, name=None): 117 | with ops.name_scope(name, "TrainingHelperInitialize"): 118 | finished = array_ops.tile([False], [self._batch_size]) 119 | all_finished = math_ops.reduce_all(finished) 120 | next_inputs = self._embedding_fn(self._input_tas.read(0)) 121 | return (finished, next_inputs) 122 | """ 123 | def initialize(self, name=None): 124 | with ops.name_scope(name, "MyHelperInitialize"): 125 | finished = math_ops.equal(0, self._sequence_length) 126 | all_finished = math_ops.reduce_all(finished) 127 | next_inputs = self._embedding_fn(self._input_tas.read(0)) 128 | return (finished, next_inputs) 129 | 130 | def sample(self, time, outputs, name=None, **unused_kwargs): 131 | if not isinstance(outputs, ops.Tensor): 132 | raise TypeError("Expected outputs to be a single Tensor, got: %s" % 133 | type(outputs)) 134 | prefixed = (time+1 >= self._sequence_length) 135 | all_prefixed = math_ops.reduce_all(prefixed) 136 | 137 | sample_id_sampler = categorical.Categorical(logits=outputs) 138 | sample_ids = control_flow_ops.cond( 139 | all_prefixed, lambda: sample_id_sampler.sample(seed=self._seed), 140 | lambda: self._input_tas.read(time+1)) #nest.map_structure(lambda inp: inp.read(time+1), self._input_tas)) #self._input_tas.read(time+1)) 141 | 142 | return sample_ids 143 | 144 | def next_inputs(self, time, outputs, state, sample_ids, name=None): 145 | with ops.name_scope(name, "MyHelperNextInputs", 146 | [time, outputs, state]): 147 | next_time = time + 1 148 | finished = math_ops.equal(sample_ids, self._end_token) 149 | prefixed = (next_time >= self._sequence_length) 150 | all_prefixed = math_ops.reduce_all(prefixed) 151 | all_finished = math_ops.reduce_all(finished) 152 | 153 | next_inputs = self._embedding_fn(sample_ids) 154 | """ 155 | next_inputs = control_flow_ops.cond( 156 | all_prefixed, lambda: self._embedding_fn(sample_ids), 157 | lambda: nest.map_structure(read_from_ta, self._input_tas)) 158 | """ 159 | """ 160 | next_inputs = control_flow_ops.cond( 161 | all_prefixed, lambda: self._embedding_fn(sample_ids), 162 | lambda: self._embedding_fn(self._input_tas.read(next_time))) 163 | # lambda: self._input_tas.read(next_time)) 164 | """ 165 | return (finished, next_inputs, state) 166 | -------------------------------------------------------------------------------- /Util/my_seq2seq.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tensorflow.python.layers import core as core_layers 3 | import tensorflow as tf 4 | import numpy as np 5 | import time 6 | import myResidualCell 7 | import jieba 8 | from bleu import BLEU 9 | import random 10 | import pickle as cPickle 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | 15 | 16 | class Seq2Seq: 17 | def __init__(self, dp, rnn_size, n_layers, encoder_embedding_dim, decoder_embedding_dim, max_infer_length, 18 | sess, lr=0.001, grad_clip=5.0, beam_width=5, 
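                 # Parameter notes (grounded in the methods below):
                 # force_teaching_ratio=1.0 means pure teacher forcing, since
                 # ScheduledEmbeddingTrainingHelper receives
                 # sampling_probability = 1 - force_teaching_ratio;
                 # beam_penalty is passed to BeamSearchDecoder as
                 # length_penalty_weight; decay_scheme selects the halving
                 # schedule in get_learning_rate_decay ('luong234' or 'luong10').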
force_teaching_ratio=1.0, beam_penalty=1.0, 19 | residual=False, output_keep_prob=0.5, input_keep_prob=0.9, cell_type='lstm', reverse=False, 20 | decay_scheme='luong234'): 21 | 22 | self.rnn_size = rnn_size 23 | self.n_layers = n_layers 24 | self.grad_clip = grad_clip 25 | self.dp = dp 26 | self.encoder_embedding_dim = encoder_embedding_dim 27 | self.decoder_embedding_dim = decoder_embedding_dim 28 | self.beam_width = beam_width 29 | self.beam_penalty = beam_penalty 30 | self.max_infer_length = max_infer_length 31 | self.residual = residual 32 | self.decay_scheme = decay_scheme 33 | if self.residual: 34 | assert encoder_embedding_dim == rnn_size 35 | assert decoder_embedding_dim == rnn_size 36 | self.reverse = reverse 37 | self.cell_type = cell_type 38 | self.force_teaching_ratio = force_teaching_ratio 39 | self._output_keep_prob = output_keep_prob 40 | self._input_keep_prob = input_keep_prob 41 | self.sess = sess 42 | self.lr=lr 43 | self.build_graph() 44 | self.sess.run(tf.global_variables_initializer()) 45 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = 5) 46 | self.summary_placeholders, self.update_ops, self.summary_op = self.setup_summary() 47 | 48 | # end constructor 49 | 50 | def build_graph(self): 51 | self.register_symbols() 52 | self.add_input_layer() 53 | self.add_encoder_layer() 54 | with tf.variable_scope('decode'): 55 | self.add_decoder_for_training() 56 | with tf.variable_scope('decode', reuse=True): 57 | self.add_decoder_for_inference() 58 | with tf.variable_scope('decode', reuse=True): 59 | self.add_decoder_for_prefix_inference() 60 | self.add_backward_path() 61 | # end method 62 | 63 | def _item_or_tuple(self, seq): 64 | """Returns `seq` as tuple or the singular element. 65 | Which is returned is determined by how the AttentionMechanism(s) were passed 66 | to the constructor. 67 | Args: 68 | seq: A non-empty sequence of items or generator. 69 | Returns: 70 | Either the values in the sequence as a tuple if AttentionMechanism(s) 71 | were passed to the constructor as a sequence or the singular element. 
72 | """ 73 | t = tuple(seq) 74 | if self._is_multi: 75 | return t 76 | else: 77 | return t[0] 78 | 79 | def add_input_layer(self): 80 | self.X = tf.placeholder(tf.int32, [None, None], name="X") 81 | self.Y = tf.placeholder(tf.int32, [None, None], name="Y") 82 | self.X_seq_len = tf.placeholder(tf.int32, [None], name="X_seq_len") 83 | self.Y_seq_len = tf.placeholder(tf.int32, [None], name="Y_seq_len") 84 | self.input_keep_prob = tf.placeholder(tf.float32,name="input_keep_prob") 85 | self.output_keep_prob = tf.placeholder(tf.float32,name="output_keep_prob") 86 | self.batch_size = tf.shape(self.X)[0] 87 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 88 | # end method 89 | 90 | def single_cell(self, reuse=False): 91 | if self.cell_type == 'lstm': 92 | cell = tf.contrib.rnn.LayerNormBasicLSTMCell(self.rnn_size, reuse=reuse) 93 | else: 94 | cell = tf.contrib.rnn.GRUBlockCell(self.rnn_size) 95 | cell = tf.contrib.rnn.DropoutWrapper(cell, self.output_keep_prob, self.input_keep_prob) 96 | if self.residual: 97 | cell = myResidualCell.ResidualWrapper(cell) 98 | return cell 99 | 100 | def add_encoder_layer(self): 101 | encoder_embedding = tf.get_variable('encoder_embedding', [len(self.dp.X_w2id), self.encoder_embedding_dim], 102 | tf.float32, tf.random_uniform_initializer(-1.0, 1.0)) 103 | 104 | self.encoder_inputs = tf.nn.embedding_lookup(encoder_embedding, self.X) 105 | bi_encoder_output, bi_encoder_state = tf.nn.bidirectional_dynamic_rnn( 106 | cell_fw = tf.contrib.rnn.MultiRNNCell([self.single_cell() for _ in range(self.n_layers)]), 107 | cell_bw = tf.contrib.rnn.MultiRNNCell([self.single_cell() for _ in range(self.n_layers)]), 108 | inputs = self.encoder_inputs, 109 | sequence_length = self.X_seq_len, 110 | dtype = tf.float32, 111 | scope = 'bidirectional_rnn') 112 | self.encoder_out = tf.concat(bi_encoder_output, 2) 113 | encoder_state = [] 114 | for layer_id in range(self.n_layers): 115 | encoder_state.append(bi_encoder_state[0][layer_id]) # forward 116 | encoder_state.append(bi_encoder_state[1][layer_id]) # backward 117 | self.encoder_state = tuple(encoder_state) 118 | """ 119 | def add_encoder_layer(self): 120 | encoder_embedding = tf.get_variable('encoder_embedding', [len(self.dp.X_w2id), self.encoder_embedding_dim], 121 | tf.float32, tf.random_uniform_initializer(-1.0, 1.0)) 122 | self.encoder_out = tf.nn.embedding_lookup(encoder_embedding, self.X) 123 | for n in range(self.n_layers): 124 | (out_fw, out_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn( 125 | cell_fw = self.single_cell(), cell_bw = self.single_cell(), 126 | inputs = self.encoder_out, 127 | sequence_length = self.X_seq_len, 128 | dtype = tf.float32, 129 | scope = 'bidirectional_rnn_'+str(n)) 130 | self.encoder_out = tf.concat((out_fw, out_bw), 2) 131 | self.encoder_state = () 132 | for n in range(self.n_layers): # replicate top-most state 133 | self.encoder_state += (state_fw, state_bw) 134 | """ 135 | def processed_decoder_input(self): 136 | main = tf.strided_slice(self.Y, [0, 0], [self.batch_size, -1], [1, 1]) # remove last char 137 | decoder_input = tf.concat([tf.fill([self.batch_size, 1], self._y_go), main], 1) 138 | return decoder_input 139 | 140 | def add_attention_for_training(self): 141 | if self.cell_type == 'lstm': 142 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 143 | num_units = self.rnn_size, 144 | memory = self.encoder_out, 145 | memory_sequence_length = self.X_seq_len, 146 | normalize=True) 147 | else: 148 | attention_mechanism = 
tf.contrib.seq2seq.LuongAttention( 149 | num_units = self.rnn_size, 150 | memory = self.encoder_out, 151 | memory_sequence_length = self.X_seq_len, 152 | scale=True) 153 | 154 | self.decoder_cell = tf.contrib.seq2seq.AttentionWrapper( 155 | cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell() for _ in range(2 * self.n_layers)]), 156 | attention_mechanism = attention_mechanism, 157 | attention_layer_size = self.rnn_size) 158 | 159 | def add_decoder_for_training(self): 160 | self.add_attention_for_training() 161 | decoder_embedding = tf.get_variable('decoder_embedding', [len(self.dp.Y_w2id), self.decoder_embedding_dim], 162 | tf.float32, tf.random_uniform_initializer(-1.0, 1.0)) 163 | training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper( 164 | inputs = tf.nn.embedding_lookup(decoder_embedding, self.processed_decoder_input()), 165 | sequence_length = self.Y_seq_len, 166 | embedding = decoder_embedding, 167 | sampling_probability = 1 - self.force_teaching_ratio, 168 | time_major = False) 169 | training_decoder = tf.contrib.seq2seq.BasicDecoder( 170 | cell = self.decoder_cell, 171 | helper = training_helper, 172 | initial_state = self.decoder_cell.zero_state(self.batch_size, tf.float32).clone(cell_state=self.encoder_state), 173 | output_layer = core_layers.Dense(len(self.dp.Y_w2id))) 174 | training_decoder_output, training_final_state, _ = tf.contrib.seq2seq.dynamic_decode( 175 | decoder = training_decoder, 176 | impute_finished = True, 177 | maximum_iterations = tf.reduce_max(self.Y_seq_len)) 178 | self.training_logits = training_decoder_output.rnn_output 179 | self.init_prefix_state = training_final_state 180 | 181 | def add_attention_for_inference(self): 182 | self.encoder_out_tiled = tf.contrib.seq2seq.tile_batch(self.encoder_out, self.beam_width) 183 | self.encoder_state_tiled = tf.contrib.seq2seq.tile_batch(self.encoder_state, self.beam_width) 184 | self.X_seq_len_tiled = tf.contrib.seq2seq.tile_batch(self.X_seq_len, self.beam_width) 185 | if self.cell_type == 'lstm': 186 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention( 187 | num_units = self.rnn_size, 188 | memory = self.encoder_out_tiled, 189 | memory_sequence_length = self.X_seq_len_tiled, 190 | normalize=True) 191 | else: 192 | attention_mechanism = tf.contrib.seq2seq.LuongAttention( 193 | num_units = self.rnn_size, 194 | memory = self.encoder_out_tiled, 195 | memory_sequence_length = self.X_seq_len_tiled, 196 | scale=True) 197 | self.decoder_cell = tf.contrib.seq2seq.AttentionWrapper( 198 | cell = tf.nn.rnn_cell.MultiRNNCell([self.single_cell(reuse=True) for _ in range(2 * self.n_layers)]), 199 | attention_mechanism = attention_mechanism, 200 | attention_layer_size = self.rnn_size) 201 | self.attention_mechanism = attention_mechanism 202 | 203 | def add_decoder_for_inference(self): 204 | self.add_attention_for_inference() 205 | predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder( 206 | cell = self.decoder_cell, 207 | embedding = tf.get_variable('decoder_embedding'), 208 | start_tokens = tf.tile(tf.constant([self._y_go], dtype=tf.int32), [self.batch_size]), 209 | end_token = self._y_eos, 210 | initial_state = self.decoder_cell.zero_state(self.batch_size * self.beam_width, tf.float32).clone( 211 | cell_state = self.encoder_state_tiled), 212 | beam_width = self.beam_width, 213 | output_layer = core_layers.Dense(len(self.dp.Y_w2id), _reuse=True), 214 | length_penalty_weight = self.beam_penalty) 215 | predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode( 216 | decoder = 

    def add_decoder_for_prefix_inference(self):
        self.add_attention_for_inference()
        prefix_cell_state = tf.contrib.seq2seq.tile_batch(self.init_prefix_state.cell_state, self.beam_width)
        prefix_attention = tf.contrib.seq2seq.tile_batch(self.init_prefix_state.attention, self.beam_width)
        prefix_time = self.init_prefix_state.time
        prefix_alignments = self.init_prefix_state.alignments
        prefix_alignment_history = self.init_prefix_state.alignment_history

        init_state = tf.contrib.seq2seq.AttentionWrapperState(cell_state=prefix_cell_state,
                                                              attention=prefix_attention, time=prefix_time,
                                                              attention_state=self.init_prefix_state.attention_state,
                                                              alignments=prefix_alignments,
                                                              alignment_history=prefix_alignment_history)
        predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell = self.decoder_cell,
            embedding = tf.get_variable('decoder_embedding'),
            start_tokens = tf.tile(tf.constant([self._y_go], dtype=tf.int32), [self.batch_size]),
            end_token = self._y_eos,
            initial_state = init_state,
            beam_width = self.beam_width,
            output_layer = core_layers.Dense(len(self.dp.Y_w2id), _reuse=True),
            length_penalty_weight = self.beam_penalty)
        self.prefix_go = tf.placeholder(tf.int32, [None])
        prefix_go_beam = tf.tile(tf.expand_dims(self.prefix_go, 1), [1, self.beam_width])
        prefix_emb = tf.nn.embedding_lookup(tf.get_variable('decoder_embedding'), prefix_go_beam)
        # Overwrite the decoder's start inputs so beam search begins from the
        # last prefix token instead of <GO>.
        predicting_decoder._start_inputs = prefix_emb
        predicting_prefix_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder = predicting_decoder,
            impute_finished = False,
            maximum_iterations = self.max_infer_length)
        self.predicting_prefix_ids = predicting_prefix_decoder_output.predicted_ids
        self.prefix_score = predicting_prefix_decoder_output.beam_search_decoder_output.scores

    def add_backward_path(self):
        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)
        self.loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,
                                                     targets = self.Y,
                                                     weights = masks)
        self.batch_loss = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,
                                                           targets = self.Y,
                                                           weights = masks,
                                                           average_across_batch=False)
        params = tf.trainable_variables()
        gradients = tf.gradients(self.loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.grad_clip)
        self.learning_rate = tf.constant(self.lr)
        self.learning_rate = self.get_learning_rate_decay(self.decay_scheme)  # decay
        self.train_op = tf.train.AdamOptimizer(self.learning_rate).apply_gradients(
            zip(clipped_gradients, params), global_step=self.global_step)

    def register_symbols(self):
        self._x_go = self.dp.X_w2id['<GO>']
        self._x_eos = self.dp.X_w2id['<EOS>']
        self._x_pad = self.dp.X_w2id['<PAD>']
        self._x_unk = self.dp.X_w2id['<UNK>']

        self._y_go = self.dp.Y_w2id['<GO>']
        self._y_eos = self.dp.Y_w2id['<EOS>']
        self._y_pad = self.dp.Y_w2id['<PAD>']
        self._y_unk = self.dp.Y_w2id['<UNK>']
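    """
    register_symbols assumes both vocabularies reserve the four special
    tokens at fixed keys; a minimal sketch of a compatible vocabulary (the
    exact token spellings here are an assumption of this example):

        X_w2id = {'<GO>': 0, '<EOS>': 1, '<PAD>': 2, '<UNK>': 3, 'a': 4, 'b': 5}
        X_id2w = dict(zip(X_w2id.values(), X_w2id.keys()))
        ids = [X_w2id.get(ch, X_w2id['<UNK>']) for ch in 'abz']  # -> [4, 5, 3]
    """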

    def infer(self, input_word):
        if self.reverse:
            input_word = input_word[::-1]
        input_indices = [self.dp.X_w2id.get(char, self._x_unk) for char in input_word]
        out_indices = self.sess.run(self.predicting_ids, {
            self.X: [input_indices], self.X_seq_len: [len(input_indices)],
            self.output_keep_prob: 1, self.input_keep_prob: 1})
        outputs = []
        for idx in range(out_indices.shape[-1]):  # one hypothesis per beam
            eos_id = self._y_eos
            ot = out_indices[0, :, idx]
            if eos_id in ot:
                ot = ot.tolist()
                ot = ot[:ot.index(eos_id)]  # truncate at <EOS>
            if self.reverse:
                ot = ot[::-1]
            output_str = ''.join([self.dp.Y_id2w.get(i, u'&') for i in ot])
            outputs.append(output_str)
        return outputs

    def prefix_infer(self, input_word, prefix):
        input_indices_X = [self.dp.X_w2id.get(char, self._x_unk) for char in input_word]
        input_indices_Y = [self.dp.Y_w2id.get(char, self._y_unk) for char in prefix]
        # Beam search resumes from the last prefix token
        # (see add_decoder_for_prefix_inference).
        prefix_go = [input_indices_Y[-1]]
        out_indices, scores = self.sess.run([self.predicting_prefix_ids, self.prefix_score], {
            self.X: [input_indices_X], self.X_seq_len: [len(input_indices_X)],
            self.Y: [input_indices_Y], self.Y_seq_len: [len(input_indices_Y)],
            self.prefix_go: prefix_go, self.input_keep_prob: 1, self.output_keep_prob: 1})

        outputs = []
        for idx in range(out_indices.shape[-1]):
            eos_id = self._y_eos
            ot = out_indices[0, :, idx]
            if eos_id in ot:
                ot = ot.tolist()
                ot = ot[:ot.index(eos_id)]
            if self.reverse:
                ot = ot[::-1]
                output_str = ''.join([self.dp.Y_id2w.get(i, u'&') for i in ot]) + prefix
            else:
                output_str = prefix + ''.join([self.dp.Y_id2w.get(i, u'&') for i in ot])
            outputs.append(output_str)
        return outputs

    def restore(self, path):
        self.saver.restore(self.sess, path)
        print('restore %s success' % path)

    def get_learning_rate_decay(self, decay_scheme='luong234'):
        num_train_steps = self.dp.num_steps
        if decay_scheme == "luong10":
            start_decay_step = int(num_train_steps / 2)
            remain_steps = num_train_steps - start_decay_step
            decay_steps = int(remain_steps / 10)  # decay 10 times
            decay_factor = 0.5
        else:  # "luong234": start at 2/3 of training
            start_decay_step = int(num_train_steps * 2 / 3)
            remain_steps = num_train_steps - start_decay_step
            decay_steps = int(remain_steps / 4)  # decay 4 times
            decay_factor = 0.5
        return tf.cond(
            self.global_step < start_decay_step,
            lambda: self.learning_rate,
            lambda: tf.train.exponential_decay(
                self.learning_rate,
                (self.global_step - start_decay_step),
                decay_steps, decay_factor, staircase=True),
            name="learning_rate_decay_cond")
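    """
    The default schedule above in plain numbers (illustration only): with
    num_train_steps = 12000, decay starts at step 8000 and the rate halves
    every 1000 steps. The loop mimics tf.cond plus staircased
    exponential_decay:

        num_train_steps = 12000
        start_decay_step = int(num_train_steps * 2 / 3)              # 8000
        decay_steps = int((num_train_steps - start_decay_step) / 4)  # 1000
        lr = 0.001
        for step in [7999, 8000, 9000, 10000]:
            n = max(0, (step - start_decay_step) // decay_steps)
            print(step, lr * 0.5 ** n)  # 0.001, 0.001, 0.0005, 0.00025
    """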

    def setup_summary(self):
        train_loss = tf.Variable(0.)
        tf.summary.scalar('Train_loss', train_loss)

        test_loss = tf.Variable(0.)
        tf.summary.scalar('Test_loss', test_loss)

        bleu_score = tf.Variable(0.)
        tf.summary.scalar('BLEU_score', bleu_score)

        tf.summary.scalar('lr_rate', self.learning_rate)

        summary_vars = [train_loss, test_loss, bleu_score]
        summary_placeholders = [tf.placeholder(tf.float32) for _ in range(len(summary_vars))]
        update_ops = [summary_vars[i].assign(summary_placeholders[i]) for i in range(len(summary_vars))]
        summary_op = tf.summary.merge_all()
        return summary_placeholders, update_ops, summary_op


class Seq2Seq_DP:
    def __init__(self, X_indices, Y_indices, X_w2id, Y_w2id, BATCH_SIZE, n_epoch):
        assert len(X_indices) == len(Y_indices)
        num_test = int(len(X_indices) * 0.2)  # hold out 20% for testing
        self.n_epoch = n_epoch
        self.X_train = np.array(X_indices[num_test:])
        self.Y_train = np.array(Y_indices[num_test:])
        self.X_test = np.array(X_indices[:num_test])
        self.Y_test = np.array(Y_indices[:num_test])
        self.num_batch = int(len(self.X_train) / BATCH_SIZE)
        self.num_steps = self.num_batch * self.n_epoch
        self.batch_size = BATCH_SIZE
        self.X_w2id = X_w2id
        self.X_id2w = dict(zip(X_w2id.values(), X_w2id.keys()))
        self.Y_w2id = Y_w2id
        self.Y_id2w = dict(zip(Y_w2id.values(), Y_w2id.keys()))
        self._x_pad = self.X_w2id['<PAD>']
        self._y_pad = self.Y_w2id['<PAD>']
        print('Train_data: %d | Test_data: %d | Batch_size: %d | Num_batch: %d | X_vocab_size: %d | Y_vocab_size: %d'
              % (len(self.X_train), len(self.X_test), BATCH_SIZE, self.num_batch, len(self.X_w2id), len(self.Y_w2id)))

    def next_batch(self, X, Y):
        r = np.random.permutation(len(X))  # reshuffle every epoch
        X = X[r]
        Y = Y[r]
        for i in range(0, len(X) - len(X) % self.batch_size, self.batch_size):
            X_batch = X[i : i + self.batch_size]
            Y_batch = Y[i : i + self.batch_size]
            padded_X_batch, X_batch_lens = self.pad_sentence_batch(X_batch, self._x_pad)
            padded_Y_batch, Y_batch_lens = self.pad_sentence_batch(Y_batch, self._y_pad)
            yield (np.array(padded_X_batch),
                   np.array(padded_Y_batch),
                   X_batch_lens,
                   Y_batch_lens)

    def sample_test_batch(self):
        padded_X_batch, X_batch_lens = self.pad_sentence_batch(self.X_test[: self.batch_size], self._x_pad)
        padded_Y_batch, Y_batch_lens = self.pad_sentence_batch(self.Y_test[: self.batch_size], self._y_pad)
        return np.array(padded_X_batch), np.array(padded_Y_batch), X_batch_lens, Y_batch_lens

    def pad_sentence_batch(self, sentence_batch, pad_int):
        padded_seqs = []
        seq_lens = []
        max_sentence_len = max([len(sentence) for sentence in sentence_batch])
        for sentence in sentence_batch:
            padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))
            seq_lens.append(len(sentence))
        return padded_seqs, seq_lens
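"""
pad_sentence_batch in action (illustration only; the pad id is assumed to
be 0 here). The original lengths are returned alongside the padded batch so
they can be fed to X_seq_len / Y_seq_len:

    batch = [[4, 5], [4, 5, 6, 7]]
    pad_id = 0
    max_len = max(len(s) for s in batch)
    padded = [s + [pad_id] * (max_len - len(s)) for s in batch]
    lens = [len(s) for s in batch]
    # padded -> [[4, 5, 0, 0], [4, 5, 6, 7]]; lens -> [2, 4]
"""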

class Seq2Seq_util:
    def __init__(self, dp, model, display_freq=3):
        self.display_freq = display_freq
        self.dp = dp
        self.model = model

    def train(self, epoch):
        avg_loss = 0.0
        tic = time.time()
        X_test_batch, Y_test_batch, X_test_batch_lens, Y_test_batch_lens = self.dp.sample_test_batch()
        for local_step, (X_train_batch, Y_train_batch, X_train_batch_lens, Y_train_batch_lens) in enumerate(
                self.dp.next_batch(self.dp.X_train, self.dp.Y_train)):
            self.model.step, _, loss = self.model.sess.run(
                [self.model.global_step, self.model.train_op, self.model.loss],
                {self.model.X: X_train_batch,
                 self.model.Y: Y_train_batch,
                 self.model.X_seq_len: X_train_batch_lens,
                 self.model.Y_seq_len: Y_train_batch_lens,
                 self.model.output_keep_prob: self.model._output_keep_prob,
                 self.model.input_keep_prob: self.model._input_keep_prob})
            avg_loss += loss
            """
            stats = [loss]
            for i in range(len(stats)):
                self.model.sess.run(self.model.update_ops[i], feed_dict={
                    self.model.summary_placeholders[i]: float(stats[i])
                })
            summary_str = self.model.sess.run([self.model.summary_op])
            self.summary_writer.add_summary(summary_str, self.model.step + 1)
            """
            # Integer division so validation runs display_freq times per epoch;
            # a float divisor would make the modulo test almost never hit zero.
            if local_step % max(1, self.dp.num_batch // self.display_freq) == 0:
                val_loss = self.model.sess.run(self.model.loss, {self.model.X: X_test_batch,
                                                                 self.model.Y: Y_test_batch,
                                                                 self.model.X_seq_len: X_test_batch_lens,
                                                                 self.model.Y_seq_len: Y_test_batch_lens,
                                                                 self.model.output_keep_prob: 1,
                                                                 self.model.input_keep_prob: 1})
                print("Epoch %d/%d | Batch %d/%d | Train_loss: %.3f | Test_loss: %.3f | Time_cost:%.3f"
                      % (epoch, self.n_epoch, local_step, self.dp.num_batch,
                         avg_loss / (local_step + 1), val_loss, time.time() - tic))
                self.cal()
                tic = time.time()
        return avg_loss / self.dp.num_batch

    def test(self):
        avg_loss = 0.0
        for local_step, (X_test_batch, Y_test_batch, X_test_batch_lens, Y_test_batch_lens) in enumerate(
                self.dp.next_batch(self.dp.X_test, self.dp.Y_test)):
            val_loss = self.model.sess.run(self.model.loss, {self.model.X: X_test_batch,
                                                             self.model.Y: Y_test_batch,
                                                             self.model.X_seq_len: X_test_batch_lens,
                                                             self.model.Y_seq_len: Y_test_batch_lens,
                                                             self.model.output_keep_prob: 1,
                                                             self.model.input_keep_prob: 1})
            avg_loss += val_loss
        return avg_loss / (local_step + 1)

    def fit(self, train_dir, is_bleu):
        self.n_epoch = self.dp.n_epoch  # train() reads this attribute
        test_loss_list = []
        train_loss_list = []
        time_cost_list = []
        bleu_list = []
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(train_dir, "runs", timestamp))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        print("Writing to %s" % out_dir)
        checkpoint_prefix = os.path.join(out_dir, "model")
        self.summary_writer = tf.summary.FileWriter(os.path.join(out_dir, 'Summary'), self.model.sess.graph)
        for epoch in range(1, self.n_epoch + 1):
            tic = time.time()
            train_loss = self.train(epoch)
            train_loss_list.append(train_loss)
            test_loss = self.test()
            test_loss_list.append(test_loss)
            toc = time.time()
            time_cost_list.append((toc - tic))
            if is_bleu:
                bleu = self.test_bleu()
                bleu_list.append(bleu)
                print("Epoch %d/%d | Train_loss: %.3f | Test_loss: %.3f | Bleu: %.3f"
                      % (epoch, self.n_epoch, train_loss, test_loss, bleu))
            else:
                bleu = 0.0
                print("Epoch %d/%d | Train_loss: %.3f | Test_loss: %.3f"
                      % (epoch, self.n_epoch, train_loss, test_loss))

            stats = [train_loss, test_loss, bleu]
            for i in range(len(stats)):
                self.model.sess.run(self.model.update_ops[i], feed_dict={
                    self.model.summary_placeholders[i]: float(stats[i])
                })
            summary_str = self.model.sess.run(self.model.summary_op)
            self.summary_writer.add_summary(summary_str, epoch)
            cPickle.dump((train_loss_list, test_loss_list, time_cost_list, bleu_list),
                         open(os.path.join(out_dir, "res.pkl"), 'wb'))
            path = self.model.saver.save(self.model.sess, checkpoint_prefix, global_step=epoch)
            print("Saved model checkpoint to %s" % path)
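    """
    Minimal sketch of the TF1 summary pattern fit() relies on (illustration
    only): scalar Variables are assigned through placeholders, then the
    merged summary is evaluated once per epoch and handed to the FileWriter.

        import tensorflow as tf
        train_loss = tf.Variable(0.)
        tf.summary.scalar('Train_loss', train_loss)
        ph = tf.placeholder(tf.float32)
        update = train_loss.assign(ph)
        summary_op = tf.summary.merge_all()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(update, {ph: 1.23})
            summary_str = sess.run(summary_op)  # pass to FileWriter.add_summary
    """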
%s" % path) 511 | 512 | def show(self, sent, id2w): 513 | return "".join([id2w.get(idx, u'&') for idx in sent]) 514 | 515 | def cal(self, n_example=5): 516 | train_n_example = int(n_example / 2) 517 | test_n_example = n_example - train_n_example 518 | for _ in range(test_n_example): 519 | example = self.show(self.dp.X_test[_], self.dp.X_id2w) 520 | y = self.show(self.dp.Y_test[_], self.dp.Y_id2w) 521 | o = self.model.infer(example)[0] 522 | print('Input: %s | Output: %s | GroundTruth: %s' % (example, o, y)) 523 | for _ in range(train_n_example): 524 | example = self.show(self.dp.X_train[_], self.dp.X_id2w) 525 | y = self.show(self.dp.Y_train[_], self.dp.Y_id2w) 526 | o = self.model.infer(example)[0] 527 | print('Input: %s | Output: %s | GroundTruth: %s' % (example, o, y)) 528 | print("") 529 | 530 | def test_bleu(self, N=300, gram=4): 531 | all_score = [] 532 | for i in range(N): 533 | input_indices = self.show(self.dp.X_test[i], self.dp.X_id2w) 534 | o = self.model.infer(input_indices)[0] 535 | refer4bleu = [[' '.join([self.dp.Y_id2w.get(w, u'&') for w in self.dp.Y_test[i]])]] 536 | candi = [' '.join(w for w in o)] 537 | score = BLEU(candi, refer4bleu, gram=gram) 538 | all_score.append(score) 539 | return np.mean(all_score) 540 | 541 | def show_res(self, path): 542 | res = cPickle.load(open(path)) 543 | plt.figure(1) 544 | plt.title('The results') 545 | l1, = plt.plot(res[0], 'g') 546 | l2, = plt.plot(res[1], 'r') 547 | l3, = plt.plot(res[3], 'b') 548 | plt.legend(handles = [l1, l2, l3], labels = ["Train_loss","Test_loss","BLEU"], loc = 'best') 549 | plt.show() 550 | 551 | def test_all(self, path, epoch_range, is_bleu=True): 552 | val_loss_list = [] 553 | bleu_list = [] 554 | for i in range(epoch_range[0], epoch_range[-1]): 555 | self.model.restore(path + str(i)) 556 | val_loss = self.test() 557 | val_loss_list.append(val_loss) 558 | if is_bleu: 559 | bleu_score = self.test_bleu() 560 | bleu_list.append(bleu_score) 561 | plt.figure(1) 562 | plt.title('The results') 563 | l1, = plt.plot(val_loss_list,'r') 564 | l2, = plt.plot(bleu_list,'b') 565 | plt.legend(handles = [l1, l2], labels = ["Test_loss","BLEU"], loc = 'best') 566 | plt.show() 567 | 568 | 569 | -------------------------------------------------------------------------------- /bleu.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import codecs 3 | import os 4 | import math 5 | import operator 6 | import json 7 | from functools import reduce 8 | 9 | 10 | def fetch_data(cand, ref): 11 | """ Store each reference and candidate sentences as a list """ 12 | references = [] 13 | if '.txt' in ref: 14 | reference_file = codecs.open(ref, 'r', 'utf-8') 15 | references.append(reference_file.readlines()) 16 | else: 17 | for root, dirs, files in os.walk(ref): 18 | for f in files: 19 | reference_file = codecs.open(os.path.join(root, f), 'r', 'utf-8') 20 | references.append(reference_file.readlines()) 21 | candidate_file = codecs.open(cand, 'r', 'utf-8') 22 | candidate = candidate_file.readlines() 23 | return candidate, references 24 | 25 | def count_ngram(candidate, references, n): 26 | clipped_count = 0 27 | count = 0 28 | r = 0 29 | c = 0 30 | for si in range(len(candidate)): 31 | # Calculate precision for each sentence 32 | ref_counts = [] 33 | ref_lengths = [] 34 | # Build dictionary of ngram counts 35 | for reference in references: 36 | ref_sentence = reference[si] 37 | ngram_d = {} 38 | words = ref_sentence.strip().split() 39 | ref_lengths.append(len(words)) 40 | limits = len(words) - 

def count_ngram(candidate, references, n):
    clipped_count = 0
    count = 0
    r = 0
    c = 0
    for si in range(len(candidate)):
        # Calculate precision for each sentence
        ref_counts = []
        ref_lengths = []
        # Build dictionary of ngram counts
        for reference in references:
            ref_sentence = reference[si]
            ngram_d = {}
            words = ref_sentence.strip().split()
            ref_lengths.append(len(words))
            limits = len(words) - n + 1
            # loop through the sentence, considering the ngram length
            for i in range(limits):
                ngram = ' '.join(words[i:i+n]).lower()
                if ngram in ngram_d.keys():
                    ngram_d[ngram] += 1
                else:
                    ngram_d[ngram] = 1
            ref_counts.append(ngram_d)
        # candidate
        cand_sentence = candidate[si]
        cand_dict = {}
        words = cand_sentence.strip().split()
        limits = len(words) - n + 1
        for i in range(0, limits):
            ngram = ' '.join(words[i:i + n]).lower()
            if ngram in cand_dict:
                cand_dict[ngram] += 1
            else:
                cand_dict[ngram] = 1
        clipped_count += clip_count(cand_dict, ref_counts)
        count += limits
        r += best_length_match(ref_lengths, len(words))
        c += len(words)
    if clipped_count == 0:
        pr = 0
    else:
        pr = float(clipped_count) / count
    bp = brevity_penalty(c, r)
    return pr, bp


def _count_ngram(candidate, ref_counts, ref_lengths, n):
    """Single-sentence variant of count_ngram that reuses precomputed
    reference ngram counts and lengths (see get_reference_count)."""
    clipped_count = 0
    count = 0
    r = 0
    c = 0
    si = 0
    cand_sentence = candidate[si]
    cand_dict = {}
    words = cand_sentence.strip().split()
    limits = len(words) - n + 1
    for i in range(0, limits):
        ngram = ' '.join(words[i:i + n]).lower()
        if ngram in cand_dict:
            cand_dict[ngram] += 1
        else:
            cand_dict[ngram] = 1
    clipped_count += clip_count(cand_dict, ref_counts)
    count += limits
    r += best_length_match(ref_lengths, len(words))
    c += len(words)
    if clipped_count == 0:
        pr = 0
    else:
        pr = float(clipped_count) / count
    bp = brevity_penalty(c, r)
    return pr, bp


def clip_count(cand_d, ref_ds):
    """Count the clip count for each ngram considering all references"""
    count = 0
    for m in cand_d.keys():
        m_w = cand_d[m]
        m_max = 0
        for ref in ref_ds:
            if m in ref:
                m_max = max(m_max, ref[m])
        m_w = min(m_w, m_max)
        count += m_w
    return count


def best_length_match(ref_l, cand_l):
    """Find the closest length of reference to that of candidate"""
    least_diff = abs(cand_l - ref_l[0])
    best = ref_l[0]
    for ref in ref_l:
        if abs(cand_l - ref) < least_diff:
            least_diff = abs(cand_l - ref)
            best = ref
    return best


def brevity_penalty(c, r):
    if c == 0:
        return 0.0
    if c > r:
        bp = 1
    else:
        bp = math.exp(1 - (float(r) / c))
    return bp


def geometric_mean(precisions):
    return (reduce(operator.mul, precisions)) ** (1.0 / len(precisions))


def get_reference_count(references, n):
    ref_counts = []
    ref_lengths = []
    # Build dictionary of ngram counts
    for reference in references:
        ref_sentence = reference[0]
        ngram_d = {}
        words = ref_sentence.strip().split()
        ref_lengths.append(len(words))
        limits = len(words) - n + 1
        # loop through the sentence, considering the ngram length
        for i in range(limits):
            ngram = ' '.join(words[i:i+n]).lower()
            if ngram in ngram_d.keys():
                ngram_d[ngram] += 1
            else:
                ngram_d[ngram] = 1
        ref_counts.append(ngram_d)

    return ref_counts, ref_lengths

def BLEU(candidate, references, gram=4):
    precisions = []
    for i in range(gram):
        pr, bp = count_ngram(candidate, references, i + 1)
        precisions.append(pr)
    bleu = geometric_mean(precisions) * bp
    return bleu


def _BLEU(candidate, ref_counts_n, ref_lengths_n, gram):
    """Return cumulative BLEU-1 .. BLEU-gram scores against precomputed
    reference statistics (one (counts, lengths) pair per ngram order)."""
    assert len(ref_counts_n) == gram
    assert len(ref_lengths_n) == gram
    precisions = []
    bleu_gram = []
    for i in range(gram):
        pr, bp = _count_ngram(candidate, ref_counts_n[i], ref_lengths_n[i], i + 1)
        precisions.append(pr)
        bleu = geometric_mean(precisions) * bp
        bleu_gram.append(bleu)

    return bleu_gram

--------------------------------------------------------------------------------
/overview_2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dayihengliu/Text-Infilling-Gradient-Search/af89b634fa9d74222d29ed5ef0b91da46533dd62/overview_2.pdf
--------------------------------------------------------------------------------
/results/_URNN-f_res.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dayihengliu/Text-Infilling-Gradient-Search/af89b634fa9d74222d29ed5ef0b91da46533dd62/results/_URNN-f_res.pkl
--------------------------------------------------------------------------------