├── LICENSE
├── README.md
├── configuration.py
├── dataloader.py
├── discriminator.py
├── generator.py
├── misc
│   ├── README.md
│   └── seqgan.png
├── rollout.py
├── save
│   ├── eval_file.txt
│   ├── experiment-log.txt
│   ├── generator_sample.txt
│   ├── real_data.txt
│   └── target_params.pkl
├── target_lstm.py
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 ChenChengKuan
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SeqGAN_tensorflow
2 | 
3 | This code reproduces the results of the synthetic data experiments in "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient" (Yu et al.). It replaces the original tensor-array implementation with higher-level TensorFlow APIs for better flexibility.
4 | 
5 | ## Introduction
6 | The basic idea of SeqGAN is to regard the sequence generator as an agent in reinforcement learning. To train this agent, the REINFORCE algorithm (Williams, 1992) is applied to the generator, while a discriminator is trained to provide the reward. To estimate the reward of a partially generated sequence, Monte Carlo sampling is used to roll out the unfinished sequence to its full length.
7 | ![seqgan](https://github.com/ChenChengKuan/SeqGAN_tensorflow/blob/master/misc/seqgan.png)
8 | 
9 | Some works based on the training method used in SeqGAN:
10 | * Recurrent Topic-Transition GAN for Visual Paragraph Generation (Liang et al., ICCV 2017)
11 | * Towards Diverse and Natural Image Descriptions via a Conditional GAN (Dai et al., ICCV 2017)
12 | * Show, Adapt and Tell: Adversarial Training of Cross-domain Image Captioner (Chen et al., ICCV 2017)
13 | * Adversarial Ranking for Language Generation (Lin et al., NIPS 2017)
14 | * Long Text Generation via Adversarial Training with Leaked Information (Guo et al., AAAI 2018)
15 | 
16 | ## Prerequisites
17 | * Python 2.7
18 | * TensorFlow 1.3
19 | ## Run the code
20 | Simply running `python train.py` will start the training process. It will first pretrain the generator and the discriminator and then start adversarial training.
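For intuition, the per-step reward used in adversarial training can be sketched in plain NumPy as below. This is only a simplified illustration of the Monte Carlo rollout estimate described in the Introduction, not the actual implementation; `discriminator_prob_real` and `rollout_completion` are hypothetical stand-ins for the TensorFlow graphs built in `discriminator.py` and `rollout.py`.

```python
import numpy as np

# Hypothetical stand-ins for illustration only; the real project builds these
# as TensorFlow ops in discriminator.py and rollout.py.
def discriminator_prob_real(sequences):
    # Probability that each sequence is real; random numbers here for illustration.
    return np.random.rand(len(sequences))

def rollout_completion(prefix, sequence_length, vocab_size):
    # Complete a partial sequence with sampled tokens (Monte Carlo rollout).
    tail = np.random.randint(vocab_size, size=sequence_length - len(prefix))
    return np.concatenate([prefix, tail])

def estimate_rewards(sample, sequence_length=20, vocab_size=5000, rollout_num=16):
    # Reward for timestep t: average discriminator score of rollout_num
    # completions of the prefix sample[:t]; the full sequence is scored directly.
    rewards = np.zeros(sequence_length)
    for t in range(1, sequence_length + 1):
        if t < sequence_length:
            completions = [rollout_completion(sample[:t], sequence_length, vocab_size)
                           for _ in range(rollout_num)]
            rewards[t - 1] = discriminator_prob_real(completions).mean()
        else:
            rewards[t - 1] = discriminator_prob_real([sample])[0]
    return rewards

print(estimate_rewards(np.random.randint(5000, size=20)))
```

Each prefix of a generated sequence is completed `rollout_num` times and scored by the discriminator; the averaged probability of being real becomes the reward for that timestep in the policy-gradient update performed in `train.py`.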
21 | 
22 | ## Results
23 | The output in `save/experiment-log.txt` should look similar to the log below, which is close to the result reported in the [original implementation](https://github.com/LantaoYu/SeqGAN):
24 | ```
25 | pre-training...
26 | epoch: 0 nll: 10.1971
27 | epoch: 5 nll: 9.4694
28 | epoch: 10 nll: 9.2169
29 | epoch: 15 nll: 9.17986
30 | epoch: 20 nll: 9.16206
31 | epoch: 25 nll: 9.1344
32 | epoch: 30 nll: 9.12127
33 | epoch: 35 nll: 9.0948
34 | epoch: 40 nll: 9.10186
35 | epoch: 45 nll: 9.10108
36 | epoch: 50 nll: 9.0971
37 | epoch: 55 nll: 9.11246
38 | epoch: 60 nll: 9.1182
39 | epoch: 65 nll: 9.10095
40 | epoch: 70 nll: 9.09244
41 | epoch: 75 nll: 9.08816
42 | epoch: 80 nll: 9.10319
43 | epoch: 85 nll: 9.08916
44 | epoch: 90 nll: 9.08348
45 | epoch: 95 nll: 9.09661
46 | epoch: 100 nll: 9.10361
47 | epoch: 105 nll: 9.11718
48 | epoch: 110 nll: 9.10492
49 | epoch: 115 nll: 9.1038
50 | adversarial training...
51 | epoch: 0 nll: 9.09558
52 | epoch: 5 nll: 9.03083
53 | epoch: 10 nll: 8.96725
54 | epoch: 15 nll: 8.91415
55 | epoch: 20 nll: 8.87554
56 | epoch: 25 nll: 8.82305
57 | epoch: 30 nll: 8.76805
58 | epoch: 35 nll: 8.73597
59 | epoch: 40 nll: 8.71933
60 | epoch: 45 nll: 8.71653
61 | epoch: 50 nll: 8.71746
62 | epoch: 55 nll: 8.7036
63 | epoch: 60 nll: 8.68666
64 | epoch: 65 nll: 8.68931
65 | epoch: 70 nll: 8.68588
66 | epoch: 75 nll: 8.69977
67 | epoch: 80 nll: 8.69636
68 | epoch: 85 nll: 8.69916
69 | epoch: 90 nll: 8.6969
70 | epoch: 95 nll: 8.71021
71 | epoch: 100 nll: 8.72561
72 | epoch: 105 nll: 8.71369
73 | epoch: 110 nll: 8.71723
74 | epoch: 115 nll: 8.72388
75 | epoch: 120 nll: 8.71293
76 | epoch: 125 nll: 8.70667
77 | epoch: 130 nll: 8.70341
78 | epoch: 135 nll: 8.69929
79 | epoch: 140 nll: 8.69793
80 | epoch: 145 nll: 8.67705
81 | epoch: 150 nll: 8.65372
82 | ```
83 | Note: Part of this code (dataloader, discriminator, target LSTM) is based on the [original implementation by Lantao Yu](https://github.com/LantaoYu/SeqGAN). Many thanks for his code.
84 | 
--------------------------------------------------------------------------------
/configuration.py:
--------------------------------------------------------------------------------
1 | class generator_config(object):
2 |     """Wrapper class for generator hyperparameters"""
3 | 
4 |     def __init__(self):
5 |         self.emb_dim = 32 #dimension of embedding
6 |         self.num_emb = 5000 #output vocabulary size
7 |         self.hidden_dim = 32 #dimension of hidden unit
8 |         self.sequence_length = 20 #maximum input sequence length
9 |         self.gen_batch_size = 64 #batch size of generator
10 |         self.start_token = 0 #special token for start of sentence
11 | 
12 | 
13 | class discriminator_config(object):
14 |     """Wrapper class for discriminator hyperparameters"""
15 | 
16 |     def __init__(self):
17 |         self.sequence_length = 20 #maximum input sequence length
18 |         self.num_classes = 2 #number of classes (real and fake)
19 |         self.vocab_size = 5000 #vocabulary size, should be the same as num_emb
20 |         self.dis_embedding_dim = 64 #dimension of discriminator embedding space
21 |         self.dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20] #convolutional kernel sizes of discriminator
22 |         self.dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160] #number of filters for each conv.
kernel 23 | self.dis_dropout_keep_prob = 0.75 # dropout rate of discriminator 24 | self.dis_l2_reg_lambda = 0.2 #L2 regularization strength 25 | self.dis_batch_size = 64 #Batch size for discriminator 26 | self.dis_learning_rate = 1e-4 #Learning rate of discriminator 27 | 28 | 29 | class training_config(object): 30 | """Wrapper class for parameters for training""" 31 | 32 | def __init__(self): 33 | self.gen_learning_rate = 0.01 #learning rate of generator 34 | self.gen_update_time = 1 #update times of generator in adversarial training 35 | self.dis_update_time_adv = 5 #update times of discriminator in adversarial training 36 | self.dis_update_epoch_adv = 3 #update epoch / times of discriminator 37 | self.dis_update_time_pre = 50 #pretraining times of discriminator 38 | self.dis_update_epoch_pre = 3 #number of epoch / time in pretraining 39 | self.pretrained_epoch_num = 120 #Number of pretraining epoch 40 | self.rollout_num = 16 #Rollout number for reward estimation 41 | self.test_per_epoch = 5 #Test the NLL per epoch 42 | self.batch_size = 64 #Batch size used for training 43 | self.save_pretrained = 120 # Whether to save model in certain epoch (optional) 44 | self.grad_clip = 5.0 #Gradient Clipping 45 | self.seed = 88 #Random seed used for initialization 46 | self.start_token = 0 #special start token 47 | self.total_batch = 200 #total batch used for adversarial training 48 | self.positive_file = "save/real_data.txt" # save path of real data generated by target LSTM 49 | self.negative_file = "save/generator_sample.txt" #save path of fake data generated by generator 50 | self.eval_file = "save/eval_file.txt" #file used for evaluation 51 | self.generated_num = 10000 #Number of samples from generator used for evaluation 52 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Gen_Data_loader(): 5 | def __init__(self, batch_size): 6 | self.batch_size = batch_size 7 | self.token_stream = [] 8 | 9 | def create_batches(self, data_file): 10 | self.token_stream = [] 11 | with open(data_file, 'r') as f: 12 | for line in f: 13 | line = line.strip() 14 | line = line.split() 15 | parse_line = [int(x) for x in line] 16 | if len(parse_line) == 20: 17 | self.token_stream.append(parse_line) 18 | 19 | self.num_batch = int(len(self.token_stream) / self.batch_size) 20 | self.token_stream = self.token_stream[:self.num_batch * self.batch_size] 21 | self.sequence_batch = np.split(np.array(self.token_stream), self.num_batch, 0) 22 | self.pointer = 0 23 | 24 | def next_batch(self): 25 | ret = self.sequence_batch[self.pointer] 26 | self.pointer = (self.pointer + 1) % self.num_batch 27 | return ret 28 | 29 | def reset_pointer(self): 30 | self.pointer = 0 31 | 32 | 33 | class Dis_dataloader(): 34 | def __init__(self, batch_size): 35 | self.batch_size = batch_size 36 | self.sentences = np.array([]) 37 | self.labels = np.array([]) 38 | 39 | def load_train_data(self, positive_file, negative_file): 40 | # Load data 41 | positive_examples = [] 42 | negative_examples = [] 43 | with open(positive_file)as fin: 44 | for line in fin: 45 | line = line.strip() 46 | line = line.split() 47 | parse_line = [int(x) for x in line] 48 | positive_examples.append(parse_line) 49 | with open(negative_file)as fin: 50 | for line in fin: 51 | line = line.strip() 52 | line = line.split() 53 | parse_line = [int(x) for x in line] 54 | if len(parse_line) == 20: 55 | 
negative_examples.append(parse_line) 56 | self.sentences = np.array(positive_examples + negative_examples) 57 | 58 | # Generate labels 59 | positive_labels = [[0, 1] for _ in positive_examples] 60 | negative_labels = [[1, 0] for _ in negative_examples] 61 | self.labels = np.concatenate([positive_labels, negative_labels], 0) 62 | 63 | # Shuffle the data 64 | shuffle_indices = np.random.permutation(np.arange(len(self.labels))) 65 | self.sentences = self.sentences[shuffle_indices] 66 | self.labels = self.labels[shuffle_indices] 67 | 68 | # Split batches 69 | self.num_batch = int(len(self.labels) / self.batch_size) 70 | self.sentences = self.sentences[:self.num_batch * self.batch_size] 71 | self.labels = self.labels[:self.num_batch * self.batch_size] 72 | self.sentences_batches = np.split(self.sentences, self.num_batch, 0) 73 | self.labels_batches = np.split(self.labels, self.num_batch, 0) 74 | 75 | self.pointer = 0 76 | 77 | 78 | def next_batch(self): 79 | ret = self.sentences_batches[self.pointer], self.labels_batches[self.pointer] 80 | self.pointer = (self.pointer + 1) % self.num_batch 81 | return ret 82 | 83 | def reset_pointer(self): 84 | self.pointer = 0 85 | 86 | -------------------------------------------------------------------------------- /discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | # An alternative to tf.nn.rnn_cell._linear function, which has been removed in Tensorfow 1.0.1 5 | # The highway layer is borrowed from https://github.com/mkroutikov/tf-lstm-char-cnn 6 | def linear(input_, output_size, scope=None): 7 | ''' 8 | Linear map: output[k] = sum_i(Matrix[k, i] * input_[i] ) + Bias[k] 9 | Args: 10 | input_: a tensor or a list of 2D, batch x n, Tensors. 11 | output_size: int, second dimension of W[i]. 12 | scope: VariableScope for the created subgraph; defaults to "Linear". 13 | Returns: 14 | A 2D Tensor with shape [batch x output_size] equal to 15 | sum_i(input_[i] * W[i]), where W[i]s are newly created matrices. 16 | Raises: 17 | ValueError: if some of the arguments has unspecified or wrong shape. 18 | ''' 19 | 20 | shape = input_.get_shape().as_list() 21 | if len(shape) != 2: 22 | raise ValueError("Linear is expecting 2D arguments: %s" % str(shape)) 23 | if not shape[1]: 24 | raise ValueError("Linear expects shape[1] of arguments: %s" % str(shape)) 25 | input_size = shape[1] 26 | 27 | # Now the computation. 28 | with tf.variable_scope(scope or "SimpleLinear"): 29 | matrix = tf.get_variable("Matrix", [output_size, input_size], dtype=input_.dtype) 30 | bias_term = tf.get_variable("Bias", [output_size], dtype=input_.dtype) 31 | 32 | return tf.matmul(input_, tf.transpose(matrix)) + bias_term 33 | 34 | def highway(input_, size, num_layers=1, bias=-2.0, f=tf.nn.relu, scope='Highway'): 35 | """Highway Network (cf. http://arxiv.org/abs/1505.00387). 36 | t = sigmoid(Wy + b) 37 | z = t * g(Wy + b) + (1 - t) * y 38 | where g is nonlinearity, t is transform gate, and (1 - t) is carry gate. 39 | """ 40 | 41 | with tf.variable_scope(scope): 42 | for idx in range(num_layers): 43 | g = f(linear(input_, size, scope='highway_lin_%d' % idx)) 44 | 45 | t = tf.sigmoid(linear(input_, size, scope='highway_gate_%d' % idx) + bias) 46 | 47 | output = t * g + (1. - t) * input_ 48 | input_ = output 49 | 50 | return output 51 | 52 | class Discriminator(object): 53 | """ 54 | A CNN for text classification. 55 | Uses an embedding layer, followed by a convolutional, max-pooling and softmax layer. 
56 | """ 57 | 58 | def __init__(self, config): 59 | # Placeholders for input, output and dropout 60 | self.sequence_length = config.sequence_length 61 | self.num_classes = config.num_classes 62 | self.vocab_size = config.vocab_size 63 | self.filter_sizes = config.dis_filter_sizes 64 | self.num_filters = config.dis_num_filters 65 | self.vocab_size = config.vocab_size 66 | self.dis_learning_rate = config.dis_learning_rate 67 | self.embedding_size = config.dis_embedding_dim 68 | self.l2_reg_lambda = config.dis_l2_reg_lambda 69 | self.input_x = tf.placeholder(tf.int32, [None, self.sequence_length], name="input_x") 70 | self.input_y = tf.placeholder(tf.float32, [None, self.num_classes], name="input_y") 71 | self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob") 72 | # Keeping track of l2 regularization loss (optional) 73 | self.l2_loss = tf.constant(0.0) 74 | def build_discriminator(self): 75 | with tf.variable_scope('discriminator'): 76 | 77 | # Embedding layer 78 | with tf.device('/cpu:0'), tf.name_scope("embedding"): 79 | self.W = tf.Variable( 80 | tf.random_uniform([self.vocab_size, self.embedding_size], -1.0, 1.0), 81 | name="W") 82 | self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x) 83 | self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1) 84 | 85 | # Create a convolution + maxpool layer for each filter size 86 | pooled_outputs = [] 87 | for filter_size, num_filter in zip(self.filter_sizes, self.num_filters): 88 | with tf.name_scope("conv-maxpool-%s" % filter_size): 89 | # Convolution Layer 90 | filter_shape = [filter_size, self.embedding_size, 1, num_filter] 91 | W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W") 92 | b = tf.Variable(tf.constant(0.1, shape=[num_filter]), name="b") 93 | conv = tf.nn.conv2d( 94 | self.embedded_chars_expanded, 95 | W, 96 | strides=[1, 1, 1, 1], 97 | padding="VALID", 98 | name="conv") 99 | # Apply nonlinearity 100 | h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu") 101 | # Maxpooling over the outputs 102 | pooled = tf.nn.max_pool( 103 | h, 104 | ksize=[1, self.sequence_length - filter_size + 1, 1, 1], 105 | strides=[1, 1, 1, 1], 106 | padding='VALID', 107 | name="pool") 108 | pooled_outputs.append(pooled) 109 | 110 | # Combine all the pooled features 111 | num_filters_total = sum(self.num_filters) 112 | self.h_pool = tf.concat(pooled_outputs, 3) 113 | self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total]) 114 | 115 | # Add highway 116 | with tf.name_scope("highway"): 117 | self.h_highway = highway(self.h_pool_flat, self.h_pool_flat.get_shape()[1], 1, 0) 118 | 119 | # Add dropout 120 | with tf.name_scope("dropout"): 121 | self.h_drop = tf.nn.dropout(self.h_highway, self.dropout_keep_prob) 122 | 123 | # Final (unnormalized) scores and predictions 124 | with tf.name_scope("output"): 125 | W = tf.Variable(tf.truncated_normal([num_filters_total, self.num_classes], stddev=0.1), name="W") 126 | b = tf.Variable(tf.constant(0.1, shape=[self.num_classes]), name="b") 127 | self.l2_loss += tf.nn.l2_loss(W) 128 | self.l2_loss += tf.nn.l2_loss(b) 129 | self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name="scores") 130 | self.ypred_for_auc = tf.nn.softmax(self.scores) 131 | self.predictions = tf.argmax(self.scores, 1, name="predictions") 132 | 133 | # CalculateMean cross-entropy loss 134 | with tf.name_scope("loss"): 135 | losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.input_y) 136 | self.loss = tf.reduce_mean(losses) + self.l2_reg_lambda * 
self.l2_loss 137 | 138 | self.params = [param for param in tf.trainable_variables() if 'discriminator' in param.name] 139 | d_optimizer = tf.train.AdamOptimizer(self.dis_learning_rate) 140 | grads_and_vars = d_optimizer.compute_gradients(self.loss, self.params, aggregation_method=2) 141 | self.train_op = d_optimizer.apply_gradients(grads_and_vars) 142 | -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class Generator(object): 4 | """SeqGAN implementation based on https://arxiv.org/abs/1609.05473 5 | "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient" 6 | Lantao Yu, Weinan Zhang, Jun Wang, Yong Yu 7 | """ 8 | def __init__(self, config): 9 | """ Basic Set up 10 | 11 | Args: 12 | num_emb: output vocabulary size 13 | batch_size: batch size for generator 14 | emb_dim: LSTM hidden unit dimension 15 | sequence_length: maximum length of input sequence 16 | start_token: special token used to represent start of sentence 17 | initializer: initializer for LSTM kernel and output matrix 18 | """ 19 | self.num_emb = config.num_emb 20 | self.batch_size = config.gen_batch_size 21 | self.emb_dim = config.emb_dim 22 | self.hidden_dim = config.hidden_dim 23 | self.sequence_length = config.sequence_length 24 | self.start_token = tf.constant(config.start_token, dtype=tf.int32, shape=[self.batch_size]) 25 | self.initializer = tf.random_normal_initializer(stddev=0.1) 26 | 27 | def build_input(self, name): 28 | """ Buid input placeholder 29 | 30 | Input: 31 | name: name of network 32 | Output: 33 | self.input_seqs_pre (if name == pretrained) 34 | self.input_seqs_mask (if name == pretrained, optional mask for masking invalid token) 35 | self.input_seqs_adv (if name == 'adversarial') 36 | self.rewards (if name == 'adversarial') 37 | """ 38 | assert name in ['pretrain', 'adversarial', 'sample'] 39 | if name == 'pretrain': 40 | self.input_seqs_pre = tf.placeholder(tf.int32, [None, self.sequence_length], name="input_seqs_pre") 41 | self.input_seqs_mask = tf.placeholder(tf.float32, [None, self.sequence_length], name="input_seqs_mask") 42 | elif name == 'adversarial': 43 | self.input_seqs_adv = tf.placeholder(tf.int32, [None, self.sequence_length], name="input_seqs_adv") 44 | self.rewards = tf.placeholder(tf.float32, [None, self.sequence_length], name="reward") 45 | 46 | def build_pretrain_network(self): 47 | """ Buid pretrained network 48 | 49 | Input: 50 | self.input_seqs_pre 51 | self.input_seqs_mask 52 | Output: 53 | self.pretrained_loss 54 | self.pretrained_loss_sum (optional) 55 | """ 56 | self.build_input(name="pretrain") 57 | self.pretrained_loss = 0.0 58 | with tf.variable_scope("teller"): 59 | with tf.variable_scope("lstm"): 60 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.hidden_dim, state_is_tuple=True) 61 | with tf.device("/cpu:0"), tf.variable_scope("embedding"): 62 | word_emb_W = tf.get_variable("word_emb_W", [self.num_emb, self.emb_dim], "float32", self.initializer) 63 | with tf.variable_scope("output"): 64 | output_W = tf.get_variable("output_W", [self.emb_dim, self.num_emb], "float32", self.initializer) 65 | 66 | with tf.variable_scope("lstm"): 67 | for j in range(self.sequence_length): 68 | with tf.device("/cpu:0"): 69 | if j == 0: 70 | # 71 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, self.start_token) 72 | else: 73 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, self.input_seqs_pre[:, j-1]) 74 | if j == 0: 75 | state = 
lstm1.zero_state(self.batch_size, tf.float32) 76 | 77 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) 78 | 79 | logits = tf.matmul(output, output_W) 80 | # calculate loss 81 | pretrained_loss_t = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.input_seqs_pre[:,j], logits=logits) 82 | pretrained_loss_t = tf.reduce_sum(tf.multiply(pretrained_loss_t, self.input_seqs_mask[:,j])) 83 | self.pretrained_loss += pretrained_loss_t 84 | word_predict = tf.to_int32(tf.argmax(logits, 1)) 85 | self.pretrained_loss /= tf.reduce_sum(self.input_seqs_mask) 86 | self.pretrained_loss_sum = tf.summary.scalar("pretrained_loss",self.pretrained_loss) 87 | 88 | def build_adversarial_network(self): 89 | """ Buid adversarial training network 90 | 91 | Input: 92 | self.input_seqs_adv 93 | self.rewards 94 | Output: 95 | self.gen_loss_adv 96 | """ 97 | self.build_input(name="adversarial") 98 | self.softmax_list_reshape = [] 99 | self.softmax_list = [] 100 | with tf.variable_scope("teller"): 101 | tf.get_variable_scope().reuse_variables() 102 | with tf.variable_scope("lstm"): 103 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.hidden_dim, state_is_tuple=True) 104 | with tf.device("/cpu:0"), tf.variable_scope("embedding"): 105 | word_emb_W = tf.get_variable("word_emb_W", [self.num_emb, self.emb_dim], "float32", self.initializer) 106 | with tf.variable_scope("output"): 107 | output_W = tf.get_variable("output_W", [self.emb_dim, self.num_emb], "float32", self.initializer) 108 | with tf.variable_scope("lstm"): 109 | for j in range(self.sequence_length): 110 | tf.get_variable_scope().reuse_variables() 111 | with tf.device("/cpu:0"): 112 | if j == 0: 113 | # 114 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, self.start_token) 115 | else: 116 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, self.input_seqs_adv[:, j-1]) 117 | if j == 0: 118 | state = lstm1.zero_state(self.batch_size, tf.float32) 119 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) 120 | 121 | logits = tf.matmul(output, output_W) 122 | softmax = tf.nn.softmax(logits) 123 | self.softmax_list.append(softmax) 124 | self.softmax_list_reshape = tf.transpose(self.softmax_list, perm=[1, 0, 2]) 125 | self.gen_loss_adv = -tf.reduce_sum( 126 | tf.reduce_sum( 127 | tf.one_hot(tf.to_int32(tf.reshape(self.input_seqs_adv, [-1])), self.num_emb, 1.0, 0.0) * tf.log( 128 | tf.clip_by_value(tf.reshape(self.softmax_list_reshape, [-1, self.num_emb]), 1e-20, 1.0) 129 | ), 1) * tf.reshape(self.rewards, [-1])) 130 | def build_sample_network(self): 131 | """ Buid sampling network 132 | 133 | Output: 134 | self.sample_word_list_reshape 135 | """ 136 | self.build_input(name="sample") 137 | self.sample_word_list = [] 138 | with tf.variable_scope("teller"): 139 | tf.get_variable_scope().reuse_variables() 140 | with tf.variable_scope("lstm"): 141 | lstm1 = tf.nn.rnn_cell.LSTMCell(self.hidden_dim, state_is_tuple=True) 142 | with tf.device("/cpu:0"), tf.variable_scope("embedding"): 143 | word_emb_W = tf.get_variable("word_emb_W", [self.num_emb, self.emb_dim], "float32", self.initializer) 144 | with tf.variable_scope("output"): 145 | output_W = tf.get_variable("output_W", [self.emb_dim, self.num_emb], "float32", self.initializer) 146 | 147 | with tf.variable_scope("lstm"): 148 | for j in range(self.sequence_length): 149 | with tf.device("/cpu:0"): 150 | if j == 0: 151 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, self.start_token) 152 | else: 153 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, sample_word) 154 | if j == 0: 155 | state = 
lstm1.zero_state(self.batch_size, tf.float32) 156 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) 157 | logits = tf.matmul(output, output_W) 158 | logprob = tf.log(tf.nn.softmax(logits)) 159 | sample_word = tf.reshape(tf.to_int32(tf.multinomial(logprob, 1)), shape=[self.batch_size]) 160 | self.sample_word_list.append(sample_word) #sequence_length * batch_size 161 | self.sample_word_list_reshape = tf.transpose(tf.squeeze(tf.stack(self.sample_word_list)), perm=[1,0]) #batch_size * sequene_length 162 | def build(self): 163 | """Create all network for pretraining, adversairal training and sampling""" 164 | self.build_pretrain_network() 165 | self.build_adversarial_network() 166 | self.build_sample_network() 167 | def generate(self, sess): 168 | """Helper function for sample generation""" 169 | return sess.run(self.sample_word_list_reshape) 170 | -------------------------------------------------------------------------------- /misc/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /misc/seqgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenChengKuan/SeqGAN_tensorflow/50bc47f9188cdda5b613044de12dbade7b7bf83b/misc/seqgan.png -------------------------------------------------------------------------------- /rollout.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | class rollout(): 3 | """Rollout implementation for generator""" 4 | def __init__(self, config): 5 | #configuraiton setting 6 | self.sequence_length = config.sequence_length 7 | self.hidden_dim = config.hidden_dim 8 | self.num_emb = config.num_emb 9 | self.emb_dim = config.emb_dim 10 | self.batch_size = config.gen_batch_size 11 | self.start_token = config.start_token 12 | self.pred_seq = tf.placeholder(tf.int32, [None, self.sequence_length], name="pred_seq_rollout") 13 | self.sample_rollout_step = [] 14 | 15 | #Rollout graph initialization 16 | with tf.variable_scope("teller"): 17 | tf.get_variable_scope().reuse_variables() 18 | with tf.variable_scope("lstm"): 19 | lstm1 = tf.contrib.rnn.BasicLSTMCell(self.hidden_dim) 20 | with tf.device("/cpu:0"), tf.variable_scope("embedding"): 21 | word_emb_W = tf.get_variable("word_emb_W", [self.num_emb, self.emb_dim], tf.float32) 22 | with tf.variable_scope("output"): 23 | output_W = tf.get_variable("output_W", [self.emb_dim, self.num_emb], tf.float32) 24 | 25 | zero_state = lstm1.zero_state([self.batch_size], tf.float32) 26 | start_token = tf.constant(self.start_token, dtype=tf.int32, shape=[self.batch_size]) 27 | for step in range(1, self.sequence_length): 28 | if step % 5 == 0: 29 | print "Rollout step: {}".format(step) 30 | #Get the token for i < step 31 | sample_rollout_left = tf.reshape(self.pred_seq[:, 0:step], shape=[self.batch_size, step]) 32 | sample_rollout_rihgt = [] 33 | 34 | #Update the hidden state for i < step to prepare sampling token for i >= step 35 | for j in range(step): 36 | if j == 0: 37 | with tf.device("/cpu:0"): 38 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, start_token) 39 | else: 40 | tf.get_variable_scope().reuse_variables() 41 | with tf.device("/cpu:0"): 42 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, self.pred_seq[:, j-1]) 43 | with tf.variable_scope("lstm"): 44 | if j == 0: 45 | output, state = lstm1(lstm1_in, zero_state, scope=tf.get_variable_scope()) 46 | else: 47 | 
output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) 48 | #Sampling token for i >= step 49 | for j in range(step, self.sequence_length): 50 | with tf.device("/cpu:0"): 51 | if j == step: 52 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, self.pred_seq[:, j-1]) 53 | else: 54 | lstm1_in = tf.nn.embedding_lookup(word_emb_W, tf.stop_gradient(sample_word)) 55 | with tf.variable_scope("lstm"): 56 | output, state = lstm1(lstm1_in, state, scope=tf.get_variable_scope()) 57 | logits = tf.matmul(output, output_W) 58 | log_probs = tf.log(tf.nn.softmax(logits)+1e-8) #add a tolerance to prevent unmeaningful log value 59 | sample_word = tf.to_int32(tf.squeeze(tf.multinomial(log_probs, 1))) 60 | sample_rollout_rihgt.append(sample_word) 61 | sample_rollout_rihgt = tf.transpose(tf.stack(sample_rollout_rihgt)) 62 | sample_rollout = tf.concat([sample_rollout_left, sample_rollout_rihgt], axis=1) 63 | self.sample_rollout_step.append(sample_rollout) 64 | -------------------------------------------------------------------------------- /save/experiment-log.txt: -------------------------------------------------------------------------------- 1 | pre-training... 2 | epoch: 0 nll: 10.1971 3 | epoch: 5 nll: 9.4694 4 | epoch: 10 nll: 9.2169 5 | epoch: 15 nll: 9.17986 6 | epoch: 20 nll: 9.16206 7 | epoch: 25 nll: 9.1344 8 | epoch: 30 nll: 9.12127 9 | epoch: 35 nll: 9.0948 10 | epoch: 40 nll: 9.10186 11 | epoch: 45 nll: 9.10108 12 | epoch: 50 nll: 9.0971 13 | epoch: 55 nll: 9.11246 14 | epoch: 60 nll: 9.1182 15 | epoch: 65 nll: 9.10095 16 | epoch: 70 nll: 9.09244 17 | epoch: 75 nll: 9.08816 18 | epoch: 80 nll: 9.10319 19 | epoch: 85 nll: 9.08916 20 | epoch: 90 nll: 9.08348 21 | epoch: 95 nll: 9.09661 22 | epoch: 100 nll: 9.10361 23 | epoch: 105 nll: 9.11718 24 | epoch: 110 nll: 9.10492 25 | epoch: 115 nll: 9.1038 26 | epoch: 0 nll: 9.09558 27 | epoch: 5 nll: 9.03083 28 | epoch: 10 nll: 8.96725 29 | epoch: 15 nll: 8.91415 30 | epoch: 20 nll: 8.87554 31 | epoch: 25 nll: 8.82305 32 | epoch: 30 nll: 8.76805 33 | epoch: 35 nll: 8.73597 34 | epoch: 40 nll: 8.71933 35 | epoch: 45 nll: 8.71653 36 | epoch: 50 nll: 8.71746 37 | epoch: 55 nll: 8.7036 38 | epoch: 60 nll: 8.68666 39 | epoch: 65 nll: 8.68931 40 | epoch: 70 nll: 8.68588 41 | epoch: 75 nll: 8.69977 42 | epoch: 80 nll: 8.69636 43 | epoch: 85 nll: 8.69916 44 | epoch: 90 nll: 8.6969 45 | epoch: 95 nll: 8.71021 46 | epoch: 100 nll: 8.72561 47 | epoch: 105 nll: 8.71369 48 | epoch: 110 nll: 8.71723 49 | epoch: 115 nll: 8.72388 50 | epoch: 120 nll: 8.71293 51 | epoch: 125 nll: 8.70667 52 | epoch: 130 nll: 8.70341 53 | epoch: 135 nll: 8.69929 54 | epoch: 140 nll: 8.69793 55 | epoch: 145 nll: 8.67705 56 | epoch: 150 nll: 8.65372 57 | epoch: 155 nll: 8.66387 58 | epoch: 160 nll: 8.66754 59 | epoch: 165 nll: 8.64971 60 | epoch: 170 nll: 8.65363 61 | epoch: 175 nll: 8.64989 62 | epoch: 180 nll: 8.65539 63 | epoch: 185 nll: 8.66427 64 | epoch: 190 nll: 8.66863 65 | epoch: 195 nll: 8.67093 66 | epoch: 199 nll: 8.66977 67 | -------------------------------------------------------------------------------- /target_lstm.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import tensor_array_ops, control_flow_ops 3 | 4 | 5 | class TARGET_LSTM(object): 6 | def __init__(self, config, params): 7 | self.num_emb = config.num_emb 8 | self.batch_size = config.gen_batch_size 9 | self.emb_dim = config.emb_dim 10 | self.hidden_dim = config.hidden_dim 11 | self.sequence_length = 
config.sequence_length 12 | self.start_token = tf.constant([config.start_token] * self.batch_size, dtype=tf.int32) 13 | self.g_params = [] 14 | self.temperature = 1.0 15 | self.params = params 16 | 17 | tf.set_random_seed(66) 18 | 19 | with tf.variable_scope('generator'): 20 | self.g_embeddings = tf.Variable(self.params[0]) 21 | self.g_params.append(self.g_embeddings) 22 | self.g_recurrent_unit = self.create_recurrent_unit(self.g_params) # maps h_tm1 to h_t for generator 23 | self.g_output_unit = self.create_output_unit(self.g_params) # maps h_t to o_t (output token logits) 24 | 25 | # placeholder definition 26 | self.x = tf.placeholder(tf.int32, shape=[self.batch_size, self.sequence_length]) # sequence of tokens generated by generator 27 | 28 | # processed for batch 29 | with tf.device("/cpu:0"): 30 | self.processed_x = tf.transpose(tf.nn.embedding_lookup(self.g_embeddings, self.x), perm=[1, 0, 2]) # seq_length x batch_size x emb_dim 31 | 32 | # initial states 33 | self.h0 = tf.zeros([self.batch_size, self.hidden_dim]) 34 | self.h0 = tf.stack([self.h0, self.h0]) 35 | 36 | # generator on initial randomness 37 | gen_o = tensor_array_ops.TensorArray(dtype=tf.float32, size=self.sequence_length, 38 | dynamic_size=False, infer_shape=True) 39 | gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length, 40 | dynamic_size=False, infer_shape=True) 41 | 42 | def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x): 43 | h_t = self.g_recurrent_unit(x_t, h_tm1) # hidden_memory_tuple 44 | o_t = self.g_output_unit(h_t) # batch x vocab , logits not prob 45 | log_prob = tf.log(tf.nn.softmax(o_t)) 46 | next_token = tf.cast(tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]), tf.int32) 47 | x_tp1 = tf.nn.embedding_lookup(self.g_embeddings, next_token) # batch x emb_dim 48 | gen_o = gen_o.write(i, tf.reduce_sum(tf.multiply(tf.one_hot(next_token, self.num_emb, 1.0, 0.0), 49 | tf.nn.softmax(o_t)), 1)) # [batch_size] , prob 50 | gen_x = gen_x.write(i, next_token) # indices, batch_size 51 | return i + 1, x_tp1, h_t, gen_o, gen_x 52 | 53 | _, _, _, self.gen_o, self.gen_x = control_flow_ops.while_loop( 54 | cond=lambda i, _1, _2, _3, _4: i < self.sequence_length, 55 | body=_g_recurrence, 56 | loop_vars=(tf.constant(0, dtype=tf.int32), 57 | tf.nn.embedding_lookup(self.g_embeddings, self.start_token), self.h0, gen_o, gen_x) 58 | ) 59 | 60 | self.gen_x = self.gen_x.stack() # seq_length x batch_size 61 | self.gen_x = tf.transpose(self.gen_x, perm=[1, 0]) # batch_size x seq_length 62 | 63 | # supervised pretraining for generator 64 | g_predictions = tensor_array_ops.TensorArray( 65 | dtype=tf.float32, size=self.sequence_length, 66 | dynamic_size=False, infer_shape=True) 67 | 68 | ta_emb_x = tensor_array_ops.TensorArray( 69 | dtype=tf.float32, size=self.sequence_length) 70 | ta_emb_x = ta_emb_x.unstack(self.processed_x) 71 | 72 | def _pretrain_recurrence(i, x_t, h_tm1, g_predictions): 73 | h_t = self.g_recurrent_unit(x_t, h_tm1) 74 | o_t = self.g_output_unit(h_t) 75 | g_predictions = g_predictions.write(i, tf.nn.softmax(o_t)) # batch x vocab_size 76 | x_tp1 = ta_emb_x.read(i) 77 | return i + 1, x_tp1, h_t, g_predictions 78 | 79 | _, _, _, self.g_predictions = control_flow_ops.while_loop( 80 | cond=lambda i, _1, _2, _3: i < self.sequence_length, 81 | body=_pretrain_recurrence, 82 | loop_vars=(tf.constant(0, dtype=tf.int32), 83 | tf.nn.embedding_lookup(self.g_embeddings, self.start_token), 84 | self.h0, g_predictions)) 85 | 86 | self.g_predictions = tf.transpose( 87 | self.g_predictions.stack(), perm=[1, 
0, 2]) # batch_size x seq_length x vocab_size 88 | 89 | # pretraining loss 90 | self.pretrain_loss = -tf.reduce_sum( 91 | tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log( 92 | tf.reshape(self.g_predictions, [-1, self.num_emb]))) / (self.sequence_length * self.batch_size) 93 | 94 | self.out_loss = tf.reduce_sum( 95 | tf.reshape( 96 | -tf.reduce_sum( 97 | tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log( 98 | tf.reshape(self.g_predictions, [-1, self.num_emb])), 1 99 | ), [-1, self.sequence_length] 100 | ), 1 101 | ) # batch_size 102 | 103 | def generate(self, session): 104 | # h0 = np.random.normal(size=self.hidden_dim) 105 | outputs = session.run(self.gen_x) 106 | return outputs 107 | 108 | def init_matrix(self, shape): 109 | return tf.random_normal(shape, stddev=1.0) 110 | 111 | def create_recurrent_unit(self, params): 112 | # Weights and Bias for input and hidden tensor 113 | self.Wi = tf.Variable(self.params[1]) 114 | self.Ui = tf.Variable(self.params[2]) 115 | self.bi = tf.Variable(self.params[3]) 116 | 117 | self.Wf = tf.Variable(self.params[4]) 118 | self.Uf = tf.Variable(self.params[5]) 119 | self.bf = tf.Variable(self.params[6]) 120 | 121 | self.Wog = tf.Variable(self.params[7]) 122 | self.Uog = tf.Variable(self.params[8]) 123 | self.bog = tf.Variable(self.params[9]) 124 | 125 | self.Wc = tf.Variable(self.params[10]) 126 | self.Uc = tf.Variable(self.params[11]) 127 | self.bc = tf.Variable(self.params[12]) 128 | params.extend([ 129 | self.Wi, self.Ui, self.bi, 130 | self.Wf, self.Uf, self.bf, 131 | self.Wog, self.Uog, self.bog, 132 | self.Wc, self.Uc, self.bc]) 133 | 134 | def unit(x, hidden_memory_tm1): 135 | previous_hidden_state, c_prev = tf.unstack(hidden_memory_tm1) 136 | 137 | # Input Gate 138 | i = tf.sigmoid( 139 | tf.matmul(x, self.Wi) + 140 | tf.matmul(previous_hidden_state, self.Ui) + self.bi 141 | ) 142 | 143 | # Forget Gate 144 | f = tf.sigmoid( 145 | tf.matmul(x, self.Wf) + 146 | tf.matmul(previous_hidden_state, self.Uf) + self.bf 147 | ) 148 | 149 | # Output Gate 150 | o = tf.sigmoid( 151 | tf.matmul(x, self.Wog) + 152 | tf.matmul(previous_hidden_state, self.Uog) + self.bog 153 | ) 154 | 155 | # New Memory Cell 156 | c_ = tf.nn.tanh( 157 | tf.matmul(x, self.Wc) + 158 | tf.matmul(previous_hidden_state, self.Uc) + self.bc 159 | ) 160 | 161 | # Final Memory cell 162 | c = f * c_prev + i * c_ 163 | 164 | # Current Hidden state 165 | current_hidden_state = o * tf.nn.tanh(c) 166 | 167 | return tf.stack([current_hidden_state, c]) 168 | 169 | return unit 170 | 171 | def create_output_unit(self, params): 172 | self.Wo = tf.Variable(self.params[13]) 173 | self.bo = tf.Variable(self.params[14]) 174 | params.extend([self.Wo, self.bo]) 175 | 176 | def unit(hidden_memory_tuple): 177 | hidden_state, c_prev = tf.unstack(hidden_memory_tuple) 178 | # hidden_state : batch x hidden_dim 179 | logits = tf.matmul(hidden_state, self.Wo) + self.bo 180 | # output = tf.nn.softmax(logits) 181 | return logits 182 | 183 | return unit 184 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import cPickle 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from configuration import * 8 | from utils import * 9 | from dataloader import Gen_Data_loader, Dis_dataloader 10 | from discriminator import Discriminator 11 | from generator import Generator 12 | from rollout import 
rollout 13 | from target_lstm import TARGET_LSTM 14 | 15 | #Hardware related setting 16 | config_hardware = tf.ConfigProto() 17 | config_hardware.gpu_options.per_process_gpu_memory_fraction = 0.40 18 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 19 | 20 | def main(unused_argv): 21 | config_train = training_config() 22 | config_gen = generator_config() 23 | config_dis = discriminator_config() 24 | np.random.seed(config_train.seed) 25 | assert config_train.start_token == 0 26 | 27 | #Build dataloader for generaotr, testing and discriminator 28 | gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size) 29 | likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size) 30 | dis_data_loader = Dis_dataloader(config_dis.dis_batch_size) 31 | 32 | #Build generator and its rollout 33 | generator = Generator(config=config_gen) 34 | generator.build() 35 | rollout_gen = rollout(config=config_gen) 36 | 37 | #Build target LSTM 38 | target_params = cPickle.load(open('save/target_params.pkl')) 39 | target_lstm = TARGET_LSTM(config=config_gen, params=target_params) # The oracle model 40 | 41 | #Build discriminator 42 | discriminator = Discriminator(config=config_dis) 43 | discriminator.build_discriminator() 44 | 45 | #Build optimizer op for pretraining 46 | pretrained_optimizer = tf.train.AdamOptimizer(config_train.gen_learning_rate) 47 | var_pretrained = [v for v in tf.trainable_variables() if 'teller' in v.name] #Using name 'teller' here to prevent name collision of target LSTM 48 | gradients, variables = zip(*pretrained_optimizer.compute_gradients(generator.pretrained_loss, var_list=var_pretrained)) 49 | gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip) 50 | gen_pre_upate = pretrained_optimizer.apply_gradients(zip(gradients, variables)) 51 | 52 | #Initialize all variables 53 | sess = tf.Session(config=config_hardware) 54 | sess.run(tf.global_variables_initializer()) 55 | 56 | #Initalize data loader of generator 57 | generate_samples(sess, target_lstm, config_train.batch_size, config_train.generated_num, config_train.positive_file) 58 | gen_data_loader.create_batches(config_train.positive_file) 59 | 60 | #Start pretraining 61 | log = open('save/experiment-log.txt', 'w') 62 | print 'Start pre-training generator...' 63 | log.write('pre-training...\n') 64 | for epoch in xrange(config_train.pretrained_epoch_num): 65 | gen_data_loader.reset_pointer() 66 | for it in xrange(gen_data_loader.num_batch): 67 | batch = gen_data_loader.next_batch() 68 | _, g_loss = sess.run([gen_pre_upate, generator.pretrained_loss], feed_dict={generator.input_seqs_pre:batch,\ 69 | generator.input_seqs_mask:np.ones_like(batch)}) 70 | if epoch % config_train.test_per_epoch == 0: 71 | generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file) 72 | likelihood_data_loader.create_batches(config_train.eval_file) 73 | test_loss = target_loss(sess, target_lstm, likelihood_data_loader) 74 | print 'pre-train epoch ', epoch, 'test_loss ', test_loss 75 | buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n' 76 | log.write(buffer) 77 | 78 | print 'Start pre-training discriminator...' 
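# Descriptive note on the loop below (comment added for clarity): discriminator
# pretraining repeats dis_update_time_pre (50) rounds; each round draws fresh
# negative samples from the current generator into save/generator_sample.txt,
# reloads the real/fake training data, and then runs dis_update_epoch_pre (3)
# passes of discriminator updates over it.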
79 | for t in range(config_train.dis_update_time_pre): 80 | print "Times: " + str(t) 81 | generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.negative_file) 82 | dis_data_loader.load_train_data(config_train.positive_file, config_train.negative_file) 83 | for _ in range(config_train.dis_update_epoch_pre): 84 | dis_data_loader.reset_pointer() 85 | for it in xrange(dis_data_loader.num_batch): 86 | x_batch, y_batch = dis_data_loader.next_batch() 87 | feed = { 88 | discriminator.input_x: x_batch, 89 | discriminator.input_y: y_batch, 90 | discriminator.dropout_keep_prob: config_dis.dis_dropout_keep_prob 91 | } 92 | _ = sess.run(discriminator.train_op, feed) 93 | 94 | #Build optimizer op for adversarial training 95 | train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate) 96 | gradients, variables = zip(*train_adv_opt.compute_gradients(generator.gen_loss_adv,var_list=var_pretrained)) 97 | gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip) 98 | train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables)) 99 | 100 | #Initialize global variables of optimizer for adversarial training 101 | uninitialized_var = [e for e in tf.global_variables() if e not in tf.trainable_variables()] 102 | init_vars_uninit_op = tf.variables_initializer(uninitialized_var) 103 | sess.run(init_vars_uninit_op) 104 | 105 | #Start adversarial training 106 | for total_batch in xrange(config_train.total_batch): 107 | for iter_gen in xrange(config_train.gen_update_time): 108 | samples = sess.run(generator.sample_word_list_reshape) 109 | 110 | feed = {"pred_seq_rollout:0":samples} 111 | reward_rollout = [] 112 | #calcuate the reward given in the specific stpe t by roll out 113 | for iter_roll in xrange(config_train.rollout_num): 114 | rollout_list = sess.run(rollout_gen.sample_rollout_step, feed_dict=feed) 115 | rollout_list_stack = np.vstack(rollout_list) #shape: #batch_size * #rollout_step, #sequence length 116 | reward_rollout_seq = sess.run(discriminator.ypred_for_auc, feed_dict={discriminator.input_x:rollout_list_stack, discriminator.dropout_keep_prob:1.0}) 117 | reward_last_tok = sess.run(discriminator.ypred_for_auc, feed_dict={discriminator.input_x:samples, discriminator.dropout_keep_prob:1.0}) 118 | reward_allseq = np.concatenate((reward_rollout_seq, reward_last_tok), axis=0)[:,1] 119 | reward_tmp = [] 120 | for r in xrange(config_gen.gen_batch_size): 121 | reward_tmp.append(reward_allseq[range(r, config_gen.gen_batch_size * config_gen.sequence_length, config_gen.gen_batch_size)]) 122 | reward_rollout.append(np.array(reward_tmp)) 123 | rewards = np.sum(reward_rollout, axis=0)/config_train.rollout_num 124 | _, gen_loss = sess.run([train_adv_update, generator.gen_loss_adv], feed_dict={generator.input_seqs_adv:samples,\ 125 | generator.rewards:rewards}) 126 | if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1: 127 | generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file) 128 | likelihood_data_loader.create_batches(config_train.eval_file) 129 | test_loss = target_loss(sess, target_lstm, likelihood_data_loader) 130 | buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n' 131 | print 'total_batch: ', total_batch, 'test_loss: ', test_loss 132 | log.write(buffer) 133 | 134 | for _ in range(config_train.dis_update_time_adv): 135 | generate_samples(sess, generator, config_train.batch_size, 
config_train.generated_num, config_train.negative_file) 136 | dis_data_loader.load_train_data(config_train.positive_file, config_train.negative_file) 137 | 138 | for _ in range(config_train.dis_update_epoch_adv): 139 | dis_data_loader.reset_pointer() 140 | for it in xrange(dis_data_loader.num_batch): 141 | x_batch, y_batch = dis_data_loader.next_batch() 142 | feed = { 143 | discriminator.input_x: x_batch, 144 | discriminator.input_y: y_batch, 145 | discriminator.dropout_keep_prob: config_dis.dis_dropout_keep_prob 146 | } 147 | _ = sess.run(discriminator.train_op, feed) 148 | log.close() 149 | if __name__ == "__main__": 150 | tf.app.run() 151 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | def generate_samples(sess, trainable_model, batch_size, generated_num, output_file): 3 | # Generate Samples 4 | generated_samples = [] 5 | for _ in range(int(generated_num / batch_size)): 6 | generated_samples.extend(trainable_model.generate(sess)) 7 | 8 | with open(output_file, 'w') as fout: 9 | for poem in generated_samples: 10 | buffer = ' '.join([str(x) for x in poem]) + '\n' 11 | fout.write(buffer) 12 | 13 | 14 | def target_loss(sess, target_lstm, data_loader): 15 | # target_loss means the oracle negative log-likelihood tested with the oracle model "target_lstm" 16 | # For more details, please see the Section 4 in https://arxiv.org/abs/1609.05473 17 | nll = [] 18 | data_loader.reset_pointer() 19 | 20 | for it in xrange(data_loader.num_batch): 21 | batch = data_loader.next_batch() 22 | g_loss = sess.run(target_lstm.pretrain_loss, {target_lstm.x: batch}) 23 | nll.append(g_loss) 24 | 25 | return np.mean(nll) 26 | --------------------------------------------------------------------------------