├── README.md
├── data
└── kanji.cpkl
├── example
├── aaron_sheep_sample.svg
├── cat_vae.png
├── catbus.svg
├── catbus2.svg
├── data_format.svg
├── doraemon.svg
├── elephant.svg
├── elephantpig.svg
├── frog_crab_cat.png
├── full_predictions.svg
├── morph_catchair.svg
├── morph_catchairs.svg
├── multiple_interpolations.png
├── omniglot_sample.svg
├── output.svg
├── owlmorph.svg
├── pig_morph.png
├── short_kanji_sample.svg
├── sketch_rnn.png
├── sketch_rnn.svg
├── sketch_rnn_examples.svg
├── sketch_rnn_schematic.svg
├── training.svg
├── vae_analogy.svg
├── vae_cat.svg
├── vae_cats.svg
├── vae_morph.svg
├── vae_morphs.svg
├── vae_pig.svg
└── vae_pigs.svg
├── get_kanji.sh
├── model.py
├── requirements.linux-x64-cpu.txt
├── requirements.linux-x64-gpu.txt
├── requirements.mac.txt
├── sample.py
├── save
└── kanji
│ ├── checkpoint
│ ├── config.pkl
│ └── model.ckpt-0
├── svg
├── __init__.py
└── path
│ ├── __init__.py
│ ├── parser.py
│ ├── path.py
│ └── tests
│ ├── __init__.py
│ ├── test_doc.py
│ ├── test_generation.py
│ ├── test_parsing.py
│ └── test_paths.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Deprecated
2 |
3 | This version of sketch-rnn has been deprecated. Please see the updated version of [sketch-rnn](https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/README.md), which is a full generative model for vector drawings.
4 |
5 | # sketch-rnn
6 |
7 | Implementation of a multi-layer recurrent neural network (RNN, LSTM, GRU) used to model and generate sketches stored in .svg vector graphic files. The method combines a Mixture Density Network with an RNN and also models dynamic end-of-stroke and end-of-content probabilities. These are learned from a large corpus of similar .svg files so that the model can generate drawings similar to the vector training data.
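To make the "Mixture Density Network" part more concrete, here is a minimal plain-Python sketch of the per-step shape loss. It computes the density of a correlated bivariate Gaussian (eqs. 24 and 25 of the Graves paper) and the negative log-likelihood of one pen offset under a weighted mixture of such Gaussians (eq. 26). The function names are hypothetical. The repo computes the same quantities with TensorFlow ops in `tf_2d_normal` and `get_lossfunc` inside `model.py`.

```python
import math

# Hypothetical plain-Python helpers; model.py computes the same quantities with
# TensorFlow ops in tf_2d_normal and get_lossfunc.

def bivariate_normal_pdf(x1, x2, mu1, mu2, s1, s2, rho):
    """Density of a correlated 2-D Gaussian (Graves 2013, eqs. 24-25)."""
    n1 = (x1 - mu1) / s1
    n2 = (x2 - mu2) / s2
    one_minus_rho2 = 1.0 - rho * rho
    z = n1 * n1 + n2 * n2 - 2.0 * rho * n1 * n2
    return math.exp(-z / (2.0 * one_minus_rho2)) / (
        2.0 * math.pi * s1 * s2 * math.sqrt(one_minus_rho2))


def mixture_nll(x1, x2, pis, mu1s, mu2s, s1s, s2s, rhos, eps=1e-20):
    """Negative log-likelihood of one (x1, x2) offset under a Gaussian mixture (eq. 26)."""
    total = sum(
        pi * bivariate_normal_pdf(x1, x2, m1, m2, a, b, r)
        for pi, m1, m2, a, b, r in zip(pis, mu1s, mu2s, s1s, s2s, rhos))
    # Clamp before the log, as model.py does, so a zero density cannot give log(0).
    return -math.log(max(total, eps))
```

In `model.py`, the mixture weights come from a softmax over the network output, the standard deviations come from `exp`, and the correlations come from `tanh` (see `get_mixture_coef`). This keeps each component a valid distribution.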
8 |
9 | See my blog post at [blog.otoro.net](http://blog.otoro.net/2015/12/28/recurrent-net-dreams-up-fake-chinese-characters-in-vector-format-with-tensorflow/) for a detailed description of applying `sketch-rnn` to learn to generate fake Chinese characters in vector format.
10 |
11 | Example Training Sketches (20 randomly chosen from the 11,000 characters in the [KanjiVG](http://kanjivg.tagaini.net/) dataset):
12 |
13 | 
14 |
15 | Generated Sketches (Temperature = 0.1):
16 |
17 | 
18 |
19 | # Basic Usage
20 |
21 | I tested the implementation on TensorFlow 0.5.0 (the `requirements.*.txt` files pin TensorFlow 0.7.1). I also used the following libraries:
22 |
23 | ```
24 | svgwrite
25 | IPython.display.SVG
26 | IPython.display.display
27 | xml.etree.ElementTree
28 | argparse
29 | cPickle
30 | svg.path
31 | ```
32 |
33 | ## Loading in Training Data
34 |
35 | The training data is located inside the `data` subdirectory. In this repo, I've included `kanji.cpkl`, which is a preprocessed array of KanjiVG characters.
36 |
37 | To add a new set of training data, for example from the [TU Berlin Sketch Database](http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/), create a subdirectory, say `tuberlin`, inside the `data` directory. Also create a directory with the same name inside the `save` directory. You end up with `data/tuberlin/` and `save/tuberlin/`, where `tuberlin` is the value you pass to the `--dataset_name` flag of the training and sampling programs. `save/tuberlin/` will hold the checkpointed trained models.
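You can also create the two directories from Python. The sketch below assumes it is run from the repository root and uses a hypothetical helper name, `make_dataset_dirs`. It only uses calls that work on both Python 2, which this code base targets, and Python 3.

```python
import os

def make_dataset_dirs(dataset_name, root="."):
    """Hypothetical helper: create data/<dataset_name>/ and save/<dataset_name>/ under root."""
    data_dir = os.path.join(root, "data", dataset_name)
    save_dir = os.path.join(root, "save", dataset_name)
    for path in (data_dir, save_dir):
        if not os.path.isdir(path):
            os.makedirs(path)  # also creates data/ or save/ if they are missing
    return data_dir, save_dir

# Usage, from the repository root:
# make_dataset_dirs("tuberlin")
```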
38 |
39 | Now put a large collection of .svg files into `data/tuberlin/`. You can even create subdirectories within `data/tuberlin/`, because the `SketchLoader` class scans the entire subdirectory tree.
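The `SketchLoader` source in `utils.py` is not reproduced here, so the following only illustrates what scanning the whole subdirectory tree means. It uses a hypothetical `find_svg_files` helper built on `os.walk`.

```python
import os

def find_svg_files(data_dir):
    """Hypothetical helper: sorted paths of every .svg file under data_dir, at any depth."""
    matches = []
    for dirpath, _dirnames, filenames in os.walk(data_dir):
        for name in filenames:
            if name.lower().endswith(".svg"):  # accept .svg and .SVG
                matches.append(os.path.join(dirpath, name))
    return sorted(matches)

# Usage:
# files = find_svg_files(os.path.join("data", "tuberlin"))
```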
40 |
41 | Currently, `sketch-rnn` only processes `path` elements inside .svg files, and within the `path` elements it only handles lines and Bézier curves. I found this sufficient for the TU Berlin and KanjiVG databases. It wouldn't be difficult to extend it to the other curve elements, and even to shape elements, in the future.
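As an illustration of the first step, collecting the `path` elements, the sketch below uses the standard library's `xml.etree.ElementTree` to pull the `d` attribute out of every `<path>`. It works whether or not the file declares the SVG namespace. This is not the repo's own loader, and the function name `path_data` is hypothetical.

```python
import xml.etree.ElementTree as ET

SVG_NS = "{http://www.w3.org/2000/svg}"

def path_data(svg_text):
    """Hypothetical helper: the `d` string of every <path>, in document order."""
    root = ET.fromstring(svg_text)
    result = []
    for elem in root.iter():  # walks the whole tree, including nested <g> groups
        if elem.tag in ("path", SVG_NS + "path"):
            d = elem.get("d")
            if d:
                result.append(d)
    return result
```

Each returned string can then be passed to `parse_path` from the bundled `svg.path` package, which turns it into `Line`, `CubicBezier` and other segment objects.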
42 |
43 | After the .svg files have been copied in, you can use `utils.py` to draw some random training samples:
44 |
45 | ```
46 | %run -i utils.py
47 | loader = SketchLoader(data_filename = 'tuberlin')
48 | draw_stroke_color(random.choice(loader.raw_data))
49 | ```
50 |
51 | 
52 |
53 | For this algorithm to work, I recommend that the data be similar in size and similar in style and content. For example, if the dataset mixes bananas, buildings, elephants, rockets and insects of varying shapes and sizes, the model would most likely just produce gibberish.
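One rough way to check whether a dataset is "similar in size" is to compute the bounding box of each sketch. The helper below is hypothetical. It assumes that each row starts with relative pen offsets `(dx, dy)`, which is what the 5-column `[x1, x2, end_of_stroke, end_of_content, continue]` rows in `model.py` and `sample.py` appear to hold. The repo's own `calculate_start_point` in `utils.py` is not shown here and may work differently.

```python
def sketch_size(strokes):
    """Hypothetical helper: (width, height) of one sketch.

    Assumption: columns 0 and 1 of each row are offsets from the previous pen
    position. The starting position (the origin) is included in the box.
    """
    x = y = 0.0
    min_x = max_x = min_y = max_y = 0.0
    for row in strokes:
        x += row[0]
        y += row[1]
        min_x, max_x = min(min_x, x), max(max_x, x)
        min_y, max_y = min(min_y, y), max(max_y, y)
    return max_x - min_x, max_y - min_y
```

Sketches whose size differs from the rest by an order of magnitude are candidates for rescaling or removal before training.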
54 |
55 | ## Training the Model
56 |
57 | Once the data is in place, you can train the model. Continuing with the `tuberlin` example, run `python train.py --dataset_name tuberlin`.
58 |
59 | A number of flags can be set for training if you wish to experiment with the parameters. You will probably want to adjust them, especially the `--data_scale` scaling factor, to better suit the sizes of your .svg data.
60 |
61 | The default values (in parentheses below) are defined in `train.py`:
62 |
63 | ```
64 | --rnn_size RNN_SIZE size of RNN hidden state (256)
65 | --num_layers NUM_LAYERS number of layers in the RNN (2)
66 | --model MODEL rnn, gru, or lstm (lstm)
67 | --batch_size BATCH_SIZE minibatch size (100)
68 | --seq_length SEQ_LENGTH RNN sequence length (300)
69 | --num_epochs NUM_EPOCHS number of epochs (500)
70 | --save_every SAVE_EVERY save frequency (250)
71 | --grad_clip GRAD_CLIP clip gradients at this value (5.0)
72 | --learning_rate LEARNING_RATE learning rate (0.005)
73 | --decay_rate DECAY_RATE decay rate after each epoch (adam is used) (0.99)
74 | --num_mixture NUM_MIXTURE number of gaussian mixtures (24)
75 | --data_scale DATA_SCALE factor to scale raw data down by (15.0)
76 | --keep_prob KEEP_PROB dropout keep probability (0.8)
77 | --stroke_importance_factor F gradient boosting of sketch-finish event (200.0)
78 | --dataset_name DATASET_NAME name of directory containing training data (kanji)
79 | ```
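`--stroke_importance_factor` does not change the optimizer. In `model.py`, it scales the cross-entropy of the pen state at each step (`pen_data_weighting` in `get_lossfunc`). The plain-Python restatement below uses a hypothetical function name to show the resulting weights. A continue-stroke step counts once, an end-of-stroke step is weighted by the square root of the factor, and an end-of-content step is weighted by the full factor.

```python
import math

def pen_loss_weight(pen_onehot, stroke_importance_factor=200.0):
    """Hypothetical helper mirroring pen_data_weighting in model.py for one time step.

    pen_onehot is (end_of_stroke, end_of_content, continue_stroke), the column
    order of pen_data in model.py.
    """
    eos, eoc, cont = pen_onehot
    return (cont
            + math.sqrt(stroke_importance_factor) * eos
            + stroke_importance_factor * eoc)
```

With the default of 200, end-of-stroke steps get a weight of about 14.1 and end-of-content steps get a weight of 200. This presumably compensates for how rarely these events occur compared with ordinary pen movements.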
80 |
81 | ## Sampling a Sketch
82 |
83 | I've included a pretrained model in `save/kanji/`, so sampling should work out of the box. Running `python sample.py --filename output --num_picture 10 --dataset_name kanji` uses the pretrained model to generate `output.svg`, a file containing 10 fake Kanji characters. Run `python sample.py --help` to see the other flags, for example `--num_col` to change the number of sketches per row.
84 |
85 | It should be straightforward to adapt `sample.py` to generate sketches interactively from an IPython prompt rather than the command line. Running `%run -i sample.py` in an IPython session displays the generated sketches inline and also writes the .svg output.
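The `--temperature` flag is passed to `Model.sample` as both `temp_mixture` and `temp_pen`. The sketch below restates two sampling steps in plain Python, using hypothetical names. The first step rescales a discrete distribution by temperature and draws an index by walking its cumulative sum, as `get_pi_idx` does. The second step draws the pen offset from the chosen correlated Gaussian. `model.py` uses `np.random.multivariate_normal` for that second step. Here it is swapped for an equivalent Cholesky-based draw so that the example needs only the standard library.

```python
import math
import random

def adjust_temperature(probs, temperature):
    """Hypothetical helper: sharpen (T < 1) or flatten (T > 1) a discrete distribution."""
    logits = [math.log(max(p, 1e-20)) / temperature for p in probs]
    top = max(logits)  # subtract the max for numerical stability, as model.py does
    weights = [math.exp(l - top) for l in logits]
    total = sum(weights)
    return [w / total for w in weights]

def sample_index(probs, rng=random):
    """Hypothetical helper: draw an index by walking the cumulative distribution."""
    x = rng.random()
    accumulate = 0.0
    for i, p in enumerate(probs):
        accumulate += p
        if accumulate >= x:
            return i
    return len(probs) - 1  # guard against round-off (model.py prints an error and returns -1)

def sample_offset(mu1, mu2, s1, s2, rho, rng=random):
    """Hypothetical helper: draw (dx, dy) from a correlated 2-D Gaussian via its Cholesky factor."""
    z1, z2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    dx = mu1 + s1 * z1
    dy = mu2 + s2 * (rho * z1 + math.sqrt(1.0 - rho * rho) * z2)
    return dx, dy
```

A low temperature such as the default of 0.1 concentrates almost all of the probability on the most likely mixture component, while a temperature of 1.0 leaves the distribution unchanged.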
86 |
87 | ## More useful links, pointers, datasets
88 |
89 | - Alex Graves' [paper](http://arxiv.org/abs/1308.0850) on text sequence and handwriting generation.
90 |
91 | - Karpathy's [char-rnn](https://github.com/karpathy/char-rnn) tool, the motivation for creating sketch-rnn.
92 |
93 | - [KanjiVG](http://kanjivg.tagaini.net/). A fantastic database of kanji stroke order.
94 |
95 | - A very clean TensorFlow implementation of [char-rnn](https://github.com/sherjilozair/char-rnn-tensorflow), written by [Sherjil Ozair](https://github.com/sherjilozair), on which I based the skeleton of this code.
96 |
97 | - [svg.path](https://pypi.python.org/pypi/svg.path). I used this well written tool to help convert path data into line data.
98 |
99 | - CASIA Online and Offline Chinese [Handwriting Databases](http://www.nlpr.ia.ac.cn/databases/handwriting/Download.html). Download stroke data for cursive handwritten Simplified Chinese.
100 |
101 | - How Do Humans Sketch Objects? [TU Berlin Sketch Database](http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/). It would be interesting to extend this work to generate random vector art of real-life objects.
102 |
103 | - Doraemon in [SVG format](http://yylam.blogspot.hk/2012/04/doraemon-in-svg-format-doraemonsvg.html).
104 |
105 | - [Potrace](https://en.wikipedia.org/wiki/Potrace). A beautiful tool for converting raster (bitmap) drawings into SVG, which could be used to scale up the resolution of drawings. It could also potentially be used to generate large amounts of training data.
106 |
107 | - [Rendering Bézier Curve Code](http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic). I used this very useful code to convert Bézier curves into line segments.
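The Rosetta Code routine linked above approximates a cubic Bézier curve with straight line segments. Below is a minimal sketch of that idea. It assumes uniform steps in the curve parameter t and a fixed number of steps; both are illustrative choices and not necessarily what sketch-rnn does. The function names are hypothetical. Points are complex numbers `x + y*1j`, the same convention the bundled `svg.path` parser uses.

```python
def cubic_bezier_points(p0, p1, p2, p3, n=10):
    """Hypothetical helper: n+1 points on a cubic Bezier curve at evenly spaced t.

    p0..p3 are complex numbers (x + y*1j): start, two control points, end.
    """
    points = []
    for i in range(n + 1):
        t = i / float(n)  # float() keeps the division correct on Python 2
        u = 1.0 - t
        # Bernstein form: B(t) = u^3 p0 + 3 u^2 t p1 + 3 u t^2 p2 + t^3 p3
        points.append(u ** 3 * p0 + 3 * u ** 2 * t * p1 + 3 * u * t ** 2 * p2 + t ** 3 * p3)
    return points

def to_line_segments(points):
    """Hypothetical helper: pair consecutive points into (start, end) line segments."""
    return list(zip(points[:-1], points[1:]))
```

A fixed step count is the simplest choice. An adaptive scheme that keeps subdividing until each piece is nearly straight would use fewer points on flat curves and more on tight ones.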
108 |
109 |
110 | # License
111 |
112 | MIT
113 |
--------------------------------------------------------------------------------
/data/kanji.cpkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/data/kanji.cpkl
--------------------------------------------------------------------------------
/example/cat_vae.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/example/cat_vae.png
--------------------------------------------------------------------------------
/example/catbus.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/example/catbus2.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/example/doraemon.svg:
--------------------------------------------------------------------------------
1 |
2 |
71 |
--------------------------------------------------------------------------------
/example/elephant.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
35 |
--------------------------------------------------------------------------------
/example/elephantpig.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/example/frog_crab_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/example/frog_crab_cat.png
--------------------------------------------------------------------------------
/example/multiple_interpolations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/example/multiple_interpolations.png
--------------------------------------------------------------------------------
/example/omniglot_sample.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/example/pig_morph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/example/pig_morph.png
--------------------------------------------------------------------------------
/example/sketch_rnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/example/sketch_rnn.png
--------------------------------------------------------------------------------
/get_kanji.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #fetch and unpack the data
4 | cd data
5 | wget https://github.com/KanjiVG/kanjivg/releases/download/r20150615-2/kanjivg-20150615-2-all.zip
6 | unzip kanjivg-20150615-2-all.zip
7 | # move aside one problem file
8 | mkdir -p rejects
9 | mv kanji/05747-Kaisho.svg rejects
10 | # done, now try: python train.py --dataset_name kanji
11 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.models.rnn import rnn_cell
3 |
4 | import numpy as np
5 | import random
6 |
7 | class Model():
8 | def __init__(self, args, infer=False):
9 | if infer:
10 | args.batch_size = 1
11 | args.seq_length = 1
12 | self.args = args
13 |
14 | if args.model == 'rnn':
15 | cell_fn = rnn_cell.BasicRNNCell
16 | elif args.model == 'gru':
17 | cell_fn = rnn_cell.GRUCell
18 | elif args.model == 'lstm':
19 | cell_fn = rnn_cell.BasicLSTMCell
20 | else:
21 | raise Exception("model type not supported: {}".format(args.model))
22 |
23 | cell = cell_fn(args.rnn_size)
24 |
25 | cell = rnn_cell.MultiRNNCell([cell] * args.num_layers)
26 |
27 | if (infer == False and args.keep_prob < 1): # training mode
28 | cell = rnn_cell.DropoutWrapper(cell, output_keep_prob = args.keep_prob)
29 |
30 | self.cell = cell
31 |
32 | self.input_data = tf.placeholder(dtype=tf.float32, shape=[args.batch_size, args.seq_length, 5])
33 | self.target_data = tf.placeholder(dtype=tf.float32, shape=[args.batch_size, args.seq_length, 5])
34 | self.initial_state = cell.zero_state(batch_size=args.batch_size, dtype=tf.float32)
35 |
36 | self.num_mixture = args.num_mixture
37 | NOUT = 3 + self.num_mixture * 6 # [end_of_stroke + end_of_char, continue_with_stroke] + prob + 2*(mu + sig) + corr
38 |
39 | with tf.variable_scope('rnn_mdn'):
40 | output_w = tf.get_variable("output_w", [args.rnn_size, NOUT])
41 | output_b = tf.get_variable("output_b", [NOUT])
42 |
43 | inputs = tf.split(1, args.seq_length, self.input_data)
44 | inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
45 |
46 | self.initial_input = np.zeros((args.batch_size, 5), dtype=np.float32)
47 | self.initial_input[:,4] = 1.0 # initially, the pen is down.
48 | self.initial_input = tf.constant(self.initial_input)
49 |
50 | def tfrepeat(a, repeats):
51 | num_row = a.get_shape()[0].value
52 | num_col = a.get_shape()[1].value
53 | assert(num_col == 1)
54 | result = [a for i in range(repeats)]
55 | result = tf.concat(0, result)
56 | result = tf.reshape(result, [repeats, num_row])
57 | result = tf.transpose(result)
58 | return result
59 |
60 | def custom_rnn_autodecoder(decoder_inputs, initial_input, initial_state, cell, scope=None):
61 | # customized rnn_decoder for the task of dealing with end of character
62 | with tf.variable_scope(scope or "rnn_decoder"):
63 | states = [initial_state]
64 | outputs = []
65 | prev = None
66 |
67 | for i in xrange(len(decoder_inputs)):
68 | inp = decoder_inputs[i]
69 | if i > 0:
70 | tf.get_variable_scope().reuse_variables()
71 | output, new_state = cell(inp, states[-1])
72 |
73 | num_batches = self.args.batch_size # new_state.get_shape()[0].value
74 | num_state = new_state.get_shape()[1].value
75 |
76 | # if the input has an end-of-character signal, have to zero out the state
77 |
78 | #to do: test this code.
79 |
80 | eoc_detection = inp[:,3]
81 | eoc_detection = tf.reshape(eoc_detection, [num_batches, 1])
82 |
83 | eoc_detection_state = tfrepeat(eoc_detection, num_state)
84 |
85 | eoc_detection_state = tf.greater(eoc_detection_state, tf.zeros_like(eoc_detection_state, dtype=tf.float32))
86 |
87 | new_state = tf.select(eoc_detection_state, initial_state, new_state)
88 |
89 | outputs.append(output)
90 | states.append(new_state)
91 | return outputs, states
92 |
93 | outputs, states = custom_rnn_autodecoder(inputs, self.initial_input, self.initial_state, cell, scope='rnn_mdn')
94 | output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
95 | output = tf.nn.xw_plus_b(output, output_w, output_b)
96 | self.final_state = states[-1]
97 |
98 | # reshape target data so that it is compatible with prediction shape
99 | flat_target_data = tf.reshape(self.target_data,[-1, 5])
100 | [x1_data, x2_data, eos_data, eoc_data, cont_data] = tf.split(1, 5, flat_target_data)
101 | pen_data = tf.concat(1, [eos_data, eoc_data, cont_data])
102 |
103 | # long method:
104 | #flat_target_data = tf.split(1, args.seq_length, self.target_data)
105 | #flat_target_data = [tf.squeeze(flat_target_data_, [1]) for flat_target_data_ in flat_target_data]
106 | #flat_target_data = tf.reshape(tf.concat(1, flat_target_data), [-1, 3])
107 |
108 | def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
109 | # eq # 24 and 25 of http://arxiv.org/abs/1308.0850
110 | norm1 = tf.sub(x1, mu1)
111 | norm2 = tf.sub(x2, mu2)
112 | s1s2 = tf.mul(s1, s2)
113 | z = tf.square(tf.div(norm1, s1))+tf.square(tf.div(norm2, s2))-2*tf.div(tf.mul(rho, tf.mul(norm1, norm2)), s1s2)
114 | negRho = 1-tf.square(rho)
115 | result = tf.exp(tf.div(-z,2*negRho))
116 | denom = 2*np.pi*tf.mul(s1s2, tf.sqrt(negRho))
117 | result = tf.div(result, denom)
118 | return result
119 |
120 | def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen, x1_data, x2_data, pen_data):
121 | result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr)
122 | # implementing eq # 26 of http://arxiv.org/abs/1308.0850
123 | epsilon = 1e-20
124 | result1 = tf.mul(result0, z_pi)
125 | result1 = tf.reduce_sum(result1, 1, keep_dims=True)
126 | result1 = -tf.log(tf.maximum(result1, 1e-20)) # at the beginning, some errors are exactly zero.
127 | result_shape = tf.reduce_mean(result1)
128 |
129 | result2 = tf.nn.softmax_cross_entropy_with_logits(z_pen, pen_data)
130 | pen_data_weighting = pen_data[:, 2]+np.sqrt(self.args.stroke_importance_factor)*pen_data[:, 0]+self.args.stroke_importance_factor*pen_data[:, 1]
131 | result2 = tf.mul(result2, pen_data_weighting)
132 | result_pen = tf.reduce_mean(result2)
133 |
134 | result = result_shape + result_pen
135 | return result, result_shape, result_pen,
136 |
137 | # below is where we need to do MDN splitting of distribution params
138 | def get_mixture_coef(output):
139 | # returns the tf slices containing mdn dist params
140 | # ie, eq 18 -> 23 of http://arxiv.org/abs/1308.0850
141 | z = output
142 | z_pen = z[:, 0:3] # end of stroke, end of character/content, continue w/ stroke
143 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(1, 6, z[:, 3:])
144 |
145 | # process output z's into MDN paramters
146 |
147 | # softmax all the pi's:
148 | max_pi = tf.reduce_max(z_pi, 1, keep_dims=True)
149 | z_pi = tf.sub(z_pi, max_pi)
150 | z_pi = tf.exp(z_pi)
151 | normalize_pi = tf.inv(tf.reduce_sum(z_pi, 1, keep_dims=True))
152 | z_pi = tf.mul(normalize_pi, z_pi)
153 |
154 | # exponentiate the sigmas and also make corr between -1 and 1.
155 | z_sigma1 = tf.exp(z_sigma1)
156 | z_sigma2 = tf.exp(z_sigma2)
157 | z_corr = tf.tanh(z_corr)
158 |
159 | return [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen]
160 |
161 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen] = get_mixture_coef(output)
162 |
163 | self.pi = o_pi
164 | self.mu1 = o_mu1
165 | self.mu2 = o_mu2
166 | self.sigma1 = o_sigma1
167 | self.sigma2 = o_sigma2
168 | self.corr = o_corr
169 | self.pen = o_pen # state of the pen
170 |
171 | [lossfunc, loss_shape, loss_pen] = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, x1_data, x2_data, pen_data)
172 | self.cost = lossfunc
173 | self.cost_shape = loss_shape
174 | self.cost_pen = loss_pen
175 |
176 | self.lr = tf.Variable(0.01, trainable=False)
177 | tvars = tf.trainable_variables()
178 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), args.grad_clip)
179 | optimizer = tf.train.AdamOptimizer(self.lr, epsilon=0.001)
180 | self.train_op = optimizer.apply_gradients(zip(grads, tvars))
181 |
182 |
183 | def sample(self, sess, num=300, temp_mixture=1.0, temp_pen=1.0, stop_if_eoc = False):
184 |
185 | def get_pi_idx(x, pdf):
186 | N = pdf.size
187 | accumulate = 0
188 | for i in range(0, N):
189 | accumulate += pdf[i]
190 | if (accumulate >= x):
191 | return i
192 | print 'error with sampling ensemble'
193 | return -1
194 |
195 | def sample_gaussian_2d(mu1, mu2, s1, s2, rho):
196 | mean = [mu1, mu2]
197 | cov = [[s1*s1, rho*s1*s2], [rho*s1*s2, s2*s2]]
198 | x = np.random.multivariate_normal(mean, cov, 1)
199 | return x[0][0], x[0][1]
200 |
201 | prev_x = np.zeros((1, 1, 5), dtype=np.float32)
202 | #prev_x[0, 0, 2] = 1 # initially, we want to see beginning of new stroke
203 | #prev_x[0, 0, 3] = 1 # initially, we want to see beginning of new character/content
204 | prev_state = sess.run(self.cell.zero_state(self.args.batch_size, tf.float32))
205 |
206 | strokes = np.zeros((num, 5), dtype=np.float32)
207 | mixture_params = []
208 |
209 | for i in xrange(num):
210 |
211 | feed = {self.input_data: prev_x, self.initial_state:prev_state}
212 |
213 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, next_state] = sess.run([self.pi, self.mu1, self.mu2, self.sigma1, self.sigma2, self.corr, self.pen, self.final_state],feed)
214 |
215 | pi_pdf = o_pi[0]
216 | if i > 1:
217 | pi_pdf = np.log(pi_pdf) / temp_mixture
218 | pi_pdf -= pi_pdf.max()
219 | pi_pdf = np.exp(pi_pdf)
220 | pi_pdf /= pi_pdf.sum()
221 |
222 | idx = get_pi_idx(random.random(), pi_pdf)
223 |
224 | pen_pdf = o_pen[0]
225 | if i > 1:
226 |         pen_pdf /= temp_pen # apply the pen temperature to the logits before the softmax below
227 | pen_pdf -= pen_pdf.max()
228 | pen_pdf = np.exp(pen_pdf)
229 | pen_pdf /= pen_pdf.sum()
230 |
231 | pen_idx = get_pi_idx(random.random(), pen_pdf)
232 | eos = 0
233 | eoc = 0
234 | cont_state = 0
235 |
236 | if pen_idx == 0:
237 | eos = 1
238 | elif pen_idx == 1:
239 | eoc = 1
240 | else:
241 | cont_state = 1
242 |
243 | next_x1, next_x2 = sample_gaussian_2d(o_mu1[0][idx], o_mu2[0][idx], o_sigma1[0][idx], o_sigma2[0][idx], o_corr[0][idx])
244 |
245 | strokes[i,:] = [next_x1, next_x2, eos, eoc, cont_state]
246 |
247 | params = [pi_pdf, o_mu1[0], o_mu2[0], o_sigma1[0], o_sigma2[0], o_corr[0], pen_pdf]
248 | mixture_params.append(params)
249 |
250 | # early stopping condition
251 | if (stop_if_eoc and eoc == 1):
252 | strokes = strokes[0:i+1, :]
253 | break
254 |
255 | prev_x = np.zeros((1, 1, 5), dtype=np.float32)
256 | prev_x[0][0] = np.array([next_x1, next_x2, eos, eoc, cont_state], dtype=np.float32)
257 | prev_state = next_state
258 |
259 | strokes[:,0:2] *= self.args.data_scale
260 | return strokes, mixture_params
261 |
262 |
263 |
--------------------------------------------------------------------------------
/requirements.linux-x64-cpu.txt:
--------------------------------------------------------------------------------
1 | decorator==4.0.9
2 | ipython==4.1.2
3 | ipython-genutils==0.1.0
4 | numpy==1.10.4
5 | path.py==8.1.2
6 | pexpect==4.0.1
7 | pickleshare==0.6
8 | protobuf==3.0.0b2
9 | ptyprocess==0.5.1
10 | pyparsing==2.1.0
11 | simplegeneric==0.8.1
12 | six==1.10.0
13 | svgwrite==1.1.6
14 | https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.7.1-cp27-none-linux_x86_64.whl
15 | traitlets==4.2.1
16 | wheel==0.29.0
17 | wsgiref==0.1.2
18 |
--------------------------------------------------------------------------------
/requirements.linux-x64-gpu.txt:
--------------------------------------------------------------------------------
1 | decorator==4.0.9
2 | ipython==4.1.2
3 | ipython-genutils==0.1.0
4 | numpy==1.10.4
5 | path.py==8.1.2
6 | pexpect==4.0.1
7 | pickleshare==0.6
8 | protobuf==3.0.0b2
9 | ptyprocess==0.5.1
10 | pyparsing==2.1.0
11 | simplegeneric==0.8.1
12 | six==1.10.0
13 | svgwrite==1.1.6
14 | https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.7.1-cp27-none-linux_x86_64.whl
15 | traitlets==4.2.1
16 | wheel==0.29.0
17 | wsgiref==0.1.2
18 |
--------------------------------------------------------------------------------
/requirements.mac.txt:
--------------------------------------------------------------------------------
1 | decorator==4.0.9
2 | ipython==4.1.2
3 | ipython-genutils==0.1.0
4 | numpy==1.10.4
5 | path.py==8.1.2
6 | pexpect==4.0.1
7 | pickleshare==0.6
8 | protobuf==3.0.0b2
9 | ptyprocess==0.5.1
10 | pyparsing==2.1.0
11 | simplegeneric==0.8.1
12 | six==1.10.0
13 | svgwrite==1.1.6
14 | https://storage.googleapis.com/tensorflow/mac/tensorflow-0.7.1-cp27-none-any.whl
15 | traitlets==4.2.1
16 | wheel==0.29.0
17 | wsgiref==0.1.2
18 |
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | import time
5 | import os
6 | import cPickle
7 | import argparse
8 |
9 | from utils import *
10 | from model import Model
11 | import random
12 |
13 | import svgwrite
14 | from IPython.display import SVG, display
15 |
16 | # main code (not in a main function since I want to run this script in IPython as well).
17 | def in_ipython():
18 | try:
19 | __IPYTHON__
20 | except NameError:
21 | return False
22 | else:
23 | return True
24 |
25 |
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--filename', type=str, default='output',
28 | help='filename of .svg file to output, without .svg')
29 | parser.add_argument('--sample_length', type=int, default=600,
30 | help='number of strokes to sample')
31 | parser.add_argument('--picture_size', type=float, default=160,
32 | help='a centered svg will be generated of this size')
33 | parser.add_argument('--scale_factor', type=float, default=1,
34 | help='factor to scale down by for svg output. smaller means bigger output')
35 | parser.add_argument('--num_picture', type=int, default=20,
36 | help='number of pictures to generate')
37 | parser.add_argument('--num_col', type=int, default=5,
38 | help='if num_picture > 1, how many pictures per row?')
39 | parser.add_argument('--dataset_name', type=str, default="kanji",
40 | help='name of directory containing training data')
41 | parser.add_argument('--color_mode', type=int, default=1,
42 | help='set to 0 if you are a black and white sort of person...')
43 | parser.add_argument('--stroke_width', type=float, default=2.0,
44 | help='thickness of pen lines')
45 | parser.add_argument('--temperature', type=float, default=0.1,
46 | help='sampling temperature')
47 | sample_args = parser.parse_args()
48 |
49 | color_mode = True
50 | if sample_args.color_mode == 0:
51 | color_mode = False
52 |
53 |
54 | with open(os.path.join('save', sample_args.dataset_name, 'config.pkl')) as f: # future
55 | saved_args = cPickle.load(f)
56 |
57 | model = Model(saved_args, True)
58 | sess = tf.InteractiveSession()
59 | saver = tf.train.Saver(tf.all_variables())
60 |
61 | ckpt = tf.train.get_checkpoint_state(os.path.join('save', sample_args.dataset_name))
62 | print "loading model: ",ckpt.model_checkpoint_path
63 |
64 | saver.restore(sess, ckpt.model_checkpoint_path)
65 |
66 | def draw_sketch_array(strokes_array, svg_only = False):
67 | draw_stroke_color_array(strokes_array, factor=sample_args.scale_factor, maxcol = sample_args.num_col, svg_filename = sample_args.filename+'.svg', stroke_width = sample_args.stroke_width, block_size = sample_args.picture_size, svg_only = svg_only, color_mode = color_mode)
68 |
69 | def sample_sketches(min_size_ratio = 0.0, max_size_ratio = 0.8, min_num_stroke = 4, max_num_stroke=22, svg_only = True):
70 | N = sample_args.num_picture
71 | frame_size = float(sample_args.picture_size)
72 | max_size = frame_size * max_size_ratio
73 | min_size = frame_size * min_size_ratio
74 | count = 0
75 | sketch_list = []
76 | param_list = []
77 |
78 | temp_mixture = sample_args.temperature
79 | temp_pen = sample_args.temperature
80 |
81 | while count < N:
82 | #print "attempting to generate picture #", count
83 | print '.',
84 | [strokes, params] = model.sample(sess, sample_args.sample_length, temp_mixture, temp_pen, stop_if_eoc = True)
85 | [sx, sy, num_stroke, num_char, _] = strokes.sum(0)
86 | if num_stroke < min_num_stroke or num_char == 0 or num_stroke > max_num_stroke:
87 | #print "num_stroke ", num_stroke, " num_char ", num_char
88 | continue
89 | [sx, sy, sizex, sizey] = calculate_start_point(strokes)
90 | if sizex > max_size or sizey > max_size:
91 | #print "sizex ", sizex, " sizey ", sizey
92 | continue
93 | if sizex < min_size or sizey < min_size:
94 | #print "sizex ", sizex, " sizey ", sizey
95 | continue
96 | # success
97 | print count+1,"/",N
98 | count += 1
99 | sketch_list.append(strokes)
100 | param_list.append(params)
101 | # draw the pics
102 | draw_sketch_array(sketch_list, svg_only = svg_only)
103 | return sketch_list, param_list
104 |
105 | if __name__ == '__main__':
106 | ipython_mode = in_ipython()
107 | if ipython_mode:
108 | print "IPython detected"
109 | else:
110 | print "Console mode"
111 | [strokes, params] = sample_sketches(svg_only = not ipython_mode)
112 |
113 |
114 |
--------------------------------------------------------------------------------
/save/kanji/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "model.ckpt-0"
2 | all_model_checkpoint_paths: "model.ckpt-0"
3 |
--------------------------------------------------------------------------------
/save/kanji/config.pkl:
--------------------------------------------------------------------------------
1 | ccopy_reg
2 | _reconstructor
3 | p1
4 | (cargparse
5 | Namespace
6 | p2
7 | c__builtin__
8 | object
9 | p3
10 | NtRp4
11 | (dp5
12 | S'grad_clip'
13 | p6
14 | F5
15 | sS'rnn_size'
16 | p7
17 | I256
18 | sS'data_scale'
19 | p8
20 | F15
21 | sS'learning_rate'
22 | p9
23 | F0.0050000000000000001
24 | sS'num_layers'
25 | p10
26 | I2
27 | sS'seq_length'
28 | p11
29 | I300
30 | sS'decay_rate'
31 | p12
32 | F0.98999999999999999
33 | sS'num_mixture'
34 | p13
35 | I24
36 | sS'batch_size'
37 | p14
38 | I100
39 | sS'num_epochs'
40 | p15
41 | I500
42 | sS'dataset_name'
43 | p16
44 | S'kanji'
45 | p17
46 | sS'model'
47 | p18
48 | S'lstm'
49 | p19
50 | sS'save_every'
51 | p20
52 | I250
53 | sS'keep_prob'
54 | p21
55 | F0.80000000000000004
56 | sS'stroke_importance_factor'
57 | p22
58 | F200
59 | sb.
--------------------------------------------------------------------------------
/save/kanji/model.ckpt-0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/save/kanji/model.ckpt-0
--------------------------------------------------------------------------------
/svg/__init__.py:
--------------------------------------------------------------------------------
1 | __import__('pkg_resources').declare_namespace(__name__)
2 |
--------------------------------------------------------------------------------
/svg/path/__init__.py:
--------------------------------------------------------------------------------
1 | from .path import Path, Line, Arc, CubicBezier, QuadraticBezier
2 | from .parser import parse_path
3 |
--------------------------------------------------------------------------------
/svg/path/parser.py:
--------------------------------------------------------------------------------
1 | # SVG Path specification parser
2 |
3 | import re
4 | from . import path
5 |
6 | COMMANDS = set('MmZzLlHhVvCcSsQqTtAa')
7 | UPPERCASE = set('MZLHVCSQTA')
8 |
9 | COMMAND_RE = re.compile("([MmZzLlHhVvCcSsQqTtAa])")
10 | FLOAT_RE = re.compile("[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
11 |
12 |
13 | def _tokenize_path(pathdef):
14 | for x in COMMAND_RE.split(pathdef):
15 | if x in COMMANDS:
16 | yield x
17 | for token in FLOAT_RE.findall(x):
18 | yield token
19 |
20 |
21 | def parse_path(pathdef, current_pos=0j):
22 | # In the SVG specs, initial movetos are absolute, even if
23 | # specified as 'm'. This is the default behavior here as well.
24 | # But if you pass in a current_pos variable, the initial moveto
25 | # will be relative to that current_pos. This is useful.
26 | elements = list(_tokenize_path(pathdef))
27 | # Reverse for easy use of .pop()
28 | elements.reverse()
29 |
30 | segments = path.Path()
31 | start_pos = None
32 | command = None
33 |
34 | while elements:
35 |
36 | if elements[-1] in COMMANDS:
37 | # New command.
38 | last_command = command # Used by S and T
39 | command = elements.pop()
40 | absolute = command in UPPERCASE
41 | command = command.upper()
42 | else:
43 |             # Implicit command: repeat it, and remember it for S/T reflection.
44 |             last_command = command
45 | if command is None:
46 | raise ValueError("Unallowed implicit command in %s, position %s" % (
47 | pathdef, len(pathdef.split()) - len(elements)))
48 |
49 | if command == 'M':
50 | # Moveto command.
51 | x = elements.pop()
52 | y = elements.pop()
53 | pos = float(x) + float(y) * 1j
54 | if absolute:
55 | current_pos = pos
56 | else:
57 | current_pos += pos
58 |
59 |             # When M is called, reset start_pos.
60 |             # This behavior of Z is defined in the SVG spec:
61 | # http://www.w3.org/TR/SVG/paths.html#PathDataClosePathCommand
62 | start_pos = current_pos
63 |
64 | # Implicit moveto commands are treated as lineto commands.
65 | # So we set command to lineto here, in case there are
66 | # further implicit commands after this moveto.
67 | command = 'L'
68 |
69 | elif command == 'Z':
70 | # Close path
71 | segments.append(path.Line(current_pos, start_pos))
72 | segments.closed = True
73 | current_pos = start_pos
74 | start_pos = None
75 | command = None # You can't have implicit commands after closing.
76 |
77 | elif command == 'L':
78 | x = elements.pop()
79 | y = elements.pop()
80 | pos = float(x) + float(y) * 1j
81 | if not absolute:
82 | pos += current_pos
83 | segments.append(path.Line(current_pos, pos))
84 | current_pos = pos
85 |
86 | elif command == 'H':
87 | x = elements.pop()
88 | pos = float(x) + current_pos.imag * 1j
89 | if not absolute:
90 | pos += current_pos.real
91 | segments.append(path.Line(current_pos, pos))
92 | current_pos = pos
93 |
94 | elif command == 'V':
95 | y = elements.pop()
96 | pos = current_pos.real + float(y) * 1j
97 | if not absolute:
98 | pos += current_pos.imag * 1j
99 | segments.append(path.Line(current_pos, pos))
100 | current_pos = pos
101 |
102 | elif command == 'C':
103 | control1 = float(elements.pop()) + float(elements.pop()) * 1j
104 | control2 = float(elements.pop()) + float(elements.pop()) * 1j
105 | end = float(elements.pop()) + float(elements.pop()) * 1j
106 |
107 | if not absolute:
108 | control1 += current_pos
109 | control2 += current_pos
110 | end += current_pos
111 |
112 | segments.append(path.CubicBezier(current_pos, control1, control2, end))
113 | current_pos = end
114 |
115 | elif command == 'S':
116 | # Smooth curve. First control point is the "reflection" of
117 | # the second control point in the previous path.
118 |
119 |             if last_command not in ('C', 'S'):
120 |                 # If there is no previous command or if the previous command
121 |                 # was not a C, c, S or s, assume the first control point is
122 |                 # coincident with the current point.
123 | control1 = current_pos
124 | else:
125 | # The first control point is assumed to be the reflection of
126 | # the second control point on the previous command relative
127 | # to the current point.
128 | control1 = current_pos + current_pos - segments[-1].control2
129 |
130 | control2 = float(elements.pop()) + float(elements.pop()) * 1j
131 | end = float(elements.pop()) + float(elements.pop()) * 1j
132 |
133 | if not absolute:
134 | control2 += current_pos
135 | end += current_pos
136 |
137 | segments.append(path.CubicBezier(current_pos, control1, control2, end))
138 | current_pos = end
139 |
140 | elif command == 'Q':
141 | control = float(elements.pop()) + float(elements.pop()) * 1j
142 | end = float(elements.pop()) + float(elements.pop()) * 1j
143 |
144 | if not absolute:
145 | control += current_pos
146 | end += current_pos
147 |
148 | segments.append(path.QuadraticBezier(current_pos, control, end))
149 | current_pos = end
150 |
151 | elif command == 'T':
152 |             # Smooth curve. The control point is the "reflection" of
153 |             # the control point in the previous path.
154 |
155 |             if last_command not in ('Q', 'T'):
156 |                 # If there is no previous command or if the previous command
157 |                 # was not a Q, q, T or t, assume the control point is
158 |                 # coincident with the current point.
159 | control = current_pos
160 | else:
161 | # The control point is assumed to be the reflection of
162 | # the control point on the previous command relative
163 | # to the current point.
164 | control = current_pos + current_pos - segments[-1].control
165 |
166 | end = float(elements.pop()) + float(elements.pop()) * 1j
167 |
168 | if not absolute:
169 | end += current_pos
170 |
171 | segments.append(path.QuadraticBezier(current_pos, control, end))
172 | current_pos = end
173 |
174 | elif command == 'A':
175 | radius = float(elements.pop()) + float(elements.pop()) * 1j
176 | rotation = float(elements.pop())
177 | arc = float(elements.pop())
178 | sweep = float(elements.pop())
179 | end = float(elements.pop()) + float(elements.pop()) * 1j
180 |
181 | if not absolute:
182 | end += current_pos
183 |
184 | segments.append(path.Arc(current_pos, radius, rotation, arc, sweep, end))
185 | current_pos = end
186 |
187 | return segments
188 |
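The tokenizer and the S/T reflection rule in `parse_path` can be tried in isolation. The following is a minimal standalone sketch, not part of svg.path: it copies the two regular expressions from parser.py, and `tokenize` and `reflect` are hypothetical helper names introduced only for illustration.

```python
import re

# Same patterns as parser.py: command letters, and signed floats with optional exponent.
COMMAND_RE = re.compile(r"([MmZzLlHhVvCcSsQqTtAa])")
FLOAT_RE = re.compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
COMMANDS = set("MmZzLlHhVvCcSsQqTtAa")


def tokenize(pathdef):
    """Split a path string into command letters and number strings (hypothetical helper)."""
    tokens = []
    for chunk in COMMAND_RE.split(pathdef):
        if chunk in COMMANDS:
            tokens.append(chunk)
        # A sign can start a new number without a separating space or comma.
        tokens.extend(FLOAT_RE.findall(chunk))
    return tokens


def reflect(control, current):
    """Reflect a control point through the current point, as S and T do (hypothetical helper)."""
    return current + (current - control)


print(tokenize("M100,200c10-5,20-10,30-20"))
# -> ['M', '100', '200', 'c', '10', '-5', '20', '-10', '30', '-20']
print(reflect(400 + 50j, 600 + 300j))
# -> (800+550j)
```

The second call reproduces the control point `800 + 550j` that test_parsing.py expects for `T1000,300` following `Q400,50 600,300`; the same reflection gives `250 + 300j` for the `S400,300 400,200` example after `C100,100 250,100 250,200`.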
--------------------------------------------------------------------------------
/svg/path/path.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from math import sqrt, cos, sin, acos, degrees, radians, log
3 | from collections.abc import MutableSequence
4 |
5 |
6 | # This file contains classes for the different types of SVG path segments as
7 | # well as a Path object that contains a sequence of path segments.
8 |
9 | MIN_DEPTH = 5
10 | ERROR = 1e-12
11 |
12 |
13 | def segment_length(curve, start, end, start_point, end_point, error, min_depth, depth):
14 | """Recursively approximates the length by straight lines"""
15 | mid = (start + end) / 2
16 | mid_point = curve.point(mid)
17 | length = abs(end_point - start_point)
18 | first_half = abs(mid_point - start_point)
19 | second_half = abs(end_point - mid_point)
20 |
21 | length2 = first_half + second_half
22 | if (length2 - length > error) or (depth < min_depth):
23 | # Calculate the length of each segment:
24 | depth += 1
25 | return (segment_length(curve, start, mid, start_point, mid_point,
26 | error, min_depth, depth) +
27 | segment_length(curve, mid, end, mid_point, end_point,
28 | error, min_depth, depth))
29 | # This is accurate enough.
30 | return length2
31 |
32 |
33 | class Line(object):
34 |
35 | def __init__(self, start, end):
36 | self.start = start
37 | self.end = end
38 |
39 | def __repr__(self):
40 | return 'Line(start=%s, end=%s)' % (self.start, self.end)
41 |
42 | def __eq__(self, other):
43 | if not isinstance(other, Line):
44 | return NotImplemented
45 | return self.start == other.start and self.end == other.end
46 |
47 | def __ne__(self, other):
48 | if not isinstance(other, Line):
49 | return NotImplemented
50 | return not self == other
51 |
52 | def point(self, pos):
53 | distance = self.end - self.start
54 | return self.start + distance * pos
55 |
56 | def length(self, error=None, min_depth=None):
57 | distance = (self.end - self.start)
58 | return sqrt(distance.real ** 2 + distance.imag ** 2)
59 |
60 |
61 | class CubicBezier(object):
62 | def __init__(self, start, control1, control2, end):
63 | self.start = start
64 | self.control1 = control1
65 | self.control2 = control2
66 | self.end = end
67 |
68 | def __repr__(self):
69 | return 'CubicBezier(start=%s, control1=%s, control2=%s, end=%s)' % (
70 | self.start, self.control1, self.control2, self.end)
71 |
72 | def __eq__(self, other):
73 | if not isinstance(other, CubicBezier):
74 | return NotImplemented
75 | return self.start == other.start and self.end == other.end and \
76 | self.control1 == other.control1 and self.control2 == other.control2
77 |
78 | def __ne__(self, other):
79 | if not isinstance(other, CubicBezier):
80 | return NotImplemented
81 | return not self == other
82 |
83 | def is_smooth_from(self, previous):
84 | """Checks if this segment would be a smooth segment following the previous"""
85 | if isinstance(previous, CubicBezier):
86 | return (self.start == previous.end and
87 | (self.control1 - self.start) == (previous.end - previous.control2))
88 | else:
89 | return self.control1 == self.start
90 |
91 | def point(self, pos):
92 | """Calculate the x,y position at a certain position of the path"""
93 | return ((1 - pos) ** 3 * self.start) + \
94 | (3 * (1 - pos) ** 2 * pos * self.control1) + \
95 | (3 * (1 - pos) * pos ** 2 * self.control2) + \
96 | (pos ** 3 * self.end)
97 |
98 | def length(self, error=ERROR, min_depth=MIN_DEPTH):
99 |         """Approximate the total length of the curve with straight lines"""
100 | start_point = self.point(0)
101 | end_point = self.point(1)
102 | return segment_length(self, 0, 1, start_point, end_point, error, min_depth, 0)
103 |
104 |
105 | class QuadraticBezier(object):
106 | def __init__(self, start, control, end):
107 | self.start = start
108 | self.end = end
109 | self.control = control
110 |
111 | def __repr__(self):
112 | return 'QuadraticBezier(start=%s, control=%s, end=%s)' % (
113 | self.start, self.control, self.end)
114 |
115 | def __eq__(self, other):
116 | if not isinstance(other, QuadraticBezier):
117 | return NotImplemented
118 | return self.start == other.start and self.end == other.end and \
119 | self.control == other.control
120 |
121 | def __ne__(self, other):
122 | if not isinstance(other, QuadraticBezier):
123 | return NotImplemented
124 | return not self == other
125 |
126 | def is_smooth_from(self, previous):
127 | """Checks if this segment would be a smooth segment following the previous"""
128 | if isinstance(previous, QuadraticBezier):
129 | return (self.start == previous.end and
130 | (self.control - self.start) == (previous.end - previous.control))
131 | else:
132 | return self.control == self.start
133 |
134 | def point(self, pos):
135 | return (1 - pos) ** 2 * self.start + 2 * (1 - pos) * pos * self.control + \
136 | pos ** 2 * self.end
137 |
138 | def length(self, error=None, min_depth=None):
139 | a = self.start - 2*self.control + self.end
140 | b = 2*(self.control - self.start)
141 | a_dot_b = a.real*b.real + a.imag*b.imag
142 |
143 | if abs(a) < 1e-12:
144 | s = abs(b)
145 | elif abs(a_dot_b + abs(a)*abs(b)) < 1e-12:
146 | k = abs(b)/abs(a)
147 | if k >= 2:
148 | s = abs(b) - abs(a)
149 | else:
150 | s = abs(a)*(k**2/2 - k + 1)
151 | else:
152 | # For an explanation of this case, see
153 | # http://www.malczak.info/blog/quadratic-bezier-curve-length/
154 | A = 4 * (a.real ** 2 + a.imag ** 2)
155 | B = 4 * (a.real * b.real + a.imag * b.imag)
156 | C = b.real ** 2 + b.imag ** 2
157 |
158 | Sabc = 2 * sqrt(A + B + C)
159 | A2 = sqrt(A)
160 | A32 = 2 * A * A2
161 | C2 = 2 * sqrt(C)
162 | BA = B / A2
163 |
164 | s = (A32 * Sabc + A2 * B * (Sabc - C2) + (4 * C * A - B ** 2) *
165 | log((2 * A2 + BA + Sabc) / (BA + C2))) / (4 * A32)
166 | return s
167 |
168 | class Arc(object):
169 |
170 | def __init__(self, start, radius, rotation, arc, sweep, end):
171 | """radius is complex, rotation is in degrees,
172 |            arc (large-arc flag) and sweep are 1 or 0 (True/False work too)"""
173 |
174 | self.start = start
175 | self.radius = radius
176 | self.rotation = rotation
177 | self.arc = bool(arc)
178 | self.sweep = bool(sweep)
179 | self.end = end
180 |
181 | self._parameterize()
182 |
183 | def __repr__(self):
184 | return 'Arc(start=%s, radius=%s, rotation=%s, arc=%s, sweep=%s, end=%s)' % (
185 | self.start, self.radius, self.rotation, self.arc, self.sweep, self.end)
186 |
187 | def __eq__(self, other):
188 | if not isinstance(other, Arc):
189 | return NotImplemented
190 | return self.start == other.start and self.end == other.end and \
191 | self.radius == other.radius and self.rotation == other.rotation and \
192 | self.arc == other.arc and self.sweep == other.sweep
193 |
194 | def __ne__(self, other):
195 | if not isinstance(other, Arc):
196 | return NotImplemented
197 | return not self == other
198 |
199 | def _parameterize(self):
200 | # Conversion from endpoint to center parameterization
201 | # http://www.w3.org/TR/SVG/implnote.html#ArcImplementationNotes
202 |
203 | cosr = cos(radians(self.rotation))
204 | sinr = sin(radians(self.rotation))
205 | dx = (self.start.real - self.end.real) / 2
206 | dy = (self.start.imag - self.end.imag) / 2
207 | x1prim = cosr * dx + sinr * dy
208 | x1prim_sq = x1prim * x1prim
209 | y1prim = -sinr * dx + cosr * dy
210 | y1prim_sq = y1prim * y1prim
211 |
212 | rx = self.radius.real
213 | rx_sq = rx * rx
214 | ry = self.radius.imag
215 | ry_sq = ry * ry
216 |
217 | # Correct out of range radii
218 | radius_check = (x1prim_sq / rx_sq) + (y1prim_sq / ry_sq)
219 | if radius_check > 1:
220 | rx *= sqrt(radius_check)
221 | ry *= sqrt(radius_check)
222 | rx_sq = rx * rx
223 | ry_sq = ry * ry
224 |
225 | t1 = rx_sq * y1prim_sq
226 | t2 = ry_sq * x1prim_sq
227 | c = sqrt(abs((rx_sq * ry_sq - t1 - t2) / (t1 + t2)))
228 |
229 | if self.arc == self.sweep:
230 | c = -c
231 | cxprim = c * rx * y1prim / ry
232 | cyprim = -c * ry * x1prim / rx
233 |
234 | self.center = complex((cosr * cxprim - sinr * cyprim) +
235 | ((self.start.real + self.end.real) / 2),
236 | (sinr * cxprim + cosr * cyprim) +
237 | ((self.start.imag + self.end.imag) / 2))
238 |
239 | ux = (x1prim - cxprim) / rx
240 | uy = (y1prim - cyprim) / ry
241 | vx = (-x1prim - cxprim) / rx
242 | vy = (-y1prim - cyprim) / ry
243 | n = sqrt(ux * ux + uy * uy)
244 | p = ux
245 | theta = degrees(acos(p / n))
246 | if uy < 0:
247 | theta = -theta
248 | self.theta = theta % 360
249 |
250 | n = sqrt((ux * ux + uy * uy) * (vx * vx + vy * vy))
251 | p = ux * vx + uy * vy
252 | if p == 0:
253 | delta = degrees(acos(0))
254 | else:
255 | delta = degrees(acos(p / n))
256 | if (ux * vy - uy * vx) < 0:
257 | delta = -delta
258 | self.delta = delta % 360
259 | if not self.sweep:
260 | self.delta -= 360
261 |
262 | def point(self, pos):
263 | angle = radians(self.theta + (self.delta * pos))
264 | cosr = cos(radians(self.rotation))
265 | sinr = sin(radians(self.rotation))
266 |
267 | x = (cosr * cos(angle) * self.radius.real - sinr * sin(angle) *
268 | self.radius.imag + self.center.real)
269 | y = (sinr * cos(angle) * self.radius.real + cosr * sin(angle) *
270 | self.radius.imag + self.center.imag)
271 | return complex(x, y)
272 |
273 | def length(self, error=ERROR, min_depth=MIN_DEPTH):
274 | """The length of an elliptical arc segment requires numerical
275 | integration, and in that case it's simpler to just do a geometric
276 | approximation, as for cubic bezier curves.
277 | """
278 | start_point = self.point(0)
279 | end_point = self.point(1)
280 | return segment_length(self, 0, 1, start_point, end_point, error, min_depth, 0)
281 |
282 |
283 | class Path(MutableSequence):
284 | """A Path is a sequence of path segments"""
285 |
286 | # Put it here, so there is a default if unpickled.
287 | _closed = False
288 |
289 | def __init__(self, *segments, **kw):
290 | self._segments = list(segments)
291 | self._length = None
292 | self._lengths = None
293 | if 'closed' in kw:
294 | self.closed = kw['closed']
295 |
296 | def __getitem__(self, index):
297 | return self._segments[index]
298 |
299 | def __setitem__(self, index, value):
300 | self._segments[index] = value
301 | self._length = None
302 |
303 | def __delitem__(self, index):
304 | del self._segments[index]
305 | self._length = None
306 |
307 | def insert(self, index, value):
308 | self._segments.insert(index, value)
309 | self._length = None
310 |
311 | def reverse(self):
312 | # Reversing the order of a path would require reversing each element
313 | # as well. That's not implemented.
314 | raise NotImplementedError
315 |
316 | def __len__(self):
317 | return len(self._segments)
318 |
319 | def __repr__(self):
320 | return 'Path(%s, closed=%s)' % (
321 | ', '.join(repr(x) for x in self._segments), self.closed)
322 |
323 | def __eq__(self, other):
324 | if not isinstance(other, Path):
325 | return NotImplemented
326 | if len(self) != len(other):
327 | return False
328 | for s, o in zip(self._segments, other._segments):
329 | if not s == o:
330 | return False
331 | return True
332 |
333 | def __ne__(self, other):
334 | if not isinstance(other, Path):
335 | return NotImplemented
336 | return not self == other
337 |
338 | def _calc_lengths(self, error=ERROR, min_depth=MIN_DEPTH):
339 | if self._length is not None:
340 | return
341 |
342 | lengths = [each.length(error=error, min_depth=min_depth) for each in self._segments]
343 | self._length = sum(lengths)
344 | self._lengths = [each / self._length for each in lengths]
345 |
346 | def point(self, pos, error=ERROR):
347 |
348 | # Shortcuts
349 | if pos == 0.0:
350 | return self._segments[0].point(pos)
351 | if pos == 1.0:
352 | return self._segments[-1].point(pos)
353 |
354 | self._calc_lengths(error=error)
355 | # Find which segment the point we search for is located on:
356 | segment_start = 0
357 | for index, segment in enumerate(self._segments):
358 | segment_end = segment_start + self._lengths[index]
359 | if segment_end >= pos:
360 | # This is the segment! How far in on the segment is the point?
361 | segment_pos = (pos - segment_start) / (segment_end - segment_start)
362 | break
363 | segment_start = segment_end
364 |
365 | return segment.point(segment_pos)
366 |
367 | def length(self, error=ERROR, min_depth=MIN_DEPTH):
368 | self._calc_lengths(error, min_depth)
369 | return self._length
370 |
371 | def _is_closable(self):
372 | """Returns true if the end is on the start of a segment"""
373 | end = self[-1].end
374 | for segment in self:
375 | if segment.start == end:
376 | return True
377 | return False
378 |
379 | @property
380 | def closed(self):
381 | """Checks that the path is closed"""
382 | return self._closed and self._is_closable()
383 |
384 | @closed.setter
385 | def closed(self, value):
386 | value = bool(value)
387 | if value and not self._is_closable():
388 | raise ValueError("End does not coincide with a segment start.")
389 | self._closed = value
390 |
391 | def d(self):
392 | if self.closed:
393 | segments = self[:-1]
394 | else:
395 | segments = self[:]
396 |
397 | current_pos = None
398 | parts = []
399 | previous_segment = None
400 | end = self[-1].end
401 |
402 | for segment in segments:
403 | start = segment.start
404 | # If the start of this segment does not coincide with the end of
405 | # the last segment or if this segment is actually the close point
406 | # of a closed path, then we should start a new subpath here.
407 | if current_pos != start or (self.closed and start == end):
408 | parts.append('M {0:G},{1:G}'.format(start.real, start.imag))
409 |
410 | if isinstance(segment, Line):
411 | parts.append('L {0:G},{1:G}'.format(
412 | segment.end.real, segment.end.imag)
413 | )
414 | elif isinstance(segment, CubicBezier):
415 | if segment.is_smooth_from(previous_segment):
416 | parts.append('S {0:G},{1:G} {2:G},{3:G}'.format(
417 | segment.control2.real, segment.control2.imag,
418 | segment.end.real, segment.end.imag)
419 | )
420 | else:
421 | parts.append('C {0:G},{1:G} {2:G},{3:G} {4:G},{5:G}'.format(
422 | segment.control1.real, segment.control1.imag,
423 | segment.control2.real, segment.control2.imag,
424 | segment.end.real, segment.end.imag)
425 | )
426 | elif isinstance(segment, QuadraticBezier):
427 | if segment.is_smooth_from(previous_segment):
428 | parts.append('T {0:G},{1:G}'.format(
429 | segment.end.real, segment.end.imag)
430 | )
431 | else:
432 | parts.append('Q {0:G},{1:G} {2:G},{3:G}'.format(
433 | segment.control.real, segment.control.imag,
434 | segment.end.real, segment.end.imag)
435 | )
436 |
437 | elif isinstance(segment, Arc):
438 | parts.append('A {0:G},{1:G} {2:G} {3:d},{4:d} {5:G},{6:G}'.format(
439 | segment.radius.real, segment.radius.imag, segment.rotation,
440 | int(segment.arc), int(segment.sweep),
441 | segment.end.real, segment.end.imag)
442 | )
443 | current_pos = segment.end
444 | previous_segment = segment
445 |
446 | if self.closed:
447 | parts.append('Z')
448 |
449 | return ' '.join(parts)
450 |
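`CubicBezier.length` and `Arc.length` rely on the recursive chord approximation in `segment_length`, while `QuadraticBezier.length` uses a closed form. The standalone sketch below reimplements both ideas so they can be compared; the helper names `quad_point`, `chord_length` and `quad_length_closed_form` are hypothetical and are not svg.path API. It only covers the non-degenerate quadratic case, which is why `QuadraticBezier.length` handles `a ≈ 0` and `a·b ≈ -|a||b|` separately before using the formula.

```python
from math import log, sqrt


def quad_point(p0, p1, p2, t):
    """Point on a quadratic Bezier (same formula as QuadraticBezier.point)."""
    return (1 - t) ** 2 * p0 + 2 * (1 - t) * t * p1 + t ** 2 * p2


def chord_length(point, error=1e-9, min_depth=5):
    """Adaptive straight-line length estimate over t in [0, 1] (same idea as segment_length)."""
    def recurse(t0, t1, p0, p1, depth):
        tm = (t0 + t1) / 2
        pm = point(tm)
        chord = abs(p1 - p0)
        two_chords = abs(pm - p0) + abs(p1 - pm)
        # Split again while halving still changes the estimate, or while too shallow.
        if two_chords - chord > error or depth < min_depth:
            return (recurse(t0, tm, p0, pm, depth + 1) +
                    recurse(tm, t1, pm, p1, depth + 1))
        return two_chords
    return recurse(0.0, 1.0, point(0.0), point(1.0), 0)


def quad_length_closed_form(p0, p1, p2):
    """Closed-form length, general case only (as in QuadraticBezier.length)."""
    a = p0 - 2 * p1 + p2
    b = 2 * (p1 - p0)
    # The speed |B'(t)| is sqrt(A t^2 + B t + C) with these coefficients.
    A = 4 * (a.real ** 2 + a.imag ** 2)
    B = 4 * (a.real * b.real + a.imag * b.imag)
    C = b.real ** 2 + b.imag ** 2
    Sabc = 2 * sqrt(A + B + C)
    A2 = sqrt(A)
    A32 = 2 * A * A2
    C2 = 2 * sqrt(C)
    BA = B / A2
    return (A32 * Sabc + A2 * B * (Sabc - C2) +
            (4 * C * A - B ** 2) * log((2 * A2 + BA + Sabc) / (BA + C2))) / (4 * A32)


# The Q segment from test_parsing.py: Q400,50 600,300 starting at 200,300.
p0, p1, p2 = 200 + 300j, 400 + 50j, 600 + 300j
exact = quad_length_closed_form(p0, p1, p2)
approx = chord_length(lambda t: quad_point(p0, p1, p2, t))
print(exact, approx)
```

The two estimates agree to well under 1e-4 here. The recursive version works for any curve with a `point(t)` method, at the cost of many point evaluations when `error` is small, which is the trade-off `MIN_DEPTH` and `ERROR` control in path.py.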
--------------------------------------------------------------------------------
/svg/path/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/svg/path/tests/__init__.py
--------------------------------------------------------------------------------
/svg/path/tests/test_doc.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import doctest
3 |
4 |
5 | def load_tests(loader, tests, ignore):
6 | tests.addTests(doctest.DocFileSuite('README.rst', package='__main__'))
7 | return tests
8 |
--------------------------------------------------------------------------------
/svg/path/tests/test_generation.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import unittest
3 | from ..path import CubicBezier, QuadraticBezier, Line, Arc, Path
4 | from ..parser import parse_path
5 |
6 |
7 | class TestGeneration(unittest.TestCase):
8 |
9 | def test_svg_examples(self):
10 | """Examples from the SVG spec"""
11 | paths = [
12 | 'M 100,100 L 300,100 L 200,300 Z',
13 | 'M 0,0 L 50,20 M 100,100 L 300,100 L 200,300 Z',
14 | 'M 100,100 L 200,200',
15 | 'M 100,200 L 200,100 L -100,-200',
16 | 'M 100,200 C 100,100 250,100 250,200 S 400,300 400,200',
17 | 'M 100,200 C 100,100 400,100 400,200',
18 | 'M 100,500 C 25,400 475,400 400,500',
19 | 'M 100,800 C 175,700 325,700 400,800',
20 | 'M 600,200 C 675,100 975,100 900,200',
21 | 'M 600,500 C 600,350 900,650 900,500',
22 | 'M 600,800 C 625,700 725,700 750,800 S 875,900 900,800',
23 | 'M 200,300 Q 400,50 600,300 T 1000,300',
24 | 'M -3.4E+38,3.4E+38 L -3.4E-38,3.4E-38',
25 | 'M 0,0 L 50,20 M 50,20 L 200,100 Z',
26 | 'M 600,350 L 650,325 A 25,25 -30 0,1 700,300 L 750,275',
27 | ]
28 |
29 | for path in paths:
30 | self.assertEqual(parse_path(path).d(), path)
31 |
32 | def test_normalizing(self):
33 | # Relative paths will be made absolute, subpaths merged if they can,
34 | # and syntax will change.
35 | self.assertEqual(parse_path('M0 0L3.4E2-10L100.0,100M100,100l100,-100').d(),
36 | 'M 0,0 L 340,-10 L 100,100 L 200,0')
37 |
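`Path.d()` writes every coordinate with the `G` format specifier, which is what turns `3.4E2` into `340` and `100.0` into `100` in `test_normalizing`. A short standalone check of that formatting follows; `fmt_line_to` is a hypothetical helper that mimics only the `L` branch of `d()`. It also shows a side effect worth knowing: `G` keeps six significant digits, so `d()` is not a lossless round trip for high-precision coordinates.

```python
def fmt_line_to(z):
    """Format an 'L' command the way Path.d() does (hypothetical helper)."""
    return 'L {0:G},{1:G}'.format(z.real, z.imag)


print(fmt_line_to(340 - 10j))              # L 340,-10
print(fmt_line_to(-3.4e-38 + 3.4e-38j))    # L -3.4E-38,3.4E-38
print('{0:G}'.format(1234567.0))           # 1.23457E+06 (precision lost)
```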
--------------------------------------------------------------------------------
/svg/path/tests/test_parsing.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import unittest
3 | from ..path import CubicBezier, QuadraticBezier, Line, Arc, Path
4 | from ..parser import parse_path
5 |
6 |
7 | class TestParser(unittest.TestCase):
8 |
9 | def test_svg_examples(self):
10 | """Examples from the SVG spec"""
11 | path1 = parse_path('M 100 100 L 300 100 L 200 300 z')
12 | self.assertEqual(path1, Path(Line(100 + 100j, 300 + 100j),
13 | Line(300 + 100j, 200 + 300j),
14 | Line(200 + 300j, 100 + 100j)))
15 | self.assertTrue(path1.closed)
16 |
17 |         # for Z command behavior when there are multiple subpaths
18 | path1 = parse_path('M 0 0 L 50 20 M 100 100 L 300 100 L 200 300 z')
19 | self.assertEqual(path1, Path(
20 | Line(0 + 0j, 50 + 20j),
21 | Line(100 + 100j, 300 + 100j),
22 | Line(300 + 100j, 200 + 300j),
23 | Line(200 + 300j, 100 + 100j)))
24 |
25 | path1 = parse_path('M 100 100 L 200 200')
26 | path2 = parse_path('M100 100L200 200')
27 | self.assertEqual(path1, path2)
28 |
29 | path1 = parse_path('M 100 200 L 200 100 L -100 -200')
30 | path2 = parse_path('M 100 200 L 200 100 -100 -200')
31 | self.assertEqual(path1, path2)
32 |
33 | path1 = parse_path("""M100,200 C100,100 250,100 250,200
34 | S400,300 400,200""")
35 | self.assertEqual(path1,
36 | Path(CubicBezier(100 + 200j, 100 + 100j, 250 + 100j, 250 + 200j),
37 | CubicBezier(250 + 200j, 250 + 300j, 400 + 300j, 400 + 200j)))
38 |
39 | path1 = parse_path('M100,200 C100,100 400,100 400,200')
40 | self.assertEqual(path1,
41 | Path(CubicBezier(100 + 200j, 100 + 100j, 400 + 100j, 400 + 200j)))
42 |
43 | path1 = parse_path('M100,500 C25,400 475,400 400,500')
44 | self.assertEqual(path1,
45 | Path(CubicBezier(100 + 500j, 25 + 400j, 475 + 400j, 400 + 500j)))
46 |
47 | path1 = parse_path('M100,800 C175,700 325,700 400,800')
48 | self.assertEqual(path1,
49 | Path(CubicBezier(100 + 800j, 175 + 700j, 325 + 700j, 400 + 800j)))
50 |
51 | path1 = parse_path('M600,200 C675,100 975,100 900,200')
52 | self.assertEqual(path1,
53 | Path(CubicBezier(600 + 200j, 675 + 100j, 975 + 100j, 900 + 200j)))
54 |
55 | path1 = parse_path('M600,500 C600,350 900,650 900,500')
56 | self.assertEqual(path1,
57 | Path(CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j)))
58 |
59 | path1 = parse_path("""M600,800 C625,700 725,700 750,800
60 | S875,900 900,800""")
61 | self.assertEqual(path1,
62 | Path(CubicBezier(600 + 800j, 625 + 700j, 725 + 700j, 750 + 800j),
63 | CubicBezier(750 + 800j, 775 + 900j, 875 + 900j, 900 + 800j)))
64 |
65 | path1 = parse_path('M200,300 Q400,50 600,300 T1000,300')
66 | self.assertEqual(path1,
67 | Path(QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j),
68 | QuadraticBezier(600 + 300j, 800 + 550j, 1000 + 300j)))
69 |
70 | path1 = parse_path('M300,200 h-150 a150,150 0 1,0 150,-150 z')
71 | self.assertEqual(path1,
72 | Path(Line(300 + 200j, 150 + 200j),
73 | Arc(150 + 200j, 150 + 150j, 0, 1, 0, 300 + 50j),
74 | Line(300 + 50j, 300 + 200j)))
75 |
76 | path1 = parse_path('M275,175 v-150 a150,150 0 0,0 -150,150 z')
77 | self.assertEqual(path1,
78 | Path(Line(275 + 175j, 275 + 25j),
79 | Arc(275 + 25j, 150 + 150j, 0, 0, 0, 125 + 175j),
80 | Line(125 + 175j, 275 + 175j)))
81 |
82 | path1 = parse_path("""M600,350 l 50,-25
83 | a25,25 -30 0,1 50,-25 l 50,-25
84 | a25,50 -30 0,1 50,-25 l 50,-25
85 | a25,75 -30 0,1 50,-25 l 50,-25
86 | a25,100 -30 0,1 50,-25 l 50,-25""")
87 | self.assertEqual(path1,
88 | Path(Line(600 + 350j, 650 + 325j),
89 | Arc(650 + 325j, 25 + 25j, -30, 0, 1, 700 + 300j),
90 | Line(700 + 300j, 750 + 275j),
91 | Arc(750 + 275j, 25 + 50j, -30, 0, 1, 800 + 250j),
92 | Line(800 + 250j, 850 + 225j),
93 | Arc(850 + 225j, 25 + 75j, -30, 0, 1, 900 + 200j),
94 | Line(900 + 200j, 950 + 175j),
95 | Arc(950 + 175j, 25 + 100j, -30, 0, 1, 1000 + 150j),
96 | Line(1000 + 150j, 1050 + 125j)))
97 |
98 | def test_others(self):
99 | # Other paths that need testing:
100 |
101 | # Relative moveto:
102 | path1 = parse_path('M 0 0 L 50 20 m 50 80 L 300 100 L 200 300 z')
103 | self.assertEqual(path1, Path(
104 | Line(0 + 0j, 50 + 20j),
105 | Line(100 + 100j, 300 + 100j),
106 | Line(300 + 100j, 200 + 300j),
107 | Line(200 + 300j, 100 + 100j)))
108 |
109 | # Initial smooth and relative CubicBezier
110 | path1 = parse_path("""M100,200 s 150,-100 150,0""")
111 | self.assertEqual(path1,
112 | Path(CubicBezier(100 + 200j, 100 + 200j, 250 + 100j, 250 + 200j)))
113 |
114 | # Initial smooth and relative QuadraticBezier
115 | path1 = parse_path("""M100,200 t 150,0""")
116 | self.assertEqual(path1,
117 | Path(QuadraticBezier(100 + 200j, 100 + 200j, 250 + 200j)))
118 |
119 | # Relative QuadraticBezier
120 | path1 = parse_path("""M100,200 q 0,0 150,0""")
121 | self.assertEqual(path1,
122 | Path(QuadraticBezier(100 + 200j, 100 + 200j, 250 + 200j)))
123 |
124 | def test_negative(self):
125 | """You don't need spaces before a minus-sign"""
126 | path1 = parse_path('M100,200c10-5,20-10,30-20')
127 | path2 = parse_path('M 100 200 c 10 -5 20 -10 30 -20')
128 | self.assertEqual(path1, path2)
129 |
130 | def test_numbers(self):
131 | """Exponents and other number format cases"""
132 | # It can be e or E, the plus is optional, and a minimum of +/-3.4e38 must be supported.
133 | path1 = parse_path('M-3.4e38 3.4E+38L-3.4E-38,3.4e-38')
134 | path2 = Path(Line(-3.4e+38 + 3.4e+38j, -3.4e-38 + 3.4e-38j))
135 | self.assertEqual(path1, path2)
136 |
137 | def test_errors(self):
138 | self.assertRaises(ValueError, parse_path, 'M 100 100 L 200 200 Z 100 200')
139 |
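`Arc._parameterize` in path.py converts SVG's endpoint parameterization to a center parameterization, following the W3C implementation notes it links. The standalone sketch below reproduces only the center computation (`arc_center` is a hypothetical name, not svg.path API) and checks it against the two pie-slice arcs in `test_svg_examples`: each arc's center should be the starting point of its path, (300, 200) and (275, 175).

```python
from math import cos, radians, sin, sqrt


def arc_center(start, radius, rotation, large_arc, sweep, end):
    """Center of an SVG elliptical arc, following the same steps as Arc._parameterize."""
    cosr, sinr = cos(radians(rotation)), sin(radians(rotation))
    # Step 1: half the endpoint difference, rotated into the ellipse's frame.
    dx = (start.real - end.real) / 2
    dy = (start.imag - end.imag) / 2
    x1p = cosr * dx + sinr * dy
    y1p = -sinr * dx + cosr * dy
    rx, ry = radius.real, radius.imag
    # Scale the radii up if they are too small to reach both endpoints.
    check = x1p ** 2 / rx ** 2 + y1p ** 2 / ry ** 2
    if check > 1:
        rx, ry = rx * sqrt(check), ry * sqrt(check)
    # Step 2: center in the rotated frame; the two flags pick one of two candidates.
    t1, t2 = rx ** 2 * y1p ** 2, ry ** 2 * x1p ** 2
    c = sqrt(abs((rx ** 2 * ry ** 2 - t1 - t2) / (t1 + t2)))
    if bool(large_arc) == bool(sweep):
        c = -c
    cxp = c * rx * y1p / ry
    cyp = -c * ry * x1p / rx
    # Step 3: rotate back and translate by the midpoint of the endpoints.
    return complex(cosr * cxp - sinr * cyp + (start.real + end.real) / 2,
                   sinr * cxp + cosr * cyp + (start.imag + end.imag) / 2)


print(arc_center(150 + 200j, 150 + 150j, 0, 1, 0, 300 + 50j))   # (300+200j)
print(arc_center(275 + 25j, 150 + 150j, 0, 0, 0, 125 + 175j))   # (275+175j)
```

Like the original, this sketch does not guard against coincident endpoints (`t1 + t2 == 0`), which the SVG spec says should be treated as an omitted arc.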
--------------------------------------------------------------------------------
/svg/path/tests/test_paths.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import unittest
3 | from math import sqrt, pi
4 |
5 | from ..path import CubicBezier, QuadraticBezier, Line, Arc, Path
6 |
7 |
8 | # Most of these test points are not calculated separately, as that would
9 | # take too long and be too error-prone. Instead the curves have been verified
10 | # to be correct visually, by drawing them with the turtle module, with code
11 | # like this:
12 | #
13 | # import turtle
14 | # t = turtle.Turtle()
15 | # t.penup()
16 | #
17 | # for arc in (path1, path2):
18 | # p = arc.point(0)
19 | # t.goto(p.real - 500, -p.imag + 300)
20 | # t.dot(3, 'black')
21 | # t.pendown()
22 | # for x in range(1, 101):
23 | # p = arc.point(x * 0.01)
24 | # t.goto(p.real - 500, -p.imag + 300)
25 | # t.penup()
26 | # t.dot(3, 'black')
27 | #
28 | # raw_input()
29 | #
30 | # After the paths have been verified to be correct this way, the testing of
31 | # points along the paths has been added as regression tests, to make sure
32 | # nobody changes the way curves are drawn by mistake. Therefore, do not take
33 | # these points religiously. They might be subtly wrong, unless otherwise
34 | # noted.
35 |
36 | class LineTest(unittest.TestCase):
37 |
38 | def test_lines(self):
39 | # These points are calculated, and not just regression tests.
40 |
41 | line1 = Line(0j, 400 + 0j)
42 | self.assertAlmostEqual(line1.point(0), (0j))
43 | self.assertAlmostEqual(line1.point(0.3), (120 + 0j))
44 | self.assertAlmostEqual(line1.point(0.5), (200 + 0j))
45 | self.assertAlmostEqual(line1.point(0.9), (360 + 0j))
46 | self.assertAlmostEqual(line1.point(1), (400 + 0j))
47 | self.assertAlmostEqual(line1.length(), 400)
48 |
49 | line2 = Line(400 + 0j, 400 + 300j)
50 | self.assertAlmostEqual(line2.point(0), (400 + 0j))
51 | self.assertAlmostEqual(line2.point(0.3), (400 + 90j))
52 | self.assertAlmostEqual(line2.point(0.5), (400 + 150j))
53 | self.assertAlmostEqual(line2.point(0.9), (400 + 270j))
54 | self.assertAlmostEqual(line2.point(1), (400 + 300j))
55 | self.assertAlmostEqual(line2.length(), 300)
56 |
57 | line3 = Line(400 + 300j, 0j)
58 | self.assertAlmostEqual(line3.point(0), (400 + 300j))
59 | self.assertAlmostEqual(line3.point(0.3), (280 + 210j))
60 | self.assertAlmostEqual(line3.point(0.5), (200 + 150j))
61 | self.assertAlmostEqual(line3.point(0.9), (40 + 30j))
62 | self.assertAlmostEqual(line3.point(1), (0j))
63 | self.assertAlmostEqual(line3.length(), 500)
64 |
65 | def test_equality(self):
66 | # This is to test the __eq__ and __ne__ methods, so we can't use
67 | # assertEqual and assertNotEqual
68 | line = Line(0j, 400 + 0j)
69 | self.assertTrue(line == Line(0, 400))
70 | self.assertTrue(line != Line(100, 400))
71 | self.assertFalse(line == str(line))
72 | self.assertTrue(line != str(line))
73 | self.assertFalse(CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j) ==
74 | line)
75 |
76 |
77 | class CubicBezierTest(unittest.TestCase):
78 | def test_approx_circle(self):
79 |         """This is an approximate circle drawn in Inkscape"""
80 |
81 | arc1 = CubicBezier(
82 | complex(0, 0),
83 | complex(0, 109.66797),
84 | complex(-88.90345, 198.57142),
85 | complex(-198.57142, 198.57142)
86 | )
87 |
88 | self.assertAlmostEqual(arc1.point(0), (0j))
89 | self.assertAlmostEqual(arc1.point(0.1), (-2.59896457 + 32.20931647j))
90 | self.assertAlmostEqual(arc1.point(0.2), (-10.12330256 + 62.76392816j))
91 | self.assertAlmostEqual(arc1.point(0.3), (-22.16418039 + 91.25500149j))
92 | self.assertAlmostEqual(arc1.point(0.4), (-38.31276448 + 117.27370288j))
93 | self.assertAlmostEqual(arc1.point(0.5), (-58.16022125 + 140.41119875j))
94 | self.assertAlmostEqual(arc1.point(0.6), (-81.29771712 + 160.25865552j))
95 | self.assertAlmostEqual(arc1.point(0.7), (-107.31641851 + 176.40723961j))
96 | self.assertAlmostEqual(arc1.point(0.8), (-135.80749184 + 188.44811744j))
97 | self.assertAlmostEqual(arc1.point(0.9), (-166.36210353 + 195.97245543j))
98 | self.assertAlmostEqual(arc1.point(1), (-198.57142 + 198.57142j))
99 |
100 | arc2 = CubicBezier(
101 | complex(-198.57142, 198.57142),
102 | complex(-109.66797 - 198.57142, 0 + 198.57142),
103 | complex(-198.57143 - 198.57142, -88.90345 + 198.57142),
104 | complex(-198.57143 - 198.57142, 0),
105 | )
106 |
107 | self.assertAlmostEqual(arc2.point(0), (-198.57142 + 198.57142j))
108 | self.assertAlmostEqual(arc2.point(0.1), (-230.78073675 + 195.97245543j))
109 | self.assertAlmostEqual(arc2.point(0.2), (-261.3353492 + 188.44811744j))
110 | self.assertAlmostEqual(arc2.point(0.3), (-289.82642365 + 176.40723961j))
111 | self.assertAlmostEqual(arc2.point(0.4), (-315.8451264 + 160.25865552j))
112 | self.assertAlmostEqual(arc2.point(0.5), (-338.98262375 + 140.41119875j))
113 | self.assertAlmostEqual(arc2.point(0.6), (-358.830082 + 117.27370288j))
114 | self.assertAlmostEqual(arc2.point(0.7), (-374.97866745 + 91.25500149j))
115 | self.assertAlmostEqual(arc2.point(0.8), (-387.0195464 + 62.76392816j))
116 | self.assertAlmostEqual(arc2.point(0.9), (-394.54388515 + 32.20931647j))
117 | self.assertAlmostEqual(arc2.point(1), (-397.14285 + 0j))
118 |
119 | arc3 = CubicBezier(
120 | complex(-198.57143 - 198.57142, 0),
121 | complex(0 - 198.57143 - 198.57142, -109.66797),
122 | complex(88.90346 - 198.57143 - 198.57142, -198.57143),
123 | complex(-198.57142, -198.57143)
124 | )
125 |
126 | self.assertAlmostEqual(arc3.point(0), (-397.14285 + 0j))
127 | self.assertAlmostEqual(arc3.point(0.1), (-394.54388515 - 32.20931675j))
128 | self.assertAlmostEqual(arc3.point(0.2), (-387.0195464 - 62.7639292j))
129 | self.assertAlmostEqual(arc3.point(0.3), (-374.97866745 - 91.25500365j))
130 | self.assertAlmostEqual(arc3.point(0.4), (-358.830082 - 117.2737064j))
131 | self.assertAlmostEqual(arc3.point(0.5), (-338.98262375 - 140.41120375j))
132 | self.assertAlmostEqual(arc3.point(0.6), (-315.8451264 - 160.258662j))
133 | self.assertAlmostEqual(arc3.point(0.7), (-289.82642365 - 176.40724745j))
134 | self.assertAlmostEqual(arc3.point(0.8), (-261.3353492 - 188.4481264j))
135 | self.assertAlmostEqual(arc3.point(0.9), (-230.78073675 - 195.97246515j))
136 | self.assertAlmostEqual(arc3.point(1), (-198.57142 - 198.57143j))
137 |
138 | arc4 = CubicBezier(
139 | complex(-198.57142, -198.57143),
140 | complex(109.66797 - 198.57142, 0 - 198.57143),
141 | complex(0, 88.90346 - 198.57143),
142 | complex(0, 0),
143 | )
144 |
145 | self.assertAlmostEqual(arc4.point(0), (-198.57142 - 198.57143j))
146 | self.assertAlmostEqual(arc4.point(0.1), (-166.36210353 - 195.97246515j))
147 | self.assertAlmostEqual(arc4.point(0.2), (-135.80749184 - 188.4481264j))
148 | self.assertAlmostEqual(arc4.point(0.3), (-107.31641851 - 176.40724745j))
149 | self.assertAlmostEqual(arc4.point(0.4), (-81.29771712 - 160.258662j))
150 | self.assertAlmostEqual(arc4.point(0.5), (-58.16022125 - 140.41120375j))
151 | self.assertAlmostEqual(arc4.point(0.6), (-38.31276448 - 117.2737064j))
152 | self.assertAlmostEqual(arc4.point(0.7), (-22.16418039 - 91.25500365j))
153 | self.assertAlmostEqual(arc4.point(0.8), (-10.12330256 - 62.7639292j))
154 | self.assertAlmostEqual(arc4.point(0.9), (-2.59896457 - 32.20931675j))
155 | self.assertAlmostEqual(arc4.point(1), (0j))
156 |
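In this Inkscape circle, each control point sits 109.66797 away from a radius of 198.57142, a ratio of about 0.55228. That is the same "kappa" constant that `test_length` below uses for its quarter-circle arc. The constant can be derived by requiring the cubic's midpoint to land exactly on the circle. This is a sketch for the unit quarter circle, assuming the usual symmetric placement of the control points:

```latex
% Unit quarter circle: P_0=(1,0), P_1=(1,\kappa), P_2=(\kappa,1), P_3=(0,1)
B\!\left(\tfrac12\right) = \tfrac18\left(P_0 + 3P_1 + 3P_2 + P_3\right)
  = \left(\tfrac{4+3\kappa}{8},\ \tfrac{4+3\kappa}{8}\right)
% Require B(1/2) = (\sqrt{2}/2, \sqrt{2}/2):
\frac{4+3\kappa}{8} = \frac{\sqrt{2}}{2}
  \;\Longrightarrow\; \kappa = \frac{4\left(\sqrt{2}-1\right)}{3} \approx 0.5523
```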
157 | def test_svg_examples(self):
158 |
159 | # M100,200 C100,100 250,100 250,200
160 | path1 = CubicBezier(100 + 200j, 100 + 100j, 250 + 100j, 250 + 200j)
161 | self.assertAlmostEqual(path1.point(0), (100 + 200j))
162 | self.assertAlmostEqual(path1.point(0.3), (132.4 + 137j))
163 | self.assertAlmostEqual(path1.point(0.5), (175 + 125j))
164 | self.assertAlmostEqual(path1.point(0.9), (245.8 + 173j))
165 | self.assertAlmostEqual(path1.point(1), (250 + 200j))
166 |
167 | # S400,300 400,200
168 | path2 = CubicBezier(250 + 200j, 250 + 300j, 400 + 300j, 400 + 200j)
169 | self.assertAlmostEqual(path2.point(0), (250 + 200j))
170 | self.assertAlmostEqual(path2.point(0.3), (282.4 + 263j))
171 | self.assertAlmostEqual(path2.point(0.5), (325 + 275j))
172 | self.assertAlmostEqual(path2.point(0.9), (395.8 + 227j))
173 | self.assertAlmostEqual(path2.point(1), (400 + 200j))
174 |
175 | # M100,200 C100,100 400,100 400,200
176 | path3 = CubicBezier(100 + 200j, 100 + 100j, 400 + 100j, 400 + 200j)
177 | self.assertAlmostEqual(path3.point(0), (100 + 200j))
178 | self.assertAlmostEqual(path3.point(0.3), (164.8 + 137j))
179 | self.assertAlmostEqual(path3.point(0.5), (250 + 125j))
180 | self.assertAlmostEqual(path3.point(0.9), (391.6 + 173j))
181 | self.assertAlmostEqual(path3.point(1), (400 + 200j))
182 |
183 | # M100,500 C25,400 475,400 400,500
184 | path4 = CubicBezier(100 + 500j, 25 + 400j, 475 + 400j, 400 + 500j)
185 | self.assertAlmostEqual(path4.point(0), (100 + 500j))
186 | self.assertAlmostEqual(path4.point(0.3), (145.9 + 437j))
187 | self.assertAlmostEqual(path4.point(0.5), (250 + 425j))
188 | self.assertAlmostEqual(path4.point(0.9), (407.8 + 473j))
189 | self.assertAlmostEqual(path4.point(1), (400 + 500j))
190 |
191 | # M100,800 C175,700 325,700 400,800
192 | path5 = CubicBezier(100 + 800j, 175 + 700j, 325 + 700j, 400 + 800j)
193 | self.assertAlmostEqual(path5.point(0), (100 + 800j))
194 | self.assertAlmostEqual(path5.point(0.3), (183.7 + 737j))
195 | self.assertAlmostEqual(path5.point(0.5), (250 + 725j))
196 | self.assertAlmostEqual(path5.point(0.9), (375.4 + 773j))
197 | self.assertAlmostEqual(path5.point(1), (400 + 800j))
198 |
199 | # M600,200 C675,100 975,100 900,200
200 | path6 = CubicBezier(600 + 200j, 675 + 100j, 975 + 100j, 900 + 200j)
201 | self.assertAlmostEqual(path6.point(0), (600 + 200j))
202 | self.assertAlmostEqual(path6.point(0.3), (712.05 + 137j))
203 | self.assertAlmostEqual(path6.point(0.5), (806.25 + 125j))
204 | self.assertAlmostEqual(path6.point(0.9), (911.85 + 173j))
205 | self.assertAlmostEqual(path6.point(1), (900 + 200j))
206 |
207 | # M600,500 C600,350 900,650 900,500
208 | path7 = CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j)
209 | self.assertAlmostEqual(path7.point(0), (600 + 500j))
210 | self.assertAlmostEqual(path7.point(0.3), (664.8 + 462.2j))
211 | self.assertAlmostEqual(path7.point(0.5), (750 + 500j))
212 | self.assertAlmostEqual(path7.point(0.9), (891.6 + 532.4j))
213 | self.assertAlmostEqual(path7.point(1), (900 + 500j))
214 |
215 | # M600,800 C625,700 725,700 750,800
216 | path8 = CubicBezier(600 + 800j, 625 + 700j, 725 + 700j, 750 + 800j)
217 | self.assertAlmostEqual(path8.point(0), (600 + 800j))
218 | self.assertAlmostEqual(path8.point(0.3), (638.7 + 737j))
219 | self.assertAlmostEqual(path8.point(0.5), (675 + 725j))
220 | self.assertAlmostEqual(path8.point(0.9), (740.4 + 773j))
221 | self.assertAlmostEqual(path8.point(1), (750 + 800j))
222 |
223 | # S875,900 900,800
224 | inversion = (750 + 800j) + (750 + 800j) - (725 + 700j)
225 | path9 = CubicBezier(750 + 800j, inversion, 875 + 900j, 900 + 800j)
226 | self.assertAlmostEqual(path9.point(0), (750 + 800j))
227 | self.assertAlmostEqual(path9.point(0.3), (788.7 + 863j))
228 | self.assertAlmostEqual(path9.point(0.5), (825 + 875j))
229 | self.assertAlmostEqual(path9.point(0.9), (890.4 + 827j))
230 | self.assertAlmostEqual(path9.point(1), (900 + 800j))
231 |
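The expected points above follow from the Bernstein form of a cubic Bezier. The `inversion` lines implement the SVG rule for the `S` shorthand: its first control point is the previous segment's second control point, reflected about the current point. The standalone sketch below reproduces a few of these values. `cubic_point` and `reflect` are hypothetical helpers written for illustration and are not part of `svg.path`.

```python
def cubic_point(p0: complex, p1: complex, p2: complex, p3: complex, t: float) -> complex:
    """Evaluate a cubic Bezier at parameter t using the Bernstein polynomials."""
    mt = 1.0 - t
    return mt**3 * p0 + 3 * mt**2 * t * p1 + 3 * mt * t**2 * p2 + t**3 * p3


def reflect(control: complex, current: complex) -> complex:
    """Reflect the previous control point about the current point (SVG S/T shorthand)."""
    return 2 * current - control


# M100,200 C100,100 250,100 250,200 evaluated at t = 0.5
print(cubic_point(100 + 200j, 100 + 100j, 250 + 100j, 250 + 200j, 0.5))  # (175+125j)

# S875,900 900,800 following the C ending at 750,800 with second control 725,700
c1 = reflect(725 + 700j, 750 + 800j)  # 775+900j, the same value as `inversion`
print(cubic_point(750 + 800j, c1, 875 + 900j, 900 + 800j, 0.5))
```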
232 | def test_length(self):
233 |
234 | # A straight line:
235 | arc = CubicBezier(
236 | complex(0, 0),
237 | complex(0, 0),
238 | complex(0, 100),
239 | complex(0, 100)
240 | )
241 |
242 | self.assertAlmostEqual(arc.length(), 100)
243 |
244 | # A diagonal line:
245 | arc = CubicBezier(
246 | complex(0, 0),
247 | complex(0, 0),
248 | complex(100, 100),
249 | complex(100, 100)
250 | )
251 |
252 | self.assertAlmostEqual(arc.length(), sqrt(2 * 100 * 100))
253 |
254 | # A quarter circle arc with radius 100:
255 | kappa = 4 * (sqrt(2) - 1) / 3 # http://www.whizkidtech.redprince.net/bezier/circle/
256 |
257 | arc = CubicBezier(
258 | complex(0, 0),
259 | complex(0, kappa * 100),
260 | complex(100 - kappa * 100, 100),
261 | complex(100, 100)
262 | )
263 |
264 | # We can't compare with pi*50 here, because this is just an
265 | # approximation of a circle arc. pi*50 is 157.079632679
266 | # So this is just yet another "warn if this changes" test.
267 | # This value is not verified to be correct.
268 | self.assertAlmostEqual(arc.length(), 157.1016698)
269 |
270 | # A recursive solution has also been suggested, but for CubicBezier
271 | # curves it could get a false solution on curves where the midpoint is on a
272 | # straight line between the start and end. For example, the following
273 | # curve would get solved as a straight line and get the length 300.
274 | # Make sure this is not the case.
275 | arc = CubicBezier(
276 | complex(600, 500),
277 | complex(600, 350),
278 | complex(900, 650),
279 | complex(900, 500)
280 | )
281 | self.assertTrue(arc.length() > 300.0)
282 |
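`test_length` checks the library's own numerical arc length. As a rough cross-check, the length can also be estimated by sampling the curve densely and summing the chord lengths. This is not the algorithm `svg.path` uses. Because it samples at fixed steps, it is not fooled by the S-shaped curve whose midpoint happens to lie on its chord. The sample count `n=10000` is an arbitrary assumption. The estimate approaches the true length from below as `n` grows.

```python
import math


def cubic_point(p0, p1, p2, p3, t):
    """Bernstein-form evaluation of a cubic Bezier with complex control points."""
    mt = 1.0 - t
    return mt**3 * p0 + 3 * mt**2 * t * p1 + 3 * mt * t**2 * p2 + t**3 * p3


def polyline_length(p0, p1, p2, p3, n=10000):
    """Approximate the arc length by summing n chords between evenly spaced t values."""
    total = 0.0
    prev = p0
    for i in range(1, n + 1):
        cur = cubic_point(p0, p1, p2, p3, i / n)
        total += abs(cur - prev)
        prev = cur
    return total


kappa = 4 * (math.sqrt(2) - 1) / 3
quarter = polyline_length(0j, complex(0, kappa * 100), complex(100 - kappa * 100, 100), 100 + 100j)
s_curve = polyline_length(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j)
print(quarter, s_curve)  # quarter is close to 157.10; s_curve exceeds its 300-unit chord
```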
283 | def test_equality(self):
284 | # This is to test the __eq__ and __ne__ methods, so we can't use
285 | # assertEqual and assertNotEqual
286 | segment = CubicBezier(complex(600, 500), complex(600, 350),
287 | complex(900, 650), complex(900, 500))
288 |
289 | self.assertTrue(segment ==
290 | CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j))
291 | self.assertTrue(segment !=
292 | CubicBezier(600 + 501j, 600 + 350j, 900 + 650j, 900 + 500j))
293 | self.assertTrue(segment != Line(0, 400))
294 |
295 |
296 | class QuadraticBezierTest(unittest.TestCase):
297 |
298 | def test_svg_examples(self):
299 | """These are the paths from the SVG specs"""
300 | # M200,300 Q400,50 600,300 T1000,300
301 | path1 = QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j)
302 | self.assertAlmostEqual(path1.point(0), (200 + 300j))
303 | self.assertAlmostEqual(path1.point(0.3), (320 + 195j))
304 | self.assertAlmostEqual(path1.point(0.5), (400 + 175j))
305 | self.assertAlmostEqual(path1.point(0.9), (560 + 255j))
306 | self.assertAlmostEqual(path1.point(1), (600 + 300j))
307 |
308 | # T1000, 300
309 | inversion = (600 + 300j) + (600 + 300j) - (400 + 50j)
310 | path2 = QuadraticBezier(600 + 300j, inversion, 1000 + 300j)
311 | self.assertAlmostEqual(path2.point(0), (600 + 300j))
312 | self.assertAlmostEqual(path2.point(0.3), (720 + 405j))
313 | self.assertAlmostEqual(path2.point(0.5), (800 + 425j))
314 | self.assertAlmostEqual(path2.point(0.9), (960 + 345j))
315 | self.assertAlmostEqual(path2.point(1), (1000 + 300j))
316 |
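The quadratic cases follow the same pattern, using the degree-2 Bernstein polynomials. In `utils.py`, `build_lines` only samples `Line` and `CubicBezier` segments and prints a warning for anything else. One possible way to handle a `QuadraticBezier` there would be to elevate it to the exactly equivalent cubic first. That is an assumption for illustration; the repository does not do this. `quad_point`, `quad_to_cubic` and `cubic_point` are hypothetical helpers.

```python
def quad_point(p0: complex, p1: complex, p2: complex, t: float) -> complex:
    """Evaluate a quadratic Bezier at parameter t."""
    mt = 1.0 - t
    return mt * mt * p0 + 2 * mt * t * p1 + t * t * p2


def quad_to_cubic(p0: complex, p1: complex, p2: complex):
    """Exact degree elevation: the returned cubic traces the same curve."""
    return p0, p0 + 2 / 3 * (p1 - p0), p2 + 2 / 3 * (p1 - p2), p2


def cubic_point(p0, p1, p2, p3, t):
    mt = 1.0 - t
    return mt**3 * p0 + 3 * mt**2 * t * p1 + 3 * mt * t**2 * p2 + t**3 * p3


# M200,300 Q400,50 600,300 at t = 0.5, directly and via the elevated cubic
print(quad_point(200 + 300j, 400 + 50j, 600 + 300j, 0.5))
print(cubic_point(*quad_to_cubic(200 + 300j, 400 + 50j, 600 + 300j), 0.5))
```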
317 | def test_length(self):
318 | # expected results calculated with
319 | # svg.path.segment_length(q, 0, 1, q.start, q.end, 1e-14, 20, 0)
320 | q1 = QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j)
321 | q2 = QuadraticBezier(200 + 300j, 400 + 50j, 500 + 200j)
322 | closedq = QuadraticBezier(6+2j, 5-1j, 6+2j)
323 | linq1 = QuadraticBezier(1, 2, 3)
324 | linq2 = QuadraticBezier(1+3j, 2+5j, -9 - 17j)
325 | nodalq = QuadraticBezier(1, 1, 1)
326 | tests = [(q1, 487.77109389525975),
327 | (q2, 379.90458193489155),
328 | (closedq, 3.1622776601683795),
329 | (linq1, 2),
330 | (linq2, 22.73335777124786),
331 | (nodalq, 0)]
332 | for q, exp_res in tests:
333 | self.assertAlmostEqual(q.length(), exp_res)
334 |
335 | def test_equality(self):
336 | # This is to test the __eq__ and __ne__ methods, so we can't use
337 | # assertEqual and assertNotEqual
338 | segment = QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j)
339 | self.assertTrue(segment == QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j))
340 | self.assertTrue(segment != QuadraticBezier(200 + 301j, 400 + 50j, 600 + 300j))
341 | self.assertFalse(segment == Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j))
342 | self.assertTrue(Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j) != segment)
343 |
344 |
345 | class ArcTest(unittest.TestCase):
346 |
347 | def test_points(self):
348 | arc1 = Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j)
349 | self.assertAlmostEqual(arc1.center, 100 + 0j)
350 | self.assertAlmostEqual(arc1.theta, 180.0)
351 | self.assertAlmostEqual(arc1.delta, -90.0)
352 |
353 | self.assertAlmostEqual(arc1.point(0.0), (0j))
354 | self.assertAlmostEqual(arc1.point(0.1), (1.23116594049 + 7.82172325201j))
355 | self.assertAlmostEqual(arc1.point(0.2), (4.89434837048 + 15.4508497187j))
356 | self.assertAlmostEqual(arc1.point(0.3), (10.8993475812 + 22.699524987j))
357 | self.assertAlmostEqual(arc1.point(0.4), (19.0983005625 + 29.3892626146j))
358 | self.assertAlmostEqual(arc1.point(0.5), (29.2893218813 + 35.3553390593j))
359 | self.assertAlmostEqual(arc1.point(0.6), (41.2214747708 + 40.4508497187j))
360 | self.assertAlmostEqual(arc1.point(0.7), (54.6009500260 + 44.5503262094j))
361 | self.assertAlmostEqual(arc1.point(0.8), (69.0983005625 + 47.5528258148j))
362 | self.assertAlmostEqual(arc1.point(0.9), (84.3565534960 + 49.3844170298j))
363 | self.assertAlmostEqual(arc1.point(1.0), (100 + 50j))
364 |
365 | arc2 = Arc(0j, 100 + 50j, 0, 1, 0, 100 + 50j)
366 | self.assertAlmostEqual(arc2.center, 50j)
367 | self.assertAlmostEqual(arc2.theta, 270.0)
368 | self.assertAlmostEqual(arc2.delta, -270.0)
369 |
370 | self.assertAlmostEqual(arc2.point(0.0), (0j))
371 | self.assertAlmostEqual(arc2.point(0.1), (-45.399049974 + 5.44967379058j))
372 | self.assertAlmostEqual(arc2.point(0.2), (-80.9016994375 + 20.6107373854j))
373 | self.assertAlmostEqual(arc2.point(0.3), (-98.7688340595 + 42.178276748j))
374 | self.assertAlmostEqual(arc2.point(0.4), (-95.1056516295 + 65.4508497187j))
375 | self.assertAlmostEqual(arc2.point(0.5), (-70.7106781187 + 85.3553390593j))
376 | self.assertAlmostEqual(arc2.point(0.6), (-30.9016994375 + 97.5528258148j))
377 | self.assertAlmostEqual(arc2.point(0.7), (15.643446504 + 99.3844170298j))
378 | self.assertAlmostEqual(arc2.point(0.8), (58.7785252292 + 90.4508497187j))
379 | self.assertAlmostEqual(arc2.point(0.9), (89.1006524188 + 72.699524987j))
380 | self.assertAlmostEqual(arc2.point(1.0), (100 + 50j))
381 |
382 | arc3 = Arc(0j, 100 + 50j, 0, 0, 1, 100 + 50j)
383 | self.assertAlmostEqual(arc3.center, 50j)
384 | self.assertAlmostEqual(arc3.theta, 270.0)
385 | self.assertAlmostEqual(arc3.delta, 90.0)
386 |
387 | self.assertAlmostEqual(arc3.point(0.0), (0j))
388 | self.assertAlmostEqual(arc3.point(0.1), (15.643446504 + 0.615582970243j))
389 | self.assertAlmostEqual(arc3.point(0.2), (30.9016994375 + 2.44717418524j))
390 | self.assertAlmostEqual(arc3.point(0.3), (45.399049974 + 5.44967379058j))
391 | self.assertAlmostEqual(arc3.point(0.4), (58.7785252292 + 9.54915028125j))
392 | self.assertAlmostEqual(arc3.point(0.5), (70.7106781187 + 14.6446609407j))
393 | self.assertAlmostEqual(arc3.point(0.6), (80.9016994375 + 20.6107373854j))
394 | self.assertAlmostEqual(arc3.point(0.7), (89.1006524188 + 27.300475013j))
395 | self.assertAlmostEqual(arc3.point(0.8), (95.1056516295 + 34.5491502813j))
396 | self.assertAlmostEqual(arc3.point(0.9), (98.7688340595 + 42.178276748j))
397 | self.assertAlmostEqual(arc3.point(1.0), (100 + 50j))
398 |
399 | arc4 = Arc(0j, 100 + 50j, 0, 1, 1, 100 + 50j)
400 | self.assertAlmostEqual(arc4.center, 100 + 0j)
401 | self.assertAlmostEqual(arc4.theta, 180.0)
402 | self.assertAlmostEqual(arc4.delta, 270.0)
403 |
404 | self.assertAlmostEqual(arc4.point(0.0), (0j))
405 | self.assertAlmostEqual(arc4.point(0.1), (10.8993475812 - 22.699524987j))
406 | self.assertAlmostEqual(arc4.point(0.2), (41.2214747708 - 40.4508497187j))
407 | self.assertAlmostEqual(arc4.point(0.3), (84.3565534960 - 49.3844170298j))
408 | self.assertAlmostEqual(arc4.point(0.4), (130.901699437 - 47.5528258148j))
409 | self.assertAlmostEqual(arc4.point(0.5), (170.710678119 - 35.3553390593j))
410 | self.assertAlmostEqual(arc4.point(0.6), (195.105651630 - 15.4508497187j))
411 | self.assertAlmostEqual(arc4.point(0.7), (198.768834060 + 7.82172325201j))
412 | self.assertAlmostEqual(arc4.point(0.8), (180.901699437 + 29.3892626146j))
413 | self.assertAlmostEqual(arc4.point(0.9), (145.399049974 + 44.5503262094j))
414 | self.assertAlmostEqual(arc4.point(1.0), (100 + 50j))
415 |
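The `center`, `theta` and `delta` values asserted above come from converting SVG's endpoint arc parameters to a center parameterization. The sketch below follows the conversion in the SVG 1.1 implementation notes (appendix F.6.5, with radius scaling from F.6.6). It is written independently for illustration and may differ from `svg.path`'s code in edge cases. It assumes `start != end` and nonzero radii. Angles are in degrees, and `radius` is passed as `rx + ry*1j`, as in the tests.

```python
import math


def arc_center(start, radius, rotation, large_arc, sweep, end):
    """Return (center, theta, delta, rx, ry) for an SVG elliptical arc."""
    rx, ry = abs(radius.real), abs(radius.imag)
    phi = math.radians(rotation)
    cos_phi, sin_phi = math.cos(phi), math.sin(phi)
    # Step 1: move the origin to the chord midpoint and undo the x-axis rotation.
    dx, dy = (start.real - end.real) / 2, (start.imag - end.imag) / 2
    x1p = cos_phi * dx + sin_phi * dy
    y1p = -sin_phi * dx + cos_phi * dy
    # Scale the radii up if they are too small to span the chord.
    lam = x1p**2 / rx**2 + y1p**2 / ry**2
    if lam > 1:
        rx, ry = rx * math.sqrt(lam), ry * math.sqrt(lam)
    # Step 2: the center in the rotated frame; the two flags pick one of two candidates.
    num = rx**2 * ry**2 - rx**2 * y1p**2 - ry**2 * x1p**2
    den = rx**2 * y1p**2 + ry**2 * x1p**2
    coef = math.sqrt(max(0.0, num / den))
    if large_arc == sweep:
        coef = -coef
    cxp = coef * rx * y1p / ry
    cyp = -coef * ry * x1p / rx
    # Step 3: rotate back and translate to the chord midpoint.
    cx = cos_phi * cxp - sin_phi * cyp + (start.real + end.real) / 2
    cy = sin_phi * cxp + cos_phi * cyp + (start.imag + end.imag) / 2
    # Step 4: start angle, plus a signed sweep whose sign matches the sweep flag.
    ux, uy = (x1p - cxp) / rx, (y1p - cyp) / ry
    vx, vy = (-x1p - cxp) / rx, (-y1p - cyp) / ry
    theta = math.degrees(math.atan2(uy, ux)) % 360
    delta = math.degrees(math.atan2(ux * vy - uy * vx, ux * vx + uy * vy))
    if not sweep and delta > 0:
        delta -= 360
    elif sweep and delta < 0:
        delta += 360
    return complex(cx, cy), theta, delta, rx, ry


def arc_point(start, radius, rotation, large_arc, sweep, end, t):
    """Point at parameter t, moving linearly in angle from theta to theta + delta."""
    center, theta, delta, rx, ry = arc_center(start, radius, rotation, large_arc, sweep, end)
    angle = math.radians(theta + delta * t)
    phi = math.radians(rotation)
    x = center.real + rx * math.cos(angle) * math.cos(phi) - ry * math.sin(angle) * math.sin(phi)
    y = center.imag + rx * math.cos(angle) * math.sin(phi) + ry * math.sin(angle) * math.cos(phi)
    return complex(x, y)


print(arc_center(0j, 100 + 50j, 0, 0, 0, 100 + 50j)[:3])  # center, theta, delta for arc1
```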
416 | def test_length(self):
417 | # I'll test the length calculations by making a circle, in two parts.
418 | arc1 = Arc(0j, 100 + 100j, 0, 0, 0, 200 + 0j)
419 | arc2 = Arc(200 + 0j, 100 + 100j, 0, 0, 0, 0j)
420 | self.assertAlmostEqual(arc1.length(), pi * 100)
421 | self.assertAlmostEqual(arc2.length(), pi * 100)
422 |
423 | def test_equality(self):
424 | # This is to test the __eq__ and __ne__ methods, so we can't use
425 | # assertEqual and assertNotEqual
426 | segment = Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j)
427 | self.assertTrue(segment == Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j))
428 | self.assertTrue(segment != Arc(0j, 100 + 50j, 0, 1, 0, 100 + 50j))
429 |
430 |
431 | class TestPath(unittest.TestCase):
432 |
433 | def test_circle(self):
434 | arc1 = Arc(0j, 100 + 100j, 0, 0, 0, 200 + 0j)
435 | arc2 = Arc(200 + 0j, 100 + 100j, 0, 0, 0, 0j)
436 | path = Path(arc1, arc2)
437 | self.assertAlmostEqual(path.point(0.0), (0j))
438 | self.assertAlmostEqual(path.point(0.25), (100 + 100j))
439 | self.assertAlmostEqual(path.point(0.5), (200 + 0j))
440 | self.assertAlmostEqual(path.point(0.75), (100 - 100j))
441 | self.assertAlmostEqual(path.point(1.0), (0j))
442 | self.assertAlmostEqual(path.length(), pi * 200)
443 |
444 | def test_svg_specs(self):
445 | """The paths that are in the SVG specs"""
446 |
447 | # Big pie: M300,200 h-150 a150,150 0 1,0 150,-150 z
448 | path = Path(Line(300 + 200j, 150 + 200j),
449 | Arc(150 + 200j, 150 + 150j, 0, 1, 0, 300 + 50j),
450 | Line(300 + 50j, 300 + 200j))
451 | # The points and length for this path are calculated and not regression tests.
452 | self.assertAlmostEqual(path.point(0.0), (300 + 200j))
453 | self.assertAlmostEqual(path.point(0.14897825542), (150 + 200j))
454 | self.assertAlmostEqual(path.point(0.5), (406.066017177 + 306.066017177j))
455 | self.assertAlmostEqual(path.point(1 - 0.14897825542), (300 + 50j))
456 | self.assertAlmostEqual(path.point(1.0), (300 + 200j))
457 | # The errors seem to accumulate. Still 6 decimal places is more than good enough.
458 | self.assertAlmostEqual(path.length(), pi * 225 + 300, places=6)
459 |
460 | # Little pie: M275,175 v-150 a150,150 0 0,0 -150,150 z
461 | path = Path(Line(275 + 175j, 275 + 25j),
462 | Arc(275 + 25j, 150 + 150j, 0, 0, 0, 125 + 175j),
463 | Line(125 + 175j, 275 + 175j))
464 | # The points and length for this path are calculated and not regression tests.
465 | self.assertAlmostEqual(path.point(0.0), (275 + 175j))
466 | self.assertAlmostEqual(path.point(0.2800495767557787), (275 + 25j))
467 | self.assertAlmostEqual(path.point(0.5), (168.93398282201787 + 68.93398282201787j))
468 | self.assertAlmostEqual(path.point(1 - 0.2800495767557787), (125 + 175j))
469 | self.assertAlmostEqual(path.point(1.0), (275 + 175j))
470 | # The errors seem to accumulate. Still 6 decimal places is more than good enough.
471 | self.assertAlmostEqual(path.length(), pi * 75 + 300, places=6)
472 |
473 | # Bumpy path: M600,350 l 50,-25
474 | # a25,25 -30 0,1 50,-25 l 50,-25
475 | # a25,50 -30 0,1 50,-25 l 50,-25
476 | # a25,75 -30 0,1 50,-25 l 50,-25
477 | # a25,100 -30 0,1 50,-25 l 50,-25
478 | path = Path(Line(600 + 350j, 650 + 325j),
479 | Arc(650 + 325j, 25 + 25j, -30, 0, 1, 700 + 300j),
480 | Line(700 + 300j, 750 + 275j),
481 | Arc(750 + 275j, 25 + 50j, -30, 0, 1, 800 + 250j),
482 | Line(800 + 250j, 850 + 225j),
483 | Arc(850 + 225j, 25 + 75j, -30, 0, 1, 900 + 200j),
484 | Line(900 + 200j, 950 + 175j),
485 | Arc(950 + 175j, 25 + 100j, -30, 0, 1, 1000 + 150j),
486 | Line(1000 + 150j, 1050 + 125j),
487 | )
488 | # These are *not* calculated, but just regression tests. Be skeptical.
489 | self.assertAlmostEqual(path.point(0.0), (600 + 350j))
490 | self.assertAlmostEqual(path.point(0.3), (755.31526434 + 217.51578768j))
491 | self.assertAlmostEqual(path.point(0.5), (832.23324151 + 156.33454892j))
492 | self.assertAlmostEqual(path.point(0.9), (974.00559321 + 115.26473532j))
493 | self.assertAlmostEqual(path.point(1.0), (1050 + 125j))
494 | # Regression value only; compared at the default 7 decimal places.
495 | self.assertAlmostEqual(path.length(), 860.6756221710)
496 |
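The pie values are described above as calculated rather than regression values. The segment boundaries they use are consistent with spreading a path's global parameter over its segments in proportion to segment length. For the big pie, that gives 150 / (300 + 225π). For the little pie, it gives 150 / (300 + 75π). This is an inference from the numbers; check `path.py` for the actual implementation. `locate` below is a hypothetical helper that sketches the mapping.

```python
import math


def locate(lengths, t):
    """Map a global t in [0, 1] to (segment index, local t), weighting segments by length."""
    total = sum(lengths)
    target = t * total
    start = 0.0
    for i, length in enumerate(lengths):
        end = start + length
        if target <= end or i == len(lengths) - 1:
            local = (target - start) / length if length > 0 else 0.0
            return i, local
        start = end
    raise ValueError("lengths must be non-empty")


# Big pie: h-150, then three quarters of a radius-150 circle, then a closing line of 150.
big_pie = [150.0, 0.75 * 2 * math.pi * 150, 150.0]
little_pie = [150.0, 0.25 * 2 * math.pi * 150, 150.0]
print(big_pie[0] / sum(big_pie), little_pie[0] / sum(little_pie))
print(locate(big_pie, 0.5))  # halfway along the path is halfway along the arc
```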
497 | def test_repr(self):
498 | path = Path(
499 | Line(start=600 + 350j, end=650 + 325j),
500 | Arc(start=650 + 325j, radius=25 + 25j, rotation=-30, arc=0, sweep=1, end=700 + 300j),
501 | CubicBezier(start=700 + 300j, control1=800 + 400j, control2=750 + 200j, end=600 + 100j),
502 | QuadraticBezier(start=600 + 100j, control=600, end=600 + 300j))
503 | self.assertEqual(eval(repr(path)), path)
504 |
505 | def test_reverse(self):
506 | # Currently you can't reverse paths.
507 | self.assertRaises(NotImplementedError, Path().reverse)
508 |
509 | def test_equality(self):
510 | # This is to test the __eq__ and __ne__ methods, so we can't use
511 | # assertEqual and assertNotEqual
512 | path1 = Path(
513 | Line(start=600 + 350j, end=650 + 325j),
514 | Arc(start=650 + 325j, radius=25 + 25j, rotation=-30, arc=0, sweep=1, end=700 + 300j),
515 | CubicBezier(start=700 + 300j, control1=800 + 400j, control2=750 + 200j, end=600 + 100j),
516 | QuadraticBezier(start=600 + 100j, control=600, end=600 + 300j))
517 | path2 = Path(
518 | Line(start=600 + 350j, end=650 + 325j),
519 | Arc(start=650 + 325j, radius=25 + 25j, rotation=-30, arc=0, sweep=1, end=700 + 300j),
520 | CubicBezier(start=700 + 300j, control1=800 + 400j, control2=750 + 200j, end=600 + 100j),
521 | QuadraticBezier(start=600 + 100j, control=600, end=600 + 300j))
522 |
523 | self.assertTrue(path1 == path2)
524 | # Modify path2:
525 | path2[0].start = 601 + 350j
526 | self.assertTrue(path1 != path2)
527 |
528 | # Modify back:
529 | path2[0].start = 600 + 350j
530 | self.assertFalse(path1 != path2)
531 |
532 | # Get rid of the last segment:
533 | del path2[-1]
534 | self.assertFalse(path1 == path2)
535 |
536 | # It's not equal to a list of its segments
537 | self.assertTrue(path1 != path1[:])
538 | self.assertFalse(path1 == path1[:])
539 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | import argparse
5 | import time
6 | import os
7 | import cPickle
8 |
9 | from utils import SketchLoader
10 | from model import Model
11 |
12 | def main():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--rnn_size', type=int, default=256,
15 | help='size of RNN hidden state')
16 | parser.add_argument('--num_layers', type=int, default=2,
17 | help='number of layers in the RNN')
18 | parser.add_argument('--model', type=str, default='lstm',
19 | help='rnn, gru, or lstm')
20 | parser.add_argument('--batch_size', type=int, default=100,
21 | help='minibatch size')
22 | parser.add_argument('--seq_length', type=int, default=300,
23 | help='RNN sequence length')
24 | parser.add_argument('--num_epochs', type=int, default=500,
25 | help='number of epochs')
26 | parser.add_argument('--save_every', type=int, default=250,
27 | help='save frequency')
28 | parser.add_argument('--grad_clip', type=float, default=5.0,
29 | help='clip gradients at this value')
30 | parser.add_argument('--learning_rate', type=float, default=0.005,
31 | help='learning rate')
32 | parser.add_argument('--decay_rate', type=float, default=0.99,
33 | help='decay rate for rmsprop')
34 | parser.add_argument('--num_mixture', type=int, default=24,
35 | help='number of gaussian mixtures')
36 | parser.add_argument('--data_scale', type=float, default=15.0,
37 | help='factor to scale raw data down by')
38 | parser.add_argument('--keep_prob', type=float, default=0.8,
39 | help='dropout keep probability')
40 | parser.add_argument('--stroke_importance_factor', type=float, default=200.0,
41 | help='relative importance of pen status over mdn coordinate accuracy')
42 | parser.add_argument('--dataset_name', type=str, default="kanji",
43 | help='name of directory containing training data')
44 | args = parser.parse_args()
45 | train(args)
46 |
47 | def train(args):
48 | data_loader = SketchLoader(args.batch_size, args.seq_length, args.data_scale, args.dataset_name)
49 |
50 | dirname = os.path.join('save', args.dataset_name)
51 | if not os.path.exists(dirname):
52 | os.makedirs(dirname)
53 |
54 | with open(os.path.join(dirname, 'config.pkl'), 'w') as f:
55 | cPickle.dump(args, f)
56 |
57 | model = Model(args)
58 |
59 | b_processed = 0
60 |
61 | with tf.Session() as sess:
62 |
63 | tf.initialize_all_variables().run()
64 | saver = tf.train.Saver(tf.all_variables())
65 |
66 | # load previously trained model if applicable
67 | ckpt = tf.train.get_checkpoint_state(os.path.join('save', args.dataset_name))
68 | if ckpt:
69 | print "loading last model: ",ckpt.model_checkpoint_path
70 | saver.restore(sess, ckpt.model_checkpoint_path)
71 |
72 | def save_model():
73 | checkpoint_path = os.path.join('save', args.dataset_name, 'model.ckpt')
74 | saver.save(sess, checkpoint_path, global_step = b_processed)
75 | print "model saved to {}".format(checkpoint_path)
76 |
77 | for e in xrange(args.num_epochs):
78 | sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
79 | data_loader.reset_index_pointer()
80 | state = model.initial_state.eval()
81 | while data_loader.epoch_finished == False:
82 | start = time.time()
83 | input_data, target_data = data_loader.next_batch()
84 | feed = {model.input_data: input_data, model.target_data: target_data, model.initial_state: state}
85 | train_loss, shape_loss, pen_loss, state, _ = sess.run([model.cost, model.cost_shape, model.cost_pen, model.final_state, model.train_op], feed)
86 | end = time.time()
87 | b_processed += 1
88 | print "{}/{} (epoch {} batch {}), cost = {:.2f} ({:.2f}+{:.4f}), time/batch = {:.2f}" \
89 | .format(data_loader.pointer + e * data_loader.num_samples,
90 | args.num_epochs * data_loader.num_samples,
91 | e, b_processed ,train_loss, shape_loss, pen_loss, end - start)
92 | # NaN != NaN is always True, so comparing against np.NaN can never catch a bad loss; use np.isfinite
93 | assert np.isfinite(train_loss) and train_loss < 30000 # if dodgy loss, exit w/ error.
94 | if (b_processed) % args.save_every == 0 and ((b_processed) > 0):
95 | save_model()
96 | save_model()
97 |
98 | if __name__ == '__main__':
99 | main()
100 |
101 |
102 |
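Two details of the training loop can be checked in isolation. First, the learning rate is reset at the start of every epoch to `learning_rate * decay_rate ** e`, which is exponential decay per epoch. Second, a NaN loss cannot be caught with `!=`, because NaN compares unequal to everything, including itself. Note that `train_loss < 30000` on its own already fails for NaN and +inf. The sketch below is Python 3, whereas the repository targets Python 2. `learning_rate` and `loss_is_sane` are hypothetical helpers.

```python
import math


def learning_rate(base_lr: float, decay_rate: float, epoch: int) -> float:
    """Per-epoch exponential decay, mirroring the value assigned to model.lr in train()."""
    return base_lr * (decay_rate ** epoch)


def loss_is_sane(loss: float, limit: float = 30000.0) -> bool:
    """True only for a finite loss below the limit; NaN and +/-inf are rejected."""
    return math.isfinite(loss) and loss < limit


nan = float("nan")
print(nan != nan)  # True, which is why `train_loss != np.NaN` always passes
print([round(learning_rate(0.005, 0.99, e), 6) for e in (0, 100, 499)])
print(loss_is_sane(nan), loss_is_sane(float("inf")), loss_is_sane(12.3))  # False False True
```

With the defaults (`--learning_rate 0.005`, `--decay_rate 0.99`), the rate falls to roughly 37% of its initial value after 100 epochs.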
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cPickle
3 | import numpy as np
4 | import xml.etree.ElementTree as ET
5 | import random
6 | import svgwrite
7 | from IPython.display import SVG, display
8 | from svg.path import Path, Line, Arc, CubicBezier, QuadraticBezier, parse_path
9 |
10 | def calculate_start_point(data, factor=1.0, block_size = 200):
11 | # computes the pen start position that centers the sketch in the middle of the block
12 | # determines maxx, minx, maxy, miny
13 | sx = 0
14 | sy = 0
15 | maxx = 0
16 | minx = 0
17 | maxy = 0
18 | miny = 0
19 | for i in xrange(len(data)):
20 | sx += round(float(data[i, 0])*factor, 3)
21 | sy += round(float(data[i, 1])*factor, 3)
22 | maxx = max(maxx, sx)
23 | minx = min(minx, sx)
24 | maxy = max(maxy, sy)
25 | miny = min(miny, sy)
26 |
27 | abs_x = block_size/2-(maxx-minx)/2-minx
28 | abs_y = block_size/2-(maxy-miny)/2-miny
29 |
30 | return abs_x, abs_y, (maxx-minx), (maxy-miny)
31 |
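`calculate_start_point` walks the relative moves to find the sketch's bounding box. It then places the pen so that the box is centered in a `block_size` square. The bounding box is initialised at 0, so the origin (where the pen starts) is always included. The sketch below shows the same arithmetic. It omits the original's `factor` scaling and 3-decimal rounding, and `start_point` is a hypothetical name.

```python
def start_point(offsets, block_size=200):
    """Return the pen start (x, y) that centers a sketch of relative moves in a block."""
    x = y = 0.0
    min_x = max_x = min_y = max_y = 0.0  # the pen starts at the origin, so include it
    for dx, dy in offsets:
        x += dx
        y += dy
        min_x, max_x = min(min_x, x), max(max_x, x)
        min_y, max_y = min(min_y, y), max(max_y, y)
    width, height = max_x - min_x, max_y - min_y
    return block_size / 2 - width / 2 - min_x, block_size / 2 - height / 2 - min_y


# A 10 x 20 "L" shape: drawn from (95, 90) it spans x 95..105 and y 90..110.
print(start_point([(10, 0), (0, 20)]))  # (95.0, 90.0)
```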
32 | def draw_stroke_color_array(data, factor=1, svg_filename = 'sample.svg', stroke_width = 1, block_size = 200, maxcol = 5, svg_only = False, color_mode = True):
33 |
34 | num_char = len(data)
35 |
36 | if num_char < 1:
37 | return
38 |
39 | max_color_intensity = 225
40 |
41 | numrow = np.ceil(float(num_char)/float(maxcol))
42 | dwg = svgwrite.Drawing(svg_filename, size=(block_size*(min(num_char, maxcol)), block_size*numrow))
43 | dwg.add(dwg.rect(insert=(0, 0), size=(block_size*(min(num_char, maxcol)), block_size*numrow),fill='white'))
44 |
45 | the_color = "rgb("+str(random.randint(0, max_color_intensity))+","+str(int(random.randint(0, max_color_intensity)))+","+str(int(random.randint(0, max_color_intensity)))+")"
46 |
47 | for j in xrange(len(data)):
48 |
49 | lift_pen = 0
50 | #end_of_char = 0
51 | cdata = data[j]
52 | abs_x, abs_y, size_x, size_y = calculate_start_point(cdata, factor, block_size)
53 | abs_x += (j % maxcol) * block_size
54 | abs_y += (j // maxcol) * block_size
55 |
56 | for i in xrange(len(cdata)):
57 |
58 | x = round(float(cdata[i,0])*factor, 3)
59 | y = round(float(cdata[i,1])*factor, 3)
60 |
61 | prev_x = round(abs_x, 3)
62 | prev_y = round(abs_y, 3)
63 |
64 | abs_x += x
65 | abs_y += y
66 |
67 | if (lift_pen == 1):
68 | p = "M "+str(abs_x)+","+str(abs_y)+" "
69 | the_color = "rgb("+str(random.randint(0, max_color_intensity))+","+str(int(random.randint(0, max_color_intensity)))+","+str(int(random.randint(0, max_color_intensity)))+")"
70 | else:
71 | p = "M "+str(prev_x)+","+str(prev_y)+" L "+str(abs_x)+","+str(abs_y)+" "
72 |
73 | lift_pen = max(cdata[i, 2], cdata[i, 3]) # lift pen if either eos or eoc is set
74 | #end_of_char = cdata[i, 3] # not used for now.
75 |
76 | if color_mode == False:
77 | the_color = "#000"
78 |
79 | dwg.add(dwg.path(p).stroke(the_color,stroke_width).fill(the_color)) #, opacity=round(random.random()*0.5+0.5, 3)
80 |
81 | dwg.save()
82 | if svg_only == False:
83 | display(SVG(dwg.tostring()))
84 |
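`draw_stroke_color_array` tiles the sketches row by row into a grid at most `maxcol` cells wide. Sketch `j` goes in column `j % maxcol` and row `j / maxcol`, which relies on Python 2 integer division; `//` gives the same result under both Python 2 and 3. The canvas layout can be checked on its own with a small sketch. `grid_layout` is a hypothetical helper, and it uses `math.ceil` where the original uses `np.ceil`.

```python
import math


def grid_layout(num_sketches, block_size=200, maxcol=5):
    """Canvas size and per-sketch cell origins for a row-major grid of blocks."""
    numrow = math.ceil(num_sketches / maxcol)
    width = block_size * min(num_sketches, maxcol)
    height = block_size * numrow
    # // keeps the row index an integer under Python 3 as well as Python 2.
    origins = [((j % maxcol) * block_size, (j // maxcol) * block_size)
               for j in range(num_sketches)]
    return (width, height), origins


size, origins = grid_layout(7)
print(size)         # (1000, 400)
print(origins[5:])  # [(0, 200), (200, 200)]: the sixth sketch wraps to a second row
```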
85 | def draw_stroke_color(data, factor=1, svg_filename = 'sample.svg', stroke_width = 1, block_size = 200, maxcol = 5, svg_only = False, color_mode = True):
86 |
87 | def split_sketch(data):
88 | # split a sketch with many eoc into an array of sketches, each with just one eoc at the end.
89 | # ignores last stub with no eoc.
90 | counter = 0
91 | result = []
92 | for i in xrange(len(data)):
93 | eoc = data[i, 3]
94 | if eoc > 0:
95 | result.append(data[counter:i+1])
96 | counter = i+1
97 | #if (counter < len(data)): # ignore the rest
98 | # result.append(data[counter:])
99 | return result
100 |
101 | data = np.array(data, dtype=np.float32)
102 | data = split_sketch(data)
103 | draw_stroke_color_array(data, factor, svg_filename, stroke_width, block_size, maxcol, svg_only, color_mode)
104 |
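`split_sketch` cuts the stroke array after every row whose end-of-character flag (column 3) is set, and drops any trailing rows that lack one. The same behaviour can be written with NumPy in a few lines. This is an alternative sketch rather than the repository's code, it assumes a 2-D `(N, 4)` input, and `split_sketch_np` is a hypothetical name.

```python
import numpy as np


def split_sketch_np(data):
    """Split an (N, 4) stroke array after each row with eoc > 0; drop the trailing stub."""
    data = np.asarray(data, dtype=np.float32)
    ends = np.flatnonzero(data[:, 3] > 0) + 1  # split positions just after each eoc row
    return np.split(data, ends)[:-1]           # the last piece never ends with an eoc


strokes = [[1, 0, 0, 0], [0, 1, 1, 1],   # character 1
           [2, 0, 0, 0], [0, 2, 0, 1],   # character 2
           [5, 5, 0, 0]]                 # stub without eoc, ignored
print([len(piece) for piece in split_sketch_np(strokes)])  # [2, 2]
```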
105 | class SketchLoader():
106 | def __init__(self, batch_size=50, seq_length=300, scale_factor = 1.0, data_filename = "kanji"):
107 | self.data_dir = "./data"
108 | self.batch_size = batch_size
109 | self.seq_length = seq_length
110 | self.scale_factor = scale_factor # divide data by this factor
111 |
112 | data_file = os.path.join(self.data_dir, data_filename+".cpkl")
113 | raw_data_dir = os.path.join(self.data_dir, data_filename)
114 |
115 | if not (os.path.exists(data_file)) :
116 | print "creating training data cpkl file from raw source"
117 | self.length_data = self.preprocess(raw_data_dir, data_file)
118 |
119 | self.load_preprocessed(data_file)
120 | self.num_samples = len(self.raw_data)
121 | self.index = range(self.num_samples) # this list will be randomized later.
122 | self.reset_index_pointer()
123 |
124 | def preprocess(self, data_dir, data_file):
125 | # create the training data file from the raw .svg files under data_dir.
126 | len_data = []
127 | def cubicbezier(x0, y0, x1, y1, x2, y2, x3, y3, n=20):
128 | # from http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic
129 | pts = []
130 | for i in range(n+1):
131 | t = float(i) / float(n)
132 | a = (1. - t)**3
133 | b = 3. * t * (1. - t)**2
134 | c = 3.0 * t**2 * (1.0 - t)
135 | d = t**3
136 |
137 | x = float(a * x0 + b * x1 + c * x2 + d * x3)
138 | y = float(a * y0 + b * y1 + c * y2 + d * y3)
139 | pts.append( (x, y) )
140 | return pts
141 |
142 | def get_path_strings(svgfile):
143 | tree = ET.parse(svgfile)
144 | p = []
145 | for elem in tree.iter():
146 | if elem.attrib.has_key('d'):
147 | p.append(elem.attrib['d'])
148 | return p
149 |
150 | def build_lines(svgfile, line_length_threshold = 10.0, min_points_per_path = 1, max_points_per_path = 3):
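151 | # line_length_threshold sets the Bezier sampling density: roughly one segment per this many units of chord length, clamped to [min_points_per_path, max_points_per_path]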
152 | path_strings = get_path_strings(svgfile)
153 |
154 | lines = []
155 |
156 | for path_string in path_strings:
157 | full_path = parse_path(path_string)
158 | for i in range(len(full_path)):
159 | p = full_path[i]
160 | if type(p) != Line and type(p) != CubicBezier:
161 | print "encountered an element that is not just a line or bezier "
162 | print "type: ",type(p)
163 | print p
164 | else:
165 | x_start = p.start.real
166 | y_start = p.start.imag
167 | x_end = p.end.real
168 | y_end = p.end.imag
169 | line_length = np.sqrt((x_end-x_start)*(x_end-x_start)+(y_end-y_start)*(y_end-y_start))
170 | len_data.append(line_length)
171 | points = []
172 | if type(p) == CubicBezier:
173 | x_con1 = p.control1.real
174 | y_con1 = p.control1.imag
175 | x_con2 = p.control2.real
176 | y_con2 = p.control2.imag
177 | n_points = int(line_length / line_length_threshold)+1
178 | n_points = max(n_points, min_points_per_path)
179 | n_points = min(n_points, max_points_per_path)
180 | points = cubicbezier(x_start, y_start, x_con1, y_con1, x_con2, y_con2, x_end, y_end, n_points)
181 | else:
182 | points = [(x_start, y_start), (x_end, y_end)]
183 | if i == 0: # emit the starting point only for the first segment of each path (assumes later segments begin where the previous one ended)
184 | lines.append([points[0][0], points[0][1], 0, 0]) # put eoc to be zero
185 | for j in range(1, len(points)):
186 | eos = 0
187 | if j == len(points)-1 and i == len(full_path)-1:
188 | eos = 1
189 | lines.append([points[j][0], points[j][1], eos, 0]) # put eoc to be zero
190 | lines = np.array(lines, dtype=np.float32)
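191 | # make it relative moves: each row becomes the offset from the previous point
192 | lines[1:,0:2] = lines[1:,0:2] - lines[0:-1,0:2] # not in-place: NumPy < 1.13 gave undefined results for in-place ops on overlapping slices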
193 | lines[-1,3] = 1 # end of character
194 | lines[0] = [0, 0, 0, 0] # start at origin
195 | return lines[1:]
196 |
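`build_lines` returns one row per point as `[dx, dy, eos, eoc]`: `eos = 1` on the last point of each path (so the following row is a pen-up move), `eoc = 1` on the final row, and the first absolute point is dropped so every character starts at the origin. The stdlib-only sketch below reproduces that layout from strokes that are already lists of absolute points; the name `strokes_to_rows` and the list-of-tuples input are assumptions for illustration, whereas the real code works on a NumPy array built from `svg.path` segments.

```python
def strokes_to_rows(strokes):
    """Convert strokes (lists of absolute (x, y) points) to [dx, dy, eos, eoc] rows.

    eos = 1 marks the last point of a stroke (the pen lifts before the next row),
    eoc = 1 marks the final row of the character, and the first absolute point
    is dropped so the character starts at the origin.
    """
    points = []
    for stroke in strokes:
        for j, (x, y) in enumerate(stroke):
            eos = 1 if j == len(stroke) - 1 else 0
            points.append([x, y, eos, 0])
    rows = []
    for prev, cur in zip(points, points[1:]):
        # offset from the previous point; the pen-state flags belong to the current point
        rows.append([cur[0] - prev[0], cur[1] - prev[1], cur[2], cur[3]])
    if rows:
        rows[-1][3] = 1  # end of character
    return rows


rows = strokes_to_rows([[(0, 0), (10, 0)], [(10, 5), (10, 15)]])
# [[10, 0, 1, 0], [0, 5, 0, 0], [0, 10, 1, 1]]
```

The middle row `[0, 5, 0, 0]` is the pen-up jump from the end of the first stroke to the start of the second; it is only recognisable as a jump because the row before it has `eos = 1`.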
197 | # build the list of all files under data_dir (only .svg files are processed below)
198 | filelist = []
199 | # Set the directory you want to start from
200 | rootDir = data_dir
201 | for dirName, subdirList, fileList in os.walk(rootDir):
202 | #print('Found directory: %s' % dirName)
203 | for fname in fileList:
204 | #print('\t%s' % fname)
205 | filelist.append(os.path.join(dirName, fname))
206 |
207 | # build the stroke database from every .svg file found under data_dir
208 | sketch = []
209 | for i in range(len(filelist)):
210 | if filelist[i].endswith('.svg'):
211 | print 'processing '+filelist[i]
212 | sketch.append(build_lines(filelist[i]))
213 |
214 | f = open(data_file,"wb")
215 | cPickle.dump(sketch, f, protocol=2)
216 | f.close()
217 | return len_data
218 |
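`preprocess` walks `data_dir` recursively, parses each `.svg` file, and caches the resulting list of arrays with pickle protocol 2, so later runs can skip parsing via `load_preprocessed`. The sketch below shows the same discover-then-cache round trip in Python 3, where the standard `pickle` module takes over the role of Python 2's `cPickle`; the helper `find_svg_files`, the temporary directory, and the toy sketch data are illustrative assumptions.

```python
import os
import pickle
import tempfile


def find_svg_files(root_dir):
    """Return every .svg path under root_dir, walking subdirectories like preprocess does."""
    found = []
    for dir_name, _subdirs, file_names in os.walk(root_dir):
        for fname in file_names:
            if fname.endswith('.svg'):
                found.append(os.path.join(dir_name, fname))
    return sorted(found)  # os.walk order is filesystem dependent, so sort for reproducibility


with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, 'sub'))
    for rel in ('a.svg', 'notes.txt', os.path.join('sub', 'b.svg')):
        with open(os.path.join(tmp, rel), 'w') as f:
            f.write('<svg xmlns="http://www.w3.org/2000/svg"/>')
    svg_files = [os.path.relpath(p, tmp) for p in find_svg_files(tmp)]

    # cache the parsed sketches once, then reload them (protocol 2, as in preprocess)
    sketches = [[[1.0, 2.0, 0, 0], [3.0, 4.0, 1, 1]]]
    data_file = os.path.join(tmp, 'sketches.cpkl')
    with open(data_file, 'wb') as f:
        pickle.dump(sketches, f, protocol=2)
    with open(data_file, 'rb') as f:
        restored = pickle.load(f)
```

The cache is keyed only on whether the `.cpkl` file exists, so after changing the raw `.svg` files the cached file has to be deleted by hand before it will be rebuilt.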
219 | def load_preprocessed(self, data_file):
220 | f = open(data_file,"rb")
221 | self.raw_data = cPickle.load(f)
222 | # scale the data here, rather than at the data construction (since scaling may change)
223 | for data in self.raw_data:
224 | data[:,0:2] /= self.scale_factor
225 | f.close()
226 |
227 | def next_batch(self):
228 | # returns one batch as (inputs, targets), where targets are the inputs shifted one step ahead.
229 | # constraint: each sequence in the batch starts at the beginning of a new character (its end need not be the end of a character)
230 |
231 | def next_seq(n):
232 | result = np.zeros((n, 5), dtype=np.float32) # x, y, [eos, eoc, cont] tokens
233 | #result[0, 2:4] = 1 # set eos and eoc to true for first point
234 | # experimental data augmentation: scale x and y independently by a random factor in [0.7, 1.3) to generate more examples
235 | rand_scale_factor_x = np.random.rand()*0.6+0.7
236 | rand_scale_factor_y = np.random.rand()*0.6+0.7
237 | idx = 0
238 | data = self.current_data()
239 | for i in xrange(n):
240 | result[i, 0:4] = data[idx] # dx, dy, eos, eoc
241 | result[i, 4] = 1 # continue on stroke
242 | if (result[i, 2] > 0 or result[i, 3] > 0):
243 | result[i, 4] = 0
244 | idx += 1
245 | if (idx >= len(data)): # finished this sketch (including its last row): mark eoc and skip to the next sketch next time
246 | result[i, 4] = 0
247 | result[i, 3] = 1
248 | result[i, 2] = 0 # overrides end of stroke one-hot
249 | idx = 0
250 | self.tick_index_pointer()
251 | data = self.current_data()
252 | assert(result[i, 2:5].sum() == 1)
253 | self.tick_index_pointer() # advance so the next sequence in the batch starts at a new character; unread rows of the current sketch are skipped
254 | result[:, 0] *= rand_scale_factor_x
255 | result[:, 1] *= rand_scale_factor_y
256 | return result
257 |
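`next_seq` expands the stored `[dx, dy, eos, eoc]` rows into five columns, turning the pen state into a one-hot `[eos, eoc, cont]` where `cont = 1` means the pen keeps drawing. When a sketch is exhausted its final row is forced to `eoc`, overriding `eos`, and reading continues with the next sketch. The pure-Python sketch below shows that packing; the function name `pack_sequence`, the `start` argument, and the wrap-around over a plain list (instead of `tick_index_pointer`) are assumptions, the random scaling is left out, and this version reads every row of a sketch, including the last one, before moving on.

```python
def pack_sequence(sketches, n, start=0):
    """Fill n rows of [dx, dy, eos, eoc, cont], reading sketches in order from `start`.

    When a sketch runs out, its last row is marked end-of-character (eoc) and
    reading continues at the start of the next sketch, wrapping around the list.
    """
    result = []
    s, idx = start, 0
    for _ in range(n):
        dx, dy, eos, eoc = sketches[s][idx]
        idx += 1
        if idx >= len(sketches[s]):  # consumed the whole sketch
            eos, eoc = 0, 1  # eoc overrides eos in the one-hot pen state
            s, idx = (s + 1) % len(sketches), 0
        cont = 0 if (eos or eoc) else 1
        assert eos + eoc + cont == 1  # exactly one pen state is set
        result.append([dx, dy, eos, eoc, cont])
    return result


sk0 = [[1, 0, 0, 0], [2, 0, 1, 1]]
sk1 = [[0, 1, 0, 0], [0, 2, 1, 0], [0, 3, 1, 1]]
packed = pack_sequence([sk0, sk1], 6)
# [[1, 0, 0, 0, 1], [2, 0, 0, 1, 0], [0, 1, 0, 0, 1],
#  [0, 2, 1, 0, 0], [0, 3, 0, 1, 0], [1, 0, 0, 0, 1]]
```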
258 | skip_length = self.seq_length+1
259 |
260 | batch = []
261 |
262 | for i in xrange(self.batch_size):
263 | seq = next_seq(skip_length)
264 | batch.append(seq)
265 |
266 | batch = np.array(batch, dtype=np.float32)
267 |
268 | return batch[:,0:-1], batch[:, 1:]
269 |
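`next_batch` draws `seq_length + 1` rows per sequence so that slicing off the last step gives the inputs and slicing off the first step gives the next-step targets, each of length `seq_length`. Each call to `next_seq` also draws fresh x and y scale factors, so the augmentation varies per sequence. A stdlib sketch of that assembly, assuming lists of rows instead of NumPy arrays and `random.Random` in place of `np.random` (the name `make_batch` is hypothetical):

```python
import random


def make_batch(sequences, rng):
    """sequences: equal-length lists of [dx, dy, eos, eoc, cont] rows (seq_length + 1 each)."""
    batch = []
    for seq in sequences:
        # per-sequence augmentation: independent x and y scale factors in [0.7, 1.3)
        sx = rng.random() * 0.6 + 0.7
        sy = rng.random() * 0.6 + 0.7
        batch.append([[r[0] * sx, r[1] * sy] + list(r[2:]) for r in seq])
    inputs = [seq[:-1] for seq in batch]   # steps 0 .. seq_length - 1
    targets = [seq[1:] for seq in batch]   # steps 1 .. seq_length (the next step)
    return inputs, targets


seq_length = 3
seqs = [[[float(t), 1.0, 0, 0, 1] for t in range(seq_length + 1)] for _ in range(2)]
inputs, targets = make_batch(seqs, random.Random(0))
```

Only the two offset columns are scaled; the pen-state columns pass through unchanged, so the one-hot invariant from `next_seq` still holds after augmentation.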
270 | def current_data(self):
271 | return self.raw_data[self.index[self.pointer]]
272 |
273 | def tick_index_pointer(self):
274 | self.pointer += 1
275 | if (self.pointer >= len(self.raw_data)):
276 | self.pointer = 0
277 | self.epoch_finished = True
278 |
279 | def reset_index_pointer(self):
280 | # randomize order for the raw list in the next go.
281 | self.pointer = 0
282 | self.epoch_finished = False
283 | self.index = np.random.permutation(self.index)
284 |
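The loader visits sketches through a shuffled index: `tick_index_pointer` advances and wraps the pointer, setting `epoch_finished` after a full pass, but it does not reshuffle. A new random order only appears when `reset_index_pointer` is called, presumably by the training loop once it sees `epoch_finished` (that caller is not shown here). The class below is a hypothetical stand-in for that bookkeeping; it shuffles a list in place with `random.Random.shuffle` where the original uses `np.random.permutation`.

```python
import random


class EpochIndex(object):
    """Hypothetical stand-in for the pointer/index bookkeeping in DataLoader."""

    def __init__(self, num_samples, rng):
        self.rng = rng
        self.index = list(range(num_samples))
        self.reset()

    def reset(self):
        # new random visiting order; start again from the first entry
        self.pointer = 0
        self.epoch_finished = False
        self.rng.shuffle(self.index)

    def current(self):
        return self.index[self.pointer]

    def tick(self):
        # wrap around at the end and record that a full pass has been made;
        # the order is NOT reshuffled here, only in reset()
        self.pointer += 1
        if self.pointer >= len(self.index):
            self.pointer = 0
            self.epoch_finished = True


it = EpochIndex(5, random.Random(42))
seen = []
while not it.epoch_finished:
    seen.append(it.current())
    it.tick()
print(sorted(seen))  # [0, 1, 2, 3, 4]: every sample visited exactly once

first_order = list(it.index)
it.reset()  # reshuffle for the next epoch
```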
285 |
286 |
--------------------------------------------------------------------------------