├── .babelrc
├── .gitignore
├── LICENSE
├── README.md
├── bin
└── convert.js
├── examples
├── Log In Page.svg
└── login_page.sketch
├── images
├── export_instructions.png
├── export_instructions_2.png
├── sketch_conversion.png
└── sketch_to_react_native.png
├── package-lock.json
├── package.json
├── scripts
├── __init__.py
├── count_ops.py
├── evaluate.py
├── graph_pb2tb.py
├── label_image.py
├── quantize_graph.py
└── show_image.py
├── src
├── attributes.js
├── components.js
├── flex.js
├── index.js
├── input.js
├── lib
│ ├── files.js
│ └── utils.js
├── neural_net.js
├── output.js
├── process.js
└── screenshot.js
└── tf_files
├── retrained_graph.pb
└── retrained_labels.txt
/.babelrc:
--------------------------------------------------------------------------------
1 | {
2 | "presets": [
3 | [
4 | "env",
5 | {
6 | "node": "6.10.0"
7 | }
8 | ]
9 | ],
10 | "plugins": [
11 | "transform-regenerator",
12 | "transform-async-to-generator",
13 | [
14 | "transform-object-rest-spread",
15 | {
16 | "useBuiltIns": true
17 | }
18 | ],
19 | "transform-es2015-destructuring"
20 | ]
21 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | node_modules
2 | *.svg
3 | !examples/*.svg
4 | output/*
5 | temp/*
6 | .DS_Store
7 | scripts/__pycache__/*
8 | build
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 NanoHop
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Convert Sketch files to React Native components
2 |
3 | * Sketch App: https://www.sketchapp.com/
4 | * React Native: https://facebook.github.io/react-native/
5 |
6 | Do you have designs in Sketch, and need to turn those into a mobile app? This will take those designs, and automatically create React Native components.
7 |
8 | 
9 |
10 | ## Want to try it without installing everything?
11 |
12 | Send me your Sketch file and instructions, and I'll run it and email the output back to you: chris@nanohop.com
13 |
14 | ## Want to collaborate?
15 |
16 | [Join us on Slack!](https://join.slack.com/t/design-to-code/shared_invite/enQtMjU5MzQ2OTAzNDEzLWE1MjUzMWVhNzVlNTFlZTFmYmRjODNkNDZmMDI1M2NhMTcwZjgwM2Q5M2Q3OTk5YmNhNTI5MDRmZDk5NmY1MWY)
17 |
18 | ## Want this as a service?
19 |
20 | We offer that! We also offer a **human in the loop**, where we'll clean up the output before it goes back to you. Send me an email to learn more: chris@nanohop.com
21 |
22 | *****
23 |
24 | # Getting started
25 |
26 | ### Prerequisites:
27 |
28 | * Node 8.5.0+ https://nodejs.org/en/
29 | * Python 3.6.1+ https://www.python.org/downloads/
30 | * Install TensorFlow https://www.tensorflow.org/install/
31 |
32 | ### Steps to run:
33 |
34 | ```bash
35 | > git clone https://github.com/nanohop/sketch-to-react-native.git
36 | > cd sketch-to-react-native
37 | > npm install && npm link
38 | > sketch-to-react-native ~/Desktop/myfile.svg
39 | ```
40 |
41 | ### Extract the component from Sketch as an SVG:
42 |
43 | 
44 |
45 | 
46 |
47 | ### Use that SVG file as the argument for convert.js
48 |
49 | ```bash
50 | > sketch-to-react-native input-file.svg
51 | ```
52 |
53 | It will run and save the output to the ./output folder. Make sure to grab both the .js file, and the associated _images_ folder! Drop that into your React Native application, and see the magic!
54 |
55 | ## What if it doesn't work?
56 |
57 | Please let me know! This is early software, and I'm trying to solve as many edge cases as I can find. Please file an issue or send me an email.
58 |
59 | # Conversion process
60 |
61 | Sketch exports a fairly clean SVG, which makes the process easier, but there is still a lot of processing involved. Here's the basic steps:
62 |
63 | 1. Prep the SVG to make processing easier
64 | 2. Use a deep neural net to filter out unwanted elements
65 | 3. Get component bounding boxes using headless Chrome
66 | 4. Figure out all the child/parent relationships
67 | 5. Convert from absolute (pixel) positioning to Flex Box
68 | 6. Extract images from SVG paths and polygons
69 | 7. Generate styles for every component
70 | 8. Generate components
71 | 9. Export to an output file
72 |
73 |
74 | ### A note about the most difficult part
75 |
76 | Sketch (and the SVG export) is a pixel based (absolute positioned) system - but that's no good for React Native. One of the primary difficulties was figuring out what the proper parent / sibling relationships were between all the components, and then converting from absolute positioning to flexBox. There is still some work to clean this part up, but in general it's working fairly well for most inputs.
77 |
78 | This also means that components in Sketch can _partially_ overlap - which doesn't do well during the conversion process. I'm working on a fix for that; but until then - you'll get best results if there are no partially overlapping components in your Sketch file. (Fully overlapping is fine - those will get properly converted to a parent => child relationship)
79 |
80 | # FAQ
81 |
82 | ## Didn't Airbnb already release something that does this?
83 |
84 | Nope! You're probably thinking of [react-sketchapp](https://github.com/airbnb/react-sketchapp) that takes react components, and generates Sketch files. This goes the opposite direction: it takes Sketch files, and creates React Native components.
85 |
86 |
87 | ## Why React Native (mobile)? Why not React (web)?
88 |
89 | I started with mobile because the app designs for mobile are generally more straightforward, with less variation - so it was easier to make the first version. I'm planning to do React for web as well though! Send me an email if you'd like updates when it's available: chris@nanohop.com
90 |
91 |
92 | ## Is the generated code any good?
93 |
94 | I'm a mobile developer myself - and we're rightfully fearful of any generated code. It is really high on my priority list to keep the output as clean and readable as possible. It's not perfect all the time, but it is one of my top priorities.
95 |
96 |
97 | ## How much time does this save?
98 |
99 | I've found that screens that would normally take me about an hour to create, take as little as 10 minutes - so that's as much as 80% time savings! The output does have to be cleaned up a little bit generally, but I find it usually provides a good starting point.
100 |
101 |
102 | ## Why use headless Chrome?
103 |
104 | It seems like overkill, but headless Chrome has a great SVG rendering engine, and it was the easiest way to get bounding boxes to work, and to export the SVG assets as pngs. That will probably change in the future.
105 |
106 |
107 | ## Is there a way to try this out without installing everything?
108 |
109 | I plan to get a hosted version up at some point, but until then you can email me your Sketch file and instructions, and I'll run it and email the component back to you: chris@nanohop.com
110 |
111 |
112 | ## What can't it do yet?
113 |
114 | This is a work in progress, so there are a few things it doesn't do well: overlapping components, reusing component styles, reusing common components, collapsing unnecessary wrapping Views, (and more). I have a long roadmap of features to add.
115 |
116 |
117 | ## Is there a Sketch plugin to do this automatically?
118 |
119 | Not yet, but it's on the roadmap!
120 |
121 |
122 | ## How can I help?
123 |
124 | If you'd like to help, I'd love to have you involved! Feel free to file issues, or send me an email with any Sketch file that doesn't work quite right, and I'll also review and merge pull requests as well.
125 |
126 |
127 |
128 | # Example
129 |
130 |
131 | 
132 |
133 |
134 | You can see that it's not perfect - but it provides a really good starting point, and it's getting better all the time!
135 |
136 |
137 | ## Here's the generated code:
138 |
139 |
140 |
141 | ```javascript
142 |
143 | import React, { Component } from 'react';
144 |
145 | import {
146 | StyleSheet,
147 | Text,
148 | View,
149 | TouchableOpacity,
150 | TextInput,
151 | ScrollView,
152 | Image
153 | } from 'react-native';
154 |
155 | import BackArrow from './Log_In_Page_images/Back-Arrow.png'
156 | import Logo from './Log_In_Page_images/Logo.png'
157 |
158 | export default class Main extends Component {
159 |
160 | render() {
161 | return (
162 |
166 |
167 |
168 |
169 |
170 | Email Address
171 |
172 |
173 | Password
174 |
175 |
176 | Log In
177 |
178 |
179 |
180 | )
181 | }
182 |
183 | }
184 |
185 | const styles = StyleSheet.create({
186 | Base: {
187 | height: 680,
188 | backgroundColor: '#5CC5F8',
189 | borderRadius: 6,
190 | paddingTop: 20,
191 | paddingBottom: 122
192 | },
193 | BackArrow: {
194 | alignSelf: 'flex-start',
195 | marginLeft: 18
196 | },
197 | Logo: {
198 | alignSelf: 'center',
199 | marginTop: 25
200 | },
201 | EmailInput: {
202 | height: 64,
203 | backgroundColor: '#FAFAFA',
204 | borderRadius: 6,
205 | alignSelf: 'center',
206 | marginTop: 49,
207 | width: 292,
208 | alignItems: 'flex-start',
209 | marginLeft: 30,
210 | justifyContent: 'center'
211 | },
212 | EmailAddress: {
213 | backgroundColor: 'transparent',
214 | fontSize: 18,
215 | fontWeight: '300',
216 | color: '#444444',
217 | textAlign: 'left',
218 | marginLeft: 30
219 | },
220 | PasswordInput: {
221 | height: 64,
222 | backgroundColor: '#FAFAFA',
223 | borderRadius: 6,
224 | alignSelf: 'center',
225 | marginTop: 14,
226 | width: 292,
227 | alignItems: 'flex-start',
228 | marginLeft: 30,
229 | justifyContent: 'center'
230 | },
231 | Password: {
232 | backgroundColor: 'transparent',
233 | fontSize: 18,
234 | fontWeight: '300',
235 | color: '#444444',
236 | textAlign: 'left',
237 | marginLeft: 30
238 | },
239 | LoginButton: {
240 | height: 64,
241 | backgroundColor: '#332AC6',
242 | borderRadius: 6,
243 | alignSelf: 'center',
244 | marginTop: 121,
245 | width: 292,
246 | alignItems: 'center',
247 | justifyContent: 'center'
248 | },
249 | LogIn: {
250 | backgroundColor: 'transparent',
251 | fontSize: 24,
252 | fontWeight: '300',
253 | color: '#FFFFFF',
254 | textAlign: 'center'
255 | }
256 | })
257 |
258 |
259 | ```
260 |
261 |
262 |
263 |
--------------------------------------------------------------------------------
/bin/convert.js:
--------------------------------------------------------------------------------
#!/usr/bin/env node
// CLI entry point installed as `sketch-to-react-native` (see "bin" in
// package.json). Delegates to the Babel-compiled output in ../build,
// which is produced by `npm run build` (the "prepare" script).
require('../build');
--------------------------------------------------------------------------------
/examples/Log In Page.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/examples/login_page.sketch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanohop/sketch-to-react-native/6075a766749f6ed419ff0f34abdb81f8cc4cf0f3/examples/login_page.sketch
--------------------------------------------------------------------------------
/images/export_instructions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanohop/sketch-to-react-native/6075a766749f6ed419ff0f34abdb81f8cc4cf0f3/images/export_instructions.png
--------------------------------------------------------------------------------
/images/export_instructions_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanohop/sketch-to-react-native/6075a766749f6ed419ff0f34abdb81f8cc4cf0f3/images/export_instructions_2.png
--------------------------------------------------------------------------------
/images/sketch_conversion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanohop/sketch-to-react-native/6075a766749f6ed419ff0f34abdb81f8cc4cf0f3/images/sketch_conversion.png
--------------------------------------------------------------------------------
/images/sketch_to_react_native.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanohop/sketch-to-react-native/6075a766749f6ed419ff0f34abdb81f8cc4cf0f3/images/sketch_to_react_native.png
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "svg_to_react",
3 | "version": "1.0.0",
4 | "description": "",
5 | "main": "src/index.js",
6 | "dependencies": {
7 | "puppeteer": "^0.10.2",
8 | "svgson": "^2.1.0"
9 | },
10 | "devDependencies": {
11 | "babel-cli": "^6.26.0",
12 | "babel-core": "^6.26.0",
13 | "babel-plugin-transform-async-to-generator": "^6.24.1",
14 | "babel-plugin-transform-es2015-destructuring": "^6.23.0",
15 | "babel-plugin-transform-object-rest-spread": "^6.26.0",
16 | "babel-plugin-transform-regenerator": "^6.26.0",
17 | "babel-preset-env": "^1.6.1",
18 | "regenerator-runtime": "^0.11.0"
19 | },
20 | "scripts": {
21 | "convert": "babel-node ./src",
22 | "build": "babel ./src -d ./build",
23 | "prepare": "npm run build",
24 | "test": "echo \"Error: no test specified\" && exit 1"
25 | },
26 | "bin": {
27 | "sketch-to-react-native": "./bin/convert.js"
28 | },
29 | "engines": {
30 | "node": ">=6.10.0"
31 | },
32 | "author": "",
33 | "license": "ISC"
34 | }
35 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2017 Google Inc.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
--------------------------------------------------------------------------------
/scripts/count_ops.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2017 Google Inc.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import os
21 |
22 | import sys
23 | import tensorflow as tf
24 |
def load_graph(file_name):
  """Deserialize the binary GraphDef stored at `file_name` into a new tf.Graph."""
  with open(file_name, 'rb') as graph_file:
    serialized = graph_file.read()
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(serialized)
  # Import with an empty name scope so tensor names keep their original form.
  with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
  return graph
33 |
def count_ops(file_name, op_name = None):
  """Count operations in the graph at `file_name`.

  With `op_name` of None, returns the total op count; otherwise returns
  how many ops have exactly that name.
  """
  operations = load_graph(file_name).get_operations()
  if op_name is None:
    return len(operations)
  return sum(1 for op in operations if op.name == op_name)
42 |
43 | if __name__ == "__main__":
44 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
45 | print(count_ops(*sys.argv[1:]))
46 |
47 |
--------------------------------------------------------------------------------
/scripts/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2017 Google Inc.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import os
21 |
22 | import sys
23 | import argparse
24 |
25 | import numpy as np
26 | import PIL.Image as Image
27 | import tensorflow as tf
28 |
29 | import scripts.retrain as retrain
30 | from scripts.count_ops import load_graph
31 |
def evaluate_graph(graph_file_name):
  """Run the retrained graph over the 'testing' image split and report quality.

  Args:
    graph_file_name: path to a frozen GraphDef (.pb) produced by retraining.

  Returns:
    Tuple of (mean accuracy, mean softmax cross-entropy) over all test images.
  """
  with load_graph(graph_file_name).as_default() as graph:
    # NOTE(review): the placeholder hard-codes 5 classes; this only matches
    # graphs retrained on exactly five labels -- confirm against retrain.py.
    ground_truth_input = tf.placeholder(
        tf.float32, [None, 5], name='GroundTruthInput')

    image_buffer_input = graph.get_tensor_by_name('input:0')
    final_tensor = graph.get_tensor_by_name('final_result:0')
    accuracy, _ = retrain.add_evaluation_step(final_tensor, ground_truth_input)

    # Raw logits feeding the final softmax, used for the cross-entropy metric.
    logits = graph.get_tensor_by_name("final_training_ops/Wx_plus_b/add:0")
    xent = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
      labels = ground_truth_input,
      logits = logits))

  # Dataset location and split percentages (mirrors the retraining setup).
  # NOTE(review): validation_batch_size is never read below.
  image_dir = 'tf_files/flower_photos'
  testing_percentage = 10
  validation_percentage = 10
  validation_batch_size = 100
  category='testing'

  image_lists = retrain.create_image_lists(
      image_dir, testing_percentage,
      validation_percentage)
  class_count = len(image_lists.keys())

  # Build a parallel list of one-hot ground-truth rows and image file paths.
  ground_truths = []
  file_names = []

  for label_index, label_name in enumerate(image_lists.keys()):
    for image_index, image_name in enumerate(image_lists[label_name][category]):
      image_name = retrain.get_image_path(
          image_lists, label_name, image_index, image_dir, category)
      ground_truth = np.zeros([1, class_count], dtype=np.float32)
      ground_truth[0, label_index] = 1.0
      ground_truths.append(ground_truth)
      file_names.append(image_name)

  accuracies = []
  xents = []
  with tf.Session(graph=graph) as sess:
    for filename, ground_truth in zip(file_names, ground_truths):
      # Resize to 224x224 and scale pixel values to roughly [-1, 1) --
      # presumably matching the retrained network's expected input; verify.
      image = Image.open(filename).resize((224,224),Image.ANTIALIAS)
      image = np.array(image, dtype=np.float32)[None,...]
      image = (image-128)/128.0

      feed_dict={
        image_buffer_input: image,
        ground_truth_input: ground_truth}

      eval_accuracy, eval_xent = sess.run([accuracy, xent], feed_dict)

      accuracies.append(eval_accuracy)
      xents.append(eval_xent)


  return np.mean(accuracies), np.mean(xents)
88 |
89 | if __name__ == "__main__":
90 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
91 | accuracy,xent = evaluate_graph(*sys.argv[1:])
92 | print('Accuracy: %g' % accuracy)
93 | print('Cross Entropy: %g' % xent)
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/scripts/graph_pb2tb.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2017 Google Inc.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import os
18 | import sys
19 |
20 | import tensorflow as tf
21 |
def load_graph(graph_pb_path):
  """Read a serialized GraphDef from disk and import it into a fresh tf.Graph."""
  with open(graph_pb_path, 'rb') as pb_file:
    raw_bytes = pb_file.read()
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(raw_bytes)
  # name='' keeps the imported node names unprefixed.
  with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
  return graph
30 |
31 |
def graph_to_tensorboard(graph, out_dir):
  """Write `graph` as a TensorBoard event file under `out_dir`.

  Args:
    graph: a tf.Graph to visualize.
    out_dir: directory where the event file is written.
  """
  with tf.Session():
    train_writer = tf.summary.FileWriter(out_dir)
    try:
      train_writer.add_graph(graph)
    finally:
      # Fix: close the writer so pending events are flushed to disk
      # (the original never closed it, risking an empty/partial event file).
      train_writer.close()
36 |
37 |
def main(out_dir, graph_pb_path):
  """Load the GraphDef at `graph_pb_path` and export it for TensorBoard into `out_dir`."""
  graph_to_tensorboard(load_graph(graph_pb_path), out_dir)
41 |
42 | if __name__ == "__main__":
43 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
44 | main(*sys.argv[1:])
45 |
--------------------------------------------------------------------------------
/scripts/label_image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import argparse
20 | import sys
21 |
22 | import numpy as np
23 | import tensorflow as tf
24 | import os
25 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
26 |
def load_graph(model_file):
  """Deserialize `model_file` (a binary GraphDef) and import it into a new tf.Graph."""
  graph_def = tf.GraphDef()
  with open(model_file, "rb") as model_handle:
    graph_def.ParseFromString(model_handle.read())
  # Default import prefixes node names with "import/" (relied on in __main__).
  graph = tf.Graph()
  with graph.as_default():
    tf.import_graph_def(graph_def)
  return graph
37 |
def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
                                input_mean=0, input_std=255):
  """Decode an image file into a normalized float32 batch tensor.

  Args:
    file_name: path to a .png/.gif/.bmp/.jpg image (decoder chosen by suffix,
      JPEG is the fallback).
    input_height, input_width: target size after bilinear resize.
    input_mean, input_std: per-pixel normalization: (pixel - mean) / std.

  Returns:
    A numpy array of shape [1, input_height, input_width, channels].
  """
  input_name = "file_reader"
  output_name = "normalized"
  file_reader = tf.read_file(file_name, input_name)
  if file_name.endswith(".png"):
    image_reader = tf.image.decode_png(file_reader, channels = 3,
                                       name='png_reader')
  elif file_name.endswith(".gif"):
    # decode_gif yields [frames, h, w, c]; squeeze drops the frame axis.
    image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
                                                  name='gif_reader'))
  elif file_name.endswith(".bmp"):
    image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
  else:
    image_reader = tf.image.decode_jpeg(file_reader, channels = 3,
                                        name='jpeg_reader')
  float_caster = tf.cast(image_reader, tf.float32)
  dims_expander = tf.expand_dims(float_caster, 0)
  resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
  normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
  # Fix: context-manage the Session so it is always closed; the original
  # created a Session and leaked it.
  with tf.Session() as sess:
    result = sess.run(normalized)

  return result
62 |
def load_labels(label_file):
  """Return the list of class labels, one per line of `label_file` (newline stripped)."""
  proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
  return [line.rstrip() for line in proto_as_ascii_lines]
69 |
70 | if __name__ == "__main__":
71 | file_name = "tf_files/flower_photos/daisy/3475870145_685a19116d.jpg"
72 | model_file = "tf_files/retrained_graph.pb"
73 | label_file = "tf_files/retrained_labels.txt"
74 | input_height = 224
75 | input_width = 224
76 | input_mean = 128
77 | input_std = 128
78 | input_layer = "input"
79 | output_layer = "final_result"
80 |
81 | parser = argparse.ArgumentParser()
82 | parser.add_argument("--image", help="image to be processed")
83 | parser.add_argument("--graph", help="graph/model to be executed")
84 | parser.add_argument("--labels", help="name of file containing labels")
85 | parser.add_argument("--input_height", type=int, help="input height")
86 | parser.add_argument("--input_width", type=int, help="input width")
87 | parser.add_argument("--input_mean", type=int, help="input mean")
88 | parser.add_argument("--input_std", type=int, help="input std")
89 | parser.add_argument("--input_layer", help="name of input layer")
90 | parser.add_argument("--output_layer", help="name of output layer")
91 | args = parser.parse_args()
92 |
93 | if args.graph:
94 | model_file = args.graph
95 | if args.image:
96 | file_name = args.image
97 | if args.labels:
98 | label_file = args.labels
99 | if args.input_height:
100 | input_height = args.input_height
101 | if args.input_width:
102 | input_width = args.input_width
103 | if args.input_mean:
104 | input_mean = args.input_mean
105 | if args.input_std:
106 | input_std = args.input_std
107 | if args.input_layer:
108 | input_layer = args.input_layer
109 | if args.output_layer:
110 | output_layer = args.output_layer
111 |
112 | graph = load_graph(model_file)
113 | t = read_tensor_from_image_file(file_name,
114 | input_height=input_height,
115 | input_width=input_width,
116 | input_mean=input_mean,
117 | input_std=input_std)
118 |
119 | input_name = "import/" + input_layer
120 | output_name = "import/" + output_layer
121 | input_operation = graph.get_operation_by_name(input_name);
122 | output_operation = graph.get_operation_by_name(output_name);
123 |
124 | with tf.Session(graph=graph) as sess:
125 | results = sess.run(output_operation.outputs[0],
126 | {input_operation.outputs[0]: t})
127 | results = np.squeeze(results)
128 |
129 | top_k = results.argsort()[-5:][::-1]
130 | labels = load_labels(label_file)
131 | for i in top_k:
132 | print(labels[i], results[i])
133 |
--------------------------------------------------------------------------------
/scripts/quantize_graph.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Transforms a float-trained graph into an equivalent quantized version.
16 |
17 | An example of command-line usage is:
18 | bazel build tensorflow/tools/quantization:quantize_graph \
19 | && bazel-bin/tensorflow/tools/quantization/quantize_graph \
20 | --input=tensorflow_inception_graph.pb
21 | --output_node_names="softmax2" --print_nodes --output=/tmp/quantized_graph.pb \
22 | --mode=eightbit --logtostderr
23 |
24 | """
25 |
26 | from __future__ import absolute_import
27 | from __future__ import division
28 | from __future__ import print_function
29 |
30 | import os
31 | import collections
32 | import re
33 | import numpy as np
34 |
35 | from tensorflow.core.framework import attr_value_pb2
36 | from tensorflow.core.framework import graph_pb2
37 | from tensorflow.core.framework import node_def_pb2
38 | from tensorflow.python.client import session
39 | from tensorflow.python.framework import constant_op
40 | from tensorflow.python.framework import dtypes
41 | from tensorflow.python.framework import graph_util
42 | from tensorflow.python.framework import importer
43 | from tensorflow.python.framework import ops
44 | from tensorflow.python.framework import tensor_shape
45 | from tensorflow.python.framework import tensor_util
46 | from tensorflow.python.ops import array_ops
47 | from tensorflow.python.platform import app
48 | from tensorflow.python.platform import flags as flags_lib
49 | from tensorflow.python.platform import gfile
50 |
# Command-line flags for the quantization tool. `flags_lib` is TensorFlow's
# thin wrapper over absl-style flags; values are read via FLAGS.<name>.
flags = flags_lib
FLAGS = flags.FLAGS

flags.DEFINE_boolean("print_nodes", False, """Lists all nodes in the model.""")
flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""")
flags.DEFINE_string("output_node_names", "",
                    """Output node names, comma separated.""")
flags.DEFINE_string("output", "", """File to save the output graph to.""")
flags.DEFINE_integer("bitdepth", 8,
                     """How many bits to quantize the graph to.""")
flags.DEFINE_string("mode", "round",
                    """What transformation to apply (round, quantize,"""
                    """ eightbit, weights, or weights_rounded).""")
flags.DEFINE_string("test_input_dims", "1,224,224,3",
                    """The size of the input tensor to use when testing a"""
                    """ graph loaded from a file.""")
flags.DEFINE_boolean("strip_redundant_quantization", True,
                     """Removes redundant dequantize/quantize pairs.""")
flags.DEFINE_boolean("quantized_input", False,
                     "If true, assume Placeholders are quantized with values "
                     "covering [--quantized_input_min,--quantized_input_max]. "
                     "Only supported when --mode=eightbit")
flags.DEFINE_float("quantized_input_min", 0,
                   "The minimum of the actual input range when "
                   "--quantized_input")
flags.DEFINE_float("quantized_input_max", 1,
                   "The maximum of the actual input range when "
                   "--quantized_input")
flags.DEFINE_float(
    "quantized_fallback_min", None,
    "The fallback 'min' value to use for layers which lack min-max "
    "information. Note: this should be considered a coarse tool just good "
    "enough for experimentation purposes, since graphs quantized in this way "
    "would be very inaccurate.")
flags.DEFINE_float(
    "quantized_fallback_max", None,
    "The fallback 'max' value to use for layers which lack min-max "
    "information. Note: this should be considered a coarse tool just good "
    "enough for experimentation purposes, since graphs quantized in this way "
    "would be very inaccurate.")
91 |
92 |
def print_input_nodes(current_node, nodes_map, indent, already_visited):
  """Recursively prints a node and its transitive inputs, one per line.

  Each line is indented by graph depth and formatted as "op:name". Nodes
  already present in already_visited are skipped; visited names are recorded
  in already_visited as a side effect.
  """
  print("%s%s:%s" % (" " * indent, current_node.op, current_node.name))
  already_visited[current_node.name] = True
  for upstream_name in current_node.input:
    if upstream_name not in already_visited:
      print_input_nodes(nodes_map[upstream_name], nodes_map, indent + 1,
                        already_visited)
101 |
102 |
def create_node(op, name, inputs):
  """Creates a NodeDef with the given op type, name, and input names.

  Args:
    op: Operation type string (e.g. "Const", "Reshape").
    name: Name for the new node.
    inputs: Iterable of input node/tensor name strings.

  Returns:
    A populated node_def_pb2.NodeDef.
  """
  new_node = node_def_pb2.NodeDef()
  new_node.op = op
  new_node.name = name
  # Extend once rather than once per element; identical result, less churn.
  new_node.input.extend(inputs)
  return new_node
110 |
111 |
def create_constant_node(name, value, dtype, shape=None):
  """Builds a Const NodeDef holding `value` with the given dtype and shape."""
  const_node = create_node("Const", name, [])
  set_attr_dtype(const_node, "dtype", dtype)
  set_attr_tensor(const_node, "value", value, dtype, shape=shape)
  return const_node
117 |
118 |
def copy_attr(node, key, attr_value):
  """Copies an existing AttrValue proto onto node.attr[key].

  The KeyError guard mirrors the other set_attr_* helpers below; it makes
  the copy best-effort rather than fatal.
  """
  try:
    node.attr[key].CopyFrom(attr_value)
  except KeyError:
    pass
124 |
125 |
def set_attr_dtype(node, key, value):
  """Stores the DType `value` as a type attribute on the node (best-effort)."""
  try:
    dtype_attr = attr_value_pb2.AttrValue(type=value.as_datatype_enum)
    node.attr[key].CopyFrom(dtype_attr)
  except KeyError:
    pass
132 |
133 |
def set_attr_shape(node, key, value):
  """Stores `value` as a TensorShapeProto attribute on the node (best-effort)."""
  try:
    shape_proto = tensor_shape.as_shape(value).as_proto()
    node.attr[key].CopyFrom(attr_value_pb2.AttrValue(shape=shape_proto))
  except KeyError:
    pass
140 |
141 |
def set_attr_tensor(node, key, value, dtype, shape=None):
  """Stores `value` as a TensorProto attribute with the given dtype/shape."""
  try:
    tensor_proto = tensor_util.make_tensor_proto(
        value, dtype=dtype, shape=shape)
    node.attr[key].CopyFrom(attr_value_pb2.AttrValue(tensor=tensor_proto))
  except KeyError:
    pass
149 |
150 |
def set_attr_string(node, key, value):
  """Stores the bytes `value` as a string attribute on the node (best-effort)."""
  try:
    string_attr = attr_value_pb2.AttrValue(s=value)
    node.attr[key].CopyFrom(string_attr)
  except KeyError:
    pass
156 |
157 |
def set_attr_int_list(node, key, value):
  """Stores an iterable of ints as a list attribute on the node (best-effort)."""
  int_list = attr_value_pb2.AttrValue.ListValue(i=value)
  try:
    node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=int_list))
  except KeyError:
    pass
164 |
165 |
def set_attr_bool(node, key, value):
  """Stores the bool `value` as an attribute on the node (best-effort)."""
  try:
    bool_attr = attr_value_pb2.AttrValue(b=value)
    node.attr[key].CopyFrom(bool_attr)
  except KeyError:
    pass
171 |
172 |
def set_attr_int(node, key, value):
  """Stores the int `value` as an attribute on the node (best-effort)."""
  try:
    int_attr = attr_value_pb2.AttrValue(i=value)
    node.attr[key].CopyFrom(int_attr)
  except KeyError:
    pass
178 |
179 |
def set_attr_float(node, key, value):
  """Stores the float `value` as an attribute on the node (best-effort)."""
  try:
    float_attr = attr_value_pb2.AttrValue(f=value)
    node.attr[key].CopyFrom(float_attr)
  except KeyError:
    pass
185 |
186 |
def node_name_from_input(node_name):
  """Strips off ports and other decorations to get the underlying node name.

  Removes a single leading control-dependency marker ("^") and one trailing
  ":<digits>" port suffix, if present.
  """
  stripped = node_name[1:] if node_name.startswith("^") else node_name
  port_match = re.search(r"(.*):\d+$", stripped)
  if port_match:
    return port_match.group(1)
  return stripped
195 |
196 |
def ensure_tensor_name_has_port(node_name):
  """Makes sure that a tensor name has :0 if no explicit port exists."""
  if re.search(r"(.*):\d+$", node_name):
    return node_name
  return node_name + ":0"
205 |
206 |
def unique_node_name_from_input(node_name):
  """Replaces invalid characters in input names to get a unique node name."""
  sanitized = node_name.replace(":", "__port__")
  return sanitized.replace("^", "__hat__")
210 |
211 |
def quantize_array(arr, num_buckets):
  """Quantizes a numpy array.

  Each scalar in arr is mapped to the center of one of num_buckets equally
  wide buckets spanning [arr.min(), arr.max()]. For instance,
  quantize_array([0, 0.3, 0.6, 1], 2) => [0.25, 0.25, 0.75, 0.75]

  Args:
    arr: The numpy array to quantize.
    num_buckets: The number of buckets to map "var" to.
  Returns:
    The quantized numpy array.
  Raises:
    ValueError: when num_buckets < 1.
  """
  if num_buckets < 1:
    raise ValueError("num_buckets must be >= 1")
  lo = arr.min()
  hi = arr.max()
  # A constant array cannot be bucketed meaningfully; return it untouched.
  if hi == lo:
    return arr
  bucket_width = (hi - lo) / num_buckets
  # Bucket index per scalar; the maximum lands exactly on the upper edge,
  # so clamp it back into the last bucket.
  bucket_indices = np.floor((arr - lo) / bucket_width)
  bucket_indices[bucket_indices == num_buckets] = num_buckets - 1
  # Map each scalar to the center of its bucket.
  return lo + bucket_width * (bucket_indices + 0.5)
240 |
241 |
def quantize_weight_rounded(input_node):
  """Returns a replacement node for input_node containing bucketed floats."""
  tensor_proto = input_node.attr["value"].tensor
  float_values = tensor_util.MakeNdarray(tensor_proto)
  # Currently, the parameter FLAGS.bitdepth is used to compute the
  # number of buckets as 1 << FLAGS.bitdepth, meaning the number of
  # buckets can only be a power of 2.
  # This could be fixed by introducing a new parameter, num_buckets,
  # which would allow for more flexibility in choosing the right model
  # size/accuracy tradeoff. But I didn't want to add more parameters
  # to this script than absolutely necessary.
  bucket_count = 1 << FLAGS.bitdepth
  rounded_values = quantize_array(float_values, bucket_count)
  shape_list = tensor_util.TensorShapeProtoToList(tensor_proto.tensor_shape)
  # The replacement Const keeps the original name so consumers are unchanged.
  return [
      create_constant_node(
          input_node.name, rounded_values, dtypes.float32, shape=shape_list)
  ]
264 |
265 |
def quantize_weight_eightbit(input_node, quantization_mode):
  """Returns replacement nodes for input_node using the Dequantize op.

  The float Const is re-expressed as a quint8 Const plus min/max float
  Consts, followed by a Dequantize that takes over the original node's
  name so downstream consumers are unaffected.
  """
  # Names for the three new Consts; the Dequantize reuses input_node.name.
  base_name = input_node.name + "_"
  quint8_const_name = base_name + "quint8_const"
  min_name = base_name + "min"
  max_name = base_name + "max"
  float_tensor = tensor_util.MakeNdarray(input_node.attr["value"].tensor)
  min_value = np.min(float_tensor.flatten())
  max_value = np.max(float_tensor.flatten())
  # Make sure that the range includes zero.
  if min_value > 0.0:
    min_value = 0.0
  # min_value == max_value is a tricky case. It can occur for general
  # tensors, and of course for scalars. The quantized ops cannot deal
  # with this case, so we set max_value to something else.
  # It's a tricky question what is the numerically best solution to
  # deal with this degeneracy.
  # TODO(petewarden): Better use a tolerance than a hard comparison?
  if min_value == max_value:
    if abs(min_value) < 0.000001:
      max_value = min_value + 1.0
    elif min_value > 0:
      max_value = 2 * min_value
    else:
      max_value = min_value / 2.0

  # Evaluate QuantizeV2 eagerly in a throwaway session to obtain the
  # quint8 data for the new Const node.
  sess = session.Session()
  with sess.as_default():
    quantize_op = array_ops.quantize_v2(
        float_tensor,
        min_value,
        max_value,
        dtypes.quint8,
        mode=quantization_mode)
    quint8_tensor = quantize_op[0].eval()
  shape = tensor_util.TensorShapeProtoToList(input_node.attr["value"]
                                             .tensor.tensor_shape)
  quint8_const_node = create_constant_node(
      quint8_const_name, quint8_tensor, dtypes.quint8, shape=shape)
  min_node = create_constant_node(min_name, min_value, dtypes.float32)
  max_node = create_constant_node(max_name, max_value, dtypes.float32)
  dequantize_node = create_node("Dequantize", input_node.name,
                                [quint8_const_name, min_name, max_name])
  set_attr_dtype(dequantize_node, "T", dtypes.quint8)
  set_attr_string(dequantize_node, "mode", quantization_mode)
  return [quint8_const_node, min_node, max_node, dequantize_node]
312 |
313 |
# Transient state used while eightbitize_nodes_recursively walks the graph:
#   already_visited: node name -> True once the node has been processed.
#   output_node_stack: (consumer node, input index, quantize_input) frames
#     describing the path from the outputs down to the current node.
#   merged_with_fake_quant: node name -> True once merged with a FakeQuant*.
EightbitizeRecursionState = collections.namedtuple(
    "EightbitizeRecursionState",
    ["already_visited", "output_node_stack", "merged_with_fake_quant"])
317 |
318 |
class GraphRewriter(object):
  """Takes a float graph, and rewrites it in quantized form.

  The rewrite strategy is selected by `mode` ("round", "quantize",
  "eightbit", "weights", or "weights_rounded") passed to the constructor;
  call rewrite() with the output node names to produce the new GraphDef.
  """
321 |
322 | def __init__(self,
323 | input_graph,
324 | mode,
325 | quantized_input_range,
326 | fallback_quantization_range=None):
327 | """Sets up the class to rewrite a float graph.
328 |
329 | Args:
330 | input_graph: A float graph to transform.
331 | mode: A string controlling how quantization is performed -
332 | round, quantize, eightbit, or weights.
333 | quantized_input_range: if set, assume the input is
334 | quantized and represents the range
335 | [quantized_input_range[0], quantized_input_range[1]]
336 | fallback_quantization_range: if set, then for nodes where the quantization
337 | range can't be inferred from the graph, use the range
338 | [fallback_quantization_range[0], fallback_quantization_range[1]) instead
339 | of using a RequantizationRange node in the graph.
340 |
341 | Raises:
342 | ValueError: Two nodes with the same name were found in the graph.
343 | """
344 | self.input_graph = input_graph
345 | self.nodes_map = self.create_nodes_map(input_graph)
346 | self.output_graph = None
347 | self.mode = mode
348 | self.final_node_renames = {}
349 | if quantized_input_range:
350 | self.input_range = (quantized_input_range[0], quantized_input_range[1])
351 | if self.input_range[0] >= self.input_range[1]:
352 | raise ValueError("Invalid quantized_input_range: [%s,%s]" %
353 | self.input_range)
354 | if self.mode != "eightbit":
355 | raise ValueError(
356 | "quantized_input_range can only be specified in eightbit mode")
357 | else:
358 | self.input_range = None
359 |
360 | if fallback_quantization_range:
361 | self.fallback_quantization_range = [
362 | fallback_quantization_range[0], fallback_quantization_range[1]
363 | ]
364 | if (self.fallback_quantization_range[0] >=
365 | self.fallback_quantization_range[1]):
366 | raise ValueError("Invalid fallback_quantization_range: [%s,%s]" %
367 | self.fallback_quantization_range)
368 | if self.mode != "eightbit":
369 | raise ValueError("fallback_quantization_range can only be "
370 | "specified in eightbit mode")
371 | else:
372 | self.fallback_quantization_range = None
373 |
374 | # Data that is valid only during the recursive call to rewrite the graph.
375 | self.state = None
376 |
377 | def create_nodes_map(self, graph):
378 | """Builds a mapping of node names to their defs from the graph."""
379 | nodes_map = {}
380 | for node in graph.node:
381 | if node.name not in nodes_map.keys():
382 | nodes_map[node.name] = node
383 | else:
384 | raise ValueError("Duplicate node names detected.")
385 | return nodes_map
386 |
  def rewrite(self, output_node_names):
    """Triggers rewriting of the float graph.

    Dispatches on self.mode: "round" and "quantize" walk the graph
    recursively from the outputs; "eightbit" additionally strips training
    nodes, emits range constants, and optionally removes redundant
    quantize/dequantize pairs; "weights"/"weights_rounded" only transform
    Const weights.

    Args:
      output_node_names: A list of names of the nodes that produce the final
        results.

    Returns:
      A quantized version of the float graph.
    """
    self.output_graph = graph_pb2.GraphDef()
    output_nodes = [
        self.nodes_map[output_node_name]
        for output_node_name in output_node_names
    ]
    if self.mode == "round":
      self.already_visited = {}
      for output_node in output_nodes:
        self.round_nodes_recursively(output_node)
    elif self.mode == "quantize":
      self.already_visited = {}
      self.already_quantized = {}
      for output_node in output_nodes:
        self.quantize_nodes_recursively(output_node)
    elif self.mode == "eightbit":
      # Strip training-only nodes first; the graph changes, so output nodes
      # must be re-resolved from the (updated) nodes map below.
      self.set_input_graph(graph_util.remove_training_nodes(self.input_graph))
      output_nodes = [
          self.nodes_map[output_node_name]
          for output_node_name in output_node_names
      ]

      self.state = EightbitizeRecursionState(
          already_visited={}, output_node_stack=[], merged_with_fake_quant={})
      for output_node in output_nodes:
        self.eightbitize_nodes_recursively(output_node)
      # Recursion state is only meaningful during the walk above.
      self.state = None
      if self.input_range:
        # Materialize the caller-supplied input range as named Consts.
        self.add_output_graph_node(
            create_constant_node("quantized_input_min_value", self.input_range[
                0], dtypes.float32, []))
        self.add_output_graph_node(
            create_constant_node("quantized_input_max_value", self.input_range[
                1], dtypes.float32, []))
      if self.fallback_quantization_range:
        # These names are referenced by add_quantize_down_nodes.
        self.add_output_graph_node(
            create_constant_node("fallback_quantization_min_value",
                                 self.fallback_quantization_range[0],
                                 dtypes.float32, []))
        self.add_output_graph_node(
            create_constant_node("fallback_quantization_max_value",
                                 self.fallback_quantization_range[1],
                                 dtypes.float32, []))
      if FLAGS.strip_redundant_quantization:
        self.output_graph = self.remove_redundant_quantization(
            self.output_graph)
        self.remove_dead_nodes(output_node_names)
      self.apply_final_node_renames()
    elif self.mode == "weights":
      self.output_graph = self.quantize_weights(self.input_graph,
                                                b"MIN_COMBINED")
      self.remove_dead_nodes(output_node_names)
    elif self.mode == "weights_rounded":
      self.output_graph = self.quantize_weights(self.input_graph, self.mode)
      self.remove_dead_nodes(output_node_names)
    else:
      print("Bad mode - " + self.mode + ".")
    return self.output_graph
454 |
455 | def round_nodes_recursively(self, current_node):
456 | """The entry point for simple rounding quantization."""
457 | if self.already_visited[current_node.name]:
458 | return
459 | self.already_visited[current_node.name] = True
460 | for input_node_name in current_node.input:
461 | input_node_name = node_name_from_input(input_node_name)
462 | input_node = self.nodes_map[input_node_name]
463 | self.round_nodes_recursively(input_node)
464 | nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
465 | if any(current_node.op in s for s in nodes_to_quantize):
466 | new_node = node_def_pb2.NodeDef()
467 | new_node.CopyFrom(current_node)
468 | new_node.name = current_node.name + "_original"
469 | self.add_output_graph_node(new_node)
470 | levels = 1 << FLAGS.bitdepth
471 | constant_name = current_node.name + "_round_depth"
472 | constant_tensor = constant_op.constant(
473 | levels, dtype=dtypes.int32, name=constant_name)
474 | constant_node = constant_tensor.op.node_def
475 | self.add_output_graph_node(constant_node)
476 | quantize_node = node_def_pb2.NodeDef()
477 | quantize_node.op = "RoundToSteps"
478 | quantize_node.name = current_node.name
479 | quantize_node.input.extend([current_node.name + "_original"])
480 | quantize_node.input.extend([constant_node.name])
481 | self.add_output_graph_node(quantize_node)
482 | else:
483 | new_node = node_def_pb2.NodeDef()
484 | new_node.CopyFrom(current_node)
485 | self.add_output_graph_node(new_node)
486 |
487 | def quantize_nodes_recursively(self, current_node):
488 | """The entry point for quantizing nodes to eight bit and back."""
489 | if self.already_visited[current_node.name]:
490 | return
491 | self.already_visited[current_node.name] = True
492 | for input_node_name in current_node.input:
493 | input_node_name = node_name_from_input(input_node_name)
494 | input_node = self.nodes_map[input_node_name]
495 | self.quantize_nodes_recursively(input_node)
496 | nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
497 | if any(current_node.op in s for s in nodes_to_quantize):
498 | for input_name in current_node.input:
499 | input_name = node_name_from_input(input_name)
500 | input_node = self.nodes_map[input_name]
501 | self.quantize_node(input_node)
502 | self.quantize_node(current_node)
503 | else:
504 | new_node = node_def_pb2.NodeDef()
505 | new_node.CopyFrom(current_node)
506 | self.add_output_graph_node(new_node)
507 |
  def quantize_node(self, input_node):
    """Handles quantizing a single node.

    The node is kept under an "_original" name; Min/Max over its flattened
    output feed a Quantize, and a Dequantize takes over the original name
    so consumers are unaffected. Nodes already processed (tracked in
    self.already_quantized) are skipped.
    """
    input_name = input_node.name
    if input_name in self.already_quantized:
      return
    self.already_quantized[input_name] = True
    # Names for the helper nodes; the Dequantize reuses the original name.
    original_input_name = input_name + "_original"
    reshape_name = input_name + "_reshape"
    reshape_dims_name = input_name + "_reshape_dims"
    max_name = input_name + "_max"
    min_name = input_name + "_min"
    dims_name = input_name + "_dims"
    quantize_name = input_name + "_quantize"
    dequantize_name = input_name
    original_input_node = node_def_pb2.NodeDef()
    original_input_node.CopyFrom(input_node)
    original_input_node.name = original_input_name
    self.add_output_graph_node(original_input_node)
    # Flatten to 1-D (-1) so Min/Max reduce over a single dimension.
    reshape_dims_node = create_constant_node(reshape_dims_name, -1,
                                             dtypes.int32, [1])
    self.add_output_graph_node(reshape_dims_node)
    reshape_node = create_node("Reshape", reshape_name,
                               [original_input_name, reshape_dims_name])
    set_attr_dtype(reshape_node, "T", dtypes.float32)
    self.add_output_graph_node(reshape_node)
    dims_node = create_constant_node(dims_name, 0, dtypes.int32, [1])
    self.add_output_graph_node(dims_node)
    max_node = create_node("Max", max_name, [reshape_name, dims_name])
    set_attr_dtype(max_node, "T", dtypes.float32)
    set_attr_bool(max_node, "keep_dims", False)
    self.add_output_graph_node(max_node)
    min_node = create_node("Min", min_name, [reshape_name, dims_name])
    set_attr_dtype(min_node, "T", dtypes.float32)
    set_attr_bool(min_node, "keep_dims", False)
    self.add_output_graph_node(min_node)
    quantize_node = create_node("Quantize", quantize_name,
                                [original_input_name, min_name, max_name])
    set_attr_dtype(quantize_node, "T", dtypes.quint8)
    set_attr_string(quantize_node, "mode", b"MIN_FIRST")
    self.add_output_graph_node(quantize_node)
    dequantize_node = create_node("Dequantize", dequantize_name,
                                  [quantize_name, min_name, max_name])
    set_attr_dtype(dequantize_node, "T", dtypes.quint8)
    set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
    self.add_output_graph_node(dequantize_node)
553 |
554 | def should_merge_with_fake_quant_node(self):
555 | """Should the current node merge with self.state.output_node_stack[-1]?"""
556 | if not self.state.output_node_stack:
557 | return False
558 | top = self.state.output_node_stack[-1]
559 | return top[1] == 0 and top[0].op in ["FakeQuantWithMinMaxVars"]
560 |
561 | def should_quantize_const(self, node):
562 | if not self.state.output_node_stack:
563 | return False
564 | top = self.state.output_node_stack[-1]
565 | if not top[2]:
566 | return False
567 | dtype = dtypes.as_dtype(node.attr["dtype"].type)
568 | assert dtype == dtypes.float32, (
569 | "Failed to quantized constant %s of type %s" % (node.name, dtype))
570 | return True
571 |
572 | def eightbitize_nodes_recursively(self, current_node):
573 | """The entry point for transforming a graph into full eight bit."""
574 | if current_node.name in self.state.already_visited:
575 | if (self.should_merge_with_fake_quant_node() or
576 | current_node.name in self.state.merged_with_fake_quant):
577 | raise ValueError("Unsupported graph structure: output of node %s "
578 | "is processed by a FakeQuant* node and should have "
579 | "no other outputs.", current_node.name)
580 | return
581 | self.state.already_visited[current_node.name] = True
582 |
583 | for i, input_node_name in enumerate(current_node.input):
584 | quantize_input = False
585 | if current_node.op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool",
586 | "AvgPool", "Relu", "Relu6",
587 | "BatchNormWithGlobalNormalization"):
588 | quantize_input = True
589 | elif current_node.op == "Concat" and i > 0:
590 | quantize_input = (
591 | dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32)
592 | elif current_node.op == "Reshape" and i == 0:
593 | quantize_input = (
594 | dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32)
595 |
596 | self.state.output_node_stack.append((current_node, i, quantize_input))
597 |
598 | input_node_name = node_name_from_input(input_node_name)
599 | input_node = self.nodes_map[input_node_name]
600 | self.eightbitize_nodes_recursively(input_node)
601 |
602 | self.state.output_node_stack.pop()
603 |
604 | if current_node.op == "MatMul":
605 | self.eightbitize_mat_mul_node(current_node)
606 | elif current_node.op == "Conv2D":
607 | self.eightbitize_conv_node(current_node)
608 | elif current_node.op == "BiasAdd":
609 | self.eightbitize_bias_add_node(current_node)
610 | elif current_node.op == "MaxPool" or current_node.op == "AvgPool":
611 | self.eightbitize_single_input_tensor_node(current_node,
612 | self.add_pool_function)
613 | elif current_node.op == "Relu" or current_node.op == "Relu6":
614 | self.eightbitize_single_input_tensor_node(current_node,
615 | self.add_relu_function)
616 | elif (current_node.op == "Concat" and
617 | dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32):
618 | self.eightbitize_concat_node(current_node)
619 | elif current_node.op == "BatchNormWithGlobalNormalization":
620 | self.eightbitize_batch_norm_node(current_node)
621 | elif (current_node.op == "Reshape" and
622 | dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32):
623 | self.eightbitize_reshape_node(current_node)
624 | elif (self.input_range and
625 | current_node.op in ("Placeholder", "PlaceholderV2")):
626 | self.eightbitize_placeholder_node(current_node)
627 | elif current_node.op == "FakeQuantWithMinMaxVars":
628 | # It will have been merged into the underlying node.
629 | pass
630 | elif current_node.op == "Const":
631 | if self.should_quantize_const(current_node):
632 | for n in quantize_weight_eightbit(current_node, b"MIN_FIRST"):
633 | self.add_output_graph_node(n)
634 | else:
635 | new_node = node_def_pb2.NodeDef()
636 | new_node.CopyFrom(current_node)
637 | self.add_output_graph_node(new_node)
638 |
639 | ###################################################################
640 | # Note: if more cases are added here, you may need to update the op
641 | # name lists in the loop over children at the start of the function.
642 | ###################################################################
643 | else:
644 | new_node = node_def_pb2.NodeDef()
645 | new_node.CopyFrom(current_node)
646 | self.add_output_graph_node(new_node)
647 |
648 | if (self.should_merge_with_fake_quant_node() and
649 | current_node.name not in self.state.merged_with_fake_quant):
650 | raise ValueError(
651 | "FakeQuant* node %s failed to merge with node %s of type %s" %
652 | (self.state.output_node_stack[-1][0], current_node.name,
653 | current_node.op))
654 |
655 | def add_eightbit_prologue_nodes(self, original_node):
656 | """Adds input conversion nodes to handle quantizing the underlying node."""
657 | namespace_prefix = original_node.name + "_eightbit"
658 | reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
659 | namespace_prefix)
660 | input_names = []
661 | min_max_names = []
662 | for original_input_name in original_node.input:
663 | quantize_input_name, min_input_name, max_input_name = (
664 | self.eightbitize_input_to_node(namespace_prefix, original_input_name,
665 | reshape_dims_name,
666 | reduction_dims_name))
667 | input_names.append(quantize_input_name)
668 | min_max_names.append(min_input_name)
669 | min_max_names.append(max_input_name)
670 | all_input_names = []
671 | all_input_names.extend(input_names)
672 | all_input_names.extend(min_max_names)
673 | return all_input_names
674 |
675 | def add_common_quantization_nodes(self, namespace_prefix):
676 | """Builds constant nodes needed for quantization of inputs."""
677 | reshape_dims_name = namespace_prefix + "_reshape_dims"
678 | reduction_dims_name = namespace_prefix + "_reduction_dims"
679 |
680 | reshape_dims_node = create_constant_node(reshape_dims_name, -1,
681 | dtypes.int32, [1])
682 | self.add_output_graph_node(reshape_dims_node)
683 | reduction_dims_node = create_constant_node(reduction_dims_name, 0,
684 | dtypes.int32, [1])
685 | self.add_output_graph_node(reduction_dims_node)
686 | return reshape_dims_name, reduction_dims_name
687 |
  def eightbitize_input_to_node(self, namespace_prefix, original_input_name,
                                reshape_dims_name, reduction_dims_name):
    """Takes one float input to an op, and converts it to quantized form.

    Emits Reshape -> Min/Max -> QuantizeV2 nodes for the input and returns
    the quantized tensor name plus the names of QuantizeV2's min (:1) and
    max (:2) output ports.
    """
    # Sanitize ":"/"^" so the derived node names are legal and unique.
    unique_input_name = unique_node_name_from_input(original_input_name)
    reshape_input_name = namespace_prefix + "_reshape_" + unique_input_name
    min_input_name = namespace_prefix + "_min_" + unique_input_name
    max_input_name = namespace_prefix + "_max_" + unique_input_name
    quantize_input_name = namespace_prefix + "_quantize_" + unique_input_name
    # Flatten the input so Min/Max reduce over a single dimension.
    reshape_input_node = create_node("Reshape", reshape_input_name,
                                     [original_input_name, reshape_dims_name])
    set_attr_dtype(reshape_input_node, "T", dtypes.float32)
    self.add_output_graph_node(reshape_input_node)
    min_input_node = create_node("Min", min_input_name,
                                 [reshape_input_name, reduction_dims_name])
    set_attr_dtype(min_input_node, "T", dtypes.float32)
    set_attr_bool(min_input_node, "keep_dims", False)
    self.add_output_graph_node(min_input_node)
    max_input_node = create_node("Max", max_input_name,
                                 [reshape_input_name, reduction_dims_name])
    set_attr_dtype(max_input_node, "T", dtypes.float32)
    set_attr_bool(max_input_node, "keep_dims", False)
    self.add_output_graph_node(max_input_node)
    quantize_input_node = create_node(
        "QuantizeV2", quantize_input_name,
        [original_input_name, min_input_name, max_input_name])
    set_attr_dtype(quantize_input_node, "T", dtypes.quint8)
    set_attr_string(quantize_input_node, "mode", b"MIN_FIRST")
    self.add_output_graph_node(quantize_input_node)
    # QuantizeV2 emits the actual min/max it used on ports :1 and :2.
    min_output_name = quantize_input_name + ":1"
    max_output_name = quantize_input_name + ":2"
    return quantize_input_name, min_output_name, max_output_name
719 |
  def add_quantize_down_nodes(self, original_node, quantized_output_name):
    """Adds nodes that requantize a qint32 result down to quint8.

    The requantization range comes from (in priority order) a merged
    FakeQuantWithMinMaxVars node, the configured fallback range constants,
    or a freshly added RequantizationRange node.

    Args:
      original_node: The float node being replaced (used for naming and for
        the FakeQuant merge bookkeeping).
      quantized_output_name: Name of the node producing the qint32 output;
        its :1 and :2 ports carry the min/max.

    Returns:
      The name of the Requantize node producing the quint8 output.
    """
    quantized_outputs = [
        quantized_output_name, quantized_output_name + ":1",
        quantized_output_name + ":2"
    ]
    min_max_inputs = None
    if self.should_merge_with_fake_quant_node():
      # Use the inputs to the FakeQuantWithMinMaxVars node as the inputs to
      # Requantize.
      fake_quant_node = self.state.output_node_stack[-1][0]
      min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
      assert original_node.name not in self.state.merged_with_fake_quant
      self.state.merged_with_fake_quant[original_node.name] = True
    elif self.fallback_quantization_range:
      # These Consts are emitted by rewrite() when the fallback range is set.
      min_max_inputs = [
          "fallback_quantization_min_value:0",
          "fallback_quantization_max_value:0"
      ]
    else:
      # Add a RequantizationRange node for finding the min and max values.
      requant_range_node = create_node(
          "RequantizationRange", original_node.name + "_eightbit_requant_range",
          quantized_outputs)
      set_attr_dtype(requant_range_node, "Tinput", dtypes.qint32)
      self.add_output_graph_node(requant_range_node)
      min_max_inputs = [
          requant_range_node.name + ":0", requant_range_node.name + ":1"
      ]
    requantize_node = create_node("Requantize",
                                  original_node.name + "_eightbit_requantize",
                                  quantized_outputs + min_max_inputs)
    set_attr_dtype(requantize_node, "Tinput", dtypes.qint32)
    set_attr_dtype(requantize_node, "out_type", dtypes.quint8)
    self.add_output_graph_node(requantize_node)
    return requantize_node.name
755 |
  def add_dequantize_result_node(self,
                                 quantized_output_name,
                                 original_node_name,
                                 min_tensor_index=1):
    """Adds a Dequantize node converting the quantized result back to float.

    By default it takes over original_node_name so consumers are unchanged.
    If the result feeds a FakeQuantWithMinMaxVars node, the Dequantize
    instead takes the FakeQuant node's name and min/max inputs, merging the
    FakeQuant node away.

    Args:
      quantized_output_name: Node whose ports carry the quantized value and
        its min/max (at min_tensor_index and min_tensor_index + 1).
      original_node_name: Name of the float node being replaced.
      min_tensor_index: Port of quantized_output_name holding the min value.
    """
    min_max_inputs = [
        "%s:%s" % (quantized_output_name, min_tensor_index),
        "%s:%s" % (quantized_output_name, (min_tensor_index + 1))
    ]
    dequantize_name = original_node_name
    if self.should_merge_with_fake_quant_node():
      fake_quant_node = self.state.output_node_stack[-1][0]
      if original_node_name not in self.state.merged_with_fake_quant:
        min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
        self.state.merged_with_fake_quant[original_node_name] = True
      dequantize_name = fake_quant_node.name

    dequantize_node = create_node(
        "Dequantize", dequantize_name,
        [quantized_output_name, min_max_inputs[0], min_max_inputs[1]])
    set_attr_dtype(dequantize_node, "T", dtypes.quint8)
    set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
    self.add_output_graph_node(dequantize_node)
778 |
779 | def eightbitize_mat_mul_node(self, original_node):
780 | """Replaces a MatMul node with the eight bit equivalent sub-graph."""
781 | quantized_mat_mul_name = original_node.name + "_eightbit_quantized_mat_mul"
782 | all_input_names = self.add_eightbit_prologue_nodes(original_node)
783 | quantized_mat_mul_node = create_node("QuantizedMatMul",
784 | quantized_mat_mul_name,
785 | all_input_names)
786 | set_attr_dtype(quantized_mat_mul_node, "T1", dtypes.quint8)
787 | set_attr_dtype(quantized_mat_mul_node, "T2", dtypes.quint8)
788 | set_attr_dtype(quantized_mat_mul_node, "Toutput", dtypes.qint32)
789 | copy_attr(quantized_mat_mul_node, "transpose_a",
790 | original_node.attr["transpose_a"])
791 | copy_attr(quantized_mat_mul_node, "transpose_b",
792 | original_node.attr["transpose_b"])
793 | self.add_output_graph_node(quantized_mat_mul_node)
794 | quantize_down_name = self.add_quantize_down_nodes(original_node,
795 | quantized_mat_mul_name)
796 | self.add_dequantize_result_node(quantize_down_name, original_node.name)
797 |
798 | def eightbitize_conv_node(self, original_node):
799 | """Replaces a Conv2D node with the eight bit equivalent sub-graph."""
800 | all_input_names = self.add_eightbit_prologue_nodes(original_node)
801 | quantized_conv_name = original_node.name + "_eightbit_quantized_conv"
802 | quantized_conv_node = create_node("QuantizedConv2D", quantized_conv_name,
803 | all_input_names)
804 | copy_attr(quantized_conv_node, "strides", original_node.attr["strides"])
805 | copy_attr(quantized_conv_node, "padding", original_node.attr["padding"])
806 | set_attr_dtype(quantized_conv_node, "Tinput", dtypes.quint8)
807 | set_attr_dtype(quantized_conv_node, "Tfilter", dtypes.quint8)
808 | set_attr_dtype(quantized_conv_node, "out_type", dtypes.qint32)
809 | self.add_output_graph_node(quantized_conv_node)
810 | quantize_down_name = self.add_quantize_down_nodes(original_node,
811 | quantized_conv_name)
812 | self.add_dequantize_result_node(quantize_down_name, original_node.name)
813 |
814 | def eightbitize_bias_add_node(self, original_node):
815 | """Replaces a BiasAdd node with the eight bit equivalent sub-graph."""
816 | quantized_bias_add_name = (
817 | original_node.name + "_eightbit_quantized_bias_add")
818 | all_input_names = self.add_eightbit_prologue_nodes(original_node)
819 | quantized_bias_add_node = create_node("QuantizedBiasAdd",
820 | quantized_bias_add_name,
821 | all_input_names)
822 | set_attr_dtype(quantized_bias_add_node, "T1", dtypes.quint8)
823 | set_attr_dtype(quantized_bias_add_node, "T2", dtypes.quint8)
824 | set_attr_dtype(quantized_bias_add_node, "out_type", dtypes.qint32)
825 | self.add_output_graph_node(quantized_bias_add_node)
826 | quantize_down_name = self.add_quantize_down_nodes(original_node,
827 | quantized_bias_add_name)
828 | self.add_dequantize_result_node(quantize_down_name, original_node.name)
829 |
  def eightbitize_single_input_tensor_node(self, original_node,
                                           add_op_function):
    """Replaces a single-tensor node with the eight bit equivalent sub-graph.

    Converts a node like this:

       Shape(f)   Input(f)
        |           |
        +--------v v
                Operation
                    |
                    v
                   (f)

     Into a quantized equivalent:

                    Input(f)              ReshapeDims
                      +------v v-------------+
                      |    Reshape
                      |      |
                      |      |          ReductionDims
                      |      +-----+         |
                      |      | +---c---------+
                      |      v v   v v-------+
                      |     Min   Max
                      |  +----+    |
                      v  v      v  v--------+
                     Quantize
                        |
                        v
                 QuantizedOperation
                    |   |   |
                    v   v   v
                    Dequantize
                        |
                        v
                       (f)


    Args:
      original_node: Float node to be converted.
      add_op_function: Function to create the actual node.

    Returns:
      Subgraph representing the quantized version of the original node.

    """
    quantized_op_name = original_node.name + "_eightbit_quantized"
    # e.g. "MaxPool" becomes "QuantizedMaxPool".
    quantized_op_type = "Quantized" + original_node.op
    all_input_names = self.add_eightbit_prologue_nodes(original_node)
    quantized_op_node = create_node(quantized_op_type, quantized_op_name,
                                    all_input_names)
    # The callback fills in op-specific attributes (e.g. ksize/strides for
    # pools, Tinput for relus).
    add_op_function(original_node, quantized_op_node)
    self.add_output_graph_node(quantized_op_node)
    self.add_dequantize_result_node(quantized_op_name, original_node.name)
885 |
886 | def add_pool_function(self, original_node, quantized_op_node):
887 | set_attr_dtype(quantized_op_node, "T", dtypes.quint8)
888 | copy_attr(quantized_op_node, "ksize", original_node.attr["ksize"])
889 | copy_attr(quantized_op_node, "strides", original_node.attr["strides"])
890 | copy_attr(quantized_op_node, "padding", original_node.attr["padding"])
891 |
  def add_relu_function(self, unused_arg_node, quantized_op_node):
    """Marks the quantized Relu-style op's input dtype as quint8."""
    set_attr_dtype(quantized_op_node, "Tinput", dtypes.quint8)
894 |
895 | def eightbitize_concat_node(self, original_node):
896 | """Replaces a Concat node with the eight bit equivalent sub-graph.
897 |
898 | Converts a node like this:
899 |
900 | Shape(f) Input0(f) Input1(f)
901 | | | |
902 | +--------v v v----------+
903 | Concat
904 | |
905 | v
906 | (f)
907 |
908 | Into a quantized equivalent:
909 |
910 | Shape(f) Input0(f) ReshapeDims Input1(f)
911 | | +------v v--------------+------------------v v------+
912 | | | Reshape Reshape |
913 | | | | | |
914 | | | | ReductionDims | |
915 | | | +------+ | +--------+ |
916 | | | | +---c---------+-----------c-----+ | |
917 | | | +v v v v-------+---------v v v v+ |
918 | | | Min Max Min Max |
919 | | | +----+ | | +-----+ |
920 | | v v v--------+ +----------v v v
921 | | Quantize Quantize
922 | | +------------------+ +----------------------+
923 | +-------------------------------+ | |
924 | v v v
925 | QuantizedConcat
926 | | | |
927 | v v v
928 | Dequantize
929 | |
930 | v
931 | (f)
932 | Args:
933 | original_node: Float node to be converted.
934 |
935 | Returns:
936 | Subgraph representing the quantized version of the original node.
937 |
938 | """
939 | namespace_prefix = original_node.name + "_eightbit"
940 | quantized_concat_name = namespace_prefix + "_quantized_concat"
941 | reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
942 | namespace_prefix)
943 | shape_input_name = original_node.input[0]
944 | original_inputs = original_node.input[1:]
945 | input_names = []
946 | min_names = []
947 | max_names = []
948 | for original_input_name in original_inputs:
949 | quantize_input_name, min_input_name, max_input_name = (
950 | self.eightbitize_input_to_node(namespace_prefix, original_input_name,
951 | reshape_dims_name,
952 | reduction_dims_name))
953 | input_names.append(quantize_input_name)
954 | min_names.append(min_input_name)
955 | max_names.append(max_input_name)
956 | all_input_names = [shape_input_name]
957 | all_input_names.extend(input_names)
958 | all_input_names.extend(min_names)
959 | all_input_names.extend(max_names)
960 | quantized_concat_node = create_node("QuantizedConcat",
961 | quantized_concat_name, all_input_names)
962 | set_attr_int(quantized_concat_node, "N", len(original_inputs))
963 | set_attr_dtype(quantized_concat_node, "T", dtypes.quint8)
964 | self.add_output_graph_node(quantized_concat_node)
965 | self.add_dequantize_result_node(quantized_concat_name, original_node.name)
966 |
967 | def eightbitize_placeholder_node(self, current_node):
968 | """Replaces a placeholder node with a quint8 placeholder node+dequantize."""
969 | name = current_node.name
970 |
971 | # Convert the placeholder into a quantized type.
972 | output_node = node_def_pb2.NodeDef()
973 | output_node.CopyFrom(current_node)
974 | set_attr_dtype(output_node, "dtype", dtypes.quint8)
975 | output_node.name += "_original_input"
976 | self.add_output_graph_node(output_node)
977 |
978 | # Add a dequantize to convert back to float.
979 | dequantize_node = create_node("Dequantize", name, [
980 | output_node.name, "quantized_input_min_value",
981 | "quantized_input_max_value"
982 | ])
983 | set_attr_dtype(dequantize_node, "T", dtypes.quint8)
984 | set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
985 | self.add_output_graph_node(dequantize_node)
986 |
987 | # For the descent over the graph to work, the dequantize node must be named
988 | # current_node.name. However, for the feeding of the graph to work, the
989 | # placeholder must have the name current_node.name; so record a final set
990 | # of renames to apply after all processing has been done.
991 | self.final_node_renames[output_node.name] = name
992 | self.final_node_renames[dequantize_node.name] = name + "_dequantize"
993 |
994 | def eightbitize_reshape_node(self, original_node):
995 | """Replaces a Reshape node with the eight bit equivalent sub-graph.
996 |
997 | Args:
998 | original_node: Float node to be converted.
999 |
1000 | Returns:
1001 | Subgraph representing the quantized version of the original node.
1002 |
1003 | """
1004 | namespace_prefix = original_node.name + "_eightbit"
1005 | quantized_reshape_name = namespace_prefix + "_quantized_reshape"
1006 | reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
1007 | namespace_prefix)
1008 | shape_input_name = original_node.input[1]
1009 | quantize_input_name, min_input_name, max_input_name = (
1010 | self.eightbitize_input_to_node(namespace_prefix, original_node.input[0],
1011 | reshape_dims_name, reduction_dims_name))
1012 | quantized_reshape_node = create_node(
1013 | "QuantizedReshape", quantized_reshape_name,
1014 | [quantize_input_name, shape_input_name, min_input_name, max_input_name])
1015 | set_attr_dtype(quantized_reshape_node, "T", dtypes.quint8)
1016 | self.add_output_graph_node(quantized_reshape_node)
1017 | self.add_dequantize_result_node(quantized_reshape_name, original_node.name)
1018 |
1019 | def eightbitize_batch_norm_node(self, original_node):
1020 | """Replaces a MatMul node with the eight bit equivalent sub-graph."""
1021 | namespace_prefix = original_node.name + "_eightbit"
1022 | original_input_name = original_node.input[0]
1023 | original_mean_name = original_node.input[1]
1024 | original_variance_name = original_node.input[2]
1025 | original_beta_name = original_node.input[3]
1026 | original_gamma_name = original_node.input[4]
1027 | quantized_batch_norm_name = namespace_prefix + "_quantized_batch_norm"
1028 |
1029 | reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
1030 | namespace_prefix)
1031 | quantize_input_name, min_input_name, max_input_name = (
1032 | self.eightbitize_input_to_node(namespace_prefix, original_input_name,
1033 | reshape_dims_name, reduction_dims_name))
1034 | quantize_mean_name, min_mean_name, max_mean_name = (
1035 | self.eightbitize_input_to_node(namespace_prefix, original_mean_name,
1036 | reshape_dims_name, reduction_dims_name))
1037 | quantize_variance_name, min_variance_name, max_variance_name = (
1038 | self.eightbitize_input_to_node(namespace_prefix, original_variance_name,
1039 | reshape_dims_name, reduction_dims_name))
1040 | quantize_beta_name, min_beta_name, max_beta_name = (
1041 | self.eightbitize_input_to_node(namespace_prefix, original_beta_name,
1042 | reshape_dims_name, reduction_dims_name))
1043 | quantize_gamma_name, min_gamma_name, max_gamma_name = (
1044 | self.eightbitize_input_to_node(namespace_prefix, original_gamma_name,
1045 | reshape_dims_name, reduction_dims_name))
1046 | quantized_batch_norm_node = create_node(
1047 | "QuantizedBatchNormWithGlobalNormalization", quantized_batch_norm_name,
1048 | [
1049 | quantize_input_name, min_input_name, max_input_name,
1050 | quantize_mean_name, min_mean_name, max_mean_name,
1051 | quantize_variance_name, min_variance_name, max_variance_name,
1052 | quantize_beta_name, min_beta_name, max_beta_name,
1053 | quantize_gamma_name, min_gamma_name, max_gamma_name
1054 | ])
1055 | set_attr_dtype(quantized_batch_norm_node, "Tinput", dtypes.quint8)
1056 | set_attr_dtype(quantized_batch_norm_node, "out_type", dtypes.qint32)
1057 | copy_attr(quantized_batch_norm_node, "scale_after_normalization",
1058 | original_node.attr["scale_after_normalization"])
1059 | copy_attr(quantized_batch_norm_node, "variance_epsilon",
1060 | original_node.attr["variance_epsilon"])
1061 | self.add_output_graph_node(quantized_batch_norm_node)
1062 | quantize_down_name = self.add_quantize_down_nodes(original_node,
1063 | quantized_batch_norm_name)
1064 | self.add_dequantize_result_node(quantize_down_name, original_node.name)
1065 |
  def add_output_graph_node(self, output_node):
    """Inserts one node into the new graph.

    Args:
      output_node: NodeDef to append to self.output_graph.node.
    """
    self.output_graph.node.extend([output_node])
1069 |
  def remove_redundant_quantization(self, old_graph):
    """Removes unneeded pairs of quantize/dequantize ops from the graph.

    This is a bit of a tricky function, because it's attempting to spot the
    pattern of dequantizing from eight-bit up to float, and then immediately
    quantizing back down to eight bits again, that's introduced by previous
    passes that do 'key-hole' conversions of individual nodes but have to
    convert back to float to match the previous output interface, since they
    don't know that the next op can handle quantized tensors.
    It works by:
    - Looking for Quantize nodes.
    - Checking to see if their first input is a Dequantize node.
    - Seeing if their min/max inputs come from Min/Max nodes.
    - Making sure those Min/Max nodes are being fed from the same Dequantize.
    - Or that the Min is indirectly being fed from the same Dequantize as Max.
    - Making sure the Dequantize is going through a Reshape (which we add
      during the previous pass when we create the quantize sub-graph).
    - Looking for the dims Const op for the Min/Max dims.
    If all of these conditions are met, then it's a sub-graph pattern that
    we know how to optimize out (and is likely the common one we've introduced).
    We then rewire the graph to skip it entirely, and then rely on the dead node
    removal pass to get rid of any nodes that are no longer needed.

    Args:
      old_graph: The model we'll be stripping redundant nodes from.

    Returns:
      A graph with the unnecessary nodes removed.

    Raises:
      ValueError: Two nodes with the same name were found in the graph.
    """
    old_nodes_map = self.create_nodes_map(old_graph)
    self.output_graph = graph_pb2.GraphDef()
    # Maps tensor names to the replacement tensor they should be wired to.
    inputs_to_rename = {}
    # We go through all the nodes, looking for any that match the patterns we
    # know how to optimize away.
    for node in old_graph.node:
      # We always start with a Quantize node, and examine its inputs to see if
      # they are in a form that can be removed.
      if node.op not in ["Quantize", "QuantizeV2"]:
        continue
      dequantize_node_name = node_name_from_input(node.input[0])
      if dequantize_node_name not in old_nodes_map:
        raise ValueError("Input node name '" + dequantize_node_name +
                         "' not found in node '" + node.name + "'")
      dequantize_node = old_nodes_map[dequantize_node_name]
      # Do we have a Dequantize feeding in, with the same type as the Quantize?
      if dequantize_node.op != "Dequantize":
        continue
      if node.attr["T"] != dequantize_node.attr["T"]:
        continue
      # Now look at the other inputs, and ensure they're Min/Max nodes.
      min_node_name = node_name_from_input(node.input[1])
      max_node_name = node_name_from_input(node.input[2])
      min_node = old_nodes_map[min_node_name]
      max_node = old_nodes_map[max_node_name]
      is_min_right_type = (min_node.op in ["Min", "Dequantize"])
      is_max_right_type = (max_node.op in ["Max", "Dequantize"])
      if not is_min_right_type or not is_max_right_type:
        print("Didn't find expected types on inputs : %s, %s." % (min_node.op,
                                                                  max_node.op))
        continue
      min_node_input_name = node_name_from_input(min_node.input[0])
      max_node_input_name = node_name_from_input(max_node.input[0])
      # There are two different patterns for Min nodes we can recognize, one
      # where the input comes directly from the same one as the Max, and
      # another where we run it through another Min first, so check for both.
      is_same_input = False
      if min_node_input_name == max_node_input_name:
        is_same_input = True
      else:
        first_min_node_input = old_nodes_map[min_node_input_name]
        if first_min_node_input.op == "Concat":
          second_min_node_name = node_name_from_input(
              first_min_node_input.input[1])
          second_min_node = old_nodes_map[second_min_node_name]
          if second_min_node.op == "Min":
            second_min_node_input_name = node_name_from_input(
                second_min_node.input[0])
            is_same_input = (second_min_node_input_name == max_node_input_name)
      if not is_same_input:
        print("Different min/max inputs: " + min_node_input_name)
        continue
      # We recognize this pattern, so mark the graph edges to be rewired to
      # route around it entirely, since we know it's a no-op.
      dequantize_source_name = node_name_from_input(dequantize_node.input[0])
      node_tensor_name = ensure_tensor_name_has_port(node.name)
      # Outputs :1 and :2 of the Quantize node carry the min/max range
      # values; route consumers of those to the Dequantize's own min/max.
      min_tensor_name = node.name + ":1"
      max_tensor_name = node.name + ":2"
      inputs_to_rename[node_tensor_name] = dequantize_source_name
      inputs_to_rename[min_tensor_name] = dequantize_node.input[1]
      inputs_to_rename[max_tensor_name] = dequantize_node.input[2]
    # Finally we apply all the rewiring we've marked to the graph.
    for node in old_graph.node:
      for index, input_full_name in enumerate(node.input):
        input_name = ensure_tensor_name_has_port(input_full_name)
        if input_name in inputs_to_rename:
          node.input[index] = inputs_to_rename[input_name]
      self.add_output_graph_node(node)
    return self.output_graph
1171 |
1172 | def apply_final_node_renames(self):
1173 | """Applies node renames in self.final_node_renames to self.output_graph."""
1174 | old_graph = self.output_graph
1175 | self.output_graph = graph_pb2.GraphDef()
1176 | for node in old_graph.node:
1177 | node.name = self.final_node_renames.get(node.name, node.name)
1178 | for index, input_name in enumerate(node.input):
1179 | node_name = node_name_from_input(input_name)
1180 | input_full_name = ensure_tensor_name_has_port(input_name)
1181 | if node_name in self.final_node_renames:
1182 | node.input[index] = "%s%s" % (self.final_node_renames[node_name],
1183 | input_full_name[len(node_name):])
1184 | self.add_output_graph_node(node)
1185 | return self.output_graph
1186 |
  def remove_dead_nodes(self, output_names):
    """Removes nodes that are no longer needed for inference from the graph.

    Args:
      output_names: Names of the output nodes whose transitive fan-in should
        be kept; everything else is discarded.
    """
    old_output_graph = self.output_graph
    # Keep only the sub-graph that the requested outputs depend on.
    self.output_graph = graph_util.extract_sub_graph(old_output_graph,
                                                     output_names)
1192 |
1193 | def quantize_weights(self, input_graph, quantization_mode):
1194 | """Quantize float Const ops.
1195 |
1196 | There are two modes of operations, both replace float Const ops with
1197 | quantized values.
1198 | 1. If quantization_mode is "weights_rounded", this function replaces float
1199 | Const ops with quantized float Const ops - same as the original op, but
1200 | float values being mapped to the center of one of 1<%s" % caption))
36 |
--------------------------------------------------------------------------------
/src/attributes.js:
--------------------------------------------------------------------------------
1 |
2 | const {
3 | aboutZero,
4 | aboutCentered,
5 | aboutEqual,
6 | allAboutEqual,
7 | smallComparedTo
8 | } = require('./lib/utils');
9 |
10 |
// Builds the React Native JSX attribute string and style object for one node.
//
// js          - parsed element ({id, type, style, directAttrs?})
// ele         - flattened layout element ({id, alignJustify?})
// idDims      - map of element id -> absolute dimensions from the design
// childParent - map of child id -> parent id
//
// Returns {attrs, attrObjs, styleId, componentStyles}: `attrs` is a
// ready-to-inline attribute string and componentStyles maps the sanitized
// style id to its style object.
const nativeAttrs = (js, ele, idDims, childParent) => {
  const dims = idDims[js.id] ? idDims[js.id] : {};
  let componentStyles = {};

  // Absolute positioning from the SVG is dropped; flexbox handles layout.
  const {
    top,
    left,
    width,
    // height,
    position,
    ...rest
  } = js.style

  let jsStyles = {...rest}

  // Note: 360 is the max width of iphone; full-width elements rely on flex
  // instead of a fixed width.
  if(width < 360) {
    jsStyles.width = width;
  }

  if(js.type == 'View' && !jsStyles.width && dims.width < 360) {
    jsStyles.width = dims.width;
  }

  // Derive text alignment from the element's horizontal position inside its
  // parent.
  if(js.type == 'Text' || js.type == 'Tspan') {
    const parent = childParent[js.id];
    const parentDims = parent && idDims[parent];
    if(parentDims) {
      const spaceBefore = dims.left - parentDims.left;
      const spaceAfter = parentDims.right - dims.right;

      if(aboutCentered(spaceBefore, spaceAfter) && !aboutZero(spaceBefore)) {
        jsStyles['textAlign'] = 'center';
      }

      if(!aboutCentered(spaceBefore, spaceAfter) && spaceBefore < spaceAfter) {
        jsStyles['textAlign'] = 'left';
        jsStyles['marginLeft'] = spaceBefore;
      }

      if(!aboutCentered(spaceBefore, spaceAfter) && spaceBefore > spaceAfter) {
        jsStyles['textAlign'] = 'right';
        jsStyles['marginRight'] = spaceAfter;
      }

    }
  }

  // Give Views an explicit height only when they don't span their parent.
  const parent = childParent[ele.id]
  if(parent) {
    // Synthetic 'row'/'column' parents are plain strings here (no
    // .parentDims), and idDims may lack an entry, so guard parentDims
    // before dereferencing it (previously this could throw).
    const parentDims = parent == 'row' || parent == 'column' ? parent.parentDims : idDims[parent]
    if(js.type == 'View' && !jsStyles.height && parentDims && !aboutEqual(dims.height, parentDims.height)) {
      jsStyles.height = dims.height;
    }
  }

  // Numeric layout keys are rounded (and dropped when <= 0); everything else
  // is single-quoted so it can be emitted verbatim in the stylesheet source.
  const NUMERIC_STYLE_KEYS = [
    'top', 'left', 'flex', 'height', 'width', 'fontSize', 'lineHeight',
    'borderRadius', 'paddingTop', 'paddingBottom', 'paddingRight',
    'paddingLeft', 'marginTop', 'marginBottom', 'marginLeft', 'marginRight'
  ];

  let styles = jsStyles ? Object.keys(jsStyles).reduce((r, key) => {
    if(NUMERIC_STYLE_KEYS.includes(key)) {
      if(jsStyles[key] > 0) {
        r[key] = Math.round(jsStyles[key])
      }
    } else {
      r[key] = `'${jsStyles[key]}'`
    }
    return r;
  }, {}) : {};

  // Alignment/justification computed by the flex pass overrides raw styles.
  if(ele.alignJustify) {
    const NUMERIC_ALIGN_KEYS = [
      'paddingTop', 'paddingBottom', 'paddingRight', 'paddingLeft',
      'marginTop', 'marginBottom', 'marginLeft', 'marginRight'
    ];
    Object.keys(ele.alignJustify).forEach((key) => {
      if(NUMERIC_ALIGN_KEYS.includes(key)) {
        if(ele.alignJustify[key] > 0) {
          styles[key] = Math.round(ele.alignJustify[key]);
        }
      } else {
        styles[key] = `'${ele.alignJustify[key]}'`
      }
    })
  }

  // Pass-through attributes supplied directly on the element.
  let attrObjs = {}

  if(js.directAttrs) {
    Object.keys(js.directAttrs).forEach((key) => {
      attrObjs[key] = "'" + js.directAttrs[key] + "'"
    })
  }

  const attrString = Object.keys(attrObjs).map((key) => {
    return `${key}={${attrObjs[key]}}`
  }).join(" ");

  const attrs = attrString.length > 0 ? " " + attrString : "";

  // Style ids must be valid JS identifiers: strip punctuation, avoid a
  // leading digit.
  let styleId = js.id.replace(/[^_0-9a-zA-Z]/g, '')
  if(styleId.match(/^[0-9]/)) {
    styleId = "_" + styleId
  }

  componentStyles[styleId] = styles;
  return({attrs, attrObjs, styleId, componentStyles});
}
136 |
// Depth-first search for a backgroundColor on any node whose dimensions
// exactly match the base element's; the last (deepest-late) match wins.
const firstBackgroundColor = (js, base, idDims) => {
  const baseDims = idDims[base.id]
  const nodeDims = idDims[js.id]
  let backgroundColor = null;

  const coversBase = baseDims && nodeDims &&
    baseDims.top == nodeDims.top &&
    baseDims.bottom == nodeDims.bottom &&
    baseDims.left == nodeDims.left &&
    baseDims.right == nodeDims.right;

  if(coversBase && js.style.backgroundColor) {
    backgroundColor = js.style.backgroundColor;
  }

  // A matching descendant overrides this node's own color.
  for(const child of js.childs) {
    const childColor = firstBackgroundColor(child, base, idDims);
    if(childColor) {
      backgroundColor = childColor;
    }
  }

  return backgroundColor;
}
154 |
155 |
// Computes alignItems/justifyContent (plus compensating margins/padding) for
// a container from the measured gaps between it and its children.
//
// comp     - container: a real element, or a synthetic {id:'row'|'column',
//            parentDims} node produced by the flex pass
// children - a single child object, or an array of children
// idDims   - map of element id -> absolute dimensions
// jsObjs   - map of element id -> parsed element; child styles are mutated
//            in place (alignSelf / margins) in the multi-child case
//
// Returns an object of style overrides for the container.
const determineAlignJustify = (comp, children, idDims, jsObjs) => {
  let alignJustify = {}

  if(comp.id == 'row' && children && children.length > 0 && idDims[children[0].id]) {
    const spaceBefore = idDims[children[0].id].left - comp.parentDims.left;
    const spaceAfter = comp.parentDims.right - idDims[children[children.length - 1].id].right;
    let spacesBetween = [];

    children.forEach((child, index) => {
      if(index > 0) {
        spacesBetween.push(idDims[children[index].id].left - idDims[children[index - 1].id].right)
      }
    });

    if(spacesBetween.length > 0 && allAboutEqual(spacesBetween)) {
      // space between case
      if(aboutEqual(spaceBefore, spaceAfter)) {
        if(smallComparedTo(spaceBefore, spacesBetween[0])) {
          alignJustify['justifyContent'] = 'space-between'
        }
      }
      // space around case
      if(aboutEqual(spaceBefore, spaceAfter) && !aboutZero(spaceBefore)) {
        alignJustify['justifyContent'] = 'space-around'
      }
    }

  } else if(comp.id == 'column') {

  } else {
    if(children.id) {
      // One child
      const childDims = children.id == 'row' || children.id == 'column' ? children.parentDims : idDims[children.id];
      // Fix: this previously tested `children.id == 'column'` (copy-paste
      // from the line above), so a synthetic 'column' container looked up
      // idDims['column'] (undefined) and crashed below. Test comp's own id.
      const parentDims = comp.id == 'row' || comp.id == 'column' ? comp.parentDims : idDims[comp.id];

      const spaceBefore = childDims.left - parentDims.left;
      const spaceAfter = parentDims.right - childDims.right;
      const spaceAbove = childDims.top - parentDims.top;
      const spaceBelow = parentDims.bottom - childDims.bottom;

      if(aboutEqual(spaceBefore, spaceAfter)) {
        alignJustify['alignItems'] = 'center'
      } else if(spaceBefore > spaceAfter) {
        alignJustify['alignItems'] = 'flex-end'
        alignJustify['marginRight'] = spaceAfter
      } else if(spaceBefore < spaceAfter) {
        alignJustify['alignItems'] = 'flex-start'
        alignJustify['marginLeft'] = spaceBefore
      }

      if(aboutEqual(spaceAbove, spaceBelow)) {
        alignJustify['justifyContent'] = 'center'
      } else if(spaceAbove > spaceBelow) {
        alignJustify['justifyContent'] = 'flex-end'
        alignJustify['marginBottom'] = spaceBelow
      } else if(spaceAbove < spaceBelow) {
        alignJustify['justifyContent'] = 'flex-start'
        alignJustify['marginTop'] = spaceAbove
      }

    } else {
      // Multiple Children; default to column

      const firstChildDims = children[0].id == 'row' || children[0].id == 'column' ? children[0].parentDims : idDims[children[0].id]
      const lastChildDims = children[children.length - 1].id == 'row' || children[children.length - 1].id == 'column' ? children[children.length - 1].parentDims : idDims[children[children.length - 1].id]
      // Fix: same copy-paste bug as above (`children.id == 'column'`).
      const parentDims = comp.id == 'row' || comp.id == 'column' ? comp.parentDims : idDims[comp.id];

      const spaceAbove = firstChildDims.top - parentDims.top;
      const spaceBelow = parentDims.bottom - lastChildDims.bottom;

      let spacesBetween = [];
      children.forEach((child, index) => {
        const childJS = jsObjs[child.id];
        const childDims = child.id == 'row' || child.id == 'column' ? child.parentDims : idDims[child.id]

        if(child.id != 'row' && child.id != 'column') {
          const spaceBefore = childDims.left - parentDims.left;
          const spaceAfter = parentDims.right - childDims.right;
          if(aboutEqual(spaceBefore, spaceAfter)) {
            childJS.style['alignSelf'] = 'center';
          } else if(spaceBefore > spaceAfter) {
            childJS.style['alignSelf'] = 'flex-end';
            childJS.style['marginRight'] = spaceAfter;
          } else if(spaceBefore < spaceAfter) {
            childJS.style['alignSelf'] = 'flex-start';
            childJS.style['marginLeft'] = spaceBefore;
          }

        }

        if(index > 0) {
          const lastChildDims = children[index - 1].id == 'row' || children[index - 1].id == 'column' ? children[index - 1].parentDims : idDims[children[index - 1].id]
          spacesBetween.push(childDims.top - lastChildDims.bottom)
          if(childJS) {
            childJS.style['marginTop'] = childDims.top - lastChildDims.bottom;
          }
        }
      });

      if(!aboutEqual(spaceBelow, spaceAbove)) {
        // apply padding
        alignJustify['paddingTop'] = spaceAbove;
        alignJustify['paddingBottom'] = spaceBelow;
      }

    }
  }

  return alignJustify;
}
266 |
267 |
268 | module.exports.nativeAttrs = nativeAttrs;
269 | module.exports.determineAlignJustify = determineAlignJustify;
270 | module.exports.firstBackgroundColor = firstBackgroundColor;
271 |
--------------------------------------------------------------------------------
/src/components.js:
--------------------------------------------------------------------------------
1 |
// For every element id, finds the nearest enclosing element drawn before it,
// producing a child-id -> parent-id map.
const generateChildParent = (orderedIds, idDims) => {
  const childParent = {};

  // Dimensions may be stored flipped; normalize to top<=bottom, left<=right.
  const normalize = (d) => ({
    top: Math.min(d.top, d.bottom),
    bottom: Math.max(d.top, d.bottom),
    left: Math.min(d.left, d.right),
    right: Math.max(d.left, d.right),
  });

  for(const key of orderedIds) {
    const box = normalize(idDims[key]);

    // All other elements whose bounding box fully contains this one.
    const containers = orderedIds.filter((candidate) => {
      if(candidate == key) {
        return false;
      }
      const cBox = normalize(idDims[candidate]);
      return box.top >= cBox.top && box.bottom <= cBox.bottom &&
        box.left >= cBox.left && box.right <= cBox.right;
    });

    // Walk containers from last to first and keep the first one that occurs
    // earlier in the draw order than this element.
    for(let i = containers.length - 1; i >= 0; i--) {
      if(orderedIds.indexOf(containers[i]) < orderedIds.indexOf(key)) {
        childParent[key] = containers[i];
        break;
      }
    }
  }

  return childParent;
}
43 |
44 | module.exports.generateChildParent = generateChildParent;
45 |
46 |
--------------------------------------------------------------------------------
/src/flex.js:
--------------------------------------------------------------------------------
1 |
2 | const { determineAlignJustify } = require('./attributes');
3 |
4 |
5 |
6 | const flexColumns = (nodeIds, idDims, parentChildren, childParent) => {
7 | if(nodeIds.length == 0) {
8 | return nodeIds
9 | }
10 |
11 | const orderedNodeIds = nodeIds.sort((a, b) => {
12 | const adims = idDims[a];
13 | const bdims = idDims[b];
14 |
15 | if(adims.left == bdims.left) {
16 | return adims.top - bdims.top;
17 | }
18 |
19 | return adims.left - bdims.left
20 | });
21 |
22 | let colBreaks = [0];
23 | let currentRight = null;
24 |
25 | orderedNodeIds.forEach((nodeId, i) => {
26 | const node = idDims[nodeId];
27 | if(!currentRight) {
28 | currentRight = node.right;
29 | } else {
30 | if(node.left > currentRight) {
31 | currentRight = node.right;
32 | // done with break
33 | colBreaks.push(i)
34 | } else {
35 | // same row
36 | if(node.right > currentRight) {
37 | currentRight = node.right
38 | }
39 | }
40 | }
41 |
42 | })
43 |
44 | colBreaks.push(orderedNodeIds.length)
45 | let cols = [];
46 |
47 | for(let i=0; i {
72 | const children = parentChildren[c] ? parentChildren[c] : [];
73 | return {
74 | id: c,
75 | children: flexBox(children, idDims, parentChildren, childParent)
76 | }
77 | })
78 | });
79 | } else {
80 | let parentDims = idDims[childParent[comps[0]]];
81 | const newLeft = minLeft(comps, idDims);
82 | const newRight = maxRight(comps, idDims);
83 |
84 | if(newLeft) {
85 | parentDims.left = newLeft
86 | }
87 | if(newRight) {
88 | parentDims.right = newRight
89 | }
90 |
91 | cols.push({
92 | id: 'column',
93 | parentDims,
94 | children: flexRows(comps, idDims, parentChildren, childParent)
95 | })
96 | }
97 | }
98 |
99 | return cols.length == 1 ? cols[0] : cols;
100 | }
101 |
102 |
103 |
104 | const flexRows = (nodeIds, idDims, parentChildren, childParent) => {
105 | if(nodeIds.length == 0) {
106 | return nodeIds
107 | }
108 |
109 | const orderedNodeIds = nodeIds.sort((a, b) => {
110 | const adims = idDims[a];
111 | const bdims = idDims[b];
112 |
113 | if(adims.top == bdims.top) {
114 | return adims.left - bdims.left;
115 | }
116 |
117 | return adims.top - bdims.top
118 | });
119 |
120 | let rowBreaks = [0];
121 | let currentBottom = null;
122 |
123 | orderedNodeIds.forEach((nodeId, i) => {
124 | // TODO: node is sometimes null here?
125 | const node = idDims[nodeId];
126 | if(!currentBottom) {
127 | currentBottom = node && node.bottom;
128 | } else {
129 | if(node.top > currentBottom) {
130 | currentBottom = node && node.bottom;
131 | // done with break
132 | rowBreaks.push(i)
133 | } else {
134 | // same row
135 | if(node && node.bottom > currentBottom) {
136 | currentBottom = node && node.bottom
137 | }
138 | }
139 | }
140 |
141 | })
142 |
143 | rowBreaks.push(orderedNodeIds.length)
144 | let rows = [];
145 |
146 |
147 | for(let i=0; i {
171 | const children = parentChildren[c] ? parentChildren[c] : [];
172 | return {
173 | id: c,
174 | children: flexBox(children, idDims, parentChildren, childParent)
175 | }
176 | })
177 | });
178 | } else {
179 | let parentDims = idDims[childParent[comps[0]]];
180 | const newTop = minTop(comps, idDims);
181 | const newBottom = maxBottom(comps, idDims);
182 |
183 | if(newTop) {
184 | parentDims.top = newTop
185 | }
186 | if(newBottom) {
187 | parentDims.bottom = newBottom
188 | }
189 |
190 | rows.push({
191 | id: 'row',
192 | parentDims,
193 | children: flexColumns(comps, idDims, parentChildren, childParent)
194 | })
195 | }
196 | }
197 |
198 | return rows.length == 1 ? rows[0] : rows;
199 | }
200 |
201 |
202 |
203 |
// Smallest `top` among the given ids. Assumes `ids` is non-empty and every
// id has an entry in idDims.
const minTop = (ids, idDims) => {
  let min = idDims[ids[0]].top
  ids.forEach((id) => {
    // Fix: compare each id's dims (was idDims[ids[0]], which always looked
    // at the first element and so never updated the minimum).
    if(idDims[id].top < min) {
      min = idDims[id].top;
    }
  })
  return min
}
213 |
// Largest `bottom` among the given ids. The seed is defensive: it tolerates
// an empty list or a missing first entry.
const maxBottom = (ids, idDims) => {
  let max = ids && ids[0] && idDims[ids[0]] && idDims[ids[0]].bottom
  ids.forEach((id) => {
    // Fix: compare each id's dims (was idDims[ids[0]], which always looked
    // at the first element and so never updated the maximum).
    if(idDims[id].bottom > max) {
      max = idDims[id].bottom;
    }
  })
  return max
}
223 |
224 |
// Smallest `left` edge among the given ids. Assumes `ids` is non-empty and
// every id has an entry in idDims.
const minLeft = (ids, idDims) => {
  let min = idDims[ids[0]].left
  ids.forEach((id) => {
    // Fix: compare each id's dims (was idDims[ids[0]], which always looked
    // at the first element and so never updated the minimum).
    if(idDims[id].left < min) {
      min = idDims[id].left;
    }
  })
  return min
}
234 |
// Largest `right` edge among the given ids. Assumes `ids` is non-empty and
// every id has an entry in idDims.
const maxRight = (ids, idDims) => {
  let max = idDims[ids[0]].right
  ids.forEach((id) => {
    // Fix: compare each id's dims (was idDims[ids[0]], which always looked
    // at the first element and so never updated the maximum).
    if(idDims[id].right > max) {
      max = idDims[id].right;
    }
  })
  return max
}
244 |
245 |
246 |
247 |
// Entry point: converts a flat list of node ids into a nested row/column
// layout tree by delegating to flexRows.
const flexBox = (nodeIds, idDims, parentChildren, childParent) => {
  // NOTE! nodes here can be cols or rows; currently it's just
  // based on whatever they are in the design... This
  // will have to be addressed when there is a design that
  // has sibling columns instead of rows.

  // SO, a todo here is to detect if siblings are rows or columns
  // (can they be both? no. have to break up into rows first, or cols first.)
  return flexRows(nodeIds, idDims, parentChildren, childParent)
}
258 |
259 |
260 |
261 |
262 |
// Flattens a nested row/column component tree (from flexBox) into an ordered
// list of {spaces, id, single|end, alignJustify} entries used to emit JSX.
// NOTE(review): the outer parameter is spelled `roodId` -- it is only
// forwarded to the inner closure (as `rootId`), so behavior is unaffected,
// but it looks like a typo for `rootId`.
const flattenBoxComponents = (comps, roodId, idDims, jsObjs, indent=0) => {
  // Shared accumulator mutated by the recursive closure below.
  let flatEles = [];

  const flattenBoxComps = (comps, rootId, idDims, jsObjs, indent=0) => {
    if(comps.id && comps.children && comps.children.id) {
      // single element with one child
      // If this wrapper's dims exactly match the root's, skip it and recurse
      // straight into the child instead of emitting a redundant container.
      const compDims = idDims[comps.id];
      const rootDims = idDims[rootId];
      if(rootDims.top == compDims.top && rootDims.bottom == compDims.bottom && rootDims.left == compDims.left && rootDims.right == compDims.right) {
        flattenBoxComps(comps.children, rootId, idDims, jsObjs, indent)
        return flatEles
      }
    }

    // Indentation string for the generated source at this depth.
    const spaces = ' ' + new Array(indent + 1).join(' ');
    if(comps && comps.id) {
      // A single component: has children when it holds one child object or a
      // non-empty child array.
      const hasChildren = comps.children && comps.children.id || comps.children.length > 0;

      let alignJustify = {}
      if(hasChildren) {
        alignJustify = determineAlignJustify(comps, comps.children, idDims, jsObjs);
      }

      flatEles.push({spaces, id: comps.id, single: !hasChildren, alignJustify});
      if(hasChildren) {
        // Emit the children, then a closing marker for this container.
        flattenBoxComps(comps.children, rootId, idDims, jsObjs, indent + 1)
        flatEles.push({spaces, id: comps.id, end: true});
      }
    }
    if(comps && Array.isArray(comps)) {
      // A list of sibling components: same handling as above, per element.
      comps.forEach((comp) => {
        const hasChildren = comp.children && comp.children.id || comp.children.length > 0;

        let alignJustify = {}
        if(hasChildren) {
          alignJustify = determineAlignJustify(comp, comp.children, idDims, jsObjs);
        }

        flatEles.push({spaces, id: comp.id, single: !hasChildren, alignJustify});
        if(hasChildren) {
          flattenBoxComps(comp.children, rootId, idDims, jsObjs, indent + 1)
          flatEles.push({spaces, id: comp.id, end: true});
        }
      })
    }
  }

  flattenBoxComps(comps, roodId, idDims, jsObjs, indent);

  return flatEles;
}
314 |
315 |
316 |
317 |
// Public API consumed by src/index.js.
module.exports.flexBox = flexBox;
module.exports.flattenBoxComponents = flattenBoxComponents;
320 |
--------------------------------------------------------------------------------
/src/index.js:
--------------------------------------------------------------------------------
1 | require("regenerator-runtime/runtime");
2 |
3 | const svgson = require('svgson');
4 | const fs = require('fs');
5 | const path = require('path');
6 |
7 | const { emptyAndCreateDir, makeDir, copyFolderRecursive } = require('./lib/files');
8 | const { processNode, imagifyParents } = require('./process');
9 | const { getBrowserBoundingBoxes, screenshotElements } = require('./screenshot');
10 | const { flexBox, flattenBoxComponents } = require('./flex');
11 | const { generateComponent, generateComponentStrings } = require('./output');
12 | const { firstBackgroundColor, nativeAttrs } = require('./attributes');
13 | const { generateChildParent } = require('./components');
14 | const { prepData } = require('./input');
15 | const { removeStatusBarAndKeyboard } = require('./neural_net');
16 | const {
17 | aboutZero,
18 | aboutEqual,
19 | allAboutEqual,
20 | smallComparedTo
21 | } = require('./lib/utils');
22 |
// catch unhandled rejections (e.g. async/await without try/catch)
process.on("unhandledRejection", function(err) { console.error(err); });

// First CLI argument must be a path to an .svg file.
const INPUT_FILE = process.argv[2]
if(!INPUT_FILE || INPUT_FILE == '' || !INPUT_FILE.match(/\.svg$/)) {
  // Fix: throw a real Error (not a bare string) so a stack trace exists.
  throw new Error("Usage: convert.js [svg_file]")
}

const pathArray = INPUT_FILE.split('/')
const INPUT_FILENAME = pathArray[pathArray.length - 1]

// Basename with spaces replaced and the .svg extension stripped.
const INPUT_FILE_NO_SPACES = INPUT_FILENAME.replace(/\s/g, '_').split(".svg")[0]
// The extension is already stripped above, so just append the new ones.
const OUTPUT_FILE = INPUT_FILE_NO_SPACES + ".js"

const BASE_PATH = path.resolve();
const OUTPUT_DIR = 'output';
const TEMP_DIR = 'temp';
const IMAGES_DIR = INPUT_FILE_NO_SPACES + '_images';
const TEMP_IMAGES_DIR = path.join(BASE_PATH, TEMP_DIR, IMAGES_DIR);
const TEMP_COMPONENT_DIR = path.join(BASE_PATH, TEMP_DIR, 'components');

makeDir(OUTPUT_DIR); // don't delete what's in output each time!
emptyAndCreateDir(TEMP_DIR);
emptyAndCreateDir(TEMP_IMAGES_DIR);
emptyAndCreateDir(TEMP_COMPONENT_DIR);
48 |
// Main conversion pipeline:
// SVG text -> sanitized ids -> svgson JSON -> component tree ->
// neural-net cleanup -> browser measurement -> flexbox inference ->
// generated React Native component + exported images.
(async () => {

  // NOTE(review): readFile's `err` is ignored — a missing input file only
  // surfaces later as `data` being undefined. Consider handling it here.
  fs.readFile(INPUT_FILE, 'utf-8', (err, data) => {

    // Sanitize ids and write the prepped SVG copy into TEMP_DIR.
    const preppedData = prepData({
      data,
      tempDir: TEMP_DIR,
      inputFile: INPUT_FILENAME
    });

    svgson(preppedData, {}, async function(result) {

      // svgson JSON -> intermediate component tree.
      const processedJS = processNode(result);

      const preppedFile = 'file://'+path.join(BASE_PATH, TEMP_DIR, 'prepped_'+INPUT_FILENAME);

      // Screenshot every element and prune status bars / keyboards.
      const cleanedJS = await removeStatusBarAndKeyboard(preppedFile, TEMP_COMPONENT_DIR, processedJS);

      // NOTE(review): this result is discarded; boxes are re-measured
      // below after imagifyParents. Confirm whether this extra browser
      // round-trip is actually needed.
      await getBrowserBoundingBoxes(cleanedJS, preppedFile);

      // Collapse purely image-like subtrees into Image leaves, then
      // measure the final tree in the browser.
      const js = imagifyParents(cleanedJS)
      const { idDims, orderedIds } = await getBrowserBoundingBoxes(js, preppedFile);

      const mainBackgroundColor = firstBackgroundColor(js.childs[0], js.childs[0], idDims);

      // Containment maps: child id -> parent id, and the inverse.
      const childParent = generateChildParent(orderedIds, idDims);

      let parentChildren = {};
      orderedIds.forEach((id) => {
        if(childParent[id]) {
          if(!parentChildren[childParent[id]]) {
            parentChildren[childParent[id]] = [];
          }
          parentChildren[childParent[id]].push(id);
        }
      });

      // Flat id -> node lookup over the whole tree.
      let jsObjs = {}
      const unrollJs = (js) => {
        jsObjs[js.id] = js;
        js.childs.forEach((child) => {
          unrollJs(child)
        })
      }
      unrollJs(js);

      // Infer the flexbox row/column structure, then linearize it for
      // string generation.
      const boxComponents = flexBox([orderedIds[0]], idDims, parentChildren, childParent)

      const flatEles = flattenBoxComponents(boxComponents, boxComponents.id, idDims, jsObjs);

      // Vector-ish elements are exported as PNG assets instead of JSX.
      const polygons = orderedIds.filter((id) => {
        const item = jsObjs[id];
        return(item.type == 'Polygon' || item.type == 'Path' || item.type == 'Image');
      })

      await screenshotElements(
        preppedFile,
        TEMP_IMAGES_DIR,
        polygons
      );

      // Merge the per-element styles into one StyleSheet object.
      let globalStyles = {};
      flatEles.forEach((ele) => {
        const js = jsObjs[ele.id];
        if(!ele.end && ele.id != 'row' && ele.id != 'column') {
          const { componentStyles } = nativeAttrs(js, ele, idDims, childParent);
          globalStyles = {...globalStyles, ...componentStyles}
        }
      });

      const { imports, componentStrings } = generateComponentStrings({
        flatEles,
        idDims,
        childParent,
        jsObjs,
        imagesDir: IMAGES_DIR
      });

      const generatedComponent = generateComponent({
        imports,
        rootStyle: processedJS.rootStyle,
        mainBackgroundColor,
        componentStrings,
        globalStyles
      })
      // Copy image assets and write the final component into output/.
      emptyAndCreateDir(OUTPUT_DIR + '/' + IMAGES_DIR)
      copyFolderRecursive(TEMP_IMAGES_DIR, OUTPUT_DIR + '/' + IMAGES_DIR)
      fs.writeFileSync(OUTPUT_DIR + '/' + OUTPUT_FILE, generatedComponent)

      console.log("")
      console.log("Images directory written: ", path.join(BASE_PATH, OUTPUT_DIR, IMAGES_DIR))
      console.log("React Native component generated: ", path.join(BASE_PATH, OUTPUT_DIR, OUTPUT_FILE))
      console.log("")

    });
  })

})();
147 |
148 |
149 |
150 |
--------------------------------------------------------------------------------
/src/input.js:
--------------------------------------------------------------------------------
1 |
2 | const fs = require('fs');
3 |
// Rewrite every element id in the raw SVG text so ids are safe to use as
// DOM selectors and JS identifiers: strip non-word characters, prefix ids
// that start with a digit, and de-duplicate repeats. Writes the prepped
// SVG to <tempDir>/prepped_<inputFile> and returns the prepped text.
const prepData = ({ data, tempDir, inputFile }) => {
  let idList = [];

  // NOTE(review): this regex looks wrong — `\[\s\S]*` matches a literal
  // "[" followed by whitespace + non-whitespace + "]"s, not "anything".
  // It was presumably meant to capture the whole <defs>...</defs> section
  // so it could be restored untouched after the id rewrite below; as
  // written, oldDefs is null and the restore is a no-op. Left as-is since
  // the rest of the pipeline is tuned to the current behavior — confirm
  // intent before changing.
  const oldDefs = data.match(/\[\s\S]*\<\/defs\>/g)

  const newData = data.replace(/\sid=\"[^\"]+/g, function(d) {
    // Keep only [_0-9a-zA-Z] characters from the original id.
    let onlyId = d.replace(/^\sid\=\"/, '').replace(/[^\_0-9a-zA-Z]/g, '');

    // Ids must not start with a digit (CSS selectors / JS identifiers).
    if(onlyId.match(/^[0-9]+/)) {
      onlyId = "_" + onlyId
    }

    d = ' id="' + onlyId;

    // De-duplicate: suffix repeated ids with a running counter.
    if(idList.indexOf(d) >= 0) {
      const newId = d + "_" + idList.length
      idList.push(newId)
      return newId
    } else {
      idList.push(d);
      return d;
    }
  });

  // See NOTE(review) above — currently a no-op because the pattern never
  // matches (and `oldDefs` would be an array, not a string, if it did).
  const preppedData = newData.replace(/\[\s\S]*\<\/defs\>/g, oldDefs)

  fs.writeFileSync(tempDir + '/prepped_'+inputFile, preppedData)

  return preppedData;
}

module.exports.prepData = prepData;
36 |
--------------------------------------------------------------------------------
/src/lib/files.js:
--------------------------------------------------------------------------------
1 | const fs = require('fs');
2 |
// Recursively delete a directory and everything inside it.
// Silently does nothing when the directory does not exist.
const deleteFolderRecursive = function(path) {
  if (!fs.existsSync(path)) {
    return;
  }
  for (const entry of fs.readdirSync(path)) {
    const entryPath = path + "/" + entry;
    if (fs.lstatSync(entryPath).isDirectory()) {
      deleteFolderRecursive(entryPath); // descend into subdirectory
    } else {
      fs.unlinkSync(entryPath); // plain file: remove it
    }
  }
  fs.rmdirSync(path); // directory is empty now
};
17 |
// Create `path` if it does not exist yet; leave it untouched otherwise.
module.exports.makeDir = (path) => {
  if (fs.existsSync(path)) {
    return;
  }
  fs.mkdirSync(path);
}
23 |
// Wipe `dir` (if present) and recreate it empty.
module.exports.emptyAndCreateDir = (dir) => {
  deleteFolderRecursive(dir);
  fs.mkdirSync(dir);
}
28 |
// Recursively copy the contents of `path` into `toPath`.
// Fixes two bugs: the recursive call referenced the bare name
// `copyFolderRecursive` (undefined for an anonymous function expression,
// so recursing into a subdirectory threw a ReferenceError) and it dropped
// the destination argument. Destination directories are now created as
// needed, so nested folders copy correctly.
module.exports.copyFolderRecursive = function copyFolderRecursive(path, toPath) {
  if (fs.existsSync(path)) {
    if (!fs.existsSync(toPath)) {
      fs.mkdirSync(toPath);
    }
    fs.readdirSync(path).forEach(function(file, index){
      const curPath = path + "/" + file;
      const newPath = toPath + "/" + file;
      if (fs.lstatSync(curPath).isDirectory()) { // recurse into subdirectory
        copyFolderRecursive(curPath, newPath);
      } else { // copy file (async stream; callers don't await completion)
        fs.createReadStream(curPath).pipe(fs.createWriteStream(newPath));
      }
    });
  }
};
42 |
43 |
--------------------------------------------------------------------------------
/src/lib/utils.js:
--------------------------------------------------------------------------------
1 | // TODO: remove hardcoded values
2 |
// True when `a` is within 6 units of zero (layout slop tolerance).
const aboutZero = (a) => Math.abs(a) < 6
6 |
// True when `a` and `b` differ by less than 17 units — loose enough to
// treat two edges as "centered" on each other.
const aboutCentered = (a, b) => {
  const diff = a - b;
  return Math.abs(diff) < 17;
}
10 |
// True when `a` and `b` are within 6 units of each other.
const aboutEqual = (a, b) => Math.abs(b - a) < 6
14 |
// Intended meaning: is `a` small compared to `b`?
// NOTE(review): the ratio test looks inverted/misnamed — `b / a > 0.5` is
// true even when `a` is much larger than `b` (e.g. a=100, b=60). Callers
// appear tuned to the current behavior, so it is documented rather than
// changed; confirm intent before "fixing".
const smallComparedTo = (a, b) => {
  // An `a` of ~0 is always small next to a non-~0 `b`.
  if(aboutZero(a) && !aboutZero(b)) {
    return true
  }
  const ratio = b / a;
  return(a != 0 && ratio > 0.5)
}
22 |
// True when every value in `arr` sits within the aboutEqual tolerance of
// the array's mean. An empty array is trivially "all equal".
const allAboutEqual = (arr) => {
  const total = arr.reduce(function(sum, a) { return sum + a }, 0);
  const avg = total / (arr.length || 1);
  return arr.every((ele) => aboutEqual(avg, ele));
}
29 |
30 |
// Shared fuzzy-comparison helpers used across the layout pipeline.
module.exports.aboutZero = aboutZero;
module.exports.aboutCentered = aboutCentered;
module.exports.aboutEqual = aboutEqual;
module.exports.allAboutEqual = allAboutEqual;
module.exports.smallComparedTo = smallComparedTo;
36 |
--------------------------------------------------------------------------------
/src/neural_net.js:
--------------------------------------------------------------------------------
1 |
2 | const exec = require('child_process').exec;
3 | const fs = require('fs');
4 | const path = require('path');
5 |
6 | const { screenshotAllElements } = require('./screenshot');
7 |
8 |
// Run the retrained TensorFlow classifier over the screenshot taken for
// `child` and resolve true when the model is near-certain the element is
// OS chrome (status bar or keyboard) that should be removed.
// Resolves false for children without an id, without a screenshot on
// disk, or when the classifier produces no output; never rejects.
const childIsBadElement = (async (child, tempDir) => {
  if(!child.id || child.id == '') {
    return false
  }

  const filepath = path.join(tempDir, 'components/'+child.id+'.png');
  if (!fs.existsSync(filepath)) {
    return false;
  }

  const command = `python -m ${path.join(__dirname, '../scripts/label_image')} --graph=${path.join(__dirname, '../tf_files/retrained_graph.pb')} --image=`+filepath;

  return new Promise((resolve) => {
    exec(command,
      function (error, stdout, stderr) {
        // Fix: the original called resolve() a second time after the
        // stdout branch had already resolved; each path now resolves
        // exactly once. stdout is still checked first, matching the
        // original's effective behavior when both stdout and error exist.
        if(stdout && stdout != '') {
          const results = stdout.split("\n")

          // Last whitespace-separated token of the first line matching
          // `prefix` is the classifier's confidence score (0 if absent).
          const confidenceFor = (prefix) => {
            const matching = results.filter((result) => result.match(prefix));
            if(matching.length === 0) {
              return 0;
            }
            const parts = matching[0].split(/\s/);
            return parseFloat(parts[parts.length - 1]);
          };

          const isStatusBar = confidenceFor(/^status\sbar/);
          const isKeyboard = confidenceFor(/^keyboard/);

          // Make sure you're really sure it's a status bar or keyboard
          // just text can trigger even a 98% status bar response
          resolve(isStatusBar > 0.99 || isKeyboard > 0.99)
          return;
        }

        if (error !== null) {
          console.log('exec error: ' + error);
        }
        resolve(false)
      });
  });

});
52 |
53 |
// Depth-first filter: remove any descendant the classifier flags as a
// bad element (status bar / keyboard). Mutates `js.childs` in place and
// returns `js`. Children are classified in parallel.
const filterElements = (async (js, tempDir) => {
  if(!js.childs) {
    return js;
  }

  const checked = await Promise.all(js.childs.map(async (child) => {
    const isBad = await childIsBadElement(child, tempDir);
    if(isBad) {
      return null;
    }
    await filterElements(child, tempDir);
    return child;
  }));

  js.childs = checked.filter((child) => child);
  return js;
})
72 |
// Screenshot every element in the tree, then prune the ones the neural
// net identifies as a status bar or keyboard. Returns the pruned tree.
const removeStatusBarAndKeyboard = (async (file, tempDir, js) => {
  // Screenshots must exist on disk before classification can run.
  await screenshotAllElements(file, tempDir, js)
  console.warn("filtering elements...")
  return await filterElements(js, tempDir)
})



module.exports.removeStatusBarAndKeyboard = removeStatusBarAndKeyboard;
85 |
--------------------------------------------------------------------------------
/src/output.js:
--------------------------------------------------------------------------------
1 |
2 | const { nativeAttrs } = require('./attributes');
3 |
4 |
// Note: the spacing (indentation) is important (even though it looks weird)
// for everything in the generated StyleSheet block — the template-literal
// indentation becomes the generated file's formatting.
// Renders { styleId: { key: value, ... }, ... } into the body of a
// StyleSheet.create(...) call.
const generateStyleSheetString = (componentStyles) => {
  return Object.keys(componentStyles).map((key) => {
    const keyStyles = Object.keys(componentStyles[key]).map((styleKey) => {
      const styleString = componentStyles[key][styleKey]
      return ` ${styleKey}: ${styleString}`
    }).join(",\n");
    return ` ${key}: {
${keyStyles}
}`
  }).join(",\n")
}
18 |
19 |
// Turn the flattened element list into the JSX body of the generated
// component. Returns { componentStrings, imports } where imports are the
// image import lines for rasterized Path/Polygon/Image elements.
// NOTE(review): several branches return only `${ele.spaces}` with no
// markup — verify the JSX tag text in these template literals was not
// lost; compare against a known-good generated component.
const generateComponentStrings = ({
  flatEles,
  idDims,
  childParent,
  jsObjs,
  imagesDir
}) => {
  let imports = [];

  const componentStrings = flatEles.map((ele) => {

    // NOTE(review): `color` (a random debug color) is never used below.
    const color = '#'+'0123456789abcddd'.split('').map(function(v,i,a){ return i>5 ? null : a[Math.floor(Math.random()*16)] }).join('');

    // Closing tags for the synthetic row/column wrappers.
    if(ele.end) {
      if(ele.id == 'row') return `${ele.spaces}`
      if(ele.id == 'column') return `${ele.spaces}`
      // The rest of the ends are covered below
    }

    // Opening tag for a synthetic row wrapper, with inferred flex styles.
    if(ele.id == 'row') {
      let styleAttrs = [`flexDirection: 'row'`]
      if(ele.alignJustify) {
        const aj = ele.alignJustify
        if(aj.justifyContent) styleAttrs.push(`justifyContent: '${aj.justifyContent}'`)
        if(aj.alignItems) styleAttrs.push(`alignItems: '${aj.alignItems}'`)
        if(aj.marginRight) styleAttrs.push(`marginRight: '${aj.marginRight}'`)
        if(aj.marginLeft) styleAttrs.push(`marginLeft: '${aj.marginLeft}'`)
        if(aj.marginTop) styleAttrs.push(`marginTop: '${aj.marginTop}'`)
        if(aj.marginBottom) styleAttrs.push(`marginBottom: '${aj.marginBottom}'`)
      }
      return `${ele.spaces}`
    }
    if(ele.id == 'column') {
      return `${ele.spaces}`
    }

    const js = jsObjs[ele.id];

    // Closing tags for real elements, by node type.
    if(ele.end) {
      if(js.type == 'Path') return `${ele.spaces}`
      if(js.type == 'Image') return `${ele.spaces}`
      if(js.type == 'Polygon') return `${ele.spaces}`
      return `${ele.spaces}`
    }

    const {attrs, attrObjs, styleId} = nativeAttrs(js, ele, idDims, childParent);

    const styles = styleId ? ` style={styles.${styleId}}` : ''

    if(js.type == 'Text') {
      if(js.childs.length == 0) {
        return `${ele.spaces}${js.text}${ele.single ? '' : ''}`
      }
      // Single tspan child: inline its text directly.
      if(js.childs.length == 1 && js.childs[0].type == 'Tspan') {
        return `${ele.spaces}${js.text}${js.childs[0].text}`
      } else {
        // Multiple tspans: one line each, newline-separated.
        const tspans = js.childs.map((t) => {
          if(t.type == 'Tspan') {
            // TODO: once tspans have ids, change this
            // const {Tattrs, TattrObjs, TstyleId} = nativeAttrs(t, ele);
            // const Tstyles = TstyleId ? ` style={styles.${styleId}}` : ''
            const Tstyles = ''
            return `${ele.spaces}  ${t.text}{'\\n'}`
          }
        }).join("\n")

        return `${ele.spaces}${js.text}\n${tspans}\n${ele.spaces}`
      }

    }
    if(js.type == 'Tspan') {
      return `${js.text}`
    }

    // Polygon/Image/Path were rasterized earlier; emit an image import
    // (once per id) and reference the PNG asset.
    if(js.type == 'Polygon') {
      const imageStyle = attrObjs.style ? ` style={${attrObjs.style}}` : '';
      const jsname = js.id.replace("-", "").replace("+", "")
      if(imports.indexOf(`import ${jsname} from './${imagesDir}/${js.id}.png'`) < 0) {
        imports.push(`import ${jsname} from './${imagesDir}/${js.id}.png'`)
      }
      return `${ele.spaces}`
    }

    if(js.type == 'Image') {
      const imageStyle = attrObjs.style ? ` style={${attrObjs.style}}` : '';
      const jsname = js.id.replace("-", "").replace("+", "")
      if(imports.indexOf(`import ${jsname} from './${imagesDir}/${js.id}.png'`) < 0) {
        imports.push(`import ${jsname} from './${imagesDir}/${js.id}.png'`)
      }
      return `${ele.spaces}`
    }

    if(js.type == 'Path') {
      const imageStyle = attrObjs.style ? ` style={${attrObjs.style}}` : '';
      if(ele.end) {
        return `${ele.spaces}`
      }
      const jsname = js.id.replace("-", "").replace("+", "")
      if(imports.indexOf(`import ${jsname} from './${imagesDir}/${js.id}.png'`) < 0) {
        imports.push(`import ${jsname} from './${imagesDir}/${js.id}.png'`)
      }
      if(ele.single) {
        return `${ele.spaces}`
      } else {
        return `${ele.spaces}`
      }
    }

    // Fallback: generic element rendered with its native type name.
    if(ele.end) {
      return(ele.spaces+'');
    }
    return `${ele.spaces}<${js.type}${attrs}${styles}${ele.single ? ' /' : ''}>`

  }).join("\n");

  return({
    componentStrings,
    imports
  })
}
140 |
141 |
142 |
// Assemble the final React Native component source file: fixed react /
// react-native imports, the generated image imports, the JSX body, and
// the StyleSheet.
// NOTE(review): the template relies on exact indentation; verify the
// wrapper markup around `componentStrings` (between `render() { return (`
// and the closing `)`) is intact — it appears empty here.
const generateComponent = ({
  imports,
  rootStyle,
  mainBackgroundColor,
  componentStrings,
  globalStyles
}) => {
  const styleSheetString = generateStyleSheetString(globalStyles);

  // Note: the spacing (indentation) is important (even though it looks weird)
  // for everything in the generated component's render body.
  return (
`import React, { Component } from 'react';

import {
  StyleSheet,
  Text,
  View,
  TouchableOpacity,
  TextInput,
  ScrollView,
  Image
} from 'react-native';

`+imports.join("\n")+`

export default class Main extends Component {

  render() {
    return (

` + componentStrings + `

    )
  }

}

const styles = StyleSheet.create({
${styleSheetString}
})
`
  )
}


module.exports.generateComponent = generateComponent;
module.exports.generateComponentStrings = generateComponentStrings;
194 |
--------------------------------------------------------------------------------
/src/process.js:
--------------------------------------------------------------------------------
1 |
2 |
// Dispatch an svgson node to its type-specific processor.
// Returns undefined for node types we do not handle (callers filter
// those out in processChildren).
const processNode = (node, parentAttrs={}) => {
  const processors = {
    svg: processSVG,
    g: processG,
    text: processText,
    rect: processRect,
    path: processPath,
    polygon: processPolygon,
    tspan: processTspan
  };
  // hasOwnProperty guard: node names like "constructor" must not hit
  // Object.prototype members.
  if(!Object.prototype.hasOwnProperty.call(processors, node.name)) {
    return undefined;
  }
  return processors[node.name](node, parentAttrs);
}
21 |
// Process every child node, dropping the ones processNode returns
// undefined for (unhandled SVG element types).
const processChildren = (parent, parentAttrs = {}) => {
  const children = parent.childs ? parent.childs : [];

  // If all the nodes are going to be images or empty,
  // Then just process the entire thing as one image.

  // NOTE(review): each child is processed twice (once in filter, once in
  // map). Because the processors mutate the shared `parentAttrs`
  // (fontSize, fill, viewBox, ...), the map pass sees attributes
  // accumulated by ALL children's filter pass — collapsing this into a
  // single pass would subtly change inherited attributes, so it is left
  // as-is despite the redundant work.
  return children.filter((node) => {
    const processedNode = processNode(node, parentAttrs);
    return processedNode;
  }).map((node) => {
    return processNode(node, parentAttrs)
  });
}
35 |
36 |
// Convert the root <svg> node into a ScrollView description.
// Records the viewBox origin on parentAttrs so descendants can offset
// their coordinates, and derives the root background color from the
// inline style attribute (defaulting to white).
const processSVG = (node, parentAttrs={}) => {

  const viewBoxArray = node.attrs.viewBox.split(" ")
  const viewBox = {
    x: viewBoxArray[0],
    y: viewBoxArray[1]
  }
  parentAttrs.viewBox = viewBox;

  let rootStyle = {}
  if(node.attrs.style) {
    rootStyle = node.attrs.style.split(/\;\s*/).reduce((r, style) => {
      const keyValue = style.split(/\:\s*/)
      if(keyValue.length == 2) {
        r[keyValue[0]] = keyValue[1]
        // Fix: parenthesize the length check. The original's
        // `a && b || c` grouping set backgroundColor for ANY property
        // whose value happened to be 7 characters long, not just
        // `background` with a #rgb / #rrggbb value.
        if(keyValue[0] == 'background' && (keyValue[1].length == 4 || keyValue[1].length == 7)) {
          r['backgroundColor'] = keyValue[1]
        }
      }
      return r;
    }, {});
  }

  const backgroundColor = rootStyle && rootStyle.background ? rootStyle.background : rootStyle && rootStyle.backgroundColor ? rootStyle.backgroundColor : '#ffffff';
  let styles = {
    flex: 1,
    alignSelf: 'stretch'
  }
  if(backgroundColor) {
    styles.backgroundColor = backgroundColor
  }

  return {
    id: node.attrs.id,
    type: 'ScrollView',
    rootStyle,
    childs: processChildren(node, parentAttrs),
    style: styles
  }
}
77 |
// Polygons are rasterized later (screenshotted to PNGs), so only the id
// needs to survive here — no children, no styles.
const processPolygon = (node, parentAttrs={}) => ({
  id: node.attrs.id,
  type: 'Polygon',
  childs: [],
  style: {}
})
86 |
87 |
// Convert a <g> group into a View. Inheritable text attributes
// (fontSize, fontWeight, lineHeight, fill) are copied onto parentAttrs
// so descendant text nodes pick them up; a <use> child carrying a fill
// paints the group's background.
const processG = (node, parentAttrs={}) => {
  const tl = topLeft(node, parentAttrs)

  let styles = {...tl}

  if(node.attrs.fontSize) {
    parentAttrs.fontSize = node.attrs.fontSize
  }
  if(node.attrs.fontWeight) {
    parentAttrs.fontWeight = node.attrs.fontWeight
  }
  // Font family removed for now

  // if(node.attrs.fontFamily) {
  //   parentAttrs.fontFamily = node.attrs.fontFamily.split(", ")[0]
  // }
  if(node.attrs.lineHeight) {
    // Fix: the original assigned node.attrs.lineSpacing (usually
    // undefined) even though the guard checks lineHeight.
    parentAttrs.lineHeight = node.attrs.lineHeight
  }
  if(node.attrs.fill) {
    parentAttrs.fill = node.attrs.fill
  }

  let use = null;
  node.childs && node.childs.forEach((child) => {
    if(child.name == 'use') {
      if(child.attrs.fill) {
        styles.backgroundColor = child.attrs.fill
      }
    }
  })


  return {
    id: node.attrs.id,
    type: 'View',
    childs: processChildren(node, parentAttrs),
    style: styles
  }
}
128 |
129 |
// Convert a <rect> into a positioned View with background color and
// corner radius.
const processRect = (node, parentAttrs={}) => {
  const tl = topLeft(node, parentAttrs)

  let styles = {...tl}

  const attrs = node.attrs ? node.attrs : {}
  if(attrs.fill) {
    styles.backgroundColor = attrs.fill
  }

  if(attrs.opacity && attrs.fill) {
    // NOTE(review): this appends the opacity's decimal digits to the
    // color (0.5 -> "50"), not a hex alpha byte (0.5 -> "80") — confirm
    // the intended alpha encoding.
    const op = parseFloat(attrs.opacity).toFixed(2).split(".")[1];
    styles.backgroundColor = attrs.fill + op
  }

  if(attrs.rx) {
    styles.borderRadius = attrs.rx
  }

  if(attrs.ry) {
    // Fix: the original assigned attrs.rx here despite checking ry.
    styles.borderRadius = attrs.ry
  }

  return {
    id: node.attrs.id,
    type: 'View',
    childs: processChildren(node, parentAttrs),
    style: styles
  }
}
160 |
161 |
// Convert a <text> node into a Text component. Position comes from x/y;
// typography is inherited from the surrounding <g> via parentAttrs and
// can be overridden by this node's own attributes (local wins).
const processText = (node, parentAttrs={}) => {
  const attrs = node.attrs ? node.attrs : {}

  let style = {}

  if(attrs.x || attrs.y) {
    style.position = 'absolute'
    style.left = attrs.x ? attrs.x : 0
    style.top = attrs.y ? attrs.y : 0
  }

  style.backgroundColor = 'transparent'

  // Apply the inherited value first, then overwrite with the local one.
  if(parentAttrs.fontSize) style.fontSize = parentAttrs.fontSize
  if(attrs.fontSize) style.fontSize = attrs.fontSize

  if(parentAttrs.lineHeight) style.lineHeight = parentAttrs.lineHeight

  if(parentAttrs.fontWeight) style.fontWeight = parentAttrs.fontWeight
  if(attrs.fontWeight) style.fontWeight = attrs.fontWeight

  // Font family removed for now; sketch files do not
  // contain the fonts, so this became a mess.

  if(parentAttrs.fill) style.color = parentAttrs.fill
  if(attrs.fill) style.color = attrs.fill

  // The actual characters live in Tspan children; this node's own text
  // is always empty.
  return {
    id: node.attrs.id,
    type: 'Text',
    childs: processChildren(node, parentAttrs),
    text: "",
    style
  }
}
217 |
218 |
// Convert a <tspan> into a leaf Tspan node carrying its text content.
const processTspan = (node, parentAttrs={}) => {
  const attrs = node.attrs ? node.attrs : {}

  let style = {}
  if(attrs.fill) {
    style.color = attrs.fill;
  }
  if(attrs.x || attrs.y) {
    style.position = 'absolute'
    style.left = attrs.x ? attrs.x : 0
    style.top = attrs.y ? attrs.y : 0
  }

  // A tspan's sole child (when present) is the raw text node.
  const children = node.childs ? node.childs : [];
  let textChild = '';
  if(children.length == 1 && children[0].text) {
    textChild = children[0].text;
  }

  return {
    id: node.attrs.id,
    type: 'Tspan',
    childs: [],
    text: textChild,
    style
  }
}
243 |
244 |
245 |
// Convert a <path> into a Path node. The raw path data and paint
// attributes are preserved in directAttrs so the element can be rendered
// (and later rasterized) faithfully.
const processPath = (node, parentAttrs={}) => {
  const attrs = node.attrs ? node.attrs : {}

  let directAttrs = {}
  if(attrs.d) {
    directAttrs.d = attrs.d;
  }
  // Inherited fill first, overridden by the node's own fill.
  if(parentAttrs.fill) {
    directAttrs.fill = parentAttrs.fill;
  }
  if(attrs.fill) {
    directAttrs.fill = attrs.fill;
  }
  if(attrs.opacity) {
    directAttrs.opacity = attrs.opacity;
  }

  let style = {}
  if(attrs.x || attrs.y) {
    style.position = 'absolute'
    style.left = attrs.x ? attrs.x : 0
    style.top = attrs.y ? attrs.y : 0
  }

  return {
    id: node.attrs.id,
    type: 'Path',
    childs: processChildren(node, parentAttrs),
    style,
    directAttrs: directAttrs
  }
}
282 |
283 |
284 |
// Derive absolute-position styles for a node from its translate(...)
// transform and x/y/width/height attributes, offset by the root viewBox
// origin. Note: the subtractions coerce the string attribute values to
// numbers; explicit x/y win over the transform-derived position.
const topLeft = (node, parentAttrs={}) => {
  const attrs = node.attrs ? node.attrs : {};
  const viewBox = parentAttrs.viewBox;
  const offsetX = viewBox ? viewBox.x : 0;
  const offsetY = viewBox ? viewBox.y : 0;

  let styles = {position: 'absolute'}

  if(attrs.transform && attrs.transform.match(/^translate\([^\(]+\)$/)) {
    const pair = attrs.transform.replace(/^translate\(/, '').replace(/\)$/, '').split(", ")
    styles.left = pair[0] - offsetX
    styles.top = pair[1] - offsetY
  }

  if(attrs.height) {
    styles.height = attrs.height
  }
  if(attrs.width) {
    styles.width = attrs.width
  }

  if(attrs.x) {
    styles.left = attrs.x - offsetX
  }
  if(attrs.y) {
    styles.top = attrs.y - offsetY
  }

  return styles
}
313 |
// If the bounding boxes for a view are all empty views
// or paths or polygons or masks?
// Then call the view an image, and process it once.
const imagifyParents = (js) => {
  const processedChildren = js.childs.map((child) => imagifyParents(child));

  // A child counts as "image-like" when it is an empty View, a Path, or
  // a Polygon.
  const imageLike = (child) => {
    if(child.type == 'View' && child.childs.length == 0) {
      return true;
    }
    return ['Path', 'Polygon'].indexOf(child.type) > -1;
  };

  // Collapse a node into a single Image leaf only when it has children
  // and every one of them is image-like.
  const collapse = js.childs.length > 0 && js.childs.every(imageLike);

  let newJS = {...js}
  if(collapse) {
    newJS.type = 'Image';
    newJS.childs = [];
  } else {
    newJS.childs = processedChildren;
  }

  return newJS;
}
345 |
346 |
// Public API: processNode converts the svgson tree; imagifyParents
// collapses image-like subtrees into Image leaves.
module.exports.processNode = processNode;
module.exports.imagifyParents = imagifyParents;
349 |
--------------------------------------------------------------------------------
/src/screenshot.js:
--------------------------------------------------------------------------------
1 | const puppeteer = require('puppeteer');
2 |
3 |
// Collect the ids of every descendant of `js` (the root's own id is not
// included), skipping nodes without an id. Order: all direct children
// first, then each child's descendants in turn.
const getAllElementIds = (js) => {
  let ids = [];
  js.childs.forEach((child) => {
    if(child.id) {
      ids.push(child.id);
    }
  });
  js.childs.forEach((child) => {
    ids = ids.concat(getAllElementIds(child));
  });
  return ids;
}
16 |
// Screenshot every element in the processed tree into tempDir.
const screenshotAllElements = (async (file, tempDir, rootJS) => {

  const elementIds = getAllElementIds(rootJS);
  console.warn("gathering elements...")
  await screenshotElements(file, tempDir, elementIds);

})
24 |
25 |
// Open the given file in headless Chrome and screenshot each element id
// in turn. Fix: the browser close is now awaited and runs in a finally
// block — the original fired close() without awaiting it and leaked the
// Chrome process whenever a screenshot threw.
const screenshotElements = (async (file, tempDir, elementIds) => {
  const browser = await puppeteer.launch();
  try {
    const page = await browser.newPage();
    await page.goto(file);

    for(let id of elementIds) {
      await screenshotDOMElement(page, id, tempDir)
    }
  } finally {
    await browser.close();
  }
});
37 |
38 |
// Measure the element with the given id inside the page, then capture a
// screenshot clipped to its bounding box at <tempDir>/<selector>.png.
// Throws when no element matches the selector.
async function screenshotDOMElement(page, selector, tempDir) {

  // Runs in the browser context; returns null when the id is missing.
  const rect = await page.evaluate(selector => {
    const element = document.querySelector("#" + selector);
    if (!element)
      return null;
    const {x, y, width, height} = element.getBoundingClientRect();
    return {left: x, top: y, width, height, id: element.id};
  }, selector);

  if(!rect) {
    throw Error(`Could not find element that matches selector: ${selector}.`);
  }

  const clip = {
    x: rect.left,
    y: rect.top,
    width: rect.width,
    height: rect.height
  };

  return await page.screenshot({
    path: tempDir+'/'+selector+'.png',
    clip
  });
}
63 |
64 |
// Depth-first walk: for every node with an id, ask the browser for its
// bounding box and record it into idDims / orderedIds (both mutated in
// place). Nodes the browser cannot find are silently skipped.
const getEleDimensions = (async (node, page, idDims, orderedIds) => {

  if(node.id) {
    // Runs inside the page; normalizes the rect so top <= bottom and
    // left <= right regardless of how the browser reports it.
    const dimensions = await page.evaluate((node) => {
      const ele = document.getElementById(node.id) && document.getElementById(node.id).getBoundingClientRect()
      if(!ele) {
        return null;
      }
      return {
        top: Math.min(ele.top, ele.bottom),
        left: Math.min(ele.left, ele.right),
        right: Math.max(ele.right, ele.left),
        bottom: Math.max(ele.bottom, ele.top),
        height: ele.height,
        width: ele.width
      }
    }, node);

    if(dimensions) {
      idDims[node.id] = dimensions;
      orderedIds.push(node.id);
    }
  }

  for(let child of node.childs) {
    await getEleDimensions(child, page, idDims, orderedIds);
  }

});
96 |
97 |
// Load the prepped SVG in headless Chrome and collect bounding boxes for
// every element with an id. Returns { idDims, orderedIds }.
// Fix: the browser close is now awaited inside a finally block — the
// original fired close() without awaiting it and leaked the Chrome
// process whenever the page load or measurement threw.
const getBrowserBoundingBoxes = (async (js, file) => {
  const browser = await puppeteer.launch();
  try {
    const page = await browser.newPage();
    await page.goto(file);

    let idDims = {};
    let orderedIds = [];
    await getEleDimensions(js, page, idDims, orderedIds);

    return({
      idDims,
      orderedIds
    });
  } finally {
    await browser.close();
  }
});
115 |
116 |
// Public API used by src/index.js and src/neural_net.js.
module.exports.getBrowserBoundingBoxes = getBrowserBoundingBoxes;
module.exports.screenshotElements = screenshotElements;
module.exports.screenshotAllElements = screenshotAllElements;
120 |
--------------------------------------------------------------------------------
/tf_files/retrained_graph.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nanohop/sketch-to-react-native/6075a766749f6ed419ff0f34abdb81f8cc4cf0f3/tf_files/retrained_graph.pb
--------------------------------------------------------------------------------
/tf_files/retrained_labels.txt:
--------------------------------------------------------------------------------
1 | activity
2 | button
3 | keyboard
4 | slider
5 | status bar
6 | switch
7 | text
8 |
--------------------------------------------------------------------------------