├── README.md ├── data └── kanji.cpkl ├── example ├── aaron_sheep_sample.svg ├── cat_vae.png ├── catbus.svg ├── catbus2.svg ├── data_format.svg ├── doraemon.svg ├── elephant.svg ├── elephantpig.svg ├── frog_crab_cat.png ├── full_predictions.svg ├── morph_catchair.svg ├── morph_catchairs.svg ├── multiple_interpolations.png ├── omniglot_sample.svg ├── output.svg ├── owlmorph.svg ├── pig_morph.png ├── short_kanji_sample.svg ├── sketch_rnn.png ├── sketch_rnn.svg ├── sketch_rnn_examples.svg ├── sketch_rnn_schematic.svg ├── training.svg ├── vae_analogy.svg ├── vae_cat.svg ├── vae_cats.svg ├── vae_morph.svg ├── vae_morphs.svg ├── vae_pig.svg └── vae_pigs.svg ├── get_kanji.sh ├── model.py ├── requirements.linux-x64-cpu.txt ├── requirements.linux-x64-gpu.txt ├── requirements.mac.txt ├── sample.py ├── save └── kanji │ ├── checkpoint │ ├── config.pkl │ └── model.ckpt-0 ├── svg ├── __init__.py └── path │ ├── __init__.py │ ├── parser.py │ ├── path.py │ └── tests │ ├── __init__.py │ ├── test_doc.py │ ├── test_generation.py │ ├── test_parsing.py │ └── test_paths.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Deprecated 2 | 3 | This version of sketch-rnn has been deprecated. Please see the updated version of [sketch-rnn](https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/README.md), which is a full generative model for vector drawings. 4 | 5 | # sketch-rnn 6 | 7 | Implementation of a multi-layer recurrent neural network (RNN, LSTM, or GRU) used to model and generate sketches stored in .svg vector graphic files. The methodology is to combine a Mixture Density Network with an RNN, along with modelling dynamic end-of-stroke and end-of-content probabilities learned from a large corpus of similar .svg files, to generate drawings that are similar to the vector training data.
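The 5-column rows that `model.py` splits into `x1_data, x2_data, eos_data, eoc_data, cont_data` pair each pen movement with the end-of-stroke / end-of-content state described above. The helper below is a hypothetical sketch, not the preprocessing in `utils.py`. It assumes that the first two columns are offsets from the previous point (the first one measured from the origin) divided by `data_scale`, and that exactly one of the three pen flags is set per row.

```python
import numpy as np


def strokes_to_offsets(polylines, data_scale=15.0):
    """Encode a drawing as rows of [dx, dy, eos, eoc, cont].

    Hypothetical helper. `polylines` is a list of strokes, each a sequence of
    absolute (x, y) points. Assumed convention: eos marks the last point of
    every stroke except the final one, eoc marks the last point of the whole
    drawing, and cont marks every other point.
    """
    rows = []
    prev = np.zeros(2)
    for s_idx, stroke in enumerate(polylines):
        stroke = np.asarray(stroke, dtype=np.float64)
        last_stroke = s_idx == len(polylines) - 1
        for p_idx, point in enumerate(stroke):
            dx, dy = (point - prev) / data_scale  # offset from previous point, scaled down
            prev = point
            last_point = p_idx == len(stroke) - 1
            eoc = 1.0 if (last_point and last_stroke) else 0.0
            eos = 1.0 if (last_point and not last_stroke) else 0.0
            cont = 1.0 - eos - eoc  # exactly one pen flag is set per row
            rows.append([dx, dy, eos, eoc, cont])
    return np.array(rows, dtype=np.float32)


# Two strokes: a horizontal line, then a vertical one.
rows = strokes_to_offsets([[(0, 0), (30, 0)], [(30, 15), (30, 45)]])
# Undo the scaling and accumulate the offsets to recover absolute points.
points = np.cumsum(rows[:, :2], axis=0) * 15.0
```

Undoing the scaling mirrors the `strokes[:,0:2] *= self.args.data_scale` line at the end of `Model.sample`. A cumulative sum then turns offsets back into absolute coordinates for drawing.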
8 | 9 | See my blog post at [blog.otoro.net](http://blog.otoro.net/2015/12/28/recurrent-net-dreams-up-fake-chinese-characters-in-vector-format-with-tensorflow/) for a detailed description of applying `sketch-rnn` to generate fake Chinese characters in vector format. 10 | 11 | Example Training Sketches (20 randomly chosen out of the 11,000 characters in the [KanjiVG](http://kanjivg.tagaini.net/) dataset): 12 | 13 | ![Example Training Sketches](https://cdn.rawgit.com/hardmaru/sketch-rnn/master/example/training.svg) 14 | 15 | Generated Sketches (Temperature = 0.1): 16 | 17 | ![Generated Sketches](https://cdn.rawgit.com/hardmaru/sketch-rnn/master/example/output.svg) 18 | 19 | # Basic Usage 20 | 21 | I tested the implementation on TensorFlow 0.50. I also used the following libraries: 22 | 23 | ``` 24 | svgwrite 25 | IPython.display.SVG 26 | IPython.display.display 27 | xml.etree.ElementTree 28 | argparse 29 | cPickle 30 | svg.path 31 | ``` 32 | 33 | ## Loading in Training Data 34 | 35 | The training data is located inside the `data` subdirectory. In this repo, I've included `kanji.cpkl`, which is a preprocessed array of KanjiVG characters. 36 | 37 | To add a new set of training data, for example from the [TU Berlin Sketch Database](http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/), create a subdirectory, say `tuberlin`, inside the `data` directory, and also create a directory with the same name inside the `save` directory. You end up with `data/tuberlin/` and `save/tuberlin/`, where `tuberlin` is the dataset name passed to the training and sampling programs via `--dataset_name`. `save/tuberlin/` will contain the checkpointed trained models. 38 | 39 | Now, put a large collection of .svg files into `data/tuberlin/`. You can even create subdirectories within `data/tuberlin/`, since the `SketchLoader` class scans the entire subdirectory tree.
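A recursive scan like the one the README attributes to `SketchLoader` can be sketched with `os.walk`. This is an illustrative stand-in, not the loader in `utils.py`. The function name `find_svg_files` is hypothetical, and matching the `.svg` extension case-insensitively is an assumption.

```python
import os


def find_svg_files(root):
    """Return the sorted paths of every .svg file under `root`, at any depth."""
    matches = []
    # os.walk visits `root` and every nested subdirectory.
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if name.lower().endswith('.svg'):
                matches.append(os.path.join(dirpath, name))
    return sorted(matches)
```

For the example above, `find_svg_files(os.path.join('data', 'tuberlin'))` would list every sketch no matter how deeply it is nested.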
40 | 41 | Currently, `sketch-rnn` only processes `path` elements inside svg files, and within the `path` elements it only handles lines and Bézier curves. I found this sufficient for the TU Berlin and KanjiVG databases, although it wouldn't be difficult to extend it to process the other curve elements, and even shape elements, in the future. 42 | 43 | You can use `utils.py` to draw some random training data after the svg files have been copied in: 44 | 45 | ``` 46 | %run -i utils.py 47 | loader = SketchLoader(data_filename = 'tuberlin') 48 | draw_stroke_color(random.choice(loader.raw_data)) 49 | ``` 50 | 51 | ![Example Elephant from TU Berlin database](https://cdn.rawgit.com/hardmaru/sketch-rnn/master/example/elephant.svg) 52 | 53 | For this algorithm to work, I recommend that the data be similar in size and similar in style / content. For example, if the training set mixes bananas, buildings, elephants, rockets, and insects of varying shapes and sizes, the model would most likely just produce gibberish. 54 | 55 | ## Training the Model 56 | 57 | Once the data is in place, continuing with the 'tuberlin' example, you can run `python train.py --dataset_name tuberlin`. 58 | 59 | A number of flags can be set for training if you wish to experiment with the parameters. You probably want to change these, especially the scaling factor (`--data_scale`), to better suit the sizes of your .svg data.
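Because only `path` elements are read, loading an SVG starts with collecting their `d` attributes. The following is a minimal sketch using the standard-library `xml.etree.ElementTree` (listed among the dependencies above). `extract_path_data` is a hypothetical helper, and it assumes the file uses either no namespace or the standard SVG namespace.

```python
import xml.etree.ElementTree as ET

SVG_NS = '{http://www.w3.org/2000/svg}'


def extract_path_data(svg_text):
    """Return the `d` attribute of every <path> element in an SVG document."""
    root = ET.fromstring(svg_text)
    paths = []
    for elem in root.iter():  # walks the root and all descendants in document order
        # Match <path> with or without the SVG namespace prefix.
        if elem.tag in ('path', SVG_NS + 'path'):
            d = elem.get('d')
            if d:  # skip paths with a missing or empty `d`
                paths.append(d)
    return paths
```

Each returned string could then be handed to the bundled `svg.path.parse_path`, which turns it into `Line` and `CubicBezier` segments.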
60 | 61 | The default values are defined in `train.py`: 62 | 63 | ``` 64 | --rnn_size RNN_SIZE size of RNN hidden state (256) 65 | --num_layers NUM_LAYERS number of layers in the RNN (2) 66 | --model MODEL rnn, gru, or lstm (lstm) 67 | --batch_size BATCH_SIZE minibatch size (100) 68 | --seq_length SEQ_LENGTH RNN sequence length (300) 69 | --num_epochs NUM_EPOCHS number of epochs (500) 70 | --save_every SAVE_EVERY save frequency (250) 71 | --grad_clip GRAD_CLIP clip gradients at this value (5.0) 72 | --learning_rate LEARNING_RATE learning rate (0.005) 73 | --decay_rate DECAY_RATE decay rate after each epoch (adam is used) (0.99) 74 | --num_mixture NUM_MIXTURE number of gaussian mixtures (24) 75 | --data_scale DATA_SCALE factor to scale raw data down by (15.0) 76 | --keep_prob KEEP_PROB dropout keep probability (0.8) 77 | --stroke_importance_factor F gradient boosting of sketch-finish event (200.0) 78 | --dataset_name DATASET_NAME name of directory containing training data (kanji) 79 | ``` 80 | 81 | ## Sampling a Sketch 82 | 83 | I've included a pretrained model in `save/`, so it should work out of the box. Running `python sample.py --filename output --num_picture 10 --dataset_name kanji` will generate an .svg file containing 10 fake Kanji characters using the pretrained model. Run `python sample.py --help` to see the extra flags, for example to change the number of sketches per row. 84 | 85 | It should be straightforward to adapt `sample.py` to generate sketches interactively from an IPython prompt rather than from the command line. Running `%run -i sample.py` in an IPython interactive session shows the generated sketches in the IPython interface and also writes an .svg output file. 86 | 87 | ## More useful links, pointers, datasets 88 | 89 | - Alex Graves' [paper](http://arxiv.org/abs/1308.0850) on text sequence and handwriting generation.
90 | 91 | - Karpathy's [char-rnn](https://github.com/karpathy/char-rnn) tool, the motivation for creating sketch-rnn. 92 | 93 | - [KanjiVG](http://kanjivg.tagaini.net/). A fantastic database of Kanji stroke order. 94 | 95 | - A very clean TensorFlow implementation of [char-rnn](https://github.com/sherjilozair/char-rnn-tensorflow), written by [Sherjil Ozair](https://github.com/sherjilozair), on which I based the skeleton of this code. 96 | 97 | - [svg.path](https://pypi.python.org/pypi/svg.path). I used this well-written tool to help convert path data into line data. 98 | 99 | - CASIA Online and Offline Chinese [Handwriting Databases](http://www.nlpr.ia.ac.cn/databases/handwriting/Download.html). Download stroke data for written cursive Simplified Chinese. 100 | 101 | - How Do Humans Sketch Objects? [TU Berlin Sketch Database](http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/). It would be interesting to extend this work to generate random vector art of real-life objects. 102 | 103 | - Doraemon in [SVG format](http://yylam.blogspot.hk/2012/04/doraemon-in-svg-format-doraemonsvg.html). 104 | 105 | - [Potrace](https://en.wikipedia.org/wiki/Potrace). A beautiful tool for converting raster bitmap drawings into SVG, which could be used to scale up the resolution of drawings. It could also be applied to generate large amounts of training data. 106 | 107 | - [Rendering Bézier Curve Code](http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic). I used this very useful code to convert Bézier curves into line segments.
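A cubic Bézier can be flattened into line segments by evaluating the curve at evenly spaced parameter values. The sketch below uses the Bernstein form of the curve. It is an independent illustration rather than the Rosetta Code routine linked above, and the name `cubic_bezier_points` and the default of 10 segments are assumptions.

```python
def cubic_bezier_points(p0, p1, p2, p3, num_segments=10):
    """Approximate a cubic Bezier curve by `num_segments` straight segments.

    p0..p3 are (x, y) tuples (start, two control points, end). Returns
    num_segments + 1 points evaluated at evenly spaced t in [0, 1].
    """
    if num_segments < 1:
        raise ValueError("num_segments must be at least 1")
    points = []
    for i in range(num_segments + 1):
        t = i / num_segments
        mt = 1.0 - t
        # Bernstein basis weights for a cubic curve.
        a = mt ** 3
        b = 3 * mt ** 2 * t
        c = 3 * mt * t ** 2
        d = t ** 3
        x = a * p0[0] + b * p1[0] + c * p2[0] + d * p3[0]
        y = a * p0[1] + b * p1[1] + c * p2[1] + d * p3[1]
        points.append((x, y))
    return points
```

Consecutive pairs of the returned points form the line segments. More segments follow the curve more closely at the cost of longer stroke sequences for the RNN.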
108 | 109 | 110 | # License 111 | 112 | MIT 113 | -------------------------------------------------------------------------------- /data/kanji.cpkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/data/kanji.cpkl -------------------------------------------------------------------------------- /example/cat_vae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/example/cat_vae.png -------------------------------------------------------------------------------- /example/catbus.svg: -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /example/catbus2.svg: -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /example/doraemon.svg: -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /example/elephant.svg: -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /example/elephantpig.svg: --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- /example/frog_crab_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/example/frog_crab_cat.png -------------------------------------------------------------------------------- /example/multiple_interpolations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/example/multiple_interpolations.png -------------------------------------------------------------------------------- /example/omniglot_sample.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /example/pig_morph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/example/pig_morph.png -------------------------------------------------------------------------------- /example/sketch_rnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/example/sketch_rnn.png -------------------------------------------------------------------------------- /get_kanji.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #fetch and unpack the data 4 | cd data 5 | wget https://github.com/KanjiVG/kanjivg/releases/download/r20150615-2/kanjivg-20150615-2-all.zip 6 | unzip kanjivg-20150615-2-all.zip 7 | # move aside one problem file 8 | mkdir -p rejects 9 | mv kanji/05747-Kaisho.svg rejects 10 | # done, now try: python train.py --dataset_name kanji 11 | 
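`model.py`, which follows, implements Graves' eqs. 24 and 25 in `tf_2d_normal` and the mixture loss of eq. 26 in `get_lossfunc`, written against the TensorFlow 0.x API. A NumPy transcription can be handy for checking values by hand. This is an illustrative sketch, not code from the repo, and the function names are hypothetical. It keeps the same `1e-20` floor that the TensorFlow loss uses.

```python
import numpy as np


def bivariate_normal(x1, x2, mu1, mu2, s1, s2, rho):
    """Density of a 2-D Gaussian with correlation rho (Graves 2013, eqs. 24 and 25)."""
    norm1 = x1 - mu1
    norm2 = x2 - mu2
    s1s2 = s1 * s2
    z = (norm1 / s1) ** 2 + (norm2 / s2) ** 2 - 2 * rho * norm1 * norm2 / s1s2
    neg_rho = 1 - rho ** 2
    return np.exp(-z / (2 * neg_rho)) / (2 * np.pi * s1s2 * np.sqrt(neg_rho))


def mixture_nll(x1, x2, pi, mu1, mu2, s1, s2, rho, eps=1e-20):
    """Negative log-likelihood of one offset under an M-component mixture (eq. 26).

    pi, mu1, mu2, s1, s2 and rho are length-M arrays; pi should sum to 1.
    """
    density = np.sum(pi * bivariate_normal(x1, x2, mu1, mu2, s1, s2, rho))
    # Floor the density so that log(0) cannot occur early in training.
    return -np.log(max(density, eps))
```

With a single component centred on the data point and unit, uncorrelated sigmas, the loss reduces to `log(2 * pi)`, which is a quick sanity check on the normalising constant.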
-------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.models.rnn import rnn_cell 3 | 4 | import numpy as np 5 | import random 6 | 7 | class Model(): 8 | def __init__(self, args, infer=False): 9 | if infer: 10 | args.batch_size = 1 11 | args.seq_length = 1 12 | self.args = args 13 | 14 | if args.model == 'rnn': 15 | cell_fn = rnn_cell.BasicRNNCell 16 | elif args.model == 'gru': 17 | cell_fn = rnn_cell.GRUCell 18 | elif args.model == 'lstm': 19 | cell_fn = rnn_cell.BasicLSTMCell 20 | else: 21 | raise Exception("model type not supported: {}".format(args.model)) 22 | 23 | cell = cell_fn(args.rnn_size) 24 | 25 | cell = rnn_cell.MultiRNNCell([cell] * args.num_layers) 26 | 27 | if (infer == False and args.keep_prob < 1): # training mode 28 | cell = rnn_cell.DropoutWrapper(cell, output_keep_prob = args.keep_prob) 29 | 30 | self.cell = cell 31 | 32 | self.input_data = tf.placeholder(dtype=tf.float32, shape=[args.batch_size, args.seq_length, 5]) 33 | self.target_data = tf.placeholder(dtype=tf.float32, shape=[args.batch_size, args.seq_length, 5]) 34 | self.initial_state = cell.zero_state(batch_size=args.batch_size, dtype=tf.float32) 35 | 36 | self.num_mixture = args.num_mixture 37 | NOUT = 3 + self.num_mixture * 6 # [end_of_stroke + end_of_char, continue_with_stroke] + prob + 2*(mu + sig) + corr 38 | 39 | with tf.variable_scope('rnn_mdn'): 40 | output_w = tf.get_variable("output_w", [args.rnn_size, NOUT]) 41 | output_b = tf.get_variable("output_b", [NOUT]) 42 | 43 | inputs = tf.split(1, args.seq_length, self.input_data) 44 | inputs = [tf.squeeze(input_, [1]) for input_ in inputs] 45 | 46 | self.initial_input = np.zeros((args.batch_size, 5), dtype=np.float32) 47 | self.initial_input[:,4] = 1.0 # initially, the pen is down. 
48 | self.initial_input = tf.constant(self.initial_input) 49 | 50 | def tfrepeat(a, repeats): 51 | num_row = a.get_shape()[0].value 52 | num_col = a.get_shape()[1].value 53 | assert(num_col == 1) 54 | result = [a for i in range(repeats)] 55 | result = tf.concat(0, result) 56 | result = tf.reshape(result, [repeats, num_row]) 57 | result = tf.transpose(result) 58 | return result 59 | 60 | def custom_rnn_autodecoder(decoder_inputs, initial_input, initial_state, cell, scope=None): 61 | # customized rnn_decoder for the task of dealing with end of character 62 | with tf.variable_scope(scope or "rnn_decoder"): 63 | states = [initial_state] 64 | outputs = [] 65 | prev = None 66 | 67 | for i in xrange(len(decoder_inputs)): 68 | inp = decoder_inputs[i] 69 | if i > 0: 70 | tf.get_variable_scope().reuse_variables() 71 | output, new_state = cell(inp, states[-1]) 72 | 73 | num_batches = self.args.batch_size # new_state.get_shape()[0].value 74 | num_state = new_state.get_shape()[1].value 75 | 76 | # if the input has an end-of-character signal, have to zero out the state 77 | 78 | #to do: test this code. 
79 | 80 | eoc_detection = inp[:,3] 81 | eoc_detection = tf.reshape(eoc_detection, [num_batches, 1]) 82 | 83 | eoc_detection_state = tfrepeat(eoc_detection, num_state) 84 | 85 | eoc_detection_state = tf.greater(eoc_detection_state, tf.zeros_like(eoc_detection_state, dtype=tf.float32)) 86 | 87 | new_state = tf.select(eoc_detection_state, initial_state, new_state) 88 | 89 | outputs.append(output) 90 | states.append(new_state) 91 | return outputs, states 92 | 93 | outputs, states = custom_rnn_autodecoder(inputs, self.initial_input, self.initial_state, cell, scope='rnn_mdn') 94 | output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size]) 95 | output = tf.nn.xw_plus_b(output, output_w, output_b) 96 | self.final_state = states[-1] 97 | 98 | # reshape target data so that it is compatible with prediction shape 99 | flat_target_data = tf.reshape(self.target_data,[-1, 5]) 100 | [x1_data, x2_data, eos_data, eoc_data, cont_data] = tf.split(1, 5, flat_target_data) 101 | pen_data = tf.concat(1, [eos_data, eoc_data, cont_data]) 102 | 103 | # long method: 104 | #flat_target_data = tf.split(1, args.seq_length, self.target_data) 105 | #flat_target_data = [tf.squeeze(flat_target_data_, [1]) for flat_target_data_ in flat_target_data] 106 | #flat_target_data = tf.reshape(tf.concat(1, flat_target_data), [-1, 3]) 107 | 108 | def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho): 109 | # eq # 24 and 25 of http://arxiv.org/abs/1308.0850 110 | norm1 = tf.sub(x1, mu1) 111 | norm2 = tf.sub(x2, mu2) 112 | s1s2 = tf.mul(s1, s2) 113 | z = tf.square(tf.div(norm1, s1))+tf.square(tf.div(norm2, s2))-2*tf.div(tf.mul(rho, tf.mul(norm1, norm2)), s1s2) 114 | negRho = 1-tf.square(rho) 115 | result = tf.exp(tf.div(-z,2*negRho)) 116 | denom = 2*np.pi*tf.mul(s1s2, tf.sqrt(negRho)) 117 | result = tf.div(result, denom) 118 | return result 119 | 120 | def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen, x1_data, x2_data, pen_data): 121 | result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, 
z_sigma1, z_sigma2, z_corr) 122 | # implementing eq # 26 of http://arxiv.org/abs/1308.0850 123 | epsilon = 1e-20 124 | result1 = tf.mul(result0, z_pi) 125 | result1 = tf.reduce_sum(result1, 1, keep_dims=True) 126 | result1 = -tf.log(tf.maximum(result1, 1e-20)) # at the beginning, some errors are exactly zero. 127 | result_shape = tf.reduce_mean(result1) 128 | 129 | result2 = tf.nn.softmax_cross_entropy_with_logits(z_pen, pen_data) 130 | pen_data_weighting = pen_data[:, 2]+np.sqrt(self.args.stroke_importance_factor)*pen_data[:, 0]+self.args.stroke_importance_factor*pen_data[:, 1] 131 | result2 = tf.mul(result2, pen_data_weighting) 132 | result_pen = tf.reduce_mean(result2) 133 | 134 | result = result_shape + result_pen 135 | return result, result_shape, result_pen, 136 | 137 | # below is where we need to do MDN splitting of distribution params 138 | def get_mixture_coef(output): 139 | # returns the tf slices containing mdn dist params 140 | # ie, eq 18 -> 23 of http://arxiv.org/abs/1308.0850 141 | z = output 142 | z_pen = z[:, 0:3] # end of stroke, end of character/content, continue w/ stroke 143 | z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(1, 6, z[:, 3:]) 144 | 145 | # process output z's into MDN paramters 146 | 147 | # softmax all the pi's: 148 | max_pi = tf.reduce_max(z_pi, 1, keep_dims=True) 149 | z_pi = tf.sub(z_pi, max_pi) 150 | z_pi = tf.exp(z_pi) 151 | normalize_pi = tf.inv(tf.reduce_sum(z_pi, 1, keep_dims=True)) 152 | z_pi = tf.mul(normalize_pi, z_pi) 153 | 154 | # exponentiate the sigmas and also make corr between -1 and 1. 
155 | z_sigma1 = tf.exp(z_sigma1) 156 | z_sigma2 = tf.exp(z_sigma2) 157 | z_corr = tf.tanh(z_corr) 158 | 159 | return [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen] 160 | 161 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen] = get_mixture_coef(output) 162 | 163 | self.pi = o_pi 164 | self.mu1 = o_mu1 165 | self.mu2 = o_mu2 166 | self.sigma1 = o_sigma1 167 | self.sigma2 = o_sigma2 168 | self.corr = o_corr 169 | self.pen = o_pen # state of the pen 170 | 171 | [lossfunc, loss_shape, loss_pen] = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, x1_data, x2_data, pen_data) 172 | self.cost = lossfunc 173 | self.cost_shape = loss_shape 174 | self.cost_pen = loss_pen 175 | 176 | self.lr = tf.Variable(0.01, trainable=False) 177 | tvars = tf.trainable_variables() 178 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), args.grad_clip) 179 | optimizer = tf.train.AdamOptimizer(self.lr, epsilon=0.001) 180 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 181 | 182 | 183 | def sample(self, sess, num=300, temp_mixture=1.0, temp_pen=1.0, stop_if_eoc = False): 184 | 185 | def get_pi_idx(x, pdf): 186 | N = pdf.size 187 | accumulate = 0 188 | for i in range(0, N): 189 | accumulate += pdf[i] 190 | if (accumulate >= x): 191 | return i 192 | print 'error with sampling ensemble' 193 | return -1 194 | 195 | def sample_gaussian_2d(mu1, mu2, s1, s2, rho): 196 | mean = [mu1, mu2] 197 | cov = [[s1*s1, rho*s1*s2], [rho*s1*s2, s2*s2]] 198 | x = np.random.multivariate_normal(mean, cov, 1) 199 | return x[0][0], x[0][1] 200 | 201 | prev_x = np.zeros((1, 1, 5), dtype=np.float32) 202 | #prev_x[0, 0, 2] = 1 # initially, we want to see beginning of new stroke 203 | #prev_x[0, 0, 3] = 1 # initially, we want to see beginning of new character/content 204 | prev_state = sess.run(self.cell.zero_state(self.args.batch_size, tf.float32)) 205 | 206 | strokes = np.zeros((num, 5), dtype=np.float32) 207 | mixture_params = [] 208 | 209 | for i in 
xrange(num): 210 | 211 | feed = {self.input_data: prev_x, self.initial_state:prev_state} 212 | 213 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, next_state] = sess.run([self.pi, self.mu1, self.mu2, self.sigma1, self.sigma2, self.corr, self.pen, self.final_state],feed) 214 | 215 | pi_pdf = o_pi[0] 216 | if i > 1: 217 | pi_pdf = np.log(pi_pdf) / temp_mixture 218 | pi_pdf -= pi_pdf.max() 219 | pi_pdf = np.exp(pi_pdf) 220 | pi_pdf /= pi_pdf.sum() 221 | 222 | idx = get_pi_idx(random.random(), pi_pdf) 223 | 224 | pen_pdf = o_pen[0] 225 | if i > 1: 226 | pen_pdf /= temp_pen # apply temperature to the pen logits 227 | pen_pdf -= pen_pdf.max() # softmax: convert logits to probabilities 228 | pen_pdf = np.exp(pen_pdf) 229 | pen_pdf /= pen_pdf.sum() 230 | 231 | pen_idx = get_pi_idx(random.random(), pen_pdf) 232 | eos = 0 233 | eoc = 0 234 | cont_state = 0 235 | 236 | if pen_idx == 0: 237 | eos = 1 238 | elif pen_idx == 1: 239 | eoc = 1 240 | else: 241 | cont_state = 1 242 | 243 | next_x1, next_x2 = sample_gaussian_2d(o_mu1[0][idx], o_mu2[0][idx], o_sigma1[0][idx], o_sigma2[0][idx], o_corr[0][idx]) 244 | 245 | strokes[i,:] = [next_x1, next_x2, eos, eoc, cont_state] 246 | 247 | params = [pi_pdf, o_mu1[0], o_mu2[0], o_sigma1[0], o_sigma2[0], o_corr[0], pen_pdf] 248 | mixture_params.append(params) 249 | 250 | # early stopping condition 251 | if (stop_if_eoc and eoc == 1): 252 | strokes = strokes[0:i+1, :] 253 | break 254 | 255 | prev_x = np.zeros((1, 1, 5), dtype=np.float32) 256 | prev_x[0][0] = np.array([next_x1, next_x2, eos, eoc, cont_state], dtype=np.float32) 257 | prev_state = next_state 258 | 259 | strokes[:,0:2] *= self.args.data_scale 260 | return strokes, mixture_params 261 | 262 | 263 | -------------------------------------------------------------------------------- /requirements.linux-x64-cpu.txt: -------------------------------------------------------------------------------- 1 | decorator==4.0.9 2 | ipython==4.1.2 3 | ipython-genutils==0.1.0 4 | 
numpy==1.10.4 5 | path.py==8.1.2 6 | pexpect==4.0.1 7 |
pickleshare==0.6 8 | protobuf==3.0.0b2 9 | ptyprocess==0.5.1 10 | pyparsing==2.1.0 11 | simplegeneric==0.8.1 12 | six==1.10.0 13 | svgwrite==1.1.6 14 | https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.7.1-cp27-none-linux_x86_64.whl 15 | traitlets==4.2.1 16 | wheel==0.29.0 17 | wsgiref==0.1.2 18 | -------------------------------------------------------------------------------- /requirements.linux-x64-gpu.txt: -------------------------------------------------------------------------------- 1 | decorator==4.0.9 2 | ipython==4.1.2 3 | ipython-genutils==0.1.0 4 | numpy==1.10.4 5 | path.py==8.1.2 6 | pexpect==4.0.1 7 | pickleshare==0.6 8 | protobuf==3.0.0b2 9 | ptyprocess==0.5.1 10 | pyparsing==2.1.0 11 | simplegeneric==0.8.1 12 | six==1.10.0 13 | svgwrite==1.1.6 14 | https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.7.1-cp27-none-linux_x86_64.whl 15 | traitlets==4.2.1 16 | wheel==0.29.0 17 | wsgiref==0.1.2 18 | -------------------------------------------------------------------------------- /requirements.mac.txt: -------------------------------------------------------------------------------- 1 | decorator==4.0.9 2 | ipython==4.1.2 3 | ipython-genutils==0.1.0 4 | numpy==1.10.4 5 | path.py==8.1.2 6 | pexpect==4.0.1 7 | pickleshare==0.6 8 | protobuf==3.0.0b2 9 | ptyprocess==0.5.1 10 | pyparsing==2.1.0 11 | simplegeneric==0.8.1 12 | six==1.10.0 13 | svgwrite==1.1.6 14 | https://storage.googleapis.com/tensorflow/mac/tensorflow-0.7.1-cp27-none-any.whl 15 | traitlets==4.2.1 16 | wheel==0.29.0 17 | wsgiref==0.1.2 18 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | import time 5 | import os 6 | import cPickle 7 | import argparse 8 | 9 | from utils import * 10 | from model import Model 11 | import random 12 | 13 | import svgwrite 14 | from IPython.display import 
SVG, display 15 | 16 | # main code (not in a main function since I want to run this script in IPython as well). 17 | def in_ipython(): 18 | try: 19 | __IPYTHON__ 20 | except NameError: 21 | return False 22 | else: 23 | return True 24 | 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--filename', type=str, default='output', 28 | help='filename of .svg file to output, without .svg') 29 | parser.add_argument('--sample_length', type=int, default=600, 30 | help='number of strokes to sample') 31 | parser.add_argument('--picture_size', type=float, default=160, 32 | help='a centered svg will be generated of this size') 33 | parser.add_argument('--scale_factor', type=float, default=1, 34 | help='factor to scale down by for svg output. smaller means bigger output') 35 | parser.add_argument('--num_picture', type=int, default=20, 36 | help='number of pictures to generate') 37 | parser.add_argument('--num_col', type=int, default=5, 38 | help='if num_picture > 1, how many pictures per row?') 39 | parser.add_argument('--dataset_name', type=str, default="kanji", 40 | help='name of directory containing training data') 41 | parser.add_argument('--color_mode', type=int, default=1, 42 | help='set to 0 if you are a black and white sort of person...') 43 | parser.add_argument('--stroke_width', type=float, default=2.0, 44 | help='thickness of pen lines') 45 | parser.add_argument('--temperature', type=float, default=0.1, 46 | help='sampling temperature') 47 | sample_args = parser.parse_args() 48 | 49 | color_mode = True 50 | if sample_args.color_mode == 0: 51 | color_mode = False 52 | 53 | 54 | with open(os.path.join('save', sample_args.dataset_name, 'config.pkl')) as f: # future 55 | saved_args = cPickle.load(f) 56 | 57 | model = Model(saved_args, True) 58 | sess = tf.InteractiveSession() 59 | saver = tf.train.Saver(tf.all_variables()) 60 | 61 | ckpt = tf.train.get_checkpoint_state(os.path.join('save', sample_args.dataset_name)) 62 | print "loading model: 
",ckpt.model_checkpoint_path 63 | 64 | saver.restore(sess, ckpt.model_checkpoint_path) 65 | 66 | def draw_sketch_array(strokes_array, svg_only = False): 67 | draw_stroke_color_array(strokes_array, factor=sample_args.scale_factor, maxcol = sample_args.num_col, svg_filename = sample_args.filename+'.svg', stroke_width = sample_args.stroke_width, block_size = sample_args.picture_size, svg_only = svg_only, color_mode = color_mode) 68 | 69 | def sample_sketches(min_size_ratio = 0.0, max_size_ratio = 0.8, min_num_stroke = 4, max_num_stroke=22, svg_only = True): 70 | N = sample_args.num_picture 71 | frame_size = float(sample_args.picture_size) 72 | max_size = frame_size * max_size_ratio 73 | min_size = frame_size * min_size_ratio 74 | count = 0 75 | sketch_list = [] 76 | param_list = [] 77 | 78 | temp_mixture = sample_args.temperature 79 | temp_pen = sample_args.temperature 80 | 81 | while count < N: 82 | #print "attempting to generate picture #", count 83 | print '.', 84 | [strokes, params] = model.sample(sess, sample_args.sample_length, temp_mixture, temp_pen, stop_if_eoc = True) 85 | [sx, sy, num_stroke, num_char, _] = strokes.sum(0) 86 | if num_stroke < min_num_stroke or num_char == 0 or num_stroke > max_num_stroke: 87 | #print "num_stroke ", num_stroke, " num_char ", num_char 88 | continue 89 | [sx, sy, sizex, sizey] = calculate_start_point(strokes) 90 | if sizex > max_size or sizey > max_size: 91 | #print "sizex ", sizex, " sizey ", sizey 92 | continue 93 | if sizex < min_size or sizey < min_size: 94 | #print "sizex ", sizex, " sizey ", sizey 95 | continue 96 | # success 97 | print count+1,"/",N 98 | count += 1 99 | sketch_list.append(strokes) 100 | param_list.append(params) 101 | # draw the pics 102 | draw_sketch_array(sketch_list, svg_only = svg_only) 103 | return sketch_list, param_list 104 | 105 | if __name__ == '__main__': 106 | ipython_mode = in_ipython() 107 | if ipython_mode: 108 | print "IPython detected" 109 | else: 110 | print "Console mode" 111 | 
[strokes, params] = sample_sketches(svg_only = not ipython_mode) 112 | 113 | 114 | -------------------------------------------------------------------------------- /save/kanji/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt-0" 2 | all_model_checkpoint_paths: "model.ckpt-0" 3 | -------------------------------------------------------------------------------- /save/kanji/config.pkl: -------------------------------------------------------------------------------- 1 | ccopy_reg 2 | _reconstructor 3 | p1 4 | (cargparse 5 | Namespace 6 | p2 7 | c__builtin__ 8 | object 9 | p3 10 | NtRp4 11 | (dp5 12 | S'grad_clip' 13 | p6 14 | F5 15 | sS'rnn_size' 16 | p7 17 | I256 18 | sS'data_scale' 19 | p8 20 | F15 21 | sS'learning_rate' 22 | p9 23 | F0.0050000000000000001 24 | sS'num_layers' 25 | p10 26 | I2 27 | sS'seq_length' 28 | p11 29 | I300 30 | sS'decay_rate' 31 | p12 32 | F0.98999999999999999 33 | sS'num_mixture' 34 | p13 35 | I24 36 | sS'batch_size' 37 | p14 38 | I100 39 | sS'num_epochs' 40 | p15 41 | I500 42 | sS'dataset_name' 43 | p16 44 | S'kanji' 45 | p17 46 | sS'model' 47 | p18 48 | S'lstm' 49 | p19 50 | sS'save_every' 51 | p20 52 | I250 53 | sS'keep_prob' 54 | p21 55 | F0.80000000000000004 56 | sS'stroke_importance_factor' 57 | p22 58 | F200 59 | sb. 
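`sample.py` passes its `--temperature` flag to `Model.sample` as both `temp_mixture` and `temp_pen`. The mixture weights arrive as probabilities, so they are rescaled in log space. The pen outputs are raw logits (`z_pen` is never softmaxed in `get_mixture_coef`), so they are divided by the temperature directly before a softmax. Below is a standalone NumPy sketch of that step with hypothetical helper names. `sample_index` uses `np.searchsorted` on the cumulative distribution in place of the loop in `get_pi_idx`.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over a 1-D array."""
    shifted = logits - np.max(logits)
    e = np.exp(shifted)
    return e / e.sum()


def adjust_temperature(pi, pen_logits, temp_mixture=1.0, temp_pen=1.0):
    """Sharpen (temp < 1) or flatten (temp > 1) the mixture and pen distributions.

    `pi` holds mixture probabilities, so it is moved to log space before being
    divided by the temperature; `pen_logits` are unnormalised scores.
    """
    pi_adj = softmax(np.log(pi) / temp_mixture)
    pen_adj = softmax(np.asarray(pen_logits, dtype=np.float64) / temp_pen)
    return pi_adj, pen_adj


def sample_index(pdf, rng):
    """Draw an index from a discrete distribution by inverting its CDF."""
    idx = int(np.searchsorted(np.cumsum(pdf), rng.random()))
    return min(idx, len(pdf) - 1)  # guard against round-off at the top end


rng = np.random.default_rng(0)
pi_adj, pen_adj = adjust_temperature(np.array([0.5, 0.3, 0.2]), np.array([2.0, 0.0, -1.0]), 0.1, 0.1)
component = sample_index(pi_adj, rng)
pen_state = sample_index(pen_adj, rng)  # 0 = end of stroke, 1 = end of content, 2 = continue
```

At a temperature of 1.0 both distributions are unchanged. At the README's 0.1, nearly all of the probability mass collapses onto the most likely mixture component and pen state, which is why low-temperature samples look cleaner but less varied.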
-------------------------------------------------------------------------------- /save/kanji/model.ckpt-0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/save/kanji/model.ckpt-0 -------------------------------------------------------------------------------- /svg/__init__.py: -------------------------------------------------------------------------------- 1 | __import__('pkg_resources').declare_namespace(__name__) 2 | -------------------------------------------------------------------------------- /svg/path/__init__.py: -------------------------------------------------------------------------------- 1 | from .path import Path, Line, Arc, CubicBezier, QuadraticBezier 2 | from .parser import parse_path 3 | -------------------------------------------------------------------------------- /svg/path/parser.py: -------------------------------------------------------------------------------- 1 | # SVG Path specification parser 2 | 3 | import re 4 | from . import path 5 | 6 | COMMANDS = set('MmZzLlHhVvCcSsQqTtAa') 7 | UPPERCASE = set('MZLHVCSQTA') 8 | 9 | COMMAND_RE = re.compile("([MmZzLlHhVvCcSsQqTtAa])") 10 | FLOAT_RE = re.compile("[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?") 11 | 12 | 13 | def _tokenize_path(pathdef): 14 | for x in COMMAND_RE.split(pathdef): 15 | if x in COMMANDS: 16 | yield x 17 | for token in FLOAT_RE.findall(x): 18 | yield token 19 | 20 | 21 | def parse_path(pathdef, current_pos=0j): 22 | # In the SVG specs, initial movetos are absolute, even if 23 | # specified as 'm'. This is the default behavior here as well. 24 | # But if you pass in a current_pos variable, the initial moveto 25 | # will be relative to that current_pos. This is useful. 
26 | elements = list(_tokenize_path(pathdef)) 27 | # Reverse for easy use of .pop() 28 | elements.reverse() 29 | 30 | segments = path.Path() 31 | start_pos = None 32 | command = None 33 | 34 | while elements: 35 | 36 | if elements[-1] in COMMANDS: 37 | # New command. 38 | last_command = command # Used by S and T 39 | command = elements.pop() 40 | absolute = command in UPPERCASE 41 | command = command.upper() 42 | else: 43 | # If this element starts with numbers, it is an implicit command 44 | # and we don't change the command. Check that it's allowed: 45 | if command is None: 46 | raise ValueError("Unallowed implicit command in %s, position %s" % ( 47 | pathdef, len(pathdef.split()) - len(elements))) 48 | 49 | if command == 'M': 50 | # Moveto command. 51 | x = elements.pop() 52 | y = elements.pop() 53 | pos = float(x) + float(y) * 1j 54 | if absolute: 55 | current_pos = pos 56 | else: 57 | current_pos += pos 58 | 59 | # when M is called, reset start_pos 60 | # This behavior of Z is defined in svg spec: 61 | # http://www.w3.org/TR/SVG/paths.html#PathDataClosePathCommand 62 | start_pos = current_pos 63 | 64 | # Implicit moveto commands are treated as lineto commands. 65 | # So we set command to lineto here, in case there are 66 | # further implicit commands after this moveto. 67 | command = 'L' 68 | 69 | elif command == 'Z': 70 | # Close path 71 | segments.append(path.Line(current_pos, start_pos)) 72 | segments.closed = True 73 | current_pos = start_pos 74 | start_pos = None 75 | command = None # You can't have implicit commands after closing. 
76 | 77 | elif command == 'L': 78 | x = elements.pop() 79 | y = elements.pop() 80 | pos = float(x) + float(y) * 1j 81 | if not absolute: 82 | pos += current_pos 83 | segments.append(path.Line(current_pos, pos)) 84 | current_pos = pos 85 | 86 | elif command == 'H': 87 | x = elements.pop() 88 | pos = float(x) + current_pos.imag * 1j 89 | if not absolute: 90 | pos += current_pos.real 91 | segments.append(path.Line(current_pos, pos)) 92 | current_pos = pos 93 | 94 | elif command == 'V': 95 | y = elements.pop() 96 | pos = current_pos.real + float(y) * 1j 97 | if not absolute: 98 | pos += current_pos.imag * 1j 99 | segments.append(path.Line(current_pos, pos)) 100 | current_pos = pos 101 | 102 | elif command == 'C': 103 | control1 = float(elements.pop()) + float(elements.pop()) * 1j 104 | control2 = float(elements.pop()) + float(elements.pop()) * 1j 105 | end = float(elements.pop()) + float(elements.pop()) * 1j 106 | 107 | if not absolute: 108 | control1 += current_pos 109 | control2 += current_pos 110 | end += current_pos 111 | 112 | segments.append(path.CubicBezier(current_pos, control1, control2, end)) 113 | current_pos = end 114 | 115 | elif command == 'S': 116 | # Smooth curve. First control point is the "reflection" of 117 | # the second control point in the previous path. 118 | 119 | if last_command not in ('C', 'S'): 120 | # If there is no previous command or if the previous command 121 | # was not a C, c, S or s, assume the first control point is 122 | # coincident with the current point. 123 | control1 = current_pos 124 | else: 125 | # The first control point is assumed to be the reflection of 126 | # the second control point on the previous command relative 127 | # to the current point.
128 | control1 = current_pos + current_pos - segments[-1].control2 129 | 130 | control2 = float(elements.pop()) + float(elements.pop()) * 1j 131 | end = float(elements.pop()) + float(elements.pop()) * 1j 132 | 133 | if not absolute: 134 | control2 += current_pos 135 | end += current_pos 136 | 137 | segments.append(path.CubicBezier(current_pos, control1, control2, end)) 138 | current_pos = end 139 | 140 | elif command == 'Q': 141 | control = float(elements.pop()) + float(elements.pop()) * 1j 142 | end = float(elements.pop()) + float(elements.pop()) * 1j 143 | 144 | if not absolute: 145 | control += current_pos 146 | end += current_pos 147 | 148 | segments.append(path.QuadraticBezier(current_pos, control, end)) 149 | current_pos = end 150 | 151 | elif command == 'T': 152 | # Smooth curve. Control point is the "reflection" of 153 | # the control point in the previous path. 154 | 155 | if last_command not in ('Q', 'T'): 156 | # If there is no previous command or if the previous command 157 | # was not a Q, q, T or t, assume the control point is 158 | # coincident with the current point. 159 | control = current_pos 160 | else: 161 | # The control point is assumed to be the reflection of 162 | # the control point on the previous command relative 163 | # to the current point.
164 | control = current_pos + current_pos - segments[-1].control 165 | 166 | end = float(elements.pop()) + float(elements.pop()) * 1j 167 | 168 | if not absolute: 169 | end += current_pos 170 | 171 | segments.append(path.QuadraticBezier(current_pos, control, end)) 172 | current_pos = end 173 | 174 | elif command == 'A': 175 | radius = float(elements.pop()) + float(elements.pop()) * 1j 176 | rotation = float(elements.pop()) 177 | arc = float(elements.pop()) 178 | sweep = float(elements.pop()) 179 | end = float(elements.pop()) + float(elements.pop()) * 1j 180 | 181 | if not absolute: 182 | end += current_pos 183 | 184 | segments.append(path.Arc(current_pos, radius, rotation, arc, sweep, end)) 185 | current_pos = end 186 | 187 | return segments 188 | -------------------------------------------------------------------------------- /svg/path/path.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from math import sqrt, cos, sin, acos, degrees, radians, log 3 | from collections import MutableSequence 4 | 5 | 6 | # This file contains classes for the different types of SVG path segments as 7 | # well as a Path object that contains a sequence of path segments. 
8 | 9 | MIN_DEPTH = 5 10 | ERROR = 1e-12 11 | 12 | 13 | def segment_length(curve, start, end, start_point, end_point, error, min_depth, depth): 14 | """Recursively approximates the length by straight lines""" 15 | mid = (start + end) / 2 16 | mid_point = curve.point(mid) 17 | length = abs(end_point - start_point) 18 | first_half = abs(mid_point - start_point) 19 | second_half = abs(end_point - mid_point) 20 | 21 | length2 = first_half + second_half 22 | if (length2 - length > error) or (depth < min_depth): 23 | # Calculate the length of each segment: 24 | depth += 1 25 | return (segment_length(curve, start, mid, start_point, mid_point, 26 | error, min_depth, depth) + 27 | segment_length(curve, mid, end, mid_point, end_point, 28 | error, min_depth, depth)) 29 | # This is accurate enough. 30 | return length2 31 | 32 | 33 | class Line(object): 34 | 35 | def __init__(self, start, end): 36 | self.start = start 37 | self.end = end 38 | 39 | def __repr__(self): 40 | return 'Line(start=%s, end=%s)' % (self.start, self.end) 41 | 42 | def __eq__(self, other): 43 | if not isinstance(other, Line): 44 | return NotImplemented 45 | return self.start == other.start and self.end == other.end 46 | 47 | def __ne__(self, other): 48 | if not isinstance(other, Line): 49 | return NotImplemented 50 | return not self == other 51 | 52 | def point(self, pos): 53 | distance = self.end - self.start 54 | return self.start + distance * pos 55 | 56 | def length(self, error=None, min_depth=None): 57 | distance = (self.end - self.start) 58 | return sqrt(distance.real ** 2 + distance.imag ** 2) 59 | 60 | 61 | class CubicBezier(object): 62 | def __init__(self, start, control1, control2, end): 63 | self.start = start 64 | self.control1 = control1 65 | self.control2 = control2 66 | self.end = end 67 | 68 | def __repr__(self): 69 | return 'CubicBezier(start=%s, control1=%s, control2=%s, end=%s)' % ( 70 | self.start, self.control1, self.control2, self.end) 71 | 72 | def __eq__(self, other): 73 | if not 
isinstance(other, CubicBezier): 74 | return NotImplemented 75 | return self.start == other.start and self.end == other.end and \ 76 | self.control1 == other.control1 and self.control2 == other.control2 77 | 78 | def __ne__(self, other): 79 | if not isinstance(other, CubicBezier): 80 | return NotImplemented 81 | return not self == other 82 | 83 | def is_smooth_from(self, previous): 84 | """Checks if this segment would be a smooth segment following the previous""" 85 | if isinstance(previous, CubicBezier): 86 | return (self.start == previous.end and 87 | (self.control1 - self.start) == (previous.end - previous.control2)) 88 | else: 89 | return self.control1 == self.start 90 | 91 | def point(self, pos): 92 | """Calculate the x,y position at a certain position of the path""" 93 | return ((1 - pos) ** 3 * self.start) + \ 94 | (3 * (1 - pos) ** 2 * pos * self.control1) + \ 95 | (3 * (1 - pos) * pos ** 2 * self.control2) + \ 96 | (pos ** 3 * self.end) 97 | 98 | def length(self, error=ERROR, min_depth=MIN_DEPTH): 99 | """Approximate the total length of the curve by recursive subdivision""" 100 | start_point = self.point(0) 101 | end_point = self.point(1) 102 | return segment_length(self, 0, 1, start_point, end_point, error, min_depth, 0) 103 | 104 | 105 | class QuadraticBezier(object): 106 | def __init__(self, start, control, end): 107 | self.start = start 108 | self.end = end 109 | self.control = control 110 | 111 | def __repr__(self): 112 | return 'QuadraticBezier(start=%s, control=%s, end=%s)' % ( 113 | self.start, self.control, self.end) 114 | 115 | def __eq__(self, other): 116 | if not isinstance(other, QuadraticBezier): 117 | return NotImplemented 118 | return self.start == other.start and self.end == other.end and \ 119 | self.control == other.control 120 | 121 | def __ne__(self, other): 122 | if not isinstance(other, QuadraticBezier): 123 | return NotImplemented 124 | return not self == other 125 | 126 | def is_smooth_from(self, previous): 127 | """Checks if this segment
would be a smooth segment following the previous""" 128 | if isinstance(previous, QuadraticBezier): 129 | return (self.start == previous.end and 130 | (self.control - self.start) == (previous.end - previous.control)) 131 | else: 132 | return self.control == self.start 133 | 134 | def point(self, pos): 135 | return (1 - pos) ** 2 * self.start + 2 * (1 - pos) * pos * self.control + \ 136 | pos ** 2 * self.end 137 | 138 | def length(self, error=None, min_depth=None): 139 | a = self.start - 2*self.control + self.end 140 | b = 2*(self.control - self.start) 141 | a_dot_b = a.real*b.real + a.imag*b.imag 142 | 143 | if abs(a) < 1e-12: 144 | s = abs(b) 145 | elif abs(a_dot_b + abs(a)*abs(b)) < 1e-12: 146 | k = abs(b)/abs(a) 147 | if k >= 2: 148 | s = abs(b) - abs(a) 149 | else: 150 | s = abs(a)*(k**2/2 - k + 1) 151 | else: 152 | # For an explanation of this case, see 153 | # http://www.malczak.info/blog/quadratic-bezier-curve-length/ 154 | A = 4 * (a.real ** 2 + a.imag ** 2) 155 | B = 4 * (a.real * b.real + a.imag * b.imag) 156 | C = b.real ** 2 + b.imag ** 2 157 | 158 | Sabc = 2 * sqrt(A + B + C) 159 | A2 = sqrt(A) 160 | A32 = 2 * A * A2 161 | C2 = 2 * sqrt(C) 162 | BA = B / A2 163 | 164 | s = (A32 * Sabc + A2 * B * (Sabc - C2) + (4 * C * A - B ** 2) * 165 | log((2 * A2 + BA + Sabc) / (BA + C2))) / (4 * A32) 166 | return s 167 | 168 | class Arc(object): 169 | 170 | def __init__(self, start, radius, rotation, arc, sweep, end): 171 | """radius is complex, rotation is in degrees, 172 | arc (the large-arc flag) and sweep are 1 or 0 (True/False also work)""" 173 | 174 | self.start = start 175 | self.radius = radius 176 | self.rotation = rotation 177 | self.arc = bool(arc) 178 | self.sweep = bool(sweep) 179 | self.end = end 180 | 181 | self._parameterize() 182 | 183 | def __repr__(self): 184 | return 'Arc(start=%s, radius=%s, rotation=%s, arc=%s, sweep=%s, end=%s)' % ( 185 | self.start, self.radius, self.rotation, self.arc, self.sweep, self.end) 186 | 187 | def __eq__(self, other): 188 | if not
isinstance(other, Arc): 189 | return NotImplemented 190 | return self.start == other.start and self.end == other.end and \ 191 | self.radius == other.radius and self.rotation == other.rotation and \ 192 | self.arc == other.arc and self.sweep == other.sweep 193 | 194 | def __ne__(self, other): 195 | if not isinstance(other, Arc): 196 | return NotImplemented 197 | return not self == other 198 | 199 | def _parameterize(self): 200 | # Conversion from endpoint to center parameterization 201 | # http://www.w3.org/TR/SVG/implnote.html#ArcImplementationNotes 202 | 203 | cosr = cos(radians(self.rotation)) 204 | sinr = sin(radians(self.rotation)) 205 | dx = (self.start.real - self.end.real) / 2 206 | dy = (self.start.imag - self.end.imag) / 2 207 | x1prim = cosr * dx + sinr * dy 208 | x1prim_sq = x1prim * x1prim 209 | y1prim = -sinr * dx + cosr * dy 210 | y1prim_sq = y1prim * y1prim 211 | 212 | rx = self.radius.real 213 | rx_sq = rx * rx 214 | ry = self.radius.imag 215 | ry_sq = ry * ry 216 | 217 | # Correct out of range radii 218 | radius_check = (x1prim_sq / rx_sq) + (y1prim_sq / ry_sq) 219 | if radius_check > 1: 220 | rx *= sqrt(radius_check) 221 | ry *= sqrt(radius_check) 222 | rx_sq = rx * rx 223 | ry_sq = ry * ry 224 | 225 | t1 = rx_sq * y1prim_sq 226 | t2 = ry_sq * x1prim_sq 227 | c = sqrt(abs((rx_sq * ry_sq - t1 - t2) / (t1 + t2))) 228 | 229 | if self.arc == self.sweep: 230 | c = -c 231 | cxprim = c * rx * y1prim / ry 232 | cyprim = -c * ry * x1prim / rx 233 | 234 | self.center = complex((cosr * cxprim - sinr * cyprim) + 235 | ((self.start.real + self.end.real) / 2), 236 | (sinr * cxprim + cosr * cyprim) + 237 | ((self.start.imag + self.end.imag) / 2)) 238 | 239 | ux = (x1prim - cxprim) / rx 240 | uy = (y1prim - cyprim) / ry 241 | vx = (-x1prim - cxprim) / rx 242 | vy = (-y1prim - cyprim) / ry 243 | n = sqrt(ux * ux + uy * uy) 244 | p = ux 245 | theta = degrees(acos(p / n)) 246 | if uy < 0: 247 | theta = -theta 248 | self.theta = theta % 360 249 | 250 | n = sqrt((ux 
* ux + uy * uy) * (vx * vx + vy * vy)) 251 | p = ux * vx + uy * vy 252 | if p == 0: 253 | delta = degrees(acos(0)) 254 | else: 255 | delta = degrees(acos(p / n)) 256 | if (ux * vy - uy * vx) < 0: 257 | delta = -delta 258 | self.delta = delta % 360 259 | if not self.sweep: 260 | self.delta -= 360 261 | 262 | def point(self, pos): 263 | angle = radians(self.theta + (self.delta * pos)) 264 | cosr = cos(radians(self.rotation)) 265 | sinr = sin(radians(self.rotation)) 266 | 267 | x = (cosr * cos(angle) * self.radius.real - sinr * sin(angle) * 268 | self.radius.imag + self.center.real) 269 | y = (sinr * cos(angle) * self.radius.real + cosr * sin(angle) * 270 | self.radius.imag + self.center.imag) 271 | return complex(x, y) 272 | 273 | def length(self, error=ERROR, min_depth=MIN_DEPTH): 274 | """The length of an elliptical arc segment requires numerical 275 | integration, and in that case it's simpler to just do a geometric 276 | approximation, as for cubic bezier curves. 277 | """ 278 | start_point = self.point(0) 279 | end_point = self.point(1) 280 | return segment_length(self, 0, 1, start_point, end_point, error, min_depth, 0) 281 | 282 | 283 | class Path(MutableSequence): 284 | """A Path is a sequence of path segments""" 285 | 286 | # Put it here, so there is a default if unpickled. 
287 | _closed = False 288 | 289 | def __init__(self, *segments, **kw): 290 | self._segments = list(segments) 291 | self._length = None 292 | self._lengths = None 293 | if 'closed' in kw: 294 | self.closed = kw['closed'] 295 | 296 | def __getitem__(self, index): 297 | return self._segments[index] 298 | 299 | def __setitem__(self, index, value): 300 | self._segments[index] = value 301 | self._length = None 302 | 303 | def __delitem__(self, index): 304 | del self._segments[index] 305 | self._length = None 306 | 307 | def insert(self, index, value): 308 | self._segments.insert(index, value) 309 | self._length = None 310 | 311 | def reverse(self): 312 | # Reversing the order of a path would require reversing each element 313 | # as well. That's not implemented. 314 | raise NotImplementedError 315 | 316 | def __len__(self): 317 | return len(self._segments) 318 | 319 | def __repr__(self): 320 | return 'Path(%s, closed=%s)' % ( 321 | ', '.join(repr(x) for x in self._segments), self.closed) 322 | 323 | def __eq__(self, other): 324 | if not isinstance(other, Path): 325 | return NotImplemented 326 | if len(self) != len(other): 327 | return False 328 | for s, o in zip(self._segments, other._segments): 329 | if not s == o: 330 | return False 331 | return True 332 | 333 | def __ne__(self, other): 334 | if not isinstance(other, Path): 335 | return NotImplemented 336 | return not self == other 337 | 338 | def _calc_lengths(self, error=ERROR, min_depth=MIN_DEPTH): 339 | if self._length is not None: 340 | return 341 | 342 | lengths = [each.length(error=error, min_depth=min_depth) for each in self._segments] 343 | self._length = sum(lengths) 344 | self._lengths = [each / self._length for each in lengths] 345 | 346 | def point(self, pos, error=ERROR): 347 | 348 | # Shortcuts 349 | if pos == 0.0: 350 | return self._segments[0].point(pos) 351 | if pos == 1.0: 352 | return self._segments[-1].point(pos) 353 | 354 | self._calc_lengths(error=error) 355 | # Find which segment the point we 
search for is located on: 356 | segment_start = 0 357 | for index, segment in enumerate(self._segments): 358 | segment_end = segment_start + self._lengths[index] 359 | if segment_end >= pos: 360 | # This is the segment! How far in on the segment is the point? 361 | segment_pos = (pos - segment_start) / (segment_end - segment_start) 362 | break 363 | segment_start = segment_end 364 | 365 | return segment.point(segment_pos) 366 | 367 | def length(self, error=ERROR, min_depth=MIN_DEPTH): 368 | self._calc_lengths(error, min_depth) 369 | return self._length 370 | 371 | def _is_closable(self): 372 | """Returns true if the end is on the start of a segment""" 373 | end = self[-1].end 374 | for segment in self: 375 | if segment.start == end: 376 | return True 377 | return False 378 | 379 | @property 380 | def closed(self): 381 | """Checks that the path is closed""" 382 | return self._closed and self._is_closable() 383 | 384 | @closed.setter 385 | def closed(self, value): 386 | value = bool(value) 387 | if value and not self._is_closable(): 388 | raise ValueError("End does not coincide with a segment start.") 389 | self._closed = value 390 | 391 | def d(self): 392 | if self.closed: 393 | segments = self[:-1] 394 | else: 395 | segments = self[:] 396 | 397 | current_pos = None 398 | parts = [] 399 | previous_segment = None 400 | end = self[-1].end 401 | 402 | for segment in segments: 403 | start = segment.start 404 | # If the start of this segment does not coincide with the end of 405 | # the last segment or if this segment is actually the close point 406 | # of a closed path, then we should start a new subpath here. 
407 | if current_pos != start or (self.closed and start == end): 408 | parts.append('M {0:G},{1:G}'.format(start.real, start.imag)) 409 | 410 | if isinstance(segment, Line): 411 | parts.append('L {0:G},{1:G}'.format( 412 | segment.end.real, segment.end.imag) 413 | ) 414 | elif isinstance(segment, CubicBezier): 415 | if segment.is_smooth_from(previous_segment): 416 | parts.append('S {0:G},{1:G} {2:G},{3:G}'.format( 417 | segment.control2.real, segment.control2.imag, 418 | segment.end.real, segment.end.imag) 419 | ) 420 | else: 421 | parts.append('C {0:G},{1:G} {2:G},{3:G} {4:G},{5:G}'.format( 422 | segment.control1.real, segment.control1.imag, 423 | segment.control2.real, segment.control2.imag, 424 | segment.end.real, segment.end.imag) 425 | ) 426 | elif isinstance(segment, QuadraticBezier): 427 | if segment.is_smooth_from(previous_segment): 428 | parts.append('T {0:G},{1:G}'.format( 429 | segment.end.real, segment.end.imag) 430 | ) 431 | else: 432 | parts.append('Q {0:G},{1:G} {2:G},{3:G}'.format( 433 | segment.control.real, segment.control.imag, 434 | segment.end.real, segment.end.imag) 435 | ) 436 | 437 | elif isinstance(segment, Arc): 438 | parts.append('A {0:G},{1:G} {2:G} {3:d},{4:d} {5:G},{6:G}'.format( 439 | segment.radius.real, segment.radius.imag, segment.rotation, 440 | int(segment.arc), int(segment.sweep), 441 | segment.end.real, segment.end.imag) 442 | ) 443 | current_pos = segment.end 444 | previous_segment = segment 445 | 446 | if self.closed: 447 | parts.append('Z') 448 | 449 | return ' '.join(parts) 450 | -------------------------------------------------------------------------------- /svg/path/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/sketch-rnn/862bb94a15d48f0c42af71a9ede1681a0a2d2602/svg/path/tests/__init__.py -------------------------------------------------------------------------------- /svg/path/tests/test_doc.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | import doctest 3 | 4 | 5 | def load_tests(loader, tests, ignore): 6 | tests.addTests(doctest.DocFileSuite('README.rst', package='__main__')) 7 | return tests 8 | -------------------------------------------------------------------------------- /svg/path/tests/test_generation.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import unittest 3 | from ..path import CubicBezier, QuadraticBezier, Line, Arc, Path 4 | from ..parser import parse_path 5 | 6 | 7 | class TestGeneration(unittest.TestCase): 8 | 9 | def test_svg_examples(self): 10 | """Examples from the SVG spec""" 11 | paths = [ 12 | 'M 100,100 L 300,100 L 200,300 Z', 13 | 'M 0,0 L 50,20 M 100,100 L 300,100 L 200,300 Z', 14 | 'M 100,100 L 200,200', 15 | 'M 100,200 L 200,100 L -100,-200', 16 | 'M 100,200 C 100,100 250,100 250,200 S 400,300 400,200', 17 | 'M 100,200 C 100,100 400,100 400,200', 18 | 'M 100,500 C 25,400 475,400 400,500', 19 | 'M 100,800 C 175,700 325,700 400,800', 20 | 'M 600,200 C 675,100 975,100 900,200', 21 | 'M 600,500 C 600,350 900,650 900,500', 22 | 'M 600,800 C 625,700 725,700 750,800 S 875,900 900,800', 23 | 'M 200,300 Q 400,50 600,300 T 1000,300', 24 | 'M -3.4E+38,3.4E+38 L -3.4E-38,3.4E-38', 25 | 'M 0,0 L 50,20 M 50,20 L 200,100 Z', 26 | 'M 600,350 L 650,325 A 25,25 -30 0,1 700,300 L 750,275', 27 | ] 28 | 29 | for path in paths: 30 | self.assertEqual(parse_path(path).d(), path) 31 | 32 | def test_normalizing(self): 33 | # Relative paths will be made absolute, subpaths merged if they can, 34 | # and syntax will change. 
35 | self.assertEqual(parse_path('M0 0L3.4E2-10L100.0,100M100,100l100,-100').d(), 36 | 'M 0,0 L 340,-10 L 100,100 L 200,0') 37 | -------------------------------------------------------------------------------- /svg/path/tests/test_parsing.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import unittest 3 | from ..path import CubicBezier, QuadraticBezier, Line, Arc, Path 4 | from ..parser import parse_path 5 | 6 | 7 | class TestParser(unittest.TestCase): 8 | 9 | def test_svg_examples(self): 10 | """Examples from the SVG spec""" 11 | path1 = parse_path('M 100 100 L 300 100 L 200 300 z') 12 | self.assertEqual(path1, Path(Line(100 + 100j, 300 + 100j), 13 | Line(300 + 100j, 200 + 300j), 14 | Line(200 + 300j, 100 + 100j))) 15 | self.assertTrue(path1.closed) 16 | 17 | # for Z command behavior when there are multiple subpaths 18 | path1 = parse_path('M 0 0 L 50 20 M 100 100 L 300 100 L 200 300 z') 19 | self.assertEqual(path1, Path( 20 | Line(0 + 0j, 50 + 20j), 21 | Line(100 + 100j, 300 + 100j), 22 | Line(300 + 100j, 200 + 300j), 23 | Line(200 + 300j, 100 + 100j))) 24 | 25 | path1 = parse_path('M 100 100 L 200 200') 26 | path2 = parse_path('M100 100L200 200') 27 | self.assertEqual(path1, path2) 28 | 29 | path1 = parse_path('M 100 200 L 200 100 L -100 -200') 30 | path2 = parse_path('M 100 200 L 200 100 -100 -200') 31 | self.assertEqual(path1, path2) 32 | 33 | path1 = parse_path("""M100,200 C100,100 250,100 250,200 34 | S400,300 400,200""") 35 | self.assertEqual(path1, 36 | Path(CubicBezier(100 + 200j, 100 + 100j, 250 + 100j, 250 + 200j), 37 | CubicBezier(250 + 200j, 250 + 300j, 400 + 300j, 400 + 200j))) 38 | 39 | path1 = parse_path('M100,200 C100,100 400,100 400,200') 40 | self.assertEqual(path1, 41 | Path(CubicBezier(100 + 200j, 100 + 100j, 400 + 100j, 400 + 200j))) 42 | 43 | path1 = parse_path('M100,500 C25,400 475,400 400,500') 44 | self.assertEqual(path1, 45 | Path(CubicBezier(100 + 500j, 25 +
400j, 475 + 400j, 400 + 500j))) 46 | 47 | path1 = parse_path('M100,800 C175,700 325,700 400,800') 48 | self.assertEqual(path1, 49 | Path(CubicBezier(100 + 800j, 175 + 700j, 325 + 700j, 400 + 800j))) 50 | 51 | path1 = parse_path('M600,200 C675,100 975,100 900,200') 52 | self.assertEqual(path1, 53 | Path(CubicBezier(600 + 200j, 675 + 100j, 975 + 100j, 900 + 200j))) 54 | 55 | path1 = parse_path('M600,500 C600,350 900,650 900,500') 56 | self.assertEqual(path1, 57 | Path(CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j))) 58 | 59 | path1 = parse_path("""M600,800 C625,700 725,700 750,800 60 | S875,900 900,800""") 61 | self.assertEqual(path1, 62 | Path(CubicBezier(600 + 800j, 625 + 700j, 725 + 700j, 750 + 800j), 63 | CubicBezier(750 + 800j, 775 + 900j, 875 + 900j, 900 + 800j))) 64 | 65 | path1 = parse_path('M200,300 Q400,50 600,300 T1000,300') 66 | self.assertEqual(path1, 67 | Path(QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j), 68 | QuadraticBezier(600 + 300j, 800 + 550j, 1000 + 300j))) 69 | 70 | path1 = parse_path('M300,200 h-150 a150,150 0 1,0 150,-150 z') 71 | self.assertEqual(path1, 72 | Path(Line(300 + 200j, 150 + 200j), 73 | Arc(150 + 200j, 150 + 150j, 0, 1, 0, 300 + 50j), 74 | Line(300 + 50j, 300 + 200j))) 75 | 76 | path1 = parse_path('M275,175 v-150 a150,150 0 0,0 -150,150 z') 77 | self.assertEqual(path1, 78 | Path(Line(275 + 175j, 275 + 25j), 79 | Arc(275 + 25j, 150 + 150j, 0, 0, 0, 125 + 175j), 80 | Line(125 + 175j, 275 + 175j))) 81 | 82 | path1 = parse_path("""M600,350 l 50,-25 83 | a25,25 -30 0,1 50,-25 l 50,-25 84 | a25,50 -30 0,1 50,-25 l 50,-25 85 | a25,75 -30 0,1 50,-25 l 50,-25 86 | a25,100 -30 0,1 50,-25 l 50,-25""") 87 | self.assertEqual(path1, 88 | Path(Line(600 + 350j, 650 + 325j), 89 | Arc(650 + 325j, 25 + 25j, -30, 0, 1, 700 + 300j), 90 | Line(700 + 300j, 750 + 275j), 91 | Arc(750 + 275j, 25 + 50j, -30, 0, 1, 800 + 250j), 92 | Line(800 + 250j, 850 + 225j), 93 | Arc(850 + 225j, 25 + 75j, -30, 0, 1, 900 + 200j), 94 | Line(900 + 200j, 
950 + 175j), 95 | Arc(950 + 175j, 25 + 100j, -30, 0, 1, 1000 + 150j), 96 | Line(1000 + 150j, 1050 + 125j))) 97 | 98 | def test_others(self): 99 | # Other paths that need testing: 100 | 101 | # Relative moveto: 102 | path1 = parse_path('M 0 0 L 50 20 m 50 80 L 300 100 L 200 300 z') 103 | self.assertEqual(path1, Path( 104 | Line(0 + 0j, 50 + 20j), 105 | Line(100 + 100j, 300 + 100j), 106 | Line(300 + 100j, 200 + 300j), 107 | Line(200 + 300j, 100 + 100j))) 108 | 109 | # Initial smooth and relative CubicBezier 110 | path1 = parse_path("""M100,200 s 150,-100 150,0""") 111 | self.assertEqual(path1, 112 | Path(CubicBezier(100 + 200j, 100 + 200j, 250 + 100j, 250 + 200j))) 113 | 114 | # Initial smooth and relative QuadraticBezier 115 | path1 = parse_path("""M100,200 t 150,0""") 116 | self.assertEqual(path1, 117 | Path(QuadraticBezier(100 + 200j, 100 + 200j, 250 + 200j))) 118 | 119 | # Relative QuadraticBezier 120 | path1 = parse_path("""M100,200 q 0,0 150,0""") 121 | self.assertEqual(path1, 122 | Path(QuadraticBezier(100 + 200j, 100 + 200j, 250 + 200j))) 123 | 124 | def test_negative(self): 125 | """You don't need spaces before a minus-sign""" 126 | path1 = parse_path('M100,200c10-5,20-10,30-20') 127 | path2 = parse_path('M 100 200 c 10 -5 20 -10 30 -20') 128 | self.assertEqual(path1, path2) 129 | 130 | def test_numbers(self): 131 | """Exponents and other number format cases""" 132 | # It can be e or E, the plus is optional, and a minimum of +/-3.4e38 must be supported. 
133 | path1 = parse_path('M-3.4e38 3.4E+38L-3.4E-38,3.4e-38') 134 | path2 = Path(Line(-3.4e+38 + 3.4e+38j, -3.4e-38 + 3.4e-38j)) 135 | self.assertEqual(path1, path2) 136 | 137 | def test_errors(self): 138 | self.assertRaises(ValueError, parse_path, 'M 100 100 L 200 200 Z 100 200') 139 | -------------------------------------------------------------------------------- /svg/path/tests/test_paths.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import unittest 3 | from math import sqrt, pi 4 | 5 | from ..path import CubicBezier, QuadraticBezier, Line, Arc, Path 6 | 7 | 8 | # Most of these test points are not calculated separately, as that would 9 | # take too long and be too error-prone. Instead the curves have been verified 10 | # to be correct visually, by drawing them with the turtle module, with code 11 | # like this: 12 | # 13 | # import turtle 14 | # t = turtle.Turtle() 15 | # t.penup() 16 | # 17 | # for arc in (path1, path2): 18 | # p = arc.point(0) 19 | # t.goto(p.real - 500, -p.imag + 300) 20 | # t.dot(3, 'black') 21 | # t.pendown() 22 | # for x in range(1, 101): 23 | # p = arc.point(x * 0.01) 24 | # t.goto(p.real - 500, -p.imag + 300) 25 | # t.penup() 26 | # t.dot(3, 'black') 27 | # 28 | # raw_input() 29 | # 30 | # After the paths have been verified to be correct this way, the testing of 31 | # points along the paths has been added as regression tests, to make sure 32 | # nobody changes the way curves are drawn by mistake. Therefore, do not take 33 | # these points religiously. They might be subtly wrong, unless otherwise 34 | # noted. 35 | 36 | class LineTest(unittest.TestCase): 37 | 38 | def test_lines(self): 39 | # These points are calculated, and not just regression tests.
40 | 41 | line1 = Line(0j, 400 + 0j) 42 | self.assertAlmostEqual(line1.point(0), (0j)) 43 | self.assertAlmostEqual(line1.point(0.3), (120 + 0j)) 44 | self.assertAlmostEqual(line1.point(0.5), (200 + 0j)) 45 | self.assertAlmostEqual(line1.point(0.9), (360 + 0j)) 46 | self.assertAlmostEqual(line1.point(1), (400 + 0j)) 47 | self.assertAlmostEqual(line1.length(), 400) 48 | 49 | line2 = Line(400 + 0j, 400 + 300j) 50 | self.assertAlmostEqual(line2.point(0), (400 + 0j)) 51 | self.assertAlmostEqual(line2.point(0.3), (400 + 90j)) 52 | self.assertAlmostEqual(line2.point(0.5), (400 + 150j)) 53 | self.assertAlmostEqual(line2.point(0.9), (400 + 270j)) 54 | self.assertAlmostEqual(line2.point(1), (400 + 300j)) 55 | self.assertAlmostEqual(line2.length(), 300) 56 | 57 | line3 = Line(400 + 300j, 0j) 58 | self.assertAlmostEqual(line3.point(0), (400 + 300j)) 59 | self.assertAlmostEqual(line3.point(0.3), (280 + 210j)) 60 | self.assertAlmostEqual(line3.point(0.5), (200 + 150j)) 61 | self.assertAlmostEqual(line3.point(0.9), (40 + 30j)) 62 | self.assertAlmostEqual(line3.point(1), (0j)) 63 | self.assertAlmostEqual(line3.length(), 500) 64 | 65 | def test_equality(self): 66 | # This is to test the __eq__ and __ne__ methods, so we can't use 67 | # assertEqual and assertNotEqual 68 | line = Line(0j, 400 + 0j) 69 | self.assertTrue(line == Line(0, 400)) 70 | self.assertTrue(line != Line(100, 400)) 71 | self.assertFalse(line == str(line)) 72 | self.assertTrue(line != str(line)) 73 | self.assertFalse(CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j) == 74 | line) 75 | 76 | 77 | class CubicBezierTest(unittest.TestCase): 78 | def test_approx_circle(self): 79 | """This is an approximate circle drawn in Inkscape""" 80 | 81 | arc1 = CubicBezier( 82 | complex(0, 0), 83 | complex(0, 109.66797), 84 | complex(-88.90345, 198.57142), 85 | complex(-198.57142, 198.57142) 86 | ) 87 | 88 | self.assertAlmostEqual(arc1.point(0), (0j)) 89 | self.assertAlmostEqual(arc1.point(0.1), (-2.59896457 +
32.20931647j)) 90 | self.assertAlmostEqual(arc1.point(0.2), (-10.12330256 + 62.76392816j)) 91 | self.assertAlmostEqual(arc1.point(0.3), (-22.16418039 + 91.25500149j)) 92 | self.assertAlmostEqual(arc1.point(0.4), (-38.31276448 + 117.27370288j)) 93 | self.assertAlmostEqual(arc1.point(0.5), (-58.16022125 + 140.41119875j)) 94 | self.assertAlmostEqual(arc1.point(0.6), (-81.29771712 + 160.25865552j)) 95 | self.assertAlmostEqual(arc1.point(0.7), (-107.31641851 + 176.40723961j)) 96 | self.assertAlmostEqual(arc1.point(0.8), (-135.80749184 + 188.44811744j)) 97 | self.assertAlmostEqual(arc1.point(0.9), (-166.36210353 + 195.97245543j)) 98 | self.assertAlmostEqual(arc1.point(1), (-198.57142 + 198.57142j)) 99 | 100 | arc2 = CubicBezier( 101 | complex(-198.57142, 198.57142), 102 | complex(-109.66797 - 198.57142, 0 + 198.57142), 103 | complex(-198.57143 - 198.57142, -88.90345 + 198.57142), 104 | complex(-198.57143 - 198.57142, 0), 105 | ) 106 | 107 | self.assertAlmostEqual(arc2.point(0), (-198.57142 + 198.57142j)) 108 | self.assertAlmostEqual(arc2.point(0.1), (-230.78073675 + 195.97245543j)) 109 | self.assertAlmostEqual(arc2.point(0.2), (-261.3353492 + 188.44811744j)) 110 | self.assertAlmostEqual(arc2.point(0.3), (-289.82642365 + 176.40723961j)) 111 | self.assertAlmostEqual(arc2.point(0.4), (-315.8451264 + 160.25865552j)) 112 | self.assertAlmostEqual(arc2.point(0.5), (-338.98262375 + 140.41119875j)) 113 | self.assertAlmostEqual(arc2.point(0.6), (-358.830082 + 117.27370288j)) 114 | self.assertAlmostEqual(arc2.point(0.7), (-374.97866745 + 91.25500149j)) 115 | self.assertAlmostEqual(arc2.point(0.8), (-387.0195464 + 62.76392816j)) 116 | self.assertAlmostEqual(arc2.point(0.9), (-394.54388515 + 32.20931647j)) 117 | self.assertAlmostEqual(arc2.point(1), (-397.14285 + 0j)) 118 | 119 | arc3 = CubicBezier( 120 | complex(-198.57143 - 198.57142, 0), 121 | complex(0 - 198.57143 - 198.57142, -109.66797), 122 | complex(88.90346 - 198.57143 - 198.57142, -198.57143), 123 | complex(-198.57142, 
-198.57143) 124 | ) 125 | 126 | self.assertAlmostEqual(arc3.point(0), (-397.14285 + 0j)) 127 | self.assertAlmostEqual(arc3.point(0.1), (-394.54388515 - 32.20931675j)) 128 | self.assertAlmostEqual(arc3.point(0.2), (-387.0195464 - 62.7639292j)) 129 | self.assertAlmostEqual(arc3.point(0.3), (-374.97866745 - 91.25500365j)) 130 | self.assertAlmostEqual(arc3.point(0.4), (-358.830082 - 117.2737064j)) 131 | self.assertAlmostEqual(arc3.point(0.5), (-338.98262375 - 140.41120375j)) 132 | self.assertAlmostEqual(arc3.point(0.6), (-315.8451264 - 160.258662j)) 133 | self.assertAlmostEqual(arc3.point(0.7), (-289.82642365 - 176.40724745j)) 134 | self.assertAlmostEqual(arc3.point(0.8), (-261.3353492 - 188.4481264j)) 135 | self.assertAlmostEqual(arc3.point(0.9), (-230.78073675 - 195.97246515j)) 136 | self.assertAlmostEqual(arc3.point(1), (-198.57142 - 198.57143j)) 137 | 138 | arc4 = CubicBezier( 139 | complex(-198.57142, -198.57143), 140 | complex(109.66797 - 198.57142, 0 - 198.57143), 141 | complex(0, 88.90346 - 198.57143), 142 | complex(0, 0), 143 | ) 144 | 145 | self.assertAlmostEqual(arc4.point(0), (-198.57142 - 198.57143j)) 146 | self.assertAlmostEqual(arc4.point(0.1), (-166.36210353 - 195.97246515j)) 147 | self.assertAlmostEqual(arc4.point(0.2), (-135.80749184 - 188.4481264j)) 148 | self.assertAlmostEqual(arc4.point(0.3), (-107.31641851 - 176.40724745j)) 149 | self.assertAlmostEqual(arc4.point(0.4), (-81.29771712 - 160.258662j)) 150 | self.assertAlmostEqual(arc4.point(0.5), (-58.16022125 - 140.41120375j)) 151 | self.assertAlmostEqual(arc4.point(0.6), (-38.31276448 - 117.2737064j)) 152 | self.assertAlmostEqual(arc4.point(0.7), (-22.16418039 - 91.25500365j)) 153 | self.assertAlmostEqual(arc4.point(0.8), (-10.12330256 - 62.7639292j)) 154 | self.assertAlmostEqual(arc4.point(0.9), (-2.59896457 - 32.20931675j)) 155 | self.assertAlmostEqual(arc4.point(1), (0j)) 156 | 157 | def test_svg_examples(self): 158 | 159 | # M100,200 C100,100 250,100 250,200 160 | path1 = CubicBezier(100 + 
200j, 100 + 100j, 250 + 100j, 250 + 200j) 161 | self.assertAlmostEqual(path1.point(0), (100 + 200j)) 162 | self.assertAlmostEqual(path1.point(0.3), (132.4 + 137j)) 163 | self.assertAlmostEqual(path1.point(0.5), (175 + 125j)) 164 | self.assertAlmostEqual(path1.point(0.9), (245.8 + 173j)) 165 | self.assertAlmostEqual(path1.point(1), (250 + 200j)) 166 | 167 | # S400,300 400,200 168 | path2 = CubicBezier(250 + 200j, 250 + 300j, 400 + 300j, 400 + 200j) 169 | self.assertAlmostEqual(path2.point(0), (250 + 200j)) 170 | self.assertAlmostEqual(path2.point(0.3), (282.4 + 263j)) 171 | self.assertAlmostEqual(path2.point(0.5), (325 + 275j)) 172 | self.assertAlmostEqual(path2.point(0.9), (395.8 + 227j)) 173 | self.assertAlmostEqual(path2.point(1), (400 + 200j)) 174 | 175 | # M100,200 C100,100 400,100 400,200 176 | path3 = CubicBezier(100 + 200j, 100 + 100j, 400 + 100j, 400 + 200j) 177 | self.assertAlmostEqual(path3.point(0), (100 + 200j)) 178 | self.assertAlmostEqual(path3.point(0.3), (164.8 + 137j)) 179 | self.assertAlmostEqual(path3.point(0.5), (250 + 125j)) 180 | self.assertAlmostEqual(path3.point(0.9), (391.6 + 173j)) 181 | self.assertAlmostEqual(path3.point(1), (400 + 200j)) 182 | 183 | # M100,500 C25,400 475,400 400,500 184 | path4 = CubicBezier(100 + 500j, 25 + 400j, 475 + 400j, 400 + 500j) 185 | self.assertAlmostEqual(path4.point(0), (100 + 500j)) 186 | self.assertAlmostEqual(path4.point(0.3), (145.9 + 437j)) 187 | self.assertAlmostEqual(path4.point(0.5), (250 + 425j)) 188 | self.assertAlmostEqual(path4.point(0.9), (407.8 + 473j)) 189 | self.assertAlmostEqual(path4.point(1), (400 + 500j)) 190 | 191 | # M100,800 C175,700 325,700 400,800 192 | path5 = CubicBezier(100 + 800j, 175 + 700j, 325 + 700j, 400 + 800j) 193 | self.assertAlmostEqual(path5.point(0), (100 + 800j)) 194 | self.assertAlmostEqual(path5.point(0.3), (183.7 + 737j)) 195 | self.assertAlmostEqual(path5.point(0.5), (250 + 725j)) 196 | self.assertAlmostEqual(path5.point(0.9), (375.4 + 773j)) 197 | 
self.assertAlmostEqual(path5.point(1), (400 + 800j)) 198 | 199 | # M600,200 C675,100 975,100 900,200 200 | path6 = CubicBezier(600 + 200j, 675 + 100j, 975 + 100j, 900 + 200j) 201 | self.assertAlmostEqual(path6.point(0), (600 + 200j)) 202 | self.assertAlmostEqual(path6.point(0.3), (712.05 + 137j)) 203 | self.assertAlmostEqual(path6.point(0.5), (806.25 + 125j)) 204 | self.assertAlmostEqual(path6.point(0.9), (911.85 + 173j)) 205 | self.assertAlmostEqual(path6.point(1), (900 + 200j)) 206 | 207 | # M600,500 C600,350 900,650 900,500 208 | path7 = CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j) 209 | self.assertAlmostEqual(path7.point(0), (600 + 500j)) 210 | self.assertAlmostEqual(path7.point(0.3), (664.8 + 462.2j)) 211 | self.assertAlmostEqual(path7.point(0.5), (750 + 500j)) 212 | self.assertAlmostEqual(path7.point(0.9), (891.6 + 532.4j)) 213 | self.assertAlmostEqual(path7.point(1), (900 + 500j)) 214 | 215 | # M600,800 C625,700 725,700 750,800 216 | path8 = CubicBezier(600 + 800j, 625 + 700j, 725 + 700j, 750 + 800j) 217 | self.assertAlmostEqual(path8.point(0), (600 + 800j)) 218 | self.assertAlmostEqual(path8.point(0.3), (638.7 + 737j)) 219 | self.assertAlmostEqual(path8.point(0.5), (675 + 725j)) 220 | self.assertAlmostEqual(path8.point(0.9), (740.4 + 773j)) 221 | self.assertAlmostEqual(path8.point(1), (750 + 800j)) 222 | 223 | # S875,900 900,800 224 | inversion = (750 + 800j) + (750 + 800j) - (725 + 700j) 225 | path9 = CubicBezier(750 + 800j, inversion, 875 + 900j, 900 + 800j) 226 | self.assertAlmostEqual(path9.point(0), (750 + 800j)) 227 | self.assertAlmostEqual(path9.point(0.3), (788.7 + 863j)) 228 | self.assertAlmostEqual(path9.point(0.5), (825 + 875j)) 229 | self.assertAlmostEqual(path9.point(0.9), (890.4 + 827j)) 230 | self.assertAlmostEqual(path9.point(1), (900 + 800j)) 231 | 232 | def test_length(self): 233 | 234 | # A straight line: 235 | arc = CubicBezier( 236 | complex(0, 0), 237 | complex(0, 0), 238 | complex(0, 100), 239 | complex(0, 100) 240 | ) 
241 | 242 | self.assertAlmostEqual(arc.length(), 100) 243 | 244 | # A diagonal line: 245 | arc = CubicBezier( 246 | complex(0, 0), 247 | complex(0, 0), 248 | complex(100, 100), 249 | complex(100, 100) 250 | ) 251 | 252 | self.assertAlmostEqual(arc.length(), sqrt(2 * 100 * 100)) 253 | 254 | # A quarter circle arc with radius 100: 255 | kappa = 4 * (sqrt(2) - 1) / 3 # http://www.whizkidtech.redprince.net/bezier/circle/ 256 | 257 | arc = CubicBezier( 258 | complex(0, 0), 259 | complex(0, kappa * 100), 260 | complex(100 - kappa * 100, 100), 261 | complex(100, 100) 262 | ) 263 | 264 | # We can't compare with pi*50 here, because this is just an 265 | # approximation of a circle arc. pi*50 is 157.079632679 266 | # So this is just yet another "warn if this changes" test. 267 | # This value is not verified to be correct. 268 | self.assertAlmostEqual(arc.length(), 157.1016698) 269 | 270 | # A recursive solution has also been suggested, but for CubicBezier 271 | # curves it could get a false solution on curves where the midpoint is on a 272 | # straight line between the start and end. For example, the following 273 | # curve would get solved as a straight line and get the length 300. 274 | # Make sure this is not the case. 
275 | arc = CubicBezier( 276 | complex(600, 500), 277 | complex(600, 350), 278 | complex(900, 650), 279 | complex(900, 500) 280 | ) 281 | self.assertTrue(arc.length() > 300.0) 282 | 283 | def test_equality(self): 284 | # This is to test the __eq__ and __ne__ methods, so we can't use 285 | # assertEqual and assertNotEqual 286 | segment = CubicBezier(complex(600, 500), complex(600, 350), 287 | complex(900, 650), complex(900, 500)) 288 | 289 | self.assertTrue(segment == 290 | CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j)) 291 | self.assertTrue(segment != 292 | CubicBezier(600 + 501j, 600 + 350j, 900 + 650j, 900 + 500j)) 293 | self.assertTrue(segment != Line(0, 400)) 294 | 295 | 296 | class QuadraticBezierTest(unittest.TestCase): 297 | 298 | def test_svg_examples(self): 299 | """This is the path from the SVG specs""" 300 | # M200,300 Q400,50 600,300 T1000,300 301 | path1 = QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j) 302 | self.assertAlmostEqual(path1.point(0), (200 + 300j)) 303 | self.assertAlmostEqual(path1.point(0.3), (320 + 195j)) 304 | self.assertAlmostEqual(path1.point(0.5), (400 + 175j)) 305 | self.assertAlmostEqual(path1.point(0.9), (560 + 255j)) 306 | self.assertAlmostEqual(path1.point(1), (600 + 300j)) 307 | 308 | # T1000, 300 309 | inversion = (600 + 300j) + (600 + 300j) - (400 + 50j) 310 | path2 = QuadraticBezier(600 + 300j, inversion, 1000 + 300j) 311 | self.assertAlmostEqual(path2.point(0), (600 + 300j)) 312 | self.assertAlmostEqual(path2.point(0.3), (720 + 405j)) 313 | self.assertAlmostEqual(path2.point(0.5), (800 + 425j)) 314 | self.assertAlmostEqual(path2.point(0.9), (960 + 345j)) 315 | self.assertAlmostEqual(path2.point(1), (1000 + 300j)) 316 | 317 | def test_length(self): 318 | # expected results calculated with 319 | # svg.path.segment_length(q, 0, 1, q.start, q.end, 1e-14, 20, 0) 320 | q1 = QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j) 321 | q2 = QuadraticBezier(200 + 300j, 400 + 50j, 500 + 200j) 322 | closedq = 
QuadraticBezier(6+2j, 5-1j, 6+2j) 323 | linq1 = QuadraticBezier(1, 2, 3) 324 | linq2 = QuadraticBezier(1+3j, 2+5j, -9 - 17j) 325 | nodalq = QuadraticBezier(1, 1, 1) 326 | tests = [(q1, 487.77109389525975), 327 | (q2, 379.90458193489155), 328 | (closedq, 3.1622776601683795), 329 | (linq1, 2), 330 | (linq2, 22.73335777124786), 331 | (nodalq, 0)] 332 | for q, exp_res in tests: 333 | self.assertAlmostEqual(q.length(), exp_res) 334 | 335 | def test_equality(self): 336 | # This is to test the __eq__ and __ne__ methods, so we can't use 337 | # assertEqual and assertNotEqual 338 | segment = QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j) 339 | self.assertTrue(segment == QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j)) 340 | self.assertTrue(segment != QuadraticBezier(200 + 301j, 400 + 50j, 600 + 300j)) 341 | self.assertFalse(segment == Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j)) 342 | self.assertTrue(Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j) != segment) 343 | 344 | 345 | class ArcTest(unittest.TestCase): 346 | 347 | def test_points(self): 348 | arc1 = Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j) 349 | self.assertAlmostEqual(arc1.center, 100 + 0j) 350 | self.assertAlmostEqual(arc1.theta, 180.0) 351 | self.assertAlmostEqual(arc1.delta, -90.0) 352 | 353 | self.assertAlmostEqual(arc1.point(0.0), (0j)) 354 | self.assertAlmostEqual(arc1.point(0.1), (1.23116594049 + 7.82172325201j)) 355 | self.assertAlmostEqual(arc1.point(0.2), (4.89434837048 + 15.4508497187j)) 356 | self.assertAlmostEqual(arc1.point(0.3), (10.8993475812 + 22.699524987j)) 357 | self.assertAlmostEqual(arc1.point(0.4), (19.0983005625 + 29.3892626146j)) 358 | self.assertAlmostEqual(arc1.point(0.5), (29.2893218813 + 35.3553390593j)) 359 | self.assertAlmostEqual(arc1.point(0.6), (41.2214747708 + 40.4508497187j)) 360 | self.assertAlmostEqual(arc1.point(0.7), (54.6009500260 + 44.5503262094j)) 361 | self.assertAlmostEqual(arc1.point(0.8), (69.0983005625 + 47.5528258148j)) 362 | self.assertAlmostEqual(arc1.point(0.9), 
(84.3565534960 + 49.3844170298j)) 363 | self.assertAlmostEqual(arc1.point(1.0), (100 + 50j)) 364 | 365 | arc2 = Arc(0j, 100 + 50j, 0, 1, 0, 100 + 50j) 366 | self.assertAlmostEqual(arc2.center, 50j) 367 | self.assertAlmostEqual(arc2.theta, 270.0) 368 | self.assertAlmostEqual(arc2.delta, -270.0) 369 | 370 | self.assertAlmostEqual(arc2.point(0.0), (0j)) 371 | self.assertAlmostEqual(arc2.point(0.1), (-45.399049974 + 5.44967379058j)) 372 | self.assertAlmostEqual(arc2.point(0.2), (-80.9016994375 + 20.6107373854j)) 373 | self.assertAlmostEqual(arc2.point(0.3), (-98.7688340595 + 42.178276748j)) 374 | self.assertAlmostEqual(arc2.point(0.4), (-95.1056516295 + 65.4508497187j)) 375 | self.assertAlmostEqual(arc2.point(0.5), (-70.7106781187 + 85.3553390593j)) 376 | self.assertAlmostEqual(arc2.point(0.6), (-30.9016994375 + 97.5528258148j)) 377 | self.assertAlmostEqual(arc2.point(0.7), (15.643446504 + 99.3844170298j)) 378 | self.assertAlmostEqual(arc2.point(0.8), (58.7785252292 + 90.4508497187j)) 379 | self.assertAlmostEqual(arc2.point(0.9), (89.1006524188 + 72.699524987j)) 380 | self.assertAlmostEqual(arc2.point(1.0), (100 + 50j)) 381 | 382 | arc3 = Arc(0j, 100 + 50j, 0, 0, 1, 100 + 50j) 383 | self.assertAlmostEqual(arc3.center, 50j) 384 | self.assertAlmostEqual(arc3.theta, 270.0) 385 | self.assertAlmostEqual(arc3.delta, 90.0) 386 | 387 | self.assertAlmostEqual(arc3.point(0.0), (0j)) 388 | self.assertAlmostEqual(arc3.point(0.1), (15.643446504 + 0.615582970243j)) 389 | self.assertAlmostEqual(arc3.point(0.2), (30.9016994375 + 2.44717418524j)) 390 | self.assertAlmostEqual(arc3.point(0.3), (45.399049974 + 5.44967379058j)) 391 | self.assertAlmostEqual(arc3.point(0.4), (58.7785252292 + 9.54915028125j)) 392 | self.assertAlmostEqual(arc3.point(0.5), (70.7106781187 + 14.6446609407j)) 393 | self.assertAlmostEqual(arc3.point(0.6), (80.9016994375 + 20.6107373854j)) 394 | self.assertAlmostEqual(arc3.point(0.7), (89.1006524188 + 27.300475013j)) 395 | self.assertAlmostEqual(arc3.point(0.8), 
(95.1056516295 + 34.5491502813j)) 396 | self.assertAlmostEqual(arc3.point(0.9), (98.7688340595 + 42.178276748j)) 397 | self.assertAlmostEqual(arc3.point(1.0), (100 + 50j)) 398 | 399 | arc4 = Arc(0j, 100 + 50j, 0, 1, 1, 100 + 50j) 400 | self.assertAlmostEqual(arc4.center, 100 + 0j) 401 | self.assertAlmostEqual(arc4.theta, 180.0) 402 | self.assertAlmostEqual(arc4.delta, 270.0) 403 | 404 | self.assertAlmostEqual(arc4.point(0.0), (0j)) 405 | self.assertAlmostEqual(arc4.point(0.1), (10.8993475812 - 22.699524987j)) 406 | self.assertAlmostEqual(arc4.point(0.2), (41.2214747708 - 40.4508497187j)) 407 | self.assertAlmostEqual(arc4.point(0.3), (84.3565534960 - 49.3844170298j)) 408 | self.assertAlmostEqual(arc4.point(0.4), (130.901699437 - 47.5528258148j)) 409 | self.assertAlmostEqual(arc4.point(0.5), (170.710678119 - 35.3553390593j)) 410 | self.assertAlmostEqual(arc4.point(0.6), (195.105651630 - 15.4508497187j)) 411 | self.assertAlmostEqual(arc4.point(0.7), (198.768834060 + 7.82172325201j)) 412 | self.assertAlmostEqual(arc4.point(0.8), (180.901699437 + 29.3892626146j)) 413 | self.assertAlmostEqual(arc4.point(0.9), (145.399049974 + 44.5503262094j)) 414 | self.assertAlmostEqual(arc4.point(1.0), (100 + 50j)) 415 | 416 | def test_length(self): 417 | # I'll test the length calculations by making a circle, in two parts. 
418 | arc1 = Arc(0j, 100 + 100j, 0, 0, 0, 200 + 0j) 419 | arc2 = Arc(200 + 0j, 100 + 100j, 0, 0, 0, 0j) 420 | self.assertAlmostEqual(arc1.length(), pi * 100) 421 | self.assertAlmostEqual(arc2.length(), pi * 100) 422 | 423 | def test_equality(self): 424 | # This is to test the __eq__ and __ne__ methods, so we can't use 425 | # assertEqual and assertNotEqual 426 | segment = Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j) 427 | self.assertTrue(segment == Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j)) 428 | self.assertTrue(segment != Arc(0j, 100 + 50j, 0, 1, 0, 100 + 50j)) 429 | 430 | 431 | class TestPath(unittest.TestCase): 432 | 433 | def test_circle(self): 434 | arc1 = Arc(0j, 100 + 100j, 0, 0, 0, 200 + 0j) 435 | arc2 = Arc(200 + 0j, 100 + 100j, 0, 0, 0, 0j) 436 | path = Path(arc1, arc2) 437 | self.assertAlmostEqual(path.point(0.0), (0j)) 438 | self.assertAlmostEqual(path.point(0.25), (100 + 100j)) 439 | self.assertAlmostEqual(path.point(0.5), (200 + 0j)) 440 | self.assertAlmostEqual(path.point(0.75), (100 - 100j)) 441 | self.assertAlmostEqual(path.point(1.0), (0j)) 442 | self.assertAlmostEqual(path.length(), pi * 200) 443 | 444 | def test_svg_specs(self): 445 | """The paths that are in the SVG specs""" 446 | 447 | # Big pie: M300,200 h-150 a150,150 0 1,0 150,-150 z 448 | path = Path(Line(300 + 200j, 150 + 200j), 449 | Arc(150 + 200j, 150 + 150j, 0, 1, 0, 300 + 50j), 450 | Line(300 + 50j, 300 + 200j)) 451 | # The points and length for this path are calculated and not regression tests. 452 | self.assertAlmostEqual(path.point(0.0), (300 + 200j)) 453 | self.assertAlmostEqual(path.point(0.14897825542), (150 + 200j)) 454 | self.assertAlmostEqual(path.point(0.5), (406.066017177 + 306.066017177j)) 455 | self.assertAlmostEqual(path.point(1 - 0.14897825542), (300 + 50j)) 456 | self.assertAlmostEqual(path.point(1.0), (300 + 200j)) 457 | # The errors seem to accumulate. Still 6 decimal places is more than good enough. 
458 | self.assertAlmostEqual(path.length(), pi * 225 + 300, places=6) 459 | 460 | # Little pie: M275,175 v-150 a150,150 0 0,0 -150,150 z 461 | path = Path(Line(275 + 175j, 275 + 25j), 462 | Arc(275 + 25j, 150 + 150j, 0, 0, 0, 125 + 175j), 463 | Line(125 + 175j, 275 + 175j)) 464 | # The points and length for this path are calculated and not regression tests. 465 | self.assertAlmostEqual(path.point(0.0), (275 + 175j)) 466 | self.assertAlmostEqual(path.point(0.2800495767557787), (275 + 25j)) 467 | self.assertAlmostEqual(path.point(0.5), (168.93398282201787 + 68.93398282201787j)) 468 | self.assertAlmostEqual(path.point(1 - 0.2800495767557787), (125 + 175j)) 469 | self.assertAlmostEqual(path.point(1.0), (275 + 175j)) 470 | # The errors seem to accumulate. Still 6 decimal places is more than good enough. 471 | self.assertAlmostEqual(path.length(), pi * 75 + 300, places=6) 472 | 473 | # Bumpy path: M600,350 l 50,-25 474 | # a25,25 -30 0,1 50,-25 l 50,-25 475 | # a25,50 -30 0,1 50,-25 l 50,-25 476 | # a25,75 -30 0,1 50,-25 l 50,-25 477 | # a25,100 -30 0,1 50,-25 l 50,-25 478 | path = Path(Line(600 + 350j, 650 + 325j), 479 | Arc(650 + 325j, 25 + 25j, -30, 0, 1, 700 + 300j), 480 | Line(700 + 300j, 750 + 275j), 481 | Arc(750 + 275j, 25 + 50j, -30, 0, 1, 800 + 250j), 482 | Line(800 + 250j, 850 + 225j), 483 | Arc(850 + 225j, 25 + 75j, -30, 0, 1, 900 + 200j), 484 | Line(900 + 200j, 950 + 175j), 485 | Arc(950 + 175j, 25 + 100j, -30, 0, 1, 1000 + 150j), 486 | Line(1000 + 150j, 1050 + 125j), 487 | ) 488 | # These are *not* calculated, but just regression tests. Be skeptical. 489 | self.assertAlmostEqual(path.point(0.0), (600 + 350j)) 490 | self.assertAlmostEqual(path.point(0.3), (755.31526434 + 217.51578768j)) 491 | self.assertAlmostEqual(path.point(0.5), (832.23324151 + 156.33454892j)) 492 | self.assertAlmostEqual(path.point(0.9), (974.00559321 + 115.26473532j)) 493 | self.assertAlmostEqual(path.point(1.0), (1050 + 125j)) 494 | # The errors seem to accumulate. 
Still 6 decimal places is more than good enough. 495 | self.assertAlmostEqual(path.length(), 860.6756221710) 496 | 497 | def test_repr(self): 498 | path = Path( 499 | Line(start=600 + 350j, end=650 + 325j), 500 | Arc(start=650 + 325j, radius=25 + 25j, rotation=-30, arc=0, sweep=1, end=700 + 300j), 501 | CubicBezier(start=700 + 300j, control1=800 + 400j, control2=750 + 200j, end=600 + 100j), 502 | QuadraticBezier(start=600 + 100j, control=600, end=600 + 300j)) 503 | self.assertEqual(eval(repr(path)), path) 504 | 505 | def test_reverse(self): 506 | # Currently you can't reverse paths. 507 | self.assertRaises(NotImplementedError, Path().reverse) 508 | 509 | def test_equality(self): 510 | # This is to test the __eq__ and __ne__ methods, so we can't use 511 | # assertEqual and assertNotEqual 512 | path1 = Path( 513 | Line(start=600 + 350j, end=650 + 325j), 514 | Arc(start=650 + 325j, radius=25 + 25j, rotation=-30, arc=0, sweep=1, end=700 + 300j), 515 | CubicBezier(start=700 + 300j, control1=800 + 400j, control2=750 + 200j, end=600 + 100j), 516 | QuadraticBezier(start=600 + 100j, control=600, end=600 + 300j)) 517 | path2 = Path( 518 | Line(start=600 + 350j, end=650 + 325j), 519 | Arc(start=650 + 325j, radius=25 + 25j, rotation=-30, arc=0, sweep=1, end=700 + 300j), 520 | CubicBezier(start=700 + 300j, control1=800 + 400j, control2=750 + 200j, end=600 + 100j), 521 | QuadraticBezier(start=600 + 100j, control=600, end=600 + 300j)) 522 | 523 | self.assertTrue(path1 == path2) 524 | # Modify path2: 525 | path2[0].start = 601 + 350j 526 | self.assertTrue(path1 != path2) 527 | 528 | # Modify back: 529 | path2[0].start = 600 + 350j 530 | self.assertFalse(path1 != path2) 531 | 532 | # Get rid of the last segment: 533 | del path2[-1] 534 | self.assertFalse(path1 == path2) 535 | 536 | # It's not equal to a list of its segments 537 | self.assertTrue(path1 != path1[:]) 538 | self.assertFalse(path1 == path1[:]) 539 | 
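The expected `point()` values in the tests above follow from the standard Bernstein-polynomial forms of line, quadratic and cubic Bezier segments, with each point stored as a complex number (real part = x, imaginary part = y). The sketch below is a minimal, dependency-free illustration and not the `svg.path` implementation: the function names are hypothetical, and the chord-summation length estimate is a swapped-in technique that may differ from how `svg.path` computes `length()`.

```python
def line_point(start, end, t):
    # Linear interpolation between two complex points.
    return start + t * (end - start)


def quadratic_point(p0, p1, p2, t):
    # Quadratic Bezier in Bernstein form: (1-t)^2 P0 + 2(1-t)t P1 + t^2 P2.
    mt = 1.0 - t
    return mt * mt * p0 + 2.0 * mt * t * p1 + t * t * p2


def cubic_point(p0, p1, p2, p3, t):
    # Cubic Bezier in Bernstein form:
    # (1-t)^3 P0 + 3(1-t)^2 t P1 + 3(1-t) t^2 P2 + t^3 P3.
    mt = 1.0 - t
    return (mt ** 3) * p0 + 3.0 * mt * mt * t * p1 + 3.0 * mt * t * t * p2 + (t ** 3) * p3


def polyline_length(point_fn, n=1000):
    # Approximate arc length by summing the chords between n + 1 evenly
    # spaced samples; by the triangle inequality this never overestimates.
    total = 0.0
    prev = point_fn(0.0)
    for i in range(1, n + 1):
        cur = point_fn(float(i) / n)
        total += abs(cur - prev)
        prev = cur
    return total


# Same curve as path1 in CubicBezierTest.test_svg_examples.
print(cubic_point(100 + 200j, 100 + 100j, 250 + 100j, 250 + 200j, 0.5))  # (175+125j)
```

The chord sum is exact (up to rounding) for segments that trace a straight line, such as the degenerate `CubicBezier(0, 0, 100j, 100j)` in `test_length`, and otherwise slightly underestimates the true length, with the error shrinking as `n` grows.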
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | import argparse 5 | import time 6 | import os 7 | import cPickle 8 | 9 | from utils import SketchLoader 10 | from model import Model 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--rnn_size', type=int, default=256, 15 | help='size of RNN hidden state') 16 | parser.add_argument('--num_layers', type=int, default=2, 17 | help='number of layers in the RNN') 18 | parser.add_argument('--model', type=str, default='lstm', 19 | help='rnn, gru, or lstm') 20 | parser.add_argument('--batch_size', type=int, default=100, 21 | help='minibatch size') 22 | parser.add_argument('--seq_length', type=int, default=300, 23 | help='RNN sequence length') 24 | parser.add_argument('--num_epochs', type=int, default=500, 25 | help='number of epochs') 26 | parser.add_argument('--save_every', type=int, default=250, 27 | help='save frequency') 28 | parser.add_argument('--grad_clip', type=float, default=5.0, 29 | help='clip gradients at this value') 30 | parser.add_argument('--learning_rate', type=float, default=0.005, 31 | help='learning rate') 32 | parser.add_argument('--decay_rate', type=float, default=0.99, 33 | help='decay rate for rmsprop') 34 | parser.add_argument('--num_mixture', type=int, default=24, 35 | help='number of gaussian mixtures') 36 | parser.add_argument('--data_scale', type=float, default=15.0, 37 | help='factor to scale raw data down by') 38 | parser.add_argument('--keep_prob', type=float, default=0.8, 39 | help='dropout keep probability') 40 | parser.add_argument('--stroke_importance_factor', type=float, default=200.0, 41 | help='relative importance of pen status over mdn coordinate accuracy') 42 | parser.add_argument('--dataset_name', type=str, default="kanji", 43 | help='name of directory containing 
training data') 44 | args = parser.parse_args() 45 | train(args) 46 | 47 | def train(args): 48 | data_loader = SketchLoader(args.batch_size, args.seq_length, args.data_scale, args.dataset_name) 49 | 50 | dirname = os.path.join('save', args.dataset_name) 51 | if not os.path.exists(dirname): 52 | os.makedirs(dirname) 53 | 54 | with open(os.path.join(dirname, 'config.pkl'), 'w') as f: 55 | cPickle.dump(args, f) 56 | 57 | model = Model(args) 58 | 59 | b_processed = 0 60 | 61 | with tf.Session() as sess: 62 | 63 | tf.initialize_all_variables().run() 64 | saver = tf.train.Saver(tf.all_variables()) 65 | 66 | # load previously trained model if applicable 67 | ckpt = tf.train.get_checkpoint_state(os.path.join('save', args.dataset_name)) 68 | if ckpt: 69 | print "loading last model: ",ckpt.model_checkpoint_path 70 | saver.restore(sess, ckpt.model_checkpoint_path) 71 | 72 | def save_model(): 73 | checkpoint_path = os.path.join('save', args.dataset_name, 'model.ckpt') 74 | saver.save(sess, checkpoint_path, global_step = b_processed) 75 | print "model saved to {}".format(checkpoint_path) 76 | 77 | for e in xrange(args.num_epochs): 78 | sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e))) 79 | data_loader.reset_index_pointer() 80 | state = model.initial_state.eval() 81 | while data_loader.epoch_finished == False: 82 | start = time.time() 83 | input_data, target_data = data_loader.next_batch() 84 | feed = {model.input_data: input_data, model.target_data: target_data, model.initial_state: state} 85 | train_loss, shape_loss, pen_loss, state, _ = sess.run([model.cost, model.cost_shape, model.cost_pen, model.final_state, model.train_op], feed) 86 | end = time.time() 87 | b_processed += 1 88 | print "{}/{} (epoch {} batch {}), cost = {:.2f} ({:.2f}+{:.4f}), time/batch = {:.2f}" \ 89 | .format(data_loader.pointer + e * data_loader.num_samples, 90 | args.num_epochs * data_loader.num_samples, 91 | e, b_processed ,train_loss, shape_loss, pen_loss, end - start) 92 | # 
assert( train_loss != np.NaN or train_loss != np.Inf) # doesn't work. 93 | assert( train_loss < 30000) # if dodgy loss, exit w/ error. 94 | if (b_processed) % args.save_every == 0 and ((b_processed) > 0): 95 | save_model() 96 | save_model() 97 | 98 | if __name__ == '__main__': 99 | main() 100 | 101 | 102 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cPickle 3 | import numpy as np 4 | import xml.etree.ElementTree as ET 5 | import random 6 | import svgwrite 7 | from IPython.display import SVG, display 8 | from svg.path import Path, Line, Arc, CubicBezier, QuadraticBezier, parse_path 9 | 10 | def calculate_start_point(data, factor=1.0, block_size = 200): 11 | # will try to center the sketch to the middle of the block 12 | # determines maxx, minx, maxy, miny 13 | sx = 0 14 | sy = 0 15 | maxx = 0 16 | minx = 0 17 | maxy = 0 18 | miny = 0 19 | for i in xrange(len(data)): 20 | sx += round(float(data[i, 0])*factor, 3) 21 | sy += round(float(data[i, 1])*factor, 3) 22 | maxx = max(maxx, sx) 23 | minx = min(minx, sx) 24 | maxy = max(maxy, sy) 25 | miny = min(miny, sy) 26 | 27 | abs_x = block_size/2-(maxx-minx)/2-minx 28 | abs_y = block_size/2-(maxy-miny)/2-miny 29 | 30 | return abs_x, abs_y, (maxx-minx), (maxy-miny) 31 | 32 | def draw_stroke_color_array(data, factor=1, svg_filename = 'sample.svg', stroke_width = 1, block_size = 200, maxcol = 5, svg_only = False, color_mode = True): 33 | 34 | num_char = len(data) 35 | 36 | if num_char < 1: 37 | return 38 | 39 | max_color_intensity = 225 40 | 41 | numrow = np.ceil(float(num_char)/float(maxcol)) 42 | dwg = svgwrite.Drawing(svg_filename, size=(block_size*(min(num_char, maxcol)), block_size*numrow)) 43 | dwg.add(dwg.rect(insert=(0, 0), size=(block_size*(min(num_char, maxcol)), block_size*numrow),fill='white')) 44 | 45 | the_color = "rgb("+str(random.randint(0, 
max_color_intensity))+","+str(int(random.randint(0, max_color_intensity)))+","+str(int(random.randint(0, max_color_intensity)))+")" 46 | 47 | for j in xrange(len(data)): 48 | 49 | lift_pen = 0 50 | #end_of_char = 0 51 | cdata = data[j] 52 | abs_x, abs_y, size_x, size_y = calculate_start_point(cdata, factor, block_size) 53 | abs_x += (j % maxcol) * block_size 54 | abs_y += (j // maxcol) * block_size 55 | 56 | for i in xrange(len(cdata)): 57 | 58 | x = round(float(cdata[i,0])*factor, 3) 59 | y = round(float(cdata[i,1])*factor, 3) 60 | 61 | prev_x = round(abs_x, 3) 62 | prev_y = round(abs_y, 3) 63 | 64 | abs_x += x 65 | abs_y += y 66 | 67 | if (lift_pen == 1): 68 | p = "M "+str(abs_x)+","+str(abs_y)+" " 69 | the_color = "rgb("+str(random.randint(0, max_color_intensity))+","+str(int(random.randint(0, max_color_intensity)))+","+str(int(random.randint(0, max_color_intensity)))+")" 70 | else: 71 | p = "M "+str(prev_x)+","+str(prev_y)+" L "+str(abs_x)+","+str(abs_y)+" " 72 | 73 | lift_pen = max(cdata[i, 2], cdata[i, 3]) # lift pen if either eos or eoc is set 74 | #end_of_char = cdata[i, 3] # not used for now. 75 | 76 | if color_mode == False: 77 | the_color = "#000" 78 | 79 | dwg.add(dwg.path(p).stroke(the_color,stroke_width).fill(the_color)) #, opacity=round(random.random()*0.5+0.5, 3) 80 | 81 | dwg.save() 82 | if svg_only == False: 83 | display(SVG(dwg.tostring())) 84 | 85 | def draw_stroke_color(data, factor=1, svg_filename = 'sample.svg', stroke_width = 1, block_size = 200, maxcol = 5, svg_only = False, color_mode = True): 86 | 87 | def split_sketch(data): 88 | # split a sketch with many eoc into an array of sketches, each with just one eoc at the end. 89 | # ignores last stub with no eoc. 
90 | counter = 0 91 | result = [] 92 | for i in xrange(len(data)): 93 | eoc = data[i, 3] 94 | if eoc > 0: 95 | result.append(data[counter:i+1]) 96 | counter = i+1 97 | #if (counter < len(data)): # ignore the rest 98 | # result.append(data[counter:]) 99 | return result 100 | 101 | data = np.array(data, dtype=np.float32) 102 | data = split_sketch(data) 103 | draw_stroke_color_array(data, factor, svg_filename, stroke_width, block_size, maxcol, svg_only, color_mode) 104 | 105 | class SketchLoader(): 106 | def __init__(self, batch_size=50, seq_length=300, scale_factor = 1.0, data_filename = "kanji"): 107 | self.data_dir = "./data" 108 | self.batch_size = batch_size 109 | self.seq_length = seq_length 110 | self.scale_factor = scale_factor # divide data by this factor 111 | 112 | data_file = os.path.join(self.data_dir, data_filename+".cpkl") 113 | raw_data_dir = os.path.join(self.data_dir, data_filename) 114 | 115 | if not (os.path.exists(data_file)) : 116 | print "creating training data cpkl file from raw source" 117 | self.length_data = self.preprocess(raw_data_dir, data_file) 118 | 119 | self.load_preprocessed(data_file) 120 | self.num_samples = len(self.raw_data) 121 | self.index = range(self.num_samples) # this list will be randomized later. 122 | self.reset_index_pointer() 123 | 124 | def preprocess(self, data_dir, data_file): 125 | # create the .cpkl data file from the raw .svg files found under data_dir. 126 | len_data = [] 127 | def cubicbezier(x0, y0, x1, y1, x2, y2, x3, y3, n=20): 128 | # from http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic 129 | pts = [] 130 | for i in range(n+1): 131 | t = float(i) / float(n) 132 | a = (1. - t)**3 133 | b = 3. * t * (1. 
- t)**2 134 | c = 3.0 * t**2 * (1.0 - t) 135 | d = t**3 136 | 137 | x = float(a * x0 + b * x1 + c * x2 + d * x3) 138 | y = float(a * y0 + b * y1 + c * y2 + d * y3) 139 | pts.append( (x, y) ) 140 | return pts 141 | 142 | def get_path_strings(svgfile): 143 | tree = ET.parse(svgfile) 144 | p = [] 145 | for elem in tree.iter(): 146 | if elem.attrib.has_key('d'): 147 | p.append(elem.attrib['d']) 148 | return p 149 | 150 | def build_lines(svgfile, line_length_threshold = 10.0, min_points_per_path = 1, max_points_per_path = 3): 151 | # beziers are sampled with about one segment per line_length_threshold units of chord length, clamped to [min_points_per_path, max_points_per_path] segments 152 | path_strings = get_path_strings(svgfile) 153 | 154 | lines = [] 155 | 156 | for path_string in path_strings: 157 | full_path = parse_path(path_string) 158 | for i in range(len(full_path)): 159 | p = full_path[i] 160 | if type(p) != Line and type(p) != CubicBezier: 161 | print "encountered an element that is not just a line or bezier " 162 | print "type: ",type(p) 163 | print p 164 | else: 165 | x_start = p.start.real 166 | y_start = p.start.imag 167 | x_end = p.end.real 168 | y_end = p.end.imag 169 | line_length = np.sqrt((x_end-x_start)*(x_end-x_start)+(y_end-y_start)*(y_end-y_start)) 170 | len_data.append(line_length) 171 | points = [] 172 | if type(p) == CubicBezier: 173 | x_con1 = p.control1.real 174 | y_con1 = p.control1.imag 175 | x_con2 = p.control2.real 176 | y_con2 = p.control2.imag 177 | n_points = int(line_length / line_length_threshold)+1 178 | n_points = max(n_points, min_points_per_path) 179 | n_points = min(n_points, max_points_per_path) 180 | points = cubicbezier(x_start, y_start, x_con1, y_con1, x_con2, y_con2, x_end, y_end, n_points) 181 | else: 182 | points = [(x_start, y_start), (x_end, y_end)] 183 | if i == 0: # only append the starting point of the path's first segment 184 | lines.append([points[0][0], points[0][1], 0, 0]) # put eoc to be zero 185 | for j in range(1, len(points)): 186 | eos = 0 187 | if j == len(points)-1 and i == len(full_path)-1: 188 | eos = 1 189 | 
lines.append([points[j][0], points[j][1], eos, 0]) # put eoc to be zero 190 | lines = np.array(lines, dtype=np.float32) 191 | # make it relative moves (subtract into a fresh array: in-place -= on overlapping slices was undefined before NumPy 1.13) 192 | lines[1:,0:2] = lines[1:,0:2] - lines[0:-1,0:2] 193 | lines[-1,3] = 1 # end of character 194 | lines[0] = [0, 0, 0, 0] # start at origin 195 | return lines[1:] 196 | 197 | # build the list of files under data_dir (only .svg files are processed below) 198 | filelist = [] 199 | # Set the directory you want to start from 200 | rootDir = data_dir 201 | for dirName, subdirList, fileList in os.walk(rootDir): 202 | #print('Found directory: %s' % dirName) 203 | for fname in fileList: 204 | #print('\t%s' % fname) 205 | filelist.append(dirName+"/"+fname) 206 | 207 | # build the stroke database from every .svg file found under data_dir 208 | sketch = [] 209 | for i in range(len(filelist)): 210 | if (filelist[i][-3:] == 'svg'): 211 | print 'processing '+filelist[i] 212 | sketch.append(build_lines(filelist[i])) 213 | 214 | f = open(data_file,"wb") 215 | cPickle.dump(sketch, f, protocol=2) 216 | f.close() 217 | return len_data 218 | 219 | def load_preprocessed(self, data_file): 220 | f = open(data_file,"rb") 221 | self.raw_data = cPickle.load(f) 222 | # scale the data here, rather than at the data construction (since scaling may change) 223 | for data in self.raw_data: 224 | data[:,0:2] /= self.scale_factor 225 | f.close() 226 | 227 | def next_batch(self): 228 | # returns one batch of (input, target) sequences, where target is input shifted by one step. Each input sequence 229 | # starts at the beginning of a new character (although the end of a sequence doesn't have to be the end of a character) 230 | 231 | def next_seq(n): 232 | result = np.zeros((n, 5), dtype=np.float32) # x, y, [eos, eoc, cont] tokens 233 | #result[0, 2:4] = 1 # set eos and eoc to true for first point 234 | # experimental line below, put a random factor between 70-130% to generate more examples 235 | rand_scale_factor_x = np.random.rand()*0.6+0.7 236 | rand_scale_factor_y = np.random.rand()*0.6+0.7 237 | idx = 0 238 | data = self.current_data() 239 | for i in 
xrange(n): 240 | result[i, 0:4] = data[idx] # eoc = 0.0 241 | result[i, 4] = 1 # continue on stroke 242 | if (result[i, 2] > 0 or result[i, 3] > 0): 243 | result[i, 4] = 0 244 | idx += 1 245 | if (idx >= len(data)-1): # skip to next sketch example next time and mark eoc 246 | result[i, 4] = 0 247 | result[i, 3] = 1 248 | result[i, 2] = 0 # overrides end of stroke one-hot 249 | idx = 0 250 | self.tick_index_pointer() 251 | data = self.current_data() 252 | assert(result[i, 2:5].sum() == 1) 253 | self.tick_index_pointer() # needed if seq_length is less than last data. 254 | result[:, 0] *= rand_scale_factor_x 255 | result[:, 1] *= rand_scale_factor_y 256 | return result 257 | 258 | skip_length = self.seq_length+1 259 | 260 | batch = [] 261 | 262 | for i in xrange(self.batch_size): 263 | seq = next_seq(skip_length) 264 | batch.append(seq) 265 | 266 | batch = np.array(batch, dtype=np.float32) 267 | 268 | return batch[:,0:-1], batch[:, 1:] 269 | 270 | def current_data(self): 271 | return self.raw_data[self.index[self.pointer]] 272 | 273 | def tick_index_pointer(self): 274 | self.pointer += 1 275 | if (self.pointer >= len(self.raw_data)): 276 | self.pointer = 0 277 | self.epoch_finished = True 278 | 279 | def reset_index_pointer(self): 280 | # randomize order for the raw list in the next go. 281 | self.pointer = 0 282 | self.epoch_finished = False 283 | self.index = np.random.permutation(self.index) 284 | 285 | 286 | --------------------------------------------------------------------------------
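`build_lines` stores each sketch as relative `(dx, dy)` moves with end-of-stroke and end-of-character flags, and `next_seq` expands them into `[dx, dy, eos, eoc, cont]` rows with exactly one pen-state flag set. The sketch below is a simplified, hypothetical round trip through that 5-column layout (`strokes_to_deltas` and `deltas_to_strokes` are illustrative names, not functions in `utils.py`). Unlike `build_lines`, it keeps the first point as an offset from the origin instead of dropping it, and it assumes every stroke is non-empty. Decoding mirrors how `draw_stroke_color_array` accumulates offsets and lifts the pen after an `eos` or `eoc` row.

```python
import numpy as np


def strokes_to_deltas(strokes):
    # strokes: list of non-empty strokes, each a list of absolute (x, y) points.
    # Returns an (N, 5) float32 array of [dx, dy, eos, eoc, cont] rows; exactly
    # one of the last three columns is 1 in every row.
    rows = []
    for stroke in strokes:
        for x, y in stroke:
            rows.append([x, y, 0.0, 0.0, 1.0])  # pen stays down: cont = 1
        rows[-1][2:5] = [1.0, 0.0, 0.0]  # last point of a stroke: eos = 1
    rows[-1][2:5] = [0.0, 1.0, 0.0]  # last point of the sketch: eoc replaces eos
    data = np.array(rows, dtype=np.float32)
    deltas = data.copy()
    # Relative moves; the right-hand side is a fresh array, so no slice aliasing.
    deltas[1:, 0:2] = data[1:, 0:2] - data[:-1, 0:2]
    return deltas


def deltas_to_strokes(deltas):
    # Accumulate offsets back into absolute positions and start a new stroke
    # after every row whose eos or eoc flag is set.
    absolute = np.cumsum(deltas[:, 0:2], axis=0)
    strokes, current = [], []
    for (x, y), eos, eoc in zip(absolute, deltas[:, 2], deltas[:, 3]):
        current.append((float(x), float(y)))
        if eos > 0 or eoc > 0:
            strokes.append(current)
            current = []
    if current:  # trailing points with no closing flag
        strokes.append(current)
    return strokes


d = strokes_to_deltas([[(0, 0), (10, 0), (10, 10)], [(20, 5), (25, 5)]])
print(d[3].tolist())  # [10.0, -5.0, 0.0, 0.0, 1.0]
```

Keeping the pen state one-hot is what lets a single softmax over `[eos, eoc, cont]` sit alongside the mixture-density output for `(dx, dy)`; the assertion in `next_seq` enforces the same invariant.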