├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── calibrator.py ├── engine_function.py ├── model_function.py ├── train.py ├── val_tensorrt.py └── val_tensorrt_sim.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .dmypy.json 111 | dmypy.json 112 | 113 | # Pyre type checker 114 | .pyre/ 115 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 darkknightzh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | 
The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorRT_pytorch 2 | 3 | A simple demo that trains an MNIST classifier in PyTorch and speeds up its inference with TensorRT. 4 | The training code comes from the [PyTorch MNIST example](https://github.com/pytorch/examples/tree/master/mnist). 5 | The TensorRT code is adapted from the samples shipped in the TensorRT installation package. 6 | 7 | The code targets TensorRT 5.1.5 and may not be compatible with other versions of TensorRT. 8 | 9 | For more information, see the blog post [(原)pytorch中使用TensorRT](https://www.cnblogs.com/darkknightzh/p/11332155.html) (in Chinese). 10 | -------------------------------------------------------------------------------- /calibrator.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 1993-2019 NVIDIA Corporation. All rights reserved. 3 | # 4 | # NOTICE TO LICENSEE: 5 | # 6 | # This source code and/or documentation ("Licensed Deliverables") are 7 | # subject to NVIDIA intellectual property rights under U.S. and 8 | # international Copyright laws.
9 | # 10 | # These Licensed Deliverables contained herein is PROPRIETARY and 11 | # CONFIDENTIAL to NVIDIA and is being provided under the terms and 12 | # conditions of a form of NVIDIA software license agreement by and 13 | # between NVIDIA and Licensee ("License Agreement") or electronically 14 | # accepted by Licensee. Notwithstanding any terms or conditions to 15 | # the contrary in the License Agreement, reproduction or disclosure 16 | # of the Licensed Deliverables to any third party without the express 17 | # written consent of NVIDIA is prohibited. 18 | # 19 | # NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 20 | # LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE 21 | # SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS 22 | # PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. 23 | # NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED 24 | # DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, 25 | # NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 26 | # NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 27 | # LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY 28 | # SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY 29 | # DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 30 | # WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 31 | # ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE 32 | # OF THESE LICENSED DELIVERABLES. 33 | # 34 | # U.S. Government End Users. These Licensed Deliverables are a 35 | # "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT 36 | # 1995), consisting of "commercial computer software" and "commercial 37 | # computer software documentation" as such terms are used in 48 38 | # C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government 39 | # only as a commercial end item. Consistent with 48 C.F.R.12.212 and 40 | # 48 C.F.R. 
227.7202-1 through 227.7202-4 (JUNE 1995), all 41 | # U.S. Government End Users acquire the Licensed Deliverables with 42 | # only those rights set forth herein. 43 | # 44 | # Any use of the Licensed Deliverables in individual and commercial 45 | # software must include, in the user documentation and internal 46 | # comments to the code, the above Disclaimer and U.S. Government End 47 | # Users Notice. 48 | 49 | import tensorrt as trt 50 | import os 51 | import json 52 | 53 | import pycuda.driver as cuda 54 | import pycuda.autoinit 55 | from PIL import Image 56 | import numpy as np 57 | 58 | # For reading size information from batches 59 | import struct 60 | import torch 61 | import cv2 62 | import math 63 | 64 | class MNISTEntropyCalibrator(trt.IInt8EntropyCalibrator2): 65 | def __init__(self, batch_data_dir, cache_file): 66 | # Whenever you specify a custom constructor for a TensorRT class, 67 | # you MUST call the constructor of the parent explicitly. 68 | trt.IInt8EntropyCalibrator2.__init__(self) 69 | 70 | self.cache_file = cache_file 71 | # Get a list of all the batch files in the batch folder. 72 | self.batch_files = [os.path.join(batch_data_dir, f) for f in os.listdir(batch_data_dir)] 73 | 74 | # Find out the shape of a batch and then allocate a device buffer of that size. 75 | self.shape, _, _ = self.read_batch_file(self.batch_files[0]) 76 | # Each element of the calibration data is a float32. 77 | self.device_input = cuda.mem_alloc(trt.volume(self.shape) * trt.float32.itemsize) 78 | 79 | # Create a generator that will give us batches. We can use next() to iterate over the result. 80 | def load_batches(): 81 | for f in self.batch_files: 82 | shape, data, labels = self.read_batch_file(f) 83 | yield shape, data, labels 84 | self.batches = load_batches() 85 | 86 | # This function is used to load calibration data from the calibration batch files. 
87 | # In this implementation, one file corresponds to one batch, but it is also possible to use 88 | # aggregate data from multiple files, or use only data from portions of a file. 89 | def read_batch_file(self, filename): 90 | with open(filename, "rb") as f: 91 | # Read the first 4 integers. These will be the NCHW dimensions of the data. 92 | shape = tuple(struct.unpack("