├── .gitignore
├── README.md
└── keras_to_tensorflow.py

/.gitignore:
--------------------------------------------------------------------------------
keras_model/*
tensorflow_model/*

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# keras_to_tensorflow
A script that converts a model file (.h5) saved by Keras into the TensorFlow model format (.pb).

Based on the .ipynb notebook published by [amir-abdi](https://github.com/amir-abdi). The notebook was converted to a .py script and then improved:

* The Keras model file is selected via a command-line argument
* The output file name is derived from the input file name


## Usage
**Prepare an environment in which Keras and TensorFlow run**

The script uses TensorFlow 1.x graph-mode APIs (`K.get_session()`, `graph_util.convert_variables_to_constants`). It is therefore not expected to work as-is with TensorFlow 2.x; use TensorFlow 1.x with a compatible standalone Keras.

```bash
pip install keras
pip install tensorflow
```

**Arrange the files in the following directory structure**

```
.
├── README.md
├── keras_model
│   └── keras_model.h5
├── keras_to_tensorflow.py
└── tensorflow_model (created automatically if it does not exist)
```

**Run**
```bash
# sample
python keras_to_tensorflow.py keras_model/keras_model.h5
```

This writes the frozen graph to `tensorflow_model/keras_model.pb`. Because `write_graph_def_ascii_flag` is `True` by default, it also writes a text version of the graph definition to `tensorflow_model/keras_model.pb.ascii`.


## Original README
### keras_to_tensorflow
General code to convert a trained Keras model into an inference TensorFlow model.

The notebook `keras_to_tensorflow` is sample code that loads a trained Keras model and freezes the nodes (converts all TensorFlow variables to TensorFlow constants). It then saves the inference graph and weights into a protobuf file (.pb), which can be used to deploy the trained model for inference. During freezing, nodes of the network that do not contribute to the tensor holding the output predictions are pruned. This results in a smaller, optimized network.

Code showing how to freeze and save Keras models with earlier versions of TensorFlow is also available. Back then, the freeze_graph tool (`/tensorflow/python/tools/freeze_graph.py`) was used to convert the variables into constants.
This functionality is now handled by `graph_util.convert_variables_to_constants`.

--------------------------------------------------------------------------------
/keras_to_tensorflow.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8

"""
General code to convert a trained keras model into an inference tensorflow model.
"""

__author__ = "Haruyuki Ichino"
__version__ = "1.0"
__date__ = "2017/08/20"

print(__doc__)


import sys

from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io


# Set parameters
if len(sys.argv) != 2:
    print("Usage: python keras_to_tensorflow.py [keras model file path]")
    sys.exit(1)
keras_model_path = sys.argv[1]
write_graph_def_ascii_flag = True
prefix_output_node_names_of_final_network = 'output_node'
# Derive the output name from the input file, e.g. keras_model/keras_model.h5 -> keras_model.pb
# (splitext strips only the extension, so dots inside the base name are kept)
keras_model_name = osp.splitext(osp.basename(keras_model_path))[0]
tensorflow_graph_name = keras_model_name + '.pb'

# Prepare the output directory
output_dir = './tensorflow_model/'
os.makedirs(output_dir, exist_ok=True)


# Load keras model and rename output
# Learning phase 0 = inference mode (e.g. dropout disabled); set it before loading the model
K.set_learning_phase(0)
keras_model = load_model(keras_model_path)

# Wrap each model output in a tf.identity node with a predictable name
# (output_node0, output_node1, ...) so it can be fetched from the frozen graph.
# keras_model.outputs is always a list of output tensors. For a single-output
# model, keras_model.output is a tensor, so output[i] would select sample i
# of the batch instead of the i-th output.
num_output = len(keras_model.outputs)
pred = [None]*num_output
pred_node_names = [None]*num_output
for i in range(num_output):
    pred_node_names[i] = prefix_output_node_names_of_final_network+str(i)
    pred[i] = tf.identity(keras_model.outputs[i], name=pred_node_names[i])
print('Output nodes names: ', pred_node_names)


# [optional] write graph definition in ascii
sess = K.get_session()
if write_graph_def_ascii_flag:
    f = tensorflow_graph_name + '.ascii'
    tf.train.write_graph(sess.graph.as_graph_def(), output_dir, f, as_text=True)
    print('Saved the graph definition: ', osp.join(output_dir, f))


# convert variables to constants and save
# (only the subgraph needed to compute pred_node_names is kept; other nodes are pruned)
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)
graph_io.write_graph(constant_graph, output_dir, tensorflow_graph_name, as_text=False)

print('Saved the TensorFlow graph: ', osp.join(output_dir, tensorflow_graph_name))

--------------------------------------------------------------------------------
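The script applies a few naming conventions: the `.pb` name comes from the input file name, the optional graph definition is written next to it as `.pb.ascii`, and the outputs are named `output_node0`, `output_node1`, and so on. These can be summarized in a small helper. This is a hypothetical sketch; `converted_names` is not part of the repository. It assumes the default `./tensorflow_model/` output directory and `output_node` prefix, and it uses `os.path.splitext`, so dots inside the base name are kept. It can be handy for knowing which tensor (e.g. `output_node0:0`) to fetch after loading the frozen graph.

```python
import os.path as osp
from typing import List, Tuple


def converted_names(
    keras_model_path: str,
    num_outputs: int = 1,
    output_dir: str = "./tensorflow_model/",
    prefix: str = "output_node",
) -> Tuple[str, str, List[str]]:
    """Return (pb_path, ascii_path, output_node_names) for a Keras .h5 file.

    Hypothetical helper that mirrors the naming used by keras_to_tensorflow.py.
    """
    # keras_model/keras_model.h5 -> keras_model (only the extension is stripped)
    model_name = osp.splitext(osp.basename(keras_model_path))[0]
    graph_name = model_name + ".pb"
    pb_path = osp.join(output_dir, graph_name)
    # The optional text graph definition is written as <name>.pb.ascii
    ascii_path = osp.join(output_dir, graph_name + ".ascii")
    # One identity node per model output: output_node0, output_node1, ...
    node_names = [prefix + str(i) for i in range(num_outputs)]
    return pb_path, ascii_path, node_names


if __name__ == "__main__":
    pb, txt, nodes = converted_names("keras_model/keras_model.h5")
    print(pb)     # ./tensorflow_model/keras_model.pb
    print(txt)    # ./tensorflow_model/keras_model.pb.ascii
    print(nodes)  # ['output_node0']
```

For a model with several outputs, pass `num_outputs` to get every node name the converter creates.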