├── .gitignore
├── README.md
├── docker
│   ├── Dockerfile
│   └── build_image.sh
├── model.py
├── run_docker.sh
├── sample.py
├── sample_frozen.py
├── save
│   ├── checkpoint
│   ├── config.pkl
│   ├── model.ckpt-10000
│   ├── model.ckpt-10000.meta
│   ├── model.ckpt-10500
│   ├── model.ckpt-10500.meta
│   ├── model.ckpt-11000
│   ├── model.ckpt-11000.meta
│   ├── model.ckpt-9000
│   ├── model.ckpt-9000.meta
│   ├── model.ckpt-9500
│   └── model.ckpt-9500.meta
├── svg
│   ├── example.svg
│   ├── example1.color.svg
│   ├── example1.eos_pdf.svg
│   ├── example1.multi_color.svg
│   ├── example1.normal.svg
│   ├── example1.pdf.svg
│   └── many_examples.svg
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *.cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
91 | # SVG
92 | *.svg
93 |
94 | # training data folder
95 | data/
96 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Generative Handwriting Demo using TensorFlow
3 |
4 | 
5 |
6 | 
7 |
8 | An attempt to implement the random handwriting generation portion of Alex Graves' [paper](http://arxiv.org/abs/1308.0850).
9 |
10 | See my blog post at [blog.otoro.net](http://blog.otoro.net/2015/12/12/handwriting-generation-demo-in-tensorflow) for more information.
11 |
12 | ### How to use
13 |
14 | I tested the implementation on TensorFlow 1.x (the included Docker image uses 1.5.0) and Python 3. I also used the following libraries and standard modules:
15 |
16 | ```
17 | svgwrite
18 | IPython.display.SVG
19 | IPython.display.display
20 | xml.etree.ElementTree
21 | argparse
22 | pickle
23 | ```
24 |
25 | ### Training
26 |
27 | You will need permission from [these wonderful people](http://www.iam.unibe.ch/fki/databases/iam-on-line-handwriting-database) to get the IAM On-Line Handwriting data. Unzip `lineStrokes-all.tar.gz` into the `data` subdirectory, so that you end up with `data/lineStrokes/a01`, `data/lineStrokes/a02`, etc. Afterwards, running `python train.py` will start the training process.
28 |
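Under the hood, `utils.DataLoader` converts the XML files into arrays of shape `(N, 3)`, where each row is a pen offset from the previous point plus an end-of-stroke flag, and caches them in `data/strokes_training_data.cpkl`. A minimal sketch of that format (the numbers below are made up purely for illustration):

```python
import numpy as np

# each row is (dx, dy, end_of_stroke); offsets are relative to the previous
# point and get divided by --data_scale inside utils.DataLoader
toy_sequence = np.array([
    [5, -2, 0],   # pen moves to the next point, pen stays down
    [3,  0, 0],
    [1,  4, 1],   # last point of the current stroke: the pen lifts after this
    [-9, 6, 0],   # the next stroke starts with a jump to its first point
], dtype=np.float32)

print(toy_sequence.shape)  # (4, 3), matching the model's [batch, seq_length, 3] input
```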
29 | A number of flags can be set for training if you wish to experiment with the parameters. The default values are in `train.py`; the snippet after the list shows how to check which values a saved model was trained with.
30 |
31 | ```
32 | --rnn_size RNN_SIZE size of RNN hidden state
33 | --num_layers NUM_LAYERS number of layers in the RNN
34 | --model MODEL rnn, gru, or lstm
35 | --batch_size BATCH_SIZE minibatch size
36 | --seq_length SEQ_LENGTH RNN sequence length
37 | --num_epochs NUM_EPOCHS number of epochs
38 | --save_every SAVE_EVERY save frequency
39 | --grad_clip GRAD_CLIP clip gradients at this value
40 | --learning_rate LEARNING_RATE learning rate
41 | --decay_rate DECAY_RATE per-epoch learning rate decay
42 | --num_mixture NUM_MIXTURE number of gaussian mixtures
43 | --data_scale DATA_SCALE factor to scale raw data down by
44 | --keep_prob KEEP_PROB dropout keep probability
45 | ```
46 |
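`train.py` pickles the parsed arguments to `save/config.pkl`, and `sample.py` reads them back when restoring a model, so you can always check which hyperparameters a saved model was trained with. A small sketch (assuming a `config.pkl` written by `train.py` exists in `save/`):

```python
import pickle

# the pickled object is the argparse Namespace that train.py was run with
with open('save/config.pkl', 'rb') as f:
    saved_args = pickle.load(f)

print(saved_args.model, saved_args.rnn_size,
      saved_args.num_layers, saved_args.num_mixture)
```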
47 | ### Generating a Handwriting Sample
48 |
49 | I've included a pretrained model in `/save`, so it should work out of the box. Running `python sample.py --filename example_name --sample_length 1000` will generate five .svg files (a `.normal`, `.color`, `.multi_color`, `.eos_pdf`, and `.pdf` version of the sample), each with 1000 points.
50 |
51 | ### IPython interactive session.
52 |
53 | If you wish to experiment with this code interactively, run `%run -i sample.py` in an IPython console; the following code then shows how to generate samples and display them inside IPython.
54 |
55 | ```
56 | [strokes, params] = model.sample(sess, 800)
57 | draw_strokes(strokes, factor=8, svg_filename = 'sample.normal.svg')
58 | draw_strokes_random_color(strokes, factor=8, svg_filename = 'sample.color.svg')
59 | draw_strokes_random_color(strokes, factor=8, per_stroke_mode = False, svg_filename = 'sample.multi_color.svg')
60 | draw_strokes_eos_weighted(strokes, params, factor=8, svg_filename = 'sample.eos.svg')
61 | draw_strokes_pdf(strokes, params, factor=8, svg_filename = 'sample.pdf.svg')
62 |
63 | ```
64 |
65 | 
66 | 
67 | 
68 | 
69 | 
70 |
71 | Have fun!
72 |
73 | ## License
74 |
75 | MIT
76 |
77 |
78 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:1.5.0-py3
2 |
3 | RUN pip install svgwrite
4 |
5 |
--------------------------------------------------------------------------------
/docker/build_image.sh:
--------------------------------------------------------------------------------
1 | docker build --tag write-rnn-tensorflow:1.5.0-py3 .
2 |
3 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 |
6 |
7 | class Model():
8 | def __init__(self, args, infer=False):
9 | self.args = args
10 | if infer:
11 | args.batch_size = 1
12 | args.seq_length = 1
13 |
14 | if args.model == 'rnn':
15 | cell_fn = tf.contrib.rnn.BasicRNNCell
16 | elif args.model == 'gru':
17 | cell_fn = tf.contrib.rnn.GRUCell
18 | elif args.model == 'lstm':
19 | cell_fn = tf.contrib.rnn.BasicLSTMCell
20 | else:
21 | raise Exception("model type not supported: {}".format(args.model))
22 |
23 | def get_cell():
24 | return cell_fn(args.rnn_size, state_is_tuple=False)
25 |
26 | cell = tf.contrib.rnn.MultiRNNCell(
27 | [get_cell() for _ in range(args.num_layers)])
28 |
29 | if not infer and args.keep_prob < 1: # training mode
30 | cell = tf.contrib.rnn.DropoutWrapper(
31 | cell, output_keep_prob=args.keep_prob)
32 |
33 | self.cell = cell
34 |
35 | self.input_data = tf.placeholder(
36 | dtype=tf.float32, shape=[
37 | None, args.seq_length, 3], name='data_in')
38 | self.target_data = tf.placeholder(
39 | dtype=tf.float32, shape=[
40 | None, args.seq_length, 3], name='targets')
41 | zero_state = cell.zero_state(
42 | batch_size=args.batch_size, dtype=tf.float32)
43 | self.state_in = tf.identity(zero_state, name='state_in')
44 |
45 | self.num_mixture = args.num_mixture
46 | # end_of_stroke + prob + 2*(mu + sig) + corr
47 | NOUT = 1 + self.num_mixture * 6
48 |
49 | with tf.variable_scope('rnnlm'):
50 | output_w = tf.get_variable("output_w", [args.rnn_size, NOUT])
51 | output_b = tf.get_variable("output_b", [NOUT])
52 |
53 | # inputs = tf.split(axis=1, num_or_size_splits=args.seq_length, value=self.input_data)
54 | # inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
55 | inputs = tf.unstack(self.input_data, axis=1)
56 |
57 | # outputs, state_out = tf.contrib.legacy_seq2seq.rnn_decoder(inputs, self.state_in, cell, loop_function=None, scope='rnnlm')
58 | outputs, state_out = tf.contrib.legacy_seq2seq.rnn_decoder(
59 | inputs, zero_state, cell, loop_function=None, scope='rnnlm')
60 |
61 | output = tf.reshape(
62 | tf.concat(axis=1, values=outputs), [-1, args.rnn_size])
63 | output = tf.nn.xw_plus_b(output, output_w, output_b)
64 | self.state_out = tf.identity(state_out, name='state_out')
65 |
66 | # reshape target data so that it is compatible with prediction shape
67 | flat_target_data = tf.reshape(self.target_data, [-1, 3])
68 | [x1_data, x2_data, eos_data] = tf.split(
69 | axis=1, num_or_size_splits=3, value=flat_target_data)
70 |
71 | # long method:
72 | #flat_target_data = tf.split(1, args.seq_length, self.target_data)
73 | #flat_target_data = [tf.squeeze(flat_target_data_, [1]) for flat_target_data_ in flat_target_data]
74 | #flat_target_data = tf.reshape(tf.concat(1, flat_target_data), [-1, 3])
75 |
76 | def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
77 | # eq # 24 and 25 of http://arxiv.org/abs/1308.0850
78 | norm1 = tf.subtract(x1, mu1)
79 | norm2 = tf.subtract(x2, mu2)
80 | s1s2 = tf.multiply(s1, s2)
81 | z = tf.square(tf.div(norm1, s1)) + tf.square(tf.div(norm2, s2)) - \
82 | 2 * tf.div(tf.multiply(rho, tf.multiply(norm1, norm2)), s1s2)
83 | negRho = 1 - tf.square(rho)
84 | result = tf.exp(tf.div(-z, 2 * negRho))
85 | denom = 2 * np.pi * tf.multiply(s1s2, tf.sqrt(negRho))
86 | result = tf.div(result, denom)
87 | return result
88 |
89 | def get_lossfunc(
90 | z_pi,
91 | z_mu1,
92 | z_mu2,
93 | z_sigma1,
94 | z_sigma2,
95 | z_corr,
96 | z_eos,
97 | x1_data,
98 | x2_data,
99 | eos_data):
100 | result0 = tf_2d_normal(
101 | x1_data,
102 | x2_data,
103 | z_mu1,
104 | z_mu2,
105 | z_sigma1,
106 | z_sigma2,
107 | z_corr)
108 | # implementing eq # 26 of http://arxiv.org/abs/1308.0850
109 | epsilon = 1e-20
110 | result1 = tf.multiply(result0, z_pi)
111 | result1 = tf.reduce_sum(result1, 1, keep_dims=True)
112 | # at the beginning, some errors are exactly zero.
113 | result1 = -tf.log(tf.maximum(result1, epsilon))
114 |
115 | result2 = tf.multiply(z_eos, eos_data) + \
116 | tf.multiply(1 - z_eos, 1 - eos_data)
117 | result2 = -tf.log(result2)
118 |
119 | result = result1 + result2
120 | return tf.reduce_sum(result)
121 |
122 | # below is where we need to do MDN splitting of distribution params
123 | def get_mixture_coef(output):
124 | # returns the tf slices containing mdn dist params
125 | # ie, eq 18 -> 23 of http://arxiv.org/abs/1308.0850
126 | z = output
127 | z_eos = z[:, 0:1]
128 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(
129 | axis=1, num_or_size_splits=6, value=z[:, 1:])
130 |
131 | # process output z's into MDN parameters
132 |
133 | # end of stroke signal
134 | z_eos = tf.sigmoid(z_eos) # should be negated, but doesn't matter.
135 |
136 | # softmax all the pi's:
137 | max_pi = tf.reduce_max(z_pi, 1, keep_dims=True)
138 | z_pi = tf.subtract(z_pi, max_pi)
139 | z_pi = tf.exp(z_pi)
140 | normalize_pi = tf.reciprocal(
141 | tf.reduce_sum(z_pi, 1, keep_dims=True))
142 | z_pi = tf.multiply(normalize_pi, z_pi)
143 |
144 | # exponentiate the sigmas and also make corr between -1 and 1.
145 | z_sigma1 = tf.exp(z_sigma1)
146 | z_sigma2 = tf.exp(z_sigma2)
147 | z_corr = tf.tanh(z_corr)
148 |
149 | return [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_eos]
150 |
151 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2,
152 | o_corr, o_eos] = get_mixture_coef(output)
153 |
154 | # I could put all of these in a single tensor for reading out, but this
155 | # is more human readable
156 | data_out_pi = tf.identity(o_pi, "data_out_pi")
157 | data_out_mu1 = tf.identity(o_mu1, "data_out_mu1")
158 | data_out_mu2 = tf.identity(o_mu2, "data_out_mu2")
159 | data_out_sigma1 = tf.identity(o_sigma1, "data_out_sigma1")
160 | data_out_sigma2 = tf.identity(o_sigma2, "data_out_sigma2")
161 | data_out_corr = tf.identity(o_corr, "data_out_corr")
162 | data_out_eos = tf.identity(o_eos, "data_out_eos")
163 |
164 | # sticking them all (except eos) in one op anyway, makes it easier for freezing the graph later
165 | # IMPORTANT, this needs to stack the named ops above (data_out_XXX), not the prev ops (o_XXX)
166 | # otherwise when I freeze the graph up to this point, the named versions will be cut
167 | # eos is diff size to others, so excluding that
168 | data_out_mdn = tf.identity([data_out_pi,
169 | data_out_mu1,
170 | data_out_mu2,
171 | data_out_sigma1,
172 | data_out_sigma2,
173 | data_out_corr],
174 | name="data_out_mdn")
175 |
176 | self.pi = o_pi
177 | self.mu1 = o_mu1
178 | self.mu2 = o_mu2
179 | self.sigma1 = o_sigma1
180 | self.sigma2 = o_sigma2
181 | self.corr = o_corr
182 | self.eos = o_eos
183 |
184 | lossfunc = get_lossfunc(
185 | o_pi,
186 | o_mu1,
187 | o_mu2,
188 | o_sigma1,
189 | o_sigma2,
190 | o_corr,
191 | o_eos,
192 | x1_data,
193 | x2_data,
194 | eos_data)
195 | self.cost = lossfunc / (args.batch_size * args.seq_length)
196 |
197 | self.train_loss_summary = tf.summary.scalar('train_loss', self.cost)
198 | self.valid_loss_summary = tf.summary.scalar(
199 | 'validation_loss', self.cost)
200 |
201 | self.lr = tf.Variable(0.0, trainable=False)
202 | tvars = tf.trainable_variables()
203 | grads, _ = tf.clip_by_global_norm(
204 | tf.gradients(self.cost, tvars), args.grad_clip)
205 | optimizer = tf.train.AdamOptimizer(self.lr)
206 | self.train_op = optimizer.apply_gradients(zip(grads, tvars))
207 |
208 | def sample(self, sess, num=1200):
209 |
210 | def get_pi_idx(x, pdf):
211 | N = pdf.size
212 | accumulate = 0
213 | for i in range(0, N):
214 | accumulate += pdf[i]
215 | if (accumulate >= x):
216 | return i
217 | print('error with sampling ensemble')
218 | return -1
219 |
220 | def sample_gaussian_2d(mu1, mu2, s1, s2, rho):
221 | mean = [mu1, mu2]
222 | cov = [[s1 * s1, rho * s1 * s2], [rho * s1 * s2, s2 * s2]]
223 | x = np.random.multivariate_normal(mean, cov, 1)
224 | return x[0][0], x[0][1]
225 |
226 | prev_x = np.zeros((1, 1, 3), dtype=np.float32)
227 | prev_x[0, 0, 2] = 1 # initially, we want to see beginning of new stroke
228 | prev_state = sess.run(self.cell.zero_state(1, tf.float32))
229 |
230 | strokes = np.zeros((num, 3), dtype=np.float32)
231 | mixture_params = []
232 |
233 | for i in range(num):
234 |
235 | feed = {self.input_data: prev_x, self.state_in: prev_state}
236 |
237 | [o_pi,
238 | o_mu1,
239 | o_mu2,
240 | o_sigma1,
241 | o_sigma2,
242 | o_corr,
243 | o_eos,
244 | next_state] = sess.run([self.pi,
245 | self.mu1,
246 | self.mu2,
247 | self.sigma1,
248 | self.sigma2,
249 | self.corr,
250 | self.eos,
251 | self.state_out],
252 | feed)
253 |
254 | idx = get_pi_idx(random.random(), o_pi[0])
255 |
256 | eos = 1 if random.random() < o_eos[0][0] else 0
257 |
258 | next_x1, next_x2 = sample_gaussian_2d(
259 | o_mu1[0][idx], o_mu2[0][idx], o_sigma1[0][idx], o_sigma2[0][idx], o_corr[0][idx])
260 |
261 | strokes[i, :] = [next_x1, next_x2, eos]
262 |
263 | params = [
264 | o_pi[0],
265 | o_mu1[0],
266 | o_mu2[0],
267 | o_sigma1[0],
268 | o_sigma2[0],
269 | o_corr[0],
270 | o_eos[0]]
271 | mixture_params.append(params)
272 |
273 | prev_x = np.zeros((1, 1, 3), dtype=np.float32)
274 | prev_x[0][0] = np.array([next_x1, next_x2, eos], dtype=np.float32)
275 | prev_state = next_state
276 |
277 | strokes[:, 0:2] *= self.args.data_scale
278 | return strokes, mixture_params
279 |
--------------------------------------------------------------------------------
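For reference, the density computed by `tf_2d_normal` above (equations 24 and 25 of the Graves paper) can be written out in plain NumPy. This is only a sanity-check sketch, not part of the repository:

```python
import numpy as np

def np_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
    # bivariate Gaussian density, mirroring tf_2d_normal in model.py
    norm1 = x1 - mu1
    norm2 = x2 - mu2
    z = (norm1 / s1) ** 2 + (norm2 / s2) ** 2 \
        - 2.0 * rho * norm1 * norm2 / (s1 * s2)
    neg_rho = 1.0 - rho ** 2
    return np.exp(-z / (2.0 * neg_rho)) / (2.0 * np.pi * s1 * s2 * np.sqrt(neg_rho))

# with rho = 0 this factors into two independent 1-D Gaussians
print(np_2d_normal(0.1, -0.2, 0.0, 0.0, 1.0, 1.0, 0.0))
```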
/run_docker.sh:
--------------------------------------------------------------------------------
1 | docker run \
2 | -it \
3 | -v "$(pwd)":/workspace \
4 | -w /workspace \
5 | write-rnn-tensorflow:1.5.0-py3 \
6 | python train.py
7 |
8 |
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import tensorflow as tf
4 |
5 | from model import Model
6 | from utils import *
7 |
8 | # main code (not in a main function since I want to run this script in
9 | # IPython as well).
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--filename', type=str, default='sample',
13 | help='filename of .svg file to output, without .svg')
14 | parser.add_argument('--sample_length', type=int, default=800,
15 | help='number of strokes to sample')
16 | parser.add_argument(
17 | '--scale_factor',
18 | type=int,
19 | default=10,
20 | help='factor to scale down by for svg output. smaller means bigger output')
21 | parser.add_argument('--model_dir', type=str, default='save',
22 | help='directory to save model to')
23 | parser.add_argument(
24 | '--freeze_graph',
25 | dest='freeze_graph',
26 | action='store_true',
27 | help='if true, freeze (replace variables with consts), prune (for inference) and save graph')
28 |
29 | sample_args = parser.parse_args()
30 |
31 | with open(os.path.join(sample_args.model_dir, 'config.pkl'), 'rb') as f:
32 | saved_args = pickle.load(f)
33 |
34 | model = Model(saved_args, True)
35 | sess = tf.InteractiveSession()
36 | #saver = tf.train.Saver(tf.all_variables())
37 | saver = tf.train.Saver()
38 |
39 | ckpt = tf.train.get_checkpoint_state(sample_args.model_dir)
40 | print("loading model: ", ckpt.model_checkpoint_path)
41 |
42 | saver.restore(sess, ckpt.model_checkpoint_path)
43 |
44 |
45 | def sample_stroke():
46 | [strokes, params] = model.sample(sess, sample_args.sample_length)
47 | draw_strokes(
48 | strokes,
49 | factor=sample_args.scale_factor,
50 | svg_filename=sample_args.filename +
51 | '.normal.svg')
52 | draw_strokes_random_color(
53 | strokes,
54 | factor=sample_args.scale_factor,
55 | svg_filename=sample_args.filename +
56 | '.color.svg')
57 | draw_strokes_random_color(
58 | strokes,
59 | factor=sample_args.scale_factor,
60 | per_stroke_mode=False,
61 | svg_filename=sample_args.filename +
62 | '.multi_color.svg')
63 | draw_strokes_eos_weighted(
64 | strokes,
65 | params,
66 | factor=sample_args.scale_factor,
67 | svg_filename=sample_args.filename +
68 | '.eos_pdf.svg')
69 | draw_strokes_pdf(
70 | strokes,
71 | params,
72 | factor=sample_args.scale_factor,
73 | svg_filename=sample_args.filename +
74 | '.pdf.svg')
75 | return [strokes, params]
76 |
77 |
78 | def freeze_and_save_graph(sess, folder, out_nodes, as_text=False):
79 | # save graph definition
80 | graph_raw = sess.graph_def
81 | graph_frz = tf.graph_util.convert_variables_to_constants(
82 | sess, graph_raw, out_nodes)
83 | ext = '.txt' if as_text else '.pb'
84 | #tf.train.write_graph(graph_raw, folder, 'graph_raw'+ext, as_text=as_text)
85 | tf.train.write_graph(graph_frz, folder, 'graph_frz' + ext, as_text=as_text)
86 |
87 |
88 | if(sample_args.freeze_graph):
89 | freeze_and_save_graph(
90 | sess, sample_args.model_dir, [
91 | 'data_out_mdn', 'data_out_eos', 'state_out'], False)
92 |
93 | [strokes, params] = sample_stroke()
94 |
--------------------------------------------------------------------------------
/sample_frozen.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Feb 23 20:25:16 2017
4 |
5 | @author: memo
6 |
7 | demonstrates inference with frozen graph def
8 | same as sample.py, but:
9 | - instead of loading model + checkpoint, loads frozen graph
10 | - instead of calling model.sample() function, uses own sample() function with named ops
11 | """
12 |
13 | import argparse
14 |
15 | import tensorflow as tf
16 |
17 | from utils import *
18 |
19 | # main code (not in a main function since I want to run this script in
20 | # IPython as well).
21 |
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--filename', type=str, default='sample',
24 | help='filename of .svg file to output, without .svg')
25 | parser.add_argument('--sample_length', type=int, default=800,
26 | help='number of strokes to sample')
27 | parser.add_argument(
28 | '--scale_factor',
29 | type=int,
30 | default=10,
31 | help='factor to scale down by for svg output. smaller means bigger output')
32 | parser.add_argument('--model_dir', type=str, default='save',
33 | help='directory to save model to')
34 | sample_args = parser.parse_args()
35 |
36 | sess = tf.InteractiveSession()
37 |
38 | # load frozen graph
39 | from tensorflow.python.platform import gfile
40 | with gfile.FastGFile(os.path.join(sample_args.model_dir, 'graph_frz.pb'), 'rb') as f:
41 | graph_def = tf.GraphDef()
42 | graph_def.ParseFromString(f.read())
43 | sess.graph.as_default()
44 | tf.import_graph_def(graph_def, name='')
45 |
46 |
47 | def sample_stroke():
48 | # don't call model.sample(), instead call sample() function defined below
49 | [strokes, params] = sample(sess, sample_args.sample_length)
50 | draw_strokes(
51 | strokes,
52 | factor=sample_args.scale_factor,
53 | svg_filename=sample_args.filename +
54 | '.normal.svg')
55 | draw_strokes_random_color(
56 | strokes,
57 | factor=sample_args.scale_factor,
58 | svg_filename=sample_args.filename +
59 | '.color.svg')
60 | draw_strokes_random_color(
61 | strokes,
62 | factor=sample_args.scale_factor,
63 | per_stroke_mode=False,
64 | svg_filename=sample_args.filename +
65 | '.multi_color.svg')
66 | draw_strokes_eos_weighted(
67 | strokes,
68 | params,
69 | factor=sample_args.scale_factor,
70 | svg_filename=sample_args.filename +
71 | '.eos_pdf.svg')
72 | draw_strokes_pdf(
73 | strokes,
74 | params,
75 | factor=sample_args.scale_factor,
76 | svg_filename=sample_args.filename +
77 | '.pdf.svg')
78 | return [strokes, params]
79 |
80 |
81 | # copied straight from model.sample, but replaced all references to 'self'
82 | # with named ops
83 | def sample(sess, num=1200):
84 | data_in = 'data_in:0'
85 | data_out_pi = 'data_out_pi:0'
86 | data_out_mu1 = 'data_out_mu1:0'
87 | data_out_mu2 = 'data_out_mu2:0'
88 | data_out_sigma1 = 'data_out_sigma1:0'
89 | data_out_sigma2 = 'data_out_sigma2:0'
90 | data_out_corr = 'data_out_corr:0'
91 | data_out_eos = 'data_out_eos:0'
92 | state_in = 'state_in:0'
93 | state_out = 'state_out:0'
94 |
95 | def get_pi_idx(x, pdf):
96 | N = pdf.size
97 | accumulate = 0
98 | for i in range(0, N):
99 | accumulate += pdf[i]
100 | if (accumulate >= x):
101 | return i
102 | print('error with sampling ensemble')
103 | return -1
104 |
105 | def sample_gaussian_2d(mu1, mu2, s1, s2, rho):
106 | mean = [mu1, mu2]
107 | cov = [[s1 * s1, rho * s1 * s2], [rho * s1 * s2, s2 * s2]]
108 | x = np.random.multivariate_normal(mean, cov, 1)
109 | return x[0][0], x[0][1]
110 |
111 | prev_x = np.zeros((1, 1, 3), dtype=np.float32)
112 | prev_x[0, 0, 2] = 1 # initially, we want to see beginning of new stroke
113 | prev_state = sess.run(state_in)
114 |
115 | strokes = np.zeros((num, 3), dtype=np.float32)
116 | mixture_params = []
117 |
118 | for i in range(num):
119 |
120 | feed = {data_in: prev_x, state_in: prev_state}
121 |
122 | [o_pi,
123 | o_mu1,
124 | o_mu2,
125 | o_sigma1,
126 | o_sigma2,
127 | o_corr,
128 | o_eos,
129 | next_state] = sess.run([data_out_pi,
130 | data_out_mu1,
131 | data_out_mu2,
132 | data_out_sigma1,
133 | data_out_sigma2,
134 | data_out_corr,
135 | data_out_eos,
136 | state_out],
137 | feed)
138 |
139 | idx = get_pi_idx(random.random(), o_pi[0])
140 |
141 | eos = 1 if random.random() < o_eos[0][0] else 0
142 |
143 | next_x1, next_x2 = sample_gaussian_2d(
144 | o_mu1[0][idx], o_mu2[0][idx], o_sigma1[0][idx], o_sigma2[0][idx], o_corr[0][idx])
145 |
146 | strokes[i, :] = [next_x1, next_x2, eos]
147 |
148 | params = [
149 | o_pi[0],
150 | o_mu1[0],
151 | o_mu2[0],
152 | o_sigma1[0],
153 | o_sigma2[0],
154 | o_corr[0],
155 | o_eos[0]]
156 | mixture_params.append(params)
157 |
158 | prev_x = np.zeros((1, 1, 3), dtype=np.float32)
159 | prev_x[0][0] = np.array([next_x1, next_x2, eos], dtype=np.float32)
160 | prev_state = next_state
161 |
162 | # self.args.data_scale # TODO: fix mega hack hardcoding the scale
163 | strokes[:, 0:2] *= 20
164 | return strokes, mixture_params
165 |
166 |
167 | # check output
168 | [strokes, params] = sample_stroke()
169 |
--------------------------------------------------------------------------------
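If the named ops in `model.py` ever change, the frozen graph can be inspected directly to see which input/output names are available before running this script. A small sketch (assuming `save/graph_frz.pb` has already been written by `python sample.py --freeze_graph`):

```python
import tensorflow as tf

graph_def = tf.GraphDef()
with open('save/graph_frz.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# list the node names that sample_frozen.py feeds and fetches
for node in graph_def.node:
    if node.name.startswith(('data_', 'state_')):
        print(node.name)
```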
/save/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "model.ckpt-11000"
2 | all_model_checkpoint_paths: "model.ckpt-9000"
3 | all_model_checkpoint_paths: "model.ckpt-9500"
4 | all_model_checkpoint_paths: "model.ckpt-10000"
5 | all_model_checkpoint_paths: "model.ckpt-10500"
6 | all_model_checkpoint_paths: "model.ckpt-11000"
7 |
--------------------------------------------------------------------------------
/save/config.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/config.pkl
--------------------------------------------------------------------------------
/save/model.ckpt-10000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-10000
--------------------------------------------------------------------------------
/save/model.ckpt-10000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-10000.meta
--------------------------------------------------------------------------------
/save/model.ckpt-10500:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-10500
--------------------------------------------------------------------------------
/save/model.ckpt-10500.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-10500.meta
--------------------------------------------------------------------------------
/save/model.ckpt-11000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-11000
--------------------------------------------------------------------------------
/save/model.ckpt-11000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-11000.meta
--------------------------------------------------------------------------------
/save/model.ckpt-9000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-9000
--------------------------------------------------------------------------------
/save/model.ckpt-9000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-9000.meta
--------------------------------------------------------------------------------
/save/model.ckpt-9500:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-9500
--------------------------------------------------------------------------------
/save/model.ckpt-9500.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/write-rnn-tensorflow/c9423a1adf1eaa65297e16f7a1a555ffb777d8c5/save/model.ckpt-9500.meta
--------------------------------------------------------------------------------
/svg/example1.multi_color.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/svg/example1.normal.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | import time
5 |
6 | import tensorflow as tf
7 |
8 | from model import Model
9 | from utils import DataLoader
10 |
11 |
12 | def main():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--rnn_size', type=int, default=256,
15 | help='size of RNN hidden state')
16 | parser.add_argument('--num_layers', type=int, default=2,
17 | help='number of layers in the RNN')
18 | parser.add_argument('--model', type=str, default='lstm',
19 | help='rnn, gru, or lstm')
20 | parser.add_argument('--batch_size', type=int, default=50,
21 | help='minibatch size')
22 | parser.add_argument('--seq_length', type=int, default=300,
23 | help='RNN sequence length')
24 | parser.add_argument('--num_epochs', type=int, default=30,
25 | help='number of epochs')
26 | parser.add_argument('--save_every', type=int, default=500,
27 | help='save frequency')
28 | parser.add_argument('--model_dir', type=str, default='save',
29 | help='directory to save model to')
30 | parser.add_argument('--grad_clip', type=float, default=10.,
31 | help='clip gradients at this value')
32 | parser.add_argument('--learning_rate', type=float, default=0.005,
33 | help='learning rate')
34 | parser.add_argument('--decay_rate', type=float, default=0.95,
35 | help='per-epoch learning rate decay')
36 | parser.add_argument('--num_mixture', type=int, default=20,
37 | help='number of gaussian mixtures')
38 | parser.add_argument('--data_scale', type=float, default=20,
39 | help='factor to scale raw data down by')
40 | parser.add_argument('--keep_prob', type=float, default=0.8,
41 | help='dropout keep probability')
42 | args = parser.parse_args()
43 | train(args)
44 |
45 |
46 | def train(args):
47 | data_loader = DataLoader(args.batch_size, args.seq_length, args.data_scale)
48 |
49 | if args.model_dir != '' and not os.path.exists(args.model_dir):
50 | os.makedirs(args.model_dir)
51 |
52 | with open(os.path.join(args.model_dir, 'config.pkl'), 'wb') as f:
53 | pickle.dump(args, f)
54 |
55 | model = Model(args)
56 |
57 | with tf.Session() as sess:
58 | summary_writer = tf.summary.FileWriter(
59 | os.path.join(args.model_dir, 'log'), sess.graph)
60 |
61 | tf.global_variables_initializer().run()
62 | saver = tf.train.Saver(tf.global_variables())
63 | for e in range(args.num_epochs):
64 | sess.run(tf.assign(model.lr,
65 | args.learning_rate * (args.decay_rate ** e)))
66 | data_loader.reset_batch_pointer()
67 | v_x, v_y = data_loader.validation_data()
68 | valid_feed = {
69 | model.input_data: v_x,
70 | model.target_data: v_y,
71 | model.state_in: model.state_in.eval()}
72 | state = model.state_in.eval()
73 | for b in range(data_loader.num_batches):
74 | ith_train_step = e * data_loader.num_batches + b
75 | start = time.time()
76 | x, y = data_loader.next_batch()
77 | feed = {
78 | model.input_data: x,
79 | model.target_data: y,
80 | model.state_in: state}
81 | train_loss_summary, train_loss, state, _ = sess.run(
82 | [model.train_loss_summary, model.cost, model.state_out, model.train_op], feed)
83 | summary_writer.add_summary(train_loss_summary, ith_train_step)
84 |
85 | valid_loss_summary, valid_loss = sess.run(
86 | [model.valid_loss_summary, model.cost], valid_feed)
87 | summary_writer.add_summary(valid_loss_summary, ith_train_step)
88 |
89 | end = time.time()
90 | print(
91 | "{}/{} (epoch {}), train_loss = {:.3f}, valid_loss = {:.3f}, time/batch = {:.3f}"
92 | .format(
93 | ith_train_step,
94 | args.num_epochs * data_loader.num_batches,
95 | e,
96 | train_loss,
97 | valid_loss,
98 | end - start))
99 | if (ith_train_step %
100 | args.save_every == 0) and (ith_train_step > 0):
101 | checkpoint_path = os.path.join(
102 | args.model_dir, 'model.ckpt')
103 | saver.save(
104 | sess,
105 | checkpoint_path,
106 | global_step=ith_train_step)
107 | print("model saved to {}".format(checkpoint_path))
108 |
109 |
110 | if __name__ == '__main__':
111 | main()
112 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import random
4 | import xml.etree.ElementTree as ET
5 |
6 | import numpy as np
7 | import svgwrite
8 | from IPython.display import SVG, display
9 |
10 |
11 | def get_bounds(data, factor):
12 | min_x = 0
13 | max_x = 0
14 | min_y = 0
15 | max_y = 0
16 |
17 | abs_x = 0
18 | abs_y = 0
19 | for i in range(len(data)):
20 | x = float(data[i, 0]) / factor
21 | y = float(data[i, 1]) / factor
22 | abs_x += x
23 | abs_y += y
24 | min_x = min(min_x, abs_x)
25 | min_y = min(min_y, abs_y)
26 | max_x = max(max_x, abs_x)
27 | max_y = max(max_y, abs_y)
28 |
29 | return (min_x, max_x, min_y, max_y)
30 |
31 | # old version, where each path is entire stroke (smaller svg size, but
32 | # have to keep same color)
33 |
34 |
35 | def draw_strokes(data, factor=10, svg_filename='sample.svg'):
36 | min_x, max_x, min_y, max_y = get_bounds(data, factor)
37 | dims = (50 + max_x - min_x, 50 + max_y - min_y)
38 |
39 | dwg = svgwrite.Drawing(svg_filename, size=dims)
40 | dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))
41 |
42 | lift_pen = 1
43 |
44 | abs_x = 25 - min_x
45 | abs_y = 25 - min_y
46 | p = "M%s,%s " % (abs_x, abs_y)
47 |
48 | command = "m"
49 |
50 | for i in range(len(data)):
51 | if (lift_pen == 1):
52 | command = "m"
53 | elif (command != "l"):
54 | command = "l"
55 | else:
56 | command = ""
57 | x = float(data[i, 0]) / factor
58 | y = float(data[i, 1]) / factor
59 | lift_pen = data[i, 2]
60 | p += command + str(x) + "," + str(y) + " "
61 |
62 | the_color = "black"
63 | stroke_width = 1
64 |
65 | dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill("none"))
66 |
67 | dwg.save()
68 | display(SVG(dwg.tostring()))
69 |
70 |
71 | def draw_strokes_eos_weighted(
72 | stroke,
73 | param,
74 | factor=10,
75 | svg_filename='sample_eos.svg'):
76 | c_data_eos = np.zeros((len(stroke), 3))
77 | for i in range(len(param)):
78 | # make color gray scale, darker = more likely to eos
79 | c_data_eos[i, :] = (1 - param[i][6][0]) * 225
80 | draw_strokes_custom_color(
81 | stroke,
82 | factor=factor,
83 | svg_filename=svg_filename,
84 | color_data=c_data_eos,
85 | stroke_width=3)
86 |
87 |
88 | def draw_strokes_random_color(
89 | stroke,
90 | factor=10,
91 | svg_filename='sample_random_color.svg',
92 | per_stroke_mode=True):
93 | c_data = np.array(np.random.rand(len(stroke), 3) * 240, dtype=np.uint8)
94 | if per_stroke_mode:
95 | switch_color = False
96 | for i in range(len(stroke)):
97 | if switch_color == False and i > 0:
98 | c_data[i] = c_data[i - 1]
99 | if stroke[i, 2] < 1: # same stroke
100 | switch_color = False
101 | else:
102 | switch_color = True
103 | draw_strokes_custom_color(
104 | stroke,
105 | factor=factor,
106 | svg_filename=svg_filename,
107 | color_data=c_data,
108 | stroke_width=2)
109 |
110 |
111 | def draw_strokes_custom_color(
112 | data,
113 | factor=10,
114 | svg_filename='test.svg',
115 | color_data=None,
116 | stroke_width=1):
117 | min_x, max_x, min_y, max_y = get_bounds(data, factor)
118 | dims = (50 + max_x - min_x, 50 + max_y - min_y)
119 |
120 | dwg = svgwrite.Drawing(svg_filename, size=dims)
121 | dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))
122 |
123 | lift_pen = 1
124 | abs_x = 25 - min_x
125 | abs_y = 25 - min_y
126 |
127 | for i in range(len(data)):
128 |
129 | x = float(data[i, 0]) / factor
130 | y = float(data[i, 1]) / factor
131 |
132 | prev_x = abs_x
133 | prev_y = abs_y
134 |
135 | abs_x += x
136 | abs_y += y
137 |
138 | if (lift_pen == 1):
139 | p = "M " + str(abs_x) + "," + str(abs_y) + " "
140 | else:
141 | p = "M +" + str(prev_x) + "," + str(prev_y) + \
142 | " L " + str(abs_x) + "," + str(abs_y) + " "
143 |
144 | lift_pen = data[i, 2]
145 |
146 | the_color = "black"
147 |
148 | if (color_data is not None):
149 | the_color = "rgb(" + str(int(color_data[i, 0])) + "," + str(
150 | int(color_data[i, 1])) + "," + str(int(color_data[i, 2])) + ")"
151 |
152 | dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill(the_color))
153 | dwg.save()
154 | display(SVG(dwg.tostring()))
155 |
156 |
157 | def draw_strokes_pdf(data, param, factor=10, svg_filename='sample_pdf.svg'):
158 | min_x, max_x, min_y, max_y = get_bounds(data, factor)
159 | dims = (50 + max_x - min_x, 50 + max_y - min_y)
160 |
161 | dwg = svgwrite.Drawing(svg_filename, size=dims)
162 | dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))
163 |
164 | abs_x = 25 - min_x
165 | abs_y = 25 - min_y
166 |
167 | num_mixture = len(param[0][0])
168 |
169 | for i in range(len(data)):
170 |
171 | x = float(data[i, 0]) / factor
172 | y = float(data[i, 1]) / factor
173 |
174 | for k in range(num_mixture):
175 | pi = param[i][0][k]
176 | if pi > 0.01: # optimisation, ignore pi's less than 1% chance
177 | mu1 = param[i][1][k]
178 | mu2 = param[i][2][k]
179 | s1 = param[i][3][k]
180 | s2 = param[i][4][k]
181 | sigma = np.sqrt(s1 * s2)
182 | dwg.add(dwg.circle(center=(abs_x + mu1 * factor,
183 | abs_y + mu2 * factor),
184 | r=int(sigma * factor)).fill('red',
185 | opacity=pi / (sigma * sigma * factor)))
186 |
187 | prev_x = abs_x
188 | prev_y = abs_y
189 |
190 | abs_x += x
191 | abs_y += y
192 |
193 | dwg.save()
194 | display(SVG(dwg.tostring()))
195 |
196 |
197 | class DataLoader():
198 | def __init__(
199 | self,
200 | batch_size=50,
201 | seq_length=300,
202 | scale_factor=10,
203 | limit=500):
204 | self.data_dir = "./data"
205 | self.batch_size = batch_size
206 | self.seq_length = seq_length
207 | self.scale_factor = scale_factor # divide data by this factor
208 | self.limit = limit # removes large noisy gaps in the data
209 |
210 | data_file = os.path.join(self.data_dir, "strokes_training_data.cpkl")
211 | raw_data_dir = self.data_dir + "/lineStrokes"
212 |
213 | if not (os.path.exists(data_file)):
214 | print("creating training data pkl file from raw source")
215 | self.preprocess(raw_data_dir, data_file)
216 |
217 | self.load_preprocessed(data_file)
218 | self.reset_batch_pointer()
219 |
220 | def preprocess(self, data_dir, data_file):
221 | # create data file from raw xml files from iam handwriting source.
222 |
223 | # build the list of xml files
224 | filelist = []
225 | # Set the directory you want to start from
226 | rootDir = data_dir
227 | for dirName, subdirList, fileList in os.walk(rootDir):
228 | #print('Found directory: %s' % dirName)
229 | for fname in fileList:
230 | #print('\t%s' % fname)
231 | filelist.append(dirName + "/" + fname)
232 |
233 | # function to read each individual xml file
234 | def getStrokes(filename):
235 | tree = ET.parse(filename)
236 | root = tree.getroot()
237 |
238 | result = []
239 |
240 | x_offset = 1e20
241 | y_offset = 1e20
242 | y_height = 0
243 | for i in range(1, 4):
244 | x_offset = min(x_offset, float(root[0][i].attrib['x']))
245 | y_offset = min(y_offset, float(root[0][i].attrib['y']))
246 | y_height = max(y_height, float(root[0][i].attrib['y']))
247 | y_height -= y_offset
248 | x_offset -= 100
249 | y_offset -= 100
250 |
251 | for stroke in root[1].findall('Stroke'):
252 | points = []
253 | for point in stroke.findall('Point'):
254 | points.append(
255 | [float(point.attrib['x']) - x_offset, float(point.attrib['y']) - y_offset])
256 | result.append(points)
257 |
258 | return result
259 |
260 | # converts a list of arrays into a 2d numpy int16 array
261 | def convert_stroke_to_array(stroke):
262 |
263 | n_point = 0
264 | for i in range(len(stroke)):
265 | n_point += len(stroke[i])
266 | stroke_data = np.zeros((n_point, 3), dtype=np.int16)
267 |
268 | prev_x = 0
269 | prev_y = 0
270 | counter = 0
271 |
272 | for j in range(len(stroke)):
273 | for k in range(len(stroke[j])):
274 | stroke_data[counter, 0] = int(stroke[j][k][0]) - prev_x
275 | stroke_data[counter, 1] = int(stroke[j][k][1]) - prev_y
276 | prev_x = int(stroke[j][k][0])
277 | prev_y = int(stroke[j][k][1])
278 | stroke_data[counter, 2] = 0
279 | if (k == (len(stroke[j]) - 1)): # end of stroke
280 | stroke_data[counter, 2] = 1
281 | counter += 1
282 | return stroke_data
283 |
284 | # build stroke database of every xml file inside iam database
285 | strokes = []
286 | for i in range(len(filelist)):
287 | if (filelist[i][-3:] == 'xml'):
288 | print('processing ' + filelist[i])
289 | strokes.append(
290 | convert_stroke_to_array(
291 | getStrokes(
292 | filelist[i])))
293 |
294 | f = open(data_file, "wb")
295 | pickle.dump(strokes, f, protocol=2)
296 | f.close()
297 |
298 | def load_preprocessed(self, data_file):
299 | f = open(data_file, "rb")
300 | self.raw_data = pickle.load(f)
301 | f.close()
302 |
303 | # goes thru the list, and only keeps the text entries that have more
304 | # than seq_length points
305 | self.data = []
306 | self.valid_data = []
307 | counter = 0
308 |
309 | # every 1 in 20 (5%) will be used for validation data
310 | cur_data_counter = 0
311 | for data in self.raw_data:
312 | if len(data) > (self.seq_length + 2):
313 | # removes large gaps from the data
314 | data = np.minimum(data, self.limit)
315 | data = np.maximum(data, -self.limit)
316 | data = np.array(data, dtype=np.float32)
317 | data[:, 0:2] /= self.scale_factor
318 | cur_data_counter = cur_data_counter + 1
319 | if cur_data_counter % 20 == 0:
320 | self.valid_data.append(data)
321 | else:
322 | self.data.append(data)
323 | # number of equiv batches this datapoint is worth
324 | counter += int(len(data) / ((self.seq_length + 2)))
325 |
326 | print("train data: {}, valid data: {}".format(
327 | len(self.data), len(self.valid_data)))
328 | # minus 1, since we want the ydata to be a shifted version of x data
329 | self.num_batches = int(counter / self.batch_size)
330 |
331 | def validation_data(self):
332 | # returns validation data
333 | x_batch = []
334 | y_batch = []
335 | for i in range(self.batch_size):
336 | data = self.valid_data[i % len(self.valid_data)]
337 | idx = 0
338 | x_batch.append(np.copy(data[idx:idx + self.seq_length]))
339 | y_batch.append(np.copy(data[idx + 1:idx + self.seq_length + 1]))
340 | return x_batch, y_batch
341 |
342 | def next_batch(self):
343 | # returns a randomised, seq_length sized portion of the training data
344 | x_batch = []
345 | y_batch = []
346 | for i in range(self.batch_size):
347 | data = self.data[self.pointer]
348 | # number of equiv batches this datapoint is worth
349 | n_batch = int(len(data) / ((self.seq_length + 2)))
350 | idx = random.randint(0, len(data) - self.seq_length - 2)
351 | x_batch.append(np.copy(data[idx:idx + self.seq_length]))
352 | y_batch.append(np.copy(data[idx + 1:idx + self.seq_length + 1]))
353 | # adjust sampling probability.
354 | if random.random() < (1.0 / float(n_batch)):
355 | # if this is a long datapoint, sample this data more with
356 | # higher probability
357 | self.tick_batch_pointer()
358 | return x_batch, y_batch
359 |
360 | def tick_batch_pointer(self):
361 | self.pointer += 1
362 | if (self.pointer >= len(self.data)):
363 | self.pointer = 0
364 |
365 | def reset_batch_pointer(self):
366 | self.pointer = 0
367 |
--------------------------------------------------------------------------------
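The drawing helpers above only need an `(N, 3)` array of offsets with an end-of-stroke column, so they can be tried without a trained model. A small sketch (assumes `svgwrite` and IPython are installed; the random strokes below are purely illustrative):

```python
import numpy as np
from utils import draw_strokes, draw_strokes_random_color

# fake stroke data: random offsets with an occasional pen lift
rng = np.random.RandomState(0)
toy = np.zeros((200, 3), dtype=np.float32)
toy[:, 0:2] = rng.randn(200, 2) * 5
toy[:, 2] = (rng.rand(200) < 0.05).astype(np.float32)

draw_strokes(toy, factor=1, svg_filename='toy.normal.svg')
draw_strokes_random_color(toy, factor=1, svg_filename='toy.color.svg')
```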