├── .babelrc ├── .gitignore ├── LICENSE ├── README.md ├── bin └── convert.js ├── examples ├── Log In Page.svg └── login_page.sketch ├── images ├── export_instructions.png ├── export_instructions_2.png ├── sketch_conversion.png └── sketch_to_react_native.png ├── package-lock.json ├── package.json ├── scripts ├── __init__.py ├── count_ops.py ├── evaluate.py ├── graph_pb2tb.py ├── label_image.py ├── quantize_graph.py └── show_image.py ├── src ├── attributes.js ├── components.js ├── flex.js ├── index.js ├── input.js ├── lib │ ├── files.js │ └── utils.js ├── neural_net.js ├── output.js ├── process.js └── screenshot.js └── tf_files ├── retrained_graph.pb └── retrained_labels.txt /.babelrc: -------------------------------------------------------------------------------- 1 | { 2 | "presets": [ 3 | [ 4 | "env", 5 | { 6 | "node": "6.10.0" 7 | } 8 | ] 9 | ], 10 | "plugins": [ 11 | "transform-regenerator", 12 | "transform-async-to-generator", 13 | [ 14 | "transform-object-rest-spread", 15 | { 16 | "useBuiltIns": true 17 | } 18 | ], 19 | "transform-es2015-destructuring" 20 | ] 21 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | *.svg 3 | !examples/*.svg 4 | output/* 5 | temp/* 6 | .DS_Store 7 | scripts/__pycache__/* 8 | build -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 NanoHop 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the 
Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convert Sketch files to React Native components 2 | 3 | * Sketch App: https://www.sketchapp.com/ 4 | * React Native: https://facebook.github.io/react-native/ 5 | 6 | Do you have designs in Sketch, and need to turn those into a mobile app? This will take those designs, and automatically create React Native components. 7 | 8 | ![Sketch to React Native](images/sketch_to_react_native.png?raw=true) 9 | 10 | ## Want to try it without installing everything? 11 | 12 | Send me your Sketch file and instructions, and I'll run it and email the output back to you: chris@nanohop.com 13 | 14 | ## Want to collaborate? 15 | 16 | [Join us on Slack!](https://join.slack.com/t/design-to-code/shared_invite/enQtMjU5MzQ2OTAzNDEzLWE1MjUzMWVhNzVlNTFlZTFmYmRjODNkNDZmMDI1M2NhMTcwZjgwM2Q5M2Q3OTk5YmNhNTI5MDRmZDk5NmY1MWY) 17 | 18 | ## Want this as a service? 19 | 20 | We offer that! We also offer a **human in the loop**, where we'll clean up the output before it goes back to you. 
Send me an email to learn more: chris@nanohop.com 21 | 22 | ***** 23 | 24 | # Getting started 25 | 26 | ### Prerequisites: 27 | 28 | * Node 8.5.0+ https://nodejs.org/en/ 29 | * Python 3.6.1+ https://www.python.org/downloads/ 30 | * Install TensorFlow https://www.tensorflow.org/install/ 31 | 32 | ### Steps to run: 33 | 34 | ```bash 35 | > git clone https://github.com/nanohop/sketch-to-react-native.git 36 | > cd sketch-to-react-native 37 | > npm install && npm link 38 | > sketch-to-react-native ~/Desktop/myfile.svg 39 | ``` 40 | 41 | ### Extract the component from Sketch as an SVG: 42 | 43 | ![Export Instructions](images/export_instructions.png?raw=true) 44 | 45 | ![Export Instructions 2](images/export_instructions_2.png?raw=true) 46 | 47 | ### Use that SVG file as the argument for convert.js 48 | 49 | ```bash 50 | > sketch-to-react-native input-file.svg 51 | ``` 52 | 53 | It will run and save the output to the ./output folder. Make sure to grab both the .js file, and the associated _images_ folder! Drop that into your React Native application, and see the magic! 54 | 55 | ## What if it doesn't work? 56 | 57 | Please let me know! This is early software, and I'm trying to solve as many edge cases as I can find. Please file an issue or send me an email. 58 | 59 | # Conversion process 60 | 61 | Sketch exports a fairly clean SVG, which makes the process easier, but there is still a lot of processing involved. Here are the basic steps: 62 | 63 | 1. Prep the SVG to make processing easier 64 | 2. Use a deep neural net to filter out unwanted elements 65 | 3. Get component bounding boxes using headless Chrome 66 | 4. Figure out all the child/parent relationships 67 | 5. Convert from absolute (pixel) positioning to Flexbox 68 | 6. Extract images from SVG paths and polygons 69 | 7. Generate styles for every component 70 | 8. Generate components 71 | 9. 
Export to an output file 72 | 73 | 74 | ### A note about the most difficult part 75 | 76 | Sketch (and the SVG export) is a pixel based (absolute positioned) system - but that's no good for React Native. One of the primary difficulties was figuring out what the proper parent / sibling relationships were between all the components, and then converting from absolute positioning to Flexbox. There is still some work to clean this part up, but in general it's working fairly well for most inputs. 77 | 78 | This also means that components in Sketch can _partially_ overlap - which doesn't do well during the conversion process. I'm working on a fix for that; but until then - you'll get best results if there are no partially overlapping components in your Sketch file. (Fully overlapping is fine - those will get properly converted to a parent => child relationship) 79 | 80 | # FAQ 81 | 82 | ## Didn't Airbnb already release something that does this? 83 | 84 | Nope! You're probably thinking of [react-sketchapp](https://github.com/airbnb/react-sketchapp) that takes React components, and generates Sketch files. This goes the opposite direction: it takes Sketch files, and creates React Native components. 85 | 86 | 87 | ## Why React Native (mobile)? Why not React (web)? 88 | 89 | I started with mobile because the app designs for mobile are generally more straightforward, with less variation - so it was easier to make the first version. I'm planning to do React for web as well though! Send me an email if you'd like updates when it's available: chris@nanohop.com 90 | 91 | 92 | ## Is the generated code any good? 93 | 94 | I'm a mobile developer myself - and we're rightfully fearful of any generated code. It is really high on my priority list to keep the output as clean and readable as possible. It's not perfect all the time, but it is one of my top priorities. 95 | 96 | 97 | ## How much time does this save? 
98 | 99 | I've found that screens that would normally take me about an hour to create, take as little as 10 minutes - so that's as much as 80% time savings! The output does have to be cleaned up a little bit generally, but I find it usually provides a good starting point. 100 | 101 | 102 | ## Why use headless Chrome? 103 | 104 | It seems like overkill, but headless Chrome has a great SVG rendering engine, and it was the easiest way to get bounding boxes to work, and to export the SVG assets as PNGs. That will probably change in the future. 105 | 106 | 107 | ## Is there a way to try this out without installing everything? 108 | 109 | I plan to get a hosted version up at some point, but until then you can email me your Sketch file and instructions, and I'll run it and email the component back to you: chris@nanohop.com 110 | 111 | 112 | ## What can't it do yet? 113 | 114 | This is a work in progress, so there are a few things it doesn't do well: overlapping components, reusing component styles, reusing common components, collapsing unnecessary wrapping Views, (and more). I have a long roadmap of features to add. 115 | 116 | 117 | ## Is there a Sketch plugin to do this automatically? 118 | 119 | Not yet, but it's on the roadmap! 120 | 121 | 122 | ## How can I help? 123 | 124 | If you'd like to help, I'd love to have you involved! Feel free to file issues, or send me an email with any Sketch file that doesn't work quite right, and I'll review and merge pull requests as well. 125 | 126 | 127 | 128 | # Example 129 | 130 | 131 | ![Sketch to React Native conversion example](images/sketch_conversion.png?raw=true "Example Conversion") 132 | 133 | 134 | You can see that it's not perfect - but it provides a really good starting point, and it's getting better all the time! 
135 | 136 | 137 | ## Here's the generated code: 138 | 139 | 140 | 141 | ```javascript 142 | 143 | import React, { Component } from 'react'; 144 | 145 | import { 146 | StyleSheet, 147 | Text, 148 | View, 149 | TouchableOpacity, 150 | TextInput, 151 | ScrollView, 152 | Image 153 | } from 'react-native'; 154 | 155 | import BackArrow from './Log_In_Page_images/Back-Arrow.png' 156 | import Logo from './Log_In_Page_images/Logo.png' 157 | 158 | export default class Main extends Component { 159 | 160 | render() { 161 | return ( 162 | 166 | 167 | 168 | 169 | 170 | Email Address 171 | 172 | 173 | Password 174 | 175 | 176 | Log In 177 | 178 | 179 | 180 | ) 181 | } 182 | 183 | } 184 | 185 | const styles = StyleSheet.create({ 186 | Base: { 187 | height: 680, 188 | backgroundColor: '#5CC5F8', 189 | borderRadius: 6, 190 | paddingTop: 20, 191 | paddingBottom: 122 192 | }, 193 | BackArrow: { 194 | alignSelf: 'flex-start', 195 | marginLeft: 18 196 | }, 197 | Logo: { 198 | alignSelf: 'center', 199 | marginTop: 25 200 | }, 201 | EmailInput: { 202 | height: 64, 203 | backgroundColor: '#FAFAFA', 204 | borderRadius: 6, 205 | alignSelf: 'center', 206 | marginTop: 49, 207 | width: 292, 208 | alignItems: 'flex-start', 209 | marginLeft: 30, 210 | justifyContent: 'center' 211 | }, 212 | EmailAddress: { 213 | backgroundColor: 'transparent', 214 | fontSize: 18, 215 | fontWeight: '300', 216 | color: '#444444', 217 | textAlign: 'left', 218 | marginLeft: 30 219 | }, 220 | PasswordInput: { 221 | height: 64, 222 | backgroundColor: '#FAFAFA', 223 | borderRadius: 6, 224 | alignSelf: 'center', 225 | marginTop: 14, 226 | width: 292, 227 | alignItems: 'flex-start', 228 | marginLeft: 30, 229 | justifyContent: 'center' 230 | }, 231 | Passsord: { 232 | backgroundColor: 'transparent', 233 | fontSize: 18, 234 | fontWeight: '300', 235 | color: '#444444', 236 | textAlign: 'left', 237 | marginLeft: 30 238 | }, 239 | LoginButton: { 240 | height: 64, 241 | backgroundColor: '#332AC6', 242 | borderRadius: 6, 243 | 
alignSelf: 'center', 244 | marginTop: 121, 245 | width: 292, 246 | alignItems: 'center', 247 | justifyContent: 'center' 248 | }, 249 | LogIn: { 250 | backgroundColor: 'transparent', 251 | fontSize: 24, 252 | fontWeight: '300', 253 | color: '#FFFFFF', 254 | textAlign: 'center' 255 | } 256 | }) 257 | 258 | 259 | ``` 260 | 261 | 262 | 263 | -------------------------------------------------------------------------------- /bin/convert.js: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | require('../build'); -------------------------------------------------------------------------------- /examples/Log In Page.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Log In Page 5 | Created with Sketch. 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | Log In 14 | 15 | 16 | 17 | Email Address 18 | 19 | 20 | 21 | Password 22 | 23 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /examples/login_page.sketch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanohop/sketch-to-react-native/6075a766749f6ed419ff0f34abdb81f8cc4cf0f3/examples/login_page.sketch -------------------------------------------------------------------------------- /images/export_instructions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanohop/sketch-to-react-native/6075a766749f6ed419ff0f34abdb81f8cc4cf0f3/images/export_instructions.png -------------------------------------------------------------------------------- /images/export_instructions_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanohop/sketch-to-react-native/6075a766749f6ed419ff0f34abdb81f8cc4cf0f3/images/export_instructions_2.png 
-------------------------------------------------------------------------------- /images/sketch_conversion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanohop/sketch-to-react-native/6075a766749f6ed419ff0f34abdb81f8cc4cf0f3/images/sketch_conversion.png -------------------------------------------------------------------------------- /images/sketch_to_react_native.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanohop/sketch-to-react-native/6075a766749f6ed419ff0f34abdb81f8cc4cf0f3/images/sketch_to_react_native.png -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "svg_to_react", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "src/index.js", 6 | "dependencies": { 7 | "puppeteer": "^0.10.2", 8 | "svgson": "^2.1.0" 9 | }, 10 | "devDependencies": { 11 | "babel-cli": "^6.26.0", 12 | "babel-core": "^6.26.0", 13 | "babel-plugin-transform-async-to-generator": "^6.24.1", 14 | "babel-plugin-transform-es2015-destructuring": "^6.23.0", 15 | "babel-plugin-transform-object-rest-spread": "^6.26.0", 16 | "babel-plugin-transform-regenerator": "^6.26.0", 17 | "babel-preset-env": "^1.6.1", 18 | "regenerator-runtime": "^0.11.0" 19 | }, 20 | "scripts": { 21 | "convert": "babel-node ./src", 22 | "build": "babel ./src -d ./build", 23 | "prepare": "npm run build", 24 | "test": "echo \"Error: no test specified\" && exit 1" 25 | }, 26 | "bin": { 27 | "sketch-to-react-native": "./bin/convert.js" 28 | }, 29 | "engines": { 30 | "node": ">=6.10.0" 31 | }, 32 | "author": "", 33 | "license": "ISC" 34 | } 35 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/python 2 | # 3 | # Copyright 2017 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | -------------------------------------------------------------------------------- /scripts/count_ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2017 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | 22 | import sys 23 | import tensorflow as tf 24 | 25 | def load_graph(file_name): 26 | with open(file_name,'rb') as f: 27 | content = f.read() 28 | graph_def = tf.GraphDef() 29 | graph_def.ParseFromString(content) 30 | with tf.Graph().as_default() as graph: 31 | tf.import_graph_def(graph_def, name='') 32 | return graph 33 | 34 | def count_ops(file_name, op_name = None): 35 | graph = load_graph(file_name) 36 | 37 | if op_name is None: 38 | return len(graph.get_operations()) 39 | else: 40 | return sum(1 for op in graph.get_operations() 41 | if op.name == op_name) 42 | 43 | if __name__ == "__main__": 44 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 45 | print(count_ops(*sys.argv[1:])) 46 | 47 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2017 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | 22 | import sys 23 | import argparse 24 | 25 | import numpy as np 26 | import PIL.Image as Image 27 | import tensorflow as tf 28 | 29 | import scripts.retrain as retrain 30 | from scripts.count_ops import load_graph 31 | 32 | def evaluate_graph(graph_file_name): 33 | with load_graph(graph_file_name).as_default() as graph: 34 | ground_truth_input = tf.placeholder( 35 | tf.float32, [None, 5], name='GroundTruthInput') 36 | 37 | image_buffer_input = graph.get_tensor_by_name('input:0') 38 | final_tensor = graph.get_tensor_by_name('final_result:0') 39 | accuracy, _ = retrain.add_evaluation_step(final_tensor, ground_truth_input) 40 | 41 | logits = graph.get_tensor_by_name("final_training_ops/Wx_plus_b/add:0") 42 | xent = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( 43 | labels = ground_truth_input, 44 | logits = logits)) 45 | 46 | image_dir = 'tf_files/flower_photos' 47 | testing_percentage = 10 48 | validation_percentage = 10 49 | validation_batch_size = 100 50 | category='testing' 51 | 52 | image_lists = retrain.create_image_lists( 53 | image_dir, testing_percentage, 54 | validation_percentage) 55 | class_count = len(image_lists.keys()) 56 | 57 | ground_truths = [] 58 | file_names = [] 59 | 60 | for label_index, label_name in enumerate(image_lists.keys()): 61 | for image_index, image_name in enumerate(image_lists[label_name][category]): 62 | image_name = retrain.get_image_path( 63 | image_lists, label_name, image_index, image_dir, category) 64 | ground_truth = np.zeros([1, class_count], dtype=np.float32) 65 | ground_truth[0, label_index] = 1.0 66 | ground_truths.append(ground_truth) 67 | file_names.append(image_name) 68 | 69 | accuracies = [] 70 | xents = [] 71 | with tf.Session(graph=graph) as sess: 72 | for filename, ground_truth in zip(file_names, ground_truths): 73 | image = 
Image.open(filename).resize((224,224),Image.ANTIALIAS) 74 | image = np.array(image, dtype=np.float32)[None,...] 75 | image = (image-128)/128.0 76 | 77 | feed_dict={ 78 | image_buffer_input: image, 79 | ground_truth_input: ground_truth} 80 | 81 | eval_accuracy, eval_xent = sess.run([accuracy, xent], feed_dict) 82 | 83 | accuracies.append(eval_accuracy) 84 | xents.append(eval_xent) 85 | 86 | 87 | return np.mean(accuracies), np.mean(xents) 88 | 89 | if __name__ == "__main__": 90 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 91 | accuracy,xent = evaluate_graph(*sys.argv[1:]) 92 | print('Accuracy: %g' % accuracy) 93 | print('Cross Entropy: %g' % xent) 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /scripts/graph_pb2tb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2017 Google Inc. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import os 18 | import sys 19 | 20 | import tensorflow as tf 21 | 22 | def load_graph(graph_pb_path): 23 | with open(graph_pb_path,'rb') as f: 24 | content = f.read() 25 | graph_def = tf.GraphDef() 26 | graph_def.ParseFromString(content) 27 | with tf.Graph().as_default() as graph: 28 | tf.import_graph_def(graph_def, name='') 29 | return graph 30 | 31 | 32 | def graph_to_tensorboard(graph, out_dir): 33 | with tf.Session(): 34 | train_writer = tf.summary.FileWriter(out_dir) 35 | train_writer.add_graph(graph) 36 | 37 | 38 | def main(out_dir, graph_pb_path): 39 | graph = load_graph(graph_pb_path) 40 | graph_to_tensorboard(graph, out_dir) 41 | 42 | if __name__ == "__main__": 43 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 44 | main(*sys.argv[1:]) 45 | -------------------------------------------------------------------------------- /scripts/label_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import argparse 20 | import sys 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | import os 25 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 26 | 27 | def load_graph(model_file): 28 | graph = tf.Graph() 29 | graph_def = tf.GraphDef() 30 | 31 | with open(model_file, "rb") as f: 32 | graph_def.ParseFromString(f.read()) 33 | with graph.as_default(): 34 | tf.import_graph_def(graph_def) 35 | 36 | return graph 37 | 38 | def read_tensor_from_image_file(file_name, input_height=299, input_width=299, 39 | input_mean=0, input_std=255): 40 | input_name = "file_reader" 41 | output_name = "normalized" 42 | file_reader = tf.read_file(file_name, input_name) 43 | if file_name.endswith(".png"): 44 | image_reader = tf.image.decode_png(file_reader, channels = 3, 45 | name='png_reader') 46 | elif file_name.endswith(".gif"): 47 | image_reader = tf.squeeze(tf.image.decode_gif(file_reader, 48 | name='gif_reader')) 49 | elif file_name.endswith(".bmp"): 50 | image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader') 51 | else: 52 | image_reader = tf.image.decode_jpeg(file_reader, channels = 3, 53 | name='jpeg_reader') 54 | float_caster = tf.cast(image_reader, tf.float32) 55 | dims_expander = tf.expand_dims(float_caster, 0); 56 | resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width]) 57 | normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std]) 58 | sess = tf.Session() 59 | result = sess.run(normalized) 60 | 61 | return result 62 | 63 | def load_labels(label_file): 64 | label = [] 65 | proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines() 66 | for l in proto_as_ascii_lines: 67 | label.append(l.rstrip()) 68 | return label 69 | 70 | if __name__ == "__main__": 71 | file_name = 
"tf_files/flower_photos/daisy/3475870145_685a19116d.jpg" 72 | model_file = "tf_files/retrained_graph.pb" 73 | label_file = "tf_files/retrained_labels.txt" 74 | input_height = 224 75 | input_width = 224 76 | input_mean = 128 77 | input_std = 128 78 | input_layer = "input" 79 | output_layer = "final_result" 80 | 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument("--image", help="image to be processed") 83 | parser.add_argument("--graph", help="graph/model to be executed") 84 | parser.add_argument("--labels", help="name of file containing labels") 85 | parser.add_argument("--input_height", type=int, help="input height") 86 | parser.add_argument("--input_width", type=int, help="input width") 87 | parser.add_argument("--input_mean", type=int, help="input mean") 88 | parser.add_argument("--input_std", type=int, help="input std") 89 | parser.add_argument("--input_layer", help="name of input layer") 90 | parser.add_argument("--output_layer", help="name of output layer") 91 | args = parser.parse_args() 92 | 93 | if args.graph: 94 | model_file = args.graph 95 | if args.image: 96 | file_name = args.image 97 | if args.labels: 98 | label_file = args.labels 99 | if args.input_height: 100 | input_height = args.input_height 101 | if args.input_width: 102 | input_width = args.input_width 103 | if args.input_mean: 104 | input_mean = args.input_mean 105 | if args.input_std: 106 | input_std = args.input_std 107 | if args.input_layer: 108 | input_layer = args.input_layer 109 | if args.output_layer: 110 | output_layer = args.output_layer 111 | 112 | graph = load_graph(model_file) 113 | t = read_tensor_from_image_file(file_name, 114 | input_height=input_height, 115 | input_width=input_width, 116 | input_mean=input_mean, 117 | input_std=input_std) 118 | 119 | input_name = "import/" + input_layer 120 | output_name = "import/" + output_layer 121 | input_operation = graph.get_operation_by_name(input_name); 122 | output_operation = graph.get_operation_by_name(output_name); 123 | 
124 | with tf.Session(graph=graph) as sess: 125 | results = sess.run(output_operation.outputs[0], 126 | {input_operation.outputs[0]: t}) 127 | results = np.squeeze(results) 128 | 129 | top_k = results.argsort()[-5:][::-1] 130 | labels = load_labels(label_file) 131 | for i in top_k: 132 | print(labels[i], results[i]) 133 | -------------------------------------------------------------------------------- /scripts/quantize_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Transforms a float-trained graph into an equivalent quantized version. 
16 | 17 | An example of command-line usage is: 18 | bazel build tensorflow/tools/quantization:quantize_graph \ 19 | && bazel-bin/tensorflow/tools/quantization/quantize_graph \ 20 | --input=tensorflow_inception_graph.pb 21 | --output_node_names="softmax2" --print_nodes --output=/tmp/quantized_graph.pb \ 22 | --mode=eightbit --logtostderr 23 | 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import collections 32 | import re 33 | import numpy as np 34 | 35 | from tensorflow.core.framework import attr_value_pb2 36 | from tensorflow.core.framework import graph_pb2 37 | from tensorflow.core.framework import node_def_pb2 38 | from tensorflow.python.client import session 39 | from tensorflow.python.framework import constant_op 40 | from tensorflow.python.framework import dtypes 41 | from tensorflow.python.framework import graph_util 42 | from tensorflow.python.framework import importer 43 | from tensorflow.python.framework import ops 44 | from tensorflow.python.framework import tensor_shape 45 | from tensorflow.python.framework import tensor_util 46 | from tensorflow.python.ops import array_ops 47 | from tensorflow.python.platform import app 48 | from tensorflow.python.platform import flags as flags_lib 49 | from tensorflow.python.platform import gfile 50 | 51 | flags = flags_lib 52 | FLAGS = flags.FLAGS 53 | 54 | flags.DEFINE_boolean("print_nodes", False, """Lists all nodes in the model.""") 55 | flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""") 56 | flags.DEFINE_string("output_node_names", "", 57 | """Output node names, comma separated.""") 58 | flags.DEFINE_string("output", "", """File to save the output graph to.""") 59 | flags.DEFINE_integer("bitdepth", 8, 60 | """How many bits to quantize the graph to.""") 61 | flags.DEFINE_string("mode", "round", 62 | """What transformation to apply (round, quantize,""" 63 | """ eightbit, 
weights, or weights_rounded).""") 64 | flags.DEFINE_string("test_input_dims", "1,224,224,3", 65 | """The size of the input tensor to use when testing a""" 66 | """ graph loaded from a file.""") 67 | flags.DEFINE_boolean("strip_redundant_quantization", True, 68 | """Removes redundant dequantize/quantize pairs.""") 69 | flags.DEFINE_boolean("quantized_input", False, 70 | "If true, assume Placeholders are quantized with values " 71 | "covering [--quantized_input_min,--quantized_input_max]. " 72 | "Only supported when --mode=eightbit") 73 | flags.DEFINE_float("quantized_input_min", 0, 74 | "The minimum of the actual input range when " 75 | "--quantized_input") 76 | flags.DEFINE_float("quantized_input_max", 1, 77 | "The maximum of the actual input range when " 78 | "--quantized_input") 79 | flags.DEFINE_float( 80 | "quantized_fallback_min", None, 81 | "The fallback 'min' value to use for layers which lack min-max " 82 | "information. Note: this should be considered a coarse tool just good " 83 | "enough for experimentation purposes, since graphs quantized in this way " 84 | "would be very inaccurate.") 85 | flags.DEFINE_float( 86 | "quantized_fallback_max", None, 87 | "The fallback 'max' value to use for layers which lack min-max " 88 | "information. 
def print_input_nodes(current_node, nodes_map, indent, already_visited):
  """Prints current_node and, recursively, every node feeding into it.

  Each node is printed once as "<op>:<name>", indented by one space per level
  of depth below the starting node.

  Args:
    current_node: Node to print; must expose `op`, `name` and `input`.
    nodes_map: Dict mapping node names to node definitions.
    indent: Number of leading spaces for current_node's line.
    already_visited: Dict of node names already printed; updated in place.
  """
  print(" " * indent + current_node.op + ":" + current_node.name)
  already_visited[current_node.name] = True
  for input_node_name in current_node.input:
    # Inputs may carry a "^" control prefix or a ":<port>" suffix, while
    # nodes_map and already_visited are keyed by the bare node name.
    input_node_name = node_name_from_input(input_node_name)
    if input_node_name in already_visited:
      continue
    input_node = nodes_map[input_node_name]
    print_input_nodes(input_node, nodes_map, indent + 1, already_visited)
def quantize_array(arr, num_buckets):
  """Quantizes a numpy array.

  This function maps each scalar in arr to the center of one of num_buckets
  equal-width buckets spanning [min(arr), max(arr)]. For instance,
  quantize_array([0, 0.3, 0.6, 1], 2) => [0.25, 0.25, 0.75, 0.75]

  Args:
    arr: The numpy array (or array-like, e.g. a list) to quantize.
    num_buckets: The number of buckets to map the values of arr to.
  Returns:
    The quantized numpy array. An empty array, or one whose values are all
    equal, is returned unchanged.
  Raises:
    ValueError: when num_buckets < 1.
  """
  if num_buckets < 1:
    raise ValueError("num_buckets must be >= 1")
  # Accept plain sequences too (as the example above does); an ndarray is
  # passed through without copying.
  arr = np.asarray(arr)
  if arr.size == 0:
    # min()/max() are undefined for an empty array; nothing to quantize.
    return arr
  arr_max = arr.max()
  arr_min = arr.min()
  if arr_max == arr_min:
    return arr
  bucket_width = (arr_max - arr_min) / num_buckets
  # Map scalars to bucket indices. max(arr) lands exactly on the upper edge,
  # so clamp it into the last bucket.
  bucket_indices = np.floor((arr - arr_min) / bucket_width)
  bucket_indices = np.minimum(bucket_indices, num_buckets - 1)
  # Map each scalar to the center of its bucket.
  return arr_min + bucket_width * (bucket_indices + 0.5)
254 | num_buckets = 1 << FLAGS.bitdepth 255 | tensor_value_rounded = quantize_array(tensor_value, num_buckets) 256 | tensor_shape_list = tensor_util.TensorShapeProtoToList(shape) 257 | return [ 258 | create_constant_node( 259 | input_node.name, 260 | tensor_value_rounded, 261 | dtypes.float32, 262 | shape=tensor_shape_list) 263 | ] 264 | 265 | 266 | def quantize_weight_eightbit(input_node, quantization_mode): 267 | """Returns replacement nodes for input_node using the Dequantize op.""" 268 | base_name = input_node.name + "_" 269 | quint8_const_name = base_name + "quint8_const" 270 | min_name = base_name + "min" 271 | max_name = base_name + "max" 272 | float_tensor = tensor_util.MakeNdarray(input_node.attr["value"].tensor) 273 | min_value = np.min(float_tensor.flatten()) 274 | max_value = np.max(float_tensor.flatten()) 275 | # Make sure that the range includes zero. 276 | if min_value > 0.0: 277 | min_value = 0.0 278 | # min_value == max_value is a tricky case. It can occur for general 279 | # tensors, and of course for scalars. The quantized ops cannot deal 280 | # with this case, so we set max_value to something else. 281 | # It's a tricky question what is the numerically best solution to 282 | # deal with this degeneracy. 283 | # TODO(petewarden): Better use a tolerance than a hard comparison? 
284 | if min_value == max_value: 285 | if abs(min_value) < 0.000001: 286 | max_value = min_value + 1.0 287 | elif min_value > 0: 288 | max_value = 2 * min_value 289 | else: 290 | max_value = min_value / 2.0 291 | 292 | sess = session.Session() 293 | with sess.as_default(): 294 | quantize_op = array_ops.quantize_v2( 295 | float_tensor, 296 | min_value, 297 | max_value, 298 | dtypes.quint8, 299 | mode=quantization_mode) 300 | quint8_tensor = quantize_op[0].eval() 301 | shape = tensor_util.TensorShapeProtoToList(input_node.attr["value"] 302 | .tensor.tensor_shape) 303 | quint8_const_node = create_constant_node( 304 | quint8_const_name, quint8_tensor, dtypes.quint8, shape=shape) 305 | min_node = create_constant_node(min_name, min_value, dtypes.float32) 306 | max_node = create_constant_node(max_name, max_value, dtypes.float32) 307 | dequantize_node = create_node("Dequantize", input_node.name, 308 | [quint8_const_name, min_name, max_name]) 309 | set_attr_dtype(dequantize_node, "T", dtypes.quint8) 310 | set_attr_string(dequantize_node, "mode", quantization_mode) 311 | return [quint8_const_node, min_node, max_node, dequantize_node] 312 | 313 | 314 | EightbitizeRecursionState = collections.namedtuple( 315 | "EightbitizeRecursionState", 316 | ["already_visited", "output_node_stack", "merged_with_fake_quant"]) 317 | 318 | 319 | class GraphRewriter(object): 320 | """Takes a float graph, and rewrites it in quantized form.""" 321 | 322 | def __init__(self, 323 | input_graph, 324 | mode, 325 | quantized_input_range, 326 | fallback_quantization_range=None): 327 | """Sets up the class to rewrite a float graph. 328 | 329 | Args: 330 | input_graph: A float graph to transform. 331 | mode: A string controlling how quantization is performed - 332 | round, quantize, eightbit, or weights. 
333 | quantized_input_range: if set, assume the input is 334 | quantized and represents the range 335 | [quantized_input_range[0], quantized_input_range[1]] 336 | fallback_quantization_range: if set, then for nodes where the quantization 337 | range can't be inferred from the graph, use the range 338 | [fallback_quantization_range[0], fallback_quantization_range[1]) instead 339 | of using a RequantizationRange node in the graph. 340 | 341 | Raises: 342 | ValueError: Two nodes with the same name were found in the graph. 343 | """ 344 | self.input_graph = input_graph 345 | self.nodes_map = self.create_nodes_map(input_graph) 346 | self.output_graph = None 347 | self.mode = mode 348 | self.final_node_renames = {} 349 | if quantized_input_range: 350 | self.input_range = (quantized_input_range[0], quantized_input_range[1]) 351 | if self.input_range[0] >= self.input_range[1]: 352 | raise ValueError("Invalid quantized_input_range: [%s,%s]" % 353 | self.input_range) 354 | if self.mode != "eightbit": 355 | raise ValueError( 356 | "quantized_input_range can only be specified in eightbit mode") 357 | else: 358 | self.input_range = None 359 | 360 | if fallback_quantization_range: 361 | self.fallback_quantization_range = [ 362 | fallback_quantization_range[0], fallback_quantization_range[1] 363 | ] 364 | if (self.fallback_quantization_range[0] >= 365 | self.fallback_quantization_range[1]): 366 | raise ValueError("Invalid fallback_quantization_range: [%s,%s]" % 367 | self.fallback_quantization_range) 368 | if self.mode != "eightbit": 369 | raise ValueError("fallback_quantization_range can only be " 370 | "specified in eightbit mode") 371 | else: 372 | self.fallback_quantization_range = None 373 | 374 | # Data that is valid only during the recursive call to rewrite the graph. 
375 | self.state = None 376 | 377 | def create_nodes_map(self, graph): 378 | """Builds a mapping of node names to their defs from the graph.""" 379 | nodes_map = {} 380 | for node in graph.node: 381 | if node.name not in nodes_map.keys(): 382 | nodes_map[node.name] = node 383 | else: 384 | raise ValueError("Duplicate node names detected.") 385 | return nodes_map 386 | 387 | def rewrite(self, output_node_names): 388 | """Triggers rewriting of the float graph. 389 | 390 | Args: 391 | output_node_names: A list of names of the nodes that produce the final 392 | results. 393 | 394 | Returns: 395 | A quantized version of the float graph. 396 | """ 397 | self.output_graph = graph_pb2.GraphDef() 398 | output_nodes = [ 399 | self.nodes_map[output_node_name] 400 | for output_node_name in output_node_names 401 | ] 402 | if self.mode == "round": 403 | self.already_visited = {} 404 | for output_node in output_nodes: 405 | self.round_nodes_recursively(output_node) 406 | elif self.mode == "quantize": 407 | self.already_visited = {} 408 | self.already_quantized = {} 409 | for output_node in output_nodes: 410 | self.quantize_nodes_recursively(output_node) 411 | elif self.mode == "eightbit": 412 | self.set_input_graph(graph_util.remove_training_nodes(self.input_graph)) 413 | output_nodes = [ 414 | self.nodes_map[output_node_name] 415 | for output_node_name in output_node_names 416 | ] 417 | 418 | self.state = EightbitizeRecursionState( 419 | already_visited={}, output_node_stack=[], merged_with_fake_quant={}) 420 | for output_node in output_nodes: 421 | self.eightbitize_nodes_recursively(output_node) 422 | self.state = None 423 | if self.input_range: 424 | self.add_output_graph_node( 425 | create_constant_node("quantized_input_min_value", self.input_range[ 426 | 0], dtypes.float32, [])) 427 | self.add_output_graph_node( 428 | create_constant_node("quantized_input_max_value", self.input_range[ 429 | 1], dtypes.float32, [])) 430 | if self.fallback_quantization_range: 431 | 
def round_nodes_recursively(self, current_node):
  """The entry point for simple rounding quantization.

  Visits current_node's inputs depth-first, then copies current_node into the
  output graph. A node whose op matches Conv2D, BiasAdd or MatMul is copied
  under the name "<name>_original" and followed by a RoundToSteps node that
  takes over the original name, fed by the copy and a constant holding
  1 << FLAGS.bitdepth.

  Args:
    current_node: The NodeDef to process.
  """
  # Membership test, not indexing: already_visited starts out empty, so
  # indexing it raised KeyError on the very first node.
  if current_node.name in self.already_visited:
    return
  self.already_visited[current_node.name] = True
  for input_node_name in current_node.input:
    input_node_name = node_name_from_input(input_node_name)
    input_node = self.nodes_map[input_node_name]
    self.round_nodes_recursively(input_node)
  nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
  # NOTE(review): this is a substring test, so an op such as "Add" also
  # matches "BiasAdd" - confirm whether exact membership was intended.
  if any(current_node.op in s for s in nodes_to_quantize):
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(current_node)
    new_node.name = current_node.name + "_original"
    self.add_output_graph_node(new_node)
    levels = 1 << FLAGS.bitdepth
    constant_name = current_node.name + "_round_depth"
    constant_tensor = constant_op.constant(
        levels, dtype=dtypes.int32, name=constant_name)
    constant_node = constant_tensor.op.node_def
    self.add_output_graph_node(constant_node)
    quantize_node = node_def_pb2.NodeDef()
    quantize_node.op = "RoundToSteps"
    quantize_node.name = current_node.name
    quantize_node.input.extend([current_node.name + "_original"])
    quantize_node.input.extend([constant_node.name])
    self.add_output_graph_node(quantize_node)
  else:
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(current_node)
    self.add_output_graph_node(new_node)
def should_merge_with_fake_quant_node(self):
  """Should the current node merge with self.state.output_node_stack[-1]?

  Returns:
    True when the top stack entry records input index 0 of a
    FakeQuantWithMinMaxVars node; False otherwise, including when the stack
    is empty.
  """
  stack = self.state.output_node_stack
  if stack:
    consumer_entry = stack[-1]
    # Entry layout as pushed by the graph walk: (node, input index, ...).
    return (consumer_entry[1] == 0 and
            consumer_entry[0].op in ["FakeQuantWithMinMaxVars"])
  return False
def eightbitize_nodes_recursively(self, current_node):
  """The entry point for transforming a graph into full eight bit.

  Recurses into every input of current_node first, pushing
  (current_node, input index, quantize_input) onto
  self.state.output_node_stack for the duration of each recursion, then emits
  the eight-bit replacement for current_node (or a plain copy when its op has
  no dedicated handler).

  Args:
    current_node: The NodeDef to process.

  Raises:
    ValueError: an already-visited node is reached again through a FakeQuant*
      node (or was merged with one earlier), or a FakeQuant* consumer is
      waiting on the stack and this node did not merge with it.
  """
  if current_node.name in self.state.already_visited:
    if (self.should_merge_with_fake_quant_node() or
        current_node.name in self.state.merged_with_fake_quant):
      # Format with %: passing the name as a second ValueError argument left
      # the "%s" placeholder unformatted in the message.
      raise ValueError("Unsupported graph structure: output of node %s "
                       "is processed by a FakeQuant* node and should have "
                       "no other outputs." % current_node.name)
    return
  self.state.already_visited[current_node.name] = True

  for i, input_node_name in enumerate(current_node.input):
    quantize_input = False
    if current_node.op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool",
                           "AvgPool", "Relu", "Relu6",
                           "BatchNormWithGlobalNormalization"):
      quantize_input = True
    elif current_node.op == "Concat" and i > 0:
      quantize_input = (
          dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32)
    elif current_node.op == "Reshape" and i == 0:
      quantize_input = (
          dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32)

    self.state.output_node_stack.append((current_node, i, quantize_input))

    input_node_name = node_name_from_input(input_node_name)
    input_node = self.nodes_map[input_node_name]
    self.eightbitize_nodes_recursively(input_node)

    self.state.output_node_stack.pop()

  if current_node.op == "MatMul":
    self.eightbitize_mat_mul_node(current_node)
  elif current_node.op == "Conv2D":
    self.eightbitize_conv_node(current_node)
  elif current_node.op == "BiasAdd":
    self.eightbitize_bias_add_node(current_node)
  elif current_node.op == "MaxPool" or current_node.op == "AvgPool":
    self.eightbitize_single_input_tensor_node(current_node,
                                              self.add_pool_function)
  elif current_node.op == "Relu" or current_node.op == "Relu6":
    self.eightbitize_single_input_tensor_node(current_node,
                                              self.add_relu_function)
  elif (current_node.op == "Concat" and
        dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32):
    self.eightbitize_concat_node(current_node)
  elif current_node.op == "BatchNormWithGlobalNormalization":
    self.eightbitize_batch_norm_node(current_node)
  elif (current_node.op == "Reshape" and
        dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32):
    self.eightbitize_reshape_node(current_node)
  elif (self.input_range and
        current_node.op in ("Placeholder", "PlaceholderV2")):
    self.eightbitize_placeholder_node(current_node)
  elif current_node.op == "FakeQuantWithMinMaxVars":
    # It will have been merged into the underlying node.
    pass
  elif current_node.op == "Const":
    if self.should_quantize_const(current_node):
      for n in quantize_weight_eightbit(current_node, b"MIN_FIRST"):
        self.add_output_graph_node(n)
    else:
      new_node = node_def_pb2.NodeDef()
      new_node.CopyFrom(current_node)
      self.add_output_graph_node(new_node)

  ###################################################################
  # Note: if more cases are added here, you may need to update the op
  # name lists in the loop over children at the start of the function.
  ###################################################################
  else:
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(current_node)
    self.add_output_graph_node(new_node)

  if (self.should_merge_with_fake_quant_node() and
      current_node.name not in self.state.merged_with_fake_quant):
    # Report the FakeQuant node by name; interpolating the node itself
    # dumped its entire definition into the message.
    raise ValueError(
        "FakeQuant* node %s failed to merge with node %s of type %s" %
        (self.state.output_node_stack[-1][0].name, current_node.name,
         current_node.op))
def eightbitize_input_to_node(self, namespace_prefix, original_input_name,
                              reshape_dims_name, reduction_dims_name):
  """Takes one float input to an op, and converts it to quantized form.

  Adds four nodes to the output graph, in this order: a Reshape of the input
  by reshape_dims_name, Min and Max reductions of the reshaped tensor over
  reduction_dims_name, and a QuantizeV2 (mode MIN_FIRST, quint8) of the
  original input using those Min/Max results as its range.

  Args:
    namespace_prefix: Prefix for the names of the generated nodes.
    original_input_name: Name of the float input tensor.
    reshape_dims_name: Name of the node holding the Reshape dims.
    reduction_dims_name: Name of the node holding the reduction dims.

  Returns:
    A tuple (quantized name, min name, max name), where min and max are
    outputs ":1" and ":2" of the QuantizeV2 node.
  """
  suffix = unique_node_name_from_input(original_input_name)
  flat_name = namespace_prefix + "_reshape_" + suffix
  min_name = namespace_prefix + "_min_" + suffix
  max_name = namespace_prefix + "_max_" + suffix
  quantized_name = namespace_prefix + "_quantize_" + suffix

  flatten = create_node("Reshape", flat_name,
                        [original_input_name, reshape_dims_name])
  set_attr_dtype(flatten, "T", dtypes.float32)
  self.add_output_graph_node(flatten)

  # Min first, then Max: identical wiring, only the op and name differ.
  for reduce_op, reduce_name in (("Min", min_name), ("Max", max_name)):
    reducer = create_node(reduce_op, reduce_name,
                          [flat_name, reduction_dims_name])
    set_attr_dtype(reducer, "T", dtypes.float32)
    set_attr_bool(reducer, "keep_dims", False)
    self.add_output_graph_node(reducer)

  quantizer = create_node("QuantizeV2", quantized_name,
                          [original_input_name, min_name, max_name])
  set_attr_dtype(quantizer, "T", dtypes.quint8)
  set_attr_string(quantizer, "mode", b"MIN_FIRST")
  self.add_output_graph_node(quantizer)

  return quantized_name, quantized_name + ":1", quantized_name + ":2"
quantized_output_name, quantized_output_name + ":1", 723 | quantized_output_name + ":2" 724 | ] 725 | min_max_inputs = None 726 | if self.should_merge_with_fake_quant_node(): 727 | # Use the inputs to the FakeQuantWithMinMaxVars node as the inputs to 728 | # Requantize. 729 | fake_quant_node = self.state.output_node_stack[-1][0] 730 | min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]] 731 | assert original_node.name not in self.state.merged_with_fake_quant 732 | self.state.merged_with_fake_quant[original_node.name] = True 733 | elif self.fallback_quantization_range: 734 | min_max_inputs = [ 735 | "fallback_quantization_min_value:0", 736 | "fallback_quantization_max_value:0" 737 | ] 738 | else: 739 | # Add a RequantizationRange node for finding the min and max values. 740 | requant_range_node = create_node( 741 | "RequantizationRange", original_node.name + "_eightbit_requant_range", 742 | quantized_outputs) 743 | set_attr_dtype(requant_range_node, "Tinput", dtypes.qint32) 744 | self.add_output_graph_node(requant_range_node) 745 | min_max_inputs = [ 746 | requant_range_node.name + ":0", requant_range_node.name + ":1" 747 | ] 748 | requantize_node = create_node("Requantize", 749 | original_node.name + "_eightbit_requantize", 750 | quantized_outputs + min_max_inputs) 751 | set_attr_dtype(requantize_node, "Tinput", dtypes.qint32) 752 | set_attr_dtype(requantize_node, "out_type", dtypes.quint8) 753 | self.add_output_graph_node(requantize_node) 754 | return requantize_node.name 755 | 756 | def add_dequantize_result_node(self, 757 | quantized_output_name, 758 | original_node_name, 759 | min_tensor_index=1): 760 | min_max_inputs = [ 761 | "%s:%s" % (quantized_output_name, min_tensor_index), 762 | "%s:%s" % (quantized_output_name, (min_tensor_index + 1)) 763 | ] 764 | dequantize_name = original_node_name 765 | if self.should_merge_with_fake_quant_node(): 766 | fake_quant_node = self.state.output_node_stack[-1][0] 767 | if original_node_name not in 
def eightbitize_mat_mul_node(self, original_node):
  """Replaces a MatMul node with the eight bit equivalent sub-graph.

  Emits the input-quantization prologue, a QuantizedMatMul node (quint8
  operands, qint32 output, transpose flags copied from the original), the
  quantize-down nodes, and finally a dequantize node registered under the
  original node's name.

  Args:
    original_node: Float MatMul node to be converted.
  """
  quantized_name = original_node.name + "_eightbit_quantized_mat_mul"
  prologue_outputs = self.add_eightbit_prologue_nodes(original_node)
  quantized_node = create_node("QuantizedMatMul", quantized_name,
                               prologue_outputs)

  for type_attr, dtype in (("T1", dtypes.quint8), ("T2", dtypes.quint8),
                           ("Toutput", dtypes.qint32)):
    set_attr_dtype(quantized_node, type_attr, dtype)
  for flag in ("transpose_a", "transpose_b"):
    copy_attr(quantized_node, flag, original_node.attr[flag])
  self.add_output_graph_node(quantized_node)

  requantized_name = self.add_quantize_down_nodes(original_node,
                                                  quantized_name)
  self.add_dequantize_result_node(requantized_name, original_node.name)
def eightbitize_bias_add_node(self, original_node):
  """Replaces a BiasAdd node with the eight bit equivalent sub-graph.

  Emits the input-quantization prologue, a QuantizedBiasAdd node (quint8
  operands, qint32 output), the quantize-down nodes, and finally a dequantize
  node registered under the original node's name.

  Args:
    original_node: Float BiasAdd node to be converted.
  """
  quantized_name = original_node.name + "_eightbit_quantized_bias_add"
  prologue_outputs = self.add_eightbit_prologue_nodes(original_node)
  quantized_node = create_node("QuantizedBiasAdd", quantized_name,
                               prologue_outputs)

  for type_attr, dtype in (("T1", dtypes.quint8), ("T2", dtypes.quint8),
                           ("out_type", dtypes.qint32)):
    set_attr_dtype(quantized_node, type_attr, dtype)
  self.add_output_graph_node(quantized_node)

  requantized_name = self.add_quantize_down_nodes(original_node,
                                                  quantized_name)
  self.add_dequantize_result_node(requantized_name, original_node.name)
def add_pool_function(self, original_node, quantized_op_node):
  """Fills in the attributes of a quantized pooling node.

  Sets "T" to quint8 and copies ksize, strides and padding from the float
  node.

  Args:
    original_node: Float pooling node supplying the attributes.
    quantized_op_node: Quantized node to configure; modified in place.
  """
  set_attr_dtype(quantized_op_node, "T", dtypes.quint8)
  for pooling_attr in ("ksize", "strides", "padding"):
    copy_attr(quantized_op_node, pooling_attr,
              original_node.attr[pooling_attr])
897 | 898 | Converts a node like this: 899 | 900 | Shape(f) Input0(f) Input1(f) 901 | | | | 902 | +--------v v v----------+ 903 | Concat 904 | | 905 | v 906 | (f) 907 | 908 | Into a quantized equivalent: 909 | 910 | Shape(f) Input0(f) ReshapeDims Input1(f) 911 | | +------v v--------------+------------------v v------+ 912 | | | Reshape Reshape | 913 | | | | | | 914 | | | | ReductionDims | | 915 | | | +------+ | +--------+ | 916 | | | | +---c---------+-----------c-----+ | | 917 | | | +v v v v-------+---------v v v v+ | 918 | | | Min Max Min Max | 919 | | | +----+ | | +-----+ | 920 | | v v v--------+ +----------v v v 921 | | Quantize Quantize 922 | | +------------------+ +----------------------+ 923 | +-------------------------------+ | | 924 | v v v 925 | QuantizedConcat 926 | | | | 927 | v v v 928 | Dequantize 929 | | 930 | v 931 | (f) 932 | Args: 933 | original_node: Float node to be converted. 934 | 935 | Returns: 936 | Subgraph representing the quantized version of the original node. 937 | 938 | """ 939 | namespace_prefix = original_node.name + "_eightbit" 940 | quantized_concat_name = namespace_prefix + "_quantized_concat" 941 | reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes( 942 | namespace_prefix) 943 | shape_input_name = original_node.input[0] 944 | original_inputs = original_node.input[1:] 945 | input_names = [] 946 | min_names = [] 947 | max_names = [] 948 | for original_input_name in original_inputs: 949 | quantize_input_name, min_input_name, max_input_name = ( 950 | self.eightbitize_input_to_node(namespace_prefix, original_input_name, 951 | reshape_dims_name, 952 | reduction_dims_name)) 953 | input_names.append(quantize_input_name) 954 | min_names.append(min_input_name) 955 | max_names.append(max_input_name) 956 | all_input_names = [shape_input_name] 957 | all_input_names.extend(input_names) 958 | all_input_names.extend(min_names) 959 | all_input_names.extend(max_names) 960 | quantized_concat_node = 
create_node("QuantizedConcat", 961 | quantized_concat_name, all_input_names) 962 | set_attr_int(quantized_concat_node, "N", len(original_inputs)) 963 | set_attr_dtype(quantized_concat_node, "T", dtypes.quint8) 964 | self.add_output_graph_node(quantized_concat_node) 965 | self.add_dequantize_result_node(quantized_concat_name, original_node.name) 966 | 967 | def eightbitize_placeholder_node(self, current_node): 968 | """Replaces a placeholder node with a quint8 placeholder node+dequantize.""" 969 | name = current_node.name 970 | 971 | # Convert the placeholder into a quantized type. 972 | output_node = node_def_pb2.NodeDef() 973 | output_node.CopyFrom(current_node) 974 | set_attr_dtype(output_node, "dtype", dtypes.quint8) 975 | output_node.name += "_original_input" 976 | self.add_output_graph_node(output_node) 977 | 978 | # Add a dequantize to convert back to float. 979 | dequantize_node = create_node("Dequantize", name, [ 980 | output_node.name, "quantized_input_min_value", 981 | "quantized_input_max_value" 982 | ]) 983 | set_attr_dtype(dequantize_node, "T", dtypes.quint8) 984 | set_attr_string(dequantize_node, "mode", b"MIN_FIRST") 985 | self.add_output_graph_node(dequantize_node) 986 | 987 | # For the descent over the graph to work, the dequantize node must be named 988 | # current_node.name. However, for the feeding of the graph to work, the 989 | # placeholder must have the name current_node.name; so record a final set 990 | # of renames to apply after all processing has been done. 991 | self.final_node_renames[output_node.name] = name 992 | self.final_node_renames[dequantize_node.name] = name + "_dequantize" 993 | 994 | def eightbitize_reshape_node(self, original_node): 995 | """Replaces a Reshape node with the eight bit equivalent sub-graph. 996 | 997 | Args: 998 | original_node: Float node to be converted. 999 | 1000 | Returns: 1001 | Subgraph representing the quantized version of the original node. 
1002 | 1003 | """ 1004 | namespace_prefix = original_node.name + "_eightbit" 1005 | quantized_reshape_name = namespace_prefix + "_quantized_reshape" 1006 | reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes( 1007 | namespace_prefix) 1008 | shape_input_name = original_node.input[1] 1009 | quantize_input_name, min_input_name, max_input_name = ( 1010 | self.eightbitize_input_to_node(namespace_prefix, original_node.input[0], 1011 | reshape_dims_name, reduction_dims_name)) 1012 | quantized_reshape_node = create_node( 1013 | "QuantizedReshape", quantized_reshape_name, 1014 | [quantize_input_name, shape_input_name, min_input_name, max_input_name]) 1015 | set_attr_dtype(quantized_reshape_node, "T", dtypes.quint8) 1016 | self.add_output_graph_node(quantized_reshape_node) 1017 | self.add_dequantize_result_node(quantized_reshape_name, original_node.name) 1018 | 1019 | def eightbitize_batch_norm_node(self, original_node): 1020 | """Replaces a MatMul node with the eight bit equivalent sub-graph.""" 1021 | namespace_prefix = original_node.name + "_eightbit" 1022 | original_input_name = original_node.input[0] 1023 | original_mean_name = original_node.input[1] 1024 | original_variance_name = original_node.input[2] 1025 | original_beta_name = original_node.input[3] 1026 | original_gamma_name = original_node.input[4] 1027 | quantized_batch_norm_name = namespace_prefix + "_quantized_batch_norm" 1028 | 1029 | reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes( 1030 | namespace_prefix) 1031 | quantize_input_name, min_input_name, max_input_name = ( 1032 | self.eightbitize_input_to_node(namespace_prefix, original_input_name, 1033 | reshape_dims_name, reduction_dims_name)) 1034 | quantize_mean_name, min_mean_name, max_mean_name = ( 1035 | self.eightbitize_input_to_node(namespace_prefix, original_mean_name, 1036 | reshape_dims_name, reduction_dims_name)) 1037 | quantize_variance_name, min_variance_name, max_variance_name = ( 1038 | 
self.eightbitize_input_to_node(namespace_prefix, original_variance_name, 1039 | reshape_dims_name, reduction_dims_name)) 1040 | quantize_beta_name, min_beta_name, max_beta_name = ( 1041 | self.eightbitize_input_to_node(namespace_prefix, original_beta_name, 1042 | reshape_dims_name, reduction_dims_name)) 1043 | quantize_gamma_name, min_gamma_name, max_gamma_name = ( 1044 | self.eightbitize_input_to_node(namespace_prefix, original_gamma_name, 1045 | reshape_dims_name, reduction_dims_name)) 1046 | quantized_batch_norm_node = create_node( 1047 | "QuantizedBatchNormWithGlobalNormalization", quantized_batch_norm_name, 1048 | [ 1049 | quantize_input_name, min_input_name, max_input_name, 1050 | quantize_mean_name, min_mean_name, max_mean_name, 1051 | quantize_variance_name, min_variance_name, max_variance_name, 1052 | quantize_beta_name, min_beta_name, max_beta_name, 1053 | quantize_gamma_name, min_gamma_name, max_gamma_name 1054 | ]) 1055 | set_attr_dtype(quantized_batch_norm_node, "Tinput", dtypes.quint8) 1056 | set_attr_dtype(quantized_batch_norm_node, "out_type", dtypes.qint32) 1057 | copy_attr(quantized_batch_norm_node, "scale_after_normalization", 1058 | original_node.attr["scale_after_normalization"]) 1059 | copy_attr(quantized_batch_norm_node, "variance_epsilon", 1060 | original_node.attr["variance_epsilon"]) 1061 | self.add_output_graph_node(quantized_batch_norm_node) 1062 | quantize_down_name = self.add_quantize_down_nodes(original_node, 1063 | quantized_batch_norm_name) 1064 | self.add_dequantize_result_node(quantize_down_name, original_node.name) 1065 | 1066 | def add_output_graph_node(self, output_node): 1067 | """Inserts one node into the new graph.""" 1068 | self.output_graph.node.extend([output_node]) 1069 | 1070 | def remove_redundant_quantization(self, old_graph): 1071 | """Removes unneeded pairs of quantize/dequantize ops from the graph. 
1072 | 1073 | This is a bit of a tricky function, because it's attempting to spot the 1074 | pattern of dequantizing from eight-bit up to float, and then immediately 1075 | quantizing back down to eight bits again, that's introduced by previous 1076 | passes that do 'key-hole' conversions of individual nodes but have to 1077 | convert back to float to match the previous output interface, since they 1078 | don't know that the next op can handle quantized tensors. 1079 | It works by: 1080 | - Looking for Quantize nodes. 1081 | - Checking to see if their first input is a Dequantize node. 1082 | - Seeing if their min/max inputs come from Min/Max nodes. 1083 | - Making sure those Min/Max nodes are being fed from the same Dequantize. 1084 | - Or that the Min is indirectly being fed from the same Dequantize as Max. 1085 | - Making sure the Dequantize is going through a Reshape (which we add 1086 | during the previous pass when we create the quantize sub-graph). 1087 | - Looking for the dims Const op for the Min/Max dims. 1088 | If all of these conditions are met, then it's a sub-graph pattern that 1089 | we know how to optimize out (and is likely the common one we've introduced). 1090 | We then rewire the graph to skip it entirely, and then rely on the dead node 1091 | removal pass to get rid of any nodes that are no longer needed. 1092 | 1093 | Args: 1094 | old_graph: The model we'll be stripping redundant nodes from. 1095 | 1096 | Returns: 1097 | A graph with the unnecessary nodes removed. 1098 | 1099 | Raises: 1100 | ValueError: Two nodes with the same name were found in the graph. 1101 | """ 1102 | old_nodes_map = self.create_nodes_map(old_graph) 1103 | self.output_graph = graph_pb2.GraphDef() 1104 | inputs_to_rename = {} 1105 | # We go through all the nodes, looking for any that match the patterns we 1106 | # know how to optimize away. 
1107 | for node in old_graph.node: 1108 | # We always start with a Quantize node, and examine its inputs to see if 1109 | # they are in a form that can be removed. 1110 | if node.op not in ["Quantize", "QuantizeV2"]: 1111 | continue 1112 | dequantize_node_name = node_name_from_input(node.input[0]) 1113 | if dequantize_node_name not in old_nodes_map: 1114 | raise ValueError("Input node name '" + dequantize_node_name + 1115 | "' not found in node '" + node.name + "'") 1116 | dequantize_node = old_nodes_map[dequantize_node_name] 1117 | # Do we have a Dequantize feeding in, with the same type as the Quantize? 1118 | if dequantize_node.op != "Dequantize": 1119 | continue 1120 | if node.attr["T"] != dequantize_node.attr["T"]: 1121 | continue 1122 | # Now look at the other inputs, and ensure they're Min/Max nodes. 1123 | min_node_name = node_name_from_input(node.input[1]) 1124 | max_node_name = node_name_from_input(node.input[2]) 1125 | min_node = old_nodes_map[min_node_name] 1126 | max_node = old_nodes_map[max_node_name] 1127 | is_min_right_type = (min_node.op in ["Min", "Dequantize"]) 1128 | is_max_right_type = (max_node.op in ["Max", "Dequantize"]) 1129 | if not is_min_right_type or not is_max_right_type: 1130 | print("Didn't find expected types on inputs : %s, %s." % (min_node.op, 1131 | max_node.op)) 1132 | continue 1133 | min_node_input_name = node_name_from_input(min_node.input[0]) 1134 | max_node_input_name = node_name_from_input(max_node.input[0]) 1135 | # There are two different patterns for Min nodes we can recognize, one 1136 | # where the input comes directly from the same one as the Max, and 1137 | # another where we run it through another Min first, so check for both. 
1138 | is_same_input = False 1139 | if min_node_input_name == max_node_input_name: 1140 | is_same_input = True 1141 | else: 1142 | first_min_node_input = old_nodes_map[min_node_input_name] 1143 | if first_min_node_input.op == "Concat": 1144 | second_min_node_name = node_name_from_input( 1145 | first_min_node_input.input[1]) 1146 | second_min_node = old_nodes_map[second_min_node_name] 1147 | if second_min_node.op == "Min": 1148 | second_min_node_input_name = node_name_from_input( 1149 | second_min_node.input[0]) 1150 | is_same_input = (second_min_node_input_name == max_node_input_name) 1151 | if not is_same_input: 1152 | print("Different min/max inputs: " + min_node_input_name) 1153 | continue 1154 | # We recognize this pattern, so mark the graph edges to be rewired to 1155 | # route around it entirely, since we know it's a no-op. 1156 | dequantize_source_name = node_name_from_input(dequantize_node.input[0]) 1157 | node_tensor_name = ensure_tensor_name_has_port(node.name) 1158 | min_tensor_name = node.name + ":1" 1159 | max_tensor_name = node.name + ":2" 1160 | inputs_to_rename[node_tensor_name] = dequantize_source_name 1161 | inputs_to_rename[min_tensor_name] = dequantize_node.input[1] 1162 | inputs_to_rename[max_tensor_name] = dequantize_node.input[2] 1163 | # Finally we apply all the rewiring we've marked to the graph. 
1164 | for node in old_graph.node: 1165 | for index, input_full_name in enumerate(node.input): 1166 | input_name = ensure_tensor_name_has_port(input_full_name) 1167 | if input_name in inputs_to_rename: 1168 | node.input[index] = inputs_to_rename[input_name] 1169 | self.add_output_graph_node(node) 1170 | return self.output_graph 1171 | 1172 | def apply_final_node_renames(self): 1173 | """Applies node renames in self.final_node_renames to self.output_graph.""" 1174 | old_graph = self.output_graph 1175 | self.output_graph = graph_pb2.GraphDef() 1176 | for node in old_graph.node: 1177 | node.name = self.final_node_renames.get(node.name, node.name) 1178 | for index, input_name in enumerate(node.input): 1179 | node_name = node_name_from_input(input_name) 1180 | input_full_name = ensure_tensor_name_has_port(input_name) 1181 | if node_name in self.final_node_renames: 1182 | node.input[index] = "%s%s" % (self.final_node_renames[node_name], 1183 | input_full_name[len(node_name):]) 1184 | self.add_output_graph_node(node) 1185 | return self.output_graph 1186 | 1187 | def remove_dead_nodes(self, output_names): 1188 | """Removes nodes that are no longer needed for inference from the graph.""" 1189 | old_output_graph = self.output_graph 1190 | self.output_graph = graph_util.extract_sub_graph(old_output_graph, 1191 | output_names) 1192 | 1193 | def quantize_weights(self, input_graph, quantization_mode): 1194 | """Quantize float Const ops. 1195 | 1196 | There are two modes of operations, both replace float Const ops with 1197 | quantized values. 1198 | 1. 
If quantization_mode is "weights_rounded", this function replaces float 1199 | Const ops with quantized float Const ops - same as the original op, but 1200 | float values being mapped to the center of one of 1<%s" % caption)) 36 | -------------------------------------------------------------------------------- /src/attributes.js: -------------------------------------------------------------------------------- 1 | 2 | const { 3 | aboutZero, 4 | aboutCentered, 5 | aboutEqual, 6 | allAboutEqual, 7 | smallComparedTo 8 | } = require('./lib/utils'); 9 | 10 | 11 | const nativeAttrs = (js, ele, idDims, childParent) => { 12 | const dims = idDims[js.id] ? idDims[js.id] : {}; 13 | let componentStyles = {}; 14 | 15 | const { 16 | top, 17 | left, 18 | width, 19 | // height, 20 | position, 21 | ...rest 22 | } = js.style 23 | 24 | let jsStyles = {...rest} 25 | 26 | if(width < 360) { 27 | jsStyles.width = width; 28 | } 29 | 30 | // Note: 360 is the max width of iphone 31 | if(js.type == 'View' && !jsStyles.width && dims.width < 360) { 32 | jsStyles.width = dims.width; 33 | } 34 | 35 | if(js.type == 'Text' || js.type == 'Tspan') { 36 | const parent = childParent[js.id]; 37 | const parentDims = parent && idDims[parent]; 38 | if(parentDims) { 39 | const spaceBefore = dims.left - parentDims.left; 40 | const spaceAfter = parentDims.right - dims.right; 41 | 42 | if(aboutCentered(spaceBefore, spaceAfter) && !aboutZero(spaceBefore)) { 43 | jsStyles['textAlign'] = 'center'; 44 | } 45 | 46 | if(!aboutCentered(spaceBefore, spaceAfter) && spaceBefore < spaceAfter) { 47 | jsStyles['textAlign'] = 'left'; 48 | jsStyles['marginLeft'] = spaceBefore; 49 | } 50 | 51 | if(!aboutCentered(spaceBefore, spaceAfter) && spaceBefore > spaceAfter) { 52 | jsStyles['textAlign'] = 'right'; 53 | jsStyles['marginRight'] = spaceAfter; 54 | } 55 | 56 | } 57 | } 58 | 59 | 60 | const parent = childParent[ele.id] 61 | if(parent) { 62 | const parentDims = parent == 'row' || parent == 'column' ? 
parent.parentDims : idDims[parent] 63 | if(js.type == 'View' && !jsStyles.height && !aboutEqual(dims.height, parentDims.height)) { 64 | jsStyles.height = dims.height; 65 | } 66 | } 67 | 68 | let styles = jsStyles ? Object.keys(jsStyles).reduce((r, key) => { 69 | if(key == 'top' || 70 | key == 'left' || 71 | key == 'flex' || 72 | key == 'height' || 73 | key == 'width' || 74 | key == 'fontSize' || 75 | key == 'lineHeight' || 76 | key == 'borderRadius' || 77 | key == 'paddingTop' || 78 | key == 'paddingBottom' || 79 | key == 'paddingRight' || 80 | key == 'paddingLeft' || 81 | key == 'marginTop' || 82 | key == 'marginBottom' || 83 | key == 'marginLeft' || 84 | key == 'marginRight') { 85 | if(jsStyles[key] > 0) { 86 | r[key] = Math.round(jsStyles[key]) 87 | } 88 | } else { 89 | r[key] = `'${jsStyles[key]}'` 90 | } 91 | return r; 92 | }, {}) : {}; 93 | 94 | 95 | if(ele.alignJustify) { 96 | Object.keys(ele.alignJustify).forEach((key) => { 97 | if(key == 'paddingTop' || 98 | key == 'paddingBottom' || 99 | key == 'paddingRight' || 100 | key == 'paddingLeft' || 101 | key == 'marginTop' || 102 | key == 'marginBottom' || 103 | key == 'marginLeft' || 104 | key == 'marginRight') { 105 | if(ele.alignJustify[key] > 0) { 106 | styles[key] = Math.round(ele.alignJustify[key]); 107 | } 108 | } else { 109 | styles[key] = `'${ele.alignJustify[key]}'` 110 | } 111 | }) 112 | } 113 | 114 | 115 | let attrObjs = {} 116 | 117 | if(js.directAttrs) { 118 | Object.keys(js.directAttrs).forEach((key) => { 119 | attrObjs[key] = "'" + js.directAttrs[key] + "'" 120 | }) 121 | } 122 | 123 | const attrString = Object.keys(attrObjs).map((key) => { 124 | return `${key}={${attrObjs[key]}}` 125 | }).join(" "); 126 | 127 | const attrs = attrString.length > 0 ? 
" " + attrString : ""; 128 | let styleId = js.id.replace(/[^\_0-9a-zA-Z]/g, '') 129 | if(styleId.match(/^[0-9]/)) { 130 | styleId = "_" + styleId 131 | } 132 | 133 | componentStyles[styleId] = styles; 134 | return({attrs, attrObjs, styleId, componentStyles}); 135 | } 136 | 137 | const firstBackgroundColor = (js, base, idDims) => { 138 | const rDims = idDims[base.id] 139 | const jsDims = idDims[js.id] 140 | let backgroundColor = null; 141 | if(rDims && jsDims && rDims.top == jsDims.top && rDims.bottom == jsDims.bottom && rDims.left == jsDims.left && rDims.right == jsDims.right) { 142 | if(js.style.backgroundColor) { 143 | backgroundColor = js.style.backgroundColor; 144 | } 145 | } 146 | js.childs.forEach((child) => { 147 | const childColor = firstBackgroundColor(child, base, idDims); 148 | if(childColor) { 149 | backgroundColor = childColor; 150 | } 151 | }); 152 | return backgroundColor; 153 | } 154 | 155 | 156 | const determineAlignJustify = (comp, children, idDims, jsObjs) => { 157 | let alignJustify = {} 158 | 159 | if(comp.id == 'row' && children && children.length > 0 && idDims[children[0].id]) { 160 | const spaceBefore = idDims[children[0].id].left - comp.parentDims.left; 161 | const spaceAfter = comp.parentDims.right - idDims[children[children.length - 1].id].right; 162 | let spacesBetween = []; 163 | 164 | children.forEach((child, index) => { 165 | if(index > 0) { 166 | spacesBetween.push(idDims[children[index].id].left - idDims[children[index - 1].id].right) 167 | } 168 | }); 169 | 170 | if(spacesBetween.length > 0 && allAboutEqual(spacesBetween)) { 171 | // space between case 172 | if(aboutEqual(spaceBefore, spaceAfter)) { 173 | if(smallComparedTo(spaceBefore, spacesBetween[0])) { 174 | alignJustify['justifyContent'] = 'space-between' 175 | } 176 | } 177 | // space around case 178 | if(aboutEqual(spaceBefore, spaceAfter) && !aboutZero(spaceBefore)) { 179 | alignJustify['justifyContent'] = 'space-around' 180 | } 181 | } 182 | 183 | } else if(comp.id == 
'column') { 184 | 185 | } else { 186 | if(children.id) { 187 | // One child 188 | const childDims = children.id == 'row' || children.id == 'column' ? children.parentDims : idDims[children.id]; 189 | const parentDims = comp.id == 'row' || children.id == 'column' ? comp.parentDims : idDims[comp.id]; 190 | 191 | const spaceBefore = childDims.left - parentDims.left; 192 | const spaceAfter = parentDims.right - childDims.right; 193 | const spaceAbove = childDims.top - parentDims.top; 194 | const spaceBelow = parentDims.bottom - childDims.bottom; 195 | 196 | if(aboutEqual(spaceBefore, spaceAfter)) { 197 | alignJustify['alignItems'] = 'center' 198 | } else if(spaceBefore > spaceAfter) { 199 | alignJustify['alignItems'] = 'flex-end' 200 | alignJustify['marginRight'] = spaceAfter 201 | } else if(spaceBefore < spaceAfter) { 202 | alignJustify['alignItems'] = 'flex-start' 203 | alignJustify['marginLeft'] = spaceBefore 204 | } 205 | 206 | if(aboutEqual(spaceAbove, spaceBelow)) { 207 | alignJustify['justifyContent'] = 'center' 208 | } else if(spaceAbove > spaceBelow) { 209 | alignJustify['justifyContent'] = 'flex-end' 210 | alignJustify['marginBottom'] = spaceBelow 211 | } else if(spaceAbove < spaceBelow) { 212 | alignJustify['justifyContent'] = 'flex-start' 213 | alignJustify['marginTop'] = spaceAbove 214 | } 215 | 216 | } else { 217 | // Multiple Children; default to column 218 | 219 | const firstChildDims = children[0].id == 'row' || children[0].id == 'column' ? children[0].parentDims : idDims[children[0].id] 220 | const lastChildDims = children[children.length - 1].id == 'row' || children[children.length - 1].id == 'column' ? children[children.length - 1].parentDims : idDims[children[children.length - 1].id] 221 | const parentDims = comp.id == 'row' || children.id == 'column' ? 
comp.parentDims : idDims[comp.id]; 222 | 223 | const spaceAbove = firstChildDims.top - parentDims.top; 224 | const spaceBelow = parentDims.bottom - lastChildDims.bottom; 225 | 226 | let spacesBetween = []; 227 | children.forEach((child, index) => { 228 | const childJS = jsObjs[child.id]; 229 | const childDims = child.id == 'row' || child.id == 'column' ? child.parentDims : idDims[child.id] 230 | 231 | if(child.id != 'row' && child.id != 'column') { 232 | const spaceBefore = childDims.left - parentDims.left; 233 | const spaceAfter = parentDims.right - childDims.right; 234 | if(aboutEqual(spaceBefore, spaceAfter)) { 235 | childJS.style['alignSelf'] = 'center'; 236 | } else if(spaceBefore > spaceAfter) { 237 | childJS.style['alignSelf'] = 'flex-end'; 238 | childJS.style['marginRight'] = spaceAfter; 239 | } else if(spaceBefore < spaceAfter) { 240 | childJS.style['alignSelf'] = 'flex-start'; 241 | childJS.style['marginLeft'] = spaceBefore; 242 | } 243 | 244 | } 245 | 246 | if(index > 0) { 247 | const lastChildDims = children[index - 1].id == 'row' || children[index - 1].id == 'column' ? 
children[index - 1].parentDims : idDims[children[index - 1].id] 248 | spacesBetween.push(childDims.top - lastChildDims.bottom) 249 | if(childJS) { 250 | childJS.style['marginTop'] = childDims.top - lastChildDims.bottom; 251 | } 252 | } 253 | }); 254 | 255 | if(!aboutEqual(spaceBelow, spaceAbove)) { 256 | // apply padding 257 | alignJustify['paddingTop'] = spaceAbove; 258 | alignJustify['paddingBottom'] = spaceBelow; 259 | } 260 | 261 | } 262 | } 263 | 264 | return alignJustify; 265 | } 266 | 267 | 268 | module.exports.nativeAttrs = nativeAttrs; 269 | module.exports.determineAlignJustify = determineAlignJustify; 270 | module.exports.firstBackgroundColor = firstBackgroundColor; 271 | -------------------------------------------------------------------------------- /src/components.js: -------------------------------------------------------------------------------- 1 | 2 | const generateChildParent = (orderedIds, idDims) => { 3 | let childParent = {}; 4 | 5 | orderedIds.forEach((key) => { 6 | const dims = idDims[key]; 7 | const top = Math.min(dims.top, dims.bottom); 8 | const bottom = Math.max(dims.top, dims.bottom); 9 | const left = Math.min(dims.left, dims.right); 10 | const right = Math.max(dims.left, dims.right); 11 | const height = dims.height; 12 | const width = dims.width; 13 | 14 | const parentNodes = orderedIds.filter((pKey) => { 15 | if(key == pKey) { 16 | return false; 17 | } 18 | 19 | const pDims = idDims[pKey]; 20 | const ptop = Math.min(pDims.top, pDims.bottom); 21 | const pbottom = Math.max(pDims.top, pDims.bottom); 22 | const pleft = Math.min(pDims.left, pDims.right); 23 | const pright = Math.max(pDims.left, pDims.right); 24 | const pheight = pDims.height; 25 | const pwidth = pDims.width; 26 | 27 | return(top >= ptop && bottom <= pbottom && left >= pleft && right <= pright) 28 | }) 29 | 30 | if(parentNodes.length > 0) { 31 | for(let i = parentNodes.length - 1; i>=0; i--) { 32 | if(orderedIds.indexOf(parentNodes[i]) < orderedIds.indexOf(key)) { 33 | 
childParent[key] = parentNodes[i] 34 | break; 35 | } 36 | } 37 | } 38 | 39 | }); 40 | 41 | return childParent; 42 | } 43 | 44 | module.exports.generateChildParent = generateChildParent; 45 | 46 | -------------------------------------------------------------------------------- /src/flex.js: -------------------------------------------------------------------------------- 1 | 2 | const { determineAlignJustify } = require('./attributes'); 3 | 4 | 5 | 6 | const flexColumns = (nodeIds, idDims, parentChildren, childParent) => { 7 | if(nodeIds.length == 0) { 8 | return nodeIds 9 | } 10 | 11 | const orderedNodeIds = nodeIds.sort((a, b) => { 12 | const adims = idDims[a]; 13 | const bdims = idDims[b]; 14 | 15 | if(adims.left == bdims.left) { 16 | return adims.top - bdims.top; 17 | } 18 | 19 | return adims.left - bdims.left 20 | }); 21 | 22 | let colBreaks = [0]; 23 | let currentRight = null; 24 | 25 | orderedNodeIds.forEach((nodeId, i) => { 26 | const node = idDims[nodeId]; 27 | if(!currentRight) { 28 | currentRight = node.right; 29 | } else { 30 | if(node.left > currentRight) { 31 | currentRight = node.right; 32 | // done with break 33 | colBreaks.push(i) 34 | } else { 35 | // same row 36 | if(node.right > currentRight) { 37 | currentRight = node.right 38 | } 39 | } 40 | } 41 | 42 | }) 43 | 44 | colBreaks.push(orderedNodeIds.length) 45 | let cols = []; 46 | 47 | for(let i=0; i { 72 | const children = parentChildren[c] ? 
parentChildren[c] : []; 73 | return { 74 | id: c, 75 | children: flexBox(children, idDims, parentChildren, childParent) 76 | } 77 | }) 78 | }); 79 | } else { 80 | let parentDims = idDims[childParent[comps[0]]]; 81 | const newLeft = minLeft(comps, idDims); 82 | const newRight = maxRight(comps, idDims); 83 | 84 | if(newLeft) { 85 | parentDims.left = newLeft 86 | } 87 | if(newRight) { 88 | parentDims.right = newRight 89 | } 90 | 91 | cols.push({ 92 | id: 'column', 93 | parentDims, 94 | children: flexRows(comps, idDims, parentChildren, childParent) 95 | }) 96 | } 97 | } 98 | 99 | return cols.length == 1 ? cols[0] : cols; 100 | } 101 | 102 | 103 | 104 | const flexRows = (nodeIds, idDims, parentChildren, childParent) => { 105 | if(nodeIds.length == 0) { 106 | return nodeIds 107 | } 108 | 109 | const orderedNodeIds = nodeIds.sort((a, b) => { 110 | const adims = idDims[a]; 111 | const bdims = idDims[b]; 112 | 113 | if(adims.top == bdims.top) { 114 | return adims.left - bdims.left; 115 | } 116 | 117 | return adims.top - bdims.top 118 | }); 119 | 120 | let rowBreaks = [0]; 121 | let currentBottom = null; 122 | 123 | orderedNodeIds.forEach((nodeId, i) => { 124 | // TODO: node is sometimes null here? 125 | const node = idDims[nodeId]; 126 | if(!currentBottom) { 127 | currentBottom = node && node.bottom; 128 | } else { 129 | if(node.top > currentBottom) { 130 | currentBottom = node && node.bottom; 131 | // done with break 132 | rowBreaks.push(i) 133 | } else { 134 | // same row 135 | if(node && node.bottom > currentBottom) { 136 | currentBottom = node && node.bottom 137 | } 138 | } 139 | } 140 | 141 | }) 142 | 143 | rowBreaks.push(orderedNodeIds.length) 144 | let rows = []; 145 | 146 | 147 | for(let i=0; i { 171 | const children = parentChildren[c] ? 
parentChildren[c] : []; 172 | return { 173 | id: c, 174 | children: flexBox(children, idDims, parentChildren, childParent) 175 | } 176 | }) 177 | }); 178 | } else { 179 | let parentDims = idDims[childParent[comps[0]]]; 180 | const newTop = minTop(comps, idDims); 181 | const newBottom = maxBottom(comps, idDims); 182 | 183 | if(newTop) { 184 | parentDims.top = newTop 185 | } 186 | if(newBottom) { 187 | parentDims.bottom = newBottom 188 | } 189 | 190 | rows.push({ 191 | id: 'row', 192 | parentDims, 193 | children: flexColumns(comps, idDims, parentChildren, childParent) 194 | }) 195 | } 196 | } 197 | 198 | return rows.length == 1 ? rows[0] : rows; 199 | } 200 | 201 | 202 | 203 | 204 | const minTop = (ids, idDims) => { 205 | let min = idDims[ids[0]].top 206 | ids.forEach((id) => { 207 | if(idDims[ids[0]].top < min) { 208 | min = idDims[ids[0]].top; 209 | } 210 | }) 211 | return min 212 | } 213 | 214 | const maxBottom = (ids, idDims) => { 215 | let max = ids && ids[0] && idDims[ids[0]] && idDims[ids[0]].bottom 216 | ids.forEach((id) => { 217 | if(idDims[ids[0]].bottom > max) { 218 | max = idDims[ids[0]].bottom; 219 | } 220 | }) 221 | return max 222 | } 223 | 224 | 225 | const minLeft = (ids, idDims) => { 226 | let min = idDims[ids[0]].left 227 | ids.forEach((id) => { 228 | if(idDims[ids[0]].left < min) { 229 | min = idDims[ids[0]].left; 230 | } 231 | }) 232 | return min 233 | } 234 | 235 | const maxRight = (ids, idDims) => { 236 | let max = idDims[ids[0]].right 237 | ids.forEach((id) => { 238 | if(idDims[ids[0]].right > max) { 239 | max = idDims[ids[0]].right; 240 | } 241 | }) 242 | return max 243 | } 244 | 245 | 246 | 247 | 248 | const flexBox = (nodeIds, idDims, parentChildren, childParent) => { 249 | // NOTE! nodes here can be cols or rows; currently it's just 250 | // based on whatever they are in the design... This 251 | // will have to be addressed when there is a design that 252 | // has sibling columns instead of rows. 
// Entry point for turning a list of sibling node ids into a flex layout tree.
const flexBox = (nodeIds, idDims, parentChildren, childParent) => {
  // NOTE! nodes here can be cols or rows; currently it's just
  // based on whatever they are in the design... This
  // will have to be addressed when there is a design that
  // has sibling columns instead of rows.

  // SO, a todo here is to detect if siblings are rows or columns
  // (can they be both? no. have to break up into rows first, or cols first.)
  return flexRows(nodeIds, idDims, parentChildren, childParent)
}


// Flatten the nested box tree produced by flexBox into an ordered list of
// "open" / "close" entries that the output generator can print line by line.
//
// comps  - a box node ({ id, children }) or an array of box nodes;
//          `children` is itself a node, an array of nodes, or empty
// roodId - id of the root box (sic: the parameter name is kept as-is)
// idDims - map of id -> bounding box
// jsObjs - map of id -> processed node, forwarded to determineAlignJustify
// indent - starting nesting depth
//
// Returns an array of entries:
//   { spaces, id, single, alignJustify }  when an element opens
//   { spaces, id, end: true }             when an element with children closes
const flattenBoxComponents = (comps, roodId, idDims, jsObjs, indent=0) => {
  const flatEles = [];

  // A node has children when `children` is a single node (has an id)
  // or a non-empty array. A missing `children` counts as "no children".
  const hasChildNodes = (node) => {
    const kids = node.children;
    if(!kids) {
      return false;
    }
    return Boolean(kids.id) || kids.length > 0;
  }

  // True only when both boxes were measured and cover the same rectangle.
  const sameBox = (a, b) => {
    if(!a || !b) {
      return false;
    }
    return a.top == b.top && a.bottom == b.bottom && a.left == b.left && a.right == b.right;
  }

  // Emit the opening entry for `node`, its flattened children, and
  // (when it has children) the matching closing entry.
  const emit = (node, spaces, rootId, depth) => {
    const parent = hasChildNodes(node);
    const alignJustify = parent
      ? determineAlignJustify(node, node.children, idDims, jsObjs)
      : {};

    flatEles.push({spaces, id: node.id, single: !parent, alignJustify});
    if(parent) {
      walk(node.children, rootId, depth + 1)
      flatEles.push({spaces, id: node.id, end: true});
    }
  }

  const walk = (node, rootId, depth) => {
    if(!node) {
      return;
    }

    if(node.id && node.children && node.children.id) {
      // single element with one child: if it covers exactly the root box
      // it adds nothing, so skip it and flatten the child in its place
      if(sameBox(idDims[rootId], idDims[node.id])) {
        walk(node.children, rootId, depth)
        return;
      }
    }

    // indentation prefix for this depth (one join unit per level)
    const spaces = ' ' + new Array(depth + 1).join(' ');
    if(node.id) {
      emit(node, spaces, rootId, depth);
    }
    if(Array.isArray(node)) {
      node.forEach((child) => {
        emit(child, spaces, rootId, depth);
      })
    }
  }

  walk(comps, roodId, indent);

  return flatEles;
}
async/await without try/catch) 24 | process.on("unhandledRejection", function(err) { console.error(err); }); 25 | 26 | const INPUT_FILE = process.argv[2] 27 | if(!INPUT_FILE || INPUT_FILE == '' || !INPUT_FILE.match(/\.svg$/)) { 28 | throw "Usage: convert.js [svg_file]" 29 | } 30 | 31 | const pathArray = INPUT_FILE.split('/') 32 | const INPUT_FILENAME = pathArray[pathArray.length - 1] 33 | 34 | const INPUT_FILE_NO_SPACES = INPUT_FILENAME.replace(/\s/g, '_').split(".svg")[0] 35 | const OUTPUT_FILE = INPUT_FILE_NO_SPACES.split(".svg")[0] + ".js" 36 | 37 | const BASE_PATH = path.resolve(); 38 | const OUTPUT_DIR = 'output'; 39 | const TEMP_DIR = 'temp'; 40 | const IMAGES_DIR = INPUT_FILE_NO_SPACES.split(".svg")[0]+'_images'; 41 | const TEMP_IMAGES_DIR = path.join(BASE_PATH, TEMP_DIR, IMAGES_DIR); 42 | const TEMP_COMPONENT_DIR = path.join(BASE_PATH, TEMP_DIR, 'components'); 43 | 44 | makeDir(OUTPUT_DIR); // don't delete what's in output each time! 45 | emptyAndCreateDir(TEMP_DIR); 46 | emptyAndCreateDir(TEMP_IMAGES_DIR); 47 | emptyAndCreateDir(TEMP_COMPONENT_DIR); 48 | 49 | (async () => { 50 | 51 | fs.readFile(INPUT_FILE, 'utf-8', (err, data) => { 52 | 53 | const preppedData = prepData({ 54 | data, 55 | tempDir: TEMP_DIR, 56 | inputFile: INPUT_FILENAME 57 | }); 58 | 59 | svgson(preppedData, {}, async function(result) { 60 | 61 | const processedJS = processNode(result); 62 | 63 | const preppedFile = 'file://'+path.join(BASE_PATH, TEMP_DIR, 'prepped_'+INPUT_FILENAME); 64 | 65 | const cleanedJS = await removeStatusBarAndKeyboard(preppedFile, TEMP_COMPONENT_DIR, processedJS); 66 | 67 | await getBrowserBoundingBoxes(cleanedJS, preppedFile); 68 | 69 | const js = imagifyParents(cleanedJS) 70 | const { idDims, orderedIds } = await getBrowserBoundingBoxes(js, preppedFile); 71 | 72 | const mainBackgroundColor = firstBackgroundColor(js.childs[0], js.childs[0], idDims); 73 | 74 | const childParent = generateChildParent(orderedIds, idDims); 75 | 76 | let parentChildren = {}; 77 | 
orderedIds.forEach((id) => { 78 | if(childParent[id]) { 79 | if(!parentChildren[childParent[id]]) { 80 | parentChildren[childParent[id]] = []; 81 | } 82 | parentChildren[childParent[id]].push(id); 83 | } 84 | }); 85 | 86 | let jsObjs = {} 87 | const unrollJs = (js) => { 88 | jsObjs[js.id] = js; 89 | js.childs.forEach((child) => { 90 | unrollJs(child) 91 | }) 92 | } 93 | unrollJs(js); 94 | 95 | const boxComponents = flexBox([orderedIds[0]], idDims, parentChildren, childParent) 96 | 97 | const flatEles = flattenBoxComponents(boxComponents, boxComponents.id, idDims, jsObjs); 98 | 99 | const polygons = orderedIds.filter((id) => { 100 | const item = jsObjs[id]; 101 | return(item.type == 'Polygon' || item.type == 'Path' || item.type == 'Image'); 102 | }) 103 | 104 | await screenshotElements( 105 | preppedFile, 106 | TEMP_IMAGES_DIR, 107 | polygons 108 | ); 109 | 110 | let globalStyles = {}; 111 | flatEles.forEach((ele) => { 112 | const js = jsObjs[ele.id]; 113 | if(!ele.end && ele.id != 'row' && ele.id != 'column') { 114 | const { componentStyles } = nativeAttrs(js, ele, idDims, childParent); 115 | globalStyles = {...globalStyles, ...componentStyles} 116 | } 117 | }); 118 | 119 | const { imports, componentStrings } = generateComponentStrings({ 120 | flatEles, 121 | idDims, 122 | childParent, 123 | jsObjs, 124 | imagesDir: IMAGES_DIR 125 | }); 126 | 127 | const generatedComponent = generateComponent({ 128 | imports, 129 | rootStyle: processedJS.rootStyle, 130 | mainBackgroundColor, 131 | componentStrings, 132 | globalStyles 133 | }) 134 | emptyAndCreateDir(OUTPUT_DIR + '/' + IMAGES_DIR) 135 | copyFolderRecursive(TEMP_IMAGES_DIR, OUTPUT_DIR + '/' + IMAGES_DIR) 136 | fs.writeFileSync(OUTPUT_DIR + '/' + OUTPUT_FILE, generatedComponent) 137 | 138 | console.log("") 139 | console.log("Images directory written: ", path.join(BASE_PATH, OUTPUT_DIR, IMAGES_DIR)) 140 | console.log("React Native component generated: ", path.join(BASE_PATH, OUTPUT_DIR, OUTPUT_FILE)) 141 | 
console.log("") 142 | 143 | }); 144 | }) 145 | 146 | })(); 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /src/input.js: -------------------------------------------------------------------------------- 1 | 2 | const fs = require('fs'); 3 | 4 | const prepData = ({ data, tempDir, inputFile }) => { 5 | let idList = []; 6 | 7 | const oldDefs = data.match(/\[\s\S]*\<\/defs\>/g) 8 | 9 | const newData = data.replace(/\sid=\"[^\"]+/g, function(d) { 10 | let onlyId = d.replace(/^\sid\=\"/, '').replace(/[^\_0-9a-zA-Z]/g, ''); 11 | 12 | if(onlyId.match(/^[0-9]+/)) { 13 | onlyId = "_" + onlyId 14 | } 15 | 16 | d = ' id="' + onlyId; 17 | 18 | if(idList.indexOf(d) >= 0) { 19 | const newId = d + "_" + idList.length 20 | idList.push(newId) 21 | return newId 22 | } else { 23 | idList.push(d); 24 | return d; 25 | } 26 | }); 27 | 28 | const preppedData = newData.replace(/\[\s\S]*\<\/defs\>/g, oldDefs) 29 | 30 | fs.writeFileSync(tempDir + '/prepped_'+inputFile, preppedData) 31 | 32 | return preppedData; 33 | } 34 | 35 | module.exports.prepData = prepData; 36 | -------------------------------------------------------------------------------- /src/lib/files.js: -------------------------------------------------------------------------------- 1 | const fs = require('fs'); 2 | 3 | // Clear the images directory 4 | const deleteFolderRecursive = function(path) { 5 | if (fs.existsSync(path)) { 6 | fs.readdirSync(path).forEach(function(file, index){ 7 | const curPath = path + "/" + file; 8 | if (fs.lstatSync(curPath).isDirectory()) { // recurse 9 | deleteFolderRecursive(curPath); 10 | } else { // delete file 11 | fs.unlinkSync(curPath); 12 | } 13 | }); 14 | fs.rmdirSync(path); 15 | } 16 | }; 17 | 18 | module.exports.makeDir = (path) => { 19 | if (!fs.existsSync(path)) { 20 | fs.mkdirSync(path); 21 | } 22 | } 23 | 24 | module.exports.emptyAndCreateDir = (dir) => { 25 | deleteFolderRecursive(dir); 26 | fs.mkdirSync(dir); 27 | } 28 | 29 | 
// Recursively copy the contents of directory `path` into directory `toPath`.
//
// - Does nothing when `path` does not exist.
// - Creates `toPath` (one level) when it is missing, so nested directories
//   can be copied; the parent of `toPath` must already exist.
// - Files are copied with streams, so the bytes are written asynchronously
//   after this function returns.
//
// The function expression is named so the recursive call resolves: the
// previous code called a bare `copyFolderRecursive`, which was never bound
// in this module, and it also dropped the destination argument.
module.exports.copyFolderRecursive = function copyFolderRecursive(path, toPath) {
  if (!fs.existsSync(path)) {
    return;
  }
  if (!fs.existsSync(toPath)) {
    fs.mkdirSync(toPath);
  }
  fs.readdirSync(path).forEach(function(file) {
    const curPath = path + "/" + file;
    const newPath = toPath + "/" + file;
    if (fs.lstatSync(curPath).isDirectory()) { // recurse
      copyFolderRecursive(curPath, newPath);
    } else { // copy file
      fs.createReadStream(curPath).pipe(fs.createWriteStream(newPath));
    }
  });
};


// TODO: remove hardcoded values

// True when `a` is within 6 units of zero.
const aboutZero = (a) => {
  return Math.abs(a-0) < 6
}

// True when `a` and `b` differ by less than 17 units
// (looser tolerance than aboutEqual).
const aboutCentered = (a, b) => {
  return Math.abs(a-b) < 17
}

// True when `a` and `b` differ by less than 6 units.
const aboutEqual = (a, b) => {
  return Math.abs(a-b) < 6
}

// True when `a` is about zero while `b` is not, or when `a` is non-zero
// and b / a is greater than 0.5.
// NOTE(review): the ratio test accepts any `b` larger than half of `a`,
// which reads oddly for a function with this name - confirm with callers.
const smallComparedTo = (a, b) => {
  if(aboutZero(a) && !aboutZero(b)) {
    return true
  }
  const ratio = b / a;
  return(a != 0 && ratio > 0.5)
}

// True when every element of `arr` is aboutEqual to the average of `arr`
// (an empty array is considered equal).
const allAboutEqual = (arr) => {
  const avg = arr.reduce(function(sum, a) { return sum + a },0)/(arr.length||1);
  return arr.filter((ele) => {
    return !aboutEqual(avg, ele)
  }).length == 0;
}


module.exports.aboutZero = aboutZero;
module.exports.aboutCentered = aboutCentered;
module.exports.aboutEqual = aboutEqual;
module.exports.allAboutEqual = allAboutEqual;
module.exports.smallComparedTo = smallComparedTo;
tempDir) => { 10 | if(!child.id || child.id == '') { 11 | return false 12 | } 13 | 14 | const filepath = path.join(tempDir, 'components/'+child.id+'.png'); 15 | if (!fs.existsSync(filepath)) { 16 | return false; 17 | } 18 | 19 | const command = `python -m ${path.join(__dirname, '../scripts/label_image')} --graph=${path.join(__dirname, '../tf_files/retrained_graph.pb')} --image=`+filepath; 20 | 21 | return new Promise((resolve, reject) => { 22 | exec(command, 23 | function (error, stdout, stderr) { 24 | if(stdout && stdout != '') { 25 | const results = stdout.split("\n") 26 | 27 | const statusBar = results.filter((result) => { 28 | return result.match(/^status\sbar/); 29 | }); 30 | const isStatusBar = statusBar && statusBar.length > 0 && parseFloat(statusBar[0].split(/\s/)[statusBar[0].split(/\s/).length - 1]); 31 | 32 | const keyboard = results.filter((result) => { 33 | return result.match(/^keyboard/); 34 | }); 35 | const isKeyboard = keyboard && keyboard.length > 0 && parseFloat(keyboard[0].split(/\s/)[keyboard[0].split(/\s/).length - 1]); 36 | 37 | // Make sure you're really sure it's a status bar or keyboard 38 | // just text can trigger even a 98% status bar response 39 | resolve(isStatusBar > 0.99 || isKeyboard > 0.99) 40 | } 41 | 42 | if (error !== null) { 43 | console.log('exec error: ' + error); 44 | resolve(false) 45 | } else { 46 | resolve(false) 47 | } 48 | }); 49 | }); 50 | 51 | }); 52 | 53 | 54 | const filterElements = (async (js, tempDir) => { 55 | if(js.childs) { 56 | 57 | const newChildren = await Promise.all(js.childs.map(async (child) => { 58 | const childShouldBeFilteredOut = await childIsBadElement(child, tempDir); 59 | if(childShouldBeFilteredOut) { 60 | return null; 61 | } else { 62 | await filterElements(child, tempDir); 63 | return child 64 | } 65 | })); 66 | 67 | const filteredChildren = newChildren.filter(child => child); 68 | js.childs = filteredChildren 69 | } 70 | return js; 71 | }) 72 | 73 | const removeStatusBarAndKeyboard = (async 
(file, tempDir, js) => { 74 | // Start with the children 75 | await screenshotAllElements(file, tempDir, js) 76 | console.warn("filtering elements...") 77 | const newJS = await filterElements(js, tempDir) 78 | 79 | return newJS; 80 | }) 81 | 82 | 83 | 84 | module.exports.removeStatusBarAndKeyboard = removeStatusBarAndKeyboard; 85 | -------------------------------------------------------------------------------- /src/output.js: -------------------------------------------------------------------------------- 1 | 2 | const { nativeAttrs } = require('./attributes'); 3 | 4 | 5 | // Note: the spacing (indentation) is important (even though it looks weird) 6 | // for everything in the ``. 7 | const generateStyleSheetString = (componentStyles) => { 8 | return Object.keys(componentStyles).map((key) => { 9 | const keyStyles = Object.keys(componentStyles[key]).map((styleKey) => { 10 | const styleString = componentStyles[key][styleKey] 11 | return ` ${styleKey}: ${styleString}` 12 | }).join(",\n"); 13 | return ` ${key}: { 14 | ${keyStyles} 15 | }` 16 | }).join(",\n") 17 | } 18 | 19 | 20 | const generateComponentStrings = ({ 21 | flatEles, 22 | idDims, 23 | childParent, 24 | jsObjs, 25 | imagesDir 26 | }) => { 27 | let imports = []; 28 | 29 | const componentStrings = flatEles.map((ele) => { 30 | 31 | const color = '#'+'0123456789abcddd'.split('').map(function(v,i,a){ return i>5 ? 
null : a[Math.floor(Math.random()*16)] }).join(''); 32 | 33 | if(ele.end) { 34 | if(ele.id == 'row') return `${ele.spaces}` 35 | if(ele.id == 'column') return `${ele.spaces}` 36 | // The rest of the ends are covered below 37 | } 38 | 39 | if(ele.id == 'row') { 40 | let styleAttrs = [`flexDirection: 'row'`] 41 | if(ele.alignJustify) { 42 | const aj = ele.alignJustify 43 | if(aj.justifyContent) styleAttrs.push(`justifyContent: '${aj.justifyContent}'`) 44 | if(aj.alignItems) styleAttrs.push(`alignItems: '${aj.alignItems}'`) 45 | if(aj.marginRight) styleAttrs.push(`marginRight: '${aj.marginRight}'`) 46 | if(aj.marginLeft) styleAttrs.push(`marginLeft: '${aj.marginLeft}'`) 47 | if(aj.marginTop) styleAttrs.push(`marginTop: '${aj.marginTop}'`) 48 | if(aj.marginBottom) styleAttrs.push(`marginBottom: '${aj.marginBottom}'`) 49 | } 50 | return `${ele.spaces}` 51 | } 52 | if(ele.id == 'column') { 53 | return `${ele.spaces}` 54 | } 55 | 56 | const js = jsObjs[ele.id]; 57 | 58 | if(ele.end) { 59 | if(js.type == 'Path') return `${ele.spaces}` 60 | if(js.type == 'Image') return `${ele.spaces}` 61 | if(js.type == 'Polygon') return `${ele.spaces}` 62 | return `${ele.spaces}` 63 | } 64 | 65 | const {attrs, attrObjs, styleId} = nativeAttrs(js, ele, idDims, childParent); 66 | 67 | const styles = styleId ? ` style={styles.${styleId}}` : '' 68 | 69 | if(js.type == 'Text') { 70 | if(js.childs.length == 0) { 71 | return `${ele.spaces}${js.text}${ele.single ? '' : ''}` 72 | } 73 | if(js.childs.length == 1 && js.childs[0].type == 'Tspan') { 74 | return `${ele.spaces}${js.text}${js.childs[0].text}` 75 | } else { 76 | const tspans = js.childs.map((t) => { 77 | if(t.type == 'Tspan') { 78 | // TODO: once tspans have ids, change this 79 | // const {Tattrs, TattrObjs, TstyleId} = nativeAttrs(t, ele); 80 | // const Tstyles = TstyleId ? 
` style={styles.${styleId}}` : '' 81 | const Tstyles = '' 82 | return `${ele.spaces} ${t.text}{'\\n'}` 83 | } 84 | }).join("\n") 85 | 86 | return `${ele.spaces}${js.text}\n${tspans}\n${ele.spaces}` 87 | } 88 | 89 | } 90 | if(js.type == 'Tspan') { 91 | return `${js.text}` 92 | } 93 | 94 | if(js.type == 'Polygon') { 95 | const imageStyle = attrObjs.style ? ` style={${attrObjs.style}}` : ''; 96 | const jsname = js.id.replace("-", "").replace("+", "") 97 | if(imports.indexOf(`import ${jsname} from './${imagesDir}/${js.id}.png'`) < 0) { 98 | imports.push(`import ${jsname} from './${imagesDir}/${js.id}.png'`) 99 | } 100 | return `${ele.spaces}` 101 | } 102 | 103 | if(js.type == 'Image') { 104 | const imageStyle = attrObjs.style ? ` style={${attrObjs.style}}` : ''; 105 | const jsname = js.id.replace("-", "").replace("+", "") 106 | if(imports.indexOf(`import ${jsname} from './${imagesDir}/${js.id}.png'`) < 0) { 107 | imports.push(`import ${jsname} from './${imagesDir}/${js.id}.png'`) 108 | } 109 | return `${ele.spaces}` 110 | } 111 | 112 | if(js.type == 'Path') { 113 | const imageStyle = attrObjs.style ? ` style={${attrObjs.style}}` : ''; 114 | if(ele.end) { 115 | return `${ele.spaces}` 116 | } 117 | const jsname = js.id.replace("-", "").replace("+", "") 118 | if(imports.indexOf(`import ${jsname} from './${imagesDir}/${js.id}.png'`) < 0) { 119 | imports.push(`import ${jsname} from './${imagesDir}/${js.id}.png'`) 120 | } 121 | if(ele.single) { 122 | return `${ele.spaces}` 123 | } else { 124 | return `${ele.spaces}` 125 | } 126 | } 127 | 128 | if(ele.end) { 129 | return(ele.spaces+''); 130 | } 131 | return `${ele.spaces}<${js.type}${attrs}${styles}${ele.single ? 
' /' : ''}>` 132 | 133 | }).join("\n"); 134 | 135 | return({ 136 | componentStrings, 137 | imports 138 | }) 139 | } 140 | 141 | 142 | 143 | const generateComponent = ({ 144 | imports, 145 | rootStyle, 146 | mainBackgroundColor, 147 | componentStrings, 148 | globalStyles 149 | }) => { 150 | const styleSheetString = generateStyleSheetString(globalStyles); 151 | 152 | // Note: the spacing (indentation) is important (even though it looks weird) 153 | // for everything in the ``. 154 | return ( 155 | `import React, { Component } from 'react'; 156 | 157 | import { 158 | StyleSheet, 159 | Text, 160 | View, 161 | TouchableOpacity, 162 | TextInput, 163 | ScrollView, 164 | Image 165 | } from 'react-native'; 166 | 167 | `+imports.join("\n")+` 168 | 169 | export default class Main extends Component { 170 | 171 | render() { 172 | return ( 173 | 177 | ` + componentStrings + ` 178 | 179 | ) 180 | } 181 | 182 | } 183 | 184 | const styles = StyleSheet.create({ 185 | ${styleSheetString} 186 | }) 187 | ` 188 | ) 189 | } 190 | 191 | 192 | module.exports.generateComponent = generateComponent; 193 | module.exports.generateComponentStrings = generateComponentStrings; 194 | -------------------------------------------------------------------------------- /src/process.js: -------------------------------------------------------------------------------- 1 | 2 | 3 | const processNode = (node, parentAttrs={}) => { 4 | 5 | if(node.name == 'svg') { 6 | return processSVG(node, parentAttrs); 7 | } else if(node.name == 'g') { 8 | return processG(node, parentAttrs); 9 | } else if(node.name == 'text') { 10 | return processText(node, parentAttrs); 11 | } else if(node.name == 'rect') { 12 | return processRect(node, parentAttrs); 13 | } else if(node.name == 'path') { 14 | return processPath(node, parentAttrs); 15 | } else if(node.name == 'polygon') { 16 | return processPolygon(node, parentAttrs); 17 | } else if(node.name == 'tspan') { 18 | return processTspan(node, parentAttrs); 19 | } 20 | } 21 | 22 | 
// Convert every child of `parent` with processNode and drop the children
// that processNode does not understand (it returns undefined for those).
//
// Each child is converted exactly once. The previous filter-then-map form
// converted every child twice, and because the converters recurse back into
// this function the work doubled at every nesting level.
const processChildren = (parent, parentAttrs = {}) => {
  const children = parent.childs ? parent.childs : [];

  // If all the nodes are going to be images or empty,
  // Then just process the entire thing as one image.

  return children
    .map((node) => processNode(node, parentAttrs))
    .filter((processed) => processed);
}


// Convert the root <svg> node into the root ScrollView description.
// Records the viewBox origin on `parentAttrs` (read by topLeft) and parses
// the inline `style` attribute into `rootStyle`.
const processSVG = (node, parentAttrs={}) => {
  const attrs = node.attrs ? node.attrs : {}

  // viewBox is "min-x min-y width height"; only the origin is needed.
  const viewBoxArray = (attrs.viewBox || '').trim().split(/[\s,]+/)
  const viewBox = {
    x: viewBoxArray[0] || 0,
    y: viewBoxArray[1] || 0
  }
  parentAttrs.viewBox = viewBox;

  let rootStyle = {}
  if(attrs.style) {
    rootStyle = attrs.style.split(/\;\s*/).reduce((r, style) => {
      const keyValue = style.split(/\:\s*/)
      if(keyValue.length == 2) {
        const key = keyValue[0]
        const value = keyValue[1]
        r[key] = value
        // Only a `background` whose value looks like #rgb / #rrggbb is a
        // plain colour. The length test is parenthesised so that other
        // properties with a 7-character value no longer overwrite it.
        if(key == 'background' && (value.length == 4 || value.length == 7)) {
          r['backgroundColor'] = value
        }
      }
      return r;
    }, {});
  }

  const backgroundColor = rootStyle.background || rootStyle.backgroundColor || '#ffffff';
  const styles = {
    flex: 1,
    alignSelf: 'stretch',
    backgroundColor
  }

  return {
    id: attrs.id,
    type: 'ScrollView',
    rootStyle,
    childs: processChildren(node, parentAttrs),
    style: styles
  }
}

// Polygons are rendered from screenshots, so only the id is kept.
const processPolygon = (node, parentAttrs={}) => {
  return {
    id: node.attrs.id,
    type: 'Polygon',
    childs: [],
    style: {}
  }
}


// Convert a <g> group into a View.
// Text-related attributes found on the group are stored on `parentAttrs`
// so that text nodes processed afterwards can inherit them.
const processG = (node, parentAttrs={}) => {
  const attrs = node.attrs ? node.attrs : {}
  const styles = {...topLeft(node, parentAttrs)}

  if(attrs.fontSize) {
    parentAttrs.fontSize = attrs.fontSize
  }
  if(attrs.fontWeight) {
    parentAttrs.fontWeight = attrs.fontWeight
  }
  // Font family removed for now

  // if(node.attrs.fontFamily) {
  //   parentAttrs.fontFamily = node.attrs.fontFamily.split(", ")[0]
  // }

  // The previous code tested `lineHeight` but copied `lineSpacing`, so the
  // inherited value ended up undefined unless both attributes were present.
  // Accept either attribute; lineSpacing wins when both are set, as before.
  const lineHeight = attrs.lineSpacing || attrs.lineHeight
  if(lineHeight) {
    parentAttrs.lineHeight = lineHeight
  }
  if(attrs.fill) {
    parentAttrs.fill = attrs.fill
  }

  // A <use> child carrying a fill provides the group's background colour.
  const children = node.childs ? node.childs : []
  children.forEach((child) => {
    if(child.name == 'use' && child.attrs && child.attrs.fill) {
      styles.backgroundColor = child.attrs.fill
    }
  })

  return {
    id: attrs.id,
    type: 'View',
    childs: processChildren(node, parentAttrs),
    style: styles
  }
}


// Combine a hex fill (#rgb or #rrggbb) with an SVG opacity (0..1) into an
// #rrggbbaa colour. Fills that are not hex colours, or an opacity that is
// not a number, leave the fill unchanged.
const fillWithOpacity = (fill, opacity) => {
  const alpha = parseFloat(opacity)
  const hex = /^#([0-9a-f]{3}|[0-9a-f]{6})$/i.exec(fill)
  if(Number.isNaN(alpha) || !hex) {
    return fill
  }
  const digits = hex[1].length == 3 ? hex[1].replace(/./g, (c) => c + c) : hex[1]
  const clamped = Math.min(1, Math.max(0, alpha))
  // alpha channel is a 0-255 byte written as two hex digits
  const alphaHex = ('0' + Math.round(clamped * 255).toString(16)).slice(-2)
  return '#' + digits + alphaHex
}


// Convert a <rect> into a View with background colour and corner radius.
const processRect = (node, parentAttrs={}) => {
  const attrs = node.attrs ? node.attrs : {}
  const styles = {...topLeft(node, parentAttrs)}

  if(attrs.fill) {
    styles.backgroundColor = attrs.fill
  }

  if(attrs.opacity && attrs.fill) {
    // The previous code appended the decimal digits of the opacity
    // (opacity 1 became "00", i.e. fully transparent).
    styles.backgroundColor = fillWithOpacity(attrs.fill, attrs.opacity)
  }

  // rx wins when both radii are given; ry is used when it is the only one
  // (the previous code assigned rx in the ry branch).
  const radius = attrs.rx || attrs.ry
  if(radius) {
    styles.borderRadius = radius
  }

  return {
    id: attrs.id,
    type: 'View',
    childs: processChildren(node, parentAttrs),
    style: styles
  }
}


// Convert a <text> node into a Text description.
// Font size, weight, line height and colour come from the enclosing groups
// (`parentAttrs`) and are overridden by the node's own attributes.
const processText = (node, parentAttrs={}) => {
  const attrs = node.attrs ? node.attrs : {}
  const style = {}

  if(attrs.x || attrs.y) {
    style.position = 'absolute'
    style.left = attrs.x ? attrs.x : 0
    style.top = attrs.y ? attrs.y : 0
  }

  style.backgroundColor = 'transparent'

  const fontSize = attrs.fontSize || parentAttrs.fontSize
  if(fontSize) {
    style.fontSize = fontSize
  }
  if(parentAttrs.lineHeight) {
    style.lineHeight = parentAttrs.lineHeight
  }
  const fontWeight = attrs.fontWeight || parentAttrs.fontWeight
  if(fontWeight) {
    style.fontWeight = fontWeight
  }

  // Font family removed for now; sketch files do not
  // contain the fonts, so this became a mess.

  // if(parentAttrs.fontFamily) {
  //   style.fontFamily = parentAttrs.fontFamily
  // }
  // if(attrs.fontFamily) {
  //   style.fontFamily = attrs.fontFamily.split(", ")[0]
  // }

  const color = attrs.fill || parentAttrs.fill
  if(color) {
    style.color = color
  }

  return {
    id: attrs.id,
    type: 'Text',
    childs: processChildren(node, parentAttrs),
    // the visible text lives in the tspan children
    text: "",
    style
  }
}


// Convert a <tspan>; its text is taken from its single text child, if any.
const processTspan = (node, parentAttrs={}) => {
  const attrs = node.attrs ? node.attrs : {}
  const style = {}

  if(attrs.fill) {
    style.color = attrs.fill;
  }
  if(attrs.x || attrs.y) {
    style.position = 'absolute'
    style.left = attrs.x ? attrs.x : 0
    style.top = attrs.y ? attrs.y : 0
  }

  const children = node.childs ? node.childs : [];
  const textChild = children.length == 1 && children[0].text ? children[0].text : '';

  return {
    id: attrs.id,
    type: 'Tspan',
    childs: [],
    text: textChild,
    style
  }
}



// Convert a <path>; the drawing attributes are kept in `directAttrs`.
const processPath = (node, parentAttrs={}) => {
  const attrs = node.attrs ? node.attrs : {}
  const style = {}
  const directAttrs = {}

  if(attrs.d) {
    directAttrs.d = attrs.d;
  }

  // own fill overrides the fill inherited from the enclosing group
  const fill = attrs.fill || parentAttrs.fill
  if(fill) {
    directAttrs.fill = fill;
  }
  if(attrs.opacity) {
    directAttrs.opacity = attrs.opacity;
  }
  if(attrs.x || attrs.y) {
    style.position = 'absolute'
    style.left = attrs.x ? attrs.x : 0
    style.top = attrs.y ? attrs.y : 0
  }

  return {
    id: attrs.id,
    type: 'Path',
    childs: processChildren(node, parentAttrs),
    style,
    directAttrs: directAttrs
  }
}



// Build the absolute-position style for a node from its translate()
// transform and/or x, y, width, height attributes. Offsets are made
// relative to the viewBox origin stored on `parentAttrs`.
// Explicit x / y attributes take precedence over the transform.
const topLeft = (node, parentAttrs={}) => {
  const styles = {position: 'absolute'}
  const attrs = node.attrs ? node.attrs : {}
  const viewBox = parentAttrs.viewBox;
  const originX = viewBox ? viewBox.x : 0
  const originY = viewBox ? viewBox.y : 0

  if(attrs.transform && attrs.transform.match(/^translate\([^\(]+\)$/)) {
    const transform = attrs.transform
      .replace(/^translate\(/, '')
      .replace(/\)$/, '')
      .trim()
      .split(/[\s,]+/)
    styles.left = transform[0] - originX
    // translate(tx) without a second value means ty = 0
    styles.top = (transform[1] === undefined ? 0 : transform[1]) - originY
  }

  if(attrs.height) {
    styles.height = attrs.height
  }

  if(attrs.width) {
    styles.width = attrs.width
  }

  if(attrs.x) {
    styles.left = attrs.x - originX
  }

  if(attrs.y) {
    styles.top = attrs.y - originY
  }

  return styles
}
// If the bounding boxes for a view are all empty views
// or paths or polygons or masks?
// Then call the view an image, and process it once.
//
// Returns a shallow copy of `js`; the input tree is not modified.
// A node whose direct children are all image-like becomes
// { ...js, type: 'Image', childs: [] }. Otherwise its children are
// converted recursively.
const imagifyParents = (js) => {
  const childs = js.childs ? js.childs : [];

  // image-like: an empty View, or a Path / Polygon
  const isImageLike = (child) => {
    const emptyView = child.type == 'View' && (child.childs || []).length == 0;
    const isImageType = ['Path', 'Polygon'].indexOf(child.type) > -1;
    return emptyView || isImageType;
  }

  const siblingsAreImages = childs.length > 0 && childs.every(isImageLike);

  if(siblingsAreImages) {
    // the whole subtree is rendered as one screenshot
    return {...js, type: 'Image', childs: []};
  }
  return {...js, childs: childs.map((child) => imagifyParents(child))};
}


// Collect the ids of all descendants of `js` (the id of `js` itself is
// not included). Direct children come first, followed by the descendants
// of each child in order.
const getAllElementIds = (js) => {
  const childs = js.childs ? js.childs : [];
  const ids = [];
  childs.forEach((child) => {
    if(child.id) {
      ids.push(child.id);
    }
  });
  childs.forEach((child) => {
    getAllElementIds(child).forEach((id) => {
      ids.push(id);
    });
  });
  return ids;
}

// Screenshot every descendant element of `rootJS` into `tempDir`.
const screenshotAllElements = (async (file, tempDir, rootJS) => {

  const elements = getAllElementIds(rootJS);
  console.warn("gathering elements...")
  await screenshotElements(file, tempDir, elements);

})


// Open `file` in a headless browser and write one PNG per id in
// `elementIds` to `tempDir`. The browser is always closed, including
// when navigation or a screenshot fails.
const screenshotElements = (async (file, tempDir, elementIds) => {
  const browser = await puppeteer.launch();
  try {
    const page = await browser.newPage();
    await page.goto(file);

    for(const id of elementIds) {
      await screenshotDOMElement(page, id, tempDir)
    }
  } finally {
    await browser.close();
  }
});


// Screenshot the element with id `selector` to `<tempDir>/<selector>.png`.
// Throws when the page has no element with that id.
async function screenshotDOMElement(page, selector, tempDir) {

  const rect = await page.evaluate((id) => {
    // Look up by id directly, as getEleDimensions does, so the id does
    // not have to be a valid CSS selector.
    const element = document.getElementById(id);
    if (!element)
      return null;
    const {x, y, width, height} = element.getBoundingClientRect();
    return {left: x, top: y, width, height, id: element.id};
  }, selector);

  if(!rect) {
    throw Error(`Could not find element that matches selector: ${selector}.`);
  }

  return await page.screenshot({
    path: tempDir+'/'+selector+'.png',
    clip: {
      x: rect.left,
      y: rect.top,
      width: rect.width,
      height: rect.height
    }
  });
}


// Depth-first walk over `node` that records, for every node whose id is
// present in the page, its bounding box in `idDims` and its id in
// `orderedIds` (document order of the tree walk).
const getEleDimensions = (async (node, page, idDims, orderedIds) => {

  if(node.id) {
    const dimensions = await page.evaluate((node) => {
      const ele = document.getElementById(node.id) && document.getElementById(node.id).getBoundingClientRect()
      if(ele) {
        // min/max so that top <= bottom and left <= right
        return {
          top: Math.min(ele.top, ele.bottom),
          left: Math.min(ele.left, ele.right),
          right: Math.max(ele.right, ele.left),
          bottom: Math.max(ele.bottom, ele.top),
          height: ele.height,
          width: ele.width
        }
      } else {
        return null;
      }
    }, node);

    if(dimensions) {
      idDims[node.id] = dimensions;
      orderedIds.push(node.id);
    }
  }

  for(const child of (node.childs ? node.childs : [])) {
    await getEleDimensions(child, page, idDims, orderedIds);
  }

});


// Open `file` in a headless browser and measure every node of `js`.
// Resolves to { idDims, orderedIds }. The browser is always closed.
const getBrowserBoundingBoxes = (async (js, file) => {
  const browser = await puppeteer.launch();
  const idDims = {};
  const orderedIds = [];

  try {
    const page = await browser.newPage();
    await page.goto(file);
    await getEleDimensions(js, page, idDims, orderedIds);
  } finally {
    await browser.close();
  }

  return({
    idDims,
    orderedIds
  });

});
module.exports.screenshotElements = screenshotElements; 119 | module.exports.screenshotAllElements = screenshotAllElements; 120 | -------------------------------------------------------------------------------- /tf_files/retrained_graph.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanohop/sketch-to-react-native/6075a766749f6ed419ff0f34abdb81f8cc4cf0f3/tf_files/retrained_graph.pb -------------------------------------------------------------------------------- /tf_files/retrained_labels.txt: -------------------------------------------------------------------------------- 1 | activity 2 | button 3 | keyboard 4 | slider 5 | status bar 6 | switch 7 | text 8 | --------------------------------------------------------------------------------