├── LICENSE ├── README.md ├── build_engine.py ├── fold_constants.py ├── image_processing.py ├── infer.py ├── labels └── class_labels.txt ├── postprocess_onnx.py └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | Copyright 2020 NVIDIA Corporation 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 
183 | You may obtain a copy of the License at 184 | 185 | http://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorRT inference of Resnet-50 trained with QAT. 2 | 3 | **Table Of Contents** 4 | - [Description](#description) 5 | - [How does this sample work?](#how-does-this-sample-work) 6 | - [Prerequisites](#prerequisites) 7 | - [Running the sample](#running-the-sample) 8 | * [Step 1: Quantization Aware Training](#step-1-quantization-aware-training) 9 | * [Step 2: Export frozen graph of RN50 QAT](#step-2-export-frozen-graph-of-rn50-qat) 10 | * [Step 3: Constant folding](#step-3-constant-folding) 11 | * [Step 4: TF2ONNX conversion](#step-4-tf2onnx-conversion) 12 | * [Step 5: Post processing ONNX](#step-5-post-processing-onnx) 13 | * [Step 6: Build TensorRT engine from ONNX graph](#step-6-build-tensorrt-engine-from-onnx-graph) 14 | * [Step 7: TensorRT Inference](#step-7-tensorrt-inference) 15 | - [Additional resources](#additional-resources) 16 | - [Changelog](#changelog) 17 | - [Known issues](#known-issues) 18 | - [License](#license) 19 | 20 | ## Description 21 | 22 | This sample demonstrates the workflow for training a Resnet-50 model with Quantization Aware Training (QAT) and running inference on it with TensorRT. 23 | The inference implementation is an experimental prototype and is provided with no guarantee of support. 24 | 25 | ## How does this sample work? 26 | 27 | This sample demonstrates: 28 | 29 | * Training a Resnet-50 model using quantization aware training. 
30 | * Post-processing the graph and converting it to ONNX so that TensorRT can parse it successfully. 31 | * Inference of the Resnet-50 QAT graph with TensorRT. 32 | 33 | ## Prerequisites 34 | 35 | Dependencies required for this sample 36 | 37 | 1. TensorFlow NGC container (20.01-tf1-py3 or above) for Steps 1-4. Please use the `tf1` variants, which have TF 1.15.2 installed. 38 | This sample does not work with the public version of the TensorFlow 1.15.2 library. 39 | 40 | 2. Install the dependencies for Python3 inside the NGC container. 41 | - For Python 3 users, from the root directory, run: 42 | `python3 -m pip install -r requirements.txt` 43 | 44 | 3. TensorRT-7.1 45 | 46 | 4. ONNX-Graphsurgeon 0.2.1 47 | 48 | ## Running the sample 49 | 50 | ***NOTE: Steps 1-4 require NGC containers (TensorFlow 20.01-tf1-py3 NGC container or above). Steps 5-7 can be executed within or outside the NGC container*** 51 | 52 | ### Step 1: Quantization Aware Training 53 | 54 | Please follow detailed instructions on how to finetune a RN50 model using QAT. 55 | 56 | This stage involves: 57 | 58 | * Finetuning a RN50 model with quantization nodes and saving the final checkpoint. 59 | * Post-processing the above RN50 QAT checkpoint by reshaping the weights of the final FC layer into a 1x1 conv layer. 60 | 61 | ### Step 2: Export frozen graph of RN50 QAT 62 | 63 | Export the RN50 QAT graph, replacing the final FC layer with a 1x1 conv layer. 64 | Please follow these instructions to generate a frozen graph in the desired data format. 65 | 66 | ### Step 3: Constant folding 67 | 68 | Once we have the frozen graph from Step 2, run the following command to perform constant folding on the TF graph: 69 | ``` 70 | python fold_constants.py --input --output 71 | ``` 72 | 73 | Arguments: 74 | * `--input` : Input Tensorflow graph 75 | * `--output_node` : Output node name of the RN50 graph (Default: `resnet50/output/softmax_1`) 76 | * `--output` : Output name of constant folded TF graph. 
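For readers unfamiliar with the term, constant folding replaces every node whose inputs are all constants with a single precomputed constant. The sketch below illustrates the idea on a toy dictionary-based graph; it is not how `fold_constants.py` or TensorFlow's Grappler is implemented, and the graph layout, the `OPS` table and the node names are all hypothetical, chosen only for illustration.

```python
import operator

# Hypothetical toy graph format: name -> (op, payload).
# "const" nodes carry a value, "placeholder" nodes are runtime inputs,
# every other op carries the list of its input node names.
OPS = {"add": operator.add, "mul": operator.mul}


def fold_constants(graph):
    """Return a copy of `graph` in which every node whose inputs are all
    constants has been replaced by a single const node."""
    folded = {}

    def visit(name):
        if name in folded:
            return folded[name]
        op, payload = graph[name]
        if op in ("const", "placeholder"):
            folded[name] = (op, payload)
        else:
            inputs = [visit(arg) for arg in payload]
            if all(kind == "const" for kind, _ in inputs):
                # All inputs are known ahead of time: evaluate now.
                value = OPS[op](*(val for _, val in inputs))
                folded[name] = ("const", value)
            else:
                # At least one input is only known at runtime: keep the op.
                folded[name] = (op, payload)
        return folded[name]

    for name in graph:
        visit(name)
    return folded


graph = {
    "scale": ("const", 1.0 / 127.0),
    "two": ("const", 2.0),
    "double_scale": ("mul", ["scale", "two"]),      # foldable
    "input": ("placeholder", None),
    "output": ("mul", ["input", "double_scale"]),   # depends on runtime input
}
folded = fold_constants(graph)
```

After folding, `double_scale` is a plain constant while `output` is still a `mul`, because one of its inputs is a placeholder. This toy version does not prune the now-unused `scale` and `two` nodes; a real optimizer would typically remove them as well.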
77 | 78 | ### Step 4: TF2ONNX conversion 79 | 80 | The TF2ONNX converter is used to convert the constant-folded TensorFlow frozen graph into an ONNX graph. For RN50 QAT, the `tf.quantization.quantize_and_dequantize` operation (QDQ) is converted into `QuantizeLinear` and `DequantizeLinear` operations. 81 | Support for converting QDQ operations was added in version `1.6.1` of TF2ONNX. 82 | 83 | Command to convert the RN50 QAT TF graph to ONNX: 84 | ``` 85 | python3 -m tf2onnx.convert --input --output --inputs input:0 --outputs resnet50/output/softmax_1:0 --opset 11 86 | ``` 87 | 88 | Arguments: 89 | * `--input` : Name of TF input graph 90 | * `--output` : Name of ONNX output graph 91 | * `--inputs` : Name of input tensors 92 | * `--outputs` : Name of output tensors 93 | * `--opset` : ONNX opset version 94 | 95 | ### Step 5: Post processing ONNX 96 | 97 | Run the following command to postprocess the ONNX graph using the ONNX-Graphsurgeon API. This step removes the `transpose` nodes after the `Dequantize` nodes. 98 | ``` 99 | python postprocess_onnx.py --input --output 100 | ``` 101 | 102 | Arguments: 103 | * `--input` : Input ONNX graph 104 | * `--output` : Output name of postprocessed ONNX graph. 105 | 106 | ### Step 6: Build TensorRT engine from ONNX graph 107 | ``` 108 | python build_engine.py --onnx 109 | ``` 110 | 111 | Arguments: 112 | * `--onnx` : Path to the RN50 QAT ONNX graph 113 | * `--engine` : Output file name of TensorRT engine. 114 | * `--verbose` : Flag to enable verbose logging 115 | 116 | ### Step 7: TensorRT Inference 117 | 118 | Command to run inference on a sample image: 119 | 120 | ``` 121 | python infer.py --engine 122 | ``` 123 | 124 | Arguments: 125 | * `--engine` : Path to input RN50 TensorRT engine. 126 | * `--labels` : Path to the provided ImageNet 1k labels text file. 
127 | * `--image` : Path to the sample image 128 | * `--verbose` : Flag to enable verbose logging 129 | 130 | ### Sample --help options 131 | 132 | To see the full list of available options and their descriptions, use the `-h` or `--help` command line option. For example: 133 | ``` 134 | usage: .py> [-h] 135 | ``` 136 | 137 | # Additional resources 138 | 139 | The following resources provide a deeper understanding about Quantization aware training, TF2ONNX and importing a model into TensorRT using Python: 140 | 141 | **Quantization Aware Training** 142 | - [Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference](https://arxiv.org/pdf/1712.05877.pdf) 143 | - [Quantization Aware Training guide](https://www.tensorflow.org/model_optimization/guide/quantization/training) 144 | - [Resnet-50 Deep Learning Example](https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow/Classification/ConvNets/resnet50v1.5/README.md) 145 | - [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) 146 | 147 | **Parsers** 148 | - [TF2ONNX Converter](https://github.com/onnx/tensorflow-onnx) 149 | - [ONNX Parser](https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/parsers/Onnx/pyOnnx.html) 150 | 151 | **Documentation** 152 | - [Introduction To NVIDIA’s TensorRT Samples](https://docs.nvidia.com/deeplearning/sdk/tensorrt-sample-support-guide/index.html#samples) 153 | - [Working With TensorRT Using The Python API](https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#python_topics) 154 | - [Importing A Model Using A Parser In Python](https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#import_model_python) 155 | - [NVIDIA’s TensorRT Documentation Library](https://docs.nvidia.com/deeplearning/sdk/tensorrt-archived/index.html) 156 | 157 | # Changelog 158 | 159 | June 2020: Initial release of this sample 160 | 161 | # Known issues 162 | 163 | Tensorflow 
operation `tf.quantization.quantize_and_dequantize` is used for quantization during training. The gradient of this operation is not clipped based on input range. 164 | 165 | # License 166 | 167 | The sampleQAT license can be found in the LICENSE file. 168 | 169 | -------------------------------------------------------------------------------- /build_engine.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2020 NVIDIA Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import tensorrt as trt 17 | import pycuda.driver as cuda 18 | import pycuda.autoinit 19 | import numpy as np 20 | import argparse 21 | 22 | def build_profile(builder, network, profile_shapes, default_shape_value=1): 23 | """ 24 | Build optimization profile for the builder and configure the min, opt, max shapes appropriately. 
25 | """ 26 | def is_dimension_dynamic(dim): 27 | return dim is None or dim <= 0 28 | 29 | def override_shape(shape): 30 | return tuple([1 if is_dimension_dynamic(dim) else dim for dim in shape]) 31 | 32 | profile = builder.create_optimization_profile() 33 | for idx in range(network.num_inputs): 34 | inp = network.get_input(idx) 35 | 36 | def get_profile_shape(name): 37 | if name not in profile_shapes: 38 | return None 39 | shapes = profile_shapes[name] 40 | if not isinstance(shapes, list) or len(shapes) != 3: 41 | raise ValueError("Profile values must be a list containing exactly 3 shapes (tuples or Dims), but received shapes: {:} for input: {:}.\nNote: profile was: {:}.\nNote: Network inputs were: {:}".format(shapes, name, profile_shapes, [network.get_input(i).name for i in range(network.num_inputs)])) 42 | return shapes 43 | 44 | if inp.is_shape_tensor: 45 | shapes = get_profile_shape(inp.name) 46 | if not shapes: 47 | rank = inp.shape[0] 48 | shapes = [(default_shape_value, ) * rank] * 3 49 | print("Setting shape input to {:}. If this is incorrect, for shape input: {:}, please provide tuples for min, opt, and max shapes containing {:} elements".format(shapes[0], inp.name, rank)) 50 | min, opt, max = shapes 51 | profile.set_shape_input(inp.name, min, opt, max) 52 | print("Setting shape input: {:} values to min: {:}, opt: {:}, max: {:}".format(inp.name, min, opt, max)) 53 | elif -1 in inp.shape: 54 | shapes = get_profile_shape(inp.name) 55 | if not shapes: 56 | shapes = [override_shape(inp.shape)] * 3 57 | print("Overriding dynamic input shape {:} to {:}. 
If this is incorrect, for input tensor: {:}, please provide tuples for min, opt, and max shapes containing values: {:} with dynamic dimensions replaced.".format(inp.shape, shapes[0], inp.name, inp.shape)) 58 | min, opt, max = shapes 59 | profile.set_shape(inp.name, min, opt, max) 60 | print("Setting input: {:} shape to min: {:}, opt: {:}, max: {:}".format(inp.name, min, opt, max)) 61 | if not profile: 62 | print("Profile is not valid, please provide profile data. Note: profile was: {:}".format(profile_shapes)) 63 | return profile 64 | 65 | def preprocess_network(network): 66 | """ 67 | Add quantize and dequantize nodes after the input placeholder. 68 | The scale values are currently picked on an empirical basis. Ideally, 69 | you need to add these nodes during quantization aware training and 70 | learn the dynamic ranges of the input node. 71 | """ 72 | quant_scale = np.array([1.0/127.0], dtype=np.float32) 73 | dequant_scale = np.array([127.0/1.0], dtype=np.float32) 74 | # Zero point is always zero for quantization in TensorRT. 
75 | zeros = np.zeros(shape=(1, ), dtype=np.float32) 76 | 77 | for i in range(network.num_inputs): 78 | inp = network.get_input(i) 79 | # Find the first layer consuming the input tensor 80 | found = False 81 | for layer in network: 82 | if found: 83 | break 84 | 85 | for k in range(layer.num_inputs): 86 | if inp == layer.get_input(k): 87 | mode = trt.ScaleMode.UNIFORM 88 | quantize = network.add_scale(inp, mode, scale=quant_scale, shift=zeros) 89 | quantize.set_output_type(0, trt.int8) 90 | quantize.name = "InputQuantizeNode" 91 | quantize.get_output(0).name = "QuantizedInput" 92 | dequantize = network.add_scale(quantize.get_output(0), mode, scale=dequant_scale, shift=zeros) 93 | dequantize.set_output_type(0, trt.float32) 94 | dequantize.name = "InputDequantizeNode" 95 | dequantize.get_output(0).name = "DequantizedInput" 96 | layer.set_input(k, dequantize.get_output(0)) 97 | found = True 98 | break 99 | 100 | def build_engine_onnx(model_file, verbose=False): 101 | """ 102 | Parse the ONNX model file with TensorRT and build a TensorRT engine. Returns None if parsing fails. 103 | """ 104 | # Create builder and network 105 | if verbose: 106 | TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) 107 | else: 108 | TRT_LOGGER = trt.Logger(trt.Logger.INFO) 109 | 110 | network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) 111 | network_flags = network_flags | (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION)) 112 | 113 | with trt.Builder(TRT_LOGGER) as builder, builder.create_network(flags=network_flags) as network, trt.OnnxParser(network, TRT_LOGGER) as parser: 114 | with open(model_file, 'rb') as model: 115 | if not parser.parse(model.read()): 116 | print('ERROR: Failed to parse the ONNX file.') 117 | for error in range(parser.num_errors): 118 | print(parser.get_error(error)) 119 | return None 120 | 121 | # Add quantize and dequantize nodes for input of the network 122 | preprocess_network(network) 123 | config = builder.create_builder_config() 124 | config.max_workspace_size = 1 
<< 30 125 | config.flags = config.flags | 1 << int(trt.BuilderFlag.INT8) 126 | # Setting the (min, opt, max) batch sizes to be 1. Users need to configure this according to their requirements. 127 | config.add_optimization_profile(build_profile(builder, network, profile_shapes={'input' : [(1, 3, 224, 224),(1, 3, 224, 224),(1, 3, 224, 224)]})) 128 | 129 | return builder.build_engine(network, config) 130 | 131 | def main(args): 132 | 133 | model_file = args.onnx 134 | # Parse the ONNX graph through TensorRT and build the engine 135 | trt_engine = build_engine_onnx(model_file, args.verbose) 136 | # Serialize the engine and save to file 137 | with open(args.engine, "wb") as file: 138 | file.write(trt_engine.serialize()) 139 | 140 | 141 | if __name__ == '__main__': 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument("--onnx", type=str, default='rn50.onnx', help="Path to RN50 ONNX graph") 144 | parser.add_argument("--engine", type=str, default='rn50_trt.engine', help="output path to TensorRT engine") 145 | parser.add_argument('-v', '--verbose', action='store_true', help="Flag to enable verbose logging") 146 | args = parser.parse_args() 147 | main(args) 148 | -------------------------------------------------------------------------------- /fold_constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2020 NVIDIA Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | from tensorflow.core.protobuf import config_pb2, rewriter_config_pb2, meta_graph_pb2 18 | from tensorflow.core.framework import graph_pb2 19 | from tensorflow.python.framework import importer, ops 20 | from tensorflow.python.grappler import tf_optimizer 21 | from tensorflow.python.training import saver 22 | 23 | 24 | def constfold(graphdef, output_name): 25 | graph = ops.Graph() 26 | with graph.as_default(): 27 | outputs = output_name.split(',') 28 | output_collection = meta_graph_pb2.CollectionDef() 29 | output_list = output_collection.node_list.value 30 | for output in outputs: 31 | output_list.append(output) 32 | importer.import_graph_def(graphdef, name="") 33 | metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def(add_shapes=True), graph=graph) 34 | metagraph.collection_def["train_op"].CopyFrom(output_collection) 35 | 36 | rewriter_config = rewriter_config_pb2.RewriterConfig() 37 | rewriter_config.optimizers.extend(["constfold"]) 38 | rewriter_config.meta_optimizer_iterations = (rewriter_config_pb2.RewriterConfig.ONE) 39 | session_config = config_pb2.ConfigProto() 40 | session_config.graph_options.rewrite_options.CopyFrom(rewriter_config) 41 | 42 | return tf_optimizer.OptimizeGraph(session_config, metagraph) 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser("Folds constants in the provided frozen model") 46 | parser.add_argument("-i", "--input", help="The input frozen model to be constant folded.") 47 | parser.add_argument("--output_node", default="resnet50/output/softmax_1", help="Output node names separated by commas") 48 | parser.add_argument("-o", "--output", default="folded_rn50.pb", help="Path to constant folded output graph") 49 | args, _ = parser.parse_known_args() 50 | 51 | with open(args.input, 'rb') as f: 52 | graphdef = graph_pb2.GraphDef() 53 | 
graphdef.ParseFromString(f.read()) 54 | 55 | folded_graph = constfold(graphdef, args.output_node) 56 | print("Writing output to {:}".format(args.output)) 57 | with open(args.output, "wb") as f: 58 | f.write(folded_graph.SerializeToString()) 59 | -------------------------------------------------------------------------------- /image_processing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 NVIDIA Corporation 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | 17 | import numpy as np 18 | from PIL import Image 19 | 20 | 21 | logging.basicConfig(level=logging.DEBUG, 22 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 23 | datefmt="%Y-%m-%d %H:%M:%S") 24 | logger = logging.getLogger(__name__) 25 | 26 | _RESIZE_MIN = 256 27 | _R_MEAN = 123.68 28 | _G_MEAN = 116.78 29 | _B_MEAN = 103.94 30 | _CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN] 31 | 32 | 33 | def preprocess_imagenet(image, channels=3, height=224, width=224): 34 | """Pre-processing for Imagenet-based Image Classification Models: 35 | resnet50, vgg16, mobilenet, etc. 
(Doesn't seem to work for Inception) 36 | 37 | Parameters 38 | ---------- 39 | image: PIL.Image 40 | The image resulting from PIL.Image.open(filename) to preprocess 41 | channels: int 42 | The number of channels the image has (Usually 1 or 3) 43 | height: int 44 | The desired height of the image (usually 224 for Imagenet data) 45 | width: int 46 | The desired width of the image (usually 224 for Imagenet data) 47 | 48 | Returns 49 | ------- 50 | img_data: numpy array 51 | The preprocessed image data in the form of a numpy array 52 | 53 | """ 54 | # Get the image in CHW format 55 | resized_image = image.resize((width, height), Image.ANTIALIAS) 56 | img_data = np.asarray(resized_image).astype(np.float32) 57 | 58 | if len(img_data.shape) == 2: 59 | # For images without a channel dimension, we stack 60 | img_data = np.stack([img_data] * 3) 61 | logger.debug("Received grayscale image. Reshaped to {:}".format(img_data.shape)) 62 | else: 63 | img_data = img_data.transpose([2, 0, 1]) 64 | 65 | mean_vec = np.array([0.485, 0.456, 0.406]) 66 | stddev_vec = np.array([0.229, 0.224, 0.225]) 67 | assert img_data.shape[0] == channels 68 | 69 | for i in range(img_data.shape[0]): 70 | # Scale each pixel to [0, 1] and normalize per channel. 71 | img_data[i, :, :] = (img_data[i, :, :] / 255 - mean_vec[i]) / stddev_vec[i] 72 | 73 | return img_data 74 | 75 | def _smallest_size_at_least(height, width, resize_min): 76 | 77 | smaller_dim = np.minimum(float(height), float(width)) 78 | scale_ratio = resize_min / smaller_dim 79 | 80 | # Convert back to ints to make heights and widths that TF ops will accept. 
81 | new_height = int(height * scale_ratio) 82 | new_width = int(width * scale_ratio) 83 | 84 | return new_height, new_width 85 | 86 | def _central_crop(image, crop_height, crop_width): 87 | shape = image.shape 88 | height, width = shape[0], shape[1] 89 | 90 | amount_to_be_cropped_h = (height - crop_height) 91 | crop_top = amount_to_be_cropped_h // 2 92 | amount_to_be_cropped_w = (width - crop_width) 93 | crop_left = amount_to_be_cropped_w // 2 94 | cropped_image = image[crop_top:crop_height+crop_top, crop_left:crop_width+crop_left] 95 | return cropped_image 96 | 97 | def normalize_inputs(inputs): 98 | 99 | num_channels = inputs.shape[-1] 100 | 101 | if len(_CHANNEL_MEANS) != num_channels: 102 | raise ValueError('len(means) must match the number of channels') 103 | 104 | # We have a 1-D tensor of means; convert to 3-D. 105 | means_per_channel = np.reshape(_CHANNEL_MEANS, [1, 1, num_channels]) 106 | # means_per_channel = tf.cast(means_per_channel, dtype=inputs.dtype) 107 | 108 | inputs = np.subtract(inputs, means_per_channel)/255.0 109 | 110 | return inputs 111 | 112 | 113 | def preprocess_resnet50(image, channels=3, height=224, width=224): 114 | """Pre-processing for Imagenet-based Image Classification Models: 115 | resnet50 (resnet_v1_1.5 designed by Nvidia 116 | Parameters 117 | ---------- 118 | image: PIL.Image 119 | The image resulting from PIL.Image.open(filename) to preprocess 120 | channels: int 121 | The number of channels the image has (Usually 1 or 3) 122 | height: int 123 | The desired height of the image (usually 224 for Imagenet data) 124 | width: int 125 | The desired width of the image (usually 224 for Imagenet data) 126 | 127 | Returns 128 | ------- 129 | img_data: numpy array 130 | The preprocessed image data in the form of a numpy array 131 | 132 | """ 133 | # Get the shape of the image. 
134 | w, h = image.size 135 | 136 | new_height, new_width = _smallest_size_at_least(h, w, _RESIZE_MIN) 137 | 138 | # Image is still in WH format in PIL 139 | resized_image = image.resize((new_width, new_height), Image.BILINEAR) 140 | # Changes to HWC due to numpy 141 | img_data = np.asarray(resized_image).astype(np.float32) 142 | # Do a central crop 143 | cropped_image = _central_crop(img_data, height, width) 144 | assert cropped_image.shape[0] == height 145 | assert cropped_image.shape[1] == width 146 | if len(cropped_image.shape) == 2: 147 | # For images without a channel dimension, stack into HWC so the normalization below still applies 148 | cropped_image = np.stack([cropped_image] * 3, axis=-1) 149 | # logger.debug("Received grayscale image. Reshaped to {:}".format(cropped_image.shape)) 150 | 151 | 152 | normalized_inputs = normalize_inputs(cropped_image) 153 | cropped_image = np.transpose(normalized_inputs, [2, 0, 1]) 154 | 155 | return cropped_image 156 | 157 | def preprocess_inception(image, channels=3, height=224, width=224): 158 | """Pre-processing for InceptionV1. Inception expects different pre-processing 159 | than {resnet50, vgg16, mobilenet}. This may not be totally correct, 160 | but it worked for some simple test images. 
161 | 162 | Parameters 163 | ---------- 164 | image: PIL.Image 165 | The image resulting from PIL.Image.open(filename) to preprocess 166 | channels: int 167 | The number of channels the image has (Usually 1 or 3) 168 | height: int 169 | The desired height of the image (usually 224 for Imagenet data) 170 | width: int 171 | The desired width of the image (usually 224 for Imagenet data) 172 | 173 | Returns 174 | ------- 175 | img_data: numpy array 176 | The preprocessed image data in the form of a numpy array 177 | 178 | """ 179 | # Get the image in CHW format 180 | resized_image = image.resize((width, height), Image.BILINEAR) 181 | img_data = np.asarray(resized_image).astype(np.float32) 182 | 183 | if len(img_data.shape) == 2: 184 | # For images without a channel dimension, we stack 185 | img_data = np.stack([img_data] * 3) 186 | logger.debug("Received grayscale image. Reshaped to {:}".format(img_data.shape)) 187 | else: 188 | img_data = img_data.transpose([2, 0, 1]) 189 | 190 | return img_data 191 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2020 NVIDIA Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import os 17 | import argparse 18 | import PIL.Image 19 | import numpy as np 20 | import tensorrt as trt 21 | import pycuda.driver as cuda 22 | import pycuda.autoinit 23 | import image_processing 24 | 25 | TRT_DYNAMIC_DIM = -1 26 | 27 | def load_normalized_test_case(test_image, pagelocked_buffer, preprocess_func): 28 | # Expected input dimensions 29 | C, H, W = (3, 224, 224) 30 | # Normalize the images, concatenate them and copy to pagelocked memory. 31 | data = np.asarray([preprocess_func(PIL.Image.open(test_image).convert('RGB'), C, H, W)]).flatten() 32 | np.copyto(pagelocked_buffer, data) 33 | 34 | class HostDeviceMem(object): 35 | r""" Simple helper data class that's a little nicer to use than a 2-tuple. 36 | """ 37 | def __init__(self, host_mem, device_mem): 38 | self.host = host_mem 39 | self.device = device_mem 40 | 41 | def __str__(self): 42 | return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) 43 | 44 | def __repr__(self): 45 | return self.__str__() 46 | 47 | 48 | def allocate_buffers(engine: trt.ICudaEngine, batch_size: int): 49 | print('Allocating buffers ...') 50 | 51 | inputs = [] 52 | outputs = [] 53 | dbindings = [] 54 | 55 | stream = cuda.Stream() 56 | 57 | for binding in engine: 58 | size = batch_size * abs(trt.volume(engine.get_binding_shape(binding))) 59 | dtype = trt.nptype(engine.get_binding_dtype(binding)) 60 | # Allocate host and device buffers 61 | host_mem = cuda.pagelocked_empty(size, dtype) 62 | device_mem = cuda.mem_alloc(host_mem.nbytes) 63 | # Append the device buffer to device bindings. 64 | dbindings.append(int(device_mem)) 65 | 66 | # Append to the appropriate list. 
67 | if engine.binding_is_input(binding): 68 | inputs.append(HostDeviceMem(host_mem, device_mem)) 69 | else: 70 | outputs.append(HostDeviceMem(host_mem, device_mem)) 71 | 72 | return inputs, outputs, dbindings, stream 73 | 74 | def infer(engine_path, preprocess_func, batch_size, input_image, labels=[], verbose=False): 75 | 76 | if verbose: 77 | logger = trt.Logger(trt.Logger.VERBOSE) 78 | else: 79 | logger = trt.Logger(trt.Logger.INFO) 80 | 81 | with open(engine_path, 'rb') as f, trt.Runtime(logger) as runtime: 82 | engine = runtime.deserialize_cuda_engine(f.read()) 83 | 84 | def override_shape(shape, batch_size): 85 | return tuple([batch_size if dim==TRT_DYNAMIC_DIM else dim for dim in shape]) 86 | 87 | # Allocate buffers and create a CUDA stream. 88 | inputs, outputs, dbindings, stream = allocate_buffers(engine, batch_size) 89 | 90 | # Contexts are used to perform inference. 91 | with engine.create_execution_context() as context: 92 | 93 | # Resolve dynamic shapes in the context 94 | for binding in engine: 95 | binding_idx = engine.get_binding_index(binding) 96 | shape = engine.get_binding_shape(binding_idx) 97 | if engine.binding_is_input(binding_idx): 98 | if TRT_DYNAMIC_DIM in shape: 99 | shape = override_shape(shape, batch_size) 100 | context.set_binding_shape(binding_idx, shape) 101 | 102 | # Load the test images and preprocess them 103 | load_normalized_test_case(input_image, inputs[0].host, preprocess_func) 104 | 105 | # Transfer input data to the GPU. 106 | cuda.memcpy_htod(inputs[0].device, inputs[0].host) 107 | # Run inference. 
108 | context.execute(batch_size, dbindings) 109 | # Transfer predictions back to host from GPU 110 | out = outputs[0] 111 | cuda.memcpy_dtoh(out.host, out.device) 112 | 113 | softmax_output = np.array(out.host) 114 | top1_idx = np.argmax(softmax_output) 115 | output_class = labels[top1_idx+1] 116 | output_confidence = softmax_output[top1_idx] 117 | 118 | print("Output class of the image: {} Confidence: {}".format(output_class, output_confidence)) 119 | 120 | if __name__ == '__main__': 121 | parser = argparse.ArgumentParser(description='Run inference on TensorRT engines for ImageNet-based classification models.') 122 | parser.add_argument('-e', '--engine', type=str, required=True, 123 | help='Path to RN50 TensorRT engine') 124 | parser.add_argument('-i', '--image', required=True, type=str, 125 | help="Path to input image.") 126 | parser.add_argument("-l", "--labels", type=str, default=os.path.join("labels", "class_labels.txt"), 127 | help="Path to the file containing the ImageNet 1k labels.") 128 | parser.add_argument('-b', '--batch_size', default=1, type=int, 129 | help="Batch size of inputs") 130 | parser.add_argument('-v', '--verbose', action='store_true', 131 | help="Flag to enable verbose logging") 132 | args = parser.parse_args() 133 | 134 | # Class 0 is not used and is treated as background class. 
Renaming it to "background" 135 | with open(args.labels, "r") as f: 136 | background_class = ["background"] 137 | imagenet_synsets = f.read().splitlines() 138 | imagenet_classes=[] 139 | for synset in imagenet_synsets: 140 | class_name = synset.strip() 141 | imagenet_classes.append(class_name) 142 | all_classes = background_class + imagenet_classes 143 | labels = np.array(all_classes) 144 | 145 | # Preprocessing for input images 146 | preprocess_func = image_processing.preprocess_resnet50 147 | 148 | # Run inference on the test image 149 | infer(args.engine, preprocess_func, args.batch_size, args.image, labels, args.verbose) 150 | 151 | -------------------------------------------------------------------------------- /labels/class_labels.txt: -------------------------------------------------------------------------------- 1 | tench 2 | goldfish 3 | great white shark 4 | tiger shark 5 | hammerhead 6 | electric ray 7 | stingray 8 | cock 9 | hen 10 | ostrich 11 | brambling 12 | goldfinch 13 | house finch 14 | junco 15 | indigo bunting 16 | robin 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | water ouzel 22 | kite 23 | bald eagle 24 | vulture 25 | great grey owl 26 | European fire salamander 27 | common newt 28 | eft 29 | spotted salamander 30 | axolotl 31 | bullfrog 32 | tree frog 33 | tailed frog 34 | loggerhead 35 | leatherback turtle 36 | mud turtle 37 | terrapin 38 | box turtle 39 | banded gecko 40 | common iguana 41 | American chameleon 42 | whiptail 43 | agama 44 | frilled lizard 45 | alligator lizard 46 | Gila monster 47 | green lizard 48 | African chameleon 49 | Komodo dragon 50 | African crocodile 51 | American alligator 52 | triceratops 53 | thunder snake 54 | ringneck snake 55 | hognose snake 56 | green snake 57 | king snake 58 | garter snake 59 | water snake 60 | vine snake 61 | night snake 62 | boa constrictor 63 | rock python 64 | Indian cobra 65 | green mamba 66 | sea snake 67 | horned viper 68 | diamondback 69 | sidewinder 70 | trilobite 71 | 
harvestman 72 | scorpion 73 | black and gold garden spider 74 | barn spider 75 | garden spider 76 | black widow 77 | tarantula 78 | wolf spider 79 | tick 80 | centipede 81 | black grouse 82 | ptarmigan 83 | ruffed grouse 84 | prairie chicken 85 | peacock 86 | quail 87 | partridge 88 | African grey 89 | macaw 90 | sulphur-crested cockatoo 91 | lorikeet 92 | coucal 93 | bee eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | drake 99 | red-breasted merganser 100 | goose 101 | black swan 102 | tusker 103 | echidna 104 | platypus 105 | wallaby 106 | koala 107 | wombat 108 | jellyfish 109 | sea anemone 110 | brain coral 111 | flatworm 112 | nematode 113 | conch 114 | snail 115 | slug 116 | sea slug 117 | chiton 118 | chambered nautilus 119 | Dungeness crab 120 | rock crab 121 | fiddler crab 122 | king crab 123 | American lobster 124 | spiny lobster 125 | crayfish 126 | hermit crab 127 | isopod 128 | white stork 129 | black stork 130 | spoonbill 131 | flamingo 132 | little blue heron 133 | American egret 134 | bittern 135 | crane 136 | limpkin 137 | European gallinule 138 | American coot 139 | bustard 140 | ruddy turnstone 141 | red-backed sandpiper 142 | redshank 143 | dowitcher 144 | oystercatcher 145 | pelican 146 | king penguin 147 | albatross 148 | grey whale 149 | killer whale 150 | dugong 151 | sea lion 152 | Chihuahua 153 | Japanese spaniel 154 | Maltese dog 155 | Pekinese 156 | Shih-Tzu 157 | Blenheim spaniel 158 | papillon 159 | toy terrier 160 | Rhodesian ridgeback 161 | Afghan hound 162 | basset 163 | beagle 164 | bloodhound 165 | bluetick 166 | black-and-tan coonhound 167 | Walker hound 168 | English foxhound 169 | redbone 170 | borzoi 171 | Irish wolfhound 172 | Italian greyhound 173 | whippet 174 | Ibizan hound 175 | Norwegian elkhound 176 | otterhound 177 | Saluki 178 | Scottish deerhound 179 | Weimaraner 180 | Staffordshire bullterrier 181 | American Staffordshire terrier 182 | Bedlington terrier 183 | Border terrier 184 | Kerry blue 
terrier 185 | Irish terrier 186 | Norfolk terrier 187 | Norwich terrier 188 | Yorkshire terrier 189 | wire-haired fox terrier 190 | Lakeland terrier 191 | Sealyham terrier 192 | Airedale 193 | cairn 194 | Australian terrier 195 | Dandie Dinmont 196 | Boston bull 197 | miniature schnauzer 198 | giant schnauzer 199 | standard schnauzer 200 | Scotch terrier 201 | Tibetan terrier 202 | silky terrier 203 | soft-coated wheaten terrier 204 | West Highland white terrier 205 | Lhasa 206 | flat-coated retriever 207 | curly-coated retriever 208 | golden retriever 209 | Labrador retriever 210 | Chesapeake Bay retriever 211 | German short-haired pointer 212 | vizsla 213 | English setter 214 | Irish setter 215 | Gordon setter 216 | Brittany spaniel 217 | clumber 218 | English springer 219 | Welsh springer spaniel 220 | cocker spaniel 221 | Sussex spaniel 222 | Irish water spaniel 223 | kuvasz 224 | schipperke 225 | groenendael 226 | malinois 227 | briard 228 | kelpie 229 | komondor 230 | Old English sheepdog 231 | Shetland sheepdog 232 | collie 233 | Border collie 234 | Bouvier des Flandres 235 | Rottweiler 236 | German shepherd 237 | Doberman 238 | miniature pinscher 239 | Greater Swiss Mountain dog 240 | Bernese mountain dog 241 | Appenzeller 242 | EntleBucher 243 | boxer 244 | bull mastiff 245 | Tibetan mastiff 246 | French bulldog 247 | Great Dane 248 | Saint Bernard 249 | Eskimo dog 250 | malamute 251 | Siberian husky 252 | dalmatian 253 | affenpinscher 254 | basenji 255 | pug 256 | Leonberg 257 | Newfoundland 258 | Great Pyrenees 259 | Samoyed 260 | Pomeranian 261 | chow 262 | keeshond 263 | Brabancon griffon 264 | Pembroke 265 | Cardigan 266 | toy poodle 267 | miniature poodle 268 | standard poodle 269 | Mexican hairless 270 | timber wolf 271 | white wolf 272 | red wolf 273 | coyote 274 | dingo 275 | dhole 276 | African hunting dog 277 | hyena 278 | red fox 279 | kit fox 280 | Arctic fox 281 | grey fox 282 | tabby 283 | tiger cat 284 | Persian cat 285 | Siamese cat 286 | 
Egyptian cat 287 | cougar 288 | lynx 289 | leopard 290 | snow leopard 291 | jaguar 292 | lion 293 | tiger 294 | cheetah 295 | brown bear 296 | American black bear 297 | ice bear 298 | sloth bear 299 | mongoose 300 | meerkat 301 | tiger beetle 302 | ladybug 303 | ground beetle 304 | long-horned beetle 305 | leaf beetle 306 | dung beetle 307 | rhinoceros beetle 308 | weevil 309 | fly 310 | bee 311 | ant 312 | grasshopper 313 | cricket 314 | walking stick 315 | cockroach 316 | mantis 317 | cicada 318 | leafhopper 319 | lacewing 320 | dragonfly 321 | damselfly 322 | admiral 323 | ringlet 324 | monarch 325 | cabbage butterfly 326 | sulphur butterfly 327 | lycaenid 328 | starfish 329 | sea urchin 330 | sea cucumber 331 | wood rabbit 332 | hare 333 | Angora 334 | hamster 335 | porcupine 336 | fox squirrel 337 | marmot 338 | beaver 339 | guinea pig 340 | sorrel 341 | zebra 342 | hog 343 | wild boar 344 | warthog 345 | hippopotamus 346 | ox 347 | water buffalo 348 | bison 349 | ram 350 | bighorn 351 | ibex 352 | hartebeest 353 | impala 354 | gazelle 355 | Arabian camel 356 | llama 357 | weasel 358 | mink 359 | polecat 360 | black-footed ferret 361 | otter 362 | skunk 363 | badger 364 | armadillo 365 | three-toed sloth 366 | orangutan 367 | gorilla 368 | chimpanzee 369 | gibbon 370 | siamang 371 | guenon 372 | patas 373 | baboon 374 | macaque 375 | langur 376 | colobus 377 | proboscis monkey 378 | marmoset 379 | capuchin 380 | howler monkey 381 | titi 382 | spider monkey 383 | squirrel monkey 384 | Madagascar cat 385 | indri 386 | Indian elephant 387 | African elephant 388 | lesser panda 389 | giant panda 390 | barracouta 391 | eel 392 | coho 393 | rock beauty 394 | anemone fish 395 | sturgeon 396 | gar 397 | lionfish 398 | puffer 399 | abacus 400 | abaya 401 | academic gown 402 | accordion 403 | acoustic guitar 404 | aircraft carrier 405 | airliner 406 | airship 407 | altar 408 | ambulance 409 | amphibian 410 | analog clock 411 | apiary 412 | apron 413 | ashcan 414 | 
assault rifle 415 | backpack 416 | bakery 417 | balance beam 418 | balloon 419 | ballpoint 420 | Band Aid 421 | banjo 422 | bannister 423 | barbell 424 | barber chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel 429 | barrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | bathing cap 435 | bath towel 436 | bathtub 437 | beach wagon 438 | beacon 439 | beaker 440 | bearskin 441 | beer bottle 442 | beer glass 443 | bell cote 444 | bib 445 | bicycle-built-for-two 446 | bikini 447 | binder 448 | binoculars 449 | birdhouse 450 | boathouse 451 | bobsled 452 | bolo tie 453 | bonnet 454 | bookcase 455 | bookshop 456 | bottlecap 457 | bow 458 | bow tie 459 | brass 460 | brassiere 461 | breakwater 462 | breastplate 463 | broom 464 | bucket 465 | buckle 466 | bulletproof vest 467 | bullet train 468 | butcher shop 469 | cab 470 | caldron 471 | candle 472 | cannon 473 | canoe 474 | can opener 475 | cardigan 476 | car mirror 477 | carousel 478 | carpenter's kit 479 | carton 480 | car wheel 481 | cash machine 482 | cassette 483 | cassette player 484 | castle 485 | catamaran 486 | CD player 487 | cello 488 | cellular telephone 489 | chain 490 | chainlink fence 491 | chain mail 492 | chain saw 493 | chest 494 | chiffonier 495 | chime 496 | china cabinet 497 | Christmas stocking 498 | church 499 | cinema 500 | cleaver 501 | cliff dwelling 502 | cloak 503 | clog 504 | cocktail shaker 505 | coffee mug 506 | coffeepot 507 | coil 508 | combination lock 509 | computer keyboard 510 | confectionery 511 | container ship 512 | convertible 513 | corkscrew 514 | cornet 515 | cowboy boot 516 | cowboy hat 517 | cradle 518 | crane 519 | crash helmet 520 | crate 521 | crib 522 | Crock Pot 523 | croquet ball 524 | crutch 525 | cuirass 526 | dam 527 | desk 528 | desktop computer 529 | dial telephone 530 | diaper 531 | digital clock 532 | digital watch 533 | dining table 534 | dishrag 535 | dishwasher 536 | disk brake 537 | dock 538 | dogsled 539 | dome 540 | doormat 541 
| drilling platform 542 | drum 543 | drumstick 544 | dumbbell 545 | Dutch oven 546 | electric fan 547 | electric guitar 548 | electric locomotive 549 | entertainment center 550 | envelope 551 | espresso maker 552 | face powder 553 | feather boa 554 | file 555 | fireboat 556 | fire engine 557 | fire screen 558 | flagpole 559 | flute 560 | folding chair 561 | football helmet 562 | forklift 563 | fountain 564 | fountain pen 565 | four-poster 566 | freight car 567 | French horn 568 | frying pan 569 | fur coat 570 | garbage truck 571 | gasmask 572 | gas pump 573 | goblet 574 | go-kart 575 | golf ball 576 | golfcart 577 | gondola 578 | gong 579 | gown 580 | grand piano 581 | greenhouse 582 | grille 583 | grocery store 584 | guillotine 585 | hair slide 586 | hair spray 587 | half track 588 | hammer 589 | hamper 590 | hand blower 591 | hand-held computer 592 | handkerchief 593 | hard disc 594 | harmonica 595 | harp 596 | harvester 597 | hatchet 598 | holster 599 | home theater 600 | honeycomb 601 | hook 602 | hoopskirt 603 | horizontal bar 604 | horse cart 605 | hourglass 606 | iPod 607 | iron 608 | jack-o'-lantern 609 | jean 610 | jeep 611 | jersey 612 | jigsaw puzzle 613 | jinrikisha 614 | joystick 615 | kimono 616 | knee pad 617 | knot 618 | lab coat 619 | ladle 620 | lampshade 621 | laptop 622 | lawn mower 623 | lens cap 624 | letter opener 625 | library 626 | lifeboat 627 | lighter 628 | limousine 629 | liner 630 | lipstick 631 | Loafer 632 | lotion 633 | loudspeaker 634 | loupe 635 | lumbermill 636 | magnetic compass 637 | mailbag 638 | mailbox 639 | maillot 640 | maillot 641 | manhole cover 642 | maraca 643 | marimba 644 | mask 645 | matchstick 646 | maypole 647 | maze 648 | measuring cup 649 | medicine chest 650 | megalith 651 | microphone 652 | microwave 653 | military uniform 654 | milk can 655 | minibus 656 | miniskirt 657 | minivan 658 | missile 659 | mitten 660 | mixing bowl 661 | mobile home 662 | Model T 663 | modem 664 | monastery 665 | monitor 666 | moped 
667 | mortar 668 | mortarboard 669 | mosque 670 | mosquito net 671 | motor scooter 672 | mountain bike 673 | mountain tent 674 | mouse 675 | mousetrap 676 | moving van 677 | muzzle 678 | nail 679 | neck brace 680 | necklace 681 | nipple 682 | notebook 683 | obelisk 684 | oboe 685 | ocarina 686 | odometer 687 | oil filter 688 | organ 689 | oscilloscope 690 | overskirt 691 | oxcart 692 | oxygen mask 693 | packet 694 | paddle 695 | paddlewheel 696 | padlock 697 | paintbrush 698 | pajama 699 | palace 700 | panpipe 701 | paper towel 702 | parachute 703 | parallel bars 704 | park bench 705 | parking meter 706 | passenger car 707 | patio 708 | pay-phone 709 | pedestal 710 | pencil box 711 | pencil sharpener 712 | perfume 713 | Petri dish 714 | photocopier 715 | pick 716 | pickelhaube 717 | picket fence 718 | pickup 719 | pier 720 | piggy bank 721 | pill bottle 722 | pillow 723 | ping-pong ball 724 | pinwheel 725 | pirate 726 | pitcher 727 | plane 728 | planetarium 729 | plastic bag 730 | plate rack 731 | plow 732 | plunger 733 | Polaroid camera 734 | pole 735 | police van 736 | poncho 737 | pool table 738 | pop bottle 739 | pot 740 | potter's wheel 741 | power drill 742 | prayer rug 743 | printer 744 | prison 745 | projectile 746 | projector 747 | puck 748 | punching bag 749 | purse 750 | quill 751 | quilt 752 | racer 753 | racket 754 | radiator 755 | radio 756 | radio telescope 757 | rain barrel 758 | recreational vehicle 759 | reel 760 | reflex camera 761 | refrigerator 762 | remote control 763 | restaurant 764 | revolver 765 | rifle 766 | rocking chair 767 | rotisserie 768 | rubber eraser 769 | rugby ball 770 | rule 771 | running shoe 772 | safe 773 | safety pin 774 | saltshaker 775 | sandal 776 | sarong 777 | sax 778 | scabbard 779 | scale 780 | school bus 781 | schooner 782 | scoreboard 783 | screen 784 | screw 785 | screwdriver 786 | seat belt 787 | sewing machine 788 | shield 789 | shoe shop 790 | shoji 791 | shopping basket 792 | shopping cart 793 | shovel 794 | 
shower cap 795 | shower curtain 796 | ski 797 | ski mask 798 | sleeping bag 799 | slide rule 800 | sliding door 801 | slot 802 | snorkel 803 | snowmobile 804 | snowplow 805 | soap dispenser 806 | soccer ball 807 | sock 808 | solar dish 809 | sombrero 810 | soup bowl 811 | space bar 812 | space heater 813 | space shuttle 814 | spatula 815 | speedboat 816 | spider web 817 | spindle 818 | sports car 819 | spotlight 820 | stage 821 | steam locomotive 822 | steel arch bridge 823 | steel drum 824 | stethoscope 825 | stole 826 | stone wall 827 | stopwatch 828 | stove 829 | strainer 830 | streetcar 831 | stretcher 832 | studio couch 833 | stupa 834 | submarine 835 | suit 836 | sundial 837 | sunglass 838 | sunglasses 839 | sunscreen 840 | suspension bridge 841 | swab 842 | sweatshirt 843 | swimming trunks 844 | swing 845 | switch 846 | syringe 847 | table lamp 848 | tank 849 | tape player 850 | teapot 851 | teddy 852 | television 853 | tennis ball 854 | thatch 855 | theater curtain 856 | thimble 857 | thresher 858 | throne 859 | tile roof 860 | toaster 861 | tobacco shop 862 | toilet seat 863 | torch 864 | totem pole 865 | tow truck 866 | toyshop 867 | tractor 868 | trailer truck 869 | tray 870 | trench coat 871 | tricycle 872 | trimaran 873 | tripod 874 | triumphal arch 875 | trolleybus 876 | trombone 877 | tub 878 | turnstile 879 | typewriter keyboard 880 | umbrella 881 | unicycle 882 | upright 883 | vacuum 884 | vase 885 | vault 886 | velvet 887 | vending machine 888 | vestment 889 | viaduct 890 | violin 891 | volleyball 892 | waffle iron 893 | wall clock 894 | wallet 895 | wardrobe 896 | warplane 897 | washbasin 898 | washer 899 | water bottle 900 | water jug 901 | water tower 902 | whiskey jug 903 | whistle 904 | wig 905 | window screen 906 | window shade 907 | Windsor tie 908 | wine bottle 909 | wing 910 | wok 911 | wooden spoon 912 | wool 913 | worm fence 914 | wreck 915 | yawl 916 | yurt 917 | web site 918 | comic book 919 | crossword puzzle 920 | street sign 921 | 
traffic light 922 | book jacket 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot pot 928 | trifle 929 | ice cream 930 | ice lolly 931 | French loaf 932 | bagel 933 | pretzel 934 | cheeseburger 935 | hotdog 936 | mashed potato 937 | head cabbage 938 | broccoli 939 | cauliflower 940 | zucchini 941 | spaghetti squash 942 | acorn squash 943 | butternut squash 944 | cucumber 945 | artichoke 946 | bell pepper 947 | cardoon 948 | mushroom 949 | Granny Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple 955 | banana 956 | jackfruit 957 | custard apple 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate sauce 962 | dough 963 | meat loaf 964 | pizza 965 | potpie 966 | burrito 967 | red wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff 974 | coral reef 975 | geyser 976 | lakeside 977 | promontory 978 | sandbar 979 | seashore 980 | valley 981 | volcano 982 | ballplayer 983 | groom 984 | scuba diver 985 | rapeseed 986 | daisy 987 | yellow lady's slipper 988 | corn 989 | acorn 990 | hip 991 | buckeye 992 | coral fungus 993 | agaric 994 | gyromitra 995 | stinkhorn 996 | earthstar 997 | hen-of-the-woods 998 | bolete 999 | ear 1000 | toilet tissue 1001 | -------------------------------------------------------------------------------- /postprocess_onnx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2020 NVIDIA Corporation 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import onnx_graphsurgeon as gs 17 | import argparse 18 | import onnx 19 | import numpy as np 20 | 21 | def process_transpose_nodes(graph): 22 | """ 23 | This is a workaround to manually transpose the conv weights and remove 24 | the existing transpose nodes. Currently TRT has a limitation when there is 25 | a transpose node as an input to the weights of the conv layer. This utility 26 | would be removed in future releases. 27 | """ 28 | # Find all the transposes before the convolutional nodes 29 | conv_nodes = [node for node in graph.nodes if node.op == "Conv"] 30 | for node in conv_nodes: 31 | # Transpose the convolutional weights and reset them to the weights 32 | conv_weights_tensor = node.i(1).i().i().inputs[0] 33 | conv_weights_transposed = np.transpose(conv_weights_tensor.values, [3, 2, 0, 1]) 34 | conv_weights_tensor.values = conv_weights_transposed 35 | 36 | # Remove the transpose nodes after the dequant node. TensorRT does not support transpose nodes after QDQ nodes. 37 | dequant_node_output = node.i(1).i(0).outputs[0] 38 | node.inputs[1] = dequant_node_output 39 | 40 | # Remove unused nodes, and topologically sort the graph. 
41 | return graph.cleanup().toposort() 42 | 43 | if __name__=='__main__': 44 | parser = argparse.ArgumentParser(description="Post process ONNX graph by removing transpose nodes") 45 | parser.add_argument("--input", required=True, help="Input onnx graph") 46 | parser.add_argument("--output", default='postprocessed_rn50.onnx', help="Name of post processed onnx graph") 47 | args = parser.parse_args() 48 | 49 | # Load the rn50 graph 50 | graph = gs.import_onnx(onnx.load(args.input)) 51 | 52 | # Remove the transpose nodes and reshape the convolution weights 53 | graph = process_transpose_nodes(graph) 54 | 55 | # Export the onnx graph from graphsurgeon 56 | onnx_model = gs.export_onnx(graph) 57 | onnx.save_model(onnx_model, args.output) 58 | print("Output ONNX graph generated: ", args.output) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | onnx==1.6 2 | tf2onnx==1.6.1 3 | numpy 4 | pycuda 5 | --------------------------------------------------------------------------------