├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── LICENSE.md ├── README.md ├── build.py ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── test_convert.py └── test_converters │ ├── test_embedding.py │ ├── test_gather.py │ ├── test_grid_sample.py │ ├── test_group_norm.py │ └── test_topk.py └── torch2trt_dynamic ├── __init__.py ├── calibration.py ├── converters ├── AdaptiveAvgPool2d.py ├── AdaptiveMaxPool2d.py ├── BatchNorm1d.py ├── BatchNorm2d.py ├── Conv1d.py ├── Conv2d.py ├── ConvTranspose1d.py ├── ConvTranspose2d.py ├── Embedding.py ├── GRU.py ├── GroupNorm.py ├── Identity.py ├── LayerNorm.py ├── Linear.py ├── LogSoftmax.py ├── ReLU.py ├── ReLU6.py ├── __init__.py ├── activation.py ├── adaptive_avg_pool1d.py ├── adaptive_avg_pool2d.py ├── adaptive_max_pool1d.py ├── adaptive_max_pool2d.py ├── add.py ├── addcmul.py ├── arange.py ├── argmax.py ├── argmin.py ├── avg_pool2d.py ├── bmm.py ├── cast_type.py ├── cat.py ├── chunk.py ├── clamp.py ├── conv2d.py ├── cummax.py ├── cummin.py ├── cumprod.py ├── cumsum.py ├── deform_conv2d.py ├── div.py ├── dummy_converters.py ├── exview.py ├── flatten.py ├── flip.py ├── floor_divide.py ├── full.py ├── full_like.py ├── gather.py ├── gelu.py ├── getitem.py ├── grid_sample.py ├── identity.py ├── index_select.py ├── instance_norm.py ├── interpolate_custom.py ├── linear.py ├── linspace.py ├── logical.py ├── masked_fill.py ├── matmul.py ├── max.py ├── max_pool1d.py ├── max_pool2d.py ├── mean.py ├── meshgrid.py ├── min.py ├── mod.py ├── mul.py ├── narrow.py ├── new_ones.py ├── new_zeros.py ├── nms.py ├── normalize.py ├── numel.py ├── ones.py ├── ones_like.py ├── pad.py ├── permute.py ├── pixel_shuffle.py ├── pow.py ├── prelu.py ├── prod.py ├── relu.py ├── relu6.py ├── repeat.py ├── roi_align.py ├── roi_pool.py ├── roll.py ├── sigmoid.py ├── size.py ├── softmax.py ├── split.py ├── squeeze.py ├── stack.py ├── std.py ├── sub.py ├── sum.py ├── t.py ├── take.py ├── tanh.py ├── to.py ├── topk.py ├── transpose.py ├── unary.py ├── unfold.py ├── unsqueeze.py ├── view.py ├── view_as.py ├── where.py ├── zeros.py └── zeros_like.py ├── module_test.py ├── plugins ├── __init__.py ├── create_adaptivepool_plugin.py ├── create_dcn_plugin.py ├── create_groupnorm_plugin.py ├── create_nms_plugin.py ├── create_roiextractor_plugin.py ├── create_roipool_plugin.py ├── create_torchbmm_plugin.py ├── create_torchcum_plugin.py ├── create_torchcummaxmin_plugin.py ├── create_torchunfold_plugin.py └── globals.py ├── shape_converter.py ├── torch2trt_dynamic.py ├── torch_allocator.py ├── trt_module.py └── utils.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: Set up Python 3.9 11 | uses: actions/setup-python@v2 12 | with: 13 | python-version: 3.9 14 | - name: Install pre-commit hook 15 | run: | 16 | pip install pre-commit 17 | pre-commit install 18 | - name: Linting 19 | run: pre-commit run --all-files 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ninja_deps 2 | .ninja_log 3 | build.ninja 4 | tags 5 | *.o 6 | *.pb.o 7 | torch2trt.egg-info 8 | torch2trt_dynamic.egg-info 9 | build/ 10 | dist/ 11 | __pycache__/ 12 | *.so 13 | *.pb.h 14 | *.pb.cc 15 | *_pb2.py 16 | 
*.pyc 17 | *.ipynb_checkpoints 18 | *.pth 19 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | known_third_party = distutils,graphviz,numpy,packaging,pytest,setuptools,tensorrt,torch,torchvision 3 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/PyCQA/flake8 3 | rev: 4.0.1 4 | hooks: 5 | - id: flake8 6 | - repo: https://github.com/asottile/seed-isort-config 7 | rev: v2.2.0 8 | hooks: 9 | - id: seed-isort-config 10 | - repo: https://github.com/timothycrosley/isort 11 | rev: 4.3.21 12 | hooks: 13 | - id: isort 14 | - repo: https://github.com/pre-commit/mirrors-yapf 15 | rev: v0.32.0 16 | hooks: 17 | - id: yapf 18 | - repo: https://github.com/pre-commit/pre-commit-hooks 19 | rev: v4.2.0 20 | hooks: 21 | - id: trailing-whitespace 22 | - id: check-yaml 23 | - id: end-of-file-fixer 24 | - id: requirements-txt-fixer 25 | - id: double-quote-string-fixer 26 | - id: check-merge-conflict 27 | - id: fix-encoding-pragma 28 | args: ["--remove"] 29 | - id: mixed-line-ending 30 | args: ["--fix=lf"] 31 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torch2trt dynamic 2 | 3 | This is a branch of [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt) with dynamic input support. 4 | 5 | ## Usage 6 | 7 | Here are some examples 8 | 9 | ### Convert 10 | 11 | ```python 12 | from torch2trt_dynamic import module2trt, BuildEngineConfig 13 | import torch 14 | from torchvision.models import resnet18 15 | 16 | # create some regular pytorch model... 
17 | model = resnet18().cuda().eval()
18 |
19 | # create example data
20 | x = torch.ones((1, 3, 224, 224)).cuda()
21 |
22 | # convert to TensorRT feeding sample data as input
23 | config = BuildEngineConfig(
24 |     shape_ranges=dict(
25 |         x=dict(
26 |             min=(1, 3, 224, 224),
27 |             opt=(2, 3, 224, 224),
28 |             max=(4, 3, 224, 224),
29 |         )
30 |     ))
31 | trt_model = module2trt(
32 |     model,
33 |     args=[x],
34 |     config=config)
35 | ```
36 |
37 | ### Execute
38 |
39 | We can execute the returned `TRTModule` just like the original PyTorch model:
40 |
41 | ```python
42 | x = torch.rand(1, 3, 224, 224).cuda()
43 | with torch.no_grad():
44 |     y = model(x)
45 |     y_trt = trt_model(x)
46 |
47 | # check the output against PyTorch
48 | torch.testing.assert_close(y, y_trt)
49 | ```
50 |
51 | ### Save and load
52 |
53 | We can save the model as a ``state_dict``.
54 |
55 | ```python
56 | torch.save(trt_model.state_dict(), 'my_engine.pth')
57 | ```
58 |
59 | We can load the saved model into a ``TRTModule``:
60 |
61 | ```python
62 | from torch2trt_dynamic import TRTModule
63 |
64 | trt_model = TRTModule()
65 | trt_model.load_state_dict(torch.load('my_engine.pth'))
66 | ```
67 |
68 | ## Setup
69 |
70 | To install without compiling plugins, run the following:
71 |
72 | ```bash
73 | git clone https://github.com/grimoire/torch2trt_dynamic.git torch2trt_dynamic
74 | cd torch2trt_dynamic
75 | pip install .
76 | ```
77 |
78 | ### Set up plugins (optional)
79 |
80 | Some layers, such as `GroupNorm`, need C++ plugins. Install the plugin project below:
81 |
82 | [amirstan_plugin](https://github.com/grimoire/amirstan_plugin)
83 |
84 | **DO NOT FORGET** to export the environment variable `AMIRSTAN_LIBRARY_PATH` so that the compiled plugin library can be found.
85 |
86 | ## How to add (or override) a converter
87 |
88 | Here we show how to add a converter for the ``ReLU`` module using the TensorRT Python API.
89 |
90 | ```python
91 | import tensorrt as trt
92 | from torch2trt_dynamic import tensorrt_converter
93 |
94 | @tensorrt_converter('torch.nn.ReLU.forward')
95 | def convert_ReLU(ctx):
96 |     input = ctx.method_args[1]
97 |     output = ctx.method_return
98 |     layer = ctx.network.add_activation(input=input._trt, type=trt.ActivationType.RELU)
99 |     output._trt = layer.get_output(0)
100 | ```
101 |
102 | The converter takes one argument, a ``ConversionContext``, which contains
103 | the following:
104 |
105 | * ``ctx.network`` - The TensorRT network that is being constructed.
106 |
107 | * ``ctx.method_args`` - Positional arguments that were passed to the specified PyTorch function. The ``_trt`` attribute is set for relevant input tensors.
108 | * ``ctx.method_kwargs`` - Keyword arguments that were passed to the specified PyTorch function.
109 | * ``ctx.method_return`` - The value returned by the specified PyTorch function. The converter must set the ``_trt`` attribute where relevant.
110 |
111 | Please see [this folder](torch2trt_dynamic/converters) for more examples.
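To override one of the built-in converters, register a converter for the same method name. Below is a hedged sketch, not part of the original README: it assumes (as the section title suggests) that re-registering a method replaces the previous converter, and it uses the `trt_` helper that the bundled converters such as `converters/ReLU.py` rely on to fetch or create the TensorRT `ITensor` for a PyTorch tensor.

```python
import tensorrt as trt
from torch2trt_dynamic import tensorrt_converter
from torch2trt_dynamic.torch2trt_dynamic import trt_


@tensorrt_converter('torch.nn.ReLU.forward')
def convert_ReLU_override(ctx):
    input = ctx.method_args[1]
    # trt_ returns the ITensor bound to `input`, adding a constant layer
    # for tensors that do not carry a `_trt` attribute yet
    input_trt = trt_(ctx.network, input)
    output = ctx.method_return
    layer = ctx.network.add_activation(
        input=input_trt, type=trt.ActivationType.RELU)
    output._trt = layer.get_output(0)
```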
112 | -------------------------------------------------------------------------------- /build.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import os 3 | import subprocess 4 | from string import Template 5 | 6 | PLUGINS = [ 7 | 'interpolate', 8 | ] 9 | 10 | BASE_FOLDER = 'torch2trt_dynamic/converters' 11 | 12 | NINJA_TEMPLATE = Template(( 13 | 'rule link\n' 14 | ' command = g++ -shared -o $$out $$in -L$torch_dir/lib -L$cuda_dir/lib64 -L$trt_lib_dir -lc10 -lc10_cuda -ltorch -lcudart -lprotobuf -lprotobuf-lite -pthread -lpthread -lnvinfer\n' # noqa: E501 15 | 'rule protoc\n' 16 | ' command = protoc $$in --cpp_out=. --python_out=.\n' 17 | 'rule cxx\n' 18 | ' command = g++ -c -fPIC $$in -I$cuda_dir/include -I$torch_dir/include -I$torch_dir/include/torch/csrc/api/include -I. -std=c++11 -I$trt_inc_dir\n' # noqa: E501 19 | )) 20 | 21 | PLUGIN_TEMPLATE = Template(( 22 | 'build $plugin_dir/$plugin.pb.h $plugin_dir/$plugin.pb.cc $plugin_dir/${plugin}_pb2.py: protoc $plugin_dir/$plugin.proto\n' # noqa: E501 23 | 'build $plugin.pb.o: cxx $plugin_dir/$plugin.pb.cc\n' 24 | 'build $plugin.o: cxx $plugin_dir/$plugin.cpp\n')) 25 | 26 | 27 | def build(cuda_dir='/usr/local/cuda', 28 | torch_dir=imp.find_module('torch')[1], 29 | trt_inc_dir='/usr/include/aarch64-linux-gnu', 30 | trt_lib_dir='/usr/lib/aarch64-linux-gnu'): 31 | 32 | global PLUGINS, BASE_FOLDER, NINJA_TEMPLATE, PLUGIN_TEMPLATE 33 | 34 | NINJA_STR = NINJA_TEMPLATE.substitute({ 35 | 'torch_dir': torch_dir, 36 | 'cuda_dir': cuda_dir, 37 | 'trt_inc_dir': trt_inc_dir, 38 | 'trt_lib_dir': trt_lib_dir, 39 | }) 40 | 41 | plugin_o_files = [] 42 | for plugin in PLUGINS: 43 | NINJA_STR += \ 44 | PLUGIN_TEMPLATE.substitute({ 45 | 'plugin': plugin, 46 | 'plugin_dir': os.path.join(BASE_FOLDER, plugin), 47 | }) 48 | plugin_o_files += [plugin + '.pb.o', plugin + '.o'] 49 | 50 | NINJA_STR += Template( 51 | ('build torch2trt_dynamic/libtorch2trt_dynamic.so: link $o_files\n' 52 | )).substitute({'o_files': ' '.join(plugin_o_files)}) 53 | 54 | with open('build.ninja', 'w') as f: 55 | f.write(NINJA_STR) 56 | 57 | subprocess.call(['ninja']) 58 | 59 | 60 | if __name__ == '__main__': 61 | build() 62 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [aliases] 5 | test=pytest 6 | 7 | [yapf] 8 | based_on_style = pep8 9 | blank_line_before_nested_class_or_def = true 10 | split_before_expression_after_opening_paren = true 11 | 12 | [isort] 13 | line_length = 79 14 | multi_line_output = 0 15 | known_standard_library = pkg_resources,setuptools,logging,os,warnings,abc 16 | known_first_party = mmcv 17 | known_third_party = addict,cv2,m2r,numpy,onnx,onnxruntime,packaging,pytest,recommonmark,resnet_cifar,scipy,tensorrt,torch,torchvision,yaml,yapf 18 | no_lines_before = STDLIB,LOCALFOLDER 19 | default_section = THIRDPARTY 20 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grimoire/torch2trt_dynamic/05c5fdce8db9a8ff74ebbecc5ae23c74a07b7016/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_convert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic import 
TRTModule, module2trt
3 | from torchvision.models import resnet18
4 |
5 |
6 | def test_convert(tmp_path):
7 |     model = resnet18().cuda().eval()
8 |
9 |     trt_model = module2trt(
10 |         model,
11 |         args=[torch.rand(1, 3, 32, 32).cuda()],
12 |     )
13 |
14 |     model_path = tmp_path / 'tmp.pth'
15 |     torch.save(trt_model.state_dict(), model_path)
16 |     assert model_path.exists()
17 |
18 |     trt_model = TRTModule()
19 |     trt_model.load_state_dict(torch.load(model_path))
20 |
21 |     x = torch.rand(1, 3, 32, 32).cuda()
22 |     with torch.no_grad():
23 |         y = model(x)
24 |         y_trt = trt_model(x)
25 |
26 |     print(y)
27 |     print(y_trt)
28 |     torch.testing.assert_close(y, y_trt)
29 |
-------------------------------------------------------------------------------- /tests/test_converters/test_embedding.py: --------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch import nn
4 | from torch2trt_dynamic import module2trt
5 |
6 |
7 | class _TestModel(nn.Module):
8 |
9 |     def __init__(self, num, dim) -> None:
10 |         super().__init__()
11 |         self.embedding = nn.Embedding(num, dim)
12 |
13 |     def forward(self, input):
14 |         return self.embedding(input)
15 |
16 |
17 | class TestEmbedding:
18 |
19 |     @pytest.fixture
20 |     def dim(self):
21 |         yield 4
22 |
23 |     @pytest.fixture
24 |     def num(self):
25 |         yield 10
26 |
27 |     @pytest.fixture
28 |     def batch(self):
29 |         yield 2
30 |
31 |     @pytest.fixture
32 |     def input(self, batch, num):
33 |         yield torch.randint(num, (batch, 6)).cuda()
34 |
35 |     def test_embedding(self, input, dim, num):
36 |         model = _TestModel(num, dim).eval().cuda()
37 |         dummy_input = torch.zeros_like(input)
38 |         trt_model = module2trt(model, args=[dummy_input])
39 |
40 |         with torch.inference_mode():
41 |             gt = model(input)
42 |             out = trt_model(input)
43 |             torch.testing.assert_close(out, gt)
44 |
-------------------------------------------------------------------------------- /tests/test_converters/test_gather.py: --------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch import nn
4 | from torch2trt_dynamic import module2trt
5 |
6 |
7 | class _TestModel(nn.Module):
8 |
9 |     def __init__(self, dim: int) -> None:
10 |         super().__init__()
11 |         self.dim = dim
12 |
13 |     def forward(self, input, index):
14 |         return torch.gather(input, self.dim, index)
15 |
16 |
17 | class TestGather:
18 |
19 |     @pytest.fixture
20 |     def input(self):
21 |         yield torch.rand(3, 4, 5).cuda()
22 |
23 |     @pytest.fixture
24 |     def dim(self, request):
25 |         yield request.param
26 |
27 |     @pytest.fixture
28 |     def index(self, input, dim):
29 |         max_val = input.size(dim)
30 |         yield torch.randint(max_val, (3, 4, 5)).cuda()
31 |
32 |     @pytest.mark.parametrize('dim', [0, 1, 2])
33 |     def test_gather(self, input, dim, index):
34 |         model = _TestModel(dim)
35 |         dummy_input = torch.zeros_like(input)
36 |         dummy_index = torch.zeros_like(index)
37 |         trt_model = module2trt(model, args=[dummy_input, dummy_index])
38 |
39 |         with torch.inference_mode():
40 |             gt = model(input, index)
41 |             out = trt_model(input, index)
42 |             torch.testing.assert_close(out, gt)
43 |
-------------------------------------------------------------------------------- /tests/test_converters/test_group_norm.py: --------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch import nn
4 | from torch2trt_dynamic import module2trt
5 |
6 |
7 | class _TestModel(nn.Module):
8 |
9 |     def __init__(self, *args, **kwargs) -> None:
10
| super().__init__() 11 | self.gn = nn.GroupNorm(*args, **kwargs) 12 | 13 | def forward(self, input): 14 | return self.gn(input) 15 | 16 | 17 | class TestGroupNorm: 18 | 19 | @pytest.fixture 20 | def num_channels(self): 21 | yield 4 22 | 23 | @pytest.fixture 24 | def input(self, num_channels): 25 | yield torch.rand(2, num_channels, 8, 16).cuda() 26 | 27 | @pytest.fixture 28 | def num_groups(self): 29 | yield 2 30 | 31 | def test_group_norm(self, input, num_groups): 32 | num_channels = input.size(1) 33 | model = _TestModel(num_groups, num_channels) 34 | model = model.eval().cuda() 35 | dummy_input = torch.zeros_like(input) 36 | trt_model = module2trt(model, args=[dummy_input]) 37 | 38 | with torch.inference_mode(): 39 | gt = model(input) 40 | out = trt_model(input) 41 | torch.testing.assert_close(out, gt) 42 | -------------------------------------------------------------------------------- /tests/test_converters/test_topk.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | from torch2trt_dynamic import module2trt 5 | 6 | 7 | class _TestStaticKModel(nn.Module): 8 | 9 | def __init__(self, k, dim, largest) -> None: 10 | super().__init__() 11 | self.k = k 12 | self.dim = dim 13 | self.largest = largest 14 | 15 | def forward(self, input): 16 | val, index = input.topk(k=self.k, dim=self.dim, largest=self.largest) 17 | return val, index 18 | 19 | 20 | class _TestDynamicModel(nn.Module): 21 | 22 | def __init__(self, k, dim, largest) -> None: 23 | super().__init__() 24 | self.k = k 25 | self.dim = dim 26 | self.largest = largest 27 | 28 | def forward(self, input): 29 | new_k = input.size(self.dim) 30 | k = min(self.k, new_k) 31 | val, index = input.topk(k=k, dim=self.dim, largest=self.largest) 32 | return val, index 33 | 34 | 35 | class TestTopk: 36 | 37 | @pytest.fixture 38 | def shape(self, request): 39 | yield request.param 40 | 41 | @pytest.fixture 42 | def dim(self, request): 43 | yield request.param 44 | 45 | @pytest.fixture 46 | def k(self, request): 47 | yield request.param 48 | 49 | @pytest.fixture 50 | def largest(self, request): 51 | yield request.param 52 | 53 | @pytest.fixture 54 | def input(self, shape): 55 | yield torch.rand(shape).cuda() 56 | 57 | @pytest.mark.parametrize('shape,dim', [ 58 | ((5, 10), 0), 59 | ((5, 10), 1), 60 | ((5, ), 0), 61 | ]) 62 | @pytest.mark.parametrize('k', [3]) 63 | @pytest.mark.parametrize('largest', [True, False]) 64 | def test_static(self, input, k, dim, largest): 65 | model = _TestStaticKModel(k, dim, largest) 66 | 67 | dummy_input = torch.zeros_like(input) 68 | trt_model = module2trt(model, args=[dummy_input]) 69 | 70 | with torch.inference_mode(): 71 | gt = model(input) 72 | out = trt_model(input) 73 | torch.testing.assert_close(out[0], gt[0]) 74 | torch.testing.assert_close(out[1].to(torch.int64), gt[1]) 75 | 76 | @pytest.mark.parametrize('shape,dim', [ 77 | ((5, 10), 0), 78 | ((5, 10), 1), 79 | ((5, ), 0), 80 | ]) 81 | @pytest.mark.parametrize('k', [6]) 82 | @pytest.mark.parametrize('largest', [True, False]) 83 | def test_dynamic(self, input, k, dim, largest): 84 | model = _TestDynamicModel(k, dim, largest) 85 | 86 | dummy_input = torch.zeros_like(input) 87 | trt_model = module2trt(model, args=[dummy_input]) 88 | 89 | with torch.inference_mode(): 90 | gt = model(input) 91 | out = trt_model(input) 92 | torch.testing.assert_close(out[0], gt[0]) 93 | torch.testing.assert_close(out[1].to(torch.int64), gt[1]) 94 | 
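# The tests above share one conversion pattern: wrap the op in a small
# nn.Module, convert with a zeros_like dummy input, then compare the TRT
# output against eager PyTorch. A hedged sketch of a new test following the
# same pattern (hypothetical; ReLU conversion itself is already handled by
# converters/ReLU.py):
#
# class TestRelu:
#
#     @pytest.fixture
#     def input(self):
#         yield torch.rand(2, 3, 4).cuda()
#
#     def test_relu(self, input):
#         model = nn.ReLU().eval().cuda()
#         trt_model = module2trt(model, args=[torch.zeros_like(input)])
#         with torch.inference_mode():
#             torch.testing.assert_close(trt_model(input), model(input))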
-------------------------------------------------------------------------------- /torch2trt_dynamic/__init__.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | 3 | from .converters import * # noqa: F401,F403 4 | from .torch2trt_dynamic import * # noqa: F401,F403 5 | from .trt_module import TRTModule, TRTModuleMeta # noqa: F401, F403 6 | 7 | 8 | def load_plugins(): 9 | import ctypes 10 | import os 11 | ctypes.CDLL( 12 | os.path.join(os.path.dirname(__file__), 'libtorch2trt_dynamic.so')) 13 | 14 | registry = trt.get_plugin_registry() 15 | torch2trt_creators = [ 16 | c for c in registry.plugin_creator_list 17 | if c.plugin_namespace == 'torch2trt_dynamic' 18 | ] 19 | for c in torch2trt_creators: 20 | registry.register_creator(c, 'torch2trt_dynamic') 21 | 22 | 23 | try: 24 | load_plugins() 25 | PLUGINS_LOADED = True 26 | except OSError: 27 | PLUGINS_LOADED = False 28 | -------------------------------------------------------------------------------- /torch2trt_dynamic/calibration.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Tuple 3 | 4 | import tensorrt as trt 5 | 6 | if trt.__version__ >= '5.1': 7 | DEFAULT_CALIBRATION_ALGORITHM = \ 8 | trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 9 | else: 10 | DEFAULT_CALIBRATION_ALGORITHM = trt.CalibrationAlgoType.ENTROPY_CALIBRATION 11 | 12 | 13 | class TensorBatchDataset(): 14 | 15 | def __init__(self, tensors): 16 | self.tensors = tensors 17 | 18 | def __len__(self): 19 | return len(self.tensors[0]) 20 | 21 | def __getitem__(self, idx): 22 | return [t[idx] for t in self.tensors] 23 | 24 | 25 | class SequenceDataset(): 26 | 27 | def __init__(self, sequences): 28 | self.sequences = sequences 29 | 30 | def __len__(self): 31 | return len(self.sequences) 32 | 33 | def __getitem__(self, idx): 34 | return self.sequences[idx] 35 | 36 | 37 | ShapeType = Tuple[int, ...] 
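# Usage sketch (an assumption, not part of the original file): the calibrator
# below is meant to be attached to a TensorRT builder config for int8
# calibration, e.g.
#
#   config.set_flag(trt.BuilderFlag.INT8)
#   config.int8_calibrator = DatasetCalibrator(
#       dataset, batch_size=1, cache_file='calib.cache')
#
# get_batch() then returns one list of device pointers per call, and an empty
# list once the dataset iterator is exhausted, which signals that calibration
# data is finished.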
38 | 39 | 40 | class DatasetCalibrator(trt.IInt8Calibrator): 41 | 42 | def __init__(self, 43 | dataset, 44 | batch_size=1, 45 | cache_file: str = None, 46 | algorithm=DEFAULT_CALIBRATION_ALGORITHM): 47 | super(DatasetCalibrator, self).__init__() 48 | 49 | self.dataset = dataset 50 | self.batch_size = batch_size 51 | self.algorithm = algorithm 52 | self.cache_file = cache_file 53 | 54 | # create buffers that will hold data batches 55 | self.buffers = dict() 56 | self.dataset_iter = iter(dataset) 57 | 58 | def get_batch(self, names): 59 | try: 60 | inputs = next(self.dataset_iter) 61 | for name in names: 62 | tensor = inputs[name] 63 | if name not in self.buffers: 64 | self.buffers[name] = tensor.clone().cuda() 65 | else: 66 | buf = self.buffers[name] 67 | assert buf.shape == tensor.shape 68 | buf.copy_(tensor) 69 | return [int(self.buffers[name].data_ptr()) for name in names] 70 | except StopIteration: 71 | return list() 72 | 73 | def get_algorithm(self): 74 | return self.algorithm 75 | 76 | def get_batch_size(self): 77 | return self.batch_size 78 | 79 | def read_calibration_cache(self): 80 | if self.cache_file is None: 81 | return 82 | if os.path.exists(self.cache_file): 83 | with open(self.cache_file, 'rb') as f: 84 | return f.read() 85 | 86 | def write_calibration_cache(self, cache): 87 | if self.cache_file is None: 88 | return 89 | with open(self.cache_file, 'wb') as f: 90 | f.write(cache) 91 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/AdaptiveAvgPool2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.module_test import add_module_test 3 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter 4 | 5 | from .adaptive_avg_pool2d import convert_adaptive_avg_pool2d 6 | 7 | 8 | @tensorrt_converter('torch.nn.AdaptiveAvgPool2d.forward') 9 | def convert_AdaptiveAvgPool2d(ctx): 10 | ctx.method_args = (ctx.method_args[1], ctx.method_args[0].output_size) 11 | convert_adaptive_avg_pool2d(ctx) 12 | 13 | 14 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 15 | def test_AdaptiveAvgPool2d_1x1(): 16 | return torch.nn.AdaptiveAvgPool2d((1, 1)) 17 | 18 | 19 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 20 | def test_AdaptiveAvgPool2d_2x2(): 21 | return torch.nn.AdaptiveAvgPool2d((2, 2)) 22 | 23 | 24 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 25 | def test_AdaptiveAvgPool2d_3x3(): 26 | return torch.nn.AdaptiveAvgPool2d((3, 3)) 27 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/AdaptiveMaxPool2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..module_test import add_module_test 4 | from ..torch2trt_dynamic import tensorrt_converter 5 | from .adaptive_max_pool2d import convert_adaptive_max_pool2d 6 | 7 | 8 | @tensorrt_converter('torch.nn.AdaptiveMaxPool2d.forward') 9 | def convert_AdaptiveMaxPool2d(ctx): 10 | ctx.method_args = (ctx.method_args[1], ctx.method_args[0].output_size) 11 | convert_adaptive_max_pool2d(ctx) 12 | 13 | 14 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 15 | def test_AdaptiveMaxPool2d_1x1(): 16 | return torch.nn.AdaptiveMaxPool2d((1, 1)) 17 | 18 | 19 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 20 | def test_AdaptiveMaxPool2d_2x2(): 21 | return 
torch.nn.AdaptiveMaxPool2d((2, 2)) 22 | 23 | 24 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 25 | def test_AdaptiveMaxPool2d_3x3(): 26 | return torch.nn.AdaptiveMaxPool2d((3, 3)) 27 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/BatchNorm1d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorrt as trt 3 | import torch 4 | from torch2trt_dynamic.module_test import add_module_test 5 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 6 | 7 | 8 | @tensorrt_converter('torch.nn.BatchNorm1d.forward') 9 | def convert_BatchNorm1d(ctx): 10 | 11 | module = ctx.method_args[0] 12 | input = ctx.method_args[1] 13 | input_trt = trt_(ctx.network, input) 14 | output = ctx.method_return 15 | 16 | scale = module.weight.detach().cpu().numpy() / np.sqrt( 17 | module.running_var.detach().cpu().numpy() + module.eps) 18 | bias = module.bias.detach().cpu().numpy( 19 | ) - module.running_mean.detach().cpu().numpy() * scale 20 | power = np.ones_like(scale) 21 | 22 | # reshape to 2D 23 | input_shape_trt = ctx.network.add_shape(input_trt).get_output(0) 24 | one_trt = trt_(ctx.network, 25 | torch.tensor([1], dtype=torch.int32).to(input.device)) 26 | if len(input.shape) == 2: 27 | new_input_shape_trt = ctx.network.add_concatenation( 28 | [input_shape_trt, one_trt, one_trt]).get_output(0) 29 | else: 30 | new_input_shape_trt = ctx.network.add_concatenation( 31 | [input_shape_trt, one_trt]).get_output(0) 32 | layer = ctx.network.add_shuffle(input_trt) 33 | layer.set_input(1, new_input_shape_trt) 34 | 35 | layer = ctx.network.add_scale( 36 | layer.get_output(0), trt.ScaleMode.CHANNEL, bias, scale, power) 37 | 38 | # reshape back to 1D 39 | conv_out_trt = layer.get_output(0) 40 | layer = ctx.network.add_shuffle(conv_out_trt) 41 | layer.set_input(1, input_shape_trt) 42 | 43 | output._trt = layer.get_output(0) 44 | 45 | 46 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10)]) 47 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 3)]) 48 | def test_BatchNorm1d_basic(): 49 | return torch.nn.BatchNorm1d(10) 50 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/BatchNorm2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorrt as trt 3 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 4 | 5 | 6 | @tensorrt_converter('torch.nn.BatchNorm2d.forward') 7 | def convert_BatchNorm2d(ctx): 8 | module = ctx.method_args[0] 9 | input = ctx.method_args[1] 10 | input_trt = trt_(ctx.network, input) 11 | output = ctx.method_return 12 | 13 | scale = module.weight.detach().cpu().numpy() / np.sqrt( 14 | module.running_var.detach().cpu().numpy() + module.eps) 15 | bias = module.bias.detach().cpu().numpy( 16 | ) - module.running_mean.detach().cpu().numpy() * scale 17 | power = np.ones_like(scale) 18 | 19 | layer = ctx.network.add_scale(input_trt, trt.ScaleMode.CHANNEL, bias, 20 | scale, power) 21 | 22 | output._trt = layer.get_output(0) 23 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/Conv1d.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from 
torch2trt_dynamic.torch2trt_dynamic import (tensorrt_converter, 5 | torch_dtype_to_trt, trt_) 6 | 7 | 8 | @tensorrt_converter('torch.nn.Conv1d.forward') 9 | def convert_Conv1d(ctx): 10 | 11 | module = ctx.method_args[0] 12 | input = ctx.method_args[1] 13 | input_trt = trt_(ctx.network, input) 14 | output = ctx.method_return 15 | 16 | kernel_size = (module.kernel_size[0], 1) 17 | stride = (module.stride[0], 1) 18 | padding = (module.padding[0], 0) 19 | dilation = (module.dilation[0], 1) 20 | 21 | kernel = module.weight.detach().cpu().numpy()[..., None] 22 | 23 | bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype)) 24 | if module.bias is not None: 25 | bias = module.bias.detach().cpu().numpy() 26 | 27 | # reshape to 2D 28 | input_shape_trt = ctx.network.add_shape(input_trt).get_output(0) 29 | one_trt = trt_(ctx.network, 30 | torch.tensor([1], dtype=torch.int32).to(input.device)) 31 | new_input_shape_trt = ctx.network.add_concatenation( 32 | [input_shape_trt, one_trt]).get_output(0) 33 | layer = ctx.network.add_shuffle(input_trt) 34 | layer.set_input(1, new_input_shape_trt) 35 | 36 | layer = ctx.network.add_convolution( 37 | input=layer.get_output(0), 38 | num_output_maps=module.out_channels, 39 | kernel_shape=kernel_size, 40 | kernel=kernel, 41 | bias=bias) 42 | layer.stride = stride 43 | layer.padding = padding 44 | layer.dilation = dilation 45 | 46 | if module.groups is not None: 47 | layer.num_groups = module.groups 48 | 49 | # reshape back to 1D 50 | conv_out_trt = layer.get_output(0) 51 | out_shape_trt = ctx.network.add_shape(conv_out_trt).get_output(0) 52 | new_out_shape_trt = ctx.network.add_slice(out_shape_trt, [0], [3], 53 | [1]).get_output(0) 54 | layer = ctx.network.add_shuffle(conv_out_trt) 55 | layer.set_input(1, new_out_shape_trt) 56 | 57 | output._trt = layer.get_output(0) 58 | 59 | 60 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224)]) 61 | def test_Conv1d_basic(): 62 | return torch.nn.Conv1d(10, 5, kernel_size=1, stride=1, padding=0) 63 | 64 | 65 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224)]) 66 | def test_Conv1d_stride2(): 67 | return torch.nn.Conv1d(10, 5, kernel_size=1, stride=2, padding=0) 68 | 69 | 70 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224)]) 71 | def test_Conv1d_kernel3(): 72 | return torch.nn.Conv1d(10, 5, kernel_size=3, stride=2, padding=1) 73 | 74 | 75 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224)]) 76 | def test_Conv1d_dilation2(): 77 | return torch.nn.Conv1d( 78 | 10, 5, kernel_size=3, stride=1, padding=1, dilation=2) 79 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/Conv2d.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import (tensorrt_converter, 5 | torch_dtype_to_trt, trt_) 6 | 7 | 8 | @tensorrt_converter('torch.nn.Conv2d.forward') 9 | def convert_Conv2d(ctx): 10 | module = ctx.method_args[0] 11 | input = ctx.method_args[1] 12 | input_trt = trt_(ctx.network, input) 13 | output = ctx.method_return 14 | 15 | kernel_size = module.kernel_size 16 | if not isinstance(kernel_size, tuple): 17 | kernel_size = (kernel_size, ) * 2 18 | 19 | stride = module.stride 20 | if not isinstance(stride, tuple): 21 | stride = (stride, ) * 2 22 | 23 | padding = module.padding 24 | if not isinstance(padding, tuple): 25 | padding = 
(padding, ) * 2 26 | 27 | dilation = module.dilation 28 | if not isinstance(dilation, tuple): 29 | dilation = (dilation, ) * 2 30 | 31 | kernel = module.weight.detach().cpu().numpy() 32 | 33 | bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype)) 34 | if module.bias is not None: 35 | bias = module.bias.detach().cpu().numpy() 36 | 37 | layer = ctx.network.add_convolution_nd( 38 | input=input_trt, 39 | num_output_maps=module.out_channels, 40 | kernel_shape=kernel_size, 41 | kernel=kernel, 42 | bias=bias) 43 | layer.stride_nd = stride 44 | layer.padding_nd = padding 45 | layer.dilation_nd = dilation 46 | 47 | if module.groups is not None: 48 | layer.num_groups = module.groups 49 | 50 | output._trt = layer.get_output(0) 51 | 52 | 53 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)]) 54 | def test_Conv2d_basic(): 55 | return torch.nn.Conv2d(10, 5, kernel_size=1, stride=1, padding=0) 56 | 57 | 58 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)]) 59 | def test_Conv2d_stride2(): 60 | return torch.nn.Conv2d(10, 5, kernel_size=1, stride=2, padding=0) 61 | 62 | 63 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)]) 64 | def test_Conv2d_kernel3(): 65 | return torch.nn.Conv2d(10, 5, kernel_size=3, stride=2, padding=1) 66 | 67 | 68 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)]) 69 | def test_Conv2d_dilation2(): 70 | return torch.nn.Conv2d( 71 | 10, 5, kernel_size=3, stride=1, padding=1, dilation=2) 72 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/ConvTranspose1d.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | from torch2trt_dynamic.torch2trt_dynamic import (tensorrt_converter, 3 | torch_dtype_to_trt, trt_) 4 | 5 | 6 | @tensorrt_converter('torch.nn.ConvTranspose1d.forward') 7 | def convert_ConvTranspose1d(ctx): 8 | module = ctx.method_args[0] 9 | input = ctx.method_args[1] 10 | input_trt = trt_(ctx.network, input) 11 | output = ctx.method_return 12 | 13 | kernel_size = module.kernel_size 14 | if not isinstance(kernel_size, tuple): 15 | kernel_size = (kernel_size, 1) 16 | else: 17 | kernel_size = kernel_size + (1, ) 18 | 19 | stride = module.stride 20 | if not isinstance(stride, tuple): 21 | stride = (stride, 1) 22 | else: 23 | stride = stride + (1, ) 24 | 25 | padding = module.padding 26 | if not isinstance(padding, tuple): 27 | padding = (padding, 0) 28 | else: 29 | padding = padding + (0, ) 30 | 31 | kernel = module.weight.detach().cpu().numpy()[..., None] 32 | 33 | bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype)) 34 | if module.bias is not None: 35 | bias = module.bias.detach().cpu().numpy()[..., None] 36 | 37 | # unsqueeze(3) 38 | layer = ctx.network.add_shuffle(input_trt) 39 | layer.reshape_dims = (0, 0, 0, 1) 40 | input_trt = layer.get_output(0) 41 | 42 | # deconv 43 | layer = ctx.network.add_deconvolution( 44 | input=input_trt, 45 | num_output_maps=module.out_channels, 46 | kernel_shape=kernel_size, 47 | kernel=kernel, 48 | bias=bias) 49 | layer.stride = stride 50 | layer.padding = padding 51 | 52 | if module.groups is not None: 53 | layer.num_groups = module.groups 54 | 55 | output_trt = layer.get_output(0) 56 | 57 | # squeeze(3) 58 | layer = ctx.network.add_shuffle(output_trt) 59 | layer.reshape_dims = (0, 0, 0) 60 | output_trt = layer.get_output(0) 61 | 62 | output._trt = output_trt 63 | 
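# Note on the converter above: the deconvolution layer used here is 2D, so the
# 1D case is expressed by reshaping (N, C, L) -> (N, C, L, 1), running the
# deconvolution with a (k, 1) kernel and (s, 1) stride, then squeezing the
# trailing dimension back off. Conv1d.py uses the same unsqueeze/squeeze trick
# for regular convolution.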
-------------------------------------------------------------------------------- /torch2trt_dynamic/converters/ConvTranspose2d.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | from torch2trt_dynamic.torch2trt_dynamic import (tensorrt_converter, 3 | torch_dtype_to_trt, trt_) 4 | 5 | 6 | @tensorrt_converter('torch.nn.ConvTranspose2d.forward') 7 | def convert_ConvTranspose2d(ctx): 8 | module = ctx.method_args[0] 9 | input = ctx.method_args[1] 10 | input_trt = trt_(ctx.network, input) 11 | output = ctx.method_return 12 | 13 | kernel_size = module.kernel_size 14 | if not isinstance(kernel_size, tuple): 15 | kernel_size = (kernel_size, ) * 2 16 | 17 | stride = module.stride 18 | if not isinstance(stride, tuple): 19 | stride = (stride, ) * 2 20 | 21 | padding = module.padding 22 | if not isinstance(padding, tuple): 23 | padding = (padding, ) * 2 24 | 25 | kernel = module.weight.detach().cpu().numpy() 26 | 27 | bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype)) 28 | if module.bias is not None: 29 | bias = module.bias.detach().cpu().numpy() 30 | 31 | layer = ctx.network.add_deconvolution( 32 | input=input_trt, 33 | num_output_maps=module.out_channels, 34 | kernel_shape=kernel_size, 35 | kernel=kernel, 36 | bias=bias) 37 | layer.stride = stride 38 | layer.padding = padding 39 | 40 | if module.groups is not None: 41 | layer.num_groups = module.groups 42 | 43 | output._trt = layer.get_output(0) 44 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/Embedding.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | 3 | from ..torch2trt_dynamic import get_arg, tensorrt_converter, trt_ 4 | 5 | 6 | def _update_weight(weight, max_norm, norm_type): 7 | if max_norm is None: 8 | return weight 9 | num_embeddings = weight.shape[0] 10 | for emb_id in range(num_embeddings): 11 | norm = weight[emb_id].norm(norm_type) 12 | if norm > max_norm: 13 | scale = max_norm / (norm + 1e-7) 14 | weight[emb_id] = weight[emb_id] * scale 15 | return weight 16 | 17 | 18 | @tensorrt_converter('torch.nn.Embedding.forward') 19 | def convert_embedding_forward(ctx): 20 | module = ctx.method_args[0] 21 | inputs = ctx.method_args[1] 22 | weight = module.weight 23 | 24 | ctx.method_args = [inputs, weight] 25 | ctx.method_kwargs = {} 26 | convert_embedding(ctx) 27 | 28 | 29 | @tensorrt_converter('torch.nn.functional.embedding') 30 | def convert_embedding(ctx): 31 | input = get_arg(ctx, 'input', pos=0, default=None) 32 | weight = get_arg(ctx, 'weight', pos=1, default=None) 33 | padding_idx = get_arg(ctx, 'padding_idx', pos=2, default=None) 34 | max_norm = get_arg(ctx, 'max_norm', pos=3, default=None) 35 | norm_type = get_arg(ctx, 'norm_type', pos=4, default=2) 36 | output = ctx.method_return 37 | 38 | weight = _update_weight(weight, max_norm, norm_type) 39 | if padding_idx is not None: 40 | weight[padding_idx, :] = 0 41 | 42 | input_trt = trt_(ctx.network, input) 43 | weight_trt = trt_(ctx.network, weight) 44 | layer = ctx.network.add_gather_v2(weight_trt, input_trt, 45 | trt.GatherMode.DEFAULT) 46 | layer.axis = 0 47 | 48 | output._trt = layer.get_output(0) 49 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/GroupNorm.py: -------------------------------------------------------------------------------- 1 | from ..plugins import create_groupnorm_plugin 2 | from ..torch2trt_dynamic 
import tensorrt_converter, trt_ 3 | 4 | 5 | @tensorrt_converter('torch.nn.GroupNorm.forward') 6 | def convert_GroupNorm(ctx): 7 | module = ctx.method_args[0] 8 | input = ctx.method_args[1] 9 | 10 | input_trt = trt_(ctx.network, input) 11 | weight_trt = trt_(ctx.network, module.weight) 12 | bias_trt = trt_(ctx.network, module.bias) 13 | output = ctx.method_return 14 | 15 | num_groups = module.num_groups 16 | eps = module.eps 17 | 18 | plugin = create_groupnorm_plugin( 19 | 'groupnorm_' + str(id(module)), num_groups=num_groups, eps=eps) 20 | 21 | custom_layer = ctx.network.add_plugin_v2( 22 | inputs=[input_trt, weight_trt, bias_trt], plugin=plugin) 23 | 24 | output._trt = custom_layer.get_output(0) 25 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/Identity.py: -------------------------------------------------------------------------------- 1 | from ..torch2trt_dynamic import tensorrt_converter, trt_ 2 | 3 | 4 | @tensorrt_converter('torch.nn.Dropout.forward') 5 | @tensorrt_converter('torch.nn.Dropout2d.forward') 6 | @tensorrt_converter('torch.nn.Dropout3d.forward') 7 | def convert_Identity(ctx): 8 | input = ctx.method_args[1] 9 | input_trt = trt_(ctx.network, input) 10 | output = ctx.method_return 11 | output._trt = input_trt 12 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/LayerNorm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorrt as trt 3 | 4 | from ..torch2trt_dynamic import (tensor_trt_get_shape_trt, tensorrt_converter, 5 | torch_dim_to_trt_axes, trt_) 6 | 7 | 8 | @tensorrt_converter('torch.nn.LayerNorm.forward') 9 | def convert_LayerNorm(ctx): 10 | module = ctx.method_args[0] 11 | input = ctx.method_args[1] 12 | 13 | normalized_shape = module.normalized_shape 14 | weight = module.weight 15 | bias = module.bias 16 | eps = module.eps 17 | 18 | output = ctx.method_return 19 | 20 | eps_np = np.array([eps], dtype=np.float32) 21 | keep_dims = True 22 | 23 | input_trt = trt_(ctx.network, input) 24 | 25 | if len(input.shape) == 3: 26 | input_shape_trt = tensor_trt_get_shape_trt(ctx.network, input_trt) 27 | new_input_shape_trt = ctx.network.add_concatenation( 28 | [trt_(ctx.network, 1), input_shape_trt]).get_output(0) 29 | layer = ctx.network.add_shuffle(input_trt) 30 | layer.set_input(1, new_input_shape_trt) 31 | input_trt = layer.get_output(0) 32 | 33 | reduce_axes = torch_dim_to_trt_axes( 34 | tuple( 35 | range( 36 | len(input_trt.shape) - len(normalized_shape), 37 | len(input_trt.shape)))) 38 | 39 | mean_trt = ctx.network.add_reduce(input_trt, trt.ReduceOperation.AVG, 40 | reduce_axes, keep_dims).get_output(0) 41 | 42 | # compute variance over spatial (include eps, to reduce layer count) 43 | delta_trt = ctx.network.add_elementwise( 44 | input_trt, mean_trt, trt.ElementWiseOperation.SUB).get_output(0) 45 | 46 | var_trt = ctx.network.add_scale(delta_trt, trt.ScaleMode.UNIFORM, 47 | np.zeros_like(eps_np), 48 | np.ones_like(eps_np), 49 | 2 * np.ones_like(eps_np)).get_output(0) 50 | var_trt = ctx.network.add_reduce(var_trt, trt.ReduceOperation.AVG, 51 | reduce_axes, keep_dims).get_output(0) 52 | 53 | # compute sqrt(var + eps) 54 | var_trt = ctx.network.add_scale(var_trt, trt.ScaleMode.UNIFORM, eps_np, 55 | np.ones_like(eps_np), 56 | 0.5 * np.ones_like(eps_np)).get_output(0) 57 | 58 | # compute final result 59 | result_trt = ctx.network.add_elementwise( 60 | delta_trt, var_trt, 
trt.ElementWiseOperation.DIV).get_output(0) 61 | 62 | if len(input.shape) == 3: 63 | layer = ctx.network.add_shuffle(result_trt) 64 | layer.set_input(1, input_shape_trt) 65 | result_trt = layer.get_output(0) 66 | 67 | if weight is not None: 68 | assert weight.ndim <= input.ndim 69 | while weight.ndim < input.ndim: 70 | weight = weight.unsqueeze(0) 71 | weight_trt = trt_(ctx.network, weight) 72 | layer = ctx.network.add_elementwise(result_trt, weight_trt, 73 | trt.ElementWiseOperation.PROD) 74 | result_trt = layer.get_output(0) 75 | 76 | if bias is not None: 77 | assert bias.ndim <= input.ndim 78 | while bias.ndim < input.ndim: 79 | bias = bias.unsqueeze(0) 80 | bias_trt = trt_(ctx.network, bias) 81 | layer = ctx.network.add_elementwise(result_trt, bias_trt, 82 | trt.ElementWiseOperation.SUM) 83 | result_trt = layer.get_output(0) 84 | 85 | output._trt = result_trt 86 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/Linear.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | 4 | from ..module_test import add_module_test 5 | from ..torch2trt_dynamic import tensorrt_converter, torch_dtype_to_trt, trt_ 6 | 7 | 8 | @tensorrt_converter('torch.nn.Linear.forward') 9 | def convert_Linear(ctx): 10 | module = ctx.method_args[0] 11 | input = ctx.method_args[1] 12 | input_trt = trt_(ctx.network, input) 13 | output = ctx.method_return 14 | 15 | # reshape to ...xNx1x1 16 | layer = ctx.network.add_shuffle(input_trt) 17 | layer.reshape_dims = (0, ) * len(input_trt.shape) + (1, 1) 18 | 19 | # add fully connected 20 | bias = trt.Weights(torch_dtype_to_trt(module.weight.dtype)) 21 | if module.bias is not None: 22 | bias = module.bias.detach().cpu().numpy() 23 | 24 | layer = ctx.network.add_convolution_nd( 25 | input=layer.get_output(0), 26 | num_output_maps=module.out_features, 27 | kernel_shape=(1, 1), 28 | kernel=module.weight.detach().cpu().numpy(), 29 | bias=bias) 30 | 31 | # reshape back to N 32 | layer = ctx.network.add_shuffle(layer.get_output(0)) 33 | # layer.reshape_dims = tuple(output.shape[1:]) 34 | layer.reshape_dims = (0, ) * len(input_trt.shape) 35 | 36 | output._trt = layer.get_output(0) 37 | 38 | 39 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10)]) 40 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 10)]) 41 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 10)]) 42 | def test_Linear_basic(): 43 | return torch.nn.Linear(10, 5) 44 | 45 | 46 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 10)]) 47 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 10)]) 48 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 10)]) 49 | def test_Linear_no_bias(): 50 | return torch.nn.Linear(10, 5, bias=False) 51 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/LogSoftmax.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | 3 | from ..torch2trt_dynamic import tensorrt_converter, trt_ 4 | 5 | 6 | @tensorrt_converter('torch.nn.LogSoftmax.forward') 7 | def convert_LogSoftmax(ctx): 8 | input = ctx.method_args[1] 9 | input_trt = trt_(ctx.network, input) 10 | output = ctx.method_return 11 | layer = ctx.network.add_softmax(input=input_trt) 12 | layer = ctx.network.add_unary( 13 | input=layer.get_output(0), op=trt.UnaryOperation.LOG) 14 | output._trt = 
layer.get_output(0) 15 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/ReLU.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 3 | 4 | 5 | @tensorrt_converter('torch.nn.ReLU.forward') 6 | def convert_ReLU(ctx): 7 | input = ctx.method_args[1] 8 | input_trt = trt_(ctx.network, input) 9 | output = ctx.method_return 10 | layer = ctx.network.add_activation( 11 | input=input_trt, type=trt.ActivationType.RELU) 12 | output._trt = layer.get_output(0) 13 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/ReLU6.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 5 | 6 | 7 | @tensorrt_converter('torch.nn.ReLU6.forward') 8 | def convert_ReLU6(ctx): 9 | input = ctx.method_args[1] 10 | output = ctx.method_return 11 | 12 | input_trt, trt_6 = trt_(ctx.network, input, 6.) 13 | 14 | layer = ctx.network.add_activation( 15 | input=input_trt, type=trt.ActivationType.RELU) 16 | layer = ctx.network.add_elementwise( 17 | layer.get_output(0), trt_6, trt.ElementWiseOperation.MIN) 18 | 19 | output._trt = layer.get_output(0) 20 | 21 | 22 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)]) 23 | def test_relu6_basic(): 24 | return torch.nn.ReLU6() 25 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/activation.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 5 | trt_) 6 | 7 | from .unary import UnaryModule 8 | 9 | # | RELU : Rectified Linear activation (impl in relu.py) 10 | # | SIGMOID : Sigmoid activation (impl in sigmoid.py) 11 | # | TANH : Hyperbolic Tangent activation (impl in tanh.py) 12 | 13 | # | LEAKY_RELU : Leaky Relu activation: 14 | # f(x) = x if x >= 0, f(x) = alpha * x if x < 0 15 | 16 | 17 | @tensorrt_converter('torch.nn.functional.leaky_relu') 18 | @tensorrt_converter('torch.nn.functional.leaky_relu_') 19 | def convert_leaky_relu(ctx): 20 | input = get_arg(ctx, 'input', pos=0, default=None) 21 | negative_slope = get_arg(ctx, 'negative_slope', pos=1, default=0.01) 22 | output = ctx.method_return 23 | 24 | input_trt = trt_(ctx.network, input) 25 | layer = ctx.network.add_activation(input_trt, 26 | trt.ActivationType.LEAKY_RELU) 27 | layer.alpha = negative_slope 28 | 29 | output._trt = layer.get_output(0) 30 | 31 | 32 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 33 | def test_leaky_relu(): 34 | return UnaryModule(lambda x: torch.nn.functional.leaky_relu(x)) 35 | 36 | 37 | # | ELU : Elu activation: 38 | # f(x) = x if x >= 0, f(x) = alpha * (exp(x) - 1) if x < 0 39 | 40 | 41 | @tensorrt_converter('torch.nn.functional.elu') 42 | @tensorrt_converter('torch.nn.functional.elu_') 43 | def convert_elu(ctx): 44 | input = get_arg(ctx, 'input', pos=0, default=None) 45 | alpha = get_arg(ctx, 'alpha', pos=1, default=1.0) 46 | output = ctx.method_return 47 | 48 | input_trt = trt_(ctx.network, input) 49 | layer = 
ctx.network.add_activation(input_trt, trt.ActivationType.ELU) 50 | layer.alpha = alpha 51 | 52 | output._trt = layer.get_output(0) 53 | 54 | 55 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 56 | def test_elu(): 57 | return UnaryModule(lambda x: torch.nn.functional.elu(x)) 58 | 59 | 60 | # | SELU : Selu activation: 61 | # f(x) = beta * x if x > 0, f(x) = beta * (alpha * exp(x) - alpha) if x <= 0 62 | 63 | 64 | @tensorrt_converter('torch.selu') 65 | @tensorrt_converter('torch.selu_') 66 | @tensorrt_converter('torch.nn.functional.selu') 67 | @tensorrt_converter('torch.nn.functional.selu_') 68 | def convert_selu(ctx): 69 | input = get_arg(ctx, 'input', pos=0, default=None) 70 | # alpha = get_arg(ctx, 'alpha', pos=1, default=1.0) 71 | output = ctx.method_return 72 | 73 | input_trt = trt_(ctx.network, input) 74 | layer = ctx.network.add_activation(input_trt, trt.ActivationType.SELU) 75 | layer.alpha = 1.6732632423543772848170429916717 76 | layer.beta = 1.0507009873554804934193349852946 77 | 78 | output._trt = layer.get_output(0) 79 | 80 | 81 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 82 | def test_selu(): 83 | return UnaryModule(lambda x: torch.nn.functional.selu(x)) 84 | 85 | 86 | # | SOFTSIGN : Softsign activation: f(x) = x / (1 + \|x\|) 87 | 88 | 89 | @tensorrt_converter('torch.nn.functional.softsign') 90 | def convert_softsign(ctx): 91 | input = get_arg(ctx, 'input', pos=0, default=None) 92 | output = ctx.method_return 93 | 94 | input_trt = trt_(ctx.network, input) 95 | layer = ctx.network.add_activation(input_trt, trt.ActivationType.SOFTSIGN) 96 | 97 | output._trt = layer.get_output(0) 98 | 99 | 100 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 101 | def test_softsign(): 102 | return UnaryModule(lambda x: torch.nn.functional.softsign(x)) 103 | 104 | 105 | # | SOFTPLUS : Softplus activation: f(x) = alpha * log(exp(beta * x) + 1) 106 | 107 | 108 | @tensorrt_converter('torch.nn.functional.softplus') 109 | def convert_softplus(ctx): 110 | input = get_arg(ctx, 'input', pos=0, default=None) 111 | output = ctx.method_return 112 | 113 | input_trt = trt_(ctx.network, input) 114 | layer = ctx.network.add_activation(input_trt, trt.ActivationType.SOFTPLUS) 115 | 116 | output._trt = layer.get_output(0) 117 | 118 | 119 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 120 | def test_softplus(): 121 | return UnaryModule(lambda x: torch.nn.functional.softplus(x)) 122 | 123 | 124 | # | CLIP : Clip activation: 125 | # f(x) = max(alpha, min(beta, x)) (impl in clamp.py) 126 | 127 | # | HARD_SIGMOID : Hard sigmoid activation: 128 | # f(x) = max(0, min(1, alpha * x + beta)) 129 | # (not sure if there is this in Pytorch?) 130 | # | SCALED_TANH : Scaled Tanh activation: 131 | # f(x) = alpha * tanh(beta * x) (not sure if there is this in Pytorch?) 132 | # | THRESHOLDED_RELU : Thresholded Relu activation: 133 | # f(x) = x if x > alpha, f(x) = 0 if x <= alpha 134 | # (not sure if there is this in Pytorch?) 
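#
# Follow-up on the questions above: recent PyTorch does provide
# torch.nn.functional.hardsigmoid (f(x) = max(0, min(1, x / 6 + 1 / 2))),
# which maps directly onto trt.ActivationType.HARD_SIGMOID with
# alpha = 1/6 and beta = 1/2. A hedged, untested sketch of such a converter
# (not part of the original file):
#
# @tensorrt_converter('torch.nn.functional.hardsigmoid')
# def convert_hardsigmoid(ctx):
#     input = get_arg(ctx, 'input', pos=0, default=None)
#     output = ctx.method_return
#     input_trt = trt_(ctx.network, input)
#     layer = ctx.network.add_activation(input_trt,
#                                        trt.ActivationType.HARD_SIGMOID)
#     layer.alpha = 1.0 / 6.0
#     layer.beta = 0.5
#     output._trt = layer.get_output(0)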
135 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/adaptive_avg_pool1d.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 3 | trt_) 4 | 5 | 6 | @tensorrt_converter('torch.nn.functional.adaptive_avg_pool1d') 7 | def convert_adaptive_avg_pool1d(ctx): 8 | input = ctx.method_args[0] 9 | output_size = get_arg(ctx, 'output_size', pos=1, default=0) 10 | output = ctx.method_return 11 | input_trt = trt_(ctx.network, input) 12 | 13 | if output_size == 1: 14 | # use reduce as max pool2d 15 | shape_length = len(input.shape) 16 | axes = (1 << (shape_length - 1)) 17 | keepdim = True 18 | layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.AVG, 19 | axes, keepdim) 20 | output._trt = layer.get_output(0) 21 | else: 22 | from torch2trt_dynamic.plugins import create_adaptivepool_plugin 23 | output_size = (output_size, 1) 24 | 25 | # input.unsqueeze(-1) 26 | layer = ctx.network.add_shuffle(input_trt) 27 | layer.reshape_dims = (0, 0, 0, 1) 28 | input_trt = layer.get_output(0) 29 | 30 | # adaptive pool 2d 31 | plugin = create_adaptivepool_plugin( 32 | 'adaptive_avg_pool2d_' + str(id(input)), 33 | output_size=output_size, 34 | pooling_type=trt.PoolingType.AVERAGE) 35 | 36 | layer = ctx.network.add_plugin_v2(inputs=[input_trt], plugin=plugin) 37 | 38 | output_trt = layer.get_output(0) 39 | 40 | layer = ctx.network.add_shuffle(output_trt) 41 | layer.reshape_dims = (0, 0, 0) 42 | output_trt = layer.get_output(0) 43 | 44 | output._trt = output_trt 45 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/adaptive_avg_pool2d.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 3 | trt_) 4 | 5 | 6 | @tensorrt_converter('torch.nn.functional.adaptive_avg_pool2d') 7 | def convert_adaptive_avg_pool2d(ctx): 8 | input = ctx.method_args[0] 9 | output_size = get_arg(ctx, 'output_size', pos=1, default=0) 10 | output = ctx.method_return 11 | input_trt = trt_(ctx.network, input) 12 | 13 | if isinstance(output_size, int): 14 | output_size = (output_size, output_size) 15 | 16 | output_size = tuple([-1 if not o else o for o in output_size]) 17 | 18 | if output_size[0] == 1 and output_size[1] == 1: 19 | # use reduce as max pool2d 20 | shape_length = len(input.shape) 21 | axes = (1 << (shape_length - 1)) + (1 << (shape_length - 2)) 22 | keepdim = True 23 | layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.AVG, 24 | axes, keepdim) 25 | output._trt = layer.get_output(0) 26 | else: 27 | from torch2trt_dynamic.plugins import create_adaptivepool_plugin 28 | plugin = create_adaptivepool_plugin( 29 | 'adaptive_avg_pool2d_' + str(id(input)), 30 | output_size=output_size, 31 | pooling_type=trt.PoolingType.AVERAGE) 32 | 33 | layer = ctx.network.add_plugin_v2(inputs=[input_trt], plugin=plugin) 34 | 35 | output._trt = layer.get_output(0) 36 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/adaptive_max_pool1d.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 3 | trt_) 4 | 5 | 6 | 
6 | @tensorrt_converter('torch.nn.functional.adaptive_max_pool1d')
7 | def convert_adaptive_max_pool1d(ctx):
8 |     input = ctx.method_args[0]
9 |     output_size = get_arg(ctx, 'output_size', pos=1, default=0)
10 |     output = ctx.method_return
11 |     input_trt = trt_(ctx.network, input)
12 |
13 |     if output_size == 1:
14 |         # output size 1: use a MAX reduce over the last dim instead of pooling
15 |         shape_length = len(input.shape)
16 |         axes = (1 << (shape_length - 1))
17 |         keepdim = True
18 |         layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.MAX,
19 |                                        axes, keepdim)
20 |         output._trt = layer.get_output(0)
21 |     else:
22 |         from torch2trt_dynamic.plugins import create_adaptivepool_plugin
23 |         output_size = (output_size, 1)
24 |
25 |         # input.unsqueeze(-1)
26 |         layer = ctx.network.add_shuffle(input_trt)
27 |         layer.reshape_dims = (0, 0, 0, 1)
28 |         input_trt = layer.get_output(0)
29 |
30 |         # adaptive pool 2d
31 |         plugin = create_adaptivepool_plugin(
32 |             'adaptive_max_pool1d_' + str(id(input)),
33 |             output_size=output_size,
34 |             pooling_type=trt.PoolingType.MAX)
35 |
36 |         layer = ctx.network.add_plugin_v2(inputs=[input_trt], plugin=plugin)
37 |
38 |         output_trt = layer.get_output(0)
39 |
40 |         layer = ctx.network.add_shuffle(output_trt)
41 |         layer.reshape_dims = (0, 0, 0)
42 |         output_trt = layer.get_output(0)
43 |
44 |         output._trt = output_trt
45 |
--------------------------------------------------------------------------------
/torch2trt_dynamic/converters/adaptive_max_pool2d.py:
--------------------------------------------------------------------------------
1 | import tensorrt as trt
2 | import torch
3 | from torch2trt_dynamic.module_test import add_module_test
4 | from torch2trt_dynamic.plugins import create_adaptivepool_plugin
5 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter,
6 |                                                  trt_)
7 |
8 |
9 | @tensorrt_converter('torch.nn.functional.adaptive_max_pool2d')
10 | def convert_adaptive_max_pool2d(ctx):
11 |     input = ctx.method_args[0]
12 |     output_size = get_arg(ctx, 'output_size', pos=1, default=0)
13 |     output = ctx.method_return
14 |     input_trt = trt_(ctx.network, input)
15 |
16 |     if isinstance(output_size, int):
17 |         output_size = (output_size, output_size)
18 |
19 |     output_size = tuple([-1 if not o else o for o in output_size])
20 |
21 |     if output_size[0] == 1 and output_size[1] == 1:
22 |         # output size 1x1: use a MAX reduce over the last two dims instead of pooling
23 |         shape_length = len(input.shape)
24 |         axes = (1 << (shape_length - 1)) + (1 << (shape_length - 2))
25 |         keepdim = True
26 |         layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.MAX,
27 |                                        axes, keepdim)
28 |         output._trt = layer.get_output(0)
29 |     else:
30 |         plugin = create_adaptivepool_plugin(
31 |             'adaptive_max_pool2d_' + str(id(input)),
32 |             output_size=output_size,
33 |             pooling_type=trt.PoolingType.MAX)
34 |
35 |         layer = ctx.network.add_plugin_v2(inputs=[input_trt], plugin=plugin)
36 |
37 |         output._trt = layer.get_output(0)
38 |
39 |
40 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
41 | def test_adaptive_max_pool2d_1x1():
42 |     return torch.nn.AdaptiveMaxPool2d((1, 1))
43 |
44 |
45 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
46 | def test_adaptive_max_pool2d_2x2():
47 |     return torch.nn.AdaptiveMaxPool2d((2, 2))
48 |
49 |
50 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
51 | def test_adaptive_max_pool2d_3x3():
52 |     return torch.nn.AdaptiveMaxPool2d((3, 3))
53 |
--------------------------------------------------------------------------------
/torch2trt_dynamic/converters/add.py:
-------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 5 | 6 | 7 | @tensorrt_converter('torch.add') 8 | @tensorrt_converter('torch.Tensor.__iadd__') 9 | @tensorrt_converter('torch.Tensor.__add__') 10 | @tensorrt_converter('torch.Tensor.__radd__') 11 | def convert_add(ctx): 12 | input_a = ctx.method_args[0] 13 | input_b = ctx.method_args[1] 14 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 15 | output = ctx.method_return 16 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, 17 | trt.ElementWiseOperation.SUM) 18 | output._trt = layer.get_output(0) 19 | 20 | 21 | class Add(torch.nn.Module): 22 | 23 | def __init__(self): 24 | super(Add, self).__init__() 25 | 26 | def forward(self, x, y): 27 | return x + y 28 | 29 | 30 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 31 | (1, 3, 224, 224)]) 32 | def test_add_basic(): 33 | return Add() 34 | 35 | 36 | class IAdd(torch.nn.Module): 37 | 38 | def __init__(self): 39 | super(IAdd, self).__init__() 40 | 41 | def forward(self, x, y): 42 | x += y 43 | return x 44 | 45 | 46 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 47 | (1, 3, 224, 224)]) 48 | def test_add_iadd(): 49 | return IAdd() 50 | 51 | 52 | class TorchAdd(torch.nn.Module): 53 | 54 | def __init__(self): 55 | super(TorchAdd, self).__init__() 56 | 57 | def forward(self, x, y): 58 | return torch.add(x, y) 59 | 60 | 61 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 62 | (1, 3, 224, 224)]) 63 | def test_add_torchadd(): 64 | return TorchAdd() 65 | 66 | 67 | class RAddInt(torch.nn.Module): 68 | 69 | def __init__(self): 70 | super(RAddInt, self).__init__() 71 | 72 | def forward(self, x): 73 | return 1 + x 74 | 75 | 76 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 77 | def test_add_radd_int(): 78 | return RAddInt() 79 | 80 | 81 | class RAddFloat(torch.nn.Module): 82 | 83 | def __init__(self): 84 | super(RAddFloat, self).__init__() 85 | 86 | def forward(self, x): 87 | return 1.0 + x 88 | 89 | 90 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 91 | def test_add_radd_float(): 92 | return RAddFloat() 93 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/addcmul.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorrt as trt 3 | import torch 4 | from torch2trt_dynamic.module_test import add_module_test 5 | from torch2trt_dynamic.torch2trt_dynamic import (tensor_trt_get_shape_trt, 6 | tensorrt_converter, trt_) 7 | 8 | 9 | @tensorrt_converter('torch.addcmul') 10 | @tensorrt_converter('torch.Tensor.addcmul') 11 | def convert_addcmul(ctx): 12 | tensor0 = ctx.method_args[0] 13 | 14 | value = 1 15 | next_tensor_offset = 0 16 | if len(ctx.method_args) == 4: 17 | value = ctx.method_args[1] 18 | next_tensor_offset = 1 19 | if 'value' in ctx.method_kwargs: 20 | value = ctx.method_kwargs['value'] 21 | 22 | tensor1 = ctx.method_args[1 + next_tensor_offset] 23 | tensor2 = ctx.method_args[2 + next_tensor_offset] 24 | 25 | input0_trt, input1_trt, input2_trt = trt_(ctx.network, tensor0, tensor1, 26 | tensor2) 27 | output = ctx.method_return 28 | 29 | output_mul_trt = ctx.network.add_elementwise( 30 | input1_trt, 
input2_trt, trt.ElementWiseOperation.PROD).get_output(0)
31 |     if value != 1:
32 |         shift = np.zeros([1], np.float32)
33 |         scale = np.array([value], np.float32)
34 |         if len(tensor0.shape) < 4:
35 |             input_shape_trt = tensor_trt_get_shape_trt(ctx.network, input0_trt)
36 |             add_dim = 4 - len(tensor0.shape)
37 |             add_trt = trt_(ctx.network, torch.ones([add_dim],
38 |                                                    dtype=torch.int32))
39 |             new_input_shape_trt = ctx.network.add_concatenation(
40 |                 [add_trt, input_shape_trt]).get_output(0)
41 |             layer = ctx.network.add_shuffle(output_mul_trt)
42 |             layer.set_input(1, new_input_shape_trt)
43 |             output_mul_trt = layer.get_output(0)
44 |         output_mul_trt = ctx.network.add_scale(output_mul_trt,
45 |                                                trt.ScaleMode.UNIFORM, shift,
46 |                                                scale).get_output(0)
47 |
48 |         if len(tensor0.shape) < 4:
49 |             layer = ctx.network.add_shuffle(output_mul_trt)
50 |             layer.set_input(1, input_shape_trt)
51 |             output_mul_trt = layer.get_output(0)
52 |
53 |     output_trt = ctx.network.add_elementwise(
54 |         input0_trt, output_mul_trt, trt.ElementWiseOperation.SUM).get_output(0)
55 |
56 |     output._trt = output_trt
57 |
58 |
59 | class AddcmulTestModule(torch.nn.Module):
60 |
61 |     def __init__(self, value):
62 |         super(AddcmulTestModule, self).__init__()
63 |         self.value = value
64 |
65 |     def forward(self, x, y, z):
66 |         return torch.addcmul(x, self.value, y, z)
67 |
68 |
69 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 4, 5), (1, 4, 5),
70 |                                                        (1, 4, 5)])
71 | def test_addcmul():
72 |     return AddcmulTestModule(2)
73 |
--------------------------------------------------------------------------------
/torch2trt_dynamic/converters/arange.py:
--------------------------------------------------------------------------------
1 | import tensorrt as trt
2 | import torch
3 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter,
4 |                                                  torch_dtype_to_trt, trt_,
5 |                                                  trt_cast)
6 |
7 |
8 | @tensorrt_converter('torch.arange')
9 | def convert_arange(ctx):
10 |     if len(ctx.method_args) == 1:
11 |         start = 0
12 |         end = ctx.method_args[0]
13 |         kwargs = ctx.method_kwargs
14 |         step = 1 if 'step' not in kwargs else kwargs['step']
15 |         dtype = None if 'dtype' not in kwargs else kwargs['dtype']
16 |     else:
17 |         start = get_arg(ctx, 'start', pos=0, default=0)
18 |         end = get_arg(ctx, 'end', pos=1, default=1)
19 |         step = get_arg(ctx, 'step', pos=2, default=1)
20 |         dtype = get_arg(ctx, 'dtype', pos=4, default=None)
21 |
22 |     output = ctx.method_return
23 |     dtype = output.dtype
24 |     if dtype == torch.int64:
25 |         dtype = torch.int32
26 |
27 |     # cast float to int if necessary
28 |     if not hasattr(start, '_trt') and start % 1 == 0:
29 |         start = int(start)
30 |
31 |     if not hasattr(end, '_trt') and end % 1 == 0:
32 |         end = int(end)
33 |
34 |     if not hasattr(step, '_trt') and step % 1 == 0:
35 |         step = int(step)
36 |
37 |     # check const
38 |     is_const = True
39 |     is_const = False if hasattr(start, '_trt') or hasattr(
40 |         end, '_trt') or hasattr(step, '_trt') else is_const
41 |     if not isinstance(start, int) or not isinstance(
42 |             end, int) or not isinstance(step, int):
43 |         is_const = True
44 |         print('warning: arange with non-integer start:{} end:{} step:{}'.format(
45 |             type(start), type(end), type(step)) + ' is not supported dynamically, falling back to constant.')
46 |     if is_const:
47 |         # create const value
48 |         output_trt = trt_(ctx.network, output)
49 |
50 |     else:
51 |         # create fill
52 |
53 |         # compute shape
54 |         start_trt = trt_(ctx.network, start)
55 |         end_trt = trt_(ctx.network, end)
56 |         step_trt = trt_(ctx.network, step)
57 |         one_trt = trt_(ctx.network,
torch.tensor([1], dtype=torch.int32)) 58 | 59 | # length = (end - start + step - 1) // step 60 | length_trt = ctx.network.add_elementwise( 61 | end_trt, start_trt, trt.ElementWiseOperation.SUB).get_output(0) 62 | length_trt = ctx.network.add_elementwise( 63 | length_trt, step_trt, trt.ElementWiseOperation.SUM).get_output(0) 64 | length_trt = ctx.network.add_elementwise( 65 | length_trt, one_trt, trt.ElementWiseOperation.SUB).get_output(0) 66 | length_trt = ctx.network.add_elementwise( 67 | length_trt, step_trt, 68 | trt.ElementWiseOperation.FLOOR_DIV).get_output(0) 69 | 70 | # length to int 71 | length_trt = trt_cast(ctx.network, length_trt, trt.DataType.INT32) 72 | 73 | # start rank 0 74 | layer = ctx.network.add_shuffle(start_trt) 75 | layer.reshape_dims = tuple() 76 | start_trt = layer.get_output(0) 77 | 78 | layer = ctx.network.add_fill(output.shape, trt.FillOperation.LINSPACE) 79 | layer.set_input(0, length_trt) 80 | layer.set_input(1, start_trt) 81 | layer.set_input(2, step_trt) 82 | output_trt = layer.get_output(0) 83 | 84 | # cast data type 85 | data_type = torch_dtype_to_trt(dtype) 86 | 87 | if data_type is not None: 88 | layer = ctx.network.add_identity(output_trt) 89 | layer.set_output_type(0, data_type) 90 | output_trt = layer.get_output(0) 91 | 92 | output._trt = output_trt 93 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/argmax.py: -------------------------------------------------------------------------------- 1 | from torch2trt_dynamic.torch2trt_dynamic import get_arg, tensorrt_converter 2 | 3 | from .flatten import convert_flatten 4 | from .squeeze import convert_squeeze 5 | from .topk import convert_topk 6 | 7 | 8 | @tensorrt_converter('torch.Tensor.argmax') 9 | @tensorrt_converter('torch.argmax') 10 | def convert_argmax(ctx): 11 | 12 | old_args = ctx.method_args 13 | input = ctx.method_args[0] 14 | dim = get_arg(ctx, 'dim', pos=1, default=None) 15 | keepdim = get_arg(ctx, 'keepdim', pos=2, default=False) 16 | 17 | output = ctx.method_return 18 | 19 | # dim is None 20 | if dim is None: 21 | input_flatten = input.flatten() 22 | ctx.method_args = [input] 23 | ctx.method_return = input_flatten 24 | convert_flatten(ctx) 25 | input = ctx.method_return 26 | dim = 0 27 | 28 | # topk 29 | topk_output = input.topk(1, dim) 30 | topk_input = [input, 1, dim] 31 | ctx.method_args = topk_input 32 | ctx.method_return = topk_output 33 | convert_topk(ctx) 34 | topk_index = ctx.method_return[1] 35 | 36 | output._trt = topk_index._trt 37 | ctx.method_return = output 38 | 39 | # keepdim 40 | if not keepdim and topk_index.shape[dim] == 1 and len( 41 | topk_index.shape) > 1: 42 | ctx.method_args = [topk_index, dim] 43 | ctx.method_return = output 44 | convert_squeeze(ctx) 45 | ctx.method_args = old_args 46 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/argmin.py: -------------------------------------------------------------------------------- 1 | from torch2trt_dynamic.torch2trt_dynamic import get_arg, tensorrt_converter 2 | 3 | from .flatten import convert_flatten 4 | from .squeeze import convert_squeeze 5 | from .topk import convert_topk 6 | 7 | 8 | @tensorrt_converter('torch.Tensor.argmin') 9 | @tensorrt_converter('torch.argmin') 10 | def convert_argmin(ctx): 11 | 12 | old_args = ctx.method_args 13 | input = ctx.method_args[0] 14 | dim = get_arg(ctx, 'dim', pos=1, default=None) 15 | keepdim = get_arg(ctx, 'keepdim', pos=2, default=False) 16 | 17 | output = 
ctx.method_return
18 |
19 |     # dim is None
20 |     if dim is None:
21 |         input_flatten = input.flatten()
22 |         ctx.method_args = [input]
23 |         ctx.method_return = input_flatten
24 |         convert_flatten(ctx)
25 |         input = ctx.method_return
26 |         dim = 0
27 |
28 |     # topk
29 |     topk_output = input.topk(1, dim, largest=False)
30 |     topk_input = [input, 1, dim, False]
31 |     ctx.method_args = topk_input
32 |     ctx.method_return = topk_output
33 |     convert_topk(ctx)
34 |     topk_index = ctx.method_return[1]
35 |
36 |     output._trt = topk_index._trt
37 |     ctx.method_return = output
38 |
39 |     # keepdim
40 |     if not keepdim and topk_index.shape[dim] == 1 and len(
41 |             topk_index.shape) > 1:
42 |         ctx.method_args = [topk_index, dim]
43 |         ctx.method_return = output
44 |         convert_squeeze(ctx)
45 |     ctx.method_args = old_args
46 |
--------------------------------------------------------------------------------
/torch2trt_dynamic/converters/avg_pool2d.py:
--------------------------------------------------------------------------------
1 | import tensorrt as trt
2 | import torch
3 | from torch2trt_dynamic.module_test import add_module_test
4 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter,
5 |                                                  trt_)
6 |
7 |
8 | @tensorrt_converter('torch.nn.functional.avg_pool2d')
9 | def convert_avg_pool2d(ctx):
10 |     # parse args
11 |     input = get_arg(ctx, 'input', pos=0, default=None)
12 |     kernel_size = get_arg(ctx, 'kernel_size', pos=1, default=None)
13 |     stride = get_arg(ctx, 'stride', pos=2, default=kernel_size)
14 |     padding = get_arg(ctx, 'padding', pos=3, default=0)
15 |     ceil_mode = get_arg(ctx, 'ceil_mode', pos=4, default=False)
16 |     count_include_pad = get_arg(ctx, 'count_include_pad', pos=5, default=True)
17 |
18 |     # get input trt tensor (or create constant if it doesn't exist)
19 |     input_trt = trt_(ctx.network, input)
20 |
21 |     output = ctx.method_return
22 |
23 |     # get kernel size
24 |     if not isinstance(kernel_size, tuple):
25 |         kernel_size = (kernel_size, ) * 2
26 |
27 |     # get stride (defaults to kernel_size, matching torch semantics)
28 |     if not isinstance(stride, tuple):
29 |         stride = (stride, ) * 2
30 |
31 |     # get padding
32 |     if not isinstance(padding, tuple):
33 |         padding = (padding, ) * 2
34 |
35 |     layer = ctx.network.add_pooling(
36 |         input=input_trt, type=trt.PoolingType.AVERAGE, window_size=kernel_size)
37 |
38 |     layer.stride = stride
39 |     layer.padding = padding
40 |     layer.average_count_excludes_padding = not count_include_pad
41 |
42 |     if ceil_mode:
43 |         layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP
44 |
45 |     output._trt = layer.get_output(0)
46 |
47 |
48 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 6)])
49 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 5, 7)])
50 | def test_avg_pool2d_without_ceil_mode():
51 |     return torch.nn.AvgPool2d(
52 |         kernel_size=3, stride=2, padding=1, ceil_mode=False)
53 |
54 |
55 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 6)])
56 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 5, 7)])
57 | def test_avg_pool2d_with_ceil_mode():
58 |     return torch.nn.AvgPool2d(
59 |         kernel_size=3,
60 |         stride=2,
61 |         padding=1,
62 |         ceil_mode=True,
63 |         count_include_pad=False
64 |     )  # TRT does not support ceil_mode=True together with count_include_pad=True
65 |
--------------------------------------------------------------------------------
/torch2trt_dynamic/converters/bmm.py:
--------------------------------------------------------------------------------
1 | from ..plugins import create_torchbmm_plugin
2 | from ..torch2trt_dynamic import tensorrt_converter, trt_
3 |
4
| 5 | @tensorrt_converter('torch.Tensor.bmm') 6 | @tensorrt_converter('torch.bmm') 7 | def convert_bmm(ctx): 8 | mat0 = ctx.method_args[0] 9 | mat1 = ctx.method_args[1] 10 | output = ctx.method_return 11 | 12 | mat0_trt = trt_(ctx.network, mat0) 13 | mat1_trt = trt_(ctx.network, mat1) 14 | 15 | plugin = create_torchbmm_plugin('torch_bmm_' + str(id(mat0))) 16 | 17 | layer = ctx.network.add_plugin_v2( 18 | inputs=[mat0_trt, mat1_trt], plugin=plugin) 19 | 20 | output._trt = layer.get_output(0) 21 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/cast_type.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 3 | 4 | 5 | def convert_type(ctx, data_type): 6 | input = ctx.method_args[0] 7 | output = ctx.method_return 8 | 9 | input_trt = trt_(ctx.network, input) 10 | 11 | layer = ctx.network.add_identity(input_trt) 12 | layer.set_output_type(0, data_type) 13 | output._trt = layer.get_output(0) 14 | output._trt.shape # trick to enable type cast 15 | 16 | 17 | @tensorrt_converter('torch.Tensor.long') 18 | @tensorrt_converter('torch.Tensor.int') 19 | def convert_int(ctx): 20 | convert_type(ctx, trt.DataType.INT32) 21 | 22 | 23 | @tensorrt_converter('torch.Tensor.float') 24 | def convert_float(ctx): 25 | convert_type(ctx, trt.DataType.FLOAT) 26 | 27 | 28 | @tensorrt_converter('torch.Tensor.bool') 29 | def convert_bool(ctx): 30 | convert_type(ctx, trt.DataType.BOOL) 31 | 32 | 33 | @tensorrt_converter('torch.Tensor.type_as') 34 | def convert_type_as(ctx): 35 | input = ctx.method_args[0] 36 | other = ctx.method_args[1] 37 | output = ctx.method_return 38 | 39 | input_trt = trt_(ctx.network, input) 40 | other_trt = trt_(ctx.network, other) 41 | 42 | layer = ctx.network.add_identity(input_trt) 43 | layer.set_output_type(0, other_trt.dtype) 44 | output._trt = layer.get_output(0) 45 | output._trt.shape # trick to enable type cast 46 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/cat.py: -------------------------------------------------------------------------------- 1 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 2 | trt_) 3 | 4 | 5 | @tensorrt_converter('torch.cat') 6 | def convert_cat(ctx): 7 | inputs = ctx.method_args[0] 8 | 9 | dim = get_arg(ctx, 'dim', pos=1, default=0) 10 | if dim < 0: 11 | dim = len(inputs[0].shape) + dim 12 | 13 | output = ctx.method_return 14 | trt_inputs = [trt_(ctx.network, i) for i in inputs] 15 | 16 | layer = ctx.network.add_concatenation(inputs=trt_inputs) 17 | 18 | layer.axis = dim 19 | output._trt = layer.get_output(0) 20 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/conv2d.py: -------------------------------------------------------------------------------- 1 | # copy from 2 | # https://github.com/yuzhiyiliu/torch2trt/blob/origin/torch.nn.functional.conv2d_support/torch2trt/converters/conv2d.py 3 | import torch 4 | from torch2trt_dynamic.torch2trt_dynamic import get_arg, tensorrt_converter 5 | 6 | from .Conv2d import convert_Conv2d 7 | 8 | 9 | @tensorrt_converter('torch.nn.functional.conv2d') 10 | def convert_conv2d(ctx): 11 | weight = get_arg(ctx, 'weight', pos=1, default=None) 12 | bias = get_arg(ctx, 'bias', pos=2, default=None) 13 | in_channels = weight.size()[1] 14 | out_channels = weight.size()[0] 15 | 
kernel_size = tuple(weight.size()[2:4]) 16 | stride = get_arg(ctx, 'stride', pos=3, default=None) 17 | padding = get_arg(ctx, 'padding', pos=4, default=None) 18 | dilation = get_arg(ctx, 'dilation', pos=5, default=None) 19 | groups = get_arg(ctx, 'groups', pos=6, default=None) 20 | need_bias = False if bias is None else True 21 | 22 | module = torch.nn.Conv2d( 23 | in_channels=in_channels, 24 | out_channels=out_channels, 25 | kernel_size=kernel_size, 26 | stride=stride, 27 | padding=padding, 28 | dilation=dilation, 29 | groups=groups, 30 | bias=need_bias) 31 | module.weight = torch.nn.parameter.Parameter(weight) 32 | if bias is not None: 33 | bias = torch.nn.parameter.Parameter(bias) 34 | module.bias = bias 35 | 36 | ctx.method_args = (module, ctx.method_args[0]) 37 | convert_Conv2d(ctx) 38 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/cummax.py: -------------------------------------------------------------------------------- 1 | from torch2trt_dynamic.plugins import create_torchcummaxmin_plugin 2 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 3 | trt_) 4 | 5 | 6 | @tensorrt_converter('torch.cummax') 7 | @tensorrt_converter('torch.Tensor.cummax') 8 | def convert_cummax(ctx): 9 | input = ctx.method_args[0] 10 | dim = get_arg(ctx, 'dim', pos=1, default=0) 11 | cum_type = 0 12 | 13 | if dim < 0: 14 | dim = len(input.shape) + dim 15 | 16 | input_trt = trt_(ctx.network, input) 17 | output = ctx.method_return 18 | 19 | plugin = create_torchcummaxmin_plugin( 20 | 'cummax_' + str(id(input)), dim=dim, cum_type=cum_type) 21 | 22 | custom_layer = ctx.network.add_plugin_v2(inputs=[input_trt], plugin=plugin) 23 | 24 | output[0]._trt = custom_layer.get_output(0) 25 | output[1]._trt = custom_layer.get_output(1) 26 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/cummin.py: -------------------------------------------------------------------------------- 1 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 2 | trt_) 3 | 4 | from ..plugins import create_torchcummaxmin_plugin 5 | 6 | 7 | @tensorrt_converter('torch.cummin') 8 | @tensorrt_converter('torch.Tensor.cummin') 9 | def convert_cummin(ctx): 10 | input = ctx.method_args[0] 11 | dim = get_arg(ctx, 'dim', pos=1, default=0) 12 | cum_type = 1 13 | 14 | if dim < 0: 15 | dim = len(input.shape) + dim 16 | input_trt = trt_(ctx.network, input) 17 | output = ctx.method_return 18 | 19 | plugin = create_torchcummaxmin_plugin( 20 | 'cummin_' + str(id(input)), dim=dim, cum_type=cum_type) 21 | 22 | custom_layer = ctx.network.add_plugin_v2(inputs=[input_trt], plugin=plugin) 23 | 24 | output[0]._trt = custom_layer.get_output(0) 25 | output[1]._trt = custom_layer.get_output(1) 26 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/cumprod.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 3 | torch_dtype_to_trt, trt_) 4 | 5 | from ..plugins import create_torchcum_plugin 6 | from .cast_type import convert_type 7 | 8 | 9 | @tensorrt_converter('torch.cumprod') 10 | @tensorrt_converter('torch.Tensor.cumprod') 11 | def convert_cumprod(ctx): 12 | old_args = ctx.method_args 13 | old_kwargs = ctx.method_kwargs 14 | input = ctx.method_args[0] 15 | dim = get_arg(ctx, 'dim', pos=1, default=0) 16 | 
cum_type = 1 17 | 18 | if dim < 0: 19 | dim = len(input.shape) + dim 20 | output = ctx.method_return 21 | 22 | if input.dtype == torch.bool or input.dtype == bool: 23 | cast_input = input.type_as(output) 24 | ctx.method_args = [input] 25 | ctx.method_kwargs = {} 26 | ctx.method_return = cast_input 27 | convert_type(ctx, torch_dtype_to_trt(output.dtype)) 28 | input_trt = trt_(ctx.network, cast_input) 29 | else: 30 | input_trt = trt_(ctx.network, input) 31 | 32 | plugin = create_torchcum_plugin( 33 | 'cumprod_' + str(id(input)), dim=dim, cum_type=cum_type) 34 | 35 | custom_layer = ctx.network.add_plugin_v2(inputs=[input_trt], plugin=plugin) 36 | 37 | output_trt = custom_layer.get_output(0) 38 | 39 | if input.dtype != output.dtype: 40 | tmp_output = output.clone() 41 | tmp_output._trt = output_trt 42 | ctx.method_args = [tmp_output] 43 | ctx.method_kwargs = {} 44 | ctx.method_return = output 45 | convert_type(ctx, torch_dtype_to_trt(output.dtype)) 46 | output_trt = ctx.method_return._trt 47 | 48 | output._trt = output_trt 49 | 50 | ctx.method_args = old_args 51 | ctx.method_kwargs = old_kwargs 52 | ctx.method_return = output 53 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/cumsum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..plugins import create_torchcum_plugin 4 | from ..torch2trt_dynamic import (get_arg, tensorrt_converter, 5 | torch_dtype_to_trt, trt_) 6 | from .cast_type import convert_type 7 | 8 | 9 | @tensorrt_converter('torch.cumsum') 10 | @tensorrt_converter('torch.Tensor.cumsum') 11 | def convert_cumsum(ctx): 12 | old_args = ctx.method_args 13 | old_kwargs = ctx.method_kwargs 14 | input = ctx.method_args[0] 15 | dim = get_arg(ctx, 'dim', pos=1, default=0) 16 | cum_type = 0 17 | 18 | if dim < 0: 19 | dim = len(input.shape) + dim 20 | output = ctx.method_return 21 | 22 | if input.dtype == torch.bool or input.dtype == bool: 23 | cast_input = input.type_as(output) 24 | ctx.method_args = [input] 25 | ctx.method_kwargs = {} 26 | ctx.method_return = cast_input 27 | convert_type(ctx, torch_dtype_to_trt(output.dtype)) 28 | input_trt = trt_(ctx.network, cast_input) 29 | else: 30 | input_trt = trt_(ctx.network, input) 31 | 32 | plugin = create_torchcum_plugin( 33 | 'cumsum_' + str(id(input)), dim=dim, cum_type=cum_type) 34 | 35 | custom_layer = ctx.network.add_plugin_v2(inputs=[input_trt], plugin=plugin) 36 | 37 | output_trt = custom_layer.get_output(0) 38 | 39 | if input.dtype != output.dtype: 40 | tmp_output = output.clone() 41 | tmp_output._trt = output_trt 42 | ctx.method_args = [tmp_output] 43 | ctx.method_kwargs = {} 44 | ctx.method_return = output 45 | convert_type(ctx, torch_dtype_to_trt(output.dtype)) 46 | output_trt = ctx.method_return._trt 47 | 48 | output._trt = output_trt 49 | 50 | ctx.method_args = old_args 51 | ctx.method_kwargs = old_kwargs 52 | ctx.method_return = output 53 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/deform_conv2d.py: -------------------------------------------------------------------------------- 1 | import torchvision.ops # noqa: F401 2 | 3 | from ..plugins import create_dcn_plugin 4 | from ..torch2trt_dynamic import get_arg, tensorrt_converter, trt_ 5 | 6 | 7 | @tensorrt_converter('torchvision.ops.deform_conv.deform_conv2d') 8 | def convert_deform_conv2d(ctx): 9 | 10 | input = get_arg(ctx, 'input', pos=0, default=None) 11 | offset = get_arg(ctx, 
'offset', pos=1, default=None) 12 | weight = get_arg(ctx, 'weight', pos=2, default=None) 13 | bias = get_arg(ctx, 'bias', pos=3, default=None) 14 | stride = get_arg(ctx, 'stride', pos=4, default=1) 15 | padding = get_arg(ctx, 'padding', pos=5, default=0) 16 | dilation = get_arg(ctx, 'dilation', pos=6, default=1) 17 | groups = 1 18 | 19 | output = ctx.method_return 20 | 21 | input_trt = trt_(ctx.network, input) 22 | offset_trt = trt_(ctx.network, offset) 23 | 24 | kernel_size = weight.shape[2] 25 | if not isinstance(kernel_size, tuple): 26 | kernel_size = (kernel_size, ) * 2 27 | 28 | if not isinstance(stride, tuple): 29 | stride = (stride, ) * 2 30 | 31 | if not isinstance(padding, tuple): 32 | padding = (padding, ) * 2 33 | 34 | if not isinstance(dilation, tuple): 35 | dilation = (dilation, ) * 2 36 | 37 | deform_groups = int(offset.shape[1] // 38 | (2 * kernel_size[0] * kernel_size[1])) 39 | 40 | kernel = weight.detach().cpu().numpy() 41 | out_channels = output.shape[1] 42 | 43 | bias = bias.detach().cpu().numpy() 44 | 45 | plugin = create_dcn_plugin( 46 | 'dcn_' + str(id(input)), 47 | out_channels=out_channels, 48 | kernel_size=kernel_size, 49 | W=kernel, 50 | B=bias, 51 | padding=padding, 52 | stride=stride, 53 | dilation=dilation, 54 | deformable_group=deform_groups, 55 | group=groups) 56 | 57 | custom_layer = ctx.network.add_plugin_v2( 58 | inputs=[input_trt, offset_trt], plugin=plugin) 59 | 60 | output._trt = custom_layer.get_output(0) 61 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/div.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | 4 | from ..module_test import add_module_test 5 | from ..torch2trt_dynamic import tensorrt_converter, trt_ 6 | 7 | 8 | @tensorrt_converter('torch.div') 9 | @tensorrt_converter('torch.Tensor.div') 10 | @tensorrt_converter('torch.Tensor.__div__') # py2 11 | @tensorrt_converter('torch.Tensor.__idiv__') # py2 12 | @tensorrt_converter('torch.Tensor.__truediv__') # py3 13 | @tensorrt_converter('torch.Tensor.__itruediv__') # py3 14 | def convert_div(ctx): 15 | input_a = ctx.method_args[0] 16 | input_b = ctx.method_args[1] 17 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 18 | output = ctx.method_return 19 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, 20 | trt.ElementWiseOperation.DIV) 21 | output._trt = layer.get_output(0) 22 | 23 | 24 | @tensorrt_converter('torch.Tensor.__rdiv__') # py2 25 | @tensorrt_converter('torch.Tensor.__rtruediv__') # py3 26 | def convert_rdiv(ctx): 27 | input_a = ctx.method_args[1] # inputs switched for rdiv 28 | input_b = ctx.method_args[0] 29 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 30 | output = ctx.method_return 31 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, 32 | trt.ElementWiseOperation.DIV) 33 | output._trt = layer.get_output(0) 34 | 35 | 36 | class Div(torch.nn.Module): 37 | 38 | def __init__(self): 39 | super(Div, self).__init__() 40 | 41 | def forward(self, x, y): 42 | return x / y 43 | 44 | 45 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 46 | (1, 3, 224, 224)]) 47 | def test_div_basic(): 48 | return Div() 49 | 50 | 51 | class IDiv(torch.nn.Module): 52 | 53 | def __init__(self): 54 | super(IDiv, self).__init__() 55 | 56 | def forward(self, x, y): 57 | x /= y 58 | return x 59 | 60 | 61 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 62 | (1, 
3, 224, 224)]) 63 | def test_div_idiv(): 64 | return IDiv() 65 | 66 | 67 | class TorchDiv(torch.nn.Module): 68 | 69 | def __init__(self): 70 | super(TorchDiv, self).__init__() 71 | 72 | def forward(self, x, y): 73 | return torch.div(x, y) 74 | 75 | 76 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 77 | (1, 3, 224, 224)]) 78 | def test_div_torchdiv(): 79 | return TorchDiv() 80 | 81 | 82 | class RDivInt(torch.nn.Module): 83 | 84 | def __init__(self): 85 | super(RDivInt, self).__init__() 86 | 87 | def forward(self, x): 88 | return 100 / x 89 | 90 | 91 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 92 | def test_rdiv_int(): 93 | return RDivInt() 94 | 95 | 96 | class RDivFloat(torch.nn.Module): 97 | 98 | def __init__(self): 99 | super(RDivFloat, self).__init__() 100 | 101 | def forward(self, x): 102 | return 100.0 / x 103 | 104 | 105 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 106 | def test_rdiv_float(): 107 | return RDivFloat() 108 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/dummy_converters.py: -------------------------------------------------------------------------------- 1 | import torch # noqa: F401,F403 2 | 3 | from ..torch2trt_dynamic import tensorrt_converter 4 | 5 | 6 | def is_private(method): 7 | method = method.split('.')[-1] # remove prefix 8 | return method[0] == '_' and method[1] != '_' 9 | 10 | 11 | def is_function_type(method): 12 | fntype = eval(method + '.__class__.__name__') 13 | return fntype == 'function' or fntype == 'builtin_function_or_method' or \ 14 | fntype == 'method_descriptor' 15 | 16 | 17 | def get_methods(namespace): 18 | methods = [] 19 | for method in dir(eval(namespace)): 20 | full_method = namespace + '.' 
+ method 21 | if not is_private(full_method) and is_function_type(full_method): 22 | methods.append(full_method) 23 | return methods 24 | 25 | 26 | TORCH_METHODS = [] 27 | TORCH_METHODS += get_methods('torch') 28 | TORCH_METHODS += get_methods('torch.Tensor') 29 | TORCH_METHODS += get_methods('torch.nn.functional') 30 | 31 | for method in TORCH_METHODS: 32 | 33 | @tensorrt_converter(method, is_real=False) 34 | def warn_method(ctx): 35 | print('Warning: Encountered known unsupported method %s' % 36 | ctx.method_str) 37 | 38 | 39 | @tensorrt_converter('torch.Tensor.dim', is_real=False) 40 | def dont_warn(ctx): 41 | pass 42 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/exview.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | 4 | from ..torch2trt_dynamic import trt_ 5 | 6 | 7 | def next_symbol_exview(exp, start_pos): 8 | if start_pos >= len(exp): 9 | print('next_symbol_exview out of range', exp, start_pos) 10 | next_pos = start_pos 11 | symbol = -1 12 | isnumber = False 13 | return isnumber, symbol, next_pos 14 | 15 | next_pos = start_pos + 1 16 | if exp[start_pos] < '0' or exp[start_pos] > '9': 17 | symbol = exp[start_pos] 18 | isnumber = False 19 | return isnumber, symbol, next_pos 20 | 21 | symbol = int(exp[start_pos]) 22 | while next_pos < len( 23 | exp) and exp[next_pos] >= '0' and exp[next_pos] <= '9': 24 | symbol = symbol * 10 + int(exp[next_pos]) 25 | next_pos += 1 26 | isnumber = True 27 | return isnumber, int(symbol), next_pos 28 | 29 | 30 | def get_value_exview_impl(ctx, exp, inputs, start_pos): 31 | if start_pos >= len(exp): 32 | print('get_value_exview_impl out of range', exp, start_pos) 33 | return None, None 34 | 35 | isnumber, symbol, next_pos = next_symbol_exview(exp, start_pos) 36 | 37 | if isnumber: 38 | return trt_(ctx.network, 39 | torch.tensor([symbol], 40 | dtype=torch.int32).cuda(0)), next_pos 41 | if symbol.isalpha(): 42 | desc_id = ord(symbol.lower()) - ord('a') 43 | isnumber, symbol, next_pos = next_symbol_exview(exp, next_pos) 44 | if not isnumber: 45 | print('wrong expression1:', exp, 'with symbol:', symbol) 46 | return None, next_pos 47 | return ctx.network.add_slice(inputs[desc_id], [symbol], [1], 48 | [1]).get_output(0), next_pos 49 | elif symbol == '(': 50 | result = parse_exview_string_impl(ctx, exp, inputs, start_pos + 1) 51 | if next_pos >= len(exp) or exp[next_pos] != ')': 52 | print('wrong expression2:', exp, 'with symbol:', symbol) 53 | return None, next_pos 54 | return result, next_pos + 1 55 | 56 | else: 57 | print('wrong expression3:', exp, 'with symbol:', symbol) 58 | return None, next_pos 59 | 60 | 61 | def parse_exview_string_impl(ctx, exp, inputs, start_pos): 62 | if start_pos >= len(exp): 63 | print('parse_exview_string_impl out of range', exp, start_pos) 64 | return None, None 65 | 66 | return_value, next_pos = get_value_exview_impl(ctx, exp, inputs, start_pos) 67 | if return_value is None: 68 | return None, next_pos 69 | 70 | for _ in range(next_pos, len(exp)): 71 | isnumber, symbol, next_pos = next_symbol_exview(exp, next_pos) 72 | 73 | chr_sym = str(symbol) 74 | if not isnumber and chr_sym == ')': 75 | next_pos -= 1 76 | break 77 | 78 | result, next_pos = get_value_exview_impl(ctx, exp, inputs, next_pos) 79 | 80 | elementwise_op = None 81 | if chr_sym == '+': 82 | elementwise_op = trt.ElementWiseOperation.SUM 83 | if chr_sym == '-': 84 | elementwise_op = trt.ElementWiseOperation.SUB 85 | if 
chr_sym == '*': 86 | elementwise_op = trt.ElementWiseOperation.PROD 87 | if chr_sym == '/': 88 | elementwise_op = trt.ElementWiseOperation.FLOOR_DIV 89 | 90 | if elementwise_op is not None: 91 | return_value = ctx.network.add_elementwise( 92 | return_value, result, elementwise_op).get_output(0) 93 | if next_pos >= len(exp): 94 | break 95 | 96 | return return_value, next_pos 97 | 98 | 99 | def parse_exview_string(ctx, exp, tensors_shape_trt): 100 | result, _ = parse_exview_string_impl(ctx, exp, tensors_shape_trt, 0) 101 | return result 102 | 103 | 104 | def convert_exview(ctx): 105 | input = ctx.method_args[0] 106 | tensors = ctx.method_args[1] 107 | exps = ctx.method_args[2] 108 | input_trt = trt_(ctx.network, input) 109 | tensors_trt = [trt_(ctx.network, t) for t in tensors] 110 | output = ctx.method_return 111 | 112 | tensors_shape_trt = [ 113 | ctx.network.add_shape(t).get_output(0) for t in tensors_trt 114 | ] 115 | 116 | shape_trt = [ 117 | parse_exview_string(ctx, exp, tensors_shape_trt) for exp in exps 118 | ] 119 | shape_trt = ctx.network.add_concatenation(shape_trt).get_output(0) 120 | layer = ctx.network.add_shuffle(input_trt) 121 | layer.set_input(1, shape_trt) 122 | 123 | output._trt = layer.get_output(0) 124 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/flatten.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | 4 | from ..torch2trt_dynamic import get_arg, tensorrt_converter, trt_ 5 | from .identity import convert_identity 6 | 7 | 8 | @tensorrt_converter('torch.Tensor.flatten') 9 | @tensorrt_converter('torch.flatten') 10 | def convert_flatten(ctx): 11 | 12 | input = ctx.method_args[0] 13 | start_dim = get_arg(ctx, 'start_dim', pos=1, default=0) 14 | end_dim = get_arg(ctx, 'end_dim', pos=2, default=-1) 15 | 16 | if start_dim == -1: 17 | start_dim = len(input.shape) - 1 18 | if end_dim == -1: 19 | end_dim = len(input.shape) - 1 20 | if start_dim == end_dim: 21 | ctx.method_args = [input] 22 | convert_identity(ctx) 23 | return 24 | 25 | input_trt = trt_(ctx.network, input) 26 | 27 | # shuffle of bool is not allowed in cudnn 28 | if input.dtype == torch.bool: 29 | layer = ctx.network.add_identity(input_trt) 30 | layer.set_output_type(0, trt.DataType.INT32) 31 | input_trt = layer.get_output(0) 32 | 33 | shape_trt = ctx.network.add_shape(input_trt).get_output(0) 34 | output = ctx.method_return 35 | 36 | shape1_trt = None 37 | shape2_trt = None 38 | if start_dim != 0: 39 | slice1_start = [0] 40 | slice1_size = [start_dim] 41 | slice1_stride = [1] 42 | shape1_trt = ctx.network.add_slice(shape_trt, slice1_start, 43 | slice1_size, 44 | slice1_stride).get_output(0) 45 | if end_dim != len(input.shape) - 1: 46 | slice2_start = [end_dim + 1] 47 | slice2_size = [len(input.shape) - end_dim - 1] 48 | slice2_stride = [1] 49 | shape2_trt = ctx.network.add_slice(shape_trt, slice2_start, 50 | slice2_size, 51 | slice2_stride).get_output(0) 52 | 53 | slice_mid_start = [start_dim] 54 | slice_mid_size = [end_dim - start_dim + 1] 55 | slice_mid_stride = [1] 56 | shape_mid_trt = ctx.network.add_slice(shape_trt, slice_mid_start, 57 | slice_mid_size, 58 | slice_mid_stride).get_output(0) 59 | 60 | # reduce mid 61 | mid_trt = ctx.network.add_slice(shape_mid_trt, [0], [1], [1]).get_output(0) 62 | for i in range(end_dim - start_dim): 63 | other_trt = ctx.network.add_slice(shape_mid_trt, [i + 1], [1], 64 | [1]).get_output(0) 65 | mid_trt = 
ctx.network.add_elementwise( 66 | mid_trt, other_trt, trt.ElementWiseOperation.PROD).get_output(0) 67 | 68 | shape_mid_trt = mid_trt 69 | 70 | if shape1_trt is None and shape2_trt is None: 71 | new_shape_trt = shape_mid_trt 72 | elif shape1_trt is None: 73 | new_shape_trt = ctx.network.add_concatenation( 74 | [shape_mid_trt, shape2_trt]).get_output(0) 75 | elif shape2_trt is None: 76 | new_shape_trt = ctx.network.add_concatenation( 77 | [shape1_trt, shape_mid_trt]).get_output(0) 78 | else: 79 | new_shape_trt = ctx.network.add_concatenation( 80 | [shape1_trt, shape_mid_trt, shape2_trt]).get_output(0) 81 | 82 | layer = ctx.network.add_shuffle(input_trt) 83 | layer.set_input(1, new_shape_trt) 84 | output_trt = layer.get_output(0) 85 | 86 | if input.dtype == torch.bool: 87 | layer = ctx.network.add_identity(output_trt) 88 | layer.set_output_type(0, trt.DataType.BOOL) 89 | output_trt = layer.get_output(0) 90 | 91 | output._trt = output_trt 92 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/flip.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | 4 | from ..torch2trt_dynamic import (get_arg, slice_shape_trt, 5 | tensor_trt_get_shape_trt, tensorrt_converter, 6 | trt_) 7 | 8 | 9 | @tensorrt_converter('torch.flip') 10 | @tensorrt_converter('torch.Tensor.flip') 11 | def convert_flip(ctx): 12 | input = ctx.method_args[0] 13 | dims = get_arg(ctx, 'dims', pos=1, default=0) 14 | if isinstance(dims, int): 15 | dims = ctx.method_args[1:] 16 | 17 | input_dim = len(input.shape) 18 | dims = [input_dim + dim if dim < 0 else dim for dim in dims] 19 | 20 | input_trt = trt_(ctx.network, input) 21 | output = ctx.method_return 22 | 23 | input_shape_trt = tensor_trt_get_shape_trt(ctx.network, input_trt) 24 | 25 | zero_trt = trt_(ctx.network, input.new_zeros(1, dtype=torch.int32)) 26 | one_trt = trt_(ctx.network, input.new_ones(1, dtype=torch.int32)) 27 | minus_one_trt = trt_(ctx.network, 28 | -1 * input.new_ones(1, dtype=torch.int32)) 29 | starts_trt = [zero_trt for _ in range(input_dim)] 30 | steps_trt = [one_trt for _ in range(input_dim)] 31 | 32 | for d in dims: 33 | tmp_slice_trt = slice_shape_trt(ctx.network, input_shape_trt, d, 1) 34 | starts_trt[d] = ctx.network.add_elementwise( 35 | tmp_slice_trt, one_trt, trt.ElementWiseOperation.SUB).get_output(0) 36 | steps_trt[d] = minus_one_trt 37 | 38 | starts_trt = ctx.network.add_concatenation(starts_trt).get_output(0) 39 | steps_trt = ctx.network.add_concatenation(steps_trt).get_output(0) 40 | 41 | layer = ctx.network.add_slice(input_trt, [0] * input_dim, [1] * input_dim, 42 | [0] * input_dim) 43 | layer.set_input(1, starts_trt) 44 | layer.set_input(2, input_shape_trt) 45 | layer.set_input(3, steps_trt) 46 | 47 | output._trt = layer.get_output(0) 48 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/floor_divide.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | 3 | from ..torch2trt_dynamic import tensorrt_converter, trt_ 4 | 5 | 6 | @tensorrt_converter('torch.floor_divide') 7 | @tensorrt_converter('torch.Tensor.floor_divide') 8 | @tensorrt_converter('torch.Tensor.floor_divide_') 9 | @tensorrt_converter('torch.Tensor.__floordiv__') 10 | @tensorrt_converter('torch.Tensor.__ifloordiv__') 11 | def convert_floor_div(ctx): 12 | input_a = ctx.method_args[0] 13 | input_b = ctx.method_args[1] 14 | 
input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 15 | output = ctx.method_return 16 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, 17 | trt.ElementWiseOperation.FLOOR_DIV) 18 | output._trt = layer.get_output(0) 19 | 20 | 21 | @tensorrt_converter('torch.Tensor.__rfloordiv__') 22 | def convert_rfloor_div(ctx): 23 | input_a = ctx.method_args[1] # inputs switched for rdiv 24 | input_b = ctx.method_args[0] 25 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 26 | output = ctx.method_return 27 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, 28 | trt.ElementWiseOperation.FLOOR_DIV) 29 | output._trt = layer.get_output(0) 30 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/full.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | import tensorrt as trt 4 | import torch 5 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 6 | 7 | 8 | @tensorrt_converter('torch.full') 9 | def convert_full(ctx): 10 | size = ctx.method_args[0] 11 | if not isinstance(size, Iterable): 12 | size = ctx.method_args 13 | fill_value = ctx.method_args[1] 14 | dtype = torch.float32 15 | if 'dtype' in ctx.method_kwargs: 16 | dtype = ctx.method_kwargs['dtype'] 17 | output = ctx.method_return 18 | 19 | if isinstance(size, int): 20 | size = (size, ) 21 | 22 | # check const 23 | is_const = True 24 | for s in size: 25 | if hasattr(s, '_trt'): 26 | is_const = False 27 | break 28 | 29 | if is_const: 30 | # create const value 31 | output_trt = trt_(ctx.network, output) 32 | 33 | else: 34 | # create fill 35 | trt_size = [] 36 | for s in size: 37 | if hasattr(s, '_trt'): 38 | trt_size.append(s._trt) 39 | else: 40 | trt_size.append(trt_(ctx.network, s)) 41 | 42 | trt_size = ctx.network.add_concatenation(trt_size).get_output(0) 43 | 44 | layer = ctx.network.add_fill(size, trt.FillOperation.RANDOM_UNIFORM) 45 | layer.set_input(0, trt_size) 46 | layer.set_input( 47 | 1, trt_(ctx.network, 48 | torch.tensor(fill_value, dtype=dtype).cuda())) 49 | layer.set_input( 50 | 2, trt_(ctx.network, 51 | torch.tensor(fill_value, dtype=dtype).cuda())) 52 | 53 | output_trt = layer.get_output(0) 54 | 55 | data_type = None 56 | if dtype == torch.float32: 57 | data_type = trt.DataType.FLOAT 58 | elif dtype == torch.int32 or dtype == torch.long: 59 | data_type = trt.DataType.INT32 60 | elif dtype == torch.bool: 61 | data_type = trt.DataType.BOOL 62 | else: 63 | print('unsupported convert type:{}'.format(dtype)) 64 | 65 | if data_type is not None: 66 | layer = ctx.network.add_identity(output_trt) 67 | layer.set_output_type(0, data_type) 68 | output_trt = layer.get_output(0) 69 | 70 | output._trt = output_trt 71 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/full_like.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..torch2trt_dynamic import get_arg, tensorrt_converter 4 | from .add import convert_add 5 | from .cast_type import convert_bool, convert_float, convert_int 6 | from .mul import convert_mul 7 | 8 | 9 | @tensorrt_converter('torch.full_like') 10 | def convert_full_like(ctx): 11 | input = ctx.method_args[0] 12 | fill_value = get_arg(ctx, 'fill_value', pos=1, default=0) 13 | dtype = get_arg(ctx, 'dtype', pos=3, default=torch.float32) 14 | output = ctx.method_return 15 | 16 | old_method_args = 
ctx.method_args 17 | old_method_kwargs = ctx.method_kwargs 18 | 19 | # mul zero 20 | input_mul_zero = input * 0 21 | ctx.method_args = [input, 0] 22 | ctx.method_kwargs = {} 23 | ctx.method_return = input_mul_zero 24 | convert_mul(ctx) 25 | 26 | # add fill_value 27 | input_add_one = input_mul_zero + fill_value 28 | ctx.method_args = [input_mul_zero, fill_value] 29 | ctx.method_kwargs = {} 30 | ctx.method_return = input_add_one 31 | convert_add(ctx) 32 | 33 | convert_type_func = None 34 | if dtype == torch.float32: 35 | convert_type_func = convert_float 36 | elif dtype == torch.int32 or dtype == torch.long: 37 | convert_type_func = convert_int 38 | elif dtype == torch.bool: 39 | convert_type_func = convert_bool 40 | else: 41 | print('unsupported convert type:{}'.format(dtype)) 42 | 43 | if convert_type_func is not None: 44 | input_as_type = input_add_one.to(dtype) 45 | ctx.method_args = [input_add_one, dtype] 46 | ctx.method_return = input_as_type 47 | convert_type_func(ctx) 48 | ctx.method_args = [input_as_type, 0] 49 | ctx.method_kwargs = {} 50 | ctx.method_return = output 51 | convert_add(ctx) 52 | 53 | ctx.method_args = old_method_args 54 | ctx.method_kwargs = old_method_kwargs 55 | ctx.method_return = output 56 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/gather.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | 3 | from ..torch2trt_dynamic import get_arg, tensorrt_converter, trt_ 4 | 5 | 6 | @tensorrt_converter('torch.Tensor.gather') 7 | @tensorrt_converter('torch.gather') 8 | def convert_gather(ctx): 9 | inputs = ctx.method_args[0] 10 | dim = get_arg(ctx, 'dim', pos=1, default=0) 11 | index = get_arg(ctx, 'index', pos=2, default=None) 12 | output = ctx.method_return 13 | 14 | inputs_trt = trt_(ctx.network, inputs) 15 | index_trt = trt_(ctx.network, index) 16 | 17 | layer = ctx.network.add_gather_v2(inputs_trt, index_trt, 18 | trt.GatherMode.ELEMENT) 19 | layer.num_elementwise_dims = 0 20 | layer.axis = dim 21 | 22 | output._trt = layer.get_output(0) 23 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/gelu.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import tensorrt as trt 4 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 5 | 6 | 7 | @tensorrt_converter('torch.nn.functional.gelu') 8 | def convert_gelu(ctx): 9 | input = ctx.method_args[0] 10 | output = ctx.method_return 11 | 12 | input_trt, b_trt, v1_trt, v0_5_trt, v3_trt = trt_(ctx.network, input, 13 | 0.044715, 1, 0.5, 3) 14 | 15 | layer = ctx.network.add_elementwise(input_trt, v3_trt, 16 | trt.ElementWiseOperation.POW) 17 | input_p3_trt = layer.get_output(0) 18 | 19 | # b*x**3 20 | layer = ctx.network.add_elementwise(input_p3_trt, b_trt, 21 | trt.ElementWiseOperation.PROD) 22 | bx3_trt = layer.get_output(0) 23 | 24 | # x + b*x**3 25 | layer = ctx.network.add_elementwise(bx3_trt, input_trt, 26 | trt.ElementWiseOperation.SUM) 27 | xabx3_trt = layer.get_output(0) 28 | 29 | # tanh() 30 | layer = ctx.network.add_activation(xabx3_trt, 31 | trt.ActivationType.SCALED_TANH) 32 | layer.alpha = 1 33 | layer.beta = math.sqrt(2 / math.pi) 34 | tanh_trt = layer.get_output(0) 35 | 36 | # 1+tanh() 37 | layer = ctx.network.add_elementwise(tanh_trt, v1_trt, 38 | trt.ElementWiseOperation.SUM) 39 | oneatanh_trt = layer.get_output(0) 40 | 41 | # x*() 42 | layer = 
ctx.network.add_elementwise(input_trt, oneatanh_trt, 43 | trt.ElementWiseOperation.PROD) 44 | xtanh_trt = layer.get_output(0) 45 | 46 | # output 47 | layer = ctx.network.add_elementwise(xtanh_trt, v0_5_trt, 48 | trt.ElementWiseOperation.PROD) 49 | output._trt = layer.get_output(0) 50 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/grid_sample.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | 3 | from ..torch2trt_dynamic import get_arg, tensorrt_converter, trt_ 4 | 5 | _MODE_MAP = dict( 6 | bilinear=trt.ResizeMode.LINEAR, 7 | nearest=trt.ResizeMode.NEAREST, 8 | bicubic=trt.ResizeMode.CUBIC) 9 | 10 | _PAD_MODE_MAP = dict( 11 | zeros=trt.SampleMode.FILL, 12 | border=trt.SampleMode.CLAMP, 13 | reflection=trt.SampleMode.REFLECT) 14 | 15 | 16 | @tensorrt_converter('torch.nn.functional.grid_sample') 17 | def convert_grid_sample(ctx): 18 | input = ctx.method_args[0] 19 | grid = get_arg(ctx, 'grid', pos=1, default=None) 20 | mode = get_arg(ctx, 'mode', pos=2, default='bilinear') 21 | padding_mode = get_arg(ctx, 'padding_mode', pos=3, default='zeros') 22 | align_corners = get_arg(ctx, 'align_corners', pos=4, default=False) 23 | 24 | output = ctx.method_return 25 | 26 | input_trt = trt_(ctx.network, input) 27 | grid_trt = trt_(ctx.network, grid) 28 | 29 | mode = _MODE_MAP[mode] 30 | padding_mode = _PAD_MODE_MAP[padding_mode] 31 | 32 | layer = ctx.network.add_grid_sample(input_trt, grid_trt) 33 | layer.interpolation_mode = mode 34 | layer.sample_mode = padding_mode 35 | layer.align_corners = align_corners 36 | 37 | output._trt = layer.get_output(0) 38 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/identity.py: -------------------------------------------------------------------------------- 1 | from ..torch2trt_dynamic import tensorrt_converter, trt_ 2 | 3 | 4 | @tensorrt_converter('torch.Tensor.cuda') 5 | @tensorrt_converter('torch.Tensor.detach') 6 | @tensorrt_converter('torch.Tensor.contiguous') 7 | @tensorrt_converter('torch.nn.functional.dropout') 8 | @tensorrt_converter('torch.nn.functional.dropout2d') 9 | @tensorrt_converter('torch.nn.functional.dropout3d') 10 | def convert_identity(ctx): 11 | input = ctx.method_args[0] 12 | input_trt = trt_(ctx.network, input) 13 | output = ctx.method_return 14 | output._trt = input_trt 15 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/index_select.py: -------------------------------------------------------------------------------- 1 | from ..torch2trt_dynamic import get_arg, tensorrt_converter, trt_ 2 | 3 | 4 | @tensorrt_converter('torch.index_select') 5 | @tensorrt_converter('torch.Tensor.index_select') 6 | def convert_index_select(ctx): 7 | input = ctx.method_args[0] 8 | dim = get_arg(ctx, 'dim', pos=1, default=None) 9 | index = get_arg(ctx, 'index', pos=2, default=None) 10 | 11 | input_trt = trt_(ctx.network, input) 12 | index_trt = trt_(ctx.network, index) 13 | output = ctx.method_return 14 | 15 | layer = ctx.network.add_gather(input_trt, index_trt, dim) 16 | output._trt = layer.get_output(0) 17 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..torch2trt_dynamic import get_arg, tensorrt_converter 4 | from .Linear 
import convert_Linear
5 |
6 |
7 | @tensorrt_converter('torch.nn.functional.linear')
8 | def convert_linear(ctx):
9 |     old_method_args = ctx.method_args
10 |     old_method_kwargs = ctx.method_kwargs
11 |
12 |     input = ctx.method_args[0]
13 |     weight = get_arg(ctx, 'weight', pos=1, default=None)
14 |     bias = get_arg(ctx, 'bias', pos=2, default=None)
15 |     output = ctx.method_return
16 |
17 |     in_channels = weight.shape[1]
18 |     out_channels = weight.shape[0]
19 |     module = torch.nn.Linear(in_channels, out_channels, bias is not None)
20 |     module.weight = torch.nn.Parameter(weight)
21 |     if bias is not None:
22 |         module.bias = torch.nn.Parameter(bias)
23 |
24 |     ctx.method_args = [module, input]
25 |     ctx.method_kwargs = {}
26 |     convert_Linear(ctx)
27 |
28 |     ctx.method_args = old_method_args
29 |     ctx.method_kwargs = old_method_kwargs
30 |     ctx.method_return = output
31 |
--------------------------------------------------------------------------------
/torch2trt_dynamic/converters/linspace.py:
--------------------------------------------------------------------------------
1 | import tensorrt as trt
2 | import torch
3 |
4 | from ..torch2trt_dynamic import (get_arg, tensorrt_converter,
5 |                                  torch_dtype_to_trt, trt_, trt_cast)
6 |
7 |
8 | @tensorrt_converter('torch.linspace')
9 | def convert_linspace(ctx):
10 |     start = get_arg(ctx, 'start', pos=0, default=0)
11 |     end = get_arg(ctx, 'end', pos=1, default=1)
12 |     steps = get_arg(ctx, 'steps', pos=2, default=2)
13 |     dtype = get_arg(ctx, 'dtype', pos=4, default=None)
14 |
15 |     output = ctx.method_return
16 |     dtype = output.dtype
17 |     if dtype == torch.int64:
18 |         dtype = torch.int32
19 |
20 |     # check const
21 |     is_const = True
22 |     is_const = False if hasattr(start, '_trt') or hasattr(
23 |         end, '_trt') or hasattr(steps, '_trt') else is_const
24 |
25 |     if is_const:
26 |         # create const value
27 |         output_trt = trt_(ctx.network, output)
28 |
29 |     else:
30 |         # create fill
31 |
32 |         # compute shape
33 |         start_trt = trt_(ctx.network, start)
34 |         end_trt = trt_(ctx.network, end)
35 |         steps_trt = trt_(ctx.network, steps)
36 |
37 |         length_trt = steps_trt
38 |
39 |         # to float
40 |         one_trt = trt_(ctx.network, torch.tensor([1], dtype=torch.float32))
41 |         start_trt = trt_cast(ctx.network, start_trt, trt.DataType.FLOAT)
42 |         end_trt = trt_cast(ctx.network, end_trt, trt.DataType.FLOAT)
43 |         steps_trt = trt_cast(ctx.network, steps_trt, trt.DataType.FLOAT)
44 |
45 |         # step = (end - start) / (steps - 1)
46 |         step_trt = ctx.network.add_elementwise(
47 |             end_trt, start_trt, trt.ElementWiseOperation.SUB).get_output(0)
48 |         step_div_trt = ctx.network.add_elementwise(
49 |             steps_trt, one_trt, trt.ElementWiseOperation.SUB).get_output(0)
50 |         step_trt = ctx.network.add_elementwise(
51 |             step_trt, step_div_trt, trt.ElementWiseOperation.DIV).get_output(0)
52 |
53 |         # start rank 0
54 |         layer = ctx.network.add_shuffle(start_trt)
55 |         layer.reshape_dims = tuple()
56 |         start_trt = layer.get_output(0)
57 |
58 |         layer = ctx.network.add_fill(output.shape, trt.FillOperation.LINSPACE)
59 |         layer.set_input(0, length_trt)
60 |         layer.set_input(1, start_trt)
61 |         layer.set_input(2, step_trt)
62 |         output_trt = layer.get_output(0)
63 |
64 |     # cast data type
65 |     data_type = torch_dtype_to_trt(dtype)
66 |
67 |     if data_type is not None:
68 |         layer = ctx.network.add_identity(output_trt)
69 |         layer.set_output_type(0, data_type)
70 |         output_trt = layer.get_output(0)
71 |
72 |     output._trt = output_trt
73 |
--------------------------------------------------------------------------------
/torch2trt_dynamic/converters/logical.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | 3 | from ..torch2trt_dynamic import tensorrt_converter, trt_ 4 | from .unary import __convert_unary 5 | 6 | 7 | def convert_compare(ctx, compare_op): 8 | input_a = ctx.method_args[0] 9 | input_b = ctx.method_args[1] 10 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 11 | output = ctx.method_return 12 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, compare_op) 13 | layer.set_output_type(0, trt.bool) 14 | output._trt = layer.get_output(0) 15 | 16 | 17 | @tensorrt_converter('torch.gt') 18 | @tensorrt_converter('torch.Tensor.gt') 19 | @tensorrt_converter('torch.Tensor.__gt__') 20 | def convert_greater(ctx): 21 | convert_compare(ctx, trt.ElementWiseOperation.GREATER) 22 | 23 | 24 | @tensorrt_converter('torch.lt') 25 | @tensorrt_converter('torch.Tensor.lt') 26 | @tensorrt_converter('torch.Tensor.__lt__') 27 | def convert_less(ctx): 28 | convert_compare(ctx, trt.ElementWiseOperation.LESS) 29 | 30 | 31 | @tensorrt_converter('torch.Tensor.__and__') 32 | def convert_and(ctx): 33 | convert_compare(ctx, trt.ElementWiseOperation.AND) 34 | 35 | 36 | @tensorrt_converter('torch.Tensor.__or__') 37 | def convert_or(ctx): 38 | convert_compare(ctx, trt.ElementWiseOperation.OR) 39 | 40 | 41 | @tensorrt_converter('torch.eq') 42 | @tensorrt_converter('torch.Tensor.eq') 43 | @tensorrt_converter('torch.Tensor.__eq__') 44 | def convert_equal(ctx): 45 | convert_compare(ctx, trt.ElementWiseOperation.EQUAL) 46 | 47 | 48 | @tensorrt_converter('torch.ge') 49 | @tensorrt_converter('torch.Tensor.ge') 50 | @tensorrt_converter('torch.Tensor.__ge__') 51 | def convert_greaterequal(ctx): 52 | input_a = ctx.method_args[0] 53 | input_b = ctx.method_args[1] 54 | output = ctx.method_return 55 | 56 | greater = input_a > input_b 57 | equal = input_a == input_b 58 | 59 | ctx.method_return = greater 60 | convert_greater(ctx) 61 | 62 | ctx.method_return = equal 63 | convert_equal(ctx) 64 | 65 | ctx.method_args = [greater, equal] 66 | ctx.method_return = output 67 | convert_or(ctx) 68 | 69 | 70 | @tensorrt_converter('torch.le') 71 | @tensorrt_converter('torch.Tensor.le') 72 | @tensorrt_converter('torch.Tensor.__le__') 73 | def convert_lessequal(ctx): 74 | input_a = ctx.method_args[0] 75 | input_b = ctx.method_args[1] 76 | output = ctx.method_return 77 | 78 | less = input_a < input_b 79 | equal = input_a == input_b 80 | 81 | ctx.method_return = less 82 | convert_less(ctx) 83 | 84 | ctx.method_return = equal 85 | convert_equal(ctx) 86 | 87 | ctx.method_args = [less, equal] 88 | ctx.method_return = output 89 | convert_or(ctx) 90 | 91 | 92 | @tensorrt_converter('torch.ne') 93 | @tensorrt_converter('torch.Tensor.ne') 94 | @tensorrt_converter('torch.Tensor.__ne__') 95 | def convert_ne(ctx): 96 | input_a = ctx.method_args[0] 97 | input_b = ctx.method_args[1] 98 | output = ctx.method_return 99 | 100 | equal = input_a == input_b 101 | 102 | ctx.method_return = equal 103 | convert_equal(ctx) 104 | 105 | ctx.method_args = [equal] 106 | ctx.method_return = output 107 | __convert_unary(ctx, trt.UnaryOperation.NOT) 108 | 109 | 110 | @tensorrt_converter('torch.logical_xor') 111 | @tensorrt_converter('torch.Tensor.logical_xor') 112 | @tensorrt_converter('torch.Tensor.__xor__') 113 | def convert_xor(ctx): 114 | convert_compare(ctx, trt.ElementWiseOperation.XOR) 115 | -------------------------------------------------------------------------------- 
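TensorRT's elementwise layer offers GREATER, LESS, EQUAL, AND, OR and XOR, but no direct `>=`, `<=` or `!=`, so the converters above compose those operators from the available primitives. The boolean identities they rely on, checked with plain PyTorch tensors (a sketch, not repo code):

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([3, 2, 1])

assert torch.equal(a >= b, (a > b) | (a == b))  # convert_greaterequal
assert torch.equal(a <= b, (a < b) | (a == b))  # convert_lessequal
assert torch.equal(a != b, ~(a == b))           # convert_ne
```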
/torch2trt_dynamic/converters/masked_fill.py: -------------------------------------------------------------------------------- 1 | from ..torch2trt_dynamic import get_arg, tensorrt_converter 2 | 3 | 4 | @tensorrt_converter('torch.masked_fill', is_real=False) 5 | @tensorrt_converter('torch.Tensor.masked_fill', is_real=False) 6 | @tensorrt_converter('torch.Tensor.masked_fill_', is_real=False) 7 | def convert_masked_fill(ctx): 8 | input = ctx.method_args[0] 9 | mask = get_arg(ctx, 'mask', pos=1, default=None) 10 | value = get_arg(ctx, 'value', pos=2, default=0) 11 | output = ctx.method_return 12 | 13 | if value == float('-inf'): 14 | # use the most negative float32 as a finite stand-in for -inf, 15 | # so the fill constant stays representable in the TensorRT graph 16 | value = -3.4028234663852886e+38 17 | 18 | float_mask = mask.type_as(input) 19 | result = input * (1 - float_mask) + value * float_mask 20 | 21 | output._trt = result._trt 22 | ctx.method_return = output 23 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/max_pool1d.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 4 | trt_) 5 | 6 | from .squeeze import convert_squeeze 7 | from .unsqueeze import convert_unsqueeze 8 | 9 | 10 | @tensorrt_converter('torch.nn.functional.max_pool1d') 11 | def convert_max_pool1d(ctx): 12 | # parse args 13 | old_args = ctx.method_args 14 | old_kwargs = ctx.method_kwargs 15 | input = get_arg(ctx, 'input', pos=0, default=None) 16 | kernel_size = get_arg(ctx, 'kernel_size', pos=1, default=None) 17 | stride = get_arg(ctx, 'stride', pos=2, default=None) 18 | padding = get_arg(ctx, 'padding', pos=3, default=0) 19 | # dilation = get_arg(ctx, 'dilation', pos=4, default=1) 20 | ceil_mode = get_arg(ctx, 'ceil_mode', pos=5, default=False) 21 | 22 | kernel_size = (kernel_size, 1) 23 | stride = kernel_size if stride is None else (stride, 1) 24 | padding = (padding, 0) 25 | 26 | output = ctx.method_return 27 | 28 | # unsqueeze -1 29 | unsqueeze_input = input.unsqueeze(-1) 30 | ctx.method_args = [input, -1] 31 | ctx.method_kwargs = {} 32 | ctx.method_return = unsqueeze_input 33 | convert_unsqueeze(ctx) 34 | 35 | # pool2d 36 | input_trt = trt_(ctx.network, unsqueeze_input) 37 | 38 | layer = ctx.network.add_pooling( 39 | input=input_trt, type=trt.PoolingType.MAX, window_size=kernel_size) 40 | 41 | layer.stride = stride 42 | layer.padding = padding 43 | 44 | if ceil_mode: 45 | layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP 46 | 47 | pool2d_output = torch.nn.functional.max_pool2d( 48 | unsqueeze_input, 49 | kernel_size=kernel_size, 50 | stride=stride, 51 | padding=padding, 52 | ceil_mode=ceil_mode) 53 | pool2d_output._trt = layer.get_output(0) 54 | 55 | # squeeze -1 56 | ctx.method_args = [pool2d_output, -1] 57 | ctx.method_kwargs = {} 58 | ctx.method_return = output 59 | convert_squeeze(ctx) 60 | 61 | ctx.method_args = old_args 62 | ctx.method_kwargs = old_kwargs 63 | ctx.method_return = output 64 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/max_pool2d.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 5 | trt_) 6 | 7 | 8 | @tensorrt_converter('torch.nn.functional.max_pool2d') 9 | def convert_max_pool2d(ctx): 10 |
# parse args 11 | input = get_arg(ctx, 'input', pos=0, default=None) 12 | kernel_size = get_arg(ctx, 'kernel_size', pos=1, default=None) 13 | stride = get_arg(ctx, 'stride', pos=2, default=None) 14 | padding = get_arg(ctx, 'padding', pos=3, default=0) 15 | # dilation = get_arg(ctx, 'dilation', pos=4, default=1) 16 | ceil_mode = get_arg(ctx, 'ceil_mode', pos=5, default=False) 17 | 18 | # get input trt tensor (or create constant if it doesn't exist) 19 | input_trt = trt_(ctx.network, input) 20 | 21 | output = ctx.method_return 22 | 23 | # get kernel size 24 | if not isinstance(kernel_size, tuple): 25 | kernel_size = (kernel_size, ) * 2 26 | 27 | # get stride 28 | if not isinstance(stride, tuple): 29 | stride = (stride, ) * 2 30 | 31 | # get padding 32 | if not isinstance(padding, tuple): 33 | padding = (padding, ) * 2 34 | 35 | layer = ctx.network.add_pooling_nd( 36 | input=input_trt, type=trt.PoolingType.MAX, window_size=kernel_size) 37 | 38 | layer.stride_nd = stride 39 | layer.padding_nd = padding 40 | 41 | if ceil_mode: 42 | layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP 43 | 44 | output._trt = layer.get_output(0) 45 | 46 | 47 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 6)]) 48 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 5, 7)]) 49 | def test_MaxPool2d_without_ceil_mode(): 50 | return torch.nn.MaxPool2d( 51 | kernel_size=3, stride=2, padding=1, ceil_mode=False) 52 | 53 | 54 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 6)]) 55 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 5, 7)]) 56 | def test_MaxPool2d_with_ceil_mode(): 57 | return torch.nn.MaxPool2d( 58 | kernel_size=3, stride=2, padding=1, ceil_mode=True) 59 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/mean.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 5 | trt_) 6 | 7 | 8 | @tensorrt_converter('torch.mean') 9 | @tensorrt_converter('torch.Tensor.mean') 10 | def convert_mean(ctx): 11 | input = ctx.method_args[0] 12 | input_trt = trt_(ctx.network, input) 13 | output = ctx.method_return 14 | dim = get_arg(ctx, 'dim', pos=1, default=None) 15 | keep_dims = get_arg(ctx, 'keepdim', pos=2, default=False) 16 | 17 | # get dims from args or kwargs 18 | if dim is None: 19 | dim = tuple(range(len(input.shape))) 20 | 21 | # convert list to tuple 22 | if isinstance(dim, list): 23 | dim = tuple(dim) 24 | 25 | if not isinstance(dim, tuple): 26 | dim = (dim, ) 27 | 28 | dim = tuple([d if d >= 0 else len(input.shape) + d for d in dim]) 29 | 30 | # create axes bitmask for reduce layer 31 | axes = 0 32 | for d in dim: 33 | axes |= 1 << d 34 | 35 | layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.AVG, axes, 36 | keep_dims) 37 | output._trt = layer.get_output(0) 38 | 39 | 40 | class Mean(torch.nn.Module): 41 | 42 | def __init__(self, dim, keepdim): 43 | super(Mean, self).__init__() 44 | self.dim = dim 45 | self.keepdim = keepdim 46 | 47 | def forward(self, x): 48 | return x.mean(self.dim, self.keepdim) 49 | 50 | 51 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 52 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 53 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 54 | def test_mean_channel(): 55 | return 
Mean(1, False) 56 | 57 | 58 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 59 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 60 | def test_mean_tuple(): 61 | return Mean((1, 2), False) 62 | 63 | 64 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 65 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 66 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 67 | def test_mean_keepdim(): 68 | return Mean(1, True) 69 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/meshgrid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.torch2trt_dynamic import (tensor_trt_get_shape_trt, 3 | tensorrt_converter, trt_) 4 | 5 | from .repeat import _convert_repeat_impl 6 | 7 | 8 | @tensorrt_converter('torch.meshgrid') 9 | def convert_meshgrid(ctx): 10 | input_list = ctx.method_args 11 | output = ctx.method_return 12 | 13 | num_inputs = len(input_list) 14 | input_trt_list = [ 15 | trt_(ctx.network, input_tensor) for input_tensor in input_list 16 | ] 17 | input_shape_trt_list = [ 18 | tensor_trt_get_shape_trt(ctx.network, input_trt) 19 | for input_trt in input_trt_list 20 | ] 21 | 22 | output_shape_trt = ctx.network.add_concatenation( 23 | input_shape_trt_list).get_output(0) 24 | 25 | one_trt = trt_(ctx.network, torch.ones(1, dtype=torch.int32)) 26 | for index, input_trt in enumerate(input_trt_list): 27 | shuffle_shape_trt = [one_trt] * index 28 | shuffle_shape_trt += [input_shape_trt_list[index]] 29 | shuffle_shape_trt += [one_trt] * (num_inputs - 1 - index) 30 | shuffle_shape_trt = \ 31 | ctx.network.add_concatenation(shuffle_shape_trt).get_output(0) 32 | layer = ctx.network.add_shuffle(input_trt) 33 | layer.set_input(1, shuffle_shape_trt) 34 | input_trt_list[index] = layer.get_output(0) 35 | 36 | for input_trt, out in zip(input_trt_list, output): 37 | out._trt = _convert_repeat_impl(ctx, input_trt, output_shape_trt) 38 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/mod.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 3 | 4 | 5 | @tensorrt_converter('torch.Tensor.__mod__') 6 | def convert_mod(ctx): 7 | input_a = ctx.method_args[0] 8 | input_b = ctx.method_args[1] 9 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 10 | output = ctx.method_return 11 | 12 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, 13 | trt.ElementWiseOperation.FLOOR_DIV) 14 | floor_div_trt = layer.get_output(0) 15 | 16 | layer = ctx.network.add_elementwise(input_b_trt, floor_div_trt, 17 | trt.ElementWiseOperation.PROD) 18 | prod_trt = layer.get_output(0) 19 | 20 | layer = ctx.network.add_elementwise(input_a_trt, prod_trt, 21 | trt.ElementWiseOperation.SUB) 22 | output._trt = layer.get_output(0) 23 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/mul.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 5 | 6 | 7 | @tensorrt_converter('torch.mul') 8 | @tensorrt_converter('torch.Tensor.__imul__') 9 | 
@tensorrt_converter('torch.Tensor.__mul__') 10 | @tensorrt_converter('torch.Tensor.__rmul__') 11 | def convert_mul(ctx): 12 | input_a = ctx.method_args[0] 13 | input_b = ctx.method_args[1] 14 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 15 | output = ctx.method_return 16 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, 17 | trt.ElementWiseOperation.PROD) 18 | output._trt = layer.get_output(0) 19 | 20 | 21 | class Mul(torch.nn.Module): 22 | 23 | def __init__(self): 24 | super(Mul, self).__init__() 25 | 26 | def forward(self, x, y): 27 | return x * y 28 | 29 | 30 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 31 | (1, 3, 224, 224)]) 32 | def test_mul_basic(): 33 | return Mul() 34 | 35 | 36 | class IMul(torch.nn.Module): 37 | 38 | def __init__(self): 39 | super(IMul, self).__init__() 40 | 41 | def forward(self, x, y): 42 | x *= y 43 | return x 44 | 45 | 46 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 47 | (1, 3, 224, 224)]) 48 | def test_mul_imul(): 49 | return IMul() 50 | 51 | 52 | class TorchMul(torch.nn.Module): 53 | 54 | def __init__(self): 55 | super(TorchMul, self).__init__() 56 | 57 | def forward(self, x, y): 58 | return torch.mul(x, y) 59 | 60 | 61 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 62 | (1, 3, 224, 224)]) 63 | def test_mul_torchmul(): 64 | return TorchMul() 65 | 66 | 67 | class RMulInt(torch.nn.Module): 68 | 69 | def __init__(self): 70 | super(RMulInt, self).__init__() 71 | 72 | def forward(self, x): 73 | return 10 * x 74 | 75 | 76 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 77 | def test_rmul_int(): 78 | return RMulInt() 79 | 80 | 81 | class RMulFloat(torch.nn.Module): 82 | 83 | def __init__(self): 84 | super(RMulFloat, self).__init__() 85 | 86 | def forward(self, x): 87 | return 10.0 * x 88 | 89 | 90 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 91 | def test_rmul_float(): 92 | return RMulFloat() 93 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/narrow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, slice_shape_trt, 3 | tensor_trt_get_shape_trt, 4 | tensorrt_converter, trt_) 5 | 6 | from .size import get_intwarper_trt 7 | 8 | 9 | @tensorrt_converter('torch.Tensor.narrow') 10 | @tensorrt_converter('torch.narrow') 11 | def convert_narrow(ctx): 12 | input = ctx.method_args[0] 13 | if 'dim' in ctx.method_kwargs: 14 | dim = ctx.method_kwargs['dim'] 15 | elif 'dimension' in ctx.method_kwargs: 16 | dim = ctx.method_kwargs['dimension'] 17 | else: 18 | dim = ctx.method_args[1] 19 | input_dim = input.dim() 20 | if dim < 0: 21 | dim = dim + input_dim 22 | 23 | start = get_arg(ctx, 'start', pos=2, default=None) 24 | length = get_arg(ctx, 'length', pos=3, default=None) 25 | 26 | output = ctx.method_return 27 | 28 | input_trt = trt_(ctx.network, input) 29 | 30 | input_shape_trt = tensor_trt_get_shape_trt(ctx.network, input_trt) 31 | start_trt = get_intwarper_trt(start, ctx) 32 | length_trt = get_intwarper_trt(length, ctx) 33 | stride_trt = trt_(ctx.network, torch.ones([input_dim]).int()) 34 | if dim != 0: 35 | start_pre_trt = trt_(ctx.network, 36 | torch.zeros([ 37 | dim, 38 | ]).int()) 39 | start_trt = ctx.network.add_concatenation([start_pre_trt, 40 | start_trt]).get_output(0) 41 | length_pre_trt = slice_shape_trt(ctx.network, 
input_shape_trt, 0, dim) 42 | length_trt = ctx.network.add_concatenation( 43 | [length_pre_trt, length_trt]).get_output(0) 44 | if dim < input_dim - 1: 45 | start_post_trt = trt_(ctx.network, 46 | torch.zeros([input_dim - dim - 1]).int()) 47 | 48 | start_trt = ctx.network.add_concatenation([start_trt, start_post_trt 49 | ]).get_output(0) 50 | length_post_trt = slice_shape_trt(ctx.network, input_shape_trt, 51 | dim + 1) 52 | length_trt = ctx.network.add_concatenation( 53 | [length_trt, length_post_trt]).get_output(0) 54 | 55 | layer = ctx.network.add_slice(input_trt, [0] * input_dim, [1] * input_dim, 56 | [1] * input_dim) 57 | layer.set_input(1, start_trt) 58 | layer.set_input(2, length_trt) 59 | layer.set_input(3, stride_trt) 60 | output._trt = layer.get_output(0) 61 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/new_ones.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 4 | 5 | 6 | @tensorrt_converter('torch.Tensor.new_ones') 7 | def convert_new_ones(ctx): 8 | input = ctx.method_args[0] 9 | size = ctx.method_args[1] 10 | if isinstance(size, int): 11 | size = ctx.method_args[1:] 12 | 13 | dtype = input.dtype 14 | if 'dtype' in ctx.method_kwargs: 15 | dtype = ctx.method_kwargs['dtype'] 16 | 17 | output = ctx.method_return 18 | 19 | if isinstance(size, int): 20 | size = (size, ) 21 | 22 | # check const 23 | is_const = True 24 | for s in size: 25 | if hasattr(s, '_trt'): 26 | is_const = False 27 | break 28 | 29 | if is_const: 30 | # create const value 31 | output_trt = trt_(ctx.network, output) 32 | 33 | else: 34 | # create fill 35 | trt_size = [] 36 | for s in size: 37 | if hasattr(s, '_trt'): 38 | trt_size.append(s._trt) 39 | else: 40 | trt_size.append(trt_(ctx.network, s)) 41 | 42 | trt_size = ctx.network.add_concatenation(trt_size).get_output(0) 43 | 44 | layer = ctx.network.add_fill(size, trt.FillOperation.RANDOM_UNIFORM) 45 | layer.set_input(0, trt_size) 46 | layer.set_input(1, trt_(ctx.network, input.new_tensor(1))) 47 | layer.set_input(2, trt_(ctx.network, input.new_tensor(1))) 48 | 49 | output_trt = layer.get_output(0) 50 | 51 | data_type = None 52 | if dtype == torch.float32: 53 | data_type = trt.DataType.FLOAT 54 | elif dtype == torch.int32 or dtype == torch.long: 55 | data_type = trt.DataType.INT32 56 | elif dtype == torch.bool: 57 | data_type = trt.DataType.BOOL 58 | else: 59 | print('unsupported convert type:{}'.format(dtype)) 60 | 61 | if data_type is not None: 62 | layer = ctx.network.add_identity(output_trt) 63 | layer.set_output_type(0, data_type) 64 | output_trt = layer.get_output(0) 65 | 66 | output._trt = output_trt 67 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/new_zeros.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 4 | 5 | 6 | @tensorrt_converter('torch.Tensor.new_zeros') 7 | def convert_new_zeros(ctx): 8 | input = ctx.method_args[0] 9 | size = ctx.method_args[1] 10 | if isinstance(size, int): 11 | size = ctx.method_args[1:] 12 | 13 | dtype = input.dtype 14 | if 'dtype' in ctx.method_kwargs: 15 | dtype = ctx.method_kwargs['dtype'] 16 | 17 | output = ctx.method_return 18 | 19 | if isinstance(size, int): 20 | size = (size, ) 21 | 
22 | # check const 23 | is_const = True 24 | for s in size: 25 | if hasattr(s, '_trt'): 26 | is_const = False 27 | break 28 | 29 | if is_const: 30 | # create const value 31 | output_trt = trt_(ctx.network, output) 32 | 33 | else: 34 | # create fill 35 | trt_size = [] 36 | for s in size: 37 | if hasattr(s, '_trt'): 38 | trt_size.append(s._trt) 39 | else: 40 | trt_size.append(trt_(ctx.network, s)) 41 | 42 | trt_size = ctx.network.add_concatenation(trt_size).get_output(0) 43 | 44 | layer = ctx.network.add_fill(size, trt.FillOperation.RANDOM_UNIFORM) 45 | layer.set_input(0, trt_size) 46 | layer.set_input(1, trt_(ctx.network, input.new_tensor(0))) 47 | layer.set_input(2, trt_(ctx.network, input.new_tensor(0))) 48 | 49 | output_trt = layer.get_output(0) 50 | 51 | data_type = None 52 | if dtype == torch.float32: 53 | data_type = trt.DataType.FLOAT 54 | elif dtype == torch.int32 or dtype == torch.long: 55 | data_type = trt.DataType.INT32 56 | elif dtype == torch.bool: 57 | data_type = trt.DataType.BOOL 58 | else: 59 | print('unsupported convert type:{}'.format(dtype)) 60 | 61 | if data_type is not None: 62 | layer = ctx.network.add_identity(output_trt) 63 | layer.set_output_type(0, data_type) 64 | output_trt = layer.get_output(0) 65 | 66 | output._trt = output_trt 67 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/nms.py: -------------------------------------------------------------------------------- 1 | import torchvision.ops # noqa: F401 2 | from torch2trt_dynamic.plugins import create_nms_plugin 3 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 4 | trt_) 5 | 6 | 7 | @tensorrt_converter('torchvision.ops.nms') 8 | def convert_nms(ctx): 9 | 10 | boxes = get_arg(ctx, 'boxes', pos=0, default=None) 11 | scores = get_arg(ctx, 'scores', pos=1, default=None) 12 | iou_threshold = get_arg(ctx, 'iou_threshold', pos=2, default=0.7) 13 | 14 | output = ctx.method_return 15 | 16 | boxes_trt = trt_(ctx.network, boxes) 17 | scores_trt = trt_(ctx.network, scores) 18 | 19 | plugin = create_nms_plugin( 20 | 'nms_' + str(id(boxes)), iou_threshold=iou_threshold) 21 | 22 | custom_layer = ctx.network.add_plugin_v2( 23 | inputs=[boxes_trt, scores_trt], plugin=plugin) 24 | 25 | output._trt = custom_layer.get_output(0) 26 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/normalize.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 5 | torch_dim_to_trt_axes, trt_) 6 | 7 | 8 | @tensorrt_converter('torch.nn.functional.normalize') 9 | def convert_normalize(ctx): 10 | # get args 11 | input = get_arg(ctx, 'input', pos=0, default=None) 12 | p = get_arg(ctx, 'p', pos=1, default=2) 13 | dim = get_arg(ctx, 'dim', pos=2, default=1) 14 | eps = get_arg(ctx, 'eps', pos=3, default=1e-12) 15 | if dim < 0: 16 | dim = len(input.shape) + dim 17 | 18 | 19 | # input_trt = input._trt 20 | output = ctx.method_return 21 | 22 | # add broadcastable scalar constants to network 23 | input_trt, eps_trt, p_trt, p_inv_trt = trt_(ctx.network, input, eps, p, 24 | 1.0 / p) 25 | 26 | # compute norm = sum(abs(x)**p, dim=dim)**(1./p) 27 | norm = ctx.network.add_unary(input_trt, 28 | trt.UnaryOperation.ABS).get_output(0) 29 | norm = ctx.network.add_elementwise( 30 | norm, p_trt, 
trt.ElementWiseOperation.POW).get_output(0) 31 | norm = ctx.network.add_reduce( 32 | norm, 33 | trt.ReduceOperation.SUM, 34 | torch_dim_to_trt_axes(dim), 35 | keep_dims=True).get_output(0) 36 | norm = ctx.network.add_elementwise( 37 | norm, p_inv_trt, trt.ElementWiseOperation.POW).get_output(0) 38 | 39 | # clamp norm = max(norm, eps) 40 | norm = ctx.network.add_elementwise( 41 | norm, eps_trt, trt.ElementWiseOperation.MAX).get_output(0) 42 | 43 | # divide input by norm 44 | output._trt = ctx.network.add_elementwise( 45 | input_trt, norm, trt.ElementWiseOperation.DIV).get_output(0) 46 | 47 | 48 | class Normalize(torch.nn.Module): 49 | 50 | def __init__(self, *args, **kwargs): 51 | super(Normalize, self).__init__() 52 | self.args = args 53 | self.kwargs = kwargs 54 | 55 | def forward(self, x): 56 | return torch.nn.functional.normalize(x, *self.args, **self.kwargs) 57 | 58 | 59 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 60 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 61 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 62 | def test_normalize_basic(): 63 | return Normalize() 64 | 65 | 66 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 67 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 68 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 69 | def test_normalize_l1_basic(): 70 | return Normalize(p=1.0) 71 | 72 | 73 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 74 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 75 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 76 | def test_normalize_l1p5_basic(): 77 | return Normalize(p=1.5) 78 | 79 | 80 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 81 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 82 | def test_normalize_l2_height(): 83 | return Normalize(p=2.0, dim=2) 84 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/numel.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | from torch2trt_dynamic.torch2trt_dynamic import (slice_shape_trt, 3 | tensor_trt_get_shape_trt, 4 | tensorrt_converter, trt_) 5 | 6 | from .size import IntWarper 7 | 8 | 9 | @tensorrt_converter('torch.Tensor.numel') 10 | def convert_numel(ctx): 11 | input = ctx.method_args[0] 12 | 13 | input_trt = trt_(ctx.network, input) 14 | shape_trt = tensor_trt_get_shape_trt(ctx.network, input_trt) 15 | num = ctx.method_return 16 | 17 | num_trt = slice_shape_trt(ctx.network, shape_trt, 0, 1) 18 | for i in range(1, len(input.shape)): 19 | other_trt = slice_shape_trt(ctx.network, shape_trt, i, 1) 20 | num_trt = ctx.network.add_elementwise( 21 | num_trt, other_trt, trt.ElementWiseOperation.PROD).get_output(0) 22 | intwarper = IntWarper(num) 23 | intwarper._trt = num_trt 24 | 25 | ctx.method_return = intwarper 26 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/ones.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | import tensorrt as trt 4 | import torch 5 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 6 | 7 | 8 | @tensorrt_converter('torch.ones') 9 | def convert_ones(ctx): 10 | size = ctx.method_args[0] 11 | if not isinstance(size, Iterable): 12 | size = 
ctx.method_args 13 | dtype = torch.float32 14 | if 'dtype' in ctx.method_kwargs: 15 | dtype = ctx.method_kwargs['dtype'] 16 | output = ctx.method_return 17 | 18 | if isinstance(size, int): 19 | size = (size, ) 20 | 21 | # check const 22 | is_const = True 23 | for s in size: 24 | if hasattr(s, '_trt'): 25 | is_const = False 26 | break 27 | 28 | if is_const: 29 | # create const value 30 | output_trt = trt_(ctx.network, output) 31 | 32 | else: 33 | # create fill 34 | trt_size = [] 35 | for s in size: 36 | if hasattr(s, '_trt'): 37 | trt_size.append(s._trt) 38 | else: 39 | trt_size.append(trt_(ctx.network, s)) 40 | 41 | trt_size = ctx.network.add_concatenation(trt_size).get_output(0) 42 | 43 | layer = ctx.network.add_fill(size, trt.FillOperation.RANDOM_UNIFORM) 44 | layer.set_input(0, trt_size) 45 | layer.set_input( 46 | 1, trt_(ctx.network, 47 | torch.tensor(1., dtype=dtype).cuda())) 48 | layer.set_input( 49 | 2, trt_(ctx.network, 50 | torch.tensor(1., dtype=dtype).cuda())) 51 | 52 | output_trt = layer.get_output(0) 53 | 54 | data_type = None 55 | if dtype == torch.float32: 56 | data_type = trt.DataType.FLOAT 57 | elif dtype == torch.int32 or dtype == torch.long: 58 | data_type = trt.DataType.INT32 59 | elif dtype == torch.bool: 60 | data_type = trt.DataType.BOOL 61 | else: 62 | print('unsupported convert type:{}'.format(dtype)) 63 | 64 | if data_type is not None: 65 | layer = ctx.network.add_identity(output_trt) 66 | layer.set_output_type(0, data_type) 67 | output_trt = layer.get_output(0) 68 | 69 | output._trt = output_trt 70 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/ones_like.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.torch2trt_dynamic import get_arg, tensorrt_converter 3 | 4 | from .add import convert_add 5 | from .cast_type import convert_bool, convert_float, convert_int 6 | from .mul import convert_mul 7 | 8 | 9 | @tensorrt_converter('torch.ones_like') 10 | def convert_ones_like(ctx): 11 | input = ctx.method_args[0] 12 | dtype = get_arg(ctx, 'dtype', pos=1, default=torch.float32) 13 | output = ctx.method_return 14 | 15 | old_method_args = ctx.method_args 16 | old_method_kwargs = ctx.method_kwargs 17 | 18 | # mul zero 19 | input_mul_zero = input * 0 20 | ctx.method_args = [input, 0] 21 | ctx.method_kwargs = {} 22 | ctx.method_return = input_mul_zero 23 | convert_mul(ctx) 24 | 25 | # add one 26 | input_add_one = input_mul_zero + 1 27 | ctx.method_args = [input_mul_zero, 1] 28 | ctx.method_kwargs = {} 29 | ctx.method_return = input_add_one 30 | convert_add(ctx) 31 | 32 | convert_type_func = None 33 | if dtype == torch.float32: 34 | convert_type_func = convert_float 35 | elif dtype == torch.int32 or dtype == torch.long: 36 | convert_type_func = convert_int 37 | elif dtype == torch.bool: 38 | convert_type_func = convert_bool 39 | else: 40 | print('unsupported convert type:{}'.format(dtype)) 41 | 42 | if convert_type_func is not None: 43 | input_as_type = input_add_one.to(dtype) 44 | ctx.method_args = [input_add_one, dtype] 45 | ctx.method_return = input_as_type 46 | convert_type_func(ctx) 47 | ctx.method_args = [input_as_type, 0] 48 | ctx.method_kwargs = {} 49 | ctx.method_return = output 50 | convert_add(ctx) 51 | 52 | ctx.method_args = old_method_args 53 | ctx.method_kwargs = old_method_kwargs 54 | ctx.method_return = output 55 | -------------------------------------------------------------------------------- 
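`convert_ones_like` above adds no fill layer at all: it rewrites `torch.ones_like(x)` as `x * 0 + 1` plus an optional dtype cast, so the existing mul/add/cast converters do the work and the result automatically tracks `x`'s dynamic shape. The identity it leans on, as a quick sketch (not repo code):

```python
import torch

x = torch.randn(2, 3)
assert torch.equal(torch.ones_like(x), x * 0 + 1)
assert torch.equal(torch.ones_like(x, dtype=torch.int32),
                   (x * 0 + 1).to(torch.int32))
```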
/torch2trt_dynamic/converters/pad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.module_test import add_module_test 3 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 4 | trt_) 5 | 6 | 7 | @tensorrt_converter('torch.nn.functional.pad') 8 | def convert_pad(ctx): 9 | input = ctx.method_args[0] 10 | input_trt = trt_(ctx.network, input) 11 | output = ctx.method_return 12 | 13 | pad = get_arg(ctx, 'pad', pos=1, default=[0, 0, 0, 0]) 14 | pre_padding = (pad[2], pad[0]) 15 | post_padding = (pad[3], pad[1]) 16 | 17 | # mode / value are ignored since not supported by TensorRT 18 | 19 | layer = ctx.network.add_padding(input_trt, pre_padding, post_padding) 20 | output._trt = layer.get_output(0) 21 | 22 | 23 | class Pad(torch.nn.Module): 24 | 25 | def __init__(self, pad): 26 | super(Pad, self).__init__() 27 | self.pad = pad 28 | 29 | def forward(self, x): 30 | return torch.nn.functional.pad(x, self.pad) 31 | 32 | 33 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 34 | def test_pad_basic(): 35 | return Pad((1, 2, 3, 4)) 36 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/permute.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.module_test import add_module_test 3 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 4 | 5 | 6 | @tensorrt_converter('torch.Tensor.permute') 7 | def convert_permute(ctx): 8 | input = ctx.method_args[0] 9 | input_trt = trt_(ctx.network, input) 10 | output = ctx.method_return 11 | 12 | # permutation -1 because TRT does not include batch dim 13 | if isinstance(ctx.method_args[1], int): 14 | permutation = tuple(ctx.method_args[1:]) # handle permute(a, b, c) 15 | else: 16 | permutation = tuple(ctx.method_args[1]) # handle permute([a, b, c]) 17 | 18 | trt_permutation = permutation 19 | 20 | layer = ctx.network.add_shuffle(input_trt) 21 | layer.second_transpose = tuple(trt_permutation) 22 | 23 | output._trt = layer.get_output(0) 24 | 25 | 26 | class Permute(torch.nn.Module): 27 | 28 | def __init__(self, *args): 29 | super(Permute, self).__init__() 30 | self.args = args 31 | 32 | def forward(self, x): 33 | return x.permute(*self.args).contiguous() 34 | 35 | 36 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)]) 37 | def test_permute_2d_0123(): 38 | return Permute(0, 1, 2, 3) 39 | 40 | 41 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)]) 42 | def test_permute_2d_0312(): 43 | return Permute(0, 3, 1, 2) 44 | 45 | 46 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5, 6)]) 47 | def test_permute_3d_01234(): 48 | return Permute(0, 1, 2, 3, 4) 49 | 50 | 51 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5, 6)]) 52 | def test_permute_3d_04132(): 53 | return Permute(0, 4, 1, 3, 2) 54 | 55 | 56 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5, 6)]) 57 | def test_permute_list(): 58 | return Permute([0, 4, 1, 3, 2]) 59 | 60 | 61 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5, 6)]) 62 | def test_permute_tuple(): 63 | return Permute((0, 4, 1, 3, 2)) 64 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/pixel_shuffle.py: -------------------------------------------------------------------------------- 1 | import 
tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 4 | trt_) 5 | 6 | 7 | @tensorrt_converter('torch.nn.functional.pixel_shuffle') 8 | def convert_pixel_shuffle(ctx): 9 | 10 | input = ctx.method_args[0] 11 | upscale_factor = get_arg(ctx, 'upscale_factor', pos=1, default=None) 12 | 13 | input_trt = trt_(ctx.network, input) 14 | input_shape_trt = ctx.network.add_shape(input_trt).get_output(0) 15 | output = ctx.method_return 16 | 17 | batch_shape_trt = ctx.network.add_slice(input_shape_trt, [0], [1], 18 | [1]).get_output(0) 19 | channel_shape_trt = ctx.network.add_slice(input_shape_trt, [1], [1], 20 | [1]).get_output(0) 21 | height_shape_trt = ctx.network.add_slice(input_shape_trt, [2], [1], 22 | [1]).get_output(0) 23 | width_shape_trt = ctx.network.add_slice(input_shape_trt, [3], [1], 24 | [1]).get_output(0) 25 | 26 | upscale_shape_trt = trt_( 27 | ctx.network, 28 | torch.tensor([upscale_factor], dtype=torch.int32).to(input.device)) 29 | upscale_p2_trt = ctx.network.add_elementwise( 30 | upscale_shape_trt, upscale_shape_trt, 31 | trt.ElementWiseOperation.PROD).get_output(0) 32 | new_channel_shape_trt = ctx.network.add_elementwise( 33 | channel_shape_trt, upscale_p2_trt, 34 | trt.ElementWiseOperation.FLOOR_DIV).get_output(0) 35 | 36 | # (b, c0, s, s, h, w) 37 | pre_shape_trt = ctx.network.add_concatenation([ 38 | batch_shape_trt, new_channel_shape_trt, upscale_shape_trt, 39 | upscale_shape_trt, height_shape_trt, width_shape_trt 40 | ]).get_output(0) 41 | 42 | layer = ctx.network.add_shuffle(input_trt) 43 | layer.set_input(1, pre_shape_trt) 44 | layer.second_transpose = (0, 1, 4, 2, 5, 3) 45 | 46 | permute_trt = layer.get_output(0) 47 | 48 | new_height_shape_trt = ctx.network.add_elementwise( 49 | height_shape_trt, upscale_shape_trt, 50 | trt.ElementWiseOperation.PROD).get_output(0) 51 | new_width_shape_trt = ctx.network.add_elementwise( 52 | width_shape_trt, upscale_shape_trt, 53 | trt.ElementWiseOperation.PROD).get_output(0) 54 | 55 | post_shape_trt = ctx.network.add_concatenation([ 56 | batch_shape_trt, new_channel_shape_trt, new_height_shape_trt, 57 | new_width_shape_trt 58 | ]).get_output(0) 59 | 60 | layer = ctx.network.add_shuffle(permute_trt) 61 | layer.set_input(1, post_shape_trt) 62 | output._trt = layer.get_output(0) 63 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/pow.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 5 | 6 | 7 | @tensorrt_converter('torch.pow') 8 | @tensorrt_converter('torch.Tensor.pow') 9 | @tensorrt_converter('torch.Tensor.__ipow__') 10 | @tensorrt_converter('torch.Tensor.__pow__') 11 | def convert_pow(ctx): 12 | input_a = ctx.method_args[0] 13 | input_b = ctx.method_args[1] 14 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 15 | output = ctx.method_return 16 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, 17 | trt.ElementWiseOperation.POW) 18 | output._trt = layer.get_output(0) 19 | 20 | 21 | @tensorrt_converter('torch.Tensor.__rpow__') 22 | def convert_rpow(ctx): 23 | input_a = ctx.method_args[1] 24 | input_b = ctx.method_args[0] # flipped for rpow 25 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 26 | output = ctx.method_return 27 | layer = 
ctx.network.add_elementwise(input_a_trt, input_b_trt, 28 | trt.ElementWiseOperation.POW) 29 | output._trt = layer.get_output(0) 30 | 31 | 32 | class Pow(torch.nn.Module): 33 | 34 | def __init__(self): 35 | super(Pow, self).__init__() 36 | 37 | def forward(self, x, y): 38 | return x**y 39 | 40 | 41 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 42 | (1, 3, 224, 224)]) 43 | def test_pow_basic(): 44 | return Pow() 45 | 46 | 47 | # __ipow__ not yet impl in torch 48 | # class IPow(torch.nn.Module): 49 | # def __init__(self): 50 | # super(IPow, self).__init__() 51 | 52 | # def forward(self, x, y): 53 | # x **= y 54 | # return x 55 | 56 | # @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), (1, 3, 224, 224)]) # noqa: E501 57 | # def test_pow_ipow(): 58 | # return IPow() 59 | 60 | 61 | class TorchPow(torch.nn.Module): 62 | 63 | def __init__(self): 64 | super(TorchPow, self).__init__() 65 | 66 | def forward(self, x, y): 67 | return torch.pow(x, y) 68 | 69 | 70 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 71 | (1, 3, 224, 224)]) 72 | def test_torch_pow(): 73 | return TorchPow() 74 | 75 | 76 | class RpowInt(torch.nn.Module): 77 | 78 | def __init__(self): 79 | super(RpowInt, self).__init__() 80 | 81 | def forward(self, x): 82 | return 2**x 83 | 84 | 85 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 86 | def test_rpow_int(): 87 | return RpowInt() 88 | 89 | 90 | class RpowFloat(torch.nn.Module): 91 | 92 | def __init__(self): 93 | super(RpowFloat, self).__init__() 94 | 95 | def forward(self, x): 96 | return 2.0**x 97 | 98 | 99 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)]) 100 | def test_rpow_float(): 101 | return RpowFloat() 102 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/prelu.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 5 | trt_) 6 | 7 | 8 | @tensorrt_converter('torch.nn.functional.prelu') 9 | def convert_prelu(ctx): 10 | input = get_arg(ctx, 'input', pos=0, default=None) 11 | weight = get_arg(ctx, 'weight', pos=1, default=None) 12 | output = ctx.method_return 13 | 14 | weight_shape = [1] * len(input.shape) 15 | weight_shape[1] = weight.numel() 16 | 17 | input_trt = trt_(ctx.network, input) 18 | 19 | # y = prelu(x) = relu(x) - alpha * relu(-x) 20 | weight_trt = ctx.network.add_constant( 21 | weight_shape, 22 | -weight.detach().view(weight_shape).cpu().numpy()).get_output( 23 | 0) # detach so considered leaf 24 | 25 | # x >= 0 26 | a = ctx.network.add_activation(input_trt, 27 | trt.ActivationType.RELU).get_output(0) 28 | 29 | # x <= 0 30 | b = ctx.network.add_unary(input_trt, trt.UnaryOperation.NEG).get_output(0) 31 | b = ctx.network.add_activation(b, trt.ActivationType.RELU).get_output(0) 32 | b = ctx.network.add_elementwise( 33 | b, weight_trt, trt.ElementWiseOperation.PROD).get_output(0) 34 | 35 | # y = a + b 36 | y = ctx.network.add_elementwise(a, b, trt.ElementWiseOperation.SUM) 37 | 38 | output._trt = y.get_output(0) 39 | 40 | 41 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5)]) 42 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 43 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3, 3)]) 44 | def test_prelu_scalar(): 
45 | return torch.nn.PReLU() 46 | 47 | 48 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5)]) 49 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)]) 50 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3, 3)]) 51 | def test_prelu_vector(): 52 | m = torch.nn.PReLU(5) 53 | m.weight = torch.nn.Parameter( 54 | torch.randn(5)) # randn so each channel different 55 | return m 56 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/prod.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 5 | torch_dim_to_trt_axes, trt_) 6 | 7 | from .unary import UnaryModule 8 | 9 | 10 | @tensorrt_converter('torch.prod') 11 | @tensorrt_converter('torch.Tensor.prod') 12 | def convert_prod(ctx): 13 | input = ctx.method_args[0] 14 | dim = get_arg(ctx, 'dim', pos=1, default=tuple(range(1, input.ndim))) 15 | keepdim = get_arg(ctx, 'keepdim', pos=2, default=False) 16 | input_trt = trt_(ctx.network, input) 17 | output = ctx.method_return 18 | layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.PROD, 19 | torch_dim_to_trt_axes(dim), keepdim) 20 | output._trt = layer.get_output(0) 21 | 22 | 23 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 24 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 25 | def test_prod_reduce_all(): 26 | return UnaryModule(lambda x: torch.prod(x)) 27 | 28 | 29 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 30 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 31 | def test_prod_reduce_dim1(): 32 | return UnaryModule(lambda x: torch.prod(x, 1)) 33 | 34 | 35 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 36 | def test_prod_reduce_dim22(): 37 | return UnaryModule(lambda x: torch.prod(x, 2)) 38 | 39 | 40 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 41 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 42 | def test_prod_reduce_dim1_keepdim(): 43 | return UnaryModule(lambda x: torch.prod(x, 1, keepdim=True)) 44 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/relu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter 3 | 4 | from .ReLU import convert_ReLU 5 | 6 | 7 | @tensorrt_converter('torch.relu') 8 | @tensorrt_converter('torch.relu_') 9 | @tensorrt_converter('torch.nn.functional.relu') 10 | @tensorrt_converter('torch.nn.functional.relu_') 11 | def convert_relu(ctx): 12 | ctx.method_args = (torch.nn.ReLU(), ) + ctx.method_args 13 | convert_ReLU(ctx) 14 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/relu6.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter 3 | 4 | from .ReLU6 import convert_ReLU6 5 | 6 | 7 | @tensorrt_converter('torch.nn.functional.relu6') 8 | def convert_relu6(ctx): 9 | ctx.method_args = (torch.nn.ReLU6(), ) + ctx.method_args 10 | convert_ReLU6(ctx) 11 | -------------------------------------------------------------------------------- 
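`convert_prelu` above lowers PReLU without a dedicated TensorRT activation by using `prelu(x) = relu(x) - weight * relu(-x)`; the converter folds the minus sign into a negated weight constant and sums the two branches. A sketch verifying the decomposition in PyTorch (shapes are arbitrary; the weight broadcasts over dim 1 exactly as in the converter):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 5, 3, 3)
weight = torch.randn(5)

ref = F.prelu(x, weight)
decomposed = F.relu(x) - weight.view(1, 5, 1, 1) * F.relu(-x)
assert torch.allclose(ref, decomposed)
```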
/torch2trt_dynamic/converters/repeat.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, slice_shape_trt, 4 | tensor_trt_get_shape_trt, 5 | tensorrt_converter, trt_) 6 | 7 | 8 | def _unsqueeze_input(ctx, input_trt, dim): 9 | if dim == len(input_trt.shape): 10 | return input_trt 11 | ones_trt = trt_(ctx.network, 12 | torch.ones(dim - len(input_trt.shape), dtype=torch.int32)) 13 | input_shape_trt = tensor_trt_get_shape_trt(ctx.network, input_trt) 14 | input_shape_trt = ctx.network.add_concatenation( 15 | [ones_trt, input_shape_trt]).get_output(0) 16 | layer = ctx.network.add_shuffle(input_trt) 17 | layer.set_input(1, input_shape_trt) 18 | input_trt = layer.get_output(0) 19 | return input_trt 20 | 21 | 22 | def _convert_repeat_impl(ctx, input_trt, output_shape_trt): 23 | dim = output_shape_trt.shape[0] 24 | 25 | if len(input_trt.shape) < dim: 26 | input_trt = _unsqueeze_input(ctx, input_trt, dim) 27 | 28 | zeros_trt = trt_(ctx.network, torch.zeros(dim, dtype=torch.int32)) 29 | ones_trt = trt_(ctx.network, torch.ones(dim, dtype=torch.int32)) 30 | 31 | layer = ctx.network.add_slice(input_trt, [0] * dim, [1] * dim, [1] * dim) 32 | layer.set_input(1, zeros_trt) 33 | layer.set_input(2, output_shape_trt) 34 | layer.set_input(3, ones_trt) 35 | layer.mode = trt.SliceMode.WRAP 36 | 37 | output_trt = layer.get_output(0) 38 | 39 | return output_trt 40 | 41 | 42 | @tensorrt_converter('torch.Tensor.repeat') 43 | def convert_repeat(ctx): 44 | input = ctx.method_args[0] 45 | repeats = ctx.method_args[1] 46 | if isinstance(repeats, int): 47 | repeats = ctx.method_args[1:] 48 | 49 | output = ctx.method_return 50 | 51 | input_trt = trt_(ctx.network, input) 52 | input_trt = _unsqueeze_input(ctx, input_trt, len(repeats)) 53 | # compute output shape 54 | input_shape_trt = tensor_trt_get_shape_trt(ctx.network, input_trt) 55 | repeat_times_trt = [trt_(ctx.network, rep) for rep in repeats] 56 | repeat_times_trt = ctx.network.add_concatenation( 57 | repeat_times_trt).get_output(0) 58 | 59 | output_shape_trt = ctx.network.add_elementwise( 60 | input_shape_trt, repeat_times_trt, 61 | trt.ElementWiseOperation.PROD).get_output(0) 62 | 63 | # convert repeat 64 | output_trt = _convert_repeat_impl(ctx, input_trt, output_shape_trt) 65 | 66 | output._trt = output_trt 67 | 68 | 69 | @tensorrt_converter('torch.Tensor.expand') 70 | def convert_expand(ctx): 71 | input = ctx.method_args[0] 72 | if isinstance(ctx.method_args[1], int): 73 | sizes = ctx.method_args[1:] 74 | else: 75 | sizes = ctx.method_args[1] 76 | 77 | output = ctx.method_return 78 | 79 | input_trt = trt_(ctx.network, input) 80 | 81 | dim = len(sizes) 82 | 83 | # unsqueeze if necessary 84 | if len(input_trt.shape) < dim: 85 | input_trt = _unsqueeze_input(ctx, input_trt, dim) 86 | input_shape_trt = tensor_trt_get_shape_trt(ctx.network, input_trt) 87 | 88 | # compute output shape 89 | output_shape_trt = [] 90 | for i, s in enumerate(sizes): 91 | if s > 0: 92 | output_shape_trt.append(trt_(ctx.network, s)) 93 | else: 94 | output_shape_trt.append( 95 | slice_shape_trt(ctx.network, input_shape_trt, i, 1)) 96 | 97 | output_shape_trt = ctx.network.add_concatenation( 98 | output_shape_trt).get_output(0) 99 | 100 | # convert repeat 101 | output_trt = _convert_repeat_impl(ctx, input_trt, output_shape_trt) 102 | 103 | output._trt = output_trt 104 | 105 | 106 | @tensorrt_converter('torch.Tensor.expand_as') 107 | def convert_expand_as(ctx): 
108 | input = ctx.method_args[0] 109 | other = get_arg(ctx, 'other', pos=1, default=None) 110 | 111 | input_trt = trt_(ctx.network, input) 112 | other_trt = trt_(ctx.network, other) 113 | output = ctx.method_return 114 | 115 | # compute output shape 116 | output_shape_trt = tensor_trt_get_shape_trt(ctx.network, other_trt) 117 | 118 | # convert repeat 119 | output_trt = _convert_repeat_impl(ctx, input_trt, output_shape_trt) 120 | 121 | output._trt = output_trt 122 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/roi_align.py: -------------------------------------------------------------------------------- 1 | import torchvision.ops # noqa: F401 2 | from torch2trt_dynamic.plugins import create_roiextractor_plugin 3 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 4 | trt_) 5 | 6 | 7 | @tensorrt_converter('torchvision.ops.roi_align') 8 | def convert_roi_align(ctx): 9 | 10 | input = get_arg(ctx, 'input', pos=0, default=None) 11 | boxes = get_arg(ctx, 'boxes', pos=1, default=None) 12 | output_size = get_arg(ctx, 'output_size', pos=2, default=7) 13 | spatial_scale = get_arg(ctx, 'spatial_scale', pos=3, default=1.) 14 | sampling_ratio = get_arg(ctx, 'sampling_ratio', pos=4, default=-1) 15 | aligned = get_arg(ctx, 'aligned', pos=5, default=False) 16 | 17 | output = ctx.method_return 18 | 19 | input_trt = trt_(ctx.network, input) 20 | boxes_offset_trt, boxes_trt = trt_(ctx.network, 0.5 / spatial_scale, boxes) 21 | 22 | plugin = create_roiextractor_plugin( 23 | 'roi_align_' + str(id(boxes)), 24 | out_size=output_size, 25 | sample_num=sampling_ratio, 26 | featmap_strides=[1. / spatial_scale], 27 | roi_scale_factor=1., 28 | finest_scale=56, 29 | aligned=1 if aligned else 0) 30 | 31 | custom_layer = ctx.network.add_plugin_v2( 32 | inputs=[boxes_trt, input_trt], plugin=plugin) 33 | 34 | output._trt = custom_layer.get_output(0) 35 | 36 | 37 | @tensorrt_converter('torchvision.ops.RoIAlign.forward') 38 | def convert_RoiAlign(ctx): 39 | module = ctx.method_args[0] 40 | input = get_arg(ctx, 'input', pos=1, default=None) 41 | boxes = get_arg(ctx, 'boxes', pos=2, default=None) 42 | 43 | output_size = module.output_size 44 | spatial_scale = module.spatial_scale 45 | sampling_ratio = module.sampling_ratio 46 | aligned = module.aligned 47 | 48 | old_method_args = ctx.method_args 49 | old_method_kwargs = ctx.method_kwargs 50 | new_method_args = [ 51 | input, boxes, output_size, spatial_scale, sampling_ratio, aligned 52 | ] 53 | new_method_kwargs = {} 54 | ctx.method_args = new_method_args 55 | ctx.method_kwargs = new_method_kwargs 56 | convert_roi_align(ctx) 57 | ctx.method_args = old_method_args 58 | ctx.method_kwargs = old_method_kwargs 59 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/roi_pool.py: -------------------------------------------------------------------------------- 1 | import torchvision.ops # noqa: F401 2 | from torch2trt_dynamic.plugins import create_roipool_plugin 3 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 4 | trt_) 5 | 6 | 7 | @tensorrt_converter('torchvision.ops.roi_pool') 8 | def convert_roi_pool(ctx): 9 | input = get_arg(ctx, 'input', pos=0, default=None) 10 | boxes = get_arg(ctx, 'boxes', pos=1, default=None) 11 | output_size = get_arg(ctx, 'output_size', pos=2, default=7) 12 | spatial_scale = get_arg(ctx, 'spatial_scale', pos=3, default=1.) 
13 | 14 | output = ctx.method_return 15 | 16 | input_trt = trt_(ctx.network, input) 17 | boxes_trt = trt_(ctx.network, boxes) 18 | 19 | plugin = create_roipool_plugin( 20 | 'roi_pool_' + str(id(boxes)), 21 | out_size=output_size, 22 | featmap_strides=[1. / spatial_scale], 23 | roi_scale_factor=-1, 24 | finest_scale=56) 25 | 26 | custom_layer = ctx.network.add_plugin_v2( 27 | inputs=[boxes_trt, input_trt], plugin=plugin) 28 | 29 | output._trt = custom_layer.get_output(0) 30 | 31 | 32 | @tensorrt_converter('torchvision.ops.RoIPool.forward') 33 | def convert_RoIPool(ctx): 34 | module = ctx.method_args[0] 35 | input = get_arg(ctx, 'input', pos=1, default=None) 36 | boxes = get_arg(ctx, 'boxes', pos=2, default=None) 37 | 38 | output_size = module.output_size 39 | spatial_scale = module.spatial_scale 40 | 41 | old_method_args = ctx.method_args 42 | old_method_kwargs = ctx.method_kwargs 43 | new_method_args = [input, boxes, output_size, spatial_scale] 44 | new_method_kwargs = {} 45 | ctx.method_args = new_method_args 46 | ctx.method_kwargs = new_method_kwargs 47 | convert_roi_pool(ctx) 48 | ctx.method_args = old_method_args 49 | ctx.method_kwargs = old_method_kwargs 50 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/sigmoid.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 5 | 6 | 7 | @tensorrt_converter('torch.nn.functional.sigmoid') 8 | @tensorrt_converter('torch.sigmoid') 9 | @tensorrt_converter('torch.Tensor.sigmoid') 10 | def convert_sigmoid(ctx): 11 | input = ctx.method_args[0] 12 | input_trt = trt_(ctx.network, input) 13 | output = ctx.method_return 14 | 15 | layer = ctx.network.add_activation(input_trt, trt.ActivationType.SIGMOID) 16 | output._trt = layer.get_output(0) 17 | 18 | 19 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 20 | def test_sigmoid_basic(): 21 | return torch.nn.Sigmoid() 22 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.module_test import add_module_test 3 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 4 | trt_) 5 | 6 | 7 | @tensorrt_converter('torch.Tensor.softmax') 8 | @tensorrt_converter('torch.softmax') 9 | @tensorrt_converter('torch.nn.functional.softmax') 10 | def convert_softmax(ctx): 11 | 12 | input = ctx.method_args[0] 13 | input_trt = trt_(ctx.network, input) 14 | output = ctx.method_return 15 | 16 | # get dims from args or kwargs 17 | dim = get_arg(ctx, 'dim', pos=1, default=None) 18 | if dim is None: 19 | dim = -1 20 | if dim < 0: 21 | dim = len(input.shape) + dim 22 | 23 | # axes = 1 << (dim - 1) 24 | axes = 1 << dim 25 | 26 | layer = ctx.network.add_softmax(input=input_trt) 27 | layer.axes = axes 28 | 29 | output._trt = layer.get_output(0) 30 | 31 | 32 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 33 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 34 | def test_softmax_module(): 35 | return torch.nn.Softmax(1) 36 | 37 | 38 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 39 | def test_softmax_module_dim2(): 40 | return torch.nn.Softmax(2) 41 | 
-------------------------------------------------------------------------------- /torch2trt_dynamic/converters/split.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.module_test import add_module_test 3 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 4 | trt_) 5 | 6 | 7 | @tensorrt_converter('torch.split') 8 | @tensorrt_converter('torch.Tensor.split') 9 | def convert_split(ctx): 10 | input = get_arg(ctx, 'input', 0, None) 11 | input_trt = trt_(ctx.network, input) 12 | # we don't need to parse split/chunk (arg 1) 13 | # since we infer size from output tensors 14 | dim = get_arg(ctx, 'dim', 2, 0) 15 | 16 | outputs = ctx.method_return 17 | 18 | assert (dim >= 1) 19 | 20 | start = [0] * len(input.shape) # exclude batch 21 | stride = [1] * len(start) 22 | offset = 0 23 | trt_dim = dim 24 | 25 | # add slice layers 26 | for i, output in enumerate(outputs): 27 | shape = list(output.shape) 28 | start[trt_dim] = offset 29 | layer = ctx.network.add_slice( 30 | input_trt, start=start, shape=shape, stride=stride) 31 | output._trt = layer.get_output(0) 32 | offset = offset + shape[trt_dim] 33 | 34 | 35 | class TorchSplit(torch.nn.Module): 36 | 37 | def __init__(self, *args, **kwargs): 38 | super(TorchSplit, self).__init__() 39 | self.args = args 40 | self.kwargs = kwargs 41 | 42 | def forward(self, x): 43 | return torch.split(x, *self.args, **self.kwargs) 44 | 45 | 46 | class TensorSplit(torch.nn.Module): 47 | 48 | def __init__(self, *args, **kwargs): 49 | super(TensorSplit, self).__init__() 50 | self.args = args 51 | self.kwargs = kwargs 52 | 53 | def forward(self, x): 54 | return x.split(*self.args, **self.kwargs) 55 | 56 | 57 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 58 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 59 | def test_torch_split_1_1(): 60 | return TorchSplit(1, 1) 61 | 62 | 63 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 64 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 65 | def test_torch_split_2_1(): 66 | return TorchSplit(2, 1) 67 | 68 | 69 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 70 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 71 | def test_torch_split_3_1(): 72 | return TorchSplit(3, 1) 73 | 74 | 75 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 76 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 77 | def test_torch_split_3_2(): 78 | return TorchSplit(3, 2) 79 | 80 | 81 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 82 | def test_tensor_split_3_2(): 83 | return TensorSplit(3, 2) 84 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/squeeze.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 3 | trt_) 4 | 5 | from .identity import convert_identity 6 | 7 | 8 | @tensorrt_converter('torch.Tensor.squeeze') 9 | @tensorrt_converter('torch.squeeze') 10 | def convert_squeeze(ctx): 11 | 12 | input = ctx.method_args[0] 13 | dim = get_arg(ctx, 'dim', pos=1, default=None) 14 | if dim is None: 15 | dim = list( 16 | filter(lambda x: input.shape[x] == 1, range(len(input.shape)))) 17 | else: 18 | if input.shape[dim] != 1: 19 | ctx.method_args = [input] 20 | 
convert_identity(ctx) 21 | return 22 | if dim < 0: 23 | dim = len(input.shape) + dim 24 | dim = [dim] 25 | input_trt = trt_(ctx.network, input) 26 | shape_trt = ctx.network.add_shape(input_trt).get_output(0) 27 | output = ctx.method_return 28 | 29 | reverse_dim = list(filter(lambda x: x not in dim, range(len(input.shape)))) 30 | reverse_dim_trt = trt_( 31 | ctx.network, 32 | torch.tensor(reverse_dim, dtype=torch.int32).to(input.device)) 33 | 34 | new_shape_trt = ctx.network.add_gather(shape_trt, reverse_dim_trt, 35 | 0).get_output(0) 36 | 37 | layer = ctx.network.add_shuffle(input_trt) 38 | layer.set_input(1, new_shape_trt) 39 | output._trt = layer.get_output(0) 40 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/stack.py: -------------------------------------------------------------------------------- 1 | from torch2trt_dynamic.torch2trt_dynamic import get_arg, tensorrt_converter 2 | 3 | from .cat import convert_cat 4 | from .unsqueeze import convert_unsqueeze 5 | 6 | 7 | @tensorrt_converter('torch.stack') 8 | def convert_stack(ctx): 9 | inputs = ctx.method_args[0] 10 | dim = get_arg(ctx, 'dim', pos=1, default=0) 11 | output = ctx.method_return 12 | 13 | unsqueeze_inputs = [] 14 | for input in inputs: 15 | unsqueeze_input = input.unsqueeze(dim=dim) 16 | ctx.method_args = (input, dim) 17 | ctx.method_return = unsqueeze_input 18 | convert_unsqueeze(ctx) 19 | unsqueeze_inputs.append(unsqueeze_input) 20 | 21 | ctx.method_args = (unsqueeze_inputs, dim) 22 | ctx.method_return = output 23 | 24 | convert_cat(ctx) 25 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/sub.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 5 | 6 | 7 | @tensorrt_converter('torch.sub') 8 | @tensorrt_converter('torch.Tensor.__isub__') 9 | @tensorrt_converter('torch.Tensor.__sub__') 10 | def convert_sub(ctx): 11 | input_a = ctx.method_args[0] 12 | input_b = ctx.method_args[1] 13 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 14 | output = ctx.method_return 15 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, 16 | trt.ElementWiseOperation.SUB) 17 | output._trt = layer.get_output(0) 18 | 19 | 20 | @tensorrt_converter('torch.Tensor.__rsub__') 21 | def convert_rsub(ctx): 22 | input_a = ctx.method_args[1] 23 | input_b = ctx.method_args[0] # flipped for rsub 24 | input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b) 25 | output = ctx.method_return 26 | layer = ctx.network.add_elementwise(input_a_trt, input_b_trt, 27 | trt.ElementWiseOperation.SUB) 28 | output._trt = layer.get_output(0) 29 | 30 | 31 | class Sub(torch.nn.Module): 32 | 33 | def __init__(self): 34 | super(Sub, self).__init__() 35 | 36 | def forward(self, x, y): 37 | return x - y 38 | 39 | 40 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 41 | (1, 3, 224, 224)]) 42 | def test_sub_basic(): 43 | return Sub() 44 | 45 | 46 | class ISub(torch.nn.Module): 47 | 48 | def __init__(self): 49 | super(ISub, self).__init__() 50 | 51 | def forward(self, x, y): 52 | x -= y 53 | return x 54 | 55 | 56 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224), 57 | (1, 3, 224, 224)]) 58 | def test_sub_isub(): 59 | return ISub() 60 | 61 | 62 | class 
TorchSub(torch.nn.Module):
63 | 
64 |     def __init__(self):
65 |         super(TorchSub, self).__init__()
66 | 
67 |     def forward(self, x, y):
68 |         return torch.sub(x, y)
69 | 
70 | 
71 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224),
72 |                                                        (1, 3, 224, 224)])
73 | def test_torch_sub():
74 |     return TorchSub()
75 | 
76 | 
77 | class RSubInt(torch.nn.Module):
78 | 
79 |     def __init__(self):
80 |         super(RSubInt, self).__init__()
81 | 
82 |     def forward(self, x):
83 |         return 1 - x
84 | 
85 | 
86 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
87 | def test_rsub_int():
88 |     return RSubInt()
89 | 
90 | 
91 | class RSubFloat(torch.nn.Module):
92 | 
93 |     def __init__(self):
94 |         super(RSubFloat, self).__init__()
95 | 
96 |     def forward(self, x):
97 |         return 1.0 - x
98 | 
99 | 
100 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 224, 224)])
101 | def test_rsub_float():
102 |     return RSubFloat()
103 | 
--------------------------------------------------------------------------------
/torch2trt_dynamic/converters/sum.py:
--------------------------------------------------------------------------------
 1 | import tensorrt as trt
 2 | import torch
 3 | from torch2trt_dynamic.module_test import add_module_test
 4 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter,
 5 |                                                  torch_dim_to_trt_axes, trt_)
 6 | 
 7 | from .unary import UnaryModule
 8 | 
 9 | 
10 | @tensorrt_converter('torch.sum')
11 | @tensorrt_converter('torch.Tensor.sum')
12 | def convert_sum(ctx):
13 |     input = ctx.method_args[0]
14 |     dim = get_arg(ctx, 'dim', pos=1, default=tuple(range(1, input.ndim)))
15 |     keepdim = get_arg(ctx, 'keepdim', pos=2, default=False)
16 |     if isinstance(dim, int) and dim < 0:
17 |         dim = input.dim() + dim
18 |     input_trt = trt_(ctx.network, input)
19 |     output = ctx.method_return
20 |     layer = ctx.network.add_reduce(input_trt, trt.ReduceOperation.SUM,
21 |                                    torch_dim_to_trt_axes(dim), keepdim)
22 |     output._trt = layer.get_output(0)
23 | 
24 | 
25 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)])
26 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
27 | def test_sum_reduce_all():
28 |     return UnaryModule(lambda x: torch.sum(x))
29 | 
30 | 
31 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)])
32 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
33 | def test_sum_reduce_dim1():
34 |     return UnaryModule(lambda x: torch.sum(x, 1))
35 | 
36 | 
37 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
38 | def test_sum_reduce_dim2():
39 |     return UnaryModule(lambda x: torch.sum(x, 2))
40 | 
41 | 
42 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)])
43 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
44 | def test_sum_reduce_dim1_keepdim():
45 |     return UnaryModule(lambda x: torch.sum(x, 1, keepdim=True))
46 | 
--------------------------------------------------------------------------------
/torch2trt_dynamic/converters/t.py:
--------------------------------------------------------------------------------
 1 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_
 2 | 
 3 | from .transpose import convert_transpose
 4 | 
 5 | 
 6 | @tensorrt_converter('torch.Tensor.t')
 7 | def convert_t(ctx):
 8 |     input = ctx.method_args[0]
 9 |     input_trt = trt_(ctx.network, input)
10 |     output = ctx.method_return
11 |     # 1-d tensors are returned unchanged; otherwise t() swaps dims 0 and 1
12 | 
13 |     if len(input.shape) == 1:
14 |         layer = ctx.network.add_identity(input_trt)
15 |         output._trt = layer.get_output(0)
16 |     else:
17 | ctx.method_args = [input, 1, 0] 18 | ctx.method_kwargs = {} 19 | convert_transpose(ctx) 20 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/take.py: -------------------------------------------------------------------------------- 1 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 2 | trt_) 3 | 4 | 5 | @tensorrt_converter('torch.take') 6 | def convert_take(ctx): 7 | input = ctx.method_args[0] 8 | index = get_arg(ctx, 'index', pos=1, default=None) 9 | 10 | input_trt = trt_(ctx.network, input) 11 | index_trt = trt_(ctx.network, index) 12 | output = ctx.method_return 13 | 14 | # flatten input 15 | layer = ctx.network.add_shuffle(input_trt) 16 | layer.reshape_dims = (-1, ) 17 | flatten_input_trt = layer.get_output(0) 18 | 19 | # flatten index 20 | output_trt = ctx.network.add_gather(flatten_input_trt, index_trt, 21 | 0).get_output(0) 22 | 23 | output._trt = output_trt 24 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/tanh.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 5 | 6 | 7 | @tensorrt_converter('torch.nn.functional.tanh') 8 | @tensorrt_converter('torch.tanh') 9 | def convert_tanh(ctx): 10 | input = ctx.method_args[0] 11 | input_trt = trt_(ctx.network, input) 12 | output = ctx.method_return 13 | 14 | layer = ctx.network.add_activation(input_trt, trt.ActivationType.TANH) 15 | output._trt = layer.get_output(0) 16 | 17 | 18 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 19 | def test_tanh_basic(): 20 | return torch.nn.Tanh() 21 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/to.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.torch2trt_dynamic import (tensorrt_converter, trt_, 3 | trt_cast) 4 | 5 | 6 | @tensorrt_converter('torch.Tensor.to') 7 | def convert_Tensor_to(ctx): 8 | input = ctx.method_args[0] 9 | output = ctx.method_return 10 | 11 | input_trt = trt_(ctx.network, input) 12 | if output.dtype == input.dtype: 13 | output._trt = input_trt 14 | else: 15 | data_type = output.dtype 16 | if data_type == torch.int64: 17 | data_type = torch.int32 18 | 19 | output_trt = trt_cast(ctx.network, input_trt, data_type) 20 | output._trt = output_trt 21 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/topk.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import torch 3 | from torch2trt_dynamic.module_test import add_module_test 4 | from torch2trt_dynamic.torch2trt_dynamic import (bind_arguments, 5 | tensorrt_converter, trt_) 6 | 7 | from .size import IntWarper 8 | 9 | 10 | def _dummy_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): 11 | pass 12 | 13 | 14 | @tensorrt_converter('torch.topk') 15 | @tensorrt_converter('torch.Tensor.topk') 16 | def convert_topk(ctx): 17 | arguments = bind_arguments(_dummy_topk, ctx) 18 | input = arguments['input'] 19 | k = arguments['k'] 20 | dim = arguments['dim'] 21 | largest = arguments['largest'] 22 | 23 | if dim is None: 24 | dim = len(input.shape) - 1 25 | if dim < 0: 26 | dim 
= len(input.shape) + dim 27 | 28 | def __add_unsqueeze_layer(input_trt, dim): 29 | layer = ctx.network.add_shuffle(input_trt) 30 | layer.reshape_dims = (1, ) + tuple(input_trt.shape) 31 | input_trt = layer.get_output(0) 32 | dim += 1 33 | return input_trt, dim 34 | 35 | def __add_topk_layer(k, dim): 36 | topkOp = trt.TopKOperation.MAX if largest else trt.TopKOperation.MIN 37 | 38 | k_trt = None 39 | if isinstance(k, IntWarper): 40 | k_trt = trt_(ctx.network, k) 41 | layer = ctx.network.add_shuffle(k_trt) 42 | layer.reshape_dims = tuple() 43 | k_trt = layer.get_output(0) 44 | 45 | if isinstance(k, int) and k > 3840: 46 | print('Clamp k to 3840.') 47 | k = 3840 48 | 49 | layer = ctx.network.add_topk(input_trt, topkOp, k, 1 << dim) 50 | 51 | if k_trt is not None: 52 | layer.set_input(1, k_trt) 53 | 54 | output0_trt = layer.get_output(0) 55 | output1_trt = layer.get_output(1) 56 | return output0_trt, output1_trt 57 | 58 | def __add_squeeze_layer(output_trt): 59 | layer = ctx.network.add_shuffle(output_trt) 60 | layer.reshape_dims = tuple(output_trt.shape)[1:] 61 | return layer.get_output(0) 62 | 63 | input_trt = trt_(ctx.network, input) 64 | output = ctx.method_return 65 | 66 | # can only use topk on dim>=2 67 | need_unsqueeze = len(input_trt.shape) == 1 68 | if need_unsqueeze: 69 | input_trt, dim = __add_unsqueeze_layer(input_trt, dim) 70 | 71 | output0_trt, output1_trt = __add_topk_layer(k, dim) 72 | 73 | # recovery 74 | if need_unsqueeze: 75 | output0_trt = __add_squeeze_layer(output0_trt) 76 | output1_trt = __add_squeeze_layer(output1_trt) 77 | 78 | output[0]._trt = output0_trt 79 | output[1]._trt = output1_trt 80 | 81 | 82 | class TopkTestModule(torch.nn.Module): 83 | 84 | def __init__(self, k, dim, largest): 85 | super(TopkTestModule, self).__init__() 86 | self.k = k 87 | self.dim = dim 88 | self.largest = largest 89 | 90 | def forward(self, x): 91 | return x.topk(k=self.k, dim=self.dim, largest=self.largest) 92 | 93 | 94 | @add_module_test( 95 | torch.float32, 96 | torch.device('cuda'), [(1, 20, 4, 6)], 97 | max_workspace_size=1 << 20) 98 | @add_module_test( 99 | torch.float32, 100 | torch.device('cuda'), [(1, 20, 6)], 101 | max_workspace_size=1 << 20) 102 | @add_module_test( 103 | torch.float32, torch.device('cuda'), [(1, 20)], max_workspace_size=1 << 20) 104 | def test_topk_dim1(): 105 | return TopkTestModule(10, 1, True) 106 | 107 | 108 | @add_module_test( 109 | torch.float32, 110 | torch.device('cuda'), [(1, 4, 20, 6)], 111 | max_workspace_size=1 << 20) 112 | @add_module_test( 113 | torch.float32, 114 | torch.device('cuda'), [(1, 6, 20)], 115 | max_workspace_size=1 << 20) 116 | def test_topk_dim2(): 117 | return TopkTestModule(10, 2, True) 118 | 119 | 120 | @add_module_test( 121 | torch.float32, 122 | torch.device('cuda'), [(1, 20, 4, 6)], 123 | max_workspace_size=1 << 20) 124 | @add_module_test( 125 | torch.float32, 126 | torch.device('cuda'), [(1, 20, 6)], 127 | max_workspace_size=1 << 20) 128 | @add_module_test( 129 | torch.float32, torch.device('cuda'), [(1, 20)], max_workspace_size=1 << 20) 130 | def test_topk_largest_false(): 131 | return TopkTestModule(10, 1, False) 132 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/transpose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.module_test import add_module_test 3 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 4 | 5 | 6 | 
@tensorrt_converter('torch.transpose')
 7 | @tensorrt_converter('torch.Tensor.transpose')
 8 | def convert_transpose(ctx):
 9 |     input = ctx.method_args[0]
10 |     input_trt = trt_(ctx.network, input)
11 |     output = ctx.method_return
12 |     # build a full permutation with dim0 and dim1 swapped
13 | 
14 |     dim = input.dim()
15 |     permutation = list(range(dim))
16 |     dim0 = ctx.method_args[1]
17 |     dim1 = ctx.method_args[2]
18 |     dim0 = dim0 if dim0 >= 0 else dim + dim0
19 |     dim1 = dim1 if dim1 >= 0 else dim + dim1
20 |     permutation[dim0] = dim1
21 |     permutation[dim1] = dim0
22 |     layer = ctx.network.add_shuffle(input_trt)
23 |     layer.second_transpose = tuple(permutation)
24 |     output._trt = layer.get_output(0)
25 | 
26 | 
27 | class Transpose(torch.nn.Module):
28 | 
29 |     def __init__(self, dim0, dim1):
30 |         super(Transpose, self).__init__()
31 |         self.dim0 = dim0
32 |         self.dim1 = dim1
33 | 
34 |     def forward(self, x):
35 |         return torch.transpose(x, self.dim0, self.dim1).contiguous()
36 | 
37 | 
38 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)])
39 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)])
40 | def test_transpose_12():
41 |     return Transpose(1, 2)
42 | 
--------------------------------------------------------------------------------
/torch2trt_dynamic/converters/unfold.py:
--------------------------------------------------------------------------------
 1 | from torch2trt_dynamic.plugins import create_torchunfold_plugin
 2 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter,
 3 |                                                  trt_)
 4 | 
 5 | 
 6 | @tensorrt_converter('torch.nn.functional.unfold')
 7 | def convert_unfold(ctx):
 8 |     input = ctx.method_args[0]
 9 |     kernel_size = get_arg(ctx, 'kernel_size', pos=1, default=0)
10 |     dilation = get_arg(ctx, 'dilation', pos=2, default=1)
11 |     padding = get_arg(ctx, 'padding', pos=3, default=0)
12 |     stride = get_arg(ctx, 'stride', pos=4, default=1)
13 |     output = ctx.method_return
14 |     input_trt = trt_(ctx.network, input)
15 | 
16 |     plugin = create_torchunfold_plugin(
17 |         'unfold_' + str(id(input)),
18 |         kernel_size=kernel_size,
19 |         dilation=dilation,
20 |         padding=padding,
21 |         stride=stride)
22 | 
23 |     layer = ctx.network.add_plugin_v2(inputs=[input_trt], plugin=plugin)
24 | 
25 |     output._trt = layer.get_output(0)
26 | 
--------------------------------------------------------------------------------
/torch2trt_dynamic/converters/unsqueeze.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter,
 3 |                                                  trt_)
 4 | 
 5 | 
 6 | @tensorrt_converter('torch.Tensor.unsqueeze')
 7 | @tensorrt_converter('torch.unsqueeze')
 8 | def convert_unsqueeze(ctx):
 9 | 
10 |     input = ctx.method_args[0]
11 |     dim = get_arg(ctx, 'dim', pos=1, default=None)
12 |     if dim < 0:
13 |         dim = len(input.shape) + dim + 1
14 |     input_trt = trt_(ctx.network, input)
15 |     shape_trt = ctx.network.add_shape(input_trt).get_output(0)
16 |     unsqueeze_trt = trt_(ctx.network, input.new_ones((1), dtype=torch.int32))
17 |     output = ctx.method_return
18 | 
19 |     shape1_trt = None
20 |     shape2_trt = None
21 |     if dim == 0:
22 |         shape2_trt = shape_trt
23 |     elif dim == len(input.shape):
24 |         shape1_trt = shape_trt
25 |     else:
26 |         slice1_start = [0]
27 |         slice1_size = [dim]
28 |         slice1_stride = [1]
29 |         shape1_trt = ctx.network.add_slice(shape_trt, slice1_start,
30 |                                            slice1_size,
31 |                                            slice1_stride).get_output(0)
32 |         slice2_start = [dim]
33 |         slice2_size = [len(input.shape) - dim]
34 |         slice2_stride = [1]
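        # the tail dims [dim:] sliced out below are concatenated after
        # the new unit dim to rebuild the output shape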
35 | shape2_trt = ctx.network.add_slice(shape_trt, slice2_start, 36 | slice2_size, 37 | slice2_stride).get_output(0) 38 | 39 | if shape1_trt is None: 40 | new_shape_trt = ctx.network.add_concatenation( 41 | [unsqueeze_trt, shape2_trt]).get_output(0) 42 | elif shape2_trt is None: 43 | new_shape_trt = ctx.network.add_concatenation( 44 | [shape1_trt, unsqueeze_trt]).get_output(0) 45 | else: 46 | new_shape_trt = ctx.network.add_concatenation( 47 | [shape1_trt, unsqueeze_trt, shape2_trt]).get_output(0) 48 | 49 | layer = ctx.network.add_shuffle(input_trt) 50 | layer.set_input(1, new_shape_trt) 51 | output._trt = layer.get_output(0) 52 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/view.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.module_test import add_module_test 3 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 4 | trt_, trt_cast) 5 | 6 | from .size import IntWarper 7 | 8 | 9 | @tensorrt_converter('torch.Tensor.reshape') 10 | @tensorrt_converter('torch.Tensor.view') 11 | def convert_view(ctx): 12 | 13 | input = ctx.method_args[0] 14 | size = get_arg(ctx, 'shape', pos=1, default=[]) 15 | if isinstance(size, int): 16 | size = tuple(ctx.method_args[1:]) 17 | input_trt = trt_(ctx.network, input) 18 | output = ctx.method_return 19 | 20 | if input.dtype == torch.bool: 21 | input_trt = trt_cast(ctx.network, input_trt, torch.int32) 22 | 23 | # check if there are shape tensor 24 | is_shape_tensor = False 25 | for s in size: 26 | if isinstance(s, IntWarper): 27 | is_shape_tensor = True 28 | break 29 | 30 | # negative shape might cause overflow, forbid for now 31 | for s in size: 32 | if s < 0: 33 | is_shape_tensor = True 34 | break 35 | 36 | # compute shape tensor 37 | if is_shape_tensor: 38 | shape_trt = [] 39 | for idx, s in enumerate(size): 40 | if isinstance(s, IntWarper): 41 | shape_trt.append(s._trt) 42 | else: 43 | const_shape_trt = trt_( 44 | ctx.network, input.new_tensor([s], dtype=torch.int32)) 45 | shape_trt.append(const_shape_trt) 46 | 47 | shape_trt = ctx.network.add_concatenation(shape_trt).get_output(0) 48 | 49 | layer = ctx.network.add_shuffle(input_trt) 50 | if is_shape_tensor: 51 | layer.set_input(1, shape_trt) 52 | else: 53 | layer.reshape_dims = output.shape 54 | 55 | output_trt = layer.get_output(0) 56 | 57 | if input.dtype == torch.bool: 58 | output_trt = trt_cast(ctx.network, output_trt, torch.bool) 59 | 60 | output._trt = output_trt 61 | 62 | 63 | class View(torch.nn.Module): 64 | 65 | def __init__(self, *dims): 66 | super(View, self).__init__() 67 | self.dims = dims 68 | 69 | def forward(self, x): 70 | return x.view(*self.dims) 71 | 72 | 73 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 74 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 75 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 76 | def test_view_1d(): 77 | return View(1, -1) 78 | 79 | 80 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 81 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 82 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 83 | def test_view_2d(): 84 | return View(1, 1, -1) 85 | 86 | 87 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3)]) 88 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) 89 | @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3, 3)]) 90 | def 
test_view_3d(): 91 | return View(1, 1, 1, -1) 92 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/view_as.py: -------------------------------------------------------------------------------- 1 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 2 | trt_) 3 | 4 | 5 | @tensorrt_converter('torch.Tensor.view_as') 6 | def convert_view_as(ctx): 7 | 8 | input = ctx.method_args[0] 9 | other = get_arg(ctx, 'other', pos=1, default=None) 10 | input_trt = trt_(ctx.network, input) 11 | other_trt = trt_(ctx.network, other) 12 | output = ctx.method_return 13 | 14 | shape_trt = ctx.network.add_shape(other_trt).get_output(0) 15 | 16 | layer = ctx.network.add_shuffle(input_trt) 17 | layer.set_input(1, shape_trt) 18 | output._trt = layer.get_output(0) 19 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/where.py: -------------------------------------------------------------------------------- 1 | from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter, 2 | trt_) 3 | 4 | 5 | @tensorrt_converter('torch.where') 6 | def convert_where(ctx): 7 | condition = get_arg(ctx, 'condition', pos=0, default=None) 8 | x = get_arg(ctx, 'x', pos=1, default=None) 9 | y = get_arg(ctx, 'y', pos=2, default=None) 10 | 11 | condition_trt = trt_(ctx.network, condition) 12 | x_trt = trt_(ctx.network, x) 13 | y_trt = trt_(ctx.network, y) 14 | output = ctx.method_return 15 | 16 | layer = ctx.network.add_select(condition_trt, x_trt, y_trt) 17 | output_trt = layer.get_output(0) 18 | 19 | output._trt = output_trt 20 | 21 | 22 | @tensorrt_converter('torch.Tensor.where') 23 | def convert_Tensor_where(ctx): 24 | x = ctx.method_args[0] 25 | condition = get_arg(ctx, 'condition', pos=1, default=None) 26 | y = get_arg(ctx, 'y', pos=2, default=None) 27 | 28 | ctx.method_args = [condition, x, y] 29 | ctx.method_kwargs = {} 30 | convert_where(ctx) 31 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/zeros.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | import tensorrt as trt 4 | import torch 5 | from torch2trt_dynamic.torch2trt_dynamic import tensorrt_converter, trt_ 6 | 7 | 8 | @tensorrt_converter('torch.zeros') 9 | def convert_zeros(ctx): 10 | size = ctx.method_args[0] 11 | if not isinstance(size, Iterable): 12 | size = ctx.method_args 13 | dtype = torch.float32 14 | if 'dtype' in ctx.method_kwargs: 15 | dtype = ctx.method_kwargs['dtype'] 16 | output = ctx.method_return 17 | 18 | if isinstance(size, int): 19 | size = (size, ) 20 | 21 | # check const 22 | is_const = True 23 | for s in size: 24 | if hasattr(s, '_trt'): 25 | is_const = False 26 | break 27 | 28 | if is_const: 29 | # create const value 30 | output_trt = trt_(ctx.network, output) 31 | 32 | else: 33 | # create fill 34 | trt_size = [] 35 | for s in size: 36 | if hasattr(s, '_trt'): 37 | trt_size.append(s._trt) 38 | else: 39 | trt_size.append(trt_(ctx.network, s)) 40 | 41 | trt_size = ctx.network.add_concatenation(trt_size).get_output(0) 42 | 43 | layer = ctx.network.add_fill(size, trt.FillOperation.RANDOM_UNIFORM) 44 | layer.set_input(0, trt_size) 45 | layer.set_input( 46 | 1, trt_(ctx.network, 47 | torch.tensor(0., dtype=dtype).cuda())) 48 | layer.set_input( 49 | 2, trt_(ctx.network, 50 | torch.tensor(0., dtype=dtype).cuda())) 51 | 52 | output_trt = layer.get_output(0) 
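        # RANDOM_UNIFORM with alpha == beta == 0.0 degenerates to a
        # constant zero fill; the identity layer below re-types the
        # result to the requested dtype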
53 | 54 | data_type = None 55 | if dtype == torch.float32: 56 | data_type = trt.DataType.FLOAT 57 | elif dtype == torch.int32 or dtype == torch.long: 58 | data_type = trt.DataType.INT32 59 | elif dtype == torch.bool: 60 | data_type = trt.DataType.BOOL 61 | else: 62 | print('unsupported convert type:{}'.format(dtype)) 63 | 64 | if data_type is not None: 65 | layer = ctx.network.add_identity(output_trt) 66 | layer.set_output_type(0, data_type) 67 | output_trt = layer.get_output(0) 68 | 69 | output._trt = output_trt 70 | -------------------------------------------------------------------------------- /torch2trt_dynamic/converters/zeros_like.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch2trt_dynamic.torch2trt_dynamic import get_arg, tensorrt_converter 3 | 4 | from .add import convert_add 5 | from .cast_type import convert_bool, convert_float, convert_int 6 | from .mul import convert_mul 7 | 8 | 9 | @tensorrt_converter('torch.zeros_like') 10 | def convert_zeros_like(ctx): 11 | input = ctx.method_args[0] 12 | dtype = get_arg(ctx, 'dtype', pos=1, default=torch.float32) 13 | output = ctx.method_return 14 | 15 | old_method_args = ctx.method_args 16 | old_method_kwargs = ctx.method_kwargs 17 | 18 | # mul zero 19 | input_mul_zero = input * 0 20 | ctx.method_args = [input, 0] 21 | ctx.method_kwargs = {} 22 | ctx.method_return = input_mul_zero 23 | convert_mul(ctx) 24 | 25 | convert_type_func = None 26 | if dtype == torch.float32: 27 | convert_type_func = convert_float 28 | elif dtype == torch.int32 or dtype == torch.long: 29 | convert_type_func = convert_int 30 | elif dtype == torch.bool: 31 | convert_type_func = convert_bool 32 | else: 33 | print('unsupported convert type:{}'.format(dtype)) 34 | 35 | if convert_type_func is not None: 36 | input_as_type = input_mul_zero.to(dtype) 37 | ctx.method_args = [input_mul_zero, dtype] 38 | ctx.method_return = input_as_type 39 | convert_type_func(ctx) 40 | ctx.method_args = [input_as_type, 0] 41 | ctx.method_kwargs = {} 42 | ctx.method_return = output 43 | convert_add(ctx) 44 | 45 | ctx.method_args = old_method_args 46 | ctx.method_kwargs = old_method_kwargs 47 | ctx.method_return = output 48 | -------------------------------------------------------------------------------- /torch2trt_dynamic/module_test.py: -------------------------------------------------------------------------------- 1 | class ModuleTest(object): 2 | 3 | def __init__(self, module_fn, dtype, device, input_shapes, 4 | **torch2trt_kwargs): 5 | self.module_fn = module_fn 6 | self.dtype = dtype 7 | self.device = device 8 | self.input_shapes = input_shapes 9 | self.torch2trt_kwargs = torch2trt_kwargs 10 | 11 | def module_name(self): 12 | return self.module_fn.__module__ + '.' 
+ self.module_fn.__name__ 13 | 14 | 15 | MODULE_TESTS = [] 16 | 17 | 18 | def add_module_test(dtype, device, input_shapes, **torch2trt_kwargs): 19 | 20 | def register_module_test(module): 21 | global MODULE_TESTS 22 | MODULE_TESTS += [ 23 | ModuleTest(module, dtype, device, input_shapes, **torch2trt_kwargs) 24 | ] 25 | return module 26 | 27 | return register_module_test 28 | -------------------------------------------------------------------------------- /torch2trt_dynamic/plugins/__init__.py: -------------------------------------------------------------------------------- 1 | from .create_adaptivepool_plugin import create_adaptivepool_plugin 2 | from .create_dcn_plugin import create_dcn_plugin 3 | from .create_groupnorm_plugin import create_groupnorm_plugin 4 | from .create_nms_plugin import create_nms_plugin 5 | from .create_roiextractor_plugin import create_roiextractor_plugin 6 | from .create_roipool_plugin import create_roipool_plugin 7 | from .create_torchbmm_plugin import create_torchbmm_plugin 8 | from .create_torchcum_plugin import create_torchcum_plugin 9 | from .create_torchcummaxmin_plugin import create_torchcummaxmin_plugin 10 | from .create_torchunfold_plugin import create_torchunfold_plugin 11 | from .globals import load_plugin_library 12 | 13 | __all__ = [ 14 | 'create_groupnorm_plugin', 'create_adaptivepool_plugin', 15 | 'create_torchcummaxmin_plugin', 'create_torchcum_plugin', 16 | 'create_dcn_plugin', 'create_nms_plugin', 'create_roiextractor_plugin', 17 | 'create_roipool_plugin', 'create_torchbmm_plugin', 18 | 'create_torchunfold_plugin' 19 | ] 20 | 21 | load_plugin_library() 22 | -------------------------------------------------------------------------------- /torch2trt_dynamic/plugins/create_adaptivepool_plugin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorrt as trt 3 | 4 | 5 | def create_adaptivepool_plugin(layer_name, output_size, pooling_type): 6 | 7 | creator = trt.get_plugin_registry().get_plugin_creator( 8 | 'AdaptivePoolPluginDynamic', '1', '') 9 | 10 | pfc = trt.PluginFieldCollection() 11 | 12 | pf_output_size = trt.PluginField('output_size', 13 | np.array(output_size, dtype=np.int32), 14 | trt.PluginFieldType.INT32) 15 | pfc.append(pf_output_size) 16 | 17 | pf_pooling_type = trt.PluginField( 18 | 'pooling_type', np.array([int(pooling_type)], dtype=np.int32), 19 | trt.PluginFieldType.INT32) 20 | pfc.append(pf_pooling_type) 21 | 22 | return creator.create_plugin(layer_name, pfc) 23 | -------------------------------------------------------------------------------- /torch2trt_dynamic/plugins/create_groupnorm_plugin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorrt as trt 3 | 4 | 5 | def create_groupnorm_plugin(layer_name, num_groups, eps=1e-5): 6 | 7 | creator = trt.get_plugin_registry().get_plugin_creator( 8 | 'GroupNormPluginDynamic', '1', '') 9 | 10 | pfc = trt.PluginFieldCollection() 11 | 12 | pf_num_groups = trt.PluginField('num_groups', 13 | np.array([num_groups], dtype=np.int32), 14 | trt.PluginFieldType.INT32) 15 | pfc.append(pf_num_groups) 16 | 17 | pf_eps = trt.PluginField('eps', np.array([eps], dtype=np.float32), 18 | trt.PluginFieldType.FLOAT32) 19 | pfc.append(pf_eps) 20 | 21 | return creator.create_plugin(layer_name, pfc) 22 | -------------------------------------------------------------------------------- /torch2trt_dynamic/plugins/create_nms_plugin.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorrt as trt 3 | 4 | 5 | def create_nms_plugin(layer_name, iou_threshold): 6 | 7 | creator = trt.get_plugin_registry().get_plugin_creator( 8 | 'TorchNMSPluginDynamic', '1', '') 9 | 10 | pfc = trt.PluginFieldCollection() 11 | 12 | pf_iou_threshold = trt.PluginField( 13 | 'iou_threshold', np.array([iou_threshold], dtype=np.float32), 14 | trt.PluginFieldType.FLOAT32) 15 | pfc.append(pf_iou_threshold) 16 | 17 | return creator.create_plugin(layer_name, pfc) 18 | -------------------------------------------------------------------------------- /torch2trt_dynamic/plugins/create_roiextractor_plugin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorrt as trt 3 | 4 | 5 | def create_roiextractor_plugin(layer_name, out_size, sample_num, 6 | featmap_strides, roi_scale_factor, finest_scale, 7 | aligned): 8 | 9 | creator = trt.get_plugin_registry().get_plugin_creator( 10 | 'RoiExtractorPluginDynamic', '1', '') 11 | 12 | pfc = trt.PluginFieldCollection() 13 | 14 | pf_out_size = trt.PluginField('out_size', 15 | np.array([out_size], dtype=np.int32), 16 | trt.PluginFieldType.INT32) 17 | pfc.append(pf_out_size) 18 | 19 | pf_sample_num = trt.PluginField('sample_num', 20 | np.array([sample_num], dtype=np.int32), 21 | trt.PluginFieldType.INT32) 22 | pfc.append(pf_sample_num) 23 | 24 | pf_featmap_strides = trt.PluginField( 25 | 'featmap_strides', 26 | np.array(featmap_strides).astype(np.float32), 27 | trt.PluginFieldType.FLOAT32) 28 | pfc.append(pf_featmap_strides) 29 | 30 | pf_roi_scale_factor = trt.PluginField( 31 | 'roi_scale_factor', np.array([roi_scale_factor], dtype=np.float32), 32 | trt.PluginFieldType.FLOAT32) 33 | pfc.append(pf_roi_scale_factor) 34 | 35 | pf_finest_scale = trt.PluginField('finest_scale', 36 | np.array([finest_scale], dtype=np.int32), 37 | trt.PluginFieldType.INT32) 38 | pfc.append(pf_finest_scale) 39 | 40 | pf_aligned = trt.PluginField('aligned', np.array([aligned], 41 | dtype=np.int32), 42 | trt.PluginFieldType.INT32) 43 | pfc.append(pf_aligned) 44 | 45 | return creator.create_plugin(layer_name, pfc) 46 | -------------------------------------------------------------------------------- /torch2trt_dynamic/plugins/create_roipool_plugin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorrt as trt 3 | 4 | 5 | def create_roipool_plugin(layer_name, out_size, featmap_strides, 6 | roi_scale_factor, finest_scale): 7 | 8 | creator = trt.get_plugin_registry().get_plugin_creator( 9 | 'RoiPoolPluginDynamic', '1', '') 10 | 11 | pfc = trt.PluginFieldCollection() 12 | 13 | pf_out_size = trt.PluginField('out_size', 14 | np.array([out_size], dtype=np.int32), 15 | trt.PluginFieldType.INT32) 16 | pfc.append(pf_out_size) 17 | 18 | pf_featmap_strides = trt.PluginField( 19 | 'featmap_strides', 20 | np.array(featmap_strides).astype(np.float32), 21 | trt.PluginFieldType.FLOAT32) 22 | pfc.append(pf_featmap_strides) 23 | 24 | pf_roi_scale_factor = trt.PluginField( 25 | 'roi_scale_factor', np.array([roi_scale_factor], dtype=np.float32), 26 | trt.PluginFieldType.FLOAT32) 27 | pfc.append(pf_roi_scale_factor) 28 | 29 | pf_finest_scale = trt.PluginField('finest_scale', 30 | np.array([finest_scale], dtype=np.int32), 31 | trt.PluginFieldType.INT32) 32 | pfc.append(pf_finest_scale) 33 | 34 | return creator.create_plugin(layer_name, pfc) 35 | 
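All of these `create_*_plugin` helpers follow the same pattern: look up a creator registered by `libamirstan_plugin.so`, pack the parameters into a `trt.PluginFieldCollection`, and return a plugin ready for `add_plugin_v2`. A hedged usage sketch for the RoI pool plugin (the layer name and stride are illustrative, and `network`, `boxes_trt`, `feat_trt` stand in for an existing `INetworkDefinition` and its tensors):

```python
# importing the package loads libamirstan_plugin.so, which registers
# 'RoiPoolPluginDynamic' with the TensorRT plugin registry
from torch2trt_dynamic.plugins import create_roipool_plugin

plugin = create_roipool_plugin(
    'roi_pool_example',
    out_size=7,
    featmap_strides=[1.0 / 16],
    roi_scale_factor=-1,
    finest_scale=56)

# mirror the input order used by convert_roi_pool: boxes first, then features
# layer = network.add_plugin_v2(inputs=[boxes_trt, feat_trt], plugin=plugin)
# roi_feats_trt = layer.get_output(0)
```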
-------------------------------------------------------------------------------- /torch2trt_dynamic/plugins/create_torchbmm_plugin.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | 3 | 4 | def create_torchbmm_plugin(layer_name): 5 | 6 | creator = trt.get_plugin_registry().get_plugin_creator( 7 | 'TorchBmmPluginDynamic', '1', '') 8 | 9 | pfc = trt.PluginFieldCollection() 10 | 11 | return creator.create_plugin(layer_name, pfc) 12 | -------------------------------------------------------------------------------- /torch2trt_dynamic/plugins/create_torchcum_plugin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorrt as trt 3 | 4 | 5 | def create_torchcum_plugin(layer_name, dim, cum_type): 6 | 7 | creator = trt.get_plugin_registry().get_plugin_creator( 8 | 'TorchCumPluginDynamic', '1', '') 9 | 10 | pfc = trt.PluginFieldCollection() 11 | 12 | pf_dim = trt.PluginField('dim', np.array([dim], dtype=np.int32), 13 | trt.PluginFieldType.INT32) 14 | pfc.append(pf_dim) 15 | 16 | pf_cum_type = trt.PluginField('cum_type', 17 | np.array([cum_type], dtype=np.int32), 18 | trt.PluginFieldType.INT32) 19 | pfc.append(pf_cum_type) 20 | 21 | return creator.create_plugin(layer_name, pfc) 22 | -------------------------------------------------------------------------------- /torch2trt_dynamic/plugins/create_torchcummaxmin_plugin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorrt as trt 3 | 4 | 5 | def create_torchcummaxmin_plugin(layer_name, dim, cum_type): 6 | 7 | creator = trt.get_plugin_registry().get_plugin_creator( 8 | 'TorchCumMaxMinPluginDynamic', '1', '') 9 | 10 | pfc = trt.PluginFieldCollection() 11 | 12 | pf_dim = trt.PluginField('dim', np.array([dim], dtype=np.int32), 13 | trt.PluginFieldType.INT32) 14 | pfc.append(pf_dim) 15 | 16 | pf_cum_type = trt.PluginField('cum_type', 17 | np.array([cum_type], dtype=np.int32), 18 | trt.PluginFieldType.INT32) 19 | pfc.append(pf_cum_type) 20 | 21 | return creator.create_plugin(layer_name, pfc) 22 | -------------------------------------------------------------------------------- /torch2trt_dynamic/plugins/create_torchunfold_plugin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorrt as trt 3 | 4 | 5 | def create_torchunfold_plugin(layer_name, kernel_size, dilation, padding, 6 | stride): 7 | 8 | creator = trt.get_plugin_registry().get_plugin_creator( 9 | 'TorchUnfoldPluginDynamic', '1', '') 10 | 11 | pfc = trt.PluginFieldCollection() 12 | 13 | if isinstance(kernel_size, int): 14 | kernel_size = (kernel_size, kernel_size) 15 | pf_kernel_size = trt.PluginField('kernel_size', 16 | np.array(kernel_size, dtype=np.int32), 17 | trt.PluginFieldType.INT32) 18 | pfc.append(pf_kernel_size) 19 | 20 | if isinstance(dilation, int): 21 | dilation = (dilation, dilation) 22 | pf_dilation = trt.PluginField('dilation', 23 | np.array(dilation, dtype=np.int32), 24 | trt.PluginFieldType.INT32) 25 | pfc.append(pf_dilation) 26 | 27 | if isinstance(padding, int): 28 | padding = (padding, padding) 29 | pf_padding = trt.PluginField('padding', np.array(padding, dtype=np.int32), 30 | trt.PluginFieldType.INT32) 31 | pfc.append(pf_padding) 32 | 33 | if isinstance(stride, int): 34 | stride = (stride, stride) 35 | pf_stride = trt.PluginField('stride', np.array(stride, dtype=np.int32), 36 | trt.PluginFieldType.INT32) 
37 |     pfc.append(pf_stride)
38 | 
39 |     return creator.create_plugin(layer_name, pfc)
40 | 
--------------------------------------------------------------------------------
/torch2trt_dynamic/plugins/globals.py:
--------------------------------------------------------------------------------
 1 | import ctypes
 2 | import os
 3 | import os.path as osp
 4 | 
 5 | dir_path = osp.join(os.path.expanduser('~'), 'space/trt_plugin/build/lib/')
 6 | 
 7 | if not osp.exists(dir_path):
 8 |     if 'AMIRSTAN_LIBRARY_PATH' in os.environ:
 9 |         dir_path = os.environ['AMIRSTAN_LIBRARY_PATH']
10 |     else:
11 |         dir_path = os.path.dirname(os.path.realpath(__file__))
12 | 
13 | 
14 | def load_plugin_library():
15 |     ctypes.CDLL(osp.join(dir_path, 'libamirstan_plugin.so'))
16 | 
--------------------------------------------------------------------------------
/torch2trt_dynamic/shape_converter.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | def get_tensor_shape(self):
 5 |     return self.size()
 6 | 
 7 | 
 8 | old_get_attribute = torch.Tensor.__getattribute__
 9 | 
10 | 
11 | def new_getattribute__(self, name):
12 |     if name == 'shape':
13 |         return get_tensor_shape(self)
14 |     else:
15 |         return old_get_attribute(self, name)
16 | 
17 | 
18 | class ShapeConverter:
19 | 
20 |     def __init__(self):
21 |         pass
22 | 
23 |     def __enter__(self):
24 |         torch.Tensor.__getattribute__ = new_getattribute__
25 | 
26 |     def __exit__(self, type, val, tb):
27 |         torch.Tensor.__getattribute__ = old_get_attribute
28 | 
--------------------------------------------------------------------------------
/torch2trt_dynamic/torch_allocator.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) OpenMMLab. All rights reserved.
 2 | import tensorrt as trt
 3 | import torch
 4 | 
 5 | 
 6 | class TorchAllocator(trt.IGpuAllocator):
 7 |     """PyTorch CUDA allocator wrapper."""
 8 | 
 9 |     def __init__(self, device_id: int = None) -> None:
10 |         super().__init__()
11 | 
12 |         self.device_id = device_id
13 |         self.mems = set()
14 |         self.caching_delete = torch._C._cuda_cudaCachingAllocator_raw_delete
15 | 
16 |     def __del__(self):
17 |         """destructor."""
18 |         for mem in self.mems.copy():
19 |             self.deallocate(mem)
20 | 
21 |     def allocate(self: trt.IGpuAllocator, size: int, alignment: int,
22 |                  flags: int) -> int:
23 |         """allocate gpu memory.
24 | 
25 |         Args:
26 |             self (trt.IGpuAllocator): gpu allocator
27 |             size (int): memory size.
28 |             alignment (int): memory alignment.
29 |             flags (int): flags.
30 | 
31 |         Returns:
32 |             int: memory address.
33 |         """
34 |         torch_stream = torch.cuda.current_stream(self.device_id)
35 |         assert alignment >= 0
36 |         if alignment > 0:
37 |             size = (size | (alignment - 1)) + 1
38 |         mem = torch.cuda.caching_allocator_alloc(
39 |             size, device=self.device_id, stream=torch_stream)
40 |         self.mems.add(mem)
41 |         return mem
42 | 
43 |     def deallocate(self: trt.IGpuAllocator, memory: int) -> bool:
44 |         """deallocate memory.
45 | 
46 |         Args:
47 |             self (trt.IGpuAllocator): gpu allocator
48 |             memory (int): memory address.
49 | 
50 |         Returns:
51 |             bool: deallocate success.
52 | """ 53 | if memory not in self.mems: 54 | return False 55 | 56 | self.caching_delete(memory) 57 | self.mems.discard(memory) 58 | return True 59 | -------------------------------------------------------------------------------- /torch2trt_dynamic/utils.py: -------------------------------------------------------------------------------- 1 | import graphviz 2 | 3 | 4 | def trt_network_to_dot_graph(network): 5 | dot = graphviz.Digraph(comment='Network') 6 | 7 | # add nodes (layers) 8 | for i in range(network.num_layers): 9 | layer = network.get_layer(i) 10 | dot.node(layer.name) 11 | 12 | # add nodes (inputs) 13 | for i in range(network.num_inputs): 14 | dot.node(network.get_input(i).name) 15 | 16 | # add nodes (outputs) 17 | for i in range(network.num_outputs): 18 | dot.node(network.get_output(i).name) 19 | 20 | # add layer->layer edges 21 | for a in range(network.num_layers): 22 | layer_a = network.get_layer(a) 23 | 24 | for b in range(network.num_layers): 25 | layer_b = network.get_layer(b) 26 | 27 | for i in range(layer_a.num_outputs): 28 | output_i = layer_a.get_output(i) 29 | 30 | for j in range(layer_b.num_inputs): 31 | input_j = layer_b.get_input(j) 32 | 33 | if output_i == input_j: 34 | dot.edge( 35 | layer_a.name, 36 | layer_b.name, 37 | label=str(input_j.shape)) 38 | 39 | # add input->layer edges 40 | for i in range(network.num_inputs): 41 | input_i = network.get_input(i) 42 | 43 | for b in range(network.num_layers): 44 | layer_b = network.get_layer(b) 45 | 46 | for j in range(layer_b.num_inputs): 47 | input_j = layer_b.get_input(j) 48 | 49 | if input_i == input_j: 50 | dot.edge( 51 | input_i.name, layer_b.name, label=str(input_j.shape)) 52 | 53 | # add layer->output edges 54 | for i in range(network.num_outputs): 55 | input_i = network.get_output(i) 56 | 57 | for b in range(network.num_layers): 58 | layer_b = network.get_layer(b) 59 | 60 | for j in range(layer_b.num_outputs): 61 | input_j = layer_b.get_output(j) 62 | 63 | if input_i == input_j: 64 | dot.edge( 65 | layer_b.name, input_i.name, label=str(input_j.shape)) 66 | 67 | return dot 68 | --------------------------------------------------------------------------------