├── readme.md
├── analyzer.py
├── reader.py
├── test.py
├── train.py
└── model.py

/readme.md:
--------------------------------------------------------------------------------
# seqGAN

A practice project that applies SeqGAN to dialogue generation.

Paper: <https://arxiv.org/abs/1701.06547>

See the [TensorFlow Neural Machine Translation Tutorial](https://github.com/tensorflow/nmt) for more about the implementation.

This project still lacks many components, but it has received far more attention than I expected, so I will be making some enhancements over the coming days.

In the following days, I will:

* upload some results
* fix some bugs to make the results reproducible

--------------------------------------------------------------------------------
/analyzer.py:
--------------------------------------------------------------------------------
# Build a frequency table over every token in the post/response corpus and
# write out the smallest vocabulary that covers 99% of all token occurrences.
import re

file_post = open('data/small/weibo_pair_train_Q.post', 'rb')
file_resp = open('data/small/weibo_pair_train_Q.response', 'rb')

symbol = {}
for f in (file_post, file_resp):
    for line in f.readlines():
        line = line.decode('utf-8')[:-1]  # drop the trailing newline
        for word in re.split(' ', line):
            symbol[word] = symbol.get(word, 0) + 1

file_result = open('result.txt', 'wb')

total = sum(symbol.values())
num = 0
word_num = 0

# Reserve the first four vocabulary slots for the special tokens
# (GO/EOS/UNK/PAD, ids 0-3 in model.py).
file_result.write('\n\n\n\n')

for item in sorted(symbol.items(), key=lambda x: x[1], reverse=True):
    print str([item[0].encode('utf-8')])
    file_result.write(item[0].encode('utf-8') + '\n')
    num += item[1]
    word_num += 1
    if num >= total * 0.99:
        break

print word_num

--------------------------------------------------------------------------------
/reader.py:
--------------------------------------------------------------------------------
import re

EOS_ID = 1
UNK_ID = 2

class reader():
    def __init__(self, file_name_post, file_name_resp, file_name_word):
        # Build the word -> index map and the index -> word list from the
        # vocabulary file (one word per line).
        with open(file_name_word, 'rb') as file_word:
            self.d = {}
            self.symbol = []
            num = 0
            for line in file_word.readlines():
                line = line[:-1]
                self.symbol.append(line)
                self.d[line] = num
                num += 1
        self.file_name_post = file_name_post
        self.file_name_resp = file_name_resp
        self.post = open(self.file_name_post, 'rb')
        self.resp = open(self.file_name_resp, 'rb')
        self.epoch = 0
        self.k = 0  # number of pairs consumed in the current epoch

    def get_batch(self, batch_size):
        result = []
        self.k += batch_size
        for _ in range(batch_size):
            post = self.post.readline()
            resp = self.resp.readline()
            if not post:
                # End of corpus: rewind both files and return a fresh batch.
                self.restore()
                self.epoch += 1
                self.k = 0
                print 'epoch: ', self.epoch
                return self.get_batch(batch_size)
            post = post[:-1]
            resp = resp[:-1]
            words_post = re.split(' ', post)
            words_resp = re.split(' ', resp)
            index_post = [self.d[word] if word in self.d else UNK_ID for word in words_post]
            index_resp = [self.d[word] if word in self.d else UNK_ID for word in words_resp]
            index_resp = index_resp + [EOS_ID]
            result.append((index_post, index_resp))
        return result

    def restore(self):
        self.post.close()
        self.resp.close()
        self.post = open(self.file_name_post, 'rb')
        self.resp = open(self.file_name_resp, 'rb')
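
# Minimal usage sketch (the data paths mirror the ones used in train.py):
# each batch element is an (index_post, index_resp) pair of token-id lists,
# with index_resp terminated by EOS_ID.
if __name__ == '__main__':
    r = reader('data/small/weibo_pair_train_Q.post',
               'data/small/weibo_pair_train_Q.response', 'data/words_99%.txt')
    for index_post, index_resp in r.get_batch(2):
        print index_post, index_resp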
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
# encoding=utf8
# Inference script: load a trained generator and decode responses with
# beam search and with sampling, either interactively or from a file.

import tensorflow as tf

from model import generator_model, discriminator_model
from reader import reader
import re

gpu_rate = 0.25
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_rate)

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="2"

UNK_ID = 2
reader = reader('data/small/weibo_pair_train_Q.post',
        'data/small/weibo_pair_train_Q.response', 'data/words_99%.txt')

def generate_batch(post):
    # Wrap a single raw post into the (index_post, index_resp) batch format.
    post = post.decode('utf-8')
    words_post = re.split(' ', post)
    index_post = [reader.d[word] if word in reader.d else UNK_ID for word in words_post]
    return [(index_post, [])]

g_model = generator_model(vocab_size=len(reader.d),
        embedding_size=128,
        lstm_size=128,
        num_layer=4,
        max_length_encoder=40,
        max_length_decoder=40,
        max_gradient_norm=2,
        batch_size_num=20,
        learning_rate=0.001,
        beam_width=5)
# The discriminator is built too, so every variable in the checkpoint has a
# matching variable in this graph when restoring.
d_model = discriminator_model(vocab_size=len(reader.d),
        embedding_size=128,
        lstm_size=128,
        num_layer=4,
        max_post_length=40,
        max_resp_length=40,
        max_gradient_norm=2,
        batch_size_num=20,
        learning_rate=0.001)

saver = tf.train.Saver(tf.global_variables(), keep_checkpoint_every_n_hours=1.0)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
# Restore weights into the graph built above. (Importing the meta graph here
# would create a second, disconnected copy of the model, so restore directly.)
saver.restore(sess, tf.train.latest_checkpoint('saved/'))
print 'load finished'

from_screen = raw_input('is input from screen: (y)/n')
from_screen = (from_screen != 'n')

if not from_screen:
    file_input = open('data/small/test.post', 'r')
    bs_output = open('data/small/test_bs.response', 'w')
    sample_output = open('data/small/test_sample.response', 'w')

while True:
    if from_screen:
        post = raw_input()
    else:
        post = file_input.readline()
    batch = generate_batch(post)
    resp = g_model.generate(sess, batch, 'beam_search')
    print resp
    resp = resp[0]

    print 'beam search'
    result = ''
    for sentence in resp:
        for index in sentence:
            result += reader.symbol[index] if index >= 0 else 'unk'
            result += ' '
        result += '\n'
    result += '\n'
    if from_screen:
        print result,
    else:
        bs_output.write(result)

    resp = g_model.generate(sess, batch, 'sample')
    resp = resp[0]

    print 'sample'
    result = ''
    for word in resp:
        result += reader.symbol[word] if word >= 0 else 'unk'
        result += ' '  # was '': the output words ran together without a separator
    result += '\n'
    if from_screen:
        print result,
    else:
        sample_output.write(result)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from model import generator_model, discriminator_model
from reader import reader
import sys
reload(sys)
sys.setdefaultencoding("utf-8")  # Python 2: default str/unicode conversions to UTF-8

gpu_rate = 0.5
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_rate)

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
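# CUDA_DEVICE_ORDER=PCI_BUS_ID makes CUDA's device numbering match nvidia-smi,
# so the CUDA_VISIBLE_DEVICES value below pins this job to one physical GPU.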
os.environ["CUDA_VISIBLE_DEVICES"]="3" 14 | 15 | from tensorflow.python.client import device_lib 16 | print device_lib.list_local_devices() 17 | 18 | reader = reader('data/small/weibo_pair_train_Q.post', 19 | 'data/small/weibo_pair_train_Q.response', 'data/words_99%.txt') 20 | 21 | print len(reader.d) 22 | 23 | g_model = generator_model(vocab_size=len(reader.d), 24 | embedding_size=128, 25 | lstm_size=128, 26 | num_layer=4, 27 | max_length_encoder=40, 28 | max_length_decoder=40, 29 | max_gradient_norm=2, 30 | batch_size_num=20, 31 | learning_rate=0.001, 32 | beam_width=5) 33 | d_model = discriminator_model(vocab_size=len(reader.d), 34 | embedding_size=128, 35 | lstm_size=128, 36 | num_layer=4, 37 | max_post_length=40, 38 | max_resp_length=40, 39 | max_gradient_norm=2, 40 | batch_size_num=20, 41 | learning_rate=0.001) 42 | 43 | saver = tf.train.Saver(tf.global_variables(), keep_checkpoint_every_n_hours=1.0) 44 | 45 | config = tf.ConfigProto() 46 | config.gpu_options.allow_growth = True 47 | sess = tf.Session(config=config) 48 | try: 49 | loader = tf.train.import_meta_graph('saved/model.ckpt.meta') 50 | loader.restore(sess, tf.train.latest_checkpoint('saved/')) 51 | print 'load finished' 52 | except: 53 | sess.run(tf.global_variables_initializer()) 54 | print 'load failed' 55 | 56 | d_step = 5 57 | g_step = 50 58 | loop_num = 0 59 | 60 | try: 61 | for __ in range(2000): 62 | for _ in range(50): 63 | g_model.pretrain(sess, reader) 64 | if _ % 100 == 0: 65 | batch = reader.get_batch(g_model.batch_size) 66 | result = g_model.generate(sess, batch, 'sample') 67 | for index in range(g_model.batch_size): 68 | post = batch[index][0] 69 | resp = result[index] 70 | def output(l): 71 | print '[', 72 | for word in l: 73 | print reader.symbol[word], 74 | print ']' 75 | output(post) 76 | output(resp) 77 | print '\n', 78 | 79 | for _ in range(5): 80 | d_model.update(sess, g_model, reader) 81 | if __ % 40 == 0: 82 | saver.save(sess, 'saved/model.ckpt') 83 | 84 | while True: 85 | for _ in range(d_step): 86 | g_model.update(sess, d_model, reader) 87 | for _ in range(g_step): 88 | d_model.update(sess, g_model, reader) 89 | loop_num += 1 90 | if loop_num % 50 == 0: 91 | saver.save(sess, 'saved/model.ckpt') 92 | except KeyboardInterrupt: 93 | saver.save(sess, 'saved/model.ckpt') 94 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.seq2seq 3 | import numpy as np 4 | from tensorflow.python.layers import core as layers_core 5 | 6 | GO_ID = 0 7 | EOS_ID = 1 8 | UNK_ID = 2 9 | PAD_ID = 3 10 | 11 | sample_times = 5 12 | 13 | def cut(resp): 14 | for time in range(len(resp)): 15 | if resp[time] == EOS_ID: 16 | resp = resp[:time+1] 17 | break 18 | return resp 19 | 20 | def out(batch, reader): 21 | for post, resp in batch: 22 | print '(', 23 | for word in post: 24 | print reader.symbol[word], 25 | print '), (', 26 | for word in resp: 27 | print reader.symbol[word], 28 | print ')\n', 29 | 30 | class generator_model(): 31 | def __init__(self, 32 | vocab_size, 33 | embedding_size, 34 | lstm_size, 35 | num_layer, 36 | max_length_encoder, max_length_decoder, 37 | max_gradient_norm, 38 | batch_size_num, 39 | learning_rate, 40 | beam_width, 41 | embed=None): 42 | self.batch_size = batch_size_num 43 | self.max_length_encoder = max_length_encoder 44 | self.max_length_decoder = max_length_decoder 45 | with tf.variable_scope('g_model') as scope: 46 | 
            self.encoder_input = tf.placeholder(tf.int32, [max_length_encoder, None])
            self.decoder_output = tf.placeholder(tf.int32, [max_length_decoder, None])
            self.target_weight = tf.placeholder(tf.float32, [max_length_decoder, None]) # for pretraining or updating
            self.reward = tf.placeholder(tf.float32, [max_length_decoder, None]) # for updating
            self.start_tokens = tf.placeholder(tf.int32, [None]) # for partial sampling
            self.max_inference_length = tf.placeholder(tf.int32, []) # for inference

            self.encoder_length = tf.placeholder(tf.int32, [None])
            self.decoder_length = tf.placeholder(tf.int32, [None])
            batch_size = tf.shape(self.encoder_length)[0]
            decoder_output = self.decoder_output
            # Shift the targets right and prepend GO to form the decoder input.
            # TODO: what if decoder_output has a zero-sized dimension?
            self.decoder_input = tf.concat([tf.ones([1, batch_size], dtype=tf.int32) * GO_ID, decoder_output[:-1]], axis=0)
            if embed is None:
                embedding = tf.get_variable('embedding', [vocab_size, embedding_size])
            else:
                embedding = tf.get_variable('embedding', [vocab_size, embedding_size], initializer=embed)
            encoder_embedded = tf.nn.embedding_lookup(embedding, self.encoder_input)
            decoder_embedded = tf.nn.embedding_lookup(embedding, self.decoder_input)

            # Decoder state pieces fed back in when resuming generation from a
            # partial response (see update() below).
            self.cell_state = tf.placeholder(tf.float32, [2*num_layer, None, lstm_size]) # for partial sampling
            self.attention = tf.placeholder(tf.float32, [None, lstm_size])
            self.time = tf.placeholder(tf.int32)
            self.alignments = tf.placeholder(tf.float32, [None, max_length_encoder])

            def build_attention_state():
                # Rebuild an AttentionWrapperState from the flattened placeholders:
                # consecutive (c, h) rows become per-layer LSTMStateTuples.
                cell_state = tuple([tf.contrib.rnn.LSTMStateTuple(self.cell_state[i], self.cell_state[i+1])
                        for i in range(0, 2*num_layer, 2)])
                print cell_state
                return tf.contrib.seq2seq.AttentionWrapperState(cell_state,
                        self.attention, self.time, self.alignments, tuple([]))
            partial_decoder_state = build_attention_state()

            print 'shape:', decoder_output.get_shape()
            def single_cell():
                return tf.contrib.rnn.BasicLSTMCell(lstm_size)
            def multi_cell():
                return tf.contrib.rnn.MultiRNNCell([single_cell() for _ in range(num_layer)])
            with tf.variable_scope('encoder'):
                encoder_cell = multi_cell()
                encoder_output, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_embedded,
                        self.encoder_length, time_major=True, dtype=tf.float32)

            with tf.variable_scope('decoder') as decoder_scope:
                # Attention memory must be batch-major: [batch, time, units].
                attention_state = tf.transpose(encoder_output, [1, 0, 2])
                print attention_state, lstm_size, self.encoder_length
                attention_mechanism = tf.contrib.seq2seq.LuongAttention(lstm_size, attention_state,
                        memory_sequence_length=self.encoder_length)
                # train or evaluate
                decoder_cell_raw = multi_cell()
                # attention wrapper
                decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell_raw, attention_mechanism,
                        attention_layer_size=lstm_size)
                decoder_init_state = decoder_cell.zero_state(batch_size, tf.float32).clone(cell_state=encoder_state)

                helper = tf.contrib.seq2seq.TrainingHelper(decoder_embedded, self.decoder_length, time_major=True)
                projection_layer = layers_core.Dense(vocab_size) # use_bias ?
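                # Training path: the TrainingHelper above feeds the ground-truth
                # previous token at each step (teacher forcing), and the shared
                # Dense projection maps decoder outputs to vocabulary logits.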
                decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_init_state, output_layer=projection_layer)

                output, decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder, output_time_major=True, swap_memory=True, scope=decoder_scope)
                logits = output.rnn_output
                self.result_train = tf.transpose(output.sample_id)
                self.decoder_state = decoder_state
                # inference (sample)
                helper_sample = tf.contrib.seq2seq.SampleEmbeddingHelper(embedding,
                        start_tokens=tf.fill([batch_size], GO_ID), end_token=EOS_ID)
                decoder_sample = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper_sample, decoder_init_state,
                        output_layer=projection_layer)
                output, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder_sample,
                        swap_memory=True, scope=decoder_scope, maximum_iterations=self.max_inference_length)
                self.result_sample = output.sample_id

                # inference (partial sample): resume sampling from a fed-in decoder state
                helper_partial = tf.contrib.seq2seq.SampleEmbeddingHelper(embedding,
                        start_tokens=self.start_tokens, end_token=EOS_ID)
                decoder_partial = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper_partial, partial_decoder_state,
                        output_layer=projection_layer)
                output, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder_partial,
                        swap_memory=True, scope=decoder_scope, maximum_iterations=self.max_inference_length)
                self.result_partial = output.sample_id

                # inference (greedy)
                helper_greedy = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding,
                        start_tokens=tf.fill([batch_size], GO_ID), end_token=EOS_ID)
                decoder_greedy = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper_greedy, decoder_init_state,
                        output_layer=projection_layer)
                output, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder_greedy,
                        swap_memory=True, scope=decoder_scope, maximum_iterations=self.max_inference_length)
                self.result_greedy = output.sample_id

            # inference (beam search): tile the encoder memory beam_width times
            attention_state = tf.contrib.seq2seq.tile_batch(
                    attention_state, multiplier=beam_width)
            source_seq_length = tf.contrib.seq2seq.tile_batch(
                    self.encoder_length, multiplier=beam_width)
            encoder_state = tf.contrib.seq2seq.tile_batch(
                    encoder_state, multiplier=beam_width)
            with tf.variable_scope('decoder', reuse=True) as decoder_scope:
                attention_mechanism = tf.contrib.seq2seq.LuongAttention(lstm_size, attention_state,
                        memory_sequence_length=source_seq_length)

                decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell_raw, attention_mechanism,
                        attention_layer_size=lstm_size)
                beam_search_init_state = decoder_cell.zero_state(batch_size*beam_width, tf.float32).clone(cell_state=encoder_state)
                print 'bs_init_state:', beam_search_init_state
                decoder_beam_search = tf.contrib.seq2seq.BeamSearchDecoder(
                        cell=decoder_cell,
                        embedding=embedding,
                        start_tokens=tf.fill([batch_size], GO_ID), end_token=EOS_ID,
                        initial_state=beam_search_init_state,
                        beam_width=beam_width,
                        output_layer=projection_layer,
                        length_penalty_weight=0.0)
                output, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder_beam_search,
                        swap_memory=True, scope=decoder_scope, maximum_iterations=self.max_inference_length)
                # [batch, time, beam] -> [batch, beam, time]
                self.result_beam_search = tf.transpose(output.predicted_ids, [0, 2, 1])

            dim = tf.shape(logits)[0]
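            # dynamic_decode only runs as many steps as the longest sequence in
            # the batch, so logits can be shorter than max_length_decoder; crop
            # the targets, weights, and rewards to the same time dimension.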
            decoder_output = tf.split(decoder_output, [dim, max_length_decoder-dim])[0]
            target_weight = tf.split(self.target_weight, [dim, max_length_decoder-dim])[0]
            reward = tf.split(self.reward, [dim, max_length_decoder-dim])[0]

            params = scope.trainable_variables()
            print 'shape:', logits.get_shape()
            # update for pretraining (maximum likelihood)
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=decoder_output, logits=logits) # max_len * batch
            self.loss_pretrain = tf.reduce_sum(target_weight * cross_entropy) / tf.cast(batch_size, tf.float32)
            self.perplexity = tf.exp(tf.reduce_sum(target_weight * cross_entropy) / tf.reduce_sum(target_weight))
            gradient_pretrain = tf.gradients(self.loss_pretrain, params)
            gradient_pretrain, _ = tf.clip_by_global_norm(gradient_pretrain, max_gradient_norm)
            optimizer = tf.train.AdamOptimizer(learning_rate)
            self.opt_pretrain = optimizer.apply_gradients(zip(gradient_pretrain, params))

            # update for GAN (policy gradient): weight the log-probability of
            # each generated token by its reward from the discriminator
            one_hot = tf.one_hot(decoder_output, vocab_size)
            self.prob = tf.reduce_sum(one_hot * tf.nn.softmax(logits), axis=2)
            self.loss_generator = tf.reduce_sum(-tf.log(tf.maximum(self.prob, 1e-5)) * reward * target_weight) / tf.cast(batch_size, tf.float32)
            gradient_generator = tf.gradients(self.loss_generator, params)
            gradient_generator, _ = tf.clip_by_global_norm(gradient_generator, max_gradient_norm)
            optimizer = tf.train.AdamOptimizer(learning_rate)
            self.opt_update = optimizer.apply_gradients(zip(gradient_generator, params))

    def all_params(self):
        # Print every trainable variable and the total parameter count.
        with tf.variable_scope('g_model') as scope:
            total = 0
            for var in scope.trainable_variables():
                shape = var.get_shape()
                k = 1
                print shape,
                for dim in shape:
                    k *= dim.value
                print k, var.name
                total += k
            print 'total:', total

    def pretrain(self, sess, reader):
        feed_post = [[] for _ in range(self.max_length_encoder)]
        feed_resp = [[] for _ in range(self.max_length_decoder)]
        feed_weight = [[] for _ in range(self.max_length_decoder)]
        feed_post_length = []
        feed_resp_length = []

        # read training data and pad every sequence to its maximum length
        # (post and response loops are separate so the code stays correct
        # when max_length_encoder != max_length_decoder)
        batch = reader.get_batch(self.batch_size)
        for post, resp in batch:
            feed_post_length.append(len(post))
            feed_resp_length.append(len(resp))
            for time in range(self.max_length_encoder):
                feed_post[time].append(post[time] if time < len(post) else PAD_ID)
            for time in range(self.max_length_decoder):
                feed_resp[time].append(resp[time] if time < len(resp) else PAD_ID)
                if time < len(resp):
                    feed_weight[time].append(1)
                else:
                    feed_weight[time].append(0)

        feed_dict = {}
        feed_dict[self.encoder_input] = feed_post
        feed_dict[self.decoder_output] = feed_resp
        feed_dict[self.target_weight] = feed_weight
        feed_dict[self.encoder_length] = feed_post_length
        feed_dict[self.decoder_length] = feed_resp_length

        perplexity, result, loss, state, _ = sess.run([self.perplexity, self.result_train, self.loss_pretrain, self.decoder_state, self.opt_pretrain], feed_dict=feed_dict)
        """
        for sentence in result:
            for word in sentence:
                print reader.symbol[word],
                if word == EOS_ID:
                    break
            print '\n',
        """
        print 'perplexity:', perplexity,
        print 'loss:', loss,
        # 958640 is the hard-coded number of training pairs, used only for the progress display
        print reader.epoch, str(reader.k)+'/958640', reader.k / 958640.0

    def update(self, sess, discriminator, reader):
        # for each post, sample a response from the current policy
        batch = reader.get_batch(self.batch_size)
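        # SeqGAN-style reward estimation: for each prefix length t of a sampled
        # response, roll out `sample_times` full continuations, score each with
        # the discriminator, and average to estimate Q(t). The per-token reward
        # fed to the policy-gradient loss is then Q(t+1) - Q(t).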
        resp_generator = self.generate(sess, batch, 'sample')

        max_len = len(resp_generator[0])
        feed_reward = [] # max_len * batch_size
        feed_post = [[] for _ in range(self.max_length_encoder)]
        feed_post_length = []
        for index in range(self.batch_size):
            post = batch[index][0]
            feed_post_length.append(len(post))
            for time in range(self.max_length_encoder):
                feed_post[time].append(post[time] if time < len(post) else PAD_ID)
        for t in range(max_len+1):
            # for each partial response, run the training decoder to get the
            # decoder state after t tokens
            feed_resp = [[] for _ in range(self.max_length_decoder)]
            feed_resp_length = []
            for index in range(self.batch_size):
                resp = resp_generator[index]
                resp = cut(resp)
                resp = resp[:t]
                feed_resp_length.append(len(resp))
                for time in range(self.max_length_decoder):
                    feed_resp[time].append(resp[time] if time < len(resp) else PAD_ID)
            feed_dict = {}
            feed_dict[self.encoder_input] = feed_post
            feed_dict[self.encoder_length] = feed_post_length
            feed_dict[self.decoder_output] = feed_resp
            feed_dict[self.decoder_length] = feed_resp_length

            state = sess.run(self.decoder_state, feed_dict=feed_dict)

            # from each partial response, randomly sample several full continuations
            start_tokens = [resp[t-1] for resp in resp_generator] if t >= 1 else [GO_ID] * self.batch_size
            mean_reward = [0 for _ in range(self.batch_size)]
            for num in range(sample_times):
                # flatten the per-layer LSTM state tuples into the placeholder layout
                cell_state = []
                for lstm_tuple in state.cell_state:
                    cell_state = cell_state + [lstm_tuple.c, lstm_tuple.h]
                feed_dict = {}
                feed_dict[self.encoder_input] = feed_post
                feed_dict[self.encoder_length] = feed_post_length
                feed_dict[self.start_tokens] = start_tokens
                feed_dict[self.max_inference_length] = self.max_length_decoder - t
                feed_dict[self.cell_state] = cell_state
                feed_dict[self.attention] = state.attention
                feed_dict[self.time] = state.time
                feed_dict[self.alignments] = state.alignments

                output = sess.run(self.result_partial, feed_dict=feed_dict)
                # feed the completed responses into the discriminator and compute Q
                feed_resp = []
                for index in range(self.batch_size):
                    resp = resp_generator[index]
                    resp = cut(resp)
                    length = len(resp)
                    resp = resp[:t]
                    # if the response already ended before step t, keep it as is
                    final_resp = np.append(resp, output[index]) if length > t else resp
                    feed_resp.append(final_resp)
                feed_batch = [(batch[index][0], feed_resp[index]) for index in range(self.batch_size)]
                poss = discriminator.evaluate(sess, feed_batch, reader)
                for index in range(self.batch_size):
                    mean_reward[index] += poss[index] / sample_times
            for index in range(self.batch_size):
                resp = cut(resp_generator[index])[:t]
                print poss[index],
                print '[',
                for word in resp:
                    print reader.symbol[word],
                print '], [',
                for word in output[index]:
                    print reader.symbol[word],
                print ']'
            feed_reward.append(mean_reward)
        # turn the cumulative Q estimates into per-step rewards: reward(t) = Q(t+1) - Q(t)
        for t in range(max_len):
            for index in range(self.batch_size):
                feed_reward[t][index] = feed_reward[t+1][index] - feed_reward[t][index]
        feed_reward = feed_reward[:max_len]
        feed_reward = feed_reward + [[0 for _ in range(self.batch_size)]] * (self.max_length_decoder - max_len)
        #print feed_reward

        # update generator
        feed_resp = [[] for _ in range(self.max_length_decoder)]
        feed_weight = [[] for _ in range(self.max_length_decoder)]
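        # feed_weight masks PAD positions so the policy-gradient loss only
        # covers tokens that were actually sampled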
        feed_resp_length = []
        for index in range(self.batch_size):
            resp = resp_generator[index]
            resp = cut(resp)
            feed_resp_length.append(len(resp))
            for time in range(self.max_length_decoder):
                feed_resp[time].append(resp[time] if time < len(resp) else PAD_ID)
                if time < len(resp):
                    feed_weight[time].append(1)
                else:
                    feed_weight[time].append(0)
        print feed_reward
        feed_dict = {}
        feed_dict[self.encoder_input] = feed_post
        feed_dict[self.encoder_length] = feed_post_length
        feed_dict[self.decoder_output] = feed_resp
        feed_dict[self.decoder_length] = feed_resp_length
        feed_dict[self.reward] = feed_reward
        feed_dict[self.target_weight] = feed_weight

        loss, prob, _ = sess.run([self.loss_generator, self.prob, self.opt_update], feed_dict=feed_dict)
        print 'generator updated, loss =', prob, loss

        # teacher forcing: also push the generator toward the ground-truth
        # responses, with a reward centred on how convincing the discriminator
        # finds each real pair
        evaluate_result = discriminator.evaluate(sess, batch, reader)

        feed_dict = {}
        feed_dict[self.encoder_input] = feed_post
        feed_dict[self.encoder_length] = feed_post_length
        feed_resp = [[] for _ in range(self.max_length_decoder)]
        feed_weight = [[] for _ in range(self.max_length_decoder)]
        feed_reward = [[] for _ in range(self.max_length_decoder)]
        feed_resp_length = []
        for index in range(self.batch_size):
            resp = batch[index][1]
            resp = cut(resp)
            feed_resp_length.append(len(resp))
            for time in range(self.max_length_decoder):
                feed_reward[time].append(evaluate_result[index] - 0.5)
                feed_resp[time].append(resp[time] if time < len(resp) else PAD_ID)
                if time < len(resp):
                    feed_weight[time].append(1)
                else:
                    feed_weight[time].append(0)
        feed_dict[self.decoder_output] = feed_resp
        feed_dict[self.decoder_length] = feed_resp_length
        feed_dict[self.reward] = feed_reward
        feed_dict[self.target_weight] = feed_weight

        loss, prob, _ = sess.run([self.loss_generator, self.prob, self.opt_update], feed_dict=feed_dict)
        print 'generator updated, loss =', prob, loss

    def generate(self, sess, batch, mode):
        feed_post = [[] for _ in range(self.max_length_encoder)]
        feed_weight = [[] for _ in range(self.max_length_decoder)]
        feed_post_length = []
        for post, resp in batch:
            feed_post_length.append(len(post))
            for time in range(self.max_length_encoder):
                feed_post[time].append(post[time] if time < len(post) else PAD_ID)
        feed_dict = {}
        feed_dict[self.encoder_input] = feed_post
        feed_dict[self.encoder_length] = feed_post_length
        feed_dict[self.max_inference_length] = self.max_length_decoder-1
        result = None
        if mode == 'sample':
            result = sess.run(self.result_sample, feed_dict=feed_dict)
        elif mode == 'greedy':
            result = sess.run(self.result_greedy, feed_dict=feed_dict)
        elif mode == 'beam_search':
            result = sess.run(self.result_beam_search, feed_dict=feed_dict)
        return result

class discriminator_model():
    def __init__(self, vocab_size,
            embedding_size,
            lstm_size,
            num_layer,
            max_post_length, max_resp_length,
            max_gradient_norm,
            batch_size_num,
            learning_rate,
            embed=None):
        self.batch_size = batch_size_num
        self.max_post_length = max_post_length
        self.max_resp_length = max_resp_length
        #with tf.variable_scope('g_model', reuse=True):
        #embedding = tf.get_variable('embedding')
        if embed is None:
            embedding = tf.get_variable('embedding', [vocab_size, embedding_size])
        else:
            embedding = tf.get_variable('embedding', [vocab_size, embedding_size], initializer=embed)
        with tf.variable_scope('d_model') as scope:
            self.post_input = tf.placeholder(tf.int32, [max_post_length, None])
            self.resp_input = tf.placeholder(tf.int32, [max_resp_length, None])
            self.post_length = tf.placeholder(tf.int32, [None])
            self.resp_length = tf.placeholder(tf.int32, [None])
            self.labels = tf.placeholder(tf.int64, [None])

            batch_size = tf.shape(self.post_length)[0]
            post_embedded = tf.nn.embedding_lookup(embedding, self.post_input)
            resp_embedded = tf.nn.embedding_lookup(embedding, self.resp_input)
            def single_cell():
                return tf.contrib.rnn.BasicLSTMCell(lstm_size)
            def multi_cell():
                return tf.contrib.rnn.MultiRNNCell([single_cell() for _ in range(num_layer)])
            # encode post and response with the same LSTM encoder (shared weights)
            with tf.variable_scope('encoder'):
                cell = multi_cell()
                post_output, post_state = tf.nn.dynamic_rnn(cell, post_embedded,
                        self.post_length, time_major=True, dtype=tf.float32)
            with tf.variable_scope('encoder', reuse=True):
                resp_output, resp_state = tf.nn.dynamic_rnn(cell, resp_embedded,
                        self.resp_length, time_major=True, dtype=tf.float32)

            def concat(lstm_tuple):
                # flatten a per-layer (c, h) state tuple into one feature vector
                return tf.concat([tf.concat([pair.c, pair.h], axis=1) for pair in lstm_tuple], axis=1)
            post_state_concat = concat(post_state)
            resp_state_concat = concat(resp_state)

            """
            cell_sentence = tf.contrib.rnn.BasicLSTMCell(lstm_size)
            init_state = cell_sentence.zero_state(batch_size, tf.float32)
            out1, state_mid = cell_sentence(post_state_concat, init_state, scope=scope)
            out2, state_final = cell_sentence(resp_state_concat, state_mid, scope=scope)

            state_final_concat = tf.concat([state_final.c, state_final.h], axis=1)
            """
            # binary classifier: is this (post, response) pair real or generated?
            state_final_concat = tf.concat([post_state_concat, resp_state_concat], axis=1)
            logits = tf.layers.dense(state_final_concat, 2)
            print logits, self.labels
            self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.labels, logits=logits))
            self.poss = tf.nn.softmax(logits)[:, 1] # probability that the pair is real

            result = tf.argmax(logits, axis=1)
            self.acc = tf.reduce_mean(tf.cast(tf.equal(result, self.labels), tf.float32))

            params = scope.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm)
            optimizer = tf.train.AdamOptimizer(learning_rate)
            self.opt_train = optimizer.apply_gradients(zip(gradients, params))

    def all_params(self):
        with tf.variable_scope('d_model') as scope:
            total = 0
            for var in scope.trainable_variables():
                shape = var.get_shape()
                k = 1
                print shape,
                for dim in shape:
                    k *= dim.value
                print k, var.name
                total += k
            print 'total:', total

    def update(self, sess, generator, reader):
        batch = reader.get_batch(self.batch_size)

        # negative examples: responses sampled from the current generator
        resp_generator = generator.generate(sess, batch, 'sample')
        feed_post = [[] for _ in range(self.max_post_length)]
        feed_resp = [[] for _ in range(self.max_resp_length)]
        feed_post_length = []
        feed_resp_length = []
        feed_labels = []
        print 'positive:'
        out(batch, reader)
        print 'negative:'
        out([(batch[index][0], resp_generator[index]) for index in range(self.batch_size)], reader)
        # real pairs first, labelled 1
        for post, resp in batch:
            feed_post_length.append(len(post))
            feed_resp_length.append(len(resp))
            feed_labels.append(1)
            for time in range(self.max_post_length):
                feed_post[time].append(post[time] if time < len(post) else PAD_ID)
            for time in range(self.max_resp_length):
                feed_resp[time].append(resp[time] if time < len(resp) else PAD_ID)
        #out([(batch[index][0], resp_generator[index]) for index in range(self.batch_size)], reader)
        # then generated pairs, labelled 0
        for index in range(self.batch_size):
            post = batch[index][0]
            resp = resp_generator[index]
            resp = cut(resp)
            feed_post_length.append(len(post))
            feed_resp_length.append(len(resp))
            feed_labels.append(0)
            for time in range(self.max_post_length):
                feed_post[time].append(post[time] if time < len(post) else PAD_ID)
            for time in range(self.max_resp_length):
                feed_resp[time].append(resp[time] if time < len(resp) else PAD_ID)

        feed_dict = {}
        feed_dict[self.post_input] = feed_post
        feed_dict[self.resp_input] = feed_resp
        feed_dict[self.post_length] = feed_post_length
        feed_dict[self.resp_length] = feed_resp_length
        feed_dict[self.labels] = feed_labels

        poss, loss, acc, _ = sess.run([self.poss, self.loss, self.acc, self.opt_train], feed_dict=feed_dict)
        print 'discriminator:', loss, acc, poss

    def evaluate(self, sess, batch, reader):
        # return P(real) for every (post, response) pair in the batch
        feed_post = [[] for _ in range(self.max_post_length)]
        feed_resp = [[] for _ in range(self.max_resp_length)]
        feed_post_length = []
        feed_resp_length = []
        #out(batch, reader)
        for post, resp in batch:
            feed_post_length.append(len(post))
            resp = cut(resp)
            feed_resp_length.append(len(resp))
            for time in range(self.max_post_length):
                feed_post[time].append(post[time] if time < len(post) else PAD_ID)
            for time in range(self.max_resp_length):
                feed_resp[time].append(resp[time] if time < len(resp) else PAD_ID)

        feed_dict = {}
        feed_dict[self.post_input] = feed_post
        feed_dict[self.resp_input] = feed_resp
        feed_dict[self.post_length] = feed_post_length
        feed_dict[self.resp_length] = feed_resp_length

        poss = sess.run(self.poss, feed_dict=feed_dict)
        return poss

if __name__ == '__main__':
    # smoke test: build both graphs with arbitrary hyper-parameters
    g = generator_model(1000, 128, 101, 4, 98, 99, 5, 20, 0.001, 5)
    #g.all_params()
    d = discriminator_model(1000, 100, 101, 4, 40, 40, 2, 20, 0.001)
    #d.all_params()

--------------------------------------------------------------------------------