├── .gitignore ├── .travis.yml ├── LICENSE.md ├── README.md ├── data └── tinyshakespeare │ └── input.txt ├── logs └── .gitignore ├── model.py ├── sample.py ├── save └── .gitignore ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | .venv/ 84 | venv/ 85 | ENV/ 86 | 87 | # Spyder project settings 88 | .spyderproject 89 | 90 | # Rope project settings 91 | .ropeproject 92 | 93 | # char-rnn-tensorflow data files 94 | data.npy 95 | vocab.pkl 96 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: trusty 2 | 3 | language: python 4 | 5 | python: 6 | - 2.7 7 | - 3.6 8 | - 3.5 9 | - 3.4 10 | 11 | # Use container-based infrastructure 12 | sudo: false 13 | 14 | install: 15 | - pip install -U pip 16 | - pip install pyflakes 17 | - pip install coverage 18 | - pip install tensorflow 19 | 20 | # Make a smaller input file. Output won't look great, but runs much quicker. 21 | - mkdir data/teenytinyshakespeare 22 | - head -100 data/tinyshakespeare/input.txt > data/teenytinyshakespeare/input.txt 23 | 24 | script: 25 | - pyflakes . 26 | - coverage erase 27 | - export NUM_EPOCHS=15 28 | - export SAVE_DIR_FIRST=$(mktemp -d) 29 | - export SAVE_DIR_SECOND=$(mktemp -d) 30 | - export LOG_DIR=$(mktemp -d) 31 | - export SAMPLE_FILE=$(mktemp) 32 | - coverage run --append --include=./* train.py --data_dir data/teenytinyshakespeare --save_dir $SAVE_DIR_FIRST --log_dir $LOG_DIR --num_epochs $NUM_EPOCHS;test -s $SAVE_DIR_FIRST/model.ckpt-$(( $NUM_EPOCHS - 1)).index 33 | - coverage run --append --include=./* train.py --data_dir data/teenytinyshakespeare --init_from $SAVE_DIR_FIRST --save_dir $SAVE_DIR_SECOND --log_dir $LOG_DIR --num_epochs $NUM_EPOCHS;test -s $SAVE_DIR_SECOND/model.ckpt-$(( $NUM_EPOCHS - 1)).index 34 | - coverage run --append --include=./* sample.py --save_dir $SAVE_DIR_SECOND | tee $SAMPLE_FILE;test -s $SAMPLE_FILE 35 | 36 | after_script: 37 | - pip install pycodestyle 38 | - pycodestyle --statistics --count . 
39 | - pip freeze
40 | 
41 | after_success:
42 | - coverage report
43 | - pip install coveralls
44 | - coveralls
45 | 
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2015 Sherjil Ozair
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | char-rnn-tensorflow
2 | ===
3 | 
4 | [![Join the chat at https://gitter.im/char-rnn-tensorflow/Lobby](https://badges.gitter.im/char-rnn-tensorflow/Lobby.svg)](https://gitter.im/char-rnn-tensorflow/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
5 | [![Coverage Status](https://coveralls.io/repos/github/sherjilozair/char-rnn-tensorflow/badge.svg)](https://coveralls.io/github/sherjilozair/char-rnn-tensorflow)
6 | [![Build Status](https://travis-ci.org/sherjilozair/char-rnn-tensorflow.svg?branch=master)](https://travis-ci.org/sherjilozair/char-rnn-tensorflow)
7 | 
8 | Multi-layer Recurrent Neural Networks (LSTM, RNN) for character-level language models in Python using TensorFlow.
9 | 
10 | Inspired by Andrej Karpathy's [char-rnn](https://github.com/karpathy/char-rnn).
11 | 
12 | ## Requirements
13 | - [TensorFlow 1.0](http://www.tensorflow.org)
14 | 
15 | ## Basic Usage
16 | To train with default parameters on the tinyshakespeare corpus, run `python train.py`. To see all available parameters, use `python train.py --help`.
17 | 
18 | To sample from a checkpointed model, run `python sample.py`.
19 | Sampling while training is still in progress (to check the latest checkpoint) works only on the CPU or on a separate GPU.
20 | To force CPU mode, use `export CUDA_VISIBLE_DEVICES=""` and `unset CUDA_VISIBLE_DEVICES` afterward
21 | (resp. `set CUDA_VISIBLE_DEVICES=""` and `set CUDA_VISIBLE_DEVICES=` on Windows).
22 | 
23 | To continue training after an interruption, or to run for more epochs, use `python train.py --init_from=save`.
24 | 
25 | ## Datasets
26 | You can use any plain text file as input.
For example, you could download [The complete Sherlock Holmes](https://sherlock-holm.es/ascii/) like this:
27 | 
28 | ```bash
29 | cd data
30 | mkdir sherlock
31 | cd sherlock
32 | wget https://sherlock-holm.es/stories/plain-text/cnus.txt
33 | mv cnus.txt input.txt
34 | ```
35 | 
36 | Then start training from the top-level directory using `python train.py --data_dir=./data/sherlock/`.
37 | 
38 | A quick tip for concatenating many small, disparate `.txt` files into one large training file: `ls *.txt | xargs -L 1 cat >> input.txt`.
39 | 
40 | ## Tuning
41 | 
42 | Tuning your models is still something of a "dark art" at this point. In general:
43 | 
44 | 1. Start with as much clean input.txt as possible, e.g. 50MiB.
45 | 2. Establish a baseline using the default settings.
46 | 3. Use TensorBoard to compare all of your runs visually, to aid experimentation.
47 | 4. Tweak --rnn_size up somewhat from 128 if you have a lot of input data.
48 | 5. Tweak --num_layers from 2 to 3, but no higher unless you have experience.
49 | 6. Tweak --seq_length up from 50 based on the length of a valid input string
50 |    (e.g. names are <= 12 characters, sentences may be up to 64 characters, etc.).
51 |    An LSTM cell can "remember" across spans longer than this sequence, but the effect falls off over longer character distances.
52 | 7. Finally, once you have done all of the above, only then add some dropout.
53 |    Start with --output_keep_prob 0.8, and move to both --input_keep_prob 0.8 --output_keep_prob 0.5 only after exhausting the values above. (An illustrative command combining these flags appears at the end of this README.)
54 | 
55 | ## Tensorboard
56 | To visualize training progress, model graphs, and internal state histograms, fire up TensorBoard and point it at your `log_dir`, e.g.:
57 | ```bash
58 | $ tensorboard --logdir=./logs/
59 | ```
60 | 
61 | Then open a browser to [http://localhost:6006](http://localhost:6006), or to whichever host/port you specified.
62 | 
63 | 
64 | ## Roadmap
65 | - [ ] Add explanatory comments
66 | - [ ] Expose more command-line arguments
67 | - [ ] Compare accuracy and performance with char-rnn
68 | - [ ] More Tensorboard instrumentation
69 | 
70 | ## Contributing
71 | Please feel free to:
72 | * Leave feedback in the issues
73 | * Open a Pull Request
74 | * Join the [Gitter chat](https://gitter.im/char-rnn-tensorflow/Lobby)
75 | * Share your success stories and data sets!
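## Example: a tuned run

A minimal, illustrative sketch of the tuning advice above. The flag values here are assumptions, not recommendations; good settings depend on your corpus size. All flags shown are defined in `train.py` and `sample.py` (see `python train.py --help`).

```bash
# assumes you created data/sherlock/input.txt as described in the Datasets section
python train.py --data_dir=./data/sherlock/ \
                --rnn_size 256 --num_layers 2 --seq_length 64 \
                --num_epochs 25 --output_keep_prob 0.8
python sample.py --save_dir save -n 1000 --prime "The "
```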
76 | 
--------------------------------------------------------------------------------
/logs/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import rnn
3 | from tensorflow.contrib import legacy_seq2seq
4 | 
5 | import numpy as np
6 | 
7 | 
8 | class Model():
9 |     def __init__(self, args, training=True):
10 |         self.args = args
11 |         if not training:
12 |             args.batch_size = 1
13 |             args.seq_length = 1
14 | 
15 |         # choose the rnn cell type
16 |         if args.model == 'rnn':
17 |             cell_fn = rnn.RNNCell
18 |         elif args.model == 'gru':
19 |             cell_fn = rnn.GRUCell
20 |         elif args.model == 'lstm':
21 |             cell_fn = rnn.LSTMCell
22 |         elif args.model == 'nas':
23 |             cell_fn = rnn.NASCell
24 |         else:
25 |             raise Exception("model type not supported: {}".format(args.model))
26 | 
27 |         # wrap the multi-layered rnn cells into one cell, with dropout
28 |         cells = []
29 |         for _ in range(args.num_layers):
30 |             cell = cell_fn(args.rnn_size)
31 |             if training and (args.output_keep_prob < 1.0 or args.input_keep_prob < 1.0):
32 |                 cell = rnn.DropoutWrapper(cell,
33 |                                           input_keep_prob=args.input_keep_prob,
34 |                                           output_keep_prob=args.output_keep_prob)
35 |             cells.append(cell)
36 |         self.cell = cell = rnn.MultiRNNCell(cells, state_is_tuple=True)
37 | 
38 |         # input/target data (int32 since input is char-level)
39 |         self.input_data = tf.placeholder(
40 |             tf.int32, [args.batch_size, args.seq_length])
41 |         self.targets = tf.placeholder(
42 |             tf.int32, [args.batch_size, args.seq_length])
43 |         self.initial_state = cell.zero_state(args.batch_size, tf.float32)
44 | 
45 |         # softmax output layer, used to classify the next character
46 |         with tf.variable_scope('rnnlm'):
47 |             softmax_w = tf.get_variable("softmax_w",
48 |                                         [args.rnn_size, args.vocab_size])
49 |             softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
50 | 
51 |         # transform input characters into embeddings
52 |         embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
53 |         inputs = tf.nn.embedding_lookup(embedding, self.input_data)
54 | 
55 |         # also apply dropout to the embedded inputs (note: this reuses output_keep_prob; the per-cell DropoutWrapper above already handles input_keep_prob)
56 |         if training and args.output_keep_prob:
57 |             inputs = tf.nn.dropout(inputs, args.output_keep_prob)
58 | 
59 |         # split the input into a list of seq_length steps, as rnn_decoder expects
60 |         inputs = tf.split(inputs, args.seq_length, 1)
61 |         inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
62 | 
63 |         # loop function for rnn_decoder: takes the i-th cell's output and generates the (i+1)-th cell's input
64 |         def loop(prev, _):
65 |             prev = tf.matmul(prev, softmax_w) + softmax_b
66 |             prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
67 |             return tf.nn.embedding_lookup(embedding, prev_symbol)
68 | 
69 |         # rnn_decoder generates the outputs and final state. When not training, use the loop function to feed each prediction back in as the next input.
70 |         outputs, last_state = legacy_seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if not training else None, scope='rnnlm')
71 |         output = tf.reshape(tf.concat(outputs, 1), [-1, args.rnn_size])
72 | 
73 |         # output layer
74 |         self.logits = tf.matmul(output, softmax_w) + softmax_b
75 |         self.probs = tf.nn.softmax(self.logits)
76 | 
77 |         # the loss is the per-character log loss, averaged over batch and sequence below
78 | loss = legacy_seq2seq.sequence_loss_by_example( 79 | [self.logits], 80 | [tf.reshape(self.targets, [-1])], 81 | [tf.ones([args.batch_size * args.seq_length])]) 82 | with tf.name_scope('cost'): 83 | self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length 84 | self.final_state = last_state 85 | self.lr = tf.Variable(0.0, trainable=False) 86 | tvars = tf.trainable_variables() 87 | 88 | # calculate gradients 89 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), 90 | args.grad_clip) 91 | with tf.name_scope('optimizer'): 92 | optimizer = tf.train.AdamOptimizer(self.lr) 93 | 94 | # apply gradient change to the all the trainable variable. 95 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 96 | 97 | # instrument tensorboard 98 | tf.summary.histogram('logits', self.logits) 99 | tf.summary.histogram('loss', loss) 100 | tf.summary.scalar('train_loss', self.cost) 101 | 102 | def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1): 103 | state = sess.run(self.cell.zero_state(1, tf.float32)) 104 | for char in prime[:-1]: 105 | x = np.zeros((1, 1)) 106 | x[0, 0] = vocab[char] 107 | feed = {self.input_data: x, self.initial_state: state} 108 | [state] = sess.run([self.final_state], feed) 109 | 110 | def weighted_pick(weights): 111 | t = np.cumsum(weights) 112 | s = np.sum(weights) 113 | return(int(np.searchsorted(t, np.random.rand(1)*s))) 114 | 115 | ret = prime 116 | char = prime[-1] 117 | for _ in range(num): 118 | x = np.zeros((1, 1)) 119 | x[0, 0] = vocab[char] 120 | feed = {self.input_data: x, self.initial_state: state} 121 | [probs, state] = sess.run([self.probs, self.final_state], feed) 122 | p = probs[0] 123 | 124 | if sampling_type == 0: 125 | sample = np.argmax(p) 126 | elif sampling_type == 2: 127 | if char == ' ': 128 | sample = weighted_pick(p) 129 | else: 130 | sample = np.argmax(p) 131 | else: # sampling_type == 1 default: 132 | sample = weighted_pick(p) 133 | 134 | pred = chars[sample] 135 | ret += pred 136 | char = pred 137 | return ret 138 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import os 7 | from six.moves import cPickle 8 | 9 | 10 | from six import text_type 11 | 12 | 13 | parser = argparse.ArgumentParser( 14 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 15 | parser.add_argument('--save_dir', type=str, default='save', 16 | help='model directory to store checkpointed models') 17 | parser.add_argument('-n', type=int, default=500, 18 | help='number of characters to sample') 19 | parser.add_argument('--prime', type=text_type, default=u'', 20 | help='prime text') 21 | parser.add_argument('--sample', type=int, default=1, 22 | help='0 to use max at each timestep, 1 to sample at ' 23 | 'each timestep, 2 to sample on spaces') 24 | 25 | args = parser.parse_args() 26 | 27 | import tensorflow as tf 28 | from model import Model 29 | 30 | def sample(args): 31 | with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f: 32 | saved_args = cPickle.load(f) 33 | with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f: 34 | chars, vocab = cPickle.load(f) 35 | #Use most frequent char if no prime is given 36 | if args.prime == '': 37 | args.prime = chars[0] 38 | model = Model(saved_args, training=False) 39 | with tf.Session() as sess: 40 | tf.global_variables_initializer().run() 
41 | saver = tf.train.Saver(tf.global_variables()) 42 | ckpt = tf.train.get_checkpoint_state(args.save_dir) 43 | if ckpt and ckpt.model_checkpoint_path: 44 | saver.restore(sess, ckpt.model_checkpoint_path) 45 | data = model.sample(sess, chars, vocab, args.n, args.prime, 46 | args.sample).encode('utf-8') 47 | print(data.decode("utf-8")) 48 | 49 | if __name__ == '__main__': 50 | sample(args) 51 | -------------------------------------------------------------------------------- /save/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import time 7 | import os 8 | from six.moves import cPickle 9 | 10 | 11 | parser = argparse.ArgumentParser( 12 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 13 | # Data and model checkpoints directories 14 | parser.add_argument('--data_dir', type=str, default='data/tinyshakespeare', 15 | help='data directory containing input.txt with training examples') 16 | parser.add_argument('--save_dir', type=str, default='save', 17 | help='directory to store checkpointed models') 18 | parser.add_argument('--log_dir', type=str, default='logs', 19 | help='directory to store tensorboard logs') 20 | parser.add_argument('--save_every', type=int, default=1000, 21 | help='Save frequency. Number of passes between checkpoints of the model.') 22 | parser.add_argument('--init_from', type=str, default=None, 23 | help="""continue training from saved model at this path (usually "save"). 24 | Path must contain files saved by previous training process: 25 | 'config.pkl' : configuration; 26 | 'chars_vocab.pkl' : vocabulary definitions; 27 | 'checkpoint' : paths to model file(s) (created by tf). 28 | Note: this file contains absolute paths, be careful when moving files around; 29 | 'model.ckpt-*' : file(s) with model definition (created by tf) 30 | Model params must be the same between multiple runs (model, rnn_size, num_layers and seq_length). 31 | """) 32 | # Model params 33 | parser.add_argument('--model', type=str, default='lstm', 34 | help='lstm, rnn, gru, or nas') 35 | parser.add_argument('--rnn_size', type=int, default=128, 36 | help='size of RNN hidden state') 37 | parser.add_argument('--num_layers', type=int, default=2, 38 | help='number of layers in the RNN') 39 | # Optimization 40 | parser.add_argument('--seq_length', type=int, default=50, 41 | help='RNN sequence length. Number of timesteps to unroll for.') 42 | parser.add_argument('--batch_size', type=int, default=50, 43 | help="""minibatch size. Number of sequences propagated through the network in parallel. 44 | Pick batch-sizes to fully leverage the GPU (e.g. until the memory is filled up) 45 | commonly in the range 10-500.""") 46 | parser.add_argument('--num_epochs', type=int, default=50, 47 | help='number of epochs. 
Number of full passes through the training examples.')
48 | parser.add_argument('--grad_clip', type=float, default=5.,
49 |                     help='clip gradients at this value')
50 | parser.add_argument('--learning_rate', type=float, default=0.002,
51 |                     help='learning rate')
52 | parser.add_argument('--decay_rate', type=float, default=0.97,
53 |                     help='learning rate decay, applied once per epoch')
54 | parser.add_argument('--output_keep_prob', type=float, default=1.0,
55 |                     help='probability of keeping activations in the hidden layer (output dropout)')
56 | parser.add_argument('--input_keep_prob', type=float, default=1.0,
57 |                     help='probability of keeping activations in the input layer (input dropout)')
58 | args = parser.parse_args()
59 | 
60 | import tensorflow as tf
61 | from utils import TextLoader
62 | from model import Model
63 | 
64 | def train(args):
65 |     data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
66 |     args.vocab_size = data_loader.vocab_size
67 | 
68 |     # check compatibility if training is continued from a previously saved model
69 |     if args.init_from is not None:
70 |         # check if all necessary files exist
71 |         assert os.path.isdir(args.init_from), "%s must be a path" % args.init_from
72 |         assert os.path.isfile(os.path.join(args.init_from, "config.pkl")), "config.pkl file does not exist in path %s" % args.init_from
73 |         assert os.path.isfile(os.path.join(args.init_from, "chars_vocab.pkl")), "chars_vocab.pkl file does not exist in path %s" % args.init_from
74 |         ckpt = tf.train.latest_checkpoint(args.init_from)
75 |         assert ckpt, "No checkpoint found"
76 | 
77 |         # open old config and check if models are compatible
78 |         with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
79 |             saved_model_args = cPickle.load(f)
80 |         need_be_same = ["model", "rnn_size", "num_layers", "seq_length"]
81 |         for checkme in need_be_same:
82 |             assert vars(saved_model_args)[checkme] == vars(args)[checkme], "Command line argument and saved model disagree on '%s'" % checkme
83 | 
84 |         # open saved vocab/dict and check if vocabs/dicts are compatible
85 |         with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
86 |             saved_chars, saved_vocab = cPickle.load(f)
87 |         assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
88 |         assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"
89 | 90 | if not os.path.isdir(args.save_dir): 91 | os.makedirs(args.save_dir) 92 | with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f: 93 | cPickle.dump(args, f) 94 | with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f: 95 | cPickle.dump((data_loader.chars, data_loader.vocab), f) 96 | 97 | model = Model(args) 98 | 99 | with tf.Session() as sess: 100 | # instrument for tensorboard 101 | summaries = tf.summary.merge_all() 102 | writer = tf.summary.FileWriter( 103 | os.path.join(args.log_dir, time.strftime("%Y-%m-%d-%H-%M-%S"))) 104 | writer.add_graph(sess.graph) 105 | 106 | sess.run(tf.global_variables_initializer()) 107 | saver = tf.train.Saver(tf.global_variables()) 108 | # restore model 109 | if args.init_from is not None: 110 | saver.restore(sess, ckpt) 111 | for e in range(args.num_epochs): 112 | sess.run(tf.assign(model.lr, 113 | args.learning_rate * (args.decay_rate ** e))) 114 | data_loader.reset_batch_pointer() 115 | state = sess.run(model.initial_state) 116 | for b in range(data_loader.num_batches): 117 | start = time.time() 118 | x, y = data_loader.next_batch() 119 | feed = {model.input_data: x, model.targets: y} 120 | for i, (c, h) in enumerate(model.initial_state): 121 | feed[c] = state[i].c 122 | feed[h] = state[i].h 123 | 124 | # instrument for tensorboard 125 | summ, train_loss, state, _ = sess.run([summaries, model.cost, model.final_state, model.train_op], feed) 126 | writer.add_summary(summ, e * data_loader.num_batches + b) 127 | 128 | end = time.time() 129 | print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" 130 | .format(e * data_loader.num_batches + b, 131 | args.num_epochs * data_loader.num_batches, 132 | e, train_loss, end - start)) 133 | if (e * data_loader.num_batches + b) % args.save_every == 0\ 134 | or (e == args.num_epochs-1 and 135 | b == data_loader.num_batches-1): 136 | # save for the last result 137 | checkpoint_path = os.path.join(args.save_dir, 'model.ckpt') 138 | saver.save(sess, checkpoint_path, 139 | global_step=e * data_loader.num_batches + b) 140 | print("model saved to {}".format(checkpoint_path)) 141 | 142 | 143 | if __name__ == '__main__': 144 | train(args) 145 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | import collections 4 | from six.moves import cPickle 5 | import numpy as np 6 | 7 | 8 | class TextLoader(): 9 | def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'): 10 | self.data_dir = data_dir 11 | self.batch_size = batch_size 12 | self.seq_length = seq_length 13 | self.encoding = encoding 14 | 15 | input_file = os.path.join(data_dir, "input.txt") 16 | vocab_file = os.path.join(data_dir, "vocab.pkl") 17 | tensor_file = os.path.join(data_dir, "data.npy") 18 | 19 | if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)): 20 | print("reading text file") 21 | self.preprocess(input_file, vocab_file, tensor_file) 22 | else: 23 | print("loading preprocessed files") 24 | self.load_preprocessed(vocab_file, tensor_file) 25 | self.create_batches() 26 | self.reset_batch_pointer() 27 | 28 | # preprocess data for the first time. 
29 |     def preprocess(self, input_file, vocab_file, tensor_file):
30 |         with codecs.open(input_file, "r", encoding=self.encoding) as f:
31 |             data = f.read()
32 |         counter = collections.Counter(data)
33 |         count_pairs = sorted(counter.items(), key=lambda x: -x[1])
34 |         self.chars, _ = zip(*count_pairs)
35 |         self.vocab_size = len(self.chars)
36 |         self.vocab = dict(zip(self.chars, range(len(self.chars))))
37 |         with open(vocab_file, 'wb') as f:
38 |             cPickle.dump(self.chars, f)
39 |         self.tensor = np.array(list(map(self.vocab.get, data)))
40 |         np.save(tensor_file, self.tensor)
41 | 
42 | 
43 |     # load the preprocessed data if it has been processed before.
44 |     def load_preprocessed(self, vocab_file, tensor_file):
45 |         with open(vocab_file, 'rb') as f:
46 |             self.chars = cPickle.load(f)
47 |         self.vocab_size = len(self.chars)
48 |         self.vocab = dict(zip(self.chars, range(len(self.chars))))
49 |         self.tensor = np.load(tensor_file)
50 |         self.num_batches = int(self.tensor.size / (self.batch_size *
51 |                                                     self.seq_length))
52 |     # separate the whole data into batches.
53 |     def create_batches(self):
54 |         self.num_batches = int(self.tensor.size / (self.batch_size *
55 |                                                     self.seq_length))
56 | 
57 |         # When the data (tensor) is too small,
58 |         # give a better error message
59 |         if self.num_batches == 0:
60 |             assert False, "Not enough data. Make seq_length and batch_size smaller."
61 | 
62 |         # truncate the data so its length is exactly self.num_batches * self.batch_size * self.seq_length
63 |         self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]
64 |         xdata = self.tensor
65 |         ydata = np.copy(self.tensor)
66 | 
67 |         # ydata is xdata shifted by one position (the next-character targets)
68 |         ydata[:-1] = xdata[1:]
69 |         ydata[-1] = xdata[0]
70 |         self.x_batches = np.split(xdata.reshape(self.batch_size, -1),
71 |                                   self.num_batches, 1)
72 |         self.y_batches = np.split(ydata.reshape(self.batch_size, -1),
73 |                                   self.num_batches, 1)
74 | 
75 |     def next_batch(self):
76 |         x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
77 |         self.pointer += 1
78 |         return x, y
79 | 
80 |     def reset_batch_pointer(self):
81 |         self.pointer = 0
82 | 
--------------------------------------------------------------------------------
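A minimal sketch (not a file in the repo) of what `TextLoader.create_batches()` in `utils.py` does: the targets are the inputs shifted left by one character (wrapping at the end), and both are cut into `(batch_size, seq_length)` blocks. The toy sizes below are assumptions chosen only to keep the arrays readable.

```python
import numpy as np

# pretend these 12 integers are character indices produced by preprocess()
tensor = np.arange(12)
batch_size, seq_length = 2, 3
num_batches = tensor.size // (batch_size * seq_length)   # -> 2

# truncate, then build targets as inputs shifted left by one, wrapping at the end
tensor = tensor[:num_batches * batch_size * seq_length]
xdata = tensor
ydata = np.copy(tensor)
ydata[:-1] = xdata[1:]
ydata[-1] = xdata[0]

# cut into num_batches arrays of shape (batch_size, seq_length)
x_batches = np.split(xdata.reshape(batch_size, -1), num_batches, 1)
y_batches = np.split(ydata.reshape(batch_size, -1), num_batches, 1)

print(x_batches[0])   # [[0 1 2]
                      #  [6 7 8]]
print(y_batches[0])   # [[1 2 3]
                      #  [7 8 9]]
```

`next_batch()` then walks through these lists with a pointer, handing one `(x, y)` pair to the training loop per step.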