├── LICENSE
├── README.md
├── convert_onnx_to_tensorrt
│   └── convert_onnx_to_tensorrt.py
└── convert_pytorch_to_onnx
    └── convert_pytorch_to_onnx.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Yong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | Description
4 | =============
5 |
6 | #### - TensorRT engine converter for various TensorRT versions (refer to each branch or tag)
7 |
8 | #### - ONNX (Open Neural Network Exchange)
9 | - Standard format for expressing machine learning algorithms and models
10 | - More details about ONNX: https://en.wikipedia.org/wiki/Open_Neural_Network_Exchange
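
For illustration, a minimal sanity check of an exported ONNX model (a sketch; `model.onnx` is a placeholder path, and the converter scripts below perform the same check):
```
import onnx

# Load the exported model and verify that the IR is well formed
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)

# Print a human-readable representation of the graph
print(onnx.helper.printable_graph(onnx_model.graph))
```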
11 |
12 | #### - TensorRT
13 | - NVIDIA SDK for high-performance deep learning inference
14 |   - An inference optimizer and runtime that delivers low latency and high throughput for deep learning applications
15 |   - Explicit batch is required when dealing with dynamic shapes; otherwise, the network is created with an implicit batch dimension.
16 | - More details about TensorRT: https://blog.naver.com/qbxlvnf11/222403199156
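
In code, explicit batch is requested when the TensorRT network is created, as convert_onnx_to_tensorrt.py does:
```
import tensorrt as trt

logger = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(logger)

# EXPLICIT_BATCH makes the batch dimension part of the network definition,
# which is required for dynamic shapes
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
```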
17 |
18 | Contents
19 | =============
20 | #### - [Converting PyTorch to ONNX](https://github.com/qbxlvnf11/convert-pytorch-onnx-tensorrt/blob/TensorRT-21.08/convert_pytorch_to_onnx/convert_pytorch_to_onnx.py)
21 |   - Details: https://blog.naver.com/qbxlvnf11/222342675767
22 |   - Export & load the ONNX model
23 |   - Run inference with the ONNX model
24 |   - Compare output and time efficiency between ONNX and PyTorch
25 |   - Setting batch size of input data: explicit batch or implicit batch (see the export sketch below)
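
Dynamic (explicit) batch is chosen at export time by marking dimension 0 as dynamic, as in convert_pytorch_to_onnx.py (`model` and `sample_input` stand in for the ResNet18 and its dummy input):
```
import torch

# Mark dim 0 (batch) of the input and output as dynamic
dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}

torch.onnx.export(model, sample_input, 'onnx_output_explicit.onnx',
                  input_names=['input'], output_names=['output'],
                  dynamic_axes=dynamic_axes, opset_version=11)
```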
26 |
27 | #### - [Converting ONNX to TensorRT and testing time efficiency](https://github.com/qbxlvnf11/convert-pytorch-onnx-tensorrt/blob/TensorRT-21.08/convert_onnx_to_tensorrt/convert_onnx_to_tensorrt.py)
28 | - Build & load TensorRT engine
29 | - Setting batch size of input data: explicit batch or implicit batch
30 |   - Key trtexec options (see the example command below)
31 |     - Precision of engine: FP32, FP16
32 |     - optShapes: set the most used input data size of model for inference
33 |     - minShapes: set the min input data size of model for inference
34 |     - maxShapes: set the max input data size of model for inference
35 |   - Run inference with the TensorRT engine
36 |   - Compare output and time efficiency among TensorRT, ONNX, and PyTorch
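
For reference, the same shape settings map onto trtexec options directly (a sketch, assuming the input tensor is named `input` as in the export above):
```
trtexec --onnx=onnx_output_explicit.onnx --saveEngine=FP16_explicit.engine --fp16 \
        --minShapes=input:1x3x224x224 --optShapes=input:1x3x224x224 --maxShapes=input:8x3x224x224
```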
37 |
38 | TensorRT Docker Environment
39 | =============
40 |
41 | #### - Download TensorRT Docker environment
42 | ```
43 | docker pull qbxlvnf11docker/tensorrt_21.08
44 | ```
45 |
46 | #### - Run TensorRT Docker environment
47 | ```
48 | nvidia-docker run -it -p 9000:9000 -e GRANT_SUDO=yes --user root --name tensorrt_21.08_env -v {code_folder_path}:/workspace -w /workspace qbxlvnf11docker/tensorrt_21.08:latest bash
49 | ```
50 |
51 | Examples of running ResNet18 inference with TensorRT
52 | =============
53 |
54 | #### - Explicit batch
55 |   - Converting the PyTorch model to ONNX
56 | ```
57 | python convert_pytorch_to_onnx/convert_pytorch_to_onnx.py --dynamic_axes True --output_path onnx_output_explicit.onnx --batch_size {batch_size}
58 | ```
59 |
60 |   - Converting ONNX to TensorRT and testing time efficiency (FP32)
61 |     - Setting the three shape parameters (minShapes, optShapes, maxShapes) according to the inference environment
62 | ```
63 | python convert_onnx_to_tensorrt/convert_onnx_to_tensorrt.py --dynamic_axes True --onnx_model_path onnx_output_explicit.onnx --batch_size {batch_size} --tensorrt_engine_path FP32_explicit.engine --engine_precision FP32
64 | ```
65 |
66 |   - Converting ONNX to TensorRT and testing time efficiency (FP16)
67 |     - Setting the three shape parameters (minShapes, optShapes, maxShapes) according to the inference environment
68 | ```
69 | python convert_onnx_to_tensorrt/convert_onnx_to_tensorrt.py --dynamic_axes True --onnx_model_path onnx_output_explicit.onnx --batch_size {batch_size} --tensorrt_engine_path FP16_explicit.engine --engine_precision FP16
70 | ```
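
The optimization profile's batch range can be adjusted with the script's `--min_engine_batch_size`, `--opt_engine_batch_size`, and `--max_engine_batch_size` options (defaults: 1, 1, and 8).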
71 |
72 | #### - Implicit batch
73 |   - Converting the PyTorch model to ONNX
74 | ```
75 | python convert_pytorch_to_onnx/convert_pytorch_to_onnx.py --dynamic_axes False --output_path onnx_output_implicit.onnx --batch_size {batch_size}
76 | ```
77 |
78 |   - Converting ONNX to TensorRT and testing time efficiency (FP32)
79 | ```
80 | python convert_onnx_to_tensorrt/convert_onnx_to_tensorrt.py --dynamic_axes False --onnx_model_path onnx_output_implicit.onnx --batch_size {batch_size_of_implicit_batch_onnx_model} --tensorrt_engine_path FP32_implicit.engine --engine_precision FP32
81 | ```
82 |
83 |   - Converting ONNX to TensorRT and testing time efficiency (FP16)
84 | ```
85 | python convert_onnx_to_tensorrt/convert_onnx_to_tensorrt.py --dynamic_axes False --onnx_model_path onnx_output_implicit.onnx --batch_size {batch_size_of_implicit_batch_onnx_model} --tensorrt_engine_path FP16_implicit.engine --engine_precision FP16
86 | ```
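
Once built, an engine can be reloaded and run without re-parsing the ONNX model, as the converter script does (`FP32_explicit.engine` is one of the engines built above):
```
import tensorrt as trt

TRT_LOGGER = trt.Logger()

# Deserialize the engine and create an execution context
with open('FP32_explicit.engine', 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

# Bind the actual input shape before running inference
context.set_binding_shape(0, (1, 3, 224, 224))
```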
87 |
88 | #### - Comparison of time efficiency and output
89 | - Explicit batch test of FP32 TensorRT engine
90 | - Batch size of inf data = 1
91 | - Batch size of optShapes = 1
92 |
93 |
94 |
95 | - Explicit batch test of FP16 TensorRT engine
96 | - Batch size of inf data = 1
97 | - Batch size of optShapes = 1
98 |
99 |
100 |
101 |
102 | - Explicit batch test of FP16 TensorRT engine
103 | - Batch size of inf data = 8
104 | - Batch size of optShapes = 1
105 |
106 |
107 |
108 | - Implicit batch test of FP32 TensorRT engine
109 | - Batch size of inf data = 1
110 |
111 |
112 |
113 | - Implicit batch test of FP16 TensorRT engine
114 | - Batch size of inf data = 1
115 |
116 |
117 |
118 | References
119 | =============
120 |
121 | #### - Converting PyTorch models to ONNX
122 |
123 | https://pytorch.org/docs/stable/onnx.html
124 |
125 | #### - TensorRT
126 |
127 | https://developer.nvidia.com/tensorrt
128 |
129 | #### - TensorRT Release 21.08
130 |
131 | https://docs.nvidia.com/deeplearning/tensorrt/container-release-notes/rel_21-08.html#rel_21-08
132 |
133 | #### - TensorRT8 code
134 |
135 | https://github.com/NVIDIA/trt-samples-for-hackathon-cn/blob/master/cookbook/04-Parser/pyTorch-ONNX-TensorRT/main.py
136 |
137 | #### - ImageNet 1000 samples
138 |
139 | https://www.kaggle.com/ifigotin/imagenetmini-1000
140 |
141 | Author
142 | =============
143 |
144 | #### - LinkedIn: https://www.linkedin.com/in/taeyong-kong-016bb2154
145 |
146 | #### - Blog URL: https://blog.naver.com/qbxlvnf11
147 |
148 | #### - Email: qbxlvnf11@google.com, qbxlvnf11@naver.com
149 |
150 |
--------------------------------------------------------------------------------
/convert_onnx_to_tensorrt/convert_onnx_to_tensorrt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | from skimage import io
5 | from skimage.transform import resize
6 | from collections import OrderedDict
7 | from PIL import Image
8 | import cv2
9 | import time
10 |
11 | # Torch
12 | import torch
13 | from torch import nn
14 | import torchvision.models as models
15 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
16 | import torch.optim as optim
17 | import torchvision.datasets as datasets
18 | import torchvision.transforms as transforms
19 | from torchvision.utils import save_image
20 |
21 | # ONNX: pip install onnx, onnxruntime
22 | try:
23 | import onnx
24 | import onnxruntime as rt
25 | except ImportError as e:
26 | raise ImportError(f'Please install onnx and onnxruntime first. {e}')
27 |
28 | # CUDA & TensorRT
29 | #import pycuda.driver as cuda
30 | from cuda import cuda
31 | import pycuda.autoinit
32 | import tensorrt as trt
33 |
34 | TRT_LOGGER = trt.Logger()
35 |
36 | def parse_args():
37 | parser = argparse.ArgumentParser(description='Convert ONNX models to TensorRT')
38 |
39 | parser.add_argument('--device', help='cuda or not',
40 | default='cuda:0')
41 |
42 | # Sample image
43 | parser.add_argument('--batch_size', type=int, help='data batch size',
44 | default=1)
45 | parser.add_argument('--img_size', help='input size',
46 | default=[3, 224, 224])
47 | parser.add_argument('--sample_folder_path', help='sample image folder path',
48 | default='./imagenet-mini/train')
49 | #parser.add_argument('--sample_image_path', help='sample image path',
50 | #default='./sample.jpg')
51 |
52 | # Model path
53 | parser.add_argument('--onnx_model_path', help='onnx model path',
54 | default='./onnx_model.onnx')
55 | parser.add_argument('--tensorrt_engine_path', help='tensorrt engine path',
56 | default='./tensorrt_engine.engine')
57 |
58 | # TensorRT engine params
59 | parser.add_argument('--dynamic_axes', help='dynamic batch input or output',
60 | default='True')
61 | parser.add_argument('--engine_precision', help='precision of TensorRT engine', choices=['FP32', 'FP16'],
62 | default='FP32')
63 | parser.add_argument('--min_engine_batch_size', type=int, help='set the min input data size of model for inference',
64 | default=1)
65 | parser.add_argument('--opt_engine_batch_size', type=int, help='set the most used input data size of model for inference',
66 | default=1)
67 | parser.add_argument('--max_engine_batch_size', type=int, help='set the max input data size of model for inference',
68 | default=8)
69 | parser.add_argument('--engine_workspace', type=int, help='workspace of engine',
70 | default=1024)
71 |
72 | args = string_to_bool(parser.parse_args())
73 |
74 | return args
75 |
76 | def string_to_bool(args):
77 |
78 |     if args.dynamic_axes.lower() == 'true': args.dynamic_axes = True
79 |     else: args.dynamic_axes = False
80 |
81 | return args
82 |
83 | def get_transform(img_size):
84 | options = []
85 | options.append(transforms.Resize((img_size[1], img_size[2])))
86 | options.append(transforms.ToTensor())
87 | #options.append(transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]))
88 | transform = transforms.Compose(options)
89 | return transform
90 |
91 | '''
92 | def load_image(img_path, size):
93 | img_raw = io.imread(img_path)
94 | img_raw = np.rollaxis(img_raw, 2, 0)
95 | img_resize = resize(img_raw / 255, size, anti_aliasing=True)
96 | img_resize = img_resize.astype(np.float32)
97 | return img_resize, img_raw
98 | '''
99 |
100 | def load_image_folder(folder_path, img_size, batch_size):
101 | transforming = get_transform(img_size)
102 | dataset = datasets.ImageFolder(folder_path, transform=transforming)
103 | data_loader = torch.utils.data.DataLoader(dataset,
104 | batch_size=batch_size,
105 | shuffle=True,
106 | num_workers=1)
107 | data_iter = iter(data_loader)
108 | torch_images, class_list = next(data_iter)
109 | print('class:', class_list)
110 | print('torch images size:', torch_images.size())
111 | save_image(torch_images[0], 'sample.png')
112 |
113 | return torch_images.cpu().numpy()
114 |
115 | def build_engine(onnx_model_path, tensorrt_engine_path, engine_precision, dynamic_axes, \
116 | img_size, batch_size, min_engine_batch_size, opt_engine_batch_size, max_engine_batch_size):
117 |
118 | # Builder
119 | logger = trt.Logger(trt.Logger.ERROR)
120 | builder = trt.Builder(logger)
121 | network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
122 | profile = builder.create_optimization_profile()
123 | config = builder.create_builder_config()
124 | #config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 3 << 30)
125 | # Set FP16
126 | if engine_precision == 'FP16':
127 | config.set_flag(trt.BuilderFlag.FP16)
128 |
129 | # Onnx parser
130 | parser = trt.OnnxParser(network, logger)
131 | if not os.path.exists(onnx_model_path):
132 | print("Failed finding ONNX file!")
133 | exit()
134 | print("Succeeded finding ONNX file!")
135 | with open(onnx_model_path, "rb") as model:
136 | if not parser.parse(model.read()):
137 | print("Failed parsing .onnx file!")
138 | for error in range(parser.num_errors):
139 | print(parser.get_error(error))
140 | exit()
141 | print("Succeeded parsing .onnx file!")
142 |
143 | # Input
144 | inputTensor = network.get_input(0)
145 | # Dynamic batch (min, opt, max)
146 | print('inputTensor.name:', inputTensor.name)
147 | if dynamic_axes:
148 | profile.set_shape(inputTensor.name, (min_engine_batch_size, img_size[0], img_size[1], img_size[2]), \
149 | (opt_engine_batch_size, img_size[0], img_size[1], img_size[2]), \
150 | (max_engine_batch_size, img_size[0], img_size[1], img_size[2]))
151 | print('Set dynamic')
152 | else:
153 | profile.set_shape(inputTensor.name, (batch_size, img_size[0], img_size[1], img_size[2]), \
154 | (batch_size, img_size[0], img_size[1], img_size[2]), \
155 | (batch_size, img_size[0], img_size[1], img_size[2]))
156 | config.add_optimization_profile(profile)
157 | #network.unmark_output(network.get_output(0))
158 |
159 | # Write engine
160 | engineString = builder.build_serialized_network(network, config)
161 |     if engineString is None:
162 | print("Failed building engine!")
163 | exit()
164 | print("Succeeded building engine!")
165 | with open(tensorrt_engine_path, "wb") as f:
166 | f.write(engineString)
167 |
168 | def trt_inference(engine, context, data):
169 |
170 | nInput = np.sum([engine.binding_is_input(i) for i in range(engine.num_bindings)])
171 | nOutput = engine.num_bindings - nInput
172 | print('nInput:', nInput)
173 | print('nOutput:', nOutput)
174 |
175 | for i in range(nInput):
176 | print("Bind[%2d]:i[%2d]->" % (i, i), engine.get_binding_dtype(i), engine.get_binding_shape(i), context.get_binding_shape(i), engine.get_binding_name(i))
177 | for i in range(nInput,nInput+nOutput):
178 | print("Bind[%2d]:o[%2d]->" % (i, i - nInput), engine.get_binding_dtype(i), engine.get_binding_shape(i), context.get_binding_shape(i), engine.get_binding_name(i))
179 |
180 | bufferH = []
181 | bufferH.append(np.ascontiguousarray(data.reshape(-1)))
182 |
183 | for i in range(nInput, nInput + nOutput):
184 | bufferH.append(np.empty(context.get_binding_shape(i), dtype=trt.nptype(engine.get_binding_dtype(i))))
185 |
186 | bufferD = []
187 | for i in range(nInput + nOutput):
188 | bufferD.append(cuda.cuMemAlloc(bufferH[i].nbytes)[1])
189 |
190 | for i in range(nInput):
191 | cuda.cuMemcpyHtoD(bufferD[i], bufferH[i].ctypes.data, bufferH[i].nbytes)
192 |
193 | context.execute_v2(bufferD)
194 |
195 | for i in range(nInput, nInput + nOutput):
196 | cuda.cuMemcpyDtoH(bufferH[i].ctypes.data, bufferD[i], bufferH[i].nbytes)
197 |
198 | for b in bufferD:
199 | cuda.cuMemFree(b)
200 |
201 | return bufferH
202 |
203 | def main():
204 | args = parse_args()
205 |
206 | # Sample images (folder)
207 | print(args.sample_folder_path)
208 | img_resize = load_image_folder(args.sample_folder_path, args.img_size, args.batch_size).astype(np.float32)
209 | '''
210 | # Sample (one image)
211 | print(args.sample_image_path)
212 | img_resize, img_raw = load_image(args.sample_image_path, args.img_size)
213 | '''
214 |
215 | print('inference image size:', img_resize.shape)
216 |
217 | # Build TensorRT engine
218 | build_engine(args.onnx_model_path, args.tensorrt_engine_path, args.engine_precision, args.dynamic_axes, \
219 | args.img_size, args.batch_size, args.min_engine_batch_size, args.opt_engine_batch_size, args.max_engine_batch_size)
220 |
221 | # Read the engine from the file and deserialize
222 | with open(args.tensorrt_engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
223 | engine = runtime.deserialize_cuda_engine(f.read())
224 | context = engine.create_execution_context()
225 |
226 | # TensorRT inference
227 | context.set_binding_shape(0, (args.batch_size, args.img_size[0], args.img_size[1], args.img_size[2]))
228 |
229 | trt_start_time = time.time()
230 | trt_outputs = trt_inference(engine, context, img_resize)
231 | trt_outputs = np.array(trt_outputs[1]).reshape(args.batch_size, -1)
232 | trt_end_time = time.time()
233 |
234 | # ONNX inference
235 | onnx_model = onnx.load(args.onnx_model_path)
236 | sess = rt.InferenceSession(args.onnx_model_path)
237 |
238 | input_all = [node.name for node in onnx_model.graph.input]
239 | input_initializer = [
240 | node.name for node in onnx_model.graph.initializer
241 | ]
242 | net_feed_input = list(set(input_all) - set(input_initializer))
243 | assert len(net_feed_input) == 1
244 |
245 | sess_input = sess.get_inputs()[0].name
246 | sess_output = sess.get_outputs()[0].name
247 |
248 | onnx_start_time = time.time()
249 | onnx_result = sess.run([sess_output], {sess_input: img_resize})[0]
250 | onnx_end_time = time.time()
251 |
252 | # Pytorch inference
253 | resnet18 = models.resnet18(pretrained=True).to(args.device)
254 | resnet18.eval()
255 |
256 | img_resize_torch = torch.Tensor(img_resize).to(args.device)
257 | torch_start_time = time.time()
258 | pytorch_result = resnet18(img_resize_torch)
259 | torch_end_time = time.time()
260 | pytorch_result = pytorch_result.detach().cpu().numpy()
261 |
262 |     ## Comparison of outputs of TensorRT, ONNX, and PyTorch models
263 |
264 | # Time Efficiency & output
265 | print('--pytorch--')
266 | print(pytorch_result.shape) # (batch_size, 1000)
267 | print(pytorch_result[0][:10])
268 | print(np.argmax(pytorch_result, axis=1))
269 | print('Time:', torch_end_time - torch_start_time)
270 |
271 | print('--onnx--')
272 | print(onnx_result.shape)
273 | print(onnx_result[0][:10])
274 | print(np.argmax(onnx_result, axis=1))
275 | print('Time: ', onnx_end_time - onnx_start_time)
276 |
277 | print('--tensorrt--')
278 | print(trt_outputs.shape)
279 | print(trt_outputs[0][:10])
280 | print(np.argmax(trt_outputs, axis=1))
281 | print('Time: ', trt_end_time - trt_start_time)
282 |
283 | if __name__ == '__main__':
284 | main()
285 |
--------------------------------------------------------------------------------
/convert_pytorch_to_onnx/convert_pytorch_to_onnx.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | from skimage import io
5 | from skimage.transform import resize
6 | from collections import OrderedDict
7 | from PIL import Image
8 | import cv2
9 | import time
10 |
11 | # Torch
12 | import torch
13 | from torch import nn
14 | import torchvision.models as models
15 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
16 | import torch.optim as optim
17 | import torchvision.datasets as datasets
18 | import torchvision.transforms as transforms
19 | from torchvision.utils import save_image
20 |
21 | # ONNX: pip install onnx, onnxruntime
22 | try:
23 | import onnx
24 | import onnxruntime as rt
25 | except ImportError as e:
26 | raise ImportError(f'Please install onnx and onnxruntime first. {e}')
27 |
28 | def parse_args():
29 | parser = argparse.ArgumentParser(description='Convert Pytorch models to ONNX')
30 |
31 | parser.add_argument('--device', help='cuda or not',
32 | default='cuda:0')
33 |
34 | # Sample image
35 | parser.add_argument('--batch_size', type=int, help='onnx sample batch size',
36 | default=1)
37 | parser.add_argument('--img_size', help='image size',
38 | default=[3, 224, 224])
39 | parser.add_argument('--sample_folder_path', help='sample image folder path',
40 | default='./imagenet-mini/train/')
41 | #parser.add_argument('--sample_image_path', help='sample image path',
42 | #default='./sample.jpg')
43 |
44 | parser.add_argument('--output_path', help='onnx model path',
45 | default='./onnx_output.onnx')
46 |
47 | # ONNX params
48 | parser.add_argument('--dynamic_axes', help='dynamic batch input or output',
49 | default='True')
50 | parser.add_argument('--keep_initializers_as_inputs', help='If True, all the initializers (typically corresponding to parameters) in the exported graph will also be added as inputs to the graph. If False, then initializers are not added as inputs to the graph, and only the non-parameter inputs are added as inputs.',
51 | default='True')
52 | parser.add_argument('--export_params', help='If specified, all parameters will be exported. Set this to False if you want to export an untrained model.',
53 | default='True')
54 | parser.add_argument('--opset_version', type=int, help='opset version',
55 | default=11)
56 |
57 | args = string_to_bool(parser.parse_args())
58 |
59 | return args
60 |
61 | def string_to_bool(args):
62 |
63 |     if args.dynamic_axes.lower() == 'true': args.dynamic_axes = True
64 |     else: args.dynamic_axes = False
65 | 
66 |     if args.keep_initializers_as_inputs.lower() == 'true': args.keep_initializers_as_inputs = True
67 |     else: args.keep_initializers_as_inputs = False
68 | 
69 |     if args.export_params.lower() == 'true': args.export_params = True
70 |     else: args.export_params = False
71 |
72 | return args
73 |
74 | def get_transform(img_size):
75 | options = []
76 | options.append(transforms.Resize((img_size[1], img_size[2])))
77 | options.append(transforms.ToTensor())
78 | #options.append(transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]))
79 | transform = transforms.Compose(options)
80 | return transform
81 |
82 | '''def load_image(img_path, size):
83 | img_raw = io.imread(img_path)
84 | img_raw = np.rollaxis(img_raw, 2, 0)
85 | img_resize = resize(img_raw / 255, size, anti_aliasing=True)
86 | img_resize = img_resize.astype(np.float32)
87 | return img_resize, img_raw'''
88 |
89 | def load_image_folder(folder_path, img_size, batch_size):
90 | transforming = get_transform(img_size)
91 | dataset = datasets.ImageFolder(folder_path, transform=transforming)
92 | data_loader = torch.utils.data.DataLoader(dataset,
93 | batch_size=batch_size,
94 | shuffle=True,
95 | num_workers=1)
96 | data_iter = iter(data_loader)
97 | torch_images, class_list = next(data_iter)
98 | save_image(torch_images[0], 'sample.png')
99 |
100 | return torch_images.cpu().numpy()
101 |
102 | if __name__ == '__main__':
103 | args = parse_args()
104 |
105 | # Load pretrained model
106 | resnet18 = models.resnet18(pretrained=True).to(args.device)
107 |
108 | '''
109 | fc = nn.Sequential(OrderedDict([
110 | ('fc1', nn.Linear(512,1000)),
111 | ('output',nn.Softmax(dim=1))
112 | ]))
113 | resnet18.fc = fc
114 | '''
115 |
116 | print(resnet18)
117 |
118 | resnet18.eval()
119 |
120 | # Sample images (folder)
121 | print(args.sample_folder_path)
122 | img_resize = load_image_folder(args.sample_folder_path, args.img_size, args.batch_size).astype(np.float32)
123 | '''
124 | # Sample (one image)
125 | print(args.sample_image_path)
126 | img_resize, img_raw = load_image(args.sample_image_path, args.img_size)
127 | '''
128 |
129 | sample_input = torch.randn(args.batch_size, args.img_size[0], args.img_size[1], args.img_size[2]).to(args.device)
130 | print('inference image size:', img_resize.shape, 'sample input size:', sample_input.shape)
131 |
132 | if args.dynamic_axes:
133 | # Dynamic input
134 | dynamic_axes = {'input' : {0 : 'batch_size'}, 'output' : {0 : 'batch_size'}}
135 |
136 | # Export onnx
137 | torch.onnx.export(
138 | resnet18,
139 | sample_input,
140 | args.output_path,
141 | export_params=args.export_params,
142 | keep_initializers_as_inputs=args.keep_initializers_as_inputs,
143 | opset_version=args.opset_version,
144 | input_names=['input'], # input vect name
145 | output_names=['output'], # output vect name
146 | dynamic_axes=dynamic_axes, # dynamic input
147 | verbose=False)
148 | else:
149 | # Export onnx
150 | torch.onnx.export(
151 | resnet18,
152 | sample_input,
153 | args.output_path,
154 | export_params=args.export_params,
155 | keep_initializers_as_inputs=args.keep_initializers_as_inputs,
156 | opset_version=args.opset_version,
157 | input_names=['input'], # input vect name
158 | output_names=['output'], # output vect name
159 | verbose=False)
160 |
161 | # Load the ONNX model
162 | onnx_model = onnx.load(args.output_path)
163 | sess = rt.InferenceSession(args.output_path)
164 |
165 | # Check that the IR is well formed
166 | onnx.checker.check_model(onnx_model)
167 |
168 | # Print a human readable representation of the graph
169 | with open('OnnxShape.txt','w') as f:
170 | f.write(f"{onnx.helper.printable_graph(onnx_model.graph)}")
171 |
172 |     ## Comparison of outputs of ONNX and PyTorch models
173 | # Pytorch results
174 | img_resize_torch = torch.Tensor(img_resize).to(args.device)
175 | torch_start_time = time.time()
176 | pytorch_result = resnet18(img_resize_torch)
177 | torch_end_time = time.time()
178 | pytorch_result = pytorch_result.detach().cpu().numpy()
179 |
180 | # ONNX results
181 | input_all = [node.name for node in onnx_model.graph.input]
182 | input_initializer = [
183 | node.name for node in onnx_model.graph.initializer
184 | ]
185 | net_feed_input = list(set(input_all) - set(input_initializer))
186 | assert len(net_feed_input) == 1
187 |
188 | sess_input = sess.get_inputs()[0].name
189 | sess_output = sess.get_outputs()[0].name
190 |
191 | onnx_start_time = time.time()
192 | onnx_result = sess.run([sess_output], {sess_input: img_resize})[0]
193 | onnx_end_time = time.time()
194 |
195 | print('--pytorch--')
196 | print(pytorch_result.shape) # (batch_size, 1000)
197 | print(pytorch_result[0][:10])
198 | print(np.argmax(pytorch_result, axis=1))
199 | print('Time:', torch_end_time - torch_start_time)
200 |
201 | print('--onnx--')
202 | print(onnx_result.shape)
203 | print(onnx_result[0][:10])
204 | print(np.argmax(onnx_result, axis=1))
205 | print('Time:', onnx_end_time - onnx_start_time)
206 |
207 |     # Comparison
208 |     assert np.allclose(
209 |         pytorch_result, onnx_result,
210 |         atol=1.e-2), 'The outputs are different (PyTorch and ONNX)'
211 |     print('The numerical values are the same (PyTorch and ONNX)')
212 |
--------------------------------------------------------------------------------