├── requirements.txt
├── dataset.txt
├── data
│   ├── dogs.jpg
│   ├── beaver.jpg
│   ├── elephant.jpg
│   ├── cat_299x299.jpg
│   └── dog_299x299.jpg
├── .gitignore
├── LICENSE
├── README.md
├── benchmark_incption_v3.py
├── convert_rknn.py
├── freeze_graph.py
└── freeze_graph.ipynb

/requirements.txt:
--------------------------------------------------------------------------------
tensorflow
scipy
onnx
numpy==1.16.1
--------------------------------------------------------------------------------
/dataset.txt:
--------------------------------------------------------------------------------
./data/dog_299x299.jpg
./data/cat_299x299.jpg
--------------------------------------------------------------------------------
/data/dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony607/Keras_RK3399pro/HEAD/data/dogs.jpg
--------------------------------------------------------------------------------
/data/beaver.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony607/Keras_RK3399pro/HEAD/data/beaver.jpg
--------------------------------------------------------------------------------
/data/elephant.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony607/Keras_RK3399pro/HEAD/data/elephant.jpg
--------------------------------------------------------------------------------
/data/cat_299x299.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony607/Keras_RK3399pro/HEAD/data/cat_299x299.jpg
--------------------------------------------------------------------------------
/data/dog_299x299.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tony607/Keras_RK3399pro/HEAD/data/dog_299x299.jpg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*.hdf5
*.ipynb_checkpoints
*.p
*.h5
*.HDF5
__pycache__
# Ignore model directory.
model/
.vscode/

*.rknn
*.pb
*.h5
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
COPYRIGHT

All contributions by François Chollet:
Copyright (c) 2015 - 2018, François Chollet.
All rights reserved.

All contributions by Google:
Copyright (c) 2015 - 2018, Google, Inc.
All rights reserved.

All contributions by Microsoft:
Copyright (c) 2017 - 2018, Microsoft, Inc.
All rights reserved.

All other contributions:
Copyright (c) 2015 - 2018, the respective contributors.
All rights reserved.

Each contributor holds copyright over their respective contributions.
The project versioning (Git) records all such contribution source information.

LICENSE

The MIT License (MIT)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# [Run Keras/TensorFlow model on RK3399Pro](https://www.dlology.com/blog/how-to-run-keras-model-on-rk3399pro/)

### Clone or download this repo
```
git clone https://github.com/Tony607/Keras_RK3399pro
```

**Download pre-compiled Python wheel files from my [aarch64_python_packages](https://coding.net/u/zcw607/p/aarch64_python_packages/git) repo and the [rknn_toolkit](https://github.com/rockchip-toybrick/RKNPUTool/tree/master/rknn-toolkit/package) wheels from Rockchip's official GitHub.**

### Step 1: Freeze the Keras model and convert it to an RKNN model (on the Linux development machine)
Requires [Python 3.5+](https://www.python.org/ftp/python/3.6.7/python-3.6.7.exe).

### Install the required libraries on your development machine
`pip3 install -r requirements.txt`

Then install the RKNN Toolkit with the following command.
```
pip3 install rknn_toolkit-0.9.9-cp36-cp36m-linux_x86_64.whl
```

To freeze the Keras InceptionV3 ImageNet model into a single `.pb` file, run the script below.
The frozen graph will accept inputs with shape `(N, 299, 299, 3)`.
```
python3 freeze_graph.py
```

To convert the `.pb` file to an `.rknn` file, run
```
python3 convert_rknn.py
```
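
If `convert_rknn.py` cannot find the graph nodes, you can sanity-check the node names stored in the frozen graph; they must match `INPUT_NODE` and `OUTPUT_NODE` in `convert_rknn.py`. A minimal check, assuming TensorFlow 1.x and that the frozen graph was written to `./model/frozen_model.pb` (the path `freeze_graph.py` uses):

```python
import tensorflow as tf

# Load the frozen GraphDef written by freeze_graph.py.
graph_def = tf.GraphDef()
with tf.gfile.GFile("./model/frozen_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# Both names must match INPUT_NODE / OUTPUT_NODE in convert_rknn.py.
names = {n.name for n in graph_def.node}
print("input_1" in names, "predictions/Softmax" in names)
```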

### Step 2: Make predictions (on the RK3399Pro board)
Set up the board for the first time:
```bash
sudo dnf update -y
sudo dnf install -y cmake gcc gcc-c++ protobuf-devel protobuf-compiler lapack-devel
sudo dnf install -y python3-devel python3-opencv python3-numpy-f2py python3-h5py python3-lmdb
sudo dnf install -y python3-grpcio

sudo pip3 install scipy-1.2.0-cp36-cp36m-linux_aarch64.whl
sudo pip3 install onnx-1.4.1-cp36-cp36m-linux_aarch64.whl
sudo pip3 install tensorflow-1.10.1-cp36-cp36m-linux_aarch64.whl
sudo pip3 install rknn_toolkit-0.9.9-cp36-cp36m-linux_aarch64.whl
```

To run the inference benchmark on the RK3399Pro board, run the following in its terminal:
```
python3 benchmark_incption_v3.py
```
--------------------------------------------------------------------------------
/benchmark_incption_v3.py:
--------------------------------------------------------------------------------
import time
import platform
import numpy as np
import cv2
from rknn.api import RKNN


def show_outputs(outputs):
    # Print the top-5 class indices and scores from the softmax output.
    output = outputs[0][0]
    output_sorted = sorted(output, reverse=True)
    top5_str = "Inception_v3\n-----TOP 5-----\n"
    for i in range(5):
        value = output_sorted[i]
        index = np.where(output == value)
        for j in range(len(index)):
            if (i + j) >= 5:
                break
            if value > 0:
                topi = "{}: {}\n".format(index[j], value)
            else:
                topi = "-1: 0.0\n"
            top5_str += topi
    print("top5_str:`{}`".format(top5_str))


if __name__ == "__main__":

    # Create the RKNN object.
    rknn = RKNN()
    img_height = 299
    # Load the RKNN model directly.
    print("--> Loading RKNN model")
    ret = rknn.load_rknn("./inception_v3.rknn")
    if ret != 0:
        print("Load inception_v3.rknn failed!")
        exit(ret)

    # Set inputs.
    img = cv2.imread("./data/beaver.jpg")
    img = cv2.resize(img, dsize=(img_height, img_height), interpolation=cv2.INTER_CUBIC)

    # This BGR -> RGB conversion can be skipped if "reorder_channel" is set to
    # "2 1 0" in rknn.config() in `convert_rknn.py`.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Init runtime environment.
    print("--> Init runtime environment")

    if "aarch64" in platform.platform():
        target = "rk3399pro"
    else:
        target = None

    ret = rknn.init_runtime(target=target)
    if ret != 0:
        print("Init runtime environment failed")
        exit(ret)

    # Inference
    print("--> Running model")
    outputs = rknn.inference(inputs=[img])

    outputs = np.array(outputs)
    print("Outputs shape: {}".format(outputs.shape))
    show_outputs(outputs)

    # Benchmark model
    print("--> Benchmark model")

    times = []

    # Run inference 20 times and average the latency.
    for i in range(20):
        start_time = time.time()
        # Call the internal API directly.
        results = rknn.rknn_base.inference(
            inputs=[img], data_type="uint8", data_format="nhwc", outputs=None
        )
        # Alternatively, use the public API call.
        # outputs = rknn.inference(inputs=[img])
        delta = time.time() - start_time
        times.append(delta)

    # Calculate the average time for inference.
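    # Note: time.time() measures the full Python call, including host <-> NPU data
    # transfer and API overhead, so the average below will typically be higher than
    # the pure on-device time reported by rknn.eval_perf() at the end of this script.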
    mean_delta = np.array(times).mean()

    fps = 1 / mean_delta
    print("average(sec):{:.3f},fps:{:.2f}".format(mean_delta, fps))

    # Evaluate model performance.
    print("--> Begin evaluate model performance")
    perf_results = rknn.eval_perf(inputs=[img])

    rknn.release()
--------------------------------------------------------------------------------
/convert_rknn.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
from rknn.api import RKNN


def show_outputs(outputs):
    # Print the top-5 class indices and scores from the softmax output.
    output = outputs[0][0]
    output_sorted = sorted(output, reverse=True)
    top5_str = "inception_v3\n-----TOP 5-----\n"
    for i in range(5):
        value = output_sorted[i]
        index = np.where(output == value)
        for j in range(len(index)):
            if (i + j) >= 5:
                break
            if value > 0:
                topi = "{}: {}\n".format(index[j], value)
            else:
                topi = "-1: 0.0\n"
            top5_str += topi
    print("top5_str: {}".format(top5_str))


if __name__ == "__main__":
    INPUT_NODE = ["input_1"]
    OUTPUT_NODE = ["predictions/Softmax"]

    img_height = 299

    # Create the RKNN object.
    rknn = RKNN()

    # Pre-process config.
    print("--> config model")
    # channel_mean_value is "mean_R mean_G mean_B scale"; each channel is transformed
    # as (pixel - mean) / scale. "0 0 0 255" normalizes the image data to the range
    # [0, 1], while "128 128 128 128" (used here) normalizes it to the range [-1, 1].
    # reorder_channel "0 1 2" keeps the color channel order; "2 1 0" swaps the R and B
    # channels, i.e. a BGR image loaded by cv2.imread would be converted to RGB for
    # the model input.
    # need_horizontal_merge is suggested for Inception models (v1/v3/v4).
    rknn.config(
        channel_mean_value="128 128 128 128",
        reorder_channel="0 1 2",
        need_horizontal_merge=True,
        quantized_dtype="asymmetric_quantized-u8",
    )

    # Load the TensorFlow model.
    print("--> Loading model")
    ret = rknn.load_tensorflow(
        tf_pb="./model/frozen_model.pb",
        inputs=INPUT_NODE,
        outputs=OUTPUT_NODE,
        input_size_list=[[img_height, img_height, 3]],
    )
    if ret != 0:
        print("Load inception_v3 failed!")
        exit(ret)

    # Build the model.
    print("--> Building model")
    # dataset: an input data set used to rectify the quantization parameters.
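    # The file passed as `dataset` lists one calibration image path per line
    # (see ./dataset.txt, which points at the 299x299 dog and cat images in ./data);
    # these images are fed through the model to estimate the quantization ranges.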
    ret = rknn.build(do_quantization=True, dataset="./dataset.txt")
    if ret != 0:
        print("Build inception_v3 failed!")
        exit(ret)

    # Export the RKNN model.
    print("--> Export RKNN model")
    ret = rknn.export_rknn("./inception_v3.rknn")
    if ret != 0:
        print("Export inception_v3.rknn failed!")
        exit(ret)

    # Set inputs.
    img = cv2.imread("./data/elephant.jpg")
    img = cv2.resize(img, dsize=(img_height, img_height), interpolation=cv2.INTER_CUBIC)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    print("--> Init runtime environment")
    ret = rknn.init_runtime()
    if ret != 0:
        print("Init runtime environment failed")
        exit(ret)

    # Inference
    print("--> Running model")
    outputs = rknn.inference(inputs=[img])
    show_outputs(outputs)
    # print('inference result: ', outputs)

    # Evaluate model performance.
    print("--> Begin evaluate model performance")
    perf_results = rknn.eval_perf(inputs=[img])

    rknn.release()
--------------------------------------------------------------------------------
/freeze_graph.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8

# ## Save the Keras model as a single .h5 file.

# In[1]:


# Force CPU-only execution.
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import InceptionV3 as Net
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.inception_v3 import (
    preprocess_input,
    decode_predictions,
)
import numpy as np

print("TensorFlow version: {}".format(tf.__version__))

# Optional image to test the model prediction.
img_path = "./data/elephant.jpg"
model_path = "./model"

# Path to save the model h5 file.
model_fname = os.path.join(model_path, "model.h5")

os.makedirs(model_path, exist_ok=True)

img_height = 299

model = Net(weights="imagenet", input_shape=(img_height, img_height, 3))


# Load the image for prediction.
img = image.load_img(img_path, target_size=(img_height, img_height))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

preds = model.predict(x)
# Decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch).
print("Predicted:", decode_predictions(preds, top=3)[0])
# Predicted: [(u'n02504013', u'Indian_elephant', 0.82658225), (u'n01871265', u'tusker', 0.1122357), (u'n02504458', u'African_elephant', 0.061040461)]

# Save the h5 file to the path specified.
model.save(model_fname)


# ## Benchmark Keras prediction speed.

# In[2]:


import time

times = []
for i in range(20):
    start_time = time.time()
    preds = model.predict(x)
    delta = time.time() - start_time
    times.append(delta)
mean_delta = np.array(times).mean()
fps = 1 / mean_delta
print("average(sec):{},fps:{}".format(mean_delta, fps))

# Clear any previous session.
tf.keras.backend.clear_session()


# ## Freeze graph
# Generate the `.pb` file.
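# Freezing replaces the model's variables with constant ops and writes the whole
# graph out as a single GraphDef protobuf (.pb), which is the format that
# rknn.load_tensorflow() in convert_rknn.py expects.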

# In[3]:


import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.keras.models import load_model


# Clear any previous session.
tf.keras.backend.clear_session()

save_pb_dir = "./model"
model_fname = "./model/model.h5"


def freeze_graph(
    graph,
    session,
    output,
    save_pb_dir=".",
    save_pb_name="frozen_model.pb",
    save_pb_as_text=False,
):
    with graph.as_default():
        graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        graphdef_frozen = tf.graph_util.convert_variables_to_constants(
            session, graphdef_inf, output
        )
        graph_io.write_graph(
            graphdef_frozen, save_pb_dir, save_pb_name, as_text=save_pb_as_text
        )
        return graphdef_frozen


# This line must be executed before loading the Keras model.
tf.keras.backend.set_learning_phase(0)

model = load_model(model_fname)

session = tf.keras.backend.get_session()

INPUT_NODE = [t.op.name for t in model.inputs]
OUTPUT_NODE = [t.op.name for t in model.outputs]
print("\nINPUT_NODE: {}\nOUTPUT_NODE: {}".format(INPUT_NODE, OUTPUT_NODE))
frozen_graph = freeze_graph(
    session.graph,
    session,
    [out.op.name for out in model.outputs],
    save_pb_dir=save_pb_dir,
)


# ## Convert `.pb` file to RKNN model
#
# Run `convert_rknn.py`
--------------------------------------------------------------------------------
/freeze_graph.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the Keras model as a single .h5 file."
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "TensorFlow version: 1.13.1\n", 20 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 21 | "Instructions for updating:\n", 22 | "Colocations handled automatically by placer.\n", 23 | "Predicted: [('n02504458', 'African_elephant', 0.9739248), ('n01871265', 'tusker', 0.008072746), ('n02504013', 'Indian_elephant', 0.003881079)]\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "# Force use CPU only.\n", 29 | "import os\n", 30 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", 31 | "os.environ['CUDA_VISIBLE_DEVICES'] = '-1'\n", 32 | "\n", 33 | "import tensorflow as tf\n", 34 | "from tensorflow.keras.applications.inception_v3 import InceptionV3 as Net\n", 35 | "from tensorflow.keras.preprocessing import image\n", 36 | "from tensorflow.keras.applications.inception_v3 import preprocess_input, decode_predictions\n", 37 | "import numpy as np\n", 38 | "\n", 39 | "print(\"TensorFlow version: {}\".format(tf.__version__))\n", 40 | "\n", 41 | "# Optional image to test model prediction.\n", 42 | "img_path = './data/elephant.jpg'\n", 43 | "model_path = './model'\n", 44 | "\n", 45 | "# Path to save the model h5 file.\n", 46 | "model_fname = os.path.join(model_path, 'model.h5')\n", 47 | "\n", 48 | "os.makedirs(model_path, exist_ok=True)\n", 49 | "\n", 50 | "img_height = 299\n", 51 | "\n", 52 | "model = Net(weights='imagenet', input_shape=(img_height, img_height, 3))\n", 53 | "\n", 54 | "\n", 55 | "# Load the image for prediction.\n", 56 | "img = image.load_img(img_path, target_size=(img_height, img_height))\n", 57 | "x = image.img_to_array(img)\n", 58 | "x = np.expand_dims(x, axis=0)\n", 59 | "x = preprocess_input(x)\n", 60 | "\n", 61 | "preds = model.predict(x)\n", 62 | "# decode the results into a list of tuples (class, description, probability)\n", 63 | "# (one such list for each sample in the batch)\n", 64 | "print('Predicted:', decode_predictions(preds, top=3)[0])\n", 65 | "# Predicted: [(u'n02504013', u'Indian_elephant', 0.82658225), (u'n01871265', u'tusker', 0.1122357), (u'n02504458', u'African_elephant', 0.061040461)]\n", 66 | "\n", 67 | "# Save the h5 file to path specified.\n", 68 | "model.save(model_fname)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "## Benchmark Keras prediction speed." 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 2, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "average(sec):0.13158477544784547,fps:7.599663385042268\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "import time\n", 93 | "times = []\n", 94 | "for i in range(20):\n", 95 | " start_time = time.time()\n", 96 | " preds = model.predict(x)\n", 97 | " delta = (time.time() - start_time)\n", 98 | " times.append(delta)\n", 99 | "mean_delta = np.array(times).mean()\n", 100 | "fps = 1/mean_delta\n", 101 | "print('average(sec):{},fps:{}'.format(mean_delta,fps))\n", 102 | "\n", 103 | "# Clear any previous session.\n", 104 | "tf.keras.backend.clear_session()" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "## Freeze graph\n", 112 | "Generate `.pb` file." 
113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 3, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "WARNING:tensorflow:No training configuration found in save file: the model was *not* compiled. Compile it manually.\n", 125 | "['input_1'] ['predictions/Softmax']\n", 126 | "WARNING:tensorflow:From :16: remove_training_nodes (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n", 127 | "Instructions for updating:\n", 128 | "Use tf.compat.v1.graph_util.remove_training_nodes\n", 129 | "WARNING:tensorflow:From :17: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n", 130 | "Instructions for updating:\n", 131 | "Use tf.compat.v1.graph_util.convert_variables_to_constants\n", 132 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/graph_util_impl.py:245: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n", 133 | "Instructions for updating:\n", 134 | "Use tf.compat.v1.graph_util.extract_sub_graph\n", 135 | "INFO:tensorflow:Froze 378 variables.\n", 136 | "INFO:tensorflow:Converted 378 variables to const ops.\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "# force reset ipython namespaces\n", 142 | "%reset -f\n", 143 | "\n", 144 | "import tensorflow as tf\n", 145 | "from tensorflow.python.framework import graph_io\n", 146 | "from tensorflow.keras.models import load_model\n", 147 | "\n", 148 | "\n", 149 | "# Clear any previous session.\n", 150 | "tf.keras.backend.clear_session()\n", 151 | "\n", 152 | "save_pb_dir = './model'\n", 153 | "model_fname = './model/model.h5'\n", 154 | "def freeze_graph(graph, session, output, save_pb_dir='.', save_pb_name='frozen_model.pb', save_pb_as_text=False):\n", 155 | " with graph.as_default():\n", 156 | " graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())\n", 157 | " graphdef_frozen = tf.graph_util.convert_variables_to_constants(session, graphdef_inf, output)\n", 158 | " graph_io.write_graph(graphdef_frozen, save_pb_dir, save_pb_name, as_text=save_pb_as_text)\n", 159 | " return graphdef_frozen\n", 160 | "\n", 161 | "# This line must be executed before loading Keras model.\n", 162 | "tf.keras.backend.set_learning_phase(0) \n", 163 | "\n", 164 | "model = load_model(model_fname)\n", 165 | "\n", 166 | "session = tf.keras.backend.get_session()\n", 167 | "\n", 168 | "INPUT_NODE = [t.op.name for t in model.inputs]\n", 169 | "OUTPUT_NODE = [t.op.name for t in model.outputs]\n", 170 | "print(INPUT_NODE, OUTPUT_NODE)\n", 171 | "frozen_graph = freeze_graph(session.graph, session, [out.op.name for out in model.outputs], save_pb_dir=save_pb_dir)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "## Convert `.pb` file to RKNN model\n", 179 | "\n", 180 | "Run `convert_rknn.py`" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [] 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "Python 3", 194 | "language": "python", 195 | "name": "python3" 196 | }, 197 | "language_info": { 198 | "codemirror_mode": { 199 | "name": "ipython", 200 | "version": 3 201 | }, 202 | "file_extension": ".py", 203 | "mimetype": "text/x-python", 204 | 
"name": "python", 205 | "nbconvert_exporter": "python", 206 | "pygments_lexer": "ipython3", 207 | "version": "3.5.2" 208 | }, 209 | "toc": { 210 | "base_numbering": 1, 211 | "nav_menu": {}, 212 | "number_sections": true, 213 | "sideBar": true, 214 | "skip_h1_title": false, 215 | "title_cell": "Table of Contents", 216 | "title_sidebar": "Contents", 217 | "toc_cell": false, 218 | "toc_position": {}, 219 | "toc_section_display": true, 220 | "toc_window_display": false 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 2 225 | } 226 | --------------------------------------------------------------------------------