├── .gitignore
├── .hidden
├── config.py
├── extra
│   └── prepare_gutenberg.py
├── main.py
├── reader.py
├── rnncell.py
├── rnnlm.py
└── utils.py


/.gitignore:
--------------------------------------------------------------------------------
# Project-specific stuff
data_*/
models
gutenberg

# Editor stuff
*.swp

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

--------------------------------------------------------------------------------
/.hidden:
--------------------------------------------------------------------------------
config.pyc 2 | encdec.pyc 3 | reader.pyc 4 | rnncell.pyc 5 | utils.pyc 6 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import tensorflow as tf 3 | 4 | flags = tf.flags 5 | cfg = flags.FLAGS 6 | 7 | 8 | # command-line config 9 | flags.DEFINE_string ("data_path", "data_ptb", "Data path") 10 | flags.DEFINE_string ("save_file", "models/recent.dat", "Save file") 11 | flags.DEFINE_string ("load_file", "", "File to load model from") 12 | flags.DEFINE_string ("word_vocab_file", "wvocab.pk", "Word vocab pickle file in data path") 13 | flags.DEFINE_string ("char_vocab_file", "cvocab.pk", "Character vocab pickle file in data " 14 | "path") 15 | 16 | flags.DEFINE_bool ("preallocate_gpu", True, "Preallocate all of the GPU memory") 17 | flags.DEFINE_bool ("char_model", False, "Character-level model") 18 | flags.DEFINE_bool ("use_gan", True, "Use adversatial objectives") 19 | flags.DEFINE_integer("batch_size", 50, "Batch size") 20 | flags.DEFINE_integer("word_emb_size", 224, "Word embedding size") 21 | flags.DEFINE_integer("char_emb_size", 96, "Character embedding size") 22 | flags.DEFINE_integer("num_layers", 1, "Number of RNN layers") 23 | flags.DEFINE_integer("word_hidden_size", 768, "RNN hidden state size for word model") 24 | flags.DEFINE_integer("char_hidden_size", 800, "RNN hidden state size for char model") 25 | flags.DEFINE_integer("softmax_samples", 1024, "Number of classes to sample for softmax") 26 | flags.DEFINE_bool ("concat_inputs", True, "Concatenate inputs to states before " 27 | "discriminating") 28 | flags.DEFINE_float ("min_d_acc", 0.75, "Update generator if descriminator is better " 29 | "than this") 30 | flags.DEFINE_float ("max_d_acc", 0.97, "Update descriminator if accuracy less than " 31 | "this") 32 | flags.DEFINE_float ("max_perplexity", -1, "Scheduler maintains perplexity to be under " 33 | 
"this (-1 to disable)") 34 | flags.DEFINE_integer("sc_list_size", 8, "Number of previous prints to look at in " 35 | "scheduler") 36 | flags.DEFINE_float ("sc_decay", 0.33, "Scheduler importance decay") 37 | flags.DEFINE_bool ("d_rnn", True, "Recurrent discriminator") 38 | flags.DEFINE_bool ("d_energy_based", False, "Energy-based discriminator") 39 | flags.DEFINE_float ("d_word_eb_margin", 992.0, "Margin for energy-based discriminator for word " 40 | "model") 41 | flags.DEFINE_float ("d_char_eb_margin", 1024.0, "Margin for energy-based discriminator for char " 42 | "model") 43 | flags.DEFINE_integer("d_num_layers", 1, "Number of RNN layers for discriminator (if " 44 | "recurrent)") 45 | flags.DEFINE_bool ("d_rnn_bidirect", True, "Recurrent discriminator is bidirectional") 46 | flags.DEFINE_integer("d_conv_window", 5, "Convolution window for convolution on " 47 | "discriminative RNN's states") 48 | flags.DEFINE_integer("word_sent_length", 256, "Maximum length of a sentence for word model") 49 | flags.DEFINE_integer("char_sent_length", 512, "Maximum length of a sentence for char model") 50 | flags.DEFINE_float ("max_grad_norm", 5.0, "Gradient clipping") 51 | flags.DEFINE_bool ("training", True, "Training mode, turn off for testing") 52 | flags.DEFINE_string ("d_optimizer", "adam", "Discriminator optimizer to use (sgd, adam, " 53 | "adagrad, adadelta)") 54 | flags.DEFINE_string ("g_optimizer", "adam", "Generator optimizer to use (sgd, adam, " 55 | "adagrad, adadelta)") 56 | flags.DEFINE_float ("d_learning_rate", 1e-4, "Optimizer initial learning rate for " 57 | "discriminator") 58 | flags.DEFINE_float ("g_learning_rate", 1e-4, "Optimizer initial learning rate for generator") 59 | flags.DEFINE_integer("max_epoch", 10000, "Maximum number of epochs to run for") 60 | flags.DEFINE_integer("max_steps", 9999999, "Maximum number of steps to run for") 61 | flags.DEFINE_integer("gen_samples", 1, "Number of demo samples batches to generate " 62 | "per epoch") 63 | 
flags.DEFINE_integer("gen_every", 500, "Generate samples every these many training " 64 | "steps (0 to disable, -1 for each epoch)") 65 | flags.DEFINE_integer("print_every", 50, "Print every these many steps") 66 | flags.DEFINE_integer("save_every", -1, "Save every these many steps (0 to disable, " 67 | "-1 for each epoch)") 68 | flags.DEFINE_bool ("save_overwrite", True, "Overwrite the same file each time") 69 | flags.DEFINE_bool ("test_validation", True, "Use the validation set during testing") 70 | flags.DEFINE_integer("validate_every", 1, "Validate every these many epochs " 71 | "(0 to disable)") 72 | 73 | 74 | if cfg.char_model: 75 | cfg.emb_size = cfg.char_emb_size 76 | cfg.hidden_size = cfg.char_hidden_size 77 | cfg.max_sent_length = cfg.char_sent_length 78 | cfg.d_eb_margin = cfg.d_char_eb_margin 79 | cfg.vocab_file = Path(cfg.data_path) / cfg.char_vocab_file 80 | else: 81 | cfg.emb_size = cfg.word_emb_size 82 | cfg.hidden_size = cfg.word_hidden_size 83 | cfg.max_sent_length = cfg.word_sent_length 84 | cfg.d_eb_margin = cfg.d_word_eb_margin 85 | cfg.vocab_file = Path(cfg.data_path) / cfg.word_vocab_file 86 | 87 | 88 | print('Config:') 89 | cfg._parse_flags() 90 | cfg_dict = cfg.__dict__['__flags'] 91 | maxlen = max(len(k) for k in cfg_dict) 92 | for k, v in sorted(cfg_dict.items(), key=lambda x: x[0]): 93 | print(k.ljust(maxlen + 2), v) 94 | print() 95 | -------------------------------------------------------------------------------- /extra/prepare_gutenberg.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from multiprocessing import Pool 3 | from pathlib import Path 4 | import random 5 | import re 6 | import unicodedata 7 | 8 | import nltk 9 | 10 | input_dir = 'gutenberg' # raw text dir 11 | 12 | data_coverage = 0.966 # decide vocab based on how much of the data should be covered 13 | 14 | MIN_LEN = 4 15 | MAX_LEN = 75 16 | 17 | val_split = 0.0004 # gutenberg is huge 18 | test_split = 
0.0006 19 | train_split = 1.0 - val_split - test_split 20 | 21 | 22 | fix_re = re.compile(r"[^a-z0-9]+") 23 | num_re = re.compile(r'[0-9]+') 24 | 25 | 26 | def fix_word(word): 27 | word = word.lower() 28 | word = fix_re.sub('', word) 29 | word = num_re.sub('#', word) 30 | if not any(c.isalpha() for c in word): 31 | word = '' 32 | return word 33 | 34 | 35 | def process(output, vocab, lines): 36 | if not lines: 37 | return 38 | para = unicodedata.normalize('NFKC', ' '.join(lines)) 39 | for sent in nltk.sent_tokenize(para): 40 | words = [fix_word(w) for w in nltk.word_tokenize(sent)] 41 | words = [w for w in words if w] 42 | for word in words: 43 | vocab[word] += 1 44 | if len(words) >= MIN_LEN and len(words) <= MAX_LEN: # ignore very short and long sentences 45 | output.append(words) 46 | 47 | 48 | def create_file(fname, lines, vocab): 49 | with open(fname, 'w') as f: 50 | for line in lines: 51 | words = [] 52 | for w in line: 53 | if w in vocab: 54 | words.append(w) 55 | else: 56 | words.append('') 57 | print(' '.join(words), file=f) 58 | with open('full' + fname, 'w') as f: 59 | for line in lines: 60 | print(' '.join(line), file=f) 61 | 62 | 63 | def summarize(output, vocab): 64 | print() 65 | print('Size of corpus:', vocab.N()) 66 | print('Total vocab size:', vocab.B()) 67 | 68 | N = len(output) 69 | test_N = int(test_split * N) 70 | val_N = int(val_split * N) 71 | train_N = N - test_N - val_N 72 | print('Number of lines:', N) 73 | print(' Train:', train_N) 74 | print(' Val: ', val_N) 75 | print(' Test: ', test_N) 76 | print() 77 | return train_N, val_N, test_N 78 | 79 | 80 | def process_file(fname): 81 | print(fname) 82 | output = [] 83 | vocab = nltk.FreqDist() 84 | with fname.open('r', encoding='latin-1') as f: 85 | paragraph = [] 86 | for l in f: 87 | line = l.strip() 88 | if not line: 89 | process(output, vocab, paragraph) 90 | paragraph = [] 91 | else: 92 | paragraph.append(line) 93 | process(output, vocab, paragraph) 94 | return output, vocab 95 | 96 | 97 | 
if __name__ == '__main__': 98 | output = [] 99 | vocab = nltk.FreqDist() 100 | print('Reading...') 101 | fnames = sorted(Path(input_dir).glob('*.txt')) 102 | p = Pool(int(.5 + (.9 * multiprocessing.cpu_count()))) 103 | outs = p.map(process_file, fnames) 104 | for o, v in outs: 105 | output.extend(o) 106 | vocab.update(v) 107 | 108 | train_N, val_N, test_N = summarize(output, vocab) 109 | top_words = vocab.most_common() 110 | count = 0 111 | for vocab_size in range(vocab.B()): 112 | count += top_words[vocab_size][1] 113 | if count / vocab.N() >= data_coverage: 114 | top_words = set(w for w, c in vocab.most_common(vocab_size + 1)) 115 | break 116 | print('Final vocab size:', len(top_words)) 117 | 118 | random.shuffle(output) 119 | create_file('train.txt', output[:train_N], top_words) 120 | create_file('test.txt', output[train_N:train_N + test_N], top_words) 121 | create_file('valid.txt', output[train_N + test_N:], top_words) 122 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import sys 3 | import time 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from config import cfg 9 | from reader import Reader, Vocab 10 | from rnnlm import RNNLMModel 11 | import utils 12 | 13 | 14 | def call_session(session, model, batch, train_d=False, train_g=False): 15 | '''Use the session to run the model on the batch data.''' 16 | f_dict = {model.data: batch} 17 | ops = [model.nll, model.mle_cost, model.d_cost, model.g_cost] 18 | # training ops are tf.no_op() for a non-training model 19 | train_ops = [model.mle_train_op] 20 | if train_d: 21 | train_ops.append(model.d_train_op) 22 | if train_g: 23 | train_ops.append(model.g_train_op) 24 | ops.extend(train_ops) 25 | return session.run(ops, f_dict)[:-len(train_ops)] 26 | 27 | 28 | def generate_sentences(session, model, vocab): 29 | '''Generate sentences using the generator.''' 30 | 
f_dict = {model.data: np.zeros([cfg.batch_size, cfg.max_sent_length], dtype=np.int32)} 31 | utils.display_sentences(session.run(model.generated, f_dict), vocab, cfg.char_model) 32 | 33 | 34 | def save_model(session, saver, perp, cur_iters): 35 | '''Save model file.''' 36 | save_file = cfg.save_file 37 | if not cfg.save_overwrite: 38 | save_file = save_file + '.' + str(cur_iters) 39 | print("Saving model (epoch perplexity: %.3f) ..." % perp) 40 | save_file = saver.save(session, save_file) 41 | print("Saved to", save_file) 42 | 43 | 44 | def run_epoch(epoch, session, model, batch_loader, vocab, saver, steps, max_steps, scheduler, 45 | use_gan, gen_every): 46 | '''Runs the model on the given data for an epoch.''' 47 | start_time = time.time() 48 | nlls = 0.0 49 | mle_costs = 0.0 50 | iters = 0 51 | shortterm_nlls = 0.0 52 | shortterm_mle_costs = 0.0 53 | shortterm_d_costs = 0.0 54 | shortterm_g_costs = 0.0 55 | shortterm_iters = 0 56 | shortterm_steps = 0 57 | g_steps = 0 58 | d_steps = 0 59 | update_d = False 60 | update_g = False 61 | 62 | for step, batch in enumerate(batch_loader): 63 | cur_iters = steps + step 64 | if scheduler is not None: 65 | update_d = use_gan and scheduler.update_d() 66 | update_g = use_gan and scheduler.update_g() 67 | if update_d: 68 | d_steps += 1 69 | if update_g: 70 | g_steps += 1 71 | 72 | nll, mle_cost, d_cost, g_cost = call_session(session, model, batch, train_d=update_d, 73 | train_g=update_g) 74 | if scheduler is not None: 75 | if cfg.d_energy_based: 76 | d_acc = -1.0 77 | else: 78 | d_acc = np.exp(-d_cost) 79 | scheduler.add_d_acc(d_acc) 80 | 81 | if cfg.char_model: 82 | n_words = (np.sum(batch == vocab.vocab_lookup[' ']) // cfg.batch_size) + 1 83 | else: 84 | n_words = cfg.max_sent_length 85 | if scheduler is not None: 86 | scheduler.add_perp(np.exp(nll / n_words)) 87 | 88 | nlls += nll 89 | mle_costs += mle_cost 90 | shortterm_nlls += nll 91 | shortterm_mle_costs += mle_cost 92 | shortterm_d_costs += d_cost 93 | shortterm_g_costs 
+= g_cost 94 | iters += n_words 95 | shortterm_iters += n_words 96 | shortterm_steps += 1 97 | 98 | if step % cfg.print_every == 0: 99 | avg_nll = shortterm_nlls / shortterm_iters 100 | avg_mle_cost = shortterm_mle_costs / shortterm_steps 101 | avg_d_cost = shortterm_d_costs / shortterm_steps 102 | if cfg.d_energy_based: 103 | d_acc = -1.0 104 | else: 105 | d_acc = np.exp(-avg_d_cost) 106 | if g_steps: 107 | avg_g_cost = shortterm_g_costs / g_steps 108 | else: 109 | avg_g_cost = -1.0 110 | print("%d: %d (%d) perplexity: %.3f mle_loss: %.4f mle_cost: %.4f d_cost: %.4f " 111 | "g_cost: %.4f d_acc: %.4f speed: %.0f wps D:%d G:%d" % (epoch + 1, step, 112 | cur_iters, np.exp(avg_nll), avg_nll, avg_mle_cost, avg_d_cost, avg_g_cost, d_acc, 113 | shortterm_iters * cfg.batch_size / (time.time() - start_time), d_steps, g_steps)) 114 | 115 | shortterm_nlls = 0.0 116 | shortterm_mle_costs = 0.0 117 | shortterm_d_costs = 0.0 118 | shortterm_g_costs = 0.0 119 | shortterm_iters = 0 120 | shortterm_steps = 0 121 | g_steps = 0 122 | d_steps = 0 123 | start_time = time.time() 124 | 125 | if gen_every > 0 and (step + 1) % gen_every == 0: 126 | for _ in range(cfg.gen_samples): 127 | generate_sentences(session, model, vocab) 128 | 129 | if saver is not None and cur_iters and cfg.save_every > 0 and \ 130 | cur_iters % cfg.save_every == 0: 131 | save_model(session, saver, np.exp(nlls / iters), cur_iters) 132 | 133 | if max_steps > 0 and cur_iters >= max_steps: 134 | break 135 | 136 | if gen_every < 0: 137 | for _ in range(cfg.gen_samples): 138 | generate_sentences(session, model, vocab) 139 | 140 | perp = np.exp(nlls / iters) 141 | cur_iters = steps + step 142 | if saver is not None and cfg.save_every < 0: 143 | save_model(session, saver, perp, cur_iters) 144 | return perp, cur_iters 145 | 146 | 147 | def main(_): 148 | vocab = Vocab() 149 | vocab.load_from_pickle() 150 | reader = Reader(vocab) 151 | 152 | config_proto = tf.ConfigProto() 153 | if not cfg.preallocate_gpu: 154 | 
config_proto.gpu_options.allow_growth = True 155 | if not cfg.training and not cfg.save_overwrite: 156 | load_files = [f for f in glob.glob(cfg.load_file + '.*') if not f.endswith('meta')] 157 | load_files = sorted(load_files, key=lambda x: float(x[len(cfg.load_file)+1:])) 158 | else: 159 | load_files = [cfg.load_file] 160 | if not cfg.training: 161 | test_perps = [] 162 | for load_file in load_files: 163 | with tf.Graph().as_default(), tf.Session(config=config_proto) as session: 164 | with tf.variable_scope("Model") as scope: 165 | if cfg.training: 166 | with tf.variable_scope("LR"): 167 | g_lr = tf.get_variable("g_lr", shape=[], initializer=tf.zeros_initializer, 168 | trainable=False) 169 | d_lr = tf.get_variable("d_lr", shape=[], initializer=tf.zeros_initializer, 170 | trainable=False) 171 | g_optimizer = utils.get_optimizer(g_lr, cfg.g_optimizer) 172 | d_optimizer = utils.get_optimizer(d_lr, cfg.d_optimizer) 173 | model = RNNLMModel(vocab, True, cfg.use_gan, g_optimizer=g_optimizer, 174 | d_optimizer=d_optimizer) 175 | scope.reuse_variables() 176 | eval_model = RNNLMModel(vocab, False, cfg.use_gan) 177 | else: 178 | test_model = RNNLMModel(vocab, False, cfg.use_gan) 179 | saver = tf.train.Saver(max_to_keep=None) 180 | steps = 0 181 | try: 182 | # try to restore a saved model file 183 | saver.restore(session, load_file) 184 | print("\nModel restored from", load_file) 185 | with tf.variable_scope("Model", reuse=True): 186 | steps = session.run(tf.get_variable("GlobalMLE/global_step")) 187 | print('Global step', steps) 188 | except ValueError: 189 | if cfg.training: 190 | tf.initialize_all_variables().run() 191 | print("No loadable model file, new model initialized.") 192 | else: 193 | print("You need to provide a valid model file for testing!") 194 | sys.exit(1) 195 | 196 | if cfg.training: 197 | train_perps = [] 198 | valid_perps = [] 199 | session.run(tf.assign(g_lr, cfg.g_learning_rate)) 200 | session.run(tf.assign(d_lr, cfg.d_learning_rate)) 201 | 
energy_based = cfg.d_rnn and cfg.d_energy_based 202 | scheduler = utils.Scheduler(cfg.min_d_acc, cfg.max_d_acc, cfg.max_perplexity, 203 | cfg.sc_list_size, cfg.sc_decay, eb=energy_based) 204 | for i in range(cfg.max_epoch): 205 | print("\nEpoch: %d" % (i + 1)) 206 | perplexity, steps = run_epoch(i, session, model, reader.training(), vocab, 207 | saver, steps, cfg.max_steps, scheduler, 208 | cfg.use_gan, cfg.gen_every) 209 | print("Epoch: %d Train Perplexity: %.3f" % (i + 1, perplexity)) 210 | train_perps.append(perplexity) 211 | if cfg.validate_every > 0 and (i + 1) % cfg.validate_every == 0: 212 | perplexity, _ = run_epoch(i, session, eval_model, reader.validation(), 213 | vocab, None, 0, -1, None, cfg.use_gan, -1) 214 | print("Epoch: %d Validation Perplexity: %.3f" % (i + 1, perplexity)) 215 | valid_perps.append(perplexity) 216 | else: 217 | valid_perps.append(None) 218 | print('Train:', train_perps) 219 | print('Valid:', valid_perps) 220 | if steps >= cfg.max_steps: 221 | break 222 | else: 223 | if cfg.test_validation: 224 | batch_loader = reader.validation() 225 | else: 226 | batch_loader = reader.testing() 227 | print('\nTesting') 228 | perplexity, _ = run_epoch(0, session, test_model, batch_loader, vocab, None, 0, 229 | cfg.max_steps, None, cfg.use_gan, -1) 230 | print("Test Perplexity: %.3f" % perplexity) 231 | test_perps.append((int(steps), perplexity)) 232 | print('Test:', test_perps) 233 | test_model = None 234 | 235 | 236 | if __name__ == "__main__": 237 | tf.app.run() 238 | -------------------------------------------------------------------------------- /reader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pickle 3 | import random 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from config import cfg 9 | import utils 10 | 11 | 12 | class Vocab(object): 13 | 14 | '''Stores the vocab: forward and reverse mappings''' 15 | 16 | def __init__(self): 17 | self.vocab = 
['', ''] 18 | self.vocab_lookup = {w: i for i, w in enumerate(self.vocab)} 19 | self.sos_index = self.vocab_lookup.get('') 20 | self.unk_index = self.vocab_lookup.get('') 21 | 22 | def load_by_parsing(self, save=False, verbose=True): 23 | '''Read the vocab from the dataset''' 24 | if verbose: 25 | print('Loading vocabulary by parsing...') 26 | fnames = Path(cfg.data_path).glob('*.txt') 27 | for fname in fnames: 28 | if verbose: 29 | print(fname) 30 | with fname.open('r') as f: 31 | for line in f: 32 | for word in utils.read_words(line, chars=cfg.char_model): 33 | if word not in self.vocab_lookup: 34 | self.vocab_lookup[word] = len(self.vocab) 35 | self.vocab.append(word) 36 | if verbose: 37 | print('Vocabulary loaded, size:', len(self.vocab)) 38 | 39 | def load_from_pickle(self, verbose=True): 40 | '''Read the vocab from a pickled file''' 41 | pkfile = cfg.vocab_file 42 | try: 43 | if verbose: 44 | print('Loading vocabulary from pickle...') 45 | with pkfile.open('rb') as f: 46 | self.vocab, self.vocab_lookup = pickle.load(f) 47 | if verbose: 48 | print('Vocabulary loaded, size:', len(self.vocab)) 49 | except IOError: 50 | if verbose: 51 | print('Error loading from pickle, attempting parsing.') 52 | self.load_by_parsing(save=True, verbose=verbose) 53 | with pkfile.open('wb') as f: 54 | pickle.dump([self.vocab, self.vocab_lookup], f, -1) 55 | if verbose: 56 | print('Saved pickle file.') 57 | 58 | def lookup(self, words): 59 | return [self.vocab_lookup.get(w) for w in words] 60 | 61 | 62 | class Reader(object): 63 | def __init__(self, vocab): 64 | self.vocab = vocab 65 | random.seed(0) # deterministic random 66 | 67 | def read_lines(self, fnames): 68 | '''Read single lines from data''' 69 | for fname in fnames: 70 | with fname.open('r') as f: 71 | for line in f: 72 | yield self.vocab.lookup([w for w in utils.read_words(line, 73 | chars=cfg.char_model)]) 74 | 75 | def _prepare(self, lines): 76 | '''Prepare non-overlapping data''' 77 | seqs = [] 78 | seq = [] 79 | for 
line in lines: 80 | line.insert(0, self.vocab.sos_index) 81 | for word in line: 82 | seq.append(word) 83 | if len(seq) == cfg.max_sent_length: 84 | seqs.append(seq) 85 | seq = [] 86 | return seqs 87 | 88 | def buffered_read(self, fnames, buffer_size=500): 89 | '''Read and yield a list of non-overlapping sequences''' 90 | buffer_size = max(buffer_size, cfg.max_sent_length) 91 | lines = [] 92 | for line in self.read_lines(fnames): 93 | lines.append(line) 94 | if len(lines) == buffer_size: 95 | random.shuffle(lines) 96 | yield self._prepare(lines) 97 | lines = [] 98 | if lines: 99 | random.shuffle(lines) 100 | yield self._prepare(lines) 101 | 102 | def buffered_read_batches(self, fnames, buffer_size=500): 103 | batches = [] 104 | batch = [] 105 | for lines in self.buffered_read(fnames): 106 | for line in lines: 107 | batch.append(line) 108 | if len(batch) == cfg.batch_size: 109 | batches.append(self.pack(batch)) 110 | if len(batches) == buffer_size: 111 | random.shuffle(batches) 112 | for batch in batches: 113 | yield batch 114 | batches = [] 115 | batch = [] 116 | # ignore current incomplete batch 117 | if batches: 118 | random.shuffle(batches) 119 | for batch in batches: 120 | yield batch 121 | 122 | def pack(self, batch): 123 | '''Pack python-list batches into numpy batches''' 124 | ret_batch = np.zeros([cfg.batch_size, cfg.max_sent_length], dtype=np.int32) 125 | for i, s in enumerate(batch): 126 | ret_batch[i, :len(s)] = s 127 | return ret_batch 128 | 129 | def training(self): 130 | '''Read batches from training data''' 131 | yield from self.buffered_read_batches([Path(cfg.data_path) / 'train.txt']) 132 | 133 | def validation(self): 134 | '''Read batches from validation data''' 135 | yield from self.buffered_read_batches([Path(cfg.data_path) / 'valid.txt']) 136 | 137 | def testing(self): 138 | '''Read batches from testing data''' 139 | yield from self.buffered_read_batches([Path(cfg.data_path) / 'test.txt']) 140 | 141 | 142 | def main(_): 143 | '''Reader tests''' 
144 | 145 | vocab = Vocab() 146 | vocab.load_from_pickle() 147 | 148 | reader = Reader(vocab) 149 | c = 0 150 | w = 0 151 | for batch in reader.training(): 152 | n_words = np.sum(batch != 0) 153 | w += n_words 154 | c += len(batch) 155 | #for line in batch: 156 | # print(line) 157 | # for e in line: 158 | # if cfg.char_model: 159 | # print(vocab.vocab[e], end='') 160 | # else: 161 | # print(vocab.vocab[e], end=' ') 162 | # print() 163 | # print() 164 | print('Total lines:', c) 165 | print('Total words:', w) 166 | 167 | 168 | if __name__ == '__main__': 169 | tf.app.run() 170 | -------------------------------------------------------------------------------- /rnncell.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import utils 4 | 5 | 6 | class GRUCell(tf.nn.rnn_cell.RNNCell): 7 | 8 | """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).""" 9 | 10 | def __init__(self, num_units, pretanh=False, activation=tf.nn.tanh): 11 | self.num_units = num_units 12 | self.pretanh = pretanh 13 | self.activation = activation 14 | 15 | @property 16 | def state_size(self): 17 | if self.pretanh: 18 | return 2 * self.num_units 19 | else: 20 | return self.num_units 21 | 22 | @property 23 | def output_size(self): 24 | return self.num_units 25 | 26 | def __call__(self, inputs, state, scope=None): 27 | """Gated recurrent unit (GRU) with nunits cells.""" 28 | with tf.variable_scope(scope or type(self).__name__): # "GRUCell" 29 | if self.pretanh: 30 | state = state[:, :self.num_units] 31 | with tf.variable_scope("Gates"): # Reset gate and update gate. 32 | # We start with bias of 1.0 to not reset and not update. 
33 | r, u = tf.split(1, 2, utils.linear([inputs, state], 2 * self.num_units, True, 1.0)) 34 | r, u = tf.nn.sigmoid(r), tf.nn.sigmoid(u) 35 | with tf.variable_scope("Candidate"): 36 | preact = utils.linear([inputs, r * state], self.num_units, True) 37 | c = self.activation(preact) 38 | new_h = u * state + (1 - u) * c 39 | if self.pretanh: 40 | new_state = tf.concat(1, [new_h, preact]) 41 | else: 42 | new_state = new_h 43 | return new_h, new_state 44 | 45 | 46 | class MultiRNNCell(tf.nn.rnn_cell.RNNCell): 47 | 48 | """RNN cell composed sequentially of multiple simple cells.""" 49 | 50 | def __init__(self, cells, embedding=None, softmax_w=None, softmax_b=None, return_states=False, 51 | outputs_are_states=True, pretanh=False, get_embeddings=False): 52 | """Create a RNN cell composed sequentially of a number of RNNCells. If embedding is not 53 | None, the output of the previous timestep is used for the current time step using the 54 | softmax variables. 55 | """ 56 | if not cells: 57 | raise ValueError("Must specify at least one cell for MultiRNNCell.") 58 | self.cells = cells 59 | if not (embedding is None and softmax_w is None and softmax_b is None) and \ 60 | not (embedding is not None and softmax_w is not None and softmax_b is not None): 61 | raise ValueError('Embedding and softmax variables have to all be None or all not None.') 62 | self.embedding = embedding 63 | self.softmax_w = softmax_w 64 | self.softmax_b = softmax_b 65 | self.return_states = return_states 66 | self.outputs_are_states = outputs_are_states # should be true for GRUs 67 | self.pretanh = pretanh 68 | self.get_embeddings = get_embeddings 69 | if embedding is not None: 70 | self.emb_size = embedding.get_shape()[1] 71 | else: 72 | self.emb_size = 0 73 | 74 | @property 75 | def state_size(self): 76 | sizes = [cell.state_size for cell in self.cells] 77 | if self.emb_size: 78 | sizes.extend([self.emb_size, 1]) 79 | return tuple(sizes) 80 | 81 | @property 82 | def output_size(self): 83 | size = 
self.cells[-1].output_size 84 | if self.return_states: 85 | if not self.pretanh and self.outputs_are_states: 86 | skip = 1 87 | else: 88 | skip = 0 89 | if self.pretanh: 90 | size += sum(cell.state_size // 2 for cell in self.cells) 91 | else: 92 | size += sum(cell.state_size for cell in self.cells[:-skip]) 93 | if self.get_embeddings: 94 | size += self.embedding.get_shape()[1].value 95 | if self.emb_size: 96 | size += 1 # for the current timestep prediction 97 | return size 98 | 99 | def initial_state(self, initial): 100 | '''Generate the required initial state from $initial.''' 101 | if self.emb_size: 102 | initial.append(tf.zeros([initial[0].get_shape()[0], self.emb_size])) 103 | initial.append(tf.zeros([initial[0].get_shape()[0], 1])) 104 | return tuple(initial) 105 | 106 | def __call__(self, inputs, state, scope=None): 107 | """Run this multi-layer cell on inputs, starting from state.""" 108 | with tf.variable_scope(scope or type(self).__name__): # "MultiRNNCell" 109 | if self.embedding is not None: 110 | cur_inp = tf.select(tf.greater(state[-1][:, 0], 0.5), state[-2], inputs) 111 | else: 112 | cur_inp = inputs 113 | new_states = [] 114 | if self.return_states: 115 | ret_states = [] 116 | for i, cell in enumerate(self.cells): 117 | with tf.variable_scope("Layer%d" % i): 118 | if not tf.nn.nest.is_sequence(state): 119 | raise ValueError("Expected state to be a tuple of length %d, but received: " 120 | "%s" % (len(self.state_size), state)) 121 | cur_state = state[i] 122 | cur_inp, new_state = cell(cur_inp, cur_state) 123 | new_states.append(new_state) 124 | if self.return_states: 125 | if self.pretanh: 126 | size = new_state.get_shape()[1] 127 | ret_states.append(new_state[:, size // 2:]) 128 | else: 129 | ret_states.append(new_state) 130 | if self.embedding is not None: 131 | logits = tf.nn.bias_add(tf.matmul(cur_inp, tf.transpose(self.softmax_w), 132 | name='Softmax_transform'), 133 | self.softmax_b) 134 | logits = tf.nn.log_softmax(logits) 135 | dist = 
tf.contrib.distributions.Categorical(logits) 136 | prediction = tf.cast(dist.sample(), tf.int64) 137 | with tf.device('/cpu:0'): 138 | embeddings = tf.nn.embedding_lookup(self.embedding, prediction, 139 | name='rnn_embedding_k1') 140 | new_states.append(embeddings) 141 | if self.return_states and self.get_embeddings: 142 | ret_states.insert(0, embeddings) 143 | new_states.append(tf.ones([inputs.get_shape()[0], 1])) # we have valid prev input 144 | if self.return_states: 145 | output = [cur_inp] 146 | if self.embedding is not None: 147 | output.append(tf.cast(tf.expand_dims(prediction, -1), tf.float32)) 148 | if not self.pretanh and self.outputs_are_states: 149 | # skip the last layer states, since they're outputs 150 | ret_states = ret_states[:-1] 151 | return tf.concat(1, output + ret_states), tuple(new_states) 152 | else: 153 | return cur_inp, tuple(new_states) 154 | -------------------------------------------------------------------------------- /rnnlm.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from config import cfg 4 | import rnncell 5 | import utils 6 | 7 | 8 | class RNNLMModel(object): 9 | 10 | '''The adversarial recurrent language model.''' 11 | 12 | def __init__(self, vocab, training, use_gan=True, g_optimizer=None, d_optimizer=None): 13 | self.vocab = vocab 14 | self.training = training 15 | self.g_optimizer = g_optimizer 16 | self.d_optimizer = d_optimizer 17 | 18 | self.embedding = self.word_embedding_matrix() 19 | self.softmax_w, self.softmax_b = self.softmax_variables() 20 | 21 | with tf.variable_scope("GlobalMLE"): 22 | self.global_step = tf.get_variable('global_step', shape=[], 23 | initializer=tf.zeros_initializer, 24 | trainable=False) 25 | # input data 26 | self.data = tf.placeholder(tf.int32, [cfg.batch_size, cfg.max_sent_length], name='data') 27 | 28 | embs = self.word_embeddings(self.data) 29 | output, mle_states, _ = self.generator(embs, True) 30 | _, gan_states, 
self.generated = self.generator(embs, False, True) 31 | if use_gan: 32 | states = tf.concat(0, [mle_states, gan_states]) 33 | 34 | if cfg.d_energy_based: 35 | d_out = self.discriminator_energy(states) 36 | d_loss, g_loss = self.gan_energy_loss(d_out[:, :-1, :], states) 37 | self.d_cost = tf.reduce_sum(d_loss) / (2 * cfg.batch_size) 38 | self.g_cost = tf.reduce_sum(g_loss) / (2 * cfg.batch_size) 39 | else: 40 | if cfg.d_rnn: 41 | d_out = self.discriminator_rnn(states) 42 | else: 43 | d_out = self.discriminator_finalstate(states) 44 | targets = tf.concat(0, [tf.ones([cfg.batch_size, 1]), 45 | tf.zeros([cfg.batch_size, 1])]) 46 | gan_loss = self.gan_loss(d_out, targets) 47 | self.d_cost = tf.reduce_sum(gan_loss) / (2 * cfg.batch_size) 48 | self.g_cost = -self.d_cost 49 | else: 50 | self.d_cost = tf.zeros([]) 51 | self.g_cost = tf.zeros([]) 52 | 53 | # shift left the input to get the targets 54 | targets = tf.concat(1, [self.data[:, 1:], tf.zeros([cfg.batch_size, 1], tf.int32)]) 55 | self.nll = tf.reduce_sum(self.mle_loss(output, targets)) / cfg.batch_size 56 | self.mle_cost = self.nll 57 | if training: 58 | self.mle_train_op = self.train_mle(self.mle_cost) 59 | else: 60 | self.mle_train_op = tf.no_op() 61 | if training and use_gan: 62 | self.d_train_op = self.train_d(self.d_cost) 63 | self.g_train_op = self.train_g(self.g_cost) 64 | else: 65 | self.d_train_op = tf.no_op() 66 | self.g_train_op = tf.no_op() 67 | 68 | def rnn_cell(self, num_layers, hidden_size, embedding=None, softmax_w=None, softmax_b=None, 69 | return_states=False, pretanh=False, get_embeddings=False): 70 | '''Return a multi-layer RNN cell.''' 71 | return rnncell.MultiRNNCell([rnncell.GRUCell(hidden_size, pretanh=pretanh) 72 | for _ in range(num_layers)], embedding=embedding, 73 | softmax_w=softmax_w, softmax_b=softmax_b, 74 | return_states=return_states, pretanh=pretanh, 75 | get_embeddings=get_embeddings) 76 | 77 | def word_embedding_matrix(self): 78 | '''Define the word embedding matrix.''' 79 | 
with tf.device('/cpu:0') and tf.variable_scope("Embeddings"): 80 | embedding = tf.get_variable('word_embedding', [len(self.vocab.vocab), 81 | cfg.emb_size], 82 | initializer=tf.random_uniform_initializer(-1.0, 1.0)) 83 | return embedding 84 | 85 | def softmax_variables(self): 86 | '''Define the softmax weight and bias variables.''' 87 | with tf.variable_scope("MLE_Softmax"): 88 | softmax_w = tf.get_variable("W", [len(self.vocab.vocab), cfg.hidden_size], 89 | initializer=tf.contrib.layers.xavier_initializer()) 90 | softmax_b = tf.get_variable("b", [len(self.vocab.vocab)], 91 | initializer=tf.zeros_initializer) 92 | return softmax_w, softmax_b 93 | 94 | def word_embeddings(self, inputs): 95 | '''Look up word embeddings for the input indices.''' 96 | with tf.device('/cpu:0'): 97 | embeds = tf.nn.embedding_lookup(self.embedding, inputs, name='word_embedding_lookup') 98 | return embeds 99 | 100 | def generator(self, inputs, mle_mode, reuse=None): 101 | '''Use the word inputs to predict next words.''' 102 | with tf.variable_scope("Generator", reuse=reuse): 103 | if mle_mode: 104 | cell = self.rnn_cell(cfg.num_layers, cfg.hidden_size, return_states=True, 105 | pretanh=True) 106 | else: 107 | cell = self.rnn_cell(cfg.num_layers, cfg.hidden_size, self.embedding, 108 | self.softmax_w, self.softmax_b, return_states=True, 109 | pretanh=True, get_embeddings=cfg.concat_inputs) 110 | outputs, _ = tf.nn.dynamic_rnn(cell, inputs, swap_memory=True, dtype=tf.float32) 111 | output = outputs[:, :, :cfg.hidden_size] 112 | if mle_mode: 113 | generated = None 114 | skip = 0 115 | else: 116 | generated = tf.squeeze(tf.cast(outputs[:, :, cfg.hidden_size:cfg.hidden_size+1], 117 | tf.int32), [-1]) 118 | skip = 1 119 | if cfg.concat_inputs: 120 | embeddings = outputs[:, :, cfg.hidden_size+1:cfg.hidden_size+1+cfg.emb_size] 121 | embeddings = tf.concat(1, [inputs[:, :1, :], embeddings[:, :-1, :]]) 122 | skip += cfg.emb_size 123 | states = outputs[:, :, cfg.hidden_size+skip:] 124 | if 
cfg.concat_inputs: 125 | if mle_mode: 126 | states = tf.concat(2, [states, inputs]) 127 | else: 128 | states = tf.concat(2, [states, embeddings]) 129 | return output, states, generated 130 | 131 | def mle_loss(self, outputs, targets): 132 | '''Maximum likelihood estimation loss.''' 133 | # don't enfoce loss on true 's, makes the reported perlexity slightly overestimated 134 | mask = tf.cast(tf.not_equal(targets, self.vocab.unk_index, name='unk_mask'), tf.float32) 135 | output = tf.reshape(tf.concat(1, outputs), [-1, cfg.hidden_size]) 136 | if self.training and cfg.softmax_samples < len(self.vocab.vocab): 137 | targets = tf.reshape(targets, [-1, 1]) 138 | mask = tf.reshape(mask, [-1]) 139 | loss = tf.nn.sampled_softmax_loss(self.softmax_w, self.softmax_b, output, targets, 140 | cfg.softmax_samples, len(self.vocab.vocab)) 141 | loss *= mask 142 | else: 143 | logits = tf.nn.bias_add(tf.matmul(output, tf.transpose(self.softmax_w), 144 | name='softmax_transform_mle'), self.softmax_b) 145 | loss = tf.nn.seq2seq.sequence_loss_by_example([logits], 146 | [tf.reshape(targets, [-1])], 147 | [tf.reshape(mask, [-1])]) 148 | return tf.reshape(loss, [cfg.batch_size, -1]) 149 | 150 | def discriminator_rnn(self, states): 151 | '''Recurrent discriminator that operates on the sequence of states of the sentences.''' 152 | with tf.variable_scope("Discriminator"): 153 | if cfg.d_rnn_bidirect: 154 | hidden_size = cfg.hidden_size 155 | fcell = self.rnn_cell(cfg.d_num_layers, hidden_size, return_states=True) 156 | bcell = self.rnn_cell(cfg.d_num_layers, hidden_size, return_states=True) 157 | seq_lengths = [cfg.max_sent_length] * (2 * cfg.batch_size) 158 | outputs, _ = tf.nn.bidirectional_dynamic_rnn(fcell, bcell, states, 159 | sequence_length=seq_lengths, 160 | swap_memory=True, dtype=tf.float32) 161 | else: 162 | hidden_size = cfg.hidden_size * 2 163 | cell = self.rnn_cell(cfg.d_num_layers, hidden_size, return_states=True) 164 | outputs, _ = tf.nn.dynamic_rnn(cell, states, 
swap_memory=True, dtype=tf.float32) 165 | outputs = (outputs,) # to match bidirectional RNN's output format 166 | d_states = [] 167 | for out in outputs: 168 | output = out[:, :, :hidden_size] 169 | dir_states = out[:, :, hidden_size:] 170 | # for GRU, we skipped the last layer states because they're the outputs 171 | d_states.append(tf.concat(2, [dir_states, output])) 172 | return self._discriminator_conv(tf.concat(2, d_states)) 173 | 174 | def _discriminator_conv(self, states): 175 | '''Convolve output of bidirectional RNN and predict the discriminator label.''' 176 | with tf.variable_scope("Discriminator"): 177 | W_conv = tf.get_variable('W_conv', [cfg.d_conv_window, 1, states.get_shape()[2], 178 | cfg.hidden_size // cfg.d_conv_window], 179 | initializer=tf.contrib.layers.xavier_initializer_conv2d()) 180 | b_conv = tf.get_variable('b_conv', [cfg.hidden_size // cfg.d_conv_window], 181 | initializer=tf.constant_initializer(0.0)) 182 | states = tf.expand_dims(states, 2) 183 | conv = tf.nn.conv2d(states, W_conv, strides=[1, 1, 1, 1], padding='SAME') 184 | conv_out = tf.reshape(conv, [2 * cfg.batch_size, -1, 185 | cfg.hidden_size // cfg.d_conv_window]) 186 | conv_out = tf.nn.elu(tf.nn.bias_add(conv_out, b_conv)) 187 | conv_out = tf.reshape(conv_out, [2 * cfg.batch_size, -1]) 188 | output = utils.linear(conv_out, 1, True, 0.0, scope='discriminator_output') 189 | return output 190 | 191 | def discriminator_finalstate(self, states): 192 | '''Discriminator that operates on the final states of the sentences.''' 193 | with tf.variable_scope("Discriminator"): 194 | lin1 = tf.nn.elu(utils.linear(states[:, -1, :], cfg.hidden_size, True, 0.0, 195 | scope='discriminator_lin1')) 196 | lin2 = tf.nn.elu(utils.linear(lin1, cfg.hidden_size // 2, True, 0.0, 197 | scope='discriminator_lin2')) 198 | output = utils.linear(lin2, 1, True, 0.0, scope='discriminator_output') 199 | return output 200 | 201 | def discriminator_energy(self, states): 202 | '''An energy-based discriminator that 
tries to reconstruct the input states.''' 203 | with tf.variable_scope("Discriminator"): 204 | _, state = tf.nn.dynamic_rnn(self.rnn_cell(cfg.d_num_layers, cfg.hidden_size), states, 205 | swap_memory=True, dtype=tf.float32, 206 | scope='discriminator_encoder') 207 | # XXX use BiRNN+convnet for the encoder 208 | # this latent needs a more capacity than to reproduce the hidden states 209 | latent_size = cfg.hidden_size // 2 210 | latent = tf.nn.elu(utils.linear(state, latent_size, True, 211 | scope='discriminator_latent_transform')) 212 | latent = utils.highway(latent, layer_size=2, f=tf.nn.elu) 213 | decoder_input = tf.concat(1, [tf.zeros([2 * cfg.batch_size, 1, 214 | states.get_shape()[2].value]), states]) 215 | decoder_input = tf.concat(2, [decoder_input, 216 | tf.tile(tf.expand_dims(latent, 1), 217 | [1, decoder_input.get_shape()[1].value, 1])]) 218 | hidden_size = cfg.hidden_size 219 | if cfg.concat_inputs: 220 | hidden_size += cfg.emb_size 221 | output, _ = tf.nn.dynamic_rnn(self.rnn_cell(cfg.d_num_layers, hidden_size), 222 | decoder_input, swap_memory=True, dtype=tf.float32, 223 | scope='discriminator_decoder') 224 | output = tf.reshape(output, [-1, hidden_size]) 225 | reconstructed = utils.linear(output, hidden_size, True, 0.0, 226 | scope='discriminator_reconst') 227 | reconstructed = tf.reshape(reconstructed, [2 * cfg.batch_size, -1, hidden_size]) 228 | return reconstructed 229 | 230 | def gan_energy_loss(self, states, targets): 231 | '''Return the GAN energy loss. Put no variables here.''' 232 | losses = tf.reduce_sum(tf.square(states - targets), [1, 2]) / cfg.max_sent_length 233 | d_losses = losses[:cfg.batch_size] + tf.nn.relu(cfg.d_eb_margin - losses[cfg.batch_size:]) 234 | g_losses = losses[cfg.batch_size:] 235 | return d_losses, g_losses 236 | 237 | def gan_loss(self, d_out, targets): 238 | '''Return the discriminator loss according to the label (1 if MLE mode). 
239 | Put no variables here.''' 240 | return tf.nn.sigmoid_cross_entropy_with_logits(d_out, targets) 241 | 242 | def _train(self, cost, scope, optimizer, global_step=None): 243 | '''Generic training helper''' 244 | tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope) 245 | grads = tf.gradients(cost, tvars) 246 | if cfg.max_grad_norm > 0: 247 | grads, _ = tf.clip_by_global_norm(grads, cfg.max_grad_norm) 248 | return optimizer.apply_gradients(zip(grads, tvars), global_step=global_step) 249 | 250 | def train_mle(self, cost): 251 | '''Training op for MLE mode.''' 252 | return self._train(cost, '.*/(Embeddings|Generator|MLE_Softmax)', self.g_optimizer, 253 | self.global_step) 254 | 255 | def train_d(self, cost): 256 | '''Training op for GAN mode, discriminator.''' 257 | return self._train(cost, '.*/Discriminator', self.d_optimizer) 258 | 259 | def train_g(self, cost): 260 | '''Training op for GAN mode, generator.''' 261 | # don't update embeddings, just update the generated distributions 262 | return self._train(cost, '.*/Generator', self.g_optimizer) 263 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import re 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | 8 | class Scheduler(object): 9 | 10 | '''Scheduler for GANs''' 11 | 12 | def __init__(self, min_d_acc, max_d_acc, max_perp, list_size, decay, eb=False): 13 | assert min_d_acc < max_d_acc 14 | self.min_d_acc = min_d_acc 15 | self.max_d_acc = max_d_acc 16 | self.max_perp = max_perp 17 | self.list_size = list_size 18 | self.eb = eb 19 | self.d_accs = [] 20 | self.perps = [] 21 | coeffs = [1.0] 22 | for _ in range(list_size - 1): 23 | coeffs.append(coeffs[-1] * decay) 24 | self.coeffs = np.array(coeffs) / sum(coeffs) 25 | # TODO make configurable 26 | self.d_coeffs = np.array([0.998 ** i for i in range(list_size)]) 27 | 28 | def 
add_d_acc(self, d_acc): 29 | '''Observe new descriminator accuracy.''' 30 | self.d_accs.insert(0, d_acc) 31 | if len(self.d_accs) > self.list_size: 32 | self.d_accs.pop() 33 | 34 | def add_perp(self, perp): 35 | '''Observe new perplexity.''' 36 | self.perps.insert(0, perp) 37 | if len(self.perps) > self.list_size: 38 | self.perps.pop() 39 | 40 | def _current_perp(self): 41 | '''Smooth approximation of current perplexity.''' 42 | if not self.perps: 43 | return float('inf') 44 | coeffs = self.coeffs.copy(order='K') 45 | if len(self.perps) < self.list_size: 46 | coeffs = coeffs[:len(self.perps)] 47 | coeffs /= np.sum(coeffs) 48 | return np.sum(np.array(self.perps) * coeffs) 49 | 50 | def _current_d_acc(self): 51 | '''Overestimate the current discriminator accuracy to avoid a powerful discriminator.''' 52 | if not self.d_accs: 53 | return 0.5 54 | else: 55 | return np.max(self.d_accs * self.d_coeffs[:len(self.d_accs)]) 56 | 57 | def update_d(self): 58 | '''Whether or not to update the descriminator.''' 59 | if self.max_perp > 0.0 and self._current_perp() > self.max_perp: 60 | return False 61 | if self.eb or self._current_d_acc() < self.max_d_acc: 62 | return True 63 | else: 64 | return False 65 | 66 | def update_g(self): 67 | '''Whether or not to update the generator.''' 68 | if self.max_perp > 0.0 and self._current_perp() > self.max_perp: 69 | return False 70 | if self.eb or self._current_d_acc() > self.min_d_acc: 71 | return True 72 | else: 73 | return False 74 | 75 | 76 | fix_re = re.compile(r'''[^a-z0-9"'?.,]+''') 77 | num_re = re.compile(r'[0-9]+') 78 | 79 | 80 | def fix_word(word): 81 | word = word.lower() 82 | word = fix_re.sub('', word) 83 | word = num_re.sub('#', word) 84 | return word 85 | 86 | 87 | def display_sentences(output, vocab, char_model): 88 | '''Display sentences from indices.''' 89 | if char_model: 90 | space = '' 91 | nospace = ' ' 92 | else: 93 | space = ' ' 94 | nospace = '' 95 | for i, sent in enumerate(output): 96 | print('Sentence %d:' % i, 
end=' ') 97 | for word in sent: 98 | if word == vocab.sos_index: 99 | print(nospace + '. ', end='') 100 | else: 101 | print(vocab.vocab[word], end=space) 102 | print() 103 | print() 104 | 105 | 106 | def read_words(line, chars): 107 | if chars: 108 | first = True 109 | for word in line.split(): 110 | if word != '' and not chars: 111 | word = fix_word(word) 112 | if word: 113 | if chars: 114 | if not first: 115 | yield ' ' 116 | else: 117 | first = False 118 | if word == '': 119 | yield word 120 | else: 121 | for c in word: 122 | yield c 123 | else: 124 | yield word 125 | 126 | 127 | def grouper(n, iterable, fillvalue=None): 128 | '''Group elements of iterable in groups of n. For example: 129 | >>> [e for e in grouper(3, [1,2,3,4,5,6,7])] 130 | [(1, 2, 3), (4, 5, 6), (7, None, None)]''' 131 | args = [iter(iterable)] * n 132 | return itertools.zip_longest(*args, fillvalue=fillvalue) 133 | 134 | 135 | def get_optimizer(lr, name): 136 | '''Return an optimizer.''' 137 | if name == 'sgd': 138 | optimizer = tf.train.GradientDescentOptimizer(lr) 139 | elif name == 'adam': 140 | optimizer = tf.train.AdamOptimizer(lr) 141 | elif name == 'adagrad': 142 | optimizer = tf.train.AdagradOptimizer(lr) 143 | elif name == 'adadelta': 144 | optimizer = tf.train.AdadeltaOptimizer(lr) 145 | return optimizer 146 | 147 | 148 | def list_all_variables(trainable=True, rest=False): 149 | trainv = tf.trainable_variables() 150 | if trainable: 151 | print('\nTrainable:') 152 | for v in trainv: 153 | print(v.op.name) 154 | if rest: 155 | print('\nOthers:') 156 | for v in tf.all_variables(): 157 | if v not in trainv: 158 | print(v.op.name) 159 | 160 | 161 | def rowwise_lookup(params, indices): 162 | '''Look up an index from each row of params as per indices.''' 163 | shape = params.get_shape().as_list() 164 | if len(shape) == 2: 165 | hidden_size = 1 166 | else: 167 | hidden_size = shape[-1] 168 | flattened = tf.reshape(params, [-1, hidden_size]) 169 | flattened_indices = indices + 
(tf.range(shape[0]) * tf.shape(params)[1]) 170 | return tf.gather(flattened, flattened_indices) 171 | 172 | 173 | def linear(args, output_size, bias, bias_start=0.0, scope=None, train=True, initializer=None): 174 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 175 | Args: 176 | args: a 2D Tensor or a list of 2D, batch x n, Tensors. 177 | output_size: int, second dimension of W[i]. 178 | bias: boolean, whether to add a bias term or not. 179 | bias_start: starting value to initialize the bias; 0 by default. 180 | scope: VariableScope for the created subgraph; defaults to "Linear". 181 | Returns: 182 | A 2D Tensor with shape [batch x output_size] equal to 183 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 184 | Raises: 185 | ValueError: if some of the arguments has unspecified or wrong shape. 186 | Based on the code from TensorFlow.""" 187 | if not tf.nn.nest.is_sequence(args): 188 | args = [args] 189 | 190 | # Calculate the total size of arguments on dimension 1. 191 | total_arg_size = 0 192 | shapes = [a.get_shape().as_list() for a in args] 193 | for shape in shapes: 194 | if len(shape) != 2: 195 | raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes)) 196 | if not shape[1]: 197 | raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes)) 198 | else: 199 | total_arg_size += shape[1] 200 | 201 | dtype = [a.dtype for a in args][0] 202 | 203 | if initializer is None: 204 | initializer = tf.contrib.layers.xavier_initializer() 205 | # Now the computation. 
206 | with tf.variable_scope(scope or "Linear"): 207 | matrix = tf.get_variable("Matrix", [total_arg_size, output_size], dtype=dtype, 208 | initializer=initializer, trainable=train) 209 | if len(args) == 1: 210 | res = tf.matmul(args[0], matrix) 211 | else: 212 | res = tf.matmul(tf.concat(1, args), matrix) 213 | if not bias: 214 | return res 215 | bias_term = tf.get_variable("Bias", [output_size], dtype=dtype, 216 | initializer=tf.constant_initializer(bias_start, dtype=dtype), 217 | trainable=train) 218 | return res + bias_term 219 | 220 | 221 | def highway(input_, layer_size=1, bias=-2, f=tf.nn.tanh, scope=None): 222 | """Highway Network (cf. http://arxiv.org/abs/1505.00387). 223 | t = sigmoid(Wy + b) 224 | z = t * g(Wy + b) + (1 - t) * y 225 | where g is nonlinearity, t is transform gate, and (1 - t) is carry gate.""" 226 | if tf.nn.nest.is_sequence(input_): 227 | input_ = tf.concat(1, input_) 228 | shape = input_.get_shape() 229 | if len(shape) != 2: 230 | raise ValueError("Highway is expecting 2D arguments: %s" % str(shape)) 231 | size = shape[1] 232 | with tf.variable_scope(scope or "Highway"): 233 | for idx in range(layer_size): 234 | output = f(linear(input_, size, False, scope='HW_Nonlin_%d' % idx)) 235 | transform_gate = tf.sigmoid(linear(input_, size, False, scope='HW_Gate_%d' % idx) 236 | + bias) 237 | carry_gate = 1.0 - transform_gate 238 | output = transform_gate * output + carry_gate * input_ 239 | input_ = output 240 | 241 | return output 242 | --------------------------------------------------------------------------------