├── .gitignore
├── LICENSE.md
├── QUICKSTART.md
├── README.md
├── assets
│   ├── model_bn_timeline.png
│   └── model_gn_timeline.png
├── build.py
├── calibrator.py
├── eval.py
├── export.py
├── models.py
├── modify.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | data
3 | **/__pycache__
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | SPDX-License-Identifier: MIT
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a
5 | copy of this software and associated documentation files (the "Software"),
6 | to deal in the Software without restriction, including without limitation
7 | the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 | and/or sell copies of the Software, and to permit persons to whom the
9 | Software is furnished to do so, subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in
12 | all copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17 | THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20 | DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/QUICKSTART.md:
--------------------------------------------------------------------------------
1 |
5 |
6 | # Quickstart
7 |
8 | You can follow these steps to quickly recreate the models covered in the tutorial.
9 |
10 | ## Step 1 - Train a PyTorch model on the CIFAR10 dataset
11 |
12 | Execute the following command on a machine with an NVIDIA GPU.
13 |
14 | ```bash
15 | python3 train.py model_bn --checkpoint_path=data/model_bn.pth
16 | ```
17 |
18 | ## Step 2 - Export the trained model to ONNX
19 |
20 | Execute the following command on a machine with an NVIDIA GPU.
21 |
22 | ```bash
23 | python3 export.py model_bn data/model_bn.onnx --checkpoint_path=data/model_bn.pth
24 | ```
25 |
26 | > Tip: Once exported to ONNX, the models can be profiled using the ``trtexec`` tool as described in [README.md](README.md).
27 |
28 | ## Step 3 - Build the TensorRT engine
29 |
30 | Execute the following command on a machine with an NVIDIA GPU. To use the DLA, you must run it on a device that has one, such as Jetson Orin.
31 |
32 | ```bash
33 | python3 build.py data/model_bn.onnx --output=data/model_bn.engine --int8 --dla_core=0 --gpu_fallback --batch_size=32
34 | ```
35 |
38 | ## Step 4 - Evaluate the model on the CIFAR10 test dataset
39 |
40 | Execute the following command on a machine with an NVIDIA GPU. You must run it on
41 | the same machine on which you ran ``build.py``, because TensorRT engines are specific to the device they were built on.
42 |
43 | ```bash
44 | python3 eval.py data/model_bn.engine --batch_size=32
45 | ```
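If you would rather drive all four steps from one place, the following is a minimal sketch (not part of this repository) that assembles the commands above and runs them in order with ``subprocess``. The helper names ``quickstart_commands`` and ``run_quickstart`` are hypothetical. With ``dry_run=True`` it only prints the commands, which is handy for checking them on a machine without a GPU or DLA.

```python
import shlex
import subprocess


def quickstart_commands(model="model_bn", batch_size=32, dla_core=0):
    """Return the four Quickstart steps as argument lists (hypothetical helper)."""
    pth = f"data/{model}.pth"
    onnx_path = f"data/{model}.onnx"
    engine = f"data/{model}.engine"
    return [
        # Step 1: train on CIFAR10
        ["python3", "train.py", model, f"--checkpoint_path={pth}"],
        # Step 2: export the trained model to ONNX
        ["python3", "export.py", model, onnx_path, f"--checkpoint_path={pth}"],
        # Step 3: build an INT8 engine targeting the DLA, with GPU fallback
        ["python3", "build.py", onnx_path, f"--output={engine}", "--int8",
         f"--dla_core={dla_core}", "--gpu_fallback", f"--batch_size={batch_size}"],
        # Step 4: evaluate the engine on the CIFAR10 test set
        ["python3", "eval.py", engine, f"--batch_size={batch_size}"],
    ]


def run_quickstart(dry_run=True, **kwargs):
    for cmd in quickstart_commands(**kwargs):
        print(shlex.join(cmd))
        if not dry_run:
            # check=True stops the pipeline at the first failing step
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_quickstart(dry_run=True)
```

Keeping each command as a list avoids shell quoting issues. ``shlex.join`` (Python 3.8+) is used only to print them readably.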
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
5 |
6 | # Getting started with the Deep Learning Accelerator on NVIDIA Jetson Orin
7 |
8 | In this tutorial, we’ll develop a neural network that utilizes the Deep Learning Accelerator (DLA) on Jetson Orin. In case you’re unfamiliar, the DLA is an application-specific integrated circuit on Jetson Xavier and Orin that is capable of running common deep learning inference operations, such as convolutions. This dedicated hardware is power efficient and allows you to offload work from the GPU, freeing it for other tasks. If you’re interested in better utilizing the full capabilities of your Jetson platform, this tutorial is for you.
9 |
10 | To demonstrate how to use the DLA, we’ll run through the entire modeling process: from designing and training an architecture in PyTorch, to calibrating the model for INT8 precision, to profiling and validating the model. We hope this will give you a better understanding of how to use the DLA, as well as its capabilities and limitations. Understanding this information early on is important, as it will help you make design decisions to ensure your model can be easily accelerated by the DLA, enabling you to further unlock the capabilities of Jetson Orin.
11 |
12 | Let’s get started!
13 |
14 | > Tip: You can easily recreate the models covered in this tutorial yourself! Check out [``QUICKSTART.md``](QUICKSTART.md) for instructions on how to use the code contained in this repository.
15 |
16 | ## Overview
17 |
18 | - [Step 1 - Define a model in PyTorch](#step-1)
19 | - [Step 2 - Export the model to ONNX](#step-2)
20 | - [Step 3 - Optimize the untrained model using TensorRT](#step-3)
21 | - [Step 4 - Profile with NVIDIA Visual Profiler](#step-4)
22 | - [Step 5 - Modify model for better DLA utilization](#step-5)
23 | - [Step 6 - Train the model](#step-6)
24 | - [Step 7 - Optimize the model (using real weights and calibration data)](#step-7)
25 | - [Step 8 - Evaluate the accuracy of the optimized model](#step-8)
26 | - [Step 9 - Modify trained model with ONNX graph surgeon](#step-9)
27 | - [Next Steps](#next-steps)
28 |
29 |
30 |
31 |
32 | ## Step 1 - Define a model in PyTorch
33 |
34 | First, let’s define our PyTorch model. We’ll create a simple CNN image classification model targeting the CIFAR10 dataset. To demonstrate the process of identifying and replacing unsupported layers, we will intentionally use the GroupNorm normalization layer, which is not currently supported by the DLA.
35 |
36 | ```python
37 | class ModelGN(nn.Module):
38 |     def __init__(self, num_classes):
39 |         super().__init__()
40 |         self.cnn = nn.Sequential(
41 |             nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
42 |             nn.GroupNorm(8, 64),
43 |             nn.ReLU(),
44 |             nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
45 |             nn.GroupNorm(8, 128),
46 |             nn.ReLU(),
47 |             nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
48 |             nn.GroupNorm(8, 256),
49 |             nn.ReLU(),
50 |             nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
51 |             nn.GroupNorm(8, 512),
52 |             nn.ReLU()
53 |         )
54 |         self.pool = nn.AdaptiveAvgPool2d((1, 1))
55 |         self.linear = nn.Linear(512, num_classes)
56 |
57 |     def forward(self, x):
58 |         x = self.cnn(x)
59 |         x = self.pool(x)
60 |         x = x.view(x.shape[0], -1)
61 |         x = self.linear(x)
62 |         return x
63 |
64 | model_gn = ModelGN(num_classes=10).cuda().eval()
65 | ```
66 |
67 | That’s it, our initial PyTorch model is now defined. We could proceed with training this model, but training is often time-consuming. Instead, let’s first try running the model on the DLA with TensorRT using dummy weights, just to see how it performs.
68 |
69 |
70 | ## Step 2 - Export the model to ONNX
71 |
72 | To run our model on the DLA, we need to use TensorRT, NVIDIA's neural network inference library. TensorRT is a framework that ingests a neural network graph description and performs a variety of platform-specific optimizations, including but not limited to running layers on the DLA. The optimizations result in a TensorRT engine, which may be executed using the TensorRT runtime.
73 |
74 | There are a few different utilities for converting and building TensorRT engines from PyTorch models, each with their own benefits. Here we will use the PyTorch -> ONNX -> TensorRT workflow. To do this, we first export our model to ONNX as follows:
75 |
76 | ```python
77 | data = torch.zeros(1, 3, 32, 32).cuda()
78 |
79 | torch.onnx.export(model_gn, data, 'model_gn.onnx',
80 | input_names=['input'],
81 | output_names=['output'],
82 | dynamic_axes={
83 | 'input': {0: 'batch_size'},
84 | 'output': {0: 'batch_size'}
85 | }
86 | )
87 | ```
88 |
89 | Note that we specify dynamic axes for the input and output batch dimensions. By enabling dynamic batch axes, we can then generate a TensorRT engine which is capable of using batch sizes larger than the size of the example data used when exporting to ONNX.
90 |
91 |
92 | ## Step 3 - Optimize the untrained model using TensorRT
93 |
94 | To build our TensorRT engine, we can use the ``trtexec`` tool, which is installed on Jetson at ``/usr/src/tensorrt/bin/trtexec``. Let’s create an alias for it:
95 |
96 | ```bash
97 | alias trtexec=/usr/src/tensorrt/bin/trtexec
98 | ```
99 |
100 | Now, we can call ``trtexec`` and specify the parameters we wish to use for optimization. The command we use is as follows:
101 |
102 | ```bash
103 | trtexec --onnx=model_gn.onnx --shapes=input:32x3x32x32 --saveEngine=model_gn.engine --exportProfile=model_gn.json --int8 --useDLACore=0 --allowGPUFallback --useSpinWait --separateProfileRun > model_gn.log
104 | ```
105 |
106 | - ``--onnx`` - The input ONNX file path.
107 | - ``--shapes`` - The shapes for input bindings, we specify a batch size of 32.
108 | - ``--saveEngine`` - The path to save the optimized TensorRT engine.
109 | - ``--exportProfile`` - The path to output a JSON file containing layer granularity timings.
110 | - ``--int8`` - Enable INT8 precision. This is required for best performance on Orin DLA.
111 | - ``--useDLACore=0`` - The DLA core to use for all compatible layers.
112 | - ``--allowGPUFallback`` - Allow TensorRT to run layers on GPU that aren't supported on DLA.
113 | - ``--useSpinWait`` - Synchronize GPU events, for improved profiling stability.
114 | - ``--separateProfileRun`` - Perform a separate run for layer profiling.
115 | - ``> model_gn.log`` - Capture the output into a file named ``model_gn.log``
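When you rebuild several variants (for example ``model_gn`` and ``model_bn``), it can help to generate this command programmatically. The sketch below is not part of the repository. The helper name ``trtexec_build_command`` is hypothetical, and it reproduces only the flags listed above. The ``> model_gn.log`` redirection is handled by the shell, not by ``trtexec``, so it is appended only when printing.

```python
import shlex

# Default trtexec location on Jetson, as described in this tutorial
TRTEXEC = "/usr/src/tensorrt/bin/trtexec"


def trtexec_build_command(name, batch_size=32, dla_core=0):
    """Assemble the trtexec build/profile invocation shown above (hypothetical helper)."""
    return [
        TRTEXEC,
        f"--onnx={name}.onnx",
        f"--shapes=input:{batch_size}x3x32x32",  # static NCHW shape for the 'input' binding
        f"--saveEngine={name}.engine",
        f"--exportProfile={name}.json",          # per-layer timings as JSON
        "--int8",
        f"--useDLACore={dla_core}",
        "--allowGPUFallback",
        "--useSpinWait",
        "--separateProfileRun",
    ]


cmd = trtexec_build_command("model_gn")
print(shlex.join(cmd) + " > model_gn.log")
```

To actually run it from Python, you could pass ``cmd`` to ``subprocess.run`` with ``stdout`` pointed at an open log file.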
116 |
117 | The ``trtexec`` program will log information related to the optimization and profiling processes. One notable output is the collection of layers running on the DLA. After calling ``trtexec`` to build and profile our model, we see the following output:
118 |
119 | ```bash
120 | [03/31/2022-14:34:54] [I] [TRT] ---------- Layers Running on DLA ----------
121 | [03/31/2022-14:34:54] [I] [TRT] [DlaLayer] {ForeignNode[Conv_0]}
122 | [03/31/2022-14:34:54] [I] [TRT] [DlaLayer] {ForeignNode[Relu_10...Conv_11]}
123 | [03/31/2022-14:34:54] [I] [TRT] [DlaLayer] {ForeignNode[Relu_21...Conv_22]}
124 | [03/31/2022-14:34:54] [I] [TRT] [DlaLayer] {ForeignNode[Relu_32...Conv_33]}
125 | [03/31/2022-14:34:54] [I] [TRT] [DlaLayer] {ForeignNode[Relu_43]}
126 | [03/31/2022-14:34:54] [I] [TRT] ---------- Layers Running on GPU ----------
127 | [03/31/2022-14:34:54] [I] [TRT] [GpuLayer] SHUFFLE: Reshape_2 + (Unnamed Layer* 7) [Shuffle]
128 | [03/31/2022-14:34:54] [I] [TRT] [GpuLayer] PLUGIN_V2: InstanceNormalization_5
129 | [03/31/2022-14:34:54] [I] [TRT] [GpuLayer] SHUFFLE: (Unnamed Layer* 12) [Shuffle] + Reshape_7
130 | [03/31/2022-14:34:54] [I] [TRT] [GpuLayer] SCALE: 71 + (Unnamed Layer* 16) [Shuffle] + Mul_8
131 | [03/31/2022-14:34:54] [I] [TRT] [GpuLayer] SCALE: 72 + (Unnamed Layer* 19) [Shuffle] + Add_9
132 | [03/31/2022-14:34:54] [I] [TRT] [GpuLayer] SHUFFLE: Reshape_13 + (Unnamed Layer* 29) [Shuffle]
133 | ```
134 |
135 | As we can see, several layers are running on the DLA as well as the GPU. This indicates that many of our layers were not compatible with the DLA and fell back to execution on the GPU. Contiguous subgraphs that execute on the DLA appear as one "ForeignNode" block. Ideally, our entire model would appear as a single "ForeignNode" entry running on the DLA.
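When you iterate on a model, it is convenient to summarize this part of the log automatically. The following is a minimal sketch, assuming the ``[DlaLayer]`` / ``[GpuLayer]`` log format shown above. The format may differ between TensorRT versions, and ``layer_placement`` is a hypothetical helper name. Ideally the DLA list holds a single ``ForeignNode`` entry and the GPU list is empty.

```python
import re

# Matches lines such as "[TRT] [DlaLayer] {ForeignNode[Conv_0]}"
# and "[TRT] [GpuLayer] PLUGIN_V2: InstanceNormalization_5".
_LAYER_RE = re.compile(r"\[(DlaLayer|GpuLayer)\]\s+(.*)$")


def layer_placement(log_text):
    """Split the trtexec layer report into (DLA entries, GPU entries)."""
    placement = {"DlaLayer": [], "GpuLayer": []}
    for line in log_text.splitlines():
        match = _LAYER_RE.search(line)
        if match:
            placement[match.group(1)].append(match.group(2).strip())
    return placement["DlaLayer"], placement["GpuLayer"]


sample = """\
[03/31/2022-14:34:54] [I] [TRT] ---------- Layers Running on DLA ----------
[03/31/2022-14:34:54] [I] [TRT] [DlaLayer] {ForeignNode[Conv_0]}
[03/31/2022-14:34:54] [I] [TRT] [DlaLayer] {ForeignNode[Relu_10...Conv_11]}
[03/31/2022-14:34:54] [I] [TRT] ---------- Layers Running on GPU ----------
[03/31/2022-14:34:54] [I] [TRT] [GpuLayer] PLUGIN_V2: InstanceNormalization_5
"""
dla, gpu = layer_placement(sample)
print(f"{len(dla)} DLA blocks, {len(gpu)} GPU layers")
```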
136 |
137 | In addition to this information, ``trtexec`` will report the number of batches executed per second.
138 |
139 | ```bash
140 | [03/31/2022-15:12:38] [I] === Performance summary ===
141 | [03/31/2022-15:12:38] [I] Throughput: 305.926 qps
142 | ```
143 |
144 | Here we see that the engine executes ``305.926`` batches per second. Multiplying this number by the batch size results in the number of images per second.
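For example, with the batch size of 32 set through ``--shapes=input:32x3x32x32``, the conversion is a single multiplication:

```python
def images_per_second(batches_per_second, batch_size):
    """Convert trtexec's reported throughput (batches/s) to images/s."""
    return batches_per_second * batch_size


# Throughput reported above for model_gn, built with a batch size of 32
print(f"{images_per_second(305.926, 32):.1f} images/s")
```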
145 |
146 | This coarse information is useful for determining how fast a model will run in practice, but provides little insight into what we can do to improve the model performance.
147 |
148 | Let's take a look at how we can use the NVIDIA Visual Profiler to provide a timeline trace of
149 | what's happening during our model execution.
150 |
151 |
152 | ## Step 4 - Profile with NVIDIA Visual Profiler
153 |
154 | To collect data for visualization, we first need to profile our model. To do this, we will use the ``nsys`` command line tool installed on NVIDIA Jetson Orin. This tool is capable of collecting system profiling information during the execution of a provided program.
155 |
156 | We will call ``nsys profile`` using ``/usr/src/tensorrt/bin/trtexec`` as the executable to be profiled. We will provide the engine we've already created in the previous step to ``trtexec`` to avoid profiling the entire optimization process.
157 |
158 | ```bash
159 | nsys profile --trace=cuda,nvtx,cublas,cudla,cusparse,cudnn,nvmedia --output=model_gn.nvvp /usr/src/tensorrt/bin/trtexec --loadEngine=model_gn.engine --iterations=10 --idleTime=500 --duration=0 --useSpinWait
160 | ```
161 |
162 | The notable ``nsys profile`` parameters are
163 |
164 | - ``--trace`` - Determines which events we want to capture during our profiling session. Notably, ``nvmedia`` will capture calls to the DLA. ``nvtx`` will capture application specific calls, which in this instance includes TensorRT layer executions.
165 | - ``--output=model_gn.nvvp`` - Sets the output file that profiling information will be stored in. This can be loaded directly with the NVIDIA Visual Profiler.
166 |
167 | The ``trtexec`` parameters are
168 |
169 | - ``--loadEngine=model_gn.engine`` - The path of the TensorRT engine we previously generated by calling ``trtexec`` with ``--saveEngine``
170 | - ``--iterations=10`` - Limit the number of iterations to 10, since we are interested in the fine-grained execution timeline, not larger statistical information.
171 | - ``--duration=0`` - Removes the minimum profiling duration; we only wish to profile 10 executions.
172 | - ``--idleTime=500`` - Adds a ``500ms`` wait time between executions, so we can more easily distinguish separate execution calls in the timeline view.
173 | - ``--useSpinWait`` - Adds explicit synchronization to improve profiling stability
174 |
175 | After calling the profiling command, we will generate the profiling data which we load into NVIDIA Visual Profiler.
176 |
177 | Focusing on a single model execution call we see the following.
178 |
179 | 
180 |
181 | The orange block, which is recorded when specifying ``--trace=nvtx``, shows the entire TensorRT engine execution call. The yellow blocks, which are also captured by specifying ``--trace=nvtx``, show the individual TensorRT layer executions. The blocks in the row titled *Other Accelerators API* show the DLA events. As we can see, the DLA events are rather spaced out, with a variety of other events happening in between. These other events take up a significant portion of our model execution time.
182 |
183 | In this instance, the events between our DLA events are caused by the GroupNorm layer, which isn't supported by the DLA. Between each DLA block, the data is formatted for execution on the GPU, and then formatted again for execution on the DLA. Because this overhead recurs several times in our model, it increases our overall runtime significantly.
184 |
185 | Determining which layer isn't supported by the DLA can be done by referencing warnings output from ``trtexec`` or even inspecting the kernel names in the profiling view for hints. Sometimes the layer names don't map exactly to the original PyTorch layer, which can make this process more challenging. In this instance, substituting or removing layers and testing the conversion can provide a definitive indication of what is supported or problematic.
186 |
187 | > Tip: Though not covered here, the TensorRT Python API allows you to inspect which TensorRT layers in the TensorRT network definition are able to run on the DLA. We'll cover how to parse the ONNX model into a TensorRT network later in this tutorial. You could use this representation to programmatically check which layers are supported by the DLA in Python. However, associating this information with the original PyTorch model may still present some challenges, because layers may be mapped from PyTorch to ONNX / TensorRT under different representations. In these instances, iteratively testing the conversion of submodules in your model is sometimes the easiest way to home in on an incompatible layer.
188 |
189 | From this analysis, we see that the GroupNorm layer, which has relatively low computational complexity, introduces significant runtime due to the formatting overhead between GPU and DLA.
190 |
191 | Let's try a different normalization layer that *is* supported by the DLA, to see how it improves our model performance.
192 |
193 |
194 | ## Step 5 - Modify model for better DLA utilization
195 |
196 | First, we modify our model definition, replacing all ``GroupNorm`` layers with ``BatchNorm2d`` layers.
197 |
198 | ```python
199 | class ModelBN(nn.Module):
200 |     def __init__(self, num_classes):
201 |         super().__init__()
202 |         self.cnn = nn.Sequential(
203 |             nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
204 |             nn.BatchNorm2d(64),
205 |             nn.ReLU(),
206 |             nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
207 |             nn.BatchNorm2d(128),
208 |             nn.ReLU(),
209 |             nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
210 |             nn.BatchNorm2d(256),
211 |             nn.ReLU(),
212 |             nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
213 |             nn.BatchNorm2d(512),
214 |             nn.ReLU()
215 |         )
216 |         self.pool = nn.AdaptiveAvgPool2d((1, 1))
217 |         self.linear = nn.Linear(512, num_classes)
218 |
219 |     def forward(self, x):
220 |         x = self.cnn(x)
221 |         x = self.pool(x)
222 |         x = x.view(x.shape[0], -1)
223 |         x = self.linear(x)
224 |         return x
225 |
226 | model_bn = ModelBN(num_classes=10).cuda().eval()
227 | ```
228 |
229 | Second, we export our model to ONNX as before.
230 |
231 | ```python
232 | data = torch.zeros(1, 3, 32, 32).cuda()
233 |
234 | torch.onnx.export(model_bn, data, 'model_bn.onnx',
235 | input_names=['input'],
236 | output_names=['output'],
237 | dynamic_axes={
238 | 'input': {0: 'batch_size'},
239 | 'output': {0: 'batch_size'}
240 | }
241 | )
242 | ```
243 |
244 | Third, we build and profile our TensorRT engine as before
245 |
246 | ```bash
247 | trtexec --onnx=model_bn.onnx --shapes=input:32x3x32x32 --saveEngine=model_bn.engine --exportProfile=model_bn.json --int8 --useDLACore=0 --allowGPUFallback --useSpinWait --separateProfileRun > model_bn.log
248 | ```
249 |
250 | Which outputs
251 |
252 | ```bash
253 | [03/31/2022-15:12:48] [I] [TRT] ---------- Layers Running on DLA ----------
254 | [03/31/2022-15:12:48] [I] [TRT] [DlaLayer] {ForeignNode[Conv_0...Relu_7]}
255 | [03/31/2022-15:12:48] [I] [TRT] [DlaLayer] {ForeignNode[Gemm_15 + linear.bias + (Unnamed Layer* 19) [Shuffle] + unsqueeze_node_after_linear.bias + (Unnamed Layer* 19) [Shuffle]_(Unnamed Layer* 19) [Shuffle]_output + (Unnamed Layer* 20) [ElementWise]]}
256 | [03/31/2022-15:12:48] [I] [TRT] ---------- Layers Running on GPU ----------
257 | [03/31/2022-15:12:48] [I] [TRT] [GpuLayer] REDUCE: GlobalAveragePool_8
258 | [03/31/2022-15:12:48] [I] [TRT] [GpuLayer] SHUFFLE: copied_squeeze_after_(Unnamed Layer* 20) [ElementWise]
259 | ```
260 |
261 | As we can see, there are significantly fewer layers running on the GPU as well as the DLA. This is an indication that more layers (Conv_0...Relu_7) were merged into a single "ForeignNode" block for execution on the DLA.
262 |
263 | This is a good sign! And as seen below, the performance of our model has approximately doubled!
264 |
265 | ```bash
266 | [03/31/2022-15:12:55] [I] === Performance summary ===
267 | [03/31/2022-15:12:55] [I] Throughput: 649.723 qps
268 | ```
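To put "approximately doubled" in numbers, here is a quick calculation using the two throughput figures reported by ``trtexec`` (batch size 32 in both runs):

```python
gn_qps = 305.926  # model_gn (GroupNorm), from the earlier trtexec run
bn_qps = 649.723  # model_bn (BatchNorm2d), from the run above
batch_size = 32

speedup = bn_qps / gn_qps
print(f"speedup: {speedup:.2f}x")
print(f"model_bn: {bn_qps * batch_size:.0f} images/s")
```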
269 |
270 | We profile this new engine as we did before with the following command
271 |
272 | ```bash
273 | nsys profile --trace=cuda,nvtx,cublas,cudla,cusparse,cudnn,nvmedia --output=model_bn.nvvp /usr/src/tensorrt/bin/trtexec --loadEngine=model_bn.engine --iterations=10 --idleTime=500 --duration=0 --useSpinWait
274 | ```
275 |
276 | As we see below, a majority of the model execution now occurs in one large block under the *Other Accelerators API* row. This indicates our model is utilizing the DLA much more effectively!
277 |
278 | 
279 |
280 | Until this point, we've been working with dummy model weights. However, to obtain the best performance on the DLA, we need to use INT8 precision for our model. Reducing the precision of a neural network can impact the accuracy of the model. To mitigate this effect, it is important to perform calibration.
281 |
282 | Let's go over how to train the model, and perform calibration with real data.
283 |
284 |
285 | ## Step 6 - Train the model
286 |
287 | First, we train the model on the CIFAR10 dataset. This is a toy image classification dataset, which has examples in PyTorch available from a variety of sources online.
288 |
289 | Since this step isn't unique to execution on the DLA, we won't paste code inline. However,
290 | we have provided a training script [``train.py``](train.py) that may be used to generate the trained neural network weights. The usage of ``train.py`` is covered in [``QUICKSTART.md``](QUICKSTART.md). To train our new batch norm model, we simply call:
291 |
292 | ```bash
293 | python3 train.py model_bn --checkpoint_path=data/model_bn.pth
294 | ```
295 |
296 | The trained model weights are stored at ``data/model_bn.pth``. This model reaches an accuracy of roughly ``85%`` on the CIFAR10 test dataset. While other architectures can obtain better accuracy on this dataset, we've opted for a simple model, since our focus is only on validating the DLA workflow and verifying that the calibrated DLA engine obtains similar accuracy to its PyTorch counterpart.
297 |
298 | Once we have the trained weights, ``data/model_bn.pth``, we can export the model to ONNX as before.
299 |
300 | ```python
301 | model = ModelBN(num_classes=10).cuda().eval()
302 | model.load_state_dict(torch.load('data/model_bn.pth'))
303 | data = torch.zeros(1, 3, 32, 32).cuda()
304 | torch.onnx.export(
305 | model,
306 | data,
307 | 'model_bn.onnx',
308 | input_names=['input'],
309 | output_names=['output'],
310 | dynamic_axes={
311 | 'input': {0: 'batch_size'},
312 | 'output': {0: 'batch_size'}
313 | }
314 | )
315 | ```
316 |
317 | Now, our trained model, in ONNX representation, is stored at ``model_bn.onnx``.
318 |
319 | Let's proceed with optimizing our trained model.
320 |
321 |
322 | ## Step 7 - Optimize the model (using real weights and calibration data)
323 |
324 | To optimize our trained model, we'll use the TensorRT Python API rather than ``trtexec``, because we need to provide a Python interface for generating calibration data.
325 |
326 | To start, we import the TensorRT Python package, and make the required API calls to parse our ONNX model into a TensorRT network.
327 |
328 | ```python
329 | import tensorrt as trt
330 |
331 | # create logger
332 | logger = trt.Logger(trt.Logger.INFO)
333 |
334 | # create builder
335 | builder = trt.Builder(logger)
336 |
337 | # create network, enabling explicit batch
338 | network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
339 |
340 | # parse the ONNX model to generate the TensorRT network
341 | parser = trt.OnnxParser(network, logger)
342 |
343 | with open('model_bn.onnx', 'rb') as f:
344 |     parser.parse(f.read())
345 | ```
346 |
347 | Now, our ``network`` object holds a full representation of our trained model. This representation, however, is not executable. TensorRT still needs to optimize the model to produce the TensorRT *engine*, which stores the output of our optimization process. The engine describes which tactics are used to execute the layers in our network, including which sub-graphs are offloaded to the DLA for execution.
348 |
349 | There are a variety of configuration options we may provide to control the optimization process. These parameters are largely controlled through a configuration object, which may be created as follows.
350 |
351 | ```python
352 | config = builder.create_builder_config()
353 | ```
354 |
355 | Now that our ``config`` instance is defined, we'll first set the optimization profile.
356 |
357 | ```python
358 | batch_size = 32
359 |
360 | # define the optimization configuration
361 | profile = builder.create_optimization_profile()
362 | profile.set_shape(
363 | 'input',
364 | (batch_size, 3, 32, 32), # min shape
365 | (batch_size, 3, 32, 32), # optimal shape
366 | (batch_size, 3, 32, 32) # max shape
367 | )
368 |
369 | config.add_optimization_profile(profile)
370 | ```
371 |
372 | This profile controls the shapes of input tensors that our final engine will be capable of handling. These shapes include a minimum, optimal, and maximum shape. The minimum and maximum shapes control which shapes are allowed by the model, while the optimal shape controls which shape is used during the profiling and tactic selection process. For the DLA to be selected as a tactic, all three shapes must be the same.
373 |
374 | Next, we specify which DLA core we want our model to run on by default, and allow GPU fallback for layers that the DLA does not support.
375 | ```python
376 | config.dla_core = 0
377 | config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
378 | ```
379 |
380 | Finally, since DLA is supported only for ``INT8`` and ``FP16`` precisions, we enable ``INT8`` precision.
381 |
382 | ```python
383 | config.set_flag(trt.BuilderFlag.INT8)
384 | ```
385 |
386 | Next, we create a class ``DatasetCalibrator`` that will be used to provide calibration data to the TensorRT builder. This class wraps the CIFAR10 dataset we used for training, and implements the ``IInt8Calibrator`` interface to provide batches of data for calibration. By implementing the ``get_batch``, ``get_algorithm``, and ``get_batch_size`` methods of the ``IInt8Calibrator`` interface, our class may now be used by the TensorRT builder.
387 |
388 | ```python
389 | import torch
390 | import tensorrt as trt
391 | import torchvision.transforms as transforms
392 |
393 |
394 | class DatasetCalibrator(trt.IInt8Calibrator):
395 |
396 |     def __init__(self,
397 |                  input, dataset,
398 |                  algorithm=trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2):
399 |         super(DatasetCalibrator, self).__init__()
400 |         self.algorithm = algorithm
401 |         self.dataset = dataset
402 |         self.buffer = torch.zeros_like(input).contiguous()
403 |         self.count = 0
404 |
405 |     def get_batch(self, *args, **kwargs):
406 |         if self.count < len(self.dataset):
407 |             for buffer_idx in range(self.get_batch_size()):
408 |
409 |                 # get image from dataset
410 |                 dataset_idx = self.count % len(self.dataset)
411 |                 image, _ = self.dataset[dataset_idx]
412 |                 image = image.to(self.buffer.device)
413 |
414 |                 # copy image to buffer
415 |                 self.buffer[buffer_idx].copy_(image)
416 |
417 |                 # increment total number of images used for calibration
418 |                 self.count += 1
419 |
420 |             return [int(self.buffer.data_ptr())]
421 |         else:
422 |             return []  # returning None or [] signals to TensorRT that calibration is finished
423 |
424 |     def get_algorithm(self):
425 |         return self.algorithm
426 |
427 |     def get_batch_size(self):
428 |         return int(self.buffer.shape[0])
429 |
430 |     def read_calibration_cache(self, *args, **kwargs):
431 |         return None
432 |
433 |     def write_calibration_cache(self, cache, *args, **kwargs):
434 |         pass
435 | ```
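Because ``get_batch`` keeps filling full batches until ``self.count`` reaches the dataset length, you may want to check how many calibration batches TensorRT will receive. The pure-Python sketch below (no TensorRT or PyTorch needed; ``calibration_indices`` is a hypothetical helper) reproduces only the indexing logic of ``get_batch``. For the 50,000-image CIFAR10 training set and a batch size of 32, it yields 1,563 batches. The last of these holds the final 16 images plus 16 images wrapped around from the start of the dataset.

```python
def calibration_indices(dataset_len, batch_size):
    """Yield the dataset indices each get_batch() call would copy into the buffer.

    Mirrors DatasetCalibrator.get_batch: batches are produced while fewer than
    dataset_len images have been used, and the modulo wraps the last batch
    around to the start of the dataset.
    """
    count = 0
    while count < dataset_len:
        batch = []
        for _ in range(batch_size):
            batch.append(count % dataset_len)
            count += 1
        yield batch


print(list(calibration_indices(5, 2)))
```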
436 |
437 | Next, we instantiate our ``DatasetCalibrator`` using the CIFAR10 training dataset (the same dataset we used for training) and assign it to ``config.int8_calibrator``:
438 |
439 | ```python
440 | import os
441 | import torchvision
442 | import torchvision.transforms as transforms
443 | transform = transforms.Compose([
444 |     transforms.ToTensor(),
445 |     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
446 | ])
447 | train_dataset = torchvision.datasets.CIFAR10(
448 |     root=os.path.join('data', 'cifar10'),
449 |     train=True,
450 |     download=True,
451 |     transform=transform
452 | )
453 | batch_size = 32
454 | data = torch.zeros(batch_size, 3, 32, 32).cuda()
455 |
456 | config.int8_calibrator = DatasetCalibrator(data, train_dataset)
457 | ```
458 |
459 | Finally, we build our TensorRT engine using the network we generated from our ONNX file and the configuration we defined above.
460 |
461 | ```python
462 | engine_bytes = builder.build_serialized_network(network, config)
463 | ```
464 |
465 | After the optimization process is finished, we save our engine to disk.
466 |
467 | ```python
468 | with open('model_bn.engine', 'wb') as f:
469 |     f.write(engine_bytes)
470 | ```
471 |
472 | Our optimized, calibrated, INT8 / DLA compatible TensorRT engine is now stored at
473 | ``model_bn.engine``.
474 |
475 |
476 | You could now use this optimized engine with a variety of inference APIs, such as DeepStream or Triton Inference Server.
477 |
478 | Here we'll use the TensorRT Python API to run inference using our model, and evaluate the accuracy on the CIFAR10 test dataset.
479 |
480 |
481 | ## Step 8 - Evaluate the accuracy of the optimized model
482 |
483 | First, let's define the CIFAR10 test dataset and its data loader.
484 |
485 | ### Define the test dataset
486 |
487 | ```python
488 | test_dataset = torchvision.datasets.CIFAR10(
489 | root=os.path.join('data', 'cifar10'),
490 | train=False,
491 | download=True,
492 | transform=transform
493 | )
494 |
495 | test_loader = torch.utils.data.DataLoader(
496 | test_dataset,
497 | batch_size=batch_size,
498 | shuffle=False
499 | )
500 | ```
501 |
502 | Next, let's instantiate the TensorRT runtime, load our saved TensorRT engine,
503 | and create the context we'll use for execution.
504 |
505 | ```python
506 | import tensorrt as trt
507 |
508 | logger = trt.Logger(trt.Logger.INFO)
509 | runtime = trt.Runtime(logger)
510 |
511 | with open('model_bn.engine', 'rb') as f:
512 |     engine_bytes = f.read()
513 |     engine = runtime.deserialize_cuda_engine(engine_bytes)
514 |
515 | context = engine.create_execution_context()
516 | ```
517 |
518 | Next, we will create buffers to hold our neural network inputs and outputs.
519 | TensorRT engine execution requires that we provide data pointers to GPU memory
520 | for the input and output tensors. We will define these tensors in PyTorch,
521 | and then use the ``data_ptr()`` method that PyTorch Tensors provide to get access
522 | to the underlying GPU memory pointer.
523 |
524 | ```python
525 | input_binding_idx = engine.get_binding_index('input')
526 | output_binding_idx = engine.get_binding_index('output')
527 |
528 | input_shape = (batch_size, 3, 32, 32)
529 | output_shape = (batch_size, 10)
530 |
531 | context.set_binding_shape(
532 | input_binding_idx,
533 | input_shape
534 | )
535 |
536 | input_buffer = torch.zeros(input_shape, dtype=torch.float32, device=torch.device('cuda'))
537 | output_buffer = torch.zeros(output_shape, dtype=torch.float32, device=torch.device('cuda'))
538 |
539 | bindings = [None, None]
540 | bindings[input_binding_idx] = input_buffer.data_ptr()
541 | bindings[output_binding_idx] = output_buffer.data_ptr()
542 | ```
543 |
544 | Finally, we iterate through the data and evaluate the accuracy of our model.
545 |
546 | ```python
547 | test_accuracy = 0
548 |
549 | # run through test dataset
550 | for image, label in iter(test_loader):
551 |
552 |     actual_batch_size = int(image.shape[0])
553 |
554 |     input_buffer[0:actual_batch_size].copy_(image)
555 |
556 |     context.execute_async_v2(
557 |         bindings,
558 |         torch.cuda.current_stream().cuda_stream
559 |     )
560 |
561 |     torch.cuda.current_stream().synchronize()
562 |
563 |     output = output_buffer[0:actual_batch_size]
564 |     label = label.cuda()
565 |
566 |     test_accuracy += int(torch.sum(output.argmax(dim=-1) == label))
567 |
568 | test_accuracy /= len(test_dataset)
569 |
570 | print(f'TEST ACCURACY: {test_accuracy}')
571 | ```
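The slicing with ``actual_batch_size`` matters because ``DataLoader`` keeps the final, incomplete batch by default (``drop_last=False``), so the last batch can be smaller than the preallocated buffers. The pure-Python sketch below reproduces the batch sizes the loop would see on the 10,000-image CIFAR10 test set and mirrors the accuracy bookkeeping: count correct predictions across all batches, then divide by the dataset size. The batch size of 64 and the predictions and labels are made-up values for illustration.

```python
def batch_sizes(num_samples, batch_size):
    """Batch sizes a DataLoader yields when drop_last=False."""
    full_batches, remainder = divmod(num_samples, batch_size)
    return [batch_size] * full_batches + ([remainder] if remainder else [])


def accuracy(batches, num_samples):
    """Count matching (prediction, label) pairs over all batches, then normalize."""
    correct = 0
    for predictions, labels in batches:
        correct += sum(int(p == l) for p, l in zip(predictions, labels))
    return correct / num_samples


sizes = batch_sizes(10000, 64)
print(len(sizes), sizes[-1])  # 157 16

# Made-up predictions and labels: a batch of three, then a short final batch of one.
batches = [([3, 1, 4], [3, 1, 5]), ([1], [1])]
print(accuracy(batches, 4))  # 0.75
```

With ``eval.py``'s default ``--batch_size`` of 1 every batch is full, so the slicing only comes into play for larger batch sizes.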
572 |
573 | As before, the accuracy should still be approximately ``85%`` on the CIFAR10 test set!
574 |
575 |
576 | ## Step 9 - Modify trained model with ONNX graph surgeon
577 |
578 | If you've followed the tutorial to this point, you now have a trained ONNX graph, ``model_bn.onnx``, that is largely DLA compatible. However, one of the final layers, ``AdaptiveAvgPool2d``, which maps to the ONNX layer ``GlobalAveragePool``,
579 | is still not supported by DLA. To demonstrate another method for making models DLA compatible, we will modify
580 | the ONNX graph directly.
581 |
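582 | Here, we can replace the ``GlobalAveragePool`` layer, which isn't supported on the DLA, with an ``AveragePool`` layer, which is
583 | supported on DLA for kernel sizes up to 8. Because the four stride-2 convolutions shrink the 32x32 input to a 2x2 feature map, a 2x2 ``AveragePool`` produces the same result as the global pool.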
584 |
585 | We do so with the following code.
586 |
587 | ```python
588 | import onnx
589 | import onnx_graphsurgeon as gs
590 | graph = gs.import_onnx(onnx.load('model_bn.onnx'))
591 | for node in graph.nodes:
592 |     if node.op == 'GlobalAveragePool':
593 |         node.op = 'AveragePool'  # supported on DLA for kernel sizes up to 8
594 |         node.attrs['kernel_shape'] = [2, 2]  # covers the whole 2x2 feature map
595 | onnx.save(gs.export_onnx(graph), 'model_bn_modified.onnx')
596 | ```
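The ``[2, 2]`` kernel follows from the shape of the feature map that reaches the pool. Each of the four ``Conv2d`` layers in ``ModelBN`` uses ``kernel_size=3``, ``stride=2`` and ``padding=1``, so the spatial size can be traced with the standard convolution output-size formula, as in the sketch below (the helper and constant names are illustrative). A 2x2 ``AveragePool`` over a 2x2 map averages every position, exactly like ``GlobalAveragePool``.

```python
def conv_out_size(size, kernel_size=3, stride=2, padding=1):
    """Spatial output size of Conv2d: floor((size + 2*padding - kernel_size) / stride) + 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


DLA_MAX_AVG_POOL_KERNEL = 8  # AveragePool kernel limit on DLA, as stated in this tutorial

size = 32  # CIFAR10 images are 32x32
sizes = [size]
for _ in range(4):  # ModelBN has four stride-2 convolutions before the pool
    size = conv_out_size(size)
    sizes.append(size)
print(sizes)  # [32, 16, 8, 4, 2]

kernel_shape = [size, size]  # cover the whole remaining feature map
assert max(kernel_shape) <= DLA_MAX_AVG_POOL_KERNEL
print(kernel_shape)  # [2, 2]
```

One consequence of hard-coding the kernel is that the modified model is only equivalent to the original for 32x32 inputs; a larger input would leave a bigger feature map that a 2x2 pool no longer reduces to 1x1.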
597 |
598 | After optimizing our modified ONNX model with trtexec, we see the following output.
599 |
600 | ```bash
601 | [04/28/2022-21:30:53] [I] [TRT] ---------- Layers Running on DLA ----------
602 | [04/28/2022-21:30:53] [I] [TRT] [DlaLayer] {ForeignNode[Conv_0...Gemm_15]}
603 | [04/28/2022-21:30:53] [I] [TRT] ---------- Layers Running on GPU ----------
604 | [04/28/2022-21:30:53] [I] [TRT] [GpuLayer] SHUFFLE: reshape_after_Gemm_15
605 | ```
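Reading this placement summary by eye works for a small model, but when comparing several builds it can help to extract it programmatically. The helper below is a hypothetical sketch that relies only on the ``[DlaLayer]`` / ``[GpuLayer]`` line format shown in the log above; other TensorRT versions may format this summary differently.

```python
import re


def layer_placement(log_text):
    """Group the layer names reported in the DLA/GPU placement summary of a TensorRT log."""
    placement = {'DLA': [], 'GPU': []}
    for match in re.finditer(r'\[(Dla|Gpu)Layer\]\s+(.*)', log_text):
        device = match.group(1).upper()  # 'Dla' -> 'DLA', 'Gpu' -> 'GPU'
        placement[device].append(match.group(2).strip())
    return placement


log = """\
[04/28/2022-21:30:53] [I] [TRT] ---------- Layers Running on DLA ----------
[04/28/2022-21:30:53] [I] [TRT] [DlaLayer] {ForeignNode[Conv_0...Gemm_15]}
[04/28/2022-21:30:53] [I] [TRT] ---------- Layers Running on GPU ----------
[04/28/2022-21:30:53] [I] [TRT] [GpuLayer] SHUFFLE: reshape_after_Gemm_15
"""

print(layer_placement(log))
# {'DLA': ['{ForeignNode[Conv_0...Gemm_15]}'], 'GPU': ['SHUFFLE: reshape_after_Gemm_15']}
```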
606 |
607 | Now a single, larger subgraph runs on the DLA, rather than two subgraphs split apart by the unsupported pooling layer! Only the final reshape after ``Gemm_15`` remains on the GPU.
608 |
609 | There are often many ways you can creatively replace unsupported layers with other layers that are supported. It largely depends on the exact model and usage. We hope this helps give you ideas for how to better optimize your models for DLA.
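One way to apply the same idea to other models is to derive the pooling kernel from the feature-map size instead of hard-coding ``[2, 2]``, and to skip the rewrite when the kernel would exceed the DLA limit of 8. The sketch below illustrates that rule on a minimal, hypothetical ``Node`` stand-in so it runs without ONNX installed; with onnx_graphsurgeon the input size would have to come from the graph's shape information (for example ``node.inputs[0].shape``), which assumes shapes are known.

```python
from dataclasses import dataclass, field

DLA_MAX_AVG_POOL_KERNEL = 8  # AveragePool kernel limit on DLA, as stated in this tutorial


@dataclass
class Node:
    """Hypothetical stand-in for an onnx_graphsurgeon Node."""
    op: str
    input_hw: tuple  # (height, width) of the node's input feature map
    attrs: dict = field(default_factory=dict)


def replace_global_avg_pool(nodes, max_kernel=DLA_MAX_AVG_POOL_KERNEL):
    """Turn GlobalAveragePool into an equivalent AveragePool when the kernel fits."""
    replaced = 0
    for node in nodes:
        if node.op != 'GlobalAveragePool':
            continue
        height, width = node.input_hw
        if max(height, width) > max_kernel:
            continue  # kernel too large for DLA; leave the node unchanged
        node.op = 'AveragePool'
        node.attrs['kernel_shape'] = [height, width]
        replaced += 1
    return replaced


nodes = [
    Node('Conv', (4, 4)),
    Node('GlobalAveragePool', (2, 2)),
    Node('GlobalAveragePool', (16, 16)),
]
print(replace_global_avg_pool(nodes))  # 1
print([node.op for node in nodes])  # ['Conv', 'AveragePool', 'GlobalAveragePool']
```

A node left unchanged would still need GPU fallback (the ``--gpu_fallback`` option in ``build.py``) to build successfully.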
610 |
611 |
612 | ## Next steps
613 |
614 | That concludes this tutorial. We hope this helps you get started using the DLA on Jetson.
615 |
616 | For more information related to topics discussed in this tutorial, check out:
617 |
618 | - [``QUICKSTART.md``](QUICKSTART.md) - Instructions for quickly reproducing the results in this tutorial
619 | - [TensorRT User Guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html)
620 | - [TensorRT Python API](https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/)
621 | - [Nsight Systems Profiler](https://docs.nvidia.com/nsight-systems/UserGuide/index.html#cli-profiling)
622 |
--------------------------------------------------------------------------------
/assets/model_bn_timeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/jetson_dla_tutorial/ee4e0ed3c8446ef0d5ef443d017310f6207ce864/assets/model_bn_timeline.png
--------------------------------------------------------------------------------
/assets/model_gn_timeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/jetson_dla_tutorial/ee4e0ed3c8446ef0d5ef443d017310f6207ce864/assets/model_gn_timeline.png
--------------------------------------------------------------------------------
/build.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT
3 |
4 |
5 | import argparse
6 | import os
7 | import subprocess
8 | from calibrator import DatasetCalibrator
9 | import tensorrt as trt
10 | import torch
11 | import torchvision
12 | import torchvision.transforms as transforms
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('onnx', type=str, help='Path to the ONNX model.')
17 | parser.add_argument('--output', type=str, default=None, help='Path to output the optimized TensorRT engine')
18 | parser.add_argument('--max_workspace_size', type=int, default=1<<25, help='Max workspace size for TensorRT engine.')
19 | parser.add_argument('--int8', action='store_true')
20 | parser.add_argument('--fp16', action='store_true')
21 | parser.add_argument('--batch_size', type=int, default=1)
22 | parser.add_argument('--dla_core', type=int, default=None)
23 | parser.add_argument('--gpu_fallback', action='store_true')
24 | parser.add_argument('--dataset_path', type=str, default='data/cifar10')
25 | args = parser.parse_args()
26 |
27 | transform = transforms.Compose([
28 |     transforms.ToTensor(),
29 |     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
30 | ])
31 |
32 | train_dataset = torchvision.datasets.CIFAR10(
33 |     root=args.dataset_path,
34 |     train=True,
35 |     download=True,
36 |     transform=transform
37 | )
38 |
39 | test_dataset = torchvision.datasets.CIFAR10(
40 |     root=args.dataset_path,
41 |     train=False,
42 |     download=True,
43 |     transform=transform
44 | )
45 |
46 | data = torch.zeros(args.batch_size, 3, 32, 32).cuda()
47 |
48 | logger = trt.Logger(trt.Logger.INFO)
49 | builder = trt.Builder(logger)
50 | builder.max_batch_size = args.batch_size
51 | network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
52 | parser = trt.OnnxParser(network, logger)
53 |
54 | with open(args.onnx, 'rb') as f:
55 |     parser.parse(f.read())
56 |
57 | profile = builder.create_optimization_profile()
58 | profile.set_shape(
59 |     'input',
60 |     (args.batch_size, 3, 32, 32),
61 |     (args.batch_size, 3, 32, 32),
62 |     (args.batch_size, 3, 32, 32)
63 | )
64 |
65 | config = builder.create_builder_config()
66 |
67 | config.max_workspace_size = args.max_workspace_size
68 |
69 | if args.fp16:
70 |     config.set_flag(trt.BuilderFlag.FP16)
71 |
72 | if args.int8:
73 |     config.set_flag(trt.BuilderFlag.INT8)
74 |     config.int8_calibrator = DatasetCalibrator(data, train_dataset)
75 |
76 | if args.dla_core is not None:
77 |     config.default_device_type = trt.DeviceType.DLA
78 |     config.DLA_core = args.dla_core
79 |
80 | if args.gpu_fallback:
81 |     config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
82 |
83 | config.add_optimization_profile(profile)
84 | config.set_calibration_profile(profile)
85 |
86 | engine = builder.build_serialized_network(network, config)
87 |
88 | if args.output is not None:
89 |     with open(args.output, 'wb') as f:
90 |         f.write(engine)
91 |
92 |
--------------------------------------------------------------------------------
/calibrator.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT
3 |
4 |
5 | import torch
6 | import tensorrt as trt
7 |
8 |
9 | __all__ = [
10 | 'DatasetCalibrator'
11 | ]
12 |
13 |
14 | class DatasetCalibrator(trt.IInt8Calibrator):
15 |
16 |     def __init__(self,
17 |             input, dataset,
18 |             algorithm=trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2):
19 |         super(DatasetCalibrator, self).__init__()
20 |         self.dataset = dataset
21 |         self.algorithm = algorithm
22 |         self.buffer = torch.zeros_like(input).contiguous()
23 |         self.count = 0
24 |
25 |     def get_batch(self, *args, **kwargs):
26 |         if self.count < len(self.dataset):
27 |             for buffer_idx in range(self.get_batch_size()):
28 |                 dataset_idx = self.count % len(self.dataset) # roll around if not multiple of dataset
29 |                 image, _ = self.dataset[dataset_idx]
30 |                 image = image.to(self.buffer.device)
31 |                 self.buffer[buffer_idx].copy_(image)
32 |
33 |                 self.count += 1
34 |             return [int(self.buffer.data_ptr())]
35 |         else:
36 |             return []
37 |
38 |     def get_algorithm(self):
39 |         return self.algorithm
40 |
41 |     def get_batch_size(self):
42 |         return int(self.buffer.shape[0])
43 |
44 |     def read_calibration_cache(self, *args, **kwargs):
45 |         return None
46 |
47 |     def write_calibration_cache(self, cache, *args, **kwargs):
48 |         pass
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT
3 |
4 |
5 | import argparse
6 | import os
7 | import subprocess
8 | import tensorrt as trt
9 | import torch
10 | import torchvision
11 | import torchvision.transforms as transforms
12 |
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('engine', type=str, default=None, help='Path to the optimized TensorRT engine')
16 | parser.add_argument('--batch_size', type=int, default=1)
17 | parser.add_argument('--dataset_path', type=str, default='data/cifar10')
18 | args = parser.parse_args()
19 |
20 | transform = transforms.Compose([
21 |     transforms.ToTensor(),
22 |     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
23 | ])
24 |
25 | test_dataset = torchvision.datasets.CIFAR10(
26 |     root=args.dataset_path,
27 |     train=False,
28 |     download=True,
29 |     transform=transform
30 | )
31 |
32 | test_loader = torch.utils.data.DataLoader(
33 |     test_dataset,
34 |     batch_size=args.batch_size,
35 |     shuffle=False
36 | )
37 |
38 | logger = trt.Logger()
39 | runtime = trt.Runtime(logger)
40 |
41 | with open(args.engine, 'rb') as f:
42 |     engine = runtime.deserialize_cuda_engine(f.read())
43 |
44 |
45 | context = engine.create_execution_context()
46 |
47 | input_binding_idx = engine.get_binding_index('input')
48 | output_binding_idx = engine.get_binding_index('output')
49 |
50 | input_shape = (args.batch_size, 3, 32, 32)
51 | output_shape = (args.batch_size, 10)
52 |
53 | context.set_binding_shape(
54 |     input_binding_idx,
55 |     input_shape
56 | )
57 |
58 | input_buffer = torch.zeros(input_shape, dtype=torch.float32, device=torch.device('cuda'))
59 | output_buffer = torch.zeros(output_shape, dtype=torch.float32, device=torch.device('cuda'))
60 |
61 | bindings = [None, None]
62 | bindings[input_binding_idx] = input_buffer.data_ptr()
63 | bindings[output_binding_idx] = output_buffer.data_ptr()
64 |
65 | test_accuracy = 0
66 |
67 | # run through test dataset
68 | for image, label in iter(test_loader):
69 |
70 |     actual_batch_size = int(image.shape[0])
71 |
72 |     input_buffer[0:actual_batch_size].copy_(image)
73 |
74 |     context.execute_async_v2(
75 |         bindings,
76 |         torch.cuda.current_stream().cuda_stream
77 |     )
78 |
79 |     torch.cuda.current_stream().synchronize()
80 |
81 |     output = output_buffer[0:actual_batch_size]
82 |     label = label.cuda()
83 |
84 |     test_accuracy += int(torch.sum(output.argmax(dim=-1) == label))
85 |
86 | test_accuracy /= len(test_dataset)
87 |
88 | print(f'TEST ACCURACY: {test_accuracy}')
--------------------------------------------------------------------------------
/export.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT
3 |
4 |
5 | import argparse
6 | import torch
7 | import os
8 | from models import MODELS
9 |
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('model_name', type=str)
13 | parser.add_argument('output', type=str)
14 | parser.add_argument('--checkpoint_path', type=str, default=None)
15 | args = parser.parse_args()
16 |
17 | data = torch.zeros(1, 3, 32, 32).cuda()
18 |
19 | model = MODELS[args.model_name]().cuda().eval()
20 |
21 | if args.checkpoint_path is not None:
22 |     model.load_state_dict(torch.load(args.checkpoint_path))
23 |
24 | torch.onnx.export(
25 |     model,
26 |     data,
27 |     args.output,
28 |     input_names=['input'],
29 |     output_names=['output'],
30 |     dynamic_axes={
31 |         'input': {0: 'batch_size'},
32 |         'output': {0: 'batch_size'}
33 |     }
34 | )
35 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | __all__ = [
10 | 'ModelGN',
11 | 'ModelBN',
12 | 'MODELS'
13 | ]
14 |
15 |
16 | class ModelGN(nn.Module):
17 |     def __init__(self):
18 |         super().__init__()
19 |         self.cnn = nn.Sequential(
20 |             nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
21 |             nn.GroupNorm(8, 64),
22 |             nn.ReLU(),
23 |             nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
24 |             nn.GroupNorm(8, 128),
25 |             nn.ReLU(),
26 |             nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
27 |             nn.GroupNorm(8, 256),
28 |             nn.ReLU(),
29 |             nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
30 |             nn.GroupNorm(8, 512),
31 |             nn.ReLU()
32 |         )
33 |         self.pool = nn.AdaptiveAvgPool2d((1, 1))
34 |         self.linear = nn.Linear(512, 10)
35 |
36 |     def forward(self, x):
37 |         x = self.cnn(x)
38 |         x = self.pool(x)
39 |         x = x.view(x.shape[0], -1)
40 |         x = self.linear(x)
41 |         return x
42 |
43 |
44 | class ModelBN(nn.Module):
45 |     def __init__(self):
46 |         super().__init__()
47 |         self.cnn = nn.Sequential(
48 |             nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
49 |             nn.BatchNorm2d(64),
50 |             nn.ReLU(),
51 |             nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
52 |             nn.BatchNorm2d(128),
53 |             nn.ReLU(),
54 |             nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
55 |             nn.BatchNorm2d(256),
56 |             nn.ReLU(),
57 |             nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
58 |             nn.BatchNorm2d(512),
59 |             nn.ReLU()
60 |         )
61 |         self.pool = nn.AdaptiveAvgPool2d((1, 1))
62 |         self.linear = nn.Linear(512, 10)
63 |
64 |     def forward(self, x):
65 |         x = self.cnn(x)
66 |         x = self.pool(x)
67 |         x = x.view(x.shape[0], -1)
68 |         x = self.linear(x)
69 |         return x
70 |
71 |
72 | MODELS = {
73 |     'model_gn': ModelGN,
74 |     'model_bn': ModelBN
75 | }
76 |
--------------------------------------------------------------------------------
/modify.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT
3 |
4 |
5 | import argparse
6 | import onnx_graphsurgeon as gs
7 | import onnx
8 |
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('input', type=str, help='Path to the ONNX model.')
12 | parser.add_argument('output', type=str, help='Path to output the modified ONNX model.')
13 | args = parser.parse_args()
14 |
15 | graph = gs.import_onnx(onnx.load(args.input))
16 |
17 | for node in graph.nodes:
18 |     if node.op == 'GlobalAveragePool':
19 |         node.op = 'AveragePool'
20 |         node.attrs['kernel_shape'] = [2, 2]
21 |
22 | onnx.save(gs.export_onnx(graph), args.output)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT
3 |
4 |
5 | import os
6 | import argparse
7 | import torch
8 | import torch.nn.functional as F
9 | import torchvision
10 | import torchvision.transforms as transforms
11 | from models import MODELS
12 |
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('model_name', type=str, help='The name of the model. See the MODELS dictionary in models.py for options.')
16 | parser.add_argument('--batch_size', type=int, default=64, help='The data loader batch size.')
17 | parser.add_argument('--lr', type=float, default=1e-3, help='The optimizer learning rate.')
18 | parser.add_argument('--optimizer', type=str, default='adam', help='The optimizer type. Must be one of the keys in the OPTIMIZERS variable in train.py.')
19 | parser.add_argument('--momentum', type=float, default=0.9, help='The optimizer momentum. Only applies when optimizer=sgd.')
20 | parser.add_argument('--epochs', type=int, default=50, help='The number of training epochs.')
21 | parser.add_argument('--dataset_path', type=str, default='data/cifar10', help='The directory where the CIFAR10 dataset is stored or downloaded.')
22 | parser.add_argument('--checkpoint_path', type=str, default=None, help='The path to store the model weights.')
23 | args = parser.parse_args()
24 |
25 | OPTIMIZERS = {
26 |     'sgd': lambda params, args: torch.optim.SGD(params, lr=args.lr, momentum=args.momentum),
27 |     'adam': lambda params, args: torch.optim.Adam(params, lr=args.lr)
28 | }
29 |
30 | model = MODELS[args.model_name]().cuda()
31 | optimizer = OPTIMIZERS[args.optimizer](model.parameters(), args)
32 |
33 |
34 | transform_train = transforms.Compose([
35 |     transforms.RandomCrop(32, padding=4),
36 |     transforms.RandomHorizontalFlip(),
37 |     transforms.ToTensor(),
38 |     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
39 | ])
40 |
41 | transform_test = transforms.Compose([
42 |     transforms.ToTensor(),
43 |     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
44 | ])
45 |
46 | train_dataset = torchvision.datasets.CIFAR10(
47 |     root=args.dataset_path,
48 |     train=True,
49 |     download=True,
50 |     transform=transform_train
51 | )
52 |
53 | train_loader = torch.utils.data.DataLoader(
54 |     train_dataset,
55 |     batch_size=args.batch_size,
56 |     shuffle=True
57 | )
58 |
59 | test_dataset = torchvision.datasets.CIFAR10(
60 |     root=args.dataset_path,
61 |     train=False,
62 |     download=True,
63 |     transform=transform_test
64 | )
65 |
66 | test_loader = torch.utils.data.DataLoader(
67 |     test_dataset,
68 |     batch_size=args.batch_size,
69 |     shuffle=False
70 | )
71 |
72 | best_accuracy = 0.0
73 |
74 | for epoch in range(args.epochs):
75 |
76 |     train_loss = 0.0
77 |     test_loss = 0.0
78 |     train_accuracy = 0
79 |     test_accuracy = 0
80 |
81 |     # train loop
82 |
83 |     model = model.train()
84 |
85 |     for image, label in iter(train_loader):
86 |
87 |         optimizer.zero_grad()
88 |
89 |         image = image.cuda()
90 |         label = label.cuda()
91 |
92 |         output = model(image)
93 |
94 |         loss = F.cross_entropy(output, label)
95 |
96 |         loss.backward()
97 |         optimizer.step()
98 |
99 |         train_loss += float(loss)
100 |         train_accuracy += int(torch.sum(output.argmax(dim=-1) == label))
101 |
102 |     train_accuracy /= len(train_dataset)
103 |     train_loss /= len(train_loader)
104 |
105 |     model = model.eval()
106 |
107 |     for image, label in iter(test_loader):
108 |
109 |         image = image.cuda()
110 |         label = label.cuda()
111 |
112 |         output = model(image)
113 |
114 |         loss = F.cross_entropy(output, label)
115 |
116 |         test_loss += float(loss)
117 |         test_accuracy += int(torch.sum(output.argmax(dim=-1) == label))
118 |
119 |     test_accuracy /= len(test_dataset)
120 |     test_loss /= len(test_loader)
121 |
122 |     print(f'{epoch}, {train_loss}, {test_loss}, {train_accuracy}, {test_accuracy}')
123 |
124 |     if test_accuracy > best_accuracy and args.checkpoint_path is not None:
125 |         print(f'Saving checkpoint to {args.checkpoint_path} for model with test accuracy {test_accuracy}.')
126 |         torch.save(model.state_dict(), args.checkpoint_path)
127 |         best_accuracy = test_accuracy
--------------------------------------------------------------------------------