├── .gitignore ├── LICENSE.txt ├── README.md ├── benderthon ├── __init__.py ├── caffe_freeze.py ├── cmdline.py ├── tf_freeze.py └── util.py ├── requirements-dev.txt ├── requirements.txt ├── sample.py ├── setup.cfg ├── setup.py ├── test └── testdata └── g.pb /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | build 3 | checkpoints 4 | !checkpoints/.keep 5 | data 6 | !data/.keep 7 | dist 8 | output 9 | !output/.keep 10 | weights 11 | !weights/.keep 12 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # benderthon 2 | 3 | Set of utilities to work easier with [Bender](https://github.com/xmartlabs/Bender). 4 | 5 | Currently there's support for TensorFlow and Caffe, but we are working on more stuff! 6 | 7 | Works on Python 2.7.+ and 3.+, with Tensorflow 1.2+. 8 | 9 | To install: 10 | 11 | ```bash 12 | pip install benderthon 13 | ``` 14 | 15 | TensorFlow is required too. 
The simplest way to install it is:
16 |
17 | ```bash
18 | pip install tensorflow
19 | ```
20 |
21 | There are other ways; see [Installing TensorFlow](https://www.tensorflow.org/install/). Benderthon does not install it
22 | by default so that a custom installation can be used.
23 |
24 | ## tf-freeze
25 |
26 | Utility to convert **TensorFlow** checkpoints into minimal frozen **graphs**.
27 |
28 | ### Usage
29 |
30 | #### From a checkpoint
31 |
32 | To take the checkpoint `checkpoint_path.ckpt`, whose output is yielded by the node named `Tanh`, and save it to `graph_with_weights.pb`:
33 |
34 | ```bash
35 | benderthon tf-freeze checkpoint_path.ckpt graph_with_weights.pb Tanh
36 | ```
37 |
38 | #### From code
39 |
40 | If you don't have a checkpoint or prefer to run it from code, this is the way to go. This is the same example as above, but from code:
41 |
42 | ```python
43 | from benderthon import tf_freeze
44 |
45 | # …
46 |
47 | with tf.Session() as sess:
48 |     # …
49 |
50 |     tf_freeze.freeze(sess, 'graph_with_weights.pb', ['Tanh'])
51 | ```
52 |
53 | ### Sample
54 |
55 | The file `sample.py` contains an example network for the MNIST dataset with 2 convolutional layers and 2 dense layers. If you run it, it will generate a minimal protobuf file with the weights frozen, ready to run in Bender, at `output/mnist.pb`:
56 |
57 | ```bash
58 | ./sample.py
59 | ```
60 |
61 | The generated file occupies **half** the size of the original checkpoints (26MB to 13MB).
62 |
63 | The script also generates checkpoint files with the prefix `checkpoints/mnist.ckpt`, so you could have generated the protobuf from them:
64 |
65 | ```bash
66 | benderthon tf-freeze checkpoints/mnist.ckpt output/mnist.pb Prediction
67 | ```
68 |
69 | You can also get only the graph, which occupies just **13kB**:
70 |
71 | ```bash
72 | benderthon tf-freeze --no-weights checkpoints/mnist.ckpt output/mnist_only_graph.pb Prediction
73 | ```
74 |
75 | To save the weights in a separate path for later processing:
76 |
77 | ```bash
78 | benderthon tf-freeze --only-weights checkpoints/mnist.ckpt weights/ Prediction
79 | ```
80 |
81 | ## caffe-freeze
82 |
83 | This module cannot be accessed from the command-line utility; it should be used from Python code by importing `benderthon.caffe_freeze` (a usage sketch is included after the License section below).
84 |
85 | You need the `caffeflow` package installed first:
86 |
87 | ```bash
88 | pip install -e git://github.com/xmartlabs/caffeflow.git@4618f89#egg=caffeflow
89 | ```
90 |
91 | ## Development
92 |
93 | This utility is under development and the API **is not stable**, so do not rely heavily on it.
94 |
95 | To install it locally, run `./setup.py install`, but first make sure [pandoc](http://pandoc.org/) and [pypandoc](https://github.com/bebraw/pypandoc) are installed.
96 |
97 | ## License
98 |
99 | ```
100 | Copyright 2017 Xmartlabs SRL.
101 |
102 | Licensed under the Apache License, Version 2.0 (the "License");
103 | you may not use this file except in compliance with the License.
104 | You may obtain a copy of the License at
105 |
106 | http://www.apache.org/licenses/LICENSE-2.0
107 |
108 | Unless required by applicable law or agreed to in writing, software
109 | distributed under the License is distributed on an "AS IS" BASIS,
110 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
111 | See the License for the specific language governing permissions and
112 | limitations under the License.
113 | ```
114 |
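As a reference, here is a minimal sketch of calling `benderthon.caffe_freeze` from Python. It is not taken from the repository: the file names, the `data` input name, the placeholder shape and the `prob` output node are illustrative assumptions, and the exact form of `inputs` depends on the network class that caffeflow generates (a dict mapping input layer names to tensors is assumed here). Adjust everything to your own Caffe model:

```python
import tensorflow as tf

from benderthon import caffe_freeze

# Placeholder that feeds the Caffe model's input layer (name and shape are assumptions).
image = tf.placeholder(tf.float32, shape=[1, 224, 224, 3], name='data')

caffe_freeze.freeze(caffe_def_path='deploy.prototxt',
                    caffemodel_path='model.caffemodel',
                    inputs={'data': image},
                    output_file_path='output/model.pb',
                    output_node_names='prob')
```

`caffe_freeze.save_graph_only` and `caffe_freeze.save_weights` take the same leading arguments if you only need the graph structure or the raw weights.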
--------------------------------------------------------------------------------
/benderthon/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xmartlabs/benderthon/810b6fb90f56136257e7ed12e5a30d17ad7ce6ba/benderthon/__init__.py
--------------------------------------------------------------------------------
/benderthon/caffe_freeze.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """Utility to freeze Caffe models."""
4 |
5 | from __future__ import absolute_import, division, print_function, unicode_literals
6 |
7 | import contextlib
8 | import imp
9 | import os
10 |
11 | import tensorflow as tf
12 |
13 | from benderthon import tf_freeze, util
14 |
15 |
16 | @contextlib.contextmanager
17 | def dummy_context_mgr(obj):
18 |     yield obj
19 |
20 |
21 | def caffe_to_tensorflow_session(caffe_def_path, caffemodel_path, inputs, graph_name='Graph',
22 |                                 conversion_out_dir_path=None, use_padding_same=False):
23 |     """Create a TensorFlow Session from a Caffe model."""
24 |     try:
25 |         # noinspection PyUnresolvedReferences
26 |         from caffeflow import convert
27 |     except ImportError:
28 |         raise Exception("caffeflow package needs to be installed to freeze Caffe models. Check out the README file.")
29 |
30 |     with (dummy_context_mgr(conversion_out_dir_path) if conversion_out_dir_path else util.TemporaryDirectory()) as dir_path:
31 |         params_values_output_path = os.path.join(dir_path, 'params_values.npy')
32 |         network_output_path = os.path.join(dir_path, 'network.py')
33 |
34 |         convert.convert(caffe_def_path, caffemodel_path, params_values_output_path, network_output_path, False,
35 |                         use_padding_same=use_padding_same)
36 |
37 |         network_module = imp.load_source('module.name', network_output_path)
38 |         network_class = getattr(network_module, graph_name)
39 |         network = network_class(inputs)
40 |
41 |         sess = tf.Session()
42 |
43 |         network.load(params_values_output_path, sess)
44 |
45 |         return sess
46 |
47 |
48 | def freeze(caffe_def_path, caffemodel_path, inputs, output_file_path, output_node_names, graph_name='Graph',
49 |            conversion_out_dir_path=None, checkpoint_out_path=None, use_padding_same=False):
50 |     """Freeze and shrink the graph based on a Caffe model, the input tensors and the output node names."""
51 |     with caffe_to_tensorflow_session(caffe_def_path, caffemodel_path, inputs, graph_name=graph_name,
52 |                                      conversion_out_dir_path=conversion_out_dir_path,
53 |                                      use_padding_same=use_padding_same) as sess:
54 |         saver = tf.train.Saver()
55 |
56 |         with (dummy_context_mgr(checkpoint_out_path) if checkpoint_out_path else util.TemporaryDirectory()) as temp_dir_path:
57 |             checkpoint_path = checkpoint_out_path or os.path.join(temp_dir_path, 'pose.ckpt')
58 |             saver.save(sess, checkpoint_path)
59 |
60 |             output_node_names = util.output_node_names_string_as_list(output_node_names)
61 |
62 |             tf_freeze.freeze_from_checkpoint(checkpoint_path, output_file_path, output_node_names)
63 |
64 |
65 | def save_graph_only(caffe_def_path, caffemodel_path, inputs, output_file_path, output_node_names, graph_name='Graph',
66 |                     use_padding_same=False):
67 |     """Save a small version of the graph based on a Caffe model, the input tensors and the output node names."""
68 |     with caffe_to_tensorflow_session(caffe_def_path, caffemodel_path, inputs, graph_name=graph_name,
69 |                                      use_padding_same=use_padding_same) as sess:
70 |         tf_freeze.save_graph_only(sess, output_file_path, output_node_names)
71 |
72 |
73 | def
save_weights(caffe_def_path, caffemodel_path, inputs, output_path, graph_name='Graph', conv_var_names=None, 74 | conv_transpose_var_names=None, use_padding_same=False): 75 | """Save the weights of the trainable variables, each one in a different file in output_path.""" 76 | with caffe_to_tensorflow_session(caffe_def_path, caffemodel_path, inputs, graph_name=graph_name, 77 | use_padding_same=use_padding_same) as sess: 78 | tf_freeze.save_weights(sess, output_path, conv_var_names=conv_var_names, 79 | conv_transpose_var_names=conv_transpose_var_names) 80 | -------------------------------------------------------------------------------- /benderthon/cmdline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """Set of utilities to work easier with Bender.""" 4 | 5 | import argparse 6 | 7 | from benderthon import tf_freeze 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser(description="Set of utilities for Bender.") 12 | subparsers = parser.add_subparsers() 13 | 14 | parser_freeze = subparsers.add_parser('tf-freeze', 15 | help="Utility to easily convert TensorFlow checkpoints into minimal frozen " 16 | "graphs in binary protobuf format.") 17 | parser_freeze.add_argument('input_checkpoint', help="checkpoint path to load") 18 | parser_freeze.add_argument('output_path', help="path to save the binary protobuf graph or the weights") 19 | parser_freeze.add_argument('output_node_names', help="the name of the output nodes, comma separated") 20 | parser_freeze.add_argument('--no-weights', action='store_true', 21 | help="indicate that the variables are not converted to consts") 22 | parser_freeze.add_argument('--only-weights', action='store_true', 23 | help="indicate that only the weights should be saved. Each one is saved in a different " 24 | "file in the path specified by output_path.") 25 | parser_freeze.add_argument('--conv-vars', help="the name of variables used for convolutions, comma separated. " 26 | "Used for --only-weights option.", default='') 27 | parser_freeze.add_argument('--conv-transpose-vars', help="the name of variables used for transposed convolutions, " 28 | "comma separated. 
Used for --only-weights option.",
29 |                                default='')
30 |     args = parser.parse_args()
31 |     if args.no_weights:
32 |         tf_freeze.save_graph_only_from_checkpoint(args.input_checkpoint, args.output_path,
33 |                                                   args.output_node_names.split(','))
34 |     elif args.only_weights:
35 |         tf_freeze.save_weights_from_checkpoint(args.input_checkpoint, args.output_path,
36 |                                                conv_var_names=args.conv_vars.split(','),
37 |                                                conv_transpose_var_names=args.conv_transpose_vars.split(','))
38 |     else:
39 |         tf_freeze.freeze_from_checkpoint(args.input_checkpoint, args.output_path, args.output_node_names.split(','))
40 |
41 |
42 | if __name__ == '__main__':
43 |     main()
44 |
--------------------------------------------------------------------------------
/benderthon/tf_freeze.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """Utility to freeze TensorFlow graphs."""
4 |
5 | from __future__ import absolute_import, division, print_function, unicode_literals
6 |
7 | import os
8 |
9 | import tensorflow as tf
10 | from tensorflow.python.framework import graph_io, graph_util
11 | from tensorflow.python.tools import freeze_graph
12 |
13 | from benderthon.util import check_input_checkpoint, output_node_names_string_as_list, restore_from_checkpoint,\
14 |     TemporaryDirectory
15 |
16 |
17 | def freeze_from_checkpoint(input_checkpoint, output_file_path, output_node_names):
18 |     """Freeze and shrink the graph based on a checkpoint and the output node names."""
19 |     check_input_checkpoint(input_checkpoint)
20 |
21 |     output_node_names = output_node_names_string_as_list(output_node_names)
22 |
23 |     with tf.Session() as sess:
24 |         restore_from_checkpoint(sess, input_checkpoint)
25 |         freeze_graph.freeze_graph_with_def_protos(input_graph_def=sess.graph_def, input_saver_def=None,
26 |                                                   input_checkpoint=input_checkpoint,
27 |                                                   output_node_names=','.join(output_node_names),
28 |                                                   restore_op_name='save/restore_all',
29 |                                                   filename_tensor_name='save/Const:0', output_graph=output_file_path,
30 |                                                   clear_devices=True, initializer_nodes='')
31 |
32 |
33 | def freeze(sess, output_file_path, output_node_names):
34 |     """Freeze and shrink the graph based on a session and the output node names."""
35 |     with TemporaryDirectory() as temp_dir_name:
36 |         checkpoint_path = os.path.join(temp_dir_name, 'model.ckpt')
37 |         tf.train.Saver().save(sess, checkpoint_path)
38 |
39 |         freeze_from_checkpoint(checkpoint_path, output_file_path, output_node_names)
40 |
41 |
42 | def save_graph_only(sess, output_file_path, output_node_names, as_text=False):
43 |     """Save a small version of the graph based on a session and the output node names."""
44 |     graph_def = graph_util.extract_sub_graph(sess.graph_def, output_node_names)
45 |     for node in graph_def.node:  # Clear device assignments on the extracted copy (sess.graph_def returns a new proto on each access).
46 |         node.device = ''
47 |     output_dir, output_filename = os.path.split(output_file_path)
48 |     graph_io.write_graph(graph_def, output_dir, output_filename, as_text=as_text)
49 |
50 |
51 | def save_graph_only_from_checkpoint(input_checkpoint, output_file_path, output_node_names, as_text=False):
52 |     """Save a small version of the graph based on a checkpoint and the output node names."""
53 |     check_input_checkpoint(input_checkpoint)
54 |
55 |     output_node_names = output_node_names_string_as_list(output_node_names)
56 |
57 |     with tf.Session() as sess:
58 |         restore_from_checkpoint(sess, input_checkpoint)
59 |         save_graph_only(sess, output_file_path, output_node_names, as_text=as_text)
60 |
61 |
62 | def save_weights(sess, output_path, conv_var_names=None, conv_transpose_var_names=None):
63 |     """Save the weights of the trainable variables, each one in a different file in output_path."""
64 |     if not conv_var_names:
65 |         conv_var_names = []
66 |
67 |     if not conv_transpose_var_names:
68 |         conv_transpose_var_names = []
69 |
70 |     for var in tf.trainable_variables():
71 |         filename = '{}-{}'.format(output_path, var.name.replace(':', '-').replace('/', '-'))
72 |
73 |         if var.name in conv_var_names:
74 |             var = tf.transpose(var, perm=[3, 0, 1, 2])
75 |         elif var.name in conv_transpose_var_names:
76 |             var = tf.transpose(var, perm=[3, 1, 0, 2])
77 |
78 |         value = sess.run(var)
79 |
80 |         # noinspection PyTypeChecker
81 |         with open(filename, 'wb') as file_:
82 |             value.tofile(file_)
83 |
84 |
85 | def save_weights_from_checkpoint(input_checkpoint, output_path, conv_var_names=None, conv_transpose_var_names=None):
86 |     """Save the weights of the trainable variables given a checkpoint, each one in a different file in output_path."""
87 |     check_input_checkpoint(input_checkpoint)
88 |
89 |     with tf.Session() as sess:
90 |         restore_from_checkpoint(sess, input_checkpoint)
91 |         save_weights(sess, output_path, conv_var_names=conv_var_names,
92 |                      conv_transpose_var_names=conv_transpose_var_names)
93 |
--------------------------------------------------------------------------------
/benderthon/util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Util functions used by Bender subcommands."""
3 |
4 | from __future__ import absolute_import, division, print_function, unicode_literals
5 |
6 | import shutil
7 | import tempfile
8 | import warnings
9 |
10 | from six import string_types
11 | import tensorflow as tf
12 | from tensorflow.python.training import saver as saver_lib
13 |
14 |
15 | def check_input_checkpoint(input_checkpoint):
16 |     """Check if input_checkpoint is a valid path or path prefix."""
17 |     if not saver_lib.checkpoint_exists(input_checkpoint):
18 |         print("Input checkpoint '{}' doesn't exist!".format(input_checkpoint))
19 |         exit(-1)
20 |
21 |
22 | def restore_from_checkpoint(sess, input_checkpoint):
23 |     """Return a TensorFlow saver from a checkpoint containing the metagraph."""
24 |     saver = tf.train.import_meta_graph('{}.meta'.format(input_checkpoint))
25 |     saver.restore(sess, input_checkpoint)
26 |     return saver
27 |
28 |
29 | def output_node_names_string_as_list(output_node_names):
30 |     """Return a list containing output_node_names if it's a string; otherwise return output_node_names unchanged."""
31 |     if isinstance(output_node_names, string_types):
32 |         return [output_node_names]
33 |     else:
34 |         return output_node_names
35 |
36 |
37 | class TemporaryDirectory(object):
38 |     """Create and return a temporary directory. This has the same
39 |     behavior as mkdtemp but can be used as a context manager. For
40 |     example:
41 |
42 |         with TemporaryDirectory() as tmpdir:
43 |             ...
44 |
45 |     Upon exiting the context, the directory and everything contained
46 |     in it are removed.
47 | 48 | Inspired from https://hg.python.org/cpython/file/3.6/Lib/tempfile.py 49 | """ 50 | 51 | def __init__(self, suffix='', prefix='', dir_=None): 52 | self.name = tempfile.mkdtemp(suffix, prefix, dir_) 53 | 54 | @classmethod 55 | def _cleanup(cls, name, warn_message): 56 | shutil.rmtree(name) 57 | warnings.warn(warn_message, ResourceWarning) 58 | 59 | def __repr__(self): 60 | return "<{} {!r}>".format(self.__class__.__name__, self.name) 61 | 62 | def __enter__(self): 63 | return self.name 64 | 65 | def __exit__(self, exc, value, tb): 66 | shutil.rmtree(self.name) 67 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pypandoc 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow>=1.2.0 2 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import division, print_function, unicode_literals 4 | 5 | from benderthon import tf_freeze 6 | import tensorflow as tf 7 | from tensorflow.examples.tutorials.mnist import input_data 8 | 9 | SEED = 24 10 | 11 | 12 | def simple_network(x): 13 | W_fc1 = tf.Variable(tf.truncated_normal([784, 1000], stddev=0.01, seed=SEED)) 14 | b_fc1 = tf.Variable(tf.zeros([1000])) 15 | h_fc1 = tf.nn.relu(tf.matmul(x, W_fc1) + b_fc1) 16 | 17 | keep_prob1 = tf.placeholder(tf.float32) 18 | h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob=keep_prob1, seed=SEED) 19 | 20 | W_fc2 = tf.Variable(tf.truncated_normal([1000, 1000], stddev=0.01, seed=SEED)) 21 | b_fc2 = tf.Variable(tf.zeros([1000])) 22 | h_fc2 = tf.nn.relu(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 23 | 24 | keep_prob2 = tf.placeholder(tf.float32) 25 | h_fc2_drop = tf.nn.dropout(h_fc2, keep_prob=keep_prob2, seed=SEED) 26 | 27 | W_fc3 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.01, seed=SEED)) 28 | b_fc3 = tf.Variable(tf.zeros([10])) 29 | y = tf.add(tf.matmul(h_fc2_drop, W_fc3), b_fc3, name="Prediction") 30 | 31 | return y, keep_prob1, keep_prob2 32 | 33 | 34 | def deep_network(x): 35 | x_image = tf.reshape(x, [-1, 28, 28, 1]) 36 | 37 | W_conv1 = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1)) 38 | b_conv1 = tf.Variable(tf.constant(0.1, shape=[32])) 39 | h_conv1 = tf.nn.relu(tf.nn.conv2d(x_image, W_conv1, strides=[1, 1, 1, 1], padding='SAME') + b_conv1) 40 | 41 | h_pool1 = tf.nn.max_pool(h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 42 | 43 | W_conv2 = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1)) 44 | b_conv2 = tf.Variable(tf.constant(0.1, shape=[64])) 45 | h_conv2 = tf.nn.relu(tf.nn.conv2d(h_pool1, W_conv2, strides=[1, 1, 1, 1], padding='SAME') + b_conv2) 46 | 47 | h_pool2 = tf.nn.max_pool(h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 48 | 49 | W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1)) 50 | b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024])) 51 | 52 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) 53 | h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 54 | 55 | keep_prob = tf.placeholder(tf.float32) 56 | h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 57 | 58 | W_fc2 = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1)) 59 | 
b_fc2 = tf.Variable(tf.constant(0.1, shape=[10])) 60 | 61 | y = tf.add(tf.matmul(h_fc1_drop, W_fc2), b_fc2, name="Prediction") 62 | 63 | return y, keep_prob 64 | 65 | 66 | def main(): 67 | mnist = input_data.read_data_sets('data', one_hot=True, seed=SEED) 68 | 69 | x = tf.placeholder(tf.float32, [None, 784]) 70 | y_ = tf.placeholder(tf.float32, [None, 10]) 71 | 72 | y, keep_prob = deep_network(x) 73 | 74 | cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) 75 | 76 | train_step = tf.train.AdagradOptimizer(learning_rate=0.1).minimize(cross_entropy) 77 | correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 78 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 79 | 80 | saver = tf.train.Saver() 81 | 82 | with tf.Session() as sess: 83 | sess.run(tf.global_variables_initializer()) 84 | 85 | iterations = 1000 86 | for i in range(iterations): 87 | if i % 100 == 0: 88 | print("{}/{}".format(i, iterations)) 89 | batch = mnist.train.next_batch(128) 90 | sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 91 | 92 | acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}) 93 | print("Test accuracy: {}".format(acc)) 94 | 95 | saver.save(sess, 'checkpoints/mnist.ckpt') # Just in case 96 | 97 | tf_freeze.freeze(sess, 'output/mnist.pb', ['Prediction']) 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [metadata] 5 | license_file = LICENSE.txt 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from setuptools import find_packages, setup 4 | 5 | import pypandoc 6 | 7 | setup( 8 | name="benderthon", 9 | version="0.3.0", 10 | description="Set of utilities to work easier with Bender.", 11 | long_description=pypandoc.convert('README.md', 'rst'), 12 | url="https://github.com/xmartlabs/benderthon", 13 | keywords=["Bender", "machine learning", "artificial intelligence", "freeze", "model", "utility", "utilities", 14 | "TensorFlow"], 15 | classifiers=[ 16 | 'Development Status :: 3 - Alpha', 17 | 'Environment :: Console', 18 | 'Intended Audience :: Developers', 19 | 'Intended Audience :: Science/Research', 20 | 'License :: OSI Approved :: Apache Software License', 21 | 'Operating System :: OS Independent', 22 | 'Programming Language :: Python', 23 | 'Programming Language :: Python :: 2', 24 | 'Programming Language :: Python :: 2.7', 25 | 'Programming Language :: Python :: 3', 26 | 'Programming Language :: Python :: 3.3', 27 | 'Programming Language :: Python :: 3.4', 28 | 'Programming Language :: Python :: 3.5', 29 | 'Programming Language :: Python :: 3.6', 30 | 'Topic :: Scientific/Engineering', 31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 32 | ], 33 | packages=find_packages(), 34 | entry_points={ 35 | 'console_scripts': [ 36 | 'benderthon = benderthon.cmdline:main', 37 | ], 38 | }, 39 | author="Xmartlabs", 40 | author_email="hi@xmartlabs.com", 41 | maintainer="Santiago Castro", 42 | maintainer_email="santiago@xmartlabs.com", 43 | license="Apache 2.0", 44 | ) 45 | -------------------------------------------------------------------------------- /test: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | ./sample.py 6 | benderthon/cmdline.py tf-freeze --no-weights checkpoints/mnist.ckpt output/g2.pb Prediction 7 | diff testdata/g.pb output/g2.pb 8 | -------------------------------------------------------------------------------- /testdata/g.pb: -------------------------------------------------------------------------------- 1 | node { 2 | name: "Placeholder" 3 | op: "Placeholder" 4 | attr { 5 | key: "dtype" 6 | value { 7 | type: DT_FLOAT 8 | } 9 | } 10 | attr { 11 | key: "shape" 12 | value { 13 | shape { 14 | dim { 15 | size: -1 16 | } 17 | dim { 18 | size: 784 19 | } 20 | } 21 | } 22 | } 23 | } 24 | node { 25 | name: "Reshape/shape" 26 | op: "Const" 27 | attr { 28 | key: "dtype" 29 | value { 30 | type: DT_INT32 31 | } 32 | } 33 | attr { 34 | key: "value" 35 | value { 36 | tensor { 37 | dtype: DT_INT32 38 | tensor_shape { 39 | dim { 40 | size: 4 41 | } 42 | } 43 | tensor_content: "\377\377\377\377\034\000\000\000\034\000\000\000\001\000\000\000" 44 | } 45 | } 46 | } 47 | } 48 | node { 49 | name: "Reshape" 50 | op: "Reshape" 51 | input: "Placeholder" 52 | input: "Reshape/shape" 53 | attr { 54 | key: "T" 55 | value { 56 | type: DT_FLOAT 57 | } 58 | } 59 | attr { 60 | key: "Tshape" 61 | value { 62 | type: DT_INT32 63 | } 64 | } 65 | } 66 | node { 67 | name: "Variable" 68 | op: "VariableV2" 69 | attr { 70 | key: "container" 71 | value { 72 | s: "" 73 | } 74 | } 75 | attr { 76 | key: "dtype" 77 | value { 78 | type: DT_FLOAT 79 | } 80 | } 81 | attr { 82 | key: "shape" 83 | value { 84 | shape { 85 | dim { 86 | size: 5 87 | } 88 | dim { 89 | size: 5 90 | } 91 | dim { 92 | size: 1 93 | } 94 | dim { 95 | size: 32 96 | } 97 | } 98 | } 99 | } 100 | attr { 101 | key: "shared_name" 102 | value { 103 | s: "" 104 | } 105 | } 106 | } 107 | node { 108 | name: "Variable/read" 109 | op: "Identity" 110 | input: "Variable" 111 | attr { 112 | key: "T" 113 | value { 114 | type: DT_FLOAT 115 | } 116 | } 117 | attr { 118 | key: "_class" 119 | value { 120 | list { 121 | s: "loc:@Variable" 122 | } 123 | } 124 | } 125 | } 126 | node { 127 | name: "Const" 128 | op: "Const" 129 | attr { 130 | key: "dtype" 131 | value { 132 | type: DT_FLOAT 133 | } 134 | } 135 | attr { 136 | key: "value" 137 | value { 138 | tensor { 139 | dtype: DT_FLOAT 140 | tensor_shape { 141 | dim { 142 | size: 32 143 | } 144 | } 145 | float_val: 0.10000000149 146 | } 147 | } 148 | } 149 | } 150 | node { 151 | name: "Conv2D" 152 | op: "Conv2D" 153 | input: "Reshape" 154 | input: "Variable/read" 155 | attr { 156 | key: "T" 157 | value { 158 | type: DT_FLOAT 159 | } 160 | } 161 | attr { 162 | key: "data_format" 163 | value { 164 | s: "NHWC" 165 | } 166 | } 167 | attr { 168 | key: "padding" 169 | value { 170 | s: "SAME" 171 | } 172 | } 173 | attr { 174 | key: "strides" 175 | value { 176 | list { 177 | i: 1 178 | i: 1 179 | i: 1 180 | i: 1 181 | } 182 | } 183 | } 184 | attr { 185 | key: "use_cudnn_on_gpu" 186 | value { 187 | b: true 188 | } 189 | } 190 | } 191 | node { 192 | name: "add" 193 | op: "Add" 194 | input: "Conv2D" 195 | input: "Const" 196 | attr { 197 | key: "T" 198 | value { 199 | type: DT_FLOAT 200 | } 201 | } 202 | } 203 | node { 204 | name: "Relu" 205 | op: "Relu" 206 | input: "add" 207 | attr { 208 | key: "T" 209 | value { 210 | type: DT_FLOAT 211 | } 212 | } 213 | } 214 | node { 215 | name: "MaxPool" 216 | op: "MaxPool" 217 | input: "Relu" 218 | attr { 219 | key: "T" 220 | value { 221 | type: DT_FLOAT 222 | } 223 
| } 224 | attr { 225 | key: "data_format" 226 | value { 227 | s: "NHWC" 228 | } 229 | } 230 | attr { 231 | key: "ksize" 232 | value { 233 | list { 234 | i: 1 235 | i: 2 236 | i: 2 237 | i: 1 238 | } 239 | } 240 | } 241 | attr { 242 | key: "padding" 243 | value { 244 | s: "SAME" 245 | } 246 | } 247 | attr { 248 | key: "strides" 249 | value { 250 | list { 251 | i: 1 252 | i: 2 253 | i: 2 254 | i: 1 255 | } 256 | } 257 | } 258 | } 259 | node { 260 | name: "Variable_1" 261 | op: "VariableV2" 262 | attr { 263 | key: "container" 264 | value { 265 | s: "" 266 | } 267 | } 268 | attr { 269 | key: "dtype" 270 | value { 271 | type: DT_FLOAT 272 | } 273 | } 274 | attr { 275 | key: "shape" 276 | value { 277 | shape { 278 | dim { 279 | size: 5 280 | } 281 | dim { 282 | size: 5 283 | } 284 | dim { 285 | size: 32 286 | } 287 | dim { 288 | size: 64 289 | } 290 | } 291 | } 292 | } 293 | attr { 294 | key: "shared_name" 295 | value { 296 | s: "" 297 | } 298 | } 299 | } 300 | node { 301 | name: "Variable_1/read" 302 | op: "Identity" 303 | input: "Variable_1" 304 | attr { 305 | key: "T" 306 | value { 307 | type: DT_FLOAT 308 | } 309 | } 310 | attr { 311 | key: "_class" 312 | value { 313 | list { 314 | s: "loc:@Variable_1" 315 | } 316 | } 317 | } 318 | } 319 | node { 320 | name: "Const_1" 321 | op: "Const" 322 | attr { 323 | key: "dtype" 324 | value { 325 | type: DT_FLOAT 326 | } 327 | } 328 | attr { 329 | key: "value" 330 | value { 331 | tensor { 332 | dtype: DT_FLOAT 333 | tensor_shape { 334 | dim { 335 | size: 64 336 | } 337 | } 338 | float_val: 0.10000000149 339 | } 340 | } 341 | } 342 | } 343 | node { 344 | name: "Conv2D_1" 345 | op: "Conv2D" 346 | input: "MaxPool" 347 | input: "Variable_1/read" 348 | attr { 349 | key: "T" 350 | value { 351 | type: DT_FLOAT 352 | } 353 | } 354 | attr { 355 | key: "data_format" 356 | value { 357 | s: "NHWC" 358 | } 359 | } 360 | attr { 361 | key: "padding" 362 | value { 363 | s: "SAME" 364 | } 365 | } 366 | attr { 367 | key: "strides" 368 | value { 369 | list { 370 | i: 1 371 | i: 1 372 | i: 1 373 | i: 1 374 | } 375 | } 376 | } 377 | attr { 378 | key: "use_cudnn_on_gpu" 379 | value { 380 | b: true 381 | } 382 | } 383 | } 384 | node { 385 | name: "add_1" 386 | op: "Add" 387 | input: "Conv2D_1" 388 | input: "Const_1" 389 | attr { 390 | key: "T" 391 | value { 392 | type: DT_FLOAT 393 | } 394 | } 395 | } 396 | node { 397 | name: "Relu_1" 398 | op: "Relu" 399 | input: "add_1" 400 | attr { 401 | key: "T" 402 | value { 403 | type: DT_FLOAT 404 | } 405 | } 406 | } 407 | node { 408 | name: "MaxPool_1" 409 | op: "MaxPool" 410 | input: "Relu_1" 411 | attr { 412 | key: "T" 413 | value { 414 | type: DT_FLOAT 415 | } 416 | } 417 | attr { 418 | key: "data_format" 419 | value { 420 | s: "NHWC" 421 | } 422 | } 423 | attr { 424 | key: "ksize" 425 | value { 426 | list { 427 | i: 1 428 | i: 2 429 | i: 2 430 | i: 1 431 | } 432 | } 433 | } 434 | attr { 435 | key: "padding" 436 | value { 437 | s: "SAME" 438 | } 439 | } 440 | attr { 441 | key: "strides" 442 | value { 443 | list { 444 | i: 1 445 | i: 2 446 | i: 2 447 | i: 1 448 | } 449 | } 450 | } 451 | } 452 | node { 453 | name: "Variable_2" 454 | op: "VariableV2" 455 | attr { 456 | key: "container" 457 | value { 458 | s: "" 459 | } 460 | } 461 | attr { 462 | key: "dtype" 463 | value { 464 | type: DT_FLOAT 465 | } 466 | } 467 | attr { 468 | key: "shape" 469 | value { 470 | shape { 471 | dim { 472 | size: 3136 473 | } 474 | dim { 475 | size: 1024 476 | } 477 | } 478 | } 479 | } 480 | attr { 481 | key: "shared_name" 482 | value { 483 | s: "" 484 | } 
485 | } 486 | } 487 | node { 488 | name: "Variable_2/read" 489 | op: "Identity" 490 | input: "Variable_2" 491 | attr { 492 | key: "T" 493 | value { 494 | type: DT_FLOAT 495 | } 496 | } 497 | attr { 498 | key: "_class" 499 | value { 500 | list { 501 | s: "loc:@Variable_2" 502 | } 503 | } 504 | } 505 | } 506 | node { 507 | name: "Const_2" 508 | op: "Const" 509 | attr { 510 | key: "dtype" 511 | value { 512 | type: DT_FLOAT 513 | } 514 | } 515 | attr { 516 | key: "value" 517 | value { 518 | tensor { 519 | dtype: DT_FLOAT 520 | tensor_shape { 521 | dim { 522 | size: 1024 523 | } 524 | } 525 | float_val: 0.10000000149 526 | } 527 | } 528 | } 529 | } 530 | node { 531 | name: "Reshape_1/shape" 532 | op: "Const" 533 | attr { 534 | key: "dtype" 535 | value { 536 | type: DT_INT32 537 | } 538 | } 539 | attr { 540 | key: "value" 541 | value { 542 | tensor { 543 | dtype: DT_INT32 544 | tensor_shape { 545 | dim { 546 | size: 2 547 | } 548 | } 549 | tensor_content: "\377\377\377\377@\014\000\000" 550 | } 551 | } 552 | } 553 | } 554 | node { 555 | name: "Reshape_1" 556 | op: "Reshape" 557 | input: "MaxPool_1" 558 | input: "Reshape_1/shape" 559 | attr { 560 | key: "T" 561 | value { 562 | type: DT_FLOAT 563 | } 564 | } 565 | attr { 566 | key: "Tshape" 567 | value { 568 | type: DT_INT32 569 | } 570 | } 571 | } 572 | node { 573 | name: "MatMul" 574 | op: "MatMul" 575 | input: "Reshape_1" 576 | input: "Variable_2/read" 577 | attr { 578 | key: "T" 579 | value { 580 | type: DT_FLOAT 581 | } 582 | } 583 | attr { 584 | key: "transpose_a" 585 | value { 586 | b: false 587 | } 588 | } 589 | attr { 590 | key: "transpose_b" 591 | value { 592 | b: false 593 | } 594 | } 595 | } 596 | node { 597 | name: "add_2" 598 | op: "Add" 599 | input: "MatMul" 600 | input: "Const_2" 601 | attr { 602 | key: "T" 603 | value { 604 | type: DT_FLOAT 605 | } 606 | } 607 | } 608 | node { 609 | name: "Relu_2" 610 | op: "Relu" 611 | input: "add_2" 612 | attr { 613 | key: "T" 614 | value { 615 | type: DT_FLOAT 616 | } 617 | } 618 | } 619 | node { 620 | name: "Placeholder_2" 621 | op: "Placeholder" 622 | attr { 623 | key: "dtype" 624 | value { 625 | type: DT_FLOAT 626 | } 627 | } 628 | attr { 629 | key: "shape" 630 | value { 631 | shape { 632 | unknown_rank: true 633 | } 634 | } 635 | } 636 | } 637 | node { 638 | name: "dropout/Shape" 639 | op: "Shape" 640 | input: "Relu_2" 641 | attr { 642 | key: "T" 643 | value { 644 | type: DT_FLOAT 645 | } 646 | } 647 | attr { 648 | key: "out_type" 649 | value { 650 | type: DT_INT32 651 | } 652 | } 653 | } 654 | node { 655 | name: "dropout/random_uniform/min" 656 | op: "Const" 657 | attr { 658 | key: "dtype" 659 | value { 660 | type: DT_FLOAT 661 | } 662 | } 663 | attr { 664 | key: "value" 665 | value { 666 | tensor { 667 | dtype: DT_FLOAT 668 | tensor_shape { 669 | } 670 | float_val: 0.0 671 | } 672 | } 673 | } 674 | } 675 | node { 676 | name: "dropout/random_uniform/max" 677 | op: "Const" 678 | attr { 679 | key: "dtype" 680 | value { 681 | type: DT_FLOAT 682 | } 683 | } 684 | attr { 685 | key: "value" 686 | value { 687 | tensor { 688 | dtype: DT_FLOAT 689 | tensor_shape { 690 | } 691 | float_val: 1.0 692 | } 693 | } 694 | } 695 | } 696 | node { 697 | name: "dropout/random_uniform/RandomUniform" 698 | op: "RandomUniform" 699 | input: "dropout/Shape" 700 | attr { 701 | key: "T" 702 | value { 703 | type: DT_INT32 704 | } 705 | } 706 | attr { 707 | key: "dtype" 708 | value { 709 | type: DT_FLOAT 710 | } 711 | } 712 | attr { 713 | key: "seed" 714 | value { 715 | i: 0 716 | } 717 | } 718 | attr { 719 | key: 
"seed2" 720 | value { 721 | i: 0 722 | } 723 | } 724 | } 725 | node { 726 | name: "dropout/random_uniform/sub" 727 | op: "Sub" 728 | input: "dropout/random_uniform/max" 729 | input: "dropout/random_uniform/min" 730 | attr { 731 | key: "T" 732 | value { 733 | type: DT_FLOAT 734 | } 735 | } 736 | } 737 | node { 738 | name: "dropout/random_uniform/mul" 739 | op: "Mul" 740 | input: "dropout/random_uniform/RandomUniform" 741 | input: "dropout/random_uniform/sub" 742 | attr { 743 | key: "T" 744 | value { 745 | type: DT_FLOAT 746 | } 747 | } 748 | } 749 | node { 750 | name: "dropout/random_uniform" 751 | op: "Add" 752 | input: "dropout/random_uniform/mul" 753 | input: "dropout/random_uniform/min" 754 | attr { 755 | key: "T" 756 | value { 757 | type: DT_FLOAT 758 | } 759 | } 760 | } 761 | node { 762 | name: "dropout/add" 763 | op: "Add" 764 | input: "Placeholder_2" 765 | input: "dropout/random_uniform" 766 | attr { 767 | key: "T" 768 | value { 769 | type: DT_FLOAT 770 | } 771 | } 772 | } 773 | node { 774 | name: "dropout/Floor" 775 | op: "Floor" 776 | input: "dropout/add" 777 | attr { 778 | key: "T" 779 | value { 780 | type: DT_FLOAT 781 | } 782 | } 783 | } 784 | node { 785 | name: "dropout/div" 786 | op: "RealDiv" 787 | input: "Relu_2" 788 | input: "Placeholder_2" 789 | attr { 790 | key: "T" 791 | value { 792 | type: DT_FLOAT 793 | } 794 | } 795 | } 796 | node { 797 | name: "dropout/mul" 798 | op: "Mul" 799 | input: "dropout/div" 800 | input: "dropout/Floor" 801 | attr { 802 | key: "T" 803 | value { 804 | type: DT_FLOAT 805 | } 806 | } 807 | } 808 | node { 809 | name: "Variable_3" 810 | op: "VariableV2" 811 | attr { 812 | key: "container" 813 | value { 814 | s: "" 815 | } 816 | } 817 | attr { 818 | key: "dtype" 819 | value { 820 | type: DT_FLOAT 821 | } 822 | } 823 | attr { 824 | key: "shape" 825 | value { 826 | shape { 827 | dim { 828 | size: 1024 829 | } 830 | dim { 831 | size: 10 832 | } 833 | } 834 | } 835 | } 836 | attr { 837 | key: "shared_name" 838 | value { 839 | s: "" 840 | } 841 | } 842 | } 843 | node { 844 | name: "Variable_3/read" 845 | op: "Identity" 846 | input: "Variable_3" 847 | attr { 848 | key: "T" 849 | value { 850 | type: DT_FLOAT 851 | } 852 | } 853 | attr { 854 | key: "_class" 855 | value { 856 | list { 857 | s: "loc:@Variable_3" 858 | } 859 | } 860 | } 861 | } 862 | node { 863 | name: "Const_3" 864 | op: "Const" 865 | attr { 866 | key: "dtype" 867 | value { 868 | type: DT_FLOAT 869 | } 870 | } 871 | attr { 872 | key: "value" 873 | value { 874 | tensor { 875 | dtype: DT_FLOAT 876 | tensor_shape { 877 | dim { 878 | size: 10 879 | } 880 | } 881 | float_val: 0.10000000149 882 | } 883 | } 884 | } 885 | } 886 | node { 887 | name: "MatMul_1" 888 | op: "MatMul" 889 | input: "dropout/mul" 890 | input: "Variable_3/read" 891 | attr { 892 | key: "T" 893 | value { 894 | type: DT_FLOAT 895 | } 896 | } 897 | attr { 898 | key: "transpose_a" 899 | value { 900 | b: false 901 | } 902 | } 903 | attr { 904 | key: "transpose_b" 905 | value { 906 | b: false 907 | } 908 | } 909 | } 910 | node { 911 | name: "Prediction" 912 | op: "Add" 913 | input: "MatMul_1" 914 | input: "Const_3" 915 | attr { 916 | key: "T" 917 | value { 918 | type: DT_FLOAT 919 | } 920 | } 921 | } 922 | library { 923 | } 924 | versions { 925 | producer: 22 926 | } 927 | --------------------------------------------------------------------------------