├── LICENSE
├── README.md
├── build_engine.py
├── fold_constants.py
├── image_processing.py
├── infer.py
├── labels
│   └── class_labels.txt
├── postprocess_onnx.py
└── requirements.txt
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | Copyright 2020 NVIDIA Corporation
180 |
181 | Licensed under the Apache License, Version 2.0 (the "License");
182 | you may not use this file except in compliance with the License.
183 | You may obtain a copy of the License at
184 |
185 | http://www.apache.org/licenses/LICENSE-2.0
186 |
187 | Unless required by applicable law or agreed to in writing, software
188 | distributed under the License is distributed on an "AS IS" BASIS,
189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
190 | See the License for the specific language governing permissions and
191 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TensorRT inference of Resnet-50 trained with QAT.
2 |
3 | **Table Of Contents**
4 | - [Description](#description)
5 | - [How does this sample work?](#how-does-this-sample-work)
6 | - [Prerequisites](#prerequisites)
7 | - [Running the sample](#running-the-sample)
8 | * [Step 1: Quantization Aware Training](#step-1-quantization-aware-training)
9 | * [Step 2: Export frozen graph of RN50 QAT](#step-2-export-frozen-graph-of-rn50-qat)
10 | * [Step 3: Constant folding](#step-3-constant-folding)
11 | * [Step 4: TF2ONNX conversion](#step-4-tf2onnx-conversion)
12 | * [Step 5: Post processing ONNX](#step-5-post-processing-onnx)
13 | * [Step 6: Build TensorRT engine from ONNX graph](#step-6-build-tensorrt-engine-from-onnx-graph)
14 | * [Step 7: TensorRT Inference](#step-7-tensorrt-inference)
15 | - [Additional resources](#additional-resources)
16 | - [Changelog](#changelog)
17 | - [Known issues](#known-issues)
18 | - [License](#license)
19 |
20 | ## Description
21 |
22 | This sample demonstrates the workflow for training a Resnet-50 model with Quantization Aware Training (QAT) and running inference on it with TensorRT.
23 | The inference implementation is an experimental prototype and is provided with no guarantee of support.
24 |
25 | ## How does this sample work?
26 |
27 | This sample demonstrates:
28 |
29 | * Training a Resnet-50 model using quantization aware training.
30 | * Post processing and conversion to ONNX graph to ensure it is successfully parsed by TensorRT.
31 | * Inference of Resnet-50 QAT graph with TensorRT.
32 |
33 | ## Prerequisites
34 |
35 | Dependencies required for this sample
36 |
37 | 1. TensorFlow NGC containers (20.01-tf1-py3 NGC container or above) for Steps 1-4. Please use the `tf1` variants, which have TF 1.15.2 installed.
38 | This sample does not work with the public version of the TensorFlow 1.15.2 library.
39 |
40 | 2. Install the dependencies for Python3 inside the NGC container.
41 | - For Python 3 users, from the root directory, run:
42 | `python3 -m pip install -r requirements.txt`
43 |
44 | 3. TensorRT-7.1
45 |
46 | 4. ONNX-Graphsurgeon 0.2.1
47 |
48 | ## Running the sample
49 |
50 | ***NOTE: Steps 1-4 require NGC containers (TensorFlow 20.01-tf1-py3 NGC container or above). Steps 5-7 can be executed within or outside the NGC container***
51 |
52 | ### Step 1: Quantization Aware Training
53 |
54 | Please follow detailed instructions on how to finetune a RN50 model using QAT.
55 |
56 | This stage involves:
57 |
58 | * Finetuning an RN50 model with quantization nodes and saving the final checkpoint.
59 | * Post-processing the above RN50 QAT checkpoint by reshaping the weights of the final FC layer into a 1x1 conv layer.
60 |
61 | ### Step 2: Export frozen graph of RN50 QAT
62 |
63 | Export the RN50 QAT graph replacing the final FC layer with a 1x1 conv layer.
64 | Please follow these instructions to generate a frozen graph in desired data formats.
65 |
66 | ### Step 3: Constant folding
67 |
68 | Once we have the frozen graph from Step 2, run the following command to perform constant folding on the TF graph:
69 | ```
70 | python fold_constants.py --input --output
71 | ```
72 |
73 | Arguments:
74 | * `--input` : Input Tensorflow graph
75 | * `--output_node` : Output node name of the RN50 graph (Default: `resnet50/output/softmax_1`)
76 | * `--output` : Output name of constant folded TF graph.
77 |
78 | ### Step 4: TF2ONNX conversion
79 |
80 | The TF2ONNX converter converts the constant-folded TensorFlow frozen graph into an ONNX graph. For RN50 QAT, the `tf.quantization.quantize_and_dequantize` operation (QDQ) is converted into `QuantizeLinear` and `DequantizeLinear` operations.
81 | Support for converting QDQ operations was added in version `1.6.1` of TF2ONNX.
82 |
83 | Command to convert RN50 QAT TF graph to ONNX
84 | ```
85 | python3 -m tf2onnx.convert --input --output --inputs input:0 --outputs resnet50/output/softmax_1:0 --opset 11
86 | ```
87 |
88 | Arguments:
89 | * `--input` : Name of TF input graph
90 | * `--output` : Name of ONNX output graph
91 | * `--inputs` : Name of input tensors
92 | * `--outputs` : Name of output tensors
93 | * `--opset` : ONNX opset version
94 |
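The `QuantizeLinear` / `DequantizeLinear` pair that Step 4 refers to can be illustrated numerically. The following is a minimal pure-Python sketch of the per-element arithmetic for signed 8-bit values, not the converter's or TensorRT's implementation; the function names and the example scale are hypothetical and chosen only for illustration.

```python
def quantize_linear(x, scale, zero_point=0, qmin=-128, qmax=127):
    """Quantize one float: round(x / scale) + zero_point, saturated to [qmin, qmax]."""
    # Python's round() rounds halves to the nearest even integer.
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize_linear(q, scale, zero_point=0):
    """Map a quantized integer back to a float: (q - zero_point) * scale."""
    return (q - zero_point) * scale


if __name__ == "__main__":
    # Hypothetical scale: assumes the tensor's dynamic range is roughly [-2, 2].
    scale = 2.0 / 127
    for x in (0.5, -1.3, 3.0):
        q = quantize_linear(x, scale)
        print("x={:}, quantized={:}, dequantized={:}".format(x, q, dequantize_linear(q, scale)))
```

Values inside the representable range come back with a small rounding error, while values outside it (such as `3.0` with the scale above) saturate at the integer limits.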
95 | ### Step 5: Post processing ONNX
96 |
97 | Run the following command to postprocess the ONNX graph using ONNX-Graphsurgeon API. This step removes the `transpose` nodes after `Dequantize` nodes.
98 | ```
99 | python postprocess_onnx.py --input --output
100 | ```
101 |
102 | Arguments:
103 | * `--input` : Input ONNX graph
104 | * `--output` : Output name of postprocessed ONNX graph.
105 |
106 | ### Step 6: Build TensorRT engine from ONNX graph
107 | ```
108 | python build_engine.py --onnx
109 | ```
110 |
111 | Arguments:
112 | * `--onnx` : Path to RN50 QAT onnx graph
113 | * `--engine` : Output file name of TensorRT engine.
114 | * `--verbose` : Flag to enable verbose logging
115 |
116 | ### Step 7: TensorRT Inference
117 |
118 | Command to run inference on a sample image
119 |
120 | ```
121 | python infer.py --engine
122 | ```
123 |
124 | Arguments:
125 | * `--engine` : Path to input RN50 TensorRT engine.
126 | * `--labels` : Path to imagenet 1k labels text file provided.
127 | * `--image` : Path to the sample image
128 | * `--verbose` : Flag to enable verbose logging
129 |
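Before inference, `image_processing.py` resizes the sample image so that its shorter side is 256 pixels and then takes a central 224x224 crop. The sketch below reproduces only that size arithmetic in pure Python, under the assumption that the constants match `_RESIZE_MIN` and the default `height`/`width` in that file; the helper names `resize_dims` and `central_crop_box` are hypothetical.

```python
def resize_dims(height, width, resize_min=256):
    """Scale (height, width) so the shorter side becomes resize_min, keeping aspect ratio."""
    scale = resize_min / min(height, width)
    # int() truncates toward zero, as in the sample's helper.
    return int(height * scale), int(width * scale)


def central_crop_box(height, width, crop_h=224, crop_w=224):
    """Return (top, left, bottom, right) of a centered crop_h x crop_w window."""
    top = (height - crop_h) // 2
    left = (width - crop_w) // 2
    return top, left, top + crop_h, left + crop_w


if __name__ == "__main__":
    new_h, new_w = resize_dims(512, 1024)
    print("resized to", (new_h, new_w))
    print("crop box", central_crop_box(new_h, new_w))
```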
130 | ### Sample --help options
131 |
132 | To see the full list of available options and their descriptions, use the `-h` or `--help` command line option. For example:
133 | ```
134 | usage: .py> [-h]
135 | ```
136 |
137 | ## Additional resources
138 |
139 | The following resources provide a deeper understanding about Quantization aware training, TF2ONNX and importing a model into TensorRT using Python:
140 |
141 | **Quantization Aware Training**
142 | - [Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference](https://arxiv.org/pdf/1712.05877.pdf)
143 | - [Quantization Aware Training guide](https://www.tensorflow.org/model_optimization/guide/quantization/training)
144 | - [Resnet-50 Deep Learning Example](https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow/Classification/ConvNets/resnet50v1.5/README.md)
145 | - [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf)
146 |
147 | **Parsers**
148 | - [TF2ONNX Converter](https://github.com/onnx/tensorflow-onnx)
149 | - [ONNX Parser](https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/python_api/parsers/Onnx/pyOnnx.html)
150 |
151 | **Documentation**
152 | - [Introduction To NVIDIA’s TensorRT Samples](https://docs.nvidia.com/deeplearning/sdk/tensorrt-sample-support-guide/index.html#samples)
153 | - [Working With TensorRT Using The Python API](https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#python_topics)
154 | - [Importing A Model Using A Parser In Python](https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#import_model_python)
155 | - [NVIDIA’s TensorRT Documentation Library](https://docs.nvidia.com/deeplearning/sdk/tensorrt-archived/index.html)
156 |
157 | ## Changelog
158 |
159 | June 2020: Initial release of this sample
160 |
161 | ## Known issues
162 |
163 | Tensorflow operation `tf.quantization.quantize_and_dequantize` is used for quantization during training. The gradient of this operation is not clipped based on input range.
164 |
165 | ## License
166 |
167 | The sampleQAT license can be found in the LICENSE file.
168 |
169 |
--------------------------------------------------------------------------------
/build_engine.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2020 NVIDIA Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import tensorrt as trt
17 | import pycuda.driver as cuda
18 | import pycuda.autoinit
19 | import numpy as np
20 | import argparse
21 |
22 | def build_profile(builder, network, profile_shapes, default_shape_value=1):
23 |     """
24 |     Build optimization profile for the builder and configure the min, opt, max shapes appropriately.
25 |     """
26 |     def is_dimension_dynamic(dim):
27 |         return dim is None or dim <= 0
28 |
29 |     def override_shape(shape):
30 |         return tuple([1 if is_dimension_dynamic(dim) else dim for dim in shape])
31 |
32 |     profile = builder.create_optimization_profile()
33 |     for idx in range(network.num_inputs):
34 |         inp = network.get_input(idx)
35 |
36 |         def get_profile_shape(name):
37 |             if name not in profile_shapes:
38 |                 return None
39 |             shapes = profile_shapes[name]
40 |             if not isinstance(shapes, list) or len(shapes) != 3:
41 |                 raise ValueError("Profile values must be a list containing exactly 3 shapes (tuples or Dims), but received shapes: {:} for input: {:}.\nNote: profile was: {:}.\nNote: Network inputs were: {:}".format(shapes, name, profile_shapes, [network.get_input(i).name for i in range(network.num_inputs)]))
42 |             return shapes
43 |
44 |         if inp.is_shape_tensor:
45 |             shapes = get_profile_shape(inp.name)
46 |             if not shapes:
47 |                 rank = inp.shape[0]
48 |                 shapes = [(default_shape_value, ) * rank] * 3
49 |                 print("Setting shape input to {:}. If this is incorrect, for shape input: {:}, please provide tuples for min, opt, and max shapes containing {:} elements".format(shapes[0], inp.name, rank))
50 |             min, opt, max = shapes
51 |             profile.set_shape_input(inp.name, min, opt, max)
52 |             print("Setting shape input: {:} values to min: {:}, opt: {:}, max: {:}".format(inp.name, min, opt, max))
53 |         elif -1 in inp.shape:
54 |             shapes = get_profile_shape(inp.name)
55 |             if not shapes:
56 |                 shapes = [override_shape(inp.shape)] * 3
57 |                 print("Overriding dynamic input shape {:} to {:}. If this is incorrect, for input tensor: {:}, please provide tuples for min, opt, and max shapes containing values: {:} with dynamic dimensions replaced.".format(inp.shape, shapes[0], inp.name, inp.shape))
58 |             min, opt, max = shapes
59 |             profile.set_shape(inp.name, min, opt, max)
60 |             print("Setting input: {:} shape to min: {:}, opt: {:}, max: {:}".format(inp.name, min, opt, max))
61 |     if not profile:
62 |         print("Profile is not valid, please provide profile data. Note: profile was: {:}".format(profile_shapes))
63 |     return profile
64 |
65 | def preprocess_network(network):
66 |     """
67 |     Add quantize and dequantize nodes after the input placeholder.
68 |     The scale values are currently picked on an empirical basis. Ideally,
69 |     you need to add these nodes during quantization aware training and
70 |     learn the dynamic ranges of the input node.
71 |     """
72 |     quant_scale = np.array([1.0/127.0], dtype=np.float32)
73 |     dequant_scale = np.array([127.0/1.0], dtype=np.float32)
74 |     # Zero point is always zero for quantization in TensorRT.
75 |     zeros = np.zeros(shape=(1, ), dtype=np.float32)
76 |
77 |     for i in range(network.num_inputs):
78 |         inp = network.get_input(i)
79 |         # Find layer consuming input tensor
80 |         found = False
81 |         for layer in network:
82 |             if found:
83 |                 break
84 |
85 |             for k in range(layer.num_inputs):
86 |                 if inp == layer.get_input(k):
87 |                     mode = trt.ScaleMode.UNIFORM
88 |                     quantize = network.add_scale(inp, mode, scale=quant_scale, shift=zeros)
89 |                     quantize.set_output_type(0, trt.int8)
90 |                     quantize.name = "InputQuantizeNode"
91 |                     quantize.get_output(0).name = "QuantizedInput"
92 |                     dequantize = network.add_scale(quantize.get_output(0), mode, scale=dequant_scale, shift=zeros)
93 |                     dequantize.set_output_type(0, trt.float32)
94 |                     dequantize.name = "InputDequantizeNode"
95 |                     dequantize.get_output(0).name = "DequantizedInput"
96 |                     layer.set_input(k, dequantize.get_output(0))
97 |                     found = True
98 |                     break
99 |
100 | def build_engine_onnx(model_file, verbose=False):
101 |     """
102 |     Parse the ONNX model file through TensorRT and build the TRT engine.
103 |     """
104 |     # Create builder and network
105 |     if verbose:
106 |         TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
107 |     else:
108 |         TRT_LOGGER = trt.Logger(trt.Logger.INFO)
109 |
110 |     network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
111 |     network_flags = network_flags | (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION))
112 |
113 |     with trt.Builder(TRT_LOGGER) as builder, builder.create_network(flags=network_flags) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
114 |         with open(model_file, 'rb') as model:
115 |             if not parser.parse(model.read()):
116 |                 print('ERROR: Failed to parse the ONNX file.')
117 |                 for error in range(parser.num_errors):
118 |                     print(parser.get_error(error))
119 |                 return None
120 |
121 |         # Add quantize and dequantize nodes for input of the network
122 |         preprocess_network(network)
123 |         config = builder.create_builder_config()
124 |         config.max_workspace_size = 1 << 30
125 |         config.flags = config.flags | 1 << int(trt.BuilderFlag.INT8)
126 |         # Setting the (min, opt, max) batch sizes to be 1. Users need to configure this according to their requirements.
127 |         config.add_optimization_profile(build_profile(builder, network, profile_shapes={'input' : [(1, 3, 224, 224),(1, 3, 224, 224),(1, 3, 224, 224)]}))
128 |
129 |         return builder.build_engine(network, config)
130 |
131 | def main(args):
132 |     # Parse the ONNX graph through TensorRT and build the engine
133 |     trt_engine = build_engine_onnx(args.onnx, args.verbose)
134 |     if trt_engine is None:
135 |         raise SystemExit("ERROR: Failed to build the TensorRT engine.")
136 |     # Serialize the engine and save to file
137 |     with open(args.engine, "wb") as file:
138 |         file.write(trt_engine.serialize())
139 |
140 |
141 | if __name__ == '__main__':
142 |     parser = argparse.ArgumentParser()
143 |     parser.add_argument("--onnx", type=str, default='rn50.onnx', help="Path to RN50 ONNX graph")
144 |     parser.add_argument("--engine", type=str, default='rn50_trt.engine', help="output path to TensorRT engine")
145 |     parser.add_argument('-v', '--verbose', action='store_true', help="Flag to enable verbose logging")
146 |     args = parser.parse_args()
147 |     main(args)
148 |
--------------------------------------------------------------------------------
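The fallback logic in `build_profile` above can be tried without TensorRT installed. The following is a minimal pure-Python sketch, assuming the same rule the script uses (a user-supplied list of exactly three shapes wins, otherwise every dynamic dimension is replaced by 1); the function name `resolve_profile_shapes` is hypothetical and not part of the sample.

```python
def resolve_profile_shapes(input_name, input_shape, profile_shapes):
    """Return (min, opt, max) shapes for one network input.

    profile_shapes maps input names to a list of three shapes. When the input
    is not listed, dynamic dimensions (None or <= 0) are overridden with 1 and
    the same shape is used for min, opt, and max.
    """
    if input_name in profile_shapes:
        shapes = profile_shapes[input_name]
        if not isinstance(shapes, list) or len(shapes) != 3:
            raise ValueError(
                "Expected a list of exactly 3 shapes for input {:}, got {:}".format(input_name, shapes))
        return tuple(tuple(shape) for shape in shapes)

    fixed = tuple(1 if (dim is None or dim <= 0) else dim for dim in input_shape)
    return (fixed, fixed, fixed)


if __name__ == "__main__":
    # Dynamic batch dimension with no user-provided profile.
    print(resolve_profile_shapes("input", (-1, 3, 224, 224), {}))
    # Explicit min/opt/max batch sizes (illustrative values).
    user_profile = {"input": [(1, 3, 224, 224), (8, 3, 224, 224), (16, 3, 224, 224)]}
    print(resolve_profile_shapes("input", (-1, 3, 224, 224), user_profile))
```

In the script itself the three shapes are all `(1, 3, 224, 224)`, so the engine is built for batch size 1 only; widening the range is a matter of passing different min, opt, and max shapes.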
/fold_constants.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2020 NVIDIA Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import argparse
17 | from tensorflow.core.protobuf import config_pb2, rewriter_config_pb2, meta_graph_pb2
18 | from tensorflow.core.framework import graph_pb2
19 | from tensorflow.python.framework import importer, ops
20 | from tensorflow.python.grappler import tf_optimizer
21 | from tensorflow.python.training import saver
22 |
23 |
24 | def constfold(graphdef, output_name):
25 |     graph = ops.Graph()
26 |     with graph.as_default():
27 |         outputs = output_name.split(',')
28 |         output_collection = meta_graph_pb2.CollectionDef()
29 |         output_list = output_collection.node_list.value
30 |         for output in outputs:
31 |             output_list.append(output)
32 |         importer.import_graph_def(graphdef, name="")
33 |         metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
34 |         metagraph.collection_def["train_op"].CopyFrom(output_collection)
35 |
36 |     rewriter_config = rewriter_config_pb2.RewriterConfig()
37 |     rewriter_config.optimizers.extend(["constfold"])
38 |     rewriter_config.meta_optimizer_iterations = (rewriter_config_pb2.RewriterConfig.ONE)
39 |     session_config = config_pb2.ConfigProto()
40 |     session_config.graph_options.rewrite_options.CopyFrom(rewriter_config)
41 |
42 |     return tf_optimizer.OptimizeGraph(session_config, metagraph)
43 |
44 | if __name__ == '__main__':
45 |     parser = argparse.ArgumentParser("Folds constants in the provided frozen model")
46 |     parser.add_argument("-i", "--input", required=True, help="The input frozen model to be constant folded.")
47 |     parser.add_argument("--output_node", default="resnet50/output/softmax_1", help="Output node names separated by commas")
48 |     parser.add_argument("-o", "--output", default="folded_rn50.pb", help="Path to constant folded output graph")
49 |     args, _ = parser.parse_known_args()
50 |
51 |     with open(args.input, 'rb') as f:
52 |         graphdef = graph_pb2.GraphDef()
53 |         graphdef.ParseFromString(f.read())
54 |
55 |     folded_graph = constfold(graphdef, args.output_node)
56 |     print("Writing output to {:}".format(args.output))
57 |     with open(args.output, "wb") as f:
58 |         f.write(folded_graph.SerializeToString())
59 |
--------------------------------------------------------------------------------
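`fold_constants.py` above delegates the actual folding to TensorFlow's Grappler `constfold` optimizer. To show what constant folding means, here is a toy pure-Python sketch on a dictionary-based graph. It is not how Grappler is implemented, and the graph encoding, the `OPS` table, and the function name `fold_constants` are hypothetical choices made for this illustration.

```python
import operator

# Binary ops this toy folder understands.
OPS = {"Add": operator.add, "Sub": operator.sub, "Mul": operator.mul}


def fold_constants(graph):
    """Replace every op node whose inputs are all constants with a Const node.

    graph maps a node name to ("Const", value), ("Placeholder",) or
    (op_name, [input_names]). Folding repeats until nothing changes, so chains
    of constant expressions collapse. Unused Const nodes are not pruned.
    """
    folded = dict(graph)
    changed = True
    while changed:
        changed = False
        for name, node in list(folded.items()):
            if node[0] not in OPS:
                continue
            inputs = [folded[input_name] for input_name in node[1]]
            if all(inp[0] == "Const" for inp in inputs):
                value = OPS[node[0]](*(inp[1] for inp in inputs))
                folded[name] = ("Const", value)
                changed = True
    return folded


if __name__ == "__main__":
    graph = {
        "x": ("Placeholder",),
        "a": ("Const", 2.0),
        "b": ("Const", 3.0),
        "c": ("Mul", ["a", "b"]),  # depends only on constants, so it folds
        "y": ("Add", ["x", "c"]),  # depends on a placeholder, so it stays
    }
    for name, node in fold_constants(graph).items():
        print(name, node)
```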
/image_processing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 NVIDIA Corporation
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 |
17 | import numpy as np
18 | from PIL import Image
19 |
20 |
21 | logging.basicConfig(level=logging.DEBUG,
22 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
23 | datefmt="%Y-%m-%d %H:%M:%S")
24 | logger = logging.getLogger(__name__)
25 |
26 | _RESIZE_MIN = 256
27 | _R_MEAN = 123.68
28 | _G_MEAN = 116.78
29 | _B_MEAN = 103.94
30 | _CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN]
31 |
32 |
33 | def preprocess_imagenet(image, channels=3, height=224, width=224):
34 |     """Pre-processing for Imagenet-based Image Classification Models:
35 |     resnet50, vgg16, mobilenet, etc. (Doesn't seem to work for Inception)
36 |
37 |     Parameters
38 |     ----------
39 |     image: PIL.Image
40 |         The image resulting from PIL.Image.open(filename) to preprocess
41 |     channels: int
42 |         The number of channels the image has (Usually 1 or 3)
43 |     height: int
44 |         The desired height of the image (usually 224 for Imagenet data)
45 |     width: int
46 |         The desired width of the image (usually 224 for Imagenet data)
47 |
48 |     Returns
49 |     -------
50 |     img_data: numpy array
51 |         The preprocessed image data in the form of a numpy array
52 |
53 |     """
54 |     # Resize to the target size. Image.LANCZOS is the filter formerly exposed as Image.ANTIALIAS.
55 |     resized_image = image.resize((width, height), Image.LANCZOS)
56 |     img_data = np.asarray(resized_image).astype(np.float32)
57 |
58 |     if len(img_data.shape) == 2:
59 |         # For images without a channel dimension, we stack
60 |         img_data = np.stack([img_data] * 3)
61 |         logger.debug("Received grayscale image. Reshaped to {:}".format(img_data.shape))
62 |     else:
63 |         # Convert from HWC to CHW format
64 |         img_data = img_data.transpose([2, 0, 1])
65 |     mean_vec = np.array([0.485, 0.456, 0.406])
66 |     stddev_vec = np.array([0.229, 0.224, 0.225])
67 |     assert img_data.shape[0] == channels
68 |
69 |     for i in range(img_data.shape[0]):
70 |         # Scale each pixel to [0, 1] and normalize per channel.
71 |         img_data[i, :, :] = (img_data[i, :, :] / 255 - mean_vec[i]) / stddev_vec[i]
72 |
73 |     return img_data
74 |
75 | def _smallest_size_at_least(height, width, resize_min):
76 |
77 |     smaller_dim = np.minimum(float(height), float(width))
78 |     scale_ratio = resize_min / smaller_dim
79 |
80 |     # Convert back to ints to make heights and widths that TF ops will accept.
81 |     new_height = int(height * scale_ratio)
82 |     new_width = int(width * scale_ratio)
83 |
84 |     return new_height, new_width
85 |
86 | def _central_crop(image, crop_height, crop_width):
87 |     shape = image.shape
88 |     height, width = shape[0], shape[1]
89 |
90 |     amount_to_be_cropped_h = (height - crop_height)
91 |     crop_top = amount_to_be_cropped_h // 2
92 |     amount_to_be_cropped_w = (width - crop_width)
93 |     crop_left = amount_to_be_cropped_w // 2
94 |     cropped_image = image[crop_top:crop_height+crop_top, crop_left:crop_width+crop_left]
95 |     return cropped_image
96 |
97 | def normalize_inputs(inputs):
98 |
99 |     num_channels = inputs.shape[-1]
100 |
101 |     if len(_CHANNEL_MEANS) != num_channels:
102 |         raise ValueError('len(means) must match the number of channels')
103 |
104 |     # We have a 1-D tensor of means; convert to 3-D.
105 |     means_per_channel = np.reshape(_CHANNEL_MEANS, [1, 1, num_channels])
106 |     # means_per_channel = tf.cast(means_per_channel, dtype=inputs.dtype)
107 |
108 |     inputs = np.subtract(inputs, means_per_channel)/255.0
109 |
110 |     return inputs
111 |
112 |
113 | def preprocess_resnet50(image, channels=3, height=224, width=224):
114 |     """Pre-processing for ImageNet-based image classification models:
115 |     resnet50 (resnet_v1_1.5, designed by NVIDIA)
116 |     Parameters
117 |     ----------
118 |     image: PIL.Image
119 |         The image resulting from PIL.Image.open(filename) to preprocess
120 |     channels: int
121 |         The number of channels the image has (usually 1 or 3)
122 |     height: int
123 |         The desired height of the image (usually 224 for ImageNet data)
124 |     width: int
125 |         The desired width of the image (usually 224 for ImageNet data)
126 |
127 |     Returns
128 |     -------
129 |     img_data: numpy array
130 |         The preprocessed image data in the form of a numpy array
131 |
132 |     """
133 |     # Get the shape of the image.
134 |     w, h = image.size
135 |
136 |     new_height, new_width = _smallest_size_at_least(h, w, _RESIZE_MIN)
137 |
138 |     # PIL sizes are (width, height)
139 |     resized_image = image.resize((new_width, new_height), Image.BILINEAR)
140 |     # The numpy array is HWC
141 |     img_data = np.asarray(resized_image).astype(np.float32)
142 |     # Do a central crop
143 |     cropped_image = _central_crop(img_data, height, width)
144 |     assert cropped_image.shape[0] == height
145 |     assert cropped_image.shape[1] == width
146 |     if len(cropped_image.shape) == 2:
147 |         # Grayscale input has no channel axis: replicate it into three HWC channels
148 |         # so it goes through the same normalization and transpose as RGB input.
149 |         cropped_image = np.stack([cropped_image] * 3, axis=-1)
150 |         logger.debug("Received grayscale image. Reshaped to {:}".format(cropped_image.shape))
151 |
152 |     normalized_inputs = normalize_inputs(cropped_image)
153 |     cropped_image = np.transpose(normalized_inputs, [2, 0, 1])
154 |
155 |     return cropped_image
156 |
157 | def preprocess_inception(image, channels=3, height=224, width=224):
158 |     """Pre-processing for InceptionV1. Inception expects different pre-processing
159 |     than {resnet50, vgg16, mobilenet}. This may not be totally correct,
160 |     but it worked for some simple test images.
161 |
162 |     Parameters
163 |     ----------
164 |     image: PIL.Image
165 |         The image resulting from PIL.Image.open(filename) to preprocess
166 |     channels: int
167 |         The number of channels the image has (Usually 1 or 3)
168 |     height: int
169 |         The desired height of the image (usually 224 for Imagenet data)
170 |     width: int
171 |         The desired width of the image (usually 224 for Imagenet data)
172 |
173 |     Returns
174 |     -------
175 |     img_data: numpy array
176 |         The preprocessed image data in the form of a numpy array
177 |
178 |     """
179 |     # Get the image in CHW format
180 |     resized_image = image.resize((width, height), Image.BILINEAR)
181 |     img_data = np.asarray(resized_image).astype(np.float32)
182 |
183 |     if len(img_data.shape) == 2:
184 |         # For images without a channel dimension, we stack
185 |         img_data = np.stack([img_data] * 3)
186 |         logger.debug("Received grayscale image. Reshaped to {:}".format(img_data.shape))
187 |     else:
188 |         img_data = img_data.transpose([2, 0, 1])
189 |
190 |     return img_data
191 |
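The resize, central-crop, and mean-subtraction steps of `preprocess_resnet50` can be traced on a synthetic array without PIL. The sketch below is a standalone approximation, not the module itself: `RESIZE_MIN` and `CHANNEL_MEANS` hold assumed values (the module's `_RESIZE_MIN` and `_CHANNEL_MEANS` are defined outside this excerpt), the helper names are hypothetical, and the actual PIL resize is replaced by a constant array of the computed size.

```python
import numpy as np

# Assumed values; the module's _RESIZE_MIN and _CHANNEL_MEANS are defined elsewhere.
RESIZE_MIN = 256
CHANNEL_MEANS = [123.68, 116.78, 103.94]


def smallest_size_at_least(height, width, resize_min):
    """Scale (height, width) so the shorter side is about resize_min."""
    scale = resize_min / min(float(height), float(width))
    return int(height * scale), int(width * scale)


def central_crop(image, crop_height, crop_width):
    """Cut a crop_height x crop_width window from the centre of an HWC array."""
    top = (image.shape[0] - crop_height) // 2
    left = (image.shape[1] - crop_width) // 2
    return image[top:top + crop_height, left:left + crop_width]


def normalize(image_hwc, means):
    """Subtract per-channel means, then divide by 255 (as normalize_inputs does)."""
    return (image_hwc - np.reshape(means, [1, 1, -1])) / 255.0


# A 480x640 (height x width) photo keeps its aspect ratio while the shorter side is scaled.
new_h, new_w = smallest_size_at_least(480, 640, RESIZE_MIN)

# Stand-in for the resized image: a constant mid-grey HWC array.
resized = np.full((new_h, new_w, 3), 128.0, dtype=np.float32)
cropped = central_crop(resized, 224, 224)

# HWC -> CHW, the layout the engine input expects.
chw = np.transpose(normalize(cropped, CHANNEL_MEANS), [2, 0, 1])
```

Because the crop is taken after the aspect-preserving resize, the result is always `(3, 224, 224)` as long as both resized sides are at least 224.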
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2020 NVIDIA Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import argparse
18 | import PIL.Image
19 | import numpy as np
20 | import tensorrt as trt
21 | import pycuda.driver as cuda
22 | import pycuda.autoinit
23 | import image_processing
24 |
25 | TRT_DYNAMIC_DIM = -1
26 |
27 | def load_normalized_test_case(test_image, pagelocked_buffer, preprocess_func):
28 |     # Expected input dimensions
29 |     C, H, W = (3, 224, 224)
30 |     # Normalize the images, concatenate them and copy to pagelocked memory.
31 |     data = np.asarray([preprocess_func(PIL.Image.open(test_image).convert('RGB'), C, H, W)]).flatten()
32 |     np.copyto(pagelocked_buffer, data)
33 |
34 | class HostDeviceMem(object):
35 |     r""" Simple helper data class that's a little nicer to use than a 2-tuple.
36 |     """
37 |     def __init__(self, host_mem, device_mem):
38 |         self.host = host_mem
39 |         self.device = device_mem
40 |
41 |     def __str__(self):
42 |         return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
43 |
44 |     def __repr__(self):
45 |         return self.__str__()
46 |
47 |
48 | def allocate_buffers(engine: trt.ICudaEngine, batch_size: int):
49 |     print('Allocating buffers ...')
50 |
51 |     inputs = []
52 |     outputs = []
53 |     dbindings = []
54 |
55 |     stream = cuda.Stream()
56 |
57 |     for binding in engine:
58 |         size = batch_size * abs(trt.volume(engine.get_binding_shape(binding)))
59 |         dtype = trt.nptype(engine.get_binding_dtype(binding))
60 |         # Allocate host and device buffers
61 |         host_mem = cuda.pagelocked_empty(size, dtype)
62 |         device_mem = cuda.mem_alloc(host_mem.nbytes)
63 |         # Append the device buffer to device bindings.
64 |         dbindings.append(int(device_mem))
65 |
66 |         # Append to the appropriate list.
67 |         if engine.binding_is_input(binding):
68 |             inputs.append(HostDeviceMem(host_mem, device_mem))
69 |         else:
70 |             outputs.append(HostDeviceMem(host_mem, device_mem))
71 |
72 |     return inputs, outputs, dbindings, stream
73 |
74 | def infer(engine_path, preprocess_func, batch_size, input_image, labels=[], verbose=False):
75 |
76 |     if verbose:
77 |         logger = trt.Logger(trt.Logger.VERBOSE)
78 |     else:
79 |         logger = trt.Logger(trt.Logger.INFO)
80 |
81 |     with open(engine_path, 'rb') as f, trt.Runtime(logger) as runtime:
82 |         engine = runtime.deserialize_cuda_engine(f.read())
83 |
84 |         def override_shape(shape, batch_size):
85 |             return tuple([batch_size if dim == TRT_DYNAMIC_DIM else dim for dim in shape])
86 |
87 |         # Allocate buffers and create a CUDA stream.
88 |         inputs, outputs, dbindings, stream = allocate_buffers(engine, batch_size)
89 |
90 |         # Contexts are used to perform inference.
91 |         with engine.create_execution_context() as context:
92 |
93 |             # Resolve dynamic shapes in the context
94 |             for binding in engine:
95 |                 binding_idx = engine.get_binding_index(binding)
96 |                 shape = engine.get_binding_shape(binding_idx)
97 |                 if engine.binding_is_input(binding_idx):
98 |                     if TRT_DYNAMIC_DIM in shape:
99 |                         shape = override_shape(shape, batch_size)
100 |                     context.set_binding_shape(binding_idx, shape)
101 |
102 |             # Load the test images and preprocess them
103 |             load_normalized_test_case(input_image, inputs[0].host, preprocess_func)
104 |
105 |             # Transfer input data to the GPU.
106 |             cuda.memcpy_htod(inputs[0].device, inputs[0].host)
107 |             # Run inference.
108 |             context.execute(batch_size, dbindings)
109 |             # Transfer predictions back to host from GPU
110 |             out = outputs[0]
111 |             cuda.memcpy_dtoh(out.host, out.device)
112 |
113 |             softmax_output = np.array(out.host)
114 |             top1_idx = np.argmax(softmax_output)
115 |             output_class = labels[top1_idx+1]
116 |             output_confidence = softmax_output[top1_idx]
117 |
118 |             print("Output class of the image: {} Confidence: {}".format(output_class, output_confidence))
119 |
120 | if __name__ == '__main__':
121 |     parser = argparse.ArgumentParser(description='Run inference on TensorRT engines for ImageNet-based classification models.')
122 |     parser.add_argument('-e', '--engine', type=str, required=True,
123 |                         help='Path to RN50 TensorRT engine')
124 |     parser.add_argument('-i', '--image', required=True, type=str,
125 |                         help="Path to input image.")
126 |     parser.add_argument("-l", "--labels", type=str, default=os.path.join("labels", "class_labels.txt"),
127 |                         help="Path to the file containing the ImageNet 1k labels.")
128 |     parser.add_argument('-b', '--batch_size', default=1, type=int,
129 |                         help="Batch size of inputs")
130 |     parser.add_argument('-v', '--verbose', action='store_true',
131 |                         help="Flag to enable verbose logging")
132 |     args = parser.parse_args()
133 |
134 |     # Class 0 is unused and treated as the background class, so prepend a "background" label.
135 |     with open(args.labels, "r") as f:
136 |         background_class = ["background"]
137 |         imagenet_synsets = f.read().splitlines()
138 |         imagenet_classes = []
139 |         for synset in imagenet_synsets:
140 |             class_name = synset.strip()
141 |             imagenet_classes.append(class_name)
142 |         all_classes = background_class + imagenet_classes
143 |         labels = np.array(all_classes)
144 |
145 |     # Preprocessing for input images
146 |     preprocess_func = image_processing.preprocess_resnet50
147 |
148 |     # Run inference on the test image
149 |     infer(args.engine, preprocess_func, args.batch_size, args.image, labels, args.verbose)
150 |
151 |
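How `infer.py` sizes its buffers and resolves a dynamic batch dimension can be followed with plain tuples. The sketch below is a simplified stand-in rather than the TensorRT API: `volume` is a hypothetical helper that multiplies the dimensions (the behaviour the script relies on from `trt.volume`), `buffer_elements` is a hypothetical name for the size expression in `allocate_buffers`, and the shape `(-1, 3, 224, 224)` is an assumed example binding.

```python
from functools import reduce
from operator import mul

TRT_DYNAMIC_DIM = -1


def volume(shape):
    """Product of all dimensions (hypothetical stand-in for trt.volume)."""
    return reduce(mul, shape, 1)


def override_shape(shape, batch_size):
    """Replace every dynamic (-1) dimension with the batch size."""
    return tuple(batch_size if dim == TRT_DYNAMIC_DIM else dim for dim in shape)


def buffer_elements(shape, batch_size):
    """Element count as computed in allocate_buffers: batch_size * |volume|."""
    return batch_size * abs(volume(shape))


dynamic_shape = (-1, 3, 224, 224)  # assumed example input binding
batch_size = 4

# Shape passed to the execution context once the batch size is known.
resolved = override_shape(dynamic_shape, batch_size)

# Number of elements to allocate for this binding.
elements = buffer_elements(dynamic_shape, batch_size)
```

With a single dynamic dimension the raw product is negative, so `abs()` recovers the per-sample element count and multiplying by the batch size gives the full buffer. If the batch dimension were already fixed to a value above 1, the same expression would allocate more than needed.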
--------------------------------------------------------------------------------
/labels/class_labels.txt:
--------------------------------------------------------------------------------
1 | tench
2 | goldfish
3 | great white shark
4 | tiger shark
5 | hammerhead
6 | electric ray
7 | stingray
8 | cock
9 | hen
10 | ostrich
11 | brambling
12 | goldfinch
13 | house finch
14 | junco
15 | indigo bunting
16 | robin
17 | bulbul
18 | jay
19 | magpie
20 | chickadee
21 | water ouzel
22 | kite
23 | bald eagle
24 | vulture
25 | great grey owl
26 | European fire salamander
27 | common newt
28 | eft
29 | spotted salamander
30 | axolotl
31 | bullfrog
32 | tree frog
33 | tailed frog
34 | loggerhead
35 | leatherback turtle
36 | mud turtle
37 | terrapin
38 | box turtle
39 | banded gecko
40 | common iguana
41 | American chameleon
42 | whiptail
43 | agama
44 | frilled lizard
45 | alligator lizard
46 | Gila monster
47 | green lizard
48 | African chameleon
49 | Komodo dragon
50 | African crocodile
51 | American alligator
52 | triceratops
53 | thunder snake
54 | ringneck snake
55 | hognose snake
56 | green snake
57 | king snake
58 | garter snake
59 | water snake
60 | vine snake
61 | night snake
62 | boa constrictor
63 | rock python
64 | Indian cobra
65 | green mamba
66 | sea snake
67 | horned viper
68 | diamondback
69 | sidewinder
70 | trilobite
71 | harvestman
72 | scorpion
73 | black and gold garden spider
74 | barn spider
75 | garden spider
76 | black widow
77 | tarantula
78 | wolf spider
79 | tick
80 | centipede
81 | black grouse
82 | ptarmigan
83 | ruffed grouse
84 | prairie chicken
85 | peacock
86 | quail
87 | partridge
88 | African grey
89 | macaw
90 | sulphur-crested cockatoo
91 | lorikeet
92 | coucal
93 | bee eater
94 | hornbill
95 | hummingbird
96 | jacamar
97 | toucan
98 | drake
99 | red-breasted merganser
100 | goose
101 | black swan
102 | tusker
103 | echidna
104 | platypus
105 | wallaby
106 | koala
107 | wombat
108 | jellyfish
109 | sea anemone
110 | brain coral
111 | flatworm
112 | nematode
113 | conch
114 | snail
115 | slug
116 | sea slug
117 | chiton
118 | chambered nautilus
119 | Dungeness crab
120 | rock crab
121 | fiddler crab
122 | king crab
123 | American lobster
124 | spiny lobster
125 | crayfish
126 | hermit crab
127 | isopod
128 | white stork
129 | black stork
130 | spoonbill
131 | flamingo
132 | little blue heron
133 | American egret
134 | bittern
135 | crane
136 | limpkin
137 | European gallinule
138 | American coot
139 | bustard
140 | ruddy turnstone
141 | red-backed sandpiper
142 | redshank
143 | dowitcher
144 | oystercatcher
145 | pelican
146 | king penguin
147 | albatross
148 | grey whale
149 | killer whale
150 | dugong
151 | sea lion
152 | Chihuahua
153 | Japanese spaniel
154 | Maltese dog
155 | Pekinese
156 | Shih-Tzu
157 | Blenheim spaniel
158 | papillon
159 | toy terrier
160 | Rhodesian ridgeback
161 | Afghan hound
162 | basset
163 | beagle
164 | bloodhound
165 | bluetick
166 | black-and-tan coonhound
167 | Walker hound
168 | English foxhound
169 | redbone
170 | borzoi
171 | Irish wolfhound
172 | Italian greyhound
173 | whippet
174 | Ibizan hound
175 | Norwegian elkhound
176 | otterhound
177 | Saluki
178 | Scottish deerhound
179 | Weimaraner
180 | Staffordshire bullterrier
181 | American Staffordshire terrier
182 | Bedlington terrier
183 | Border terrier
184 | Kerry blue terrier
185 | Irish terrier
186 | Norfolk terrier
187 | Norwich terrier
188 | Yorkshire terrier
189 | wire-haired fox terrier
190 | Lakeland terrier
191 | Sealyham terrier
192 | Airedale
193 | cairn
194 | Australian terrier
195 | Dandie Dinmont
196 | Boston bull
197 | miniature schnauzer
198 | giant schnauzer
199 | standard schnauzer
200 | Scotch terrier
201 | Tibetan terrier
202 | silky terrier
203 | soft-coated wheaten terrier
204 | West Highland white terrier
205 | Lhasa
206 | flat-coated retriever
207 | curly-coated retriever
208 | golden retriever
209 | Labrador retriever
210 | Chesapeake Bay retriever
211 | German short-haired pointer
212 | vizsla
213 | English setter
214 | Irish setter
215 | Gordon setter
216 | Brittany spaniel
217 | clumber
218 | English springer
219 | Welsh springer spaniel
220 | cocker spaniel
221 | Sussex spaniel
222 | Irish water spaniel
223 | kuvasz
224 | schipperke
225 | groenendael
226 | malinois
227 | briard
228 | kelpie
229 | komondor
230 | Old English sheepdog
231 | Shetland sheepdog
232 | collie
233 | Border collie
234 | Bouvier des Flandres
235 | Rottweiler
236 | German shepherd
237 | Doberman
238 | miniature pinscher
239 | Greater Swiss Mountain dog
240 | Bernese mountain dog
241 | Appenzeller
242 | EntleBucher
243 | boxer
244 | bull mastiff
245 | Tibetan mastiff
246 | French bulldog
247 | Great Dane
248 | Saint Bernard
249 | Eskimo dog
250 | malamute
251 | Siberian husky
252 | dalmatian
253 | affenpinscher
254 | basenji
255 | pug
256 | Leonberg
257 | Newfoundland
258 | Great Pyrenees
259 | Samoyed
260 | Pomeranian
261 | chow
262 | keeshond
263 | Brabancon griffon
264 | Pembroke
265 | Cardigan
266 | toy poodle
267 | miniature poodle
268 | standard poodle
269 | Mexican hairless
270 | timber wolf
271 | white wolf
272 | red wolf
273 | coyote
274 | dingo
275 | dhole
276 | African hunting dog
277 | hyena
278 | red fox
279 | kit fox
280 | Arctic fox
281 | grey fox
282 | tabby
283 | tiger cat
284 | Persian cat
285 | Siamese cat
286 | Egyptian cat
287 | cougar
288 | lynx
289 | leopard
290 | snow leopard
291 | jaguar
292 | lion
293 | tiger
294 | cheetah
295 | brown bear
296 | American black bear
297 | ice bear
298 | sloth bear
299 | mongoose
300 | meerkat
301 | tiger beetle
302 | ladybug
303 | ground beetle
304 | long-horned beetle
305 | leaf beetle
306 | dung beetle
307 | rhinoceros beetle
308 | weevil
309 | fly
310 | bee
311 | ant
312 | grasshopper
313 | cricket
314 | walking stick
315 | cockroach
316 | mantis
317 | cicada
318 | leafhopper
319 | lacewing
320 | dragonfly
321 | damselfly
322 | admiral
323 | ringlet
324 | monarch
325 | cabbage butterfly
326 | sulphur butterfly
327 | lycaenid
328 | starfish
329 | sea urchin
330 | sea cucumber
331 | wood rabbit
332 | hare
333 | Angora
334 | hamster
335 | porcupine
336 | fox squirrel
337 | marmot
338 | beaver
339 | guinea pig
340 | sorrel
341 | zebra
342 | hog
343 | wild boar
344 | warthog
345 | hippopotamus
346 | ox
347 | water buffalo
348 | bison
349 | ram
350 | bighorn
351 | ibex
352 | hartebeest
353 | impala
354 | gazelle
355 | Arabian camel
356 | llama
357 | weasel
358 | mink
359 | polecat
360 | black-footed ferret
361 | otter
362 | skunk
363 | badger
364 | armadillo
365 | three-toed sloth
366 | orangutan
367 | gorilla
368 | chimpanzee
369 | gibbon
370 | siamang
371 | guenon
372 | patas
373 | baboon
374 | macaque
375 | langur
376 | colobus
377 | proboscis monkey
378 | marmoset
379 | capuchin
380 | howler monkey
381 | titi
382 | spider monkey
383 | squirrel monkey
384 | Madagascar cat
385 | indri
386 | Indian elephant
387 | African elephant
388 | lesser panda
389 | giant panda
390 | barracouta
391 | eel
392 | coho
393 | rock beauty
394 | anemone fish
395 | sturgeon
396 | gar
397 | lionfish
398 | puffer
399 | abacus
400 | abaya
401 | academic gown
402 | accordion
403 | acoustic guitar
404 | aircraft carrier
405 | airliner
406 | airship
407 | altar
408 | ambulance
409 | amphibian
410 | analog clock
411 | apiary
412 | apron
413 | ashcan
414 | assault rifle
415 | backpack
416 | bakery
417 | balance beam
418 | balloon
419 | ballpoint
420 | Band Aid
421 | banjo
422 | bannister
423 | barbell
424 | barber chair
425 | barbershop
426 | barn
427 | barometer
428 | barrel
429 | barrow
430 | baseball
431 | basketball
432 | bassinet
433 | bassoon
434 | bathing cap
435 | bath towel
436 | bathtub
437 | beach wagon
438 | beacon
439 | beaker
440 | bearskin
441 | beer bottle
442 | beer glass
443 | bell cote
444 | bib
445 | bicycle-built-for-two
446 | bikini
447 | binder
448 | binoculars
449 | birdhouse
450 | boathouse
451 | bobsled
452 | bolo tie
453 | bonnet
454 | bookcase
455 | bookshop
456 | bottlecap
457 | bow
458 | bow tie
459 | brass
460 | brassiere
461 | breakwater
462 | breastplate
463 | broom
464 | bucket
465 | buckle
466 | bulletproof vest
467 | bullet train
468 | butcher shop
469 | cab
470 | caldron
471 | candle
472 | cannon
473 | canoe
474 | can opener
475 | cardigan
476 | car mirror
477 | carousel
478 | carpenter's kit
479 | carton
480 | car wheel
481 | cash machine
482 | cassette
483 | cassette player
484 | castle
485 | catamaran
486 | CD player
487 | cello
488 | cellular telephone
489 | chain
490 | chainlink fence
491 | chain mail
492 | chain saw
493 | chest
494 | chiffonier
495 | chime
496 | china cabinet
497 | Christmas stocking
498 | church
499 | cinema
500 | cleaver
501 | cliff dwelling
502 | cloak
503 | clog
504 | cocktail shaker
505 | coffee mug
506 | coffeepot
507 | coil
508 | combination lock
509 | computer keyboard
510 | confectionery
511 | container ship
512 | convertible
513 | corkscrew
514 | cornet
515 | cowboy boot
516 | cowboy hat
517 | cradle
518 | crane
519 | crash helmet
520 | crate
521 | crib
522 | Crock Pot
523 | croquet ball
524 | crutch
525 | cuirass
526 | dam
527 | desk
528 | desktop computer
529 | dial telephone
530 | diaper
531 | digital clock
532 | digital watch
533 | dining table
534 | dishrag
535 | dishwasher
536 | disk brake
537 | dock
538 | dogsled
539 | dome
540 | doormat
541 | drilling platform
542 | drum
543 | drumstick
544 | dumbbell
545 | Dutch oven
546 | electric fan
547 | electric guitar
548 | electric locomotive
549 | entertainment center
550 | envelope
551 | espresso maker
552 | face powder
553 | feather boa
554 | file
555 | fireboat
556 | fire engine
557 | fire screen
558 | flagpole
559 | flute
560 | folding chair
561 | football helmet
562 | forklift
563 | fountain
564 | fountain pen
565 | four-poster
566 | freight car
567 | French horn
568 | frying pan
569 | fur coat
570 | garbage truck
571 | gasmask
572 | gas pump
573 | goblet
574 | go-kart
575 | golf ball
576 | golfcart
577 | gondola
578 | gong
579 | gown
580 | grand piano
581 | greenhouse
582 | grille
583 | grocery store
584 | guillotine
585 | hair slide
586 | hair spray
587 | half track
588 | hammer
589 | hamper
590 | hand blower
591 | hand-held computer
592 | handkerchief
593 | hard disc
594 | harmonica
595 | harp
596 | harvester
597 | hatchet
598 | holster
599 | home theater
600 | honeycomb
601 | hook
602 | hoopskirt
603 | horizontal bar
604 | horse cart
605 | hourglass
606 | iPod
607 | iron
608 | jack-o'-lantern
609 | jean
610 | jeep
611 | jersey
612 | jigsaw puzzle
613 | jinrikisha
614 | joystick
615 | kimono
616 | knee pad
617 | knot
618 | lab coat
619 | ladle
620 | lampshade
621 | laptop
622 | lawn mower
623 | lens cap
624 | letter opener
625 | library
626 | lifeboat
627 | lighter
628 | limousine
629 | liner
630 | lipstick
631 | Loafer
632 | lotion
633 | loudspeaker
634 | loupe
635 | lumbermill
636 | magnetic compass
637 | mailbag
638 | mailbox
639 | maillot
640 | maillot
641 | manhole cover
642 | maraca
643 | marimba
644 | mask
645 | matchstick
646 | maypole
647 | maze
648 | measuring cup
649 | medicine chest
650 | megalith
651 | microphone
652 | microwave
653 | military uniform
654 | milk can
655 | minibus
656 | miniskirt
657 | minivan
658 | missile
659 | mitten
660 | mixing bowl
661 | mobile home
662 | Model T
663 | modem
664 | monastery
665 | monitor
666 | moped
667 | mortar
668 | mortarboard
669 | mosque
670 | mosquito net
671 | motor scooter
672 | mountain bike
673 | mountain tent
674 | mouse
675 | mousetrap
676 | moving van
677 | muzzle
678 | nail
679 | neck brace
680 | necklace
681 | nipple
682 | notebook
683 | obelisk
684 | oboe
685 | ocarina
686 | odometer
687 | oil filter
688 | organ
689 | oscilloscope
690 | overskirt
691 | oxcart
692 | oxygen mask
693 | packet
694 | paddle
695 | paddlewheel
696 | padlock
697 | paintbrush
698 | pajama
699 | palace
700 | panpipe
701 | paper towel
702 | parachute
703 | parallel bars
704 | park bench
705 | parking meter
706 | passenger car
707 | patio
708 | pay-phone
709 | pedestal
710 | pencil box
711 | pencil sharpener
712 | perfume
713 | Petri dish
714 | photocopier
715 | pick
716 | pickelhaube
717 | picket fence
718 | pickup
719 | pier
720 | piggy bank
721 | pill bottle
722 | pillow
723 | ping-pong ball
724 | pinwheel
725 | pirate
726 | pitcher
727 | plane
728 | planetarium
729 | plastic bag
730 | plate rack
731 | plow
732 | plunger
733 | Polaroid camera
734 | pole
735 | police van
736 | poncho
737 | pool table
738 | pop bottle
739 | pot
740 | potter's wheel
741 | power drill
742 | prayer rug
743 | printer
744 | prison
745 | projectile
746 | projector
747 | puck
748 | punching bag
749 | purse
750 | quill
751 | quilt
752 | racer
753 | racket
754 | radiator
755 | radio
756 | radio telescope
757 | rain barrel
758 | recreational vehicle
759 | reel
760 | reflex camera
761 | refrigerator
762 | remote control
763 | restaurant
764 | revolver
765 | rifle
766 | rocking chair
767 | rotisserie
768 | rubber eraser
769 | rugby ball
770 | rule
771 | running shoe
772 | safe
773 | safety pin
774 | saltshaker
775 | sandal
776 | sarong
777 | sax
778 | scabbard
779 | scale
780 | school bus
781 | schooner
782 | scoreboard
783 | screen
784 | screw
785 | screwdriver
786 | seat belt
787 | sewing machine
788 | shield
789 | shoe shop
790 | shoji
791 | shopping basket
792 | shopping cart
793 | shovel
794 | shower cap
795 | shower curtain
796 | ski
797 | ski mask
798 | sleeping bag
799 | slide rule
800 | sliding door
801 | slot
802 | snorkel
803 | snowmobile
804 | snowplow
805 | soap dispenser
806 | soccer ball
807 | sock
808 | solar dish
809 | sombrero
810 | soup bowl
811 | space bar
812 | space heater
813 | space shuttle
814 | spatula
815 | speedboat
816 | spider web
817 | spindle
818 | sports car
819 | spotlight
820 | stage
821 | steam locomotive
822 | steel arch bridge
823 | steel drum
824 | stethoscope
825 | stole
826 | stone wall
827 | stopwatch
828 | stove
829 | strainer
830 | streetcar
831 | stretcher
832 | studio couch
833 | stupa
834 | submarine
835 | suit
836 | sundial
837 | sunglass
838 | sunglasses
839 | sunscreen
840 | suspension bridge
841 | swab
842 | sweatshirt
843 | swimming trunks
844 | swing
845 | switch
846 | syringe
847 | table lamp
848 | tank
849 | tape player
850 | teapot
851 | teddy
852 | television
853 | tennis ball
854 | thatch
855 | theater curtain
856 | thimble
857 | thresher
858 | throne
859 | tile roof
860 | toaster
861 | tobacco shop
862 | toilet seat
863 | torch
864 | totem pole
865 | tow truck
866 | toyshop
867 | tractor
868 | trailer truck
869 | tray
870 | trench coat
871 | tricycle
872 | trimaran
873 | tripod
874 | triumphal arch
875 | trolleybus
876 | trombone
877 | tub
878 | turnstile
879 | typewriter keyboard
880 | umbrella
881 | unicycle
882 | upright
883 | vacuum
884 | vase
885 | vault
886 | velvet
887 | vending machine
888 | vestment
889 | viaduct
890 | violin
891 | volleyball
892 | waffle iron
893 | wall clock
894 | wallet
895 | wardrobe
896 | warplane
897 | washbasin
898 | washer
899 | water bottle
900 | water jug
901 | water tower
902 | whiskey jug
903 | whistle
904 | wig
905 | window screen
906 | window shade
907 | Windsor tie
908 | wine bottle
909 | wing
910 | wok
911 | wooden spoon
912 | wool
913 | worm fence
914 | wreck
915 | yawl
916 | yurt
917 | web site
918 | comic book
919 | crossword puzzle
920 | street sign
921 | traffic light
922 | book jacket
923 | menu
924 | plate
925 | guacamole
926 | consomme
927 | hot pot
928 | trifle
929 | ice cream
930 | ice lolly
931 | French loaf
932 | bagel
933 | pretzel
934 | cheeseburger
935 | hotdog
936 | mashed potato
937 | head cabbage
938 | broccoli
939 | cauliflower
940 | zucchini
941 | spaghetti squash
942 | acorn squash
943 | butternut squash
944 | cucumber
945 | artichoke
946 | bell pepper
947 | cardoon
948 | mushroom
949 | Granny Smith
950 | strawberry
951 | orange
952 | lemon
953 | fig
954 | pineapple
955 | banana
956 | jackfruit
957 | custard apple
958 | pomegranate
959 | hay
960 | carbonara
961 | chocolate sauce
962 | dough
963 | meat loaf
964 | pizza
965 | potpie
966 | burrito
967 | red wine
968 | espresso
969 | cup
970 | eggnog
971 | alp
972 | bubble
973 | cliff
974 | coral reef
975 | geyser
976 | lakeside
977 | promontory
978 | sandbar
979 | seashore
980 | valley
981 | volcano
982 | ballplayer
983 | groom
984 | scuba diver
985 | rapeseed
986 | daisy
987 | yellow lady's slipper
988 | corn
989 | acorn
990 | hip
991 | buckeye
992 | coral fungus
993 | agaric
994 | gyromitra
995 | stinkhorn
996 | earthstar
997 | hen-of-the-woods
998 | bolete
999 | ear
1000 | toilet tissue
1001 |
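`infer.py` reads this file, prepends a `background` entry, and looks up the prediction at `top1_idx + 1`. The sketch below reproduces that lookup on a three-entry excerpt of the list. The scores are made up for illustration, and whether the `+ 1` offset suits a given engine depends on how many classes the model outputs, which this repository excerpt does not state.

```python
# First three lines of the labels file, as an in-memory stand-in.
label_lines = ["tench", "goldfish", "great white shark"]

# Index 0 is reserved for the unused background class.
labels = ["background"] + [line.strip() for line in label_lines]

# Made-up scores for illustration: the highest score is at index 1.
scores = [0.1, 0.7, 0.2]
top1_idx = max(range(len(scores)), key=scores.__getitem__)

# Same indexing as infer.py: shift by one to skip "background".
output_class = labels[top1_idx + 1]
output_confidence = scores[top1_idx]
```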
--------------------------------------------------------------------------------
/postprocess_onnx.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2020 NVIDIA Corporation
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import onnx_graphsurgeon as gs
17 | import argparse
18 | import onnx
19 | import numpy as np
20 |
21 | def process_transpose_nodes(graph):
22 |     """
23 |     This is a workaround to manually transpose the conv weights and remove
24 |     the existing transpose nodes. Currently TRT has a limitation when there is
25 |     a transpose node as an input to the weights of the conv layer. This utility
26 |     will be removed in a future release.
27 |     """
28 |     # Find all the transposes before the convolutional nodes
29 |     conv_nodes = [node for node in graph.nodes if node.op == "Conv"]
30 |     for node in conv_nodes:
31 |         # Transpose the convolution weights and store them back in the weights tensor
32 |         conv_weights_tensor = node.i(1).i().i().inputs[0]
33 |         conv_weights_transposed = np.transpose(conv_weights_tensor.values, [3, 2, 0, 1])
34 |         conv_weights_tensor.values = conv_weights_transposed
35 |
36 |         # Bypass the transpose node after the dequant node. TensorRT does not support transpose nodes after QDQ nodes.
37 |         dequant_node_output = node.i(1).i(0).outputs[0]
38 |         node.inputs[1] = dequant_node_output
39 |
40 |     # Remove unused nodes, and topologically sort the graph.
41 |     return graph.cleanup().toposort()
42 |
43 | if __name__ == '__main__':
44 |     parser = argparse.ArgumentParser("Post process ONNX graph by removing transpose nodes")
45 |     parser.add_argument("--input", required=True, help="Input onnx graph")
46 |     parser.add_argument("--output", default='postprocessed_rn50.onnx', help="Name of post processed onnx graph")
47 |     args = parser.parse_args()
48 |
49 |     # Load the rn50 graph
50 |     graph = gs.import_onnx(onnx.load(args.input))
51 |
52 |     # Remove the transpose nodes and reshape the convolution weights
53 |     graph = process_transpose_nodes(graph)
54 |
55 |     # Export the onnx graph from graphsurgeon
56 |     onnx_model = gs.export_onnx(graph)
57 |     print("Output ONNX graph generated: ", args.output)
58 |     onnx.save_model(onnx_model, args.output)
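The `[3, 2, 0, 1]` permutation in `process_transpose_nodes` can be checked on a small array. The sketch assumes the weights arrive in TensorFlow's `(height, width, in_channels, out_channels)` layout and need to end up in the `(out_channels, in_channels, height, width)` layout that the ONNX `Conv` operator uses; the 7x7x3x64 shape is only an illustrative example, not taken from the model.

```python
import numpy as np

# Example kernel in (height, width, in_channels, out_channels) order.
hwio = np.arange(7 * 7 * 3 * 64, dtype=np.float32).reshape(7, 7, 3, 64)

# Axis k of the result is axis perm[k] of the input.
oihw = np.transpose(hwio, [3, 2, 0, 1])

# Element (h, w, i, o) of the original is element (o, i, h, w) of the result.
h, w, i, o = 2, 5, 1, 10
same_value = bool(hwio[h, w, i, o] == oihw[o, i, h, w])
```

Only the indexing order changes; the values themselves are untouched, which is why the script can write the transposed array straight back into the weights tensor.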
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | onnx==1.6
2 | tf2onnx==1.6.1
3 | numpy
4 | pycuda
5 |
--------------------------------------------------------------------------------