├── .gitignore ├── README.md ├── deepx ├── batcher.py ├── charrnn.py ├── experiments.py ├── gan.py ├── iterate.py ├── load_generative_parameters.py ├── plots │ └── plotting.py ├── refactor_experiments.py ├── rename_weights.py ├── train_discriminator.py ├── train_generator.py └── utils.py └── tensorflow ├── README.md ├── batcher.py ├── batcher_gan.py ├── discriminator.py ├── gan.py ├── generator.py ├── predict.py ├── sample.py ├── simple_vocab.pkl ├── train_gan_new.py └── train_models.py /.gitignore: -------------------------------------------------------------------------------- 1 | deepx/*.pyc 2 | deepx/data/* 3 | deepx/dataset/* 4 | deepx/models/* 5 | deepx/.theano/* 6 | deepx/loss/* 7 | deepx/log.txt 8 | tensorflow/loss 9 | tensorflow/logs 10 | tensorflow/data 11 | tensorflow/models 12 | tensorflow/models_discriminator 13 | tensorflow/models_generator 14 | tensorflow/models_GAN 15 | tensorflow/old 16 | tensorflow/vocab 17 | tensorflow/*.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Character-Level Generative Adversarial Network 2 | Generalizing the [Generative Adversarial Networks](http://arxiv.org/abs/1406.2661) introduced by I. Goodfellow et al. into a natural language formulation. This research was initially powered by [DeepX](https://github.com/sharadmv/deepx) but more recent work is using [Tensorflow](http://www.tensorflow.org). -------------------------------------------------------------------------------- /deepx/batcher.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class WindowedBatcher(object): 4 | 5 | def __init__(self, sequences, encodings, target, batch_size=100, sequence_length=50): 6 | self.sequences = sequences 7 | 8 | self.pre_vector_sizes = [c.seq[0].shape[0] for c in self.sequences] + [target.seq[0].shape[0]] 9 | self.pre_vector_size = sum(self.pre_vector_sizes) 10 | self.target_size = target.seq[0].shape[0] 11 | 12 | self.encodings = encodings 13 | self.vocab_sizes = [c.index for c in self.encodings] + [self.target_size] 14 | self.vocab_size = sum(self.vocab_sizes) 15 | self.batch_index = 0 16 | self.batches = [] 17 | self.batch_size = batch_size 18 | self.sequence_length = sequence_length 19 | self.length = len(self.sequences[0]) 20 | 21 | self.batch_index = 0 22 | self.X = np.zeros((self.length, self.pre_vector_size), dtype=np.int32) 23 | self.X = np.hstack([c.seq for c in self.sequences] + [target.seq]) 24 | 25 | N, D = self.X.shape 26 | assert N > self.batch_size * self.sequence_length, "File has to be at least %u characters" % (self.batch_size * self.sequence_length) 27 | 28 | self.X = self.X[:N - N % (self.batch_size * self.sequence_length)] 29 | self.N, self.D = self.X.shape 30 | self.X = self.X.reshape((self.N / self.sequence_length, self.sequence_length, self.D)) 31 | 32 | self.N, self.S, self.D = self.X.shape 33 | 34 | self.num_sequences = self.N / self.sequence_length 35 | self.num_batches = self.N / self.batch_size 36 | self.batch_cache = {} 37 | 38 | def next_batch(self): 39 | idx = (self.batch_index * self.batch_size) 40 | if self.batch_index >= self.num_batches: 41 | self.batch_index = 0 42 | idx = 0 43 | 44 | if self.batch_index in self.batch_cache: 45 | batch = self.batch_cache[self.batch_index] 46 | self.batch_index += 1 47 | return batch 48 | 49 | X = self.X[idx:idx + self.batch_size] 50 | y = np.zeros((X.shape[0], self.sequence_length, 
self.vocab_size)) 51 | for i in xrange(self.batch_size): 52 | for c in xrange(self.sequence_length): 53 | seq_splits = np.split(X[i, c], np.cumsum(self.pre_vector_sizes)) 54 | vec = np.concatenate([e.convert_representation(split) for 55 | e, split in zip(self.encodings, seq_splits)] + [X[i, c, -self.target_size:]]) 56 | y[i, c] = vec 57 | 58 | X = y[:, :, :-self.target_size] 59 | y = y[:, :, -self.target_size:] 60 | 61 | X = np.swapaxes(X, 0, 1) 62 | y = np.swapaxes(y, 0, 1) 63 | # self.batch_cache[self.batch_index] = X, y 64 | self.batch_index += 1 65 | return X, y 66 | -------------------------------------------------------------------------------- /deepx/charrnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano 3 | import sys 4 | import csv 5 | theano.config.on_unused_input = 'ignore' 6 | import theano.tensor as T 7 | import logging 8 | from theanify import theanify, Theanifiable 9 | logging.basicConfig(level=logging.DEBUG) 10 | from argparse import ArgumentParser 11 | import random 12 | from dataset import * 13 | 14 | from deepx.nn import * 15 | from deepx.rnn import * 16 | from deepx.loss import * 17 | from deepx.optimize import * 18 | 19 | 20 | class WindowedBatcher(object): 21 | 22 | def __init__(self, sequences, encodings, target, batch_size=100, sequence_length=50): 23 | self.sequences = sequences 24 | 25 | self.pre_vector_sizes = [c.seq[0].shape[0] for c in self.sequences] + [target.seq[0].shape[0]] 26 | self.pre_vector_size = sum(self.pre_vector_sizes) 27 | self.target_size = target.seq[0].shape[0] 28 | 29 | self.encodings = encodings 30 | self.vocab_sizes = [c.index for c in self.encodings] + [self.target_size] 31 | self.vocab_size = sum(self.vocab_sizes) 32 | self.batch_index = 0 33 | self.batches = [] 34 | self.batch_size = batch_size 35 | self.sequence_length = sequence_length 36 | self.length = len(self.sequences[0]) 37 | 38 | self.batch_index = 0 39 | self.X = np.zeros((self.length, self.pre_vector_size), dtype=np.int32) 40 | self.X = np.hstack([c.seq for c in self.sequences] + [target.seq]) 41 | 42 | N, D = self.X.shape 43 | assert N > self.batch_size * self.sequence_length, "File has to be at least %u characters" % (self.batch_size * self.sequence_length) 44 | 45 | self.X = self.X[:N - N % (self.batch_size * self.sequence_length)] 46 | self.N, self.D = self.X.shape 47 | self.X = self.X.reshape((self.N / self.sequence_length, self.sequence_length, self.D)) 48 | 49 | self.N, self.S, self.D = self.X.shape 50 | 51 | self.num_sequences = self.N / self.sequence_length 52 | self.num_batches = self.N / self.batch_size 53 | self.batch_cache = {} 54 | 55 | def next_batch(self): 56 | idx = (self.batch_index * self.batch_size) 57 | if self.batch_index >= self.num_batches: 58 | self.batch_index = 0 59 | idx = 0 60 | 61 | if self.batch_index in self.batch_cache: 62 | batch = self.batch_cache[self.batch_index] 63 | self.batch_index += 1 64 | return batch 65 | 66 | X = self.X[idx:idx + self.batch_size] 67 | y = np.zeros((X.shape[0], self.sequence_length, self.vocab_size)) 68 | for i in xrange(self.batch_size): 69 | for c in xrange(self.sequence_length): 70 | seq_splits = np.split(X[i, c], np.cumsum(self.pre_vector_sizes)) 71 | vec = np.concatenate([e.convert_representation(split) for 72 | e, split in zip(self.encodings, seq_splits)] + [X[i, c, -self.target_size:]]) 73 | y[i, c] = vec 74 | 75 | X = y[:, :, :-self.target_size] 76 | y = y[:, :, -self.target_size:] 77 | 78 | 79 | X = np.swapaxes(X, 0, 1) 80 | y = 
np.swapaxes(y, 0, 1) 81 | # self.batch_cache[self.batch_index] = X, y 82 | self.batch_index += 1 83 | return X, y 84 | 85 | def parse_args(): 86 | argparser = ArgumentParser() 87 | 88 | argparser.add_argument("real_file") 89 | argparser.add_argument("fake_file") 90 | 91 | return argparser.parse_args() 92 | 93 | def generate(length, temperature): 94 | results = charrnn.generate( 95 | np.eye(len(encoding))[encoding.encode("i")], 96 | length, 97 | temperature).argmax(axis=1) 98 | return NumberSequence(results).decode(encoding) 99 | 100 | 101 | if __name__ == "__main__": 102 | args = parse_args() 103 | logging.debug("Reading file...") 104 | with open(args.real_file, 'r') as fp: 105 | real_reviews = [r[3:] for r in fp.read().strip().split('\n')] 106 | with open(args.fake_file, 'r') as fp: 107 | fake_reviews = [r[3:] for r in fp.read().strip().split('\n')] 108 | 109 | # Load and shuffle reviews 110 | real_targets, fake_targets = [], [] 111 | for _ in xrange(len(real_reviews)): 112 | real_targets.append([0, 1]) 113 | for _ in xrange(len(fake_reviews)): 114 | fake_targets.append([1, 0]) 115 | 116 | all_reviews = zip(real_reviews, real_targets) + zip(fake_reviews, fake_targets) 117 | 118 | random.seed(1) 119 | random.shuffle(all_reviews) 120 | 121 | # Partition training set due to memory constraints on AWS 122 | reviews, targets = zip(*all_reviews[:100000]) 123 | 124 | logging.debug("Converting to one-hot...") 125 | review_sequences = [CharacterSequence.from_string(review) for review in reviews] 126 | 127 | text_encoding = OneHotEncoding(include_start_token=True, include_stop_token=True) 128 | text_encoding.build_encoding(review_sequences) 129 | 130 | num_sequences = [c.encode(text_encoding) for c in review_sequences] 131 | final_seq = NumberSequence(np.concatenate([c.seq.astype(np.int32) for c in num_sequences])) 132 | target_sequences = [NumberSequence([target]).replicate(len(r)) for target, r in zip(targets, num_sequences)] 133 | final_target = NumberSequence(np.concatenate([c.seq.astype(np.int32) for c in target_sequences])) 134 | 135 | # Construct the batcher 136 | batcher = WindowedBatcher([final_seq], [text_encoding], final_target, sequence_length=200, batch_size=100) 137 | 138 | # Define the model 139 | logging.debug("Compiling model") 140 | discriminator = Sequence(Vector(len(text_encoding))) >> MultilayerLSTM(1024, num_layers=2) >> Softmax(2) 141 | 142 | # Training 143 | rmsprop = RMSProp(discriminator, ConvexSequentialLoss(CrossEntropy(), 0.5), clip_gradients=5) 144 | 145 | step_size = 5 146 | iterations = 1000 147 | 148 | # Training loss 149 | train_loss = [] 150 | for _ in xrange(iterations): 151 | X, y = batcher.next_batch() 152 | train_loss.append(rmsprop.train(X, y, step_size)) 153 | 154 | # Training accuracy 155 | 156 | 157 | # Testing accuracy 158 | 159 | 160 | # with open('loss/loss_0.25_50.csv', 'wb') as loss_file: 161 | # wr = csv.writer(loss_file) 162 | # wr.writerow(loss) 163 | 164 | 165 | # Training Accuracy 166 | # test_reviews = [np.vstack([text_encoding.convert_representation(a) for a in r.encode(text_encoding).seq]) [:, np.newaxis] for r in review_sequences[:100]] 167 | # test_labels = np.array(targets)[:100].argmax(axis=1) 168 | # discriminator.compile_method('accuracy') # 169 | # discriminator.compile_method('predict') 170 | 171 | # errors = 0 172 | # for review, label in zip(test_reviews, test_labels): 173 | # error = discriminator.accuracy(review[:, np.newaxis], np.zeros((1, 2, 1024)), [label]) 174 | # if error == 0: 175 | # print "Correct!" 
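# (Editor's note, descriptive comment) In this commented-out evaluation,
# discriminator.accuracy() is evidently expected to return 0 for a correct
# prediction and 1 for a miss on a single review (errors accumulates those
# values), and np.zeros((1, 2, 1024)) is presumably the initial state of the
# two 1024-unit LSTM layers; the error rate printed below is therefore
# errors / 100, taken over the first 100 training reviews.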
176 | # errors += error 177 | 178 | # print "Error rate: %f" % (float(errors) / 100) 179 | 180 | -------------------------------------------------------------------------------- /deepx/experiments.py: -------------------------------------------------------------------------------- 1 | import theano 2 | theano.config.on_unused_input = 'ignore' 3 | import numpy as np 4 | import cPickle as pickle 5 | import theano 6 | import sys 7 | import csv 8 | import logging 9 | import random 10 | import Tkinter 11 | from dataset import * 12 | from dataset.sequence import * 13 | from deepx.nn import * 14 | from deepx.rnn import * 15 | from deepx.loss import * 16 | from deepx.optimize import * 17 | from deepx import backend as T 18 | from argparse import ArgumentParser 19 | from utils import * 20 | import string 21 | 22 | logger = logging.getLogger() 23 | logger.setLevel(logging.DEBUG) 24 | 25 | 26 | def text_to_num(text): 27 | '''Convert text to number representation''' 28 | char_seq = CharacterSequence.from_string(text) 29 | num_seq = char_seq.encode(text_encoding_D) 30 | num_seq_np = num_seq.seq.astype(np.int32) 31 | X = np.eye(len(text_encoding_D))[num_seq_np] 32 | return X 33 | 34 | def predict(model, encoding, text, preprocess=True): 35 | '''Return prediction array at each time-step of input text''' 36 | if preprocess: 37 | text = text.replace('','') 38 | text = text.replace('','') 39 | char_seq = CharacterSequence.from_string(text) 40 | num_seq = char_seq.encode(encoding) 41 | num_seq_np = num_seq.seq.astype(np.int32) 42 | X = np.eye(len(encoding))[num_seq_np] 43 | return model.predict(X) 44 | 45 | 46 | ############### 47 | # Experiment 1 48 | ############### 49 | def noise_test(num_reviews, data_dir = 'data/fake_beer_reviews.txt', fractional_noise = 0.2, distribution='uniform'): 50 | '''Test performance of the discriminator with noise added to one-hot vectors''' 51 | 52 | reviews = load_reviews(data_dir) 53 | last_review = np.random.randint(num_reviews, len(reviews)) 54 | reviews = reviews[last_review - num_reviews : last_review] 55 | reviews = [r.replace('','').replace('','').replace('<','').replace('>','') for r in reviews] 56 | 57 | for i, review in enumerate(reviews): 58 | print 'Review #%i'%(i) 59 | print review, '\n' 60 | num_seq = text_to_num(review) 61 | shape = num_seq.shape 62 | print ' Unperturbed_0: ', discriminator_0.predict(num_seq)[-1] 63 | 64 | if distribution is 'constant': 65 | noise = fractional_noise * np.ones(shape) 66 | blurred = num_seq + noise 67 | elif distribution is 'uniform': 68 | noise = np.random.uniform(0.0, fractional_noise, shape) 69 | blurred = num_seq + noise 70 | elif distribution is 'dirichlet': 71 | blurred = [np.random.dirichlet(num_seq[j,0,:] + fractional_noise) for j in xrange(len(num_seq))] 72 | blurred = np.asarray(blurred) 73 | blurred = blurred.reshape(shape) 74 | print ' Perturbed_0: ', discriminator_0.predict(blurred)[-1], '\n' 75 | 76 | print ' Unperturbed_1: ', discriminator_1.predict(num_seq)[-1] 77 | print ' Perturbed_1: ', discriminator_1.predict(blurred)[-1], '\n' 78 | 79 | ############### 80 | # Experiment 2 81 | ############### 82 | 83 | class DiscriminatorEvaluation(object): 84 | def __init__(self, models): 85 | self.models = models 86 | 87 | 88 | def load_sequences(self, num_sequences=100): 89 | sequences = {} 90 | sequences['real'] = load_reviews('data/real_beer_reviews.txt')[:num_sequences] 91 | sequences['fake'] = load_reviews('data/fake_beer_reviews.txt')[:num_sequences] 92 | if num_sequences <= 79: #79 and above not in encoding 93 | 
sequences['repeat'] = [char*200 for char in string.printable[:num_sequences]] 94 | else: 95 | sequences['repeat'] = [char*200 for char in string.printable[:79]] 96 | sequences['random_char'] = load_reviews('data/curriculum/random_reviews_1.discriminator_0')[:num_sequences] 97 | 98 | 99 | def manual_evaluation(self, sequence_dict, model_list): 100 | '''Review predictions on specific sequences''' 101 | pass 102 | 103 | def batch_evaluation(self, model): 104 | '''Return predictions on specific sequences''' 105 | pass 106 | 107 | 108 | 109 | 110 | 111 | def discriminator_evaluation(models, encoding, num_sequences=5): 112 | # Sequence 0: Real reviews 113 | real_reviews = load_reviews('data/real_beer_reviews.txt')[:num_sequences] 114 | 115 | # Sequence 1: Fake reviews (Original) 116 | fake_reviews = load_reviews('data/fake_beer_reviews.txt')[:num_sequences] 117 | 118 | # Sequence 2: Repeating characters 119 | repeating_chars = [char*200 for char in string.printable[:79]] #79 and above not in encoding 120 | 121 | # Sequence 3: Repeating words 122 | repeating_words = [] 123 | 124 | # Sequence 4: Random characters 125 | random_chars = load_reviews('data/curriculum/random_reviews_1.0.txt')[:num_sequences] 126 | 127 | # Sequence 5: Random words 128 | random_words = [] 129 | 130 | print 'Real' 131 | for review in real_reviews: 132 | print review 133 | for name, model in models.items(): 134 | print('{:<25}: {:<10.3} {:< 10.3}'.format(name, *predict(model, encoding, review)[-1][0])) 135 | print '\n' 136 | 137 | print 'Fake' 138 | for review in fake_reviews: 139 | print review 140 | 141 | for name, model in models.items(): 142 | print('{:<25}: {:<10.3} {:< 10.3}'.format(name, *predict(model, encoding, review)[-1][0])) 143 | print '\n' 144 | 145 | print 'Repeating' 146 | for review in repeating_chars: 147 | print review 148 | 149 | for name, model in models.items(): 150 | print('{:<25}: {:<10.3} {:< 10.3}'.format(name, *predict(model, encoding, review)[-1][0])) 151 | print '\n' 152 | 153 | print 'Random' 154 | for review in random_chars: 155 | print review 156 | 157 | for name, model in models.items(): 158 | print('{:<25}: {:<10.3} {:< 10.3}'.format(name, *predict(model, encoding, review)[-1][0])) 159 | print '\n' 160 | 161 | 162 | def discriminator_histograms(models, encoding, num_sequences = 10): 163 | real_reviews = load_reviews('data/real_beer_reviews.txt')[:num_sequences] 164 | random_chars = load_reviews('data/curriculum/random_reviews_1.0.txt')[:num_sequences] 165 | 166 | results = {} 167 | 168 | for name in models.keys(): 169 | print name 170 | results[name] = [] 171 | 172 | print results 173 | 174 | for review in random_chars: 175 | print review 176 | for name, model in models.iteritems(): 177 | 178 | print predict(model, encoding, review)[-1][0][0] 179 | results[name].append(predict(model, encoding, review)[-1][0][0]) 180 | 181 | return results 182 | 183 | 184 | 185 | 186 | 187 | if __name__ == '__main__': 188 | logging.debug('Loading encoding...') 189 | with open('data/charnet-encoding.pkl', 'rb') as fp: 190 | text_encoding_D = pickle.load(fp) 191 | text_encoding_D.include_stop_token = False 192 | text_encoding_D.include_start_token = False 193 | 194 | discriminator_0 = Sequence(Vector(len(text_encoding_D))) >> (Repeat(LSTM(1024), 2) >> Softmax(2)) 195 | discriminator_1 = Sequence(Vector(len(text_encoding_D))) >> (Repeat(LSTM(1024), 2) >> Softmax(2)) 196 | discriminator_2 = Sequence(Vector(len(text_encoding_D))) >> Repeat(LSTM(1024) >> Dropout(0.5), 2) >> Softmax(2) 197 | discriminator_3 = 
Sequence(Vector(len(text_encoding_D))) >> (Repeat(LSTM(1024), 2) >> Softmax(2)) 198 | discriminator_4 = Sequence(Vector(len(text_encoding_D))) >> Repeat(LSTM(1024) >> Dropout(0.5), 2) >> Softmax(2) 199 | discriminator_5 = Sequence(Vector(len(text_encoding_D))) >> Repeat(LSTM(1024) >> Dropout(0.5), 2) >> Softmax(2) 200 | 201 | logging.debug('Loading discriminators...') 202 | with open('models/discriminative/discriminative-model-0.0.0.pkl', 'rb') as fp: 203 | state = pickle.load(fp) 204 | state = (state[0][0], (state[0][1], state[1])) 205 | discriminator_0.set_state(state) 206 | 207 | with open('models/discriminative/discriminative-model-0.3.1.pkl', 'rb') as fp: 208 | discriminator_1.set_state(pickle.load(fp)) 209 | 210 | with open('models/discriminative/discriminative-dropout-model-0.0.2.pkl', 'rb') as fp: 211 | discriminator_2.set_state(pickle.load(fp)) 212 | 213 | with open('models/discriminative/discriminative-adversarial-model-0.0.0.pkl', 'rb') as fp: 214 | state = pickle.load(fp) 215 | state = (state[0][0], (state[0][1], state[1])) 216 | discriminator_3.set_state(state) 217 | 218 | with open('models/discriminative/discriminative-adversarial-dropout-model-0.0.0.pkl', 'rb') as fp: 219 | discriminator_4.set_state(pickle.load(fp)) 220 | 221 | with open('models/discriminative/discriminative-adversarial-dropout-model-0.1.0.pkl', 'rb') as fp: 222 | discriminator_5.set_state(pickle.load(fp)) 223 | 224 | models = { 225 | 'original': discriminator_0, 226 | 'mix': discriminator_1, 227 | 'dropout': discriminator_2, 228 | 'adversarial': discriminator_3, 229 | 'adversarial_dropout': discriminator_4, 230 | 'adversarial_dropout_mix': discriminator_5} 231 | 232 | # discriminator_evaluation(models, text_encoding_D, 5) 233 | 234 | -------------------------------------------------------------------------------- /deepx/gan.py: -------------------------------------------------------------------------------- 1 | import theano 2 | theano.config.on_unused_input = 'ignore' 3 | import numpy as np 4 | import cPickle as pickle 5 | import theano 6 | import sys 7 | import csv 8 | import logging 9 | import random 10 | import Tkinter 11 | from dataset import * 12 | from dataset.sequence import * 13 | from batcher import * 14 | from deepx.nn import * 15 | from deepx.rnn import * 16 | from deepx.loss import * 17 | from deepx.optimize import * 18 | from deepx import backend as T 19 | from argparse import ArgumentParser 20 | from utils import * 21 | 22 | logger = logging.getLogger() 23 | logger.setLevel(logging.DEBUG) 24 | 25 | 26 | def parse_args(): 27 | argparser = ArgumentParser() 28 | argparser.add_argument('--sequence_length', default=200) 29 | argparser.add_argument('--batch_size', default=100, type=int) 30 | argparser.add_argument('--dropout_rate', default=0.5, type=float) 31 | argparser.add_argument('--save_model_every', default=100, type=int) 32 | argparser.add_argument('--log', default='loss/gan/gan_adversarial_log_current.txt') 33 | return argparser.parse_args() 34 | 35 | def generate_sample(num_reviews): 36 | '''Generate a sample from the current version of the generator''' 37 | pred_seq = generator.predict(np.tile(np.eye(100)[0], (num_reviews, 1))) 38 | return pred_seq 39 | 40 | def generate_fake_reviews(num_reviews): 41 | '''Generate fake reviews using the current generator''' 42 | pred_seq = generate_sample(num_reviews).argmax(axis=2).T 43 | num_seq = [NumberSequence(pred_seq[i]).decode(text_encoding_D) for i in xrange(num_reviews)] 44 | return_str = [''.join(n.seq) for n in num_seq] 45 | return return_str 
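# --- Illustrative sketch added by the editor; not part of the original file. ---
# generate_fake_reviews() above turns the generator's per-timestep softmax output
# (sequence_length x num_reviews x vocab_size) into text by taking the argmax over
# the vocabulary axis and decoding each index with the pickled deepx text encoding.
# The toy function below demonstrates the same argmax-and-decode step with a
# hypothetical three-character vocabulary in place of the real encoding.
def _example_argmax_decode():
    '''Toy illustration of decoding per-timestep softmax outputs to a string.'''
    idx_to_char = {0: 'a', 1: 'b', 2: 'c'}               # hypothetical vocabulary
    probs = np.array([[0.1, 0.8, 0.1],                    # (timesteps, vocab_size)
                      [0.7, 0.2, 0.1],
                      [0.2, 0.2, 0.6]])
    return ''.join(idx_to_char[i] for i in probs.argmax(axis=1))  # -> 'bac'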
46 | 47 | def predict(text, preprocess=True): 48 | '''Return prediction array at each time-step of input text''' 49 | if preprocess: 50 | text = text.replace('','') 51 | text = text.replace('','') 52 | char_seq = CharacterSequence.from_string(text) 53 | num_seq = char_seq.encode(text_encoding_D) 54 | num_seq_np = num_seq.seq.astype(np.int32) 55 | X = np.eye(len(text_encoding_D))[num_seq_np] 56 | return discriminator.predict(X) 57 | 58 | def classification_accuracy(reviews, labels): 59 | '''Classification accuracy based on prediction at final time-step''' 60 | correct = 0.0 61 | reviews = [r.replace('', '') for r in reviews] 62 | reviews = [r.replace('', '') for r in reviews] 63 | 64 | for review, label in zip(reviews, labels): 65 | pred = predict(review)[-1][0] 66 | print pred, label, pred.argmax() == label.argmax() 67 | if pred.argmax() == label.argmax(): 68 | correct += 1 69 | return correct/len(reviews) 70 | 71 | 72 | 73 | 74 | if __name__ == "__main__": 75 | args = parse_args() 76 | 77 | logging.debug('Retrieving text encoding...') 78 | with open('data/charnet-encoding.pkl', 'rb') as fp: 79 | text_encoding_G = pickle.load(fp) 80 | 81 | with open('data/charnet-encoding.pkl', 'rb') as fp: 82 | text_encoding_D = pickle.load(fp) 83 | text_encoding_D.include_stop_token = False 84 | text_encoding_D.include_start_token = False 85 | 86 | logging.debug('Declaring models...') 87 | 88 | 89 | # Classic Models 90 | if args.dropout_rate == 0.0: 91 | discriminator = Sequence(Vector(len(text_encoding_D))) >> (Repeat(LSTM(1024), 2) >> Softmax(2)) 92 | generator = Generate(Vector(len(text_encoding_G)) >> Repeat(LSTM(1024), 2) >> Softmax(len(text_encoding_G)), args.sequence_length) 93 | gennet = Sequence(Vector(len(text_encoding_G))) >> Repeat(LSTM(1024), 2) >> Softmax(len(text_encoding_G)) 94 | generator = generator.tie(gennet) 95 | 96 | assert gennet.get_parameters() == generator.get_parameters() 97 | 98 | logging.debug('Declaring GAN...') 99 | gan = gennet >> discriminator.right # Classic 100 | 101 | logging.debug('Compiling GAN...') 102 | adam_G = Adam(gan.left >> Freeze(gan.right) >> CrossEntropy(), 500) # refactor 103 | 104 | logging.debug('Compiling discriminator...') 105 | adam_D = Adam(discriminator >> CrossEntropy(), 500) # refactor 106 | 107 | with open('models/generative/generative-model-2.1.pkl', 'rb') as fp: 108 | generator.set_state(pickle.load(fp)) 109 | 110 | # # with open('models/discriminative/discriminative-model-0.2.1.pkl', 'rb') as fp: 111 | # state = pickle.load(fp) 112 | # state = (state[0][0], (state[0][1], state[1])) 113 | # discriminator.set_state(state) 114 | 115 | with open('models/discriminative/discriminative-model-0.3.1.pkl', 'rb') as fp: 116 | discriminator.set_state(pickle.load(fp)) 117 | 118 | # Dropout Models 119 | elif args.dropout_rate > 0.0 and args.dropout_rate <= 1.0: 120 | rate = args.dropout_rate 121 | 122 | discriminator = Sequence(Vector(len(text_encoding_D))) >> Repeat(LSTM(1024) >> Dropout(rate), 2) >> Softmax(2) 123 | generator = Generate(Vector(len(text_encoding_G)) >> Repeat(LSTM(1024) >> Dropout(rate), 2) >> Softmax(len(text_encoding_G)), args.sequence_length) 124 | gennet = Sequence(Vector(len(text_encoding_G))) >> Repeat(LSTM(1024) >> Dropout(rate), 2) >> Softmax(len(text_encoding_G)) 125 | generator = generator.tie(gennet) 126 | 127 | assert gennet.get_parameters() == generator.get_parameters() 128 | 129 | logging.debug('Declaring GAN...') 130 | gan = gennet >> discriminator.left.right >> discriminator.right # Dropout Hack 131 | 132 | 
logging.debug('Compiling GAN...') 133 | # adam_G = Adam(CrossEntropy(gan.left >> Freeze(gan.right)), 500) # master (4/9/16) 134 | adam_G = Adam(gan.left >> Freeze(gan.right) >> CrossEntropy(), 500) # refactor 135 | 136 | logging.debug('Compiling discriminator...') 137 | # adam_D = Adam(CrossEntropy(discriminator), 500) # master (4/9/16) 138 | # adam_D = Adam(discriminator >> CrossEntropy(), 500) #refactor 139 | loss_D = AdversarialLoss(discriminator >> CrossEntropy(), discriminator.get_inputs()[0]) 140 | adam_D = Adam(loss_D, 500) 141 | 142 | # Experiment 1 143 | # with open('models/generative/generative-dropout-model-0.0.6.pkl', 'rb') as fp: 144 | # generator.set_state(pickle.load(fp)) 145 | 146 | # with open('models/discriminative/discriminative-adversarial-dropout-model-0.1.0.pkl', 'rb') as fp: 147 | # discriminator.set_state(pickle.load(fp)) 148 | 149 | else: 150 | raise ValueError('Dropout rate must be greater or equal to 0.0 and less than or equal to 1.0') 151 | 152 | 153 | ########### 154 | # Stage II 155 | ########### 156 | def train_generator(max_iterations, step_size, stop_train_loss=1.1): 157 | '''Train the generative model (G) via a GAN framework''' 158 | 159 | avg_loss = [] 160 | with open(args.log, 'a+') as fp: 161 | for i in xrange(max_iterations): 162 | batch = generate_sample(args.batch_size) 163 | starts = np.tile(np.eye(len(text_encoding_D))[0], (1, batch.shape[1], 1)) 164 | batch = np.concatenate([starts, batch])[:-1] 165 | y = np.tile([0, 1], (args.sequence_length, args.batch_size, 1)) 166 | loss = adam_G.train(batch, y, step_size) 167 | # adam_G.loss.model.reset_states() # master (4/9/16) 168 | 169 | if i == 0: 170 | avg_loss.append(loss) 171 | avg_loss.append(loss * 0.05 + avg_loss[-1] * 0.95) 172 | 173 | print >> fp, "Generator Loss[%u]: %f (%f)" % (i, loss, avg_loss[-1]) 174 | print "Generator Loss[%u]: %f (%f)" % (i, loss, avg_loss[-1]) 175 | fp.flush() 176 | 177 | if loss <= stop_train_loss: 178 | return 179 | 180 | 181 | def train_discriminator(max_iterations, step_size, real_reviews, stop_train_loss=0.50): 182 | '''Train the discriminator (D) on real and fake reviews''' 183 | random.seed(1) 184 | 185 | num_reviews = len(real_reviews) 186 | fake_reviews = generate_sample(num_reviews) 187 | 188 | # Load and shuffle reviews 189 | logging.debug("Converting to one-hot...") 190 | batches, targets = [], [] 191 | for i in xrange(len(real_reviews)): 192 | batches.append(np.eye(len(text_encoding_D))[None, CharacterSequence.from_string(real_reviews[i][:args.sequence_length]).encode(text_encoding_D).seq.ravel()]) 193 | # assert batches[-1].shape == (1, args.sequence_length, len(text_encoding_D)), batches[-1].shape 194 | targets.append(np.tile([0, 1], (1, args.sequence_length, 1))) 195 | for i in xrange(len(real_reviews)): 196 | batches.append(fake_reviews[None, :, i]) 197 | # assert batches[-1].shape == (1, args.sequence_length, len(text_encoding_D)), batches[-1].shape 198 | targets.append(np.tile([1, 0], (1, args.sequence_length, 1))) 199 | batches = np.concatenate(batches).swapaxes(0, 1) 200 | targets = np.concatenate(targets).swapaxes(0, 1) 201 | # assert batches.shape == (args.sequence_length, num_reviews * 2, len(text_encoding_D)), batches.shape 202 | # assert targets.shape == (args.sequence_length, num_reviews * 2, 2), targets.shape 203 | 204 | avg_loss = [] 205 | with open(args.log, 'a+') as fp: 206 | for i in xrange(max_iterations): 207 | idx = np.random.permutation(xrange(batches.shape[1]))[:args.batch_size] 208 | X, y = batches[:,idx], targets[:, idx] 209 | loss 
= adam_D.train(X, y, step_size) 210 | 211 | if i == 0: 212 | avg_loss.append(loss) 213 | avg_loss.append(loss * 0.05 + avg_loss[-1] * 0.95) 214 | 215 | print >> fp, 'Discriminator Loss [%u]: %f (%f)' % (i, loss, avg_loss[-1]) 216 | print 'Discriminator Loss [%u]: %f (%f)' % (i, loss, avg_loss[-1]) 217 | fp.flush() 218 | 219 | if loss <= stop_train_loss: 220 | return 221 | 222 | 223 | def monitor_gan(real_reviews_test, num_reviews = 10): 224 | '''Monitoring function for GAN training. return_str 225 | 226 | 1. real: Avg. log-likelihood attributed to real_reviews 227 | 2. fake: Avg. log-likelihood attributed to fake_reviews 228 | ''' 229 | logging.debug('Monitor performance...') 230 | 231 | last_review = np.random.randint(num_reviews, len(real_reviews_test)) 232 | real_reviews = real_reviews_test[last_review - num_reviews : last_review] 233 | 234 | real_labels = np.asarray([[0,1] for _ in xrange(len(real_reviews))]) 235 | fake_reviews = generate_fake_reviews(num_reviews) 236 | fake_labels = np.asarray([[1,0] for _ in xrange(len(fake_reviews))]) 237 | 238 | real = classification_accuracy(real_reviews, real_labels) 239 | fake = classification_accuracy(fake_reviews, fake_labels) 240 | 241 | return real, fake 242 | 243 | def alternating_gan(num_epoch, dis_iter=1, gen_iter=1, dis_lr=0.0001, gen_lr=0.0001, num_reviews = 1000, seq_length=args.sequence_length, monitor=False): 244 | '''Alternating GAN procedure for jointly training the generator (G) 245 | and the discriminator (D)''' 246 | 247 | logging.debug('Loading real reviews...') 248 | # real_reviews_all = load_reviews('data/real_beer_reviews.txt', seq_length) 249 | real_reviews_all = load_reviews('data/real_beer_reviews_test2.txt', seq_length) 250 | 251 | # real_reviews_train = real_reviews_all[:100000] 252 | # real_reviews_test = real_reviews_all[100000:] 253 | 254 | with open(args.log, 'w') as fp: 255 | print >> fp, 'Alternating GAN for ',num_epoch,' epochs.' 
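# --- Editor's note (descriptive comment; not part of the original source). ---
# Each epoch below alternates the two players: train_discriminator() fits D on the
# supplied real reviews plus an equal number of freshly sampled generator outputs
# (targets [0, 1] for real, [1, 0] for fake), and train_generator() then trains G
# through the frozen discriminator (Freeze(gan.right)) to push its samples toward
# the "real" label. Both routines return early once their loss falls below
# stop_train_loss, so an epoch is not a fixed number of gradient steps.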
256 | 257 | for i in xrange(num_epoch): 258 | if monitor: 259 | r, f = monitor_gan(real_reviews_test) 260 | print 'Percent correct for real: %f and for fake: %f' % (r, f) 261 | 262 | logging.debug('Training discriminator...') 263 | # last_review = np.random.randint(num_reviews, len(real_reviews_train)) 264 | # real_reviews = real_reviews_train[last_review : last_review + num_reviews] 265 | real_reviews = real_reviews_all 266 | train_discriminator(dis_iter, dis_lr, real_reviews) 267 | 268 | logging.debug('Training generator...') 269 | train_generator(gen_iter, gen_lr) 270 | 271 | logging.debug('Generating new fake reviews...') 272 | fake_reviews = generate_fake_reviews(num_reviews) 273 | 274 | with open('data/gan/gan_adversarial_reviews_current.txt', 'a+') as f: 275 | print >> f, fake_reviews[0] 276 | for review in fake_reviews[:10]: 277 | print review 278 | 279 | # logging.debug('Saving models...') 280 | # with open('models/gan/gan-model-epoch'+str(i)+'.pkl', 'wb') as f: 281 | # pickle.dump(gan.get_state(), f) 282 | 283 | # if i % args.save_model_every == 0: 284 | # with open('models/generative/generative-gan-model-current.pkl', 'wb') as f: 285 | # pickle.dump(generator.get_state(), f) 286 | 287 | # with open('models/discriminative/discriminative-gan-model-current.pkl', 'wb') as f: 288 | # pickle.dump(discriminator.get_state(), f) 289 | -------------------------------------------------------------------------------- /deepx/iterate.py: -------------------------------------------------------------------------------- 1 | #print NumberSequence(generator.predict(np.eye(100)[None,0]).argmx(axis=2).ravel()).decode(text_encoding) 2 | 3 | 4 | def iterate(iterations, step_size, gan): 5 | with open(args.log, 'w') as fp: 6 | for _ in xrange(iterations): 7 | batch = np.tile(text_encoding.convert_representation([text_encoding.encode('')]), (args.batch_size, 1)) 8 | y = np.tile([0, 1], (args.sequence_length, args.batch_size, 1)) 9 | loss = rmsprop.train(batch, y, step_size) 10 | print >> fp, "Loss[%u]: %f" % (_, loss) 11 | print "Loss[%u]: %f" % (_, loss) 12 | fp.flush() 13 | train_loss.append(loss) 14 | 15 | # with open('models/current-gan-model.pkl', 'wb') as fp: 16 | # pickle.dump(gan.get_state(), fp) 17 | -------------------------------------------------------------------------------- /deepx/load_generative_parameters.py: -------------------------------------------------------------------------------- 1 | from deepx.nn import * 2 | from deepx.rnn import * 3 | import cPickle as pickle 4 | 5 | def convert_params(params): 6 | new_params = {} 7 | for param, value in params.items(): 8 | new_params["%s-0" % param] = value.tolist() 9 | return new_params 10 | 11 | if __name__ == "__main__": 12 | with open('data/charnet-top_2-1024-2.pkl', 'rb') as fp: 13 | generative_params = pickle.load(fp) 14 | lstm1 = convert_params(generative_params['lstm']['input_layer']['parameters']) 15 | lstm2 = convert_params(generative_params['lstm']['layers'][0]['parameters']) 16 | softmax = generative_params['output']['parameters'] 17 | 18 | new_state = (({}, (lstm1, lstm2)), softmax) 19 | with open('data/generative-model-original.pkl', 'wb') as fp: 20 | pickle.dump(new_state, fp) 21 | 22 | -------------------------------------------------------------------------------- /deepx/plots/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import re 4 | from argparse import ArgumentParser 5 | 6 | 7 | def parse_args(): 8 | argparser = 
ArgumentParser() 9 | argparser.add_argument('log_file', default='../loss/gan_log_current.txt', 10 | help='The training loss log file for the GAN training') 11 | argparser.add_argument('predict_file', default='../data/sequences/predictions.txt', 12 | help='Predictions') 13 | return argparser.parse_args() 14 | 15 | 16 | def plot_gan(args): 17 | '''Plot adversarial training''' 18 | with open(args.log_file, 'r') as f: 19 | next(f) 20 | train_loss = [] 21 | labels = [] 22 | prev_label = None 23 | # cur_label = 1 24 | 25 | for i, line in enumerate(f): 26 | train_loss.append(float(re.search('\((.*?)\)', line).group(1))) 27 | label = re.search('^\w+', line).group() 28 | 29 | 30 | if i == 0 or label != prev_label: 31 | labels.append(label) 32 | 33 | # if label == 'Generator': 34 | # if cur_label % 5 == 0: 35 | # labels.append(cur_label) 36 | # else: 37 | # labels.append('') 38 | # cur_label += 1 39 | 40 | else: 41 | labels.append('') 42 | prev_label = label 43 | 44 | # fig = plt.figure() 45 | # ax = fig.add_subplot(111) 46 | # x_min, x_max = ax.get_xlim() 47 | # ticks_scaled = [(tick - x_min)/(x_max - x_min) for tick in ax.get_xticks()] 48 | # ax.xaxis.set_major_locator(eval(locator)) 49 | # plt.plot((x1, x2), (0.0, 1.0), 'k-') 50 | plt.plot(train_loss) 51 | plt.xticks(np.arange(len(labels)), labels, rotation=75) 52 | plt.title('GAN Adversarial Training Loss') 53 | plt.show() 54 | 55 | 56 | 57 | def discriminator_prediction(args): 58 | '''Plot the discriminator prediction over characters''' 59 | prob, labels = [], [] 60 | 61 | with open(args.predict_file, 'rb') as f: 62 | labels = [r[0] for r in f.read().strip().split('\n')] 63 | with open(args.predict_file, 'rb') as f: 64 | prob = [float(r[3:]) for r in f.read().strip().split('\n')] 65 | 66 | plt.plot(prob) 67 | plt.title('Probability of Real as Function of Text') 68 | plt.xticks(np.arange(len(labels)), labels) 69 | plt.show() 70 | 71 | 72 | 73 | if __name__=='__main__': 74 | args = parse_args() 75 | 76 | # plot_gan(args) 77 | # discriminator_prediction(args) -------------------------------------------------------------------------------- /deepx/refactor_experiments.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import cPickle as pickle 4 | import theano 5 | import sys 6 | import csv 7 | import logging 8 | import random 9 | from dataset import * 10 | 11 | from deepx.nn import * 12 | from deepx.rnn import * 13 | from deepx.loss import * 14 | from deepx.optimize import * 15 | 16 | from batcher import * 17 | from argparse import ArgumentParser 18 | theano.config.on_unused_input = 'ignore' 19 | 20 | logging.basicConfig(level=logging.DEBUG) 21 | 22 | def parse_args(): 23 | argparser = ArgumentParser() 24 | argparser.add_argument("real_file") 25 | argparser.add_argument("fake_file") 26 | argparser.add_argument("--log", default="loss/discriminative/discriminative-adversarial-dropout-loss-0.1.0.txt") 27 | return argparser.parse_args() 28 | 29 | 30 | if __name__ == "__main__": 31 | args = parse_args() 32 | 33 | logging.debug("Reading file...") 34 | with open(args.real_file, 'r') as fp: 35 | real_reviews = [r[3:] for r in fp.read().strip().split('\n')] 36 | real_reviews = [r.replace('\x05', '') for r in real_reviews] 37 | real_reviews = [r.replace('', '') for r in real_reviews] 38 | with open(args.fake_file, 'r') as fp: 39 | fake_reviews = [r[3:] for r in fp.read().strip().split('\n')] 40 | fake_reviews = [r.replace('\x05', '') for r in fake_reviews] 41 | fake_reviews = [r.replace('', '') 
for r in fake_reviews] 42 | 43 | # Load and shuffle reviews 44 | real_targets, fake_targets = [], [] 45 | for _ in xrange(len(real_reviews)): 46 | real_targets.append([0, 1]) 47 | for _ in xrange(len(fake_reviews)): 48 | fake_targets.append([1, 0]) 49 | 50 | all_reviews = zip(real_reviews, real_targets) + zip(fake_reviews, fake_targets) 51 | 52 | random.seed(1) 53 | random.shuffle(all_reviews) 54 | 55 | reviews, targets = zip(*all_reviews[:150000]) 56 | 57 | logging.debug('Retrieving text encoding...') 58 | with open('data/charnet-encoding.pkl', 'rb') as fp: 59 | text_encoding = pickle.load(fp) 60 | text_encoding.include_stop_token = False 61 | text_encoding.include_start_token = False 62 | 63 | logging.debug("Converting to one-hot...") 64 | review_sequences = [CharacterSequence.from_string(review.replace('', '\x00').replace('', '\x01').replace('>', '').replace('<', '').replace('"','')) for review in reviews] 65 | 66 | num_sequences = [c.encode(text_encoding) for c in review_sequences] 67 | target_sequences = [NumberSequence([target]).replicate(len(r)) for target, r in zip(targets, num_sequences)] 68 | final_seq = NumberSequence(np.concatenate([c.seq.astype(np.int32) for c in num_sequences])) 69 | final_target = NumberSequence(np.concatenate([c.seq.astype(np.int32) for c in target_sequences])) 70 | 71 | # Construct the batcher 72 | batcher = WindowedBatcher([final_seq], [text_encoding], final_target, sequence_length=200, batch_size=100) 73 | 74 | logging.debug("Compiling discriminator...") 75 | 76 | # # Classical 77 | # discriminator = Sequence(Vector(len(text_encoding), batch_size=100)) >> Repeat(LSTM(1024, stateful=True), 2) >> Softmax(2) 78 | 79 | # with open('models/discriminative/discriminative-model-0.0.0.pkl', 'rb') as fp: 80 | # discriminator.set_state(pickle.load(fp)) 81 | 82 | # Dropout 83 | discriminator = Sequence(Vector(len(text_encoding))) >> Repeat(LSTM(1024) >> Dropout(0.5), 2) >> Softmax(2) 84 | with open('models/discriminative/discriminative-dropout-model-0.0.2.pkl', 'rb') as fp: 85 | discriminator.set_state(pickle.load(fp)) 86 | 87 | # Optimization procedure 88 | loss_function = AdversarialLoss(discriminator >> CrossEntropy(), discriminator.get_inputs()[0]) 89 | adam = Adam(loss_function, clip_gradients=500) 90 | 91 | train_loss = [] 92 | def train_discriminator(iterations, step_size): 93 | with open(args.log, 'a+') as fp: 94 | for _ in xrange(iterations): 95 | X, y = batcher.next_batch() 96 | loss = adam.train(X,y,step_size) 97 | print >> fp, "Loss[%u]: %f" % (_, loss) 98 | print "Loss[%u]: %f" % (_, loss) 99 | fp.flush() 100 | train_loss.append(loss) 101 | with open('models/discriminative/discriminative-adversarial-dropout-model-0.1.0.pkl', 'wb') as fp: 102 | pickle.dump(discriminator.get_state(), fp) -------------------------------------------------------------------------------- /deepx/rename_weights.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import cPickle as pickle 3 | 4 | def convert_params(params): 5 | if isinstance(params, tuple): 6 | return (convert_params(params[0]), convert_params(params[1])) 7 | new_params = {} 8 | for param, value in params.items(): 9 | if param[-1].isdigit(): 10 | new_params[param[:-2]] = value 11 | else: 12 | new_params["%s" % param] = value 13 | return new_params 14 | 15 | if __name__ == "__main__": 16 | argparser = ArgumentParser() 17 | argparser.add_argument('weights') 18 | argparser.add_argument('out') 19 | 20 | args = argparser.parse_args() 21 | 22 
| with open(args.weights) as fp: 23 | weights = pickle.load(fp) 24 | with open(args.out, 'w') as fp: 25 | pickle.dump(convert_params(weights), fp) -------------------------------------------------------------------------------- /deepx/train_discriminator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cPickle as pickle 3 | import theano 4 | import sys 5 | import csv 6 | import logging 7 | import random 8 | from dataset import * 9 | from batcher import * 10 | from deepx.nn import * 11 | from deepx.rnn import * 12 | from deepx.loss import * 13 | from deepx.optimize import * 14 | from argparse import ArgumentParser 15 | theano.config.on_unused_input = 'ignore' 16 | 17 | logging.basicConfig(level=logging.DEBUG) 18 | 19 | 20 | def parse_args(): 21 | argparser = ArgumentParser() 22 | argparser.add_argument("real_file") 23 | argparser.add_argument("fake_file") 24 | argparser.add_argument("--log", default="loss/discriminative/discriminative-adversarial-loss-0.0.0.txt") 25 | return argparser.parse_args() 26 | 27 | 28 | def generate(length, temperature): 29 | results = charrnn.generate( 30 | np.eye(len(encoding))[encoding.encode("i")],length,temperature).argmax(axis=1) 31 | return NumberSequence(results).decode(encoding) 32 | 33 | 34 | # def create_data_batcher(reviews, targets, encoding, sequence_length=200, batch_size=100): 35 | # '''Create a batcher for a set of reviews and targets given a text encoding''' 36 | # logging.debug('Converting to one-hot...') 37 | # review_seq = [CharacterSequence.from_string(review.replace('', '\x00').replace('', '\x01')) for review in reviews] 38 | 39 | # num_seq = [c.encode(encoding) for c in review_seq] 40 | # target_seq = [NumberSequence([target]).replicate(len(r)) for target, r in zip(targets, num_seq)] 41 | 42 | # final_seq = NumberSequence(np.concatenate([c.seq.astype(np.int32) for c in num_seq])) 43 | # final_target = NumberSequence(np.concatenate([c.seq.astype(np.int32) for c in target_seq])) 44 | 45 | # batcher = WindowedBatcher([final_seq], [encoding], final_target, sequence_length, batch_size) 46 | # return batcher 47 | 48 | 49 | if __name__ == "__main__": 50 | args = parse_args() 51 | 52 | logging.debug("Reading file...") 53 | with open(args.real_file, 'r') as fp: 54 | real_reviews = [r[3:] for r in fp.read().strip().split('\n')] 55 | real_reviews = [r.replace('\x05', '') for r in real_reviews] 56 | real_reviews = [r.replace('', '') for r in real_reviews] 57 | with open(args.fake_file, 'r') as fp: 58 | fake_reviews = [r[3:] for r in fp.read().strip().split('\n')] 59 | fake_reviews = [r.replace('\x05', '') for r in fake_reviews] 60 | fake_reviews = [r.replace('', '') for r in fake_reviews] 61 | 62 | # Load and shuffle reviews 63 | real_targets, fake_targets = [], [] 64 | for _ in xrange(len(real_reviews)): 65 | real_targets.append([0, 1]) 66 | for _ in xrange(len(fake_reviews)): 67 | fake_targets.append([1, 0]) 68 | 69 | all_reviews = zip(real_reviews, real_targets) + zip(fake_reviews, fake_targets) 70 | 71 | random.seed(1) 72 | random.shuffle(all_reviews) 73 | 74 | reviews, targets = zip(*all_reviews[:150000]) 75 | # test_reviews, test_targets = zip(*all_reviews[500000:]) 76 | 77 | logging.debug('Retrieving text encoding...') 78 | with open('data/charnet-encoding.pkl', 'rb') as fp: 79 | text_encoding = pickle.load(fp) 80 | text_encoding.include_stop_token = False 81 | text_encoding.include_start_token = False 82 | 83 | logging.debug("Converting to one-hot...") 84 | review_sequences = 
[CharacterSequence.from_string(review.replace('', '\x00').replace('', '\x01').replace('>', '').replace('<', '').replace('"','')) for review in reviews] 85 | 86 | num_sequences = [c.encode(text_encoding) for c in review_sequences] 87 | target_sequences = [NumberSequence([target]).replicate(len(r)) for target, r in zip(targets, num_sequences)] 88 | final_seq = NumberSequence(np.concatenate([c.seq.astype(np.int32) for c in num_sequences])) 89 | final_target = NumberSequence(np.concatenate([c.seq.astype(np.int32) for c in target_sequences])) 90 | 91 | # Construct the batcher 92 | batcher = WindowedBatcher([final_seq], [text_encoding], final_target, sequence_length=200, batch_size=100) 93 | 94 | logging.debug("Compiling discriminator...") 95 | 96 | ################################# 97 | # Classic Training Discriminator 98 | ################################# 99 | # discriminator = Sequence(Vector(len(text_encoding), batch_size=100)) >> Repeat(LSTM(1024, stateful=True), 2) >> Softmax(2) 100 | 101 | # with open('models/discriminative/discriminative-model-0.0.renamed.pkl', 'rb') as fp: 102 | # # with open('models/discriminative/discriminative-model-1.0.pkl', 'rb') as fp: 103 | # discriminator.set_state(pickle.load(fp)) 104 | 105 | #################################### 106 | # Dropout Training of Discriminator 107 | #################################### 108 | dropout_lstm = LSTM(1024, stateful=True) >> Dropout(0.5) 109 | discriminator = Sequence(Vector(len(text_encoding), batch_size=100)) >> Repeat(dropout_lstm, 2) >> Softmax(2) 110 | 111 | # Load dropout generator weights into discriminator 112 | generator = Sequence(Vector(len(text_encoding), batch_size=100)) >> Repeat(dropout_lstm, 2) >> Softmax(len(text_encoding)) 113 | with open('models/generative/generative-dropout-model-0.0.5.pkl') as f: 114 | generator.set_state(pickle.load(f)) 115 | 116 | discriminator.left.set_state(generator.left.get_state()) 117 | 118 | # Optimization procedure 119 | loss_function = AdversarialLoss(CrossEntropy(discriminator)) 120 | adam = Adam(loss_function, clip_gradients=500) 121 | 122 | # Training loss 123 | train_loss = [] 124 | def train_discriminator(iterations, step_size): 125 | with open(args.log, 'a+') as fp: 126 | for _ in xrange(iterations): 127 | X, y = batcher.next_batch() 128 | loss = adam.train(X,y,step_size) 129 | print >> fp, "Loss[%u]: %f" % (_, loss) 130 | print "Loss[%u]: %f" % (_, loss) 131 | fp.flush() 132 | train_loss.append(loss) 133 | with open('models/discriminative/discriminative-adversarial-model-0.0.0.pkl', 'wb') as fp: 134 | pickle.dump(discriminator.get_state(), fp) 135 | 136 | 137 | # Train accuracy 138 | def calculate_accuracy(num_batches=100): 139 | '''Calculate the training accuracy over number of batches. 
Calculates 140 | accuracy based on prediction at final time-step''' 141 | errors = 0 142 | total = 0 143 | 144 | for _ in xrange(num_batches): 145 | X, y = batcher.next_batch() 146 | pred = discriminator.predict(X) 147 | 148 | # Retrieve last label 149 | last_y = y[-1, :, :] 150 | last_p = pred[-1, :, :] 151 | 152 | errors += np.count_nonzero(last_y.argmax(axis=1) - last_p.argmax(axis=1)) 153 | total += batcher.batch_size 154 | 155 | return 1.0 - float(errors)/float(total) 156 | 157 | 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /deepx/train_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cPickle as pickle 3 | import theano 4 | import sys 5 | import csv 6 | import logging 7 | import random 8 | from dataset import * 9 | from deepx.nn import * 10 | from deepx.rnn import * 11 | from deepx.loss import * 12 | from deepx.optimize import * 13 | from argparse import ArgumentParser 14 | theano.config.on_unused_input = 'ignore' 15 | 16 | logger = logging.getLogger() 17 | logger.setLevel(logging.DEBUG) 18 | 19 | 20 | def parse_args(): 21 | argparser = ArgumentParser() 22 | argparser.add_argument("reviews") 23 | argparser.add_argument("--loss_log", default="loss/generative/generative_loss_current.txt") 24 | return argparser.parse_args() 25 | 26 | 27 | class WindowedBatcher(object): 28 | 29 | def __init__(self, sequences, encodings, batch_size=10, sequence_length=200): 30 | self.sequences = sequences 31 | 32 | self.pre_vector_sizes = [c.seq[0].shape[0] for c in self.sequences] 33 | self.pre_vector_size = sum(self.pre_vector_sizes) 34 | 35 | self.encodings = encodings 36 | self.vocab_sizes = [c.index for c in self.encodings] 37 | self.vocab_size = sum(self.vocab_sizes) 38 | self.batch_index = 0 39 | self.batches = [] 40 | self.batch_size = batch_size 41 | self.sequence_length = sequence_length + 1 42 | self.length = len(self.sequences[0]) 43 | 44 | self.batch_index = 0 45 | self.X = np.zeros((self.length, self.pre_vector_size)) 46 | self.X = np.hstack([c.seq for c in self.sequences]) 47 | 48 | N, D = self.X.shape 49 | assert N > self.batch_size * self.sequence_length, "File has to be at least %u characters" % (self.batch_size * self.sequence_length) 50 | 51 | self.X = self.X[:N - N % (self.batch_size * self.sequence_length)] 52 | self.N, self.D = self.X.shape 53 | self.X = self.X.reshape((self.N / self.sequence_length, self.sequence_length, self.D)) 54 | 55 | self.N, self.S, self.D = self.X.shape 56 | 57 | self.num_sequences = self.N / self.sequence_length 58 | self.num_batches = self.N / self.batch_size 59 | self.batch_cache = {} 60 | 61 | def next_batch(self): 62 | idx = (self.batch_index * self.batch_size) 63 | if self.batch_index >= self.num_batches: 64 | self.batch_index = 0 65 | idx = 0 66 | 67 | if self.batch_index in self.batch_cache: 68 | batch = self.batch_cache[self.batch_index] 69 | self.batch_index += 1 70 | return batch 71 | 72 | X = self.X[idx:idx + self.batch_size] 73 | y = np.zeros((X.shape[0], self.sequence_length, self.vocab_size)) 74 | for i in xrange(self.batch_size): 75 | for c in xrange(self.sequence_length): 76 | seq_splits = np.split(X[i, c], np.cumsum(self.pre_vector_sizes)) 77 | vec = np.concatenate([e.convert_representation(split) for 78 | e, split in zip(self.encodings, seq_splits)]) 79 | y[i, c] = vec 80 | 81 | X = y[:, :-1, :] 82 | y = y[:, 1:, :self.vocab_sizes[0]] 83 | 84 | X = np.swapaxes(X, 0, 1) 85 | y = np.swapaxes(y, 0, 1) 86 | 
self.batch_index += 1 87 | return X, y 88 | 89 | 90 | def generate_number_samples(num_reviews): 91 | '''Generate a batch of samples from the current version of the generator''' 92 | pred_seq = generator_sample.predict(np.tile(np.eye(100)[0], (num_reviews, 1))) 93 | return pred_seq 94 | 95 | 96 | def generate_text_samples(num_reviews): 97 | '''Generate fake reviews using the current generator''' 98 | pred_seq = generate_number_samples(num_reviews).argmax(axis=2).T 99 | num_seq = [NumberSequence(pred_seq[i]).decode(text_encoding) for i in xrange(num_reviews)] 100 | return_str = [''.join(n.seq) for n in num_seq] 101 | return return_str 102 | 103 | 104 | def generate_training_set(gan_versions=10, reviews_per_gan=3000, train_iter=100, step_size=10): 105 | '''Generate a reviews classically Note: Reviews may contain non-unicode characters''' 106 | with open('data/fake_beer_reviews_2.1_30000.txt', 'wb') as f: 107 | for i in xrange(gan_versions): 108 | logging.debug('Generating reviews...') 109 | reviews = generate_text_samples(reviews_per_gan) 110 | 111 | logging.debug('Appending reviews to file...') 112 | for review in reviews: 113 | print >> f, review 114 | 115 | logging.debug('Training generator...') 116 | train_generator(train_iter, step_size) 117 | 118 | 119 | if __name__ == '__main__': 120 | args = parse_args() 121 | 122 | logging.debug('Reading file...') 123 | with open(args.reviews, 'r') as f: 124 | reviews = [r[3:] for r in f.read().strip().split('\n')] 125 | reviews = [r.replace('\x05', '') for r in reviews] 126 | reviews = [r.replace('', '') for r in reviews] 127 | 128 | logging.debug('Retrieving text encoding...') 129 | with open('data/charnet-encoding.pkl', 'rb') as fp: 130 | text_encoding = pickle.load(fp) 131 | 132 | # Create reviews and targets 133 | logging.debug('Converting to one-hot...') 134 | review_sequences = [CharacterSequence.from_string(r) for r in reviews] 135 | num_sequences = [c.encode(text_encoding) for c in review_sequences] 136 | final_sequences = NumberSequence(np.concatenate([c.seq.astype(np.int32) for c in num_sequences])) 137 | 138 | # Batcher and generator 139 | batcher = WindowedBatcher([final_sequences], [text_encoding], sequence_length=200, batch_size=100) 140 | 141 | 142 | ############################# 143 | # Classic Training Generator 144 | ############################# 145 | generator = Sequence(Vector(len(text_encoding), batch_size=100)) >> Repeat(LSTM(1024, stateful=True), 2) >> Softmax(len(text_encoding)) 146 | generator_sample = Generate(Vector(len(text_encoding)) >> Repeat(LSTM(1024), 2) >> Softmax(len(text_encoding)), 500) 147 | 148 | # Tie the weights 149 | generator_sample = generator_sample.tie(generator) 150 | 151 | logging.debug('Loading prior model...') 152 | with open('models/generative/generative-model-2.0.1.pkl', 'rb') as fp: 153 | generator.set_state(pickle.load(fp)) 154 | 155 | 156 | ################################## 157 | # Dropout Training of Generator 158 | ################################## 159 | # dropout_lstm = LSTM(1024, stateful=True) >> Dropout(0.5) 160 | # generator = Sequence(Vector(len(text_encoding), batch_size=100)) >> Repeat(dropout_lstm, 2) >> Softmax(len(text_encoding)) 161 | # generator_sample = Generate(Vector(len(text_encoding)) >> Repeat(LSTM(1024) >> Dropout(0.5), 2) >> Softmax(len(text_encoding)), 500) 162 | 163 | # # Tie the weights 164 | # generator_sample = generator_sample.tie(generator) 165 | 166 | # logging.debug('Loading prior model...') 167 | # with 
open('models/generative/generative-dropout-model-0.0.5.pkl', 'rb') as fp: 168 | # generator.set_state(pickle.load(fp)) 169 | 170 | 171 | logging.debug('Compiling graph...') 172 | # loss_function = CrossEntropy(generator) 173 | loss_function = AdversarialLoss(CrossEntropy(generator)) 174 | adam = Adam(loss_function, clip_gradients=500) 175 | 176 | def train_generator(iterations, step_size): 177 | with open(args.loss_log, 'a+') as f: 178 | for _ in xrange(iterations): 179 | X, y = batcher.next_batch() 180 | # grads = rmsprop.gradient(X, y) 181 | # if grads: 182 | # for g in grads: 183 | # print np.linalg.norm(np.asarray(g)) 184 | loss = adam.train(X, y, step_size) 185 | 186 | print >> f, 'Loss[%u]: %f' % (_, loss) 187 | print 'Loss[%u]: %f' % (_, loss) 188 | f.flush() 189 | 190 | with open('models/generative/generative-adversarial-model-0.0.0.pkl', 'wb') as g: 191 | pickle.dump(generator.get_state(), g) 192 | 193 | -------------------------------------------------------------------------------- /deepx/utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | def load_reviews(file_dir, min_sequence_length=200): 5 | '''Loads list of reviews from file_dir''' 6 | with open(file_dir, 'rb') as f: 7 | reviews = [r[3:] for r in f.read().strip().split('\n')] 8 | reviews = [r.replace('\x05', '') for r in reviews] 9 | reviews = [r.replace('', '') for r in reviews] 10 | reviews = [r.replace('<', '') for r in reviews] 11 | reviews = [r.replace('>', '') for r in reviews] 12 | reviews = [r for r in reviews if len(r) >= min_sequence_length] 13 | return reviews 14 | 15 | 16 | def write_predictions_to_file(text, file_dir='data/sequences/predictions.txt'): 17 | '''Write predictions of real probabiltiy of sequence to file''' 18 | text = text.replace('','').replace('','') 19 | 20 | prob = predict(text)[:, 0, 1].tolist() 21 | 22 | with open(file_dir, 'w') as f: 23 | for i in xrange(len(text)): 24 | print >> f, '%s, %f' % (text[i], prob[i]) 25 | 26 | -------------------------------------------------------------------------------- /tensorflow/README.md: -------------------------------------------------------------------------------- 1 | # Generative Adversarial Networks (GAN) 2 | GAN training framework for natural language in base-TensorFlow. The models are initially adopted from [char-rnn-tensorflow](https://github.com/sherjilozair/char-rnn-tensorflow). 
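For reference, the vocabulary files written by the batchers (e.g. `simple_vocab.pkl`) are a cPickle'd tuple of characters sorted by descending frequency; a character's position in the tuple is its integer id in the saved `.npy` tensors. A minimal sketch of inspecting one (Python 2, mirroring `load_preprocessed` in `batcher.py`; adjust the path to wherever the pickle lives):

```python
import cPickle

with open('simple_vocab.pkl', 'r') as f:
    chars = cPickle.load(f)                     # tuple of characters, most frequent first
vocab = dict(zip(chars, range(len(chars))))     # char -> integer id used in the data tensors
print len(chars), 'characters in vocabulary'
```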
3 | 4 | # Requirements 5 | TensorFlow master branch required (using new distributions not within latest 0.9 release) 6 | - [Tensorflow](http://www.tensorflow.org) 7 | 8 | -------------------------------------------------------------------------------- /tensorflow/batcher.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import collections 4 | import cPickle 5 | import logging 6 | 7 | 8 | logger = logging.getLogger() 9 | logger.setLevel(logging.DEBUG) 10 | 11 | 12 | class Batcher(object): 13 | def __init__(self, data_dir, batch_size, seq_length): 14 | self.batch_size = batch_size 15 | self.seq_length = seq_length 16 | 17 | input_file = os.path.join(data_dir, 'real_beer_reviews.txt') 18 | vocab_file = os.path.join(data_dir, 'real_beer_vocab.pkl') 19 | tensor_file = os.path.join(data_dir, 'real_beer_data.npy') 20 | 21 | if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)): 22 | self.preprocess(input_file, vocab_file, tensor_file) 23 | else: 24 | self.load_preprocessed(vocab_file, tensor_file) 25 | 26 | self.create_batches() 27 | self.reset_batch_pointer() 28 | 29 | def preprocess(self, input_file, vocab_file, tensor_file): 30 | logging.debug('Reading text file...') 31 | with open(input_file, 'r') as f: 32 | data = f.read() 33 | counter = collections.Counter(data) 34 | count_pairs = sorted(counter.items(), key=lambda x: -x[1]) 35 | self.chars, _ = list(zip(*count_pairs)) 36 | self.vocab_size = len(self.chars) 37 | self.vocab = dict(zip(self.chars, range(len(self.chars)))) 38 | with open(vocab_file, 'w') as f: 39 | cPickle.dump(self.chars, f) 40 | self.tensor = np.array(map(self.vocab.get, data)) 41 | np.save(tensor_file, self.tensor) 42 | 43 | def load_preprocessed(self, vocab_file, tensor_file): 44 | logging.debug('Loading preprocessed files...') 45 | with open(vocab_file, 'r') as f: 46 | self.chars = cPickle.load(f) 47 | self.vocab_size = len(self.chars) 48 | self.vocab = dict(zip(self.chars, range(len(self.chars)))) 49 | self.tensor = np.load(tensor_file) 50 | 51 | def create_batches(self): 52 | logging.debug('Creating batches...') 53 | self.num_batches = self.tensor.size / (self.batch_size * self.seq_length) 54 | self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length] 55 | x_data = self.tensor 56 | y_data = np.copy(self.tensor) 57 | y_data[:-1] = x_data[1:] # Labels are simply the next char 58 | y_data[-1] = x_data[0] 59 | self.x_batches = np.split(x_data.reshape(self.batch_size, -1), self.num_batches, 1) 60 | self.y_batches = np.split(y_data.reshape(self.batch_size, -1), self.num_batches, 1) 61 | 62 | def next_batch(self): 63 | x, y = self.x_batches[self.pointer], self.y_batches[self.pointer] 64 | self.pointer += 1 65 | return x, y 66 | 67 | def reset_batch_pointer(self): 68 | self.pointer = 0 69 | 70 | 71 | class DiscriminatorBatcher(object): 72 | def __init__(self, data_dir, batch_size, seq_length): 73 | self.batch_size = batch_size 74 | self.seq_length = seq_length 75 | 76 | real_file = os.path.join(data_dir, 'real_beer_reviews.txt') 77 | fake_file = os.path.join(data_dir, 'fake_beer_reviews.txt') 78 | real_tensor = os.path.join(data_dir, 'real_beer_data_v0.1.npy') 79 | fake_tensor = os.path.join(data_dir, 'fake_beer_data_v0.1.npy') 80 | vocab_file = os.path.join(data_dir, 'combined_vocab.pkl') 81 | 82 | if not (os.path.exists(vocab_file) and os.path.exists(real_tensor) and os.path.exists(fake_tensor)): 83 | self.preprocess(real_file, fake_file, vocab_file, real_tensor, 
fake_tensor) 84 | else: 85 | self.load_preprocessed(vocab_file, real_tensor, fake_tensor) 86 | 87 | self.create_batches() 88 | self.reset_batch_pointer() 89 | 90 | def preprocess(self, real_file, fake_file, vocab_file, tensor_file_real, tensor_file_fake): 91 | logging.debug('Preprocessing...') 92 | with open(real_file, 'r') as f: 93 | data_real = f.read() 94 | with open(fake_file, 'r') as f: 95 | data_fake = f.read() 96 | data = data_real + data_fake 97 | counter = collections.Counter(data) 98 | count_pairs = sorted(counter.items(), key=lambda x: -x[1]) 99 | self.chars, _ = list(zip(*count_pairs)) 100 | self.vocab_size = len(self.chars) 101 | self.vocab = dict(zip(self.chars, range(len(self.chars)))) 102 | with open(vocab_file, 'w') as f: 103 | cPickle.dump(self.chars, f) 104 | 105 | def build_tensor(tensor_file, data_str): 106 | tensor = np.array(map(self.vocab.get, data_str)) 107 | np.save(tensor_file, tensor) 108 | return tensor 109 | 110 | self.tensor_real = build_tensor(tensor_file_real, data_real) 111 | self.tensor_fake = build_tensor(tensor_file_fake, data_fake) 112 | np.save(tensor_file_real, self.tensor_real) 113 | np.save(tensor_file_fake, self.tensor_fake) 114 | 115 | def load_preprocessed(self, vocab_file, tensor_file_real, tensor_file_fake): 116 | logging.debug('Loading preprocessed files...') 117 | with open(vocab_file, 'r') as f: 118 | self.chars = cPickle.load(f) 119 | self.vocab_size = len(self.chars) 120 | self.vocab = dict(zip(self.chars, range(len(self.chars)))) 121 | self.tensor_real = np.load(tensor_file_real) 122 | self.tensor_fake = np.load(tensor_file_fake) 123 | 124 | def create_batches(self): 125 | logging.debug('Creating batches...') 126 | 127 | # Real batches 128 | num_batches = self.tensor_real.size / (self.batch_size / 2 * self.seq_length) 129 | self.tensor_real = self.tensor_real[:num_batches * self.batch_size / 2 * self.seq_length] 130 | x_data_real = self.tensor_real 131 | y_data_real = np.ones((len(x_data_real), 1)) 132 | x_batches_real = np.split(x_data_real.reshape(self.batch_size / 2, -1), num_batches, 1) 133 | y_batches_real = np.split(y_data_real.reshape(self.batch_size / 2, -1), num_batches, 1) 134 | batches_real = [np.hstack([x, y]) for x, y in zip(x_batches_real, y_batches_real)] 135 | 136 | # Fake batches 137 | num_batches = self.tensor_fake.size / (self.batch_size / 2 * self.seq_length) 138 | self.tensor_fake = self.tensor_fake[:num_batches * self.batch_size / 2 * self.seq_length] 139 | x_data_fake = self.tensor_fake 140 | y_data_fake = np.zeros((len(x_data_fake), 1)) 141 | x_batches_fake = np.split(x_data_fake.reshape(self.batch_size / 2, -1), num_batches, 1) 142 | y_batches_fake = np.split(y_data_fake.reshape(self.batch_size / 2, -1), num_batches, 1) 143 | batches_fake = [np.hstack([x, y]) for x, y in zip(x_batches_fake, y_batches_fake)] 144 | 145 | # Combine batches 146 | batches = [np.vstack((real, fake)) for real, fake in zip(batches_real, batches_fake)] 147 | for arr in batches: 148 | np.random.shuffle(arr) 149 | self.x_batches = [arr[:, :self.seq_length] for arr in batches] 150 | self.y_batches = [arr[:, self.seq_length:] for arr in batches] 151 | self.num_batches = len(batches) 152 | 153 | def next_batch(self): 154 | x, y = self.x_batches[self.pointer], self.y_batches[self.pointer] 155 | self.pointer += 1 156 | return x, y 157 | 158 | def reset_batch_pointer(self): 159 | self.pointer = 0 160 | 161 | 162 | class GANBatcher(object): 163 | def __init__(self, data_dir, batch_size, seq_length): 164 | self.batch_size = batch_size 165 | 
self.seq_length = seq_length 166 | 167 | input_file = os.path.join(data_dir, 'simple_reviews.txt') 168 | vocab_file = os.path.join(data_dir, 'simple_vocab.pkl') 169 | tensor_file = os.path.join(data_dir, 'simple_data.npy') 170 | 171 | if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)): 172 | self.preprocess(input_file, vocab_file, tensor_file) 173 | else: 174 | self.load_preprocessed(vocab_file, tensor_file) 175 | 176 | self.create_batches() 177 | self.reset_batch_pointer() 178 | 179 | def preprocess(self, input_file, vocab_file, tensor_file): 180 | logging.debug('Reading text file...') 181 | with open(input_file, 'r') as f: 182 | data = f.read() 183 | counter = collections.Counter(data) 184 | count_pairs = sorted(counter.items(), key=lambda x: -x[1]) 185 | self.chars, _ = list(zip(*count_pairs)) 186 | self.vocab_size = len(self.chars) 187 | self.vocab = dict(zip(self.chars, range(len(self.chars)))) 188 | with open(vocab_file, 'w') as f: 189 | cPickle.dump(self.chars, f) 190 | self.tensor = np.array(map(self.vocab.get, data)) 191 | np.save(tensor_file, self.tensor) 192 | 193 | def load_preprocessed(self, vocab_file, tensor_file): 194 | logging.debug('Loading preprocessed files...') 195 | with open(vocab_file, 'r') as f: 196 | self.chars = cPickle.load(f) 197 | self.vocab_size = len(self.chars) 198 | self.vocab = dict(zip(self.chars, range(len(self.chars)))) 199 | self.tensor = np.load(tensor_file) 200 | 201 | def create_batches(self): 202 | logging.debug('Creating batches...') 203 | self.num_batches = self.tensor.size / (self.batch_size * self.seq_length) 204 | self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length] 205 | x_data = self.tensor 206 | y_data = np.copy(self.tensor) 207 | y_data[:-1] = x_data[1:] # Labels are simply the next char 208 | y_data[-1] = x_data[0] 209 | self.x_batches = np.split(x_data.reshape(self.batch_size, -1), self.num_batches, 1) 210 | self.y_batches = np.split(y_data.reshape(self.batch_size, -1), self.num_batches, 1) 211 | 212 | def next_batch(self): 213 | x, y = self.x_batches[self.pointer], self.y_batches[self.pointer] 214 | self.pointer += 1 215 | return x, y 216 | 217 | def reset_batch_pointer(self): 218 | self.pointer = 0 -------------------------------------------------------------------------------- /tensorflow/batcher_gan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import collections 4 | import cPickle 5 | import logging 6 | 7 | 8 | logger = logging.getLogger() 9 | logger.setLevel(logging.DEBUG) 10 | 11 | 12 | class DiscriminatorBatcher(object): 13 | def __init__(self, real_file, fake_file, data_dir, vocab_file, batch_size, seq_length): 14 | self.batch_size = batch_size 15 | self.seq_length = seq_length 16 | 17 | real_file = os.path.join(data_dir, real_file) 18 | fake_file = os.path.join(data_dir, fake_file) 19 | real_tensor = os.path.join(data_dir, 'real_data.npy') 20 | fake_tensor = os.path.join(data_dir, 'fake_data.npy') 21 | vocab_fiel = os.path.join(data_dir, vocab_file) 22 | 23 | # if not (os.path.exists(vocab_file) and os.path.exists(real_tensor) and os.path.exists(fake_tensor)): 24 | # self.preprocess(real_file, fake_file, vocab_file, real_tensor, fake_tensor) 25 | # else: 26 | # self.load_preprocessed(vocab_file, real_tensor, fake_tensor) 27 | 28 | self.preprocess(real_file, fake_file, vocab_file, real_tensor, fake_tensor) 29 | 30 | self.create_batches() 31 | self.reset_batch_pointer() 32 | 33 | def preprocess(self, 
real_file, fake_file, vocab_file, tensor_file_real, tensor_file_fake): 34 | logging.debug('Preprocessing...') 35 | with open(real_file, 'r') as f: 36 | data_real = f.read() 37 | with open(fake_file, 'r') as f: 38 | data_fake = f.read() 39 | data = data_real + data_fake 40 | counter = collections.Counter(data) 41 | count_pairs = sorted(counter.items(), key=lambda x: -x[1]) 42 | self.chars, _ = list(zip(*count_pairs)) 43 | self.vocab_size = len(self.chars) 44 | self.vocab = dict(zip(self.chars, range(len(self.chars)))) 45 | with open(vocab_file, 'w') as f: 46 | cPickle.dump(self.chars, f) 47 | 48 | def build_tensor(tensor_file, data_str): 49 | tensor = np.array(map(self.vocab.get, data_str)) 50 | np.save(tensor_file, tensor) 51 | return tensor 52 | 53 | self.tensor_real = build_tensor(tensor_file_real, data_real) 54 | self.tensor_fake = build_tensor(tensor_file_fake, data_fake) 55 | np.save(tensor_file_real, self.tensor_real) 56 | np.save(tensor_file_fake, self.tensor_fake) 57 | 58 | # def load_preprocessed(self, vocab_file, tensor_file_real, tensor_file_fake): 59 | # logging.debug('Loading preprocessed files...') 60 | # with open(vocab_file, 'r') as f: 61 | # self.chars = cPickle.load(f) 62 | # self.vocab_size = len(self.chars) 63 | # self.vocab = dict(zip(self.chars, range(len(self.chars)))) 64 | # self.tensor_real = np.load(tensor_file_real) 65 | # self.tensor_fake = np.load(tensor_file_fake) 66 | 67 | def create_batches(self): 68 | logging.debug('Creating batches...') 69 | 70 | # Real batches 71 | num_batches = self.tensor_real.size / (self.batch_size / 2 * self.seq_length) 72 | self.tensor_real = self.tensor_real[:num_batches * self.batch_size / 2 * self.seq_length] 73 | x_data_real = self.tensor_real 74 | y_data_real = np.ones((len(x_data_real), 1)) 75 | x_batches_real = np.split(x_data_real.reshape(self.batch_size / 2, -1), num_batches, 1) 76 | y_batches_real = np.split(y_data_real.reshape(self.batch_size / 2, -1), num_batches, 1) 77 | batches_real = [np.hstack([x, y]) for x, y in zip(x_batches_real, y_batches_real)] 78 | 79 | # Fake batches 80 | num_batches = self.tensor_fake.size / (self.batch_size / 2 * self.seq_length) 81 | self.tensor_fake = self.tensor_fake[:num_batches * self.batch_size / 2 * self.seq_length] 82 | x_data_fake = self.tensor_fake 83 | y_data_fake = np.zeros((len(x_data_fake), 1)) 84 | x_batches_fake = np.split(x_data_fake.reshape(self.batch_size / 2, -1), num_batches, 1) 85 | y_batches_fake = np.split(y_data_fake.reshape(self.batch_size / 2, -1), num_batches, 1) 86 | batches_fake = [np.hstack([x, y]) for x, y in zip(x_batches_fake, y_batches_fake)] 87 | 88 | # Combine batches 89 | batches = [np.vstack((real, fake)) for real, fake in zip(batches_real, batches_fake)] 90 | for arr in batches: 91 | np.random.shuffle(arr) 92 | self.x_batches = [arr[:, :self.seq_length] for arr in batches] 93 | self.y_batches = [arr[:, self.seq_length:] for arr in batches] 94 | self.num_batches = len(batches) 95 | 96 | def next_batch(self): 97 | x, y = self.x_batches[self.pointer], self.y_batches[self.pointer] 98 | self.pointer += 1 99 | return x, y 100 | 101 | def reset_batch_pointer(self): 102 | self.pointer = 0 103 | 104 | 105 | class GANBatcher(object): 106 | def __init__(self, input_file, vocab_file, data_dir, batch_size, seq_length): 107 | self.batch_size = batch_size 108 | self.seq_length = seq_length 109 | 110 | input_file = os.path.join(data_dir, input_file) 111 | vocab_file = os.path.join(data_dir, vocab_file) 112 | tensor_file = os.path.join(data_dir, 'simple_data.npy') 113 
| 114 | if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)): 115 | self.preprocess(input_file, vocab_file, tensor_file) 116 | else: 117 | self.load_preprocessed(vocab_file, tensor_file) 118 | 119 | self.create_batches() 120 | self.reset_batch_pointer() 121 | 122 | def preprocess(self, input_file, vocab_file, tensor_file): 123 | logging.debug('Reading text file...') 124 | with open(input_file, 'r') as f: 125 | data = f.read() 126 | counter = collections.Counter(data) 127 | count_pairs = sorted(counter.items(), key=lambda x: -x[1]) 128 | self.chars, _ = list(zip(*count_pairs)) 129 | self.vocab_size = len(self.chars) 130 | self.vocab = dict(zip(self.chars, range(len(self.chars)))) 131 | with open(vocab_file, 'w') as f: 132 | cPickle.dump(self.chars, f) 133 | self.tensor = np.array(map(self.vocab.get, data)) 134 | np.save(tensor_file, self.tensor) 135 | 136 | def load_preprocessed(self, vocab_file, tensor_file): 137 | logging.debug('Loading preprocessed files...') 138 | with open(vocab_file, 'r') as f: 139 | self.chars = cPickle.load(f) 140 | self.vocab_size = len(self.chars) 141 | self.vocab = dict(zip(self.chars, range(len(self.chars)))) 142 | self.tensor = np.load(tensor_file) 143 | 144 | def create_batches(self): 145 | logging.debug('Creating batches...') 146 | self.num_batches = self.tensor.size / (self.batch_size * self.seq_length) 147 | self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length] 148 | x_data = self.tensor 149 | y_data = np.copy(self.tensor) 150 | y_data[:-1] = x_data[1:] # Labels are simply the next char 151 | y_data[-1] = x_data[0] 152 | self.x_batches = np.split(x_data.reshape(self.batch_size, -1), self.num_batches, 1) 153 | self.y_batches = np.split(y_data.reshape(self.batch_size, -1), self.num_batches, 1) 154 | 155 | def next_batch(self): 156 | x, y = self.x_batches[self.pointer], self.y_batches[self.pointer] 157 | self.pointer += 1 158 | return x, y 159 | 160 | def reset_batch_pointer(self): 161 | self.pointer = 0 -------------------------------------------------------------------------------- /tensorflow/discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.python.ops.nn import rnn_cell 4 | from tensorflow.python.ops.nn import rnn 5 | from tensorflow.python.ops.nn import seq2seq 6 | from tensorflow.python.ops import array_ops 7 | from tensorflow.python.ops import nn_ops 8 | from tensorflow.python.ops import math_ops 9 | from tensorflow.python.framework import ops 10 | 11 | class Discriminator(object): 12 | def __init__(self, args, is_training=True): 13 | self.args = args 14 | 15 | if not is_training: 16 | args.batch_size = 1 17 | args.seq_length = 1 18 | 19 | if args.model == 'rnn': 20 | self.cell = rnn_cell.BasicRNNCell(args.rnn_size) 21 | elif args.model == 'gru': 22 | self.cell = rnn_cell.GRUCell(args.rnn_size) 23 | elif args.model == 'lstm': 24 | self.cell = rnn_cell.BasicLSTMCell(args.rnn_size) 25 | else: 26 | raise Exception('model type not supported: {}'.format(args.model)) 27 | 28 | self.cell = rnn_cell.MultiRNNCell([self.cell] * args.num_layers) 29 | 30 | self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length]) 31 | self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length]) # Target replication 32 | self.initial_state = self.cell.zero_state(args.batch_size, tf.float32) 33 | 34 | with tf.variable_scope('rnn'): 35 | softmax_w = tf.get_variable('softmax_w', 
[args.rnn_size, 2]) 36 | softmax_b = tf.get_variable('softmax_b', [2]) 37 | 38 | with tf.device('/cpu:0'): 39 | embedding = tf.get_variable('embedding', [args.vocab_size, args.rnn_size]) 40 | inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data)) 41 | inputs = [tf.squeeze(i, [1]) for i in inputs] 42 | 43 | outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, 44 | self.cell, loop_function=None) 45 | 46 | output_tf = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size]) 47 | self.logits = tf.nn.xw_plus_b(output_tf, softmax_w, softmax_b) 48 | self.probs = tf.nn.softmax(self.logits) 49 | 50 | loss = seq2seq.sequence_loss_by_example( 51 | [self.logits], 52 | [tf.reshape(self.targets, [-1])], 53 | [tf.ones([args.batch_size * args.seq_length])]) 54 | 55 | self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length 56 | 57 | self.final_state = last_state 58 | self.lr = tf.Variable(0.0, trainable = False) 59 | tvars = tf.trainable_variables() 60 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars, aggregation_method=2), args.grad_clip) 61 | optimizer = tf.train.AdamOptimizer(self.lr) 62 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 63 | 64 | def predict(self, sess, chars, vocab): 65 | '''Predict a seqeuence of chars using the current model''' 66 | state = self.cell.zero_state(self.args.batch_size, tf.float32).eval() 67 | probabilities, num_seq = [], [] 68 | 69 | for char in chars: 70 | num_seq.append(vocab[char]) 71 | 72 | for num in num_seq: 73 | x = np.array([[num]]) 74 | feed = {self.input_data: x, self.initial_state: state} 75 | [probs, state] = sess.run([self.probs, self.final_state], feed) 76 | probabilities.append(probs) 77 | 78 | return probabilities 79 | -------------------------------------------------------------------------------- /tensorflow/gan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.python.ops.nn import rnn_cell 4 | from tensorflow.python.ops.nn import rnn 5 | from tensorflow.python.ops.nn import seq2seq 6 | from tensorflow.python.ops import array_ops 7 | from tensorflow.python.ops import nn_ops 8 | from tensorflow.python.ops import math_ops 9 | from tensorflow.python.framework import ops 10 | from tensorflow.contrib.distributions import Categorical 11 | 12 | def variable_summaries(var, name): 13 | '''Attach a lot of summaries to a Tensor.''' 14 | mean = tf.reduce_mean(var) 15 | tf.scalar_summary('mean/' + name, mean) 16 | with tf.name_scope('stddev'): 17 | stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean))) 18 | tf.scalar_summary('sttdev/' + name, stddev) 19 | tf.scalar_summary('max/' + name, tf.reduce_max(var)) 20 | tf.scalar_summary('min/' + name, tf.reduce_min(var)) 21 | tf.histogram_summary(name, var) 22 | 23 | class GAN(object): 24 | def __init__(self, args, is_training=True): 25 | 26 | if not is_training: 27 | seq_length = 1 28 | else: 29 | seq_length = args.seq_length 30 | 31 | if args.model == 'rnn': 32 | cell_gen = rnn_cell.BasicRNNCell(args.rnn_size) 33 | cell_dis = rnn_cell.BasicRNNCell(args.rnn_size) 34 | elif args.model == 'gru': 35 | cell_gen = rnn_cell.GRUCell(args.rnn_size) 36 | cell_dis = rnn_cell.GRUCell(args.rnn_size) 37 | elif args.model == 'lstm': 38 | cell_gen = rnn_cell.BasicLSTMCell(args.rnn_size) 39 | cell_dis = rnn_cell.BasicLSTMCell(args.rnn_size) 40 | else: 41 | raise Exception('model type not supported: {}'.format(args.model)) 42 | 43 | # Pass the 
generated sequences and targets (1) 44 | with tf.name_scope('input'): 45 | with tf.name_scope('data'): 46 | self.input_data = tf.placeholder(tf.int32, [args.batch_size, seq_length]) 47 | with tf.name_scope('targets'): 48 | self.targets = tf.placeholder(tf.int32, [args.batch_size, seq_length]) 49 | 50 | ############ 51 | # Generator 52 | ############ 53 | with tf.variable_scope('generator'): 54 | self.cell_gen = rnn_cell.MultiRNNCell([cell_gen] * args.num_layers) 55 | self.initial_state_gen = self.cell_gen.zero_state(args.batch_size, tf.float32) 56 | 57 | with tf.variable_scope('rnn'): 58 | softmax_w = tf.get_variable('softmax_w', [args.rnn_size, args.vocab_size]) 59 | softmax_b = tf.get_variable('softmax_b', [args.vocab_size]) 60 | 61 | with tf.device('/cpu:0'): 62 | embedding = tf.get_variable('embedding', [args.vocab_size, args.rnn_size]) 63 | inputs_gen = tf.split(1, seq_length, tf.nn.embedding_lookup( 64 | embedding, self.input_data)) 65 | inputs_gen = [tf.squeeze(i, [1]) for i in inputs_gen] 66 | 67 | outputs_gen, last_state_gen = seq2seq.rnn_decoder(inputs_gen, self.initial_state_gen, 68 | self.cell_gen, loop_function=None) 69 | 70 | self.logits_sequence = [] 71 | for output_gen in outputs_gen: 72 | logits_gen = tf.nn.xw_plus_b(output_gen, softmax_w, softmax_b) 73 | self.logits_sequence.append(logits_gen) 74 | 75 | self.final_state_gen = last_state_gen 76 | 77 | ################ 78 | # Discriminator 79 | ################ 80 | with tf.variable_scope('discriminator'): 81 | self.cell_dis = rnn_cell.MultiRNNCell([cell_dis] * args.num_layers) 82 | self.initial_state_dis = self.cell_dis.zero_state(args.batch_size, tf.float32) 83 | 84 | with tf.variable_scope('rnn'): 85 | softmax_w = tf.get_variable('softmax_w', [args.rnn_size, 2]) 86 | softmax_b = tf.get_variable('softmax_b', [2]) 87 | 88 | inputs_dis = [] 89 | embedding = tf.get_variable('embedding', [args.vocab_size, args.rnn_size]) 90 | for logit in self.logits_sequence: 91 | inputs_dis.append(tf.matmul(logit, embedding)) 92 | # inputs_dis.append(tf.matmul(tf.nn.softmax(logit), embedding)) 93 | 94 | outputs_dis, last_state_dis = seq2seq.rnn_decoder(inputs_dis, 95 | self.initial_state_dis, self.cell_dis, loop_function=None) 96 | 97 | probs, logits = [], [] 98 | for output_dis in outputs_dis: 99 | logit = tf.nn.xw_plus_b(output_dis, softmax_w, softmax_b) 100 | prob = tf.nn.softmax(logit) 101 | logits.append(logit) 102 | probs.append(prob) 103 | 104 | with tf.name_scope('summary'): 105 | probs = tf.pack(probs) 106 | probs_real = tf.slice(probs, [0,0,1], [args.seq_length, args.batch_size, 1]) 107 | variable_summaries(probs_real, 'probability of real') 108 | 109 | self.final_state_dis = last_state_dis 110 | 111 | ######### 112 | # Train 113 | ######### 114 | with tf.name_scope('train'): 115 | gen_loss = seq2seq.sequence_loss_by_example( 116 | logits, 117 | tf.unpack(tf.transpose(self.targets)), 118 | tf.unpack(tf.transpose(tf.ones_like(self.targets, dtype=tf.float32)))) 119 | 120 | self.gen_cost = tf.reduce_sum(gen_loss) / args.batch_size 121 | tf.scalar_summary('training loss', self.gen_cost) 122 | self.lr_gen = tf.Variable(0.0, trainable = False) 123 | self.tvars = tf.trainable_variables() 124 | gen_vars = [v for v in self.tvars if not v.name.startswith("discriminator/")] 125 | 126 | if is_training: 127 | gen_grads = tf.gradients(self.gen_cost, gen_vars) 128 | self.all_grads = tf.gradients(self.gen_cost, self.tvars) 129 | gen_grads_clipped, _ = tf.clip_by_global_norm(gen_grads, args.grad_clip) 130 | gen_optimizer = 
tf.train.AdamOptimizer(self.lr_gen) 131 | self.gen_train_op = gen_optimizer.apply_gradients( 132 | zip(gen_grads_clipped, gen_vars)) 133 | 134 | with tf.name_scope('summary'): 135 | with tf.name_scope('weight_summary'): 136 | for v in self.tvars: 137 | variable_summaries(v, v.op.name) 138 | if is_training: 139 | with tf.name_scope('grad_summary'): 140 | for var, grad in zip(self.tvars, self.all_grads): 141 | variable_summaries(grad, 'grad/' + var.op.name) 142 | 143 | self.merged = tf.merge_all_summaries() 144 | 145 | 146 | def generate_samples(self, sess, args, chars, vocab, seq_length = 200, 147 | initial = ' ', 148 | datafile = 'data/gan/fake_reviews.txt'): 149 | ''' Generate a batch of reviews''' 150 | state = self.cell_gen.zero_state(args.batch_size, tf.float32).eval() 151 | 152 | sequence_matrix = [] 153 | for i in xrange(args.batch_size): 154 | sequence_matrix.append([]) 155 | char_arr = args.batch_size * [initial] 156 | for n in xrange(seq_length): 157 | x = np.zeros((args.batch_size, 1)) 158 | for i, char in enumerate(char_arr): 159 | x[i,0] = vocab[char] 160 | feed = {self.input_data: x, self.initial_state_gen: state} 161 | sample_op = Categorical(self.logits_sequence[0]) 162 | [sample_indexes, state] = sess.run( 163 | [sample_op.sample(n = 1), self.final_state_gen], feed) 164 | char_arr = [chars[i] for i in tf.squeeze(sample_indexes)] 165 | for i, char in enumerate(char_arr): 166 | sequence_matrix[i].append(char) 167 | 168 | with open(datafile, 'a+') as f: 169 | for line in sequence_matrix: 170 | print ''.join(line) 171 | print>>f, ''.join(line) 172 | 173 | return sequence_matrix -------------------------------------------------------------------------------- /tensorflow/generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.python.ops.nn import rnn_cell 4 | from tensorflow.python.ops.nn import rnn 5 | from tensorflow.python.ops.nn import seq2seq 6 | 7 | 8 | class Generator(object): 9 | def __init__(self, args, is_training=True, batch=True): 10 | self.args = args 11 | 12 | if not is_training: 13 | args.seq_length = 1 14 | 15 | if not batch: 16 | args.batch_size = 1 17 | 18 | if args.model == 'rnn': 19 | cell = rnn_cell.BasicRNNCell(args.rnn_size) 20 | elif args.model == 'gru': 21 | cell = rnn_cell.GRUCell(args.rnn_size) 22 | elif args.model == 'lstm': 23 | cell = rnn_cell.BasicLSTMCell(args.rnn_size) 24 | else: 25 | raise Exception("model type not supported: {}".format(args.model)) 26 | 27 | self.cell = rnn_cell.MultiRNNCell([cell] * args.num_layers) 28 | 29 | self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length]) 30 | self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length]) 31 | self.initial_state = self.cell.zero_state(args.batch_size, tf.float32) 32 | 33 | with tf.variable_scope('rnn'): 34 | softmax_w = tf.get_variable('softmax_w', [args.rnn_size, args.vocab_size]) 35 | softmax_b = tf.get_variable('softmax_b', [args.vocab_size]) 36 | 37 | with tf.device('/cpu:0'): 38 | embedding = tf.get_variable('embedding', [args.vocab_size, args.rnn_size]) 39 | inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data)) 40 | inputs = [tf.squeeze(i, [1]) for i in inputs] 41 | 42 | def loop(prev, _): 43 | prev = tf.nn.xw_plus_b(prev, softmax_w, softmax_b) 44 | prev_symbol = tf.stop_gradient(tf.argmax(prev, 1)) 45 | return tf.nn.embedding_lookup(embedding, prev_symbol) 46 | 47 | outputs, last_state = 
seq2seq.rnn_decoder(inputs, self.initial_state, 48 | self.cell, loop_function=None if is_training else loop, scope='rnn') 49 | output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size]) 50 | self.logits = tf.nn.xw_plus_b(output, softmax_w, softmax_b) 51 | self.probs = tf.nn.softmax(self.logits) 52 | loss = seq2seq.sequence_loss_by_example([self.logits], 53 | [tf.reshape(self.targets, [-1])], 54 | [tf.ones([args.batch_size * args.seq_length])], 55 | args.vocab_size) 56 | self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length 57 | self.final_state = last_state 58 | self.lr = tf.Variable(0.0, trainable = False) 59 | tvars = tf.trainable_variables() 60 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars, aggregation_method=2), args.grad_clip) 61 | optimizer = tf.train.AdamOptimizer(self.lr) 62 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 63 | 64 | def generate(self, sess, chars, vocab, seq_length = 200, initial=' '): 65 | state = self.cell.zero_state(1, tf.float32).eval() 66 | for char in initial[:-1]: 67 | x = np.zeros((1,1)) 68 | x[0,0] = vocab[char] 69 | feed = {self.input_data: x, self.initial_state: state} 70 | [state] = sess.run([self.final_state], feed) 71 | 72 | sequence = initial 73 | char = initial[-1] 74 | for n in xrange(seq_length): 75 | x = np.zeros((1,1)) 76 | x[0,0] = vocab[char] 77 | feed = {self.input_data: x, self.initial_state: state} 78 | [probs, state] = sess.run([self.probs, self.final_state], feed) 79 | p = probs[0] 80 | sample = int(np.random.choice(len(p), p=p)) 81 | pred = chars[sample] 82 | sequence += pred 83 | char = pred 84 | return sequence 85 | 86 | def generate_probabilities(self, sess, chars, vocab, seq_length = 200, initial = ' '): 87 | state = self.cell.zero_state(1, tf.float32).eval() 88 | for char in initial[:-1]: 89 | x = np.zeros((1,1)) 90 | x[0,0] = vocab[char] 91 | feed = {self.input_data: x, self.initial_state: state} 92 | [state] = sess.run([self.final_state], feed) 93 | 94 | probability_sequence = [] 95 | char = initial[-1] 96 | for n in xrange(seq_length): 97 | x = np.zeros((1,1)) 98 | x[0,0] = vocab[char] 99 | feed = {self.input_data: x, self.initial_state: state} 100 | [probs, state] = sess.run([self.probs, self.final_state], feed) 101 | p = probs[0] 102 | probability_sequence.append(p) 103 | sample = int(np.random.choice(len(p), p=p)) 104 | pred = chars[sample] 105 | char = pred 106 | return probability_sequence 107 | 108 | def sample_logits(self, logits, temperature=1.0): 109 | ''' This function is like sample_with_temperature except it can handle 110 | batch input a of [batch_size x logits] this function takes logits 111 | input, and produces a specific number from the array. 112 | 113 | args: 114 | logits: 2d array [batch_size x logits] 115 | temperature: hyperparameter to control variance of sampling 116 | 117 | returns: 118 | sequence_matrix: 1d array [batch_size] of selected numbers 119 | from distribution 120 | ''' 121 | exponent_raised = tf.exp(tf.div(logits, temperature)) 122 | matrix_X = tf.div(exponent_raised, tf.reduce_sum(exponent_raised, 123 | reduction_indices = 1, keep_dims = True)) 124 | matrix_U = tf.random_uniform(logits.get_shape(), minval = 0, maxval = 1) 125 | # You want dimension = 1 because you are argmaxing across rows. 
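        # Note: argmax(matrix_X - matrix_U) with independent uniform noise is only a
        # rough heuristic and not an exact draw from the softmax distribution; an
        # exact batched categorical sample would use the Gumbel-max trick, e.g.
        #   tf.argmax(tf.div(logits, temperature) - tf.log(-tf.log(matrix_U)), 1)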
126 | final_number = tf.argmax(tf.sub(matrix_X, matrix_U), dimension = 1) 127 | return final_number 128 | 129 | def sample_probs(self, probs): 130 | ''' This function takes probs input, [batch_size x probs] and 131 | produces a specific number from the array. 132 | 133 | args: 134 | probs: 2d array [batch_size x probs] 135 | temperature: hyperparameter to control variance of sampling 136 | 137 | returns: 138 | sequence_matrix: 1d array [batch_size] of selected numbers 139 | from distribution 140 | ''' 141 | matrix_U = tf.random_uniform(probs.get_shape(), minval = 0, maxval = 1) 142 | # You want dimension = 1 because you are argmaxing across rows. 143 | final_number = tf.argmax(tf.sub(probs, matrix_U), dimension = 1) 144 | return final_number 145 | 146 | 147 | def generate_batch(self, sess, args, chars, vocab, seq_length = 200, initial = ' '): 148 | ''' Generate a batch of reviews entirely within TensorFlow''' 149 | 150 | state = self.cell.zero_state(args.batch_size, tf.float32).eval() 151 | 152 | sequence_matrix = [] 153 | for i in xrange(args.batch_size): 154 | sequence_matrix.append([]) 155 | char_arr = args.batch_size * [initial] 156 | 157 | # TF implementation (doesn't quite work): 158 | # probs_tf = tf.placeholder(tf.float32, [args.batch_size, args.vocab_size]) 159 | # sample_op = self.sample_probs(probs_tf) 160 | 161 | for n in xrange(seq_length): 162 | x = np.zeros((args.batch_size, 1)) 163 | for i, char in enumerate(char_arr): 164 | x[i,0] = vocab[char] 165 | feed = {self.input_data: x, self.initial_state: state} 166 | [probs, state] = sess.run([self.probs, self.final_state], feed) 167 | # TF implementation (doesn't quite work): 168 | # probs_feed = {probs_tf: probs} 169 | # [sample_indexes] = sess.run([sample_op], probs_feed) 170 | 171 | # Numpy implementation: 172 | sample_indexes = [int(np.random.choice(len(p), p=p)) for p in probs] 173 | print len(sample_indexes) 174 | char_arr = [chars[i] for i in sample_indexes] 175 | for i, char in enumerate(char_arr): 176 | sequence_matrix[i].append(char) 177 | return sequence_matrix -------------------------------------------------------------------------------- /tensorflow/predict.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from argparse import ArgumentParser 4 | import time 5 | import os 6 | import cPickle 7 | from batcher import DiscriminatorBatcher 8 | from discriminator import Discriminator 9 | 10 | def parse_args(): 11 | parser = ArgumentParser() 12 | parser.add_argument('text', 13 | help='string of text to predict') 14 | parser.add_argument('--save_dir', type=str, default='models_discriminator', 15 | help='model directory to store checkpointed models') 16 | parser.add_argument('--data_dir', type=str, default='data', 17 | help='data directory containing reviews') 18 | return parser.parse_args() 19 | 20 | def predict(args): 21 | with open(os.path.join(args.save_dir, 'config.pkl')) as f: 22 | saved_args = cPickle.load(f) 23 | with open(os.path.join(args.save_dir, 'combined_vocab.pkl')) as f: 24 | _, vocab = cPickle.load(f) 25 | model = Discriminator(saved_args, is_training = False) 26 | with tf.Session() as sess: 27 | tf.initialize_all_variables().run() 28 | saver = tf.train.Saver(tf.all_variables()) 29 | ckpt = tf.train.get_checkpoint_state(args.save_dir) 30 | if ckpt and ckpt.model_checkpoint_path: 31 | saver.restore(sess, ckpt.model_checkpoint_path) 32 | return model.predict(sess, args.text, vocab) 33 | 34 | if __name__=='__main__': 35 | 
args = parse_args() 36 | with tf.device('/gpu:3'): 37 | probs = predict(args) 38 | 39 | for char, prob in zip(args.text, probs): 40 | print char, prob -------------------------------------------------------------------------------- /tensorflow/sample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from argparse import ArgumentParser 4 | import time 5 | import os 6 | import cPickle 7 | from batcher import Batcher 8 | from generator import Generator 9 | from gan import GAN 10 | 11 | def parse_args(): 12 | parser = ArgumentParser() 13 | parser.add_argument('--save_dir_gen', type=str, default='models_generator', 14 | help='model directory to store checkpointed models') 15 | parser.add_argument('--save_dir_GAN', type=str, default='models_GAN', 16 | help='model directory to store checkpointed models') 17 | parser.add_argument('--data_dir_gen', type=str, default='data', 18 | help='data directory containing reviews') 19 | parser.add_argument('--data_dir_GAN', type=str, default='data', 20 | help='data directory containing reviews') 21 | parser.add_argument('-n', type=int, default=500, 22 | help='number of characters to sample') 23 | parser.add_argument('--prime', type=str, default=' ', 24 | help='prime text') 25 | return parser.parse_args() 26 | 27 | 28 | def sample_generator(args, num_samples = 10): 29 | with open(os.path.join(args.save_dir_GAN, 'config.pkl')) as f: 30 | saved_args = cPickle.load(f) 31 | with open(os.path.join(args.save_dir_GAN, 'real_beer_vocab.pkl')) as f: 32 | chars, vocab = cPickle.load(f) 33 | generator = Generator(saved_args, is_training = False, batch = True) 34 | with tf.Session() as sess: 35 | tf.initialize_all_variables().run() 36 | saver = tf.train.Saver(tf.all_variables()) 37 | ckpt = tf.train.get_checkpoint_state(args.save_dir) 38 | if ckpt and ckpt.model_checkpoint_path: 39 | saver.restore(sess, ckpt.model_checkpoint_path) 40 | # for i in range(num_samples): 41 | # print 'Review',i,':', generator.generate(sess, chars, vocab, args.n, args.prime), '\n' 42 | 43 | return generator.generate_batch(sess, saved_args, chars, vocab) 44 | 45 | def sample_GAN(args, num_samples = 10): 46 | with open(os.path.join(args.save_dir_GAN, 'config.pkl')) as f: 47 | saved_args = cPickle.load(f) 48 | with open(os.path.join(args.save_dir_GAN, 'simple_vocab.pkl')) as f: 49 | chars, vocab = cPickle.load(f) 50 | gan = GAN(saved_args, is_training = False) 51 | with tf.Session() as sess: 52 | tf.initialize_all_variables().run() 53 | saver = tf.train.Saver(tf.all_variables()) 54 | ckpt = tf.train.get_checkpoint_state(args.save_dir_GAN) 55 | if ckpt and ckpt.model_checkpoint_path: 56 | saver.restore(sess, ckpt.model_checkpoint_path) 57 | return gan.generate_samples(sess, saved_args, chars, vocab, args.n) 58 | 59 | if __name__ == '__main__': 60 | args = parse_args() 61 | with tf.device('/gpu:3'): 62 | # sample_GAN(args) 63 | reviews = sample_GAN(args) 64 | for review in reviews: 65 | print ''.join(review) 66 | 67 | 68 | -------------------------------------------------------------------------------- /tensorflow/simple_vocab.pkl: -------------------------------------------------------------------------------- 1 | (S'e' 2 | S' ' 3 | S'r' 4 | S'b' 5 | S'\n' 6 | tp1 7 | . 
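For reference, the checked-in simple_vocab.pkl above holds just a tuple of five characters ('e', ' ', 'r', 'b', and a newline), in the frequency-sorted order the batchers' preprocess step produces; the training scripts separately write richer (chars, vocab) pickles under their model directories. A minimal sketch of loading it and rebuilding the char-to-id mapping, assuming it is read from the tensorflow/ directory:

import cPickle

# The pickle contains a tuple of characters sorted by descending frequency;
# rebuild the char -> integer-id mapping the same way the batchers do.
with open('simple_vocab.pkl') as f:
    chars = cPickle.load(f)                # e.g. ('e', ' ', 'r', 'b', '\n')
vocab = dict(zip(chars, range(len(chars))))
print vocab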
-------------------------------------------------------------------------------- /tensorflow/train_gan_new.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import logging 4 | from tensorflow.models.rnn import * 5 | from argparse import ArgumentParser 6 | from batcher_gan import DiscriminatorBatcher, GANBatcher 7 | from gan import GAN 8 | from discriminator import Discriminator 9 | from generator import Generator 10 | import time 11 | import os 12 | import cPickle 13 | 14 | logger = logging.getLogger() 15 | logger.setLevel(logging.ERROR) 16 | 17 | 18 | def parse_args(): 19 | parser = ArgumentParser() 20 | parser.add_argument('--real_input_file', type=str, default='real_reviews.txt', 21 | help='real reviews') 22 | parser.add_argument('--fake_input_file', type=str, default='fake_reviews.txt', 23 | help='fake reviews') 24 | parser.add_argument('--data_dir', type=str, default='data/gan', 25 | help='data directory containing reviews') 26 | parser.add_argument('--log_dir', type=str, default='logs', 27 | help='log directory for TensorBoard') 28 | parser.add_argument('--vocab_file', type=str, default='simple_vocab.pkl', 29 | help='data directory containing reviews') 30 | parser.add_argument('--save_dir_GAN', type=str, default='models_GAN', 31 | help='directory to store checkpointed GAN models') 32 | parser.add_argument('--save_dir_dis', type=str, default='models_GAN/discriminator', 33 | help='directory to store checkpointed discriminator models') 34 | parser.add_argument('--rnn_size', type=int, default=128, 35 | help='size of RNN hidden state') 36 | parser.add_argument('--num_layers', type=int, default=1, 37 | help='number of layers in the RNN') 38 | parser.add_argument('--model', type=str, default='lstm', 39 | help='rnn, gru, or lstm') 40 | parser.add_argument('--batch_size', type=int, default=10, 41 | help='minibatch size') 42 | parser.add_argument('--seq_length', type=int, default=30, 43 | help='RNN sequence length') 44 | parser.add_argument('-n', type=int, default=10, 45 | help='number of characters to sample') 46 | parser.add_argument('--prime', type=str, default=' ', 47 | help='prime text') 48 | parser.add_argument('--num_epochs_GAN', type=int, default=25, 49 | help='number of epochs of GAN') 50 | parser.add_argument('--num_epochs_gen', type=int, default=1, 51 | help='number of epochs to train generator') 52 | parser.add_argument('--num_epochs_dis', type=int, default=1, 53 | help='number of epochs to train discriminator') 54 | parser.add_argument('--save_every', type=int, default=500, 55 | help='save frequency') 56 | parser.add_argument('--grad_clip', type=float, default=5., 57 | help='clip gradients at this value') 58 | parser.add_argument('--learning_rate_gen', type=float, default=0.0001, 59 | help='learning rate') 60 | parser.add_argument('--learning_rate_dis', type=float, default=0.0002, 61 | help='learning rate for discriminator') 62 | parser.add_argument('--decay_rate', type=float, default=0.97, 63 | help='decay rate for rmsprop') 64 | parser.add_argument('--keep_prob', type=float, default=0.5, 65 | help='keep probability for dropout') 66 | parser.add_argument('--vocab_size', type=float, default=5, 67 | help='size of the vocabulary (characters)') 68 | return parser.parse_args() 69 | 70 | 71 | def train_generator(gan, args, sess, train_writer, weights_load = 'random'): 72 | '''Train Generator via GAN''' 73 | logging.debug('Training generator...') 74 | 75 | batcher = GANBatcher(args.fake_input_file, 
args.vocab_file, 76 | args.data_dir, args.batch_size, 77 | args.seq_length) 78 | 79 | logging.debug('Vocabulary...') 80 | # TODO: Why do this each time? Unnecessary 81 | with open(os.path.join(args.save_dir_GAN, 'config.pkl'), 'w') as f: 82 | cPickle.dump(args, f) 83 | with open(os.path.join(args.save_dir_GAN, 'simple_vocab.pkl'), 'w') as f: 84 | cPickle.dump((batcher.chars, batcher.vocab), f) 85 | 86 | # Save all GAN variables to gan_saver 87 | gan_vars = [v for v in tf.all_variables() if not 88 | (v.name.startswith('classic/') or v.name.startswith('sampler/') )] 89 | gan_saver = tf.train.Saver(gan_vars) 90 | 91 | # Retrieve trainable Discriminator variables from dis_saver 92 | # TODO: Trainable vs. All Variables? 93 | dis_vars = [v for v in tf.trainable_variables() if v.name.startswith('discriminator/')] 94 | dis_saver = tf.train.Saver(dis_vars) 95 | 96 | if weights_load is 'random': 97 | logging.debug('Random GAN parameters') 98 | elif weights_load is 'gan': 99 | logging.debug('Initial load of GAN parameters...') 100 | ckpt = tf.train.get_checkpoint_state(args.save_dir_GAN) 101 | if ckpt and ckpt.model_checkpoint_path: 102 | gan_saver.restore(sess, ckpt.model_checkpoint_path) 103 | elif weights_load is 'discriminator': 104 | logging.debug('Update GAN parameters from Discriminator...') 105 | ckpt = tf.train.get_checkpoint_state(args.save_dir_dis) 106 | if ckpt and ckpt.model_checkpoint_path: 107 | dis_saver.restore(sess, ckpt.model_checkpoint_path) 108 | else: 109 | raise Exception('Invalid weight initialization for GAN') 110 | 111 | for epoch in xrange(args.num_epochs_gen): 112 | new_lr = args.learning_rate_gen * (args.decay_rate ** epoch) 113 | sess.run(tf.assign(gan.lr_gen, new_lr)) 114 | batcher.reset_batch_pointer() 115 | state_gen = gan.initial_state_gen.eval() 116 | state_dis = gan.initial_state_dis.eval() 117 | 118 | for batch in xrange(25): 119 | # for batch in xrange(batcher.num_batches): 120 | start = time.time() 121 | x, _ = batcher.next_batch() 122 | y = np.ones(x.shape) 123 | feed = {gan.input_data: x, 124 | gan.targets: y, 125 | gan.initial_state_gen: state_gen, 126 | gan.initial_state_dis: state_dis} 127 | # feed = {gan.input_data: x, 128 | # gan.targets: y} 129 | gen_train_loss, gen_summary, state_gen, state_dis, _ = sess.run([ 130 | gan.gen_cost, 131 | gan.merged, 132 | gan.final_state_gen, 133 | gan.final_state_dis, 134 | gan.gen_train_op], feed) 135 | 136 | train_writer.add_summary(gen_summary, batch) 137 | end = time.time() 138 | 139 | print '{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}' \ 140 | .format(epoch * batcher.num_batches + batch, 141 | args.num_epochs_gen * batcher.num_batches, 142 | epoch, gen_train_loss, end - start) 143 | 144 | if (epoch * batcher.num_batches + batch) % args.save_every == 0: 145 | checkpoint_path = os.path.join(args.save_dir_GAN, 'model.ckpt') 146 | gan_saver.save(sess, checkpoint_path, global_step = epoch * batcher.num_batches + batch) 147 | print 'GAN model saved to {}'.format(checkpoint_path) 148 | 149 | 150 | 151 | def train_discriminator(discriminator, args, sess): 152 | '''Train the discriminator via classical approach''' 153 | logging.debug('Training discriminator...') 154 | 155 | batcher = DiscriminatorBatcher(args.real_input_file, 156 | args.fake_input_file, 157 | args.data_dir, args.vocab_file, 158 | args.batch_size, args.seq_length) 159 | 160 | logging.debug('Vocabulary...') 161 | with open(os.path.join(args.save_dir_GAN, 'simple_vocab.pkl'), 'w') as f: 162 | cPickle.dump((batcher.chars, batcher.vocab), f) 163 | 
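    # Scope layout (see the __main__ block below): the GAN graph owns the top-level
    # 'generator/' and 'discriminator/' scopes, a standalone Discriminator is built
    # under 'classic/', and a sampling-only GAN lives under 'sampler/'. Weights move
    # between these copies via tf.train.Saver name maps; here the GAN checkpoint's
    # 'discriminator/*' variables are restored into the local 'classic/*' variables.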
164 | logging.debug('Loading GAN parameters to Discriminator...') 165 | dis_vars = [v for v in tf.trainable_variables() if v.name.startswith('classic/')] 166 | dis_dict = {} 167 | for v in dis_vars: 168 | # Key: op.name in GAN Checkpoint file 169 | # Value: Local generator Variable 170 | dis_dict[v.op.name.replace('classic/','discriminator/')] = v 171 | dis_saver = tf.train.Saver(dis_dict) 172 | 173 | ckpt = tf.train.get_checkpoint_state(args.save_dir_GAN) 174 | if ckpt and ckpt.model_checkpoint_path: 175 | dis_saver.restore(sess, ckpt.model_checkpoint_path) 176 | 177 | for epoch in xrange(args.num_epochs_dis): 178 | # Anneal learning rate 179 | new_lr = args.learning_rate_dis * (args.decay_rate ** epoch) 180 | sess.run(tf.assign(discriminator.lr, new_lr)) 181 | batcher.reset_batch_pointer() 182 | state = discriminator.initial_state.eval() 183 | 184 | for batch in xrange(10): 185 | # for batch in xrange(batcher.num_batches): 186 | start = time.time() 187 | x, y = batcher.next_batch() 188 | 189 | feed = {discriminator.input_data: x, 190 | discriminator.targets: y, 191 | discriminator.initial_state: state} 192 | train_loss, state, _ = sess.run([discriminator.cost, 193 | discriminator.final_state, 194 | discriminator.train_op], 195 | feed) 196 | end = time.time() 197 | 198 | print '{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}' \ 199 | .format(epoch * batcher.num_batches + batch, 200 | args.num_epochs_dis * batcher.num_batches, 201 | epoch, train_loss, end - start) 202 | 203 | if (epoch * batcher.num_batches + batch) % args.save_every == 0: 204 | checkpoint_path = os.path.join(args.save_dir_dis, 'model.ckpt') 205 | dis_saver.save(sess, checkpoint_path, global_step = epoch * batcher.num_batches + batch) 206 | print 'Discriminator model saved to {}'.format(checkpoint_path) 207 | 208 | 209 | def generate_samples(generator, args, sess, num_samples=500): 210 | '''Generate samples from the current version of the GAN''' 211 | samples = [] 212 | 213 | with open(os.path.join(args.save_dir_GAN, 'config.pkl')) as f: 214 | saved_args = cPickle.load(f) 215 | with open(os.path.join(args.save_dir_GAN, args.vocab_file)) as f: 216 | chars, vocab = cPickle.load(f) 217 | 218 | logging.debug('Loading GAN parameters to Generator...') 219 | gen_vars = [v for v in tf.all_variables() if v.name.startswith('sampler/')] 220 | gen_dict = {} 221 | for v in gen_vars: 222 | # Key: op.name in GAN Checkpoint file 223 | # Value: Local generator Variable 224 | gen_dict[v.op.name.replace('sampler/','')] = v 225 | gen_saver = tf.train.Saver(gen_dict) 226 | ckpt = tf.train.get_checkpoint_state(args.save_dir_GAN) 227 | if ckpt and ckpt.model_checkpoint_path: 228 | gen_saver.restore(sess, ckpt.model_checkpoint_path) 229 | 230 | for _ in xrange(num_samples / args.batch_size): 231 | samples.append(generator.generate_samples(sess, saved_args, chars, vocab, args.n)) 232 | return samples 233 | 234 | 235 | def reset_reviews(data_dir, file_name): 236 | open(os.path.join(data_dir, file_name), 'w').close() 237 | 238 | 239 | def adversarial_training(gan, discriminator, generator, train_writer, args, sess): 240 | '''Adversarial Training''' 241 | train_generator(gan, args, sess, train_writer, weights_load = 'random') 242 | generate_samples(generator, args, sess, 200) 243 | 244 | for epoch in xrange(args.num_epochs_GAN): 245 | train_discriminator(discriminator, args, sess) 246 | train_generator(gan, args, sess, train_writer, weights_load = 'discriminator') 247 | reset_reviews(args.data_dir, args.fake_input_file) 248 | 
generate_samples(generator, args, sess, 200) 249 | 250 | 251 | if __name__=='__main__': 252 | args = parse_args() 253 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05) 254 | 255 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True, gpu_options=gpu_options)) as sess: 256 | 257 | logging.debug('Creating models...') 258 | gan = GAN(args, is_training = True) 259 | with tf.variable_scope('classic'): 260 | discriminator = Discriminator(args, is_training = True) 261 | with tf.variable_scope('sampler'): 262 | generator = GAN(args, is_training = False) 263 | 264 | logging.debug('TensorBoard...') 265 | train_writer = tf.train.SummaryWriter(args.log_dir, sess.graph) 266 | 267 | logging.debug('Initializing variables in graph...') 268 | tf.initialize_all_variables().run() 269 | 270 | adversarial_training(gan, discriminator, generator, train_writer, args, sess) 271 | # train_generator(gan, args, sess, train_writer, weights_load = 'random') 272 | # generate_samples(generator, args, sess, 50) 273 | # train_discriminator(discriminator, args, sess) 274 | -------------------------------------------------------------------------------- /tensorflow/train_models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import logging 4 | from tensorflow.models.rnn import * 5 | from argparse import ArgumentParser 6 | from batcher import Batcher, DiscriminatorBatcher 7 | from generator import Generator 8 | from discriminator import Discriminator 9 | import time 10 | import os 11 | import cPickle 12 | 13 | logger = logging.getLogger() 14 | logger.setLevel(logging.INFO) 15 | 16 | 17 | def parse_args(): 18 | parser = ArgumentParser() 19 | parser.add_argument('--data_dir', type=str, default='data', 20 | help='data directory containing reviews') 21 | parser.add_argument('--save_dir_gen', type=str, default='models_generator', 22 | help='directory to store checkpointed generator models') 23 | parser.add_argument('--save_dir_dis', type=str, default='models_discriminator', 24 | help='directory to store checkpointed discriminator models') 25 | parser.add_argument('--rnn_size', type=int, default=2048, 26 | help='size of RNN hidden state') 27 | parser.add_argument('--num_layers', type=int, default=2, 28 | help='number of layers in the RNN') 29 | parser.add_argument('--model', type=str, default='lstm', 30 | help='rnn, gru, or lstm') 31 | parser.add_argument('--batch_size', type=int, default=100, 32 | help='minibatch size') 33 | parser.add_argument('--seq_length', type=int, default=200, 34 | help='RNN sequence length') 35 | parser.add_argument('-n', type=int, default=500, 36 | help='number of characters to sample') 37 | parser.add_argument('--prime', type=str, default=' ', 38 | help='prime text') 39 | parser.add_argument('--num_epochs', type=int, default=5, 40 | help='number of epochs') 41 | parser.add_argument('--num_epochs_dis', type=int, default=5, 42 | help='number of epochs to train discriminator') 43 | parser.add_argument('--save_every', type=int, default=50, 44 | help='save frequency') 45 | parser.add_argument('--grad_clip', type=float, default=5., 46 | help='clip gradients at this value') 47 | parser.add_argument('--learning_rate', type=float, default=0.002, 48 | help='learning rate') 49 | parser.add_argument('--learning_rate_dis', type=float, default=0.0002, 50 | help='learning rate for discriminator') 51 | parser.add_argument('--decay_rate', type=float, default=0.97, 52 | 
help='decay rate for rmsprop') 53 | parser.add_argument('--keep_prob', type=float, default=0.5, 54 | help='keep probability for dropout') 55 | parser.add_argument('--vocab_size', type=float, default=100, 56 | help='size of the vocabulary (characters)') 57 | return parser.parse_args() 58 | 59 | 60 | def train_generator(args, load_recent=True): 61 | '''Train the generator via classical approach''' 62 | logging.debug('Batcher...') 63 | batcher = Batcher(args.data_dir, args.batch_size, args.seq_length) 64 | 65 | logging.debug('Vocabulary...') 66 | with open(os.path.join(args.save_dir_gen, 'config.pkl'), 'w') as f: 67 | cPickle.dump(args, f) 68 | with open(os.path.join(args.save_dir_gen, 'real_beer_vocab.pkl'), 'w') as f: 69 | cPickle.dump((batcher.chars, batcher.vocab), f) 70 | 71 | logging.debug('Creating generator...') 72 | generator = Generator(args, is_training = True) 73 | 74 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess: 75 | tf.initialize_all_variables().run() 76 | saver = tf.train.Saver(tf.all_variables()) 77 | 78 | if load_recent: 79 | ckpt = tf.train.get_checkpoint_state(args.save_dir_gen) 80 | if ckpt and ckpt.model_checkpoint_path: 81 | saver.restore(sess, ckpt.model_checkpoint_path) 82 | 83 | for epoch in xrange(args.num_epochs): 84 | # Anneal learning rate 85 | new_lr = args.learning_rate * (args.decay_rate ** epoch) 86 | sess.run(tf.assign(generator.lr, new_lr)) 87 | batcher.reset_batch_pointer() 88 | state = generator.initial_state.eval() 89 | 90 | for batch in xrange(batcher.num_batches): 91 | start = time.time() 92 | x, y = batcher.next_batch() 93 | feed = {generator.input_data: x, generator.targets: y, generator.initial_state: state} 94 | # train_loss, state, _ = sess.run([generator.cost, generator.final_state, generator.train_op], feed) 95 | train_loss, _ = sess.run([generator.cost, generator.train_op], feed) 96 | end = time.time() 97 | 98 | print '{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}' \ 99 | .format(epoch * batcher.num_batches + batch, 100 | args.num_epochs * batcher.num_batches, 101 | epoch, train_loss, end - start) 102 | 103 | if (epoch * batcher.num_batches + batch) % args.save_every == 0: 104 | checkpoint_path = os.path.join(args.save_dir_gen, 'model.ckpt') 105 | saver.save(sess, checkpoint_path, global_step = epoch * batcher.num_batches + batch) 106 | print 'Generator model saved to {}'.format(checkpoint_path) 107 | 108 | 109 | def train_discriminator(args, load_recent_weights = 'Generator'): 110 | '''Train the discriminator via classical approach''' 111 | logging.debug('Batcher...') 112 | batcher = DiscriminatorBatcher(args.data_dir, args.batch_size, args.seq_length) 113 | 114 | logging.debug('Vocabulary...') 115 | with open(os.path.join(args.save_dir_dis, 'config.pkl'), 'w') as f: 116 | cPickle.dump(args, f) 117 | with open(os.path.join(args.save_dir_dis, 'real_beer_vocab.pkl'), 'w') as f: 118 | cPickle.dump((batcher.chars, batcher.vocab), f) 119 | 120 | logging.debug('Creating discriminator...') 121 | discriminator = Discriminator(args, is_training = True) 122 | 123 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess: 124 | tf.initialize_all_variables().run() 125 | saver = tf.train.Saver(tf.all_variables()) 126 | 127 | if load_recent_weights is 'Discriminator': 128 | logging.debug('Loading recent Discriminator weights...') 129 | ckpt = tf.train.get_checkpoint_state(args.save_dir_dis) 130 | if ckpt and ckpt.model_checkpoint_path: 131 | 
saver.restore(sess, ckpt.model_checkpoint_path) 132 | 133 | elif load_recent_weights is 'Generator': 134 | logging.debug('Loading recent Generator weights...') 135 | # Only retrieve non-softmax weights from generator model 136 | vars_all = tf.all_variables() 137 | vars_gen = [var for var in vars_all if 'softmax' not in var.op.name] 138 | saver_gen = tf.train.Saver(vars_gen) 139 | ckpt = tf.train.get_checkpoint_state(args.save_dir_gen) 140 | if ckpt and ckpt.model_checkpoint_path: 141 | saver_gen.restore(sess, ckpt.model_checkpoint_path) 142 | 143 | else: 144 | raise Exception('Must initialize weights from either Generator or Discriminator') 145 | 146 | for epoch in xrange(args.num_epochs_dis): 147 | # Anneal learning rate 148 | new_lr = args.learning_rate_dis * (args.decay_rate ** epoch) 149 | sess.run(tf.assign(discriminator.lr, new_lr)) 150 | batcher.reset_batch_pointer() 151 | state = discriminator.initial_state.eval() 152 | 153 | for batch in xrange(batcher.num_batches): 154 | start = time.time() 155 | x, y = batcher.next_batch() 156 | 157 | feed = {discriminator.input_data: x, 158 | discriminator.targets: y, 159 | discriminator.initial_state: state} 160 | train_loss, state, _ = sess.run([discriminator.cost, 161 | discriminator.final_state, 162 | discriminator.train_op], 163 | feed) 164 | end = time.time() 165 | 166 | print '{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}' \ 167 | .format(epoch * batcher.num_batches + batch, 168 | args.num_epochs_dis * batcher.num_batches, 169 | epoch, train_loss, end - start) 170 | 171 | if (epoch * batcher.num_batches + batch) % args.save_every == 0: 172 | checkpoint_path = os.path.join(args.save_dir_dis, 'discriminator.ckpt') 173 | saver.save(sess, checkpoint_path, global_step = epoch * batcher.num_batches + batch) 174 | print 'Discriminator model saved to {}'.format(checkpoint_path) 175 | 176 | 177 | if __name__=='__main__': 178 | args = parse_args() 179 | with tf.device('/gpu:3'): 180 | # train_generator(args, load_recent=True) 181 | train_discriminator(args) 182 | 183 | # with tf.device('/gpu:3'): 184 | # train_generator(args, load_recent=True) 185 | 186 | # generate_sample(args) 187 | 188 | --------------------------------------------------------------------------------