├── svg
│   ├── path
│   │   ├── tests
│   │   │   ├── __init__.py
│   │   │   ├── test_doc.py
│   │   │   ├── test_generation.py
│   │   │   ├── test_parsing.py
│   │   │   └── test_paths.py
│   │   ├── __init__.py
│   │   ├── __init__.pyc
│   │   ├── parser.py
│   │   ├── parser.pyc
│   │   ├── path.py
│   │   └── path.pyc
│   ├── __init__.py
│   └── __init__.pyc
├── images
│   ├── architecture.png
│   ├── example.png
│   └── highlight.png
├── README.md
├── magenta_rnn.py
├── magenta_rnn.pyc
├── model.py
├── model.pyc
├── prepare_data.py
├── sketchrnn_cnn_dual_test.py
├── sketchrnn_cnn_dual_train.py
├── svg_utils.py
├── svg_utils.pyc
├── tf_data_work.py
├── tf_data_work.pyc
└── utils.pyc
/svg/path/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/model.pyc
--------------------------------------------------------------------------------
/utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/utils.pyc
--------------------------------------------------------------------------------
/svg/__init__.py:
--------------------------------------------------------------------------------
1 | __import__('pkg_resources').declare_namespace(__name__)
2 |
--------------------------------------------------------------------------------
/magenta_rnn.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/magenta_rnn.pyc
--------------------------------------------------------------------------------
/svg_utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/svg_utils.pyc
--------------------------------------------------------------------------------
/svg/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/svg/__init__.pyc
--------------------------------------------------------------------------------
/svg/path/path.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/svg/path/path.pyc
--------------------------------------------------------------------------------
/tf_data_work.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/tf_data_work.pyc
--------------------------------------------------------------------------------
/images/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/images/example.png
--------------------------------------------------------------------------------
/images/highlight.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/images/highlight.png
--------------------------------------------------------------------------------
/svg/path/parser.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/svg/path/parser.pyc
--------------------------------------------------------------------------------
/svg/path/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/svg/path/__init__.pyc
--------------------------------------------------------------------------------
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seindlut/deep_p2s/HEAD/images/architecture.png
--------------------------------------------------------------------------------
/svg/path/__init__.py:
--------------------------------------------------------------------------------
1 | from .path import Path, Line, Arc, CubicBezier, QuadraticBezier
2 | from .parser import parse_path
3 |
--------------------------------------------------------------------------------
/svg/path/tests/test_doc.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import doctest
3 |
4 |
5 | def load_tests(loader, tests, ignore):
6 | tests.addTests(doctest.DocFileSuite('README.rst', package='__main__'))
7 | return tests
8 |
--------------------------------------------------------------------------------
/svg/path/tests/test_generation.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import unittest
3 | from ..path import CubicBezier, QuadraticBezier, Line, Arc, Path
4 | from ..parser import parse_path
5 |
6 |
7 | class TestGeneration(unittest.TestCase):
8 |
9 | def test_svg_examples(self):
10 | """Examples from the SVG spec"""
11 | paths = [
12 | 'M 100,100 L 300,100 L 200,300 Z',
13 | 'M 0,0 L 50,20 M 100,100 L 300,100 L 200,300 Z',
14 | 'M 100,100 L 200,200',
15 | 'M 100,200 L 200,100 L -100,-200',
16 | 'M 100,200 C 100,100 250,100 250,200 S 400,300 400,200',
17 | 'M 100,200 C 100,100 400,100 400,200',
18 | 'M 100,500 C 25,400 475,400 400,500',
19 | 'M 100,800 C 175,700 325,700 400,800',
20 | 'M 600,200 C 675,100 975,100 900,200',
21 | 'M 600,500 C 600,350 900,650 900,500',
22 | 'M 600,800 C 625,700 725,700 750,800 S 875,900 900,800',
23 | 'M 200,300 Q 400,50 600,300 T 1000,300',
24 | 'M -3.4E+38,3.4E+38 L -3.4E-38,3.4E-38',
25 | 'M 0,0 L 50,20 M 50,20 L 200,100 Z',
26 | 'M 600,350 L 650,325 A 25,25 -30 0,1 700,300 L 750,275',
27 | ]
28 |
29 | for path in paths:
30 | self.assertEqual(parse_path(path).d(), path)
31 |
32 | def test_normalizing(self):
33 | # Relative paths will be made absolute, subpaths merged if they can,
34 | # and syntax will change.
35 | self.assertEqual(parse_path('M0 0L3.4E2-10L100.0,100M100,100l100,-100').d(),
36 | 'M 0,0 L 340,-10 L 100,100 L 200,0')
37 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Photo-to-Sketch Synthesis Model
2 |
3 | Before jumping into our code implementation, which is based on [TensorFlow](https://github.com/tensorflow/tensorflow), please refer to our paper [Learning to Sketch with Shortcut Cycle Consistency](https://arxiv.org/abs/1805.00247) for the basic idea.
4 |
5 |
6 | # Overview
7 |
8 | In this paper, we present a novel approach for translating an object photo to a sketch, mimicking the human sketching process. Teaching a machine to generate a sketch from a photo just as humans do is not easy. It requires not only developing an abstract concept of a visual object instance, but also knowing what, where, and when to sketch the next line stroke. The figure below shows that our photo-to-sketch synthesizer takes a photo as input and mimics the human sketching process by sequentially drawing one stroke at a time. The resulting synthesized sketches provide an abstract and semantically meaningful depiction of the given object, just as human sketches do.
9 |
10 |
![highlight](images/highlight.png)
11 |
12 |
13 |
14 | *Examples of our model mimicking the human sketching process, stroke by stroke.*
15 |
16 | # Model Structure
17 |
18 | We aim to learn a mapping function between the photo domain *X* and the sketch domain *Y*, where we denote the empirical data distributions as *x ~ pdata(x)* and *y ~ pdata(y)*, and represent each vector sketch segment as (*sxi*, *syi*), a two-dimensional offset vector. Our model comprises four mapping functions, learned with four subnets: a photo encoder, a sketch encoder, a photo decoder, and a sketch decoder. The architecture is illustrated below.
19 |
20 | 
21 |
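For intuition, the offset representation above can be reproduced in a few lines of NumPy. This is a minimal illustrative sketch, not part of the released code, and the sample coordinates are made up:

```python
import numpy as np

# Hypothetical absolute pen positions of one stroke, shape (T, 2).
points = np.array([[10., 10.], [12., 15.], [20., 18.]])

# Offset vectors (sx_i, sy_i): each segment is stored relative to the previous point.
offsets = np.diff(points, axis=0)  # shape (T-1, 2)

# The absolute positions are recovered from the start point plus a cumulative sum.
recovered = np.vstack([points[:1], points[:1] + np.cumsum(offsets, axis=0)])
assert np.allclose(recovered, points)
```
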
22 | # Training a Model
23 |
24 | Our deep photo-to-sketch (p2s) synthesis model is trained on the ShoeV2 and ChairV2 datasets.
25 |
26 | Usage:
27 |
28 | ```bash
29 | python sketchrnn_cnn_dual_train.py --dataset shoesv2
30 | ```
31 |
32 | As mentioned in the paper, before you train a photo-to-sketch (p2s) synthesis model, you need to pretrain your model on [QuickDraw](https://github.com/googlecreativelab/quickdraw-dataset) data from the corresponding categories.
33 |
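The train script resolves the pretrained checkpoint directory from `--log_root` (see `load_pretrain` in *sketchrnn_cnn_dual_train.py*). The snippet below is a minimal sketch of that lookup, assuming the default `--log_root`; place the QuickDraw-pretrained checkpoint in the printed directory:

```python
# Mirrors the path logic of load_pretrain() in sketchrnn_cnn_dual_train.py.
log_root = './models/runs'  # default --log_root
sv_str = 'shoe'             # 'shoe' for shoesv2, 'chair' for chairsv2
pretrain_dir = log_root.split('runs')[0] + 'pretrained_model/%s/' % sv_str
print(pretrain_dir)         # -> ./models/pretrained_model/shoe/
```
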
34 | We have tested this model on TensorFlow 1.4 with Python 2.7.
35 |
36 | # Result
37 |
38 | Example:
39 |
40 | 
41 |
42 | # Datasets
43 |
44 | The datasets for our photo-to-sketch synthesis task are the *ShoeV2* and *ChairV2* datasets, which can be downloaded from the homepage of our group, [SketchX](http://sketchx.eecs.qmul.ac.uk/downloads/).
45 |
46 | The pretraining dataset can be downloaded from [QuickDraw](https://github.com/googlecreativelab/quickdraw-dataset).
47 |
48 | The original data can be converted to HDF5 format using *prepare_data.py*, or downloaded directly from [Google Drive](https://drive.google.com/open?id=1029l8QZ9EWzQEb9GDylVM3EP4_kurW1E).
49 |
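To sanity-check a converted file, you can list its keys and dataset shapes with h5py. This is a minimal sketch; the file name assumes the default `simplify_flag=True` and `save_png=True` settings in *prepare_data.py*:

```python
import h5py

# prepare_dbs() writes train_svg_sim_png.h5 when simplify_flag and save_png
# are both True; adjust the path to your configuration.
with h5py.File('data/shoes/svg/train_svg_sim_png.h5', 'r') as hf:
    for key in hf.keys():
        print(key, hf[key].shape)
```
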
50 | # Citation
51 |
52 | If you find this project useful for academic purposes, please cite it as:
53 |
54 | ```
55 | @Inproceedings{song2018learning,
56 | title = {Learning to Sketch with Shortcut Cycle Consistency},
57 | author = {Song, Jifei and Pang, Kaiyue and Song, Yi-Zhe and Xiang, Tao and Hospedales, Timothy M},
58 | booktitle = {CVPR},
59 | year = {2018}
60 | }
61 | ```
62 |
--------------------------------------------------------------------------------
/svg/path/tests/test_parsing.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import unittest
3 | from ..path import CubicBezier, QuadraticBezier, Line, Arc, Path
4 | from ..parser import parse_path
5 |
6 |
7 | class TestParser(unittest.TestCase):
8 |
9 | def test_svg_examples(self):
10 | """Examples from the SVG spec"""
11 | path1 = parse_path('M 100 100 L 300 100 L 200 300 z')
12 | self.assertEqual(path1, Path(Line(100 + 100j, 300 + 100j),
13 | Line(300 + 100j, 200 + 300j),
14 | Line(200 + 300j, 100 + 100j)))
15 | self.assertTrue(path1.closed)
16 |
17 | # for Z command behavior when there is multiple subpaths
18 | path1 = parse_path('M 0 0 L 50 20 M 100 100 L 300 100 L 200 300 z')
19 | self.assertEqual(path1, Path(
20 | Line(0 + 0j, 50 + 20j),
21 | Line(100 + 100j, 300 + 100j),
22 | Line(300 + 100j, 200 + 300j),
23 | Line(200 + 300j, 100 + 100j)))
24 |
25 | path1 = parse_path('M 100 100 L 200 200')
26 | path2 = parse_path('M100 100L200 200')
27 | self.assertEqual(path1, path2)
28 |
29 | path1 = parse_path('M 100 200 L 200 100 L -100 -200')
30 | path2 = parse_path('M 100 200 L 200 100 -100 -200')
31 | self.assertEqual(path1, path2)
32 |
33 | path1 = parse_path("""M100,200 C100,100 250,100 250,200
34 | S400,300 400,200""")
35 | self.assertEqual(path1,
36 | Path(CubicBezier(100 + 200j, 100 + 100j, 250 + 100j, 250 + 200j),
37 | CubicBezier(250 + 200j, 250 + 300j, 400 + 300j, 400 + 200j)))
38 |
39 | path1 = parse_path('M100,200 C100,100 400,100 400,200')
40 | self.assertEqual(path1,
41 | Path(CubicBezier(100 + 200j, 100 + 100j, 400 + 100j, 400 + 200j)))
42 |
43 | path1 = parse_path('M100,500 C25,400 475,400 400,500')
44 | self.assertEqual(path1,
45 | Path(CubicBezier(100 + 500j, 25 + 400j, 475 + 400j, 400 + 500j)))
46 |
47 | path1 = parse_path('M100,800 C175,700 325,700 400,800')
48 | self.assertEqual(path1,
49 | Path(CubicBezier(100 + 800j, 175 + 700j, 325 + 700j, 400 + 800j)))
50 |
51 | path1 = parse_path('M600,200 C675,100 975,100 900,200')
52 | self.assertEqual(path1,
53 | Path(CubicBezier(600 + 200j, 675 + 100j, 975 + 100j, 900 + 200j)))
54 |
55 | path1 = parse_path('M600,500 C600,350 900,650 900,500')
56 | self.assertEqual(path1,
57 | Path(CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j)))
58 |
59 | path1 = parse_path("""M600,800 C625,700 725,700 750,800
60 | S875,900 900,800""")
61 | self.assertEqual(path1,
62 | Path(CubicBezier(600 + 800j, 625 + 700j, 725 + 700j, 750 + 800j),
63 | CubicBezier(750 + 800j, 775 + 900j, 875 + 900j, 900 + 800j)))
64 |
65 | path1 = parse_path('M200,300 Q400,50 600,300 T1000,300')
66 | self.assertEqual(path1,
67 | Path(QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j),
68 | QuadraticBezier(600 + 300j, 800 + 550j, 1000 + 300j)))
69 |
70 | path1 = parse_path('M300,200 h-150 a150,150 0 1,0 150,-150 z')
71 | self.assertEqual(path1,
72 | Path(Line(300 + 200j, 150 + 200j),
73 | Arc(150 + 200j, 150 + 150j, 0, 1, 0, 300 + 50j),
74 | Line(300 + 50j, 300 + 200j)))
75 |
76 | path1 = parse_path('M275,175 v-150 a150,150 0 0,0 -150,150 z')
77 | self.assertEqual(path1,
78 | Path(Line(275 + 175j, 275 + 25j),
79 | Arc(275 + 25j, 150 + 150j, 0, 0, 0, 125 + 175j),
80 | Line(125 + 175j, 275 + 175j)))
81 |
82 | path1 = parse_path("""M600,350 l 50,-25
83 | a25,25 -30 0,1 50,-25 l 50,-25
84 | a25,50 -30 0,1 50,-25 l 50,-25
85 | a25,75 -30 0,1 50,-25 l 50,-25
86 | a25,100 -30 0,1 50,-25 l 50,-25""")
87 | self.assertEqual(path1,
88 | Path(Line(600 + 350j, 650 + 325j),
89 | Arc(650 + 325j, 25 + 25j, -30, 0, 1, 700 + 300j),
90 | Line(700 + 300j, 750 + 275j),
91 | Arc(750 + 275j, 25 + 50j, -30, 0, 1, 800 + 250j),
92 | Line(800 + 250j, 850 + 225j),
93 | Arc(850 + 225j, 25 + 75j, -30, 0, 1, 900 + 200j),
94 | Line(900 + 200j, 950 + 175j),
95 | Arc(950 + 175j, 25 + 100j, -30, 0, 1, 1000 + 150j),
96 | Line(1000 + 150j, 1050 + 125j)))
97 |
98 | def test_others(self):
99 | # Other paths that need testing:
100 |
101 | # Relative moveto:
102 | path1 = parse_path('M 0 0 L 50 20 m 50 80 L 300 100 L 200 300 z')
103 | self.assertEqual(path1, Path(
104 | Line(0 + 0j, 50 + 20j),
105 | Line(100 + 100j, 300 + 100j),
106 | Line(300 + 100j, 200 + 300j),
107 | Line(200 + 300j, 100 + 100j)))
108 |
109 | # Initial smooth and relative CubicBezier
110 | path1 = parse_path("""M100,200 s 150,-100 150,0""")
111 | self.assertEqual(path1,
112 | Path(CubicBezier(100 + 200j, 100 + 200j, 250 + 100j, 250 + 200j)))
113 |
114 | # Initial smooth and relative QuadraticBezier
115 | path1 = parse_path("""M100,200 t 150,0""")
116 | self.assertEqual(path1,
117 | Path(QuadraticBezier(100 + 200j, 100 + 200j, 250 + 200j)))
118 |
119 | # Relative QuadraticBezier
120 | path1 = parse_path("""M100,200 q 0,0 150,0""")
121 | self.assertEqual(path1,
122 | Path(QuadraticBezier(100 + 200j, 100 + 200j, 250 + 200j)))
123 |
124 | def test_negative(self):
125 | """You don't need spaces before a minus-sign"""
126 | path1 = parse_path('M100,200c10-5,20-10,30-20')
127 | path2 = parse_path('M 100 200 c 10 -5 20 -10 30 -20')
128 | self.assertEqual(path1, path2)
129 |
130 | def test_numbers(self):
131 | """Exponents and other number format cases"""
132 | # It can be e or E, the plus is optional, and a minimum of +/-3.4e38 must be supported.
133 | path1 = parse_path('M-3.4e38 3.4E+38L-3.4E-38,3.4e-38')
134 | path2 = Path(Line(-3.4e+38 + 3.4e+38j, -3.4e-38 + 3.4e-38j))
135 | self.assertEqual(path1, path2)
136 |
137 | def test_errors(self):
138 | self.assertRaises(ValueError, parse_path, 'M 100 100 L 200 200 Z 100 200')
139 |
--------------------------------------------------------------------------------
/svg/path/parser.py:
--------------------------------------------------------------------------------
1 | # SVG Path specification parser
2 |
3 | import re
4 | from . import path
5 |
6 | COMMANDS = set('MmZzLlHhVvCcSsQqTtAa')
7 | UPPERCASE = set('MZLHVCSQTA')
8 |
9 | COMMAND_RE = re.compile("([MmZzLlHhVvCcSsQqTtAa])")
10 | FLOAT_RE = re.compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
11 |
12 |
13 | def _tokenize_path(pathdef):
14 | for x in COMMAND_RE.split(pathdef):
15 | if x in COMMANDS:
16 | yield x
17 | for token in FLOAT_RE.findall(x):
18 | yield token
19 |
20 |
21 | def parse_path(pathdef, current_pos=0j):
22 | # In the SVG specs, initial movetos are absolute, even if
23 | # specified as 'm'. This is the default behavior here as well.
24 | # But if you pass in a current_pos variable, the initial moveto
25 | # will be relative to that current_pos. This is useful.
26 | elements = list(_tokenize_path(pathdef))
27 | # Reverse for easy use of .pop()
28 | elements.reverse()
29 |
30 | segments = path.Path()
31 | start_pos = None
32 | command = None
33 |
34 | while elements:
35 |
36 | if elements[-1] in COMMANDS:
37 | # New command.
38 | last_command = command # Used by S and T
39 | command = elements.pop()
40 | absolute = command in UPPERCASE
41 | command = command.upper()
42 | else:
43 | # If this element starts with numbers, it is an implicit command
44 | # and we don't change the command. Check that it's allowed:
45 | if command is None:
46 | raise ValueError("Unallowed implicit command in %s, position %s" % (
47 | pathdef, len(pathdef.split()) - len(elements)))
48 |
49 | if command == 'M':
50 | # Moveto command.
51 | x = elements.pop()
52 | y = elements.pop()
53 | pos = float(x) + float(y) * 1j
54 | if absolute:
55 | current_pos = pos
56 | else:
57 | current_pos += pos
58 |
59 | # when M is called, reset start_pos
60 | # This behavior of Z is defined in svg spec:
61 | # http://www.w3.org/TR/SVG/paths.html#PathDataClosePathCommand
62 | start_pos = current_pos
63 |
64 | # Implicit moveto commands are treated as lineto commands.
65 | # So we set command to lineto here, in case there are
66 | # further implicit commands after this moveto.
67 | command = 'L'
68 |
69 | elif command == 'Z':
70 | # Close path
71 | segments.append(path.Line(current_pos, start_pos))
72 | segments.closed = True
73 | current_pos = start_pos
74 | start_pos = None
75 | command = None # You can't have implicit commands after closing.
76 |
77 | elif command == 'L':
78 | x = elements.pop()
79 | y = elements.pop()
80 | pos = float(x) + float(y) * 1j
81 | if not absolute:
82 | pos += current_pos
83 | segments.append(path.Line(current_pos, pos))
84 | current_pos = pos
85 |
86 | elif command == 'H':
87 | x = elements.pop()
88 | pos = float(x) + current_pos.imag * 1j
89 | if not absolute:
90 | pos += current_pos.real
91 | segments.append(path.Line(current_pos, pos))
92 | current_pos = pos
93 |
94 | elif command == 'V':
95 | y = elements.pop()
96 | pos = current_pos.real + float(y) * 1j
97 | if not absolute:
98 | pos += current_pos.imag * 1j
99 | segments.append(path.Line(current_pos, pos))
100 | current_pos = pos
101 |
102 | elif command == 'C':
103 | control1 = float(elements.pop()) + float(elements.pop()) * 1j
104 | control2 = float(elements.pop()) + float(elements.pop()) * 1j
105 | end = float(elements.pop()) + float(elements.pop()) * 1j
106 |
107 | if not absolute:
108 | control1 += current_pos
109 | control2 += current_pos
110 | end += current_pos
111 |
112 | segments.append(path.CubicBezier(current_pos, control1, control2, end))
113 | current_pos = end
114 |
115 | elif command == 'S':
116 | # Smooth curve. First control point is the "reflection" of
117 | # the second control point in the previous path.
118 |
119 | if last_command not in 'CS':
120 | # If there is no previous command or if the previous command
121 | # was not an C, c, S or s, assume the first control point is
122 | # coincident with the current point.
123 | control1 = current_pos
124 | else:
125 | # The first control point is assumed to be the reflection of
126 | # the second control point on the previous command relative
127 | # to the current point.
128 | control1 = current_pos + current_pos - segments[-1].control2
129 |
130 | control2 = float(elements.pop()) + float(elements.pop()) * 1j
131 | end = float(elements.pop()) + float(elements.pop()) * 1j
132 |
133 | if not absolute:
134 | control2 += current_pos
135 | end += current_pos
136 |
137 | segments.append(path.CubicBezier(current_pos, control1, control2, end))
138 | current_pos = end
139 |
140 | elif command == 'Q':
141 | control = float(elements.pop()) + float(elements.pop()) * 1j
142 | end = float(elements.pop()) + float(elements.pop()) * 1j
143 |
144 | if not absolute:
145 | control += current_pos
146 | end += current_pos
147 |
148 | segments.append(path.QuadraticBezier(current_pos, control, end))
149 | current_pos = end
150 |
151 | elif command == 'T':
152 | # Smooth curve. Control point is the "reflection" of
153 | # the second control point in the previous path.
154 |
155 | if last_command not in 'QT':
156 | # If there is no previous command or if the previous command
157 | # was not an Q, q, T or t, assume the first control point is
158 | # coincident with the current point.
159 | control = current_pos
160 | else:
161 | # The control point is assumed to be the reflection of
162 | # the control point on the previous command relative
163 | # to the current point.
164 | control = current_pos + current_pos - segments[-1].control
165 |
166 | end = float(elements.pop()) + float(elements.pop()) * 1j
167 |
168 | if not absolute:
169 | end += current_pos
170 |
171 | segments.append(path.QuadraticBezier(current_pos, control, end))
172 | current_pos = end
173 |
174 | elif command == 'A':
175 | radius = float(elements.pop()) + float(elements.pop()) * 1j
176 | rotation = float(elements.pop())
177 | arc = float(elements.pop())
178 | sweep = float(elements.pop())
179 | end = float(elements.pop()) + float(elements.pop()) * 1j
180 |
181 | if not absolute:
182 | end += current_pos
183 |
184 | segments.append(path.Arc(current_pos, radius, rotation, arc, sweep, end))
185 | current_pos = end
186 |
187 | return segments
188 |
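189 | # Illustrative usage (not part of the original file), matching the
190 | # round-trip behavior exercised in tests/test_generation.py:
191 | #
192 | #   >>> from svg.path import parse_path
193 | #   >>> parse_path('M 100,100 L 300,100 L 200,300 Z').d()
194 | #   'M 100,100 L 300,100 L 200,300 Z'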
--------------------------------------------------------------------------------
/prepare_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import scipy.io as sio
4 | import glob
5 | import cv2
6 | import h5py
7 | import json
8 | import sys
9 | import shutil
10 | import svg_utils
11 | from svgpathtools import svg2paths, wsvg
12 | from svgpathtools import svg2paths2
13 |
14 |
15 | def save_hdf5(fname, d):
16 | hf = h5py.File(fname, 'w')
17 | for key in d.keys():
18 | value = d[key]
19 | if type(value) is list:
20 | value = np.array(value)
21 | dtype = value.dtype.name
22 | if 'string' in dtype:
23 | dtype = value.dtype.str.split('|')[1]
24 | value = [v.encode("ascii", "ignore") for v in value]
25 | hf.create_dataset(key, (len(value),1), dtype, value)
26 | else:
27 | hf.create_dataset(key,
28 | dtype=value.dtype.name,
29 | data=value)
30 | hf.close()
31 | return fname
32 |
33 |
34 | def load_hdf5(fname):
35 | hf = h5py.File(fname, 'r')
36 | d = {key: np.array(hf.get(key)) for key in hf.keys()}
37 | hf.close()
38 | return d
39 |
40 |
41 | def read_info(dataset_folder):
42 | subset_info = {
43 | 'image_nums': [],
44 | 'class_names': [],
45 | 'num_classes': 0
46 | }
47 | data_info = {
48 | 'id': [],
49 | 'class_name': [],
50 | 'class_id': [],
51 | 'image_name': [],
52 | 'image_id': [],
53 | 'instance_id': [],
54 | 'image_data': []
55 | }
56 | class_name = dataset_folder.split('/')[-3]
57 | data_type = dataset_folder.split('/')[-2]
58 | # class_id_dict = {'shoes': 0, 'chairs': 1}
59 | class_id_dict = {'shoes': 0, 'chairs': 0}
60 | id_in_list = 0
61 | image_id_offset = 0
62 | class_id = 0
63 |
64 | image_files = os.walk(dataset_folder).next()[2]
65 | # sort image files
66 | image_files.sort()
67 | image_base_names = []
68 | unique_image_base_names = []
69 | instance_ids = []
70 | print "read info for %s in %s" % (class_name, data_type)
71 | for image_file in image_files:
72 | if '_' not in image_file:
73 | raise Exception('Sketch file name wrong')
74 | image_base_name = '_'.join(image_file.split('_')[:-1])
75 | instance_id = image_file.split('_')[-1]
76 | instance_id = instance_id.split('.')[0]
77 | instance_ids.append(int(instance_id) - 1)
78 |
79 | image_base_names.append(image_base_name)
80 | if image_base_name not in unique_image_base_names:
81 | unique_image_base_names.append(image_base_name)
82 |
83 | # re-sort to avoid the ordering inconsistency where "a56-3002_1.png" < "a_1.png" but "a.png" < "a56-3002.png" on the chair dataset
84 | image_files = np.array(image_files)[np.argsort(image_base_names)].tolist()
85 | image_base_names = np.array(image_base_names)[np.argsort(image_base_names)].tolist()
86 | instance_ids = np.array(instance_ids)[np.argsort(image_base_names)].tolist()
87 | unique_image_base_names = np.sort(unique_image_base_names).tolist()
88 | # image_base_names.sort()
89 |
90 | for idx in range(len(image_files)):
91 | image_file = image_files[idx]
92 | data_info['id'].append(id_in_list)
93 | data_info['class_name'].append(class_name)
94 | data_info['class_id'].append(class_id_dict[class_name])
95 | data_info['image_name'].append(image_file)
96 | data_info['image_id'].append(unique_image_base_names.index(image_base_names[idx]) + image_id_offset)
97 | data_info['instance_id'].append(instance_ids[idx])
98 | id_in_list += 1
99 | image_id_offset += len(unique_image_base_names)
100 | class_id += 1
101 |
102 | print "\n Data list reading complete"
103 |
104 | num_images = len(data_info['image_name'])
105 | print "save svg data"
106 | data_info['image_data'] = []
107 | data_info['data_offset'] = np.zeros((num_images, 2))
108 | start_idx = 0
109 | for idx in range(num_images):
110 | sys.stdout.write('\x1b[2K\r>> Process svg data, [%d/%d]' % (idx, num_images))
111 | sys.stdout.flush()
112 | lines = svg_utils.build_lines(os.path.join(dataset_folder, data_info['image_name'][idx]))
113 | data_info['image_data'].extend(lines)
114 | end_idx = start_idx + len(lines)
115 | data_info['data_offset'][idx, ::] = [start_idx, end_idx]
116 | start_idx = end_idx
117 | data_info['data_offset'] = data_info['data_offset'].astype(int)
118 |
119 | if save_png:
120 | if simplify_flag:
121 | png_data_dir = dataset_folder.split('sim')[0] + 'png'
122 | else:
123 | png_data_dir = dataset_folder + '_png'
124 | print "save rgb data"
125 | data_info['png_data'] = np.zeros((num_images, 256, 256), dtype=np.uint8)
126 | for idx in range(num_images):
127 | sys.stdout.write('\x1b[2K\r>> Process png data, [%d/%d]' % (idx, num_images))
128 | sys.stdout.flush()
129 | im = cv2.imread(os.path.join(png_data_dir, data_info['image_name'][idx]).split('.svg')[0] + '.png')
130 | im = cv2.resize(im, (256, 256))
131 | im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
132 | # im = cv2.imread(os.path.join(dataset_folder, test_img.split('.jpg')[0] + '.jpg'))
133 | data_info['png_data'][idx, ::] = im
134 |
135 | return subset_info, data_info
136 |
137 |
138 | def prepare_dbs(image_folder, data_type = 'h5'):
139 |
140 | if simplify_flag:
141 | simplify_str = '_sim'
142 | else:
143 | simplify_str = ''
144 |
145 | sketchy_info_train, data_info_train = read_info(os.path.join(image_folder, 'svg_train%s' % simplify_str))
146 | sketchy_info_test, data_info_test = read_info(os.path.join(image_folder, 'svg_test%s' % simplify_str))
147 |
148 | if save_png:
149 | simplify_str += '_png'
150 |
151 | save_hdf5(os.path.join(image_folder, 'train_svg%s.%s' % (simplify_str, data_type)), data_info_train)
152 | save_hdf5(os.path.join(image_folder, 'test_svg%s.%s' % (simplify_str, data_type)), data_info_test)
153 |
154 |
155 | def generate_db_list(image_folder):
156 | train_file_list_txt_origin = os.path.join(image_folder, 'train.txt')
157 | train_file_list_txt = os.path.join(image_folder, 'train_svg.txt')
158 | test_file_list_txt_origin = os.path.join(image_folder, 'test.txt')
159 | test_file_list_txt = os.path.join(image_folder, 'test_svg.txt')
160 | if not os.path.exists(train_file_list_txt) or not os.path.exists(test_file_list_txt):
161 | with open(train_file_list_txt_origin, 'r') as f:
162 | train_file_lists_origin = f.read().splitlines()
163 | with open(test_file_list_txt_origin, 'r') as f:
164 | test_file_lists_origin = f.read().splitlines()
165 | train_file_lists = [item.split('png')[0] + 'svg' for item in train_file_lists_origin]
166 | test_file_lists = [item.split('png')[0] + 'svg' for item in test_file_lists_origin]
167 | with open(train_file_list_txt, 'w') as f:
168 | f.writelines("\n".join(train_file_lists))
169 | with open(test_file_list_txt, 'w') as f:
170 | f.writelines("\n".join(test_file_lists))
171 |
172 |
173 | def split_db(image_folder, train_list_txt = 'train_svg.txt', test_list_txt = 'test_svg.txt'):
174 | if train_list_txt:
175 | with open(os.path.join(image_folder, train_list_txt)) as f:
176 | train_list = f.read().splitlines()
177 | copy_db_files(image_folder, 'svg_all', 'svg_train', train_list)
178 | if test_list_txt:
179 | with open(os.path.join(image_folder, test_list_txt)) as f:
180 | test_list = f.read().splitlines()
181 | copy_db_files(image_folder, 'svg_all', 'svg_test', test_list)
182 | # copy_db_files(image_folder, 'all', 'test', test_list)
183 |
184 |
185 | def copy_db_files(root_folder, src_folder, dst_folder, file_list):
186 | src_path = os.path.join(root_folder, src_folder)
187 | dst_path = os.path.join(root_folder, dst_folder)
188 | print "Copy files from %s/%s to %s/%s" % (root_folder, src_folder, root_folder, dst_folder)
189 | src_base_names = []
190 | dst_base_names = []
191 | if not os.path.exists(dst_path):
192 | os.mkdir(dst_path)
193 | sub_dirs = os.walk(src_path).next()[1]
194 | for sub_dir in sub_dirs:
195 | os.mkdir(os.path.join(dst_path, sub_dir.replace(' ', '_').replace('-', '_').replace('(', '').replace(')', '')))
196 | for file in file_list:
197 | sys.stdout.write('\x1b[2K\r>> Copying subfolder %s ==> %s: %d/%d' % (src_folder, dst_folder, file_list.index(file)+1, len(file_list)))
198 | sys.stdout.flush()
199 | src_file = os.path.join(src_path, file)
200 | dst_file = os.path.join(dst_path, file.replace(' ', '_').replace('-', '_').replace('(', '').replace(')', ''))
201 | base_name = file.split('_')[0]
202 | if base_name not in src_base_names:
203 | src_base_names.append(base_name)
204 | try:
205 | shutil.copy2(src_file, dst_file)
206 | if base_name not in dst_base_names:
207 | dst_base_names.append(base_name)
208 | except IOError:
209 | print "File does not exist: ", src_file
210 | print "\n Copy finished"
211 | print "Warning, none of files with below basename is copied"
212 | print [filename for filename in src_base_names if filename not in dst_base_names]
213 |
214 |
215 |
216 | if __name__ == "__main__":
217 |
218 | datasets = ['shoes', 'chairs']
219 | simplify_flag = True
220 |
221 | save_png = True
222 |
223 | for dataset in datasets:
224 | data_dir = 'data/%s/svg' % dataset
225 |
226 | generate_db_list(data_dir)
227 |
228 | split_db(data_dir)
229 |
230 | prepare_dbs(data_dir)
231 |
--------------------------------------------------------------------------------
/sketchrnn_cnn_dual_test.py:
--------------------------------------------------------------------------------
1 | """Model training."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import json
8 | import os
9 | import sys
10 | import time
11 |
12 | # internal imports
13 |
14 | import cv2
15 | import numpy as np
16 | import tensorflow as tf
17 |
18 | from model import sample, sample_recons, get_init_fn
19 | import model as sketch_rnn_model
20 | import utils
21 |
22 | FLAGS = tf.app.flags.FLAGS
23 |
24 | tf.app.flags.DEFINE_string('root_dir', './data', 'The root directory for the data')
25 | # tf.app.flags.DEFINE_string('data_dir', data_dir, 'The directory to find the dataset')
26 | # tf.app.flags.DEFINE_string('dataset', 'quickdraw', 'The dataset for classification')
27 | # tf.app.flags.DEFINE_string('dataset', 'shoes', 'The dataset for classification')
28 | tf.app.flags.DEFINE_string('dataset', 'shoesv2', 'The dataset for classification')
29 | # tf.app.flags.DEFINE_string('dataset', 'quickdraw_shoe', 'The dataset for classification')
30 | tf.app.flags.DEFINE_boolean('simplify_flag', True, 'use simplified dataset')
31 | tf.app.flags.DEFINE_boolean('use_vae', True, 'use vae or ae only')
32 | tf.app.flags.DEFINE_boolean('concat_z', True, 'concatenate z with x')
33 | tf.app.flags.DEFINE_string('log_root', './models/runs', 'Directory to store model checkpoints, tensorboard.')
34 | tf.app.flags.DEFINE_float('lr', 0.0001, "Learning rate.")
35 | tf.app.flags.DEFINE_float('decay_rate', 0.9999, "Learning rate decay for certain minibatches.")
36 | tf.app.flags.DEFINE_boolean('lr_decay', False, "Learning rate decay.")
37 | tf.app.flags.DEFINE_boolean('nkl', False, "if True, keep vae architecture but remove kl loss.")
38 | tf.app.flags.DEFINE_float('kl_weight_start', 0.01, "KL start weight when annealing.")
39 | tf.app.flags.DEFINE_float('kl_decay_rate', 0.99995, "KL annealing decay rate per minibatch")
40 | tf.app.flags.DEFINE_boolean('kl_weight_decay', False, "KL weight decay.")
41 | tf.app.flags.DEFINE_float('kl_tolerance', 0.2, "Level of KL loss at which to stop optimizing for KL.")
42 | tf.app.flags.DEFINE_float('l2_weight_start', 0.1, "start weight for l2.")
43 | tf.app.flags.DEFINE_float('l2_decay_rate', 0.99995, "l2 decay rate per minibatch")
44 | tf.app.flags.DEFINE_boolean('l2_weight_decay', False, "l2 weight decay.")
45 | tf.app.flags.DEFINE_integer('l2_decay_step', 5000, "l2 weight decay after how many steps.")
46 | tf.app.flags.DEFINE_integer('max_seq_len', 250, "max length of sequential data.")
47 | tf.app.flags.DEFINE_float('seq_lw', 1.0, "Loss weight for sequence reconstruction.")
48 | tf.app.flags.DEFINE_float('pix_lw', 1.0, "Loss weight for pixel reconstruction.")
49 | tf.app.flags.DEFINE_float('tri_weight', 1.0, "Triplet loss weight.")
50 | tf.app.flags.DEFINE_boolean('tune_cnn', True, 'finetune the CNN encoder or not')
51 | tf.app.flags.DEFINE_string('vae_type', 'p2s', 'variational autoencoder type: s2s (sketch2sketch), '
52 | 'p2s (photo2sketch), ps2s/sp2s (photo2sketch & sketch2sketch)')
53 | tf.app.flags.DEFINE_string('enc_type', 'cnn', 'type of encoder')
54 | tf.app.flags.DEFINE_string('rcons_type', 'mdn', 'type of reconstruction loss')
55 | tf.app.flags.DEFINE_integer('rd_dim', 512, 'embedding dim after mlp or other subnet')
56 | # tf.app.flags.DEFINE_boolean('reduce_dim', False, 'add fc layer before the embedding loss')
57 | tf.app.flags.DEFINE_integer('image_size', 256, 'image size for cnn')
58 | tf.app.flags.DEFINE_integer('crop_size', 224, 'crop size for cnn')
59 | tf.app.flags.DEFINE_integer('chn_size', 1, 'number of channels for cnn')
60 | tf.app.flags.DEFINE_string('basenet', 'gen_cnn', 'basenet for cnn encoder')
61 | tf.app.flags.DEFINE_string('feat_type', 'inceptionv3', 'network type for the extracted photo feature')
62 | tf.app.flags.DEFINE_integer('feat_size', 2048, 'feature size for the extracted photo feature')
63 | tf.app.flags.DEFINE_float('margin', 0.1, 'Margin for contrastive/triplet loss')
64 | tf.app.flags.DEFINE_boolean('load_pretrain', False, 'Load pretrain model for CBB')
65 | tf.app.flags.DEFINE_boolean('resume_training', True, 'Set to true to load previous checkpoint')
66 | tf.app.flags.DEFINE_string('load_dir', '', 'Directory to load the pretrained model')
67 | tf.app.flags.DEFINE_string('img_dir', '', 'Directory to save the images')
68 | tf.app.flags.DEFINE_string('add_str', '', 'add str to the image save directory and checkpoints')
69 |
70 | # hyperparameters
71 | tf.app.flags.DEFINE_integer('batch_size', 100, 'Number of images to process in a batch.')
72 | tf.app.flags.DEFINE_boolean('is_train', False, 'In the training stage or not')
73 | tf.app.flags.DEFINE_float('drop_kp', 1.0, 'Dropout keep rate')
74 |
75 | # data augmentation
76 | tf.app.flags.DEFINE_boolean('flip_aug', False, 'Whether to flip the sketch and photo or not')
77 | tf.app.flags.DEFINE_boolean('dist_aug', False, 'Whether to distort the images')
78 | tf.app.flags.DEFINE_boolean('hp_filter', False, 'Whether to add high pass filter')
79 |
80 | tf.app.flags.DEFINE_integer("print_every", 100, "print training loss after this many steps (default: 20)")
81 | tf.app.flags.DEFINE_integer("save_every", 1000, "Evaluate model on dev set after this many steps (default: 100)")
82 | tf.app.flags.DEFINE_boolean('debug_test', False, 'Run testing in debug mode')
83 | tf.app.flags.DEFINE_string("saved_flags", None, "Save all flags for printing")
84 | tf.app.flags.DEFINE_string('hparams', '',
85 | 'Pass in comma-separated key=value pairs such as \'save_every=40,decay_rate=0.99\''
86 | '(no whitespace) to be read into the HParams object defined in model.py')
87 |
88 | # save settings for sampling in testing stage
89 | tf.app.flags.DEFINE_boolean('sample_sketch', True, 'Set to true to sample sketches')
90 | tf.app.flags.DEFINE_boolean('save_gt_sketch', True, 'Set to true to save ground truth sketch')
91 | tf.app.flags.DEFINE_boolean('save_photo', False, 'Set to true to save ground truth photo')
92 | tf.app.flags.DEFINE_boolean('cond_sketch', False, 'Set to true to generate sketch conditioned on sketch')
93 | tf.app.flags.DEFINE_boolean('inter_z', False, 'Interpolate latent vector of batch z')
94 | tf.app.flags.DEFINE_boolean('recon_sketch', False, 'Set to true to reconstruct sketch')
95 | tf.app.flags.DEFINE_boolean('recon_photo', False, 'Set to true to reconstruct photo')
96 |
97 | # hyperparameters inherited from sketch-rnn
98 | tf.app.flags.DEFINE_float('random_scale_factor', 0.15, 'Random scaling data augmention proportion.')
99 | tf.app.flags.DEFINE_float('augment_stroke_prob', 0.10, 'Point dropping augmentation proportion.')
100 | tf.app.flags.DEFINE_string('rnn_model', 'lstm', 'lstm, layer_norm or hyper.')
101 | tf.app.flags.DEFINE_boolean('use_recurrent_dropout', True, 'Dropout with memory loss.')
102 | tf.app.flags.DEFINE_float('recurrent_dropout_prob', 1.0, 'Probability of recurrent dropout keep.')
103 | tf.app.flags.DEFINE_boolean('rnn_input_dropout', False, 'RNN input dropout.')
104 | tf.app.flags.DEFINE_boolean('rnn_output_dropout', False, 'RNN output dropout.')
105 | tf.app.flags.DEFINE_integer('enc_rnn_size', 256, 'Size of RNN when used as encoder')
106 | tf.app.flags.DEFINE_integer('dec_rnn_size', 512, 'Size of RNN when used as decoder')
107 | tf.app.flags.DEFINE_integer('z_size', 128, 'Size of latent vector z')
108 | tf.app.flags.DEFINE_integer('num_mixture', 20, 'Number of mixture components in the MDN output')
109 |
110 |
111 | def reset_graph():
112 | """Closes the current default session and resets the graph."""
113 | sess = tf.get_default_session()
114 | if sess:
115 | sess.close()
116 | tf.reset_default_graph()
117 |
118 |
119 | def load_model(model_dir):
120 | """Loads model for inference mode, used in jupyter notebook."""
121 | model_params = sketch_rnn_model.get_default_hparams()
122 | with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
123 | model_params.parse_json(f.read())
124 |
125 | model_params.batch_size = 1 # only sample one at a time
126 | eval_model_params = sketch_rnn_model.copy_hparams(model_params)
127 | eval_model_params.use_input_dropout = 0
128 | eval_model_params.use_recurrent_dropout = 0
129 | eval_model_params.use_output_dropout = 0
130 | eval_model_params.is_training = 0
131 | sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
132 | sample_model_params.max_seq_len = 1 # sample one point at a time
133 | return [model_params, eval_model_params, sample_model_params]
134 |
135 |
136 | def sampling_model_eval(sess, model, gen_model, data_set, seq_len):
137 | """Returns the average weighted cost, reconstruction cost and KL cost."""
138 | sketch_size, photo_size = data_set.sketch_size, data_set.image_size
139 |
140 | folders_to_create = ['gen_test', 'gen_test_png', 'gt_test', 'gt_test_png', 'gt_test_photo', 'gt_test_sketch_image',
141 | 'gen_test_s', 'gen_test_s_png', 'gen_test_inter', 'gen_test_inter_png', 'gen_test_inter_sep',
142 | 'gen_test_inter_sep_png', 'gen_photo', 'gen_test_inter_with_photo', 'recon_test',
143 | 'recon_test_png', 'recon_photo']
144 | for folder_to_create in folders_to_create:
145 | folder_path = os.path.join(FLAGS.img_dir, '%s/%s' % (data_set.dataset, folder_to_create))
146 | if not os.path.exists(folder_path):
147 | os.mkdir(folder_path)
148 |
149 | for image_index in range(photo_size):
150 |
151 | sys.stdout.write('\x1b[2K\r>> Sampling test set, [%d/%d]' % (image_index + 1, photo_size))
152 | sys.stdout.flush()
153 |
154 | image_feat, rnn_enc_seq_len = data_set.get_input_image(image_index)
155 | sample_strokes, m = sample(sess, model, image_feat, seq_len=seq_len, rnn_enc_seq_len=rnn_enc_seq_len)
156 | strokes = utils.to_normal_strokes(sample_strokes)
157 | svg_gen_sketch = os.path.join(FLAGS.img_dir, '%s/gen_test/gen_sketch%d.svg' % (data_set.dataset, image_index))
158 | png_gen_sketch = os.path.join(FLAGS.img_dir, '%s/gen_test_png/gen_sketch%d.png' % (data_set.dataset, image_index))
159 | utils.sv_svg_png_from_strokes(strokes, svg_filename=svg_gen_sketch, png_filename=png_gen_sketch)
160 |
161 | print("\nSampling finished")
162 |
163 |
164 | def load_checkpoint(sess, checkpoint_path):
165 |
166 | ckpt = tf.train.get_checkpoint_state(checkpoint_path)
167 | if ckpt is None:
168 | raise Exception('Pretrained model not found at %s' % checkpoint_path)
169 | print('Loading model %s.' % ckpt.model_checkpoint_path)
170 | init_op = get_init_fn(checkpoint_path, [''], ckpt.model_checkpoint_path)
171 | init_op(sess)
172 |
173 |
174 | def save_model(sess, model_save_path, global_step):
175 | saver = tf.train.Saver(tf.global_variables())
176 | checkpoint_path = os.path.join(model_save_path, 'vector')
177 | print('saving model %s.' % checkpoint_path)
178 | print('global_step %i.' % global_step)
179 | saver.save(sess, checkpoint_path, global_step=global_step)
180 |
181 |
182 | def sample_test(sess, sample_model, gen_model, test_set, max_seq_len):
183 |
184 | # set image dir
185 | # FLAGS.img_dir = FLAGS.log_root.split('runs')[0] + 'sv_imgs/%s/' % FLAGS.dataset
186 | FLAGS.img_dir = FLAGS.log_root.split('runs')[0] + 'sv_imgs/'
187 |
188 | FLAGS.img_dir += 'dual/' + FLAGS.basenet
189 |
190 | sampling_model_eval(sess, sample_model, gen_model, test_set, max_seq_len)
191 |
192 |
193 | def tester(model_params):
194 | """Test model."""
195 | np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
196 |
197 | print('Hyperparams:')
198 | for key, val in model_params.values().iteritems():
199 | print('%s = %s' % (key, str(val)))
200 | print('Loading data files.')
201 | test_set, sample_model_params, gen_model_params = utils.load_dataset(FLAGS.root_dir, FLAGS.dataset, model_params, inference_mode=True)
202 |
203 | reset_graph()
204 | sample_model = sketch_rnn_model.Model(sample_model_params)
205 | gen_model = sketch_rnn_model.Model(gen_model_params, reuse=True)
206 |
207 | sess = tf.Session()
208 | sess.run(tf.global_variables_initializer())
209 |
210 | if FLAGS.dataset in ['shoesv2f_sup', 'shoesv2f_train']:
211 | dataset = 'shoesv2'
212 | else:
213 | dataset = FLAGS.dataset
214 |
215 | if FLAGS.resume_training:
216 | if FLAGS.load_dir == '':
217 | FLAGS.load_dir = FLAGS.log_root.split('runs')[0] + 'model_to_test/%s/' % dataset
218 | # set dir to load the model for testing
219 | FLAGS.load_dir = os.path.join(FLAGS.load_dir, FLAGS.basenet)
220 | load_checkpoint(sess, FLAGS.load_dir)
221 |
222 | # Write config file to json file.
223 | tf.gfile.MakeDirs(FLAGS.log_root)
224 | with tf.gfile.Open(
225 | os.path.join(FLAGS.log_root, 'model_config.json'), 'w') as f:
226 | json.dump(model_params.values(), f, indent=True)
227 |
228 | sample_test(sess, sample_model, gen_model, test_set, model_params.max_seq_len)
229 |
230 |
231 | def main(unused_argv):
232 | """Load model params, save config file and start trainer."""
233 | model_params = tf.contrib.training.HParams()
234 | # merge FLAGS to hps
235 | for attr, value in sorted(FLAGS.__flags.items()):
236 | model_params.add_hparam(attr, value)
237 |
238 | tester(model_params)
239 |
240 |
241 | if __name__ == '__main__':
242 | tf.app.run(main)
--------------------------------------------------------------------------------
/sketchrnn_cnn_dual_train.py:
--------------------------------------------------------------------------------
1 | """Model training."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import json
8 | import os
9 | import sys
10 | import time
11 |
12 | # internal imports
13 |
14 | import numpy as np
15 | import tensorflow as tf
16 |
17 | from model import sample, get_init_fn
18 | import model as sketch_rnn_model
19 | import utils
20 | import cv2
21 |
22 |
23 | FLAGS = tf.app.flags.FLAGS
24 |
25 | tf.app.flags.DEFINE_string('root_dir', './data', 'The root directory for the data')
26 | tf.app.flags.DEFINE_string('dataset', 'shoesv2', 'The dataset for classification')
27 | tf.app.flags.DEFINE_string('log_root', './models/runs', 'Directory to store model checkpoints, tensorboard.')
28 | tf.app.flags.DEFINE_float('lr', 0.0001, "Learning rate.")
29 | tf.app.flags.DEFINE_float('decay_rate', 0.9, "Learning rate decay for certain minibatches.")
30 | tf.app.flags.DEFINE_integer('decay_step', 5000, "Learning rate decay after how many training steps.")
31 | tf.app.flags.DEFINE_boolean('lr_decay', False, "Learning rate decay.")
32 | tf.app.flags.DEFINE_float('grad_clip', 1.0, 'Gradient clipping')
33 | tf.app.flags.DEFINE_float('kl_weight_start', 0.01, "KL start weight when annealing.")
34 | tf.app.flags.DEFINE_float('kl_decay_rate', 0.99995, "KL annealing decay rate per minibatch")
35 | tf.app.flags.DEFINE_boolean('kl_weight_decay', False, "KL weight decay.")
36 | tf.app.flags.DEFINE_integer('kl_decay_step', 5000, "KL weight decay after how many steps.")
37 | tf.app.flags.DEFINE_float('kl_tolerance', 0.2, "Level of KL loss at which to stop optimizing for KL.")
38 | tf.app.flags.DEFINE_float('l2_weight_start', 0.1, "start weight for l2.")
39 | tf.app.flags.DEFINE_float('l2_decay_rate', 0.99995, "l2 decay rate per minibatch")
40 | tf.app.flags.DEFINE_boolean('l2_weight_decay', False, "l2 weight decay.")
41 | tf.app.flags.DEFINE_integer('l2_decay_step', 5000, "l2 weight decay after how many steps.")
42 | tf.app.flags.DEFINE_integer('max_seq_len', 250, "max length of sequential data.")
43 | tf.app.flags.DEFINE_float('seq_lw', 1.0, "Loss weight for sequence reconstruction.")
44 | tf.app.flags.DEFINE_float('pix_lw', 1.0, "Loss weight for pixel reconstruction.")
45 | tf.app.flags.DEFINE_boolean('tune_cnn', True, 'finetune the CNN encoder or not')
46 | tf.app.flags.DEFINE_string('vae_type', 'p2s', 'variational autoencoder type: s2s (sketch2sketch), '
47 | 'p2s (photo2sketch), ps2s/sp2s (photo2sketch & sketch2sketch)')
48 | tf.app.flags.DEFINE_string('enc_type', 'cnn', 'type of encoder')
49 | tf.app.flags.DEFINE_integer('image_size', 256, 'image size for cnn')
50 | tf.app.flags.DEFINE_integer('crop_size', 224, 'crop size for cnn')
51 | tf.app.flags.DEFINE_integer('chn_size', 1, 'number of channels for cnn')
52 | tf.app.flags.DEFINE_string('basenet', 'gen_cnn', 'basenet for cnn encoder')
53 | tf.app.flags.DEFINE_float('margin', 0.1, 'Margin for contrastive/triplet loss')
54 | tf.app.flags.DEFINE_boolean('load_pretrain', True, 'Load pretrain model for the p2s model')
55 | tf.app.flags.DEFINE_boolean('resume_training', False, 'Set to true to load previous checkpoint')
56 | tf.app.flags.DEFINE_string('load_dir', '', 'Directory to load the pretrained model')
57 | tf.app.flags.DEFINE_string('img_dir', '', 'Directory to save the images')
58 |
59 | # hyperparameters
60 | tf.app.flags.DEFINE_integer('batch_size', 100, 'Number of images to process in a batch.')
61 | tf.app.flags.DEFINE_boolean('is_train', True, 'In the training stage or not')
62 | tf.app.flags.DEFINE_float('drop_kp', 0.8, 'Dropout keep rate')
63 |
64 | # data augmentation
65 | tf.app.flags.DEFINE_boolean('flip_aug', False, 'Whether to flip the sketch and photo or not')
66 | tf.app.flags.DEFINE_boolean('dist_aug', False, 'Whether to distort the images')
67 | tf.app.flags.DEFINE_boolean('hp_filter', False, 'Whether to add high pass filter')
68 |
69 | # print flags
70 | tf.app.flags.DEFINE_integer("max_steps", 100000, "Max training steps")
71 | tf.app.flags.DEFINE_integer("print_every", 100, "print training loss after this many steps (default: 20)")
72 | tf.app.flags.DEFINE_integer("save_every", 1000, "Evaluate model on dev set after this many steps (default: 100)")
73 | tf.app.flags.DEFINE_boolean('debug_test', False, 'Run training in debug mode')
74 | tf.app.flags.DEFINE_boolean('tee_log', True, 'Create log file to save the print info')
75 | tf.app.flags.DEFINE_boolean('inter_z', False, 'Interpolate latent vector of batch z')
76 | tf.app.flags.DEFINE_string("saved_flags", None, "Save all flags for printing")
77 |
78 | # hyperparameters inherited from sketch-rnn
79 | tf.app.flags.DEFINE_float('random_scale_factor', 0.15, 'Random scaling data augmention proportion.')
80 | tf.app.flags.DEFINE_float('augment_stroke_prob', 0.10, 'Point dropping augmentation proportion.')
81 | tf.app.flags.DEFINE_string('rnn_model', 'lstm', 'lstm, layer_norm or hyper.')
82 | tf.app.flags.DEFINE_boolean('use_recurrent_dropout', True, 'Dropout with memory loss.')
83 | tf.app.flags.DEFINE_float('recurrent_dropout_prob', 0.90, 'Probability of recurrent dropout keep.')
84 | tf.app.flags.DEFINE_boolean('rnn_input_dropout', False, 'RNN input dropout.')
85 | tf.app.flags.DEFINE_boolean('rnn_output_dropout', False, 'RNN output dropout.')
86 | tf.app.flags.DEFINE_integer('enc_rnn_size', 256, 'Size of RNN when used as encoder')
87 | tf.app.flags.DEFINE_integer('dec_rnn_size', 512, 'Size of RNN when used as decoder')
88 | tf.app.flags.DEFINE_integer('z_size', 128, 'Size of latent vector z')
89 | tf.app.flags.DEFINE_integer('num_mixture', 20, 'Number of mixture components in the MDN output')
90 |
91 |
92 | def reset_graph():
93 | """Closes the current default session and resets the graph."""
94 | sess = tf.get_default_session()
95 | if sess:
96 | sess.close()
97 | tf.reset_default_graph()
98 |
99 |
100 | def load_model(model_dir):
101 | """Loads model for inference mode, used in jupyter notebook."""
102 | model_params = sketch_rnn_model.get_default_hparams()
103 | with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
104 | model_params.parse_json(f.read())
105 |
106 | model_params.batch_size = 1 # only sample one at a time
107 | eval_model_params = sketch_rnn_model.copy_hparams(model_params)
108 | eval_model_params.use_input_dropout = 0
109 | eval_model_params.use_recurrent_dropout = 0
110 | eval_model_params.use_output_dropout = 0
111 | eval_model_params.is_training = 0
112 | sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
113 | sample_model_params.max_seq_len = 1 # sample one point at a time
114 | return [model_params, eval_model_params, sample_model_params]
115 |
116 |
117 | def sampling_model(sess, model, gen_model, data_set, step, seq_len, subset_str=''):
118 | """Returns the average weighted cost, reconstruction cost and KL cost."""
119 | sketch_size, photo_size = data_set.sketch_size, data_set.image_size
120 |
121 | image_index = np.random.randint(0, photo_size)
122 | sketch_index = data_set.get_corr_sketch_id(image_index)
123 | gt_strokes = data_set.sketch_strokes[sketch_index]
124 |
125 | image_feat, rnn_enc_seq_len = data_set.get_input_image(image_index)
126 | sample_strokes, m = sample(sess, model, image_feat, seq_len=seq_len, rnn_enc_seq_len=rnn_enc_seq_len)
127 | strokes = utils.to_normal_strokes(sample_strokes)
128 | svg_gen_sketch = os.path.join(FLAGS.img_dir, '%s/%s/gensketch_for_photo%d_step%d.svg' % (data_set.dataset, subset_str, image_index, step))
129 | utils.draw_strokes(strokes, svg_filename=svg_gen_sketch)
130 | svg_gt_sketch = os.path.join(FLAGS.img_dir, '%s/%s/gt_sketch%d_for_photo%d.svg' % (data_set.dataset, subset_str, sketch_index, image_index))
131 | utils.draw_strokes(gt_strokes, svg_filename=svg_gt_sketch)
132 | input_sketch = data_set.pad_single_sketch(image_index)
133 | feed = {gen_model.input_sketch: input_sketch, gen_model.input_photo: image_feat, gen_model.sequence_lengths: [seq_len]}
134 | gen_photo = sess.run(gen_model.gen_photo, feed)
135 | gen_photo_file = os.path.join(FLAGS.img_dir, '%s/%s/gen_photo%d_step%d.png' % (data_set.dataset, subset_str, image_index, step))
136 | cv2.imwrite(gen_photo_file, cv2.cvtColor(gen_photo[0, ::].astype(np.uint8), cv2.COLOR_RGB2BGR))
137 | gt_photo = os.path.join(FLAGS.img_dir, '%s/%s/gt_photo%d.png' % (data_set.dataset, subset_str, image_index))
138 | if len(image_feat[0].shape) == 2:
139 | cv2.imwrite(gt_photo, image_feat[0])
140 | else:
141 | cv2.imwrite(gt_photo, cv2.cvtColor(image_feat[0].astype(np.uint8), cv2.COLOR_RGB2BGR))
142 |
143 |
144 | def load_pretrain(sess, vae_type, enc_type, dataset, basenet, log_root):
145 | if vae_type in ['ps2s', 'sp2s'] or dataset in ['shoesv2', 'chairsv2']:
146 | if 'shoe' in dataset:
147 | sv_str = 'shoe'
148 | elif 'chair' in dataset:
149 | sv_str = 'chair'
150 | pretrain_dir = log_root.split('runs')[0] + 'pretrained_model/%s/' % sv_str
151 | ckpt = tf.train.get_checkpoint_state(pretrain_dir)
152 | if ckpt is not None:
153 | pretrained_model = ckpt.model_checkpoint_path
154 | print('Loading model %s.' % pretrained_model)
155 | checkpoint_exclude_scopes = []
156 | init_fn = get_init_fn(pretrained_model, checkpoint_exclude_scopes)
157 | init_fn(sess)
158 | else:
159 | print('Warning: pretrained model not found at %s' % pretrain_dir)
160 |
161 |
162 | def resume_train(sess, load_dir, dataset, enc_type, basenet, log_root):
163 | if not load_dir:
164 | if 'shoe' in dataset:
165 | sv_str = 'shoe'
166 | elif 'chair' in dataset:
167 | sv_str = 'chair'
168 | load_dir = log_root.split('runs')[0] + 'save_models/%s/' % sv_str
169 | # set dir to load the model for resume training
170 | load_dir = load_dir + basenet
171 |
172 | load_checkpoint(sess, load_dir)
173 |
174 |
175 | def load_checkpoint(sess, checkpoint_path):
176 |
177 | ckpt = tf.train.get_checkpoint_state(checkpoint_path)
178 | if ckpt is None:
179 | raise Exception('Pretrained model not found at %s' % checkpoint_path)
180 | print('Loading model %s.' % ckpt.model_checkpoint_path)
181 | saver = tf.train.Saver(tf.global_variables())
182 | saver.restore(sess, ckpt.model_checkpoint_path)
183 |
184 |
185 | def save_model(sess, saver, model_save_path, global_step):
186 | checkpoint_path = os.path.join(model_save_path, 'p2s')
187 | print('saving model %s at global_step %i.' % (checkpoint_path, global_step))
188 | saver.save(sess, checkpoint_path, global_step=global_step)
189 |
190 |
191 | def train(sess, model, train_set):
192 | """Train a sketch-rnn model."""
193 |
194 | # print log
195 | if FLAGS.tee_log:
196 | utils.config_and_print_log(FLAGS.log_root, FLAGS)
197 |
198 | # set image dir
199 | FLAGS.img_dir = FLAGS.log_root.split('runs')[0] + 'sv_imgs/'
200 |
201 | # main train loop
202 | hps = model.hps
203 | start = time.time()
204 |
205 | # create saver
206 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
207 |
208 | curr_step = sess.run(model.global_step)
209 |
210 | for step_id in range(curr_step, FLAGS.max_steps + FLAGS.save_every):
211 |
212 | step = sess.run(model.global_step)
213 |
214 | if hps.vae_type in ['p2s', 's2s']:
215 | s, a, p = train_set.random_batch()
216 | feed = {
217 | model.input_sketch: a,
218 | model.input_photo: p,
219 | model.sequence_lengths: s,
220 | }
221 | else:
222 | s, a, p, n = train_set.random_batch()
223 | feed = {
224 | model.input_sketch: a,
225 | model.input_photo: p,
226 | model.input_sketch_photo: n,
227 | model.sequence_lengths: s,
228 | }
229 |
230 | (train_cost, r_cost, kl_cost, p2s_r, p2s_kl, s2p_r, s2p_kl, p2p_r, p2p_kl, s2s_r, s2s_kl, _, train_step, _) = \
231 | sess.run([model.cost, model.r_cost, model.kl_cost, model.p2s_r, model.p2s_kl, model.s2p_r, model.s2p_kl,
232 | model.p2p_r, model.p2p_kl, model.s2s_r, model.s2s_kl, model.final_state, model.global_step,
233 | model.train_op], feed)
234 |
235 | if step % hps.print_every == 0 and step > 0:
236 | end = time.time()
237 | time_taken = end - start
238 |
239 | output_format = ('step: %d, ALL (cost: %.4f, recon: %.4f, kl: %.4f), P2S (recons: %.4f, kl: %.4f), '
240 | 'S2P (recons: %.4f, kl: %.4f), P2P (recons: %.4f, kl: %.4f), S2S (recons: %.4f, kl: %.4f), '
241 | 'train_time_taken: %.4f')
242 | output_values = (step, train_cost, r_cost, kl_cost, p2s_r, p2s_kl,
243 | s2p_r, s2p_kl, p2p_r, p2p_kl, s2s_r, s2s_kl, time_taken)
244 | output_log = output_format % output_values
245 |
246 | print(output_log)
247 |
248 | start = time.time()
249 |
250 | if step % hps.save_every == 0 and step > 0:
251 |
252 | save_model(sess, saver, FLAGS.log_root, step)
253 |
254 | print("Finished training stage with %d steps" % FLAGS.max_steps)
255 |
256 |
257 | def trainer(model_params):
258 | """Train a sketch-rnn model."""
259 | np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
260 |
261 | print('Loading data files.')
262 | train_set, model_params = utils.load_dataset(FLAGS.root_dir, FLAGS.dataset, model_params)
263 |
264 | reset_graph()
265 | model = sketch_rnn_model.Model(model_params)
266 |
267 | sess = tf.Session()
268 | sess.run(tf.global_variables_initializer())
269 |
270 | if FLAGS.load_pretrain:
271 | load_pretrain(sess, FLAGS.vae_type, FLAGS.enc_type, FLAGS.dataset, FLAGS.basenet, FLAGS.log_root)
272 |
273 | if FLAGS.resume_training:
274 | resume_train(sess, FLAGS.load_dir, FLAGS.dataset, FLAGS.enc_type, FLAGS.basenet, FLAGS.feat_type, FLAGS.log_root)
275 |
276 | train(sess, model, train_set)
277 |
278 |
279 | def main(unused_argv):
280 | """Load model params, save config file and start trainer."""
281 | model_params = tf.contrib.training.HParams()
282 | # merge FLAGS to hps
283 | for attr, value in sorted(FLAGS.__flags.items()):
284 | model_params.add_hparam(attr, value)
285 |
286 | trainer(model_params)
287 |
288 |
289 | if __name__ == '__main__':
290 | tf.app.run(main)
291 |
--------------------------------------------------------------------------------
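A minimal sketch of the checkpoint round trip that save_model() and load_checkpoint() implement above, using only the standard TF1 Saver API. The './ckpts' directory and the dummy variable are illustrative placeholders, not names from this repository:

import os
import tensorflow as tf

v = tf.get_variable('v', shape=[2], initializer=tf.zeros_initializer())
global_step = tf.train.get_or_create_global_step()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
    if not os.path.exists('./ckpts'):
        os.makedirs('./ckpts')
    # save_model(): write a checkpoint with the step appended to the prefix
    saver.save(sess, './ckpts/p2s', global_step=sess.run(global_step))
    # load_checkpoint(): locate the latest checkpoint in a directory and restore it
    ckpt = tf.train.get_checkpoint_state('./ckpts')
    if ckpt is None:
        raise Exception('Pretrained model not found at ./ckpts')
    saver.restore(sess, ckpt.model_checkpoint_path)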
/tf_data_work.py:
--------------------------------------------------------------------------------
1 | # Modified from the image preprocessing code in InceptionV3.
2 | # Revised 01/10/2017 by Jeffrey.
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import random
9 | import tensorflow as tf
10 | from tensorflow.python.ops import control_flow_ops
11 |
12 |
13 |
14 | FLAGS = tf.app.flags.FLAGS
15 |
16 |
17 | def _crop(image, offset_height, offset_width, crop_height, crop_width):
18 | """Crops the given image using the provided offsets and sizes.
19 | Note that the method doesn't assume we know the input image size but it does
20 | assume we know the input image rank.
21 | Args:
22 | image: an image of shape [height, width, channels].
23 | offset_height: a scalar tensor indicating the height offset.
24 | offset_width: a scalar tensor indicating the width offset.
25 | crop_height: the height of the cropped image.
26 | crop_width: the width of the cropped image.
27 | Returns:
28 | the cropped (and resized) image.
29 | Raises:
30 | InvalidArgumentError: if the rank is not 3 or if the image dimensions are
31 | less than the crop size.
32 | """
33 | original_shape = tf.shape(image)
34 |
35 | rank_assertion = tf.Assert(
36 | tf.equal(tf.rank(image), 3),
37 | ['Rank of image must be equal to 3.'])
38 | cropped_shape = control_flow_ops.with_dependencies(
39 | [rank_assertion],
40 |       tf.stack([crop_height, crop_width, original_shape[2]]))
41 |
42 | size_assertion = tf.Assert(
43 | tf.logical_and(
44 | tf.greater_equal(original_shape[0], crop_height),
45 | tf.greater_equal(original_shape[1], crop_width)),
46 | ['Crop size greater than the image size.'])
47 |
48 |   offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0]))
49 |
50 | # Use tf.slice instead of crop_to_bounding box as it accepts tensors to
51 | # define the crop size.
52 | image = control_flow_ops.with_dependencies(
53 | [size_assertion],
54 | tf.slice(image, offsets, cropped_shape))
55 | return tf.reshape(image, cropped_shape)
56 |
57 |
58 | def _random_crop(image_list, crop_height, crop_width):
59 | """Crops the given list of images.
60 | The function applies the same crop to each image in the list. This can be
61 | effectively applied when there are multiple image inputs of the same
62 | dimension such as:
63 | image, depths, normals = _random_crop([image, depths, normals], 120, 150)
64 | Args:
65 | image_list: a list of image tensors of the same dimension but possibly
66 | varying channel.
67 | crop_height: the new height.
68 | crop_width: the new width.
69 | Returns:
70 | the image_list with cropped images.
71 | Raises:
72 | ValueError: if there are multiple image inputs provided with different size
73 | or the images are smaller than the crop dimensions.
74 | """
75 | if not image_list:
76 | raise ValueError('Empty image_list.')
77 |
78 | # Compute the rank assertions.
79 | rank_assertions = []
80 | for i in range(len(image_list)):
81 | image_rank = tf.rank(image_list[i])
82 | rank_assert = tf.Assert(
83 | tf.equal(image_rank, 3),
84 | ['Wrong rank for tensor %s [expected] [actual]',
85 | image_list[i].name, 3, image_rank])
86 | rank_assertions.append(rank_assert)
87 |
88 | image_shape = control_flow_ops.with_dependencies(
89 | [rank_assertions[0]],
90 | tf.shape(image_list[0]))
91 | image_height = image_shape[0]
92 | image_width = image_shape[1]
93 | crop_size_assert = tf.Assert(
94 | tf.logical_and(
95 | tf.greater_equal(image_height, crop_height),
96 | tf.greater_equal(image_width, crop_width)),
97 | ['Crop size greater than the image size.'])
98 |
99 | asserts = [rank_assertions[0], crop_size_assert]
100 |
101 | for i in range(1, len(image_list)):
102 | image = image_list[i]
103 | asserts.append(rank_assertions[i])
104 | shape = control_flow_ops.with_dependencies([rank_assertions[i]],
105 | tf.shape(image))
106 | height = shape[0]
107 | width = shape[1]
108 |
109 | height_assert = tf.Assert(
110 | tf.equal(height, image_height),
111 | ['Wrong height for tensor %s [expected][actual]',
112 | image.name, height, image_height])
113 | width_assert = tf.Assert(
114 | tf.equal(width, image_width),
115 | ['Wrong width for tensor %s [expected][actual]',
116 | image.name, width, image_width])
117 | asserts.extend([height_assert, width_assert])
118 |
119 | # Create a random bounding box.
120 | #
121 | # Use tf.random_uniform and not numpy.random.rand as doing the former would
122 | # generate random numbers at graph eval time, unlike the latter which
123 | # generates random numbers at graph definition time.
124 | max_offset_height = control_flow_ops.with_dependencies(
125 | asserts, tf.reshape(image_height - crop_height + 1, []))
126 | max_offset_width = control_flow_ops.with_dependencies(
127 | asserts, tf.reshape(image_width - crop_width + 1, []))
128 | offset_height = tf.random_uniform(
129 | [], maxval=max_offset_height, dtype=tf.int32)
130 | offset_width = tf.random_uniform(
131 | [], maxval=max_offset_width, dtype=tf.int32)
132 |
133 | return [_crop(image, offset_height, offset_width,
134 | crop_height, crop_width) for image in image_list]
135 |
136 |
137 | def _central_crop(image_list, crop_height, crop_width):
138 | """Performs central crops of the given image list.
139 | Args:
140 | image_list: a list of image tensors of the same dimension but possibly
141 | varying channel.
142 | crop_height: the height of the image following the crop.
143 | crop_width: the width of the image following the crop.
144 | Returns:
145 | the list of cropped images.
146 | """
147 | outputs = []
148 | for image in image_list:
149 | image_height = tf.shape(image)[0]
150 | image_width = tf.shape(image)[1]
151 |
152 |     offset_height = (image_height - crop_height) // 2
153 |     offset_width = (image_width - crop_width) // 2
154 |
155 | outputs.append(_crop(image, offset_height, offset_width,
156 | crop_height, crop_width))
157 | return outputs
158 |
159 |
160 | def decode_png(image_buffer, scope=None):
161 | """Decode a JPEG string into one 3-D float image Tensor.
162 |
163 | Args:
164 | image_buffer: scalar string Tensor.
165 | scope: Optional scope for op_scope.
166 | Returns:
167 | 3-D float Tensor with values ranging from [0, 1).
168 | """
169 | # with tf.name_scope([image_buffer], scope, 'decode_png'):
170 |   # Decode the string as an RGB PNG.
171 | image = tf.image.decode_png(image_buffer, channels=3)
172 |
173 | # After this point, all image pixels reside in [0,1)
174 | # until the very end, when they're rescaled to (-1, 1). The various
175 | # adjust_* ops all require this range for dtype float.
176 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
177 | return image
178 |
179 |
180 | def distort_color(image, thread_id=0, scope=None):
181 | """Distort the color of the image.
182 |
183 | Each color distortion is non-commutative and thus ordering of the color ops
184 | matters. Ideally we would randomly permute the ordering of the color ops.
185 |   Rather than adding that level of complication, we select a distinct ordering
186 | of color ops for each preprocessing thread.
187 |
188 | Args:
189 | image: Tensor containing single image.
190 | thread_id: preprocessing thread ID.
191 |     scope: Optional name scope.
192 | Returns:
193 | color-distorted image
194 | """
195 |   with tf.name_scope(scope, 'distort_color', [image]):
196 | color_ordering = thread_id % 2
197 |
198 | if color_ordering == 0:
199 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
200 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
201 | image = tf.image.random_hue(image, max_delta=0.2)
202 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
203 | elif color_ordering == 1:
204 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
205 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
206 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
207 | image = tf.image.random_hue(image, max_delta=0.2)
208 |
209 | # The random_* ops do not necessarily clamp.
210 | image = tf.clip_by_value(image, 0.0, 1.0)
211 | return image
212 |
213 |
214 | def distort_image(image, height, width, chn_size=3, bbox=None, flipr=False, scope=None):
215 | """Distort one image for training a network.
216 |
217 | Distorting images provides a useful technique for augmenting the data
218 | set during training in order to make the network invariant to aspects
219 |   of the image that do not affect the label.
220 |
221 | Args:
222 |     image: 3-D float Tensor of image.
223 |     height: integer, target image height.
224 |     width: integer, target image width.
225 |     chn_size: integer, number of channels in the distorted image.
226 |     bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords];
227 |       each coordinate is in [0, 1), ordered as [ymin, xmin, ymax, xmax].
228 |     flipr: boolean, whether to randomly flip the image horizontally.
229 |     scope: Optional name scope.
230 | Returns:
231 | 3-D float Tensor of distorted image used for training.
232 | """
233 | if bbox is None:
234 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
235 | dtype=tf.float32,
236 | shape=[1, 1, 4])
237 |   with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
238 | # Each bounding box has shape [1, num_boxes, box coords] and
239 | # the coordinates are ordered [ymin, xmin, ymax, xmax].
240 | # A large fraction of image datasets contain a human-annotated bounding
241 | # box delineating the region of the image containing the object of interest.
242 | # We choose to create a new bounding box for the object which is a randomly
243 | # distorted version of the human-annotated bounding box that obeys an allowed
244 | # range of aspect ratios, sizes and overlap with the human-annotated
245 | # bounding box. If no box is supplied, then we assume the bounding box is
246 | # the entire image.
247 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
248 | tf.shape(image),
249 | bounding_boxes=[[[0.0, 0.0, 1.0, 1.0]]],
250 | min_object_covered=0.1,
251 | aspect_ratio_range=[0.9, 1.1],
252 | area_range=[0.8, 1.0],
253 | max_attempts=100,
254 | use_image_if_no_bounding_boxes=True)
255 | bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
256 | # if not thread_id:
257 | # image_with_distorted_box = tf.image.draw_bounding_boxes(
258 | # tf.expand_dims(image, 0), distort_bbox)
259 | # tf.image_summary('images_with_distorted_bounding_box',
260 | # image_with_distorted_box)
261 |
262 | # Crop the image to the specified bounding box.
263 | distorted_image = tf.slice(image, bbox_begin, bbox_size)
264 |
265 | # This resizing operation may distort the images because the aspect
266 |         # ratio is not respected. resize_images is applied here with its
267 |         # default bilinear method for every image.
268 | # Note that ResizeMethod contains 4 enumerated resizing methods.
269 | distorted_image = tf.image.resize_images(distorted_image, [height, width])
270 | # Restore the shape since the dynamic slice based upon the bbox_size loses
271 | # the third dimension.
272 | distorted_image.set_shape([height, width, chn_size])
273 |
274 | # Randomly flip the image horizontally.
275 | if flipr:
276 | distorted_image = tf.image.random_flip_left_right(distorted_image)
277 |
278 | # # Randomly distort the colors.
279 | # distorted_image = distort_color(distorted_image)
280 |
281 | return distorted_image
282 |
283 |
284 | def processing_image(image_buffer, thread_id=0, data_augmentation_flag=True):
285 |
286 | image = decode_png(image_buffer)
287 |
288 | if data_augmentation_flag:
289 |     image = distort_image(image, FLAGS.output_height, FLAGS.output_width)
290 | # else:
291 | # image = eval_image(image, height, width)
292 |
293 | image.set_shape([FLAGS.output_height, FLAGS.output_width, 3])
294 | image = tf.image.resize_images(image, [FLAGS.resize_height, FLAGS.resize_width])
295 |   image = tf.subtract(image, 0.5)
296 |   image = tf.multiply(image, 2.0)
297 | # image = _central_crop([image], FLAGS.crop_height, FLAGS.crop_width)[0]
298 | image.set_shape([FLAGS.resize_height, FLAGS.resize_width, 3])
299 | image = tf.to_float(image)
300 |
301 | return image
302 |
303 |
304 | def call_distort_image(image):
305 | crop_size = FLAGS.crop_size
306 | dist_chn_size = FLAGS.dist_chn_size
307 | return distort_image(image, crop_size, crop_size, dist_chn_size)
308 |
309 |
310 | def data_augmentation(raw_images):
311 | # processed_images = tf.map_fn(lambda inputs: call_distort_image(*inputs), elems=raw_images, dtype=tf.float32)
312 | processed_images = tf.map_fn(lambda inputs: call_distort_image(inputs), raw_images)
313 | return processed_images
314 |
315 |
316 | def tf_high_pass_filter(tf_images):
317 | # implementation of high pass filter according to the "Sketch-pix2seq"
318 | filter_w = tf.constant([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], tf.float32)
319 | filter_w = tf.expand_dims(tf.expand_dims(filter_w, -1), -1, name='hp_w')
320 | filtered_images = tf.nn.conv2d(tf_images, filter_w, strides=[1, 1, 1, 1], padding='SAME')
321 | return filtered_images
322 |
323 |
324 | def tf_image_processing(tf_images, basenet, crop_size, distort=False, hp_filter=False):
325 | if len(tf_images.shape) == 3:
326 | tf_images = tf.expand_dims(tf_images, -1)
327 | if basenet == 'sketchanet':
328 | mean_value = 250.42
329 | tf_images = tf.subtract(tf_images, mean_value)
330 | if distort:
331 | print("Distorting photos")
332 | FLAGS.crop_size = crop_size
333 | FLAGS.dist_chn_size = 1
334 | tf_images = data_augmentation(tf_images)
335 | else:
336 | tf_images = tf.image.resize_images(tf_images, (crop_size, crop_size))
337 | elif basenet in ['inceptionv1', 'inceptionv3', 'gen_cnn']:
338 | tf_images = tf.divide(tf_images, 255.0)
339 | tf_images = tf.subtract(tf_images, 0.5)
340 | tf_images = tf.multiply(tf_images, 2.0)
341 | if int(tf_images.shape[-1]) != 3:
342 | tf_images = tf.concat([tf_images, tf_images, tf_images], axis=-1)
343 | if distort:
344 | print("Distorting photos")
345 | FLAGS.crop_size = crop_size
346 | FLAGS.dist_chn_size = 3
347 | tf_images = data_augmentation(tf_images)
348 | # Display the training images in the visualizer.
349 | # tf.image_summary('input_images', input_images)
350 | else:
351 | tf_images = tf.image.resize_images(tf_images, (crop_size, crop_size))
352 |
353 | if hp_filter:
354 | tf_images = tf_high_pass_filter(tf_images)
355 |
356 | return tf_images
357 |
--------------------------------------------------------------------------------
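For reference, a small standalone check of the high-pass kernel used by tf_high_pass_filter() above (TF1 graph mode); the 64x64 single-channel input is made up for the demo:

import numpy as np
import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 64, 64, 1])
kernel = tf.constant([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], tf.float32)
kernel = tf.reshape(kernel, [3, 3, 1, 1])  # conv2d expects HWIO filter layout
filtered = tf.nn.conv2d(images, kernel, strides=[1, 1, 1, 1], padding='SAME')

with tf.Session() as sess:
    out = sess.run(filtered, {images: np.ones([1, 64, 64, 1], np.float32)})
    # the kernel sums to zero, so flat image regions produce zero response
    print(out[0, 32, 32, 0])  # 0.0 in the interior of a constant image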
/svg_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cPickle
3 | import numpy as np
4 | import xml.etree.ElementTree as ET
5 | import random
6 | import svgwrite
7 | from IPython.display import SVG, display
8 | from svg.path import Path, Line, Arc, CubicBezier, QuadraticBezier, parse_path
9 |
10 |
11 | def calculate_start_point(data, factor=1.0, block_size=200):
12 | # will try to center the sketch to the middle of the block
13 | # determines maxx, minx, maxy, miny
14 | sx = 0
15 | sy = 0
16 | maxx = 0
17 | minx = 0
18 | maxy = 0
19 | miny = 0
20 | for i in xrange(len(data)):
21 | sx += round(float(data[i, 0]) * factor, 3)
22 | sy += round(float(data[i, 1]) * factor, 3)
23 | maxx = max(maxx, sx)
24 | minx = min(minx, sx)
25 | maxy = max(maxy, sy)
26 | miny = min(miny, sy)
27 |
28 | abs_x = block_size / 2 - (maxx - minx) / 2 - minx
29 | abs_y = block_size / 2 - (maxy - miny) / 2 - miny
30 |
31 | return abs_x, abs_y, (maxx - minx), (maxy - miny)
32 |
33 |
34 | def draw_stroke_color_array(data, factor=1, svg_filename='sample.svg', stroke_width=1, block_size=200, maxcol=5,
35 | svg_only=False, color_mode=True):
36 | num_char = len(data)
37 |
38 | if num_char < 1:
39 | return
40 |
41 | max_color_intensity = 225
42 |
43 | numrow = np.ceil(float(num_char) / float(maxcol))
44 | dwg = svgwrite.Drawing(svg_filename, size=(block_size * (min(num_char, maxcol)), block_size * numrow))
45 | dwg.add(dwg.rect(insert=(0, 0), size=(block_size * (min(num_char, maxcol)), block_size * numrow), fill='white'))
46 |
47 |     the_color = "rgb(%d,%d,%d)" % (random.randint(0, max_color_intensity),
48 |         random.randint(0, max_color_intensity), random.randint(0, max_color_intensity))
49 |
50 | for j in xrange(len(data)):
51 |
52 | lift_pen = 0
53 | # end_of_char = 0
54 | cdata = data[j]
55 | abs_x, abs_y, size_x, size_y = calculate_start_point(cdata, factor, block_size)
56 | abs_x += (j % maxcol) * block_size
57 | abs_y += (j / maxcol) * block_size
58 |
59 | for i in xrange(len(cdata)):
60 |
61 | x = round(float(cdata[i, 0]) * factor, 3)
62 | y = round(float(cdata[i, 1]) * factor, 3)
63 |
64 | prev_x = round(abs_x, 3)
65 | prev_y = round(abs_y, 3)
66 |
67 | abs_x += x
68 | abs_y += y
69 |
70 | if (lift_pen == 1):
71 | p = "M " + str(abs_x) + "," + str(abs_y) + " "
72 |                 the_color = "rgb(%d,%d,%d)" % (random.randint(0, max_color_intensity),
73 |                                                random.randint(0, max_color_intensity),
74 |                                                random.randint(0, max_color_intensity))
75 | else:
76 | p = "M " + str(prev_x) + "," + str(prev_y) + " L " + str(abs_x) + "," + str(abs_y) + " "
77 |
78 | lift_pen = max(cdata[i, 2], cdata[i, 3]) # lift pen if both eos or eoc
79 | # end_of_char = cdata[i, 3] # not used for now.
80 |
81 |             if not color_mode:
82 | the_color = "#000"
83 |
84 | dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill(
85 | the_color)) # , opacity=round(random.random()*0.5+0.5, 3)
86 |
87 | dwg.save()
88 |     if not svg_only:
89 | display(SVG(dwg.tostring()))
90 |
91 |
92 | def draw_stroke_color(data, factor=1, svg_filename='sample.svg', stroke_width=1, block_size=200, maxcol=5,
93 | svg_only=False, color_mode=True):
94 | def split_sketch(data):
95 | # split a sketch with many eoc into an array of sketches, each with just one eoc at the end.
96 | # ignores last stub with no eoc.
97 | counter = 0
98 | result = []
99 | for i in xrange(len(data)):
100 | eoc = data[i, 3]
101 | if eoc > 0:
102 | result.append(data[counter:i + 1])
103 | counter = i + 1
104 | # if (counter < len(data)): # ignore the rest
105 | # result.append(data[counter:])
106 | return result
107 |
108 | data = np.array(data, dtype=np.float32)
109 | data = split_sketch(data)
110 | draw_stroke_color_array(data, factor, svg_filename, stroke_width, block_size, maxcol, svg_only, color_mode)
111 |
112 |
113 | def cubicbezier(x0, y0, x1, y1, x2, y2, x3, y3, n=20):
114 | # from http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic
115 | pts = []
116 | for i in range(n + 1):
117 | t = float(i) / float(n)
118 | a = (1. - t) ** 3
119 | b = 3. * t * (1. - t) ** 2
120 | c = 3.0 * t ** 2 * (1.0 - t)
121 | d = t ** 3
122 |
123 | x = float(a * x0 + b * x1 + c * x2 + d * x3)
124 | y = float(a * y0 + b * y1 + c * y2 + d * y3)
125 | pts.append((x, y))
126 | return pts
127 |
128 |
129 | def get_path_strings(svgfile):
130 | tree = ET.parse(svgfile)
131 | p = []
132 | for elem in tree.iter():
133 |         if 'd' in elem.attrib:
134 | p.append(elem.attrib['d'])
135 | return p
136 |
137 |
138 | def build_lines(svgfile, line_length_threshold=10.0, min_points_per_path=1, max_points_per_path=3):
139 | # we don't draw lines less than line_length_threshold
140 | path_strings = get_path_strings(svgfile)
141 |
142 | lines = []
143 |
144 | for path_string in path_strings:
145 | try:
146 | full_path = parse_path(path_string)
147 |         except Exception as e:
148 |             # skip malformed path strings
149 |             print "failed to parse path:", e
150 |             continue
151 | for i in range(len(full_path)):
152 | p = full_path[i]
153 | if type(p) != Line and type(p) != CubicBezier:
154 | print "encountered an element that is not just a line or bezier "
155 | print "type: ", type(p)
156 | print p
157 | else:
158 | x_start = p.start.real
159 | y_start = p.start.imag
160 | x_end = p.end.real
161 | y_end = p.end.imag
162 | line_length = np.sqrt(
163 | (x_end - x_start) * (x_end - x_start) + (y_end - y_start) * (y_end - y_start))
164 | # len_data.append(line_length)
165 | points = []
166 | if type(p) == CubicBezier:
167 | x_con1 = p.control1.real
168 | y_con1 = p.control1.imag
169 | x_con2 = p.control2.real
170 | y_con2 = p.control2.imag
171 | n_points = int(line_length / line_length_threshold) + 1
172 | n_points = max(n_points, min_points_per_path)
173 | n_points = min(n_points, max_points_per_path)
174 | points = cubicbezier(x_start, y_start, x_con1, y_con1, x_con2, y_con2, x_end, y_end,
175 | n_points)
176 | else:
177 | points = [(x_start, y_start), (x_end, y_end)]
178 | if i == 0: # only append the starting point for svg
179 | lines.append([points[0][0], points[0][1], 0, 0]) # put eoc to be zero
180 | for j in range(1, len(points)):
181 | eos = 0
182 | if j == len(points) - 1 and i == len(full_path) - 1:
183 | eos = 1
184 | lines.append([points[j][0], points[j][1], eos, 0]) # put eoc to be zero
185 | lines = np.array(lines, dtype=np.float32)
186 | # make it relative moves
187 | lines[1:, 0:2] -= lines[0:-1, 0:2]
188 | lines[-1, 3] = 1 # end of character
189 | lines[0] = [0, 0, 0, 0] # start at origin
190 | return lines[1:]
191 |
192 |
193 | class SketchLoader():
194 | def __init__(self, batch_size=50, seq_length=300, scale_factor=1.0, data_filename="kanji"):
195 |         # Loads the preprocessed sketch data file and serves randomized
196 |         # batches of stroke sequences.
197 | self.data_dir = "./data"
198 | self.batch_size = batch_size
199 | self.seq_length = seq_length
200 | self.scale_factor = scale_factor # divide data by this factor
201 |
202 | data_file = os.path.join(self.data_dir, data_filename + ".cpkl")
203 | raw_data_dir = os.path.join(self.data_dir, data_filename)
204 |
205 | if not (os.path.exists(data_file)):
206 |             raise Exception('Data file does not exist: %s' % data_file)
207 | # self.length_data = self.preprocess(raw_data_dir, data_file)
208 |
209 | self.load_preprocessed(data_file)
210 | self.num_samples = len(self.raw_data)
211 | self.index = range(self.num_samples) # this list will be randomized later.
212 | self.reset_index_pointer()
213 |
214 | def preprocess(self, data_dir, data_file):
215 | # create data file from raw xml files from iam handwriting source.
216 | len_data = []
217 |
218 | def cubicbezier(x0, y0, x1, y1, x2, y2, x3, y3, n=20):
219 | # from http://rosettacode.org/wiki/Bitmap/B%C3%A9zier_curves/Cubic
220 | pts = []
221 | for i in range(n + 1):
222 | t = float(i) / float(n)
223 | a = (1. - t) ** 3
224 | b = 3. * t * (1. - t) ** 2
225 | c = 3.0 * t ** 2 * (1.0 - t)
226 | d = t ** 3
227 |
228 | x = float(a * x0 + b * x1 + c * x2 + d * x3)
229 | y = float(a * y0 + b * y1 + c * y2 + d * y3)
230 | pts.append((x, y))
231 | return pts
232 |
233 | def get_path_strings(svgfile):
234 | tree = ET.parse(svgfile)
235 | p = []
236 | for elem in tree.iter():
237 |                 if 'd' in elem.attrib:
238 | p.append(elem.attrib['d'])
239 | return p
240 |
241 | def build_lines(svgfile, line_length_threshold=10.0, min_points_per_path=1, max_points_per_path=3):
242 | # we don't draw lines less than line_length_threshold
243 | path_strings = get_path_strings(svgfile)
244 |
245 | lines = []
246 |
247 | for path_string in path_strings:
248 | full_path = parse_path(path_string)
249 | for i in range(len(full_path)):
250 | p = full_path[i]
251 | if type(p) != Line and type(p) != CubicBezier:
252 | print "encountered an element that is not just a line or bezier "
253 | print "type: ", type(p)
254 | print p
255 | else:
256 | x_start = p.start.real
257 | y_start = p.start.imag
258 | x_end = p.end.real
259 | y_end = p.end.imag
260 | line_length = np.sqrt(
261 | (x_end - x_start) * (x_end - x_start) + (y_end - y_start) * (y_end - y_start))
262 | len_data.append(line_length)
263 | points = []
264 | if type(p) == CubicBezier:
265 | x_con1 = p.control1.real
266 | y_con1 = p.control1.imag
267 | x_con2 = p.control2.real
268 | y_con2 = p.control2.imag
269 | n_points = int(line_length / line_length_threshold) + 1
270 | n_points = max(n_points, min_points_per_path)
271 | n_points = min(n_points, max_points_per_path)
272 | points = cubicbezier(x_start, y_start, x_con1, y_con1, x_con2, y_con2, x_end, y_end,
273 | n_points)
274 | else:
275 | points = [(x_start, y_start), (x_end, y_end)]
276 | if i == 0: # only append the starting point for svg
277 | lines.append([points[0][0], points[0][1], 0, 0]) # put eoc to be zero
278 | for j in range(1, len(points)):
279 | eos = 0
280 | if j == len(points) - 1 and i == len(full_path) - 1:
281 | eos = 1
282 | lines.append([points[j][0], points[j][1], eos, 0]) # put eoc to be zero
283 | lines = np.array(lines, dtype=np.float32)
284 | # make it relative moves
285 | lines[1:, 0:2] -= lines[0:-1, 0:2]
286 | lines[-1, 3] = 1 # end of character
287 | lines[0] = [0, 0, 0, 0] # start at origin
288 | return lines[1:]
289 |
290 | # build the list of xml files
291 | filelist = []
292 | # Set the directory you want to start from
293 | rootDir = data_dir
294 | for dirName, subdirList, fileList in os.walk(rootDir):
295 | # print('Found directory: %s' % dirName)
296 | for fname in fileList:
297 | # print('\t%s' % fname)
298 | filelist.append(dirName + "/" + fname)
299 |
300 | # build stroke database of every xml file inside iam database
301 | sketch = []
302 | for i in range(len(filelist)):
303 | if (filelist[i][-3:] == 'svg'):
304 | print 'processing ' + filelist[i]
305 | sketch.append(build_lines(filelist[i]))
306 |
307 | f = open(data_file, "wb")
308 | cPickle.dump(sketch, f, protocol=2)
309 | f.close()
310 |         # return the collected per-segment lengths so callers can inspect
311 |         # the distribution and tune line_length_threshold
312 | return len_data
313 |
314 | def load_preprocessed(self, data_file):
315 | f = open(data_file, "rb")
316 | self.raw_data = cPickle.load(f)
317 | # scale the data here, rather than at the data construction (since scaling may change)
318 | for data in self.raw_data:
319 | data[:, 0:2] /= self.scale_factor
320 | f.close()
321 |
322 | def next_batch(self):
323 | # returns a set of batches, but the constraint is that the start of each input data batch
324 | # is the start of a new character (although the end of a batch doesn't have to be end of a character)
325 |
326 | def next_seq(n):
327 | result = np.zeros((n, 5), dtype=np.float32) # x, y, [eos, eoc, cont] tokens
328 | # result[0, 2:4] = 1 # set eos and eoc to true for first point
329 | # experimental line below, put a random factor between 70-130% to generate more examples
330 | rand_scale_factor_x = np.random.rand() * 0.6 + 0.7
331 | rand_scale_factor_y = np.random.rand() * 0.6 + 0.7
332 | idx = 0
333 | data = self.current_data()
334 | for i in xrange(n):
335 | result[i, 0:4] = data[idx] # eoc = 0.0
336 | result[i, 4] = 1 # continue on stroke
337 | if (result[i, 2] > 0 or result[i, 3] > 0):
338 | result[i, 4] = 0
339 | idx += 1
340 | if (idx >= len(data) - 1): # skip to next sketch example next time and mark eoc
341 | result[i, 4] = 0
342 | result[i, 3] = 1
343 | result[i, 2] = 0 # overrides end of stroke one-hot
344 | idx = 0
345 | self.tick_index_pointer()
346 | data = self.current_data()
347 | assert (result[i, 2:5].sum() == 1)
348 | self.tick_index_pointer() # needed if seq_length is less than last data.
349 | result[:, 0] *= rand_scale_factor_x
350 | result[:, 1] *= rand_scale_factor_y
351 | return result
352 |
353 | skip_length = self.seq_length + 1
354 |
355 | batch = []
356 |
357 | for i in xrange(self.batch_size):
358 | seq = next_seq(skip_length)
359 | batch.append(seq)
360 |
361 | batch = np.array(batch, dtype=np.float32)
362 |
363 | return batch[:, 0:-1], batch[:, 1:]
364 |
365 | def current_data(self):
366 | return self.raw_data[self.index[self.pointer]]
367 |
368 | def tick_index_pointer(self):
369 | self.pointer += 1
370 | if (self.pointer >= len(self.raw_data)):
371 | self.pointer = 0
372 | self.epoch_finished = True
373 |
374 | def reset_index_pointer(self):
375 | # randomize order for the raw list in the next go.
376 | self.pointer = 0
377 | self.epoch_finished = False
378 | self.index = np.random.permutation(self.index)
379 |
380 |
381 |
--------------------------------------------------------------------------------
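A quick sanity check of the cubicbezier() sampler defined above, written in the file's own Python 2 style (assumes svg_utils.py and its dependencies are importable); the control points are arbitrary:

from svg_utils import cubicbezier

pts = cubicbezier(0, 0, 0, 10, 10, 10, 10, 0, n=4)
print len(pts)         # 5: one sample at each of t = 0, 0.25, 0.5, 0.75, 1.0
print pts[0], pts[-1]  # (0.0, 0.0) (10.0, 0.0): the curve endpoints are exact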
/svg/path/path.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from math import sqrt, cos, sin, acos, degrees, radians, log
3 | from collections import MutableSequence
4 |
5 |
6 | # This file contains classes for the different types of SVG path segments as
7 | # well as a Path object that contains a sequence of path segments.
8 |
9 | MIN_DEPTH = 5
10 | ERROR = 1e-12
11 |
12 |
13 | def segment_length(curve, start, end, start_point, end_point, error, min_depth, depth):
14 | """Recursively approximates the length by straight lines"""
15 | mid = (start + end) / 2
16 | mid_point = curve.point(mid)
17 | length = abs(end_point - start_point)
18 | first_half = abs(mid_point - start_point)
19 | second_half = abs(end_point - mid_point)
20 |
21 | length2 = first_half + second_half
22 | if (length2 - length > error) or (depth < min_depth):
23 | # Calculate the length of each segment:
24 | depth += 1
25 | return (segment_length(curve, start, mid, start_point, mid_point,
26 | error, min_depth, depth) +
27 | segment_length(curve, mid, end, mid_point, end_point,
28 | error, min_depth, depth))
29 | # This is accurate enough.
30 | return length2
31 |
32 |
33 | class Line(object):
34 |
35 | def __init__(self, start, end):
36 | self.start = start
37 | self.end = end
38 |
39 | def __repr__(self):
40 | return 'Line(start=%s, end=%s)' % (self.start, self.end)
41 |
42 | def __eq__(self, other):
43 | if not isinstance(other, Line):
44 | return NotImplemented
45 | return self.start == other.start and self.end == other.end
46 |
47 | def __ne__(self, other):
48 | if not isinstance(other, Line):
49 | return NotImplemented
50 | return not self == other
51 |
52 | def point(self, pos):
53 | distance = self.end - self.start
54 | return self.start + distance * pos
55 |
56 | def length(self, error=None, min_depth=None):
57 | distance = (self.end - self.start)
58 | return sqrt(distance.real ** 2 + distance.imag ** 2)
59 |
60 |
61 | class CubicBezier(object):
62 | def __init__(self, start, control1, control2, end):
63 | self.start = start
64 | self.control1 = control1
65 | self.control2 = control2
66 | self.end = end
67 |
68 | def __repr__(self):
69 | return 'CubicBezier(start=%s, control1=%s, control2=%s, end=%s)' % (
70 | self.start, self.control1, self.control2, self.end)
71 |
72 | def __eq__(self, other):
73 | if not isinstance(other, CubicBezier):
74 | return NotImplemented
75 | return self.start == other.start and self.end == other.end and \
76 | self.control1 == other.control1 and self.control2 == other.control2
77 |
78 | def __ne__(self, other):
79 | if not isinstance(other, CubicBezier):
80 | return NotImplemented
81 | return not self == other
82 |
83 | def is_smooth_from(self, previous):
84 | """Checks if this segment would be a smooth segment following the previous"""
85 | if isinstance(previous, CubicBezier):
86 | return (self.start == previous.end and
87 | (self.control1 - self.start) == (previous.end - previous.control2))
88 | else:
89 | return self.control1 == self.start
90 |
91 | def point(self, pos):
92 | """Calculate the x,y position at a certain position of the path"""
93 | return ((1 - pos) ** 3 * self.start) + \
94 | (3 * (1 - pos) ** 2 * pos * self.control1) + \
95 | (3 * (1 - pos) * pos ** 2 * self.control2) + \
96 | (pos ** 3 * self.end)
97 |
98 | def length(self, error=ERROR, min_depth=MIN_DEPTH):
99 | """Calculate the length of the path up to a certain position"""
100 | start_point = self.point(0)
101 | end_point = self.point(1)
102 | return segment_length(self, 0, 1, start_point, end_point, error, min_depth, 0)
103 |
104 |
105 | class QuadraticBezier(object):
106 | def __init__(self, start, control, end):
107 | self.start = start
108 | self.end = end
109 | self.control = control
110 |
111 | def __repr__(self):
112 | return 'QuadraticBezier(start=%s, control=%s, end=%s)' % (
113 | self.start, self.control, self.end)
114 |
115 | def __eq__(self, other):
116 | if not isinstance(other, QuadraticBezier):
117 | return NotImplemented
118 | return self.start == other.start and self.end == other.end and \
119 | self.control == other.control
120 |
121 | def __ne__(self, other):
122 | if not isinstance(other, QuadraticBezier):
123 | return NotImplemented
124 | return not self == other
125 |
126 | def is_smooth_from(self, previous):
127 | """Checks if this segment would be a smooth segment following the previous"""
128 | if isinstance(previous, QuadraticBezier):
129 | return (self.start == previous.end and
130 | (self.control - self.start) == (previous.end - previous.control))
131 | else:
132 | return self.control == self.start
133 |
134 | def point(self, pos):
135 | return (1 - pos) ** 2 * self.start + 2 * (1 - pos) * pos * self.control + \
136 | pos ** 2 * self.end
137 |
138 | def length(self, error=None, min_depth=None):
139 | a = self.start - 2*self.control + self.end
140 | b = 2*(self.control - self.start)
141 | a_dot_b = a.real*b.real + a.imag*b.imag
142 |
143 | if abs(a) < 1e-12:
144 | s = abs(b)
145 | elif abs(a_dot_b + abs(a)*abs(b)) < 1e-12:
146 | k = abs(b)/abs(a)
147 | if k >= 2:
148 | s = abs(b) - abs(a)
149 | else:
150 | s = abs(a)*(k**2/2 - k + 1)
151 | else:
152 | # For an explanation of this case, see
153 | # http://www.malczak.info/blog/quadratic-bezier-curve-length/
154 | A = 4 * (a.real ** 2 + a.imag ** 2)
155 | B = 4 * (a.real * b.real + a.imag * b.imag)
156 | C = b.real ** 2 + b.imag ** 2
157 |
158 | Sabc = 2 * sqrt(A + B + C)
159 | A2 = sqrt(A)
160 | A32 = 2 * A * A2
161 | C2 = 2 * sqrt(C)
162 | BA = B / A2
163 |
164 | s = (A32 * Sabc + A2 * B * (Sabc - C2) + (4 * C * A - B ** 2) *
165 | log((2 * A2 + BA + Sabc) / (BA + C2))) / (4 * A32)
166 | return s
167 |
168 | class Arc(object):
169 |
170 | def __init__(self, start, radius, rotation, arc, sweep, end):
171 | """radius is complex, rotation is in degrees,
172 | large and sweep are 1 or 0 (True/False also work)"""
173 |
174 | self.start = start
175 | self.radius = radius
176 | self.rotation = rotation
177 | self.arc = bool(arc)
178 | self.sweep = bool(sweep)
179 | self.end = end
180 |
181 | self._parameterize()
182 |
183 | def __repr__(self):
184 | return 'Arc(start=%s, radius=%s, rotation=%s, arc=%s, sweep=%s, end=%s)' % (
185 | self.start, self.radius, self.rotation, self.arc, self.sweep, self.end)
186 |
187 | def __eq__(self, other):
188 | if not isinstance(other, Arc):
189 | return NotImplemented
190 | return self.start == other.start and self.end == other.end and \
191 | self.radius == other.radius and self.rotation == other.rotation and \
192 | self.arc == other.arc and self.sweep == other.sweep
193 |
194 | def __ne__(self, other):
195 | if not isinstance(other, Arc):
196 | return NotImplemented
197 | return not self == other
198 |
199 | def _parameterize(self):
200 | # Conversion from endpoint to center parameterization
201 | # http://www.w3.org/TR/SVG/implnote.html#ArcImplementationNotes
202 |
203 | cosr = cos(radians(self.rotation))
204 | sinr = sin(radians(self.rotation))
205 | dx = (self.start.real - self.end.real) / 2
206 | dy = (self.start.imag - self.end.imag) / 2
207 | x1prim = cosr * dx + sinr * dy
208 | x1prim_sq = x1prim * x1prim
209 | y1prim = -sinr * dx + cosr * dy
210 | y1prim_sq = y1prim * y1prim
211 |
212 | rx = self.radius.real
213 | rx_sq = rx * rx
214 | ry = self.radius.imag
215 | ry_sq = ry * ry
216 |
217 | # Correct out of range radii
218 | radius_check = (x1prim_sq / rx_sq) + (y1prim_sq / ry_sq)
219 | if radius_check > 1:
220 | rx *= sqrt(radius_check)
221 | ry *= sqrt(radius_check)
222 | rx_sq = rx * rx
223 | ry_sq = ry * ry
224 |
225 | t1 = rx_sq * y1prim_sq
226 | t2 = ry_sq * x1prim_sq
227 | c = sqrt(abs((rx_sq * ry_sq - t1 - t2) / (t1 + t2)))
228 |
229 | if self.arc == self.sweep:
230 | c = -c
231 | cxprim = c * rx * y1prim / ry
232 | cyprim = -c * ry * x1prim / rx
233 |
234 | self.center = complex((cosr * cxprim - sinr * cyprim) +
235 | ((self.start.real + self.end.real) / 2),
236 | (sinr * cxprim + cosr * cyprim) +
237 | ((self.start.imag + self.end.imag) / 2))
238 |
239 | ux = (x1prim - cxprim) / rx
240 | uy = (y1prim - cyprim) / ry
241 | vx = (-x1prim - cxprim) / rx
242 | vy = (-y1prim - cyprim) / ry
243 | n = sqrt(ux * ux + uy * uy)
244 | p = ux
245 | theta = degrees(acos(p / n))
246 | if uy < 0:
247 | theta = -theta
248 | self.theta = theta % 360
249 |
250 | n = sqrt((ux * ux + uy * uy) * (vx * vx + vy * vy))
251 | p = ux * vx + uy * vy
252 | if p == 0:
253 | delta = degrees(acos(0))
254 | else:
255 | delta = degrees(acos(p / n))
256 | if (ux * vy - uy * vx) < 0:
257 | delta = -delta
258 | self.delta = delta % 360
259 | if not self.sweep:
260 | self.delta -= 360
261 |
262 | def point(self, pos):
263 | angle = radians(self.theta + (self.delta * pos))
264 | cosr = cos(radians(self.rotation))
265 | sinr = sin(radians(self.rotation))
266 |
267 | x = (cosr * cos(angle) * self.radius.real - sinr * sin(angle) *
268 | self.radius.imag + self.center.real)
269 | y = (sinr * cos(angle) * self.radius.real + cosr * sin(angle) *
270 | self.radius.imag + self.center.imag)
271 | return complex(x, y)
272 |
273 | def length(self, error=ERROR, min_depth=MIN_DEPTH):
274 | """The length of an elliptical arc segment requires numerical
275 | integration, and in that case it's simpler to just do a geometric
276 | approximation, as for cubic bezier curves.
277 | """
278 | start_point = self.point(0)
279 | end_point = self.point(1)
280 | return segment_length(self, 0, 1, start_point, end_point, error, min_depth, 0)
281 |
282 |
283 | class Path(MutableSequence):
284 | """A Path is a sequence of path segments"""
285 |
286 | # Put it here, so there is a default if unpickled.
287 | _closed = False
288 |
289 | def __init__(self, *segments, **kw):
290 | self._segments = list(segments)
291 | self._length = None
292 | self._lengths = None
293 | if 'closed' in kw:
294 | self.closed = kw['closed']
295 |
296 | def __getitem__(self, index):
297 | return self._segments[index]
298 |
299 | def __setitem__(self, index, value):
300 | self._segments[index] = value
301 | self._length = None
302 |
303 | def __delitem__(self, index):
304 | del self._segments[index]
305 | self._length = None
306 |
307 | def insert(self, index, value):
308 | self._segments.insert(index, value)
309 | self._length = None
310 |
311 | def reverse(self):
312 | # Reversing the order of a path would require reversing each element
313 | # as well. That's not implemented.
314 | raise NotImplementedError
315 |
316 | def __len__(self):
317 | return len(self._segments)
318 |
319 | def __repr__(self):
320 | return 'Path(%s, closed=%s)' % (
321 | ', '.join(repr(x) for x in self._segments), self.closed)
322 |
323 | def __eq__(self, other):
324 | if not isinstance(other, Path):
325 | return NotImplemented
326 | if len(self) != len(other):
327 | return False
328 | for s, o in zip(self._segments, other._segments):
329 | if not s == o:
330 | return False
331 | return True
332 |
333 | def __ne__(self, other):
334 | if not isinstance(other, Path):
335 | return NotImplemented
336 | return not self == other
337 |
338 | def _calc_lengths(self, error=ERROR, min_depth=MIN_DEPTH):
339 | if self._length is not None:
340 | return
341 |
342 | lengths = [each.length(error=error, min_depth=min_depth) for each in self._segments]
343 | self._length = sum(lengths)
344 | self._lengths = [each / self._length for each in lengths]
345 |
346 | def point(self, pos, error=ERROR):
347 |
348 | # Shortcuts
349 | if pos == 0.0:
350 | return self._segments[0].point(pos)
351 | if pos == 1.0:
352 | return self._segments[-1].point(pos)
353 |
354 | self._calc_lengths(error=error)
355 | # Find which segment the point we search for is located on:
356 | segment_start = 0
357 | for index, segment in enumerate(self._segments):
358 | segment_end = segment_start + self._lengths[index]
359 | if segment_end >= pos:
360 | # This is the segment! How far in on the segment is the point?
361 | segment_pos = (pos - segment_start) / (segment_end - segment_start)
362 | break
363 | segment_start = segment_end
364 |
365 | return segment.point(segment_pos)
366 |
367 | def length(self, error=ERROR, min_depth=MIN_DEPTH):
368 | self._calc_lengths(error, min_depth)
369 | return self._length
370 |
371 | def _is_closable(self):
372 | """Returns true if the end is on the start of a segment"""
373 | end = self[-1].end
374 | for segment in self:
375 | if segment.start == end:
376 | return True
377 | return False
378 |
379 | @property
380 | def closed(self):
381 | """Checks that the path is closed"""
382 | return self._closed and self._is_closable()
383 |
384 | @closed.setter
385 | def closed(self, value):
386 | value = bool(value)
387 | if value and not self._is_closable():
388 | raise ValueError("End does not coincide with a segment start.")
389 | self._closed = value
390 |
391 | def d(self):
392 | if self.closed:
393 | segments = self[:-1]
394 | else:
395 | segments = self[:]
396 |
397 | current_pos = None
398 | parts = []
399 | previous_segment = None
400 | end = self[-1].end
401 |
402 | for segment in segments:
403 | start = segment.start
404 | # If the start of this segment does not coincide with the end of
405 | # the last segment or if this segment is actually the close point
406 | # of a closed path, then we should start a new subpath here.
407 | if current_pos != start or (self.closed and start == end):
408 | parts.append('M {0:G},{1:G}'.format(start.real, start.imag))
409 |
410 | if isinstance(segment, Line):
411 | parts.append('L {0:G},{1:G}'.format(
412 | segment.end.real, segment.end.imag)
413 | )
414 | elif isinstance(segment, CubicBezier):
415 | if segment.is_smooth_from(previous_segment):
416 | parts.append('S {0:G},{1:G} {2:G},{3:G}'.format(
417 | segment.control2.real, segment.control2.imag,
418 | segment.end.real, segment.end.imag)
419 | )
420 | else:
421 | parts.append('C {0:G},{1:G} {2:G},{3:G} {4:G},{5:G}'.format(
422 | segment.control1.real, segment.control1.imag,
423 | segment.control2.real, segment.control2.imag,
424 | segment.end.real, segment.end.imag)
425 | )
426 | elif isinstance(segment, QuadraticBezier):
427 | if segment.is_smooth_from(previous_segment):
428 | parts.append('T {0:G},{1:G}'.format(
429 | segment.end.real, segment.end.imag)
430 | )
431 | else:
432 | parts.append('Q {0:G},{1:G} {2:G},{3:G}'.format(
433 | segment.control.real, segment.control.imag,
434 | segment.end.real, segment.end.imag)
435 | )
436 |
437 | elif isinstance(segment, Arc):
438 | parts.append('A {0:G},{1:G} {2:G} {3:d},{4:d} {5:G},{6:G}'.format(
439 | segment.radius.real, segment.radius.imag, segment.rotation,
440 | int(segment.arc), int(segment.sweep),
441 | segment.end.real, segment.end.imag)
442 | )
443 | current_pos = segment.end
444 | previous_segment = segment
445 |
446 | if self.closed:
447 | parts.append('Z')
448 |
449 | return ' '.join(parts)
450 |
--------------------------------------------------------------------------------
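A short illustration of the Path API defined above: segments encode 2-D points as complex numbers, length() uses the adaptive bisection in segment_length(), and d() serializes the sequence back to an SVG path string. The coordinates are arbitrary:

from svg.path import Path, Line, CubicBezier

path = Path(Line(0 + 0j, 10 + 0j),
            CubicBezier(10 + 0j, 15 + 5j, 20 + 5j, 25 + 0j))
print(path.point(0.0))  # 0j, the start of the first segment
print(path.length())    # 10.0 plus the numerically estimated bezier length
print(path.d())         # M 0,0 L 10,0 C 15,5 20,5 25,0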
/magenta_rnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """SketchRNN RNN definition."""
15 |
16 | # internal imports
17 | import numpy as np
18 | import tensorflow as tf
19 |
20 |
21 | def orthogonal(shape):
22 | """Orthogonal initilaizer."""
23 | flat_shape = (shape[0], np.prod(shape[1:]))
24 | a = np.random.normal(0.0, 1.0, flat_shape)
25 | u, _, v = np.linalg.svd(a, full_matrices=False)
26 | q = u if u.shape == flat_shape else v
27 | return q.reshape(shape)
28 |
29 |
30 | def orthogonal_initializer(scale=1.0):
31 | """Orthogonal initializer."""
32 |
33 | def _initializer(shape, dtype=tf.float32,
34 | partition_info=None): # pylint: disable=unused-argument
35 | return tf.constant(orthogonal(shape) * scale, dtype)
36 |
37 | return _initializer
38 |
39 |
40 | def lstm_ortho_initializer(scale=1.0):
41 | """LSTM orthogonal initializer."""
42 |
43 | def _initializer(shape, dtype=tf.float32,
44 | partition_info=None): # pylint: disable=unused-argument
45 | size_x = shape[0]
46 |     size_h = shape[1] // 4  # assumes the four LSTM gates.
47 | t = np.zeros(shape)
48 | t[:, :size_h] = orthogonal([size_x, size_h]) * scale
49 | t[:, size_h:size_h * 2] = orthogonal([size_x, size_h]) * scale
50 | t[:, size_h * 2:size_h * 3] = orthogonal([size_x, size_h]) * scale
51 | t[:, size_h * 3:] = orthogonal([size_x, size_h]) * scale
52 | return tf.constant(t, dtype)
53 |
54 | return _initializer
55 |
56 |
57 | class LSTMCell(tf.contrib.rnn.RNNCell):
58 | """Vanilla LSTM cell.
59 |
60 | Uses ortho initializer, and also recurrent dropout without memory loss
61 | (https://arxiv.org/abs/1603.05118)
62 | """
63 |
64 | def __init__(self,
65 | num_units,
66 | forget_bias=1.0,
67 | use_recurrent_dropout=False,
68 | dropout_keep_prob=0.9):
69 | self.num_units = num_units
70 | self.forget_bias = forget_bias
71 | self.use_recurrent_dropout = use_recurrent_dropout
72 | self.dropout_keep_prob = dropout_keep_prob
73 |
74 | @property
75 | def state_size(self):
76 | return 2 * self.num_units
77 |
78 | @property
79 | def output_size(self):
80 | return self.num_units
81 |
82 | def get_output(self, state):
83 | unused_c, h = tf.split(state, 2, 1)
84 | return h
85 |
86 | def __call__(self, x, state, scope=None):
87 | with tf.variable_scope(scope or type(self).__name__):
88 | c, h = tf.split(state, 2, 1)
89 |
90 | x_size = x.get_shape().as_list()[1]
91 |
92 | w_init = None # uniform
93 |
94 | h_init = lstm_ortho_initializer(1.0)
95 |
96 | # Keep W_xh and W_hh separate here as well to use different init methods.
97 | w_xh = tf.get_variable(
98 | 'W_xh', [x_size, 4 * self.num_units], initializer=w_init)
99 | w_hh = tf.get_variable(
100 | 'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init)
101 | bias = tf.get_variable(
102 | 'bias', [4 * self.num_units],
103 | initializer=tf.constant_initializer(0.0))
104 |
105 | concat = tf.concat([x, h], 1)
106 | w_full = tf.concat([w_xh, w_hh], 0)
107 | hidden = tf.matmul(concat, w_full) + bias
108 |
109 | i, j, f, o = tf.split(hidden, 4, 1)
110 |
111 | if self.use_recurrent_dropout:
112 | g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
113 | else:
114 | g = tf.tanh(j)
115 |
116 | new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
117 | new_h = tf.tanh(new_c) * tf.sigmoid(o)
118 |
119 |       return new_h, tf.concat([new_c, new_h], 1)  # state as a single tensor, not a tuple.
120 |
121 |
122 | def layer_norm_all(h,
123 | batch_size,
124 | base,
125 | num_units,
126 | scope='layer_norm',
127 | reuse=False,
128 | gamma_start=1.0,
129 | epsilon=1e-3,
130 | use_bias=True):
131 | """Layer Norm (faster version, but not using defun)."""
132 | # Performs layer norm on multiple base at once (ie, i, g, j, o for lstm)
133 | # Reshapes h in to perform layer norm in parallel
134 | h_reshape = tf.reshape(h, [batch_size, base, num_units])
135 | mean = tf.reduce_mean(h_reshape, [2], keep_dims=True)
136 | var = tf.reduce_mean(tf.square(h_reshape - mean), [2], keep_dims=True)
137 | epsilon = tf.constant(epsilon)
138 | rstd = tf.rsqrt(var + epsilon)
139 | h_reshape = (h_reshape - mean) * rstd
140 | # reshape back to original
141 | h = tf.reshape(h_reshape, [batch_size, base * num_units])
142 | with tf.variable_scope(scope):
143 | if reuse:
144 | tf.get_variable_scope().reuse_variables()
145 | gamma = tf.get_variable(
146 | 'ln_gamma', [4 * num_units],
147 | initializer=tf.constant_initializer(gamma_start))
148 | if use_bias:
149 | beta = tf.get_variable(
150 | 'ln_beta', [4 * num_units], initializer=tf.constant_initializer(0.0))
151 | if use_bias:
152 | return gamma * h + beta
153 | return gamma * h
154 |
155 |
156 | def layer_norm(x,
157 | num_units,
158 | scope='layer_norm',
159 | reuse=False,
160 | gamma_start=1.0,
161 | epsilon=1e-3,
162 | use_bias=True):
163 | """Calculate layer norm."""
164 | axes = [1]
165 | mean = tf.reduce_mean(x, axes, keep_dims=True)
166 | x_shifted = x - mean
167 | var = tf.reduce_mean(tf.square(x_shifted), axes, keep_dims=True)
168 | inv_std = tf.rsqrt(var + epsilon)
169 | with tf.variable_scope(scope):
170 | if reuse is True:
171 | tf.get_variable_scope().reuse_variables()
172 | gamma = tf.get_variable(
173 | 'ln_gamma', [num_units],
174 | initializer=tf.constant_initializer(gamma_start))
175 | if use_bias:
176 | beta = tf.get_variable(
177 | 'ln_beta', [num_units], initializer=tf.constant_initializer(0.0))
178 | output = gamma * (x_shifted) * inv_std
179 | if use_bias:
180 | output += beta
181 | return output
182 |
183 |
184 | def raw_layer_norm(x, epsilon=1e-3):
185 | axes = [1]
186 | mean = tf.reduce_mean(x, axes, keep_dims=True)
187 | std = tf.sqrt(
188 | tf.reduce_mean(tf.square(x - mean), axes, keep_dims=True) + epsilon)
189 | output = (x - mean) / (std)
190 | return output
191 |
192 |
193 | def super_linear(x,
194 | output_size,
195 | scope=None,
196 | reuse=False,
197 | init_w='ortho',
198 | weight_start=0.0,
199 | use_bias=True,
200 | bias_start=0.0,
201 | input_size=None):
202 | """Performs linear operation. Uses ortho init defined earlier."""
203 | shape = x.get_shape().as_list()
204 | with tf.variable_scope(scope or 'linear'):
205 | if reuse is True:
206 | tf.get_variable_scope().reuse_variables()
207 |
208 | w_init = None # uniform
209 | if input_size is None:
210 | x_size = shape[1]
211 | else:
212 | x_size = input_size
213 | if init_w == 'zeros':
214 | w_init = tf.constant_initializer(0.0)
215 | elif init_w == 'constant':
216 | w_init = tf.constant_initializer(weight_start)
217 | elif init_w == 'gaussian':
218 | w_init = tf.random_normal_initializer(stddev=weight_start)
219 | elif init_w == 'ortho':
220 | w_init = lstm_ortho_initializer(1.0)
221 |
222 | w = tf.get_variable(
223 | 'super_linear_w', [x_size, output_size], tf.float32, initializer=w_init)
224 | if use_bias:
225 | b = tf.get_variable(
226 | 'super_linear_b', [output_size],
227 | tf.float32,
228 | initializer=tf.constant_initializer(bias_start))
229 | return tf.matmul(x, w) + b
230 | return tf.matmul(x, w)
231 |
232 |
233 | class LayerNormLSTMCell(tf.contrib.rnn.RNNCell):
234 | """Layer-Norm, with Ortho Init. and Recurrent Dropout without Memory Loss.
235 |
236 | https://arxiv.org/abs/1607.06450 - Layer Norm
237 | https://arxiv.org/abs/1603.05118 - Recurrent Dropout without Memory Loss
238 | """
239 |
240 | def __init__(self,
241 | num_units,
242 | forget_bias=1.0,
243 | use_recurrent_dropout=False,
244 | dropout_keep_prob=0.90):
245 | """Initialize the Layer Norm LSTM cell.
246 |
247 | Args:
248 | num_units: int, The number of units in the LSTM cell.
249 | forget_bias: float, The bias added to forget gates (default 1.0).
250 | use_recurrent_dropout: Whether to use Recurrent Dropout (default False)
251 | dropout_keep_prob: float, dropout keep probability (default 0.90)
252 | """
253 | self.num_units = num_units
254 | self.forget_bias = forget_bias
255 | self.use_recurrent_dropout = use_recurrent_dropout
256 | self.dropout_keep_prob = dropout_keep_prob
257 |
258 | @property
259 | def input_size(self):
260 | return self.num_units
261 |
262 | @property
263 | def output_size(self):
264 | return self.num_units
265 |
266 | @property
267 | def state_size(self):
268 | return 2 * self.num_units
269 |
270 | def get_output(self, state):
271 | h, unused_c = tf.split(state, 2, 1)
272 | return h
273 |
274 | def __call__(self, x, state, timestep=0, scope=None):
275 | with tf.variable_scope(scope or type(self).__name__):
276 | h, c = tf.split(state, 2, 1)
277 |
278 | h_size = self.num_units
279 | x_size = x.get_shape().as_list()[1]
280 | batch_size = x.get_shape().as_list()[0]
281 |
282 | w_init = None # uniform
283 |
284 | h_init = lstm_ortho_initializer(1.0)
285 |
286 | w_xh = tf.get_variable(
287 | 'W_xh', [x_size, 4 * self.num_units], initializer=w_init)
288 | w_hh = tf.get_variable(
289 | 'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init)
290 |
291 | concat = tf.concat([x, h], 1) # concat for speed.
292 | w_full = tf.concat([w_xh, w_hh], 0)
293 |       concat = tf.matmul(concat, w_full)  # no bias here; layer norm below supplies beta.
294 |
295 | # i = input_gate, j = new_input, f = forget_gate, o = output_gate
296 | concat = layer_norm_all(concat, batch_size, 4, h_size, 'ln_all')
297 | i, j, f, o = tf.split(concat, 4, 1)
298 |
299 | if self.use_recurrent_dropout:
300 | g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
301 | else:
302 | g = tf.tanh(j)
303 |
304 | new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
305 | new_h = tf.tanh(layer_norm(new_c, h_size, 'ln_c')) * tf.sigmoid(o)
306 |
307 | return new_h, tf.concat([new_h, new_c], 1)
308 |
309 |
310 | class HyperLSTMCell(tf.contrib.rnn.RNNCell):
311 | """HyperLSTM with Ortho Init, Layer Norm, Recurrent Dropout, no Memory Loss.
312 |
313 | https://arxiv.org/abs/1609.09106
314 | http://blog.otoro.net/2016/09/28/hyper-networks/
315 | """
316 |
317 | def __init__(self,
318 | num_units,
319 | forget_bias=1.0,
320 | use_recurrent_dropout=False,
321 | dropout_keep_prob=0.90,
322 | use_layer_norm=True,
323 | hyper_num_units=256,
324 | hyper_embedding_size=32,
325 | hyper_use_recurrent_dropout=False):
326 | """Initialize the Layer Norm HyperLSTM cell.
327 |
328 | Args:
329 | num_units: int, The number of units in the LSTM cell.
330 | forget_bias: float, The bias added to forget gates (default 1.0).
331 | use_recurrent_dropout: Whether to use Recurrent Dropout (default False)
332 | dropout_keep_prob: float, dropout keep probability (default 0.90)
333 | use_layer_norm: boolean. (default True)
334 | Controls whether we use LayerNorm layers in main LSTM & HyperLSTM cell.
335 | hyper_num_units: int, number of units in HyperLSTM cell.
336 |         (default is 256; recommend experimenting with larger values for larger tasks)
337 | hyper_embedding_size: int, size of signals emitted from HyperLSTM cell.
338 |         (default is 32; recommend trying larger values for large datasets)
339 | hyper_use_recurrent_dropout: boolean. (default False)
340 | Controls whether HyperLSTM cell also uses recurrent dropout.
341 | Recommend turning this on only if hyper_num_units becomes large (>= 512)
342 | """
343 | self.num_units = num_units
344 | self.forget_bias = forget_bias
345 | self.use_recurrent_dropout = use_recurrent_dropout
346 | self.dropout_keep_prob = dropout_keep_prob
347 | self.use_layer_norm = use_layer_norm
348 | self.hyper_num_units = hyper_num_units
349 | self.hyper_embedding_size = hyper_embedding_size
350 | self.hyper_use_recurrent_dropout = hyper_use_recurrent_dropout
351 |
352 | self.total_num_units = self.num_units + self.hyper_num_units
353 |
354 | if self.use_layer_norm:
355 | cell_fn = LayerNormLSTMCell
356 | else:
357 | cell_fn = LSTMCell
358 | self.hyper_cell = cell_fn(
359 | hyper_num_units,
360 | use_recurrent_dropout=hyper_use_recurrent_dropout,
361 | dropout_keep_prob=dropout_keep_prob)
362 |
363 | @property
364 | def input_size(self):
365 | return self._input_size
366 |
367 | @property
368 | def output_size(self):
369 | return self.num_units
370 |
371 | @property
372 | def state_size(self):
373 | return 2 * self.total_num_units
374 |
375 | def get_output(self, state):
376 | total_h, unused_total_c = tf.split(state, 2, 1)
377 | h = total_h[:, 0:self.num_units]
378 | return h
379 |
380 | def hyper_norm(self, layer, scope='hyper', use_bias=True):
381 | num_units = self.num_units
382 | embedding_size = self.hyper_embedding_size
383 | # recurrent batch norm init trick (https://arxiv.org/abs/1603.09025).
384 | init_gamma = 0.10  # gamma init from Cooijmans et al. (recurrent batch norm).
385 | with tf.variable_scope(scope):
386 | zw = super_linear(
387 | self.hyper_output,
388 | embedding_size,
389 | init_w='constant',
390 | weight_start=0.00,
391 | use_bias=True,
392 | bias_start=1.0,
393 | scope='zw')
394 | alpha = super_linear(
395 | zw,
396 | num_units,
397 | init_w='constant',
398 | weight_start=init_gamma / embedding_size,
399 | use_bias=False,
400 | scope='alpha')
401 | result = tf.multiply(alpha, layer)
402 | if use_bias:
403 | zb = super_linear(
404 | self.hyper_output,
405 | embedding_size,
406 | init_w='gaussian',
407 | weight_start=0.01,
408 | use_bias=False,
409 | bias_start=0.0,
410 | scope='zb')
411 | beta = super_linear(
412 | zb,
413 | num_units,
414 | init_w='constant',
415 | weight_start=0.00,
416 | use_bias=False,
417 | scope='beta')
418 | result += beta
419 | return result
420 |
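# Summary of hyper_norm: an embedding zw (and optionally zb) is computed from
# the hyper cell's output, projected to per-unit scales alpha (and shifts
# beta), and applied as alpha * layer + beta. This is the weight-modulation
# scheme of the HyperNetworks paper referenced in the class docstring.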
421 | def __call__(self, x, state, timestep=0, scope=None):
422 | with tf.variable_scope(scope or type(self).__name__):
423 | total_h, total_c = tf.split(state, 2, 1)
424 | h = total_h[:, 0:self.num_units]
425 | c = total_c[:, 0:self.num_units]
426 | self.hyper_state = tf.concat(
427 | [total_h[:, self.num_units:], total_c[:, self.num_units:]], 1)
428 |
429 | batch_size = x.get_shape().as_list()[0]
430 | x_size = x.get_shape().as_list()[1]
431 | self._input_size = x_size
432 |
433 | w_init = None # uniform
434 |
435 | h_init = lstm_ortho_initializer(1.0)
436 |
437 | w_xh = tf.get_variable(
438 | 'W_xh', [x_size, 4 * self.num_units], initializer=w_init)
439 | w_hh = tf.get_variable(
440 | 'W_hh', [self.num_units, 4 * self.num_units], initializer=h_init)
441 | bias = tf.get_variable(
442 | 'bias', [4 * self.num_units],
443 | initializer=tf.constant_initializer(0.0))
444 |
445 | # concatenate the input and hidden states for hyperlstm input
446 | hyper_input = tf.concat([x, h], 1)
447 | hyper_output, hyper_new_state = self.hyper_cell(hyper_input,
448 | self.hyper_state)
449 | self.hyper_output = hyper_output
450 | self.hyper_state = hyper_new_state
451 |
452 | xh = tf.matmul(x, w_xh)
453 | hh = tf.matmul(h, w_hh)
454 |
455 | # split Wxh contributions
456 | ix, jx, fx, ox = tf.split(xh, 4, 1)
457 | ix = self.hyper_norm(ix, 'hyper_ix', use_bias=False)
458 | jx = self.hyper_norm(jx, 'hyper_jx', use_bias=False)
459 | fx = self.hyper_norm(fx, 'hyper_fx', use_bias=False)
460 | ox = self.hyper_norm(ox, 'hyper_ox', use_bias=False)
461 |
462 | # split Whh contributions
463 | ih, jh, fh, oh = tf.split(hh, 4, 1)
464 | ih = self.hyper_norm(ih, 'hyper_ih', use_bias=True)
465 | jh = self.hyper_norm(jh, 'hyper_jh', use_bias=True)
466 | fh = self.hyper_norm(fh, 'hyper_fh', use_bias=True)
467 | oh = self.hyper_norm(oh, 'hyper_oh', use_bias=True)
468 |
469 | # split bias
470 | ib, jb, fb, ob = tf.split(bias, 4, 0) # bias is to be broadcasted.
471 |
472 | # i = input_gate, j = new_input, f = forget_gate, o = output_gate
473 | i = ix + ih + ib
474 | j = jx + jh + jb
475 | f = fx + fh + fb
476 | o = ox + oh + ob
477 |
478 | if self.use_layer_norm:
479 | concat = tf.concat([i, j, f, o], 1)
480 | concat = layer_norm_all(concat, batch_size, 4, self.num_units, 'ln_all')
481 | i, j, f, o = tf.split(concat, 4, 1)
482 |
483 | if self.use_recurrent_dropout:
484 | g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
485 | else:
486 | g = tf.tanh(j)
487 |
488 | new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
489 | new_h = tf.tanh(layer_norm(new_c, self.num_units, 'ln_c')) * tf.sigmoid(o)
490 |
491 | hyper_h, hyper_c = tf.split(hyper_new_state, 2, 1)
492 | new_total_h = tf.concat([new_h, hyper_h], 1)
493 | new_total_c = tf.concat([new_c, hyper_c], 1)
494 | new_total_state = tf.concat([new_total_h, new_total_c], 1)
495 | return new_h, new_total_state
496 |
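# A minimal usage sketch (illustrative only; batch_size, input_dim and x_t
# are assumptions, not names from this file):
#
#   cell = HyperLSTMCell(num_units=512, hyper_num_units=256)
#   # The packed state holds [h, c] for both the main LSTM and the inner
#   # HyperLSTM, so state_size == 2 * (num_units + hyper_num_units).
#   state = tf.zeros([batch_size, cell.state_size])
#   h, state = cell(x_t, state)  # x_t: [batch_size, input_dim]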
--------------------------------------------------------------------------------
/svg/path/tests/test_paths.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import unittest
3 | from math import sqrt, pi
4 |
5 | from ..path import CubicBezier, QuadraticBezier, Line, Arc, Path
6 |
7 |
8 | # Most of these test points are not calculated separately, as that would
9 | # take too long and be too error prone. Instead the curves have been verified
10 | # to be correct visually, by drawing them with the turtle module, with code
11 | # like this:
12 | #
13 | # import turtle
14 | # t = turtle.Turtle()
15 | # t.penup()
16 | #
17 | # for arc in (path1, path2):
18 | # p = arc.point(0)
19 | # t.goto(p.real - 500, -p.imag + 300)
20 | # t.dot(3, 'black')
21 | # t.pendown()
22 | # for x in range(1, 101):
23 | # p = arc.point(x * 0.01)
24 | # t.goto(p.real - 500, -p.imag + 300)
25 | # t.penup()
26 | # t.dot(3, 'black')
27 | #
28 | # raw_input()
29 | #
30 | # After the paths have been verified to be correct this way, the testing of
31 | # points along the paths has been added as regression tests, to make sure
32 | # nobody changes the way curves are drawn by mistake. Therefore, do not take
33 | # these points religiously. They might be subtly wrong, unless otherwise
34 | # noted.
35 |
36 | class LineTest(unittest.TestCase):
37 |
38 | def test_lines(self):
39 | # These points are calculated, and not just regression tests.
40 |
41 | line1 = Line(0j, 400 + 0j)
42 | self.assertAlmostEqual(line1.point(0), (0j))
43 | self.assertAlmostEqual(line1.point(0.3), (120 + 0j))
44 | self.assertAlmostEqual(line1.point(0.5), (200 + 0j))
45 | self.assertAlmostEqual(line1.point(0.9), (360 + 0j))
46 | self.assertAlmostEqual(line1.point(1), (400 + 0j))
47 | self.assertAlmostEqual(line1.length(), 400)
48 |
49 | line2 = Line(400 + 0j, 400 + 300j)
50 | self.assertAlmostEqual(line2.point(0), (400 + 0j))
51 | self.assertAlmostEqual(line2.point(0.3), (400 + 90j))
52 | self.assertAlmostEqual(line2.point(0.5), (400 + 150j))
53 | self.assertAlmostEqual(line2.point(0.9), (400 + 270j))
54 | self.assertAlmostEqual(line2.point(1), (400 + 300j))
55 | self.assertAlmostEqual(line2.length(), 300)
56 |
57 | line3 = Line(400 + 300j, 0j)
58 | self.assertAlmostEqual(line3.point(0), (400 + 300j))
59 | self.assertAlmostEqual(line3.point(0.3), (280 + 210j))
60 | self.assertAlmostEqual(line3.point(0.5), (200 + 150j))
61 | self.assertAlmostEqual(line3.point(0.9), (40 + 30j))
62 | self.assertAlmostEqual(line3.point(1), (0j))
63 | self.assertAlmostEqual(line3.length(), 500)
64 |
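# Segments represent 2-D points as complex numbers, so Line.point(t) is just
# start + t * (end - start). For line3 above:
#   (400+300j) + 0.3 * (0j - (400+300j)) = 280 + 210j
# and its length is abs(400+300j) = 500, a 3-4-5 triangle scaled by 100.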
65 | def test_equality(self):
66 | # This is to test the __eq__ and __ne__ methods, so we can't use
67 | # assertEqual and assertNotEqual
68 | line = Line(0j, 400 + 0j)
69 | self.assertTrue(line == Line(0, 400))
70 | self.assertTrue(line != Line(100, 400))
71 | self.assertFalse(line == str(line))
72 | self.assertTrue(line != str(line))
73 | self.assertFalse(CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j) ==
74 | line)
75 |
76 |
77 | class CubicBezierTest(unittest.TestCase):
78 | def test_approx_circle(self):
79 | """This is a approximate circle drawn in Inkscape"""
80 |
81 | arc1 = CubicBezier(
82 | complex(0, 0),
83 | complex(0, 109.66797),
84 | complex(-88.90345, 198.57142),
85 | complex(-198.57142, 198.57142)
86 | )
87 |
88 | self.assertAlmostEqual(arc1.point(0), (0j))
89 | self.assertAlmostEqual(arc1.point(0.1), (-2.59896457 + 32.20931647j))
90 | self.assertAlmostEqual(arc1.point(0.2), (-10.12330256 + 62.76392816j))
91 | self.assertAlmostEqual(arc1.point(0.3), (-22.16418039 + 91.25500149j))
92 | self.assertAlmostEqual(arc1.point(0.4), (-38.31276448 + 117.27370288j))
93 | self.assertAlmostEqual(arc1.point(0.5), (-58.16022125 + 140.41119875j))
94 | self.assertAlmostEqual(arc1.point(0.6), (-81.29771712 + 160.25865552j))
95 | self.assertAlmostEqual(arc1.point(0.7), (-107.31641851 + 176.40723961j))
96 | self.assertAlmostEqual(arc1.point(0.8), (-135.80749184 + 188.44811744j))
97 | self.assertAlmostEqual(arc1.point(0.9), (-166.36210353 + 195.97245543j))
98 | self.assertAlmostEqual(arc1.point(1), (-198.57142 + 198.57142j))
99 |
100 | arc2 = CubicBezier(
101 | complex(-198.57142, 198.57142),
102 | complex(-109.66797 - 198.57142, 0 + 198.57142),
103 | complex(-198.57143 - 198.57142, -88.90345 + 198.57142),
104 | complex(-198.57143 - 198.57142, 0),
105 | )
106 |
107 | self.assertAlmostEqual(arc2.point(0), (-198.57142 + 198.57142j))
108 | self.assertAlmostEqual(arc2.point(0.1), (-230.78073675 + 195.97245543j))
109 | self.assertAlmostEqual(arc2.point(0.2), (-261.3353492 + 188.44811744j))
110 | self.assertAlmostEqual(arc2.point(0.3), (-289.82642365 + 176.40723961j))
111 | self.assertAlmostEqual(arc2.point(0.4), (-315.8451264 + 160.25865552j))
112 | self.assertAlmostEqual(arc2.point(0.5), (-338.98262375 + 140.41119875j))
113 | self.assertAlmostEqual(arc2.point(0.6), (-358.830082 + 117.27370288j))
114 | self.assertAlmostEqual(arc2.point(0.7), (-374.97866745 + 91.25500149j))
115 | self.assertAlmostEqual(arc2.point(0.8), (-387.0195464 + 62.76392816j))
116 | self.assertAlmostEqual(arc2.point(0.9), (-394.54388515 + 32.20931647j))
117 | self.assertAlmostEqual(arc2.point(1), (-397.14285 + 0j))
118 |
119 | arc3 = CubicBezier(
120 | complex(-198.57143 - 198.57142, 0),
121 | complex(0 - 198.57143 - 198.57142, -109.66797),
122 | complex(88.90346 - 198.57143 - 198.57142, -198.57143),
123 | complex(-198.57142, -198.57143)
124 | )
125 |
126 | self.assertAlmostEqual(arc3.point(0), (-397.14285 + 0j))
127 | self.assertAlmostEqual(arc3.point(0.1), (-394.54388515 - 32.20931675j))
128 | self.assertAlmostEqual(arc3.point(0.2), (-387.0195464 - 62.7639292j))
129 | self.assertAlmostEqual(arc3.point(0.3), (-374.97866745 - 91.25500365j))
130 | self.assertAlmostEqual(arc3.point(0.4), (-358.830082 - 117.2737064j))
131 | self.assertAlmostEqual(arc3.point(0.5), (-338.98262375 - 140.41120375j))
132 | self.assertAlmostEqual(arc3.point(0.6), (-315.8451264 - 160.258662j))
133 | self.assertAlmostEqual(arc3.point(0.7), (-289.82642365 - 176.40724745j))
134 | self.assertAlmostEqual(arc3.point(0.8), (-261.3353492 - 188.4481264j))
135 | self.assertAlmostEqual(arc3.point(0.9), (-230.78073675 - 195.97246515j))
136 | self.assertAlmostEqual(arc3.point(1), (-198.57142 - 198.57143j))
137 |
138 | arc4 = CubicBezier(
139 | complex(-198.57142, -198.57143),
140 | complex(109.66797 - 198.57142, 0 - 198.57143),
141 | complex(0, 88.90346 - 198.57143),
142 | complex(0, 0),
143 | )
144 |
145 | self.assertAlmostEqual(arc4.point(0), (-198.57142 - 198.57143j))
146 | self.assertAlmostEqual(arc4.point(0.1), (-166.36210353 - 195.97246515j))
147 | self.assertAlmostEqual(arc4.point(0.2), (-135.80749184 - 188.4481264j))
148 | self.assertAlmostEqual(arc4.point(0.3), (-107.31641851 - 176.40724745j))
149 | self.assertAlmostEqual(arc4.point(0.4), (-81.29771712 - 160.258662j))
150 | self.assertAlmostEqual(arc4.point(0.5), (-58.16022125 - 140.41120375j))
151 | self.assertAlmostEqual(arc4.point(0.6), (-38.31276448 - 117.2737064j))
152 | self.assertAlmostEqual(arc4.point(0.7), (-22.16418039 - 91.25500365j))
153 | self.assertAlmostEqual(arc4.point(0.8), (-10.12330256 - 62.7639292j))
154 | self.assertAlmostEqual(arc4.point(0.9), (-2.59896457 - 32.20931675j))
155 | self.assertAlmostEqual(arc4.point(1), (0j))
156 |
157 | def test_svg_examples(self):
158 |
159 | # M100,200 C100,100 250,100 250,200
160 | path1 = CubicBezier(100 + 200j, 100 + 100j, 250 + 100j, 250 + 200j)
161 | self.assertAlmostEqual(path1.point(0), (100 + 200j))
162 | self.assertAlmostEqual(path1.point(0.3), (132.4 + 137j))
163 | self.assertAlmostEqual(path1.point(0.5), (175 + 125j))
164 | self.assertAlmostEqual(path1.point(0.9), (245.8 + 173j))
165 | self.assertAlmostEqual(path1.point(1), (250 + 200j))
166 |
167 | # S400,300 400,200
168 | path2 = CubicBezier(250 + 200j, 250 + 300j, 400 + 300j, 400 + 200j)
169 | self.assertAlmostEqual(path2.point(0), (250 + 200j))
170 | self.assertAlmostEqual(path2.point(0.3), (282.4 + 263j))
171 | self.assertAlmostEqual(path2.point(0.5), (325 + 275j))
172 | self.assertAlmostEqual(path2.point(0.9), (395.8 + 227j))
173 | self.assertAlmostEqual(path2.point(1), (400 + 200j))
174 |
175 | # M100,200 C100,100 400,100 400,200
176 | path3 = CubicBezier(100 + 200j, 100 + 100j, 400 + 100j, 400 + 200j)
177 | self.assertAlmostEqual(path3.point(0), (100 + 200j))
178 | self.assertAlmostEqual(path3.point(0.3), (164.8 + 137j))
179 | self.assertAlmostEqual(path3.point(0.5), (250 + 125j))
180 | self.assertAlmostEqual(path3.point(0.9), (391.6 + 173j))
181 | self.assertAlmostEqual(path3.point(1), (400 + 200j))
182 |
183 | # M100,500 C25,400 475,400 400,500
184 | path4 = CubicBezier(100 + 500j, 25 + 400j, 475 + 400j, 400 + 500j)
185 | self.assertAlmostEqual(path4.point(0), (100 + 500j))
186 | self.assertAlmostEqual(path4.point(0.3), (145.9 + 437j))
187 | self.assertAlmostEqual(path4.point(0.5), (250 + 425j))
188 | self.assertAlmostEqual(path4.point(0.9), (407.8 + 473j))
189 | self.assertAlmostEqual(path4.point(1), (400 + 500j))
190 |
191 | # M100,800 C175,700 325,700 400,800
192 | path5 = CubicBezier(100 + 800j, 175 + 700j, 325 + 700j, 400 + 800j)
193 | self.assertAlmostEqual(path5.point(0), (100 + 800j))
194 | self.assertAlmostEqual(path5.point(0.3), (183.7 + 737j))
195 | self.assertAlmostEqual(path5.point(0.5), (250 + 725j))
196 | self.assertAlmostEqual(path5.point(0.9), (375.4 + 773j))
197 | self.assertAlmostEqual(path5.point(1), (400 + 800j))
198 |
199 | # M600,200 C675,100 975,100 900,200
200 | path6 = CubicBezier(600 + 200j, 675 + 100j, 975 + 100j, 900 + 200j)
201 | self.assertAlmostEqual(path6.point(0), (600 + 200j))
202 | self.assertAlmostEqual(path6.point(0.3), (712.05 + 137j))
203 | self.assertAlmostEqual(path6.point(0.5), (806.25 + 125j))
204 | self.assertAlmostEqual(path6.point(0.9), (911.85 + 173j))
205 | self.assertAlmostEqual(path6.point(1), (900 + 200j))
206 |
207 | # M600,500 C600,350 900,650 900,500
208 | path7 = CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j)
209 | self.assertAlmostEqual(path7.point(0), (600 + 500j))
210 | self.assertAlmostEqual(path7.point(0.3), (664.8 + 462.2j))
211 | self.assertAlmostEqual(path7.point(0.5), (750 + 500j))
212 | self.assertAlmostEqual(path7.point(0.9), (891.6 + 532.4j))
213 | self.assertAlmostEqual(path7.point(1), (900 + 500j))
214 |
215 | # M600,800 C625,700 725,700 750,800
216 | path8 = CubicBezier(600 + 800j, 625 + 700j, 725 + 700j, 750 + 800j)
217 | self.assertAlmostEqual(path8.point(0), (600 + 800j))
218 | self.assertAlmostEqual(path8.point(0.3), (638.7 + 737j))
219 | self.assertAlmostEqual(path8.point(0.5), (675 + 725j))
220 | self.assertAlmostEqual(path8.point(0.9), (740.4 + 773j))
221 | self.assertAlmostEqual(path8.point(1), (750 + 800j))
222 |
223 | # S875,900 900,800
224 | inversion = (750 + 800j) + (750 + 800j) - (725 + 700j)
225 | path9 = CubicBezier(750 + 800j, inversion, 875 + 900j, 900 + 800j)
226 | self.assertAlmostEqual(path9.point(0), (750 + 800j))
227 | self.assertAlmostEqual(path9.point(0.3), (788.7 + 863j))
228 | self.assertAlmostEqual(path9.point(0.5), (825 + 875j))
229 | self.assertAlmostEqual(path9.point(0.9), (890.4 + 827j))
230 | self.assertAlmostEqual(path9.point(1), (900 + 800j))
231 |
232 | def test_length(self):
233 |
234 | # A straight line:
235 | arc = CubicBezier(
236 | complex(0, 0),
237 | complex(0, 0),
238 | complex(0, 100),
239 | complex(0, 100)
240 | )
241 |
242 | self.assertAlmostEqual(arc.length(), 100)
243 |
244 | # A diagonal line:
245 | arc = CubicBezier(
246 | complex(0, 0),
247 | complex(0, 0),
248 | complex(100, 100),
249 | complex(100, 100)
250 | )
251 |
252 | self.assertAlmostEqual(arc.length(), sqrt(2 * 100 * 100))
253 |
254 | # A quarter circle arc with radius 100:
255 | kappa = 4 * (sqrt(2) - 1) / 3 # http://www.whizkidtech.redprince.net/bezier/circle/
256 |
257 | arc = CubicBezier(
258 | complex(0, 0),
259 | complex(0, kappa * 100),
260 | complex(100 - kappa * 100, 100),
261 | complex(100, 100)
262 | )
263 |
264 | # We can't compare with pi*50 here, because this is just an
265 | # approximation of a circle arc. pi*50 is 157.079632679
266 | # So this is just yet another "warn if this changes" test.
267 | # This value is not verified to be correct.
268 | self.assertAlmostEqual(arc.length(), 157.1016698)
269 |
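# For context: kappa is approximately 0.5522847, and this cubic approximation
# overshoots the exact quarter-circle length pi*50 ~ 157.0796327 by about
# 0.022, i.e. roughly 0.014%.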
270 | # A recursive solution has also been suggested, but for CubicBezier
271 | # curves it could get a false solution on curves where the midpoint is on a
272 | # straight line between the start and end. For example, the following
273 | # curve would get solved as a straight line and get the length 300.
274 | # Make sure this is not the case.
275 | arc = CubicBezier(
276 | complex(600, 500),
277 | complex(600, 350),
278 | complex(900, 650),
279 | complex(900, 500)
280 | )
281 | self.assertTrue(arc.length() > 300.0)
282 |
283 | def test_equality(self):
284 | # This is to test the __eq__ and __ne__ methods, so we can't use
285 | # assertEqual and assertNotEqual
286 | segment = CubicBezier(complex(600, 500), complex(600, 350),
287 | complex(900, 650), complex(900, 500))
288 |
289 | self.assertTrue(segment ==
290 | CubicBezier(600 + 500j, 600 + 350j, 900 + 650j, 900 + 500j))
291 | self.assertTrue(segment !=
292 | CubicBezier(600 + 501j, 600 + 350j, 900 + 650j, 900 + 500j))
293 | self.assertTrue(segment != Line(0, 400))
294 |
295 |
296 | class QuadraticBezierTest(unittest.TestCase):
297 |
298 | def test_svg_examples(self):
299 | """These is the path in the SVG specs"""
300 | # M200,300 Q400,50 600,300 T1000,300
301 | path1 = QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j)
302 | self.assertAlmostEqual(path1.point(0), (200 + 300j))
303 | self.assertAlmostEqual(path1.point(0.3), (320 + 195j))
304 | self.assertAlmostEqual(path1.point(0.5), (400 + 175j))
305 | self.assertAlmostEqual(path1.point(0.9), (560 + 255j))
306 | self.assertAlmostEqual(path1.point(1), (600 + 300j))
307 |
308 | # T1000, 300
309 | inversion = (600 + 300j) + (600 + 300j) - (400 + 50j)
310 | path2 = QuadraticBezier(600 + 300j, inversion, 1000 + 300j)
311 | self.assertAlmostEqual(path2.point(0), (600 + 300j))
312 | self.assertAlmostEqual(path2.point(0.3), (720 + 405j))
313 | self.assertAlmostEqual(path2.point(0.5), (800 + 425j))
314 | self.assertAlmostEqual(path2.point(0.9), (960 + 345j))
315 | self.assertAlmostEqual(path2.point(1), (1000 + 300j))
316 |
317 | def test_length(self):
318 | # expected results calculated with
319 | # svg.path.segment_length(q, 0, 1, q.start, q.end, 1e-14, 20, 0)
320 | q1 = QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j)
321 | q2 = QuadraticBezier(200 + 300j, 400 + 50j, 500 + 200j)
322 | closedq = QuadraticBezier(6+2j, 5-1j, 6+2j)
323 | linq1 = QuadraticBezier(1, 2, 3)
324 | linq2 = QuadraticBezier(1+3j, 2+5j, -9 - 17j)
325 | nodalq = QuadraticBezier(1, 1, 1)
326 | tests = [(q1, 487.77109389525975),
327 | (q2, 379.90458193489155),
328 | (closedq, 3.1622776601683795),
329 | (linq1, 2),
330 | (linq2, 22.73335777124786),
331 | (nodalq, 0)]
332 | for q, exp_res in tests:
333 | self.assertAlmostEqual(q.length(), exp_res)
334 |
335 | def test_equality(self):
336 | # This is to test the __eq__ and __ne__ methods, so we can't use
337 | # assertEqual and assertNotEqual
338 | segment = QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j)
339 | self.assertTrue(segment == QuadraticBezier(200 + 300j, 400 + 50j, 600 + 300j))
340 | self.assertTrue(segment != QuadraticBezier(200 + 301j, 400 + 50j, 600 + 300j))
341 | self.assertFalse(segment == Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j))
342 | self.assertTrue(Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j) != segment)
343 |
344 |
345 | class ArcTest(unittest.TestCase):
346 |
347 | def test_points(self):
348 | arc1 = Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j)
349 | self.assertAlmostEqual(arc1.center, 100 + 0j)
350 | self.assertAlmostEqual(arc1.theta, 180.0)
351 | self.assertAlmostEqual(arc1.delta, -90.0)
352 |
353 | self.assertAlmostEqual(arc1.point(0.0), (0j))
354 | self.assertAlmostEqual(arc1.point(0.1), (1.23116594049 + 7.82172325201j))
355 | self.assertAlmostEqual(arc1.point(0.2), (4.89434837048 + 15.4508497187j))
356 | self.assertAlmostEqual(arc1.point(0.3), (10.8993475812 + 22.699524987j))
357 | self.assertAlmostEqual(arc1.point(0.4), (19.0983005625 + 29.3892626146j))
358 | self.assertAlmostEqual(arc1.point(0.5), (29.2893218813 + 35.3553390593j))
359 | self.assertAlmostEqual(arc1.point(0.6), (41.2214747708 + 40.4508497187j))
360 | self.assertAlmostEqual(arc1.point(0.7), (54.6009500260 + 44.5503262094j))
361 | self.assertAlmostEqual(arc1.point(0.8), (69.0983005625 + 47.5528258148j))
362 | self.assertAlmostEqual(arc1.point(0.9), (84.3565534960 + 49.3844170298j))
363 | self.assertAlmostEqual(arc1.point(1.0), (100 + 50j))
364 |
365 | arc2 = Arc(0j, 100 + 50j, 0, 1, 0, 100 + 50j)
366 | self.assertAlmostEqual(arc2.center, 50j)
367 | self.assertAlmostEqual(arc2.theta, 270.0)
368 | self.assertAlmostEqual(arc2.delta, -270.0)
369 |
370 | self.assertAlmostEqual(arc2.point(0.0), (0j))
371 | self.assertAlmostEqual(arc2.point(0.1), (-45.399049974 + 5.44967379058j))
372 | self.assertAlmostEqual(arc2.point(0.2), (-80.9016994375 + 20.6107373854j))
373 | self.assertAlmostEqual(arc2.point(0.3), (-98.7688340595 + 42.178276748j))
374 | self.assertAlmostEqual(arc2.point(0.4), (-95.1056516295 + 65.4508497187j))
375 | self.assertAlmostEqual(arc2.point(0.5), (-70.7106781187 + 85.3553390593j))
376 | self.assertAlmostEqual(arc2.point(0.6), (-30.9016994375 + 97.5528258148j))
377 | self.assertAlmostEqual(arc2.point(0.7), (15.643446504 + 99.3844170298j))
378 | self.assertAlmostEqual(arc2.point(0.8), (58.7785252292 + 90.4508497187j))
379 | self.assertAlmostEqual(arc2.point(0.9), (89.1006524188 + 72.699524987j))
380 | self.assertAlmostEqual(arc2.point(1.0), (100 + 50j))
381 |
382 | arc3 = Arc(0j, 100 + 50j, 0, 0, 1, 100 + 50j)
383 | self.assertAlmostEqual(arc3.center, 50j)
384 | self.assertAlmostEqual(arc3.theta, 270.0)
385 | self.assertAlmostEqual(arc3.delta, 90.0)
386 |
387 | self.assertAlmostEqual(arc3.point(0.0), (0j))
388 | self.assertAlmostEqual(arc3.point(0.1), (15.643446504 + 0.615582970243j))
389 | self.assertAlmostEqual(arc3.point(0.2), (30.9016994375 + 2.44717418524j))
390 | self.assertAlmostEqual(arc3.point(0.3), (45.399049974 + 5.44967379058j))
391 | self.assertAlmostEqual(arc3.point(0.4), (58.7785252292 + 9.54915028125j))
392 | self.assertAlmostEqual(arc3.point(0.5), (70.7106781187 + 14.6446609407j))
393 | self.assertAlmostEqual(arc3.point(0.6), (80.9016994375 + 20.6107373854j))
394 | self.assertAlmostEqual(arc3.point(0.7), (89.1006524188 + 27.300475013j))
395 | self.assertAlmostEqual(arc3.point(0.8), (95.1056516295 + 34.5491502813j))
396 | self.assertAlmostEqual(arc3.point(0.9), (98.7688340595 + 42.178276748j))
397 | self.assertAlmostEqual(arc3.point(1.0), (100 + 50j))
398 |
399 | arc4 = Arc(0j, 100 + 50j, 0, 1, 1, 100 + 50j)
400 | self.assertAlmostEqual(arc4.center, 100 + 0j)
401 | self.assertAlmostEqual(arc4.theta, 180.0)
402 | self.assertAlmostEqual(arc4.delta, 270.0)
403 |
404 | self.assertAlmostEqual(arc4.point(0.0), (0j))
405 | self.assertAlmostEqual(arc4.point(0.1), (10.8993475812 - 22.699524987j))
406 | self.assertAlmostEqual(arc4.point(0.2), (41.2214747708 - 40.4508497187j))
407 | self.assertAlmostEqual(arc4.point(0.3), (84.3565534960 - 49.3844170298j))
408 | self.assertAlmostEqual(arc4.point(0.4), (130.901699437 - 47.5528258148j))
409 | self.assertAlmostEqual(arc4.point(0.5), (170.710678119 - 35.3553390593j))
410 | self.assertAlmostEqual(arc4.point(0.6), (195.105651630 - 15.4508497187j))
411 | self.assertAlmostEqual(arc4.point(0.7), (198.768834060 + 7.82172325201j))
412 | self.assertAlmostEqual(arc4.point(0.8), (180.901699437 + 29.3892626146j))
413 | self.assertAlmostEqual(arc4.point(0.9), (145.399049974 + 44.5503262094j))
414 | self.assertAlmostEqual(arc4.point(1.0), (100 + 50j))
415 |
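# For reference, the positional arguments mirror the SVG 'A' command
# "A rx,ry rotation large-arc sweep x,y": Arc(start, radius=rx + ry*1j,
# rotation, arc, sweep, end), so arc1 above encodes "M 0,0 A 100,50 0 0,0 100,50".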
416 | def test_length(self):
417 | # I'll test the length calculations by making a circle, in two parts.
418 | arc1 = Arc(0j, 100 + 100j, 0, 0, 0, 200 + 0j)
419 | arc2 = Arc(200 + 0j, 100 + 100j, 0, 0, 0, 0j)
420 | self.assertAlmostEqual(arc1.length(), pi * 100)
421 | self.assertAlmostEqual(arc2.length(), pi * 100)
422 |
423 | def test_equality(self):
424 | # This is to test the __eq__ and __ne__ methods, so we can't use
425 | # assertEqual and assertNotEqual
426 | segment = Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j)
427 | self.assertTrue(segment == Arc(0j, 100 + 50j, 0, 0, 0, 100 + 50j))
428 | self.assertTrue(segment != Arc(0j, 100 + 50j, 0, 1, 0, 100 + 50j))
429 |
430 |
431 | class TestPath(unittest.TestCase):
432 |
433 | def test_circle(self):
434 | arc1 = Arc(0j, 100 + 100j, 0, 0, 0, 200 + 0j)
435 | arc2 = Arc(200 + 0j, 100 + 100j, 0, 0, 0, 0j)
436 | path = Path(arc1, arc2)
437 | self.assertAlmostEqual(path.point(0.0), (0j))
438 | self.assertAlmostEqual(path.point(0.25), (100 + 100j))
439 | self.assertAlmostEqual(path.point(0.5), (200 + 0j))
440 | self.assertAlmostEqual(path.point(0.75), (100 - 100j))
441 | self.assertAlmostEqual(path.point(1.0), (0j))
442 | self.assertAlmostEqual(path.length(), pi * 200)
443 |
444 | def test_svg_specs(self):
445 | """The paths that are in the SVG specs"""
446 |
447 | # Big pie: M300,200 h-150 a150,150 0 1,0 150,-150 z
448 | path = Path(Line(300 + 200j, 150 + 200j),
449 | Arc(150 + 200j, 150 + 150j, 0, 1, 0, 300 + 50j),
450 | Line(300 + 50j, 300 + 200j))
451 | # The points and length for this path are calculated and not regression tests.
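# For example, the opening Line has length 150 while the whole path has
# length pi*225 + 300, so the Line ends at t = 150 / (pi*225 + 300), which is
# the 0.14897825542 used below: Path.point() parametrizes by fractional arc length.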
452 | self.assertAlmostEqual(path.point(0.0), (300 + 200j))
453 | self.assertAlmostEqual(path.point(0.14897825542), (150 + 200j))
454 | self.assertAlmostEqual(path.point(0.5), (406.066017177 + 306.066017177j))
455 | self.assertAlmostEqual(path.point(1 - 0.14897825542), (300 + 50j))
456 | self.assertAlmostEqual(path.point(1.0), (300 + 200j))
457 | # The errors seem to accumulate. Still 6 decimal places is more than good enough.
458 | self.assertAlmostEqual(path.length(), pi * 225 + 300, places=6)
459 |
460 | # Little pie: M275,175 v-150 a150,150 0 0,0 -150,150 z
461 | path = Path(Line(275 + 175j, 275 + 25j),
462 | Arc(275 + 25j, 150 + 150j, 0, 0, 0, 125 + 175j),
463 | Line(125 + 175j, 275 + 175j))
464 | # The points and length for this path are calculated and not regression tests.
465 | self.assertAlmostEqual(path.point(0.0), (275 + 175j))
466 | self.assertAlmostEqual(path.point(0.2800495767557787), (275 + 25j))
467 | self.assertAlmostEqual(path.point(0.5), (168.93398282201787 + 68.93398282201787j))
468 | self.assertAlmostEqual(path.point(1 - 0.2800495767557787), (125 + 175j))
469 | self.assertAlmostEqual(path.point(1.0), (275 + 175j))
470 | # The errors seem to accumulate. Still 6 decimal places is more than good enough.
471 | self.assertAlmostEqual(path.length(), pi * 75 + 300, places=6)
472 |
473 | # Bumpy path: M600,350 l 50,-25
474 | # a25,25 -30 0,1 50,-25 l 50,-25
475 | # a25,50 -30 0,1 50,-25 l 50,-25
476 | # a25,75 -30 0,1 50,-25 l 50,-25
477 | # a25,100 -30 0,1 50,-25 l 50,-25
478 | path = Path(Line(600 + 350j, 650 + 325j),
479 | Arc(650 + 325j, 25 + 25j, -30, 0, 1, 700 + 300j),
480 | Line(700 + 300j, 750 + 275j),
481 | Arc(750 + 275j, 25 + 50j, -30, 0, 1, 800 + 250j),
482 | Line(800 + 250j, 850 + 225j),
483 | Arc(850 + 225j, 25 + 75j, -30, 0, 1, 900 + 200j),
484 | Line(900 + 200j, 950 + 175j),
485 | Arc(950 + 175j, 25 + 100j, -30, 0, 1, 1000 + 150j),
486 | Line(1000 + 150j, 1050 + 125j),
487 | )
488 | # These are *not* calculated, but just regression tests. Be skeptical.
489 | self.assertAlmostEqual(path.point(0.0), (600 + 350j))
490 | self.assertAlmostEqual(path.point(0.3), (755.31526434 + 217.51578768j))
491 | self.assertAlmostEqual(path.point(0.5), (832.23324151 + 156.33454892j))
492 | self.assertAlmostEqual(path.point(0.9), (974.00559321 + 115.26473532j))
493 | self.assertAlmostEqual(path.point(1.0), (1050 + 125j))
494 | # The errors seem to accumulate. Still 6 decimal places is more than good enough.
495 | self.assertAlmostEqual(path.length(), 860.6756221710)
496 |
497 | def test_repr(self):
498 | path = Path(
499 | Line(start=600 + 350j, end=650 + 325j),
500 | Arc(start=650 + 325j, radius=25 + 25j, rotation=-30, arc=0, sweep=1, end=700 + 300j),
501 | CubicBezier(start=700 + 300j, control1=800 + 400j, control2=750 + 200j, end=600 + 100j),
502 | QuadraticBezier(start=600 + 100j, control=600, end=600 + 300j))
503 | self.assertEqual(eval(repr(path)), path)
504 |
505 | def test_reverse(self):
506 | # Currently you can't reverse paths.
507 | self.assertRaises(NotImplementedError, Path().reverse)
508 |
509 | def test_equality(self):
510 | # This is to test the __eq__ and __ne__ methods, so we can't use
511 | # assertEqual and assertNotEqual
512 | path1 = Path(
513 | Line(start=600 + 350j, end=650 + 325j),
514 | Arc(start=650 + 325j, radius=25 + 25j, rotation=-30, arc=0, sweep=1, end=700 + 300j),
515 | CubicBezier(start=700 + 300j, control1=800 + 400j, control2=750 + 200j, end=600 + 100j),
516 | QuadraticBezier(start=600 + 100j, control=600, end=600 + 300j))
517 | path2 = Path(
518 | Line(start=600 + 350j, end=650 + 325j),
519 | Arc(start=650 + 325j, radius=25 + 25j, rotation=-30, arc=0, sweep=1, end=700 + 300j),
520 | CubicBezier(start=700 + 300j, control1=800 + 400j, control2=750 + 200j, end=600 + 100j),
521 | QuadraticBezier(start=600 + 100j, control=600, end=600 + 300j))
522 |
523 | self.assertTrue(path1 == path2)
524 | # Modify path2:
525 | path2[0].start = 601 + 350j
526 | self.assertTrue(path1 != path2)
527 |
528 | # Modify back:
529 | path2[0].start = 600 + 350j
530 | self.assertFalse(path1 != path2)
531 |
532 | # Get rid of the last segment:
533 | del path2[-1]
534 | self.assertFalse(path1 == path2)
535 |
536 | # It's not equal to a list of its segments
537 | self.assertTrue(path1 != path1[:])
538 | self.assertFalse(path1 == path1[:])
539 |
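# Optional: allow running this test module directly as a script (test
# runners will also discover the suite without this).
if __name__ == '__main__':
    unittest.main()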
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Sketch-RNN Model."""
15 |
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import os
21 | import random
22 |
23 | # internal imports
24 |
25 | import numpy as np
26 | import tensorflow as tf
27 | import tensorflow.contrib.slim as slim
28 | try:
29 | from magenta.models.sketch_rnn import rnn
30 | except ImportError:
31 | import magenta_rnn as rnn
32 |
33 | from build_subnet import *
34 | from tf_data_work import *
35 |
36 |
37 | def copy_hparams(hparams):
38 | """Return a copy of an HParams instance."""
39 | return tf.contrib.training.HParams(**hparams.values())
40 |
41 |
42 | class Model(object):
43 | """Define a SketchRNN model."""
44 |
45 | def __init__(self, hps, reuse=False):
46 | """Initializer for the SketchRNN model.
47 |
48 | Args:
49 | hps: a HParams object containing model hyperparameters
50 | reuse: a boolean that when True, attempts to reuse variables.
51 |
52 | """
53 | self.hps = hps
54 | with tf.variable_scope('vector_rnn', reuse=reuse):
55 | self.build_model(hps)
56 |
57 | def encoder(self, input_batch, sequence_lengths, reuse):
58 | if self.hps.enc_type == 'rnn': # vae mode:
59 | image_embeddings = self.rnn_encoder(input_batch, sequence_lengths)
60 | elif self.hps.enc_type == 'cnn':
61 | image_embeddings = self.cnn_encoder(input_batch, reuse)
62 | elif self.hps.enc_type == 'feat':
63 | image_embeddings = input_batch
64 | else:
65 | raise Exception('Please choose a valid encoder type')
66 | return image_embeddings
67 |
68 | def rnn_encoder(self, batch, sequence_lengths):
69 |
70 | if self.hps.rnn_model == 'lstm':
71 | enc_cell_fn = rnn.LSTMCell
72 | elif self.hps.rnn_model == 'layer_norm':
73 | enc_cell_fn = rnn.LayerNormLSTMCell
74 | elif self.hps.rnn_model == 'hyper':
75 | enc_cell_fn = rnn.HyperLSTMCell
76 | else:
77 | assert False, 'please choose a respectable cell'
78 |
79 | if self.hps.rnn_model == 'hyper':
80 | self.enc_cell_fw = enc_cell_fn(
81 | self.hps.enc_rnn_size,
82 | use_recurrent_dropout=self.hps.use_recurrent_dropout,
83 | dropout_keep_prob=self.hps.recurrent_dropout_prob)
84 | self.enc_cell_bw = enc_cell_fn(
85 | self.hps.enc_rnn_size,
86 | use_recurrent_dropout=self.hps.use_recurrent_dropout,
87 | dropout_keep_prob=self.hps.recurrent_dropout_prob)
88 | else:
89 | self.enc_cell_fw = enc_cell_fn(
90 | self.hps.enc_rnn_size,
91 | use_recurrent_dropout=self.hps.use_recurrent_dropout,
92 | dropout_keep_prob=self.hps.recurrent_dropout_prob)
93 | self.enc_cell_bw = enc_cell_fn(
94 | self.hps.enc_rnn_size,
95 | use_recurrent_dropout=self.hps.use_recurrent_dropout,
96 | dropout_keep_prob=self.hps.recurrent_dropout_prob)
97 |
98 | """Define the bi-directional encoder module of sketch-rnn."""
99 | unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
100 | self.enc_cell_fw,
101 | self.enc_cell_bw,
102 | batch,
103 | sequence_length=sequence_lengths,
104 | time_major=False,
105 | swap_memory=True,
106 | dtype=tf.float32,
107 | scope='ENC_RNN')
108 |
109 | last_state_fw, last_state_bw = last_states
110 | last_h_fw = self.enc_cell_fw.get_output(last_state_fw)
111 | last_h_bw = self.enc_cell_bw.get_output(last_state_bw)
112 | last_h = tf.concat([last_h_fw, last_h_bw], 1)
113 | return last_h
114 |
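# The returned embedding concatenates the forward and backward final hidden
# states, so its width is 2 * enc_rnn_size.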
115 | def cnn_encoder(self, batch_input, reuse):
116 | if self.hps.is_train:
117 | is_train = True
118 | dropout_keep_prob = self.hps.drop_kp
119 | else:
120 | is_train = False
121 | dropout_keep_prob = 1.0
122 | tf_batch_input = tf_image_processing(batch_input, self.hps.basenet, self.hps.crop_size, self.hps.dist_aug, self.hps.hp_filter)
123 | self.tf_images = tf_batch_input
124 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
125 | if self.hps.basenet == 'sketchanet':
126 | feature = sketch_a_net_slim(tf_batch_input)
127 |
128 | elif self.hps.basenet == 'gen_cnn':
129 | # feature = generative_cnn_encoder(tf_batch_input, is_train, dropout_keep_prob, reuse=reuse)
130 | feature = generative_cnn_encoder(tf_batch_input, True, dropout_keep_prob, reuse=reuse)
131 |
132 | elif self.hps.basenet == 'alexnet':
133 | feature, end_points = tf_alexnet_single(tf_batch_input, dropout_keep_prob)
134 |
135 | elif self.hps.basenet == 'vgg':
136 | _, feature = build_single_vggnet(tf_batch_input, is_train, dropout_keep_prob)
137 |
138 | elif self.hps.basenet == 'resnet':
139 | print('Warning, resnet scope is not set')
140 | _, feature = build_single_resnet(tf_batch_input, is_train, name_scope='resnet_v1_50')
141 |
142 | elif self.hps.basenet == 'inceptionv1':
143 | # _, feature = build_single_inceptionv1(tf_batch_input, is_train, dropout_keep_prob)
144 | _, feature = build_single_inceptionv1(tf_batch_input, True, dropout_keep_prob)
145 | # _, feature = build_single_inceptionv1(tf_batch_input, False, dropout_keep_prob)
146 |
147 | elif self.hps.basenet == 'inceptionv3':
148 | # _, feature = build_single_inceptionv3(batch_input, is_train, dropout_keep_prob, reduce_dim=False)
149 | _, feature = build_single_inceptionv3(tf_batch_input, True, dropout_keep_prob, reduce_dim=False)
150 | # _, feature = build_single_inceptionv3(tf_batch_input, False, dropout_keep_prob, reduce_dim=False)
151 |
152 | else:
153 | raise Exception('Please choose a valid basenet')
154 | return feature
155 |
156 | def decoder(self, actual_input_x, initial_state, reuse):
157 |
158 | # decoder module of sketch-rnn is below
159 | with tf.variable_scope("RNN", reuse=reuse) as rnn_scope:
160 | output, last_state = tf.nn.dynamic_rnn(
161 | self.cell,
162 | actual_input_x,
163 | initial_state=initial_state,
164 | time_major=False,
165 | swap_memory=True,
166 | dtype=tf.float32,
167 | scope=rnn_scope)
168 | return output, last_state
169 |
170 | def cnn_decoder(self, z_input, reuse):
171 |
172 | if self.hps.is_train:
173 | is_train = True
174 | dropout_keep_prob = self.hps.drop_kp
175 | else:
176 | is_train = False
177 | dropout_keep_prob = 1.0
178 |
179 | output = generative_cnn_decoder(z_input, is_train, dropout_keep_prob, reuse)
180 |
181 | return output
182 |
183 | def get_mu_sig(self, image_embedding):
184 | enc_size = int(image_embedding.shape[-1])
185 | mu = rnn.super_linear(
186 | image_embedding,
187 | self.hps.z_size,
188 | input_size=enc_size,
189 | scope='ENC_RNN_mu',
190 | init_w='gaussian',
191 | weight_start=0.001)
192 | presig = rnn.super_linear(
193 | image_embedding,
194 | self.hps.z_size,
195 | input_size=enc_size,
196 | scope='ENC_RNN_sigma',
197 | init_w='gaussian',
198 | weight_start=0.001)
199 | return mu, presig
200 |
201 | def build_kl_for_vae(self, image_embedding, scope_name, with_state=True, reuse=False):
202 | with tf.variable_scope(scope_name, reuse=reuse):
203 | if with_state:
204 | return self.get_init_state(image_embedding)
205 | else:
206 | return self.get_kl_cost(image_embedding)
207 |
208 | def get_init_state(self, image_embedding):
209 | self.mean, self.presig = self.get_mu_sig(image_embedding)
210 | self.sigma = tf.exp(self.presig / 2.0) # sigma > 0. div 2.0 -> sqrt.
211 | eps = tf.random_normal(
212 | (self.hps.batch_size, self.hps.z_size), 0.0, 1.0, dtype=tf.float32)
213 | # batch_z = self.mean + tf.multiply(self.sigma, eps)
214 | if self.hps.is_train:
215 | batch_z = self.mean + tf.multiply(self.sigma, eps)
216 | else:
217 | batch_z = self.mean
218 | if self.hps.inter_z:
219 | batch_z = self.mean + tf.multiply(self.sigma, self.sample_gussian)
220 | # KL cost
221 | kl_cost = -0.5 * tf.reduce_mean(
222 | (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig)))
223 | kl_cost = tf.maximum(kl_cost, self.hps.kl_tolerance)
224 |
225 | # get initial state based on batch_z
226 | initial_state = tf.nn.tanh(
227 | rnn.super_linear(
228 | batch_z,
229 | self.cell.state_size,
230 | init_w='gaussian',
231 | weight_start=0.001,
232 | input_size=self.hps.z_size))
233 | pre_tile_y = tf.reshape(batch_z, [self.hps.batch_size, 1, self.hps.z_size])
234 | overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1])
235 | actual_input_x = tf.concat([self.input_x, overlay_x], 2)
236 |
237 | return initial_state, actual_input_x, batch_z, kl_cost
238 |
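# For reference: since sigma = exp(presig / 2), presig is log(sigma^2) and the
# kl_cost above is the standard diagonal-Gaussian form
#   KL(q || N(0, I)) = -0.5 * mean(1 + log(sigma^2) - mu^2 - sigma^2),
# floored at hps.kl_tolerance so a near-zero KL stops contributing gradient.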
239 | def get_kl_cost(self, image_embedding):
240 | self.mean, self.presig = self.get_mu_sig(image_embedding)
241 | self.sigma = tf.exp(self.presig / 2.0) # sigma > 0. div 2.0 -> sqrt.
242 | eps = tf.random_normal(
243 | (self.hps.batch_size, self.hps.z_size), 0.0, 1.0, dtype=tf.float32)
244 | batch_z = self.mean + tf.multiply(self.sigma, eps)
245 | # KL cost
246 | kl_cost = -0.5 * tf.reduce_mean(
247 | (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig)))
248 | kl_cost = tf.maximum(kl_cost, self.hps.kl_tolerance)
249 |
250 | return batch_z, kl_cost
251 |
252 | def config_model(self, hps):
253 | """Define model architecture."""
254 | if hps.is_train:
255 | self.global_step = tf.Variable(0, name='global_step', trainable=False)
256 | # self.global_step = tf.get_variable('global_step', trainable=False)
257 |
258 | if hps.rnn_model == 'lstm':
259 | cell_fn = rnn.LSTMCell
260 | elif hps.rnn_model == 'layer_norm':
261 | cell_fn = rnn.LayerNormLSTMCell
262 | elif hps.rnn_model == 'hyper':
263 | cell_fn = rnn.HyperLSTMCell
264 | else:
265 | assert False, 'please choose a respectable cell'
266 |
267 | self.hps.crop_size, self.hps.chn_size = get_input_size()
268 |
269 | use_recurrent_dropout = self.hps.use_recurrent_dropout
270 | rnn_input_dropout = self.hps.rnn_input_dropout
271 | rnn_output_dropout = self.hps.rnn_output_dropout
272 |
273 | if hps.rnn_model == 'hyper':
274 | cell = cell_fn(
275 | hps.dec_rnn_size,
276 | use_recurrent_dropout=use_recurrent_dropout,
277 | dropout_keep_prob=self.hps.recurrent_dropout_prob)
278 | else:
279 | cell = cell_fn(
280 | hps.dec_rnn_size,
281 | use_recurrent_dropout=use_recurrent_dropout,
282 | dropout_keep_prob=self.hps.recurrent_dropout_prob)
283 |
284 | # dropout:
285 | if rnn_input_dropout:
286 | cell = tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=self.hps.input_dropout_prob)
287 | if rnn_output_dropout:
288 | cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.hps.output_dropout_prob)
289 | self.cell = cell
290 |
291 | batch_size = self.hps.batch_size
292 | image_size = self.hps.image_size
293 |
294 | self.sequence_lengths = tf.placeholder(
295 | dtype=tf.int32, shape=[batch_size], name='seq_len')
296 | self.input_sketch = tf.placeholder(
297 | dtype=tf.float32,
298 | shape=[batch_size, self.hps.max_seq_len + 1, 5], name='input_sketch')
299 | self.target_sketch = tf.placeholder(
300 | dtype=tf.float32,
301 | shape=[batch_size, self.hps.max_seq_len + 1, 5], name='target_sketch')
302 |
303 | if self.hps.chn_size == 1:
304 | image_shape = [batch_size, image_size, image_size]
305 | else:
306 | image_shape = [batch_size, image_size, image_size, self.hps.chn_size]
307 |
308 | sketch_shape = [batch_size, image_size, image_size]
309 |
310 | if self.hps.vae_type == 's2s':
311 | self.input_photo = tf.placeholder(dtype=tf.float32, shape=sketch_shape, name='input_photo')
312 | else:
313 | self.input_photo = tf.placeholder(dtype=tf.float32, shape=image_shape, name='input_photo')
314 | if self.hps.vae_type in ['ps2s', 'sp2s']:
315 | self.input_sketch_photo = tf.placeholder(dtype=tf.float32, shape=sketch_shape, name='input_sketch_photo')
316 |
317 | self.input_label = tf.placeholder(dtype=tf.int32, shape=batch_size, name='label')
318 |
319 | if self.hps.inter_z:
320 | self.sample_gussian = tf.placeholder(dtype=tf.float32, shape=batch_size, name='sample_gussian')
321 |
322 | # The target/expected vectors of strokes
323 | self.output_x = self.input_sketch[:, 1:self.hps.max_seq_len + 1, :]
324 | # vectors of strokes to be fed to decoder (same as above, but lagged behind
325 | # one step to include initial dummy value of (0, 0, 1, 0, 0))
326 | self.input_x = self.input_sketch[:, :self.hps.max_seq_len, :]
327 |
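# Each stroke step is a 5-tuple: two real-valued offsets (x1, x2) plus a
# 3-way one-hot pen state (split later into eos/eoc/cont), with
# (0, 0, 1, 0, 0) serving as the initial dummy input.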
328 | def build_pix_encoder(self, input_image, reuse=False):
329 |
330 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
331 | image_embedding = self.cnn_encoder(input_image, reuse)
332 |
333 | return image_embedding
334 |
335 | def build_seq_encoder(self, input_strokes, reuse=False):
336 |
337 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
338 | strokes_embedding = self.rnn_encoder(input_strokes, self.sequence_lengths)
339 |
340 | return strokes_embedding
341 |
342 | # ######################
343 |
344 | def build_seq_decoder(self, feat_embedding, kl_name_scope, reuse = False):
345 |
346 | initial_state, actual_input_x, batch_z, kl_cost = self.build_kl_for_vae(feat_embedding, kl_name_scope, reuse=False)
347 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
348 | output, last_state = self.decoder(actual_input_x, initial_state, reuse)
349 |
350 | return output, initial_state, last_state, actual_input_x, batch_z, kl_cost
351 |
352 | def build_pix_decoder(self, feat_embedding, kl_name_scope, reuse = False):
353 |
354 | batch_z, kl_cost = self.build_kl_for_vae(feat_embedding, kl_name_scope, with_state=False, reuse=False)
355 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
356 | output = self.cnn_decoder(batch_z, reuse)
357 |
358 | return output, batch_z, kl_cost
359 |
360 | def build_pix2seq_embedding(self, input_image, encode_pix=True, reuse = False):
361 |
362 | if encode_pix:
363 | self.pix_embedding = self.build_pix_encoder(input_image)
364 |
365 | return self.build_seq_decoder(self.pix_embedding, 'p2s', reuse=reuse)
366 |
367 | def build_seq2pix_embedding(self, input_strokes, encode_seq=True, reuse = False):
368 |
369 | if encode_seq:
370 | self.seq_embedding = self.build_seq_encoder(input_strokes)
371 |
372 | return self.build_pix_decoder(self.seq_embedding, 's2p', reuse=reuse)
373 |
374 | def build_pix2pix_embedding(self, input_image, encode_pix=True, reuse = False):
375 |
376 | if encode_pix:
377 | self.pix_embedding = self.build_pix_encoder(input_image)
378 |
379 | return self.build_pix_decoder(self.pix_embedding, 'p2p', reuse=reuse)
380 |
381 | def build_seq2seq_embedding(self, input_strokes, encode_seq=True, reuse=False):
382 |
383 | if encode_seq:
384 | self.seq_embedding = self.build_seq_encoder(input_strokes)
385 |
386 | return self.build_seq_decoder(self.seq_embedding, 's2s', reuse=reuse)
387 |
388 | def build_seq_loss(self, output, initial_state, final_state, batch_z, kl_cost, vae_type, reuse=False):
389 | # code for output, pi miu
390 | pi, mu1, mu2, sigma1, sigma2, corr, pen_logits, pen, y1_data, y2_data, r_cost, r_score, gen_strokes = self.build_strokes_rcons(output, reuse=reuse)
391 | # self.pen: pen state probabilities (result of applying softmax to self.pen_logits)
392 |
393 | r_cost = self.hps.seq_lw * r_cost
394 |
395 | cost_dict = {'rcons': r_cost, 'kl': kl_cost}
396 |
397 | end_points = {'init_s': initial_state, 'fin_s': final_state, 'pi': pi, 'mu1': mu1, 'mu2': mu2, 'sigma1': sigma1,
398 | 'sigma2': sigma2, 'corr': corr, 'pen': pen, 'batch_z': batch_z}
399 |
400 | cost_dict_vae_type = {'%s_%s' % (vae_type, key): cost_dict[key] for key in cost_dict.keys()}
401 | end_points_vae_type = {'%s_%s' % (vae_type, key): end_points[key] for key in end_points.keys()}
402 |
403 | return gen_strokes, cost_dict_vae_type, end_points_vae_type
404 |
405 | def build_pix_loss(self, gen_photo, batch_z, kl_cost, vae_type, reuse=False):
406 | r_cost = self.build_photo_rcons(self.target_photo, gen_photo)
407 | # self.pen: pen state probabilities (result of applying softmax to self.pen_logits)
408 |
409 | r_cost = self.hps.pix_lw * r_cost
410 |
411 | cost_dict = {'rcons': r_cost, 'kl': kl_cost}
412 |
413 | gen_photo_rgb = tf.cast((gen_photo + 1) * 127.5, tf.int16)
414 |
415 | end_points = {'gen_photo': gen_photo, 'gen_photo_rgb': gen_photo_rgb, 'batch_z': batch_z}
416 |
417 | cost_dict_vae_type = {'%s_%s' % (vae_type, key): cost_dict[key] for key in cost_dict.keys()}
418 | end_points_vae_type = {'%s_%s' % (vae_type, key): end_points[key] for key in end_points.keys()}
419 |
420 | return gen_photo_rgb, cost_dict_vae_type, end_points_vae_type
421 |
422 | def build_strokes_rcons(self, output, reuse=False):
423 |
424 | # target data
425 | x1_data, x2_data = self.x1_data, self.x2_data
426 | eos_data, eoc_data, cont_data = self.eos_data, self.eoc_data, self.cont_data
427 |
428 | # TODO(deck): Better understand this comment.
429 | # Number of outputs is 3 (one logit per pen state) plus 6 per mixture
430 | # component: mean_x, stdev_x, mean_y, stdev_y, correlation_xy, and the
431 | # mixture weight/probability (Pi_k)
432 | n_out = (3 + self.hps.num_mixture * 6)
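# e.g. with num_mixture = 20 (a common sketch-rnn setting, assumed here),
# n_out = 3 + 20 * 6 = 123 outputs per timestep.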
433 |
434 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
435 |
436 | with tf.variable_scope('RNN'):
437 | output_w = tf.get_variable('output_w', [self.hps.dec_rnn_size, n_out])
438 | output_b = tf.get_variable('output_b', [n_out])
439 |
440 | output_reshape = tf.reshape(output, [-1, self.hps.dec_rnn_size])
441 |
442 | output_mdn = tf.nn.xw_plus_b(output_reshape, output_w, output_b)
443 |
444 | o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits = get_mixture_coef(output_mdn)
445 |
446 | o_id = tf.stack([tf.range(0, tf.shape(o_mu1)[0]), tf.cast(tf.argmax(o_pi, 1), tf.int32)], axis=1)
447 | y1_data = tf.gather_nd(o_mu1, o_id)
448 | y2_data = tf.gather_nd(o_mu2, o_id)
449 |
450 | y3_data = tf.cast(tf.greater(tf.argmax(o_pen_logits, 1), 0), tf.float32)
451 | gen_strokes = tf.stack([y1_data, y2_data, y3_data], 1)
452 | gen_strokes = tf.reshape(gen_strokes, [self.hps.batch_size, -1, 3])
453 | start_points_np = np.zeros((self.hps.batch_size, 1, 3))
454 | start_points_tf = tf.constant(start_points_np, dtype=tf.float32)
455 | gen_strokes = tf.concat([start_points_tf, gen_strokes], 1)
456 |
457 | pen_data = tf.concat([eos_data, eoc_data, cont_data], 1)
458 |
459 | rcons_loss = get_rcons_loss_pen_state(o_pen_logits, pen_data, self.hps.is_train)
460 |
461 | rcons_loss += get_rcons_loss_mdn(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, x1_data, x2_data, pen_data)
462 |
463 | r_cost = tf.reduce_mean(rcons_loss)
464 | r_score = -tf.reduce_sum(tf.reshape(rcons_loss, [self.hps.batch_size, -1]), axis=1)
465 | return o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen_logits, o_pen, y1_data, y2_data, r_cost, r_score, gen_strokes
466 |
467 | def build_photo_rcons(self, real_images, output_images):
468 |
469 | # target data
470 | image_shape = real_images.get_shape()
471 | output_images_reshaped = tf.reshape(output_images, image_shape)
472 | pixel_losses = tf.reduce_mean(tf.square(real_images - output_images_reshaped))
473 | return pixel_losses
474 |
475 | def build_l2_loss(self):
476 |
477 | self.rnn_l2 = tf.reduce_mean(tf.square(self.end_points['p2s_batch_z'] - self.end_points['s2s_batch_z']))
478 | self.cnn_l2 = tf.reduce_mean(tf.square(self.end_points['s2p_batch_z'] - self.end_points['p2p_batch_z']))
479 |
480 | self.l2_cost = self.rnn_l2 + self.cnn_l2
481 |
482 | def build_seq_discriminator(self, x, y, l, reuse):
483 |
484 | # set label for orig and gen
485 | label_r = tf.ones([self.hps.batch_size, 1], tf.int32)
486 | label_f = tf.zeros([self.hps.batch_size, 1], tf.int32)
487 |
488 | # build domain classifier
489 | cell_type, n_hidden, num_layers = self.hps.dis_model, self.hps.dis_num_hidden, self.hps.dis_num_layers
490 | in_dp, out_dp, batch_size = self.hps.dis_input_dropout, self.hps.dis_output_dropout, self.hps.batch_size
491 | pred_r, logits_r = rnn_discriminator(x, l, cell_type, n_hidden, num_layers, in_dp, out_dp, batch_size, reuse=reuse)
492 | pred_f, logits_f = rnn_discriminator(y, l, cell_type, n_hidden, num_layers, in_dp, out_dp, batch_size, reuse=True)
493 |
494 | # if self.hps.w_gan:
495 | # dis_loss, gen_loss = wgan_gp_loss(logits_r, logits_f, None, use_gradients=False)
496 | # dis_acc = tf.constant(-1.0)
497 | # else:
498 | dis_loss, gen_loss, dis_acc = get_adv_loss(logits_r, logits_f, label_r, label_f)
499 |
500 | dis_loss *= self.hps.rnn_dis_lw
501 | gen_loss *= self.hps.rnn_gen_lw
502 |
503 | return dis_loss, gen_loss, dis_acc
504 |
505 | def build_pix_discriminator(self, x, y, reuse):
506 | # set label for orig and gen
507 | label_r = tf.ones([self.hps.batch_size, 1], tf.int32)
508 | label_f = tf.zeros([self.hps.batch_size, 1], tf.int32)
509 |
510 | # build domain classifier
511 | batch_size = self.hps.batch_size
512 | pred_r, logits_r = cnn_discriminator(x, batch_size, reuse=reuse)
513 | pred_f, logits_f = cnn_discriminator(y, batch_size, reuse=True)
514 |
515 | if self.hps.gp_gan:
516 | alpha = tf.random_uniform(shape=[self.hps.batch_size, 1, 1, 1], minval=0., maxval=1.)
517 | differences = y - x
518 | interpolates = x + (alpha*differences)
519 | gradients = tf.gradients(cnn_discriminator(interpolates, batch_size, reuse=True)[1], [interpolates])[0]
520 | dis_loss, gen_loss, dis_acc = get_adv_gp_loss(logits_r, logits_f, label_r, label_f, gradients)
521 | else:
522 | dis_loss, gen_loss, dis_acc = get_adv_loss(logits_r, logits_f, label_r, label_f)
523 |
524 | dis_loss *= self.hps.cnn_dis_lw
525 | gen_loss *= self.hps.cnn_gen_lw
526 |
527 | return dis_loss, gen_loss, dis_acc
528 |
529 | def build_wgan_seq_discriminator(self, x, y, l, reuse):
530 |
531 | print("Build wgan seq discriminator")
532 |
533 | # build domain classifier
534 | logits_r = wgan_gp_rnn_discriminator(x, reuse=reuse)
535 | logits_f = wgan_gp_rnn_discriminator(y, reuse=True)
536 |
537 | alpha = tf.random_uniform(shape=[self.hps.batch_size, 1, 1], minval=0., maxval=1.)
538 | differences = y - x
539 | interpolates = x + (alpha*differences)
540 | gradients = tf.gradients(wgan_gp_rnn_discriminator(interpolates, reuse=True), [interpolates])[0]
541 | dis_loss, gen_loss = wgan_gp_loss(logits_f, logits_r, gradients)
542 | dis_acc = tf.constant(-1.0)
543 |
544 | dis_loss *= self.hps.rnn_dis_lw
545 | gen_loss *= self.hps.rnn_gen_lw
546 |
547 | return dis_loss, gen_loss, dis_acc
548 |
549 | def build_wgan_pix_discriminator(self, x, y, reuse):
550 |
551 | print("Build wgan pix discriminator")
552 |
553 | # build domain classifier
554 | logits_r = wgan_gp_cnn_discriminator(x, reuse=reuse)
555 | logits_f = wgan_gp_cnn_discriminator(y, reuse=True)
556 |
557 | alpha = tf.random_uniform(shape=[self.hps.batch_size, 1, 1, 1], minval=0., maxval=1.)
558 | differences = y - x
559 | interpolates = x + (alpha*differences)
560 | gradients = tf.gradients(wgan_gp_cnn_discriminator(interpolates, reuse=True), [interpolates])[0]
561 | dis_loss, gen_loss = wgan_gp_loss(logits_f, logits_r, gradients)
562 | dis_acc = tf.constant(-1.0)
563 |
564 | dis_loss *= self.hps.cnn_dis_lw
565 | gen_loss *= self.hps.cnn_gen_lw
566 |
567 | return dis_loss, gen_loss, dis_acc
568 |
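# Both WGAN discriminators above follow the WGAN-GP recipe: score real and
# fake samples, then penalize the critic's gradient at random interpolates
# x_hat = x + alpha * (y - x); the usual penalty term is
# (||grad D(x_hat)||_2 - 1)^2, assuming wgan_gp_loss (defined elsewhere)
# implements it that way.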
569 | def get_train_vars(self):
570 |
571 | self.t_vars = tf.trainable_variables()
572 | self.d_vars = [var for var in self.t_vars if 'DIS' in var.name]
573 | self.g_vars = [var for var in self.t_vars if 'DIS' not in var.name]
574 |
575 | def get_train_op(self):
576 |
577 | self.apply_decay()
578 |
579 | # get total loss
580 | self.get_total_loss()
581 |
582 | # get train vars
583 | self.get_train_vars()
584 |
585 | optimizer = tf.train.AdamOptimizer(self.lr)
586 | gvs = optimizer.compute_gradients(self.cost)
587 |
588 | capped_gvs = clip_gradients(gvs, self.hps.grad_clip)
589 | self.train_op = optimizer.apply_gradients(
590 | capped_gvs, global_step=self.global_step, name='train_step')
591 |
592 | def apply_decay(self):
593 | if self.hps.lr_decay:
594 | self.lr = tf.train.exponential_decay(self.hps.lr, self.global_step, self.hps.decay_step, self.hps.decay_rate, staircase=True)
595 | else:
596 | self.lr = self.hps.lr
597 |
598 | # self.kl_weight = tf.Variable(self.hps.kl_weight_start, trainable=False)
599 | if self.hps.kl_weight_decay:
600 | self.kl_weight = tf.train.exponential_decay(self.hps.kl_weight_start, self.global_step, self.hps.kl_decay_step, self.hps.kl_decay_rate, staircase=True)
601 | else:
602 | self.kl_weight = self.hps.kl_weight_start
603 |
604 | if self.hps.l2_weight_decay:
605 | self.l2_weight = tf.train.exponential_decay(self.hps.l2_weight_start, self.global_step, self.hps.l2_decay_step, self.hps.l2_decay_rate, staircase=True)
606 | else:
607 | self.l2_weight = self.hps.l2_weight_start
608 |
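# With staircase=True these schedules decay in discrete steps:
#   value = start * decay_rate ** floor(global_step / decay_step)
# e.g. (illustrative numbers) lr = 1e-3 with decay_rate = 0.5 and
# decay_step = 10000 halves the learning rate every 10k steps.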
609 | def get_total_loss(self):
610 | self.p2s_kl, self.s2p_kl = self.cost_dict['p2s_kl'], self.cost_dict['s2p_kl']
611 | self.p2p_kl, self.s2s_kl = self.cost_dict['p2p_kl'], self.cost_dict['s2s_kl']
612 | self.kl_cost = self.p2s_kl + self.s2p_kl + self.p2p_kl + self.s2s_kl
613 | self.cost = self.kl_cost * self.kl_weight
614 |
615 | # get reconstruction loss
616 | self.p2s_r, self.s2p_r = self.cost_dict['p2s_rcons'], self.cost_dict['s2p_rcons']
617 | self.p2p_r, self.s2s_r = self.cost_dict['p2p_rcons'], self.cost_dict['s2s_rcons']
618 | self.r_cost = self.p2s_r + self.s2p_r + self.p2p_r + self.s2s_r
619 | self.cost += self.r_cost
620 |
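# i.e. the total objective assembled here is
#   cost = kl_weight * (KL_p2s + KL_s2p + KL_p2p + KL_s2s)
#          + (R_p2s + R_s2p + R_p2p + R_s2s)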
621 | def get_target_strokes(self):
622 | target = tf.reshape(self.output_x, [-1, 5])
623 | # reshape target data so that it is compatible with prediction shape
624 | [self.x1_data, self.x2_data, self.eos_data, self.eoc_data, self.cont_data] = tf.split(target, 5, 1)
625 | start_points_np = np.zeros((self.hps.batch_size, 1, 3))
626 | start_points_tf = tf.constant(start_points_np, dtype=tf.float32)
627 | self.target_strokes = tf.concat([self.x1_data, self.x2_data, 1 - self.eos_data], 1)
628 | self.target_strokes = tf.reshape(self.target_strokes, [self.hps.batch_size, -1, 3])
629 | self.target_strokes = tf.concat([start_points_tf, self.target_strokes], 1)
630 |
631 | def get_target_photo(self):
632 | self.target_photo = \
633 | tf_image_processing(self.input_photo, self.hps.basenet, self.hps.crop_size, self.hps.dist_aug, self.hps.hp_filter)
634 |
635 | def build_model(self, hps):
636 | self.config_model(hps)
637 |
638 | # get target data
639 | self.get_target_strokes()
640 | self.get_target_photo()
641 |
642 | # build photo to stroke-level synthesis part
643 | self.gen_strokes, cost_dict_p2s, end_points_p2s = self.build_pix2seq_branch(self.input_photo)
644 |
645 | # build stroke-level to photo synthesis part
646 | self.gen_photo, cost_dict_s2p, end_points_s2p = self.build_seq2pix_branch(self.input_sketch)
647 |
648 | # build photo to photo reconstruction part
649 | self.recon_photo, cost_dict_p2p, end_points_p2p = self.build_pix2pix_branch(self.input_photo, encode_pix=False, reuse=True)
650 |
651 | # build sketch to sketch reconstruction part
652 | self.recon_sketch, cost_dict_s2s, end_points_s2s = self.build_seq2seq_branch(self.input_sketch, encode_seq=False, reuse=True)
653 |
654 | self.cost_dict = dict(cost_dict_p2s.items() + cost_dict_s2p.items() + cost_dict_p2p.items() + cost_dict_s2s.items())
655 | self.end_points = dict(end_points_p2s.items() + end_points_s2p.items() + end_points_p2p.items() + end_points_s2s.items())
656 |
657 | self.initial_state, self.final_state = self.end_points['p2s_init_s'], self.end_points['p2s_fin_s']
658 | self.pi, self.corr = self.end_points['p2s_pi'], self.end_points['p2s_corr']
659 | self.mu1, self.mu2 = self.end_points['p2s_mu1'], self.end_points['p2s_mu2']
660 | self.sigma1, self.sigma2 = self.end_points['p2s_sigma1'], self.end_points['p2s_sigma2']
661 | self.pen = self.end_points['p2s_pen']
662 | self.batch_z = self.end_points['p2s_batch_z']
663 |
664 | self.recon_initial_state, self.recon_final_state = self.end_points['s2s_init_s'], self.end_points['s2s_fin_s']
665 | self.recon_pi, self.recon_corr = self.end_points['s2s_pi'], self.end_points['s2s_corr']
666 | self.recon_mu1, self.recon_mu2 = self.end_points['s2s_mu1'], self.end_points['s2s_mu2']
667 | self.recon_sigma1, self.recon_sigma2 = self.end_points['s2s_sigma1'], self.end_points['s2s_sigma2']
668 | self.recon_pen = self.end_points['s2s_pen']
669 | self.recon_batch_z = self.end_points['s2s_batch_z']
670 |
671 | if self.hps.is_train:
672 | # self.get_train_op_with_bn() # doesn't work
673 | self.get_train_op()
674 |
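# Branch naming used throughout: p2s = photo -> sequence (sketch synthesis),
# s2p = sequence -> photo, p2p = photo -> photo reconstruction, and
# s2s = sequence -> sequence reconstruction. Note that merging the per-branch
# dicts via dict(a.items() + b.items() + ...) is Python 2 syntax; a Python 3
# port would use e.g. merged = {}; merged.update(a); merged.update(b); ...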
675 | def build_pix2seq_branch(self, input_photo, encode_pix=True, reuse=False):
676 | # pixel to sequence
677 | output, initial_state, final_state, actual_input_x, batch_z, kl_cost = \
678 | self.build_pix2seq_embedding(input_photo, encode_pix=encode_pix, reuse=reuse)
679 |
680 | return self.build_seq_loss(output, initial_state, final_state, batch_z, kl_cost, 'p2s', reuse=reuse)
681 |
682 | def build_seq2pix_branch(self, input_strokes, encode_seq=True, reuse=False):
683 | # sequence to pixel
684 | gen_photo, batch_z, kl_cost = self.build_seq2pix_embedding(input_strokes, encode_seq=encode_seq, reuse=reuse)
685 |
686 | return self.build_pix_loss(gen_photo, batch_z, kl_cost, 's2p', reuse=reuse)
687 |
688 | def build_pix2pix_branch(self, input_photo, encode_pix=False, reuse=False):
689 | # pixel to pixel
690 | gen_photo, batch_z, kl_cost = self.build_pix2pix_embedding(input_photo, encode_pix=encode_pix, reuse=reuse)
691 |
692 | return self.build_pix_loss(gen_photo, batch_z, kl_cost, 'p2p', reuse=reuse)
693 |
694 | def build_seq2seq_branch(self, input_strokes, encode_seq=False, reuse=False):
695 | output, initial_state, final_state, actual_input_x, batch_z, kl_cost = \
696 | self.build_seq2seq_embedding(input_strokes, encode_seq=encode_seq, reuse=reuse)
697 |
698 | return self.build_seq_loss(output, initial_state, final_state, batch_z, kl_cost, 's2s', reuse=reuse)
699 |
700 |
701 | def get_pi_idx(pdf):
702 | """Samples from a pdf."""
703 | return np.argmax(pdf)
704 |
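# get_pi_idx() above is greedy (argmax), so the `temperature` argument of
# sample() below currently has no effect on decoding. A hypothetical
# stochastic variant in the style of Magenta's sketch-rnn, shown here only
# for illustration (not used anywhere in this repo):


def get_pi_idx_sampled(pdf, temperature=1.0):
    """Draws a component index from pdf, sharpened/flattened by temperature."""
    logits = np.log(pdf + 1e-8) / temperature
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()
    return np.random.choice(len(probs), p=probs)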
705 |
706 | def sample(sess, model, input_image, sketch=None, seq_len=250, temperature=0.5, with_sketch=False, rnn_enc_seq_len=None, cond_sketch=False, inter_z=False, inter_z_sample=0):
707 | """Samples a sequence from a pre-trained model."""
708 |
709 | prev_x = np.zeros((1, 1, 5), dtype=np.float32)
710 | prev_x[0, 0, 2] = 1 # initially, we want to see the beginning of a new stroke
711 | # print("enter the function of sample")
712 |
713 | if cond_sketch:
714 | if int(model.input_photo.get_shape()[-1]) == 3:
715 | input_image = input_image[:,:,:,np.newaxis]
716 | input_image = np.concatenate([input_image, input_image, input_image], -1)
717 |
718 | if rnn_enc_seq_len is None:
719 | if inter_z:
720 | prev_state = sess.run(model.initial_state, feed_dict={model.input_photo: input_image, model.sample_gussian: inter_z_sample})
721 | else:
722 | prev_state = sess.run(model.initial_state, feed_dict={model.input_photo: input_image})
723 | # image_embedding = sess.run(model.image_embedding, feed_dict={model.input_photo: input_image})
724 | # batch_z = sess.run(model.batch_z, feed_dict={model.input_photo: input_image})
725 | else:
726 | if inter_z:
727 | prev_state = sess.run(model.initial_state, feed_dict={model.input_photo: input_image, model.sequence_lengths: rnn_enc_seq_len, model.sample_gussian: inter_z_sample})
728 | else:
729 | prev_state = sess.run(model.initial_state, feed_dict={model.input_photo: input_image, model.sequence_lengths: rnn_enc_seq_len})
730 | # image_embedding = sess.run(model.image_embedding, feed_dict={model.input_photo: input_image, model.sequence_lengths: rnn_enc_seq_len})
731 | # batch_z = sess.run(model.batch_z, feed_dict={model.input_photo: input_image, model.sequence_lengths: rnn_enc_seq_len})
732 |
733 | strokes = np.zeros((seq_len, 5), dtype=np.float32)
734 | mixture_params = []
735 |
736 | for i in range(seq_len):
737 | if inter_z:
738 | feed = {
739 | model.input_x: prev_x,
740 | model.sequence_lengths: [1],
741 | model.initial_state: prev_state,
742 | model.input_photo: input_image,
743 | model.sample_gussian: inter_z_sample
744 | }
745 | else:
746 | feed = {
747 | model.input_x: prev_x,
748 | model.sequence_lengths: [1],
749 | model.initial_state: prev_state,
750 | model.input_photo: input_image
751 | }
752 |
753 | params = sess.run([
754 | model.pi, model.mu1, model.mu2, model.sigma1, model.sigma2, model.corr,
755 | model.pen, model.final_state
756 | ], feed)
757 |
758 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, next_state] = params
759 |
760 | idx = get_pi_idx(o_pi[0])
761 |
762 | idx_eos = get_pi_idx(o_pen[0])
763 | eos = [0, 0, 0]
764 | eos[idx_eos] = 1
765 |
766 | next_x1, next_x2 = o_mu1[0][idx], o_mu2[0][idx]
767 |
768 | strokes[i, :] = [next_x1, next_x2, eos[0], eos[1], eos[2]]
769 |
770 | params = [o_pi[0], o_mu1[0], o_mu2[0], o_sigma1[0], o_sigma2[0], o_corr[0], o_pen[0]]
771 |
772 | mixture_params.append(params)
773 |
774 | prev_x = np.zeros((1, 1, 5), dtype=np.float32)
775 | if with_sketch:
776 | prev_x[0][0] = sketch[0][i+1]
777 | else:
778 | prev_x[0][0] = np.array(
779 | [next_x1, next_x2, eos[0], eos[1], eos[2]], dtype=np.float32)
780 | prev_state = next_state
781 |
782 | return strokes, mixture_params
783 |
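# Typical use (the session/model setup names are assumptions):
#
#   strokes, mixture_params = sample(sess, sample_model, photo_batch,
#                                    seq_len=250)
#
# `strokes` is a (seq_len, 5) array in stroke-5 format; the loop always runs
# for the full seq_len, so callers should truncate at the first end-of-sketch
# pen state (strokes[i, 4] == 1) if a shorter sketch is wanted.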
784 |
785 | def sample_recons(sess, model, gen_model, input_sketch, sketch=None, seq_len=250, temperature=0.5, with_sketch=False, cond_sketch=False, inter_z=False, inter_z_sample=0):
786 | """Samples a sequence from a pre-trained model."""
787 |
788 | prev_x = np.zeros((1, 1, 5), dtype=np.float32)
789 | prev_x[0, 0, 2] = 1 # initially, we want to see the beginning of a new stroke
790 | # print("enter the function of sample")
791 |
792 | if inter_z:
793 | feed_dict = {
794 | gen_model.input_sketch: input_sketch,
795 | gen_model.sequence_lengths: [seq_len],
796 | gen_model.sample_gussian: inter_z_sample
797 | }
798 | prev_state, batch_z = sess.run([gen_model.recon_initial_state, gen_model.recon_batch_z], feed_dict=feed_dict)
799 | else:
800 | feed_dict = {
801 | gen_model.input_sketch: input_sketch,
802 | gen_model.sequence_lengths: [seq_len]
803 | }
804 | prev_state, batch_z = sess.run([gen_model.recon_initial_state, gen_model.recon_batch_z], feed_dict=feed_dict)
805 |
806 | strokes = np.zeros((seq_len, 5), dtype=np.float32)
807 | mixture_params = []
808 |
809 | for i in range(seq_len):
810 | if not model.hps.concat_z:
811 | feed = {
812 | model.input_x: prev_x,
813 | model.sequence_lengths: [1],
814 | model.recon_initial_state: prev_state
815 | }
816 | elif inter_z:
817 | feed = {
818 | model.input_x: prev_x,
819 | model.sequence_lengths: [1],
820 | model.recon_initial_state: prev_state,
821 | model.sample_gussian: inter_z_sample,
822 | model.recon_batch_z: batch_z
823 | }
824 | else:
825 | feed = {
826 | model.input_x: prev_x,
827 | model.sequence_lengths: [1],
828 | model.recon_initial_state: prev_state,
829 | model.recon_batch_z: batch_z
830 | }
831 |
832 | params = sess.run([
833 | model.recon_pi, model.recon_mu1, model.recon_mu2, model.recon_sigma1, model.recon_sigma2, model.recon_corr,
834 | model.recon_pen, model.recon_final_state
835 | ], feed)
836 |
837 | [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, next_state] = params
838 |
839 | idx = get_pi_idx(o_pi[0])
840 |
841 | idx_eos = get_pi_idx(o_pen[0])
842 | eos = [0, 0, 0]
843 | eos[idx_eos] = 1
844 |
845 | next_x1, next_x2 = o_mu1[0][idx], o_mu2[0][idx]
846 |
847 | strokes[i, :] = [next_x1, next_x2, eos[0], eos[1], eos[2]]
848 |
849 | params = [o_pi[0], o_mu1[0], o_mu2[0], o_sigma1[0], o_sigma2[0], o_corr[0], o_pen[0]]
850 |
851 | mixture_params.append(params)
852 |
853 | prev_x = np.zeros((1, 1, 5), dtype=np.float32)
854 | if with_sketch:
855 | prev_x[0][0] = sketch[0][i+1]
856 | else:
857 | prev_x[0][0] = np.array(
858 | [next_x1, next_x2, eos[0], eos[1], eos[2]], dtype=np.float32)
859 | prev_state = next_state
860 |
861 | return strokes, mixture_params
862 |
863 |
864 | def get_init_fn(pretrain_model, checkpoint_exclude_scopes):
865 | """Returns a function run by the chief worker to warm-start the training."""
866 | print("load pretrained model from %s" % pretrain_model)
867 | exclusions = [scope.strip() for scope in checkpoint_exclude_scopes]
868 |
869 | variables_to_restore = []
870 | # for var in slim.get_model_variables():
871 | for var in tf.trainable_variables():
872 | excluded = False
873 | for exclusion in exclusions:
874 | if var.op.name.startswith(exclusion):
875 | excluded = True
876 | break
877 | if not excluded:
878 | print(var.name)
879 | variables_to_restore.append(var)
880 |
881 | return slim.assign_from_checkpoint_fn(pretrain_model, variables_to_restore)
882 |
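# Typical use with tf.contrib.slim (checkpoint path and scope names below are
# assumptions for illustration):
#
#   init_fn = get_init_fn('pretrained/basenet.ckpt',
#                         checkpoint_exclude_scopes=['rnn_decoder', 'fc'])
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       init_fn(sess)  # overwrites matching variables with checkpoint values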
883 |
884 | def get_input_size():
885 | if FLAGS.basenet == 'alexnet':
886 | crop_size = 227
887 | channel_size = 3
888 | elif FLAGS.basenet in ['resnet', 'inceptionv3']:
889 | crop_size = 299
890 | channel_size = 3
891 | elif FLAGS.basenet in ['sketchynet', 'inceptionv1', 'vgg', 'mobilenet', 'gen_cnn']: # 'resnet' already matched the 299 branch above
892 | crop_size = 224
893 | channel_size = 3
894 | else:
895 | crop_size = 225
896 | channel_size = 1
897 | return crop_size, channel_size
898 |
--------------------------------------------------------------------------------