├── save ├── experiment-log.txt └── target_params_py3.pkl ├── figures ├── lc.png └── seqgan.png ├── .gitignore ├── model_settings.py ├── README.md ├── dataloader.py ├── target_lstm.py ├── sequence_gan.py ├── generator.py ├── discriminator.py └── rollout.py /save/experiment-log.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/lc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desire2020/RankGAN/HEAD/figures/lc.png -------------------------------------------------------------------------------- /figures/seqgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desire2020/RankGAN/HEAD/figures/seqgan.png -------------------------------------------------------------------------------- /save/target_params_py3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desire2020/RankGAN/HEAD/save/target_params_py3.pkl -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | save/eval_file.txt 4 | save/generator_sample.txt 5 | save/real_data.txt 6 | -------------------------------------------------------------------------------- /model_settings.py: -------------------------------------------------------------------------------- 1 | use_real_data = False 2 | COCO_vocab_size = 4980 3 | COCO_seq_len = 32 4 | NEWS_vocab_size = 5742 5 | NEWS_seq_len = 48 6 | seq_len = 20 7 | real_data_vocab_size = NEWS_vocab_size 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RankGAN 2 | 3 | ## Requirements: 4 | * **TensorFlow r1.6.0** 5 | * Python 3.x 6 | * CUDA 9.0 (for GPU) 7 | 8 | ## Introduction 9 | Applies Generative Adversarial Nets to the generation of sequences of discrete tokens, replacing the binary discriminator with a ranker to guide optimization. 10 | 11 | The earlier research paper [SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient](http://arxiv.org/abs/1609.05473) was accepted at the Thirty-First AAAI Conference on Artificial Intelligence (AAAI-17). 12 | 13 | The research paper [Adversarial Ranking for Language Generation](https://papers.nips.cc/paper/6908-adversarial-ranking-for-language-generation.pdf) was accepted at the 31st Conference on Neural Information Processing Systems (NIPS 2017). 14 | 15 | We provide example code to reproduce the synthetic-data experiments with the oracle evaluation mechanism. 16 | To run the experiment with default parameters: 17 | ``` 18 | $ python sequence_gan.py 19 | ``` 20 | You can change all the parameters in `sequence_gan.py`. 21 | 22 | The experiment has two stages. In the first stage, the generator is pre-trained with Maximum Likelihood Estimation (supervised learning) on the positive data provided by the oracle model. In the second stage, adversarial training is used to improve the generator. 23 | 24 | 25 | Note: this code is based on the [previous work by ofirnachum](https://github.com/ofirnachum/sequence_gan) and [SeqGAN](https://github.com/LantaoYu/SeqGAN).
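For reference, the score written to `save/experiment-log.txt` is the average negative log-likelihood of generated sequences under the oracle LSTM (`target_lstm.py`, with parameters loaded from `save/target_params_py3.pkl`); lower is better. A minimal sketch of this evaluation, equivalent to `target_loss` in `sequence_gan.py`, assuming a live `sess`, the oracle `target_lstm`, and a `Gen_Data_loader` holding generated samples:
```
import numpy as np

def oracle_nll(sess, target_lstm, data_loader):
    # Average per-token negative log-likelihood of the generated sequences
    # under the oracle model (NLL_oracle); lower means the generator is
    # closer to the oracle distribution.
    nll = []
    data_loader.reset_pointer()
    for _ in range(data_loader.num_batch):
        batch = data_loader.next_batch()
        nll.append(sess.run(target_lstm.pretrain_loss, {target_lstm.x: batch}))
    return np.mean(nll)
```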
26 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import model_settings 4 | 5 | class Gen_Data_loader(): 6 | def __init__(self, batch_size): 7 | self.batch_size = batch_size 8 | self.token_stream = [] 9 | 10 | def create_batches(self, data_file): 11 | global pos_size 12 | self.token_stream = [] 13 | with open(data_file, 'r') as f: 14 | for line in f: 15 | line = line.strip() 16 | line = line.split() 17 | parse_line = [int(x) for x in line] 18 | if len(parse_line) == model_settings.seq_len: 19 | self.token_stream.append(parse_line) 20 | 21 | self.num_batch = int(len(self.token_stream) / self.batch_size) 22 | self.data_size = len(self.token_stream) 23 | pos_size = self.data_size 24 | self.token_stream = self.token_stream[:self.num_batch * self.batch_size] 25 | self.sequence_batch = np.split(np.array(self.token_stream), self.num_batch, 0) 26 | self.pointer = 0 27 | 28 | def next_batch(self): 29 | ret = self.sequence_batch[self.pointer] 30 | self.pointer = (self.pointer + 1) % self.num_batch 31 | return ret 32 | 33 | def reset_pointer(self): 34 | self.pointer = 0 35 | 36 | 37 | class Dis_dataloader(): 38 | def __init__(self, batch_size, ref_size=None): 39 | self.batch_size = batch_size 40 | if ref_size != None: 41 | self.ref_size = ref_size 42 | else: 43 | self.ref_size = 16 44 | self.sentences = np.array([]) 45 | self.labels = np.array([]) 46 | 47 | def load_train_data(self, positive_file, negative_file): 48 | # Load data 49 | global pos_size 50 | positive_examples = [] 51 | negative_examples = [] 52 | with open(positive_file)as fin: 53 | for line in fin: 54 | if (random.random() * pos_size) < 10000: 55 | line = line.strip() 56 | line = line.split() 57 | parse_line = [int(x) for x in line] 58 | positive_examples.append(parse_line) 59 | with open(negative_file)as fin: 60 | for line in fin: 61 | line = line.strip() 62 | line = line.split() 63 | parse_line = [int(x) for x in line] 64 | if len(parse_line) == model_settings.seq_len: 65 | negative_examples.append(parse_line) 66 | self.sentences = np.array(positive_examples + negative_examples) 67 | self.positive_examples = positive_examples 68 | # Generate labels 69 | positive_labels = [[0, 1] for _ in positive_examples] 70 | negative_labels = [[1, 0] for _ in negative_examples] 71 | self.labels = np.concatenate([positive_labels, negative_labels], 0) 72 | 73 | # Shuffle the data 74 | shuffle_indices = np.random.permutation(np.arange(len(self.labels))) 75 | self.sentences = self.sentences[shuffle_indices] 76 | self.labels = self.labels[shuffle_indices] 77 | 78 | # Split batches 79 | self.num_batch = int(len(self.labels) / self.batch_size) 80 | self.sentences = self.sentences[:self.num_batch * self.batch_size] 81 | self.labels = self.labels[:self.num_batch * self.batch_size] 82 | self.sentences_batches = np.split(self.sentences, self.num_batch, 0) 83 | self.labels_batches = np.split(self.labels, self.num_batch, 0) 84 | 85 | self.pointer = 0 86 | 87 | def get_reference(self): 88 | ref_samples = [] 89 | for _ in range(self.ref_size): 90 | ref_samples.append(self.positive_examples[random.randint(0, len(self.positive_examples) - 1)]) 91 | return np.array(ref_samples) 92 | 93 | def next_batch(self): 94 | ret = self.sentences_batches[self.pointer], self.labels_batches[self.pointer], self.get_reference() 95 | self.pointer = (self.pointer + 1) % self.num_batch 96 | return ret 97 | 98 
| def reset_pointer(self): 99 | self.pointer = 0 100 | 101 | -------------------------------------------------------------------------------- /target_lstm.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import tensor_array_ops, control_flow_ops 3 | 4 | 5 | class TARGET_LSTM(object): 6 | def __init__(self, num_emb, batch_size, emb_dim, hidden_dim, sequence_length, start_token, params): 7 | self.num_emb = num_emb 8 | self.batch_size = batch_size 9 | self.emb_dim = emb_dim 10 | self.hidden_dim = hidden_dim 11 | self.sequence_length = sequence_length 12 | self.start_token = tf.constant([start_token] * self.batch_size, dtype=tf.int32) 13 | self.g_params = [] 14 | self.temperature = 1.0 15 | self.params = params 16 | 17 | tf.set_random_seed(66) 18 | 19 | with tf.variable_scope('generator'): 20 | self.g_embeddings = tf.Variable(self.params[0]) 21 | self.g_params.append(self.g_embeddings) 22 | self.g_recurrent_unit = self.create_recurrent_unit(self.g_params) # maps h_tm1 to h_t for generator 23 | self.g_output_unit = self.create_output_unit(self.g_params) # maps h_t to o_t (output token logits) 24 | 25 | # placeholder definition 26 | self.x = tf.placeholder(tf.int32, shape=[self.batch_size, self.sequence_length]) # sequence of tokens generated by generator 27 | 28 | # processed for batch 29 | with tf.device("/cpu:0"): 30 | self.processed_x = tf.transpose(tf.nn.embedding_lookup(self.g_embeddings, self.x), perm=[1, 0, 2]) # seq_length x batch_size x emb_dim 31 | 32 | # initial states 33 | self.h0 = tf.zeros([self.batch_size, self.hidden_dim]) 34 | self.h0 = tf.stack([self.h0, self.h0]) 35 | 36 | # generator on initial randomness 37 | gen_o = tensor_array_ops.TensorArray(dtype=tf.float32, size=self.sequence_length, 38 | dynamic_size=False, infer_shape=True) 39 | gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length, 40 | dynamic_size=False, infer_shape=True) 41 | 42 | def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x): 43 | h_t = self.g_recurrent_unit(x_t, h_tm1) # hidden_memory_tuple 44 | o_t = self.g_output_unit(h_t) # batch x vocab , logits not prob 45 | log_prob = tf.log(tf.nn.softmax(o_t)) 46 | next_token = tf.cast(tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]), tf.int32) 47 | x_tp1 = tf.nn.embedding_lookup(self.g_embeddings, next_token) # batch x emb_dim 48 | gen_o = gen_o.write(i, tf.reduce_sum(tf.multiply(tf.one_hot(next_token, self.num_emb, 1.0, 0.0), 49 | tf.nn.softmax(o_t)), 1)) # [batch_size] , prob 50 | gen_x = gen_x.write(i, next_token) # indices, batch_size 51 | return i + 1, x_tp1, h_t, gen_o, gen_x 52 | 53 | _, _, _, self.gen_o, self.gen_x = control_flow_ops.while_loop( 54 | cond=lambda i, _1, _2, _3, _4: i < self.sequence_length, 55 | body=_g_recurrence, 56 | loop_vars=(tf.constant(0, dtype=tf.int32), 57 | tf.nn.embedding_lookup(self.g_embeddings, self.start_token), self.h0, gen_o, gen_x) 58 | ) 59 | 60 | self.gen_x = self.gen_x.stack() # seq_length x batch_size 61 | self.gen_x = tf.transpose(self.gen_x, perm=[1, 0]) # batch_size x seq_length 62 | 63 | # supervised pretraining for generator 64 | g_predictions = tensor_array_ops.TensorArray( 65 | dtype=tf.float32, size=self.sequence_length, 66 | dynamic_size=False, infer_shape=True) 67 | 68 | ta_emb_x = tensor_array_ops.TensorArray( 69 | dtype=tf.float32, size=self.sequence_length) 70 | ta_emb_x = ta_emb_x.unstack(self.processed_x) 71 | 72 | def _pretrain_recurrence(i, x_t, h_tm1, g_predictions): 73 | h_t = 
self.g_recurrent_unit(x_t, h_tm1) 74 | o_t = self.g_output_unit(h_t) 75 | g_predictions = g_predictions.write(i, tf.nn.softmax(o_t)) # batch x vocab_size 76 | x_tp1 = ta_emb_x.read(i) 77 | return i + 1, x_tp1, h_t, g_predictions 78 | 79 | _, _, _, self.g_predictions = control_flow_ops.while_loop( 80 | cond=lambda i, _1, _2, _3: i < self.sequence_length, 81 | body=_pretrain_recurrence, 82 | loop_vars=(tf.constant(0, dtype=tf.int32), 83 | tf.nn.embedding_lookup(self.g_embeddings, self.start_token), 84 | self.h0, g_predictions)) 85 | 86 | self.g_predictions = tf.transpose( 87 | self.g_predictions.stack(), perm=[1, 0, 2]) # batch_size x seq_length x vocab_size 88 | 89 | # pretraining loss 90 | self.pretrain_loss = -tf.reduce_sum( 91 | tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log( 92 | tf.reshape(self.g_predictions, [-1, self.num_emb]))) / (self.sequence_length * self.batch_size) 93 | 94 | self.out_loss = tf.reduce_sum( 95 | tf.reshape( 96 | -tf.reduce_sum( 97 | tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log( 98 | tf.reshape(self.g_predictions, [-1, self.num_emb])), 1 99 | ), [-1, self.sequence_length] 100 | ), 1 101 | ) # batch_size 102 | 103 | def generate(self, session): 104 | # h0 = np.random.normal(size=self.hidden_dim) 105 | outputs = session.run(self.gen_x) 106 | return outputs 107 | 108 | def init_matrix(self, shape): 109 | return tf.random_normal(shape, stddev=1.0) 110 | 111 | def create_recurrent_unit(self, params): 112 | # Weights and Bias for input and hidden tensor 113 | self.Wi = tf.Variable(self.params[1]) 114 | self.Ui = tf.Variable(self.params[2]) 115 | self.bi = tf.Variable(self.params[3]) 116 | 117 | self.Wf = tf.Variable(self.params[4]) 118 | self.Uf = tf.Variable(self.params[5]) 119 | self.bf = tf.Variable(self.params[6]) 120 | 121 | self.Wog = tf.Variable(self.params[7]) 122 | self.Uog = tf.Variable(self.params[8]) 123 | self.bog = tf.Variable(self.params[9]) 124 | 125 | self.Wc = tf.Variable(self.params[10]) 126 | self.Uc = tf.Variable(self.params[11]) 127 | self.bc = tf.Variable(self.params[12]) 128 | params.extend([ 129 | self.Wi, self.Ui, self.bi, 130 | self.Wf, self.Uf, self.bf, 131 | self.Wog, self.Uog, self.bog, 132 | self.Wc, self.Uc, self.bc]) 133 | 134 | def unit(x, hidden_memory_tm1): 135 | previous_hidden_state, c_prev = tf.unstack(hidden_memory_tm1) 136 | 137 | # Input Gate 138 | i = tf.sigmoid( 139 | tf.matmul(x, self.Wi) + 140 | tf.matmul(previous_hidden_state, self.Ui) + self.bi 141 | ) 142 | 143 | # Forget Gate 144 | f = tf.sigmoid( 145 | tf.matmul(x, self.Wf) + 146 | tf.matmul(previous_hidden_state, self.Uf) + self.bf 147 | ) 148 | 149 | # Output Gate 150 | o = tf.sigmoid( 151 | tf.matmul(x, self.Wog) + 152 | tf.matmul(previous_hidden_state, self.Uog) + self.bog 153 | ) 154 | 155 | # New Memory Cell 156 | c_ = tf.nn.tanh( 157 | tf.matmul(x, self.Wc) + 158 | tf.matmul(previous_hidden_state, self.Uc) + self.bc 159 | ) 160 | 161 | # Final Memory cell 162 | c = f * c_prev + i * c_ 163 | 164 | # Current Hidden state 165 | current_hidden_state = o * tf.nn.tanh(c) 166 | 167 | return tf.stack([current_hidden_state, c]) 168 | 169 | return unit 170 | 171 | def create_output_unit(self, params): 172 | self.Wo = tf.Variable(self.params[13]) 173 | self.bo = tf.Variable(self.params[14]) 174 | params.extend([self.Wo, self.bo]) 175 | 176 | def unit(hidden_memory_tuple): 177 | hidden_state, c_prev = tf.unstack(hidden_memory_tuple) 178 | # hidden_state : batch x hidden_dim 179 | logits = 
tf.matmul(hidden_state, self.Wo) + self.bo 180 | # output = tf.nn.softmax(logits) 181 | return logits 182 | 183 | return unit -------------------------------------------------------------------------------- /sequence_gan.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import tensorflow as tf 4 | import random 5 | from dataloader import Gen_Data_loader, Dis_dataloader 6 | from generator import Generator 7 | from discriminator import Discriminator 8 | from rollout import ROLLOUT 9 | from target_lstm import TARGET_LSTM 10 | import pickle 11 | import model_settings 12 | 13 | ######################################################################################### 14 | # Generator Hyper-parameters 15 | ###################################################################################### 16 | EMB_DIM = 32 # embedding dimension 17 | HIDDEN_DIM = 32 # hidden state dimension of lstm cell 18 | SEQ_LENGTH = model_settings.seq_len # sequence length 19 | START_TOKEN = 0 20 | PRE_EPOCH_NUM = 120 # supervise (maximum likelihood estimation) epochs 21 | SEED = 88 22 | BATCH_SIZE = 64 23 | 24 | ######################################################################################### 25 | # Discriminator Hyper-parameters 26 | ######################################################################################### 27 | dis_embedding_dim = 64 28 | dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, int((20 + model_settings.seq_len) / 2), model_settings.seq_len] 29 | dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160, 160, 160] 30 | dis_dropout_keep_prob = 0.75 31 | dis_l2_reg_lambda = 0.2 32 | dis_batch_size = 64 33 | 34 | ######################################################################################### 35 | # Basic Training Parameters 36 | ######################################################################################### 37 | TOTAL_BATCH = 10000 38 | positive_file = 'save/real_data.txt' 39 | negative_file = 'save/generator_sample.txt' 40 | eval_file = 'save/eval_file.txt' 41 | generated_num = 10000 42 | 43 | 44 | def generate_samples(sess, trainable_model, batch_size, generated_num, output_file): 45 | # Generate Samples 46 | generated_samples = [] 47 | for _ in range(int(generated_num / batch_size)): 48 | generated_samples.extend(trainable_model.generate(sess)) 49 | 50 | with open(output_file, 'w') as fout: 51 | for poem in generated_samples: 52 | buffer = ' '.join([str(x) for x in poem]) + '\n' 53 | fout.write(buffer) 54 | 55 | 56 | def target_loss(sess, target_lstm, data_loader): 57 | # target_loss means the oracle negative log-likelihood tested with the oracle model "target_lstm" 58 | # For more details, please see the Section 4 in https://arxiv.org/abs/1609.05473 59 | nll = [] 60 | data_loader.reset_pointer() 61 | 62 | for it in range(data_loader.num_batch): 63 | batch = data_loader.next_batch() 64 | g_loss = sess.run(target_lstm.pretrain_loss, {target_lstm.x: batch}) 65 | nll.append(g_loss) 66 | 67 | return np.mean(nll) 68 | 69 | 70 | def pre_train_epoch(sess, trainable_model, data_loader): 71 | # Pre-train the generator using MLE for one epoch 72 | supervised_g_losses = [] 73 | data_loader.reset_pointer() 74 | 75 | for it in range(data_loader.num_batch): 76 | batch = data_loader.next_batch() 77 | if random.random() < (float(10000) / float(data_loader.data_size)): 78 | _, g_loss = trainable_model.pretrain_step(sess, batch) 79 | 
supervised_g_losses.append(g_loss) 80 | return np.mean(supervised_g_losses) 81 | 82 | 83 | def main(): 84 | random.seed(SEED) 85 | np.random.seed(SEED) 86 | assert START_TOKEN == 0 87 | 88 | gen_data_loader = Gen_Data_loader(BATCH_SIZE) 89 | likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing 90 | if not model_settings.use_real_data: 91 | vocab_size = 5000 92 | else: 93 | vocab_size = model_settings.real_data_vocab_size 94 | dis_data_loader = Dis_dataloader(BATCH_SIZE, 4) 95 | 96 | generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN) 97 | target_params = pickle.load(open('save/target_params_py3.pkl', 'rb')) 98 | target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, params=target_params) # The oracle model 99 | 100 | discriminator = Discriminator(sequence_length=model_settings.seq_len, num_classes=2, vocab_size=vocab_size, emd_dim=dis_embedding_dim, 101 | filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda, batch_size=dis_batch_size, reference_size=4) 102 | 103 | config = tf.ConfigProto() 104 | config.gpu_options.allow_growth = True 105 | sess = tf.Session(config=config) 106 | sess.run(tf.global_variables_initializer()) 107 | 108 | # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution 109 | if not model_settings.use_real_data: 110 | generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file) 111 | gen_data_loader.create_batches(positive_file) 112 | 113 | log = open('save/experiment-log.txt', 'w') 114 | # pre-train generator 115 | print('Start pre-training...') 116 | log.write('pre-training...\n') 117 | for epoch in range(PRE_EPOCH_NUM): 118 | loss = pre_train_epoch(sess, generator, gen_data_loader) 119 | print('Pre-training generator epoch #%d, loss=%f' % (epoch, loss)) 120 | if not model_settings.use_real_data: 121 | if epoch % 5 == 0: 122 | generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file) 123 | likelihood_data_loader.create_batches(eval_file) 124 | test_loss = target_loss(sess, target_lstm, likelihood_data_loader) 125 | print('pre-train epoch ', epoch, 'test_loss ', test_loss) 126 | buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n' 127 | log.write(buffer) 128 | print('Start pre-training discriminator...') 129 | # Train 3 epoch on the generated data and do this for 50 times 130 | for idx in range(50): 131 | generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file) 132 | dis_data_loader.load_train_data(positive_file, negative_file) 133 | for _ in range(3): 134 | dis_data_loader.reset_pointer() 135 | for it in range(dis_data_loader.num_batch): 136 | x_batch, y_batch, ref_batch = dis_data_loader.next_batch() 137 | feed = { 138 | discriminator.input_x: x_batch, 139 | discriminator.input_y: y_batch, 140 | discriminator.input_ref: ref_batch, 141 | discriminator.dropout_keep_prob: dis_dropout_keep_prob 142 | } 143 | _, loss, pos_vec, neg_vec = sess.run( 144 | [discriminator.train_op, discriminator.loss, discriminator.pos_vec, discriminator.neg_vec], feed) 145 | # print 'pos_vec:', np.sum(pos_vec), 'neg_vec:', np.sum(neg_vec) 146 | print('Pre-training discriminator epoch #%d, loss=%f' % (idx, loss)) 147 | 148 | rollout = ROLLOUT(generator, 0.8) 149 | 150 | print('#########################################################################') 151 | print('Start Adversarial Training...') 152 | log.write('adversarial training...\n') 
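    # What the adversarial loop below does (see also rollout.py and discriminator.py):
    #   1. The generator samples a batch of sequences.
    #   2. For every prefix length, the ROLLOUT network (a slowly-updated copy of the
    #      generator, mixed with update_rate=0.8) completes the prefix rollout_num=16
    #      times; the ranking discriminator scores each completion against reference
    #      sentences drawn from the positive data.
    #   3. The scores, averaged over rollouts, become per-timestep rewards, and the
    #      generator takes one policy-gradient step on
    #      g_loss = -sum_t log pi(x_t | x_<t) * reward_t   (see generator.g_loss).
    #   4. The rollout parameters are then refreshed, and the discriminator is
    #      re-trained on fresh negative samples (5 rounds of 3 epochs each).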
153 | for total_batch in range(TOTAL_BATCH): 154 | # Train the generator for one step 155 | for it in range(1): 156 | samples = generator.generate(sess) 157 | generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file) 158 | dis_data_loader.load_train_data(positive_file, negative_file) 159 | rewards = rollout.get_reward(sess, samples, 16, discriminator, dis_data_loader) 160 | feed = {generator.x: samples, generator.rewards: rewards} 161 | _, loss = sess.run([generator.g_updates, generator.g_loss], feed_dict=feed) 162 | print('Training generator epoch #%d, loss=%f' % (total_batch, loss)) 163 | 164 | # Test 165 | if not model_settings.use_real_data: 166 | if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1: 167 | generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file) 168 | likelihood_data_loader.create_batches(eval_file) 169 | test_loss = target_loss(sess, target_lstm, likelihood_data_loader) 170 | buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n' 171 | print('total_batch: ', total_batch, 'test_loss: ', test_loss) 172 | log.write(buffer) 173 | 174 | # Update roll-out parameters 175 | rollout.update_params() 176 | 177 | # Train the discriminator 178 | for idx in range(5): 179 | generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file) 180 | dis_data_loader.load_train_data(positive_file, negative_file) 181 | 182 | for _ in range(3): 183 | dis_data_loader.reset_pointer() 184 | for it in range(dis_data_loader.num_batch): 185 | x_batch, y_batch, ref_batch = dis_data_loader.next_batch() 186 | feed = { 187 | discriminator.input_x: x_batch, 188 | discriminator.input_y: y_batch, 189 | discriminator.input_ref: ref_batch, 190 | discriminator.dropout_keep_prob: dis_dropout_keep_prob 191 | } 192 | _, loss, pos_vec, neg_vec = sess.run([discriminator.train_op, discriminator.loss, discriminator.pos_vec, discriminator.neg_vec], feed) 193 | # print 'pos_vec:', np.sum(pos_vec), 'neg_vec:', np.sum(neg_vec) 194 | print('Training discriminator epoch #%d-%d, loss=%f' % (total_batch, idx, loss)) 195 | log.close() 196 | 197 | 198 | if __name__ == '__main__': 199 | main() 200 | -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import tensor_array_ops, control_flow_ops 3 | 4 | 5 | class Generator(object): 6 | def __init__(self, num_emb, batch_size, emb_dim, hidden_dim, 7 | sequence_length, start_token, 8 | learning_rate=0.01, reward_gamma=0.95): 9 | self.num_emb = num_emb 10 | self.batch_size = batch_size 11 | self.emb_dim = emb_dim 12 | self.hidden_dim = hidden_dim 13 | self.sequence_length = sequence_length 14 | self.start_token = tf.constant([start_token] * self.batch_size, dtype=tf.int32) 15 | self.learning_rate = tf.Variable(float(learning_rate), trainable=False) 16 | self.reward_gamma = reward_gamma 17 | self.g_params = [] 18 | self.d_params = [] 19 | self.temperature = 1.0 20 | self.grad_clip = 5.0 21 | 22 | self.expected_reward = tf.Variable(tf.zeros([self.sequence_length])) 23 | 24 | tf.set_random_seed(666) 25 | 26 | with tf.variable_scope('generator'): 27 | self.g_embeddings = tf.Variable(self.init_matrix([self.num_emb, self.emb_dim])) 28 | self.g_params.append(self.g_embeddings) 29 | self.g_recurrent_unit = self.create_recurrent_unit(self.g_params) # maps h_tm1 to h_t for generator 30 | self.g_output_unit = 
self.create_output_unit(self.g_params) # maps h_t to o_t (output token logits) 31 | 32 | # placeholder definition 33 | self.x = tf.placeholder(tf.int32, shape=[self.batch_size, self.sequence_length]) # sequence of tokens generated by generator 34 | self.rewards = tf.placeholder(tf.float32, shape=[self.batch_size, self.sequence_length]) # get from rollout policy and discriminator 35 | 36 | # processed for batch 37 | with tf.device("/cpu:0"): 38 | self.processed_x = tf.transpose(tf.nn.embedding_lookup(self.g_embeddings, self.x), perm=[1, 0, 2]) # seq_length x batch_size x emb_dim 39 | 40 | # Initial states 41 | self.h0 = tf.zeros([self.batch_size, self.hidden_dim]) 42 | self.h0 = tf.stack([self.h0, self.h0]) 43 | 44 | gen_o = tensor_array_ops.TensorArray(dtype=tf.float32, size=self.sequence_length, 45 | dynamic_size=False, infer_shape=True) 46 | gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length, 47 | dynamic_size=False, infer_shape=True) 48 | 49 | def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x): 50 | h_t = self.g_recurrent_unit(x_t, h_tm1) # hidden_memory_tuple 51 | o_t = self.g_output_unit(h_t) # batch x vocab , logits not prob 52 | log_prob = tf.log(tf.nn.softmax(o_t)) 53 | next_token = tf.cast(tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]), tf.int32) 54 | x_tp1 = tf.nn.embedding_lookup(self.g_embeddings, next_token) # batch x emb_dim 55 | gen_o = gen_o.write(i, tf.reduce_sum(tf.multiply(tf.one_hot(next_token, self.num_emb, 1.0, 0.0), 56 | tf.nn.softmax(o_t)), 1)) # [batch_size] , prob 57 | gen_x = gen_x.write(i, next_token) # indices, batch_size 58 | return i + 1, x_tp1, h_t, gen_o, gen_x 59 | 60 | _, _, _, self.gen_o, self.gen_x = control_flow_ops.while_loop( 61 | cond=lambda i, _1, _2, _3, _4: i < self.sequence_length, 62 | body=_g_recurrence, 63 | loop_vars=(tf.constant(0, dtype=tf.int32), 64 | tf.nn.embedding_lookup(self.g_embeddings, self.start_token), self.h0, gen_o, gen_x)) 65 | 66 | self.gen_x = self.gen_x.stack() # seq_length x batch_size 67 | self.gen_x = tf.transpose(self.gen_x, perm=[1, 0]) # batch_size x seq_length 68 | 69 | # supervised pretraining for generator 70 | g_predictions = tensor_array_ops.TensorArray( 71 | dtype=tf.float32, size=self.sequence_length, 72 | dynamic_size=False, infer_shape=True) 73 | 74 | ta_emb_x = tensor_array_ops.TensorArray( 75 | dtype=tf.float32, size=self.sequence_length) 76 | ta_emb_x = ta_emb_x.unstack(self.processed_x) 77 | 78 | def _pretrain_recurrence(i, x_t, h_tm1, g_predictions): 79 | h_t = self.g_recurrent_unit(x_t, h_tm1) 80 | o_t = self.g_output_unit(h_t) 81 | g_predictions = g_predictions.write(i, tf.nn.softmax(o_t)) # batch x vocab_size 82 | x_tp1 = ta_emb_x.read(i) 83 | return i + 1, x_tp1, h_t, g_predictions 84 | 85 | _, _, _, self.g_predictions = control_flow_ops.while_loop( 86 | cond=lambda i, _1, _2, _3: i < self.sequence_length, 87 | body=_pretrain_recurrence, 88 | loop_vars=(tf.constant(0, dtype=tf.int32), 89 | tf.nn.embedding_lookup(self.g_embeddings, self.start_token), 90 | self.h0, g_predictions)) 91 | 92 | self.g_predictions = tf.transpose(self.g_predictions.stack(), perm=[1, 0, 2]) # batch_size x seq_length x vocab_size 93 | 94 | # pretraining loss 95 | self.pretrain_loss = -tf.reduce_sum( 96 | tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log( 97 | tf.clip_by_value(tf.reshape(self.g_predictions, [-1, self.num_emb]), 1e-20, 1.0) 98 | ) 99 | ) / (self.sequence_length * self.batch_size) 100 | 101 | # training updates 102 | pretrain_opt = 
self.g_optimizer(self.learning_rate) 103 | 104 | self.pretrain_grad, _ = tf.clip_by_global_norm(tf.gradients(self.pretrain_loss, self.g_params), self.grad_clip) 105 | self.pretrain_updates = pretrain_opt.apply_gradients(zip(self.pretrain_grad, self.g_params)) 106 | 107 | ####################################################################################################### 108 | # Unsupervised Training 109 | ####################################################################################################### 110 | self.g_loss = -tf.reduce_sum( 111 | tf.reduce_sum( 112 | tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log( 113 | tf.clip_by_value(tf.reshape(self.g_predictions, [-1, self.num_emb]), 1e-20, 1.0) 114 | ), 1) * tf.reshape(self.rewards, [-1]) 115 | ) 116 | 117 | g_opt = self.g_optimizer(self.learning_rate) 118 | 119 | self.g_grad, _ = tf.clip_by_global_norm(tf.gradients(self.g_loss, self.g_params), self.grad_clip) 120 | self.g_updates = g_opt.apply_gradients(zip(self.g_grad, self.g_params)) 121 | 122 | def generate(self, sess): 123 | outputs = sess.run(self.gen_x) 124 | return outputs 125 | 126 | def pretrain_step(self, sess, x): 127 | outputs = sess.run([self.pretrain_updates, self.pretrain_loss], feed_dict={self.x: x}) 128 | return outputs 129 | 130 | def init_matrix(self, shape): 131 | return tf.random_normal(shape, stddev=0.1) 132 | 133 | def init_vector(self, shape): 134 | return tf.zeros(shape) 135 | 136 | def create_recurrent_unit(self, params): 137 | # Weights and Bias for input and hidden tensor 138 | self.Wi = tf.Variable(self.init_matrix([self.emb_dim, self.hidden_dim])) 139 | self.Ui = tf.Variable(self.init_matrix([self.hidden_dim, self.hidden_dim])) 140 | self.bi = tf.Variable(self.init_matrix([self.hidden_dim])) 141 | 142 | self.Wf = tf.Variable(self.init_matrix([self.emb_dim, self.hidden_dim])) 143 | self.Uf = tf.Variable(self.init_matrix([self.hidden_dim, self.hidden_dim])) 144 | self.bf = tf.Variable(self.init_matrix([self.hidden_dim])) 145 | 146 | self.Wog = tf.Variable(self.init_matrix([self.emb_dim, self.hidden_dim])) 147 | self.Uog = tf.Variable(self.init_matrix([self.hidden_dim, self.hidden_dim])) 148 | self.bog = tf.Variable(self.init_matrix([self.hidden_dim])) 149 | 150 | self.Wc = tf.Variable(self.init_matrix([self.emb_dim, self.hidden_dim])) 151 | self.Uc = tf.Variable(self.init_matrix([self.hidden_dim, self.hidden_dim])) 152 | self.bc = tf.Variable(self.init_matrix([self.hidden_dim])) 153 | params.extend([ 154 | self.Wi, self.Ui, self.bi, 155 | self.Wf, self.Uf, self.bf, 156 | self.Wog, self.Uog, self.bog, 157 | self.Wc, self.Uc, self.bc]) 158 | 159 | def unit(x, hidden_memory_tm1): 160 | previous_hidden_state, c_prev = tf.unstack(hidden_memory_tm1) 161 | 162 | # Input Gate 163 | i = tf.sigmoid( 164 | tf.matmul(x, self.Wi) + 165 | tf.matmul(previous_hidden_state, self.Ui) + self.bi 166 | ) 167 | 168 | # Forget Gate 169 | f = tf.sigmoid( 170 | tf.matmul(x, self.Wf) + 171 | tf.matmul(previous_hidden_state, self.Uf) + self.bf 172 | ) 173 | 174 | # Output Gate 175 | o = tf.sigmoid( 176 | tf.matmul(x, self.Wog) + 177 | tf.matmul(previous_hidden_state, self.Uog) + self.bog 178 | ) 179 | 180 | # New Memory Cell 181 | c_ = tf.nn.tanh( 182 | tf.matmul(x, self.Wc) + 183 | tf.matmul(previous_hidden_state, self.Uc) + self.bc 184 | ) 185 | 186 | # Final Memory cell 187 | c = f * c_prev + i * c_ 188 | 189 | # Current Hidden state 190 | current_hidden_state = o * tf.nn.tanh(c) 191 | 192 | return tf.stack([current_hidden_state, c]) 193 | 
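        # The closure above is a standard LSTM cell: input gate i, forget gate f,
        # output gate o, and candidate memory c_; the new cell state is
        # c = f * c_prev + i * c_ and the hidden state is h = o * tanh(c),
        # returned stacked as [h, c] to match the rest of the code.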
194 | return unit 195 | 196 | def create_output_unit(self, params): 197 | self.Wo = tf.Variable(self.init_matrix([self.hidden_dim, self.num_emb])) 198 | self.bo = tf.Variable(self.init_matrix([self.num_emb])) 199 | params.extend([self.Wo, self.bo]) 200 | 201 | def unit(hidden_memory_tuple): 202 | hidden_state, c_prev = tf.unstack(hidden_memory_tuple) 203 | # hidden_state : batch x hidden_dim 204 | logits = tf.matmul(hidden_state, self.Wo) + self.bo 205 | # output = tf.nn.softmax(logits) 206 | return logits 207 | 208 | return unit 209 | 210 | def g_optimizer(self, *args, **kwargs): 211 | return tf.train.AdamOptimizer(*args, **kwargs) 212 | -------------------------------------------------------------------------------- /discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.python.ops import tensor_array_ops, control_flow_ops 4 | 5 | 6 | # An alternative to tf.nn.rnn_cell._linear function, which has been removed in Tensorfow 1.0.1 7 | # The highway layer is borrowed from https://github.com/mkroutikov/tf-lstm-char-cnn 8 | def linear(input_, output_size, scope=None): 9 | ''' 10 | Linear map: output[k] = sum_i(Matrix[k, i] * input_[i] ) + Bias[k] 11 | Args: 12 | input_: a tensor or a list of 2D, batch x n, Tensors. 13 | output_size: int, second dimension of W[i]. 14 | scope: VariableScope for the created subgraph; defaults to "Linear". 15 | Returns: 16 | A 2D Tensor with shape [batch x output_size] equal to 17 | sum_i(input_[i] * W[i]), where W[i]s are newly created matrices. 18 | Raises: 19 | ValueError: if some of the arguments has unspecified or wrong shape. 20 | ''' 21 | 22 | shape = input_.get_shape().as_list() 23 | if len(shape) != 2: 24 | raise ValueError("Linear is expecting 2D arguments: %s" % str(shape)) 25 | if not shape[1]: 26 | raise ValueError("Linear expects shape[1] of arguments: %s" % str(shape)) 27 | input_size = shape[1] 28 | 29 | # Now the computation. 30 | with tf.variable_scope(scope or "SimpleLinear", reuse=tf.AUTO_REUSE): 31 | matrix = tf.get_variable("Matrix", [output_size, input_size], dtype=input_.dtype) 32 | bias_term = tf.get_variable("Bias", [output_size], dtype=input_.dtype) 33 | 34 | return tf.matmul(input_, tf.transpose(matrix)) + bias_term 35 | 36 | 37 | def highway(input_, size, num_layers=1, bias=-2.0, f=tf.nn.relu, scope='Highway'): 38 | """Highway Network (cf. http://arxiv.org/abs/1505.00387). 39 | t = sigmoid(Wy + b) 40 | z = t * g(Wy + b) + (1 - t) * y 41 | where g is nonlinearity, t is transform gate, and (1 - t) is carry gate. 42 | """ 43 | 44 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): 45 | for idx in range(num_layers): 46 | g = f(linear(input_, size, scope='highway_lin_%d' % idx)) 47 | 48 | t = tf.sigmoid(linear(input_, size, scope='highway_gate_%d' % idx) + bias) 49 | 50 | output = t * g + (1. 
- t) * input_ 51 | input_ = output 52 | 53 | return output 54 | 55 | 56 | def cosine_distance(y_s, y_u, gamma=1.0): 57 | return gamma * tf.reduce_sum(y_s * y_u) / (tf.norm(y_s) * tf.norm(y_u)) 58 | 59 | 60 | def get_rank_score(emb_test, embs_ref): 61 | p = embs_ref.shape 62 | ref_size = p.as_list()[0] 63 | 64 | def _loop_body(i, ret_v, emb_test, embs_ref): 65 | return i + 1, ret_v + cosine_distance(emb_test, tf.nn.embedding_lookup(embs_ref, i)), emb_test, embs_ref 66 | 67 | _, ret, _, _ = control_flow_ops.while_loop( 68 | cond=lambda i, _1, _2, _3: i < ref_size, 69 | body=_loop_body, 70 | loop_vars=(tf.constant(0, dtype=tf.int32), tf.constant(0.0, dtype=tf.float32), emb_test, embs_ref) 71 | ) 72 | return ret / ref_size 73 | 74 | 75 | class Discriminator(object): 76 | """ 77 | A CNN for text classification. 78 | Uses an embedding layer, followed by a convolutional, max-pooling and softmax layer. 79 | """ 80 | def __init__( 81 | self, sequence_length, num_classes, vocab_size, 82 | emd_dim, filter_sizes, num_filters, l2_reg_lambda=0.0, batch_size=32, reference_size=16, dropout_keep_prob = .75): 83 | # Placeholders for input, output and dropout 84 | self.input_x = tf.placeholder(tf.int32, [batch_size, sequence_length], name="input_x") 85 | self.input_ref = tf.placeholder(tf.int32, [reference_size, sequence_length], name="input_ref") 86 | self.input_y = tf.placeholder(tf.float32, [batch_size, num_classes], name="input_y") 87 | self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_dim") 88 | 89 | # Keeping track of l2 regularization loss (optional) 90 | l2_loss = tf.constant(0.0) 91 | 92 | with tf.variable_scope('discriminator'): 93 | 94 | # Embedding layer 95 | with tf.device('/cpu:0'), tf.name_scope("embedding"): 96 | self.W = tf.Variable( 97 | tf.random_uniform([vocab_size, emd_dim], -1.0, 1.0), 98 | name="W") 99 | self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x) 100 | self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1) 101 | self.embedded_chars_ref = tf.nn.embedding_lookup(self.W, self.input_ref) 102 | self.embedded_chars_expanded_ref = tf.expand_dims(self.embedded_chars_ref, -1) 103 | 104 | # Create a convolution + maxpool layer for each filter size 105 | pooled_outputs = [] 106 | pooled_outputs_ref = [] 107 | for filter_size, num_filter in zip(filter_sizes, num_filters): 108 | with tf.name_scope("conv-maxpool-%s" % filter_size): 109 | # Convolution Layer 110 | filter_shape = [filter_size, emd_dim, 1, num_filter] 111 | W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W") 112 | b = tf.Variable(tf.constant(0.1, shape=[num_filter]), name="b") 113 | conv = tf.nn.conv2d( 114 | self.embedded_chars_expanded, 115 | W, 116 | strides=[1, 1, 1, 1], 117 | padding="VALID", 118 | name="conv") 119 | conv_ref = tf.nn.conv2d( 120 | self.embedded_chars_expanded_ref, 121 | W, 122 | strides=[1, 1, 1, 1], 123 | padding="VALID", 124 | name="conv_ref" 125 | ) 126 | # Apply nonlinearity 127 | h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu") 128 | h_ref = tf.nn.relu(tf.nn.bias_add(conv_ref, b, name="relu_ref")) 129 | # Maxpooling over the outputs 130 | pooled = tf.nn.max_pool( 131 | h, 132 | ksize=[1, sequence_length - filter_size + 1, 1, 1], 133 | strides=[1, 1, 1, 1], 134 | padding='VALID', 135 | name="pool") 136 | pooled_ref = tf.nn.max_pool( 137 | h_ref, 138 | ksize=[1, sequence_length - filter_size + 1, 1, 1], 139 | strides=[1, 1, 1, 1], 140 | padding='VALID', 141 | name="pool_ref") 142 | pooled_outputs.append(pooled) 143 | 
pooled_outputs_ref.append(pooled_ref) 144 | 145 | # Combine all the pooled features 146 | num_filters_total = sum(num_filters) 147 | self.h_pool = tf.concat(pooled_outputs, 3) 148 | self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total]) 149 | 150 | self.h_pool_ref = tf.concat(pooled_outputs_ref, 3) 151 | self.h_pool_flat_ref = tf.reshape(self.h_pool_ref, [-1, num_filters_total]) 152 | 153 | # Add highway 154 | with tf.name_scope("highway"): 155 | self.h_highway = highway(self.h_pool_flat, self.h_pool_flat.get_shape()[1], 1, 0, scope="highway") 156 | self.h_highway_ref = highway(self.h_pool_flat_ref, self.h_pool_flat_ref.get_shape()[1], 1, 0, 157 | scope="highway") 158 | 159 | # Add dropout 160 | with tf.name_scope("dropout"): 161 | self.h_drop = tf.nn.dropout(self.h_highway, self.dropout_keep_prob) 162 | self.h_drop_ref = tf.nn.dropout(self.h_highway_ref, self.dropout_keep_prob) 163 | 164 | # Final (unnormalized) scores and predictions 165 | with tf.name_scope("output"): 166 | """ 167 | scores = tf.TensorArray(dtype=tf.float32, size=batch_size, dynamic_size=False, infer_shape=True) 168 | def rank_recurrence(i, scores): 169 | rank_score = get_rank_score(tf.nn.embedding_lookup(self.h_drop, i), self.h_drop_ref) 170 | scores = scores.write(i, rank_score) 171 | return i + 1, scores 172 | _, self.scores = control_flow_ops.while_loop( 173 | cond=lambda i, _1: i < batch_size, 174 | body=rank_recurrence, 175 | loop_vars=(tf.constant(0, dtype=tf.int32), scores) 176 | ) 177 | """ 178 | score = [] 179 | """ 180 | for i in range(batch_size): 181 | value = tf.constant(0.0, dtype=tf.float32) 182 | for j in range(reference_size): 183 | value += cosine_distance(tf.nn.embedding_lookup(self.h_drop, i), 184 | tf.nn.embedding_lookup(self.h_drop_ref, j)) 185 | score.append(value) 186 | self.scores = tf.stack(score) 187 | self.scores = tf.reshape(self.scores, [-1]) 188 | """ 189 | self.reference = tf.reduce_mean(tf.nn.l2_normalize(self.h_drop_ref, axis=-1), axis=0, keep_dims=True) 190 | self.feature = tf.nn.l2_normalize(self.h_drop, axis=-1) 191 | self.scores = tf.reshape(self.feature @ tf.transpose(self.reference, perm=[1, 0]), [-1]) 192 | self.ypred_for_auc = tf.reshape(tf.nn.softmax(self.scores), [-1]) 193 | self.log_score = tf.log(self.ypred_for_auc) 194 | 195 | # CalculateMean cross-entropy loss 196 | with tf.name_scope("loss"): 197 | self.neg_vec = tf.nn.embedding_lookup(tf.transpose(self.input_y), 1) 198 | self.pos_vec = tf.nn.embedding_lookup(tf.transpose(self.input_y), 0) 199 | losses_minus = self.log_score * self.neg_vec 200 | losses_posit = self.log_score * self.pos_vec 201 | self.loss = (- tf.reduce_sum(losses_minus) / tf.maximum(tf.reduce_sum(self.neg_vec), 1e-5) + tf.reduce_sum( 202 | losses_posit) / tf.maximum(tf.reduce_sum(self.pos_vec), 1e-5)) / reference_size 203 | 204 | self.params = [param for param in tf.trainable_variables() if 'discriminator' in param.name] 205 | d_optimizer = tf.train.AdamOptimizer(1e-4) 206 | grads_and_vars = d_optimizer.compute_gradients(self.loss, self.params, aggregation_method=2) 207 | self.train_op = d_optimizer.apply_gradients(grads_and_vars) 208 | 209 | -------------------------------------------------------------------------------- /rollout.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import tensor_array_ops, control_flow_ops 3 | import numpy as np 4 | import model_settings 5 | 6 | 7 | class ROLLOUT(object): 8 | def __init__(self, lstm, update_rate): 9 
| self.lstm = lstm 10 | self.update_rate = update_rate 11 | 12 | self.num_emb = self.lstm.num_emb 13 | self.batch_size = self.lstm.batch_size 14 | self.emb_dim = self.lstm.emb_dim 15 | self.hidden_dim = self.lstm.hidden_dim 16 | self.sequence_length = self.lstm.sequence_length 17 | self.start_token = tf.identity(self.lstm.start_token) 18 | self.learning_rate = self.lstm.learning_rate 19 | 20 | self.g_embeddings = tf.identity(self.lstm.g_embeddings) 21 | self.g_recurrent_unit = self.create_recurrent_unit() # maps h_tm1 to h_t for generator 22 | self.g_output_unit = self.create_output_unit() # maps h_t to o_t (output token logits) 23 | 24 | ##################################################################################################### 25 | # placeholder definition 26 | self.x = tf.placeholder(tf.int32, shape=[self.batch_size, self.sequence_length]) # sequence of tokens generated by generator 27 | self.given_num = tf.placeholder(tf.int32) 28 | 29 | # processed for batch 30 | with tf.device("/cpu:0"): 31 | self.processed_x = tf.transpose(tf.nn.embedding_lookup(self.g_embeddings, self.x), perm=[1, 0, 2]) # seq_length x batch_size x emb_dim 32 | 33 | ta_emb_x = tensor_array_ops.TensorArray( 34 | dtype=tf.float32, size=self.sequence_length) 35 | ta_emb_x = ta_emb_x.unstack(self.processed_x) 36 | 37 | ta_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length) 38 | ta_x = ta_x.unstack(tf.transpose(self.x, perm=[1, 0])) 39 | ##################################################################################################### 40 | 41 | self.h0 = tf.zeros([self.batch_size, self.hidden_dim]) 42 | self.h0 = tf.stack([self.h0, self.h0]) 43 | 44 | gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length, 45 | dynamic_size=False, infer_shape=True) 46 | 47 | # When current index i < given_num, use the provided tokens as the input at each time step 48 | def _g_recurrence_1(i, x_t, h_tm1, given_num, gen_x): 49 | h_t = self.g_recurrent_unit(x_t, h_tm1) # hidden_memory_tuple 50 | x_tp1 = ta_emb_x.read(i) 51 | gen_x = gen_x.write(i, ta_x.read(i)) 52 | return i + 1, x_tp1, h_t, given_num, gen_x 53 | 54 | # When current index i >= given_num, start roll-out, use the output as time step t as the input at time step t+1 55 | def _g_recurrence_2(i, x_t, h_tm1, given_num, gen_x): 56 | h_t = self.g_recurrent_unit(x_t, h_tm1) # hidden_memory_tuple 57 | o_t = self.g_output_unit(h_t) # batch x vocab , logits not prob 58 | log_prob = tf.log(tf.nn.softmax(o_t)) 59 | next_token = tf.cast(tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]), tf.int32) 60 | x_tp1 = tf.nn.embedding_lookup(self.g_embeddings, next_token) # batch x emb_dim 61 | gen_x = gen_x.write(i, next_token) # indices, batch_size 62 | return i + 1, x_tp1, h_t, given_num, gen_x 63 | 64 | i, x_t, h_tm1, given_num, self.gen_x = control_flow_ops.while_loop( 65 | cond=lambda i, _1, _2, given_num, _4: i < given_num, 66 | body=_g_recurrence_1, 67 | loop_vars=(tf.constant(0, dtype=tf.int32), 68 | tf.nn.embedding_lookup(self.g_embeddings, self.start_token), self.h0, self.given_num, gen_x)) 69 | 70 | _, _, _, _, self.gen_x = control_flow_ops.while_loop( 71 | cond=lambda i, _1, _2, _3, _4: i < self.sequence_length, 72 | body=_g_recurrence_2, 73 | loop_vars=(i, x_t, h_tm1, given_num, self.gen_x)) 74 | 75 | self.gen_x = self.gen_x.stack() # seq_length x batch_size 76 | self.gen_x = tf.transpose(self.gen_x, perm=[1, 0]) # batch_size x seq_length 77 | 78 | def get_reward(self, sess, input_x, rollout_num, discriminator, 
dis_data_loader): 79 | rewards = [] 80 | for i in range(rollout_num): 81 | for given_num in range(1, model_settings.seq_len): 82 | feed = {self.x: input_x, self.given_num: given_num} 83 | samples = sess.run(self.gen_x, feed) 84 | feed = {discriminator.input_x: samples, discriminator.dropout_keep_prob: 1.0, discriminator.input_ref: dis_data_loader.get_reference()} 85 | scores = sess.run(discriminator.ypred_for_auc, feed) 86 | ypred = np.array([item for item in scores]) 87 | if i == 0: 88 | rewards.append(ypred) 89 | else: 90 | rewards[given_num - 1] += ypred 91 | 92 | # the last token reward 93 | feed = {discriminator.input_x: input_x, discriminator.dropout_keep_prob: 1.0, discriminator.input_ref: dis_data_loader.get_reference()} 94 | scores = sess.run(discriminator.ypred_for_auc, feed) 95 | ypred = np.array([item for item in scores]) 96 | if i == 0: 97 | rewards.append(ypred) 98 | else: 99 | rewards[model_settings.seq_len - 1] += ypred 100 | 101 | rewards = np.transpose(np.array(rewards)) / (1.0 * rollout_num) # batch_size x seq_length 102 | return rewards 103 | 104 | def create_recurrent_unit(self): 105 | # Weights and Bias for input and hidden tensor 106 | self.Wi = tf.identity(self.lstm.Wi) 107 | self.Ui = tf.identity(self.lstm.Ui) 108 | self.bi = tf.identity(self.lstm.bi) 109 | 110 | self.Wf = tf.identity(self.lstm.Wf) 111 | self.Uf = tf.identity(self.lstm.Uf) 112 | self.bf = tf.identity(self.lstm.bf) 113 | 114 | self.Wog = tf.identity(self.lstm.Wog) 115 | self.Uog = tf.identity(self.lstm.Uog) 116 | self.bog = tf.identity(self.lstm.bog) 117 | 118 | self.Wc = tf.identity(self.lstm.Wc) 119 | self.Uc = tf.identity(self.lstm.Uc) 120 | self.bc = tf.identity(self.lstm.bc) 121 | 122 | def unit(x, hidden_memory_tm1): 123 | previous_hidden_state, c_prev = tf.unstack(hidden_memory_tm1) 124 | 125 | # Input Gate 126 | i = tf.sigmoid( 127 | tf.matmul(x, self.Wi) + 128 | tf.matmul(previous_hidden_state, self.Ui) + self.bi 129 | ) 130 | 131 | # Forget Gate 132 | f = tf.sigmoid( 133 | tf.matmul(x, self.Wf) + 134 | tf.matmul(previous_hidden_state, self.Uf) + self.bf 135 | ) 136 | 137 | # Output Gate 138 | o = tf.sigmoid( 139 | tf.matmul(x, self.Wog) + 140 | tf.matmul(previous_hidden_state, self.Uog) + self.bog 141 | ) 142 | 143 | # New Memory Cell 144 | c_ = tf.nn.tanh( 145 | tf.matmul(x, self.Wc) + 146 | tf.matmul(previous_hidden_state, self.Uc) + self.bc 147 | ) 148 | 149 | # Final Memory cell 150 | c = f * c_prev + i * c_ 151 | 152 | # Current Hidden state 153 | current_hidden_state = o * tf.nn.tanh(c) 154 | 155 | return tf.stack([current_hidden_state, c]) 156 | 157 | return unit 158 | 159 | def update_recurrent_unit(self): 160 | # Weights and Bias for input and hidden tensor 161 | self.Wi = self.update_rate * self.Wi + (1 - self.update_rate) * tf.identity(self.lstm.Wi) 162 | self.Ui = self.update_rate * self.Ui + (1 - self.update_rate) * tf.identity(self.lstm.Ui) 163 | self.bi = self.update_rate * self.bi + (1 - self.update_rate) * tf.identity(self.lstm.bi) 164 | 165 | self.Wf = self.update_rate * self.Wf + (1 - self.update_rate) * tf.identity(self.lstm.Wf) 166 | self.Uf = self.update_rate * self.Uf + (1 - self.update_rate) * tf.identity(self.lstm.Uf) 167 | self.bf = self.update_rate * self.bf + (1 - self.update_rate) * tf.identity(self.lstm.bf) 168 | 169 | self.Wog = self.update_rate * self.Wog + (1 - self.update_rate) * tf.identity(self.lstm.Wog) 170 | self.Uog = self.update_rate * self.Uog + (1 - self.update_rate) * tf.identity(self.lstm.Uog) 171 | self.bog = self.update_rate * self.bog + (1 - 
self.update_rate) * tf.identity(self.lstm.bog) 172 | 173 | self.Wc = self.update_rate * self.Wc + (1 - self.update_rate) * tf.identity(self.lstm.Wc) 174 | self.Uc = self.update_rate * self.Uc + (1 - self.update_rate) * tf.identity(self.lstm.Uc) 175 | self.bc = self.update_rate * self.bc + (1 - self.update_rate) * tf.identity(self.lstm.bc) 176 | 177 | def unit(x, hidden_memory_tm1): 178 | previous_hidden_state, c_prev = tf.unstack(hidden_memory_tm1) 179 | 180 | # Input Gate 181 | i = tf.sigmoid( 182 | tf.matmul(x, self.Wi) + 183 | tf.matmul(previous_hidden_state, self.Ui) + self.bi 184 | ) 185 | 186 | # Forget Gate 187 | f = tf.sigmoid( 188 | tf.matmul(x, self.Wf) + 189 | tf.matmul(previous_hidden_state, self.Uf) + self.bf 190 | ) 191 | 192 | # Output Gate 193 | o = tf.sigmoid( 194 | tf.matmul(x, self.Wog) + 195 | tf.matmul(previous_hidden_state, self.Uog) + self.bog 196 | ) 197 | 198 | # New Memory Cell 199 | c_ = tf.nn.tanh( 200 | tf.matmul(x, self.Wc) + 201 | tf.matmul(previous_hidden_state, self.Uc) + self.bc 202 | ) 203 | 204 | # Final Memory cell 205 | c = f * c_prev + i * c_ 206 | 207 | # Current Hidden state 208 | current_hidden_state = o * tf.nn.tanh(c) 209 | 210 | return tf.stack([current_hidden_state, c]) 211 | 212 | return unit 213 | 214 | def create_output_unit(self): 215 | self.Wo = tf.identity(self.lstm.Wo) 216 | self.bo = tf.identity(self.lstm.bo) 217 | 218 | def unit(hidden_memory_tuple): 219 | hidden_state, c_prev = tf.unstack(hidden_memory_tuple) 220 | # hidden_state : batch x hidden_dim 221 | logits = tf.matmul(hidden_state, self.Wo) + self.bo 222 | # output = tf.nn.softmax(logits) 223 | return logits 224 | 225 | return unit 226 | 227 | def update_output_unit(self): 228 | self.Wo = self.update_rate * self.Wo + (1 - self.update_rate) * tf.identity(self.lstm.Wo) 229 | self.bo = self.update_rate * self.bo + (1 - self.update_rate) * tf.identity(self.lstm.bo) 230 | 231 | def unit(hidden_memory_tuple): 232 | hidden_state, c_prev = tf.unstack(hidden_memory_tuple) 233 | # hidden_state : batch x hidden_dim 234 | logits = tf.matmul(hidden_state, self.Wo) + self.bo 235 | # output = tf.nn.softmax(logits) 236 | return logits 237 | 238 | return unit 239 | 240 | def update_params(self): 241 | self.g_embeddings = tf.identity(self.lstm.g_embeddings) 242 | self.g_recurrent_unit = self.update_recurrent_unit() 243 | self.g_output_unit = self.update_output_unit() 244 | --------------------------------------------------------------------------------
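For clarity, the ranking score that `rollout.get_reward` reads back as `discriminator.ypred_for_auc` reduces to a cosine-style similarity between each candidate's feature vector and the mean of the l2-normalized reference features, followed by a softmax over the batch. A minimal NumPy sketch of that computation, with `features` and `ref_features` as stand-ins for the `h_drop` and `h_drop_ref` tensors in `discriminator.py`:
```
import numpy as np

def rank_scores(features, ref_features):
    # features:     (batch_size, num_filters_total) candidate feature vectors
    # ref_features: (ref_size,   num_filters_total) reference feature vectors
    ref = ref_features / np.linalg.norm(ref_features, axis=-1, keepdims=True)
    ref = ref.mean(axis=0, keepdims=True)      # mean of the normalized references
    feat = features / np.linalg.norm(features, axis=-1, keepdims=True)
    scores = (feat @ ref.T).reshape(-1)        # similarity to the reference set
    e = np.exp(scores - scores.max())          # softmax over the batch gives the
    return e / e.sum()                         # relative ranking used as reward
```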