├── Unit.py ├── OutputUnit.py ├── README.md ├── LstmUnit.py ├── AttentionUnit.py ├── DataLoader.py ├── MleTrain.py └── SeqUnit.py /Unit.py: -------------------------------------------------------------------------------- 1 | from LstmUnit import LstmUnit 2 | from OutputUnit import OutputUnit 3 | from SeqUnit import SeqUnit 4 | # from RlUnit import RlUnit  # RlUnit.py is not included in this release; left commented out to avoid an ImportError 5 | from AttentionUnit import AttentionWrapper 6 | 7 | -------------------------------------------------------------------------------- /OutputUnit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import pickle 3 | 4 | class OutputUnit(object): 5 | def __init__(self, input_size, output_size, scope_name): 6 | self.input_size = input_size 7 | self.output_size = output_size 8 | self.scope_name = scope_name 9 | self.params = {} 10 | 11 | with tf.variable_scope(scope_name): 12 | self.W = tf.get_variable('W', [input_size, output_size]) 13 | self.b = tf.get_variable('b', [output_size], initializer=tf.zeros_initializer(), dtype=tf.float32) 14 | 15 | self.params.update({'W': self.W, 'b': self.b}) 16 | 17 | def __call__(self, x, finished = None): 18 | out = tf.nn.xw_plus_b(x, self.W, self.b) 19 | 20 | if finished is not None: 21 | out = tf.where(finished, tf.zeros_like(out), out) 22 | #out = tf.multiply(1 - finished, out) 23 | return out 24 | 25 | def save(self, path): 26 | param_values = {} 27 | for param in self.params: 28 | param_values[param] = self.params[param].eval() 29 | with open(path, 'wb') as f: 30 | pickle.dump(param_values, f, True) 31 | 32 | def load(self, path): 33 | param_values = pickle.load(open(path, 'rb')) 34 | for param in param_values: 35 | self.params[param].load(param_values[param]) 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Relevance Based Text Summarization Model 2 | Code for "Improving Semantic Relevance for Sequence-to-Sequence Learning of Chinese Social Media Text Summarization". 3 | The code is also used for "A Semantic Relevance Based Neural Network for Text Summarization and Text Simplification". 4 | ## Requirements 5 | * TensorFlow r1.0.1 6 | * Python 3.5 7 | * CUDA 8.0 (for GPU) 8 | * [ROUGE](http://research.microsoft.com/~cyl/download/ROUGE-1.5.5.tgz) 9 | ## Data 10 | The dataset in the paper is Large Scale Chinese Short Text Summarization [(LCSTS)](http://icrc.hitsz.edu.cn/Article/show/139.html). 11 | To preprocess the data, please split the sentences into characters and transform the characters into numbers (ids); a minimal sketch is given below.
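A possible preprocessing sketch (not part of this repository): the raw file names `train.text` / `train.summary` and the use of id 1 for unknown characters are assumptions. `DataLoader.py` expects one example per line with space-separated integer ids in `data/*.summary.id` and `data/*.text.id`, and the code reserves id 0 for padding and id 2 for the start/stop token.
```python
from collections import Counter

def build_vocab(paths, vocab_size=4003):
    # Count characters over the raw training files and keep the most frequent ones.
    counter = Counter()
    for path in paths:
        with open(path, encoding='utf-8') as f:
            for line in f:
                counter.update(line.strip())
    chars = [c for c, _ in counter.most_common(vocab_size - 3)]
    # 0: padding, 2: start/stop token (as used by DataLoader/SeqUnit); 1: unknown (assumed)
    return {c: i + 3 for i, c in enumerate(chars)}

def to_ids(in_path, out_path, vocab):
    # One example per line, space-separated integer ids, as DataLoader expects.
    with open(in_path, encoding='utf-8') as fin, open(out_path, 'w') as fout:
        for line in fin:
            fout.write(' '.join(str(vocab.get(c, 1)) for c in line.strip()) + '\n')

vocab = build_vocab(['lcsts/data/train.summary', 'lcsts/data/train.text'])
for split in ('train', 'dev', 'test'):
    for side in ('summary', 'text'):
        to_ids('lcsts/data/%s.%s' % (split, side),
               'lcsts/data/%s.%s.id' % (split, side), vocab)
```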
12 | ## Run 13 | ```bash 14 | python3 MleTrain.py 15 | ``` 16 | ## Cite 17 | If you use this code for your research, please cite the paper this code is 18 | based on: Improving Semantic Relevance for Sequence-to-Sequence Learning of 19 | Chinese Social Media Text Summarization: 20 | 21 | ``` 22 | @inproceedings{MaEA2017, 23 | author = {Shuming Ma and Xu Sun and Jingjing Xu and Houfeng Wang and Wenjie Li and Qi Su}, 24 | title = {Improving Semantic Relevance for Sequence-to-Sequence Learning of Chinese Social Media Text Summarization}, 25 | booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational 26 | Linguistics, {ACL} 2017, Vancouver, Canada, July 30 - August 4, Volume 27 | 2: Short Papers}, 28 | pages = {635--640}, 29 | year = {2017} 30 | } 31 | ``` 32 | -------------------------------------------------------------------------------- /LstmUnit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import pickle 3 | 4 | class LstmUnit(object): 5 | def __init__(self, hidden_size, input_size, scope_name): 6 | self.hidden_size = hidden_size 7 | self.input_size = input_size 8 | self.scope_name = scope_name 9 | self.params = {} 10 | 11 | with tf.variable_scope(scope_name): 12 | self.W = tf.get_variable('W', [self.input_size+self.hidden_size, 4*self.hidden_size]) 13 | self.b = tf.get_variable('b', [4*self.hidden_size], initializer=tf.zeros_initializer(), dtype=tf.float32) 14 | 15 | self.params.update({'W':self.W, 'b':self.b}) 16 | 17 | def __call__(self, x, s, finished = None): 18 | h_prev, c_prev = s 19 | 20 | x = tf.concat([x, h_prev], 1) 21 | i,j,f,o = tf.split(tf.nn.xw_plus_b(x, self.W, self.b), 4, 1) 22 | 23 | # Final Memory cell 24 | c = tf.sigmoid(f+1.0) * c_prev + tf.sigmoid(i) * tf.tanh(j) 25 | h = tf.sigmoid(o) * tf.tanh(c) 26 | 27 | out, state = h, (h, c) 28 | if finished is not None: 29 | out = tf.where(finished, tf.zeros_like(h), h) 30 | state = (tf.where(finished, h_prev, h), tf.where(finished, c_prev, c)) 31 | #out = tf.multiply(1 - finished, h) 32 | #state = (tf.multiply(1 - finished, h) + tf.multiply(finished, h_prev), 33 | # tf.multiply(1 - finished, c) + tf.multiply(finished, c_prev)) 34 | 35 | return out, state 36 | 37 | def save(self, path): 38 | param_values = {} 39 | for param in self.params: 40 | param_values[param] = self.params[param].eval() 41 | with open(path, 'wb') as f: 42 | pickle.dump(param_values, f, True) 43 | 44 | def load(self, path): 45 | param_values = pickle.load(open(path, 'rb')) 46 | for param in param_values: 47 | self.params[param].load(param_values[param]) 48 | 49 | -------------------------------------------------------------------------------- /AttentionUnit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import pickle 3 | 4 | 5 | class AttentionWrapper(object): 6 | def __init__(self, hidden_size, input_size, hs, scope_name): 7 | self.hs = tf.transpose(hs, [1,0,2]) 8 | self.hidden_size = hidden_size 9 | self.input_size = input_size 10 | self.scope_name = scope_name 11 | self.params = {} 12 | 13 | with tf.variable_scope(scope_name): 14 | self.Wh = tf.get_variable('Wh', [input_size, hidden_size]) 15 | self.bh = tf.get_variable('bh', [hidden_size]) 16 | self.Ws = tf.get_variable('Ws', [input_size, hidden_size]) 17 | self.bs = tf.get_variable('bs', [hidden_size]) 18 | self.Wo = tf.get_variable('Wo', [2*input_size, hidden_size]) 19 | self.bo = tf.get_variable('bo', [hidden_size]) 20 | 
self.params.update({'Wh': self.Wh, 'Ws': self.Ws, 'Wo': self.Wo, 21 | 'bh': self.bh, 'bs': self.bs, 'bo': self.bo}) 22 | 23 | hs2d = tf.reshape(self.hs, [-1, input_size]) 24 | phi_hs2d = tf.tanh(tf.nn.xw_plus_b(hs2d, self.Wh, self.bh)) 25 | self.phi_hs = tf.reshape(phi_hs2d, tf.shape(self.hs)) 26 | 27 | def __call__(self, x, finished = None): 28 | gamma_h = tf.tanh(tf.nn.xw_plus_b(x, self.Ws, self.bs)) 29 | ''' 30 | weights = tf.reduce_sum(self.phi_hs * gamma_h, reduction_indices=2, keep_dims=True) 31 | context = tf.mod(tf.reduce_max(weights * 1000000 + self.hs, reduction_indices=0), 1000000) 32 | context = tf.stop_gradient(context) 33 | ''' 34 | weights = tf.reduce_sum(self.phi_hs * gamma_h, reduction_indices=2, keep_dims=True) 35 | weights = tf.exp(weights - tf.reduce_max(weights, reduction_indices=0, keep_dims=True)) 36 | weights = tf.divide(weights, (1e-6 + tf.reduce_sum(weights, reduction_indices=0, keep_dims=True))) 37 | context = tf.reduce_sum(self.hs * weights, reduction_indices=0) 38 | out = tf.tanh(tf.nn.xw_plus_b(tf.concat([context, x], -1), self.Wo, self.bo)) 39 | 40 | if finished is not None: 41 | out = tf.where(finished, tf.zeros_like(out), out) 42 | return out 43 | 44 | def save(self, path): 45 | param_values = {} 46 | for param in self.params: 47 | param_values[param] = self.params[param].eval() 48 | with open(path, 'wb') as f: 49 | pickle.dump(param_values, f, True) 50 | 51 | def load(self, path): 52 | param_values = pickle.load(open(path, 'rb')) 53 | for param in param_values: 54 | self.params[param].load(param_values[param]) 55 | -------------------------------------------------------------------------------- /DataLoader.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import time 3 | import numpy as np 4 | 5 | class DataLoader(object): 6 | def __init__(self, data_dir, limits): 7 | self.train_data_path = [data_dir + '/data/train.summary.id', data_dir + '/data/train.text.id'] 8 | self.test_data_path = [data_dir + '/data/test.summary.id', data_dir + '/data/test.text.id'] 9 | self.dev_data_path = [data_dir + '/data/dev.summary.id', data_dir + '/data/dev.text.id'] 10 | self.limits = limits 11 | start_time = time.time() 12 | 13 | self.train_set = self.load_data(self.train_data_path) 14 | self.test_set = self.load_data(self.test_data_path) 15 | self.dev_set = self.load_data(self.dev_data_path) 16 | 17 | print ('Reading datasets consumes %.3f seconds' % (time.time() - start_time)) 18 | 19 | def load_data(self, path): 20 | summary_path, text_path = path 21 | summaries = open(summary_path, 'r').read().strip().split('\n') 22 | texts = open(text_path, 'r').read().strip().split('\n') 23 | if self.limits > 0: 24 | summaries = summaries[:self.limits] 25 | texts = texts[:self.limits] 26 | summaries = [list(map(int,summary.split(' '))) for summary in summaries] 27 | texts = [list(map(int,text.split(' '))) for text in texts] 28 | 29 | return summaries, texts 30 | 31 | def batch_iter(self, data, batch_size, shuffle): 32 | summaries, texts = data 33 | data_size = len(summaries) 34 | num_batches = int(data_size / batch_size) if data_size % batch_size == 0 \ 35 | else int(data_size / batch_size) + 1 36 | 37 | if shuffle: 38 | shuffle_indices = np.random.permutation(np.arange(data_size)) 39 | summaries = np.array(summaries)[shuffle_indices] 40 | texts = np.array(texts)[shuffle_indices] 41 | 42 | for batch_num in range(num_batches): 43 | start_index = batch_num * batch_size 44 | end_index = min((batch_num + 1) * batch_size,
data_size) 45 | max_summary_len = max([len(sample) for sample in summaries[start_index:end_index]]) 46 | max_text_len = max([len(sample) for sample in texts[start_index:end_index]]) 47 | batch_data = {'enc_in':[], 'enc_len':[], 'dec_in':[], 'dec_len':[], 'dec_out':[]} 48 | for summary, text in zip(summaries[start_index:end_index], texts[start_index:end_index]): 49 | summary_len = len(summary) 50 | text_len = len(text) 51 | gold = summary + [2] + [0] * (max_summary_len - summary_len) 52 | summary = summary + [0] * (max_summary_len - summary_len) 53 | text = text + [0] * (max_text_len - text_len) 54 | batch_data['enc_in'].append(text) 55 | batch_data['enc_len'].append(text_len) 56 | batch_data['dec_in'].append(summary) 57 | batch_data['dec_len'].append(summary_len) 58 | batch_data['dec_out'].append(gold) 59 | 60 | yield batch_data -------------------------------------------------------------------------------- /MleTrain.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import tensorflow as tf 5 | import time 6 | from Unit import * 7 | from DataLoader import DataLoader 8 | import numpy as np 9 | from ROUGE.PythonROUGE import PythonROUGE  # ROUGE wrapper; requires the ROUGE package listed in the README 10 | 11 | tf.app.flags.DEFINE_string("cell", "lstm", "Rnn cell.") 12 | tf.app.flags.DEFINE_integer("hidden_size", 500, "Size of each layer.") 13 | tf.app.flags.DEFINE_integer("emb_size", 400, "Size of embedding.") 14 | tf.app.flags.DEFINE_integer("batch_size", 32, "Batch size of train set.") 15 | tf.app.flags.DEFINE_integer("epoch", 50, "Number of training epoch.") 16 | tf.app.flags.DEFINE_float("dropout", 1.0, "Dropout keep probability.") 17 | tf.app.flags.DEFINE_string("gpu", '0', "GPU id.") 18 | tf.app.flags.DEFINE_string("opt",'Adam','Optimizer.') 19 | tf.app.flags.DEFINE_string("mode",'train','train or test') 20 | tf.app.flags.DEFINE_string("save",'0','save directory') 21 | tf.app.flags.DEFINE_string("load",'0','load directory') 22 | tf.app.flags.DEFINE_string("dir",'lcsts','data set directory') 23 | tf.app.flags.DEFINE_integer("limits",0,'max data set size') 24 | tf.app.flags.DEFINE_boolean("ckpt", False,'load checkpoint or not') 25 | tf.app.flags.DEFINE_boolean("attention", True,'attention mechanism or not') 26 | tf.app.flags.DEFINE_boolean("dev", False,'dev or test') 27 | tf.app.flags.DEFINE_boolean("SRB", True,'use SRB or not') 28 | tf.app.flags.DEFINE_integer("source_vocab", 4003,'vocabulary size') 29 | tf.app.flags.DEFINE_integer("target_vocab", 4003,'vocabulary size') 30 | tf.app.flags.DEFINE_integer("report", 500,'report interval (batches)') 31 | tf.app.flags.DEFINE_string("m","train",'running message') 32 | FLAGS = tf.app.flags.FLAGS 33 | 34 | gold_path = FLAGS.dir + '/evaluation/test_gold_summarys_' 35 | pred_path = FLAGS.dir + '/evaluation/test_pred_summarys_' 36 | 37 | if FLAGS.save == "0": 38 | save_dir = FLAGS.dir + '/' + str(int(time.time() * 1000)) + '/' 39 | os.mkdir(save_dir) 40 | else: 41 | save_dir = FLAGS.save 42 | log_file = save_dir + 'log.txt' 43 | 44 | def train(sess, dataloader, model): 45 | write_log("#######################################################\n") 46 | for flag in FLAGS.__flags: 47 | write_log(flag + " = " + str(FLAGS.__flags[flag]) + "\n") 48 | write_log("#######################################################\n") 49 | trainset = dataloader.train_set 50 | k = 0 51 | for _ in range(FLAGS.epoch): 52 | loss, start_time = 0.0, time.time() 53 | for x in dataloader.batch_iter(trainset, FLAGS.batch_size, True): 54 | loss +=
model(x, sess) 55 | k += 1 56 | sys.stdout.write('training %.2f ...\r' % (k % FLAGS.report * 100.0 / FLAGS.report)) 57 | sys.stdout.flush() 58 | if (k % FLAGS.report == 0): 59 | cost_time = time.time() - start_time 60 | #print("%d : loss = %.3f, time = %.3f" % (k // FLAGS.report, loss, cost_time), end=' ') 61 | write_log("%d : loss = %.3f, time = %.3f " % (k // FLAGS.report, loss, cost_time)) 62 | loss, start_time = 0.0, time.time() 63 | write_log(evaluate(sess, dataloader, model)) 64 | model.save(save_dir) 65 | 66 | 67 | def test(sess, dataloader, model): 68 | model.load(save_dir) 69 | print(evaluate(sess, dataloader, model, FLAGS.dev), end='') 70 | 71 | 72 | def evaluate(sess, dataloader, model, dev=False): 73 | if dev: 74 | evalset = dataloader.dev_set 75 | else: 76 | evalset = dataloader.test_set 77 | k = 0 78 | for x in dataloader.batch_iter(evalset, FLAGS.batch_size, False): 79 | predictions = model.generate(x, sess) 80 | for summary in np.array(predictions): 81 | with open(pred_path + str(k), 'w') as sw: 82 | summary = list(summary) 83 | if 2 in summary: 84 | summary = summary[:summary.index(2)] if summary[0] != 2 else [2] 85 | sw.write(" ".join([str(x) for x in summary]) + '\n') 86 | k += 1 87 | # print(k) 88 | pred_set = [pred_path + str(i) for i in range(k)] 89 | gold_set = [[gold_path + str(i)] for i in range(k)] 90 | recall, precision, F_measure = PythonROUGE(pred_set, gold_set, ngram_order=2) 91 | result = "F_measure: %s Recall: %s Precision: %s\n" % (str(F_measure), str(recall), str(precision)) 92 | #print(result) 93 | return result 94 | 95 | def write_log(s): 96 | print(s, end='') 97 | with open(log_file, 'a') as f: 98 | f.write(s) 99 | 100 | def main(): 101 | config = tf.ConfigProto(allow_soft_placement=True) 102 | config.gpu_options.allow_growth = True 103 | with tf.Session(config=config) as sess: 104 | dataloader = DataLoader(FLAGS.dir, FLAGS.limits) 105 | model = SeqUnit(batch_size=FLAGS.batch_size, hidden_size=FLAGS.hidden_size, emb_size=FLAGS.emb_size, 106 | source_vocab=FLAGS.source_vocab, target_vocab=FLAGS.target_vocab, scope_name="seq2seq", 107 | name="seq2seq", attention=FLAGS.attention, SRB=FLAGS.SRB) 108 | sess.run(tf.global_variables_initializer()) 109 | if FLAGS.load != '0': 110 | model.load(FLAGS.load) 111 | if FLAGS.mode == 'train': 112 | train(sess, dataloader, model) 113 | else: 114 | test(sess, dataloader, model) 115 | 116 | if __name__=='__main__': 117 | with tf.device('/gpu:' + FLAGS.gpu): 118 | main() -------------------------------------------------------------------------------- /SeqUnit.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from Unit import * 3 | import pickle 4 | from AttentionUnit import AttentionWrapper 5 | 6 | 7 | class SeqUnit(object): 8 | def __init__(self, batch_size, hidden_size, emb_size, source_vocab, target_vocab, scope_name, name, attention, 9 | start_token=2, stop_token=2, max_length=100, SRB=True, weight=0.0001): 10 | self.batch_size = batch_size 11 | self.hidden_size = hidden_size 12 | self.emb_size = emb_size 13 | self.source_vocab = source_vocab 14 | self.target_vocab = target_vocab 15 | self.grad_clip = 5.0 16 | self.start_token = start_token 17 | self.stop_token = stop_token 18 | self.max_length = max_length 19 | self.scope_name = scope_name 20 | self.name = name 21 | self.attention = attention 22 | self.SRB = SRB 23 | self.weight = weight 24 | 25 | self.units = {} 26 | self.params = {} 27 | 28 | self.encoder_input = tf.placeholder(tf.int32, [None, 
None]) 29 | self.decoder_input = tf.placeholder(tf.int32, [None, None]) 30 | self.encoder_len = tf.placeholder(tf.int32, [None]) 31 | self.decoder_len = tf.placeholder(tf.int32, [None]) 32 | self.decoder_output = tf.placeholder(tf.int32, [None, None]) 33 | 34 | with tf.variable_scope(scope_name): 35 | self.enc_lstm = LstmUnit(self.hidden_size, self.emb_size, 'encoder_lstm') 36 | self.dec_lstm = LstmUnit(self.hidden_size, self.emb_size, 'decoder_lstm') 37 | self.dec_out = OutputUnit(self.hidden_size, self.target_vocab, 'decoder_output') 38 | self.gated_linear = OutputUnit(self.hidden_size+self.emb_size, self.hidden_size, 'gated_linear') 39 | self.gated_output = OutputUnit(self.hidden_size, 1, 'gated_output') 40 | 41 | self.units.update({'encoder_lstm': self.enc_lstm, 'decoder_lstm': self.dec_lstm, 42 | 'decoder_output': self.dec_out, 'gated_linear': self.gated_linear, 43 | 'gated_output': self.gated_output}) 44 | 45 | with tf.device('/cpu:0'): 46 | with tf.variable_scope(scope_name): 47 | self.embedding = tf.get_variable('embedding', [self.source_vocab, self.emb_size]) 48 | self.encoder_embed = tf.nn.embedding_lookup(self.embedding, self.encoder_input) 49 | self.decoder_embed = tf.nn.embedding_lookup(self.embedding, self.decoder_input) 50 | self.params.update({'embedding': self.embedding}) 51 | 52 | if SRB: 53 | en_outputs, en_state = self.gated_attention_encoder(self.encoder_embed, self.encoder_len) 54 | else: 55 | en_outputs, en_state = self.encoder(self.encoder_embed, self.encoder_len) 56 | 57 | if self.attention: 58 | with tf.variable_scope(scope_name): 59 | self.att_layer = AttentionWrapper(self.hidden_size, self.hidden_size, en_outputs, "attention") 60 | self.units.update({'attention': self.att_layer}) 61 | de_outputs, de_state = self.decoder_t(en_state, self.decoder_embed, self.decoder_len) 62 | self.g_tokens = self.decoder_g(en_state) 63 | 64 | _, de_h = de_state 65 | _, en_h = en_state 66 | de_h = de_h - en_h 67 | similarity = tf.reduce_mean(tf.nn.l2_normalize(en_h, dim=-1) * tf.nn.l2_normalize(de_h, dim=-1)) 68 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=de_outputs, labels=self.decoder_output) 69 | mask = tf.sign(tf.to_float(self.decoder_output)) 70 | losses = mask * losses 71 | if SRB: 72 | self.mean_loss = tf.reduce_mean(losses) - self.weight * similarity 73 | else: 74 | self.mean_loss = tf.reduce_mean(losses) 75 | 76 | tvars = tf.trainable_variables() 77 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.mean_loss, tvars), self.grad_clip) 78 | optimizer = tf.train.AdamOptimizer() 79 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 80 | 81 | 82 | 83 | def encoder(self, inputs, inputs_len): 84 | batch_size = tf.shape(self.encoder_input)[0] 85 | max_time = tf.shape(self.encoder_input)[1] 86 | hidden_size = self.hidden_size 87 | 88 | time = tf.constant(0, dtype=tf.int32) 89 | h0 = (tf.zeros([batch_size, hidden_size], dtype=tf.float32), 90 | tf.zeros([batch_size, hidden_size], dtype=tf.float32)) 91 | f0 = tf.zeros([batch_size], dtype=tf.bool) 92 | inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time) 93 | inputs_ta = inputs_ta.unstack(tf.transpose(inputs, [1,0,2])) 94 | emit_ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0) 95 | 96 | def loop_fn(t, x_t, s_t, emit_ta, finished): 97 | o_t, s_nt = self.enc_lstm(x_t, s_t, finished) 98 | emit_ta = emit_ta.write(t, o_t) 99 | finished = tf.greater_equal(t+1, inputs_len) 100 | x_nt = tf.cond(tf.reduce_all(finished), lambda: tf.zeros([batch_size, self.emb_size], dtype=tf.float32), 
101 | lambda: inputs_ta.read(t+1)) 102 | return t+1, x_nt, s_nt, emit_ta, finished 103 | 104 | _, _, state, emit_ta, _ = tf.while_loop( 105 | cond=lambda _1, _2, _3, _4, finished: tf.logical_not(tf.reduce_all(finished)), 106 | body=loop_fn, 107 | loop_vars=(time, inputs_ta.read(0), h0, emit_ta, f0)) 108 | 109 | outputs = tf.transpose(emit_ta.stack(), [1,0,2]) 110 | return outputs, state 111 | 112 | 113 | def gated_attention_encoder(self, inputs, inputs_len): 114 | batch_size = tf.shape(self.encoder_input)[0] 115 | max_time = tf.shape(self.encoder_input)[1] 116 | hidden_size = self.hidden_size 117 | 118 | time = tf.constant(0, dtype=tf.int32) 119 | h0 = (tf.zeros([batch_size, hidden_size], dtype=tf.float32), 120 | tf.zeros([batch_size, hidden_size], dtype=tf.float32)) 121 | f0 = tf.zeros([batch_size], dtype=tf.bool) 122 | inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time) 123 | inputs_ta = inputs_ta.unstack(tf.transpose(inputs, [1,0,2])) 124 | emit_ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0) 125 | 126 | def loop_fn(t, x_t, s_t, emit_ta, finished): 127 | o_t, s_nt = self.enc_lstm(x_t, s_t, finished) 128 | finished = tf.greater_equal(t+1, inputs_len) 129 | x_nt = tf.cond(tf.reduce_all(finished), lambda: tf.zeros([batch_size, self.emb_size], dtype=tf.float32), 130 | lambda: inputs_ta.read(t+1)) 131 | 132 | h = tf.nn.relu(self.gated_linear(tf.concat([o_t, x_nt], axis=-1))) 133 | p = tf.sigmoid(self.gated_output(o_t)) 134 | x_nt = x_nt * p 135 | o_t = o_t * p 136 | emit_ta = emit_ta.write(t, o_t) 137 | 138 | return t+1, x_nt, s_nt, emit_ta, finished 139 | 140 | _, _, state, emit_ta, _ = tf.while_loop( 141 | cond=lambda _1, _2, _3, _4, finished: tf.logical_not(tf.reduce_all(finished)), 142 | body=loop_fn, 143 | loop_vars=(time, inputs_ta.read(0), h0, emit_ta, f0)) 144 | 145 | outputs = tf.transpose(emit_ta.stack(), [1,0,2]) 146 | return outputs, state 147 | 148 | 149 | def decoder_t(self, initial_state, inputs, inputs_len): 150 | batch_size = tf.shape(self.decoder_input)[0] 151 | max_time = tf.shape(self.decoder_input)[1] 152 | 153 | time = tf.constant(0, dtype=tf.int32) 154 | h0 = initial_state 155 | f0 = tf.zeros([batch_size], dtype=tf.bool) 156 | x0 = tf.nn.embedding_lookup(self.embedding, tf.fill([batch_size], self.start_token)) 157 | inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time) 158 | inputs_ta = inputs_ta.unstack(tf.transpose(inputs, [1,0,2])) 159 | emit_ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0) 160 | 161 | def loop_fn(t, x_t, s_t, emit_ta, finished): 162 | o_t, s_nt = self.dec_lstm(x_t, s_t, finished) 163 | if self.attention: 164 | o_t = self.att_layer(o_t) 165 | o_t = self.dec_out(o_t, finished) 166 | emit_ta = emit_ta.write(t, o_t) 167 | finished = tf.greater_equal(t, inputs_len) 168 | x_nt = tf.cond(tf.reduce_all(finished), lambda: tf.zeros([batch_size, self.emb_size], dtype=tf.float32), 169 | lambda: inputs_ta.read(t)) 170 | return t+1, x_nt, s_nt, emit_ta, finished 171 | 172 | _, _, state, emit_ta, _ = tf.while_loop( 173 | cond=lambda _1, _2, _3, _4, finished: tf.logical_not(tf.reduce_all(finished)), 174 | body=loop_fn, 175 | loop_vars=(time, x0, h0, emit_ta, f0)) 176 | 177 | outputs = tf.transpose(emit_ta.stack(), [1,0,2]) 178 | return outputs, state 179 | 180 | 181 | def decoder_g(self, initial_state): 182 | batch_size = tf.shape(self.encoder_input)[0] 183 | 184 | time = tf.constant(0, dtype=tf.int32) 185 | h0 = initial_state 186 | f0 = tf.zeros([batch_size], dtype=tf.bool) 187 | x0 = 
tf.nn.embedding_lookup(self.embedding, tf.fill([batch_size], self.start_token)) 188 | emit_ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0) 189 | 190 | def loop_fn(t, x_t, s_t, emit_ta, finished): 191 | o_t, s_nt = self.dec_lstm(x_t, s_t, finished) 192 | if self.attention: 193 | o_t = self.att_layer(o_t) 194 | o_t = self.dec_out(o_t, finished) 195 | emit_ta = emit_ta.write(t, o_t) 196 | 197 | next_token = tf.arg_max(o_t, 1) 198 | x_nt = tf.nn.embedding_lookup(self.embedding, next_token) 199 | finished = tf.logical_or(finished, tf.equal(next_token, self.stop_token)) 200 | finished = tf.logical_or(finished, tf.greater_equal(t, self.max_length)) 201 | return t+1, x_nt, s_nt, emit_ta, finished 202 | 203 | _, _, state, emit_ta, _ = tf.while_loop( 204 | cond=lambda _1, _2, _3, _4, finished: tf.logical_not(tf.reduce_all(finished)), 205 | body=loop_fn, 206 | loop_vars=(time, x0, h0, emit_ta, f0)) 207 | 208 | outputs = tf.transpose(emit_ta.stack(), [1,0,2]) 209 | pred_tokens = tf.arg_max(outputs, 2) 210 | return pred_tokens 211 | 212 | 213 | def __call__(self, x, sess): 214 | loss, _ = sess.run([self.mean_loss, self.train_op], 215 | {self.encoder_input: x['enc_in'], self.encoder_len: x['enc_len'], 216 | self.decoder_input: x['dec_in'], self.decoder_len: x['dec_len'], 217 | self.decoder_output: x['dec_out']}) 218 | return loss 219 | 220 | def generate(self, x, sess): 221 | predictions = sess.run(self.g_tokens, 222 | {self.encoder_input: x['enc_in'], self.encoder_len: x['enc_len']}) 223 | return predictions 224 | 225 | 226 | def save(self, path): 227 | for u in self.units: 228 | self.units[u].save(path+u+".pkl") 229 | param_values = {} 230 | for param in self.params: 231 | param_values[param] = self.params[param].eval() 232 | with open(path+self.name+".pkl", 'wb') as f: 233 | pickle.dump(param_values, f, True) 234 | 235 | def load(self, path): 236 | for u in self.units: 237 | self.units[u].load(path+u+".pkl") 238 | param_values = pickle.load(open(path+self.name+".pkl", 'rb')) 239 | for param in param_values: 240 | self.params[param].load(param_values[param]) --------------------------------------------------------------------------------
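For orientation, a minimal end-to-end sketch of how the classes above fit together, mirroring what MleTrain.py does (the data directory 'lcsts' and the checkpoint prefix 'lcsts/checkpoint/' are placeholders; the hyperparameters follow the defaults in MleTrain.py):
```python
# Minimal usage sketch mirroring MleTrain.py (TensorFlow 1.x graph mode).
import tensorflow as tf
from DataLoader import DataLoader
from Unit import *  # provides LstmUnit, OutputUnit, SeqUnit, AttentionWrapper

dataloader = DataLoader('lcsts', limits=0)  # expects lcsts/data/{train,dev,test}.{summary,text}.id
model = SeqUnit(batch_size=32, hidden_size=500, emb_size=400,
                source_vocab=4003, target_vocab=4003,
                scope_name="seq2seq", name="seq2seq", attention=True, SRB=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # One epoch of MLE training: __call__ runs train_op and returns the batch loss.
    for batch in dataloader.batch_iter(dataloader.train_set, 32, shuffle=True):
        loss = model(batch, sess)
    # Greedy decoding on the test set: generate() returns token ids per example.
    for batch in dataloader.batch_iter(dataloader.test_set, 32, shuffle=False):
        predictions = model.generate(batch, sess)
        break
    model.save('lcsts/checkpoint/')  # directory must already exist; one pickle per unit
```
In practice MleTrain.py handles flag parsing, logging, checkpointing, and ROUGE evaluation, so `python3 MleTrain.py` as described in the README remains the supported entry point.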