├── svg
│   ├── path
│   │   ├── tests
│   │   │   ├── __init__.py
│   │   │   ├── test_doc.py
│   │   │   ├── test_generation.py
│   │   │   ├── test_parsing.py
│   │   │   └── test_paths.py
│   │   ├── path.pyc
│   │   ├── parser.pyc
│   │   ├── __init__.pyc
│   │   ├── __init__.py
│   │   ├── parser.py
│   │   └── path.py
│   ├── __init__.py
│   └── __init__.pyc
├── model.pyc
├── utils.pyc
├── magenta_rnn.pyc
├── svg_utils.pyc
├── tf_data_work.pyc
├── images
│   ├── example.png
│   ├── highlight.png
│   └── architecture.png
├── README.md
├── prepare_data.py
├── sketchrnn_cnn_dual_test.py
├── sketchrnn_cnn_dual_train.py
├── tf_data_work.py
├── svg_utils.py
├── magenta_rnn.py
└── model.py
/svg/path/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/model.pyc -------------------------------------------------------------------------------- /utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/utils.pyc -------------------------------------------------------------------------------- /svg/__init__.py: -------------------------------------------------------------------------------- 1 | __import__('pkg_resources').declare_namespace(__name__) 2 | -------------------------------------------------------------------------------- /magenta_rnn.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/magenta_rnn.pyc -------------------------------------------------------------------------------- /svg_utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/svg_utils.pyc -------------------------------------------------------------------------------- /svg/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/svg/__init__.pyc -------------------------------------------------------------------------------- /svg/path/path.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/svg/path/path.pyc -------------------------------------------------------------------------------- /tf_data_work.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/tf_data_work.pyc -------------------------------------------------------------------------------- /images/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/images/example.png -------------------------------------------------------------------------------- /images/highlight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/images/highlight.png -------------------------------------------------------------------------------- /svg/path/parser.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/svg/path/parser.pyc
-------------------------------------------------------------------------------- /svg/path/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/svg/path/__init__.pyc -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/images/architecture.png -------------------------------------------------------------------------------- /svg/path/__init__.py: -------------------------------------------------------------------------------- 1 | from .path import Path, Line, Arc, CubicBezier, QuadraticBezier 2 | from .parser import parse_path 3 | -------------------------------------------------------------------------------- /svg/path/tests/test_doc.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import doctest 3 | 4 | 5 | def load_tests(loader, tests, ignore): 6 | tests.addTests(doctest.DocFileSuite('README.rst', package='__main__')) 7 | return tests 8 | -------------------------------------------------------------------------------- /svg/path/tests/test_generation.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import unittest 3 | from ..path import CubicBezier, QuadraticBezier, Line, Arc, Path 4 | from ..parser import parse_path 5 | 6 | 7 | class TestGeneration(unittest.TestCase): 8 | 9 | def test_svg_examples(self): 10 | """Examples from the SVG spec""" 11 | paths = [ 12 | 'M 100,100 L 300,100 L 200,300 Z', 13 | 'M 0,0 L 50,20 M 100,100 L 300,100 L 200,300 Z', 14 | 'M 100,100 L 200,200', 15 | 'M 100,200 L 200,100 L -100,-200', 16 | 'M 100,200 C 100,100 250,100 250,200 S 400,300 400,200', 17 | 'M 100,200 C 100,100 400,100 400,200', 18 | 'M 100,500 C 25,400 475,400 400,500', 19 | 'M 100,800 C 175,700 325,700 400,800', 20 | 'M 600,200 C 675,100 975,100 900,200', 21 | 'M 600,500 C 600,350 900,650 900,500', 22 | 'M 600,800 C 625,700 725,700 750,800 S 875,900 900,800', 23 | 'M 200,300 Q 400,50 600,300 T 1000,300', 24 | 'M -3.4E+38,3.4E+38 L -3.4E-38,3.4E-38', 25 | 'M 0,0 L 50,20 M 50,20 L 200,100 Z', 26 | 'M 600,350 L 650,325 A 25,25 -30 0,1 700,300 L 750,275', 27 | ] 28 | 29 | for path in paths: 30 | self.assertEqual(parse_path(path).d(), path) 31 | 32 | def test_normalizing(self): 33 | # Relative paths will be made absolute, subpaths merged if they can, 34 | # and syntax will change. 35 | self.assertEqual(parse_path('M0 0L3.4E2-10L100.0,100M100,100l100,-100').d(), 36 | 'M 0,0 L 340,-10 L 100,100 L 200,0') 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Photo-to-Sketch Synthesis Model 2 | 3 | Before jumping into our code implementation, which is based on [TensorFlow](https://github.com/tensorflow/tensorflow), please refer to our paper [Learning to Sketch with Shortcut Cycle Consistency](https://arxiv.org/abs/1805.00247) for the basic idea. 4 | 5 | 6 | # Overview 7 | 8 | In our paper, we present a novel approach for translating an object photo into a sketch, mimicking the human sketching process. Teaching a machine to generate a sketch from a photo just like humans do is not easy.
This requires not only developing an abstract concept of a visual object instance, but also knowing what, where and when to sketch the next line stroke. The figure below shows how the developed photo-to-sketch synthesizer takes a photo as input and mimics the human sketching process by sequentially drawing one stroke at a time. The resulting synthesized sketches provide an abstract and semantically meaningful depiction of the given object, just like human sketches do. 9 | 10 |
![Photo-to-sketch synthesis examples](images/highlight.png)
11 | 12 | 13 | 14 | *Examples of our model mimicking the human sketching process, stroke by stroke.* 15 | 16 | # Model Structure 17 | 18 | We aim to learn a mapping function between the photo domain *X* and the sketch domain *Y*, where we denote the empirical data distributions as *x ~ pdata(x)* and *y ~ pdata(y)*, and represent each vector sketch segment as (*sxi*, *syi*), a two-dimensional offset vector. Our model includes four mapping functions, learned by four subnets: a photo encoder, a sketch encoder, a photo decoder, and a sketch decoder. Our model architecture is illustrated below. 19 | 20 |
![Model architecture](images/architecture.png)
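For intuition, the stroke sequences that the sketch encoder and decoder operate on follow the sketch-rnn convention: each step carries the two-dimensional offset vector (*sxi*, *syi*) plus a pen state. Below is a minimal, self-contained sketch of that encoding; the numbers are made up purely for illustration, and the actual conversion in this repo is done by `utils.to_normal_strokes` in the training/testing scripts.

```python
import numpy as np

# Stroke-3 format: each row is (dx, dy, pen_lifted).
# (dx, dy) is the two-dimensional offset vector (s_x, s_y) from the paper;
# pen_lifted = 1 marks the end of the current stroke.
strokes = np.array([
    [10.0,  0.0, 0],   # pen moves right while drawing
    [ 0.0,  8.0, 0],   # pen moves down
    [-5.0, -3.0, 1],   # last point of this stroke: the pen lifts
    [ 2.0,  2.0, 0],   # a new stroke begins
], dtype=np.float32)

# Absolute pen positions are recovered by cumulatively summing the offsets.
absolute_xy = np.cumsum(strokes[:, :2], axis=0)
print(absolute_xy)
```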
21 | 22 | # Training a Model 23 | 24 | Our deep photo-to-sketch (p2s) synthesis model is trained on the ShoeV2 and ChairV2 datasets. 25 | 26 | Usage: 27 | 28 | ```bash 29 | python sketchrnn_cnn_dual_train.py --dataset shoesv2 30 | ``` 31 | 32 | As mentioned in the paper, before training a photo-to-sketch (p2s) synthesis model, you need to pretrain your model on [QuickDraw](https://github.com/googlecreativelab/quickdraw-dataset) data from the corresponding categories. 33 | 34 | We have tested this model on TensorFlow 1.4 with Python 2.7. 35 | 36 | # Result 37 | 38 | Example: 39 | 40 |
![Photo-to-sketch synthesis result](images/example.png)
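Samples like the one above come from the test script, which is invoked analogously to the training command. Note that, per the flag defaults in *sketchrnn_cnn_dual_test.py*, it expects a trained checkpoint under `models/model_to_test/<dataset>/<basenet>`:

```bash
python sketchrnn_cnn_dual_test.py --dataset shoesv2
```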
41 | 42 | # Datasets 43 | 44 | The datasets for our photo-to-sketch synthesis task are the *ShoeV2* and *ChairV2* datasets, which can be downloaded from the homepage of our group, [SketchX](http://sketchx.eecs.qmul.ac.uk/downloads/). 45 | 46 | The pretraining dataset can be downloaded from [QuickDraw](https://github.com/googlecreativelab/quickdraw-dataset). 47 | 48 | The original data can be converted to HDF5 format using *prepare_data.py*, or you can download the converted data from [Google Drive](https://drive.google.com/open?id=1029l8QZ9EWzQEb9GDylVM3EP4_kurW1E). 49 | 50 | # Citation 51 | 52 | If you find this project useful for academic purposes, please cite it as: 53 | 54 | ``` 55 | @inproceedings{song2018learning, 56 | title = {Learning to Sketch with Shortcut Cycle Consistency}, 57 | author = {Song, Jifei and Pang, Kaiyue and Song, Yi-Zhe and Xiang, Tao and Hospedales, Timothy M}, 58 | booktitle = {CVPR}, 59 | year = {2018} 60 | } 61 | ``` 62 | -------------------------------------------------------------------------------- /svg/path/tests/test_parsing.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import unittest 3 | from ..path import CubicBezier, QuadraticBezier, Line, Arc, Path 4 | from ..parser import parse_path 5 | 6 | 7 | class TestParser(unittest.TestCase): 8 | 9 | def test_svg_examples(self): 10 | """Examples from the SVG spec""" 11 | path1 = parse_path('M 100 100 L 300 100 L 200 300 z') 12 | self.assertEqual(path1, Path(Line(100 + 100j, 300 + 100j), 13 | Line(300 + 100j, 200 + 300j), 14 | Line(200 + 300j, 100 + 100j))) 15 | self.assertTrue(path1.closed) 16 | 17 | # for Z command behavior when there are multiple subpaths 18 | path1 = parse_path('M 0 0 L 50 20 M 100 100 L 300 100 L 200 300 z') 19 | self.assertEqual(path1, Path( 20 | Line(0 + 0j, 50 + 20j), 21 | Line(100 + 100j, 300 + 100j), 22 | Line(300 + 100j, 200 + 300j), 23 | Line(200 + 300j, 100 + 100j))) 24 | 25 | path1 = parse_path('M 100 100 L 200 200') 26 | path2 = parse_path('M100 100L200 200') 27 | self.assertEqual(path1, path2) 28 | 29 | path1 = parse_path('M 100 200 L 200 100 L -100 -200') 30 | path2 = parse_path('M 100 200 L 200 100 -100 -200') 31 | self.assertEqual(path1, path2) 32 | 33 | path1 = parse_path("""M100,200 C100,100 250,100 250,200 34 | S400,300 400,200""") 35 | self.assertEqual(path1, 36 | Path(CubicBezier(100 + 200j, 100 + 100j, 250 + 100j, 250 + 200j), 37 | CubicBezier(250 + 200j, 250 + 300j, 400 + 300j, 400 + 200j))) 38 | 39 | path1 = parse_path('M100,200 C100,100 400,100 400,200') 40 | self.assertEqual(path1, 41 | Path(CubicBezier(100 + 200j, 100 + 100j, 400 + 100j, 400 + 200j))) 42 | 43 | path1 = parse_path('M100,500 C25,400 475,400 400,500') 44 | self.assertEqual(path1, 45 | Path(CubicBezier(100 + 500j, 25 + 400j, 475 + 400j, 400 + 500j))) 46 | 47 | path1 = parse_path('M100,800 C175,700 325,700 400,800') 48 | self.assertEqual(path1, 49 | Path(CubicBezier(100 + 800j, 175 + 700j, 325 + 700j, 400 + 800j))) 50 | 51 | path1 = parse_path('M600,200 C675,100 975,100 900,200') 52 | self.assertEqual(path1, 53 | Path(CubicBezier(600 + 200j, 675 + 100j, 975 + 100j, 900 + 200j))) 54 | 55 | path1 = parse_path('M600,500 C600,350 900,650 900,500') 56 | self.assertEqual(path1, 57 | Path(CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j))) 58 | 59 | path1 = parse_path("""M600,800 C625,700 725,700 750,800 60 | S875,900 900,800""") 61 | self.assertEqual(path1, 62 | Path(CubicBezier(600 + 800j, 625 + 700j, 725 + 700j, 750 + 800j), 63 |
CubicBezier(750 + 800j, 775 + 900j, 875 + 900j, 900 + 800j))) 64 | 65 | path1 = parse_path('M200,300 Q400,50 600,300 T1000,300') 66 | self.assertEqual(path1, 67 | Path(QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j), 68 | QuadraticBezier(600 + 300j, 800 + 550j, 1000 + 300j))) 69 | 70 | path1 = parse_path('M300,200 h-150 a150,150 0 1,0 150,-150 z') 71 | self.assertEqual(path1, 72 | Path(Line(300 + 200j, 150 + 200j), 73 | Arc(150 + 200j, 150 + 150j, 0, 1, 0, 300 + 50j), 74 | Line(300 + 50j, 300 + 200j))) 75 | 76 | path1 = parse_path('M275,175 v-150 a150,150 0 0,0 -150,150 z') 77 | self.assertEqual(path1, 78 | Path(Line(275 + 175j, 275 + 25j), 79 | Arc(275 + 25j, 150 + 150j, 0, 0, 0, 125 + 175j), 80 | Line(125 + 175j, 275 + 175j))) 81 | 82 | path1 = parse_path("""M600,350 l 50,-25 83 | a25,25 -30 0,1 50,-25 l 50,-25 84 | a25,50 -30 0,1 50,-25 l 50,-25 85 | a25,75 -30 0,1 50,-25 l 50,-25 86 | a25,100 -30 0,1 50,-25 l 50,-25""") 87 | self.assertEqual(path1, 88 | Path(Line(600 + 350j, 650 + 325j), 89 | Arc(650 + 325j, 25 + 25j, -30, 0, 1, 700 + 300j), 90 | Line(700 + 300j, 750 + 275j), 91 | Arc(750 + 275j, 25 + 50j, -30, 0, 1, 800 + 250j), 92 | Line(800 + 250j, 850 + 225j), 93 | Arc(850 + 225j, 25 + 75j, -30, 0, 1, 900 + 200j), 94 | Line(900 + 200j, 950 + 175j), 95 | Arc(950 + 175j, 25 + 100j, -30, 0, 1, 1000 + 150j), 96 | Line(1000 + 150j, 1050 + 125j))) 97 | 98 | def test_others(self): 99 | # Other paths that need testing: 100 | 101 | # Relative moveto: 102 | path1 = parse_path('M 0 0 L 50 20 m 50 80 L 300 100 L 200 300 z') 103 | self.assertEqual(path1, Path( 104 | Line(0 + 0j, 50 + 20j), 105 | Line(100 + 100j, 300 + 100j), 106 | Line(300 + 100j, 200 + 300j), 107 | Line(200 + 300j, 100 + 100j))) 108 | 109 | # Initial smooth and relative CubicBezier 110 | path1 = parse_path("""M100,200 s 150,-100 150,0""") 111 | self.assertEqual(path1, 112 | Path(CubicBezier(100 + 200j, 100 + 200j, 250 + 100j, 250 + 200j))) 113 | 114 | # Initial smooth and relative QuadraticBezier 115 | path1 = parse_path("""M100,200 t 150,0""") 116 | self.assertEqual(path1, 117 | Path(QuadraticBezier(100 + 200j, 100 + 200j, 250 + 200j))) 118 | 119 | # Relative QuadraticBezier 120 | path1 = parse_path("""M100,200 q 0,0 150,0""") 121 | self.assertEqual(path1, 122 | Path(QuadraticBezier(100 + 200j, 100 + 200j, 250 + 200j))) 123 | 124 | def test_negative(self): 125 | """You don't need spaces before a minus-sign""" 126 | path1 = parse_path('M100,200c10-5,20-10,30-20') 127 | path2 = parse_path('M 100 200 c 10 -5 20 -10 30 -20') 128 | self.assertEqual(path1, path2) 129 | 130 | def test_numbers(self): 131 | """Exponents and other number format cases""" 132 | # It can be e or E, the plus is optional, and a minimum of +/-3.4e38 must be supported. 133 | path1 = parse_path('M-3.4e38 3.4E+38L-3.4E-38,3.4e-38') 134 | path2 = Path(Line(-3.4e+38 + 3.4e+38j, -3.4e-38 + 3.4e-38j)) 135 | self.assertEqual(path1, path2) 136 | 137 | def test_errors(self): 138 | self.assertRaises(ValueError, parse_path, 'M 100 100 L 200 200 Z 100 200') 139 | -------------------------------------------------------------------------------- /svg/path/parser.py: -------------------------------------------------------------------------------- 1 | # SVG Path specification parser 2 | 3 | import re 4 | from . 
import path 5 | 6 | COMMANDS = set('MmZzLlHhVvCcSsQqTtAa') 7 | UPPERCASE = set('MZLHVCSQTA') 8 | 9 | COMMAND_RE = re.compile("([MmZzLlHhVvCcSsQqTtAa])") 10 | FLOAT_RE = re.compile("[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?") 11 | 12 | 13 | def _tokenize_path(pathdef): 14 | for x in COMMAND_RE.split(pathdef): 15 | if x in COMMANDS: 16 | yield x 17 | for token in FLOAT_RE.findall(x): 18 | yield token 19 | 20 | 21 | def parse_path(pathdef, current_pos=0j): 22 | # In the SVG specs, initial movetos are absolute, even if 23 | # specified as 'm'. This is the default behavior here as well. 24 | # But if you pass in a current_pos variable, the initial moveto 25 | # will be relative to that current_pos. This is useful. 26 | elements = list(_tokenize_path(pathdef)) 27 | # Reverse for easy use of .pop() 28 | elements.reverse() 29 | 30 | segments = path.Path() 31 | start_pos = None 32 | command = None 33 | 34 | while elements: 35 | 36 | if elements[-1] in COMMANDS: 37 | # New command. 38 | last_command = command # Used by S and T 39 | command = elements.pop() 40 | absolute = command in UPPERCASE 41 | command = command.upper() 42 | else: 43 | # If this element starts with numbers, it is an implicit command 44 | # and we don't change the command. Check that it's allowed: 45 | if command is None: 46 | raise ValueError("Unallowed implicit command in %s, position %s" % ( 47 | pathdef, len(pathdef.split()) - len(elements))) 48 | 49 | if command == 'M': 50 | # Moveto command. 51 | x = elements.pop() 52 | y = elements.pop() 53 | pos = float(x) + float(y) * 1j 54 | if absolute: 55 | current_pos = pos 56 | else: 57 | current_pos += pos 58 | 59 | # when M is called, reset start_pos 60 | # This behavior of Z is defined in svg spec: 61 | # http://www.w3.org/TR/SVG/paths.html#PathDataClosePathCommand 62 | start_pos = current_pos 63 | 64 | # Implicit moveto commands are treated as lineto commands. 65 | # So we set command to lineto here, in case there are 66 | # further implicit commands after this moveto. 67 | command = 'L' 68 | 69 | elif command == 'Z': 70 | # Close path 71 | segments.append(path.Line(current_pos, start_pos)) 72 | segments.closed = True 73 | current_pos = start_pos 74 | start_pos = None 75 | command = None # You can't have implicit commands after closing. 76 | 77 | elif command == 'L': 78 | x = elements.pop() 79 | y = elements.pop() 80 | pos = float(x) + float(y) * 1j 81 | if not absolute: 82 | pos += current_pos 83 | segments.append(path.Line(current_pos, pos)) 84 | current_pos = pos 85 | 86 | elif command == 'H': 87 | x = elements.pop() 88 | pos = float(x) + current_pos.imag * 1j 89 | if not absolute: 90 | pos += current_pos.real 91 | segments.append(path.Line(current_pos, pos)) 92 | current_pos = pos 93 | 94 | elif command == 'V': 95 | y = elements.pop() 96 | pos = current_pos.real + float(y) * 1j 97 | if not absolute: 98 | pos += current_pos.imag * 1j 99 | segments.append(path.Line(current_pos, pos)) 100 | current_pos = pos 101 | 102 | elif command == 'C': 103 | control1 = float(elements.pop()) + float(elements.pop()) * 1j 104 | control2 = float(elements.pop()) + float(elements.pop()) * 1j 105 | end = float(elements.pop()) + float(elements.pop()) * 1j 106 | 107 | if not absolute: 108 | control1 += current_pos 109 | control2 += current_pos 110 | end += current_pos 111 | 112 | segments.append(path.CubicBezier(current_pos, control1, control2, end)) 113 | current_pos = end 114 | 115 | elif command == 'S': 116 | # Smooth curve. 
First control point is the "reflection" of 117 | # the second control point in the previous path. 118 | 119 | if last_command not in 'CS': 120 | # If there is no previous command or if the previous command 121 | # was not an C, c, S or s, assume the first control point is 122 | # coincident with the current point. 123 | control1 = current_pos 124 | else: 125 | # The first control point is assumed to be the reflection of 126 | # the second control point on the previous command relative 127 | # to the current point. 128 | control1 = current_pos + current_pos - segments[-1].control2 129 | 130 | control2 = float(elements.pop()) + float(elements.pop()) * 1j 131 | end = float(elements.pop()) + float(elements.pop()) * 1j 132 | 133 | if not absolute: 134 | control2 += current_pos 135 | end += current_pos 136 | 137 | segments.append(path.CubicBezier(current_pos, control1, control2, end)) 138 | current_pos = end 139 | 140 | elif command == 'Q': 141 | control = float(elements.pop()) + float(elements.pop()) * 1j 142 | end = float(elements.pop()) + float(elements.pop()) * 1j 143 | 144 | if not absolute: 145 | control += current_pos 146 | end += current_pos 147 | 148 | segments.append(path.QuadraticBezier(current_pos, control, end)) 149 | current_pos = end 150 | 151 | elif command == 'T': 152 | # Smooth curve. Control point is the "reflection" of 153 | # the second control point in the previous path. 154 | 155 | if last_command not in 'QT': 156 | # If there is no previous command or if the previous command 157 | # was not an Q, q, T or t, assume the first control point is 158 | # coincident with the current point. 159 | control = current_pos 160 | else: 161 | # The control point is assumed to be the reflection of 162 | # the control point on the previous command relative 163 | # to the current point. 
164 | control = current_pos + current_pos - segments[-1].control 165 | 166 | end = float(elements.pop()) + float(elements.pop()) * 1j 167 | 168 | if not absolute: 169 | end += current_pos 170 | 171 | segments.append(path.QuadraticBezier(current_pos, control, end)) 172 | current_pos = end 173 | 174 | elif command == 'A': 175 | radius = float(elements.pop()) + float(elements.pop()) * 1j 176 | rotation = float(elements.pop()) 177 | arc = float(elements.pop()) 178 | sweep = float(elements.pop()) 179 | end = float(elements.pop()) + float(elements.pop()) * 1j 180 | 181 | if not absolute: 182 | end += current_pos 183 | 184 | segments.append(path.Arc(current_pos, radius, rotation, arc, sweep, end)) 185 | current_pos = end 186 | 187 | return segments 188 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.io as sio 4 | import glob 5 | import cv2 6 | import h5py 7 | import json 8 | import sys 9 | import shutil 10 | import svg_utils 11 | from svgpathtools import svg2paths, wsvg 12 | from svgpathtools import svg2paths2 13 | 14 | 15 | def save_hdf5(fname, d): 16 | hf = h5py.File(fname, 'w') 17 | for key in d.keys(): 18 | value = d[key] 19 | if type(value) is list: 20 | value = np.array(value) 21 | dtype = value.dtype.name 22 | if 'string' in dtype: 23 | dtype = value.dtype.str.split('|')[1] 24 | value = [v.encode("ascii", "ignore") for v in value] 25 | hf.create_dataset(key, (len(value),1), dtype, value) 26 | else: 27 | hf.create_dataset(key, 28 | dtype=value.dtype.name, 29 | data=value) 30 | hf.close() 31 | return fname 32 | 33 | 34 | def load_hdf5(fname): 35 | hf = h5py.File(fname, 'r') 36 | d = {key: np.array(hf.get(key)) for key in hf.keys()} 37 | hf.close() 38 | return d 39 | 40 | 41 | def read_info(dataset_folder): 42 | subset_info = { 43 | 'image_nums': [], 44 | 'class_names': [], 45 | 'num_classes': 0 46 | } 47 | data_info = { 48 | 'id': [], 49 | 'class_name': [], 50 | 'class_id': [], 51 | 'image_name': [], 52 | 'image_id': [], 53 | 'instance_id': [], 54 | 'image_data': [] 55 | } 56 | class_name = dataset_folder.split('/')[-3] 57 | data_type = dataset_folder.split('/')[-2] 58 | # class_id_dict = {'shoes': 0, 'chairs': 1} 59 | class_id_dict = {'shoes': 0, 'chairs': 0} 60 | id_in_list = 0 61 | image_id_offset = 0 62 | class_id = 0 63 | 64 | image_files = os.walk(dataset_folder).next()[2] 65 | # sort image files 66 | image_files.sort() 67 | image_base_names = [] 68 | unique_image_base_names = [] 69 | instance_ids = [] 70 | print "read info for %s in %s" % (class_name, data_type) 71 | for image_file in image_files: 72 | if '_' not in image_file: 73 | raise Exception('Sketch file name wrong') 74 | image_base_name = '_'.join(image_file.split('_')[:-1]) 75 | instance_id = image_file.split('_')[-1] 76 | instance_id = instance_id.split('.')[0] 77 | instance_ids.append(int(instance_id) - 1) 78 | 79 | image_base_names.append(image_base_name) 80 | if image_base_name not in unique_image_base_names: 81 | unique_image_base_names.append(image_base_name) 82 | 83 | # this is to avoid the ranking problem that "a56-3002_1.png < a_1.png" but "a.png < a56-3002.png" on chair dataset 84 | image_files = np.array(image_files)[np.argsort(image_base_names)].tolist() 85 | instance_ids = np.array(instance_ids)[np.argsort(image_base_names)].tolist() 86 | image_base_names = np.array(image_base_names)[np.argsort(image_base_names)].tolist() 87 |
unique_image_base_names = np.sort(unique_image_base_names).tolist() 88 | # image_base_names.sort() 89 | 90 | for idx in range(len(image_files)): 91 | image_file = image_files[idx] 92 | data_info['id'].append(id_in_list) 93 | data_info['class_name'].append(class_name) 94 | data_info['class_id'].append(class_id_dict[class_name]) 95 | data_info['image_name'].append(image_file) 96 | data_info['image_id'].append(unique_image_base_names.index(image_base_names[idx]) + image_id_offset) 97 | data_info['instance_id'].append(instance_ids[idx]) 98 | id_in_list += 1 99 | image_id_offset += len(unique_image_base_names) 100 | class_id += 1 101 | 102 | print "\n Data list reading complete" 103 | 104 | num_images = len(data_info['image_name']) 105 | print "save svg data" 106 | data_info['image_data'] = [] 107 | data_info['data_offset'] = np.zeros((num_images, 2)) 108 | start_idx = 0 109 | for idx in range(num_images): 110 | sys.stdout.write('\x1b[2K\r>> Process svg data, [%d/%d]' % (idx, num_images)) 111 | sys.stdout.flush() 112 | lines = svg_utils.build_lines(os.path.join(dataset_folder, data_info['image_name'][idx])) 113 | data_info['image_data'].extend(lines) 114 | end_idx = start_idx + len(lines) 115 | data_info['data_offset'][idx, ::] = [start_idx, end_idx] 116 | start_idx = end_idx 117 | data_info['data_offset'] = data_info['data_offset'].astype(int) 118 | 119 | if save_png: 120 | if simplify_flag: 121 | png_data_dir = dataset_folder.split('sim')[0] + 'png' 122 | else: 123 | png_data_dir = dataset_folder + '_png' 124 | print "save rgb data" 125 | data_info['png_data'] = np.zeros((num_images, 256, 256), dtype=np.uint8) 126 | for idx in range(num_images): 127 | sys.stdout.write('\x1b[2K\r>> Process png data, [%d/%d]' % (idx, num_images)) 128 | sys.stdout.flush() 129 | im = cv2.imread(os.path.join(png_data_dir, data_info['image_name'][idx]).split('.svg')[0] + '.png') 130 | im = cv2.resize(im, (256, 256)) 131 | im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) 132 | # im = cv2.imread(os.path.join(dataset_folder, test_img.split('.jpg')[0] + '.jpg')) 133 | data_info['png_data'][idx, ::] = im 134 | 135 | return subset_info, data_info 136 | 137 | 138 | def prepare_dbs(image_folder, data_type = 'h5'): 139 | 140 | if simplify_flag: 141 | simplify_str = '_sim' 142 | else: 143 | simplify_str = '' 144 | 145 | sketchy_info_train, data_info_train = read_info(os.path.join(image_folder, 'svg_train%s' % simplify_str)) 146 | sketchy_info_test, data_info_test = read_info(os.path.join(image_folder, 'svg_test%s' % simplify_str)) 147 | 148 | if save_png: 149 | simplify_str += '_png' 150 | 151 | save_hdf5(os.path.join(image_folder, 'train_svg%s.%s' % (simplify_str, data_type)), data_info_train) 152 | save_hdf5(os.path.join(image_folder, 'test_svg%s.%s' % (simplify_str, data_type)), data_info_test) 153 | 154 | 155 | def generate_db_list(image_folder): 156 | train_file_list_txt_origin = os.path.join(image_folder, 'train.txt') 157 | train_file_list_txt = os.path.join(image_folder, 'train_svg.txt') 158 | test_file_list_txt_origin = os.path.join(image_folder, 'test.txt') 159 | test_file_list_txt = os.path.join(image_folder, 'test_svg.txt') 160 | if not os.path.exists(train_file_list_txt) or not os.path.exists(test_file_list_txt): 161 | with open(train_file_list_txt_origin, 'r') as f: 162 | train_file_lists_origin = f.read().splitlines() 163 | with open(test_file_list_txt_origin, 'r') as f: 164 | test_file_lists_origin = f.read().splitlines() 165 | train_file_lists = [item.split('png')[0] + 'svg' for item in train_file_lists_origin] 166 
| test_file_lists = [item.split('png')[0] + 'svg' for item in test_file_lists_origin] 167 | with open(train_file_list_txt, 'w') as f: 168 | f.writelines("\n".join(train_file_lists)) 169 | with open(test_file_list_txt, 'w') as f: 170 | f.writelines("\n".join(test_file_lists)) 171 | 172 | 173 | def split_db(image_folder, train_list_txt = 'train_svg.txt', test_list_txt = 'test_svg.txt'): 174 | if train_list_txt: 175 | with open(os.path.join(image_folder, train_list_txt)) as f: 176 | train_list = f.read().splitlines() 177 | copy_db_files(image_folder, 'svg_all', 'svg_train', train_list) 178 | if test_list_txt: 179 | with open(os.path.join(image_folder, test_list_txt)) as f: 180 | test_list = f.read().splitlines() 181 | copy_db_files(image_folder, 'svg_all', 'svg_test', test_list) 182 | # copy_db_files(image_folder, 'all', 'test', test_list) 183 | 184 | 185 | def copy_db_files(root_folder, src_folder, dst_folder, file_list): 186 | src_path = os.path.join(root_folder, src_folder) 187 | dst_path = os.path.join(root_folder, dst_folder) 188 | print "Copy files from %s/%s to %s/%s" % (root_folder, src_folder, root_folder, dst_folder) 189 | src_base_names = [] 190 | dst_base_names = [] 191 | if not os.path.exists(dst_path): 192 | os.mkdir(dst_path) 193 | sub_dirs = os.walk(src_path).next()[1] 194 | for sub_dir in sub_dirs: 195 | os.mkdir(os.path.join(dst_path, sub_dir.replace(' ', '_').replace('-', '_').replace('(', '').replace(')', ''))) 196 | for file in file_list: 197 | sys.stdout.write('\x1b[2K\r>> Copying subfolder %s ==> %s: %d/%d' % (src_folder, dst_folder, file_list.index(file)+1, len(file_list))) 198 | sys.stdout.flush() 199 | src_file = os.path.join(src_path, file) 200 | dst_file = os.path.join(dst_path, file.replace(' ', '_').replace('-', '_').replace('(', '').replace(')', '')) 201 | base_name = file.split('_')[0] 202 | if base_name not in src_base_names: 203 | src_base_names.append(base_name) 204 | try: 205 | shutil.copy2(src_file, dst_file) 206 | if base_name not in dst_base_names: 207 | dst_base_names.append(base_name) 208 | except: 209 | print "File not exist: ", src_file 210 | print "\n Copy finished" 211 | print "Warning, none of files with below basename is copied" 212 | print [filename for filename in src_base_names if filename not in dst_base_names] 213 | 214 | 215 | 216 | if __name__ == "__main__": 217 | 218 | datasets = ['shoes', 'chairs'] 219 | simplify_flag = True 220 | 221 | save_png = True 222 | 223 | for dataset in datasets: 224 | data_dir = 'data/%s/svg' % dataset 225 | 226 | generate_db_list(data_dir) 227 | 228 | split_db(data_dir) 229 | 230 | prepare_dbs(data_dir) 231 | -------------------------------------------------------------------------------- /sketchrnn_cnn_dual_test.py: -------------------------------------------------------------------------------- 1 | """Model training.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import json 8 | import os 9 | import sys 10 | import time 11 | 12 | # internal imports 13 | 14 | import cv2 15 | import numpy as np 16 | import tensorflow as tf 17 | 18 | from model import sample, sample_recons, get_init_fn 19 | import model as sketch_rnn_model 20 | import utils 21 | 22 | FLAGS = tf.app.flags.FLAGS 23 | 24 | tf.app.flags.DEFINE_string('root_dir', './data', 'The root directory for the data') 25 | # tf.app.flags.DEFINE_string('data_dir', data_dir, 'The directory to find the dataset') 26 | # tf.app.flags.DEFINE_string('dataset', 'quickdraw', 'The 
dataset for classification') 27 | # tf.app.flags.DEFINE_string('dataset', 'shoes', 'The dataset for classification') 28 | tf.app.flags.DEFINE_string('dataset', 'shoesv2', 'The dataset for classification') 29 | # tf.app.flags.DEFINE_string('dataset', 'quickdraw_shoe', 'The dataset for classification') 30 | tf.app.flags.DEFINE_boolean('simplify_flag', True, 'use simplified dataset') 31 | tf.app.flags.DEFINE_boolean('use_vae', True, 'use vae or ae only') 32 | tf.app.flags.DEFINE_boolean('concat_z', True, 'concatenate z with x') 33 | tf.app.flags.DEFINE_string('log_root', './models/runs', 'Directory to store model checkpoints, tensorboard.') 34 | tf.app.flags.DEFINE_float('lr', 0.0001, "Learning rate.") 35 | tf.app.flags.DEFINE_float('decay_rate', 0.9999, "Learning rate decay for certain minibatches.") 36 | tf.app.flags.DEFINE_boolean('lr_decay', False, "Learning rate decay.") 37 | tf.app.flags.DEFINE_boolean('nkl', False, "if True, keep vae architecture but remove kl loss.") 38 | tf.app.flags.DEFINE_float('kl_weight_start', 0.01, "KL start weight when annealing.") 39 | tf.app.flags.DEFINE_float('kl_decay_rate', 0.99995, "KL annealing decay rate per minibatch") 40 | tf.app.flags.DEFINE_boolean('kl_weight_decay', False, "KL weight decay.") 41 | tf.app.flags.DEFINE_float('kl_tolerance', 0.2, "Level of KL loss at which to stop optimizing for KL.") 42 | tf.app.flags.DEFINE_float('l2_weight_start', 0.1, "start weight for l2.") 43 | tf.app.flags.DEFINE_float('l2_decay_rate', 0.99995, "l2 decay rate per minibatch") 44 | tf.app.flags.DEFINE_boolean('l2_weight_decay', False, "l2 weight decay.") 45 | tf.app.flags.DEFINE_integer('l2_decay_step', 5000, "l2 weight decay after how many steps.") 46 | tf.app.flags.DEFINE_integer('max_seq_len', 250, "max length of sequential data.") 47 | tf.app.flags.DEFINE_float('seq_lw', 1.0, "Loss weight for sequence reconstruction.") 48 | tf.app.flags.DEFINE_float('pix_lw', 1.0, "Loss weight for pixel reconstruction.") 49 | tf.app.flags.DEFINE_float('tri_weight', 1.0, "Triplet loss weight.") 50 | tf.app.flags.DEFINE_boolean('tune_cnn', True, 'finetune the cnn part or not') 51 | tf.app.flags.DEFINE_string('vae_type', 'p2s', 'variational autoencoder type: s2s, sketch2sketch, p2s, photo2sketch, ' 52 | 'ps2s/sp2s, photo2sketch & sketch2sketch') 53 | tf.app.flags.DEFINE_string('enc_type', 'cnn', 'type of encoder') 54 | tf.app.flags.DEFINE_string('rcons_type', 'mdn', 'type of reconstruction loss') 55 | tf.app.flags.DEFINE_integer('rd_dim', 512, 'embedding dim after mlp or other subnet') 56 | # tf.app.flags.DEFINE_boolean('reduce_dim', False, 'add fc layer before the embedding loss') 57 | tf.app.flags.DEFINE_integer('image_size', 256, 'image size for cnn') 58 | tf.app.flags.DEFINE_integer('crop_size', 224, 'crop size for cnn') 59 | tf.app.flags.DEFINE_integer('chn_size', 1, 'number of channels for cnn') 60 | tf.app.flags.DEFINE_string('basenet', 'gen_cnn', 'basenet for cnn encoder') 61 | tf.app.flags.DEFINE_string('feat_type', 'inceptionv3', 'feature type for the extracted photo feature') 62 | tf.app.flags.DEFINE_integer('feat_size', 2048, 'feature size for the extracted photo feature') 63 | tf.app.flags.DEFINE_float('margin', 0.1, 'Margin for contrastive/triplet loss') 64 | tf.app.flags.DEFINE_boolean('load_pretrain', False, 'Load pretrain model for CBB') 65 | tf.app.flags.DEFINE_boolean('resume_training', True, 'Set to true to load previous checkpoint') 66 | tf.app.flags.DEFINE_string('load_dir', '', 'Directory to load the pretrained model') 67 |
tf.app.flags.DEFINE_string('img_dir', '', 'Directory to save the images') 68 | tf.app.flags.DEFINE_string('add_str', '', 'add str to the image save directory and checkpoints') 69 | 70 | # hyperparameters 71 | tf.app.flags.DEFINE_integer('batch_size', 100, 'Number of images to process in a batch.') 72 | tf.app.flags.DEFINE_boolean('is_train', False, 'In the training stage or not') 73 | tf.app.flags.DEFINE_float('drop_kp', 1.0, 'Dropout keep rate') 74 | 75 | # data augmentation 76 | tf.app.flags.DEFINE_boolean('flip_aug', False, 'Whether to flip the sketch and photo or not') 77 | tf.app.flags.DEFINE_boolean('dist_aug', False, 'Whether to distort the images') 78 | tf.app.flags.DEFINE_boolean('hp_filter', False, 'Whether to add high pass filter') 79 | 80 | tf.app.flags.DEFINE_integer("print_every", 100, "print training loss after this many steps (default: 20)") 81 | tf.app.flags.DEFINE_integer("save_every", 1000, "Evaluate model on dev set after this many steps (default: 100)") 82 | tf.app.flags.DEFINE_boolean('debug_test', False, 'Set to true to load previous checkpoint') 83 | tf.app.flags.DEFINE_string("saved_flags", None, "Save all flags for printing") 84 | tf.app.flags.DEFINE_string('hparams', '', 85 | 'Pass in comma-separated key=value pairs such as \'save_every=40,decay_rate=0.99\'' 86 | '(no whitespace) to be read into the HParams object defined in model.py') 87 | 88 | # save settings for sampling in testing stage 89 | tf.app.flags.DEFINE_boolean('sample_sketch', True, 'Set to true to sample sketches') 90 | tf.app.flags.DEFINE_boolean('save_gt_sketch', True, 'Set to true to save ground truth sketch') 91 | tf.app.flags.DEFINE_boolean('save_photo', False, 'Set to true to save ground truth photo') 92 | tf.app.flags.DEFINE_boolean('cond_sketch', False, 'Set to true to generate sketch conditioned on sketch') 93 | tf.app.flags.DEFINE_boolean('inter_z', False, 'Interpolate latent vector of batch z') 94 | tf.app.flags.DEFINE_boolean('recon_sketch', False, 'Set to true to reconstruct sketch') 95 | tf.app.flags.DEFINE_boolean('recon_photo', False, 'Set to true to reconstruct photo') 96 | 97 | # hyperparameters inherited from sketch-rnn 98 | tf.app.flags.DEFINE_float('random_scale_factor', 0.15, 'Random scaling data augmentation proportion.') 99 | tf.app.flags.DEFINE_float('augment_stroke_prob', 0.10, 'Point dropping augmentation proportion.') 100 | tf.app.flags.DEFINE_string('rnn_model', 'lstm', 'lstm, layer_norm or hyper.') 101 | tf.app.flags.DEFINE_boolean('use_recurrent_dropout', True, 'Dropout with memory loss.') 102 | tf.app.flags.DEFINE_float('recurrent_dropout_prob', 1.0, 'Probability of recurrent dropout keep.') 103 | tf.app.flags.DEFINE_boolean('rnn_input_dropout', False, 'RNN input dropout.') 104 | tf.app.flags.DEFINE_boolean('rnn_output_dropout', False, 'RNN output dropout.') 105 | tf.app.flags.DEFINE_integer('enc_rnn_size', 256, 'Size of RNN when used as encoder') 106 | tf.app.flags.DEFINE_integer('dec_rnn_size', 512, 'Size of RNN when used as decoder') 107 | tf.app.flags.DEFINE_integer('z_size', 128, 'Size of latent vector z') 108 | tf.app.flags.DEFINE_integer('num_mixture', 20, 'Number of mixtures in the Gaussian mixture model') 109 | 110 | 111 | def reset_graph(): 112 | """Closes the current default session and resets the graph.""" 113 | sess = tf.get_default_session() 114 | if sess: 115 | sess.close() 116 | tf.reset_default_graph() 117 | 118 | 119 | def load_model(model_dir): 120 | """Loads model for inference mode, used in jupyter notebook.""" 121 | model_params =
sketch_rnn_model.get_default_hparams() 122 | with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f: 123 | model_params.parse_json(f.read()) 124 | 125 | model_params.batch_size = 1 # only sample one at a time 126 | eval_model_params = sketch_rnn_model.copy_hparams(model_params) 127 | eval_model_params.use_input_dropout = 0 128 | eval_model_params.use_recurrent_dropout = 0 129 | eval_model_params.use_output_dropout = 0 130 | eval_model_params.is_training = 0 131 | sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params) 132 | sample_model_params.max_seq_len = 1 # sample one point at a time 133 | return [model_params, eval_model_params, sample_model_params] 134 | 135 | 136 | def sampling_model_eval(sess, model, gen_model, data_set, seq_len): 137 | """Samples a sketch for every photo in the test set and saves it as SVG/PNG.""" 138 | sketch_size, photo_size = data_set.sketch_size, data_set.image_size 139 | 140 | folders_to_create = ['gen_test', 'gen_test_png', 'gt_test', 'gt_test_png', 'gt_test_photo', 'gt_test_sketch_image', 141 | 'gen_test_s', 'gen_test_s_png', 'gen_test_inter', 'gen_test_inter_png', 'gen_test_inter_sep', 142 | 'gen_test_inter_sep_png', 'gen_photo', 'gen_test_inter_with_photo', 'recon_test', 143 | 'recon_test_png', 'recon_photo'] 144 | for folder_to_create in folders_to_create: 145 | folder_path = os.path.join(FLAGS.img_dir, '%s/%s' % (data_set.dataset, folder_to_create)) 146 | if not os.path.exists(folder_path): 147 | os.mkdir(folder_path) 148 | 149 | for image_index in range(photo_size): 150 | 151 | sys.stdout.write('\x1b[2K\r>> Sampling test set, [%d/%d]' % (image_index + 1, photo_size)) 152 | sys.stdout.flush() 153 | 154 | image_feat, rnn_enc_seq_len = data_set.get_input_image(image_index) 155 | sample_strokes, m = sample(sess, model, image_feat, seq_len=seq_len, rnn_enc_seq_len=rnn_enc_seq_len) 156 | strokes = utils.to_normal_strokes(sample_strokes) 157 | svg_gen_sketch = os.path.join(FLAGS.img_dir, '%s/gen_test/gen_sketch%d.svg' % (data_set.dataset, image_index)) 158 | png_gen_sketch = os.path.join(FLAGS.img_dir, '%s/gen_test_png/gen_sketch%d.png' % (data_set.dataset, image_index)) 159 | utils.sv_svg_png_from_strokes(strokes, svg_filename=svg_gen_sketch, png_filename=png_gen_sketch) 160 | 161 | print("\nSampling finished") 162 | 163 | 164 | def load_checkpoint(sess, checkpoint_path): 165 | 166 | ckpt = tf.train.get_checkpoint_state(checkpoint_path) 167 | if ckpt is None: 168 | raise Exception('Pretrained model not found at %s' % checkpoint_path) 169 | print('Loading model %s.' % ckpt.model_checkpoint_path) 170 | init_op = get_init_fn(checkpoint_path, [''], ckpt.model_checkpoint_path) 171 | init_op(sess) 172 | 173 | 174 | def save_model(sess, model_save_path, global_step): 175 | saver = tf.train.Saver(tf.global_variables()) 176 | checkpoint_path = os.path.join(model_save_path, 'vector') 177 | print('saving model %s.' % checkpoint_path) 178 | print('global_step %i.'
% global_step) 179 | saver.save(sess, checkpoint_path, global_step=global_step) 180 | 181 | 182 | def sample_test(sess, sample_model, gen_model, test_set, max_seq_len): 183 | 184 | # set image dir 185 | # FLAGS.img_dir = FLAGS.log_root.split('runs')[0] + 'sv_imgs/%s/' % FLAGS.dataset 186 | FLAGS.img_dir = FLAGS.log_root.split('runs')[0] + 'sv_imgs/' 187 | 188 | FLAGS.img_dir += 'dual/' + FLAGS.basenet 189 | 190 | sampling_model_eval(sess, sample_model, gen_model, test_set, max_seq_len) 191 | 192 | 193 | def tester(model_params): 194 | """Test model.""" 195 | np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True) 196 | 197 | print('Hyperparams:') 198 | for key, val in model_params.values().iteritems(): 199 | print('%s = %s' % (key, str(val))) 200 | print('Loading data files.') 201 | test_set, sample_model_params, gen_model_params = utils.load_dataset(FLAGS.root_dir, FLAGS.dataset, model_params, inference_mode=True) 202 | 203 | reset_graph() 204 | sample_model = sketch_rnn_model.Model(sample_model_params) 205 | gen_model = sketch_rnn_model.Model(gen_model_params, reuse=True) 206 | 207 | sess = tf.Session() 208 | sess.run(tf.global_variables_initializer()) 209 | 210 | if FLAGS.dataset in ['shoesv2f_sup', 'shoesv2f_train']: 211 | dataset = 'shoesv2' 212 | else: 213 | dataset = FLAGS.dataset 214 | 215 | if FLAGS.resume_training: 216 | if FLAGS.load_dir == '': 217 | FLAGS.load_dir = FLAGS.log_root.split('runs')[0] + 'model_to_test/%s/' % dataset 218 | # set dir to load the model for testing 219 | FLAGS.load_dir = os.path.join(FLAGS.load_dir, FLAGS.basenet) 220 | load_checkpoint(sess, FLAGS.load_dir) 221 | 222 | # Write config file to json file. 223 | tf.gfile.MakeDirs(FLAGS.log_root) 224 | with tf.gfile.Open( 225 | os.path.join(FLAGS.log_root, 'model_config.json'), 'w') as f: 226 | json.dump(model_params.values(), f, indent=True) 227 | 228 | sample_test(sess, sample_model, gen_model, test_set, model_params.max_seq_len) 229 | 230 | 231 | def main(unused_argv): 232 | """Load model params, save config file and start the tester.""" 233 | model_params = tf.contrib.training.HParams() 234 | # merge FLAGS to hps 235 | for attr, value in sorted(FLAGS.__flags.items()): 236 | model_params.add_hparam(attr, value) 237 | 238 | tester(model_params) 239 | 240 | 241 | if __name__ == '__main__': 242 | tf.app.run(main) -------------------------------------------------------------------------------- /sketchrnn_cnn_dual_train.py: -------------------------------------------------------------------------------- 1 | """Model training.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import json 8 | import os 9 | import sys 10 | import time 11 | 12 | # internal imports 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | 17 | from model import sample, get_init_fn 18 | import model as sketch_rnn_model 19 | import utils 20 | import cv2 21 | 22 | 23 | FLAGS = tf.app.flags.FLAGS 24 | 25 | tf.app.flags.DEFINE_string('root_dir', './data', 'The root directory for the data') 26 | tf.app.flags.DEFINE_string('dataset', 'shoesv2', 'The dataset for classification') 27 | tf.app.flags.DEFINE_string('log_root', './models/runs', 'Directory to store model checkpoints, tensorboard.') 28 | tf.app.flags.DEFINE_float('lr', 0.0001, "Learning rate.") 29 | tf.app.flags.DEFINE_float('decay_rate', 0.9, "Learning rate decay for certain minibatches.") 30 | tf.app.flags.DEFINE_integer('decay_step', 5000,
"Learning rate decay after how many training steps.") 31 | tf.app.flags.DEFINE_boolean('lr_decay', False, "Learning rate decay.") 32 | tf.app.flags.DEFINE_float('grad_clip', 1.0, 'Gradient clipping') 33 | tf.app.flags.DEFINE_float('kl_weight_start', 0.01, "KL start weight when annealing.") 34 | tf.app.flags.DEFINE_float('kl_decay_rate', 0.99995, "KL annealing decay rate per minibatch") 35 | tf.app.flags.DEFINE_boolean('kl_weight_decay', False, "KL weight decay.") 36 | tf.app.flags.DEFINE_boolean('kl_decay_step', 5000, "KL weight decay after how many steps.") 37 | tf.app.flags.DEFINE_boolean('kl_tolerance', 0.2, "Level of KL loss at which to stop optimizing for KL.") 38 | tf.app.flags.DEFINE_float('l2_weight_start', 0.1, "start weight for l2.") 39 | tf.app.flags.DEFINE_float('l2_decay_rate', 0.99995, "l2 decay rate per minibatch") 40 | tf.app.flags.DEFINE_boolean('l2_weight_decay', False, "l2 weight decay.") 41 | tf.app.flags.DEFINE_boolean('l2_decay_step', 5000, "l2 weight decay after how many steps.") 42 | tf.app.flags.DEFINE_integer('max_seq_len', 250, "max length of sequential data.") 43 | tf.app.flags.DEFINE_float('seq_lw', 1.0, "Loss weight for sequence reconstruction.") 44 | tf.app.flags.DEFINE_float('pix_lw', 1.0, "Loss weight for pixel reconstruction.") 45 | tf.app.flags.DEFINE_boolean('tune_cnn', True, 'finetune the cnn part or not, this is trying to ') 46 | tf.app.flags.DEFINE_string('vae_type', 'p2s', 'variational autoencoder type: s2s, sketch2sketch, p2s, photo2sketch, ' 47 | 'ps2s/sp2s, photo2sketch & sketch2sketch') 48 | tf.app.flags.DEFINE_string('enc_type', 'cnn', 'type of encoder') 49 | tf.app.flags.DEFINE_boolean('image_size', 256, 'image size for cnn') 50 | tf.app.flags.DEFINE_boolean('crop_size', 224, 'crop size for cnn') 51 | tf.app.flags.DEFINE_boolean('chn_size', 1, 'number of channel for cnn') 52 | tf.app.flags.DEFINE_string('basenet', 'gen_cnn', 'basenet for cnn encoder') 53 | tf.app.flags.DEFINE_float('margin', 0.1, 'Margin for contrastive/triplet loss') 54 | tf.app.flags.DEFINE_boolean('load_pretrain', True, 'Load pretrain model for the p2s model') 55 | tf.app.flags.DEFINE_boolean('resume_training', False, 'Set to true to load previous checkpoint') 56 | tf.app.flags.DEFINE_string('load_dir', '', 'Directory to load the pretrained model') 57 | tf.app.flags.DEFINE_string('img_dir', '', 'Directory to save the images') 58 | 59 | # hyperparameters 60 | tf.app.flags.DEFINE_integer('batch_size', 100, 'Number of images to process in a batch.') 61 | tf.app.flags.DEFINE_boolean('is_train', True, 'In the training stage or not') 62 | tf.app.flags.DEFINE_float('drop_kp', 0.8, 'Dropout keep rate') 63 | 64 | # data augmentation 65 | tf.app.flags.DEFINE_boolean('flip_aug', False, 'Whether to flip the sketch and photo or not') 66 | tf.app.flags.DEFINE_boolean('dist_aug', False, 'Whether to distort the images') 67 | tf.app.flags.DEFINE_boolean('hp_filter', False, 'Whether to add high pass filter') 68 | 69 | # print flags 70 | tf.app.flags.DEFINE_integer("max_steps", 100000, "Max training steps") 71 | tf.app.flags.DEFINE_integer("print_every", 100, "print training loss after this many steps (default: 20)") 72 | tf.app.flags.DEFINE_integer("save_every", 1000, "Evaluate model on dev set after this many steps (default: 100)") 73 | tf.app.flags.DEFINE_boolean('debug_test', False, 'Set to true to load previous checkpoint') 74 | tf.app.flags.DEFINE_boolean('tee_log', True, 'Create log file to save the print info') 75 | tf.app.flags.DEFINE_boolean('inter_z', False, 'Interpolate latent 
vector of batch z') 76 | tf.app.flags.DEFINE_string("saved_flags", None, "Save all flags for printing") 77 | 78 | # hyperparameters inherited from sketch-rnn 79 | tf.app.flags.DEFINE_float('random_scale_factor', 0.15, 'Random scaling data augmentation proportion.') 80 | tf.app.flags.DEFINE_float('augment_stroke_prob', 0.10, 'Point dropping augmentation proportion.') 81 | tf.app.flags.DEFINE_string('rnn_model', 'lstm', 'lstm, layer_norm or hyper.') 82 | tf.app.flags.DEFINE_boolean('use_recurrent_dropout', True, 'Dropout with memory loss.') 83 | tf.app.flags.DEFINE_float('recurrent_dropout_prob', 0.90, 'Probability of recurrent dropout keep.') 84 | tf.app.flags.DEFINE_boolean('rnn_input_dropout', False, 'RNN input dropout.') 85 | tf.app.flags.DEFINE_boolean('rnn_output_dropout', False, 'RNN output dropout.') 86 | tf.app.flags.DEFINE_integer('enc_rnn_size', 256, 'Size of RNN when used as encoder') 87 | tf.app.flags.DEFINE_integer('dec_rnn_size', 512, 'Size of RNN when used as decoder') 88 | tf.app.flags.DEFINE_integer('z_size', 128, 'Size of latent vector z') 89 | tf.app.flags.DEFINE_integer('num_mixture', 20, 'Number of mixtures in the Gaussian mixture model') 90 | 91 | 92 | def reset_graph(): 93 | """Closes the current default session and resets the graph.""" 94 | sess = tf.get_default_session() 95 | if sess: 96 | sess.close() 97 | tf.reset_default_graph() 98 | 99 | 100 | def load_model(model_dir): 101 | """Loads model for inference mode, used in jupyter notebook.""" 102 | model_params = sketch_rnn_model.get_default_hparams() 103 | with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f: 104 | model_params.parse_json(f.read()) 105 | 106 | model_params.batch_size = 1 # only sample one at a time 107 | eval_model_params = sketch_rnn_model.copy_hparams(model_params) 108 | eval_model_params.use_input_dropout = 0 109 | eval_model_params.use_recurrent_dropout = 0 110 | eval_model_params.use_output_dropout = 0 111 | eval_model_params.is_training = 0 112 | sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params) 113 | sample_model_params.max_seq_len = 1 # sample one point at a time 114 | return [model_params, eval_model_params, sample_model_params] 115 | 116 | 117 | def sampling_model(sess, model, gen_model, data_set, step, seq_len, subset_str=''): 118 | """Samples a sketch for a random photo and saves the generated and ground-truth images.""" 119 | sketch_size, photo_size = data_set.sketch_size, data_set.image_size 120 | 121 | image_index = np.random.randint(0, photo_size) 122 | sketch_index = data_set.get_corr_sketch_id(image_index) 123 | gt_strokes = data_set.sketch_strokes[sketch_index] 124 | 125 | image_feat, rnn_enc_seq_len = data_set.get_input_image(image_index) 126 | sample_strokes, m = sample(sess, model, image_feat, seq_len=seq_len, rnn_enc_seq_len=rnn_enc_seq_len) 127 | strokes = utils.to_normal_strokes(sample_strokes) 128 | svg_gen_sketch = os.path.join(FLAGS.img_dir, '%s/%s/gensketch_for_photo%d_step%d.svg' % (data_set.dataset, subset_str, image_index, step)) 129 | utils.draw_strokes(strokes, svg_filename=svg_gen_sketch) 130 | svg_gt_sketch = os.path.join(FLAGS.img_dir, '%s/%s/gt_sketch%d_for_photo%d.svg' % (data_set.dataset, subset_str, sketch_index, image_index)) 131 | utils.draw_strokes(gt_strokes, svg_filename=svg_gt_sketch) 132 | input_sketch = data_set.pad_single_sketch(image_index) 133 | feed = {gen_model.input_sketch: input_sketch, gen_model.input_photo: image_feat, gen_model.sequence_lengths: [seq_len]} 134 | gen_photo = sess.run(gen_model.gen_photo, feed) 135 | gen_photo_file =
os.path.join(FLAGS.img_dir, '%s/%s/gen_photo%d_step%d.png' % (data_set.dataset, subset_str, image_index, step)) 136 | cv2.imwrite(gen_photo_file, cv2.cvtColor(gen_photo[0, ::].astype(np.uint8), cv2.COLOR_RGB2BGR)) 137 | gt_photo = os.path.join(FLAGS.img_dir, '%s/%s/gt_photo%d.png' % (data_set.dataset, subset_str, image_index)) 138 | if len(image_feat[0].shape) == 2: 139 | cv2.imwrite(gt_photo, image_feat[0]) 140 | else: 141 | cv2.imwrite(gt_photo, cv2.cvtColor(image_feat[0].astype(np.uint8), cv2.COLOR_RGB2BGR)) 142 | 143 | 144 | def load_pretrain(sess, vae_type, enc_type, dataset, basenet, log_root): 145 | if vae_type in ['ps2s', 'sp2s'] or dataset in ['shoesv2', 'chairsv2']: 146 | if 'shoe' in dataset: 147 | sv_str = 'shoe' 148 | elif 'chair' in dataset: 149 | sv_str = 'chair' 150 | pretrain_dir = log_root.split('runs')[0] + 'pretrained_model/%s/' % sv_str 151 | ckpt = tf.train.get_checkpoint_state(pretrain_dir) 152 | if ckpt is not None: 153 | pretrained_model = ckpt.model_checkpoint_path 154 | print('Loading model %s.' % pretrained_model) 155 | checkpoint_exclude_scopes = [] 156 | init_fn = get_init_fn(pretrained_model, checkpoint_exclude_scopes) 157 | init_fn(sess) 158 | else: 159 | print('Warning: pretrained model not found at %s' % pretrain_dir) 160 | 161 | 162 | def resume_train(sess, load_dir, dataset, enc_type, basenet, feat_type, log_root): 163 | if not load_dir: 164 | if 'shoe' in dataset: 165 | sv_str = 'shoe' 166 | elif 'chair' in dataset: 167 | sv_str = 'chair' 168 | load_dir = log_root.split('runs')[0] + 'save_models/%s/' % sv_str 169 | # set dir to load the model for resume training 170 | load_dir = load_dir + basenet 171 | 172 | load_checkpoint(sess, load_dir) 173 | 174 | 175 | def load_checkpoint(sess, checkpoint_path): 176 | 177 | ckpt = tf.train.get_checkpoint_state(checkpoint_path) 178 | if ckpt is None: 179 | raise Exception('Pretrained model not found at %s' % checkpoint_path) 180 | print('Loading model %s.' % ckpt.model_checkpoint_path) 181 | saver = tf.train.Saver(tf.global_variables()) 182 | saver.restore(sess, ckpt.model_checkpoint_path) 183 | 184 | 185 | def save_model(sess, saver, model_save_path, global_step): 186 | checkpoint_path = os.path.join(model_save_path, 'p2s') 187 | print('saving model %s at global_step %i.' 
% (checkpoint_path, global_step)) 188 | saver.save(sess, checkpoint_path, global_step=global_step) 189 | 190 | 191 | def train(sess, model, train_set): 192 | """Train a sketch-rnn model.""" 193 | 194 | # print log 195 | if FLAGS.tee_log: 196 | utils.config_and_print_log(FLAGS.log_root, FLAGS) 197 | 198 | # set image dir 199 | FLAGS.img_dir = FLAGS.log_root.split('runs')[0] + 'sv_imgs/' 200 | 201 | # main train loop 202 | hps = model.hps 203 | start = time.time() 204 | 205 | # create saver 206 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=5) 207 | 208 | curr_step = sess.run(model.global_step) 209 | 210 | for step_id in range(curr_step, FLAGS.max_steps + FLAGS.save_every): 211 | 212 | step = sess.run(model.global_step) 213 | 214 | if hps.vae_type in ['p2s', 's2s']: 215 | s, a, p = train_set.random_batch() 216 | feed = { 217 | model.input_sketch: a, 218 | model.input_photo: p, 219 | model.sequence_lengths: s, 220 | } 221 | else: 222 | s, a, p, n = train_set.random_batch() 223 | feed = { 224 | model.input_sketch: a, 225 | model.input_photo: p, 226 | model.input_sketch_photo: n, 227 | model.sequence_lengths: s, 228 | } 229 | 230 | (train_cost, r_cost, kl_cost, p2s_r, p2s_kl, s2p_r, s2p_kl, p2p_r, p2p_kl, s2s_r, s2s_kl, _, train_step, _) = \ 231 | sess.run([model.cost, model.r_cost, model.kl_cost, model.p2s_r, model.p2s_kl, model.s2p_r, model.s2p_kl, 232 | model.p2p_r, model.p2p_kl, model.s2s_r, model.s2s_kl, model.final_state, model.global_step, 233 | model.train_op], feed) 234 | 235 | if step % hps.print_every == 0 and step > 0: 236 | end = time.time() 237 | time_taken = end - start 238 | 239 | output_format = ('step: %d, ALL (cost: %.4f, recon: %.4f, kl: %.4f), P2S (recons: %.4f, kl: %.4f), ' 240 | 'S2P (recons: %.4f, kl: %.4f), P2P (recons: %.4f, kl: %.4f), S2S (recons: %.4f, kl: %.4f), ' 241 | 'train_time_taken: %.4f') 242 | output_values = (step, train_cost, r_cost, kl_cost, p2s_r, p2s_kl, 243 | s2p_r, s2p_kl, p2p_r, p2p_kl, s2s_r, s2s_kl, time_taken) 244 | output_log = output_format % output_values 245 | 246 | print(output_log) 247 | 248 | start = time.time() 249 | 250 | if step % hps.save_every == 0 and step > 0: 251 | 252 | save_model(sess, saver, FLAGS.log_root, step) 253 | 254 | print("Finished training stage with %d steps" % FLAGS.max_steps) 255 | 256 | 257 | def trainer(model_params): 258 | """Train a sketch-rnn model.""" 259 | np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True) 260 | 261 | print('Loading data files.') 262 | train_set, model_params = utils.load_dataset(FLAGS.root_dir, FLAGS.dataset, model_params) 263 | 264 | reset_graph() 265 | model = sketch_rnn_model.Model(model_params) 266 | 267 | sess = tf.Session() 268 | sess.run(tf.global_variables_initializer()) 269 | 270 | if FLAGS.load_pretrain: 271 | load_pretrain(sess, FLAGS.vae_type, FLAGS.enc_type, FLAGS.dataset, FLAGS.basenet, FLAGS.log_root) 272 | 273 | if FLAGS.resume_training: 274 | resume_train(sess, FLAGS.load_dir, FLAGS.dataset, FLAGS.enc_type, FLAGS.basenet, FLAGS.feat_type, FLAGS.log_root) 275 | 276 | train(sess, model, train_set) 277 | 278 | 279 | def main(unused_argv): 280 | """Load model params, save config file and start trainer.""" 281 | model_params = tf.contrib.training.HParams() 282 | # merge FLAGS to hps 283 | for attr, value in sorted(FLAGS.__flags.items()): 284 | model_params.add_hparam(attr, value) 285 | 286 | trainer(model_params) 287 | 288 | 289 | if __name__ == '__main__': 290 | tf.app.run(main) 291 | 
-------------------------------------------------------------------------------- /tf_data_work.py: --------------------------------------------------------------------------------
1 | # modified from image processing in inceptionv3
2 | # revised at 01/10/2017, by Jeffrey
3 | 
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | 
8 | import random
9 | import tensorflow as tf
10 | from tensorflow.python.ops import control_flow_ops
11 | import pdb
12 | # from model import image_processing_test
13 | 
14 | FLAGS = tf.app.flags.FLAGS
15 | 
16 | 
17 | def _crop(image, offset_height, offset_width, crop_height, crop_width):
18 |   """Crops the given image using the provided offsets and sizes.
19 |   Note that the method doesn't assume we know the input image size but it does
20 |   assume we know the input image rank.
21 |   Args:
22 |     image: an image of shape [height, width, channels].
23 |     offset_height: a scalar tensor indicating the height offset.
24 |     offset_width: a scalar tensor indicating the width offset.
25 |     crop_height: the height of the cropped image.
26 |     crop_width: the width of the cropped image.
27 |   Returns:
28 |     the cropped (and resized) image.
29 |   Raises:
30 |     InvalidArgumentError: if the rank is not 3 or if the image dimensions are
31 |       less than the crop size.
32 |   """
33 |   original_shape = tf.shape(image)
34 | 
35 |   rank_assertion = tf.Assert(
36 |       tf.equal(tf.rank(image), 3),
37 |       ['Rank of image must be equal to 3.'])
38 |   cropped_shape = control_flow_ops.with_dependencies(
39 |       [rank_assertion],
40 |       tf.pack([crop_height, crop_width, original_shape[2]]))
41 | 
42 |   size_assertion = tf.Assert(
43 |       tf.logical_and(
44 |           tf.greater_equal(original_shape[0], crop_height),
45 |           tf.greater_equal(original_shape[1], crop_width)),
46 |       ['Crop size greater than the image size.'])
47 | 
48 |   offsets = tf.to_int32(tf.pack([offset_height, offset_width, 0]))
49 | 
50 |   # Use tf.slice instead of crop_to_bounding_box as it accepts tensors to
51 |   # define the crop size.
52 |   image = control_flow_ops.with_dependencies(
53 |       [size_assertion],
54 |       tf.slice(image, offsets, cropped_shape))
55 |   return tf.reshape(image, cropped_shape)
56 | 
57 | 
58 | def _random_crop(image_list, crop_height, crop_width):
59 |   """Crops the given list of images.
60 |   The function applies the same crop to each image in the list. This can be
61 |   effectively applied when there are multiple image inputs of the same
62 |   dimension such as:
63 |     image, depths, normals = _random_crop([image, depths, normals], 120, 150)
64 |   Args:
65 |     image_list: a list of image tensors of the same dimension but possibly
66 |       varying channel.
67 |     crop_height: the new height.
68 |     crop_width: the new width.
69 |   Returns:
70 |     the image_list with cropped images.
71 |   Raises:
72 |     ValueError: if there are multiple image inputs provided with different size
73 |       or the images are smaller than the crop dimensions.
74 |   """
75 |   if not image_list:
76 |     raise ValueError('Empty image_list.')
77 | 
78 |   # Compute the rank assertions.
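  # Every tensor in image_list must be rank 3 ([height, width, channels]); the
  # assertions built below are attached via control_flow_ops.with_dependencies
  # so that a bad input fails at graph execution time instead of silently
  # producing a wrong crop.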
79 | rank_assertions = [] 80 | for i in range(len(image_list)): 81 | image_rank = tf.rank(image_list[i]) 82 | rank_assert = tf.Assert( 83 | tf.equal(image_rank, 3), 84 | ['Wrong rank for tensor %s [expected] [actual]', 85 | image_list[i].name, 3, image_rank]) 86 | rank_assertions.append(rank_assert) 87 | 88 | image_shape = control_flow_ops.with_dependencies( 89 | [rank_assertions[0]], 90 | tf.shape(image_list[0])) 91 | image_height = image_shape[0] 92 | image_width = image_shape[1] 93 | crop_size_assert = tf.Assert( 94 | tf.logical_and( 95 | tf.greater_equal(image_height, crop_height), 96 | tf.greater_equal(image_width, crop_width)), 97 | ['Crop size greater than the image size.']) 98 | 99 | asserts = [rank_assertions[0], crop_size_assert] 100 | 101 | for i in range(1, len(image_list)): 102 | image = image_list[i] 103 | asserts.append(rank_assertions[i]) 104 | shape = control_flow_ops.with_dependencies([rank_assertions[i]], 105 | tf.shape(image)) 106 | height = shape[0] 107 | width = shape[1] 108 | 109 | height_assert = tf.Assert( 110 | tf.equal(height, image_height), 111 | ['Wrong height for tensor %s [expected][actual]', 112 | image.name, height, image_height]) 113 | width_assert = tf.Assert( 114 | tf.equal(width, image_width), 115 | ['Wrong width for tensor %s [expected][actual]', 116 | image.name, width, image_width]) 117 | asserts.extend([height_assert, width_assert]) 118 | 119 | # Create a random bounding box. 120 | # 121 | # Use tf.random_uniform and not numpy.random.rand as doing the former would 122 | # generate random numbers at graph eval time, unlike the latter which 123 | # generates random numbers at graph definition time. 124 | max_offset_height = control_flow_ops.with_dependencies( 125 | asserts, tf.reshape(image_height - crop_height + 1, [])) 126 | max_offset_width = control_flow_ops.with_dependencies( 127 | asserts, tf.reshape(image_width - crop_width + 1, [])) 128 | offset_height = tf.random_uniform( 129 | [], maxval=max_offset_height, dtype=tf.int32) 130 | offset_width = tf.random_uniform( 131 | [], maxval=max_offset_width, dtype=tf.int32) 132 | 133 | return [_crop(image, offset_height, offset_width, 134 | crop_height, crop_width) for image in image_list] 135 | 136 | 137 | def _central_crop(image_list, crop_height, crop_width): 138 | """Performs central crops of the given image list. 139 | Args: 140 | image_list: a list of image tensors of the same dimension but possibly 141 | varying channel. 142 | crop_height: the height of the image following the crop. 143 | crop_width: the width of the image following the crop. 144 | Returns: 145 | the list of cropped images. 146 | """ 147 | outputs = [] 148 | for image in image_list: 149 | image_height = tf.shape(image)[0] 150 | image_width = tf.shape(image)[1] 151 | 152 | offset_height = (image_height - crop_height) / 2 153 | offset_width = (image_width - crop_width) / 2 154 | 155 | outputs.append(_crop(image, offset_height, offset_width, 156 | crop_height, crop_width)) 157 | return outputs 158 | 159 | 160 | def decode_png(image_buffer, scope=None): 161 | """Decode a JPEG string into one 3-D float image Tensor. 162 | 163 | Args: 164 | image_buffer: scalar string Tensor. 165 | scope: Optional scope for op_scope. 166 | Returns: 167 | 3-D float Tensor with values ranging from [0, 1). 168 | """ 169 | # with tf.name_scope([image_buffer], scope, 'decode_png'): 170 | # Decode the string as an RGB JPEG. 
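  # (The "JPEG" wording above is inherited from the Inception preprocessing
  # pipeline this file was adapted from; the buffer is actually decoded as a
  # 3-channel PNG by the call below.)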
171 |   image = tf.image.decode_png(image_buffer, channels=3)
172 | 
173 |   # After this point, all image pixels reside in [0,1)
174 |   # until the very end, when they're rescaled to (-1, 1). The various
175 |   # adjust_* ops all require this range for dtype float.
176 |   image = tf.image.convert_image_dtype(image, dtype=tf.float32)
177 |   return image
178 | 
179 | 
180 | def distort_color(image, thread_id=0, scope=None):
181 |   """Distort the color of the image.
182 | 
183 |   Each color distortion is non-commutative and thus ordering of the color ops
184 |   matters. Ideally we would randomly permute the ordering of the color ops.
185 |   Rather than adding that level of complication, we select a distinct ordering
186 |   of color ops for each preprocessing thread.
187 | 
188 |   Args:
189 |     image: Tensor containing single image.
190 |     thread_id: preprocessing thread ID.
191 |     scope: Optional scope for op_scope.
192 |   Returns:
193 |     color-distorted image
194 |   """
195 |   with tf.op_scope([image], scope, 'distort_color'):
196 |     color_ordering = thread_id % 2
197 | 
198 |     if color_ordering == 0:
199 |       image = tf.image.random_brightness(image, max_delta=32. / 255.)
200 |       image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
201 |       image = tf.image.random_hue(image, max_delta=0.2)
202 |       image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
203 |     elif color_ordering == 1:
204 |       image = tf.image.random_brightness(image, max_delta=32. / 255.)
205 |       image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
206 |       image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
207 |       image = tf.image.random_hue(image, max_delta=0.2)
208 | 
209 |     # The random_* ops do not necessarily clamp.
210 |     image = tf.clip_by_value(image, 0.0, 1.0)
211 |     return image
212 | 
213 | 
214 | def distort_image(image, height, width, chn_size=3, bbox=None, flipr=False, scope=None):
215 |   """Distort one image for training a network.
216 | 
217 |   Distorting images provides a useful technique for augmenting the data
218 |   set during training in order to make the network invariant to aspects
219 |   of the image that do not affect the label.
220 | 
221 |   Args:
222 |     image: 3-D float Tensor of image
223 |     height: integer
224 |     width: integer
225 |     bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
226 |       where each coordinate is [0, 1) and the coordinates are arranged
227 |       as [ymin, xmin, ymax, xmax].
228 |     chn_size / flipr: output channel count and random left-right flip flag.
229 |     scope: Optional scope for op_scope.
230 |   Returns:
231 |     3-D float Tensor of distorted image used for training.
232 |   """
233 |   if bbox is None:
234 |     bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
235 |                        dtype=tf.float32,
236 |                        shape=[1, 1, 4])
237 |   with tf.op_scope([image, height, width, bbox], scope, 'distort_image'):
238 |     # Each bounding box has shape [1, num_boxes, box coords] and
239 |     # the coordinates are ordered [ymin, xmin, ymax, xmax].
240 |     # A large fraction of image datasets contain a human-annotated bounding
241 |     # box delineating the region of the image containing the object of interest.
242 |     # We choose to create a new bounding box for the object which is a randomly
243 |     # distorted version of the human-annotated bounding box that obeys an allowed
244 |     # range of aspect ratios, sizes and overlap with the human-annotated
245 |     # bounding box. If no box is supplied, then we assume the bounding box is
246 |     # the entire image.
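    # Given min_object_covered=0.1, area_range=[0.8, 1.0] and a near-unit
    # aspect_ratio_range, the sampled window below keeps at least 80% of the
    # image area and stays close to the original aspect ratio, so objects are
    # not badly squashed before the fixed-size resize that follows.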
247 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 248 | tf.shape(image), 249 | bounding_boxes=[[[0.0, 0.0, 1.0, 1.0]]], 250 | min_object_covered=0.1, 251 | aspect_ratio_range=[0.9, 1.1], 252 | area_range=[0.8, 1.0], 253 | max_attempts=100, 254 | use_image_if_no_bounding_boxes=True) 255 | bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box 256 | # if not thread_id: 257 | # image_with_distorted_box = tf.image.draw_bounding_boxes( 258 | # tf.expand_dims(image, 0), distort_bbox) 259 | # tf.image_summary('images_with_distorted_bounding_box', 260 | # image_with_distorted_box) 261 | 262 | # Crop the image to the specified bounding box. 263 | distorted_image = tf.slice(image, bbox_begin, bbox_size) 264 | 265 | # This resizing operation may distort the images because the aspect 266 | # ratio is not respected. We select a resize method in a round robin 267 | # fashion based on the thread number. 268 | # Note that ResizeMethod contains 4 enumerated resizing methods. 269 | distorted_image = tf.image.resize_images(distorted_image, [height, width]) 270 | # Restore the shape since the dynamic slice based upon the bbox_size loses 271 | # the third dimension. 272 | distorted_image.set_shape([height, width, chn_size]) 273 | 274 | # Randomly flip the image horizontally. 275 | if flipr: 276 | distorted_image = tf.image.random_flip_left_right(distorted_image) 277 | 278 | # # Randomly distort the colors. 279 | # distorted_image = distort_color(distorted_image) 280 | 281 | return distorted_image 282 | 283 | 284 | def processing_image(image_buffer, thread_id=0, data_augmentation_flag=True): 285 | 286 | image = decode_png(image_buffer) 287 | 288 | if data_augmentation_flag: 289 | image = distort_image(image, FLAGS.output_height, FLAGS.output_width, thread_id) 290 | # else: 291 | # image = eval_image(image, height, width) 292 | 293 | image.set_shape([FLAGS.output_height, FLAGS.output_width, 3]) 294 | image = tf.image.resize_images(image, [FLAGS.resize_height, FLAGS.resize_width]) 295 | image = tf.sub(image, 0.5) 296 | image = tf.mul(image, 2.0) 297 | # image = _central_crop([image], FLAGS.crop_height, FLAGS.crop_width)[0] 298 | image.set_shape([FLAGS.resize_height, FLAGS.resize_width, 3]) 299 | image = tf.to_float(image) 300 | 301 | return image 302 | 303 | 304 | def call_distort_image(image): 305 | crop_size = FLAGS.crop_size 306 | dist_chn_size = FLAGS.dist_chn_size 307 | return distort_image(image, crop_size, crop_size, dist_chn_size) 308 | 309 | 310 | def data_augmentation(raw_images): 311 | # processed_images = tf.map_fn(lambda inputs: call_distort_image(*inputs), elems=raw_images, dtype=tf.float32) 312 | processed_images = tf.map_fn(lambda inputs: call_distort_image(inputs), raw_images) 313 | return processed_images 314 | 315 | 316 | def tf_high_pass_filter(tf_images): 317 | # implementation of high pass filter according to the "Sketch-pix2seq" 318 | filter_w = tf.constant([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], tf.float32) 319 | filter_w = tf.expand_dims(tf.expand_dims(filter_w, -1), -1, name='hp_w') 320 | filtered_images = tf.nn.conv2d(tf_images, filter_w, strides=[1, 1, 1, 1], padding='SAME') 321 | return filtered_images 322 | 323 | 324 | def tf_image_processing(tf_images, basenet, crop_size, distort=False, hp_filter=False): 325 | if len(tf_images.shape) == 3: 326 | tf_images = tf.expand_dims(tf_images, -1) 327 | if basenet == 'sketchanet': 328 | mean_value = 250.42 329 | tf_images = tf.subtract(tf_images, mean_value) 330 | if distort: 331 | print("Distorting 
photos") 332 | FLAGS.crop_size = crop_size 333 | FLAGS.dist_chn_size = 1 334 | tf_images = data_augmentation(tf_images) 335 | else: 336 | tf_images = tf.image.resize_images(tf_images, (crop_size, crop_size)) 337 | elif basenet in ['inceptionv1', 'inceptionv3', 'gen_cnn']: 338 | tf_images = tf.divide(tf_images, 255.0) 339 | tf_images = tf.subtract(tf_images, 0.5) 340 | tf_images = tf.multiply(tf_images, 2.0) 341 | if int(tf_images.shape[-1]) != 3: 342 | tf_images = tf.concat([tf_images, tf_images, tf_images], axis=-1) 343 | if distort: 344 | print("Distorting photos") 345 | FLAGS.crop_size = crop_size 346 | FLAGS.dist_chn_size = 3 347 | tf_images = data_augmentation(tf_images) 348 | # Display the training images in the visualizer. 349 | # tf.image_summary('input_images', input_images) 350 | else: 351 | tf_images = tf.image.resize_images(tf_images, (crop_size, crop_size)) 352 | 353 | if hp_filter: 354 | tf_images = tf_high_pass_filter(tf_images) 355 | 356 | return tf_images 357 | -------------------------------------------------------------------------------- /svg_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cPickle 3 | import numpy as np 4 | import xml.etree.ElementTree as ET 5 | import random 6 | import svgwrite 7 | from IPython.display import SVG, display 8 | from svg.path import Path, Line, Arc, CubicBezier, QuadraticBezier, parse_path 9 | 10 | 11 | def calculate_start_point(data, factor=1.0, block_size=200): 12 | # will try to center the sketch to the middle of the block 13 | # determines maxx, minx, maxy, miny 14 | sx = 0 15 | sy = 0 16 | maxx = 0 17 | minx = 0 18 | maxy = 0 19 | miny = 0 20 | for i in xrange(len(data)): 21 | sx += round(float(data[i, 0]) * factor, 3) 22 | sy += round(float(data[i, 1]) * factor, 3) 23 | maxx = max(maxx, sx) 24 | minx = min(minx, sx) 25 | maxy = max(maxy, sy) 26 | miny = min(miny, sy) 27 | 28 | abs_x = block_size / 2 - (maxx - minx) / 2 - minx 29 | abs_y = block_size / 2 - (maxy - miny) / 2 - miny 30 | 31 | return abs_x, abs_y, (maxx - minx), (maxy - miny) 32 | 33 | 34 | def draw_stroke_color_array(data, factor=1, svg_filename='sample.svg', stroke_width=1, block_size=200, maxcol=5, 35 | svg_only=False, color_mode=True): 36 | num_char = len(data) 37 | 38 | if num_char < 1: 39 | return 40 | 41 | max_color_intensity = 225 42 | 43 | numrow = np.ceil(float(num_char) / float(maxcol)) 44 | dwg = svgwrite.Drawing(svg_filename, size=(block_size * (min(num_char, maxcol)), block_size * numrow)) 45 | dwg.add(dwg.rect(insert=(0, 0), size=(block_size * (min(num_char, maxcol)), block_size * numrow), fill='white')) 46 | 47 | the_color = "rgb(" + str(random.randint(0, max_color_intensity)) + "," + str( 48 | int(random.randint(0, max_color_intensity))) + "," + str(int(random.randint(0, max_color_intensity))) + ")" 49 | 50 | for j in xrange(len(data)): 51 | 52 | lift_pen = 0 53 | # end_of_char = 0 54 | cdata = data[j] 55 | abs_x, abs_y, size_x, size_y = calculate_start_point(cdata, factor, block_size) 56 | abs_x += (j % maxcol) * block_size 57 | abs_y += (j / maxcol) * block_size 58 | 59 | for i in xrange(len(cdata)): 60 | 61 | x = round(float(cdata[i, 0]) * factor, 3) 62 | y = round(float(cdata[i, 1]) * factor, 3) 63 | 64 | prev_x = round(abs_x, 3) 65 | prev_y = round(abs_y, 3) 66 | 67 | abs_x += x 68 | abs_y += y 69 | 70 | if (lift_pen == 1): 71 | p = "M " + str(abs_x) + "," + str(abs_y) + " " 72 | the_color = "rgb(" + str(random.randint(0, max_color_intensity)) + "," + str( 73 | 
int(random.randint(0, max_color_intensity))) + "," + str( 74 | int(random.randint(0, max_color_intensity))) + ")" 75 | else: 76 | p = "M " + str(prev_x) + "," + str(prev_y) + " L " + str(abs_x) + "," + str(abs_y) + " " 77 | 78 | lift_pen = max(cdata[i, 2], cdata[i, 3]) # lift pen if both eos or eoc 79 | # end_of_char = cdata[i, 3] # not used for now. 80 | 81 | if color_mode == False: 82 | the_color = "#000" 83 | 84 | dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill( 85 | the_color)) # , opacity=round(random.random()*0.5+0.5, 3) 86 | 87 | dwg.save() 88 | if svg_only == False: 89 | display(SVG(dwg.tostring())) 90 | 91 | 92 | def draw_stroke_color(data, factor=1, svg_filename='sample.svg', stroke_width=1, block_size=200, maxcol=5, 93 | svg_only=False, color_mode=True): 94 | def split_sketch(data): 95 | # split a sketch with many eoc into an array of sketches, each with just one eoc at the end. 96 | # ignores last stub with no eoc. 97 | counter = 0 98 | result = [] 99 | for i in xrange(len(data)): 100 | eoc = data[i, 3] 101 | if eoc > 0: 102 | result.append(data[counter:i + 1]) 103 | counter = i + 1 104 | # if (counter < len(data)): # ignore the rest 105 | # result.append(data[counter:]) 106 | return result 107 | 108 | data = np.array(data, dtype=np.float32) 109 | data = split_sketch(data) 110 | draw_stroke_color_array(data, factor, svg_filename, stroke_width, block_size, maxcol, svg_only, color_mode) 111 | 112 | 113 | def cubicbezier(x0, y0, x1, y1, x2, y2, x3, y3, n=20): 114 | # from http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic 115 | pts = [] 116 | for i in range(n + 1): 117 | t = float(i) / float(n) 118 | a = (1. - t) ** 3 119 | b = 3. * t * (1. - t) ** 2 120 | c = 3.0 * t ** 2 * (1.0 - t) 121 | d = t ** 3 122 | 123 | x = float(a * x0 + b * x1 + c * x2 + d * x3) 124 | y = float(a * y0 + b * y1 + c * y2 + d * y3) 125 | pts.append((x, y)) 126 | return pts 127 | 128 | 129 | def get_path_strings(svgfile): 130 | tree = ET.parse(svgfile) 131 | p = [] 132 | for elem in tree.iter(): 133 | if elem.attrib.has_key('d'): 134 | p.append(elem.attrib['d']) 135 | return p 136 | 137 | 138 | def build_lines(svgfile, line_length_threshold=10.0, min_points_per_path=1, max_points_per_path=3): 139 | # we don't draw lines less than line_length_threshold 140 | path_strings = get_path_strings(svgfile) 141 | 142 | lines = [] 143 | 144 | for path_string in path_strings: 145 | try: 146 | full_path = parse_path(path_string) 147 | except: 148 | import pdb 149 | pdb.set_trace() 150 | print "e" 151 | for i in range(len(full_path)): 152 | p = full_path[i] 153 | if type(p) != Line and type(p) != CubicBezier: 154 | print "encountered an element that is not just a line or bezier " 155 | print "type: ", type(p) 156 | print p 157 | else: 158 | x_start = p.start.real 159 | y_start = p.start.imag 160 | x_end = p.end.real 161 | y_end = p.end.imag 162 | line_length = np.sqrt( 163 | (x_end - x_start) * (x_end - x_start) + (y_end - y_start) * (y_end - y_start)) 164 | # len_data.append(line_length) 165 | points = [] 166 | if type(p) == CubicBezier: 167 | x_con1 = p.control1.real 168 | y_con1 = p.control1.imag 169 | x_con2 = p.control2.real 170 | y_con2 = p.control2.imag 171 | n_points = int(line_length / line_length_threshold) + 1 172 | n_points = max(n_points, min_points_per_path) 173 | n_points = min(n_points, max_points_per_path) 174 | points = cubicbezier(x_start, y_start, x_con1, y_con1, x_con2, y_con2, x_end, y_end, 175 | n_points) 176 | else: 177 | points = [(x_start, y_start), (x_end, y_end)] 178 | if i 
== 0: # only append the starting point for svg 179 | lines.append([points[0][0], points[0][1], 0, 0]) # put eoc to be zero 180 | for j in range(1, len(points)): 181 | eos = 0 182 | if j == len(points) - 1 and i == len(full_path) - 1: 183 | eos = 1 184 | lines.append([points[j][0], points[j][1], eos, 0]) # put eoc to be zero 185 | lines = np.array(lines, dtype=np.float32) 186 | # make it relative moves 187 | lines[1:, 0:2] -= lines[0:-1, 0:2] 188 | lines[-1, 3] = 1 # end of character 189 | lines[0] = [0, 0, 0, 0] # start at origin 190 | return lines[1:] 191 | 192 | 193 | class SketchLoader(): 194 | def __init__(self, batch_size=50, seq_length=300, scale_factor=1.0, data_filename="kanji"): 195 | import pdb 196 | pdb.set_trace() 197 | self.data_dir = "./data" 198 | self.batch_size = batch_size 199 | self.seq_length = seq_length 200 | self.scale_factor = scale_factor # divide data by this factor 201 | 202 | data_file = os.path.join(self.data_dir, data_filename + ".cpkl") 203 | raw_data_dir = os.path.join(self.data_dir, data_filename) 204 | 205 | if not (os.path.exists(data_file)): 206 | raise Exception('File not exist') 207 | # self.length_data = self.preprocess(raw_data_dir, data_file) 208 | 209 | self.load_preprocessed(data_file) 210 | self.num_samples = len(self.raw_data) 211 | self.index = range(self.num_samples) # this list will be randomized later. 212 | self.reset_index_pointer() 213 | 214 | def preprocess(self, data_dir, data_file): 215 | # create data file from raw xml files from iam handwriting source. 216 | len_data = [] 217 | 218 | def cubicbezier(x0, y0, x1, y1, x2, y2, x3, y3, n=20): 219 | # from http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic 220 | pts = [] 221 | for i in range(n + 1): 222 | t = float(i) / float(n) 223 | a = (1. - t) ** 3 224 | b = 3. * t * (1. 
- t) ** 2 225 | c = 3.0 * t ** 2 * (1.0 - t) 226 | d = t ** 3 227 | 228 | x = float(a * x0 + b * x1 + c * x2 + d * x3) 229 | y = float(a * y0 + b * y1 + c * y2 + d * y3) 230 | pts.append((x, y)) 231 | return pts 232 | 233 | def get_path_strings(svgfile): 234 | tree = ET.parse(svgfile) 235 | p = [] 236 | for elem in tree.iter(): 237 | if elem.attrib.has_key('d'): 238 | p.append(elem.attrib['d']) 239 | return p 240 | 241 | def build_lines(svgfile, line_length_threshold=10.0, min_points_per_path=1, max_points_per_path=3): 242 | # we don't draw lines less than line_length_threshold 243 | path_strings = get_path_strings(svgfile) 244 | 245 | lines = [] 246 | 247 | for path_string in path_strings: 248 | full_path = parse_path(path_string) 249 | for i in range(len(full_path)): 250 | p = full_path[i] 251 | if type(p) != Line and type(p) != CubicBezier: 252 | print "encountered an element that is not just a line or bezier " 253 | print "type: ", type(p) 254 | print p 255 | else: 256 | x_start = p.start.real 257 | y_start = p.start.imag 258 | x_end = p.end.real 259 | y_end = p.end.imag 260 | line_length = np.sqrt( 261 | (x_end - x_start) * (x_end - x_start) + (y_end - y_start) * (y_end - y_start)) 262 | len_data.append(line_length) 263 | points = [] 264 | if type(p) == CubicBezier: 265 | x_con1 = p.control1.real 266 | y_con1 = p.control1.imag 267 | x_con2 = p.control2.real 268 | y_con2 = p.control2.imag 269 | n_points = int(line_length / line_length_threshold) + 1 270 | n_points = max(n_points, min_points_per_path) 271 | n_points = min(n_points, max_points_per_path) 272 | points = cubicbezier(x_start, y_start, x_con1, y_con1, x_con2, y_con2, x_end, y_end, 273 | n_points) 274 | else: 275 | points = [(x_start, y_start), (x_end, y_end)] 276 | if i == 0: # only append the starting point for svg 277 | lines.append([points[0][0], points[0][1], 0, 0]) # put eoc to be zero 278 | for j in range(1, len(points)): 279 | eos = 0 280 | if j == len(points) - 1 and i == len(full_path) - 1: 281 | eos = 1 282 | lines.append([points[j][0], points[j][1], eos, 0]) # put eoc to be zero 283 | lines = np.array(lines, dtype=np.float32) 284 | # make it relative moves 285 | lines[1:, 0:2] -= lines[0:-1, 0:2] 286 | lines[-1, 3] = 1 # end of character 287 | lines[0] = [0, 0, 0, 0] # start at origin 288 | return lines[1:] 289 | 290 | # build the list of xml files 291 | filelist = [] 292 | # Set the directory you want to start from 293 | rootDir = data_dir 294 | for dirName, subdirList, fileList in os.walk(rootDir): 295 | # print('Found directory: %s' % dirName) 296 | for fname in fileList: 297 | # print('\t%s' % fname) 298 | filelist.append(dirName + "/" + fname) 299 | 300 | # build stroke database of every xml file inside iam database 301 | sketch = [] 302 | for i in range(len(filelist)): 303 | if (filelist[i][-3:] == 'svg'): 304 | print 'processing ' + filelist[i] 305 | sketch.append(build_lines(filelist[i])) 306 | 307 | f = open(data_file, "wb") 308 | cPickle.dump(sketch, f, protocol=2) 309 | f.close() 310 | import pdb 311 | pdb.set_trace() 312 | return len_data 313 | 314 | def load_preprocessed(self, data_file): 315 | f = open(data_file, "rb") 316 | self.raw_data = cPickle.load(f) 317 | # scale the data here, rather than at the data construction (since scaling may change) 318 | for data in self.raw_data: 319 | data[:, 0:2] /= self.scale_factor 320 | f.close() 321 | 322 | def next_batch(self): 323 | # returns a set of batches, but the constraint is that the start of each input data batch 324 | # is the start of a new character 
(although the end of a batch doesn't have to be end of a character) 325 | 326 | def next_seq(n): 327 | result = np.zeros((n, 5), dtype=np.float32) # x, y, [eos, eoc, cont] tokens 328 | # result[0, 2:4] = 1 # set eos and eoc to true for first point 329 | # experimental line below, put a random factor between 70-130% to generate more examples 330 | rand_scale_factor_x = np.random.rand() * 0.6 + 0.7 331 | rand_scale_factor_y = np.random.rand() * 0.6 + 0.7 332 | idx = 0 333 | data = self.current_data() 334 | for i in xrange(n): 335 | result[i, 0:4] = data[idx] # eoc = 0.0 336 | result[i, 4] = 1 # continue on stroke 337 | if (result[i, 2] > 0 or result[i, 3] > 0): 338 | result[i, 4] = 0 339 | idx += 1 340 | if (idx >= len(data) - 1): # skip to next sketch example next time and mark eoc 341 | result[i, 4] = 0 342 | result[i, 3] = 1 343 | result[i, 2] = 0 # overrides end of stroke one-hot 344 | idx = 0 345 | self.tick_index_pointer() 346 | data = self.current_data() 347 | assert (result[i, 2:5].sum() == 1) 348 | self.tick_index_pointer() # needed if seq_length is less than last data. 349 | result[:, 0] *= rand_scale_factor_x 350 | result[:, 1] *= rand_scale_factor_y 351 | return result 352 | 353 | skip_length = self.seq_length + 1 354 | 355 | batch = [] 356 | 357 | for i in xrange(self.batch_size): 358 | seq = next_seq(skip_length) 359 | batch.append(seq) 360 | 361 | batch = np.array(batch, dtype=np.float32) 362 | 363 | return batch[:, 0:-1], batch[:, 1:] 364 | 365 | def current_data(self): 366 | return self.raw_data[self.index[self.pointer]] 367 | 368 | def tick_index_pointer(self): 369 | self.pointer += 1 370 | if (self.pointer >= len(self.raw_data)): 371 | self.pointer = 0 372 | self.epoch_finished = True 373 | 374 | def reset_index_pointer(self): 375 | # randomize order for the raw list in the next go. 376 | self.pointer = 0 377 | self.epoch_finished = False 378 | self.index = np.random.permutation(self.index) 379 | 380 | 381 | -------------------------------------------------------------------------------- /svg/path/path.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from math import sqrt, cos, sin, acos, degrees, radians, log 3 | from collections import MutableSequence 4 | 5 | 6 | # This file contains classes for the different types of SVG path segments as 7 | # well as a Path object that contains a sequence of path segments. 8 | 9 | MIN_DEPTH = 5 10 | ERROR = 1e-12 11 | 12 | 13 | def segment_length(curve, start, end, start_point, end_point, error, min_depth, depth): 14 | """Recursively approximates the length by straight lines""" 15 | mid = (start + end) / 2 16 | mid_point = curve.point(mid) 17 | length = abs(end_point - start_point) 18 | first_half = abs(mid_point - start_point) 19 | second_half = abs(end_point - mid_point) 20 | 21 | length2 = first_half + second_half 22 | if (length2 - length > error) or (depth < min_depth): 23 | # Calculate the length of each segment: 24 | depth += 1 25 | return (segment_length(curve, start, mid, start_point, mid_point, 26 | error, min_depth, depth) + 27 | segment_length(curve, mid, end, mid_point, end_point, 28 | error, min_depth, depth)) 29 | # This is accurate enough. 
30 | return length2 31 | 32 | 33 | class Line(object): 34 | 35 | def __init__(self, start, end): 36 | self.start = start 37 | self.end = end 38 | 39 | def __repr__(self): 40 | return 'Line(start=%s, end=%s)' % (self.start, self.end) 41 | 42 | def __eq__(self, other): 43 | if not isinstance(other, Line): 44 | return NotImplemented 45 | return self.start == other.start and self.end == other.end 46 | 47 | def __ne__(self, other): 48 | if not isinstance(other, Line): 49 | return NotImplemented 50 | return not self == other 51 | 52 | def point(self, pos): 53 | distance = self.end - self.start 54 | return self.start + distance * pos 55 | 56 | def length(self, error=None, min_depth=None): 57 | distance = (self.end - self.start) 58 | return sqrt(distance.real ** 2 + distance.imag ** 2) 59 | 60 | 61 | class CubicBezier(object): 62 | def __init__(self, start, control1, control2, end): 63 | self.start = start 64 | self.control1 = control1 65 | self.control2 = control2 66 | self.end = end 67 | 68 | def __repr__(self): 69 | return 'CubicBezier(start=%s, control1=%s, control2=%s, end=%s)' % ( 70 | self.start, self.control1, self.control2, self.end) 71 | 72 | def __eq__(self, other): 73 | if not isinstance(other, CubicBezier): 74 | return NotImplemented 75 | return self.start == other.start and self.end == other.end and \ 76 | self.control1 == other.control1 and self.control2 == other.control2 77 | 78 | def __ne__(self, other): 79 | if not isinstance(other, CubicBezier): 80 | return NotImplemented 81 | return not self == other 82 | 83 | def is_smooth_from(self, previous): 84 | """Checks if this segment would be a smooth segment following the previous""" 85 | if isinstance(previous, CubicBezier): 86 | return (self.start == previous.end and 87 | (self.control1 - self.start) == (previous.end - previous.control2)) 88 | else: 89 | return self.control1 == self.start 90 | 91 | def point(self, pos): 92 | """Calculate the x,y position at a certain position of the path""" 93 | return ((1 - pos) ** 3 * self.start) + \ 94 | (3 * (1 - pos) ** 2 * pos * self.control1) + \ 95 | (3 * (1 - pos) * pos ** 2 * self.control2) + \ 96 | (pos ** 3 * self.end) 97 | 98 | def length(self, error=ERROR, min_depth=MIN_DEPTH): 99 | """Calculate the length of the path up to a certain position""" 100 | start_point = self.point(0) 101 | end_point = self.point(1) 102 | return segment_length(self, 0, 1, start_point, end_point, error, min_depth, 0) 103 | 104 | 105 | class QuadraticBezier(object): 106 | def __init__(self, start, control, end): 107 | self.start = start 108 | self.end = end 109 | self.control = control 110 | 111 | def __repr__(self): 112 | return 'QuadraticBezier(start=%s, control=%s, end=%s)' % ( 113 | self.start, self.control, self.end) 114 | 115 | def __eq__(self, other): 116 | if not isinstance(other, QuadraticBezier): 117 | return NotImplemented 118 | return self.start == other.start and self.end == other.end and \ 119 | self.control == other.control 120 | 121 | def __ne__(self, other): 122 | if not isinstance(other, QuadraticBezier): 123 | return NotImplemented 124 | return not self == other 125 | 126 | def is_smooth_from(self, previous): 127 | """Checks if this segment would be a smooth segment following the previous""" 128 | if isinstance(previous, QuadraticBezier): 129 | return (self.start == previous.end and 130 | (self.control - self.start) == (previous.end - previous.control)) 131 | else: 132 | return self.control == self.start 133 | 134 | def point(self, pos): 135 | return (1 - pos) ** 2 * self.start + 2 * (1 - 
pos) * pos * self.control + \ 136 | pos ** 2 * self.end 137 | 138 | def length(self, error=None, min_depth=None): 139 | a = self.start - 2*self.control + self.end 140 | b = 2*(self.control - self.start) 141 | a_dot_b = a.real*b.real + a.imag*b.imag 142 | 143 | if abs(a) < 1e-12: 144 | s = abs(b) 145 | elif abs(a_dot_b + abs(a)*abs(b)) < 1e-12: 146 | k = abs(b)/abs(a) 147 | if k >= 2: 148 | s = abs(b) - abs(a) 149 | else: 150 | s = abs(a)*(k**2/2 - k + 1) 151 | else: 152 | # For an explanation of this case, see 153 | # http://www.malczak.info/blog/quadratic-bezier-curve-length/ 154 | A = 4 * (a.real ** 2 + a.imag ** 2) 155 | B = 4 * (a.real * b.real + a.imag * b.imag) 156 | C = b.real ** 2 + b.imag ** 2 157 | 158 | Sabc = 2 * sqrt(A + B + C) 159 | A2 = sqrt(A) 160 | A32 = 2 * A * A2 161 | C2 = 2 * sqrt(C) 162 | BA = B / A2 163 | 164 | s = (A32 * Sabc + A2 * B * (Sabc - C2) + (4 * C * A - B ** 2) * 165 | log((2 * A2 + BA + Sabc) / (BA + C2))) / (4 * A32) 166 | return s 167 | 168 | class Arc(object): 169 | 170 | def __init__(self, start, radius, rotation, arc, sweep, end): 171 | """radius is complex, rotation is in degrees, 172 | large and sweep are 1 or 0 (True/False also work)""" 173 | 174 | self.start = start 175 | self.radius = radius 176 | self.rotation = rotation 177 | self.arc = bool(arc) 178 | self.sweep = bool(sweep) 179 | self.end = end 180 | 181 | self._parameterize() 182 | 183 | def __repr__(self): 184 | return 'Arc(start=%s, radius=%s, rotation=%s, arc=%s, sweep=%s, end=%s)' % ( 185 | self.start, self.radius, self.rotation, self.arc, self.sweep, self.end) 186 | 187 | def __eq__(self, other): 188 | if not isinstance(other, Arc): 189 | return NotImplemented 190 | return self.start == other.start and self.end == other.end and \ 191 | self.radius == other.radius and self.rotation == other.rotation and \ 192 | self.arc == other.arc and self.sweep == other.sweep 193 | 194 | def __ne__(self, other): 195 | if not isinstance(other, Arc): 196 | return NotImplemented 197 | return not self == other 198 | 199 | def _parameterize(self): 200 | # Conversion from endpoint to center parameterization 201 | # http://www.w3.org/TR/SVG/implnote.html#ArcImplementationNotes 202 | 203 | cosr = cos(radians(self.rotation)) 204 | sinr = sin(radians(self.rotation)) 205 | dx = (self.start.real - self.end.real) / 2 206 | dy = (self.start.imag - self.end.imag) / 2 207 | x1prim = cosr * dx + sinr * dy 208 | x1prim_sq = x1prim * x1prim 209 | y1prim = -sinr * dx + cosr * dy 210 | y1prim_sq = y1prim * y1prim 211 | 212 | rx = self.radius.real 213 | rx_sq = rx * rx 214 | ry = self.radius.imag 215 | ry_sq = ry * ry 216 | 217 | # Correct out of range radii 218 | radius_check = (x1prim_sq / rx_sq) + (y1prim_sq / ry_sq) 219 | if radius_check > 1: 220 | rx *= sqrt(radius_check) 221 | ry *= sqrt(radius_check) 222 | rx_sq = rx * rx 223 | ry_sq = ry * ry 224 | 225 | t1 = rx_sq * y1prim_sq 226 | t2 = ry_sq * x1prim_sq 227 | c = sqrt(abs((rx_sq * ry_sq - t1 - t2) / (t1 + t2))) 228 | 229 | if self.arc == self.sweep: 230 | c = -c 231 | cxprim = c * rx * y1prim / ry 232 | cyprim = -c * ry * x1prim / rx 233 | 234 | self.center = complex((cosr * cxprim - sinr * cyprim) + 235 | ((self.start.real + self.end.real) / 2), 236 | (sinr * cxprim + cosr * cyprim) + 237 | ((self.start.imag + self.end.imag) / 2)) 238 | 239 | ux = (x1prim - cxprim) / rx 240 | uy = (y1prim - cyprim) / ry 241 | vx = (-x1prim - cxprim) / rx 242 | vy = (-y1prim - cyprim) / ry 243 | n = sqrt(ux * ux + uy * uy) 244 | p = ux 245 | theta = degrees(acos(p / n)) 246 | 
if uy < 0: 247 | theta = -theta 248 | self.theta = theta % 360 249 | 250 | n = sqrt((ux * ux + uy * uy) * (vx * vx + vy * vy)) 251 | p = ux * vx + uy * vy 252 | if p == 0: 253 | delta = degrees(acos(0)) 254 | else: 255 | delta = degrees(acos(p / n)) 256 | if (ux * vy - uy * vx) < 0: 257 | delta = -delta 258 | self.delta = delta % 360 259 | if not self.sweep: 260 | self.delta -= 360 261 | 262 | def point(self, pos): 263 | angle = radians(self.theta + (self.delta * pos)) 264 | cosr = cos(radians(self.rotation)) 265 | sinr = sin(radians(self.rotation)) 266 | 267 | x = (cosr * cos(angle) * self.radius.real - sinr * sin(angle) * 268 | self.radius.imag + self.center.real) 269 | y = (sinr * cos(angle) * self.radius.real + cosr * sin(angle) * 270 | self.radius.imag + self.center.imag) 271 | return complex(x, y) 272 | 273 | def length(self, error=ERROR, min_depth=MIN_DEPTH): 274 | """The length of an elliptical arc segment requires numerical 275 | integration, and in that case it's simpler to just do a geometric 276 | approximation, as for cubic bezier curves. 277 | """ 278 | start_point = self.point(0) 279 | end_point = self.point(1) 280 | return segment_length(self, 0, 1, start_point, end_point, error, min_depth, 0) 281 | 282 | 283 | class Path(MutableSequence): 284 | """A Path is a sequence of path segments""" 285 | 286 | # Put it here, so there is a default if unpickled. 287 | _closed = False 288 | 289 | def __init__(self, *segments, **kw): 290 | self._segments = list(segments) 291 | self._length = None 292 | self._lengths = None 293 | if 'closed' in kw: 294 | self.closed = kw['closed'] 295 | 296 | def __getitem__(self, index): 297 | return self._segments[index] 298 | 299 | def __setitem__(self, index, value): 300 | self._segments[index] = value 301 | self._length = None 302 | 303 | def __delitem__(self, index): 304 | del self._segments[index] 305 | self._length = None 306 | 307 | def insert(self, index, value): 308 | self._segments.insert(index, value) 309 | self._length = None 310 | 311 | def reverse(self): 312 | # Reversing the order of a path would require reversing each element 313 | # as well. That's not implemented. 
314 | raise NotImplementedError 315 | 316 | def __len__(self): 317 | return len(self._segments) 318 | 319 | def __repr__(self): 320 | return 'Path(%s, closed=%s)' % ( 321 | ', '.join(repr(x) for x in self._segments), self.closed) 322 | 323 | def __eq__(self, other): 324 | if not isinstance(other, Path): 325 | return NotImplemented 326 | if len(self) != len(other): 327 | return False 328 | for s, o in zip(self._segments, other._segments): 329 | if not s == o: 330 | return False 331 | return True 332 | 333 | def __ne__(self, other): 334 | if not isinstance(other, Path): 335 | return NotImplemented 336 | return not self == other 337 | 338 | def _calc_lengths(self, error=ERROR, min_depth=MIN_DEPTH): 339 | if self._length is not None: 340 | return 341 | 342 | lengths = [each.length(error=error, min_depth=min_depth) for each in self._segments] 343 | self._length = sum(lengths) 344 | self._lengths = [each / self._length for each in lengths] 345 | 346 | def point(self, pos, error=ERROR): 347 | 348 | # Shortcuts 349 | if pos == 0.0: 350 | return self._segments[0].point(pos) 351 | if pos == 1.0: 352 | return self._segments[-1].point(pos) 353 | 354 | self._calc_lengths(error=error) 355 | # Find which segment the point we search for is located on: 356 | segment_start = 0 357 | for index, segment in enumerate(self._segments): 358 | segment_end = segment_start + self._lengths[index] 359 | if segment_end >= pos: 360 | # This is the segment! How far in on the segment is the point? 361 | segment_pos = (pos - segment_start) / (segment_end - segment_start) 362 | break 363 | segment_start = segment_end 364 | 365 | return segment.point(segment_pos) 366 | 367 | def length(self, error=ERROR, min_depth=MIN_DEPTH): 368 | self._calc_lengths(error, min_depth) 369 | return self._length 370 | 371 | def _is_closable(self): 372 | """Returns true if the end is on the start of a segment""" 373 | end = self[-1].end 374 | for segment in self: 375 | if segment.start == end: 376 | return True 377 | return False 378 | 379 | @property 380 | def closed(self): 381 | """Checks that the path is closed""" 382 | return self._closed and self._is_closable() 383 | 384 | @closed.setter 385 | def closed(self, value): 386 | value = bool(value) 387 | if value and not self._is_closable(): 388 | raise ValueError("End does not coincide with a segment start.") 389 | self._closed = value 390 | 391 | def d(self): 392 | if self.closed: 393 | segments = self[:-1] 394 | else: 395 | segments = self[:] 396 | 397 | current_pos = None 398 | parts = [] 399 | previous_segment = None 400 | end = self[-1].end 401 | 402 | for segment in segments: 403 | start = segment.start 404 | # If the start of this segment does not coincide with the end of 405 | # the last segment or if this segment is actually the close point 406 | # of a closed path, then we should start a new subpath here. 
407 | if current_pos != start or (self.closed and start == end): 408 | parts.append('M {0:G},{1:G}'.format(start.real, start.imag)) 409 | 410 | if isinstance(segment, Line): 411 | parts.append('L {0:G},{1:G}'.format( 412 | segment.end.real, segment.end.imag) 413 | ) 414 | elif isinstance(segment, CubicBezier): 415 | if segment.is_smooth_from(previous_segment): 416 | parts.append('S {0:G},{1:G} {2:G},{3:G}'.format( 417 | segment.control2.real, segment.control2.imag, 418 | segment.end.real, segment.end.imag) 419 | ) 420 | else: 421 | parts.append('C {0:G},{1:G} {2:G},{3:G} {4:G},{5:G}'.format( 422 | segment.control1.real, segment.control1.imag, 423 | segment.control2.real, segment.control2.imag, 424 | segment.end.real, segment.end.imag) 425 | ) 426 | elif isinstance(segment, QuadraticBezier): 427 | if segment.is_smooth_from(previous_segment): 428 | parts.append('T {0:G},{1:G}'.format( 429 | segment.end.real, segment.end.imag) 430 | ) 431 | else: 432 | parts.append('Q {0:G},{1:G} {2:G},{3:G}'.format( 433 | segment.control.real, segment.control.imag, 434 | segment.end.real, segment.end.imag) 435 | ) 436 | 437 | elif isinstance(segment, Arc): 438 | parts.append('A {0:G},{1:G} {2:G} {3:d},{4:d} {5:G},{6:G}'.format( 439 | segment.radius.real, segment.radius.imag, segment.rotation, 440 | int(segment.arc), int(segment.sweep), 441 | segment.end.real, segment.end.imag) 442 | ) 443 | current_pos = segment.end 444 | previous_segment = segment 445 | 446 | if self.closed: 447 | parts.append('Z') 448 | 449 | return ' '.join(parts) 450 | -------------------------------------------------------------------------------- /magenta_rnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """SketchRNN RNN definition.""" 15 | 16 | # internal imports 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | 21 | def orthogonal(shape): 22 | """Orthogonal initilaizer.""" 23 | flat_shape = (shape[0], np.prod(shape[1:])) 24 | a = np.random.normal(0.0, 1.0, flat_shape) 25 | u, _, v = np.linalg.svd(a, full_matrices=False) 26 | q = u if u.shape == flat_shape else v 27 | return q.reshape(shape) 28 | 29 | 30 | def orthogonal_initializer(scale=1.0): 31 | """Orthogonal initializer.""" 32 | 33 | def _initializer(shape, dtype=tf.float32, 34 | partition_info=None): # pylint: disable=unused-argument 35 | return tf.constant(orthogonal(shape) * scale, dtype) 36 | 37 | return _initializer 38 | 39 | 40 | def lstm_ortho_initializer(scale=1.0): 41 | """LSTM orthogonal initializer.""" 42 | 43 | def _initializer(shape, dtype=tf.float32, 44 | partition_info=None): # pylint: disable=unused-argument 45 | size_x = shape[0] 46 | size_h = shape[1] / 4 # assumes lstm. 
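    # Fill the [size_x, 4 * size_h] weight matrix one gate block at a time:
    # each of the four slices below receives its own scaled orthogonal matrix,
    # so every gate starts from a well-conditioned linear map.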
47 | t = np.zeros(shape) 48 | t[:, :size_h] = orthogonal([size_x, size_h]) * scale 49 | t[:, size_h:size_h * 2] = orthogonal([size_x, size_h]) * scale 50 | t[:, size_h * 2:size_h * 3] = orthogonal([size_x, size_h]) * scale 51 | t[:, size_h * 3:] = orthogonal([size_x, size_h]) * scale 52 | return tf.constant(t, dtype) 53 | 54 | return _initializer 55 | 56 | 57 | class LSTMCell(tf.contrib.rnn.RNNCell): 58 | """Vanilla LSTM cell. 59 | 60 | Uses ortho initializer, and also recurrent dropout without memory loss 61 | (https://arxiv.org/abs/1603.05118) 62 | """ 63 | 64 | def __init__(self, 65 | num_units, 66 | forget_bias=1.0, 67 | use_recurrent_dropout=False, 68 | dropout_keep_prob=0.9): 69 | self.num_units = num_units 70 | self.forget_bias = forget_bias 71 | self.use_recurrent_dropout = use_recurrent_dropout 72 | self.dropout_keep_prob = dropout_keep_prob 73 | 74 | @property 75 | def state_size(self): 76 | return 2 * self.num_units 77 | 78 | @property 79 | def output_size(self): 80 | return self.num_units 81 | 82 | def get_output(self, state): 83 | unused_c, h = tf.split(state, 2, 1) 84 | return h 85 | 86 | def __call__(self, x, state, scope=None): 87 | with tf.variable_scope(scope or type(self).__name__): 88 | c, h = tf.split(state, 2, 1) 89 | 90 | x_size = x.get_shape().as_list()[1] 91 | 92 | w_init = None # uniform 93 | 94 | h_init = lstm_ortho_initializer(1.0) 95 | 96 | # Keep W_xh and W_hh separate here as well to use different init methods. 97 | w_xh = tf.get_variable( 98 | 'W_xh', [x_size, 4 * self.num_units], initializer=w_init) 99 | w_hh = tf.get_variable( 100 | 'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init) 101 | bias = tf.get_variable( 102 | 'bias', [4 * self.num_units], 103 | initializer=tf.constant_initializer(0.0)) 104 | 105 | concat = tf.concat([x, h], 1) 106 | w_full = tf.concat([w_xh, w_hh], 0) 107 | hidden = tf.matmul(concat, w_full) + bias 108 | 109 | i, j, f, o = tf.split(hidden, 4, 1) 110 | 111 | if self.use_recurrent_dropout: 112 | g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob) 113 | else: 114 | g = tf.tanh(j) 115 | 116 | new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g 117 | new_h = tf.tanh(new_c) * tf.sigmoid(o) 118 | 119 | return new_h, tf.concat([new_c, new_h], 1) # fuk tuples. 
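# A minimal usage sketch for LSTMCell above (hypothetical shapes, not part of
# the original file): the cell packs its state into one [batch, 2 * num_units]
# tensor with c and h concatenated, so it drops straight into tf.nn.dynamic_rnn:
#
#   cell = LSTMCell(num_units=512, use_recurrent_dropout=True)
#   inputs = tf.placeholder(tf.float32, [batch_size, max_seq_len, 5])
#   zero_state = tf.zeros([batch_size, cell.state_size])
#   outputs, last_state = tf.nn.dynamic_rnn(
#       cell, inputs, initial_state=zero_state, dtype=tf.float32)
#   h = cell.get_output(last_state)  # recover h from the packed [c, h] state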
120 | 121 | 122 | def layer_norm_all(h, 123 | batch_size, 124 | base, 125 | num_units, 126 | scope='layer_norm', 127 | reuse=False, 128 | gamma_start=1.0, 129 | epsilon=1e-3, 130 | use_bias=True): 131 | """Layer Norm (faster version, but not using defun).""" 132 | # Performs layer norm on multiple base at once (ie, i, g, j, o for lstm) 133 | # Reshapes h in to perform layer norm in parallel 134 | h_reshape = tf.reshape(h, [batch_size, base, num_units]) 135 | mean = tf.reduce_mean(h_reshape, [2], keep_dims=True) 136 | var = tf.reduce_mean(tf.square(h_reshape - mean), [2], keep_dims=True) 137 | epsilon = tf.constant(epsilon) 138 | rstd = tf.rsqrt(var + epsilon) 139 | h_reshape = (h_reshape - mean) * rstd 140 | # reshape back to original 141 | h = tf.reshape(h_reshape, [batch_size, base * num_units]) 142 | with tf.variable_scope(scope): 143 | if reuse: 144 | tf.get_variable_scope().reuse_variables() 145 | gamma = tf.get_variable( 146 | 'ln_gamma', [4 * num_units], 147 | initializer=tf.constant_initializer(gamma_start)) 148 | if use_bias: 149 | beta = tf.get_variable( 150 | 'ln_beta', [4 * num_units], initializer=tf.constant_initializer(0.0)) 151 | if use_bias: 152 | return gamma * h + beta 153 | return gamma * h 154 | 155 | 156 | def layer_norm(x, 157 | num_units, 158 | scope='layer_norm', 159 | reuse=False, 160 | gamma_start=1.0, 161 | epsilon=1e-3, 162 | use_bias=True): 163 | """Calculate layer norm.""" 164 | axes = [1] 165 | mean = tf.reduce_mean(x, axes, keep_dims=True) 166 | x_shifted = x - mean 167 | var = tf.reduce_mean(tf.square(x_shifted), axes, keep_dims=True) 168 | inv_std = tf.rsqrt(var + epsilon) 169 | with tf.variable_scope(scope): 170 | if reuse is True: 171 | tf.get_variable_scope().reuse_variables() 172 | gamma = tf.get_variable( 173 | 'ln_gamma', [num_units], 174 | initializer=tf.constant_initializer(gamma_start)) 175 | if use_bias: 176 | beta = tf.get_variable( 177 | 'ln_beta', [num_units], initializer=tf.constant_initializer(0.0)) 178 | output = gamma * (x_shifted) * inv_std 179 | if use_bias: 180 | output += beta 181 | return output 182 | 183 | 184 | def raw_layer_norm(x, epsilon=1e-3): 185 | axes = [1] 186 | mean = tf.reduce_mean(x, axes, keep_dims=True) 187 | std = tf.sqrt( 188 | tf.reduce_mean(tf.square(x - mean), axes, keep_dims=True) + epsilon) 189 | output = (x - mean) / (std) 190 | return output 191 | 192 | 193 | def super_linear(x, 194 | output_size, 195 | scope=None, 196 | reuse=False, 197 | init_w='ortho', 198 | weight_start=0.0, 199 | use_bias=True, 200 | bias_start=0.0, 201 | input_size=None): 202 | """Performs linear operation. 
Uses ortho init defined earlier.""" 203 | shape = x.get_shape().as_list() 204 | with tf.variable_scope(scope or 'linear'): 205 | if reuse is True: 206 | tf.get_variable_scope().reuse_variables() 207 | 208 | w_init = None # uniform 209 | if input_size is None: 210 | x_size = shape[1] 211 | else: 212 | x_size = input_size 213 | if init_w == 'zeros': 214 | w_init = tf.constant_initializer(0.0) 215 | elif init_w == 'constant': 216 | w_init = tf.constant_initializer(weight_start) 217 | elif init_w == 'gaussian': 218 | w_init = tf.random_normal_initializer(stddev=weight_start) 219 | elif init_w == 'ortho': 220 | w_init = lstm_ortho_initializer(1.0) 221 | 222 | w = tf.get_variable( 223 | 'super_linear_w', [x_size, output_size], tf.float32, initializer=w_init) 224 | if use_bias: 225 | b = tf.get_variable( 226 | 'super_linear_b', [output_size], 227 | tf.float32, 228 | initializer=tf.constant_initializer(bias_start)) 229 | return tf.matmul(x, w) + b 230 | return tf.matmul(x, w) 231 | 232 | 233 | class LayerNormLSTMCell(tf.contrib.rnn.RNNCell): 234 | """Layer-Norm, with Ortho Init. and Recurrent Dropout without Memory Loss. 235 | 236 | https://arxiv.org/abs/1607.06450 - Layer Norm 237 | https://arxiv.org/abs/1603.05118 - Recurrent Dropout without Memory Loss 238 | """ 239 | 240 | def __init__(self, 241 | num_units, 242 | forget_bias=1.0, 243 | use_recurrent_dropout=False, 244 | dropout_keep_prob=0.90): 245 | """Initialize the Layer Norm LSTM cell. 246 | 247 | Args: 248 | num_units: int, The number of units in the LSTM cell. 249 | forget_bias: float, The bias added to forget gates (default 1.0). 250 | use_recurrent_dropout: Whether to use Recurrent Dropout (default False) 251 | dropout_keep_prob: float, dropout keep probability (default 0.90) 252 | """ 253 | self.num_units = num_units 254 | self.forget_bias = forget_bias 255 | self.use_recurrent_dropout = use_recurrent_dropout 256 | self.dropout_keep_prob = dropout_keep_prob 257 | 258 | @property 259 | def input_size(self): 260 | return self.num_units 261 | 262 | @property 263 | def output_size(self): 264 | return self.num_units 265 | 266 | @property 267 | def state_size(self): 268 | return 2 * self.num_units 269 | 270 | def get_output(self, state): 271 | h, unused_c = tf.split(state, 2, 1) 272 | return h 273 | 274 | def __call__(self, x, state, timestep=0, scope=None): 275 | with tf.variable_scope(scope or type(self).__name__): 276 | h, c = tf.split(state, 2, 1) 277 | 278 | h_size = self.num_units 279 | x_size = x.get_shape().as_list()[1] 280 | batch_size = x.get_shape().as_list()[0] 281 | 282 | w_init = None # uniform 283 | 284 | h_init = lstm_ortho_initializer(1.0) 285 | 286 | w_xh = tf.get_variable( 287 | 'W_xh', [x_size, 4 * self.num_units], initializer=w_init) 288 | w_hh = tf.get_variable( 289 | 'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init) 290 | 291 | concat = tf.concat([x, h], 1) # concat for speed. 292 | w_full = tf.concat([w_xh, w_hh], 0) 293 | concat = tf.matmul(concat, w_full) # + bias # live life without garbage. 
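      # The bias is intentionally left out of the matmul above: layer_norm_all()
      # below applies a learned per-gate gain (ln_gamma) and shift (ln_beta),
      # which subsumes the additive bias of a vanilla LSTM.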
294 | 295 | # i = input_gate, j = new_input, f = forget_gate, o = output_gate 296 | concat = layer_norm_all(concat, batch_size, 4, h_size, 'ln_all') 297 | i, j, f, o = tf.split(concat, 4, 1) 298 | 299 | if self.use_recurrent_dropout: 300 | g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob) 301 | else: 302 | g = tf.tanh(j) 303 | 304 | new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g 305 | new_h = tf.tanh(layer_norm(new_c, h_size, 'ln_c')) * tf.sigmoid(o) 306 | 307 | return new_h, tf.concat([new_h, new_c], 1) 308 | 309 | 310 | class HyperLSTMCell(tf.contrib.rnn.RNNCell): 311 | """HyperLSTM with Ortho Init, Layer Norm, Recurrent Dropout, no Memory Loss. 312 | 313 | https://arxiv.org/abs/1609.09106 314 | http://blog.otoro.net/2016/09/28/hyper-networks/ 315 | """ 316 | 317 | def __init__(self, 318 | num_units, 319 | forget_bias=1.0, 320 | use_recurrent_dropout=False, 321 | dropout_keep_prob=0.90, 322 | use_layer_norm=True, 323 | hyper_num_units=256, 324 | hyper_embedding_size=32, 325 | hyper_use_recurrent_dropout=False): 326 | """Initialize the Layer Norm HyperLSTM cell. 327 | 328 | Args: 329 | num_units: int, The number of units in the LSTM cell. 330 | forget_bias: float, The bias added to forget gates (default 1.0). 331 | use_recurrent_dropout: Whether to use Recurrent Dropout (default False) 332 | dropout_keep_prob: float, dropout keep probability (default 0.90) 333 | use_layer_norm: boolean. (default True) 334 | Controls whether we use LayerNorm layers in main LSTM & HyperLSTM cell. 335 | hyper_num_units: int, number of units in HyperLSTM cell. 336 | (default is 128, recommend experimenting with 256 for larger tasks) 337 | hyper_embedding_size: int, size of signals emitted from HyperLSTM cell. 338 | (default is 16, recommend trying larger values for large datasets) 339 | hyper_use_recurrent_dropout: boolean. (default False) 340 | Controls whether HyperLSTM cell also uses recurrent dropout. 341 | Recommend turning this on only if hyper_num_units becomes large (>= 512) 342 | """ 343 | self.num_units = num_units 344 | self.forget_bias = forget_bias 345 | self.use_recurrent_dropout = use_recurrent_dropout 346 | self.dropout_keep_prob = dropout_keep_prob 347 | self.use_layer_norm = use_layer_norm 348 | self.hyper_num_units = hyper_num_units 349 | self.hyper_embedding_size = hyper_embedding_size 350 | self.hyper_use_recurrent_dropout = hyper_use_recurrent_dropout 351 | 352 | self.total_num_units = self.num_units + self.hyper_num_units 353 | 354 | if self.use_layer_norm: 355 | cell_fn = LayerNormLSTMCell 356 | else: 357 | cell_fn = LSTMCell 358 | self.hyper_cell = cell_fn( 359 | hyper_num_units, 360 | use_recurrent_dropout=hyper_use_recurrent_dropout, 361 | dropout_keep_prob=dropout_keep_prob) 362 | 363 | @property 364 | def input_size(self): 365 | return self._input_size 366 | 367 | @property 368 | def output_size(self): 369 | return self.num_units 370 | 371 | @property 372 | def state_size(self): 373 | return 2 * self.total_num_units 374 | 375 | def get_output(self, state): 376 | total_h, unused_total_c = tf.split(state, 2, 1) 377 | h = total_h[:, 0:self.num_units] 378 | return h 379 | 380 | def hyper_norm(self, layer, scope='hyper', use_bias=True): 381 | num_units = self.num_units 382 | embedding_size = self.hyper_embedding_size 383 | # recurrent batch norm init trick (https://arxiv.org/abs/1603.09025). 384 | init_gamma = 0.10 # cooijmans' da man. 
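    # hyper_norm() computes alpha * layer (+ beta): alpha and beta are linear
    # projections of small embeddings (zw, zb) predicted from the hyper cell's
    # output, which is the HyperNetworks weight-scaling trick.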
385 | with tf.variable_scope(scope): 386 | zw = super_linear( 387 | self.hyper_output, 388 | embedding_size, 389 | init_w='constant', 390 | weight_start=0.00, 391 | use_bias=True, 392 | bias_start=1.0, 393 | scope='zw') 394 | alpha = super_linear( 395 | zw, 396 | num_units, 397 | init_w='constant', 398 | weight_start=init_gamma / embedding_size, 399 | use_bias=False, 400 | scope='alpha') 401 | result = tf.multiply(alpha, layer) 402 | if use_bias: 403 | zb = super_linear( 404 | self.hyper_output, 405 | embedding_size, 406 | init_w='gaussian', 407 | weight_start=0.01, 408 | use_bias=False, 409 | bias_start=0.0, 410 | scope='zb') 411 | beta = super_linear( 412 | zb, 413 | num_units, 414 | init_w='constant', 415 | weight_start=0.00, 416 | use_bias=False, 417 | scope='beta') 418 | result += beta 419 | return result 420 | 421 | def __call__(self, x, state, timestep=0, scope=None): 422 | with tf.variable_scope(scope or type(self).__name__): 423 | total_h, total_c = tf.split(state, 2, 1) 424 | h = total_h[:, 0:self.num_units] 425 | c = total_c[:, 0:self.num_units] 426 | self.hyper_state = tf.concat( 427 | [total_h[:, self.num_units:], total_c[:, self.num_units:]], 1) 428 | 429 | batch_size = x.get_shape().as_list()[0] 430 | x_size = x.get_shape().as_list()[1] 431 | self._input_size = x_size 432 | 433 | w_init = None # uniform 434 | 435 | h_init = lstm_ortho_initializer(1.0) 436 | 437 | w_xh = tf.get_variable( 438 | 'W_xh', [x_size, 4 * self.num_units], initializer=w_init) 439 | w_hh = tf.get_variable( 440 | 'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init) 441 | bias = tf.get_variable( 442 | 'bias', [4 * self.num_units], 443 | initializer=tf.constant_initializer(0.0)) 444 | 445 | # concatenate the input and hidden states for hyperlstm input 446 | hyper_input = tf.concat([x, h], 1) 447 | hyper_output, hyper_new_state = self.hyper_cell(hyper_input, 448 | self.hyper_state) 449 | self.hyper_output = hyper_output 450 | self.hyper_state = hyper_new_state 451 | 452 | xh = tf.matmul(x, w_xh) 453 | hh = tf.matmul(h, w_hh) 454 | 455 | # split Wxh contributions 456 | ix, jx, fx, ox = tf.split(xh, 4, 1) 457 | ix = self.hyper_norm(ix, 'hyper_ix', use_bias=False) 458 | jx = self.hyper_norm(jx, 'hyper_jx', use_bias=False) 459 | fx = self.hyper_norm(fx, 'hyper_fx', use_bias=False) 460 | ox = self.hyper_norm(ox, 'hyper_ox', use_bias=False) 461 | 462 | # split Whh contributions 463 | ih, jh, fh, oh = tf.split(hh, 4, 1) 464 | ih = self.hyper_norm(ih, 'hyper_ih', use_bias=True) 465 | jh = self.hyper_norm(jh, 'hyper_jh', use_bias=True) 466 | fh = self.hyper_norm(fh, 'hyper_fh', use_bias=True) 467 | oh = self.hyper_norm(oh, 'hyper_oh', use_bias=True) 468 | 469 | # split bias 470 | ib, jb, fb, ob = tf.split(bias, 4, 0) # bias is to be broadcasted. 
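      # Each gate pre-activation below is assembled as hyper_norm(W_xh x) +
      # hyper_norm(W_hh h) + b: the hyper cell rescales the input and recurrent
      # contributions per unit before the usual LSTM nonlinearities are applied.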
471 |
472 |       # i = input_gate, j = new_input, f = forget_gate, o = output_gate
473 |       i = ix + ih + ib
474 |       j = jx + jh + jb
475 |       f = fx + fh + fb
476 |       o = ox + oh + ob
477 |
478 |       if self.use_layer_norm:
479 |         concat = tf.concat([i, j, f, o], 1)
480 |         concat = layer_norm_all(concat, batch_size, 4, self.num_units, 'ln_all')
481 |         i, j, f, o = tf.split(concat, 4, 1)
482 |
483 |       if self.use_recurrent_dropout:
484 |         g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
485 |       else:
486 |         g = tf.tanh(j)
487 |
488 |       new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
489 |       new_h = tf.tanh(layer_norm(new_c, self.num_units, 'ln_c')) * tf.sigmoid(o)
490 |
491 |       hyper_h, hyper_c = tf.split(hyper_new_state, 2, 1)
492 |       new_total_h = tf.concat([new_h, hyper_h], 1)
493 |       new_total_c = tf.concat([new_c, hyper_c], 1)
494 |       new_total_state = tf.concat([new_total_h, new_total_c], 1)
495 |       return new_h, new_total_state
496 |
--------------------------------------------------------------------------------
/svg/path/tests/test_paths.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import unittest
3 | from math import sqrt, pi
4 |
5 | from ..path import CubicBezier, QuadraticBezier, Line, Arc, Path
6 |
7 |
8 | # Most of these test points are not calculated separately, as that would
9 | # take too long and be too error-prone. Instead the curves have been verified
10 | # to be correct visually, by drawing them with the turtle module, with code
11 | # like this:
12 | #
13 | # import turtle
14 | # t = turtle.Turtle()
15 | # t.penup()
16 | #
17 | # for arc in (path1, path2):
18 | #     p = arc.point(0)
19 | #     t.goto(p.real - 500, -p.imag + 300)
20 | #     t.dot(3, 'black')
21 | #     t.pendown()
22 | #     for x in range(1, 101):
23 | #         p = arc.point(x * 0.01)
24 | #         t.goto(p.real - 500, -p.imag + 300)
25 | #     t.penup()
26 | #     t.dot(3, 'black')
27 | #
28 | # raw_input()
29 | #
30 | # After the paths have been verified to be correct this way, the testing of
31 | # points along the paths has been added as regression tests, to make sure
32 | # nobody changes the way curves are drawn by mistake. Therefore, do not take
33 | # these points religiously. They might be subtly wrong, unless otherwise
34 | # noted.
35 |
36 | class LineTest(unittest.TestCase):
37 |
38 |     def test_lines(self):
39 |         # These points are calculated, and not just regression tests.
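        # NOTE: Line.point(t) is plain linear interpolation over complex
        # coordinates, point(t) = start + t * (end - start), and length()
        # is abs(end - start). For example:
        #
        #   >>> Line(0j, 400 + 0j).point(0.3)
        #   (120+0j)
        #
        # line3 below runs along a 300-400-500 right triangle, hence its
        # length() is exactly 500.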
40 |
41 |         line1 = Line(0j, 400 + 0j)
42 |         self.assertAlmostEqual(line1.point(0), (0j))
43 |         self.assertAlmostEqual(line1.point(0.3), (120 + 0j))
44 |         self.assertAlmostEqual(line1.point(0.5), (200 + 0j))
45 |         self.assertAlmostEqual(line1.point(0.9), (360 + 0j))
46 |         self.assertAlmostEqual(line1.point(1), (400 + 0j))
47 |         self.assertAlmostEqual(line1.length(), 400)
48 |
49 |         line2 = Line(400 + 0j, 400 + 300j)
50 |         self.assertAlmostEqual(line2.point(0), (400 + 0j))
51 |         self.assertAlmostEqual(line2.point(0.3), (400 + 90j))
52 |         self.assertAlmostEqual(line2.point(0.5), (400 + 150j))
53 |         self.assertAlmostEqual(line2.point(0.9), (400 + 270j))
54 |         self.assertAlmostEqual(line2.point(1), (400 + 300j))
55 |         self.assertAlmostEqual(line2.length(), 300)
56 |
57 |         line3 = Line(400 + 300j, 0j)
58 |         self.assertAlmostEqual(line3.point(0), (400 + 300j))
59 |         self.assertAlmostEqual(line3.point(0.3), (280 + 210j))
60 |         self.assertAlmostEqual(line3.point(0.5), (200 + 150j))
61 |         self.assertAlmostEqual(line3.point(0.9), (40 + 30j))
62 |         self.assertAlmostEqual(line3.point(1), (0j))
63 |         self.assertAlmostEqual(line3.length(), 500)
64 |
65 |     def test_equality(self):
66 |         # This is to test the __eq__ and __ne__ methods, so we can't use
67 |         # assertEqual and assertNotEqual
68 |         line = Line(0j, 400 + 0j)
69 |         self.assertTrue(line == Line(0, 400))
70 |         self.assertTrue(line != Line(100, 400))
71 |         self.assertFalse(line == str(line))
72 |         self.assertTrue(line != str(line))
73 |         self.assertFalse(CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j) ==
74 |                          line)
75 |
76 |
77 | class CubicBezierTest(unittest.TestCase):
78 |     def test_approx_circle(self):
79 |         """This is an approximate circle drawn in Inkscape"""
80 |
81 |         arc1 = CubicBezier(
82 |             complex(0, 0),
83 |             complex(0, 109.66797),
84 |             complex(-88.90345, 198.57142),
85 |             complex(-198.57142, 198.57142)
86 |         )
87 |
88 |         self.assertAlmostEqual(arc1.point(0), (0j))
89 |         self.assertAlmostEqual(arc1.point(0.1), (-2.59896457 + 32.20931647j))
90 |         self.assertAlmostEqual(arc1.point(0.2), (-10.12330256 + 62.76392816j))
91 |         self.assertAlmostEqual(arc1.point(0.3), (-22.16418039 + 91.25500149j))
92 |         self.assertAlmostEqual(arc1.point(0.4), (-38.31276448 + 117.27370288j))
93 |         self.assertAlmostEqual(arc1.point(0.5), (-58.16022125 + 140.41119875j))
94 |         self.assertAlmostEqual(arc1.point(0.6), (-81.29771712 + 160.25865552j))
95 |         self.assertAlmostEqual(arc1.point(0.7), (-107.31641851 + 176.40723961j))
96 |         self.assertAlmostEqual(arc1.point(0.8), (-135.80749184 + 188.44811744j))
97 |         self.assertAlmostEqual(arc1.point(0.9), (-166.36210353 + 195.97245543j))
98 |         self.assertAlmostEqual(arc1.point(1), (-198.57142 + 198.57142j))
99 |
100 |         arc2 = CubicBezier(
101 |             complex(-198.57142, 198.57142),
102 |             complex(-109.66797 - 198.57142, 0 + 198.57142),
103 |             complex(-198.57143 - 198.57142, -88.90345 + 198.57142),
104 |             complex(-198.57143 - 198.57142, 0),
105 |         )
106 |
107 |         self.assertAlmostEqual(arc2.point(0), (-198.57142 + 198.57142j))
108 |         self.assertAlmostEqual(arc2.point(0.1), (-230.78073675 + 195.97245543j))
109 |         self.assertAlmostEqual(arc2.point(0.2), (-261.3353492 + 188.44811744j))
110 |         self.assertAlmostEqual(arc2.point(0.3), (-289.82642365 + 176.40723961j))
111 |         self.assertAlmostEqual(arc2.point(0.4), (-315.8451264 + 160.25865552j))
112 |         self.assertAlmostEqual(arc2.point(0.5), (-338.98262375 + 140.41119875j))
113 |         self.assertAlmostEqual(arc2.point(0.6), (-358.830082 + 117.27370288j))
114 |         self.assertAlmostEqual(arc2.point(0.7), (-374.97866745 + 91.25500149j))
115 | 
self.assertAlmostEqual(arc2.point(0.8), (-387.0195464 + 62.76392816j)) 116 | self.assertAlmostEqual(arc2.point(0.9), (-394.54388515 + 32.20931647j)) 117 | self.assertAlmostEqual(arc2.point(1), (-397.14285 + 0j)) 118 | 119 | arc3 = CubicBezier( 120 | complex(-198.57143 - 198.57142, 0), 121 | complex(0 - 198.57143 - 198.57142, -109.66797), 122 | complex(88.90346 - 198.57143 - 198.57142, -198.57143), 123 | complex(-198.57142, -198.57143) 124 | ) 125 | 126 | self.assertAlmostEqual(arc3.point(0), (-397.14285 + 0j)) 127 | self.assertAlmostEqual(arc3.point(0.1), (-394.54388515 - 32.20931675j)) 128 | self.assertAlmostEqual(arc3.point(0.2), (-387.0195464 - 62.7639292j)) 129 | self.assertAlmostEqual(arc3.point(0.3), (-374.97866745 - 91.25500365j)) 130 | self.assertAlmostEqual(arc3.point(0.4), (-358.830082 - 117.2737064j)) 131 | self.assertAlmostEqual(arc3.point(0.5), (-338.98262375 - 140.41120375j)) 132 | self.assertAlmostEqual(arc3.point(0.6), (-315.8451264 - 160.258662j)) 133 | self.assertAlmostEqual(arc3.point(0.7), (-289.82642365 - 176.40724745j)) 134 | self.assertAlmostEqual(arc3.point(0.8), (-261.3353492 - 188.4481264j)) 135 | self.assertAlmostEqual(arc3.point(0.9), (-230.78073675 - 195.97246515j)) 136 | self.assertAlmostEqual(arc3.point(1), (-198.57142 - 198.57143j)) 137 | 138 | arc4 = CubicBezier( 139 | complex(-198.57142, -198.57143), 140 | complex(109.66797 - 198.57142, 0 - 198.57143), 141 | complex(0, 88.90346 - 198.57143), 142 | complex(0, 0), 143 | ) 144 | 145 | self.assertAlmostEqual(arc4.point(0), (-198.57142 - 198.57143j)) 146 | self.assertAlmostEqual(arc4.point(0.1), (-166.36210353 - 195.97246515j)) 147 | self.assertAlmostEqual(arc4.point(0.2), (-135.80749184 - 188.4481264j)) 148 | self.assertAlmostEqual(arc4.point(0.3), (-107.31641851 - 176.40724745j)) 149 | self.assertAlmostEqual(arc4.point(0.4), (-81.29771712 - 160.258662j)) 150 | self.assertAlmostEqual(arc4.point(0.5), (-58.16022125 - 140.41120375j)) 151 | self.assertAlmostEqual(arc4.point(0.6), (-38.31276448 - 117.2737064j)) 152 | self.assertAlmostEqual(arc4.point(0.7), (-22.16418039 - 91.25500365j)) 153 | self.assertAlmostEqual(arc4.point(0.8), (-10.12330256 - 62.7639292j)) 154 | self.assertAlmostEqual(arc4.point(0.9), (-2.59896457 - 32.20931675j)) 155 | self.assertAlmostEqual(arc4.point(1), (0j)) 156 | 157 | def test_svg_examples(self): 158 | 159 | # M100,200 C100,100 250,100 250,200 160 | path1 = CubicBezier(100 + 200j, 100 + 100j, 250 + 100j, 250 + 200j) 161 | self.assertAlmostEqual(path1.point(0), (100 + 200j)) 162 | self.assertAlmostEqual(path1.point(0.3), (132.4 + 137j)) 163 | self.assertAlmostEqual(path1.point(0.5), (175 + 125j)) 164 | self.assertAlmostEqual(path1.point(0.9), (245.8 + 173j)) 165 | self.assertAlmostEqual(path1.point(1), (250 + 200j)) 166 | 167 | # S400,300 400,200 168 | path2 = CubicBezier(250 + 200j, 250 + 300j, 400 + 300j, 400 + 200j) 169 | self.assertAlmostEqual(path2.point(0), (250 + 200j)) 170 | self.assertAlmostEqual(path2.point(0.3), (282.4 + 263j)) 171 | self.assertAlmostEqual(path2.point(0.5), (325 + 275j)) 172 | self.assertAlmostEqual(path2.point(0.9), (395.8 + 227j)) 173 | self.assertAlmostEqual(path2.point(1), (400 + 200j)) 174 | 175 | # M100,200 C100,100 400,100 400,200 176 | path3 = CubicBezier(100 + 200j, 100 + 100j, 400 + 100j, 400 + 200j) 177 | self.assertAlmostEqual(path3.point(0), (100 + 200j)) 178 | self.assertAlmostEqual(path3.point(0.3), (164.8 + 137j)) 179 | self.assertAlmostEqual(path3.point(0.5), (250 + 125j)) 180 | self.assertAlmostEqual(path3.point(0.9), (391.6 + 173j)) 181 | 
self.assertAlmostEqual(path3.point(1), (400 + 200j)) 182 | 183 | # M100,500 C25,400 475,400 400,500 184 | path4 = CubicBezier(100 + 500j, 25 + 400j, 475 + 400j, 400 + 500j) 185 | self.assertAlmostEqual(path4.point(0), (100 + 500j)) 186 | self.assertAlmostEqual(path4.point(0.3), (145.9 + 437j)) 187 | self.assertAlmostEqual(path4.point(0.5), (250 + 425j)) 188 | self.assertAlmostEqual(path4.point(0.9), (407.8 + 473j)) 189 | self.assertAlmostEqual(path4.point(1), (400 + 500j)) 190 | 191 | # M100,800 C175,700 325,700 400,800 192 | path5 = CubicBezier(100 + 800j, 175 + 700j, 325 + 700j, 400 + 800j) 193 | self.assertAlmostEqual(path5.point(0), (100 + 800j)) 194 | self.assertAlmostEqual(path5.point(0.3), (183.7 + 737j)) 195 | self.assertAlmostEqual(path5.point(0.5), (250 + 725j)) 196 | self.assertAlmostEqual(path5.point(0.9), (375.4 + 773j)) 197 | self.assertAlmostEqual(path5.point(1), (400 + 800j)) 198 | 199 | # M600,200 C675,100 975,100 900,200 200 | path6 = CubicBezier(600 + 200j, 675 + 100j, 975 + 100j, 900 + 200j) 201 | self.assertAlmostEqual(path6.point(0), (600 + 200j)) 202 | self.assertAlmostEqual(path6.point(0.3), (712.05 + 137j)) 203 | self.assertAlmostEqual(path6.point(0.5), (806.25 + 125j)) 204 | self.assertAlmostEqual(path6.point(0.9), (911.85 + 173j)) 205 | self.assertAlmostEqual(path6.point(1), (900 + 200j)) 206 | 207 | # M600,500 C600,350 900,650 900,500 208 | path7 = CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j) 209 | self.assertAlmostEqual(path7.point(0), (600 + 500j)) 210 | self.assertAlmostEqual(path7.point(0.3), (664.8 + 462.2j)) 211 | self.assertAlmostEqual(path7.point(0.5), (750 + 500j)) 212 | self.assertAlmostEqual(path7.point(0.9), (891.6 + 532.4j)) 213 | self.assertAlmostEqual(path7.point(1), (900 + 500j)) 214 | 215 | # M600,800 C625,700 725,700 750,800 216 | path8 = CubicBezier(600 + 800j, 625 + 700j, 725 + 700j, 750 + 800j) 217 | self.assertAlmostEqual(path8.point(0), (600 + 800j)) 218 | self.assertAlmostEqual(path8.point(0.3), (638.7 + 737j)) 219 | self.assertAlmostEqual(path8.point(0.5), (675 + 725j)) 220 | self.assertAlmostEqual(path8.point(0.9), (740.4 + 773j)) 221 | self.assertAlmostEqual(path8.point(1), (750 + 800j)) 222 | 223 | # S875,900 900,800 224 | inversion = (750 + 800j) + (750 + 800j) - (725 + 700j) 225 | path9 = CubicBezier(750 + 800j, inversion, 875 + 900j, 900 + 800j) 226 | self.assertAlmostEqual(path9.point(0), (750 + 800j)) 227 | self.assertAlmostEqual(path9.point(0.3), (788.7 + 863j)) 228 | self.assertAlmostEqual(path9.point(0.5), (825 + 875j)) 229 | self.assertAlmostEqual(path9.point(0.9), (890.4 + 827j)) 230 | self.assertAlmostEqual(path9.point(1), (900 + 800j)) 231 | 232 | def test_length(self): 233 | 234 | # A straight line: 235 | arc = CubicBezier( 236 | complex(0, 0), 237 | complex(0, 0), 238 | complex(0, 100), 239 | complex(0, 100) 240 | ) 241 | 242 | self.assertAlmostEqual(arc.length(), 100) 243 | 244 | # A diagonal line: 245 | arc = CubicBezier( 246 | complex(0, 0), 247 | complex(0, 0), 248 | complex(100, 100), 249 | complex(100, 100) 250 | ) 251 | 252 | self.assertAlmostEqual(arc.length(), sqrt(2 * 100 * 100)) 253 | 254 | # A quarter circle arc with radius 100: 255 | kappa = 4 * (sqrt(2) - 1) / 3 # http://www.whizkidtech.redprince.net/bezier/circle/ 256 | 257 | arc = CubicBezier( 258 | complex(0, 0), 259 | complex(0, kappa * 100), 260 | complex(100 - kappa * 100, 100), 261 | complex(100, 100) 262 | ) 263 | 264 | # We can't compare with pi*50 here, because this is just an 265 | # approximation of a circle arc. 
pi*50 is 157.079632679
266 |         # So this is just yet another "warn if this changes" test.
267 |         # This value is not verified to be correct.
268 |         self.assertAlmostEqual(arc.length(), 157.1016698)
269 |
270 |         # A recursive solution has also been suggested, but for CubicBezier
271 |         # curves it could get a false solution on curves where the midpoint is on a
272 |         # straight line between the start and end. For example, the following
273 |         # curve would get solved as a straight line and get the length 300.
274 |         # Make sure this is not the case.
275 |         arc = CubicBezier(
276 |             complex(600, 500),
277 |             complex(600, 350),
278 |             complex(900, 650),
279 |             complex(900, 500)
280 |         )
281 |         self.assertTrue(arc.length() > 300.0)
282 |
283 |     def test_equality(self):
284 |         # This is to test the __eq__ and __ne__ methods, so we can't use
285 |         # assertEqual and assertNotEqual
286 |         segment = CubicBezier(complex(600, 500), complex(600, 350),
287 |                               complex(900, 650), complex(900, 500))
288 |
289 |         self.assertTrue(segment ==
290 |                         CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j))
291 |         self.assertTrue(segment !=
292 |                         CubicBezier(600 + 501j, 600 + 350j, 900 + 650j, 900 + 500j))
293 |         self.assertTrue(segment != Line(0, 400))
294 |
295 |
296 | class QuadraticBezierTest(unittest.TestCase):
297 |
298 |     def test_svg_examples(self):
299 |         """These are the paths from the SVG spec"""
300 |         # M200,300 Q400,50 600,300 T1000,300
301 |         path1 = QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j)
302 |         self.assertAlmostEqual(path1.point(0), (200 + 300j))
303 |         self.assertAlmostEqual(path1.point(0.3), (320 + 195j))
304 |         self.assertAlmostEqual(path1.point(0.5), (400 + 175j))
305 |         self.assertAlmostEqual(path1.point(0.9), (560 + 255j))
306 |         self.assertAlmostEqual(path1.point(1), (600 + 300j))
307 |
308 |         # T1000, 300
309 |         inversion = (600 + 300j) + (600 + 300j) - (400 + 50j)
310 |         path2 = QuadraticBezier(600 + 300j, inversion, 1000 + 300j)
311 |         self.assertAlmostEqual(path2.point(0), (600 + 300j))
312 |         self.assertAlmostEqual(path2.point(0.3), (720 + 405j))
313 |         self.assertAlmostEqual(path2.point(0.5), (800 + 425j))
314 |         self.assertAlmostEqual(path2.point(0.9), (960 + 345j))
315 |         self.assertAlmostEqual(path2.point(1), (1000 + 300j))
316 |
317 |     def test_length(self):
318 |         # expected results calculated with
319 |         # svg.path.segment_length(q, 0, 1, q.start, q.end, 1e-14, 20, 0)
320 |         q1 = QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j)
321 |         q2 = QuadraticBezier(200 + 300j, 400 + 50j, 500 + 200j)
322 |         closedq = QuadraticBezier(6+2j, 5-1j, 6+2j)
323 |         linq1 = QuadraticBezier(1, 2, 3)
324 |         linq2 = QuadraticBezier(1+3j, 2+5j, -9 - 17j)
325 |         nodalq = QuadraticBezier(1, 1, 1)
326 |         tests = [(q1, 487.77109389525975),
327 |                  (q2, 379.90458193489155),
328 |                  (closedq, 3.1622776601683795),
329 |                  (linq1, 2),
330 |                  (linq2, 22.73335777124786),
331 |                  (nodalq, 0)]
332 |         for q, exp_res in tests:
333 |             self.assertAlmostEqual(q.length(), exp_res)
334 |
335 |     def test_equality(self):
336 |         # This is to test the __eq__ and __ne__ methods, so we can't use
337 |         # assertEqual and assertNotEqual
338 |         segment = QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j)
339 |         self.assertTrue(segment == QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j))
340 |         self.assertTrue(segment != QuadraticBezier(200 + 301j, 400 + 50j, 600 + 300j))
341 |         self.assertFalse(segment == Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j))
342 |         self.assertTrue(Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j) != segment)
343 |
344 |
345 | class ArcTest(unittest.TestCase):
346 |
347 | 
def test_points(self): 348 | arc1 = Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j) 349 | self.assertAlmostEqual(arc1.center, 100 + 0j) 350 | self.assertAlmostEqual(arc1.theta, 180.0) 351 | self.assertAlmostEqual(arc1.delta, -90.0) 352 | 353 | self.assertAlmostEqual(arc1.point(0.0), (0j)) 354 | self.assertAlmostEqual(arc1.point(0.1), (1.23116594049 + 7.82172325201j)) 355 | self.assertAlmostEqual(arc1.point(0.2), (4.89434837048 + 15.4508497187j)) 356 | self.assertAlmostEqual(arc1.point(0.3), (10.8993475812 + 22.699524987j)) 357 | self.assertAlmostEqual(arc1.point(0.4), (19.0983005625 + 29.3892626146j)) 358 | self.assertAlmostEqual(arc1.point(0.5), (29.2893218813 + 35.3553390593j)) 359 | self.assertAlmostEqual(arc1.point(0.6), (41.2214747708 + 40.4508497187j)) 360 | self.assertAlmostEqual(arc1.point(0.7), (54.6009500260 + 44.5503262094j)) 361 | self.assertAlmostEqual(arc1.point(0.8), (69.0983005625 + 47.5528258148j)) 362 | self.assertAlmostEqual(arc1.point(0.9), (84.3565534960 + 49.3844170298j)) 363 | self.assertAlmostEqual(arc1.point(1.0), (100 + 50j)) 364 | 365 | arc2 = Arc(0j, 100 + 50j, 0, 1, 0, 100 + 50j) 366 | self.assertAlmostEqual(arc2.center, 50j) 367 | self.assertAlmostEqual(arc2.theta, 270.0) 368 | self.assertAlmostEqual(arc2.delta, -270.0) 369 | 370 | self.assertAlmostEqual(arc2.point(0.0), (0j)) 371 | self.assertAlmostEqual(arc2.point(0.1), (-45.399049974 + 5.44967379058j)) 372 | self.assertAlmostEqual(arc2.point(0.2), (-80.9016994375 + 20.6107373854j)) 373 | self.assertAlmostEqual(arc2.point(0.3), (-98.7688340595 + 42.178276748j)) 374 | self.assertAlmostEqual(arc2.point(0.4), (-95.1056516295 + 65.4508497187j)) 375 | self.assertAlmostEqual(arc2.point(0.5), (-70.7106781187 + 85.3553390593j)) 376 | self.assertAlmostEqual(arc2.point(0.6), (-30.9016994375 + 97.5528258148j)) 377 | self.assertAlmostEqual(arc2.point(0.7), (15.643446504 + 99.3844170298j)) 378 | self.assertAlmostEqual(arc2.point(0.8), (58.7785252292 + 90.4508497187j)) 379 | self.assertAlmostEqual(arc2.point(0.9), (89.1006524188 + 72.699524987j)) 380 | self.assertAlmostEqual(arc2.point(1.0), (100 + 50j)) 381 | 382 | arc3 = Arc(0j, 100 + 50j, 0, 0, 1, 100 + 50j) 383 | self.assertAlmostEqual(arc3.center, 50j) 384 | self.assertAlmostEqual(arc3.theta, 270.0) 385 | self.assertAlmostEqual(arc3.delta, 90.0) 386 | 387 | self.assertAlmostEqual(arc3.point(0.0), (0j)) 388 | self.assertAlmostEqual(arc3.point(0.1), (15.643446504 + 0.615582970243j)) 389 | self.assertAlmostEqual(arc3.point(0.2), (30.9016994375 + 2.44717418524j)) 390 | self.assertAlmostEqual(arc3.point(0.3), (45.399049974 + 5.44967379058j)) 391 | self.assertAlmostEqual(arc3.point(0.4), (58.7785252292 + 9.54915028125j)) 392 | self.assertAlmostEqual(arc3.point(0.5), (70.7106781187 + 14.6446609407j)) 393 | self.assertAlmostEqual(arc3.point(0.6), (80.9016994375 + 20.6107373854j)) 394 | self.assertAlmostEqual(arc3.point(0.7), (89.1006524188 + 27.300475013j)) 395 | self.assertAlmostEqual(arc3.point(0.8), (95.1056516295 + 34.5491502813j)) 396 | self.assertAlmostEqual(arc3.point(0.9), (98.7688340595 + 42.178276748j)) 397 | self.assertAlmostEqual(arc3.point(1.0), (100 + 50j)) 398 | 399 | arc4 = Arc(0j, 100 + 50j, 0, 1, 1, 100 + 50j) 400 | self.assertAlmostEqual(arc4.center, 100 + 0j) 401 | self.assertAlmostEqual(arc4.theta, 180.0) 402 | self.assertAlmostEqual(arc4.delta, 270.0) 403 | 404 | self.assertAlmostEqual(arc4.point(0.0), (0j)) 405 | self.assertAlmostEqual(arc4.point(0.1), (10.8993475812 - 22.699524987j)) 406 | self.assertAlmostEqual(arc4.point(0.2), (41.2214747708 - 
40.4508497187j)) 407 | self.assertAlmostEqual(arc4.point(0.3), (84.3565534960 - 49.3844170298j)) 408 | self.assertAlmostEqual(arc4.point(0.4), (130.901699437 - 47.5528258148j)) 409 | self.assertAlmostEqual(arc4.point(0.5), (170.710678119 - 35.3553390593j)) 410 | self.assertAlmostEqual(arc4.point(0.6), (195.105651630 - 15.4508497187j)) 411 | self.assertAlmostEqual(arc4.point(0.7), (198.768834060 + 7.82172325201j)) 412 | self.assertAlmostEqual(arc4.point(0.8), (180.901699437 + 29.3892626146j)) 413 | self.assertAlmostEqual(arc4.point(0.9), (145.399049974 + 44.5503262094j)) 414 | self.assertAlmostEqual(arc4.point(1.0), (100 + 50j)) 415 | 416 | def test_length(self): 417 | # I'll test the length calculations by making a circle, in two parts. 418 | arc1 = Arc(0j, 100 + 100j, 0, 0, 0, 200 + 0j) 419 | arc2 = Arc(200 + 0j, 100 + 100j, 0, 0, 0, 0j) 420 | self.assertAlmostEqual(arc1.length(), pi * 100) 421 | self.assertAlmostEqual(arc2.length(), pi * 100) 422 | 423 | def test_equality(self): 424 | # This is to test the __eq__ and __ne__ methods, so we can't use 425 | # assertEqual and assertNotEqual 426 | segment = Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j) 427 | self.assertTrue(segment == Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j)) 428 | self.assertTrue(segment != Arc(0j, 100 + 50j, 0, 1, 0, 100 + 50j)) 429 | 430 | 431 | class TestPath(unittest.TestCase): 432 | 433 | def test_circle(self): 434 | arc1 = Arc(0j, 100 + 100j, 0, 0, 0, 200 + 0j) 435 | arc2 = Arc(200 + 0j, 100 + 100j, 0, 0, 0, 0j) 436 | path = Path(arc1, arc2) 437 | self.assertAlmostEqual(path.point(0.0), (0j)) 438 | self.assertAlmostEqual(path.point(0.25), (100 + 100j)) 439 | self.assertAlmostEqual(path.point(0.5), (200 + 0j)) 440 | self.assertAlmostEqual(path.point(0.75), (100 - 100j)) 441 | self.assertAlmostEqual(path.point(1.0), (0j)) 442 | self.assertAlmostEqual(path.length(), pi * 200) 443 | 444 | def test_svg_specs(self): 445 | """The paths that are in the SVG specs""" 446 | 447 | # Big pie: M300,200 h-150 a150,150 0 1,0 150,-150 z 448 | path = Path(Line(300 + 200j, 150 + 200j), 449 | Arc(150 + 200j, 150 + 150j, 0, 1, 0, 300 + 50j), 450 | Line(300 + 50j, 300 + 200j)) 451 | # The points and length for this path are calculated and not regression tests. 452 | self.assertAlmostEqual(path.point(0.0), (300 + 200j)) 453 | self.assertAlmostEqual(path.point(0.14897825542), (150 + 200j)) 454 | self.assertAlmostEqual(path.point(0.5), (406.066017177 + 306.066017177j)) 455 | self.assertAlmostEqual(path.point(1 - 0.14897825542), (300 + 50j)) 456 | self.assertAlmostEqual(path.point(1.0), (300 + 200j)) 457 | # The errors seem to accumulate. Still 6 decimal places is more than good enough. 458 | self.assertAlmostEqual(path.length(), pi * 225 + 300, places=6) 459 | 460 | # Little pie: M275,175 v-150 a150,150 0 0,0 -150,150 z 461 | path = Path(Line(275 + 175j, 275 + 25j), 462 | Arc(275 + 25j, 150 + 150j, 0, 0, 0, 125 + 175j), 463 | Line(125 + 175j, 275 + 175j)) 464 | # The points and length for this path are calculated and not regression tests. 465 | self.assertAlmostEqual(path.point(0.0), (275 + 175j)) 466 | self.assertAlmostEqual(path.point(0.2800495767557787), (275 + 25j)) 467 | self.assertAlmostEqual(path.point(0.5), (168.93398282201787 + 68.93398282201787j)) 468 | self.assertAlmostEqual(path.point(1 - 0.2800495767557787), (125 + 175j)) 469 | self.assertAlmostEqual(path.point(1.0), (275 + 175j)) 470 | # The errors seem to accumulate. Still 6 decimal places is more than good enough. 
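        # NOTE: this expected value is exact: the two straight edges of the
        # little pie contribute 150 + 150 = 300, and the arc is a quarter of
        # a circle of radius 150, i.e. 2 * pi * 150 / 4 = pi * 75.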
471 | self.assertAlmostEqual(path.length(), pi * 75 + 300, places=6) 472 | 473 | # Bumpy path: M600,350 l 50,-25 474 | # a25,25 -30 0,1 50,-25 l 50,-25 475 | # a25,50 -30 0,1 50,-25 l 50,-25 476 | # a25,75 -30 0,1 50,-25 l 50,-25 477 | # a25,100 -30 0,1 50,-25 l 50,-25 478 | path = Path(Line(600 + 350j, 650 + 325j), 479 | Arc(650 + 325j, 25 + 25j, -30, 0, 1, 700 + 300j), 480 | Line(700 + 300j, 750 + 275j), 481 | Arc(750 + 275j, 25 + 50j, -30, 0, 1, 800 + 250j), 482 | Line(800 + 250j, 850 + 225j), 483 | Arc(850 + 225j, 25 + 75j, -30, 0, 1, 900 + 200j), 484 | Line(900 + 200j, 950 + 175j), 485 | Arc(950 + 175j, 25 + 100j, -30, 0, 1, 1000 + 150j), 486 | Line(1000 + 150j, 1050 + 125j), 487 | ) 488 | # These are *not* calculated, but just regression tests. Be skeptical. 489 | self.assertAlmostEqual(path.point(0.0), (600 + 350j)) 490 | self.assertAlmostEqual(path.point(0.3), (755.31526434 + 217.51578768j)) 491 | self.assertAlmostEqual(path.point(0.5), (832.23324151 + 156.33454892j)) 492 | self.assertAlmostEqual(path.point(0.9), (974.00559321 + 115.26473532j)) 493 | self.assertAlmostEqual(path.point(1.0), (1050 + 125j)) 494 | # The errors seem to accumulate. Still 6 decimal places is more than good enough. 495 | self.assertAlmostEqual(path.length(), 860.6756221710) 496 | 497 | def test_repr(self): 498 | path = Path( 499 | Line(start=600 + 350j, end=650 + 325j), 500 | Arc(start=650 + 325j, radius=25 + 25j, rotation=-30, arc=0, sweep=1, end=700 + 300j), 501 | CubicBezier(start=700 + 300j, control1=800 + 400j, control2=750 + 200j, end=600 + 100j), 502 | QuadraticBezier(start=600 + 100j, control=600, end=600 + 300j)) 503 | self.assertEqual(eval(repr(path)), path) 504 | 505 | def test_reverse(self): 506 | # Currently you can't reverse paths. 507 | self.assertRaises(NotImplementedError, Path().reverse) 508 | 509 | def test_equality(self): 510 | # This is to test the __eq__ and __ne__ methods, so we can't use 511 | # assertEqual and assertNotEqual 512 | path1 = Path( 513 | Line(start=600 + 350j, end=650 + 325j), 514 | Arc(start=650 + 325j, radius=25 + 25j, rotation=-30, arc=0, sweep=1, end=700 + 300j), 515 | CubicBezier(start=700 + 300j, control1=800 + 400j, control2=750 + 200j, end=600 + 100j), 516 | QuadraticBezier(start=600 + 100j, control=600, end=600 + 300j)) 517 | path2 = Path( 518 | Line(start=600 + 350j, end=650 + 325j), 519 | Arc(start=650 + 325j, radius=25 + 25j, rotation=-30, arc=0, sweep=1, end=700 + 300j), 520 | CubicBezier(start=700 + 300j, control1=800 + 400j, control2=750 + 200j, end=600 + 100j), 521 | QuadraticBezier(start=600 + 100j, control=600, end=600 + 300j)) 522 | 523 | self.assertTrue(path1 == path2) 524 | # Modify path2: 525 | path2[0].start = 601 + 350j 526 | self.assertTrue(path1 != path2) 527 | 528 | # Modify back: 529 | path2[0].start = 600 + 350j 530 | self.assertFalse(path1 != path2) 531 | 532 | # Get rid of the last segment: 533 | del path2[-1] 534 | self.assertFalse(path1 == path2) 535 | 536 | # It's not equal to a list of it's segments 537 | self.assertTrue(path1 != path1[:]) 538 | self.assertFalse(path1 == path1[:]) 539 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Sketch-RNN Model."""
15 |
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import os
21 | import random
22 |
23 | # internal imports
24 |
25 | import numpy as np
26 | import tensorflow as tf
27 | import tensorflow.contrib.slim as slim
28 | try:
29 |   from magenta.models.sketch_rnn import rnn
30 | except ImportError:
31 |   import magenta_rnn as rnn
32 |
33 | from build_subnet import *
34 | from tf_data_work import *
35 |
36 |
37 | def copy_hparams(hparams):
38 |   """Return a copy of an HParams instance."""
39 |   return tf.contrib.training.HParams(**hparams.values())
40 |
41 |
42 | class Model(object):
43 |   """Define a SketchRNN model."""
44 |
45 |   def __init__(self, hps, reuse=False):
46 |     """Initializer for the SketchRNN model.
47 |
48 |     Args:
49 |       hps: a HParams object containing model hyperparameters.
50 |       reuse: a boolean that when true, attempts to reuse
51 |         variables.
52 |     """
53 |     self.hps = hps
54 |     with tf.variable_scope('vector_rnn', reuse=reuse):
55 |       self.build_model(hps)
56 |
57 |   def encoder(self, input_batch, sequence_lengths, reuse):
58 |     if self.hps.enc_type == 'rnn':  # vae mode:
59 |       image_embeddings = self.rnn_encoder(input_batch, sequence_lengths)
60 |     elif self.hps.enc_type == 'cnn':
61 |       image_embeddings = self.cnn_encoder(input_batch, reuse)
62 |     elif self.hps.enc_type == 'feat':
63 |       image_embeddings = input_batch
64 |     else:
65 |       raise Exception('Please choose a valid encoder type')
66 |     return image_embeddings
67 |
68 |   def rnn_encoder(self, batch, sequence_lengths):
69 |
70 |     if self.hps.rnn_model == 'lstm':
71 |       enc_cell_fn = rnn.LSTMCell
72 |     elif self.hps.rnn_model == 'layer_norm':
73 |       enc_cell_fn = rnn.LayerNormLSTMCell
74 |     elif self.hps.rnn_model == 'hyper':
75 |       enc_cell_fn = rnn.HyperLSTMCell
76 |     else:
77 |       assert False, 'please choose a respectable cell'
78 |
79 |     if self.hps.rnn_model == 'hyper':
80 |       self.enc_cell_fw = enc_cell_fn(
81 |           self.hps.enc_rnn_size,
82 |           use_recurrent_dropout=self.hps.use_recurrent_dropout,
83 |           dropout_keep_prob=self.hps.recurrent_dropout_prob)
84 |       self.enc_cell_bw = enc_cell_fn(
85 |           self.hps.enc_rnn_size,
86 |           use_recurrent_dropout=self.hps.use_recurrent_dropout,
87 |           dropout_keep_prob=self.hps.recurrent_dropout_prob)
88 |     else:
89 |       self.enc_cell_fw = enc_cell_fn(
90 |           self.hps.enc_rnn_size,
91 |           use_recurrent_dropout=self.hps.use_recurrent_dropout,
92 |           dropout_keep_prob=self.hps.recurrent_dropout_prob)
93 |       self.enc_cell_bw = enc_cell_fn(
94 |           self.hps.enc_rnn_size,
95 |           use_recurrent_dropout=self.hps.use_recurrent_dropout,
96 |           dropout_keep_prob=self.hps.recurrent_dropout_prob)
97 |
98 |     # Define the bi-directional encoder module of sketch-rnn.
99 |     unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
100 |         self.enc_cell_fw,
101 |         self.enc_cell_bw,
102 |         batch,
103 |         sequence_length=sequence_lengths,
104 |         time_major=False,
105 |         swap_memory=True,
106 |         dtype=tf.float32,
107 |         scope='ENC_RNN')
108 |
109 |     last_state_fw, last_state_bw = 
last_states 110 | last_h_fw = self.enc_cell_fw.get_output(last_state_fw) 111 | last_h_bw = self.enc_cell_bw.get_output(last_state_bw) 112 | last_h = tf.concat([last_h_fw, last_h_bw], 1) 113 | return last_h 114 | 115 | def cnn_encoder(self, batch_input, reuse): 116 | if self.hps.is_train: 117 | is_train = True 118 | dropout_keep_prob = self.hps.drop_kp 119 | else: 120 | is_train = False 121 | dropout_keep_prob = 1.0 122 | tf_batch_input = tf_image_processing(batch_input, self.hps.basenet, self.hps.crop_size, self.hps.dist_aug, self.hps.hp_filter) 123 | self.tf_images = tf_batch_input 124 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 125 | if self.hps.basenet == 'sketchanet': 126 | feature = sketch_a_net_slim(tf_batch_input) 127 | 128 | elif self.hps.basenet == 'gen_cnn': 129 | # feature = generative_cnn_encoder(tf_batch_input, is_train, dropout_keep_prob, reuse=reuse) 130 | feature = generative_cnn_encoder(tf_batch_input, True, dropout_keep_prob, reuse=reuse) 131 | 132 | elif FLAGS.basenet == 'alexnet': 133 | feature, end_points = tf_alexnet_single(tf_batch_input, dropout_keep_prob) 134 | 135 | elif FLAGS.basenet == 'vgg': 136 | _, feature = build_single_vggnet(tf_batch_input, is_train, dropout_keep_prob) 137 | 138 | elif FLAGS.basenet == 'resnet': 139 | print('Warning, resnet scope is not set') 140 | _, feature = build_single_resnet(tf_batch_input, is_train, name_scope='resnet_v1_50') 141 | 142 | elif FLAGS.basenet == 'inceptionv1': 143 | # _, feature = build_single_inceptionv1(tf_batch_input, is_train, dropout_keep_prob) 144 | _, feature = build_single_inceptionv1(tf_batch_input, True, dropout_keep_prob) 145 | # _, feature = build_single_inceptionv1(tf_batch_input, False, dropout_keep_prob) 146 | 147 | elif FLAGS.basenet == 'inceptionv3': 148 | # _, feature = build_single_inceptionv3(batch_input, is_train, dropout_keep_prob, reduce_dim=False) 149 | _, feature = build_single_inceptionv3(tf_batch_input, True, dropout_keep_prob, reduce_dim=False) 150 | # _, feature = build_single_inceptionv3(tf_batch_input, False, dropout_keep_prob, reduce_dim=False) 151 | 152 | else: 153 | raise Exception('basenet error') 154 | return feature 155 | 156 | def decoder(self, actual_input_x, initial_state, reuse): 157 | 158 | # decoder module of sketch-rnn is below 159 | with tf.variable_scope("RNN", reuse=reuse) as rnn_scope: 160 | output, last_state = tf.nn.dynamic_rnn( 161 | self.cell, 162 | actual_input_x, 163 | initial_state=initial_state, 164 | time_major=False, 165 | swap_memory=True, 166 | dtype=tf.float32, 167 | scope=rnn_scope) 168 | return output, last_state 169 | 170 | def cnn_decoder(self, z_input, reuse): 171 | 172 | if self.hps.is_train: 173 | is_train = True 174 | dropout_keep_prob = self.hps.drop_kp 175 | else: 176 | is_train = False 177 | dropout_keep_prob = 1.0 178 | 179 | output = generative_cnn_decoder(z_input, is_train, dropout_keep_prob, reuse) 180 | 181 | return output 182 | 183 | def get_mu_sig(self, image_embedding): 184 | enc_size = int(image_embedding.shape[-1]) 185 | mu = rnn.super_linear( 186 | image_embedding, 187 | self.hps.z_size, 188 | input_size=enc_size, 189 | scope='ENC_RNN_mu', 190 | init_w='gaussian', 191 | weight_start=0.001) 192 | presig = rnn.super_linear( 193 | image_embedding, 194 | self.hps.z_size, 195 | input_size=enc_size, 196 | scope='ENC_RNN_sigma', 197 | init_w='gaussian', 198 | weight_start=0.001) 199 | return mu, presig 200 | 201 | def build_kl_for_vae(self, image_embedding, scope_name, with_state=True, reuse=False): 202 | with 
tf.variable_scope(scope_name, reuse=reuse): 203 | if with_state: 204 | return self.get_init_state(image_embedding) 205 | else: 206 | return self.get_kl_cost(image_embedding) 207 | 208 | def get_init_state(self, image_embedding): 209 | self.mean, self.presig = self.get_mu_sig(image_embedding) 210 | self.sigma = tf.exp(self.presig / 2.0) # sigma > 0. div 2.0 -> sqrt. 211 | eps = tf.random_normal( 212 | (self.hps.batch_size, self.hps.z_size), 0.0, 1.0, dtype=tf.float32) 213 | # batch_z = self.mean + tf.multiply(self.sigma, eps) 214 | if self.hps.is_train: 215 | batch_z = self.mean + tf.multiply(self.sigma, eps) 216 | else: 217 | batch_z = self.mean 218 | if self.hps.inter_z: 219 | batch_z = self.mean + tf.multiply(self.sigma, self.sample_gussian) 220 | # KL cost 221 | kl_cost = -0.5 * tf.reduce_mean( 222 | (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig))) 223 | kl_cost = tf.maximum(kl_cost, self.hps.kl_tolerance) 224 | 225 | # get initial state based on batch_z 226 | initial_state = tf.nn.tanh( 227 | rnn.super_linear( 228 | batch_z, 229 | self.cell.state_size, 230 | init_w='gaussian', 231 | weight_start=0.001, 232 | input_size=self.hps.z_size)) 233 | pre_tile_y = tf.reshape(batch_z, [self.hps.batch_size, 1, self.hps.z_size]) 234 | overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1]) 235 | actual_input_x = tf.concat([self.input_x, overlay_x], 2) 236 | 237 | return initial_state, actual_input_x, batch_z, kl_cost 238 | 239 | def get_kl_cost(self, image_embedding): 240 | self.mean, self.presig = self.get_mu_sig(image_embedding) 241 | self.sigma = tf.exp(self.presig / 2.0) # sigma > 0. div 2.0 -> sqrt. 242 | eps = tf.random_normal( 243 | (self.hps.batch_size, self.hps.z_size), 0.0, 1.0, dtype=tf.float32) 244 | batch_z = self.mean + tf.multiply(self.sigma, eps) 245 | # KL cost 246 | kl_cost = -0.5 * tf.reduce_mean( 247 | (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig))) 248 | kl_cost = tf.maximum(kl_cost, self.hps.kl_tolerance) 249 | 250 | return batch_z, kl_cost 251 | 252 | def config_model(self, hps): 253 | """Define model architecture.""" 254 | if hps.is_train: 255 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 256 | # self.global_step = tf.get_variable('global_step', trainable=False) 257 | 258 | if hps.rnn_model == 'lstm': 259 | cell_fn = rnn.LSTMCell 260 | elif hps.rnn_model == 'layer_norm': 261 | cell_fn = rnn.LayerNormLSTMCell 262 | elif hps.rnn_model == 'hyper': 263 | cell_fn = rnn.HyperLSTMCell 264 | else: 265 | assert False, 'please choose a respectable cell' 266 | 267 | self.hps.crop_size, self.hps.chn_size = get_input_size() 268 | 269 | use_recurrent_dropout = self.hps.use_recurrent_dropout 270 | rnn_input_dropout = self.hps.rnn_input_dropout 271 | rnn_output_dropout = self.hps.rnn_output_dropout 272 | 273 | if hps.rnn_model == 'hyper': 274 | cell = cell_fn( 275 | hps.dec_rnn_size, 276 | use_recurrent_dropout=use_recurrent_dropout, 277 | dropout_keep_prob=self.hps.recurrent_dropout_prob) 278 | else: 279 | cell = cell_fn( 280 | hps.dec_rnn_size, 281 | use_recurrent_dropout=use_recurrent_dropout, 282 | dropout_keep_prob=self.hps.recurrent_dropout_prob) 283 | 284 | # dropout: 285 | if rnn_input_dropout: 286 | cell = tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=self.hps.input_dropout_prob) 287 | if rnn_output_dropout: 288 | cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.hps.output_dropout_prob) 289 | self.cell = cell 290 | 291 | batch_size = self.hps.batch_size 292 | image_size = self.hps.image_size 293 
| 294 | self.sequence_lengths = tf.placeholder( 295 | dtype=tf.int32, shape=[batch_size], name='seq_len') 296 | self.input_sketch = tf.placeholder( 297 | dtype=tf.float32, 298 | shape=[batch_size, self.hps.max_seq_len + 1, 5], name='input_sketch') 299 | self.target_sketch = tf.placeholder( 300 | dtype=tf.float32, 301 | shape=[batch_size, self.hps.max_seq_len + 1, 5], name='target_sketch') 302 | 303 | if self.hps.chn_size == 1: 304 | image_shape = [batch_size, image_size, image_size] 305 | else: 306 | image_shape = [batch_size, image_size, image_size, self.hps.chn_size] 307 | 308 | sketch_shape = [batch_size, image_size, image_size] 309 | 310 | if self.hps.vae_type == 's2s': 311 | self.input_photo = tf.placeholder(dtype=tf.float32, shape=sketch_shape, name='input_photo') 312 | else: 313 | self.input_photo = tf.placeholder(dtype=tf.float32, shape=image_shape, name='input_photo') 314 | if self.hps.vae_type in ['ps2s', 'sp2s']: 315 | self.input_sketch_photo = tf.placeholder(dtype=tf.float32, shape=sketch_shape, name='input_sketch_photo') 316 | 317 | self.input_label = tf.placeholder(dtype=tf.int32, shape=batch_size, name='label') 318 | 319 | if self.hps.inter_z: 320 | self.sample_gussian = tf.placeholder(dtype=tf.float32, shape=batch_size, name='sample_gussian') 321 | 322 | # The target/expected vectors of strokes 323 | self.output_x = self.input_sketch[:, 1:self.hps.max_seq_len + 1, :] 324 | # vectors of strokes to be fed to decoder (same as above, but lagged behind 325 | # one step to include initial dummy value of (0, 0, 1, 0, 0)) 326 | self.input_x = self.input_sketch[:, :self.hps.max_seq_len, :] 327 | 328 | def build_pix_encoder(self, input_image, reuse=False): 329 | 330 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 331 | image_embedding = self.cnn_encoder(input_image, reuse) 332 | 333 | return image_embedding 334 | 335 | def build_seq_encoder(self, input_strokes, reuse=False): 336 | 337 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 338 | strokes_embedding = self.rnn_encoder(input_strokes, self.sequence_lengths) 339 | 340 | return strokes_embedding 341 | 342 | # ###################### 343 | 344 | def build_seq_decoder(self, feat_embedding, kl_name_scope, reuse = False): 345 | 346 | initial_state, actual_input_x, batch_z, kl_cost = self.build_kl_for_vae(feat_embedding, kl_name_scope, reuse=False) 347 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 348 | output, last_state = self.decoder(actual_input_x, initial_state, reuse) 349 | 350 | return output, initial_state, last_state, actual_input_x, batch_z, kl_cost 351 | 352 | def build_pix_decoder(self, feat_embedding, kl_name_scope, reuse = False): 353 | 354 | batch_z, kl_cost = self.build_kl_for_vae(feat_embedding, kl_name_scope, with_state=False, reuse=False) 355 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 356 | output = self.cnn_decoder(batch_z, reuse) 357 | 358 | return output, batch_z, kl_cost 359 | 360 | def build_pix2seq_embedding(self, input_image, encode_pix=True, reuse = False): 361 | 362 | if encode_pix: 363 | self.pix_embedding = self.build_pix_encoder(input_image) 364 | 365 | return self.build_seq_decoder(self.pix_embedding, 'p2s', reuse=reuse) 366 | 367 | def build_seq2pix_embedding(self, input_strokes, encode_seq=True, reuse = False): 368 | 369 | if encode_seq: 370 | self.seq_embedding = self.build_seq_encoder(input_strokes) 371 | 372 | return self.build_pix_decoder(self.seq_embedding, 's2p', reuse=reuse) 373 | 374 | def build_pix2pix_embedding(self, 
input_image, encode_pix=True, reuse = False): 375 | 376 | if encode_pix: 377 | self.pix_embedding = self.build_pix_encoder(input_image) 378 | 379 | return self.build_pix_decoder(self.pix_embedding, 'p2p', reuse=reuse) 380 | 381 | def build_seq2seq_embedding(self, input_strokes, encode_seq=True, reuse=False): 382 | 383 | if encode_seq: 384 | self.seq_embedding = self.build_seq_encoder(input_strokes) 385 | 386 | return self.build_seq_decoder(self.seq_embedding, 's2s', reuse=reuse) 387 | 388 | def build_seq_loss(self, output, initial_state, final_state, batch_z, kl_cost, vae_type, reuse=False): 389 | # code for output, pi miu 390 | pi, mu1, mu2, sigma1, sigma2, corr, pen_logits, pen, y1_data, y2_data, r_cost, r_score, gen_strokes = self.build_strokes_rcons(output, reuse=reuse) 391 | # self.pen: pen state probabilities (result of applying softmax to self.pen_logits) 392 | 393 | r_cost = self.hps.seq_lw * r_cost 394 | 395 | cost_dict = {'rcons': r_cost, 'kl': kl_cost} 396 | 397 | end_points = {'init_s': initial_state, 'fin_s': final_state, 'pi': pi, 'mu1': mu1, 'mu2': mu2, 'sigma1': sigma1, 398 | 'sigma2': sigma2, 'corr': corr, 'pen': pen, 'batch_z': batch_z} 399 | 400 | cost_dict_vae_type = {'%s_%s' % (vae_type, key): cost_dict[key] for key in cost_dict.keys()} 401 | end_points_vae_type = {'%s_%s' % (vae_type, key): end_points[key] for key in end_points.keys()} 402 | 403 | return gen_strokes, cost_dict_vae_type, end_points_vae_type 404 | 405 | def build_pix_loss(self, gen_photo, batch_z, kl_cost, vae_type, reuse=False): 406 | r_cost = self.build_photo_rcons(self.target_photo, gen_photo) 407 | # self.pen: pen state probabilities (result of applying softmax to self.pen_logits) 408 | 409 | r_cost = self.hps.pix_lw * r_cost 410 | 411 | cost_dict = {'rcons': r_cost, 'kl': kl_cost} 412 | 413 | gen_photo_rgb = tf.cast((gen_photo + 1) * 127.5, tf.int16) 414 | 415 | end_points = {'gen_photo': gen_photo, 'gen_photo_rgb': gen_photo_rgb, 'batch_z': batch_z} 416 | 417 | cost_dict_vae_type = {'%s_%s' % (vae_type, key): cost_dict[key] for key in cost_dict.keys()} 418 | end_points_vae_type = {'%s_%s' % (vae_type, key): end_points[key] for key in end_points.keys()} 419 | 420 | return gen_photo_rgb, cost_dict_vae_type, end_points_vae_type 421 | 422 | def build_strokes_rcons(self, output, reuse=False): 423 | 424 | # target data 425 | x1_data, x2_data = self.x1_data, self.x2_data 426 | eos_data, eoc_data, cont_data = self.eos_data, self.eoc_data, self.cont_data 427 | 428 | # TODO(deck): Better understand this comment. 
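    # NOTE: concretely, with the usual sketch-rnn setting of
    # num_mixture = 20, n_out below is 3 + 20 * 6 = 123 outputs per
    # timestep: 3 pen-state logits plus (pi, mu_x, mu_y, sigma_x, sigma_y,
    # rho_xy) for each of the 20 bivariate Gaussian mixture components.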
429 | # Number of outputs is 3 (one logit per pen state) plus 6 per mixture 430 | # component: mean_x, stdev_x, mean_y, stdev_y, correlation_xy, and the 431 | # mixture weight/probability (Pi_k) 432 | n_out = (3 + self.hps.num_mixture * 6) 433 | 434 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 435 | 436 | with tf.variable_scope('RNN'): 437 | output_w = tf.get_variable('output_w', [self.hps.dec_rnn_size, n_out]) 438 | output_b = tf.get_variable('output_b', [n_out]) 439 | 440 | output_reshape = tf.reshape(output, [-1, self.hps.dec_rnn_size]) 441 | 442 | output_mdn = tf.nn.xw_plus_b(output_reshape, output_w, output_b) 443 | 444 | o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits = get_mixture_coef(output_mdn) 445 | 446 | o_id = tf.stack([tf.range(0, tf.shape(o_mu1)[0]), tf.cast(tf.argmax(o_pi, 1), tf.int32)], axis=1) 447 | y1_data = tf.gather_nd(o_mu1, o_id) 448 | y2_data = tf.gather_nd(o_mu2, o_id) 449 | 450 | y3_data = tf.cast(tf.greater(tf.argmax(o_pen_logits, 1), 0), tf.float32) 451 | gen_strokes = tf.stack([y1_data, y2_data, y3_data], 1) 452 | gen_strokes = tf.reshape(gen_strokes, [self.hps.batch_size, -1, 3]) 453 | start_points_np = np.zeros((self.hps.batch_size, 1, 3)) 454 | start_points_tf = tf.constant(start_points_np, dtype=tf.float32) 455 | gen_strokes = tf.concat([start_points_tf, gen_strokes], 1) 456 | 457 | pen_data = tf.concat([eos_data, eoc_data, cont_data], 1) 458 | 459 | rcons_loss = get_rcons_loss_pen_state(o_pen_logits, pen_data, self.hps.is_train) 460 | 461 | rcons_loss += get_rcons_loss_mdn(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, x1_data, x2_data, pen_data) 462 | 463 | r_cost = tf.reduce_mean(rcons_loss) 464 | r_score = -tf.reduce_sum(tf.reshape(rcons_loss, [self.hps.batch_size, -1]), axis=1) 465 | return o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen_logits, o_pen, y1_data, y2_data, r_cost, r_score, gen_strokes 466 | 467 | def build_photo_rcons(self, real_images, output_images): 468 | 469 | # target data 470 | image_shape = real_images.get_shape() 471 | output_images_reshaped = tf.reshape(output_images, image_shape) 472 | pixel_losses = tf.reduce_mean(tf.square(real_images - output_images_reshaped)) 473 | return pixel_losses 474 | 475 | def build_l2_loss(self): 476 | 477 | self.rnn_l2 = tf.reduce_mean(tf.square(self.end_points['p2s_batch_z'] - self.end_points['s2s_batch_z'])) 478 | self.cnn_l2 = tf.reduce_mean(tf.square(self.end_points['s2p_batch_z'] - self.end_points['p2p_batch_z'])) 479 | 480 | self.l2_cost = self.rnn_l2 + self.cnn_l2 481 | 482 | def build_seq_discriminator(self, x, y, l, reuse): 483 | 484 | # set label for orig and gen 485 | label_r = tf.ones([self.hps.batch_size, 1], tf.int32) 486 | label_f = tf.zeros([self.hps.batch_size, 1], tf.int32) 487 | 488 | # build domain classifier 489 | cell_type, n_hidden, num_layers = self.hps.dis_model, self.hps.dis_num_hidden, self.hps.dis_num_layers 490 | in_dp, out_dp, batch_size = self.hps.dis_input_dropout, self.hps.dis_output_dropout, self.hps.batch_size 491 | pred_r, logits_r = rnn_discriminator(x, l, cell_type, n_hidden, num_layers, in_dp, out_dp, batch_size, reuse=reuse) 492 | pred_f, logits_f = rnn_discriminator(y, l, cell_type, n_hidden, num_layers, in_dp, out_dp, batch_size, reuse=True) 493 | 494 | # if self.hps.w_gan: 495 | # dis_loss, gen_loss = wgan_gp_loss(logits_r, logits_f, None, use_gradients=False) 496 | # dis_acc = tf.constant(-1.0) 497 | # else: 498 | dis_loss, gen_loss, dis_acc = get_adv_loss(logits_r, logits_f, label_r, label_f) 499 | 500 | 
dis_loss *= self.hps.rnn_dis_lw 501 | gen_loss *= self.hps.rnn_gen_lw 502 | 503 | return dis_loss, gen_loss, dis_acc 504 | 505 | def build_pix_discriminator(self, x, y, reuse): 506 | # set label for orig and gen 507 | label_r = tf.ones([self.hps.batch_size, 1], tf.int32) 508 | label_f = tf.zeros([self.hps.batch_size, 1], tf.int32) 509 | 510 | # build domain classifier 511 | batch_size = self.hps.batch_size 512 | pred_r, logits_r = cnn_discriminator(x, batch_size, reuse=reuse) 513 | pred_f, logits_f = cnn_discriminator(y, batch_size, reuse=True) 514 | 515 | if self.hps.gp_gan: 516 | alpha = tf.random_uniform(shape=[self.hps.batch_size, 1, 1, 1], minval=0., maxval=1.) 517 | differences = y - x 518 | interpolates = x + (alpha*differences) 519 | gradients = tf.gradients(cnn_discriminator(interpolates, batch_size, reuse=True)[1], [interpolates])[0] 520 | dis_loss, gen_loss, dis_acc = get_adv_gp_loss(logits_r, logits_f, label_r, label_f, gradients) 521 | else: 522 | dis_loss, gen_loss, dis_acc = get_adv_loss(logits_r, logits_f, label_r, label_f) 523 | 524 | dis_loss *= self.hps.cnn_dis_lw 525 | gen_loss *= self.hps.cnn_gen_lw 526 | 527 | return dis_loss, gen_loss, dis_acc 528 | 529 | def build_wgan_seq_discriminator(self, x, y, l, reuse): 530 | 531 | print("Build wgan seq discriminator") 532 | 533 | # build domain classifier 534 | logits_r = wgan_gp_rnn_discriminator(x, reuse=reuse) 535 | logits_f = wgan_gp_rnn_discriminator(y, reuse=True) 536 | 537 | alpha = tf.random_uniform(shape=[self.hps.batch_size, 1, 1], minval=0., maxval=1.) 538 | differences = y - x 539 | interpolates = x + (alpha*differences) 540 | gradients = tf.gradients(wgan_gp_rnn_discriminator(interpolates, reuse=True), [interpolates])[0] 541 | dis_loss, gen_loss = wgan_gp_loss(logits_f, logits_r, gradients) 542 | dis_acc = tf.constant(-1.0) 543 | 544 | dis_loss *= self.hps.rnn_dis_lw 545 | gen_loss *= self.hps.rnn_gen_lw 546 | 547 | return dis_loss, gen_loss, dis_acc 548 | 549 | def build_wgan_pix_discriminator(self, x, y, reuse): 550 | 551 | print("Build wgan pix discriminator") 552 | 553 | # build domain classifier 554 | logits_r = wgan_gp_cnn_discriminator(x, reuse=reuse) 555 | logits_f = wgan_gp_cnn_discriminator(y, reuse=True) 556 | 557 | alpha = tf.random_uniform(shape=[self.hps.batch_size, 1, 1, 1], minval=0., maxval=1.) 
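    # NOTE: the next lines implement the WGAN-GP gradient penalty
    # (Gulrajani et al., 2017): sample random interpolates
    # x_hat = x + alpha * (y - x) between real and generated photos and
    # pass dD(x_hat)/dx_hat to wgan_gp_loss, which presumably penalizes
    # deviations of the gradient norm from 1.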
558 | differences = y - x 559 | interpolates = x + (alpha*differences) 560 | gradients = tf.gradients(wgan_gp_cnn_discriminator(interpolates, reuse=True), [interpolates])[0] 561 | dis_loss, gen_loss = wgan_gp_loss(logits_f, logits_r, gradients) 562 | dis_acc = tf.constant(-1.0) 563 | 564 | dis_loss *= self.hps.cnn_dis_lw 565 | gen_loss *= self.hps.cnn_gen_lw 566 | 567 | return dis_loss, gen_loss, dis_acc 568 | 569 | def get_train_vars(self): 570 | 571 | self.t_vars = tf.trainable_variables() 572 | self.d_vars = [var for var in self.t_vars if 'DIS' in var.name] 573 | self.g_vars = [var for var in self.t_vars if 'DIS' not in var.name] 574 | 575 | def get_train_op(self): 576 | 577 | self.apply_decay() 578 | 579 | # get total loss 580 | self.get_total_loss() 581 | 582 | # get train vars 583 | self.get_train_vars() 584 | 585 | optimizer = tf.train.AdamOptimizer(self.lr) 586 | gvs = optimizer.compute_gradients(self.cost) 587 | 588 | capped_gvs = clip_gradients(gvs, self.hps.grad_clip) 589 | self.train_op = optimizer.apply_gradients( 590 | capped_gvs, global_step=self.global_step, name='train_step') 591 | 592 | def apply_decay(self): 593 | if self.hps.lr_decay: 594 | self.lr = tf.train.exponential_decay(self.hps.lr, self.global_step, self.hps.decay_step, self.hps.decay_rate, staircase=True) 595 | else: 596 | self.lr = self.hps.lr 597 | 598 | # self.kl_weight = tf.Variable(self.hps.kl_weight_start, trainable=False) 599 | if self.hps.kl_weight_decay: 600 | self.kl_weight = tf.train.exponential_decay(self.hps.kl_weight_start, self.global_step, self.hps.kl_decay_step, self.hps.kl_decay_rate, staircase=True) 601 | else: 602 | self.kl_weight = self.hps.kl_weight_start 603 | 604 | if self.hps.l2_weight_decay: 605 | self.l2_weight = tf.train.exponential_decay(self.hps.l2_weight_start, self.global_step, self.hps.l2_decay_step, self.hps.l2_decay_rate, staircase=True) 606 | else: 607 | self.l2_weight = self.hps.l2_weight_start 608 | 609 | def get_total_loss(self): 610 | self.p2s_kl, self.s2p_kl = self.cost_dict['p2s_kl'], self.cost_dict['s2p_kl'] 611 | self.p2p_kl, self.s2s_kl = self.cost_dict['p2p_kl'], self.cost_dict['s2s_kl'] 612 | self.kl_cost = self.p2s_kl + self.s2p_kl + self.p2p_kl + self.s2s_kl 613 | self.cost = self.kl_cost * self.kl_weight 614 | 615 | # get reconstruction loss 616 | self.p2s_r, self.s2p_r = self.cost_dict['p2s_rcons'], self.cost_dict['s2p_rcons'] 617 | self.p2p_r, self.s2s_r = self.cost_dict['p2p_rcons'], self.cost_dict['s2s_rcons'] 618 | self.r_cost = self.p2s_r + self.s2p_r + self.p2p_r + self.s2s_r 619 | self.cost += self.r_cost 620 | 621 | def get_target_strokes(self): 622 | target = tf.reshape(self.output_x, [-1, 5]) 623 | # reshape target data so that it is compatible with prediction shape 624 | [self.x1_data, self.x2_data, self.eos_data, self.eoc_data, self.cont_data] = tf.split(target, 5, 1) 625 | start_points_np = np.zeros((self.hps.batch_size, 1, 3)) 626 | start_points_tf = tf.constant(start_points_np, dtype=tf.float32) 627 | self.target_strokes = tf.concat([self.x1_data, self.x2_data, 1 - self.eos_data], 1) 628 | self.target_strokes = tf.reshape(self.target_strokes, [self.hps.batch_size, -1, 3]) 629 | self.target_strokes = tf.concat([start_points_tf, self.target_strokes], 1) 630 | 631 | def get_target_photo(self): 632 | self.target_photo = \ 633 | tf_image_processing(self.input_photo, self.hps.basenet, self.hps.crop_size, self.hps.dist_aug, self.hps.hp_filter) 634 | 635 | def build_model(self, hps): 636 | self.config_model(hps) 637 | 638 | # get target data 639 | 
self.get_target_strokes() 640 | self.get_target_photo() 641 | 642 | # build photo to stroke-level synthesis part 643 | self.gen_strokes, cost_dict_p2s, end_points_p2s = self.build_pix2seq_branch(self.input_photo) 644 | 645 | # build stroke-level to photo synthesis part 646 | self.gen_photo, cost_dict_s2p, end_points_s2p = self.build_seq2pix_branch(self.input_sketch) 647 | 648 | # build photo to photo reconstruction part 649 | self.recon_photo, cost_dict_p2p, end_points_p2p = self.build_pix2pix_branch(self.input_photo, encode_pix=False, reuse=True) 650 | 651 | # build sketch to sketch reconstruction part 652 | self.recon_sketch, cost_dict_s2s, end_points_s2s = self.build_seq2seq_branch(self.input_sketch, encode_seq=False, reuse=True) 653 | 654 | self.cost_dict = dict(cost_dict_p2s.items() + cost_dict_s2p.items() + cost_dict_p2p.items() + cost_dict_s2s.items()) 655 | self.end_points = dict(end_points_p2s.items() + end_points_s2p.items() + end_points_p2p.items() + end_points_s2s.items()) 656 | 657 | self.initial_state, self.final_state = self.end_points['p2s_init_s'], self.end_points['p2s_fin_s'] 658 | self.pi, self.corr = self.end_points['p2s_pi'], self.end_points['p2s_corr'] 659 | self.mu1, self.mu2 = self.end_points['p2s_mu1'], self.end_points['p2s_mu2'] 660 | self.sigma1, self.sigma2 = self.end_points['p2s_sigma1'], self.end_points['p2s_sigma2'] 661 | self.pen = self.end_points['p2s_pen'] 662 | self.batch_z = self.end_points['p2s_batch_z'] 663 | 664 | self.recon_initial_state, self.recon_final_state = self.end_points['s2s_init_s'], self.end_points['s2s_fin_s'] 665 | self.recon_pi, self.recon_corr = self.end_points['s2s_pi'], self.end_points['s2s_corr'] 666 | self.recon_mu1, self.recon_mu2 = self.end_points['s2s_mu1'], self.end_points['s2s_mu2'] 667 | self.recon_sigma1, self.recon_sigma2 = self.end_points['s2s_sigma1'], self.end_points['s2s_sigma2'] 668 | self.recon_pen = self.end_points['s2s_pen'] 669 | self.recon_batch_z = self.end_points['s2s_batch_z'] 670 | 671 | if self.hps.is_train: 672 | # self.get_train_op_with_bn() # dosen't work 673 | self.get_train_op() 674 | 675 | def build_pix2seq_branch(self, input_photo, encode_pix=True, reuse=False): 676 | # pixel to sequence 677 | output, initial_state, final_state, actual_input_x, batch_z, kl_cost = \ 678 | self.build_pix2seq_embedding(input_photo, encode_pix=encode_pix, reuse=reuse) 679 | 680 | return self.build_seq_loss(output, initial_state, final_state, batch_z, kl_cost, 'p2s', reuse=reuse) 681 | 682 | def build_seq2pix_branch(self, input_strokes, encode_seq=True, reuse=False): 683 | # sequence to pixel 684 | gen_photo, batch_z, kl_cost = self.build_seq2pix_embedding(input_strokes, encode_seq=encode_seq, reuse=reuse) 685 | 686 | return self.build_pix_loss(gen_photo, batch_z, kl_cost, 's2p', reuse=reuse) 687 | 688 | def build_pix2pix_branch(self, input_photo, encode_pix=False, reuse=False): 689 | # pixel to pixel 690 | gen_photo, batch_z, kl_cost = self.build_pix2pix_embedding(input_photo, encode_pix=encode_pix, reuse=reuse) 691 | 692 | return self.build_pix_loss(gen_photo, batch_z, kl_cost, 'p2p', reuse=reuse) 693 | 694 | def build_seq2seq_branch(self, input_strokes, encode_seq=False, reuse=False): 695 | output, initial_state, final_state, actual_input_x, batch_z, kl_cost = \ 696 | self.build_seq2seq_embedding(input_strokes, encode_seq=encode_seq, reuse=reuse) 697 | 698 | return self.build_seq_loss(output, initial_state, final_state, batch_z, kl_cost, 's2s', reuse=reuse) 699 | 700 | 701 | def get_pi_idx(pdf): 702 | """Samples from a 
pdf.""" 703 | return np.argmax(pdf) 704 | 705 | 706 | def sample(sess, model, input_image, sketch=None, seq_len=250, temperature=0.5, with_sketch=False, rnn_enc_seq_len = None, cond_sketch=False, inter_z=False, inter_z_sample=0): 707 | """Samples a sequence from a pre-trained model.""" 708 | 709 | prev_x = np.zeros((1, 1, 5), dtype=np.float32) 710 | prev_x[0, 0, 2] = 1 # initially, we want to see beginning of new stroke 711 | # print("enter the function of sample") 712 | 713 | if cond_sketch: 714 | if int(model.input_photo.get_shape()[-1]) == 3: 715 | input_image = input_image[:,:,:,np.newaxis] 716 | input_image = np.concatenate([input_image, input_image, input_image], -1) 717 | 718 | if rnn_enc_seq_len is None: 719 | if inter_z: 720 | prev_state = sess.run(model.initial_state, feed_dict={model.input_photo: input_image, model.sample_gussian: inter_z_sample}) 721 | else: 722 | prev_state = sess.run(model.initial_state, feed_dict={model.input_photo: input_image}) 723 | # image_embedding = sess.run(model.image_embedding, feed_dict={model.input_photo: input_image}) 724 | # batch_z = sess.run(model.batch_z, feed_dict={model.input_photo: input_image}) 725 | else: 726 | if inter_z: 727 | prev_state = sess.run(model.initial_state, feed_dict={model.input_photo: input_image, model.sequence_lengths: rnn_enc_seq_len, model.sample_gussian: inter_z_sample}) 728 | else: 729 | prev_state = sess.run(model.initial_state, feed_dict={model.input_photo: input_image, model.sequence_lengths: rnn_enc_seq_len}) 730 | # image_embedding = sess.run(model.image_embedding, feed_dict={model.input_photo: input_image, model.sequence_lengths: rnn_enc_seq_len}) 731 | # batch_z = sess.run(model.batch_z, feed_dict={model.input_photo: input_image, model.sequence_lengths: rnn_enc_seq_len}) 732 | 733 | strokes = np.zeros((seq_len, 5), dtype=np.float32) 734 | mixture_params = [] 735 | 736 | for i in range(seq_len): 737 | if inter_z: 738 | feed = { 739 | model.input_x: prev_x, 740 | model.sequence_lengths: [1], 741 | model.initial_state: prev_state, 742 | model.input_photo: input_image, 743 | model.sample_gussian: inter_z_sample 744 | } 745 | else: 746 | feed = { 747 | model.input_x: prev_x, 748 | model.sequence_lengths: [1], 749 | model.initial_state: prev_state, 750 | model.input_photo: input_image 751 | } 752 | 753 | params = sess.run([ 754 | model.pi, model.mu1, model.mu2, model.sigma1, model.sigma2, model.corr, 755 | model.pen, model.final_state 756 | ], feed) 757 | 758 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, next_state] = params 759 | 760 | idx = get_pi_idx(o_pi[0]) 761 | 762 | idx_eos = get_pi_idx(o_pen[0]) 763 | eos = [0, 0, 0] 764 | eos[idx_eos] = 1 765 | 766 | next_x1, next_x2 = o_mu1[0][idx], o_mu2[0][idx] 767 | 768 | strokes[i, :] = [next_x1, next_x2, eos[0], eos[1], eos[2]] 769 | 770 | params = [o_pi[0], o_mu1[0], o_mu2[0], o_sigma1[0], o_sigma2[0], o_corr[0], o_pen[0]] 771 | 772 | mixture_params.append(params) 773 | 774 | prev_x = np.zeros((1, 1, 5), dtype=np.float32) 775 | if with_sketch: 776 | prev_x[0][0] = sketch[0][i+1] 777 | else: 778 | prev_x[0][0] = np.array( 779 | [next_x1, next_x2, eos[0], eos[1], eos[2]], dtype=np.float32) 780 | prev_state = next_state 781 | 782 | return strokes, mixture_params 783 | 784 | 785 | def sample_recons(sess, model, gen_model, input_sketch, sketch=None, seq_len=250, temperature=0.5, with_sketch=False, cond_sketch=False, inter_z=False, inter_z_sample=0): 786 | """Samples a sequence from a pre-trained model.""" 787 | 788 | prev_x = np.zeros((1, 1, 5), 
785 | def sample_recons(sess, model, gen_model, input_sketch, sketch=None, seq_len=250, temperature=0.5, with_sketch=False, cond_sketch=False, inter_z=False, inter_z_sample=0):
786 |     """Samples a reconstructed sketch sequence from a pre-trained model."""
787 | 
788 |     prev_x = np.zeros((1, 1, 5), dtype=np.float32)
789 |     prev_x[0, 0, 2] = 1  # initially, we want to see the beginning of a new stroke
790 |     # print("enter the function of sample")
791 | 
792 |     if inter_z:
793 |         feed_dict = {
794 |             gen_model.input_sketch: input_sketch,
795 |             gen_model.sequence_lengths: [seq_len],
796 |             gen_model.sample_gussian: inter_z_sample
797 |         }
798 |         prev_state, batch_z = sess.run([gen_model.recon_initial_state, gen_model.recon_batch_z], feed_dict=feed_dict)
799 |     else:
800 |         feed_dict = {
801 |             gen_model.input_sketch: input_sketch,
802 |             gen_model.sequence_lengths: [seq_len]
803 |         }
804 |         prev_state, batch_z = sess.run([gen_model.recon_initial_state, gen_model.recon_batch_z], feed_dict=feed_dict)
805 | 
806 |     strokes = np.zeros((seq_len, 5), dtype=np.float32)
807 |     mixture_params = []
808 | 
809 |     for i in range(seq_len):
810 |         if not model.hps.concat_z:
811 |             feed = {
812 |                 model.input_x: prev_x,
813 |                 model.sequence_lengths: [1],
814 |                 model.recon_initial_state: prev_state
815 |             }
816 |         elif inter_z:
817 |             feed = {
818 |                 model.input_x: prev_x,
819 |                 model.sequence_lengths: [1],
820 |                 model.recon_initial_state: prev_state,
821 |                 model.sample_gussian: inter_z_sample,
822 |                 model.recon_batch_z: batch_z
823 |             }
824 |         else:
825 |             feed = {
826 |                 model.input_x: prev_x,
827 |                 model.sequence_lengths: [1],
828 |                 model.recon_initial_state: prev_state,
829 |                 model.recon_batch_z: batch_z
830 |             }
831 | 
832 |         params = sess.run([
833 |             model.recon_pi, model.recon_mu1, model.recon_mu2, model.recon_sigma1, model.recon_sigma2, model.recon_corr,
834 |             model.recon_pen, model.recon_final_state
835 |         ], feed)
836 | 
837 |         [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, next_state] = params
838 | 
839 |         idx = get_pi_idx(o_pi[0])
840 | 
841 |         idx_eos = get_pi_idx(o_pen[0])
842 |         eos = [0, 0, 0]
843 |         eos[idx_eos] = 1
844 | 
845 |         next_x1, next_x2 = o_mu1[0][idx], o_mu2[0][idx]
846 | 
847 |         strokes[i, :] = [next_x1, next_x2, eos[0], eos[1], eos[2]]
848 | 
849 |         params = [o_pi[0], o_mu1[0], o_mu2[0], o_sigma1[0], o_sigma2[0], o_corr[0], o_pen[0]]
850 | 
851 |         mixture_params.append(params)
852 | 
853 |         prev_x = np.zeros((1, 1, 5), dtype=np.float32)
854 |         if with_sketch:
855 |             prev_x[0][0] = sketch[0][i+1]
856 |         else:
857 |             prev_x[0][0] = np.array(
858 |                 [next_x1, next_x2, eos[0], eos[1], eos[2]], dtype=np.float32)
859 |         prev_state = next_state
860 | 
861 |     return strokes, mixture_params
862 | 
863 | 
864 | def get_init_fn(pretrain_model, checkpoint_exclude_scopes):
865 |     """Returns a function run by the chief worker to warm-start the training."""
866 |     print("load pretrained model from %s" % pretrain_model)
867 |     exclusions = [scope.strip() for scope in checkpoint_exclude_scopes]
868 | 
869 |     variables_to_restore = []
870 |     # for var in slim.get_model_variables():
871 |     for var in tf.trainable_variables():
872 |         excluded = False
873 |         for exclusion in exclusions:
874 |             if var.op.name.startswith(exclusion):
875 |                 excluded = True
876 |                 break
877 |         if not excluded:
878 |             print(var.name)
879 |             variables_to_restore.append(var)
880 | 
881 |     return slim.assign_from_checkpoint_fn(pretrain_model, variables_to_restore)
882 | 
883 | 
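# NOTE (illustrative sketch, not part of the original model.py): hypothetical
# use of get_init_fn(); the checkpoint path and excluded scope name below are
# placeholder values, not names taken from this repository.
#
#     init_fn = get_init_fn('/path/to/pretrained.ckpt', ['decoder_scope'])
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         init_fn(sess)  # the callable returned by slim.assign_from_checkpoint_fn
#                        # restores every variable not matched by an exclusion
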
884 | def get_input_size():
885 |     if FLAGS.basenet == 'alexnet':
886 |         crop_size = 227
887 |         channel_size = 3
888 |     elif FLAGS.basenet in ['resnet', 'inceptionv3']:
889 |         crop_size = 299
890 |         channel_size = 3
891 |     elif FLAGS.basenet in ['sketchynet', 'inceptionv1', 'vgg', 'mobilenet', 'gen_cnn']:
892 |         crop_size = 224
893 |         channel_size = 3
894 |     else:
895 |         crop_size = 225
896 |         channel_size = 1
897 |     return crop_size, channel_size
898 | 
--------------------------------------------------------------------------------
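A minimal usage sketch for get_input_size() from model.py above. The flag definition is illustrative only; in this repository FLAGS is expected to be set up by the training scripts, and 'gen_cnn' is just one of the backbones the function recognizes:

    import tensorflow as tf

    tf.app.flags.DEFINE_string('basenet', 'gen_cnn', 'backbone CNN to use')  # illustrative definition
    FLAGS = tf.app.flags.FLAGS

    crop_size, channel_size = get_input_size()  # 'gen_cnn' -> (224, 3)
    input_photo = tf.placeholder(tf.float32, [None, crop_size, crop_size, channel_size])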