├── .gitignore ├── LICENSE ├── README.md └── convertkeras.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Alan Steremberg 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the 
"Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # keras_to_tensorflow 2 | Convert Keras models to TensorFlow frozen graphs for use on mobile phones and other devices. 3 | 4 | The last parameter of the script is the path to TensorFlow's `freeze_graph` tool. Build it with: 5 | ``` 6 | bazel build tensorflow/python/tools:freeze_graph 7 | ``` 8 | The built binary is usually found under your TensorFlow source directory at: 9 | ``` 10 | tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph 11 | ``` 12 | 13 | The script still needs to be fixed so that it does not put `./` in front of the file paths it passes to `freeze_graph` (for example, an absolute output path `/tmp/model` becomes `.//tmp/model`). I had some trouble getting this to work; feel free to fix it and submit a pull request.
#
# Copyright 2017 Alan Steremberg and Arthur Conner
#
"""Convert a saved Keras model into a frozen TensorFlow graph.

Command-line usage::

    python convertkeras.py MODEL FROZEN FREEZEGRAPH

MODEL is loaded with ``keras.models.load_model`` and rebuilt in inference
mode. Its graph and a checkpoint are written next to the FROZEN path prefix,
and the FREEZEGRAPH tool is then run to produce ``FROZEN.pb``.
"""

import argparse
import os
import shlex
import subprocess

from tensorflow.python.keras import backend as K
from keras.models import load_model
from keras.models import model_from_config  # noqa: F401  (currently unused)
from keras.models import Sequential, Model
import tensorflow as tf

# Disable eager execution: the Session / Saver / freeze_graph workflow below
# is TF1-style graph code.
tf.compat.v1.disable_eager_execution()


def _node_names(tensors):
    """Return the graph node name of each tensor (its name without ``:<index>``)."""
    return [tensor.name.split(':')[0] for tensor in tensors]


def _freeze_graph_command(freeze_graph_binary, graph_file, checkpoint,
                          output_node_names, output_graph):
    """Build the argv list that runs TensorFlow's ``freeze_graph`` tool.

    *freeze_graph_binary* is split with shell-like rules, so it may carry
    leading arguments (e.g. an interpreter followed by a script path). A
    leading ``~`` in its first word is expanded. The generated file paths are
    passed as separate argv entries, so they are never re-parsed by a shell.

    Raises:
        ValueError: if *freeze_graph_binary* contains no command.
    """
    words = shlex.split(freeze_graph_binary)
    if not words:
        raise ValueError("freeze_graph_binary must not be empty")
    words[0] = os.path.expanduser(words[0])
    return words + [
        "--input_graph=" + graph_file,
        "--input_checkpoint=" + checkpoint,
        "--output_node_names=" + ",".join(output_node_names),
        "--output_graph=" + output_graph,
    ]


def convert(prevmodel, export_path, freeze_graph_binary):
    """Freeze the Keras model stored at *prevmodel* into ``<export_path>.pb``.

    Intermediate files are written alongside the result:
    ``<export_path>_graph.pb`` (the GraphDef written by ``tf.io.write_graph``)
    and the ``<export_path>.ckpt`` checkpoint. Both are fed to ``freeze_graph``.

    Args:
        prevmodel: Path to a model file readable by ``keras.models.load_model``.
        export_path: Output path prefix. It may be absolute or relative to the
            current working directory.
        freeze_graph_binary: Command that runs TensorFlow's ``freeze_graph``
            tool, e.g. ``tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph``.

    Raises:
        ValueError: if *freeze_graph_binary* is empty.
        subprocess.CalledProcessError: if ``freeze_graph`` exits with a
            non-zero status.
    """
    # Resolve every path once, so relative and absolute prefixes both work.
    graph_file = os.path.abspath(export_path + "_graph.pb")
    ckpt_file = os.path.abspath(export_path + ".ckpt")
    frozen_file = os.path.abspath(export_path + ".pb")

    # Open a TensorFlow session and tell Keras to use it.
    sess = tf.compat.v1.Session()
    K.set_session(sess)
    try:
        # See https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html
        # Every operation created from here on is built in test (inference) mode.
        K.set_learning_phase(0)

        previous_model = load_model(prevmodel)
        previous_model.summary()

        # Serialize the model and its weights, then rebuild it so the learning
        # phase is hard-coded to 0 in the new graph.
        config = previous_model.get_config()
        weights = previous_model.get_weights()
        # Pick the constructor that matches the loaded model. Trying
        # Sequential first under a bare `except` would also hide genuine
        # deserialization errors.
        if isinstance(previous_model, Sequential):
            model = Sequential.from_config(config)
        else:
            model = Model.from_config(config)
        model.set_weights(weights)

        # inputs/outputs are lists, so multi-input/-output models work too.
        print("Input name:")
        for tensor in model.inputs:
            print(tensor.name)
        print("Output name:")
        for tensor in model.outputs:
            print(tensor.name)
        output_names = _node_names(model.outputs)

        saver = tf.compat.v1.train.Saver(sharded=True)
        tf.io.write_graph(sess.graph_def,
                          os.path.dirname(graph_file),
                          os.path.basename(graph_file))
        # Use the prefix the Saver reports rather than re-deriving it.
        checkpoint = saver.save(sess, ckpt_file)
    finally:
        sess.close()

    command = _freeze_graph_command(freeze_graph_binary, graph_file,
                                    checkpoint, output_names, frozen_file)
    print(" ".join(shlex.quote(word) for word in command))
    # check=True surfaces freeze_graph failures instead of ignoring them.
    subprocess.run(command, check=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Keras Tensorflow Converter')
    parser.add_argument('model', type=str, help='Path to the keras model')
    parser.add_argument('frozen', type=str, help='Path to the frozen output')
    parser.add_argument('freezegraph', type=str,
                        help='Path to the freeze_graph binary')
    args = parser.parse_args()

    convert(args.model, args.frozen, args.freezegraph)