├── scripts
│   ├── sample.sh
│   ├── shakespeare.sh
│   └── test.sh
├── README.md
├── sample.py
├── utils.py
├── config.py
├── model.py
└── train.py

/scripts/sample.sh:
--------------------------------------------------------------------------------
1 | python ../sample.py --init_dir=./output --start_text="The meaning of life is" --length=1000
--------------------------------------------------------------------------------
/scripts/shakespeare.sh:
--------------------------------------------------------------------------------
1 | python ../train.py --data_file=../data/tiny_shakespeare.txt --dropout=0.5 --verbose=1 --debug
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | python ../train.py --data_file=../data/tiny_shakespeare.txt --num_epochs=10 --dropout=0.5 --verbose=1 --test --debug
2 | python ../train.py --data_file=../data/tiny_shakespeare.txt --num_epochs=10 --dropout=0.5 --verbose=1 --test --debug --init_dir=./output
3 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Char RNN Language Model based on TensorFlow
2 | Multi-layer Recurrent Neural Networks (LSTM, GRU, RNN) for character-level language models in Python using TensorFlow.
3 | 
4 | If you are not familiar with RNNs, you can first read my post.
5 | 
6 | Inspired by Andrej Karpathy's [char-rnn](https://github.com/karpathy/char-rnn).
7 | 
8 | See his article [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) to learn more about this model.
9 | 
10 | ## How to use?
11 | In the `scripts` directory, first run `sh shakespeare.sh` to train the model, then run `sh sample.sh` to sample text from the trained model.
12 | 
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os, sys
3 | import logging
4 | 
5 | import numpy as np
6 | import tensorflow as tf
7 | from model import CharRNNLM
8 | from config import config_sample
9 | from utils import VocabularyLoader
10 | 
11 | def main():
12 |     args = config_sample()
13 | 
14 |     logging.basicConfig(stream=sys.stdout,
15 |                         format='%(asctime)s %(levelname)s:%(message)s',
16 |                         level=logging.INFO, datefmt='%I:%M:%S')
17 | 
18 |     # Prepare parameters.
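    # result.json is written by train.py at the end of training; it records the model
    # hyperparameters, the checkpoint path of the best model, its validation
    # perplexity, and the text encoding, so sampling can rebuild the same graph.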
19 |     with open(os.path.join(args.init_dir, 'result.json'), 'r') as f:
20 |         result = json.load(f)
21 |     params = result['params']
22 |     best_model = result['best_model']
23 |     best_valid_ppl = result['best_valid_ppl']
24 |     if 'encoding' in result:
25 |         args.encoding = result['encoding']
26 |     else:
27 |         args.encoding = 'utf-8'
28 |     args.vocab_file = os.path.join(args.init_dir, 'vocab.json')
29 |     vocab_loader = VocabularyLoader()
30 |     vocab_loader.load_vocab(args.vocab_file, args.encoding)
31 | 
32 |     logging.info('best_model: %s\n', best_model)
33 | 
34 |     # Create graphs
35 |     graph = tf.Graph()
36 |     with graph.as_default():
37 |         with tf.name_scope('evaluation'):
38 |             model = CharRNNLM(is_training=False, infer=True, **params)
39 |         saver = tf.train.Saver(name='model_saver')
40 | 
41 |     if args.seed >= 0:
42 |         np.random.seed(args.seed)
43 |     with tf.Session(graph=graph) as session:
44 |         saver.restore(session, best_model)
45 |         sample = model.sample_seq(session, args.length, args.start_text, vocab_loader,
46 |                                   max_prob=args.max_prob)
47 |         print('Sampled text is:\n\n%s' % sample)
48 | 
49 | if __name__ == '__main__':
50 |     main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import codecs
2 | import json
3 | import numpy as np
4 | 
5 | class VocabularyLoader(object):
6 |     def load_vocab(self, vocab_file, encoding):
7 |         with codecs.open(vocab_file, 'r', encoding=encoding) as f:
8 |             self.vocab_index_dict = json.load(f)
9 |         self.index_vocab_dict = {}
10 |         self.vocab_size = 0
11 |         for char, index in self.vocab_index_dict.items():
12 |             self.index_vocab_dict[index] = char
13 |             self.vocab_size += 1
14 | 
15 |     def create_vocab(self, text):
16 |         unique_chars = list(set(text))
17 |         self.vocab_size = len(unique_chars)
18 |         self.vocab_index_dict = {}
19 |         self.index_vocab_dict = {}
20 |         for i, char in enumerate(unique_chars):
21 |             self.vocab_index_dict[char] = i
22 |             self.index_vocab_dict[i] = char
23 | 
24 |     def save_vocab(self, vocab_file, encoding):
25 |         with codecs.open(vocab_file, 'w', encoding=encoding) as f:
26 |             json.dump(self.vocab_index_dict, f, indent=2, sort_keys=True)
27 | 
28 | 
29 | class BatchGenerator(object):
30 |     def __init__(self, vocab_index_dict, text, batch_size, seq_length):
31 |         self.batch_size = batch_size
32 |         self.seq_length = seq_length
33 |         self.tensor = np.array(list(map(vocab_index_dict.get, text)))
34 |         self.create_batches()
35 |         self.reset_batch_pointer()
36 | 
37 |     def reset_batch_pointer(self):
38 |         self.pointer = 0
39 | 
40 |     def create_batches(self):
41 |         self.num_batches = int(self.tensor.size / (self.batch_size * self.seq_length))
42 | 
43 |         # When the data (tensor) is too small, give a clearer error message.
44 |         if self.num_batches == 0:
45 |             assert False, "Not enough data. Make seq_length and batch_size small."
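        # Targets are the inputs shifted left by one character (the first character
        # wraps around to the end); both are then cut into num_batches matrices of
        # shape (batch_size, seq_length).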
46 | 
47 |         self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
48 |         xdata = self.tensor
49 |         ydata = np.copy(self.tensor)
50 |         ydata[:-1] = xdata[1:]
51 |         ydata[-1] = xdata[0]
52 |         self.x_batches = np.split(xdata.reshape(self.batch_size, -1), self.num_batches, 1)
53 |         self.y_batches = np.split(ydata.reshape(self.batch_size, -1), self.num_batches, 1)
54 | 
55 |     def next_batch(self):
56 |         x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
57 |         self.pointer += 1
58 |         return x, y
59 | 
60 | # util functions
61 | def batche2string(batch, index_vocab_dict):
62 |     return ''.join(list(map(index_vocab_dict.get, batch)))
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | 
4 | def config_train():
5 |     parser = argparse.ArgumentParser()
6 | 
7 |     # Data and vocabulary file
8 |     parser.add_argument('--data_file', type=str,
9 |                         default='data/tiny_shakespeare.txt',
10 |                         help='data file')
11 | 
12 |     parser.add_argument('--encoding', type=str,
13 |                         default='utf-8',
14 |                         help='the encoding of the data file.')
15 | 
16 |     # Parameters for saving models.
17 |     parser.add_argument('--output_dir', type=str, default='output',
18 |                         help=('directory to store final and'
19 |                               ' intermediate results and models.'))
20 | 
21 |     # Parameters to configure the neural network.
22 |     parser.add_argument('--hidden_size', type=int, default=128,
23 |                         help='size of RNN hidden state vector')
24 |     parser.add_argument('--embedding_size', type=int, default=0,
25 |                         help='size of character embeddings, 0 for one-hot')
26 |     parser.add_argument('--num_layers', type=int, default=2,
27 |                         help='number of layers in the RNN')
28 |     parser.add_argument('--num_unrollings', type=int, default=10,
29 |                         help='number of unrolling steps.')
30 |     parser.add_argument('--model', type=str, default='lstm',
31 |                         help='which model to use (rnn, lstm or gru).')
32 | 
33 |     # Parameters to control the training.
34 |     parser.add_argument('--num_epochs', type=int, default=50,
35 |                         help='number of epochs')
36 |     parser.add_argument('--batch_size', type=int, default=20,
37 |                         help='minibatch size')
38 |     parser.add_argument('--train_frac', type=float, default=0.9,
39 |                         help='fraction of data used for training.')
40 |     parser.add_argument('--valid_frac', type=float, default=0.05,
41 |                         help='fraction of data used for validation.')
42 |     # test_frac is computed as (1 - train_frac - valid_frac).
43 |     parser.add_argument('--dropout', type=float, default=0.0,
44 |                         help='dropout rate, defaults to 0 (no dropout).')
45 | 
46 |     parser.add_argument('--input_dropout', type=float, default=0.0,
47 |                         help=('dropout rate on the input layer, defaults to 0 (no dropout); '
48 |                               'ignored when using the one-hot representation.'))
49 | 
50 |     # Parameters for gradient descent.
51 |     parser.add_argument('--max_grad_norm', type=float, default=5.,
52 |                         help='clip global grad norm')
53 |     parser.add_argument('--learning_rate', type=float, default=5e-3,
54 |                         help='initial learning rate')
55 | 
56 |     # Parameters for logging.
57 |     parser.add_argument('--progress_freq', type=int, default=100,
58 |                         help=('frequency for progress report in training and evaluation.'))
59 |     parser.add_argument('--verbose', type=int, default=0,
60 |                         help=('whether to show progress report in training and evaluation.'))
61 | 
62 |     # Parameters to feed in the initial model and current best model.
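    # These are normally filled in automatically from result.json when training
    # resumes via --init_dir; they rarely need to be set by hand.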
63 |     parser.add_argument('--init_model', type=str,
64 |                         default='', help=('initial model'))
65 |     parser.add_argument('--best_model', type=str,
66 |                         default='', help=('current best model'))
67 |     parser.add_argument('--best_valid_ppl', type=float,
68 |                         default=np.inf, help=('current best validation perplexity'))
69 | 
70 |     # Parameters for using saved best models.
71 |     parser.add_argument('--init_dir', type=str, default='',
72 |                         help='continue from the outputs in the given directory')
73 | 
74 |     # Parameters for debugging.
75 |     parser.add_argument('--debug', dest='debug', action='store_true',
76 |                         help='show debug information')
77 |     parser.set_defaults(debug=False)
78 | 
79 |     # Parameters for unit testing the implementation.
80 |     parser.add_argument('--test', dest='test', action='store_true',
81 |                         help=('use only the first 50000 characters of the data to test the implementation'))
82 |     parser.set_defaults(test=False)
83 | 
84 |     args = parser.parse_args()
85 | 
86 |     return args
87 | 
88 | 
89 | def config_sample():
90 |     parser = argparse.ArgumentParser()
91 | 
92 |     # Parameters for using saved best models.
93 |     parser.add_argument('--init_dir', type=str, default='',
94 |                         help='continue from the outputs in the given directory')
95 | 
96 |     # Parameters for sampling.
97 |     parser.add_argument('--max_prob', dest='max_prob', action='store_true',
98 |                         help='always pick the most probable next character in sampling')
99 |     parser.set_defaults(max_prob=False)
100 | 
101 |     parser.add_argument('--start_text', type=str,
102 |                         default='The meaning of life is ',
103 |                         help='the text to start with')
104 | 
105 |     parser.add_argument('--length', type=int,
106 |                         default=100,
107 |                         help='length of sampled sequence')
108 | 
109 |     parser.add_argument('--seed', type=int,
110 |                         default=-1,
111 |                         help=('seed for sampling to replicate results, '
112 |                               'an integer between 0 and 4294967295.'))
113 | 
114 |     args = parser.parse_args()
115 | 
116 |     return args
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow.models.rnn import rnn
6 | 
7 | logging.getLogger('tensorflow').setLevel(logging.WARNING)
8 | 
9 | class CharRNNLM(object):
10 |     def __init__(self, is_training, batch_size, num_unrollings, vocab_size,
11 |                  hidden_size, max_grad_norm, embedding_size, num_layers,
12 |                  learning_rate, model, dropout=0.0, input_dropout=0.0, infer=False):
13 |         self.batch_size = batch_size
14 |         self.num_unrollings = num_unrollings
15 |         if infer:
16 |             self.batch_size = 1
17 |             self.num_unrollings = 1
18 |         self.hidden_size = hidden_size
19 |         self.vocab_size = vocab_size
20 |         self.max_grad_norm = max_grad_norm
21 |         self.num_layers = num_layers
22 |         self.embedding_size = embedding_size
23 |         self.model = model
24 |         self.dropout = dropout
25 |         self.input_dropout = input_dropout
26 |         if embedding_size <= 0:
27 |             self.input_size = vocab_size
28 |             self.input_dropout = 0.0
29 |         else:
30 |             self.input_size = embedding_size
31 | 
32 |         self.input_data = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='inputs')
33 |         self.targets = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='targets')
34 | 
35 |         if self.model == 'rnn':
36 |             cell_fn = tf.nn.rnn_cell.BasicRNNCell
37 |         elif self.model == 'lstm':
38 |             cell_fn = tf.nn.rnn_cell.BasicLSTMCell
39 |         elif self.model == 'gru':
40 |             cell_fn = tf.nn.rnn_cell.GRUCell
41 | 
42 |         params = 
{'input_size': self.input_size} 43 | if self.model == 'lstm': 44 | params['forget_bias'] = 0.0 45 | cell = cell_fn(self.hidden_size, **params) 46 | 47 | cells = [cell] 48 | params['input_size'] = self.hidden_size 49 | for i in range(self.num_layers-1): 50 | higher_layer_cell = cell_fn(self.hidden_size, **params) 51 | cells.append(higher_layer_cell) 52 | 53 | if is_training and self.dropout > 0: 54 | cells = [tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=1.0-self.dropout) for cell in cells] 55 | 56 | multi_cell = tf.nn.rnn_cell.MultiRNNCell(cells) 57 | 58 | with tf.name_scope('initial_state'): 59 | self.zero_state = multi_cell.zero_state(self.batch_size, tf.float32) 60 | self.initial_state = tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size], 'initial_state') 61 | 62 | with tf.name_scope('embedding_layer'): 63 | if embedding_size > 0: 64 | self.embedding = tf.get_variable('embedding', [self.vocab_size, self.embedding_size]) 65 | else: 66 | self.embedding = tf.constant(np.eye(self.vocab_size), dtype=tf.float32) 67 | inputs = tf.nn.embedding_lookup(self.embedding, self.input_data) 68 | if is_training and self.input_dropout > 0: 69 | inputs = tf.nn.dropout(inputs, 1-self.input_dropout) 70 | 71 | with tf.name_scope('slice_inputs'): 72 | # num_unrollings * (batch_size, embedding_size), the format of rnn inputs. 73 | sliced_inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(1, self.num_unrollings, inputs)] 74 | 75 | outputs, final_state = rnn.rnn(multi_cell, sliced_inputs, initial_state=self.initial_state) 76 | self.final_state = final_state 77 | 78 | with tf.name_scope('flatten_outputs'): 79 | flat_outputs = tf.reshape(tf.concat(1, outputs), [-1, hidden_size]) 80 | 81 | with tf.name_scope('flatten_targets'): 82 | flat_targets = tf.reshape(tf.concat(1, self.targets), [-1]) 83 | 84 | with tf.variable_scope('softmax') as sm_vs: 85 | softmax_w = tf.get_variable('softmax_w', [hidden_size, vocab_size]) 86 | softmax_b = tf.get_variable('softmax_b', [vocab_size]) 87 | self.logits = tf.matmul(flat_outputs, softmax_w) + softmax_b 88 | self.probs = tf.nn.softmax(self.logits) 89 | 90 | with tf.name_scope('loss'): 91 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, flat_targets) 92 | self.mean_loss = tf.reduce_mean(loss) 93 | 94 | with tf.name_scope('loss_montor'): 95 | count = tf.Variable(1.0, name='count') 96 | sum_mean_loss = tf.Variable(1.0, name='sum_mean_loss') 97 | 98 | self.reset_loss_monitor = tf.group(sum_mean_loss.assign(0.0), 99 | count.assign(0.0), name='reset_loss_monitor') 100 | self.update_loss_monitor = tf.group(sum_mean_loss.assign(sum_mean_loss+self.mean_loss), 101 | count.assign(count+1), name='update_loss_monitor') 102 | 103 | with tf.control_dependencies([self.update_loss_monitor]): 104 | self.average_loss = sum_mean_loss / count 105 | self.ppl = tf.exp(self.average_loss) 106 | 107 | average_loss_summary = tf.scalar_summary('average loss', self.average_loss) 108 | ppl_summary = tf.scalar_summary('perplexity', self.ppl) 109 | 110 | self.summaries = tf.merge_summary([average_loss_summary, ppl_summary], name='loss_monitor') 111 | 112 | self.global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0.0)) 113 | 114 | # self.learning_rate = tf.constant(learning_rate) 115 | self.learning_rate = tf.placeholder(tf.float32, [], name='learning_rate') 116 | 117 | if is_training: 118 | tvars = tf.trainable_variables() 119 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.mean_loss, tvars), self.max_grad_norm) 120 | 
optimizer = tf.train.AdamOptimizer(self.learning_rate)
121 |             self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step)
122 | 
123 | 
124 |     def run_epoch(self, session, batch_generator, is_training, learning_rate, verbose=0, freq=10):
125 |         epoch_size = batch_generator.num_batches
126 | 
127 |         if verbose > 0:
128 |             logging.info('epoch_size: %d', epoch_size)
129 |             logging.info('data_size: %d', batch_generator.tensor.size)
130 |             logging.info('num_unrollings: %d', self.num_unrollings)
131 |             logging.info('batch_size: %d', self.batch_size)
132 | 
133 |         if is_training:
134 |             extra_op = self.train_op
135 |         else:
136 |             extra_op = tf.no_op()
137 | 
138 |         state = self.zero_state.eval()
139 |         self.reset_loss_monitor.run()
140 |         batch_generator.reset_batch_pointer()
141 |         start_time = time.time()
142 |         for step in range(epoch_size):
143 |             x, y = batch_generator.next_batch()
144 | 
145 |             ops = [self.average_loss, self.ppl, self.final_state, extra_op,
146 |                    self.summaries, self.global_step]
147 | 
148 |             feed_dict = {self.input_data: x, self.targets: y, self.initial_state: state,
149 |                          self.learning_rate: learning_rate}
150 | 
151 |             results = session.run(ops, feed_dict)
152 |             average_loss, ppl, final_state, _, summary_str, global_step = results
153 | 
154 |             if (verbose > 0) and ((step + 1) % freq == 0):
155 |                 logging.info('%.1f%%, step: %d, perplexity: %.3f, speed: %.0f chars per sec',
156 |                              (step + 1) * 1.0 / epoch_size * 100, step, ppl,
157 |                              (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time))
158 |         logging.info("Perplexity: %.3f, speed: %.0f chars per sec",
159 |                      ppl, (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time))
160 | 
161 |         return ppl, summary_str, global_step
162 | 
163 |     def sample_seq(self, session, length, start_text, vocab_loader, max_prob=True):
164 |         state = self.zero_state.eval()
165 | 
166 |         # use start_text to warm up the RNN.
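        # Feed every character of start_text except the last one through the RNN just
        # to advance the hidden state; the final character becomes the first input of
        # the generation loop below.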
167 | if start_text is not None and len(start_text) > 0: 168 | seq = list(start_text) 169 | for char in start_text[:-1]: 170 | x = np.array([[vocab_loader.vocab_index_dict[char]]]) 171 | state = session.run(self.final_state, {self.input_data: x, self.initial_state: state}) 172 | x = np.array([[vocab_loader.vocab_index_dict[start_text[-1]]]]) 173 | else: 174 | x = np.array([[np.random.randint(0, vocab_loader.vocab_size)]]) 175 | seq = [] 176 | 177 | for i in range(length): 178 | state, logits = session.run([self.final_state, self.logits], 179 | {self.input_data: x, self.initial_state: state}) 180 | unnormalized_probs = np.exp(logits[0] - np.max(logits[0])) 181 | probs = unnormalized_probs / np.sum(unnormalized_probs) 182 | 183 | if max_prob: 184 | sample = np.argmax(probs) 185 | else: 186 | sample = np.random.choice(vocab_loader.vocab_size, 1, p=probs)[0] 187 | 188 | seq.append(vocab_loader.index_vocab_dict[sample]) 189 | x = np.array([[sample]]) 190 | 191 | return ''.join(seq) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import json 3 | import logging 4 | import os 5 | import shutil 6 | import sys 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from model import CharRNNLM 11 | from config import config_train 12 | from utils import VocabularyLoader, BatchGenerator, batche2string 13 | 14 | TF_VERSION = int(tf.__version__.split('.')[1]) 15 | 16 | def main(): 17 | args = config_train() 18 | 19 | # Specifying location to store model, best model and tensorboard log. 20 | args.save_model = os.path.join(args.output_dir, 'save_model/model') 21 | args.save_best_model = os.path.join(args.output_dir, 'best_model/model') 22 | args.tb_log_dir = os.path.join(args.output_dir, 'tensorboard_log/') 23 | args.vocab_file = '' 24 | 25 | # Create necessary directories. 
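    # When --init_dir is given, training resumes inside that directory; otherwise any
    # existing output_dir is deleted and fresh subdirectories are created for model
    # checkpoints, the best model, and the TensorBoard logs.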
26 | if len(args.init_dir) != 0: 27 | args.output_dir = args.init_dir 28 | else: 29 | if os.path.exists(args.output_dir): 30 | shutil.rmtree(args.output_dir) 31 | for paths in [args.save_model, args.save_best_model, args.tb_log_dir]: 32 | os.makedirs(os.path.dirname(paths)) 33 | 34 | logging.basicConfig(stream=sys.stdout, 35 | format='%(asctime)s %(levelname)s:%(message)s', 36 | level=logging.INFO, datefmt='%I:%M:%S') 37 | 38 | print('=' * 60) 39 | print('All final and intermediate outputs will be stored in %s/' % args.output_dir) 40 | print('=' * 60 + '\n') 41 | 42 | if args.debug: 43 | logging.info('args are:\n%s', args) 44 | 45 | if len(args.init_dir) != 0: 46 | with open(os.path.join(args.init_dir, 'result.json'), 'r') as f: 47 | result = json.load(f) 48 | params = result['params'] 49 | args.init_model = result['latest_model'] 50 | best_model = result['best_model'] 51 | best_valid_ppl = result['best_valid_ppl'] 52 | if 'encoding' in result: 53 | args.encoding = result['encoding'] 54 | else: 55 | args.encoding = 'utf-8' 56 | args.vocab_file = os.path.join(args.init_dir, 'vocab.json') 57 | else: 58 | params = {'batch_size': args.batch_size, 59 | 'num_unrollings': args.num_unrollings, 60 | 'hidden_size': args.hidden_size, 61 | 'max_grad_norm': args.max_grad_norm, 62 | 'embedding_size': args.embedding_size, 63 | 'num_layers': args.num_layers, 64 | 'learning_rate': args.learning_rate, 65 | 'model': args.model, 66 | 'dropout': args.dropout, 67 | 'input_dropout': args.input_dropout} 68 | best_model = '' 69 | logging.info('Parameters are:\n%s\n', json.dumps(params, sort_keys=True, indent=4)) 70 | 71 | # Read and split data. 72 | logging.info('Reading data from: %s', args.data_file) 73 | with codecs.open(args.data_file, 'r', encoding=args.encoding) as f: 74 | text = f.read() 75 | 76 | if args.test: 77 | text = text[:50000] 78 | logging.info('Number of characters: %s', len(text)) 79 | 80 | if args.debug: 81 | logging.info('First %d characters: %s', 10, text[:10]) 82 | 83 | logging.info('Creating train, valid, test split') 84 | train_size = int(args.train_frac * len(text)) 85 | valid_size = int(args.valid_frac * len(text)) 86 | test_size = len(text) - train_size - valid_size 87 | train_text = text[:train_size] 88 | valid_text = text[train_size:train_size + valid_size] 89 | test_text = text[train_size + valid_size:] 90 | 91 | vocab_loader = VocabularyLoader() 92 | if len(args.vocab_file) != 0: 93 | vocab_loader.load_vocab(args.vocab_file, args.encoding) 94 | else: 95 | logging.info('Creating vocabulary') 96 | vocab_loader.create_vocab(text) 97 | vocab_file = os.path.join(args.output_dir, 'vocab.json') 98 | vocab_loader.save_vocab(vocab_file, args.encoding) 99 | logging.info('Vocabulary is saved in %s', vocab_file) 100 | args.vocab_file = vocab_file 101 | 102 | params['vocab_size'] = vocab_loader.vocab_size 103 | logging.info('Vocab size: %d', vocab_loader.vocab_size) 104 | 105 | # Create batch generators. 
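    # One generator per split; each next_batch() call returns a pair of
    # (batch_size, num_unrollings) index matrices where the targets are the
    # inputs shifted by one character.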
106 | batch_size = params['batch_size'] 107 | num_unrollings = params['num_unrollings'] 108 | 109 | train_batches = BatchGenerator(vocab_loader.vocab_index_dict, train_text, batch_size, num_unrollings) 110 | valid_batches = BatchGenerator(vocab_loader.vocab_index_dict, valid_text, batch_size, num_unrollings) 111 | test_batches = BatchGenerator(vocab_loader.vocab_index_dict, test_text, batch_size, num_unrollings) 112 | 113 | if args.debug: 114 | logging.info('Test batch generators') 115 | x, y = train_batches.next_batch() 116 | logging.info((str(x[0]), str(batche2string(x[0], vocab_loader.index_vocab_dict)))) 117 | logging.info((str(y[0]), str(batche2string(y[0], vocab_loader.index_vocab_dict)))) 118 | 119 | # Create graphs 120 | logging.info('Creating graph') 121 | graph = tf.Graph() 122 | with graph.as_default(): 123 | with tf.name_scope('training'): 124 | train_model = CharRNNLM(is_training=True, infer=False, **params) 125 | tf.get_variable_scope().reuse_variables() 126 | with tf.name_scope('validation'): 127 | valid_model = CharRNNLM(is_training=False, infer=False, **params) 128 | with tf.name_scope('evaluation'): 129 | test_model = CharRNNLM(is_training=False, infer=False, **params) 130 | saver = tf.train.Saver(name='model_saver') 131 | best_model_saver = tf.train.Saver(name='best_model_saver') 132 | 133 | logging.info('Start training\n') 134 | 135 | result = {} 136 | result['params'] = params 137 | result['vocab_file'] = args.vocab_file 138 | result['encoding'] = args.encoding 139 | 140 | try: 141 | with tf.Session(graph=graph) as session: 142 | # Version 8 changed the api of summary writer to use 143 | # graph instead of graph_def. 144 | if TF_VERSION >= 8: 145 | graph_info = session.graph 146 | else: 147 | graph_info = session.graph_def 148 | 149 | train_writer = tf.train.SummaryWriter(args.tb_log_dir + 'train/', graph_info) 150 | valid_writer = tf.train.SummaryWriter(args.tb_log_dir + 'valid/', graph_info) 151 | 152 | # load a saved model or start from random initialization. 
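            # args.init_model is non-empty only when resuming via --init_dir, in which
            # case it points at the latest checkpoint recorded in result.json.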
153 | if len(args.init_model) != 0: 154 | saver.restore(session, args.init_model) 155 | else: 156 | tf.initialize_all_variables().run() 157 | 158 | learning_rate = args.learning_rate 159 | for epoch in range(args.num_epochs): 160 | logging.info('=' * 19 + ' Epoch %d ' + '=' * 19 + '\n', epoch) 161 | logging.info('Training on training set') 162 | # training step 163 | ppl, train_summary_str, global_step = train_model.run_epoch(session, train_batches, is_training=True, 164 | learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq) 165 | # record the summary 166 | train_writer.add_summary(train_summary_str, global_step) 167 | train_writer.flush() 168 | # save model 169 | saved_path = saver.save(session, args.save_model, 170 | global_step=train_model.global_step) 171 | 172 | logging.info('Latest model saved in %s\n', saved_path) 173 | logging.info('Evaluate on validation set') 174 | 175 | valid_ppl, valid_summary_str, _ = valid_model.run_epoch(session, valid_batches, is_training=False, 176 | learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq) 177 | 178 | # save and update best model 179 | if (len(best_model) == 0) or (valid_ppl < best_valid_ppl): 180 | best_model = best_model_saver.save(session, args.save_best_model, 181 | global_step=train_model.global_step) 182 | best_valid_ppl = valid_ppl 183 | else: 184 | learning_rate /= 2.0 185 | logging.info('Decay the learning rate: ' + str(learning_rate)) 186 | 187 | valid_writer.add_summary(valid_summary_str, global_step) 188 | valid_writer.flush() 189 | 190 | logging.info('Best model is saved in %s', best_model) 191 | logging.info('Best validation ppl is %f\n', best_valid_ppl) 192 | 193 | result['latest_model'] = saved_path 194 | result['best_model'] = best_model 195 | # Convert to float because numpy.float is not json serializable. 196 | result['best_valid_ppl'] = float(best_valid_ppl) 197 | 198 | result_path = os.path.join(args.output_dir, 'result.json') 199 | if os.path.exists(result_path): 200 | os.remove(result_path) 201 | with open(result_path, 'w') as f: 202 | json.dump(result, f, indent=2, sort_keys=True) 203 | 204 | logging.info('Latest model is saved in %s', saved_path) 205 | logging.info('Best model is saved in %s', best_model) 206 | logging.info('Best validation ppl is %f\n', best_valid_ppl) 207 | 208 | logging.info('Evaluate the best model on test set') 209 | saver.restore(session, best_model) 210 | test_ppl, _, _ = test_model.run_epoch(session, test_batches, is_training=False, 211 | learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq) 212 | result['test_ppl'] = float(test_ppl) 213 | finally: 214 | result_path = os.path.join(args.output_dir, 'result.json') 215 | if os.path.exists(result_path): 216 | os.remove(result_path) 217 | with open(result_path, 'w') as f: 218 | json.dump(result, f, indent=2, sort_keys=True) 219 | 220 | 221 | if __name__ == '__main__': 222 | main() --------------------------------------------------------------------------------