├── .gitignore
├── README.md
└── simple_mlp
    ├── README.md
    ├── main.py
    ├── mlp.py
    ├── mlp_trt.py
    ├── trt_utils
    │   ├── common.py
    │   └── convertor.py
    └── weights
        ├── simple_mlp.engine
        ├── simple_mlp.onnx
        └── simple_mlp.pth

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
*.pyc
*.pyo
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SimpleMLP-TensorRT

A simple guide and tutorial on using TensorRT to accelerate a simple Multi-Layer Perceptron (MLP). This repository includes step-by-step instructions, code examples, and explanations to help you get started with TensorRT for neural network models.

**TensorRT Version**: Ensure you have TensorRT **version 8.6** or later installed; the code in this guide targets the API available from 8.6 onwards.

## Overview of Using TensorRT

1. **[Define Your Model](#Pytorch-model)**: Start by defining and training your model in PyTorch.
2. **[Convert to ONNX](#Convert-pytorch-model-to-ONNX)**: Convert your PyTorch model to the ONNX format. This step is necessary for both static and dynamic shape configurations.
3. **[Build TensorRT Engine](#Building-Engine)**:
   - **Static Shapes**: Build an engine with predefined input and output shapes for maximum optimization.
   - **Dynamic Shapes**: Build an engine with profile settings that support varying input and output shapes, allowing flexibility for different scenarios.
4. **[Inference from the Engine](#Inference)**:
   - **Create Execution Context**: Generate a context from the TensorRT engine to manage inference execution.
   - **Allocate Memory Buffers**: Allocate memory for inputs and outputs in both host and device memory based on the shapes.
   - **Transfer Data and Run Inference**:
     - Transfer input data from host to device memory.
     - Execute inference using the TensorRT context.
     - Transfer the output data from the device back to host memory.
   - **Post-Processing**: Reshape the 1D output array to the desired dimensions for further use.

By following these steps, you can leverage TensorRT to significantly improve the performance of your neural network models on NVIDIA GPUs.

## Static vs Dynamic Shapes in TensorRT

In TensorRT, the term "shapes" refers to the dimensions of the input and output tensors that a neural network processes.
When working with TensorRT, understanding the differences between static and dynamic shapes is crucial for optimizing and deploying models effectively.

**Note**: Working with different batch sizes requires dynamic shapes.

### Choosing Between Static and Dynamic Shapes

The choice between static and dynamic shapes depends on the specific requirements of your application:
- Use **static shapes** if your input dimensions are fixed and known in advance, and you need maximum performance.
- Use **dynamic shapes** if your application needs to handle inputs of varying sizes and you require flexibility and scalability.

In the following sections, we will provide examples of how to implement both static and dynamic shapes using TensorRT for a simple MLP model.

## Pytorch model

First we define a simple Multi-Layer Perceptron (MLP) model using PyTorch.
This model will be used as the basis for our TensorRT conversion and inference examples.
We need to train the model and save its weights; here we save them as a `.pth` file.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Define the MLP model
class SimpleMLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Parameters
input_size = 784  # Example for MNIST dataset (28x28 images)
hidden_size = 32
num_classes = 10
num_epochs = 10
batch_size = 1
learning_rate = 0.001

# Dummy dataset
x_train = torch.randn(600, input_size)
y_train = torch.randint(0, num_classes, (600,))

train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Initialize the model, loss function, and optimizer
model = SimpleMLP(input_size, hidden_size, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# Save the trained model
torch.save(model.state_dict(), 'simple_mlp.pth')

print('Model training complete and saved to simple_mlp.pth')
```

## Convert pytorch model to ONNX

To leverage the power of TensorRT for both static and dynamic shapes, we first need to convert our PyTorch model to the ONNX (Open Neural Network Exchange) format. ONNX is an open format built to represent machine learning models, enabling them to be transferred between various frameworks and optimizers.

### Steps to Convert a PyTorch Model to ONNX

Regardless of whether you are working with static or dynamic shapes, the process of converting a PyTorch model to ONNX involves the following steps:

1. **Define Your Model**: Ensure your model is defined and trained in PyTorch.
2. **Create Dummy Input**: Prepare a dummy input tensor with the appropriate shape. For dynamic shapes, specify the axes that can vary.
3. **Export to ONNX**: Use the `torch.onnx.export` function to convert the model, as shown in the code below.
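Once the export below has produced `simple_mlp_dynamic.onnx`, it is worth a quick sanity check before building an engine. This is an optional sketch; it assumes the `onnx` and `onnxruntime` Python packages are installed (neither is required anywhere else in this guide):

```python
import numpy as np
import onnx
import onnxruntime as ort

onnx_path = "simple_mlp_dynamic.onnx"

# Structural check of the exported graph
onnx.checker.check_model(onnx.load(onnx_path))

# Run the model with ONNX Runtime on an arbitrary batch size
session = ort.InferenceSession(onnx_path)
dummy = np.random.randn(4, 784).astype(np.float32)  # a batch of 4 works thanks to the dynamic axis
logits = session.run(None, {"input": dummy})[0]
print(logits.shape)  # expected: (4, 10)
```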

```python
import mlp
import torch


# Convert the model to ONNX format

# Parameters
input_size = 784  # Example for MNIST dataset (28x28 images)
hidden_size = 32
num_classes = 10

# Dummy input for the model
dummy_input = torch.randn(1, input_size)
onnx_file_path = "simple_mlp_dynamic.onnx"


model = mlp.SimpleMLP(input_size, hidden_size, num_classes)
# Load the model's weights (using the final epoch as an example)
model.load_state_dict(torch.load('simple_mlp.pth'))
model.eval()  # Set the model to evaluation mode

# Export the model
torch.onnx.export(
    model,
    dummy_input,
    onnx_file_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input' : {0: 'input_batch_size'},   # for static shapes we don't need dynamic_axes
                  'output': {0: 'output_batch_size'}},
    opset_version=11
)

print(f'Model has been converted to ONNX and saved to {onnx_file_path}')
```

## Building Engine

After converting your PyTorch model to the ONNX format, the next step is to build a TensorRT engine. The engine is a highly optimized, platform-specific model that can run inference efficiently on NVIDIA GPUs.

**Note**: The engine should be built on the target device, because TensorRT optimizes the model for your specific GPU architecture.

**Note**: Building the TensorRT engine can take a significant amount of time because it involves searching through various algorithms to optimize the inference performance.

### Steps to Build Engine

1. **Load the ONNX Model**: Read the ONNX file into memory.
2. **Create TensorRT Builder and Network**: Initialize the TensorRT builder and network.
3. **Parse the ONNX Model**: Parse the ONNX model to populate the TensorRT network.
4. **Build the Engine**: Configure the builder settings and build the engine.

For dynamic shapes, we need to set an optimization profile that includes minimum, optimal, and maximum shape values for the dynamic dimensions. During inference, the input shapes must not exceed the maximum or fall below the minimum values specified. The closer the input shapes are to the optimal values, the more performance benefits we can achieve with TensorRT.

```python
import tensorrt as trt

# Initialize TensorRT logger and builder
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(TRT_LOGGER)
config = builder.create_builder_config()


# Set up a timing cache
cache = config.create_timing_cache(b"")
config.set_timing_cache(cache, ignore_mismatch=False)


# Explicit batch: the batch dimension is part of the tensor shapes,
# so the batch size is controlled by the optimization profile below.
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
parser = trt.OnnxParser(network, TRT_LOGGER)

path_onnx_model = "simple_mlp_dynamic.onnx"

with open(path_onnx_model, "rb") as f:
    if not parser.parse(f.read()):
        print(f"ERROR: Failed to parse the ONNX file {path_onnx_model}")
        for error in range(parser.num_errors):
            print(parser.get_error(error))

# The network exposes all of its inputs; it is best to refer to them by name.
# Since we only have one input, we simply take the first one.
input_tensor = network.get_input(0)

# Set an optimization profile for dynamic shapes (this step is not needed for static shapes)
profile = builder.create_optimization_profile()
min_shape = [1, 784]
opt_shape = [32, 784]
max_shape = [64, 784]
profile.set_shape(input_tensor.name, min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)

# Check whether fast FP16 is available
# print(builder.platform_has_fast_fp16)
config.set_flag(trt.BuilderFlag.FP16)

# Build and serialize the engine
engine_bytes = builder.build_serialized_network(network, config)

engine_path = "simple_mlp_dynamic.engine"
with open(engine_path, "wb") as f:
    f.write(engine_bytes)
```

## Inference

During inference, the following steps need to be performed:

1. **Create Context**: Generate a context from the pre-built TensorRT engine.
2. **Create CUDA Stream**: Create a CUDA stream to handle synchronization between host and device.
3. **Allocate Buffers**: Allocate host and device memory buffers for the inputs and outputs, sized according to the input/output shapes and the data type used.
4. **Transfer Input Data**: Move the input data from the host to the device memory.
5. **Run Inference**: Execute the inference on the device using the created context.
6. **Retrieve Output Data**: Transfer the output data from the device memory back to the host memory.

These steps ensure that the data flows correctly through the TensorRT engine for efficient inference and that synchronization between the host and device is properly managed.

### Static shapes

For static shapes, the input and output shapes are fixed at build time and can be read directly from the engine's I/O bindings. We simply need to allocate memory buffers based on these sizes.
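As a quick way to inspect these shapes, you can list an engine's I/O tensors with the tensor-name API available since TensorRT 8.5. This is a small sketch; `engine` is assumed to be an already-deserialized `ICudaEngine`, and `print_io_tensors` is just an illustrative helper name:

```python
import tensorrt as trt

def print_io_tensors(engine: trt.ICudaEngine) -> None:
    # List every I/O tensor of the engine with its mode, shape, and dtype.
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        print(name,
              engine.get_tensor_mode(name),   # trt.TensorIOMode.INPUT or OUTPUT
              engine.get_tensor_shape(name),  # e.g. (1, 784) for a static input, (-1, 784) for a dynamic batch
              engine.get_tensor_dtype(name))  # e.g. trt.DataType.FLOAT
```

The full example below sticks to the older binding-based helpers (`get_binding_shape`, `get_binding_dtype`), which still work in TensorRT 8.x but are deprecated in favour of the calls above.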
240 | ```python 241 | import numpy as np 242 | import tensorrt as trt 243 | from cuda import cuda, cudart 244 | import ctypes 245 | from typing import Optional, List 246 | 247 | ### Cudart keypoint handler 248 | def check_cuda_err(err): 249 | if isinstance(err, cuda.CUresult): 250 | if err != cuda.CUresult.CUDA_SUCCESS: 251 | raise RuntimeError("Cuda Error: {}".format(err)) 252 | if isinstance(err, cudart.cudaError_t): 253 | if err != cudart.cudaError_t.cudaSuccess: 254 | raise RuntimeError("Cuda Runtime Error: {}".format(err)) 255 | else: 256 | raise RuntimeError("Unknown error type: {}".format(err)) 257 | 258 | def cuda_call(call): 259 | err, res = call[0], call[1:] 260 | check_cuda_err(err) 261 | if len(res) == 1: 262 | res = res[0] 263 | return res 264 | 265 | 266 | ### Class for transfer data between host and device memory 267 | class HostDeviceMem: 268 | """Pair of host and device memory, where the host memory is wrapped in a numpy array""" 269 | def __init__(self, size: int, dtype: np.dtype): 270 | nbytes = size * dtype.itemsize 271 | host_mem = cuda_call(cudart.cudaMallocHost(nbytes)) 272 | pointer_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype)) 273 | 274 | self._host = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,)) 275 | self._device = cuda_call(cudart.cudaMalloc(nbytes)) 276 | self._nbytes = nbytes 277 | 278 | @property 279 | def host(self) -> np.ndarray: 280 | return self._host 281 | 282 | @host.setter 283 | def host(self, arr: np.ndarray): 284 | if arr.size > self.host.size: 285 | raise ValueError( 286 | f"Tried to fit an array of size {arr.size} into host memory of size {self.host.size}" 287 | ) 288 | #np.copyto(self.host[:arr.size], arr.flat, casting='safe') 289 | np.copyto(self.host[:arr.size], arr.flat) 290 | 291 | @property 292 | def device(self) -> int: 293 | return self._device 294 | 295 | @property 296 | def nbytes(self) -> int: 297 | return self._nbytes 298 | 299 | def __str__(self): 300 | return f"Host:\n{self.host}\nDevice:\n{self.device}\nSize:\n{self.nbytes}\n" 301 | 302 | def __repr__(self): 303 | return self.__str__() 304 | 305 | def free(self): 306 | cuda_call(cudart.cudaFree(self.device)) 307 | cuda_call(cudart.cudaFreeHost(self.host.ctypes.data)) 308 | 309 | 310 | # Allocates all buffers required for an engine, i.e. host/device inputs/outputs. 311 | def allocate_buffers(engine: trt.ICudaEngine): 312 | inputs = [] 313 | outputs = [] 314 | bindings = [] 315 | stream = cuda_call(cudart.cudaStreamCreate()) 316 | for binding in engine: 317 | size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size 318 | dtype = trt.nptype(engine.get_binding_dtype(binding)) 319 | 320 | # Allocate host and device buffers 321 | bindingMemory = HostDeviceMem(size, dtype) 322 | 323 | # Append the device buffer to device bindings. 324 | bindings.append(int(bindingMemory.device)) 325 | 326 | # Append to the appropriate list. 
        if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
            inputs.append(bindingMemory)
        else:
            outputs.append(bindingMemory)

    return inputs, outputs, bindings, stream


# Frees the resources allocated in allocate_buffers
def free_buffers(inputs: List[HostDeviceMem], outputs: List[HostDeviceMem], stream: cudart.cudaStream_t):
    for mem in inputs + outputs:
        mem.free()
    cuda_call(cudart.cudaStreamDestroy(stream))


# Wrapper for cudaMemcpy which infers copy size and does error checking
def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(cudart.cudaMemcpy(device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice))


# Wrapper for cudaMemcpy which infers copy size and does error checking
def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(cudart.cudaMemcpy(host_arr, device_ptr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost))


def _do_inference_base(inputs, outputs, stream, execute_async):
    # Transfer input data to the GPU.
    kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
    [cuda_call(cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, kind, stream)) for inp in inputs]
    # Run inference.
    execute_async()
    # Transfer predictions back from the GPU.
    kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
    [cuda_call(cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, kind, stream)) for out in outputs]
    # Synchronize the stream
    cuda_call(cudart.cudaStreamSynchronize(stream))
    # Return only the host outputs.
    return [out.host for out in outputs]


# This function is generalized for multiple inputs/outputs for full dimension networks.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference(context, bindings, inputs, outputs, stream):
    def execute_async():
        # Explicit-batch engines must use execute_async_v2; the legacy execute_async
        # belongs to the implicit-batch API and will fail on engines built with EXPLICIT_BATCH.
        context.execute_async_v2(bindings=bindings, stream_handle=stream)
    return _do_inference_base(inputs, outputs, stream, execute_async)

### Inference from TensorRT
# Function to load (deserialize) a TensorRT engine from a file
def load_engine(engine_file_path):
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

# This static example assumes an engine built with fixed shapes,
# i.e. the ONNX model was exported without dynamic_axes and no optimization profile was added.
engine_file_path = 'simple_mlp_static.engine'
engine = load_engine(engine_file_path)

# Create execution context
context = engine.create_execution_context()

# Allocate buffers
inputs, outputs, bindings, stream = allocate_buffers(engine)

# Dummy input data
input_data = np.random.randn(1, 784).astype(np.float32)

# Copy the input data into the pinned host buffer
np.copyto(inputs[0].host, input_data.ravel())

# Run inference
output_data = do_inference(context, bindings, inputs, outputs, stream)
print("Inference output:", output_data)

# Free allocated memory
free_buffers(inputs, outputs, stream)
```

### Dynamic shapes

For dynamic shapes, we set the input shapes on the execution context at inference time and then allocate memory buffers accordingly, based on those shapes.

**Note**: We can define multiple profile settings with different minimum, optimal, and maximum shapes within a single TensorRT engine. This allows the engine to be used in various scenarios without needing to rebuild it, which can be time-consuming.
For example, if we have two profile settings and three bindings for input and output, to access the first input of the second profile, you would use `get_binding(3 (number of bindings) + 0 (first input))`. 405 | 406 | **Note**: The output from TensorRT is a 1D array. We need to reshape this array to the desired dimensions to use it as the output of our model. 407 | 408 | ```python 409 | import numpy as np 410 | import tensorrt as trt 411 | from cuda import cuda, cudart 412 | import ctypes 413 | from typing import Optional, List 414 | 415 | 416 | ### Cudart keypoint handler 417 | def check_cuda_err(err): 418 | if isinstance(err, cuda.CUresult): 419 | if err != cuda.CUresult.CUDA_SUCCESS: 420 | raise RuntimeError("Cuda Error: {}".format(err)) 421 | if isinstance(err, cudart.cudaError_t): 422 | if err != cudart.cudaError_t.cudaSuccess: 423 | raise RuntimeError("Cuda Runtime Error: {}".format(err)) 424 | else: 425 | raise RuntimeError("Unknown error type: {}".format(err)) 426 | 427 | def cuda_call(call): 428 | err, res = call[0], call[1:] 429 | check_cuda_err(err) 430 | if len(res) == 1: 431 | res = res[0] 432 | return res 433 | 434 | 435 | ### Class for transfer data between host and device memory 436 | class HostDeviceMem: 437 | """Pair of host and device memory, where the host memory is wrapped in a numpy array""" 438 | def __init__(self, size: int, dtype: np.dtype): 439 | nbytes = size * dtype.itemsize 440 | host_mem = cuda_call(cudart.cudaMallocHost(nbytes)) 441 | pointer_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype)) 442 | 443 | self._host = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,)) 444 | self._device = cuda_call(cudart.cudaMalloc(nbytes)) 445 | self._nbytes = nbytes 446 | 447 | @property 448 | def host(self) -> np.ndarray: 449 | return self._host 450 | 451 | @host.setter 452 | def host(self, arr: np.ndarray): 453 | if arr.size > self.host.size: 454 | raise ValueError( 455 | f"Tried to fit an array of size {arr.size} into host memory of size {self.host.size}" 456 | ) 457 | #np.copyto(self.host[:arr.size], arr.flat, casting='safe') 458 | np.copyto(self.host[:arr.size], arr.flat) 459 | 460 | @property 461 | def device(self) -> int: 462 | return self._device 463 | 464 | @property 465 | def nbytes(self) -> int: 466 | return self._nbytes 467 | 468 | def __str__(self): 469 | return f"Host:\n{self.host}\nDevice:\n{self.device}\nSize:\n{self.nbytes}\n" 470 | 471 | def __repr__(self): 472 | return self.__str__() 473 | 474 | def free(self): 475 | cuda_call(cudart.cudaFree(self.device)) 476 | cuda_call(cudart.cudaFreeHost(self.host.ctypes.data)) 477 | 478 | 479 | # Allocates all buffers required for an engine, i.e. host/device inputs/outputs. 480 | # If engine uses dynamic shapes, specify a profile to find the maximum input & output size. 481 | def allocate_buffers(engine: trt.ICudaEngine, inputs_shape): 482 | inputs = [] 483 | outputs = [] 484 | bindings = [] 485 | stream = cuda_call(cudart.cudaStreamCreate()) 486 | tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)] 487 | for shape, binding in zip(inputs_shape, tensor_names): 488 | size = trt.volume(shape) 489 | dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(binding))) 490 | 491 | # Allocate host and device buffers 492 | bindingMemory = HostDeviceMem(size, dtype) 493 | 494 | # Append the device buffer to device bindings. 495 | bindings.append(int(bindingMemory.device)) 496 | 497 | # Append to the appropriate list. 
498 | if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT: 499 | inputs.append(bindingMemory) 500 | else: 501 | outputs.append(bindingMemory) 502 | 503 | 504 | return inputs, outputs, bindings, stream 505 | 506 | 507 | # Frees the resources allocated in allocate_buffers 508 | def free_buffers(inputs: List[HostDeviceMem], outputs: List[HostDeviceMem], stream: cudart.cudaStream_t): 509 | for mem in inputs + outputs: 510 | mem.free() 511 | cuda_call(cudart.cudaStreamDestroy(stream)) 512 | 513 | 514 | # Wrapper for cudaMemcpy which infers copy size and does error checking 515 | def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray): 516 | nbytes = host_arr.size * host_arr.itemsize 517 | cuda_call(cudart.cudaMemcpy(device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice)) 518 | 519 | 520 | # Wrapper for cudaMemcpy which infers copy size and does error checking 521 | def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int): 522 | nbytes = host_arr.size * host_arr.itemsize 523 | cuda_call(cudart.cudaMemcpy(host_arr, device_ptr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost)) 524 | 525 | 526 | def _do_inference_base(inputs, outputs, stream, execute_async): 527 | # Transfer input data to the GPU. 528 | kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice 529 | [cuda_call(cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, kind, stream)) for inp in inputs] 530 | # Run inference. 531 | execute_async() 532 | # Transfer predictions back from the GPU. 533 | kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost 534 | [cuda_call(cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, kind, stream)) for out in outputs] 535 | # Synchronize the stream 536 | cuda_call(cudart.cudaStreamSynchronize(stream)) 537 | # Return only the host outputs. 538 | return [out.host for out in outputs] 539 | 540 | 541 | # This function is generalized for multiple inputs/outputs for full dimension networks. 542 | # inputs and outputs are expected to be lists of HostDeviceMem objects. 
def do_inference_v2(context, bindings, inputs, outputs, stream):
    def execute_async():
        context.execute_async_v2(bindings=bindings, stream_handle=stream)
    return _do_inference_base(inputs, outputs, stream, execute_async)


### Inference from TensorRT
# Function to load (deserialize) a TensorRT engine from a file
def load_engine(engine_file_path):
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

# Load the engine built in the previous section
engine_file_path = "simple_mlp_dynamic.engine"
engine = load_engine(engine_file_path)


# Create context
context = engine.create_execution_context()

# In the dynamic version we first need to set the input shape on the context,
# and only then allocate buffers that match that shape.
# Dummy input data
input_data = np.random.randn(40, 784).astype(np.float32)

# Set the input shape for this inference call
input_name = engine.get_tensor_name(0)
context.set_input_shape(input_name, input_data.shape)

# Shapes used for memory allocation.
# We need to know the model's output shape for the chosen batch size in advance.
model_shapes = [input_data.shape, (input_data.shape[0], 10)]  # [input_shape=(40, 784), output_shape=(40, 10)]

# Allocate memory for inputs and outputs
inputs, outputs, bindings, stream = allocate_buffers(engine, model_shapes)

# Transfer input data to the allocated host buffer
np.copyto(inputs[0].host, input_data.ravel())

output_data = do_inference_v2(context, bindings, inputs, outputs, stream)

# Reshape the flat output to the desired shape.
# Copy it, since free_buffers below releases the pinned host memory;
# wrap it with torch.from_numpy if a torch tensor is needed.
output = output_data[0].reshape(input_data.shape[0], 10).copy()

# Free allocated memory
free_buffers(inputs, outputs, stream)
```

## Acknowledgements
[NVIDIA TensorRT documentation](https://developer.nvidia.com/tensorrt), [Inference example (common.py)](https://github.com/d246810g2000/tensorrt/blob/main/common.py)

--------------------------------------------------------------------------------
/simple_mlp/README.md:
--------------------------------------------------------------------------------
This part provides an example of a simple MLP model, including its conversion to ONNX, optimization into a TensorRT engine, and inference using the TensorRT context.

To use this with your own model, you will need to modify the MLP model, `build_onnx`, `build_engine` in `convertor.py`, and the inference part in the `SimpleMLPTRT` class to accommodate the input and output dimensions of your model. After making these adjustments, simply use `main.py` as shown in the example.
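If you want to confirm that the TensorRT engine reproduces the original network, you can compare its output against the PyTorch model on the same input. A rough sketch, assuming it is run from the `simple_mlp` directory with the weight/ONNX/engine paths used in `main.py`:

```python
import torch
from mlp import SimpleMLP
from mlp_trt import SimpleMLPTRT

# Reference PyTorch model
torch_model = SimpleMLP(784, 32, 10)
torch_model.load_state_dict(torch.load('weights/simple_mlp.pth'))
torch_model.eval()

# TensorRT wrapper (builds the ONNX file and engine on first use)
trt_model = SimpleMLPTRT(weights_file_path='weights/simple_mlp.pth',
                         onnx_file_path='weights/simple_mlp.onnx',
                         engine_file_path='weights/simple_mlp.engine')

x = torch.randn(8, 784)
with torch.no_grad():
    reference = torch_model(x)
trt_output = trt_model.infer(x)

# Differences should be small (larger if the engine was built with FP16)
print('max abs diff:', (reference - trt_output).abs().max().item())
```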
4 | -------------------------------------------------------------------------------- /simple_mlp/main.py: -------------------------------------------------------------------------------- 1 | from mlp_trt import SimpleMLPTRT 2 | 3 | import torch 4 | import random 5 | 6 | # Model parameters 7 | weights_file_path = 'weights/simple_mlp.pth' 8 | onnx_file_path = 'weights/simple_mlp.onnx' 9 | engine_file_path = 'weights/simple_mlp.engine' 10 | input_size = 784 11 | hidden_size = 32 12 | num_classes = 10 13 | 14 | 15 | # Create an instance of the inference class 16 | trt_model = SimpleMLPTRT(input_size=input_size, hidden_size=hidden_size, num_classes=num_classes, weights_file_path=weights_file_path, onnx_file_path=onnx_file_path, engine_file_path=engine_file_path) 17 | 18 | # Inference with dynamic batch_size 19 | for i in range(100): 20 | # Create dummy input with dynamic batch_size 21 | batch_size = random.randint(1, 60) 22 | dummy_input = torch.randn(batch_size, 784) 23 | 24 | # Perform inference 25 | output = trt_model.infer(dummy_input) 26 | -------------------------------------------------------------------------------- /simple_mlp/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader, TensorDataset 6 | 7 | # Define the MLP model 8 | class SimpleMLP(nn.Module): 9 | def __init__(self, input_size, hidden_size, num_classes): 10 | super(SimpleMLP, self).__init__() 11 | self.fc1 = nn.Linear(input_size, hidden_size) 12 | self.fc2 = nn.Linear(hidden_size, num_classes) 13 | 14 | def forward(self, x): 15 | x = F.relu(self.fc1(x)) 16 | x = self.fc2(x) 17 | return x 18 | 19 | # Parameters 20 | input_size = 784 # Example for MNIST dataset (28x28 images) 21 | hidden_size = 32 22 | num_classes = 10 23 | num_epochs = 10 24 | batch_size = 1 25 | learning_rate = 0.001 26 | 27 | # Dummy dataset (replace with actual data) 28 | x_train = torch.randn(600, input_size) 29 | y_train = torch.randint(0, num_classes, (600,)) 30 | 31 | train_dataset = TensorDataset(x_train, y_train) 32 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) 33 | 34 | # Initialize the model, loss function, and optimizer 35 | model = SimpleMLP(input_size, hidden_size, num_classes) 36 | criterion = nn.CrossEntropyLoss() 37 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 38 | 39 | # # Training loop 40 | # for epoch in range(num_epochs): 41 | # for i, (images, labels) in enumerate(train_loader): 42 | # # Forward pass 43 | # outputs = model(images) 44 | # loss = criterion(outputs, labels) 45 | 46 | # # Backward pass and optimization 47 | # optimizer.zero_grad() 48 | # loss.backward() 49 | # optimizer.step() 50 | 51 | # if (i+1) % 100 == 0: 52 | # print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}') 53 | 54 | # # Save the trained model 55 | torch.save(model.state_dict(), 'simple_mlp.pth') 56 | 57 | print('Model training complete and saved to simple_mlp.pth') 58 | -------------------------------------------------------------------------------- /simple_mlp/mlp_trt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from trt_utils.convertor import build_onnx, build_engine, load_engine 5 | import trt_utils.common as common 6 | import mlp 7 | 8 | class SimpleMLPTRT: 9 | def 
__init__(self, input_size=784, hidden_size=32, num_classes=10, batch_size=1, weights_file_path='simple_mlp.pth', onnx_file_path='simple_mlp.onnx', engine_file_path='simple_mlp.engine'): 10 | self.input_size = input_size 11 | self.hidden_size = hidden_size 12 | self.num_classes = num_classes 13 | self.batch_size = batch_size 14 | 15 | # Define model (not used directly for inference but can be useful for validation) 16 | self.model = mlp.SimpleMLP(input_size, hidden_size, num_classes) 17 | 18 | # File paths 19 | self.weights_file_path = weights_file_path 20 | self.onnx_file_path = onnx_file_path 21 | self.engine_file_path = engine_file_path 22 | 23 | # Build ONNX model if it doesn't exist 24 | if not os.path.exists(self.onnx_file_path): 25 | build_onnx(self.weights_file_path, self.onnx_file_path, input_size, self.model) 26 | print(f'ONNX model saved to {self.onnx_file_path}') 27 | else: 28 | print('ONNX file already exists!') 29 | 30 | # Build TensorRT engine if it doesn't exist 31 | if not os.path.exists(self.engine_file_path): 32 | build_engine(self.onnx_file_path, self.engine_file_path) 33 | print(f'TensorRT engine saved to {self.engine_file_path}') 34 | else: 35 | print('TensorRT engine file already exists!') 36 | 37 | # Load TensorRT engine 38 | self.engine = load_engine(self.engine_file_path) 39 | 40 | # Create execution context for inference 41 | self.context = self.engine.create_execution_context() 42 | 43 | def infer(self, input_tensor): 44 | """Perform inference on the input tensor.""" 45 | 46 | # Set binding for context based on the input shape 47 | input_shape = input_tensor.shape 48 | output_shape = (input_shape[0], self.num_classes) 49 | 50 | input_binding_index = self.engine.get_tensor_name(0) 51 | self.context.set_input_shape(input_binding_index, input_shape) 52 | 53 | # Allocate buffers for inputs and outputs 54 | model_shapes = [input_shape, output_shape] 55 | inputs, outputs, bindings, stream = common.allocate_buffers(self.engine, model_shapes) 56 | 57 | # Transfer data to host memory 58 | np.copyto(inputs[0].host, input_tensor.ravel()) 59 | 60 | # Do inference 61 | common.do_inference_v2(self.context, bindings, inputs, outputs, stream) 62 | 63 | # Post-process the output of the model 64 | output = torch.from_numpy(outputs[0].host.reshape(output_shape).copy()) 65 | 66 | # Free buffers 67 | common.free_buffers(inputs, outputs, stream) 68 | 69 | return output 70 | 71 | 72 | -------------------------------------------------------------------------------- /simple_mlp/trt_utils/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorrt as trt 4 | from cuda import cuda, cudart 5 | import ctypes 6 | from typing import List, Tuple 7 | 8 | 9 | def check_cuda_err(err): 10 | if isinstance(err, cuda.CUresult): 11 | if err != cuda.CUresult.CUDA_SUCCESS: 12 | raise RuntimeError("Cuda Error: {}".format(err)) 13 | if isinstance(err, cudart.cudaError_t): 14 | if err != cudart.cudaError_t.cudaSuccess: 15 | raise RuntimeError("Cuda Runtime Error: {}".format(err)) 16 | else: 17 | raise RuntimeError("Unknown error type: {}".format(err)) 18 | 19 | 20 | def cuda_call(call): 21 | err, res = call[0], call[1:] 22 | check_cuda_err(err) 23 | if len(res) == 1: 24 | res = res[0] 25 | return res 26 | 27 | 28 | class HostDeviceMem: 29 | def __init__(self, size: int, dtype: np.dtype): 30 | nbytes = size * dtype.itemsize 31 | host_mem = cuda_call(cudart.cudaMallocHost(nbytes)) 32 | pointer_type = 
ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype)) 33 | 34 | self._host = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,)) 35 | self._device = cuda_call(cudart.cudaMalloc(nbytes)) 36 | self._nbytes = nbytes 37 | 38 | @property 39 | def host(self) -> np.ndarray: 40 | return self._host 41 | 42 | @host.setter 43 | def host(self, arr: np.ndarray): 44 | if arr.size > self.host.size: 45 | raise ValueError( 46 | f"Tried to fit an array of size {arr.size} into host memory of size {self.host.size}" 47 | ) 48 | np.copyto(self.host[: arr.size], arr.flat) 49 | 50 | @property 51 | def device(self) -> int: 52 | return self._device 53 | 54 | @property 55 | def nbytes(self) -> int: 56 | return self._nbytes 57 | 58 | def __str__(self): 59 | return f"Host:\n{self.host}\nDevice:\n{self.device}\nSize:\n{self.nbytes}\n" 60 | 61 | def __repr__(self): 62 | return self.__str__() 63 | 64 | def free(self): 65 | cuda_call(cudart.cudaFree(self.device)) 66 | cuda_call(cudart.cudaFreeHost(self.host.ctypes.data)) 67 | 68 | 69 | def allocate_buffers(engine: trt.ICudaEngine, inputs_shape: List[Tuple[int]]): 70 | inputs = [] 71 | outputs = [] 72 | bindings = [] 73 | stream = cuda_call(cudart.cudaStreamCreate()) 74 | tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)] 75 | for shape, binding in zip(inputs_shape, tensor_names): 76 | size = trt.volume(shape) 77 | dtype = np.dtype(trt.nptype(engine.get_tensor_dtype(binding))) 78 | 79 | bindingMemory = HostDeviceMem(size, dtype) 80 | bindings.append(int(bindingMemory.device)) 81 | 82 | if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT: 83 | inputs.append(bindingMemory) 84 | else: 85 | outputs.append(bindingMemory) 86 | 87 | return inputs, outputs, bindings, stream 88 | 89 | 90 | def free_buffers( 91 | inputs: List[HostDeviceMem], 92 | outputs: List[HostDeviceMem], 93 | stream: cudart.cudaStream_t, 94 | ): 95 | for mem in inputs + outputs: 96 | mem.free() 97 | cuda_call(cudart.cudaStreamDestroy(stream)) 98 | 99 | 100 | def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray): 101 | nbytes = host_arr.size * host_arr.itemsize 102 | cuda_call( 103 | cudart.cudaMemcpy( 104 | device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice 105 | ) 106 | ) 107 | 108 | 109 | def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int): 110 | nbytes = host_arr.size * host_arr.itemsize 111 | cuda_call( 112 | cudart.cudaMemcpy( 113 | host_arr, device_ptr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost 114 | ) 115 | ) 116 | 117 | 118 | def _do_inference_base(inputs, outputs, stream, execute_async): 119 | kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice 120 | [ 121 | cuda_call( 122 | cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, kind, stream) 123 | ) 124 | for inp in inputs 125 | ] 126 | execute_async() 127 | kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost 128 | [ 129 | cuda_call( 130 | cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, kind, stream) 131 | ) 132 | for out in outputs 133 | ] 134 | cuda_call(cudart.cudaStreamSynchronize(stream)) 135 | return [out.host for out in outputs] 136 | 137 | 138 | def do_inference_v2(context, bindings, inputs, outputs, stream): 139 | def execute_async(): 140 | context.execute_async_v2(bindings=bindings, stream_handle=stream) 141 | 142 | return _do_inference_base(inputs, outputs, stream, execute_async) 143 | -------------------------------------------------------------------------------- /simple_mlp/trt_utils/convertor.py: 
-------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | 4 | 5 | # Convert the model to ONNX format 6 | # Dummy input for the model (replace with an actual sample if needed) 7 | def build_onnx(weights_file_path, onnx_file_path, input_size, model): 8 | 9 | dummy_input = torch.randn(1, input_size) 10 | 11 | # Load the model's weights (using the final epoch as an example) 12 | model.load_state_dict(torch.load(weights_file_path)) 13 | model.eval() # Set the model to evaluation mode 14 | 15 | # Export the model 16 | torch.onnx.export( 17 | model, 18 | dummy_input, 19 | onnx_file_path, 20 | input_names=['input'], 21 | output_names=['output'], 22 | dynamic_axes={'input' : {0: 'input_batch_size'}, 23 | 'output': {0: 'output_batch_size'}}, 24 | opset_version=11 25 | ) 26 | 27 | print(f'Model has been converted to ONNX and saved to {onnx_file_path}') 28 | def build_engine(onnx_path, engine_path): 29 | # Initialize TensorRT logger and builder 30 | TRT_LOGGER = trt.Logger(trt.Logger.INFO) 31 | builder = trt.Builder(TRT_LOGGER) 32 | config = builder.create_builder_config() 33 | 34 | 35 | # Set cache 36 | cache = config.create_timing_cache(b"") 37 | config.set_timing_cache(cache, ignore_mismatch=False) 38 | 39 | 40 | flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) 41 | builder.max_batch_size = 64 42 | network = builder.create_network(flag) 43 | parser = trt.OnnxParser(network, TRT_LOGGER) 44 | 45 | 46 | with open(onnx_path, "rb") as f: 47 | if not parser.parse(f.read()): 48 | print(f"ERROR: Failed to parse the ONNX file {onnx_path}") 49 | for error in range(parser.num_errors): 50 | print(parser.get_error(error)) 51 | 52 | input = network.get_input(0) 53 | # Check if fast Half is avaliable 54 | # print(builder.platform_has_fast_fp16) 55 | 56 | profile = builder.create_optimization_profile() 57 | 58 | min_shape = [1, 784] 59 | opt_shape = [32, 784] 60 | max_shape = [64, 784] 61 | profile.set_shape(input.name, min_shape, opt_shape, max_shape) 62 | 63 | 64 | config.add_optimization_profile(profile) 65 | #config.set_flag(trt.BuilderFlag.FP16) 66 | 67 | # Build engine 68 | engine_bytes = builder.build_serialized_network(network, config) 69 | 70 | with open(engine_path, "wb") as f: 71 | f.write(engine_bytes) 72 | 73 | def load_engine(engine_file_path: str): 74 | TRT_LOGGER = trt.Logger(trt.Logger.ERROR) 75 | with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime: 76 | return runtime.deserialize_cuda_engine(f.read()) -------------------------------------------------------------------------------- /simple_mlp/weights/simple_mlp.engine: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImDB0oo1/SimpleMLP-TensorRT/3a46cc16268b7e0cdc6f68d5d34b44ca5ba37881/simple_mlp/weights/simple_mlp.engine -------------------------------------------------------------------------------- /simple_mlp/weights/simple_mlp.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ImDB0oo1/SimpleMLP-TensorRT/3a46cc16268b7e0cdc6f68d5d34b44ca5ba37881/simple_mlp/weights/simple_mlp.onnx -------------------------------------------------------------------------------- /simple_mlp/weights/simple_mlp.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ImDB0oo1/SimpleMLP-TensorRT/3a46cc16268b7e0cdc6f68d5d34b44ca5ba37881/simple_mlp/weights/simple_mlp.pth --------------------------------------------------------------------------------