├── LICENSE
├── README.md
├── convert_onnx_to_tensorrt
│   └── convert_onnx_to_tensorrt.py
└── convert_pytorch_to_onnx
    └── convert_pytorch_to_onnx.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Yong

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

Description
=============

#### - TensorRT engine converter for various TensorRT versions (refer to each branch or tag)

#### - ONNX (Open Neural Network Exchange)
  - A standard format for expressing machine learning algorithms and models
  - More details about ONNX: https://en.wikipedia.org/wiki/Open_Neural_Network_Exchange

#### - TensorRT
  - NVIDIA SDK for high-performance deep learning inference
  - A deep learning inference optimizer and runtime that delivers low latency and high throughput for inference applications
  - Explicit batch is required when dealing with dynamic shapes; otherwise the network is created with an implicit batch dimension (see the sketch below)
  - More details about TensorRT: https://blog.naver.com/qbxlvnf11/222403199156
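
A minimal sketch of the explicit/implicit distinction in the TensorRT 8 Python API (an illustration only, assuming a working `tensorrt` install; the converter script in this repository uses the explicit-batch path):

```
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)

# Explicit batch: the batch dimension is part of the network definition,
# which is what enables optimization profiles (dynamic shapes).
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

# builder.create_network(0) would instead create a network with an implicit
# (deprecated) batch dimension, fixed at build time.
```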

Contents
=============
#### - [Converting PyTorch to ONNX](https://github.com/qbxlvnf11/convert-pytorch-onnx-tensorrt/blob/TensorRT-21.08/convert_pytorch_to_onnx/convert_pytorch_to_onnx.py)
  - Details: https://blog.naver.com/qbxlvnf11/222342675767
  - Export & load an ONNX model
  - Run inference with the ONNX model
  - Compare outputs and time efficiency between ONNX and PyTorch
  - Set the batch size of the input data: explicit batch or implicit batch

#### - [Converting ONNX to TensorRT and testing time efficiency](https://github.com/qbxlvnf11/convert-pytorch-onnx-tensorrt/blob/TensorRT-21.08/convert_onnx_to_tensorrt/convert_onnx_to_tensorrt.py)
  - Build & load a TensorRT engine
  - Set the batch size of the input data: explicit batch or implicit batch
  - Key trtexec options (a CLI sketch follows the explicit-batch examples below)
    - Precision of the engine: FP32, FP16
    - optShapes: the input size the engine is optimized for (the most frequently used size at inference)
    - minShapes: the minimum input size the engine accepts at inference
    - maxShapes: the maximum input size the engine accepts at inference
  - Run inference with the TensorRT engine
  - Compare outputs and time efficiency among TensorRT, ONNX, and PyTorch

TensorRT Docker Environment
=============

#### - Download the TensorRT Docker image
```
docker pull qbxlvnf11docker/tensorrt_21.08
```

#### - Run the TensorRT Docker environment
```
nvidia-docker run -it -p 9000:9000 -e GRANT_SUDO=yes --user root --name tensorrt_21.08_env -v {code_folder_path}:/workspace -w /workspace qbxlvnf11docker/tensorrt_21.08:latest bash
```

Examples of inferencing ResNet18 with TensorRT
=============

#### - Explicit batch
  - Converting the PyTorch model to ONNX
```
python convert_pytorch_to_onnx/convert_pytorch_to_onnx.py --dynamic_axes True --output_path onnx_output_explicit.onnx --batch_size {batch_size}
```

  - Converting ONNX to TensorRT and testing time efficiency (FP32)
  - Set the three shape parameters (minShapes, optShapes, maxShapes) according to the inference environment
```
python convert_onnx_to_tensorrt/convert_onnx_to_tensorrt.py --dynamic_axes True --onnx_model_path onnx_output_explicit.onnx --batch_size {batch_size} --tensorrt_engine_path FP32_explicit.engine --engine_precision FP32
```

  - Converting ONNX to TensorRT and testing time efficiency (FP16)
  - Set the three shape parameters (minShapes, optShapes, maxShapes) according to the inference environment
```
python convert_onnx_to_tensorrt/convert_onnx_to_tensorrt.py --dynamic_axes True --onnx_model_path onnx_output_explicit.onnx --batch_size {batch_size} --tensorrt_engine_path FP16_explicit.engine --engine_precision FP16
```
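
  - For reference, the explicit-batch FP16 conversion above corresponds roughly to the following trtexec invocation (a sketch, assuming the ONNX input tensor is named `input`, as exported by convert_pytorch_to_onnx.py, and the script's default min/opt/max batch sizes of 1/1/8)
```
trtexec --onnx=onnx_output_explicit.onnx --saveEngine=FP16_explicit.engine --fp16 --minShapes=input:1x3x224x224 --optShapes=input:1x3x224x224 --maxShapes=input:8x3x224x224
```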

#### - Implicit batch
  - Converting the PyTorch model to ONNX
```
python convert_pytorch_to_onnx/convert_pytorch_to_onnx.py --dynamic_axes False --output_path onnx_output_implicit.onnx --batch_size {batch_size}
```

  - Converting ONNX to TensorRT and testing time efficiency (FP32)
```
python convert_onnx_to_tensorrt/convert_onnx_to_tensorrt.py --dynamic_axes False --onnx_model_path onnx_output_implicit.onnx --batch_size {batch_size_of_implicit_batch_onnx_model} --tensorrt_engine_path FP32_implicit.engine --engine_precision FP32
```

  - Converting ONNX to TensorRT and testing time efficiency (FP16)
```
python convert_onnx_to_tensorrt/convert_onnx_to_tensorrt.py --dynamic_axes False --onnx_model_path onnx_output_implicit.onnx --batch_size {batch_size_of_implicit_batch_onnx_model} --tensorrt_engine_path FP16_implicit.engine --engine_precision FP16
```

#### - Comparison of time efficiency and output
  - Explicit batch test of the FP32 TensorRT engine
    - Batch size of inference data = 1
    - Batch size of optShapes = 1

  - Explicit batch test of the FP16 TensorRT engine
    - Batch size of inference data = 1
    - Batch size of optShapes = 1

  - Explicit batch test of the FP16 TensorRT engine
    - Batch size of inference data = 8
    - Batch size of optShapes = 1

  - Implicit batch test of the FP32 TensorRT engine
    - Batch size of inference data = 1

  - Implicit batch test of the FP16 TensorRT engine
    - Batch size of inference data = 1

References
=============

#### - Converting PyTorch models to ONNX

https://pytorch.org/docs/stable/onnx.html

#### - TensorRT

https://developer.nvidia.com/tensorrt

#### - TensorRT Release 21.08

https://docs.nvidia.com/deeplearning/tensorrt/container-release-notes/rel_21-08.html#rel_21-08

#### - TensorRT 8 sample code

https://github.com/NVIDIA/trt-samples-for-hackathon-cn/blob/master/cookbook/04-Parser/pyTorch-ONNX-TensorRT/main.py

#### - ImageNet 1000 samples

https://www.kaggle.com/ifigotin/imagenetmini-1000

Author
=============

#### - LinkedIn: https://www.linkedin.com/in/taeyong-kong-016bb2154

#### - Blog URL: https://blog.naver.com/qbxlvnf11

#### - Email: qbxlvnf11@google.com, qbxlvnf11@naver.com

--------------------------------------------------------------------------------
/convert_onnx_to_tensorrt/convert_onnx_to_tensorrt.py:
--------------------------------------------------------------------------------
import argparse
import os
import numpy as np
from skimage import io
from skimage.transform import resize
from collections import OrderedDict
from PIL import Image
import cv2
import time

# Torch
import torch
from torch import nn
import torchvision.models as models
from torch.utils.model_zoo import load_url as load_state_dict_from_url
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

# ONNX: pip install onnx, onnxruntime
try:
    import onnx
    import onnxruntime as rt
except ImportError as e:
    raise ImportError(f'Please install onnx and onnxruntime first. {e}')

# CUDA & TensorRT
#import pycuda.driver as cuda
from cuda import cuda  # cuda-python driver API (cuMemAlloc, cuMemcpyHtoD, ...)
import pycuda.autoinit  # imported only to create/activate a CUDA context
import tensorrt as trt

TRT_LOGGER = trt.Logger()

def parse_args():
    parser = argparse.ArgumentParser(description='Convert ONNX models to TensorRT')

    parser.add_argument('--device', help='cuda or not',
        default='cuda:0')

    # Sample image
    parser.add_argument('--batch_size', type=int, help='data batch size',
        default=1)
    parser.add_argument('--img_size', help='input size',
        default=[3, 224, 224])
    parser.add_argument('--sample_folder_path', help='sample image folder path',
        default='./imagenet-mini/train')
    #parser.add_argument('--sample_image_path', help='sample image path',
    #    default='./sample.jpg')

    # Model path
    parser.add_argument('--onnx_model_path', help='onnx model path',
        default='./onnx_model.onnx')
    parser.add_argument('--tensorrt_engine_path', help='tensorrt engine path',
        default='./tensorrt_engine.engine')

    # TensorRT engine params
    parser.add_argument('--dynamic_axes', help='dynamic batch input or output',
        default='True')
    parser.add_argument('--engine_precision', help='precision of TensorRT engine', choices=['FP32', 'FP16'],
        default='FP32')
    parser.add_argument('--min_engine_batch_size', type=int, help='minimum input batch size the engine accepts',
        default=1)
    parser.add_argument('--opt_engine_batch_size', type=int, help='batch size the engine is optimized for (the most frequently used size at inference)',
        default=1)
    parser.add_argument('--max_engine_batch_size', type=int, help='maximum input batch size the engine accepts',
        default=8)
    parser.add_argument('--engine_workspace', type=int, help='engine workspace size (currently unused; see the commented-out set_memory_pool_limit call below)',
        default=1024)

    args = string_to_bool(parser.parse_args())

    return args

def string_to_bool(args):

    # Note: `x in ('true')` is a substring test against the string 'true';
    # the one-element tuple makes this an exact (case-insensitive) match.
    if args.dynamic_axes.lower() in ('true',): args.dynamic_axes = True
    else: args.dynamic_axes = False

    return args

def get_transform(img_size):
    options = []
    options.append(transforms.Resize((img_size[1], img_size[2])))
    options.append(transforms.ToTensor())
    #options.append(transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]))
    transform = transforms.Compose(options)
    return transform

'''
def load_image(img_path, size):
    img_raw = io.imread(img_path)
    img_raw = np.rollaxis(img_raw, 2, 0)
    img_resize = resize(img_raw / 255, size, anti_aliasing=True)
    img_resize = img_resize.astype(np.float32)
    return img_resize, img_raw
'''

def load_image_folder(folder_path, img_size, batch_size):
    transforming = get_transform(img_size)
    dataset = datasets.ImageFolder(folder_path, transform=transforming)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=1)
    data_iter = iter(data_loader)
    torch_images, class_list = next(data_iter)
    print('class:', class_list)
    print('torch images size:', torch_images.size())
    save_image(torch_images[0], 'sample.png')

    return torch_images.cpu().numpy()
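
# build_engine: parse the ONNX file into an explicit-batch TensorRT network,
# attach one optimization profile for the input tensor, and serialize the
# engine to disk. With --dynamic_axes True the profile spans
# [minShapes, optShapes, maxShapes]; otherwise all three are pinned to the
# fixed batch size baked into the ONNX graph.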
def build_engine(onnx_model_path, tensorrt_engine_path, engine_precision, dynamic_axes,
                 img_size, batch_size, min_engine_batch_size, opt_engine_batch_size, max_engine_batch_size):

    # Builder
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    profile = builder.create_optimization_profile()
    config = builder.create_builder_config()
    #config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 3 << 30)
    # Set FP16
    if engine_precision == 'FP16':
        config.set_flag(trt.BuilderFlag.FP16)

    # ONNX parser
    parser = trt.OnnxParser(network, logger)
    if not os.path.exists(onnx_model_path):
        print("Failed finding ONNX file!")
        exit()
    print("Succeeded finding ONNX file!")
    with open(onnx_model_path, "rb") as model:
        if not parser.parse(model.read()):
            print("Failed parsing .onnx file!")
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            exit()
        print("Succeeded parsing .onnx file!")

    # Input
    inputTensor = network.get_input(0)
    print('inputTensor.name:', inputTensor.name)
    # Dynamic batch (min, opt, max)
    if dynamic_axes:
        profile.set_shape(inputTensor.name,
                          (min_engine_batch_size, img_size[0], img_size[1], img_size[2]),
                          (opt_engine_batch_size, img_size[0], img_size[1], img_size[2]),
                          (max_engine_batch_size, img_size[0], img_size[1], img_size[2]))
        print('Set dynamic')
    else:
        profile.set_shape(inputTensor.name,
                          (batch_size, img_size[0], img_size[1], img_size[2]),
                          (batch_size, img_size[0], img_size[1], img_size[2]),
                          (batch_size, img_size[0], img_size[1], img_size[2]))
    config.add_optimization_profile(profile)
    #network.unmark_output(network.get_output(0))

    # Write engine
    engineString = builder.build_serialized_network(network, config)
    if engineString is None:
        print("Failed building engine!")
        exit()
    print("Succeeded building engine!")
    with open(tensorrt_engine_path, "wb") as f:
        f.write(engineString)
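
# trt_inference: explicit-copy inference with the cuda-python driver API:
#   1. size host buffers (numpy) from the context's binding shapes
#   2. allocate matching device buffers with cuMemAlloc
#   3. copy inputs host -> device (cuMemcpyHtoD)
#   4. run context.execute_v2 on the list of device pointers
#   5. copy outputs device -> host (cuMemcpyDtoH) and free device memory
# (binding_is_input / get_binding_shape are the TensorRT 8 binding APIs.)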
def trt_inference(engine, context, data):

    nInput = np.sum([engine.binding_is_input(i) for i in range(engine.num_bindings)])
    nOutput = engine.num_bindings - nInput
    print('nInput:', nInput)
    print('nOutput:', nOutput)

    for i in range(nInput):
        print("Bind[%2d]:i[%2d]->" % (i, i), engine.get_binding_dtype(i), engine.get_binding_shape(i), context.get_binding_shape(i), engine.get_binding_name(i))
    for i in range(nInput, nInput + nOutput):
        print("Bind[%2d]:o[%2d]->" % (i, i - nInput), engine.get_binding_dtype(i), engine.get_binding_shape(i), context.get_binding_shape(i), engine.get_binding_name(i))

    bufferH = []
    bufferH.append(np.ascontiguousarray(data.reshape(-1)))

    for i in range(nInput, nInput + nOutput):
        bufferH.append(np.empty(context.get_binding_shape(i), dtype=trt.nptype(engine.get_binding_dtype(i))))

    bufferD = []
    for i in range(nInput + nOutput):
        # cuMemAlloc returns (err, device_ptr); keep only the pointer
        bufferD.append(cuda.cuMemAlloc(bufferH[i].nbytes)[1])

    for i in range(nInput):
        cuda.cuMemcpyHtoD(bufferD[i], bufferH[i].ctypes.data, bufferH[i].nbytes)

    context.execute_v2(bufferD)

    for i in range(nInput, nInput + nOutput):
        cuda.cuMemcpyDtoH(bufferH[i].ctypes.data, bufferD[i], bufferH[i].nbytes)

    for b in bufferD:
        cuda.cuMemFree(b)

    return bufferH

def main():
    args = parse_args()

    # Sample images (folder)
    print(args.sample_folder_path)
    img_resize = load_image_folder(args.sample_folder_path, args.img_size, args.batch_size).astype(np.float32)
    '''
    # Sample (one image)
    print(args.sample_image_path)
    img_resize, img_raw = load_image(args.sample_image_path, args.img_size)
    '''

    print('inference image size:', img_resize.shape)

    # Build TensorRT engine
    build_engine(args.onnx_model_path, args.tensorrt_engine_path, args.engine_precision, args.dynamic_axes,
                 args.img_size, args.batch_size, args.min_engine_batch_size, args.opt_engine_batch_size, args.max_engine_batch_size)

    # Read the engine from the file and deserialize
    with open(args.tensorrt_engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()

    # TensorRT inference: bind the actual input shape for this run
    # (required when the engine was built with a dynamic profile)
    context.set_binding_shape(0, (args.batch_size, args.img_size[0], args.img_size[1], args.img_size[2]))

    trt_start_time = time.time()
    trt_outputs = trt_inference(engine, context, img_resize)
    trt_outputs = np.array(trt_outputs[1]).reshape(args.batch_size, -1)
    trt_end_time = time.time()
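
    # Note: these are rough wall-clock timings for a single batch. A stricter
    # benchmark would warm up first, average over many iterations, and
    # synchronize the GPU (e.g. torch.cuda.synchronize()) before reading
    # time.time().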

    # ONNX inference
    onnx_model = onnx.load(args.onnx_model_path)
    sess = rt.InferenceSession(args.onnx_model_path)

    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [
        node.name for node in onnx_model.graph.initializer
    ]
    net_feed_input = list(set(input_all) - set(input_initializer))
    assert len(net_feed_input) == 1

    sess_input = sess.get_inputs()[0].name
    sess_output = sess.get_outputs()[0].name

    onnx_start_time = time.time()
    onnx_result = sess.run([sess_output], {sess_input: img_resize})[0]
    onnx_end_time = time.time()

    # PyTorch inference
    resnet18 = models.resnet18(pretrained=True).to(args.device)
    resnet18.eval()

    img_resize_torch = torch.Tensor(img_resize).to(args.device)
    torch_start_time = time.time()
    with torch.no_grad():  # no gradients needed for inference
        pytorch_result = resnet18(img_resize_torch)
    torch_end_time = time.time()
    pytorch_result = pytorch_result.detach().cpu().numpy()

    ## Comparison of TensorRT, ONNX, and PyTorch outputs

    # Time efficiency & output
    print('--pytorch--')
    print(pytorch_result.shape)  # (batch_size, 1000)
    print(pytorch_result[0][:10])
    print(np.argmax(pytorch_result, axis=1))
    print('Time:', torch_end_time - torch_start_time)

    print('--onnx--')
    print(onnx_result.shape)
    print(onnx_result[0][:10])
    print(np.argmax(onnx_result, axis=1))
    print('Time:', onnx_end_time - onnx_start_time)

    print('--tensorrt--')
    print(trt_outputs.shape)
    print(trt_outputs[0][:10])
    print(np.argmax(trt_outputs, axis=1))
    print('Time:', trt_end_time - trt_start_time)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/convert_pytorch_to_onnx/convert_pytorch_to_onnx.py:
--------------------------------------------------------------------------------
import argparse
import os
import numpy as np
from skimage import io
from skimage.transform import resize
from collections import OrderedDict
from PIL import Image
import cv2
import time

# Torch
import torch
from torch import nn
import torchvision.models as models
from torch.utils.model_zoo import load_url as load_state_dict_from_url
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

# ONNX: pip install onnx, onnxruntime
try:
    import onnx
    import onnxruntime as rt
except ImportError as e:
    raise ImportError(f'Please install onnx and onnxruntime first. {e}')
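
# Script flow: export a pretrained ResNet18 to ONNX (optionally with a dynamic
# batch axis), validate the exported graph, then run the same batch through
# both PyTorch and onnxruntime and compare outputs and latency.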
def parse_args():
    parser = argparse.ArgumentParser(description='Convert PyTorch models to ONNX')

    parser.add_argument('--device', help='cuda or not',
        default='cuda:0')

    # Sample image
    parser.add_argument('--batch_size', type=int, help='onnx sample batch size',
        default=1)
    parser.add_argument('--img_size', help='image size',
        default=[3, 224, 224])
    parser.add_argument('--sample_folder_path', help='sample image folder path',
        default='./imagenet-mini/train/')
    #parser.add_argument('--sample_image_path', help='sample image path',
    #    default='./sample.jpg')

    parser.add_argument('--output_path', help='onnx model path',
        default='./onnx_output.onnx')

    # ONNX params
    parser.add_argument('--dynamic_axes', help='dynamic batch input or output',
        default='True')
    parser.add_argument('--keep_initializers_as_inputs', help='If True, all the initializers (typically corresponding to parameters) in the exported graph are also added as inputs to the graph. If False, initializers are not added as inputs, and only the non-parameter inputs are.',
        default='True')
    parser.add_argument('--export_params', help='If True, all parameters are exported. Set this to False to export an untrained model.',
        default='True')
    parser.add_argument('--opset_version', type=int, help='opset version',
        default=11)

    args = string_to_bool(parser.parse_args())

    return args

def string_to_bool(args):

    # Compare against one-element tuples: `x in ('true')` is a substring test
    # against the string 'true', not an equality check.
    if args.dynamic_axes.lower() in ('true',): args.dynamic_axes = True
    else: args.dynamic_axes = False

    if args.keep_initializers_as_inputs.lower() in ('true',): args.keep_initializers_as_inputs = True
    else: args.keep_initializers_as_inputs = False

    if args.export_params.lower() in ('true',): args.export_params = True
    else: args.export_params = False

    return args

def get_transform(img_size):
    options = []
    options.append(transforms.Resize((img_size[1], img_size[2])))
    options.append(transforms.ToTensor())
    #options.append(transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]))
    transform = transforms.Compose(options)
    return transform

'''
def load_image(img_path, size):
    img_raw = io.imread(img_path)
    img_raw = np.rollaxis(img_raw, 2, 0)
    img_resize = resize(img_raw / 255, size, anti_aliasing=True)
    img_resize = img_resize.astype(np.float32)
    return img_resize, img_raw
'''

def load_image_folder(folder_path, img_size, batch_size):
    transforming = get_transform(img_size)
    dataset = datasets.ImageFolder(folder_path, transform=transforming)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=1)
    data_iter = iter(data_loader)
    torch_images, class_list = next(data_iter)
    save_image(torch_images[0], 'sample.png')

    return torch_images.cpu().numpy()

if __name__ == '__main__':
    args = parse_args()

    # Load pretrained model
    resnet18 = models.resnet18(pretrained=True).to(args.device)

    '''
    fc = nn.Sequential(OrderedDict([
        ('fc1', nn.Linear(512, 1000)),
        ('output', nn.Softmax(dim=1))
    ]))
    resnet18.fc = fc
    '''

    print(resnet18)

    resnet18.eval()

    # Sample images (folder)
    print(args.sample_folder_path)
    img_resize = load_image_folder(args.sample_folder_path, args.img_size, args.batch_size).astype(np.float32)
    '''
    # Sample (one image)
    print(args.sample_image_path)
    img_resize, img_raw = load_image(args.sample_image_path, args.img_size)
    '''

    sample_input = torch.randn(args.batch_size, args.img_size[0], args.img_size[1], args.img_size[2]).to(args.device)
    print('inference image size:', img_resize.shape, 'sample input size:', sample_input.shape)

    if args.dynamic_axes:
        # Dynamic input: mark dim 0 of both input and output as a symbolic batch axis
        dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}

        # Export onnx
        torch.onnx.export(
            resnet18,
            sample_input,
            args.output_path,
            export_params=args.export_params,
            keep_initializers_as_inputs=args.keep_initializers_as_inputs,
            opset_version=args.opset_version,
            input_names=['input'],      # input tensor name
            output_names=['output'],    # output tensor name
            dynamic_axes=dynamic_axes,  # dynamic batch axis
            verbose=False)
    else:
        # Export onnx (fixed batch size)
        torch.onnx.export(
            resnet18,
            sample_input,
            args.output_path,
            export_params=args.export_params,
            keep_initializers_as_inputs=args.keep_initializers_as_inputs,
            opset_version=args.opset_version,
            input_names=['input'],    # input tensor name
            output_names=['output'],  # output tensor name
            verbose=False)
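
    # (Optional) sanity check of the exported dynamic batch axis: with
    # --dynamic_axes True, the first input dim is symbolic, e.g.
    #   onnx.load(args.output_path).graph.input[0].type.tensor_type.shape.dim[0]
    # has dim_param == 'batch_size' instead of a fixed dim_value.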
    # Load the ONNX model
    onnx_model = onnx.load(args.output_path)
    sess = rt.InferenceSession(args.output_path)

    # Check that the IR is well formed
    onnx.checker.check_model(onnx_model)

    # Write a human-readable representation of the graph
    with open('OnnxShape.txt', 'w') as f:
        f.write(f"{onnx.helper.printable_graph(onnx_model.graph)}")

    ## Comparison of ONNX and PyTorch outputs
    # PyTorch results
    img_resize_torch = torch.Tensor(img_resize).to(args.device)
    torch_start_time = time.time()
    with torch.no_grad():  # no gradients needed for inference
        pytorch_result = resnet18(img_resize_torch)
    torch_end_time = time.time()
    pytorch_result = pytorch_result.detach().cpu().numpy()

    # ONNX results
    input_all = [node.name for node in onnx_model.graph.input]
    input_initializer = [
        node.name for node in onnx_model.graph.initializer
    ]
    net_feed_input = list(set(input_all) - set(input_initializer))
    assert len(net_feed_input) == 1

    sess_input = sess.get_inputs()[0].name
    sess_output = sess.get_outputs()[0].name

    onnx_start_time = time.time()
    onnx_result = sess.run([sess_output], {sess_input: img_resize})[0]
    onnx_end_time = time.time()

    print('--pytorch--')
    print(pytorch_result.shape)  # (batch_size, 1000)
    print(pytorch_result[0][:10])
    print(np.argmax(pytorch_result, axis=1))
    print('Time:', torch_end_time - torch_start_time)

    print('--onnx--')
    print(onnx_result.shape)
    print(onnx_result[0][:10])
    print(np.argmax(onnx_result, axis=1))
    print('Time:', onnx_end_time - onnx_start_time)

    # Comparison
    assert np.allclose(
        pytorch_result, onnx_result,
        atol=1.e-2), 'The outputs are different (PyTorch and ONNX)'
    print('The numerical values are the same (PyTorch and ONNX)')
--------------------------------------------------------------------------------