├── .gitignore ├── README.md ├── docker ├── Dockerfile └── build_image.sh ├── model.py ├── run_docker.sh ├── sample.py ├── sample_frozen.py ├── save ├── checkpoint ├── config.pkl ├── model.ckpt-10000 ├── model.ckpt-10000.meta ├── model.ckpt-10500 ├── model.ckpt-10500.meta ├── model.ckpt-11000 ├── model.ckpt-11000.meta ├── model.ckpt-9000 ├── model.ckpt-9000.meta ├── model.ckpt-9500 └── model.ckpt-9500.meta ├── svg ├── example.svg ├── example1.color.svg ├── example1.eos_pdf.svg ├── example1.multi_color.svg ├── example1.normal.svg ├── example1.pdf.svg └── many_examples.svg ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | # SVG 92 | *.svg 93 | 94 | # training data folder 95 | data/ 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Generative Handwriting Demo using TensorFlow 3 | 4 | ![example](https://cdn.rawgit.com/hardmaru/write-rnn-tensorflow/master/svg/example.svg) 5 | 6 | ![example](https://cdn.rawgit.com/hardmaru/write-rnn-tensorflow/master/svg/many_examples.svg) 7 | 8 | An attempt to implement the random handwriting generation portion of Alex Graves' [paper](http://arxiv.org/abs/1308.0850). 9 | 10 | See my blog post at [blog.otoro.net](http://blog.otoro.net/2015/12/12/handwriting-generation-demo-in-tensorflow) for more information. 11 | 12 | ### How to use 13 | 14 | I tested the implementation on TensorFlow r0.11 and Python 3. I also used the following libraries and standard-library modules to help: 15 | 16 | ``` 17 | svgwrite 18 | IPython.display.SVG 19 | IPython.display.display 20 | xml.etree.ElementTree 21 | argparse 22 | pickle 23 | ``` 24 | 25 | ### Training 26 | 27 | You will need permission from [these wonderful people](http://www.iam.unibe.ch/fki/databases/iam-on-line-handwriting-database) to get the IAM On-Line Handwriting data. 
Unzip `lineStrokes-all.tar.gz` into the data subdirectory, so that you end up with `data/lineStrokes/a01`, `data/lineStrokes/a02`, etc. Afterwards, running `python train.py` will start the training process. 28 | 29 | A number of flags can be set for training if you wish to experiment with the parameters. The default values are in `train.py`. 30 | 31 | ``` 32 | --rnn_size RNN_SIZE size of RNN hidden state 33 | --num_layers NUM_LAYERS number of layers in the RNN 34 | --model MODEL rnn, gru, or lstm 35 | --batch_size BATCH_SIZE minibatch size 36 | --seq_length SEQ_LENGTH RNN sequence length 37 | --num_epochs NUM_EPOCHS number of epochs 38 | --save_every SAVE_EVERY save frequency 39 | --grad_clip GRAD_CLIP clip gradients at this value 40 | --learning_rate LEARNING_RATE learning rate 41 | --decay_rate DECAY_RATE decay rate for rmsprop 42 | --num_mixture NUM_MIXTURE number of gaussian mixtures 43 | --data_scale DATA_SCALE factor to scale raw data down by 44 | --keep_prob KEEP_PROB dropout keep probability 45 | ``` 46 | 47 | ### Generating a Handwriting Sample 48 | 49 | I've included a pretrained model in `/save` so it should work out of the box. Running `python sample.py --filename example_name --sample_length 1000` will generate 5 .svg files (`.normal`, `.color`, `.multi_color`, `.eos_pdf` and `.pdf` renderings of the same sample), each drawn from 1000 sampled points. 50 | 51 | ### IPython interactive session 52 | 53 | If you wish to experiment with this code interactively, just run `%run -i sample.py` in an IPython console; the following code is an example of how to generate samples and display them inside IPython. 54 | 55 | ``` 56 | [strokes, params] = model.sample(sess, 800) 57 | draw_strokes(strokes, factor=8, svg_filename = 'sample.normal.svg') 58 | draw_strokes_random_color(strokes, factor=8, svg_filename = 'sample.color.svg') 59 | draw_strokes_random_color(strokes, factor=8, per_stroke_mode = False, svg_filename = 'sample.multi_color.svg') 60 | draw_strokes_eos_weighted(strokes, params, factor=8, svg_filename = 'sample.eos.svg') 61 | draw_strokes_pdf(strokes, params, factor=8, svg_filename = 'sample.pdf.svg') 62 | 63 | ``` 64 | 65 | ![example1a](https://cdn.rawgit.com/hardmaru/write-rnn-tensorflow/master/svg/example1.normal.svg) 66 | ![example1b](https://cdn.rawgit.com/hardmaru/write-rnn-tensorflow/master/svg/example1.color.svg) 67 | ![example1c](https://cdn.rawgit.com/hardmaru/write-rnn-tensorflow/master/svg/example1.multi_color.svg) 68 | ![example1d](https://cdn.rawgit.com/hardmaru/write-rnn-tensorflow/master/svg/example1.eos_pdf.svg) 69 | ![example1e](https://cdn.rawgit.com/hardmaru/write-rnn-tensorflow/master/svg/example1.pdf.svg) 70 | 71 | Have fun! 72 | 73 | ## License 74 | 75 | MIT 76 | 77 | 78 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:1.5.0-py3 2 | 3 | RUN pip install svgwrite 4 | 5 | -------------------------------------------------------------------------------- /docker/build_image.sh: -------------------------------------------------------------------------------- 1 | docker build --tag write-rnn-tensorflow:1.5.0-py3 . 
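
For reference, a hedged sketch of generating samples inside the same container: it mirrors `run_docker.sh` further down (which mounts the working directory and runs `train.py`), swapping in `sample.py`. The image tag is the one built above; the `--filename` value is purely illustrative.

```
# assumes the image built by docker/build_image.sh and a model checkpoint in ./save
docker run \
  -it \
  -v $(pwd):/workspace \
  -w /workspace \
  write-rnn-tensorflow:1.5.0-py3 \
  python sample.py --filename docker_sample --sample_length 800
```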
2 | 3 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | class Model(): 8 | def __init__(self, args, infer=False): 9 | self.args = args 10 | if infer: 11 | args.batch_size = 1 12 | args.seq_length = 1 13 | 14 | if args.model == 'rnn': 15 | cell_fn = tf.contrib.rnn.BasicRNNCell 16 | elif args.model == 'gru': 17 | cell_fn = tf.contrib.rnn.GRUCell 18 | elif args.model == 'lstm': 19 | cell_fn = tf.contrib.rnn.BasicLSTMCell 20 | else: 21 | raise Exception("model type not supported: {}".format(args.model)) 22 | 23 | def get_cell(): 24 | return cell_fn(args.rnn_size, state_is_tuple=False) 25 | 26 | cell = tf.contrib.rnn.MultiRNNCell( 27 | [get_cell() for _ in range(args.num_layers)]) 28 | 29 | if (infer == False and args.keep_prob < 1): # training mode 30 | cell = tf.contrib.rnn.DropoutWrapper( 31 | cell, output_keep_prob=args.keep_prob) 32 | 33 | self.cell = cell 34 | 35 | self.input_data = tf.placeholder( 36 | dtype=tf.float32, shape=[ 37 | None, args.seq_length, 3], name='data_in') 38 | self.target_data = tf.placeholder( 39 | dtype=tf.float32, shape=[ 40 | None, args.seq_length, 3], name='targets') 41 | zero_state = cell.zero_state( 42 | batch_size=args.batch_size, dtype=tf.float32) 43 | self.state_in = tf.identity(zero_state, name='state_in') 44 | 45 | self.num_mixture = args.num_mixture 46 | # end_of_stroke + prob + 2*(mu + sig) + corr 47 | NOUT = 1 + self.num_mixture * 6 48 | 49 | with tf.variable_scope('rnnlm'): 50 | output_w = tf.get_variable("output_w", [args.rnn_size, NOUT]) 51 | output_b = tf.get_variable("output_b", [NOUT]) 52 | 53 | # inputs = tf.split(axis=1, num_or_size_splits=args.seq_length, value=self.input_data) 54 | # inputs = [tf.squeeze(input_, [1]) for input_ in inputs] 55 | inputs = tf.unstack(self.input_data, axis=1) 56 | 57 | # outputs, state_out = tf.contrib.legacy_seq2seq.rnn_decoder(inputs, self.state_in, cell, loop_function=None, scope='rnnlm') 58 | outputs, state_out = tf.contrib.legacy_seq2seq.rnn_decoder( 59 | inputs, zero_state, cell, loop_function=None, scope='rnnlm') 60 | 61 | output = tf.reshape( 62 | tf.concat(axis=1, values=outputs), [-1, args.rnn_size]) 63 | output = tf.nn.xw_plus_b(output, output_w, output_b) 64 | self.state_out = tf.identity(state_out, name='state_out') 65 | 66 | # reshape target data so that it is compatible with prediction shape 67 | flat_target_data = tf.reshape(self.target_data, [-1, 3]) 68 | [x1_data, x2_data, eos_data] = tf.split( 69 | axis=1, num_or_size_splits=3, value=flat_target_data) 70 | 71 | # long method: 72 | #flat_target_data = tf.split(1, args.seq_length, self.target_data) 73 | #flat_target_data = [tf.squeeze(flat_target_data_, [1]) for flat_target_data_ in flat_target_data] 74 | #flat_target_data = tf.reshape(tf.concat(1, flat_target_data), [-1, 3]) 75 | 76 | def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho): 77 | # eq # 24 and 25 of http://arxiv.org/abs/1308.0850 78 | norm1 = tf.subtract(x1, mu1) 79 | norm2 = tf.subtract(x2, mu2) 80 | s1s2 = tf.multiply(s1, s2) 81 | z = tf.square(tf.div(norm1, s1)) + tf.square(tf.div(norm2, s2)) - \ 82 | 2 * tf.div(tf.multiply(rho, tf.multiply(norm1, norm2)), s1s2) 83 | negRho = 1 - tf.square(rho) 84 | result = tf.exp(tf.div(-z, 2 * negRho)) 85 | denom = 2 * np.pi * tf.multiply(s1s2, tf.sqrt(negRho)) 86 | result = tf.div(result, denom) 87 | return result 88 | 89 | def get_lossfunc( 
90 | z_pi, 91 | z_mu1, 92 | z_mu2, 93 | z_sigma1, 94 | z_sigma2, 95 | z_corr, 96 | z_eos, 97 | x1_data, 98 | x2_data, 99 | eos_data): 100 | result0 = tf_2d_normal( 101 | x1_data, 102 | x2_data, 103 | z_mu1, 104 | z_mu2, 105 | z_sigma1, 106 | z_sigma2, 107 | z_corr) 108 | # implementing eq # 26 of http://arxiv.org/abs/1308.0850 109 | epsilon = 1e-20 110 | result1 = tf.multiply(result0, z_pi) 111 | result1 = tf.reduce_sum(result1, 1, keep_dims=True) 112 | # at the beginning, some errors are exactly zero. 113 | result1 = -tf.log(tf.maximum(result1, 1e-20)) 114 | 115 | result2 = tf.multiply(z_eos, eos_data) + \ 116 | tf.multiply(1 - z_eos, 1 - eos_data) 117 | result2 = -tf.log(result2) 118 | 119 | result = result1 + result2 120 | return tf.reduce_sum(result) 121 | 122 | # below is where we need to do MDN splitting of distribution params 123 | def get_mixture_coef(output): 124 | # returns the tf slices containing mdn dist params 125 | # ie, eq 18 -> 23 of http://arxiv.org/abs/1308.0850 126 | z = output 127 | z_eos = z[:, 0:1] 128 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split( 129 | axis=1, num_or_size_splits=6, value=z[:, 1:]) 130 | 131 | # process output z's into MDN paramters 132 | 133 | # end of stroke signal 134 | z_eos = tf.sigmoid(z_eos) # should be negated, but doesn't matter. 135 | 136 | # softmax all the pi's: 137 | max_pi = tf.reduce_max(z_pi, 1, keep_dims=True) 138 | z_pi = tf.subtract(z_pi, max_pi) 139 | z_pi = tf.exp(z_pi) 140 | normalize_pi = tf.reciprocal( 141 | tf.reduce_sum(z_pi, 1, keep_dims=True)) 142 | z_pi = tf.multiply(normalize_pi, z_pi) 143 | 144 | # exponentiate the sigmas and also make corr between -1 and 1. 145 | z_sigma1 = tf.exp(z_sigma1) 146 | z_sigma2 = tf.exp(z_sigma2) 147 | z_corr = tf.tanh(z_corr) 148 | 149 | return [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_eos] 150 | 151 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, 152 | o_corr, o_eos] = get_mixture_coef(output) 153 | 154 | # I could put all of these in a single tensor for reading out, but this 155 | # is more human readable 156 | data_out_pi = tf.identity(o_pi, "data_out_pi") 157 | data_out_mu1 = tf.identity(o_mu1, "data_out_mu1") 158 | data_out_mu2 = tf.identity(o_mu2, "data_out_mu2") 159 | data_out_sigma1 = tf.identity(o_sigma1, "data_out_sigma1") 160 | data_out_sigma2 = tf.identity(o_sigma2, "data_out_sigma2") 161 | data_out_corr = tf.identity(o_corr, "data_out_corr") 162 | data_out_eos = tf.identity(o_eos, "data_out_eos") 163 | 164 | # sticking them all (except eos) in one op anyway, makes it easier for freezing the graph later 165 | # IMPORTANT, this needs to stack the named ops above (data_out_XXX), not the prev ops (o_XXX) 166 | # otherwise when I freeze the graph up to this point, the named versions will be cut 167 | # eos is diff size to others, so excluding that 168 | data_out_mdn = tf.identity([data_out_pi, 169 | data_out_mu1, 170 | data_out_mu2, 171 | data_out_sigma1, 172 | data_out_sigma2, 173 | data_out_corr], 174 | name="data_out_mdn") 175 | 176 | self.pi = o_pi 177 | self.mu1 = o_mu1 178 | self.mu2 = o_mu2 179 | self.sigma1 = o_sigma1 180 | self.sigma2 = o_sigma2 181 | self.corr = o_corr 182 | self.eos = o_eos 183 | 184 | lossfunc = get_lossfunc( 185 | o_pi, 186 | o_mu1, 187 | o_mu2, 188 | o_sigma1, 189 | o_sigma2, 190 | o_corr, 191 | o_eos, 192 | x1_data, 193 | x2_data, 194 | eos_data) 195 | self.cost = lossfunc / (args.batch_size * args.seq_length) 196 | 197 | self.train_loss_summary = tf.summary.scalar('train_loss', self.cost) 198 | self.valid_loss_summary = 
tf.summary.scalar( 199 | 'validation_loss', self.cost) 200 | 201 | self.lr = tf.Variable(0.0, trainable=False) 202 | tvars = tf.trainable_variables() 203 | grads, _ = tf.clip_by_global_norm( 204 | tf.gradients(self.cost, tvars), args.grad_clip) 205 | optimizer = tf.train.AdamOptimizer(self.lr) 206 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 207 | 208 | def sample(self, sess, num=1200): 209 | 210 | def get_pi_idx(x, pdf): 211 | N = pdf.size 212 | accumulate = 0 213 | for i in range(0, N): 214 | accumulate += pdf[i] 215 | if (accumulate >= x): 216 | return i 217 | print('error with sampling ensemble') 218 | return -1 219 | 220 | def sample_gaussian_2d(mu1, mu2, s1, s2, rho): 221 | mean = [mu1, mu2] 222 | cov = [[s1 * s1, rho * s1 * s2], [rho * s1 * s2, s2 * s2]] 223 | x = np.random.multivariate_normal(mean, cov, 1) 224 | return x[0][0], x[0][1] 225 | 226 | prev_x = np.zeros((1, 1, 3), dtype=np.float32) 227 | prev_x[0, 0, 2] = 1 # initially, we want to see beginning of new stroke 228 | prev_state = sess.run(self.cell.zero_state(1, tf.float32)) 229 | 230 | strokes = np.zeros((num, 3), dtype=np.float32) 231 | mixture_params = [] 232 | 233 | for i in range(num): 234 | 235 | feed = {self.input_data: prev_x, self.state_in: prev_state} 236 | 237 | [o_pi, 238 | o_mu1, 239 | o_mu2, 240 | o_sigma1, 241 | o_sigma2, 242 | o_corr, 243 | o_eos, 244 | next_state] = sess.run([self.pi, 245 | self.mu1, 246 | self.mu2, 247 | self.sigma1, 248 | self.sigma2, 249 | self.corr, 250 | self.eos, 251 | self.state_out], 252 | feed) 253 | 254 | idx = get_pi_idx(random.random(), o_pi[0]) 255 | 256 | eos = 1 if random.random() < o_eos[0][0] else 0 257 | 258 | next_x1, next_x2 = sample_gaussian_2d( 259 | o_mu1[0][idx], o_mu2[0][idx], o_sigma1[0][idx], o_sigma2[0][idx], o_corr[0][idx]) 260 | 261 | strokes[i, :] = [next_x1, next_x2, eos] 262 | 263 | params = [ 264 | o_pi[0], 265 | o_mu1[0], 266 | o_mu2[0], 267 | o_sigma1[0], 268 | o_sigma2[0], 269 | o_corr[0], 270 | o_eos[0]] 271 | mixture_params.append(params) 272 | 273 | prev_x = np.zeros((1, 1, 3), dtype=np.float32) 274 | prev_x[0][0] = np.array([next_x1, next_x2, eos], dtype=np.float32) 275 | prev_state = next_state 276 | 277 | strokes[:, 0:2] *= self.args.data_scale 278 | return strokes, mixture_params 279 | -------------------------------------------------------------------------------- /run_docker.sh: -------------------------------------------------------------------------------- 1 | docker run \ 2 | -it \ 3 | -v $(pwd):/workspace \ 4 | -w /workspace \ 5 | write-rnn-tensorflow:1.5.0-py3 \ 6 | python train.py 7 | 8 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import tensorflow as tf 4 | 5 | from model import Model 6 | from utils import * 7 | 8 | # main code (not in a main function since I want to run this script in 9 | # IPython as well). 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--filename', type=str, default='sample', 13 | help='filename of .svg file to output, without .svg') 14 | parser.add_argument('--sample_length', type=int, default=800, 15 | help='number of strokes to sample') 16 | parser.add_argument( 17 | '--scale_factor', 18 | type=int, 19 | default=10, 20 | help='factor to scale down by for svg output. 
smaller means bigger output') 21 | parser.add_argument('--model_dir', type=str, default='save', 22 | help='directory to save model to') 23 | parser.add_argument( 24 | '--freeze_graph', 25 | dest='freeze_graph', 26 | action='store_true', 27 | help='if true, freeze (replace variables with consts), prune (for inference) and save graph') 28 | 29 | sample_args = parser.parse_args() 30 | 31 | with open(os.path.join(sample_args.model_dir, 'config.pkl'), 'rb') as f: 32 | saved_args = pickle.load(f) 33 | 34 | model = Model(saved_args, True) 35 | sess = tf.InteractiveSession() 36 | #saver = tf.train.Saver(tf.all_variables()) 37 | saver = tf.train.Saver() 38 | 39 | ckpt = tf.train.get_checkpoint_state(sample_args.model_dir) 40 | print("loading model: ", ckpt.model_checkpoint_path) 41 | 42 | saver.restore(sess, ckpt.model_checkpoint_path) 43 | 44 | 45 | def sample_stroke(): 46 | [strokes, params] = model.sample(sess, sample_args.sample_length) 47 | draw_strokes( 48 | strokes, 49 | factor=sample_args.scale_factor, 50 | svg_filename=sample_args.filename + 51 | '.normal.svg') 52 | draw_strokes_random_color( 53 | strokes, 54 | factor=sample_args.scale_factor, 55 | svg_filename=sample_args.filename + 56 | '.color.svg') 57 | draw_strokes_random_color( 58 | strokes, 59 | factor=sample_args.scale_factor, 60 | per_stroke_mode=False, 61 | svg_filename=sample_args.filename + 62 | '.multi_color.svg') 63 | draw_strokes_eos_weighted( 64 | strokes, 65 | params, 66 | factor=sample_args.scale_factor, 67 | svg_filename=sample_args.filename + 68 | '.eos_pdf.svg') 69 | draw_strokes_pdf( 70 | strokes, 71 | params, 72 | factor=sample_args.scale_factor, 73 | svg_filename=sample_args.filename + 74 | '.pdf.svg') 75 | return [strokes, params] 76 | 77 | 78 | def freeze_and_save_graph(sess, folder, out_nodes, as_text=False): 79 | # save graph definition 80 | graph_raw = sess.graph_def 81 | graph_frz = tf.graph_util.convert_variables_to_constants( 82 | sess, graph_raw, out_nodes) 83 | ext = '.txt' if as_text else '.pb' 84 | #tf.train.write_graph(graph_raw, folder, 'graph_raw'+ext, as_text=as_text) 85 | tf.train.write_graph(graph_frz, folder, 'graph_frz' + ext, as_text=as_text) 86 | 87 | 88 | if(sample_args.freeze_graph): 89 | freeze_and_save_graph( 90 | sess, sample_args.model_dir, [ 91 | 'data_out_mdn', 'data_out_eos', 'state_out'], False) 92 | 93 | [strokes, params] = sample_stroke() 94 | -------------------------------------------------------------------------------- /sample_frozen.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Feb 23 20:25:16 2017 4 | 5 | @author: memo 6 | 7 | demonstrates inference with frozen graph def 8 | same as sample.py, but: 9 | - instead of loading model + checkpoint, loads frozen graph 10 | - instead of calling model.sample() function, uses own sample() function with named ops 11 | """ 12 | 13 | import argparse 14 | 15 | import tensorflow as tf 16 | 17 | from utils import * 18 | 19 | # main code (not in a main function since I want to run this script in 20 | # IPython as well). 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--filename', type=str, default='sample', 24 | help='filename of .svg file to output, without .svg') 25 | parser.add_argument('--sample_length', type=int, default=800, 26 | help='number of strokes to sample') 27 | parser.add_argument( 28 | '--scale_factor', 29 | type=int, 30 | default=10, 31 | help='factor to scale down by for svg output. 
smaller means bigger output') 32 | parser.add_argument('--model_dir', type=str, default='save', 33 | help='directory to save model to') 34 | sample_args = parser.parse_args() 35 | 36 | sess = tf.InteractiveSession() 37 | 38 | # load frozen graph 39 | from tensorflow.python.platform import gfile 40 | with gfile.FastGFile(os.path.join(sample_args.model_dir, 'graph_frz.pb'), 'rb') as f: 41 | graph_def = tf.GraphDef() 42 | graph_def.ParseFromString(f.read()) 43 | sess.graph.as_default() 44 | tf.import_graph_def(graph_def, name='') 45 | 46 | 47 | def sample_stroke(): 48 | # don't call model.sample(), instead call sample() function defined below 49 | [strokes, params] = sample(sess, sample_args.sample_length) 50 | draw_strokes( 51 | strokes, 52 | factor=sample_args.scale_factor, 53 | svg_filename=sample_args.filename + 54 | '.normal.svg') 55 | draw_strokes_random_color( 56 | strokes, 57 | factor=sample_args.scale_factor, 58 | svg_filename=sample_args.filename + 59 | '.color.svg') 60 | draw_strokes_random_color( 61 | strokes, 62 | factor=sample_args.scale_factor, 63 | per_stroke_mode=False, 64 | svg_filename=sample_args.filename + 65 | '.multi_color.svg') 66 | draw_strokes_eos_weighted( 67 | strokes, 68 | params, 69 | factor=sample_args.scale_factor, 70 | svg_filename=sample_args.filename + 71 | '.eos_pdf.svg') 72 | draw_strokes_pdf( 73 | strokes, 74 | params, 75 | factor=sample_args.scale_factor, 76 | svg_filename=sample_args.filename + 77 | '.pdf.svg') 78 | return [strokes, params] 79 | 80 | 81 | # copied straight from model.sample, but replaced all referenes to 'self' 82 | # with named ops 83 | def sample(sess, num=1200): 84 | data_in = 'data_in:0' 85 | data_out_pi = 'data_out_pi:0' 86 | data_out_mu1 = 'data_out_mu1:0' 87 | data_out_mu2 = 'data_out_mu2:0' 88 | data_out_sigma1 = 'data_out_sigma1:0' 89 | data_out_sigma2 = 'data_out_sigma2:0' 90 | data_out_corr = 'data_out_corr:0' 91 | data_out_eos = 'data_out_eos:0' 92 | state_in = 'state_in:0' 93 | state_out = 'state_out:0' 94 | 95 | def get_pi_idx(x, pdf): 96 | N = pdf.size 97 | accumulate = 0 98 | for i in range(0, N): 99 | accumulate += pdf[i] 100 | if (accumulate >= x): 101 | return i 102 | print('error with sampling ensemble') 103 | return -1 104 | 105 | def sample_gaussian_2d(mu1, mu2, s1, s2, rho): 106 | mean = [mu1, mu2] 107 | cov = [[s1 * s1, rho * s1 * s2], [rho * s1 * s2, s2 * s2]] 108 | x = np.random.multivariate_normal(mean, cov, 1) 109 | return x[0][0], x[0][1] 110 | 111 | prev_x = np.zeros((1, 1, 3), dtype=np.float32) 112 | prev_x[0, 0, 2] = 1 # initially, we want to see beginning of new stroke 113 | prev_state = sess.run(state_in) 114 | 115 | strokes = np.zeros((num, 3), dtype=np.float32) 116 | mixture_params = [] 117 | 118 | for i in range(num): 119 | 120 | feed = {data_in: prev_x, state_in: prev_state} 121 | 122 | [o_pi, 123 | o_mu1, 124 | o_mu2, 125 | o_sigma1, 126 | o_sigma2, 127 | o_corr, 128 | o_eos, 129 | next_state] = sess.run([data_out_pi, 130 | data_out_mu1, 131 | data_out_mu2, 132 | data_out_sigma1, 133 | data_out_sigma2, 134 | data_out_corr, 135 | data_out_eos, 136 | state_out], 137 | feed) 138 | 139 | idx = get_pi_idx(random.random(), o_pi[0]) 140 | 141 | eos = 1 if random.random() < o_eos[0][0] else 0 142 | 143 | next_x1, next_x2 = sample_gaussian_2d( 144 | o_mu1[0][idx], o_mu2[0][idx], o_sigma1[0][idx], o_sigma2[0][idx], o_corr[0][idx]) 145 | 146 | strokes[i, :] = [next_x1, next_x2, eos] 147 | 148 | params = [ 149 | o_pi[0], 150 | o_mu1[0], 151 | o_mu2[0], 152 | o_sigma1[0], 153 | o_sigma2[0], 154 | o_corr[0], 155 
| o_eos[0]] 156 | mixture_params.append(params) 157 | 158 | prev_x = np.zeros((1, 1, 3), dtype=np.float32) 159 | prev_x[0][0] = np.array([next_x1, next_x2, eos], dtype=np.float32) 160 | prev_state = next_state 161 | 162 | # self.args.data_scale # TODO: fix mega hack hardcoding the scale 163 | strokes[:, 0:2] *= 20 164 | return strokes, mixture_params 165 | 166 | 167 | # check output 168 | [strokes, params] = sample_stroke() 169 | -------------------------------------------------------------------------------- /save/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt-11000" 2 | all_model_checkpoint_paths: "model.ckpt-9000" 3 | all_model_checkpoint_paths: "model.ckpt-9500" 4 | all_model_checkpoint_paths: "model.ckpt-10000" 5 | all_model_checkpoint_paths: "model.ckpt-10500" 6 | all_model_checkpoint_paths: "model.ckpt-11000" 7 | -------------------------------------------------------------------------------- /save/config.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/config.pkl -------------------------------------------------------------------------------- /save/model.ckpt-10000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-10000 -------------------------------------------------------------------------------- /save/model.ckpt-10000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-10000.meta -------------------------------------------------------------------------------- /save/model.ckpt-10500: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-10500 -------------------------------------------------------------------------------- /save/model.ckpt-10500.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-10500.meta -------------------------------------------------------------------------------- /save/model.ckpt-11000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-11000 -------------------------------------------------------------------------------- /save/model.ckpt-11000.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-11000.meta -------------------------------------------------------------------------------- /save/model.ckpt-9000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-9000 -------------------------------------------------------------------------------- /save/model.ckpt-9000.meta: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-9000.meta -------------------------------------------------------------------------------- /save/model.ckpt-9500: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-9500 -------------------------------------------------------------------------------- /save/model.ckpt-9500.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-9500.meta -------------------------------------------------------------------------------- /svg/example1.multi_color.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /svg/example1.normal.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import time 5 | 6 | import tensorflow as tf 7 | 8 | from model import Model 9 | from utils import DataLoader 10 | 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--rnn_size', type=int, default=256, 15 | help='size of RNN hidden state') 16 | parser.add_argument('--num_layers', type=int, default=2, 17 | help='number of layers in the RNN') 18 | parser.add_argument('--model', type=str, default='lstm', 19 | help='rnn, gru, or lstm') 20 | parser.add_argument('--batch_size', type=int, default=50, 21 | help='minibatch size') 22 | parser.add_argument('--seq_length', type=int, default=300, 23 | help='RNN sequence length') 24 | parser.add_argument('--num_epochs', type=int, default=30, 25 | help='number of epochs') 26 | parser.add_argument('--save_every', type=int, default=500, 27 | help='save frequency') 28 | parser.add_argument('--model_dir', type=str, default='save', 29 | help='directory to save model to') 30 | parser.add_argument('--grad_clip', type=float, default=10., 31 | help='clip gradients at this value') 32 | parser.add_argument('--learning_rate', type=float, default=0.005, 33 | help='learning rate') 34 | parser.add_argument('--decay_rate', type=float, default=0.95, 35 | help='decay rate for rmsprop') 36 | parser.add_argument('--num_mixture', type=int, default=20, 37 | help='number of gaussian mixtures') 38 | parser.add_argument('--data_scale', type=float, default=20, 39 | help='factor to scale raw data down by') 40 | parser.add_argument('--keep_prob', type=float, default=0.8, 41 | help='dropout keep probability') 42 | args = parser.parse_args() 43 | train(args) 44 | 45 | 46 | def train(args): 47 | data_loader = DataLoader(args.batch_size, args.seq_length, args.data_scale) 48 | 49 | if args.model_dir != '' and not os.path.exists(args.model_dir): 50 | os.makedirs(args.model_dir) 51 | 52 | with open(os.path.join(args.model_dir, 'config.pkl'), 'wb') as f: 53 | pickle.dump(args, f) 54 | 55 | model = Model(args) 56 | 57 | with tf.Session() as sess: 58 | summary_writer = tf.summary.FileWriter( 59 | 
os.path.join(args.model_dir, 'log'), sess.graph) 60 | 61 | tf.global_variables_initializer().run() 62 | saver = tf.train.Saver(tf.global_variables()) 63 | for e in range(args.num_epochs): 64 | sess.run(tf.assign(model.lr, 65 | args.learning_rate * (args.decay_rate ** e))) 66 | data_loader.reset_batch_pointer() 67 | v_x, v_y = data_loader.validation_data() 68 | valid_feed = { 69 | model.input_data: v_x, 70 | model.target_data: v_y, 71 | model.state_in: model.state_in.eval()} 72 | state = model.state_in.eval() 73 | for b in range(data_loader.num_batches): 74 | ith_train_step = e * data_loader.num_batches + b 75 | start = time.time() 76 | x, y = data_loader.next_batch() 77 | feed = { 78 | model.input_data: x, 79 | model.target_data: y, 80 | model.state_in: state} 81 | train_loss_summary, train_loss, state, _ = sess.run( 82 | [model.train_loss_summary, model.cost, model.state_out, model.train_op], feed) 83 | summary_writer.add_summary(train_loss_summary, ith_train_step) 84 | 85 | valid_loss_summary, valid_loss, = sess.run( 86 | [model.valid_loss_summary, model.cost], valid_feed) 87 | summary_writer.add_summary(valid_loss_summary, ith_train_step) 88 | 89 | end = time.time() 90 | print( 91 | "{}/{} (epoch {}), train_loss = {:.3f}, valid_loss = {:.3f}, time/batch = {:.3f}" 92 | .format( 93 | ith_train_step, 94 | args.num_epochs * data_loader.num_batches, 95 | e, 96 | train_loss, 97 | valid_loss, 98 | end - start)) 99 | if (ith_train_step % 100 | args.save_every == 0) and (ith_train_step > 0): 101 | checkpoint_path = os.path.join( 102 | args.model_dir, 'model.ckpt') 103 | saver.save( 104 | sess, 105 | checkpoint_path, 106 | global_step=ith_train_step) 107 | print("model saved to {}".format(checkpoint_path)) 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | import xml.etree.ElementTree as ET 5 | 6 | import numpy as np 7 | import svgwrite 8 | from IPython.display import SVG, display 9 | 10 | 11 | def get_bounds(data, factor): 12 | min_x = 0 13 | max_x = 0 14 | min_y = 0 15 | max_y = 0 16 | 17 | abs_x = 0 18 | abs_y = 0 19 | for i in range(len(data)): 20 | x = float(data[i, 0]) / factor 21 | y = float(data[i, 1]) / factor 22 | abs_x += x 23 | abs_y += y 24 | min_x = min(min_x, abs_x) 25 | min_y = min(min_y, abs_y) 26 | max_x = max(max_x, abs_x) 27 | max_y = max(max_y, abs_y) 28 | 29 | return (min_x, max_x, min_y, max_y) 30 | 31 | # old version, where each path is entire stroke (smaller svg size, but 32 | # have to keep same color) 33 | 34 | 35 | def draw_strokes(data, factor=10, svg_filename='sample.svg'): 36 | min_x, max_x, min_y, max_y = get_bounds(data, factor) 37 | dims = (50 + max_x - min_x, 50 + max_y - min_y) 38 | 39 | dwg = svgwrite.Drawing(svg_filename, size=dims) 40 | dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white')) 41 | 42 | lift_pen = 1 43 | 44 | abs_x = 25 - min_x 45 | abs_y = 25 - min_y 46 | p = "M%s,%s " % (abs_x, abs_y) 47 | 48 | command = "m" 49 | 50 | for i in range(len(data)): 51 | if (lift_pen == 1): 52 | command = "m" 53 | elif (command != "l"): 54 | command = "l" 55 | else: 56 | command = "" 57 | x = float(data[i, 0]) / factor 58 | y = float(data[i, 1]) / factor 59 | lift_pen = data[i, 2] 60 | p += command + str(x) + "," + str(y) + " " 61 | 62 | the_color = "black" 63 | stroke_width = 1 64 | 65 | dwg.add(dwg.path(p).stroke(the_color, 
stroke_width).fill("none")) 66 | 67 | dwg.save() 68 | display(SVG(dwg.tostring())) 69 | 70 | 71 | def draw_strokes_eos_weighted( 72 | stroke, 73 | param, 74 | factor=10, 75 | svg_filename='sample_eos.svg'): 76 | c_data_eos = np.zeros((len(stroke), 3)) 77 | for i in range(len(param)): 78 | # make color gray scale, darker = more likely to eos 79 | c_data_eos[i, :] = (1 - param[i][6][0]) * 225 80 | draw_strokes_custom_color( 81 | stroke, 82 | factor=factor, 83 | svg_filename=svg_filename, 84 | color_data=c_data_eos, 85 | stroke_width=3) 86 | 87 | 88 | def draw_strokes_random_color( 89 | stroke, 90 | factor=10, 91 | svg_filename='sample_random_color.svg', 92 | per_stroke_mode=True): 93 | c_data = np.array(np.random.rand(len(stroke), 3) * 240, dtype=np.uint8) 94 | if per_stroke_mode: 95 | switch_color = False 96 | for i in range(len(stroke)): 97 | if switch_color == False and i > 0: 98 | c_data[i] = c_data[i - 1] 99 | if stroke[i, 2] < 1: # same strike 100 | switch_color = False 101 | else: 102 | switch_color = True 103 | draw_strokes_custom_color( 104 | stroke, 105 | factor=factor, 106 | svg_filename=svg_filename, 107 | color_data=c_data, 108 | stroke_width=2) 109 | 110 | 111 | def draw_strokes_custom_color( 112 | data, 113 | factor=10, 114 | svg_filename='test.svg', 115 | color_data=None, 116 | stroke_width=1): 117 | min_x, max_x, min_y, max_y = get_bounds(data, factor) 118 | dims = (50 + max_x - min_x, 50 + max_y - min_y) 119 | 120 | dwg = svgwrite.Drawing(svg_filename, size=dims) 121 | dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white')) 122 | 123 | lift_pen = 1 124 | abs_x = 25 - min_x 125 | abs_y = 25 - min_y 126 | 127 | for i in range(len(data)): 128 | 129 | x = float(data[i, 0]) / factor 130 | y = float(data[i, 1]) / factor 131 | 132 | prev_x = abs_x 133 | prev_y = abs_y 134 | 135 | abs_x += x 136 | abs_y += y 137 | 138 | if (lift_pen == 1): 139 | p = "M " + str(abs_x) + "," + str(abs_y) + " " 140 | else: 141 | p = "M +" + str(prev_x) + "," + str(prev_y) + \ 142 | " L " + str(abs_x) + "," + str(abs_y) + " " 143 | 144 | lift_pen = data[i, 2] 145 | 146 | the_color = "black" 147 | 148 | if (color_data is not None): 149 | the_color = "rgb(" + str(int(color_data[i, 0])) + "," + str( 150 | int(color_data[i, 1])) + "," + str(int(color_data[i, 2])) + ")" 151 | 152 | dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill(the_color)) 153 | dwg.save() 154 | display(SVG(dwg.tostring())) 155 | 156 | 157 | def draw_strokes_pdf(data, param, factor=10, svg_filename='sample_pdf.svg'): 158 | min_x, max_x, min_y, max_y = get_bounds(data, factor) 159 | dims = (50 + max_x - min_x, 50 + max_y - min_y) 160 | 161 | dwg = svgwrite.Drawing(svg_filename, size=dims) 162 | dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white')) 163 | 164 | abs_x = 25 - min_x 165 | abs_y = 25 - min_y 166 | 167 | num_mixture = len(param[0][0]) 168 | 169 | for i in range(len(data)): 170 | 171 | x = float(data[i, 0]) / factor 172 | y = float(data[i, 1]) / factor 173 | 174 | for k in range(num_mixture): 175 | pi = param[i][0][k] 176 | if pi > 0.01: # optimisation, ignore pi's less than 1% chance 177 | mu1 = param[i][1][k] 178 | mu2 = param[i][2][k] 179 | s1 = param[i][3][k] 180 | s2 = param[i][4][k] 181 | sigma = np.sqrt(s1 * s2) 182 | dwg.add(dwg.circle(center=(abs_x + mu1 * factor, 183 | abs_y + mu2 * factor), 184 | r=int(sigma * factor)).fill('red', 185 | opacity=pi / (sigma * sigma * factor))) 186 | 187 | prev_x = abs_x 188 | prev_y = abs_y 189 | 190 | abs_x += x 191 | abs_y += y 192 | 193 | dwg.save() 194 | 
display(SVG(dwg.tostring())) 195 | 196 | 197 | class DataLoader(): 198 | def __init__( 199 | self, 200 | batch_size=50, 201 | seq_length=300, 202 | scale_factor=10, 203 | limit=500): 204 | self.data_dir = "./data" 205 | self.batch_size = batch_size 206 | self.seq_length = seq_length 207 | self.scale_factor = scale_factor # divide data by this factor 208 | self.limit = limit # removes large noisy gaps in the data 209 | 210 | data_file = os.path.join(self.data_dir, "strokes_training_data.cpkl") 211 | raw_data_dir = self.data_dir + "/lineStrokes" 212 | 213 | if not (os.path.exists(data_file)): 214 | print("creating training data pkl file from raw source") 215 | self.preprocess(raw_data_dir, data_file) 216 | 217 | self.load_preprocessed(data_file) 218 | self.reset_batch_pointer() 219 | 220 | def preprocess(self, data_dir, data_file): 221 | # create data file from raw xml files from iam handwriting source. 222 | 223 | # build the list of xml files 224 | filelist = [] 225 | # Set the directory you want to start from 226 | rootDir = data_dir 227 | for dirName, subdirList, fileList in os.walk(rootDir): 228 | #print('Found directory: %s' % dirName) 229 | for fname in fileList: 230 | #print('\t%s' % fname) 231 | filelist.append(dirName + "/" + fname) 232 | 233 | # function to read each individual xml file 234 | def getStrokes(filename): 235 | tree = ET.parse(filename) 236 | root = tree.getroot() 237 | 238 | result = [] 239 | 240 | x_offset = 1e20 241 | y_offset = 1e20 242 | y_height = 0 243 | for i in range(1, 4): 244 | x_offset = min(x_offset, float(root[0][i].attrib['x'])) 245 | y_offset = min(y_offset, float(root[0][i].attrib['y'])) 246 | y_height = max(y_height, float(root[0][i].attrib['y'])) 247 | y_height -= y_offset 248 | x_offset -= 100 249 | y_offset -= 100 250 | 251 | for stroke in root[1].findall('Stroke'): 252 | points = [] 253 | for point in stroke.findall('Point'): 254 | points.append( 255 | [float(point.attrib['x']) - x_offset, float(point.attrib['y']) - y_offset]) 256 | result.append(points) 257 | 258 | return result 259 | 260 | # converts a list of arrays into a 2d numpy int16 array 261 | def convert_stroke_to_array(stroke): 262 | 263 | n_point = 0 264 | for i in range(len(stroke)): 265 | n_point += len(stroke[i]) 266 | stroke_data = np.zeros((n_point, 3), dtype=np.int16) 267 | 268 | prev_x = 0 269 | prev_y = 0 270 | counter = 0 271 | 272 | for j in range(len(stroke)): 273 | for k in range(len(stroke[j])): 274 | stroke_data[counter, 0] = int(stroke[j][k][0]) - prev_x 275 | stroke_data[counter, 1] = int(stroke[j][k][1]) - prev_y 276 | prev_x = int(stroke[j][k][0]) 277 | prev_y = int(stroke[j][k][1]) 278 | stroke_data[counter, 2] = 0 279 | if (k == (len(stroke[j]) - 1)): # end of stroke 280 | stroke_data[counter, 2] = 1 281 | counter += 1 282 | return stroke_data 283 | 284 | # build stroke database of every xml file inside iam database 285 | strokes = [] 286 | for i in range(len(filelist)): 287 | if (filelist[i][-3:] == 'xml'): 288 | print('processing ' + filelist[i]) 289 | strokes.append( 290 | convert_stroke_to_array( 291 | getStrokes( 292 | filelist[i]))) 293 | 294 | f = open(data_file, "wb") 295 | pickle.dump(strokes, f, protocol=2) 296 | f.close() 297 | 298 | def load_preprocessed(self, data_file): 299 | f = open(data_file, "rb") 300 | self.raw_data = pickle.load(f) 301 | f.close() 302 | 303 | # goes thru the list, and only keeps the text entries that have more 304 | # than seq_length points 305 | self.data = [] 306 | self.valid_data = [] 307 | counter = 0 308 | 309 | # every 1 
in 20 (5%) will be used for validation data 310 | cur_data_counter = 0 311 | for data in self.raw_data: 312 | if len(data) > (self.seq_length + 2): 313 | # removes large gaps from the data 314 | data = np.minimum(data, self.limit) 315 | data = np.maximum(data, -self.limit) 316 | data = np.array(data, dtype=np.float32) 317 | data[:, 0:2] /= self.scale_factor 318 | cur_data_counter = cur_data_counter + 1 319 | if cur_data_counter % 20 == 0: 320 | self.valid_data.append(data) 321 | else: 322 | self.data.append(data) 323 | # number of equiv batches this datapoint is worth 324 | counter += int(len(data) / ((self.seq_length + 2))) 325 | 326 | print("train data: {}, valid data: {}".format( 327 | len(self.data), len(self.valid_data))) 328 | # minus 1, since we want the ydata to be a shifted version of x data 329 | self.num_batches = int(counter / self.batch_size) 330 | 331 | def validation_data(self): 332 | # returns validation data 333 | x_batch = [] 334 | y_batch = [] 335 | for i in range(self.batch_size): 336 | data = self.valid_data[i % len(self.valid_data)] 337 | idx = 0 338 | x_batch.append(np.copy(data[idx:idx + self.seq_length])) 339 | y_batch.append(np.copy(data[idx + 1:idx + self.seq_length + 1])) 340 | return x_batch, y_batch 341 | 342 | def next_batch(self): 343 | # returns a randomised, seq_length sized portion of the training data 344 | x_batch = [] 345 | y_batch = [] 346 | for i in range(self.batch_size): 347 | data = self.data[self.pointer] 348 | # number of equiv batches this datapoint is worth 349 | n_batch = int(len(data) / ((self.seq_length + 2))) 350 | idx = random.randint(0, len(data) - self.seq_length - 2) 351 | x_batch.append(np.copy(data[idx:idx + self.seq_length])) 352 | y_batch.append(np.copy(data[idx + 1:idx + self.seq_length + 1])) 353 | # adjust sampling probability. 354 | if random.random() < (1.0 / float(n_batch)): 355 | # if this is a long datapoint, sample this data more with 356 | # higher probability 357 | self.tick_batch_pointer() 358 | return x_batch, y_batch 359 | 360 | def tick_batch_pointer(self): 361 | self.pointer += 1 362 | if (self.pointer >= len(self.data)): 363 | self.pointer = 0 364 | 365 | def reset_batch_pointer(self): 366 | self.pointer = 0 367 | --------------------------------------------------------------------------------
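
As a closing note on the data format: `model.sample` (and the frozen-graph `sample` in `sample_frozen.py`) returns an `(N, 3)` array whose rows are `[dx, dy, eos]` — relative pen offsets (rescaled by `data_scale`) plus an end-of-stroke flag, the same convention `DataLoader.preprocess` produces and the `draw_strokes*` helpers consume. Below is a minimal sketch, assuming such an array, of converting it into absolute per-stroke point lists (for example, to plot with something other than svgwrite); the function name and the `factor` default are illustrative and not part of the repository.

```
import numpy as np

def strokes_to_segments(strokes, factor=10.0):
    """Split an (N, 3) array of [dx, dy, eos] rows into lists of absolute (x, y) points.

    `factor` plays the same role as in draw_strokes: offsets are scaled down by it.
    """
    segments, current = [], []
    x, y = 0.0, 0.0
    for dx, dy, eos in np.asarray(strokes, dtype=np.float32):
        x += dx / factor
        y += dy / factor
        current.append((x, y))
        if eos >= 0.5:        # this point ends the current pen stroke
            segments.append(current)
            current = []
    if current:               # trailing points without a final end-of-stroke flag
        segments.append(current)
    return segments

# e.g. segments = strokes_to_segments(strokes, factor=8)
```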