├── .github ├── CODEOWNERS └── workflows │ └── lint.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── LICENSE ├── MANIFEST.in ├── README.md ├── assets └── logo │ ├── onnx2torch_dark.png │ └── onnx2torch_light.png ├── onnx2torch ├── __init__.py ├── converter.py ├── node_converters │ ├── __init__.py │ ├── activations.py │ ├── arg_extrema.py │ ├── average_pool.py │ ├── base_element_wise.py │ ├── batch_norm.py │ ├── binary_math_operations.py │ ├── cast.py │ ├── clip.py │ ├── comparisons.py │ ├── concat.py │ ├── constant.py │ ├── constant_of_shape.py │ ├── conv.py │ ├── cumsum.py │ ├── depth_to_space.py │ ├── dropout.py │ ├── einsum.py │ ├── expand.py │ ├── eye_like.py │ ├── flatten.py │ ├── functions.py │ ├── gather.py │ ├── gemm.py │ ├── global_average_pool.py │ ├── identity.py │ ├── instance_norm.py │ ├── isinf.py │ ├── isnan.py │ ├── layer_norm.py │ ├── logical.py │ ├── lrn.py │ ├── matmul.py │ ├── max_pool.py │ ├── mean.py │ ├── min_max.py │ ├── mod.py │ ├── neg.py │ ├── nms.py │ ├── nonzero.py │ ├── pad.py │ ├── pow.py │ ├── range.py │ ├── reciprocal.py │ ├── reduce.py │ ├── registry.py │ ├── reshape.py │ ├── resize.py │ ├── roialign.py │ ├── roundings.py │ ├── scatter_nd.py │ ├── shape.py │ ├── slice.py │ ├── split.py │ ├── squeeze.py │ ├── sum.py │ ├── tile.py │ ├── topk.py │ ├── transpose.py │ ├── unsqueeze.py │ └── where.py ├── onnx_graph.py ├── onnx_node.py ├── onnx_tensor.py └── utils │ ├── __init__.py │ ├── common.py │ ├── custom_export_to_onnx.py │ ├── dtype.py │ ├── indices.py │ ├── padding.py │ └── safe_shape_inference.py ├── operators.md ├── pyproject.toml └── tests ├── __init__.py ├── models ├── __init__.py └── models_test.py ├── node_converters ├── __init__.py ├── activations_test.py ├── arg_extrema_test.py ├── average_pool_max_pool_test.py ├── batch_norm_test.py ├── binary_operations_test.py ├── clip_test.py ├── comparisons_test.py ├── concat_test.py ├── constant_of_shape_test.py ├── constant_test.py ├── conv_test.py ├── 
cumsum_test.py ├── depth_to_space_test.py ├── dropout_test.py ├── einsum_test.py ├── expand_test.py ├── eye_like_test.py ├── flatten_test.py ├── gather_test.py ├── gemm_test.py ├── global_avg_pool_test.py ├── instance_norm_test.py ├── layer_norm_test.py ├── logical_test.py ├── lrn_test.py ├── matmul_test.py ├── mean_test.py ├── min_max_test.py ├── mod_test.py ├── neg_test.py ├── nms_test.py ├── pad_test.py ├── pow_test.py ├── range_test.py ├── reciprocal_test.py ├── reduce_test.py ├── reshape_test.py ├── resize_test.py ├── roialign_test.py ├── scatter_nd_test.py ├── shape_test.py ├── slice_test.py ├── split_test.py ├── squeeze_test.py ├── sum_test.py ├── test_functions.py ├── tile_test.py ├── topk_test.py ├── transpose_test.py ├── unsqueeze_test.py └── where_test.py ├── pytest.ini └── utils ├── __init__.py ├── common.py └── resources.py /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Users referenced in this file will automatically be requested as reviewers for PRs that modify the given paths. 
2 | # See https://help.github.com/articles/about-code-owners/ 3 | 4 | # Code 5 | /onnx2torch @ivkalgin @senysenyseny16 6 | /tests @ivkalgin @senysenyseny16 7 | 8 | # Actions 9 | /.github @senysenyseny16 10 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | lint-python: 11 | name: Pylint 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | - uses: actions/setup-python@v3 16 | with: 17 | python-version: '3.9' 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install pylint 22 | python -m pip install -e .[dev] 23 | - name: Analysing the code with pylint 24 | run: | 25 | pylint --output-format=colorized $(git ls-files '*.py') 26 | 27 | lint-python-format: 28 | name: Python format 29 | runs-on: ubuntu-latest 30 | steps: 31 | - uses: actions/checkout@v3 32 | - uses: actions/setup-python@v3 33 | with: 34 | python-version: '3.9' 35 | - uses: psf/black@stable 36 | with: 37 | options: --check --diff 38 | - uses: isort/isort-action@master 39 | with: 40 | configuration: --check --diff 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | 
# PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # VS Code project settings 117 | .vscode 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | # Idea 131 | .idea/ 132 | trash* 133 | *.onnx 134 | 135 | tests/.tmp 136 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_install_hook_types: [commit-msg, pre-commit, pre-push] 2 | 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.5.0 6 | hooks: 7 | - id: check-yaml 8 | - id: check-toml 9 | - id: check-json 10 | - id: end-of-file-fixer 11 | - id: trailing-whitespace 12 | - id: check-added-large-files 13 | - id: check-case-conflict 14 | - id: check-merge-conflict 15 | - id: detect-private-key 16 | - id: end-of-file-fixer 17 | - id: debug-statements 18 | - id: detect-private-key 19 | - id: detect-aws-credentials 20 | args: [--allow-missing-credentials] 21 | - id: no-commit-to-branch 22 | args: [-b=main] 23 | 24 | - repo: https://github.com/commitizen-tools/commitizen 25 | rev: v3.20.0 26 | hooks: 27 | - id: commitizen 28 | 29 | - repo: https://github.com/gitleaks/gitleaks 30 | rev: v8.18.2 31 | hooks: 32 | - id: gitleaks 33 | 34 | - repo: https://github.com/executablebooks/mdformat 35 | rev: 0.7.17 36 | hooks: 37 | - id: mdformat 38 | additional_dependencies: 39 | - mdformat-gfm 40 | - mdformat-black 41 | - mdformat-shfmt 42 | 43 | - repo: https://github.com/lyz-code/yamlfix 44 | rev: 
1.16.0 45 | hooks: 46 | - id: yamlfix 47 | 48 | - repo: https://github.com/adrienverge/yamllint.git 49 | rev: v1.35.1 50 | hooks: 51 | - id: yamllint 52 | args: 53 | - --format 54 | - parsable 55 | - --strict 56 | - -d 57 | - '{extends: relaxed, rules: {line-length: {max: 120}}}' 58 | 59 | - repo: https://github.com/psf/black 60 | rev: 24.3.0 61 | hooks: 62 | - id: black 63 | 64 | - repo: https://github.com/PyCQA/isort 65 | rev: 5.13.2 66 | hooks: 67 | - id: isort 68 | 69 | - repo: https://github.com/PyCQA/pylint 70 | rev: v3.1.0 71 | hooks: 72 | - id: pylint 73 | language: system 74 | args: [-rn, -sn] 75 | 76 | - repo: https://github.com/RobertCraigie/pyright-python 77 | rev: v1.1.356 78 | hooks: 79 | - id: pyright 80 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: onnx2torch 3 | message: "Please use this information to cite onnx2torch in research or other publications." 
4 | authors: 5 | - affiliation: ENOT LLC 6 | given-names: ENOT developers 7 | - family-names: Kalgin 8 | given-names: Igor 9 | - family-names: Yanchenko 10 | given-names: Arseny 11 | - family-names: Ivanov 12 | given-names: Pyoter 13 | - family-names: Goncharenko 14 | given-names: Alexander 15 | date-released: 2021-12-14 16 | url: "https://enot.ai" 17 | repository-code: "https://github.com/ENOT-AutoDL/onnx2torch" 18 | license: "Apache-2.0" 19 | keywords: 20 | - onnx 21 | - pytorch 22 | - convert 23 | - deep learning 24 | - machine learning 25 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-exclude tests * 2 | -------------------------------------------------------------------------------- /assets/logo/onnx2torch_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ENOT-AutoDL/onnx2torch/369412ad62c81ca5b360554572820755e31b9b7a/assets/logo/onnx2torch_dark.png -------------------------------------------------------------------------------- /assets/logo/onnx2torch_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ENOT-AutoDL/onnx2torch/369412ad62c81ca5b360554572820755e31b9b7a/assets/logo/onnx2torch_light.png -------------------------------------------------------------------------------- /onnx2torch/__init__.py: -------------------------------------------------------------------------------- 1 | from onnx2torch.converter import convert 2 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/__init__.py: -------------------------------------------------------------------------------- 1 | from onnx2torch.node_converters.activations import * 2 | from onnx2torch.node_converters.arg_extrema import * 3 | from 
onnx2torch.node_converters.average_pool import * 4 | from onnx2torch.node_converters.batch_norm import * 5 | from onnx2torch.node_converters.binary_math_operations import * 6 | from onnx2torch.node_converters.cast import * 7 | from onnx2torch.node_converters.clip import * 8 | from onnx2torch.node_converters.comparisons import * 9 | from onnx2torch.node_converters.concat import * 10 | from onnx2torch.node_converters.constant import * 11 | from onnx2torch.node_converters.constant_of_shape import * 12 | from onnx2torch.node_converters.conv import * 13 | from onnx2torch.node_converters.cumsum import * 14 | from onnx2torch.node_converters.depth_to_space import * 15 | from onnx2torch.node_converters.dropout import * 16 | from onnx2torch.node_converters.einsum import * 17 | from onnx2torch.node_converters.expand import * 18 | from onnx2torch.node_converters.eye_like import * 19 | from onnx2torch.node_converters.flatten import * 20 | from onnx2torch.node_converters.functions import * 21 | from onnx2torch.node_converters.gather import * 22 | from onnx2torch.node_converters.gemm import * 23 | from onnx2torch.node_converters.global_average_pool import * 24 | from onnx2torch.node_converters.identity import * 25 | from onnx2torch.node_converters.instance_norm import * 26 | from onnx2torch.node_converters.isinf import * 27 | from onnx2torch.node_converters.isnan import * 28 | from onnx2torch.node_converters.layer_norm import * 29 | from onnx2torch.node_converters.logical import * 30 | from onnx2torch.node_converters.lrn import * 31 | from onnx2torch.node_converters.matmul import * 32 | from onnx2torch.node_converters.max_pool import * 33 | from onnx2torch.node_converters.mean import * 34 | from onnx2torch.node_converters.min_max import * 35 | from onnx2torch.node_converters.mod import * 36 | from onnx2torch.node_converters.neg import * 37 | from onnx2torch.node_converters.nms import * 38 | from onnx2torch.node_converters.nonzero import * 39 | from onnx2torch.node_converters.pad 
import * 40 | from onnx2torch.node_converters.pow import * 41 | from onnx2torch.node_converters.range import * 42 | from onnx2torch.node_converters.reciprocal import * 43 | from onnx2torch.node_converters.reduce import * 44 | from onnx2torch.node_converters.registry import OperationDescription 45 | from onnx2torch.node_converters.registry import TConverter 46 | from onnx2torch.node_converters.registry import get_converter 47 | from onnx2torch.node_converters.reshape import * 48 | from onnx2torch.node_converters.resize import * 49 | from onnx2torch.node_converters.roialign import * 50 | from onnx2torch.node_converters.roundings import * 51 | from onnx2torch.node_converters.scatter_nd import * 52 | from onnx2torch.node_converters.shape import * 53 | from onnx2torch.node_converters.slice import * 54 | from onnx2torch.node_converters.split import * 55 | from onnx2torch.node_converters.squeeze import * 56 | from onnx2torch.node_converters.sum import * 57 | from onnx2torch.node_converters.tile import * 58 | from onnx2torch.node_converters.topk import * 59 | from onnx2torch.node_converters.transpose import * 60 | from onnx2torch.node_converters.unsqueeze import * 61 | from onnx2torch.node_converters.where import * 62 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/arg_extrema.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-docstring 2 | __all__ = [ 3 | 'OnnxArgExtremumOld', 4 | 'OnnxArgExtremum', 5 | ] 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from onnx2torch.node_converters.registry import add_converter 11 | from onnx2torch.onnx_graph import OnnxGraph 12 | from onnx2torch.onnx_node import OnnxNode 13 | from onnx2torch.utils.common import OnnxToTorchModule 14 | from onnx2torch.utils.common import OperationConverterResult 15 | from onnx2torch.utils.common import onnx_mapping_from_node 16 | 17 | DEFAULT_AXIS = 0 18 | DEFAULT_KEEPDIMS 
= 1 19 | DEFAULT_SELECT_LAST_INDEX = 0 20 | 21 | _TORCH_FUNCTION_FROM_ONNX_TYPE = { 22 | 'ArgMax': torch.argmax, 23 | 'ArgMin': torch.argmin, 24 | } 25 | 26 | 27 | class OnnxArgExtremumOld(nn.Module, OnnxToTorchModule): 28 | def __init__(self, operation_type: str, axis: int, keepdims: int): 29 | super().__init__() 30 | self.axis = axis 31 | self.keepdims = bool(keepdims) 32 | self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type] 33 | 34 | def forward(self, data: torch.Tensor) -> torch.Tensor: 35 | return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims) 36 | 37 | 38 | class OnnxArgExtremum(nn.Module, OnnxToTorchModule): 39 | def __init__(self, operation_type: str, axis: int, keepdims: int, select_last_index: int): 40 | super().__init__() 41 | self.axis = axis 42 | self.keepdims = bool(keepdims) 43 | self.select_last_index = bool(select_last_index) 44 | self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type] 45 | 46 | def forward(self, data: torch.Tensor) -> torch.Tensor: 47 | if self.select_last_index: 48 | # torch's argmax does not handle the select_last_index attribute from Onnx. 
49 | # We flip the data, call the normal argmax, then map it back to the original 50 | flipped = torch.flip(data, dims=[self.axis]) 51 | 52 | extremum_index_flipped = self.extremum_function(flipped, dim=self.axis, keepdim=self.keepdims) 53 | extremum_index_original = data.size(dim=self.axis) - 1 - extremum_index_flipped 54 | return extremum_index_original 55 | 56 | return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims) 57 | 58 | 59 | @add_converter(operation_type='ArgMax', version=12) 60 | @add_converter(operation_type='ArgMax', version=13) 61 | @add_converter(operation_type='ArgMin', version=12) 62 | @add_converter(operation_type='ArgMin', version=13) 63 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 64 | del graph 65 | return OperationConverterResult( 66 | torch_module=OnnxArgExtremum( 67 | operation_type=node.operation_type, 68 | axis=node.attributes.get('axis', DEFAULT_AXIS), 69 | keepdims=node.attributes.get('keepdims', DEFAULT_KEEPDIMS), 70 | select_last_index=node.attributes.get('select_last_index', DEFAULT_SELECT_LAST_INDEX), 71 | ), 72 | onnx_mapping=onnx_mapping_from_node(node=node), 73 | ) 74 | 75 | 76 | @add_converter(operation_type='ArgMax', version=11) 77 | @add_converter(operation_type='ArgMin', version=11) 78 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 79 | del graph 80 | return OperationConverterResult( 81 | torch_module=OnnxArgExtremumOld( 82 | operation_type=node.operation_type, 83 | axis=node.attributes.get('axis', DEFAULT_AXIS), 84 | keepdims=node.attributes.get('keepdims', DEFAULT_KEEPDIMS), 85 | ), 86 | onnx_mapping=onnx_mapping_from_node(node=node), 87 | ) 88 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/average_pool.py: -------------------------------------------------------------------------------- 1 | __all__ = [] 2 | 3 | from torch import nn 4 | 5 | from onnx2torch.node_converters.registry import 
add_converter 6 | from onnx2torch.onnx_graph import OnnxGraph 7 | from onnx2torch.onnx_node import OnnxNode 8 | from onnx2torch.utils.common import OperationConverterResult 9 | from onnx2torch.utils.common import get_shape_from_value_info 10 | from onnx2torch.utils.common import onnx_mapping_from_node 11 | from onnx2torch.utils.padding import onnx_auto_pad_to_torch_padding 12 | 13 | _AVGPOOL_CLASS_FROM_SPATIAL_RANK = { 14 | 1: nn.AvgPool1d, 15 | 2: nn.AvgPool2d, 16 | 3: nn.AvgPool3d, 17 | } 18 | 19 | 20 | @add_converter(operation_type='AveragePool', version=7) 21 | @add_converter(operation_type='AveragePool', version=10) 22 | @add_converter(operation_type='AveragePool', version=11) 23 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 24 | input_value_info = graph.value_info[node.input_values[0]] 25 | input_shape = get_shape_from_value_info(input_value_info) 26 | 27 | spatial_rank = len(input_shape) - 2 28 | try: 29 | avgpool_class = _AVGPOOL_CLASS_FROM_SPATIAL_RANK[spatial_rank] 30 | except KeyError as exc: 31 | raise NotImplementedError( 32 | f'Average pool operation with spatial rank == {spatial_rank} is not implemented' 33 | ) from exc 34 | 35 | node_attributes = node.attributes 36 | # required 37 | kernel_shape = node_attributes['kernel_shape'] 38 | # optional 39 | ceil_mode = node_attributes.get('ceil_mode', 0) 40 | strides = node_attributes.get('strides', 1) 41 | count_include_pad = node_attributes.get('count_include_pad', 0) 42 | 43 | padding, padding_module = onnx_auto_pad_to_torch_padding( 44 | onnx_padding=node_attributes.get('pads', [0] * spatial_rank * 2), 45 | auto_pad=node_attributes.get('auto_pad', 'NOTSET'), 46 | ) 47 | if padding_module is not None: 48 | raise NotImplementedError('AvgPool with non symmetrical padding is not implemented.') 49 | 50 | torch_module = avgpool_class( 51 | kernel_size=kernel_shape, 52 | stride=strides, 53 | padding=padding, 54 | count_include_pad=count_include_pad == 1, 55 | ceil_mode=ceil_mode == 1, 
56 | ) 57 | 58 | return OperationConverterResult( 59 | torch_module=torch_module, 60 | onnx_mapping=onnx_mapping_from_node(node=node), 61 | ) 62 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/base_element_wise.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-docstring 2 | import torch 3 | from torch import nn 4 | 5 | from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx 6 | from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport 7 | 8 | 9 | class OnnxBaseElementWise(nn.Module, OnnxToTorchModuleWithCustomExport): 10 | def __init__(self, op_type: str): 11 | super().__init__() 12 | self._op_type = op_type 13 | 14 | @staticmethod 15 | def _broadcast_shape(*tensors: torch.Tensor): 16 | shapes = [t.shape for t in tensors] 17 | broadcast_shape = torch.broadcast_shapes(*shapes) 18 | return broadcast_shape 19 | 20 | def apply_reduction(self, *tensors: torch.Tensor) -> torch.Tensor: 21 | del tensors 22 | raise NotImplementedError 23 | 24 | def forward(self, *input_tensors: torch.Tensor) -> torch.Tensor: 25 | if len(input_tensors) == 1: 26 | # If there is a single element, return it (no op). 27 | # Also, no need for manually building the ONNX node. 
28 | return input_tensors[0] 29 | 30 | def _forward() -> torch.Tensor: 31 | return self.apply_reduction(*input_tensors) 32 | 33 | if torch.onnx.is_in_onnx_export(): 34 | return DefaultExportToOnnx.export(_forward, self._op_type, *input_tensors, {}) 35 | 36 | return _forward() 37 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/binary_math_operations.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxBinaryMathOperation', 3 | ] 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from onnx2torch.node_converters.registry import add_converter 11 | from onnx2torch.onnx_graph import OnnxGraph 12 | from onnx2torch.onnx_node import OnnxNode 13 | from onnx2torch.utils.common import OnnxToTorchModule 14 | from onnx2torch.utils.common import OperationConverterResult 15 | from onnx2torch.utils.common import old_style_broadcast 16 | from onnx2torch.utils.common import onnx_mapping_from_node 17 | 18 | 19 | def _onnx_div(first: torch.Tensor, second: torch.Tensor) -> torch.Tensor: 20 | if first.is_floating_point() or second.is_floating_point(): # float division 21 | return torch.div(first, second) 22 | 23 | return torch.div(first, second, rounding_mode='trunc') # integer division 24 | 25 | 26 | _TORCH_FUNCTION_FROM_ONNX_TYPE = { 27 | 'Add': torch.add, 28 | 'Sub': torch.sub, 29 | 'Mul': torch.mul, 30 | 'Div': _onnx_div, 31 | } 32 | 33 | 34 | class OnnxBinaryMathOperation(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 35 | def __init__(self, operation_type: str, broadcast: Optional[int] = None, axis: Optional[int] = None): 36 | super().__init__() 37 | 38 | self.broadcast = broadcast 39 | self.axis = axis 40 | self.math_op_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type] 41 | 42 | def forward( # pylint: disable=missing-function-docstring 43 | self, 44 | first: torch.Tensor, 45 | second: 
torch.Tensor, 46 | ) -> torch.Tensor: 47 | if self.broadcast == 1 and self.axis is not None: 48 | second = old_style_broadcast(first, second, self.axis) 49 | 50 | return self.math_op_function(first, second) 51 | 52 | 53 | @add_converter(operation_type='Add', version=1) 54 | @add_converter(operation_type='Add', version=6) 55 | @add_converter(operation_type='Add', version=7) 56 | @add_converter(operation_type='Add', version=13) 57 | @add_converter(operation_type='Add', version=14) 58 | @add_converter(operation_type='Sub', version=1) 59 | @add_converter(operation_type='Sub', version=6) 60 | @add_converter(operation_type='Sub', version=7) 61 | @add_converter(operation_type='Sub', version=13) 62 | @add_converter(operation_type='Sub', version=14) 63 | @add_converter(operation_type='Mul', version=1) 64 | @add_converter(operation_type='Mul', version=6) 65 | @add_converter(operation_type='Mul', version=7) 66 | @add_converter(operation_type='Mul', version=13) 67 | @add_converter(operation_type='Mul', version=14) 68 | @add_converter(operation_type='Div', version=1) 69 | @add_converter(operation_type='Div', version=6) 70 | @add_converter(operation_type='Div', version=7) 71 | @add_converter(operation_type='Div', version=13) 72 | @add_converter(operation_type='Div', version=14) 73 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 74 | return OperationConverterResult( 75 | torch_module=OnnxBinaryMathOperation( 76 | operation_type=node.operation_type, 77 | broadcast=node.attributes.get('broadcast', None), 78 | axis=node.attributes.get('axis', None), 79 | ), 80 | onnx_mapping=onnx_mapping_from_node(node=node), 81 | ) 82 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/cast.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxCast', 3 | ] 4 | 5 | import torch 6 | from onnx import TensorProto # pylint: 
disable=no-name-in-module 7 | from torch import nn 8 | 9 | from onnx2torch.node_converters.registry import add_converter 10 | from onnx2torch.onnx_graph import OnnxGraph 11 | from onnx2torch.onnx_node import OnnxNode 12 | from onnx2torch.utils.common import OnnxToTorchModule 13 | from onnx2torch.utils.common import OperationConverterResult 14 | from onnx2torch.utils.common import onnx_mapping_from_node 15 | 16 | # pylint: disable=no-member 17 | TENSOR_TYPE_TO_TORCH_TYPE = { 18 | int(TensorProto.FLOAT): torch.float32, 19 | int(TensorProto.UINT8): torch.uint8, 20 | int(TensorProto.INT8): torch.int8, 21 | int(TensorProto.INT16): torch.int16, 22 | int(TensorProto.INT32): torch.int32, 23 | int(TensorProto.INT64): torch.int64, 24 | int(TensorProto.BOOL): torch.bool, 25 | int(TensorProto.FLOAT16): torch.float16, 26 | int(TensorProto.DOUBLE): torch.float64, 27 | int(TensorProto.COMPLEX64): torch.complex64, 28 | int(TensorProto.COMPLEX128): torch.complex128, 29 | } 30 | # pylint: enable=no-member 31 | 32 | 33 | class OnnxCast(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 34 | def __init__(self, onnx_dtype: int): 35 | super().__init__() 36 | try: 37 | self.torch_dtype = TENSOR_TYPE_TO_TORCH_TYPE[onnx_dtype] 38 | except KeyError as exc: 39 | raise NotImplementedError(f'Conversion to "{onnx_dtype}" is not implemented') from exc 40 | 41 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 42 | return input_tensor.to(self.torch_dtype) 43 | 44 | 45 | @add_converter(operation_type='Cast', version=9) 46 | @add_converter(operation_type='Cast', version=13) 47 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 48 | node_attributes = node.attributes 49 | onnx_dtype = node_attributes.get('to', None) 50 | 51 | return OperationConverterResult( 52 | torch_module=OnnxCast(onnx_dtype), 53 | onnx_mapping=onnx_mapping_from_node(node=node), 54 | ) 55 | 
-------------------------------------------------------------------------------- /onnx2torch/node_converters/clip.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxClip', 3 | ] 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | from torch import nn 9 | from torch.types import Number 10 | 11 | from onnx2torch.node_converters.registry import add_converter 12 | from onnx2torch.onnx_graph import OnnxGraph 13 | from onnx2torch.onnx_node import OnnxNode 14 | from onnx2torch.utils.common import OnnxMapping 15 | from onnx2torch.utils.common import OnnxToTorchModule 16 | from onnx2torch.utils.common import OperationConverterResult 17 | from onnx2torch.utils.common import get_const_value 18 | from onnx2torch.utils.common import onnx_mapping_from_node 19 | 20 | 21 | class OnnxClip(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 22 | def __init__( 23 | self, 24 | min_val: Optional[Number] = None, 25 | max_val: Optional[Number] = None, 26 | ): 27 | super().__init__() 28 | self.min_val = min_val 29 | self.max_val = max_val 30 | 31 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 32 | return torch.clamp(input_tensor, self.min_val, self.max_val) 33 | 34 | 35 | def _create_torch_module(min_val: Optional[torch.Tensor], max_val: Optional[torch.Tensor]) -> nn.Module: 36 | if min_val is None and max_val is None: 37 | torch_module = nn.Identity() 38 | elif min_val == 0 and max_val is None: 39 | torch_module = nn.ReLU() 40 | elif min_val == 0 and max_val == 6: 41 | torch_module = nn.ReLU6() 42 | else: 43 | torch_module = OnnxClip(min_val=min_val, max_val=max_val) 44 | 45 | return torch_module 46 | 47 | 48 | @add_converter(operation_type='Clip', version=11) 49 | @add_converter(operation_type='Clip', version=12) 50 | @add_converter(operation_type='Clip', version=13) 51 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 52 
| # Min and Max inputs are optional 53 | min_name = node.input_values[1] if len(node.input_values) > 1 else None 54 | max_name = node.input_values[2] if len(node.input_values) > 2 else None 55 | 56 | try: 57 | min_val = float(get_const_value(min_name, graph)) if min_name is not None else None 58 | max_val = float(get_const_value(max_name, graph)) if max_name is not None else None 59 | except KeyError as exc: 60 | raise NotImplementedError('Dynamic value of min/max is not implemented') from exc 61 | 62 | torch_module = _create_torch_module(min_val=min_val, max_val=max_val) 63 | 64 | return OperationConverterResult( 65 | torch_module=torch_module, 66 | onnx_mapping=OnnxMapping( 67 | inputs=(node.input_values[0],), 68 | outputs=node.output_values, 69 | ), 70 | ) 71 | 72 | 73 | @add_converter(operation_type='Clip', version=6) 74 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 75 | node_attributes = node.attributes 76 | min_val = node_attributes.get('min', None) 77 | max_val = node_attributes.get('max', None) 78 | 79 | torch_module = _create_torch_module(min_val=min_val, max_val=max_val) 80 | 81 | return OperationConverterResult( 82 | torch_module=torch_module, 83 | onnx_mapping=onnx_mapping_from_node(node), 84 | ) 85 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/comparisons.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxCompare', 3 | ] 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OnnxToTorchModule 12 | from onnx2torch.utils.common import OperationConverterResult 13 | from onnx2torch.utils.common import onnx_mapping_from_node 14 | 15 | _TORCH_FUNCTION_FROM_ONNX_TYPE = { 16 | 'Equal': 
torch.eq, 17 | 'Less': torch.less, 18 | 'LessOrEqual': torch.less_equal, 19 | 'Greater': torch.greater, 20 | 'GreaterOrEqual': torch.greater_equal, 21 | } 22 | 23 | 24 | class OnnxCompare(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 25 | def __init__(self, operation_type: str): 26 | super().__init__() 27 | self.compare_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type] 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 30 | return self.compare_function(x, y) 31 | 32 | 33 | @add_converter(operation_type='Equal', version=7) 34 | @add_converter(operation_type='Equal', version=11) 35 | @add_converter(operation_type='Equal', version=13) 36 | @add_converter(operation_type='Less', version=7) 37 | @add_converter(operation_type='Less', version=9) 38 | @add_converter(operation_type='Less', version=13) 39 | @add_converter(operation_type='Greater', version=7) 40 | @add_converter(operation_type='Greater', version=9) 41 | @add_converter(operation_type='Greater', version=13) 42 | @add_converter(operation_type='LessOrEqual', version=12) 43 | @add_converter(operation_type='GreaterOrEqual', version=12) 44 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 45 | return OperationConverterResult( 46 | torch_module=OnnxCompare(operation_type=node.operation_type), 47 | onnx_mapping=onnx_mapping_from_node(node=node), 48 | ) 49 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/concat.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxConcat', 3 | ] 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OnnxToTorchModule 12 | 
from onnx2torch.utils.common import OperationConverterResult 13 | from onnx2torch.utils.common import onnx_mapping_from_node 14 | 15 | 16 | class OnnxConcat(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 17 | def __init__(self, axis: int): 18 | super().__init__() 19 | self.axis = axis 20 | 21 | def forward(self, *input_tensors) -> torch.Tensor: # pylint: disable=missing-function-docstring 22 | return torch.cat(input_tensors, self.axis) # join every input along the stored axis 23 | 24 | 25 | @add_converter(operation_type='Concat', version=4) 26 | @add_converter(operation_type='Concat', version=11) 27 | @add_converter(operation_type='Concat', version=13) 28 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 29 | axis = node.attributes.get('axis', 0) # 0 when the node carries no 'axis' attribute 30 | torch_module = OnnxConcat( 31 | axis=axis, 32 | ) 33 | 34 | return OperationConverterResult( 35 | torch_module=torch_module, 36 | onnx_mapping=onnx_mapping_from_node(node=node), 37 | ) 38 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/constant.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxConstant', 3 | ] 4 | 5 | from typing import Any 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from onnx2torch.node_converters.registry import add_converter 11 | from onnx2torch.onnx_graph import OnnxGraph 12 | from onnx2torch.onnx_node import OnnxNode 13 | from onnx2torch.utils.common import OnnxToTorchModule 14 | from onnx2torch.utils.common import OperationConverterResult 15 | from onnx2torch.utils.common import onnx_mapping_from_node 16 | 17 | _CONSTANT_PARSING_MAPPING = { # ONNX Constant attribute name -> converter for its value 18 | 'value': lambda x: x.to_torch(), 19 | 'value_float': torch.tensor, 20 | 'value_floats': torch.tensor, 21 | 'value_int': torch.tensor, 22 | 'value_ints': torch.tensor, 23 | 'value_string': lambda x: x, 24 | 'value_strings': lambda x: x, # string values are returned unchanged 25 | } 26 | 27 | 28 | class OnnxConstant(nn.Module,
OnnxToTorchModule): # pylint: disable=missing-docstring 29 | def __init__(self, value: Any): 30 | super().__init__() 31 | # Tensors are registered as buffers so they follow the module across devices (e.g. .cuda()); other values stay plain attributes. 32 | if isinstance(value, torch.Tensor): 33 | self.register_buffer('value', value) 34 | else: 35 | self.value = value 36 | 37 | def forward(self) -> Any: # pylint: disable=missing-function-docstring 38 | return self.value 39 | 40 | 41 | def _prepare_output_value(value: Any, attr_name: str) -> Any: 42 | if attr_name in _CONSTANT_PARSING_MAPPING: 43 | return _CONSTANT_PARSING_MAPPING[attr_name](value) 44 | 45 | raise NotImplementedError(f'value type "{attr_name}" not supported yet.') 46 | 47 | 48 | @add_converter(operation_type='Constant', version=9) 49 | @add_converter(operation_type='Constant', version=11) 50 | @add_converter(operation_type='Constant', version=12) 51 | @add_converter(operation_type='Constant', version=13) 52 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 53 | attr_name, value = list(node.attributes.items())[0] # only the node's first attribute is used 54 | prepared_value = _prepare_output_value(value, attr_name) 55 | 56 | torch_module = OnnxConstant( 57 | value=prepared_value, 58 | ) 59 | 60 | return OperationConverterResult( 61 | torch_module=torch_module, 62 | onnx_mapping=onnx_mapping_from_node(node=node), 63 | ) 64 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/constant_of_shape.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxConstantOfShape', 3 | ] 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from onnx2torch.node_converters.registry import add_converter 11 | from onnx2torch.onnx_graph import OnnxGraph 12 | from onnx2torch.onnx_node import OnnxNode 13 | from onnx2torch.utils.common import OnnxToTorchModule 14 | from onnx2torch.utils.common import OperationConverterResult 15 | from
onnx2torch.utils.common import onnx_mapping_from_node 16 | 17 | 18 | class OnnxConstantOfShape(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 19 | def __init__(self, value: Optional[torch.Tensor] = None): 20 | super().__init__() 21 | 22 | if value is None: # default fill value: float32 zero 23 | value = torch.tensor(0.0, dtype=torch.float32) 24 | 25 | if value.numel() != 1: # only a single-element fill value is accepted 26 | raise ValueError('parameter "value" must be scalar') 27 | 28 | self.value: torch.Tensor 29 | self.register_buffer('value', value) 30 | 31 | def forward(self, shape: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 32 | fill_value = self.value.item() 33 | 34 | return torch.full( 35 | size=torch.Size(shape), 36 | fill_value=int(fill_value) if isinstance(fill_value, bool) else fill_value, # bool fill values are passed on as int 37 | dtype=self.value.dtype, 38 | device=self.value.device, 39 | ) 40 | 41 | 42 | @add_converter(operation_type='ConstantOfShape', version=9) 43 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 44 | node_attributes = node.attributes 45 | 46 | if 'value' in node_attributes: 47 | value = node_attributes['value'].to_torch() 48 | else: 49 | value = None # the module then falls back to its float32 zero default 50 | 51 | return OperationConverterResult( 52 | torch_module=OnnxConstantOfShape(value=value), 53 | onnx_mapping=onnx_mapping_from_node(node=node), 54 | ) 55 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/conv.py: -------------------------------------------------------------------------------- 1 | __all__ = [] 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from onnx2torch.node_converters.registry import add_converter 7 | from onnx2torch.onnx_graph import OnnxGraph 8 | from onnx2torch.onnx_node import OnnxNode 9 | from onnx2torch.utils.common import OnnxMapping 10 | from onnx2torch.utils.common import OperationConverterResult 11 | from onnx2torch.utils.padding import onnx_auto_pad_to_torch_padding 12 | 13 |
_CONV_CLASS_FROM_SPATIAL_RANK = { 14 | ('Conv', 1): nn.Conv1d, 15 | ('Conv', 2): nn.Conv2d, 16 | ('Conv', 3): nn.Conv3d, 17 | ('ConvTranspose', 1): nn.ConvTranspose1d, 18 | ('ConvTranspose', 2): nn.ConvTranspose2d, 19 | ('ConvTranspose', 3): nn.ConvTranspose3d, 20 | } 21 | 22 | 23 | @add_converter(operation_type='Conv', version=1) 24 | @add_converter(operation_type='Conv', version=11) 25 | @add_converter(operation_type='ConvTranspose', version=1) 26 | @add_converter(operation_type='ConvTranspose', version=11) 27 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 28 | weights_value_name = node.input_values[1] 29 | weights = graph.initializers[weights_value_name] 30 | weights = weights.to_torch() 31 | if len(node.input_values) == 3: 32 | bias_value_name = node.input_values[2] 33 | bias = graph.initializers[bias_value_name] 34 | bias = bias.to_torch() 35 | else: 36 | bias = None 37 | 38 | op_type = node.operation_type 39 | spatial_rank = len(weights.shape) - 2 40 | try: 41 | conv_class = _CONV_CLASS_FROM_SPATIAL_RANK[op_type, spatial_rank] 42 | except KeyError as exc: 43 | raise NotImplementedError( 44 | f'Convolution operation with spatial rank == {spatial_rank} is not implemented' 45 | ) from exc 46 | 47 | node_attributes = node.attributes 48 | padding, input_padding_module = onnx_auto_pad_to_torch_padding( 49 | onnx_padding=node_attributes.get('pads', [0] * spatial_rank * 2), 50 | auto_pad=node_attributes.get('auto_pad', 'NOTSET'), 51 | ) 52 | common_kwargs = { 53 | 'kernel_size': node_attributes.get('kernel_shape', weights.shape[2:]), 54 | 'stride': node_attributes.get('strides', 1), 55 | 'dilation': node_attributes.get('dilations', 1), 56 | 'groups': node_attributes.get('group', 1), 57 | 'padding': padding, 58 | 'bias': bias is not None, 59 | } 60 | 61 | if op_type == 'Conv': 62 | special_kwargs = { 63 | 'out_channels': weights.shape[0], 64 | 'in_channels': weights.shape[1] * common_kwargs['groups'], 65 | } 66 | elif op_type == 
'ConvTranspose': 67 | if input_padding_module is not None: 68 | raise NotImplementedError('ConvTranspose with non symmetrical padding is not implemented.') 69 | 70 | output_padding = node_attributes.get('output_padding', [0] * spatial_rank) 71 | special_kwargs = { 72 | 'out_channels': weights.shape[1] * common_kwargs['groups'], 73 | 'in_channels': weights.shape[0], 74 | 'output_padding': output_padding, 75 | } 76 | else: 77 | raise ValueError(f'Got unknown op_type "{op_type}"') 78 | 79 | torch_module = conv_class( 80 | **common_kwargs, 81 | **special_kwargs, 82 | ) 83 | with torch.no_grad(): 84 | torch_module.weight.data = weights 85 | if bias is not None: 86 | torch_module.bias.data = bias 87 | 88 | if input_padding_module is not None: 89 | torch_module = nn.Sequential(input_padding_module, torch_module) 90 | 91 | return OperationConverterResult( 92 | torch_module=torch_module, 93 | onnx_mapping=OnnxMapping( 94 | inputs=(node.input_values[0],), 95 | outputs=node.output_values, 96 | ), 97 | ) 98 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/cumsum.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxCumSum', 3 | ] 4 | import torch 5 | from torch import nn 6 | 7 | from onnx2torch.node_converters.registry import add_converter 8 | from onnx2torch.onnx_graph import OnnxGraph 9 | from onnx2torch.onnx_node import OnnxNode 10 | from onnx2torch.utils.common import OnnxToTorchModule 11 | from onnx2torch.utils.common import OperationConverterResult 12 | from onnx2torch.utils.common import onnx_mapping_from_node 13 | 14 | 15 | def _arbitrary_dim_shift_and_insert_zero( 16 | input_tensor: torch.Tensor, 17 | insert_dim: int, 18 | ) -> torch.Tensor: 19 | # single item shift 20 | slice_index, insertion = [[slice(None)] * len(input_tensor.shape)] * 2 21 | insert_dim_size = input_tensor.shape[insert_dim] 22 | 23 | slice_index[insert_dim] = slice(0, -1) 24 | 
slice_index = tuple(slice_index) 25 | tensor_slice = input_tensor[slice_index] 26 | 27 | insert_index = torch.arange(start=1, end=insert_dim_size, dtype=torch.int64, device=input_tensor.device) 28 | index_shape = [1] * len(input_tensor.shape) 29 | index_shape[insert_dim] = insert_dim_size - 1 30 | 31 | insert_index = torch.reshape(insert_index, index_shape) 32 | insert_index = insert_index + torch.zeros_like(tensor_slice, dtype=torch.int64, device=input_tensor.device) 33 | 34 | input_tensor = torch.scatter( 35 | input=input_tensor, 36 | dim=insert_dim, 37 | index=insert_index, 38 | src=tensor_slice, 39 | ) 40 | 41 | insertion[insert_dim] = slice(0, 1) 42 | insertion = tuple(insertion) 43 | input_tensor[insertion] = 0 44 | 45 | return input_tensor 46 | 47 | 48 | class OnnxCumSum(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 49 | def __init__( 50 | self, 51 | exclusive: bool = False, 52 | reverse: bool = False, 53 | ): 54 | super().__init__() 55 | self.exclusive = exclusive 56 | self.reverse = reverse 57 | 58 | def forward( # pylint: disable=missing-function-docstring 59 | self, 60 | input_tensor: torch.Tensor, 61 | axis: torch.Tensor, 62 | ) -> torch.Tensor: 63 | axis = axis.item() 64 | if self.reverse: 65 | input_tensor = torch.flip(input_tensor, dims=(axis,)) 66 | 67 | if self.exclusive: 68 | input_tensor = _arbitrary_dim_shift_and_insert_zero(input_tensor, insert_dim=axis) 69 | 70 | input_tensor = torch.cumsum(input_tensor, dim=axis) 71 | 72 | if self.reverse: 73 | input_tensor = torch.flip(input_tensor, dims=(axis,)) 74 | 75 | return input_tensor 76 | 77 | 78 | @add_converter(operation_type='CumSum', version=11) 79 | @add_converter(operation_type='CumSum', version=14) 80 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 81 | node_attributes = node.attributes 82 | exclusive = bool(node_attributes.get('exclusive', 0)) 83 | reverse = bool(node_attributes.get('reverse', 1)) 84 | 85 | return 
OperationConverterResult( 86 | torch_module=OnnxCumSum(exclusive, reverse), 87 | onnx_mapping=onnx_mapping_from_node(node), 88 | ) 89 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/depth_to_space.py: -------------------------------------------------------------------------------- 1 | __all__ = ['OnnxDepthToSpace'] 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from onnx2torch.node_converters.registry import add_converter 7 | from onnx2torch.onnx_graph import OnnxGraph 8 | from onnx2torch.onnx_node import OnnxNode 9 | from onnx2torch.utils.common import OnnxToTorchModule 10 | from onnx2torch.utils.common import OperationConverterResult 11 | from onnx2torch.utils.common import onnx_mapping_from_node 12 | 13 | 14 | class OnnxDepthToSpace(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 15 | def __init__(self, blocksize: int): 16 | super().__init__() 17 | self._upscale_factor = blocksize 18 | 19 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 20 | return torch.pixel_shuffle(input_tensor, upscale_factor=self._upscale_factor) 21 | 22 | 23 | @add_converter(operation_type='DepthToSpace', version=11) 24 | @add_converter(operation_type='DepthToSpace', version=13) 25 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 26 | del graph 27 | 28 | blocksize: int = node.attributes['blocksize'] # required 29 | mode: str = node.attributes.get('mode', 'DCR') 30 | 31 | if mode != 'CRD': 32 | raise NotImplementedError('DepthToSpace for mode other than CRD is not implemented') 33 | 34 | return OperationConverterResult( 35 | torch_module=OnnxDepthToSpace(blocksize=blocksize), 36 | onnx_mapping=onnx_mapping_from_node(node=node), 37 | ) 38 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/dropout.py: 
-------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxDropoutDynamic', 3 | ] 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | from onnx2torch.node_converters.registry import add_converter 12 | from onnx2torch.onnx_graph import OnnxGraph 13 | from onnx2torch.onnx_node import OnnxNode 14 | from onnx2torch.utils.common import OnnxToTorchModule 15 | from onnx2torch.utils.common import OperationConverterResult 16 | from onnx2torch.utils.common import onnx_mapping_from_node 17 | 18 | 19 | class OnnxDropoutDynamic(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 20 | def forward( # pylint: disable=missing-function-docstring, unused-argument 21 | self, 22 | input_tensor: torch.Tensor, 23 | ratio: float = 0.5, 24 | training_mode: Optional[torch.Tensor] = None, 25 | ) -> torch.Tensor: 26 | # The ONNX training_mode input is ignored; dropout is applied only while this module is in training mode. 27 | return F.dropout(input_tensor, p=ratio, training=self.training) 28 | 29 | 30 | @add_converter(operation_type='Dropout', version=10) 31 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 32 | node_attributes = node.attributes 33 | ratio = node_attributes.get('ratio', 0.5) # read from the node attributes; 0.5 when absent 34 | 35 | torch_module = nn.Dropout(p=ratio) 36 | 37 | return OperationConverterResult( 38 | torch_module=torch_module, 39 | onnx_mapping=onnx_mapping_from_node(node=node), 40 | ) 41 | 42 | 43 | @add_converter(operation_type='Dropout', version=12) 44 | @add_converter(operation_type='Dropout', version=13) 45 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 46 | node_attributes = node.attributes 47 | seed = node_attributes.get('seed') 48 | if seed is not None: # seeded dropout is rejected 49 | raise NotImplementedError('Dropout nodes with seeds are not supported.') 50 | 51 | return OperationConverterResult( 52 |
torch_module=OnnxDropoutDynamic(), 53 | onnx_mapping=onnx_mapping_from_node(node=node), 54 | ) 55 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/einsum.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxEinsum', 3 | ] 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OnnxToTorchModule 12 | from onnx2torch.utils.common import OperationConverterResult 13 | from onnx2torch.utils.common import onnx_mapping_from_node 14 | 15 | 16 | class OnnxEinsum(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 17 | def __init__(self, equation: str): 18 | super().__init__() 19 | self.equation = equation 20 | 21 | def forward(self, *args): # pylint: disable=missing-function-docstring 22 | return torch.einsum(self.equation, *args) # all positional inputs are passed as einsum operands 23 | 24 | 25 | @add_converter(operation_type='Einsum', version=12) 26 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 27 | return OperationConverterResult( 28 | torch_module=OnnxEinsum( 29 | equation=node.attributes['equation'], 30 | ), 31 | onnx_mapping=onnx_mapping_from_node(node=node), 32 | ) 33 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/expand.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxExpand', 3 | ] 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OperationConverterResult 12 | from onnx2torch.utils.common import
onnx_mapping_from_node 13 | from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx 14 | from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport 15 | 16 | 17 | class OnnxExpand(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-docstring 18 | def forward( # pylint: disable=missing-function-docstring 19 | self, 20 | input_tensor: torch.Tensor, 21 | shape: torch.Tensor, 22 | ) -> torch.Tensor: 23 | def _forward(): 24 | return input_tensor * torch.ones(torch.Size(shape), dtype=input_tensor.dtype, device=input_tensor.device) 25 | 26 | if torch.onnx.is_in_onnx_export(): 27 | return DefaultExportToOnnx.export(_forward, 'Expand', input_tensor, shape, {}) 28 | 29 | return _forward() 30 | 31 | 32 | @add_converter(operation_type='Expand', version=8) 33 | @add_converter(operation_type='Expand', version=13) 34 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 35 | return OperationConverterResult( 36 | torch_module=OnnxExpand(), 37 | onnx_mapping=onnx_mapping_from_node(node=node), 38 | ) 39 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/eye_like.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxEyeLike', 3 | ] 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from onnx2torch.node_converters.registry import add_converter 11 | from onnx2torch.onnx_graph import OnnxGraph 12 | from onnx2torch.onnx_node import OnnxNode 13 | from onnx2torch.utils.common import OnnxToTorchModule 14 | from onnx2torch.utils.common import OperationConverterResult 15 | from onnx2torch.utils.common import onnx_mapping_from_node 16 | from onnx2torch.utils.dtype import onnx_dtype_to_torch_dtype 17 | 18 | 19 | class OnnxEyeLike(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 20 | def __init__(self, dtype: 
Optional[int] = None, k: int = 0): # pylint: disable=invalid-name 21 | super().__init__() 22 | self.dtype = dtype 23 | self.k = k # pylint: disable=invalid-name 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 26 | if len(x.shape) != 2: 27 | raise ValueError(f'EyeLike only supports 2D tensors, got {len(x.shape)}') 28 | 29 | dtype = x.dtype if self.dtype is None else onnx_dtype_to_torch_dtype(self.dtype) 30 | if not isinstance(dtype, torch.dtype): 31 | raise ValueError(f'Expected type of dtype is torch.dtype, got {type(dtype)}') 32 | 33 | rows, cols = x.size() 34 | if self.k > rows: 35 | raise ValueError( 36 | f'EyeLike attribute k should be less or equal than the zero dimension of input tensor,' 37 | f'got {self.k} and {rows}' 38 | ) 39 | 40 | if self.k == 0: 41 | return torch.eye(n=rows, m=cols, dtype=dtype) 42 | if self.k > 0: 43 | return torch.concat( 44 | [ 45 | torch.zeros(rows, self.k, dtype=dtype), 46 | torch.eye(n=rows, m=(cols - self.k), dtype=dtype), 47 | ], 48 | dim=1, 49 | ) 50 | return torch.concat( # k < 0: 51 | [ 52 | torch.zeros(-self.k, cols, dtype=dtype), 53 | torch.eye(n=(rows + self.k), m=cols, dtype=dtype), 54 | ], 55 | dim=0, 56 | ) 57 | 58 | 59 | @add_converter(operation_type='EyeLike', version=9) 60 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 61 | node_attributes = node.attributes 62 | k = node_attributes.get('k', 0) # pylint: disable=invalid-name 63 | dtype = node_attributes.get('dtype', None) 64 | return OperationConverterResult( 65 | torch_module=OnnxEyeLike(dtype=dtype, k=k), 66 | onnx_mapping=onnx_mapping_from_node(node=node), 67 | ) 68 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/flatten.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from 
onnx2torch.node_converters.registry import add_converter 5 | from onnx2torch.onnx_graph import OnnxGraph 6 | from onnx2torch.onnx_node import OnnxNode 7 | from onnx2torch.utils.common import OnnxToTorchModule 8 | from onnx2torch.utils.common import OperationConverterResult 9 | from onnx2torch.utils.common import onnx_mapping_from_node 10 | 11 | 12 | class OnnxFlatten(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 13 | def __init__(self, axis: int = 1): 14 | super().__init__() 15 | self.axis = axis 16 | 17 | def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 18 | x = torch.flatten(x, end_dim=self.axis - 1) 19 | return torch.flatten(x, start_dim=1) 20 | 21 | @classmethod 22 | def maybe_create_simple_flatten(cls, axis: int = 1) -> nn.Module: # pylint: disable=missing-docstring 23 | if axis == 1: 24 | return nn.Flatten(start_dim=axis) 25 | 26 | return cls(axis=axis) 27 | 28 | 29 | @add_converter(operation_type='Flatten', version=13) 30 | @add_converter(operation_type='Flatten', version=11) 31 | @add_converter(operation_type='Flatten', version=9) 32 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 33 | axis = node.attributes.get('axis', 1) 34 | torch_module = OnnxFlatten.maybe_create_simple_flatten(axis=axis) 35 | 36 | return OperationConverterResult( 37 | torch_module=torch_module, 38 | onnx_mapping=onnx_mapping_from_node(node=node), 39 | ) 40 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/functions.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxFunction', 3 | ] 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import 
OnnxToTorchModule 12 | from onnx2torch.utils.common import OperationConverterResult 13 | from onnx2torch.utils.common import onnx_mapping_from_node 14 | 15 | # atanh, asinh, acosh, cosh and sinh are left out: exporting them from PyTorch to ONNX is not supported. 16 | _TORCH_FUNCTION_FROM_ONNX_TYPE = { 17 | 'Abs': torch.abs, 18 | 'Acos': torch.acos, 19 | 'Asin': torch.asin, 20 | 'Atan': torch.atan, 21 | 'Cos': torch.cos, 22 | 'Exp': torch.exp, 23 | 'Log': torch.log, 24 | 'Sign': torch.sign, 25 | 'Sin': torch.sin, 26 | 'Tan': torch.tan, 27 | 'Tanh': torch.tanh, 28 | } 29 | 30 | 31 | class OnnxFunction(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 32 | def __init__(self, function_type: str): 33 | super().__init__() 34 | self.function = _TORCH_FUNCTION_FROM_ONNX_TYPE[function_type] # KeyError for types missing from the table 35 | 36 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 37 | return self.function(input_tensor) 38 | 39 | 40 | @add_converter(operation_type='Abs', version=13) 41 | @add_converter(operation_type='Abs', version=6) 42 | @add_converter(operation_type='Acos', version=7) 43 | @add_converter(operation_type='Asin', version=7) 44 | @add_converter(operation_type='Atan', version=7) 45 | @add_converter(operation_type='Cos', version=7) 46 | @add_converter(operation_type='Exp', version=6) 47 | @add_converter(operation_type='Exp', version=13) 48 | @add_converter(operation_type='Log', version=13) 49 | @add_converter(operation_type='Log', version=6) 50 | @add_converter(operation_type='Sign', version=13) 51 | @add_converter(operation_type='Sign', version=9) 52 | @add_converter(operation_type='Sin', version=7) 53 | @add_converter(operation_type='Tan', version=7) 54 | @add_converter(operation_type='Tanh', version=13) 55 | @add_converter(operation_type='Tanh', version=6) 56 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 57 | return OperationConverterResult( 58 |
torch_module=OnnxFunction(node.operation_type), 59 | onnx_mapping=onnx_mapping_from_node(node=node), 60 | ) 61 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/gemm.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxGemm', 3 | ] 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from onnx2torch.node_converters.registry import add_converter 11 | from onnx2torch.onnx_graph import OnnxGraph 12 | from onnx2torch.onnx_node import OnnxNode 13 | from onnx2torch.utils.common import OnnxMapping 14 | from onnx2torch.utils.common import OnnxToTorchModule 15 | from onnx2torch.utils.common import OperationConverterResult 16 | from onnx2torch.utils.common import onnx_mapping_from_node 17 | 18 | 19 | class OnnxGemm(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 20 | def __init__(self, alpha: float, beta: float, trans_a: bool, trans_b: bool): 21 | super().__init__() 22 | 23 | self.alpha = alpha 24 | self.beta = beta 25 | self.trans_a = trans_a 26 | self.trans_b = trans_b 27 | 28 | def forward( # pylint: disable=missing-function-docstring 29 | self, 30 | input_a: torch.Tensor, 31 | input_b: torch.Tensor, 32 | input_c: Optional[torch.Tensor] = None, 33 | ): 34 | if self.trans_a: 35 | input_a = torch.transpose(input_a, dim0=0, dim1=1) 36 | if self.trans_b: 37 | input_b = torch.transpose(input_b, dim0=0, dim1=1) 38 | 39 | output = input_a @ input_b * self.alpha # alpha * A @ B after the optional transposes; beta * C is added below 40 | if input_c is not None: 41 | output += input_c * self.beta 42 | 43 | return output 44 | 45 | 46 | @add_converter(operation_type='Gemm', version=9) 47 | @add_converter(operation_type='Gemm', version=11) 48 | @add_converter(operation_type='Gemm', version=13) 49 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 50 | a_name = node.input_values[0] 51 | b_name = node.input_values[1] 52 | c_name = node.input_values[2] if
len(node.input_values) > 2 else None 53 | 54 | node_attributes = node.attributes 55 | alpha = node_attributes.get('alpha', 1.0) 56 | beta = node_attributes.get('beta', 1.0) 57 | trans_a = node_attributes.get('transA', 0) != 0 58 | trans_b = node_attributes.get('transB', 0) != 0 59 | 60 | if not trans_a and b_name in graph.initializers and (c_name is None or c_name in graph.initializers): # constant B (and constant C, if any) can be folded into nn.Linear 61 | if c_name is None: 62 | bias = None 63 | else: 64 | bias = graph.initializers[c_name] 65 | bias = bias.to_torch() 66 | 67 | if bias is None or bias.dim() == 1: # other bias ranks fall through to OnnxGemm below 68 | weights = graph.initializers[b_name] 69 | weights = weights.to_torch() 70 | if not trans_b: 71 | weights = weights.T # nn.Linear stores weight as (out_features, in_features) 72 | 73 | in_features, out_features = weights.shape[1], weights.shape[0] 74 | torch_module = nn.Linear( 75 | in_features=in_features, 76 | out_features=out_features, 77 | bias=bias is not None, 78 | ) 79 | 80 | with torch.no_grad(): 81 | weights = weights * alpha # alpha and beta are baked into the parameters 82 | torch_module.weight.data = weights 83 | if bias is not None: 84 | bias = bias * beta 85 | torch_module.bias.data = bias 86 | 87 | return OperationConverterResult( 88 | torch_module=torch_module, 89 | onnx_mapping=OnnxMapping( 90 | inputs=(a_name,), 91 | outputs=node.output_values, 92 | ), 93 | ) 94 | 95 | return OperationConverterResult( 96 | torch_module=OnnxGemm(alpha=alpha, beta=beta, trans_a=trans_a, trans_b=trans_b), 97 | onnx_mapping=onnx_mapping_from_node(node), 98 | ) 99 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/global_average_pool.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-docstring 2 | __all__ = [ 3 | 'OnnxGlobalAveragePool', 4 | 'OnnxGlobalAveragePoolWithKnownInputShape', 5 | ] 6 | 7 | from typing import List 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from onnx2torch.node_converters.registry import add_converter 13 | from onnx2torch.onnx_graph import OnnxGraph 14 | from
onnx2torch.onnx_node import OnnxNode 15 | from onnx2torch.utils.common import OperationConverterResult 16 | from onnx2torch.utils.common import get_shape_from_value_info 17 | from onnx2torch.utils.common import onnx_mapping_from_node 18 | from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx 19 | from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport 20 | 21 | 22 | class OnnxGlobalAveragePool(nn.Module, OnnxToTorchModuleWithCustomExport): 23 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 24 | def _forward(): 25 | x_dims = list(range(2, len(input_tensor.shape))) 26 | return torch.mean(input_tensor, dim=x_dims, keepdim=True) 27 | 28 | if torch.onnx.is_in_onnx_export(): 29 | return DefaultExportToOnnx.export(_forward, 'GlobalAveragePool', input_tensor, {}) 30 | 31 | return _forward() 32 | 33 | 34 | class OnnxGlobalAveragePoolWithKnownInputShape(nn.Module, OnnxToTorchModuleWithCustomExport): 35 | def __init__(self, input_shape: List[int]): 36 | super().__init__() 37 | self._x_dims = list(range(2, len(input_shape))) 38 | 39 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 40 | def _forward() -> torch.Tensor: 41 | return torch.mean(input_tensor, dim=self._x_dims, keepdim=True) 42 | 43 | if torch.onnx.is_in_onnx_export(): 44 | return DefaultExportToOnnx.export(_forward, 'GlobalAveragePool', input_tensor, {}) 45 | 46 | return _forward() 47 | 48 | 49 | @add_converter(operation_type='GlobalAveragePool', version=1) 50 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 51 | input_value_info = graph.value_info[node.input_values[0]] 52 | input_shape = get_shape_from_value_info(input_value_info) 53 | 54 | if input_shape is not None: 55 | torch_module = OnnxGlobalAveragePoolWithKnownInputShape(input_shape=input_shape) 56 | else: 57 | torch_module = OnnxGlobalAveragePool() 58 | 59 | return OperationConverterResult( 60 | torch_module=torch_module, 61 | 
onnx_mapping=onnx_mapping_from_node(node=node), 62 | ) 63 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/identity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from onnx2torch.node_converters.registry import add_converter 5 | from onnx2torch.onnx_graph import OnnxGraph 6 | from onnx2torch.onnx_node import OnnxNode 7 | from onnx2torch.utils.common import OnnxToTorchModule 8 | from onnx2torch.utils.common import OperationConverterResult 9 | from onnx2torch.utils.common import onnx_mapping_from_node 10 | 11 | 12 | class OnnxCopyIdentity(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 13 | def forward(self, x: torch.Tensor): # pylint: disable=missing-function-docstring 14 | return x.clone() 15 | 16 | 17 | @add_converter(operation_type='Identity', version=16) 18 | @add_converter(operation_type='Identity', version=14) 19 | @add_converter(operation_type='Identity', version=13) 20 | @add_converter(operation_type='Identity', version=1) 21 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 22 | # We need copy identity because in onnx identity create new tensor. 23 | # Pytorch identity simply returns the same tensor. 24 | # Which ruin quantization logic, because we should mark quantized tensors. 25 | # For example, input quantization node will be supressed if input tensor is already quantized. 
26 | return OperationConverterResult( 27 | torch_module=OnnxCopyIdentity(), 28 | onnx_mapping=onnx_mapping_from_node(node=node), 29 | ) 30 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/instance_norm.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxInstanceNorm', 3 | ] 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from onnx2torch.node_converters.registry import add_converter 10 | from onnx2torch.onnx_graph import OnnxGraph 11 | from onnx2torch.onnx_node import OnnxNode 12 | from onnx2torch.utils.common import OnnxMapping 13 | from onnx2torch.utils.common import OnnxToTorchModule 14 | from onnx2torch.utils.common import OperationConverterResult 15 | from onnx2torch.utils.common import get_shape_from_value_info 16 | from onnx2torch.utils.common import onnx_mapping_from_node 17 | 18 | _IN_CLASS_FROM_SPATIAL_RANK = { 19 | 0: nn.InstanceNorm1d, 20 | 1: nn.InstanceNorm1d, 21 | 2: nn.InstanceNorm2d, 22 | 3: nn.InstanceNorm3d, 23 | } 24 | 25 | 26 | class OnnxInstanceNorm(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 27 | def __init__(self, momentum: float, epsilon: float): 28 | super().__init__() 29 | self.momentum = momentum 30 | self.epsilon = epsilon 31 | 32 | def forward( # pylint: disable=missing-function-docstring 33 | self, 34 | input_data: torch.Tensor, 35 | weight: torch.Tensor, 36 | bias: torch.Tensor, 37 | ) -> torch.Tensor: 38 | return F.instance_norm( 39 | input=input_data, 40 | running_mean=None, 41 | running_var=None, 42 | weight=weight, 43 | bias=bias, 44 | use_input_stats=True, 45 | momentum=self.momentum, 46 | eps=self.epsilon, 47 | ) 48 | 49 | 50 | @add_converter(operation_type='InstanceNormalization', version=1) 51 | @add_converter(operation_type='InstanceNormalization', version=6) 52 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 53 | 
node_attributes = node.attributes 54 | epsilon = node_attributes.get('epsilon', 1e-5) 55 | momentum = 0.1 56 | 57 | if all(value_name in graph.initializers for value_name in node.input_values[1:]): 58 | input_value_info = graph.value_info[node.input_values[0]] 59 | input_shape = get_shape_from_value_info(input_value_info) 60 | spatial_rank = len(input_shape) - 2 61 | try: 62 | in_class = _IN_CLASS_FROM_SPATIAL_RANK[spatial_rank] 63 | except KeyError as exc: 64 | raise NotImplementedError( 65 | f'InstanceNorm operation with spatial rank == {spatial_rank} is not implemented' 66 | ) from exc 67 | 68 | scale_value_name = node.input_values[1] 69 | bias_value_name = node.input_values[2] 70 | 71 | scale = graph.initializers[scale_value_name].to_torch() 72 | torch_module = in_class( 73 | num_features=scale.size()[0], 74 | eps=epsilon, 75 | momentum=momentum, 76 | affine=True, 77 | track_running_stats=False, 78 | ) 79 | with torch.no_grad(): 80 | torch_module.weight.data = graph.initializers[scale_value_name].to_torch() 81 | torch_module.bias.data = graph.initializers[bias_value_name].to_torch() 82 | 83 | onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values) 84 | else: 85 | torch_module = OnnxInstanceNorm(momentum=momentum, epsilon=epsilon) 86 | onnx_mapping = onnx_mapping_from_node(node) 87 | 88 | return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping) 89 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/isinf.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-docstring 2 | __all__ = [ 3 | 'OnnxIsInf', 4 | ] 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from onnx2torch.node_converters.registry import add_converter 10 | from onnx2torch.onnx_graph import OnnxGraph 11 | from onnx2torch.onnx_node import OnnxNode 12 | from onnx2torch.utils.common import OnnxMapping 13 | from 
onnx2torch.utils.common import OnnxToTorchModule 14 | from onnx2torch.utils.common import OperationConverterResult 15 | 16 | 17 | class OnnxIsInf(nn.Module, OnnxToTorchModule): 18 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 19 | return torch.isinf(input_tensor) 20 | 21 | 22 | @add_converter(operation_type='IsInf', version=10) 23 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 24 | del graph 25 | torch_module = OnnxIsInf() 26 | 27 | return OperationConverterResult( 28 | torch_module=torch_module, 29 | onnx_mapping=OnnxMapping( 30 | inputs=(node.input_values[0],), 31 | outputs=node.output_values, 32 | ), 33 | ) 34 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/isnan.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-docstring 2 | __all__ = [ 3 | 'OnnxIsNaN', 4 | ] 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from onnx2torch.node_converters.registry import add_converter 10 | from onnx2torch.onnx_graph import OnnxGraph 11 | from onnx2torch.onnx_node import OnnxNode 12 | from onnx2torch.utils.common import OnnxMapping 13 | from onnx2torch.utils.common import OnnxToTorchModule 14 | from onnx2torch.utils.common import OperationConverterResult 15 | 16 | 17 | class OnnxIsNaN(nn.Module, OnnxToTorchModule): 18 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 19 | return torch.isnan(input_tensor) 20 | 21 | 22 | @add_converter(operation_type='IsNaN', version=13) 23 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 24 | del graph 25 | torch_module = OnnxIsNaN() 26 | 27 | return OperationConverterResult( 28 | torch_module=torch_module, 29 | onnx_mapping=OnnxMapping( 30 | inputs=(node.input_values[0],), 31 | outputs=node.output_values, 32 | ), 33 | ) 34 | -------------------------------------------------------------------------------- 
/onnx2torch/node_converters/layer_norm.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxLayerNorm', 3 | ] 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | from onnx2torch.node_converters.registry import add_converter 12 | from onnx2torch.onnx_graph import OnnxGraph 13 | from onnx2torch.onnx_node import OnnxNode 14 | from onnx2torch.utils.common import OnnxMapping 15 | from onnx2torch.utils.common import OnnxToTorchModule 16 | from onnx2torch.utils.common import OperationConverterResult 17 | from onnx2torch.utils.common import get_shape_from_value_info 18 | from onnx2torch.utils.common import onnx_mapping_from_node 19 | 20 | AXIS_DEFAULT_VALUE = -1 21 | EPSILON_DEFAULT_VALUE = 1e-5 22 | 23 | 24 | class OnnxLayerNorm(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring 25 | def __init__(self, axis: int, epsilon: float): 26 | super().__init__() 27 | self.axis = axis 28 | self.epsilon = epsilon 29 | 30 | def forward( # pylint: disable=missing-function-docstring 31 | self, 32 | inputs: torch.Tensor, 33 | scale: torch.Tensor, 34 | bias: Optional[torch.Tensor] = None, 35 | ) -> torch.Tensor: 36 | normalized_shape = inputs.shape[self.axis :] 37 | return F.layer_norm( 38 | input=inputs, 39 | normalized_shape=normalized_shape, 40 | weight=scale, 41 | bias=bias, 42 | eps=self.epsilon, 43 | ) 44 | 45 | 46 | @add_converter(operation_type='LayerNormalization', version=17) 47 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 48 | node_attributes = node.attributes 49 | 50 | axis = node_attributes.get('axis', AXIS_DEFAULT_VALUE) 51 | epsilon = node_attributes.get('epsilon', EPSILON_DEFAULT_VALUE) 52 | 53 | if all(value_name in graph.initializers for value_name in node.input_values[1:]): 54 | input_value_info = graph.value_info[node.input_values[0]] 55 | input_shape = get_shape_from_value_info(input_value_info) 
56 | 57 | torch_module = nn.LayerNorm( 58 | normalized_shape=input_shape[axis:], 59 | eps=epsilon, 60 | elementwise_affine=True, 61 | ) 62 | 63 | scale_value_name = node.input_values[1] 64 | bias_value_name = node.input_values[2] if len(node.input_values) > 2 else None 65 | 66 | with torch.no_grad(): 67 | torch_module.weight.data = graph.initializers[scale_value_name].to_torch() 68 | if bias_value_name is not None: 69 | torch_module.bias.data = graph.initializers[bias_value_name].to_torch() 70 | 71 | onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values) 72 | else: 73 | input_value_info = graph.value_info[node.input_values[0]] 74 | input_shape = get_shape_from_value_info(input_value_info) 75 | torch_module = OnnxLayerNorm(axis=axis, epsilon=epsilon) 76 | onnx_mapping = onnx_mapping_from_node(node) 77 | 78 | return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping) 79 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/logical.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-docstring 2 | __all__ = [ 3 | 'OnnxNot', 4 | 'OnnxLogical', 5 | ] 6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from onnx2torch.node_converters.registry import add_converter 13 | from onnx2torch.onnx_graph import OnnxGraph 14 | from onnx2torch.onnx_node import OnnxNode 15 | from onnx2torch.utils.common import OnnxToTorchModule 16 | from onnx2torch.utils.common import OperationConverterResult 17 | from onnx2torch.utils.common import old_style_broadcast 18 | from onnx2torch.utils.common import onnx_mapping_from_node 19 | from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx 20 | from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport 21 | 22 | _TORCH_FUNCTION_FROM_ONNX_TYPE = { 23 | 'Or': torch.logical_or, 24 | 'And': 
torch.logical_and, 25 | 'Xor': torch.logical_xor, 26 | } 27 | 28 | 29 | class OnnxNot(nn.Module, OnnxToTorchModuleWithCustomExport): 30 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 31 | def _forward() -> torch.Tensor: 32 | return torch.logical_not(input_tensor) 33 | 34 | if torch.onnx.is_in_onnx_export(): 35 | return DefaultExportToOnnx.export(_forward, 'Not', input_tensor, {}) 36 | 37 | return _forward() 38 | 39 | 40 | class OnnxLogical(nn.Module, OnnxToTorchModule): 41 | def __init__(self, operation_type: str, broadcast: Optional[int] = None, axis: Optional[int] = None): 42 | super().__init__() 43 | self.broadcast = broadcast 44 | self.axis = axis 45 | 46 | self.logic_op_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type] 47 | 48 | def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor): 49 | if self.broadcast == 1 and self.axis is not None: 50 | second_tensor = old_style_broadcast(first_tensor, second_tensor, self.axis) 51 | 52 | return self.logic_op_function(first_tensor, second_tensor) 53 | 54 | 55 | @add_converter(operation_type='Xor', version=1) 56 | @add_converter(operation_type='Xor', version=7) 57 | @add_converter(operation_type='And', version=1) 58 | @add_converter(operation_type='And', version=7) 59 | @add_converter(operation_type='Or', version=1) 60 | @add_converter(operation_type='Or', version=7) 61 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 62 | del graph 63 | return OperationConverterResult( 64 | torch_module=OnnxLogical( 65 | operation_type=node.operation_type, 66 | broadcast=node.attributes.get('broadcast', None), 67 | axis=node.attributes.get('axis', None), 68 | ), 69 | onnx_mapping=onnx_mapping_from_node(node=node), 70 | ) 71 | 72 | 73 | @add_converter(operation_type='Not', version=1) 74 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 75 | del graph 76 | return OperationConverterResult( 77 | torch_module=OnnxNot(), 78 | 
onnx_mapping=onnx_mapping_from_node(node=node), 79 | ) 80 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/lrn.py: -------------------------------------------------------------------------------- 1 | __all__ = [] 2 | 3 | from torch import nn 4 | 5 | from onnx2torch.node_converters.registry import add_converter 6 | from onnx2torch.onnx_graph import OnnxGraph 7 | from onnx2torch.onnx_node import OnnxNode 8 | from onnx2torch.utils.common import OperationConverterResult 9 | from onnx2torch.utils.common import onnx_mapping_from_node 10 | 11 | 12 | @add_converter(operation_type='LRN', version=13) 13 | @add_converter(operation_type='LRN', version=1) 14 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 15 | size = node.attributes.get('size') 16 | alpha = node.attributes.get('alpha', 0.0001) 17 | beta = node.attributes.get('beta', 0.75) 18 | k = node.attributes.get('bias', 1) # pylint: disable=invalid-name 19 | 20 | return OperationConverterResult( 21 | torch_module=nn.LocalResponseNorm(size=size, alpha=alpha, beta=beta, k=k), 22 | onnx_mapping=onnx_mapping_from_node(node=node), 23 | ) 24 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/matmul.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxMatMul', 3 | ] 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OnnxToTorchModule 12 | from onnx2torch.utils.common import OperationConverterResult 13 | from onnx2torch.utils.common import onnx_mapping_from_node 14 | 15 | 16 | class OnnxMatMul(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 17 | def forward(self, x: 
torch.Tensor, y: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 18 | return torch.matmul(x, y) 19 | 20 | 21 | @add_converter(operation_type='MatMul', version=1) 22 | @add_converter(operation_type='MatMul', version=9) 23 | @add_converter(operation_type='MatMul', version=13) 24 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 25 | return OperationConverterResult( 26 | torch_module=OnnxMatMul(), 27 | onnx_mapping=onnx_mapping_from_node(node=node), 28 | ) 29 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/max_pool.py: -------------------------------------------------------------------------------- 1 | __all__ = [] 2 | 3 | from torch import nn 4 | 5 | from onnx2torch.node_converters.registry import add_converter 6 | from onnx2torch.onnx_graph import OnnxGraph 7 | from onnx2torch.onnx_node import OnnxNode 8 | from onnx2torch.utils.common import OperationConverterResult 9 | from onnx2torch.utils.common import get_shape_from_value_info 10 | from onnx2torch.utils.common import onnx_mapping_from_node 11 | from onnx2torch.utils.padding import onnx_auto_pad_to_torch_padding 12 | 13 | _MAXPOOL_CLASS_FROM_SPATIAL_RANK = { 14 | 1: nn.MaxPool1d, 15 | 2: nn.MaxPool2d, 16 | 3: nn.MaxPool3d, 17 | } 18 | 19 | 20 | @add_converter(operation_type='MaxPool', version=12) 21 | @add_converter(operation_type='MaxPool', version=11) 22 | @add_converter(operation_type='MaxPool', version=10) 23 | @add_converter(operation_type='MaxPool', version=8) 24 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 25 | input_value_info = graph.value_info[node.input_values[0]] 26 | input_shape = get_shape_from_value_info(input_value_info) 27 | 28 | spatial_rank = len(input_shape) - 2 29 | try: 30 | maxpool_class = _MAXPOOL_CLASS_FROM_SPATIAL_RANK[spatial_rank] 31 | except KeyError as exc: 32 | raise NotImplementedError(f'Max pool operation with 
spatial rank == {spatial_rank} is not implemented') from exc 33 | 34 | node_attributes = node.attributes 35 | # required 36 | kernel_shape = node_attributes['kernel_shape'] 37 | # optional 38 | ceil_mode = node_attributes.get('ceil_mode', 0) 39 | dilation = node_attributes.get('dilations', 1) 40 | strides = node_attributes.get('strides', 1) 41 | storage_order = node_attributes.get('storage_order', 0) 42 | if storage_order != 0: 43 | raise NotImplementedError('Only row major (0) order is supported.') 44 | 45 | padding, padding_module = onnx_auto_pad_to_torch_padding( 46 | onnx_padding=node_attributes.get('pads', [0] * spatial_rank * 2), 47 | auto_pad=node_attributes.get('auto_pad', 'NOTSET'), 48 | ) 49 | 50 | torch_module = maxpool_class( 51 | kernel_size=kernel_shape, 52 | stride=strides, 53 | padding=padding, 54 | dilation=dilation, 55 | ceil_mode=ceil_mode == 1, 56 | ) 57 | if padding_module is not None: 58 | # MaxPool must ignore padded values, so we should pad by -inf 59 | padding_module.constant_value = float('-inf') 60 | torch_module = nn.Sequential(padding_module, torch_module) 61 | 62 | return OperationConverterResult( 63 | torch_module=torch_module, 64 | onnx_mapping=onnx_mapping_from_node(node=node), 65 | ) 66 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/mean.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxMean', 3 | ] 4 | 5 | import torch 6 | 7 | from onnx2torch.node_converters.base_element_wise import OnnxBaseElementWise 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OperationConverterResult 12 | from onnx2torch.utils.common import onnx_mapping_from_node 13 | 14 | 15 | class OnnxMean(OnnxBaseElementWise): # pylint: disable=missing-docstring 16 | def __init__(self): 17 | 
super().__init__(op_type='Mean') 18 | 19 | def apply_reduction(self, *tensors: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 20 | broadcast_shape = self._broadcast_shape(*tensors) 21 | 22 | output = torch.zeros(broadcast_shape, dtype=tensors[0].dtype, device=tensors[0].device) 23 | for y in tensors: 24 | output.add_(y) 25 | 26 | output = output.div(len(tensors)) # Divide by the number of tensors 27 | return output 28 | 29 | 30 | @add_converter(operation_type='Mean', version=8) 31 | @add_converter(operation_type='Mean', version=13) 32 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 33 | return OperationConverterResult( 34 | torch_module=OnnxMean(), 35 | onnx_mapping=onnx_mapping_from_node(node=node), 36 | ) 37 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/min_max.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxMinMax', 3 | ] 4 | 5 | import torch 6 | 7 | from onnx2torch.node_converters.base_element_wise import OnnxBaseElementWise 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OperationConverterResult 12 | from onnx2torch.utils.common import onnx_mapping_from_node 13 | 14 | 15 | class OnnxMinMax(OnnxBaseElementWise): # pylint: disable=missing-docstring 16 | _OPERATORS = { 17 | 'Min': torch.amin, 18 | 'Max': torch.amax, 19 | } 20 | 21 | def __init__(self, op_type: str): 22 | super().__init__(op_type=op_type) 23 | self._operator = self._OPERATORS[op_type] 24 | 25 | def apply_reduction(self, *tensors: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 26 | broadcast_shape = self._broadcast_shape(*tensors) 27 | broadcast_tensors = [t.broadcast_to(broadcast_shape) for t in tensors] 28 | 
stacked_tensors = torch.stack(broadcast_tensors) 29 | output = self._operator(stacked_tensors, dim=0) 30 | return output 31 | 32 | 33 | @add_converter(operation_type='Min', version=8) 34 | @add_converter(operation_type='Min', version=12) 35 | @add_converter(operation_type='Min', version=13) 36 | @add_converter(operation_type='Max', version=8) 37 | @add_converter(operation_type='Max', version=12) 38 | @add_converter(operation_type='Max', version=13) 39 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 40 | return OperationConverterResult( 41 | torch_module=OnnxMinMax(node.operation_type), 42 | onnx_mapping=onnx_mapping_from_node(node=node), 43 | ) 44 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/mod.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxMod', 3 | ] 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OnnxToTorchModule 12 | from onnx2torch.utils.common import OperationConverterResult 13 | from onnx2torch.utils.common import onnx_mapping_from_node 14 | 15 | 16 | class OnnxMod(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 17 | def __init__(self, fmod: int): 18 | super().__init__() 19 | self.fmod = fmod 20 | 21 | if self.fmod not in [0, 1]: 22 | raise ValueError(f'OnnxMod fom must be 0 or 1, but get {self.fmod}') 23 | 24 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 25 | return torch.fmod(x, y) if self.fmod else torch.remainder(x, y) 26 | 27 | 28 | @add_converter(operation_type='Mod', version=10) 29 | @add_converter(operation_type='Mod', version=13) 30 | def _(node: OnnxNode, graph: 
OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 31 | node_attributes = node.attributes 32 | fmod = node_attributes.get('fmod', 0) 33 | return OperationConverterResult( 34 | torch_module=OnnxMod(fmod=fmod), 35 | onnx_mapping=onnx_mapping_from_node(node=node), 36 | ) 37 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/neg.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxNeg', 3 | ] 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OnnxToTorchModule 12 | from onnx2torch.utils.common import OperationConverterResult 13 | from onnx2torch.utils.common import onnx_mapping_from_node 14 | 15 | 16 | class OnnxNeg(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 17 | def forward(self, input_tensor: torch.Tensor): # pylint: disable=missing-function-docstring 18 | return -input_tensor 19 | 20 | 21 | @add_converter(operation_type='Neg', version=1) 22 | @add_converter(operation_type='Neg', version=6) 23 | @add_converter(operation_type='Neg', version=13) 24 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 25 | return OperationConverterResult( 26 | torch_module=OnnxNeg(), 27 | onnx_mapping=onnx_mapping_from_node(node=node), 28 | ) 29 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/nonzero.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-docstring 2 | __all__ = [ 3 | 'OnnxNonZero', 4 | ] 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from onnx2torch.node_converters.registry import add_converter 10 | from onnx2torch.onnx_graph 
import OnnxGraph 11 | from onnx2torch.onnx_node import OnnxNode 12 | from onnx2torch.utils.common import OnnxMapping 13 | from onnx2torch.utils.common import OnnxToTorchModule 14 | from onnx2torch.utils.common import OperationConverterResult 15 | 16 | 17 | class OnnxNonZero(nn.Module, OnnxToTorchModule): 18 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 19 | return torch.nonzero(input_tensor) 20 | 21 | 22 | @add_converter(operation_type='NonZero', version=13) 23 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 24 | del graph 25 | torch_module = OnnxNonZero() 26 | 27 | return OperationConverterResult( 28 | torch_module=torch_module, 29 | onnx_mapping=OnnxMapping( 30 | inputs=(node.input_values[0],), 31 | outputs=node.output_values, 32 | ), 33 | ) 34 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/pow.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxPow', 3 | 'OnnxSqrt', 4 | ] 5 | 6 | from typing import Optional 7 | 8 | import torch 9 | from torch import nn 10 | 11 | from onnx2torch.node_converters.registry import add_converter 12 | from onnx2torch.onnx_graph import OnnxGraph 13 | from onnx2torch.onnx_node import OnnxNode 14 | from onnx2torch.utils.common import OnnxToTorchModule 15 | from onnx2torch.utils.common import OperationConverterResult 16 | from onnx2torch.utils.common import old_style_broadcast 17 | from onnx2torch.utils.common import onnx_mapping_from_node 18 | 19 | 20 | class OnnxPow(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 21 | def __init__(self, broadcast: Optional[int] = None, axis: Optional[int] = None): 22 | super().__init__() 23 | self.axis = axis 24 | self.broadcast = broadcast 25 | 26 | def forward( # pylint: disable=missing-function-docstring 27 | self, 28 | input_tensor: torch.Tensor, 29 | exponent: torch.Tensor, 30 | ) -> torch.Tensor: 31 | if 
self.broadcast == 1 and self.axis is not None: 32 | exponent = old_style_broadcast(input_tensor, exponent, self.axis) 33 | 34 | return torch.pow(input_tensor, exponent) 35 | 36 | 37 | class OnnxSqrt(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 38 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 39 | return torch.sqrt(input_tensor) 40 | 41 | 42 | @add_converter(operation_type='Pow', version=1) 43 | @add_converter(operation_type='Pow', version=7) 44 | @add_converter(operation_type='Pow', version=12) 45 | @add_converter(operation_type='Pow', version=13) 46 | @add_converter(operation_type='Pow', version=15) 47 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 48 | return OperationConverterResult( 49 | torch_module=OnnxPow( 50 | broadcast=node.attributes.get('broadcast', None), 51 | axis=node.attributes.get('axis', None), 52 | ), 53 | onnx_mapping=onnx_mapping_from_node(node=node), 54 | ) 55 | 56 | 57 | @add_converter(operation_type='Sqrt', version=1) 58 | @add_converter(operation_type='Sqrt', version=6) 59 | @add_converter(operation_type='Sqrt', version=13) 60 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 61 | return OperationConverterResult( 62 | torch_module=OnnxSqrt(), 63 | onnx_mapping=onnx_mapping_from_node(node=node), 64 | ) 65 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/range.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-docstring 2 | __all__ = [ 3 | 'OnnxRange', 4 | ] 5 | 6 | from typing import Union 7 | 8 | import torch 9 | from torch import nn 10 | 11 | from onnx2torch.node_converters.registry import add_converter 12 | from onnx2torch.onnx_graph import OnnxGraph 13 | from onnx2torch.onnx_node import OnnxNode 14 | from 
onnx2torch.utils.common import OperationConverterResult 15 | from onnx2torch.utils.common import onnx_mapping_from_node 16 | from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx 17 | from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport 18 | 19 | 20 | class OnnxRange(nn.Module, OnnxToTorchModuleWithCustomExport): 21 | def __init__(self): 22 | super().__init__() 23 | self.register_buffer('dummy_buffer', torch.Tensor(), persistent=False) 24 | 25 | @staticmethod 26 | def _get_scalar(value) -> Union[float, int]: 27 | if isinstance(value, torch.Tensor): 28 | return value.item() 29 | 30 | return value 31 | 32 | def _arange( 33 | self, 34 | start: Union[torch.Tensor, float, int], 35 | limit: Union[torch.Tensor, float, int], 36 | delta: Union[torch.Tensor, float, int], 37 | ) -> torch.Tensor: 38 | return torch.arange( 39 | start=self._get_scalar(start), 40 | end=self._get_scalar(limit), 41 | step=self._get_scalar(delta), 42 | device=self.dummy_buffer.device, 43 | ) 44 | 45 | def forward( 46 | self, 47 | start: Union[torch.Tensor, float, int], 48 | limit: Union[torch.Tensor, float, int], 49 | delta: Union[torch.Tensor, float, int], 50 | ) -> torch.Tensor: 51 | def _forward() -> torch.Tensor: 52 | return self._arange(start, limit, delta) 53 | 54 | if torch.onnx.is_in_onnx_export(): 55 | return DefaultExportToOnnx.export(_forward, 'Range', start, limit, delta, {}) 56 | 57 | return _forward() 58 | 59 | 60 | @add_converter(operation_type='Range', version=11) 61 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 62 | del graph 63 | return OperationConverterResult( 64 | torch_module=OnnxRange(), 65 | onnx_mapping=onnx_mapping_from_node(node), 66 | ) 67 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/reciprocal.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxReciprocal', 3 | ] 4 | 5 | 
import torch 6 | from torch import nn 7 | 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OnnxToTorchModule 12 | from onnx2torch.utils.common import OperationConverterResult 13 | from onnx2torch.utils.common import onnx_mapping_from_node 14 | 15 | 16 | class OnnxReciprocal(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 17 | def forward(self, x): # pylint: disable=missing-function-docstring 18 | return torch.reciprocal(x) 19 | 20 | 21 | @add_converter(operation_type='Reciprocal', version=1) 22 | @add_converter(operation_type='Reciprocal', version=6) 23 | @add_converter(operation_type='Reciprocal', version=13) 24 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 25 | return OperationConverterResult( 26 | torch_module=OnnxReciprocal(), 27 | onnx_mapping=onnx_mapping_from_node(node=node), 28 | ) 29 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/registry.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Callable 3 | from typing import NamedTuple 4 | 5 | from onnx import defs 6 | 7 | from onnx2torch.onnx_graph import OnnxGraph 8 | from onnx2torch.onnx_node import OnnxNode 9 | from onnx2torch.utils.common import OperationConverterResult 10 | 11 | _LOGGER = logging.getLogger(__name__) 12 | _CONVERTER_REGISTRY = {} 13 | 14 | 15 | class OperationDescription(NamedTuple): # pylint: disable=missing-class-docstring 16 | domain: str 17 | operation_type: str 18 | version: int 19 | 20 | 21 | TConverter = Callable[[OnnxNode, OnnxGraph], OperationConverterResult] 22 | 23 | 24 | def add_converter( # pylint: disable=missing-function-docstring 25 | operation_type: str, 26 | version: int, 27 | domain: str = 
defs.ONNX_DOMAIN, 28 | ): 29 | description = OperationDescription( 30 | domain=domain, 31 | operation_type=operation_type, 32 | version=version, 33 | ) 34 | 35 | def deco(converter: TConverter): 36 | if description in _CONVERTER_REGISTRY: 37 | raise ValueError(f'Operation "{description}" already registered') 38 | 39 | _CONVERTER_REGISTRY[description] = converter 40 | _LOGGER.debug(f'Operation converter registered {description}') 41 | 42 | return converter 43 | 44 | return deco 45 | 46 | 47 | def get_converter( # pylint: disable=missing-function-docstring 48 | operation_type: str, 49 | version: int, 50 | domain: str = defs.ONNX_DOMAIN, 51 | ) -> TConverter: 52 | try: 53 | version = defs.get_schema( 54 | operation_type, 55 | domain=domain, 56 | max_inclusive_version=version, 57 | ).since_version 58 | except (RuntimeError, defs.SchemaError): 59 | pass 60 | 61 | description = OperationDescription( 62 | domain=domain, 63 | operation_type=operation_type, 64 | version=version, 65 | ) 66 | 67 | converter = _CONVERTER_REGISTRY.get(description, None) 68 | if converter is None: 69 | raise NotImplementedError(f'Converter is not implemented ({description})') 70 | 71 | return converter 72 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/reshape.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxReshape', 3 | ] 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OperationConverterResult 12 | from onnx2torch.utils.common import onnx_mapping_from_node 13 | from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx 14 | from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport 15 | 16 | 17 | class OnnxReshape(nn.Module, 
OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring 18 | @staticmethod 19 | def _do_reshape(input_tensor: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: 20 | if torch.any(shape == 0): 21 | shape = [input_tensor.shape[i] if dim_size == 0 else dim_size for i, dim_size in enumerate(shape)] 22 | 23 | return torch.reshape(input_tensor, torch.Size(shape)) 24 | 25 | def forward( # pylint: disable=missing-function-docstring 26 | self, 27 | input_tensor: torch.Tensor, 28 | shape: torch.Tensor, 29 | ) -> torch.Tensor: 30 | def _forward() -> torch.Tensor: 31 | return self._do_reshape(input_tensor, shape) 32 | 33 | if torch.onnx.is_in_onnx_export(): 34 | return DefaultExportToOnnx.export(_forward, 'Reshape', input_tensor, shape, {}) 35 | 36 | return _forward() 37 | 38 | 39 | @add_converter(operation_type='Reshape', version=5) 40 | @add_converter(operation_type='Reshape', version=13) 41 | @add_converter(operation_type='Reshape', version=14) 42 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 43 | if node.attributes.get('allowzero', 0) == 1: 44 | raise NotImplementedError('"allowzero=1" is not implemented') 45 | 46 | return OperationConverterResult( 47 | torch_module=OnnxReshape(), 48 | onnx_mapping=onnx_mapping_from_node(node=node), 49 | ) 50 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/roundings.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxRound', 3 | ] 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OnnxToTorchModule 12 | from onnx2torch.utils.common import OperationConverterResult 13 | from onnx2torch.utils.common import onnx_mapping_from_node 14 | 15 | 
_TORCH_ROUND_FROM_ONNX_TYPE = { 16 | 'Ceil': torch.ceil, 17 | 'Floor': torch.floor, 18 | 'Round': torch.round, 19 | } 20 | 21 | 22 | class OnnxRound(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 23 | def __init__(self, round_type: str): 24 | super().__init__() 25 | self.round_function = _TORCH_ROUND_FROM_ONNX_TYPE[round_type] 26 | 27 | def forward(self, input_tensor: torch.Tensor): # pylint: disable=missing-function-docstring 28 | return self.round_function(input_tensor) 29 | 30 | 31 | @add_converter(operation_type='Ceil', version=13) 32 | @add_converter(operation_type='Ceil', version=6) 33 | @add_converter(operation_type='Floor', version=13) 34 | @add_converter(operation_type='Floor', version=6) 35 | @add_converter(operation_type='Round', version=11) 36 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 37 | return OperationConverterResult( 38 | torch_module=OnnxRound(node.operation_type), 39 | onnx_mapping=onnx_mapping_from_node(node=node), 40 | ) 41 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/shape.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxShape', 3 | ] 4 | 5 | from typing import Any 6 | from typing import Dict 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from onnx2torch.node_converters.registry import add_converter 13 | from onnx2torch.onnx_graph import OnnxGraph 14 | from onnx2torch.onnx_node import OnnxNode 15 | from onnx2torch.utils.common import OperationConverterResult 16 | from onnx2torch.utils.common import get_onnx_version 17 | from onnx2torch.utils.common import onnx_mapping_from_node 18 | from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx 19 | from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport 20 | 21 | 22 | class OnnxShape(nn.Module, 
OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring 23 | def __init__(self, start: int = 0, end: Optional[int] = None): 24 | super().__init__() 25 | self._start = start 26 | self._end = end 27 | 28 | def _onnx_attrs(self, opset_version: int) -> Dict[str, Any]: 29 | if opset_version < 15: 30 | if self._start != 0: 31 | raise ValueError(f'Shape from opset < 15 does not support start != 0, got {self._start}') 32 | if self._end is not None: 33 | raise ValueError(f'Shape from opset < 15 does not support end != None, got {self._end}') 34 | return {} 35 | 36 | onnx_attrs: Dict[str, Any] = {'start_i': self._start} 37 | if self._end: 38 | onnx_attrs['end_i'] = self._end 39 | 40 | return onnx_attrs 41 | 42 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 43 | def _forward(): 44 | return torch.tensor( 45 | input_tensor.shape[self._start : self._end], 46 | device=input_tensor.device, 47 | ) 48 | 49 | if torch.onnx.is_in_onnx_export(): 50 | onnx_attrs = self._onnx_attrs(opset_version=get_onnx_version()) 51 | return DefaultExportToOnnx.export(_forward, 'Shape', input_tensor, onnx_attrs) 52 | 53 | return _forward() 54 | 55 | 56 | @add_converter(operation_type='Shape', version=1) 57 | @add_converter(operation_type='Shape', version=13) 58 | @add_converter(operation_type='Shape', version=15) 59 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 60 | return OperationConverterResult( 61 | torch_module=OnnxShape( 62 | start=node.attributes.get('start', 0), 63 | end=node.attributes.get('end', None), 64 | ), 65 | onnx_mapping=onnx_mapping_from_node(node=node), 66 | ) 67 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/split.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxSplit', 3 | 'OnnxSplit13', 4 | ] 5 | 6 | from typing 
import List 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from onnx2torch.node_converters.registry import add_converter 13 | from onnx2torch.onnx_graph import OnnxGraph 14 | from onnx2torch.onnx_node import OnnxNode 15 | from onnx2torch.utils.common import OnnxToTorchModule 16 | from onnx2torch.utils.common import OperationConverterResult 17 | from onnx2torch.utils.common import onnx_mapping_from_node 18 | 19 | 20 | class OnnxSplit13(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 21 | def __init__(self, num_splits: int, axis: int = 0): 22 | super().__init__() 23 | 24 | self.axis = axis 25 | self.num_splits = num_splits 26 | 27 | def forward( # pylint: disable=missing-function-docstring 28 | self, 29 | input_tensor: torch.Tensor, 30 | split: Optional[torch.Tensor] = None, 31 | ) -> torch.Tensor: 32 | if split is None: 33 | axis_len = input_tensor.shape[self.axis] 34 | split_size_or_sections = axis_len // self.num_splits 35 | else: 36 | split_size_or_sections = split.tolist() 37 | 38 | return torch.split(input_tensor, split_size_or_sections, dim=self.axis) 39 | 40 | 41 | class OnnxSplit(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 42 | def __init__(self, num_splits: int, axis: int = 0, split: Optional[List[int]] = None): 43 | super().__init__() 44 | 45 | self.axis = axis 46 | self.num_splits = num_splits 47 | self.split = split 48 | 49 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 50 | if self.split is None: 51 | axis_len = input_tensor.shape[self.axis] 52 | split_size_or_sections = axis_len // self.num_splits 53 | else: 54 | split_size_or_sections = self.split 55 | 56 | return torch.split(input_tensor, split_size_or_sections, dim=self.axis) 57 | 58 | 59 | @add_converter(operation_type='Split', version=13) 60 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: 
disable=unused-argument 61 | axis = node.attributes.get('axis', 0) 62 | num_splits = len(node.output_values) 63 | return OperationConverterResult( 64 | torch_module=OnnxSplit13(axis=axis, num_splits=num_splits), 65 | onnx_mapping=onnx_mapping_from_node(node=node), 66 | ) 67 | 68 | 69 | @add_converter(operation_type='Split', version=11) 70 | @add_converter(operation_type='Split', version=2) 71 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 72 | axis = node.attributes.get('axis', 0) 73 | split = node.attributes.get('split', None) 74 | num_splits = len(node.output_values) 75 | return OperationConverterResult( 76 | torch_module=OnnxSplit(axis=axis, split=split, num_splits=num_splits), 77 | onnx_mapping=onnx_mapping_from_node(node=node), 78 | ) 79 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/sum.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxSum', 3 | ] 4 | 5 | import torch 6 | 7 | from onnx2torch.node_converters.base_element_wise import OnnxBaseElementWise 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OperationConverterResult 12 | from onnx2torch.utils.common import onnx_mapping_from_node 13 | 14 | 15 | class OnnxSum(OnnxBaseElementWise): # pylint: disable=missing-docstring 16 | def __init__(self): 17 | super().__init__(op_type='Sum') 18 | 19 | def apply_reduction(self, *tensors: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 20 | broadcast_shape = self._broadcast_shape(*tensors) 21 | 22 | output = torch.zeros(broadcast_shape, dtype=tensors[0].dtype, device=tensors[0].device) 23 | for y in tensors: 24 | output.add_(y) 25 | 26 | return output 27 | 28 | 29 | @add_converter(operation_type='Sum', version=8) 
30 | @add_converter(operation_type='Sum', version=13) 31 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 32 | return OperationConverterResult( 33 | torch_module=OnnxSum(), 34 | onnx_mapping=onnx_mapping_from_node(node=node), 35 | ) 36 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/tile.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-docstring 2 | __all__ = [ 3 | 'OnnxTile', 4 | ] 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from onnx2torch.node_converters.registry import add_converter 10 | from onnx2torch.onnx_graph import OnnxGraph 11 | from onnx2torch.onnx_node import OnnxNode 12 | from onnx2torch.utils.common import OperationConverterResult 13 | from onnx2torch.utils.common import onnx_mapping_from_node 14 | from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx 15 | from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport 16 | 17 | 18 | class OnnxTile(nn.Module, OnnxToTorchModuleWithCustomExport): 19 | def forward(self, input_tensor: torch.Tensor, repeats: torch.Tensor) -> torch.Tensor: 20 | def _forward() -> torch.Tensor: 21 | return input_tensor.repeat(torch.Size(repeats)) 22 | 23 | if torch.onnx.is_in_onnx_export(): 24 | return DefaultExportToOnnx.export(_forward, 'Tile', input_tensor, repeats, {}) 25 | 26 | return _forward() 27 | 28 | 29 | @add_converter(operation_type='Tile', version=6) 30 | @add_converter(operation_type='Tile', version=13) 31 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: 32 | del graph 33 | return OperationConverterResult( 34 | torch_module=OnnxTile(), 35 | onnx_mapping=onnx_mapping_from_node(node=node), 36 | ) 37 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/topk.py: 
-------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxTopK', 3 | ] 4 | 5 | from typing import Tuple 6 | from typing import Union 7 | 8 | import torch 9 | from torch import nn 10 | 11 | from onnx2torch.node_converters.registry import add_converter 12 | from onnx2torch.onnx_graph import OnnxGraph 13 | from onnx2torch.onnx_node import OnnxNode 14 | from onnx2torch.utils.common import OnnxToTorchModule 15 | from onnx2torch.utils.common import OperationConverterResult 16 | from onnx2torch.utils.common import onnx_mapping_from_node 17 | 18 | 19 | class OnnxTopK(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 20 | def __init__(self, dim: int = -1, largest: int = 1, sorted_: int = 1): 21 | super().__init__() 22 | self.dim = dim 23 | self.largest = largest == 1 24 | self.sorted = sorted_ == 1 25 | 26 | def forward( # pylint: disable=missing-function-docstring, invalid-name 27 | self, 28 | input_tensor: torch.Tensor, 29 | k: Union[torch.Tensor, int], 30 | ) -> Tuple[torch.Tensor, torch.Tensor]: 31 | k = k[0] if isinstance(k, torch.Tensor) else k 32 | 33 | top_k = torch.topk( 34 | input_tensor, 35 | k=k, 36 | dim=self.dim, 37 | largest=self.largest, 38 | sorted=self.sorted, 39 | ) 40 | return top_k.values, top_k.indices 41 | 42 | 43 | @add_converter(operation_type='TopK', version=1) 44 | @add_converter(operation_type='TopK', version=10) 45 | @add_converter(operation_type='TopK', version=11) 46 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 47 | node_attributes = node.attributes 48 | axis = node_attributes.get('axis', -1) 49 | largest = node_attributes.get('largest', 1) 50 | sorted_ = node_attributes.get('sorted', 1) 51 | 52 | return OperationConverterResult( 53 | torch_module=OnnxTopK(dim=axis, largest=largest, sorted_=sorted_), 54 | onnx_mapping=onnx_mapping_from_node(node=node), 55 | ) 56 | 
-------------------------------------------------------------------------------- /onnx2torch/node_converters/transpose.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxTranspose', 3 | ] 4 | 5 | from typing import List 6 | from typing import Optional 7 | 8 | import torch 9 | from torch import nn 10 | 11 | from onnx2torch.node_converters.registry import add_converter 12 | from onnx2torch.onnx_graph import OnnxGraph 13 | from onnx2torch.onnx_node import OnnxNode 14 | from onnx2torch.utils.common import OnnxMapping 15 | from onnx2torch.utils.common import OnnxToTorchModule 16 | from onnx2torch.utils.common import OperationConverterResult 17 | 18 | 19 | class OnnxTranspose(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 20 | def __init__(self, perm: Optional[List[int]] = None): 21 | super().__init__() 22 | self.perm = perm 23 | 24 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 25 | if self.perm is None: 26 | self.perm = list(range(input_tensor.dim()))[::-1] 27 | 28 | return input_tensor.permute(self.perm) 29 | 30 | 31 | @add_converter(operation_type='Transpose', version=1) 32 | @add_converter(operation_type='Transpose', version=13) 33 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 34 | input_values = [node.input_values[0]] 35 | perm_value_name = node.input_values[1] if len(node.input_values) > 1 else None 36 | 37 | if perm_value_name is not None: 38 | perm = graph.initializers[perm_value_name].to_torch().tolist() 39 | else: 40 | perm = node.attributes.get('perm', None) 41 | if perm is not None: 42 | perm = torch.tensor(perm, dtype=torch.long).tolist() 43 | 44 | return OperationConverterResult( 45 | torch_module=OnnxTranspose(perm=perm), 46 | onnx_mapping=OnnxMapping( 47 | inputs=tuple(input_values), 48 | outputs=node.output_values, 49 | ), 50 | ) 51 | 
-------------------------------------------------------------------------------- /onnx2torch/node_converters/unsqueeze.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxUnsqueezeStaticAxes', 3 | 'OnnxUnsqueezeDynamicAxes', 4 | ] 5 | 6 | from typing import List 7 | 8 | import torch 9 | from torch import nn 10 | 11 | from onnx2torch.node_converters.registry import add_converter 12 | from onnx2torch.onnx_graph import OnnxGraph 13 | from onnx2torch.onnx_node import OnnxNode 14 | from onnx2torch.utils.common import OnnxMapping 15 | from onnx2torch.utils.common import OnnxToTorchModule 16 | from onnx2torch.utils.common import OperationConverterResult 17 | from onnx2torch.utils.common import get_const_value 18 | from onnx2torch.utils.common import onnx_mapping_from_node 19 | from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx 20 | from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport 21 | 22 | 23 | class OnnxUnsqueezeStaticAxes(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 24 | def __init__(self, axes: List[int]): 25 | super().__init__() 26 | self._axes = sorted(axes) 27 | 28 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring 29 | result = input_tensor 30 | for axes_id in self._axes: 31 | result = torch.unsqueeze(result, dim=axes_id) 32 | 33 | return result 34 | 35 | 36 | class OnnxUnsqueezeDynamicAxes( # pylint: disable=missing-class-docstring 37 | nn.Module, 38 | OnnxToTorchModuleWithCustomExport, 39 | ): 40 | def forward( # pylint: disable=missing-function-docstring 41 | self, 42 | input_tensor: torch.Tensor, 43 | axes: torch.Tensor, 44 | ) -> torch.Tensor: 45 | def _forward(): 46 | result = input_tensor 47 | for axes_id in torch.sort(axes).values: 48 | result = torch.unsqueeze(result, dim=axes_id) 49 | 50 | return result 51 | 52 | if torch.onnx.is_in_onnx_export(): 53 | 
return DefaultExportToOnnx.export(_forward, 'Unsqueeze', input_tensor, axes, {}) 54 | 55 | return _forward() 56 | 57 | 58 | @add_converter(operation_type='Unsqueeze', version=1) 59 | @add_converter(operation_type='Unsqueeze', version=11) 60 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 61 | axes = node.attributes['axes'] 62 | return OperationConverterResult( 63 | torch_module=OnnxUnsqueezeStaticAxes(axes=axes), 64 | onnx_mapping=onnx_mapping_from_node(node), 65 | ) 66 | 67 | 68 | @add_converter(operation_type='Unsqueeze', version=13) 69 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 70 | try: 71 | axes = get_const_value(node.input_values[1], graph) 72 | axes = axes.tolist() 73 | return OperationConverterResult( 74 | torch_module=OnnxUnsqueezeStaticAxes(axes=axes), 75 | onnx_mapping=OnnxMapping( 76 | inputs=(node.input_values[0],), 77 | outputs=node.output_values, 78 | ), 79 | ) 80 | except KeyError: 81 | pass 82 | 83 | return OperationConverterResult( 84 | torch_module=OnnxUnsqueezeDynamicAxes(), 85 | onnx_mapping=onnx_mapping_from_node(node), 86 | ) 87 | -------------------------------------------------------------------------------- /onnx2torch/node_converters/where.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'OnnxWhere', 3 | ] 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from onnx2torch.node_converters.registry import add_converter 9 | from onnx2torch.onnx_graph import OnnxGraph 10 | from onnx2torch.onnx_node import OnnxNode 11 | from onnx2torch.utils.common import OnnxToTorchModule 12 | from onnx2torch.utils.common import OperationConverterResult 13 | from onnx2torch.utils.common import onnx_mapping_from_node 14 | 15 | 16 | class OnnxWhere(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring 17 | def forward( # pylint: disable=missing-function-docstring 18 | 
self, 19 | condition: torch.Tensor, 20 | x: torch.Tensor, 21 | y: torch.Tensor, 22 | ) -> torch.Tensor: 23 | return torch.where(condition, x, y) 24 | 25 | 26 | @add_converter(operation_type='Where', version=9) 27 | @add_converter(operation_type='Where', version=16) 28 | def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument 29 | return OperationConverterResult( 30 | torch_module=OnnxWhere(), 31 | onnx_mapping=onnx_mapping_from_node(node=node), 32 | ) 33 | -------------------------------------------------------------------------------- /onnx2torch/onnx_node.py: -------------------------------------------------------------------------------- 1 | from types import MappingProxyType 2 | from typing import Any 3 | from typing import Mapping 4 | from typing import Tuple 5 | 6 | from onnx.onnx_ml_pb2 import AttributeProto 7 | from onnx.onnx_ml_pb2 import NodeProto 8 | 9 | from onnx2torch.onnx_tensor import OnnxTensor 10 | 11 | 12 | class OnnxNode: # pylint: disable=missing-class-docstring 13 | def __init__(self, onnx_node_proto: NodeProto, unique_name: str): 14 | self._proto = onnx_node_proto 15 | self._unique_name = unique_name 16 | self._input_values = tuple(onnx_node_proto.input) 17 | self._output_values = tuple(onnx_node_proto.output) 18 | self._inputs = None 19 | 20 | self._proto_attributes = { 21 | attribute.name: OnnxNode._parse_attribute_value(attribute) for attribute in self._proto.attribute 22 | } 23 | 24 | @staticmethod 25 | def _parse_attribute_value(attribute: AttributeProto) -> Any: 26 | if attribute.HasField('i'): 27 | value = attribute.i 28 | elif attribute.HasField('f'): 29 | value = attribute.f 30 | elif attribute.HasField('s'): 31 | value = str(attribute.s, 'utf-8') 32 | elif attribute.HasField('t'): 33 | value = OnnxTensor(attribute.t) 34 | elif attribute.ints: 35 | value = list(attribute.ints) 36 | elif attribute.floats: 37 | value = list(attribute.floats) 38 | elif attribute.strings: 39 | value = [str(s, 
'utf-8') for s in attribute.strings] 40 | elif attribute.tensors: 41 | value = [OnnxTensor(t) for t in attribute.tensors] 42 | else: 43 | value = attribute 44 | 45 | return value 46 | 47 | @property 48 | def proto(self) -> NodeProto: # pylint: disable=missing-function-docstring 49 | return self._proto 50 | 51 | @property 52 | def name(self) -> str: # pylint: disable=missing-function-docstring 53 | return self._proto.name 54 | 55 | @property 56 | def unique_name(self) -> str: # pylint: disable=missing-function-docstring 57 | return self._unique_name 58 | 59 | @property 60 | def domain(self) -> str: # pylint: disable=missing-function-docstring 61 | return self._proto.domain 62 | 63 | @property 64 | def operation_type(self) -> str: # pylint: disable=missing-function-docstring 65 | return self._proto.op_type 66 | 67 | @property 68 | def input_values(self) -> Tuple[str, ...]: # pylint: disable=missing-function-docstring 69 | return self._input_values 70 | 71 | @property 72 | def output_values(self) -> Tuple[str, ...]: # pylint: disable=missing-function-docstring 73 | return self._output_values 74 | 75 | @property 76 | def attributes(self) -> Mapping[str, Any]: # pylint: disable=missing-function-docstring 77 | return MappingProxyType(self._proto_attributes) 78 | -------------------------------------------------------------------------------- /onnx2torch/onnx_tensor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from onnx import numpy_helper 4 | from onnx.onnx_ml_pb2 import TensorProto 5 | 6 | 7 | class OnnxTensor: # pylint: disable=missing-class-docstring 8 | def __init__(self, onnx_tensor_proto: TensorProto): 9 | self._proto = onnx_tensor_proto 10 | 11 | @classmethod 12 | def from_numpy(cls, array: np.ndarray, name: str = None): # pylint: disable=missing-function-docstring 13 | onnx_tensor_proto = numpy_helper.from_array(array, name=name) 14 | return cls(onnx_tensor_proto) 15 | 16 | 
@classmethod 17 | def from_torch(cls, tensor: torch.Tensor, name: str = None): # pylint: disable=missing-function-docstring 18 | array = tensor.detach().cpu().numpy() 19 | return cls.from_numpy(array, name=name) 20 | 21 | @property 22 | def proto(self) -> TensorProto: # pylint: disable=missing-function-docstring 23 | return self._proto 24 | 25 | @property 26 | def name(self) -> str: # pylint: disable=missing-function-docstring 27 | return self._proto.name 28 | 29 | def to_numpy(self) -> np.ndarray: # pylint: disable=missing-function-docstring 30 | return numpy_helper.to_array(self._proto).copy() 31 | 32 | def to_torch(self) -> torch.Tensor: # pylint: disable=missing-function-docstring 33 | return torch.from_numpy(self.to_numpy()) 34 | -------------------------------------------------------------------------------- /onnx2torch/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ENOT-AutoDL/onnx2torch/369412ad62c81ca5b360554572820755e31b9b7a/onnx2torch/utils/__init__.py -------------------------------------------------------------------------------- /onnx2torch/utils/common.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from typing import NamedTuple 3 | from typing import Tuple 4 | from typing import Union 5 | 6 | import torch 7 | from onnx import ValueInfoProto # pylint: disable=no-name-in-module 8 | from torch import nn 9 | from torch.onnx import symbolic_helper 10 | 11 | from onnx2torch.onnx_graph import OnnxGraph 12 | from onnx2torch.onnx_node import OnnxNode 13 | 14 | 15 | class OnnxToTorchModule: 16 | """ 17 | Marker class for onnx2torch modules. 18 | """ 19 | 20 | pass # pylint: disable=unnecessary-pass 21 | 22 | 23 | class OnnxMapping(NamedTuple): # pylint: disable=missing-class-docstring 24 | inputs: Tuple[str, ...] 25 | outputs: Tuple[str, ...] 
26 | 27 | 28 | class OperationConverterResult(NamedTuple): # pylint: disable=missing-class-docstring 29 | torch_module: nn.Module 30 | onnx_mapping: OnnxMapping 31 | 32 | 33 | def onnx_mapping_from_node(node: OnnxNode) -> OnnxMapping: # pylint: disable=missing-function-docstring 34 | return OnnxMapping( 35 | inputs=node.input_values, 36 | outputs=node.output_values, 37 | ) 38 | 39 | 40 | def get_onnx_version() -> int: 41 | """Returns opset version at the time of the export.""" 42 | if hasattr(symbolic_helper, 'GLOBALS'): 43 | return symbolic_helper.GLOBALS.export_onnx_opset_version 44 | 45 | return symbolic_helper._export_onnx_opset_version # pylint: disable=no-member, protected-access 46 | 47 | 48 | def get_shape_from_value_info(value_info: ValueInfoProto) -> List[int]: # pylint: disable=missing-function-docstring 49 | return [dim.dim_value for dim in value_info.type.tensor_type.shape.dim] 50 | 51 | 52 | def get_const_value( # pylint: disable=missing-function-docstring 53 | name: str, 54 | graph: OnnxGraph, 55 | ) -> Union[torch.Tensor, float, int, str, List]: 56 | if name in graph.initializers: 57 | return graph.initializers[name].to_torch() 58 | 59 | try: 60 | node, _ = graph.value_as_node_output(name) 61 | except KeyError as exc: 62 | raise KeyError(f'Tensor "{name}" is not found in constant values') from exc 63 | 64 | if node.operation_type == 'Constant': 65 | attr_name, attr_value = next(iter(node.attributes.items())) 66 | if attr_name == 'value': 67 | attr_value = attr_value.to_torch() 68 | 69 | return attr_value 70 | 71 | raise KeyError(f'Tensor "{name}" is not found in constant values') 72 | 73 | 74 | def old_style_broadcast( # pylint: disable=missing-function-docstring 75 | first: torch.Tensor, 76 | second: torch.Tensor, 77 | axis: int, 78 | ) -> torch.Tensor: 79 | rank = len(first.shape) 80 | axis = axis + rank if axis < 0 else axis 81 | 82 | second_shape = [1] * axis + list(second.shape) 83 | second_shape = second_shape + [1] * (rank - 
len(second_shape)) 84 | 85 | return second.view(second_shape) 86 | -------------------------------------------------------------------------------- /onnx2torch/utils/custom_export_to_onnx.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | 'CustomExportToOnnx', 3 | 'DefaultExportToOnnx', 4 | 'OnnxToTorchModuleWithCustomExport', 5 | ] 6 | 7 | from typing import Any 8 | from typing import Callable 9 | from typing import Dict 10 | from typing import Optional 11 | 12 | import torch 13 | from torch import _C as torch_C 14 | 15 | from onnx2torch.utils.common import OnnxToTorchModule 16 | 17 | 18 | class OnnxToTorchModuleWithCustomExport(OnnxToTorchModule): 19 | """ 20 | Marker class for onnx2torch modules with custom export to onnx. 21 | """ 22 | 23 | def _onnx_attrs(self, opset_version: int) -> Dict[str, Any]: # pylint: disable=unused-argument 24 | """ 25 | Returns ONNX attributes with their values as a dictionary. 26 | 27 | Parameters 28 | ---------- 29 | opset_version : int 30 | ONNX opset version. 31 | The number of attributes, their names and values depend on opset version; 32 | function should return correct set of attributes. 33 | 34 | Returns 35 | ------- 36 | Dict[str, Any] 37 | ONNX attributes. 38 | 39 | """ 40 | return {} 41 | 42 | 43 | class CustomExportToOnnx(torch.autograd.Function): 44 | """Customizes ONNX exporting from PyTorch.""" 45 | 46 | _NEXT_FORWARD_FUNCTION: Optional[Callable] = None 47 | 48 | @classmethod 49 | def export(cls, forward_function: Callable, *args) -> Any: 50 | """ 51 | Substitues custom forward function. 52 | This function is closely related to forward function, it substitues `forward_function` to real forward. 53 | 54 | Old name: `set_forward_and_apply`. 
55 | """ 56 | CustomExportToOnnx._NEXT_FORWARD_FUNCTION = forward_function 57 | return cls.apply(*args) 58 | 59 | @staticmethod 60 | def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argument, arguments-differ 61 | """Applies custom forward function.""" 62 | if CustomExportToOnnx._NEXT_FORWARD_FUNCTION is None: 63 | raise RuntimeError('Forward function is not set') 64 | 65 | try: 66 | return CustomExportToOnnx._NEXT_FORWARD_FUNCTION() # pylint: disable=not-callable 67 | finally: 68 | CustomExportToOnnx._NEXT_FORWARD_FUNCTION = None 69 | 70 | @staticmethod 71 | def backward(ctx: Any, *grad_outputs: Any) -> Any: # pylint: disable=unused-argument, missing-function-docstring 72 | raise RuntimeError('Backward called while converting to ONNX') 73 | 74 | @staticmethod 75 | def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: # pylint: disable=unused-argument 76 | """Export implementation. Return ONNX operation from this function using graph.""" 77 | raise NotImplementedError 78 | 79 | 80 | class DefaultExportToOnnx(CustomExportToOnnx): # pylint: disable=abstract-method 81 | """ 82 | CustomExportToOnnx with default symbolic method implementation. 83 | 84 | Please follow our convention, args consists of: 85 | - op_type 86 | - operation inputs 87 | - operation attributes 88 | 89 | DO NOT REORDER! 90 | 91 | Note: the number of operation outputs can be added later. 
92 | 93 | This class should be used in most cases: 94 | >>> return DefaultExportToOnnx.export(_forward, op_type, *inputs, onnx_attrs) 95 | """ 96 | 97 | @staticmethod 98 | def symbolic(graph: torch_C.Graph, *args) -> torch_C.Value: 99 | op_type, *inputs, onnx_attrs = args 100 | return graph.op(op_type, *inputs, **onnx_attrs, outputs=1) 101 | -------------------------------------------------------------------------------- /onnx2torch/utils/dtype.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from typing import Dict 3 | from typing import Type 4 | from typing import Union 5 | 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def onnx_dtype_to_torch_dtype(dtype: int) -> Union[torch.dtype, Type[str], Type[bool]]: 11 | """ 12 | Convert ONNX dtype to PyTorch dtype. 13 | 14 | Parameters 15 | ---------- 16 | dtype : int 17 | ONNX data type. 18 | 19 | Returns 20 | ------- 21 | Union[torch.dtype, Type[str], Type[bool]] 22 | Corresponding PyTorch dtype. 23 | 24 | """ 25 | # https://github.com/onnx/onnx/blob/main/onnx/onnx-ml.proto#L485 26 | _dtypes: Dict[int, Union[torch.dtype, Type[str], Type[bool]]] = { 27 | 1: torch.float32, 28 | 2: torch.uint8, 29 | 3: torch.int8, 30 | # 4: UINT16 is not supported: https://github.com/pytorch/pytorch/issues/58734. 31 | 5: torch.int16, 32 | 6: torch.int32, 33 | 7: torch.int64, 34 | 8: str, 35 | 9: bool, 36 | 10: torch.float16, 37 | 11: torch.float64, 38 | # 12: UINT32 is not supported: https://github.com/pytorch/pytorch/issues/58734. 39 | # 13: UINT64 is not supported: https://github.com/pytorch/pytorch/issues/58734. 
40 | 14: torch.complex64, 41 | 15: torch.complex128, 42 | 16: torch.bfloat16, 43 | } 44 | try: 45 | return _dtypes[dtype] 46 | except KeyError as exc: 47 | raise ValueError(f'dtype={dtype} is not supported') from exc 48 | 49 | 50 | def onnx_dtype_to_numpy_dtype(dtype: int) -> Union[np.dtype, Type[str], Type[bool]]: 51 | """ 52 | Convert ONNX dtype to Numpy dtype. 53 | 54 | Parameters 55 | ---------- 56 | dtype : int 57 | ONNX data type. 58 | 59 | Returns 60 | ------- 61 | Union[torch.dtype, Type[str], Type[bool]] 62 | Corresponding Numpy dtype. 63 | 64 | """ 65 | # https://numpy.org/doc/stable/reference/arrays.dtypes.html 66 | _dtypes: Dict[int, Any] = { 67 | 1: np.float32, 68 | 2: np.uint8, 69 | 3: np.int8, 70 | 4: np.uint16, 71 | 5: np.int16, 72 | 6: np.int32, 73 | 7: np.int64, 74 | 8: str, 75 | 9: bool, 76 | 10: np.float16, 77 | 11: np.float64, 78 | 12: np.uint32, 79 | 13: np.uint64, 80 | 14: np.complex64, 81 | 15: np.complex128, 82 | # 16: bfloat16 is not supported. 83 | } 84 | try: 85 | return _dtypes[dtype] 86 | except KeyError as exc: 87 | raise ValueError(f'dtype={dtype} is not supported') from exc 88 | -------------------------------------------------------------------------------- /onnx2torch/utils/indices.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | __all__ = [ 4 | 'upcast_indices', 5 | ] 6 | 7 | _INT_DTYPES = ( 8 | torch.int8, 9 | torch.int16, 10 | torch.int32, 11 | torch.int64, 12 | ) 13 | 14 | 15 | def upcast_indices(indices: torch.Tensor) -> torch.Tensor: 16 | """ 17 | Upcasts indices tensor to torch.int64 (long) dtype. 18 | 19 | indices : torch.Tensor 20 | Indices for upcasting to torch.int64. 21 | 22 | Returns 23 | ------- 24 | torch.Tensor 25 | Upcasted to torch.int64 tensor. 
26 | 27 | """ 28 | if not any(indices.dtype == dtype for dtype in _INT_DTYPES): 29 | raise ValueError(f'Expected types of indices: {_INT_DTYPES}, got {indices.dtype} instead') 30 | return indices.type(dtype=torch.int64) 31 | -------------------------------------------------------------------------------- /onnx2torch/utils/padding.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from typing import Tuple 3 | from typing import Union 4 | 5 | from torch import nn 6 | 7 | from onnx2torch.node_converters.pad import OnnxPadStatic 8 | 9 | 10 | def is_symmetric_onnx_padding(padding: Tuple[int, ...]) -> bool: # pylint: disable=missing-function-docstring 11 | half_len = len(padding) // 2 12 | return padding[:half_len] == padding[half_len:] 13 | 14 | 15 | def onnx_auto_pad_to_torch_padding( # pylint: disable=missing-function-docstring 16 | auto_pad: str, 17 | onnx_padding: Tuple[int, ...], 18 | ) -> Tuple[Union[int, Tuple[int, ...]], Optional[nn.Module]]: 19 | if auto_pad == 'NOTSET': 20 | if onnx_padding is None: 21 | return 0, None 22 | 23 | if is_symmetric_onnx_padding(onnx_padding): 24 | half_len = len(onnx_padding) // 2 25 | return onnx_padding[:half_len], None 26 | 27 | return 0, OnnxPadStatic.create_from_onnx_params(onnx_pads=onnx_padding) 28 | 29 | if auto_pad == 'VALID': 30 | return 0, None 31 | 32 | if auto_pad in ('SAME_UPPER', 'SAME_LOWER'): 33 | raise NotImplementedError(f'"{auto_pad}" auto_pad is not implemented') 34 | 35 | raise ValueError(f'Got unexpected auto_pad value "{auto_pad}"') 36 | -------------------------------------------------------------------------------- /onnx2torch/utils/safe_shape_inference.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | import onnx 6 | from onnx.onnx_ml_pb2 import ModelProto 7 | from onnx.shape_inference import infer_shapes 8 | from 
onnx.shape_inference import infer_shapes_path 9 | 10 | 11 | def _is_big_model(model: ModelProto) -> bool: 12 | return model.ByteSize() / (1024 * 1024 * 1024) > 2.0 13 | 14 | 15 | def _shape_inference_by_model_path( 16 | model_path: Union[Path, str], 17 | output_path: Union[Path, str], 18 | **kwargs, 19 | ) -> ModelProto: 20 | model_path = str(Path(model_path).resolve()) 21 | output_path = str(Path(output_path).resolve()) 22 | infer_shapes_path(model_path, output_path=output_path, **kwargs) 23 | 24 | return onnx.load(output_path) 25 | 26 | 27 | def safe_shape_inference( # pylint: disable=missing-function-docstring 28 | onnx_model_or_path: Union[ModelProto, Path, str], 29 | **kwargs, 30 | ) -> ModelProto: 31 | if isinstance(onnx_model_or_path, ModelProto): 32 | if not _is_big_model(onnx_model_or_path): 33 | return infer_shapes(onnx_model_or_path, **kwargs) 34 | 35 | with tempfile.TemporaryDirectory() as tmp_dir: 36 | tmp_model_path = Path(tmp_dir) / 'model.onnx' 37 | onnx.save_model( 38 | proto=onnx_model_or_path, 39 | f=str(tmp_model_path), 40 | save_as_external_data=True, 41 | all_tensors_to_one_file=True, 42 | ) 43 | return _shape_inference_by_model_path(tmp_model_path, output_path=tmp_model_path, **kwargs) 44 | 45 | with tempfile.NamedTemporaryFile(dir=Path(onnx_model_or_path).parent) as tmp_model_file: 46 | return _shape_inference_by_model_path(onnx_model_or_path, output_path=tmp_model_file.name, **kwargs) 47 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = 'onnx2torch' 3 | version = '1.5.15' 4 | license = {file = 'LICENSE'} 5 | description = 'ONNX to PyTorch converter' 6 | readme = 'README.md' 7 | keywords = ['AI', 'onnx', 'torch', 'onnx2torch', 'converters'] 8 | authors = [{name = 'ENOT LLC', email = 'enot@enot.ai'}] 9 | classifiers = [ 10 | 'Development Status :: 5 - Production/Stable', 11 | 'License :: 
OSI Approved :: Apache Software License', 12 | 'Programming Language :: Python', 13 | 'Programming Language :: Python :: 3 :: Only', 14 | ] 15 | requires-python = '>=3.6' 16 | dependencies = [ 17 | 'numpy>=1.16.4', 18 | 'onnx>=1.9.0', 19 | 'torch>=1.8.0', 20 | 'torchvision>=0.9.0', 21 | ] 22 | 23 | [project.optional-dependencies] 24 | dev = [ 25 | 'pytest', 26 | 'black', 27 | 'isort', 28 | 'pylint', 29 | 'pre-commit', 30 | 'onnxruntime', 31 | 'Pillow', 32 | 'requests', 33 | 'googledrivedownloader', 34 | ] 35 | 36 | [project.urls] 37 | homepage = 'https://enot.ai' 38 | repository = 'https://github.com/ENOT-AutoDL/onnx2torch' 39 | 40 | [tool.setuptools.packages.find] 41 | include = ['onnx2torch*'] 42 | 43 | [tool.commitizen] 44 | name = 'cz_conventional_commits' 45 | tag_format = '$version' 46 | version_scheme = 'pep440' 47 | version_provider = 'pep621' 48 | update_changelog_on_bump = true 49 | major_version_zero = true 50 | 51 | [tool.docformatter] 52 | recursive = true 53 | wrap-summaries = 0 54 | wrap-descriptions = 0 55 | blank = true 56 | black = true 57 | pre-summary-newline = true 58 | 59 | [tool.yamlfix] 60 | line_length = 120 61 | explicit_start = false 62 | sequence_style = 'keep_style' 63 | whitelines = 1 64 | section_whitelines = 1 65 | 66 | [tool.black] 67 | line-length = 120 68 | target-version = ['py36', 'py37', 'py38', 'py39'] 69 | include = '\.pyi?$' 70 | skip-string-normalization = true 71 | 72 | [tool.isort] 73 | profile = 'black' 74 | line_length = 120 75 | ensure_newline_before_comments = true 76 | force_single_line = true 77 | 78 | [tool.pylint.master] 79 | load-plugins = ['pylint.extensions.docparams'] 80 | 81 | [tool.pylint.format] 82 | max-line-length = 120 83 | 84 | [tool.pylint.design] 85 | max-args = 12 86 | max-locals = 30 87 | max-attributes = 20 88 | min-public-methods = 0 89 | 90 | [tool.pylint.typecheck] 91 | generated-members = ['torch.*'] 92 | 93 | [tool.pylint.messages_control] 94 | disable = [ 95 | 'logging-fstring-interpolation', 
96 | 'cyclic-import', 97 | 'duplicate-code', 98 | 'missing-module-docstring', 99 | 'unnecessary-pass', 100 | 'no-name-in-module', 101 | ] 102 | 103 | [tool.pylint.BASIC] 104 | good-names = ['bs', 'bn'] 105 | 106 | [tool.pyright] 107 | reportMissingImports = false 108 | reportMissingTypeStubs = false 109 | reportWildcardImportFromLibrary = false 110 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | TMP_DIR = Path(__file__).parent / '.tmp' 4 | MODELS_DIR = TMP_DIR / 'models' 5 | DATASETS_DIR = TMP_DIR / 'datasets' 6 | 7 | TMP_DIR.mkdir(exist_ok=True) 8 | MODELS_DIR.mkdir(exist_ok=True) 9 | DATASETS_DIR.mkdir(exist_ok=True) 10 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ENOT-AutoDL/onnx2torch/369412ad62c81ca5b360554572820755e31b9b7a/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/node_converters/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ENOT-AutoDL/onnx2torch/369412ad62c81ca5b360554572820755e31b9b7a/tests/node_converters/__init__.py -------------------------------------------------------------------------------- /tests/node_converters/average_pool_max_pool_test.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from typing import List 3 | 4 | import numpy as np 5 | import onnx 6 | import pytest 7 | 8 | from tests.utils.common import check_onnx_model 9 | from tests.utils.common import make_model_from_nodes 10 | 11 | 12 | def _test_pool_op( 13 | op_type, 14 | input_shape: List[int], 15 | atol_onnx_torch: float = 
0.0, 16 | **kwargs, 17 | ) -> None: 18 | x = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32) 19 | test_inputs = {'x': x} 20 | 21 | node = onnx.helper.make_node( 22 | op_type, 23 | inputs=['x'], 24 | outputs=['y'], 25 | **kwargs, 26 | ) 27 | model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs) 28 | check_onnx_model( 29 | model, 30 | test_inputs, 31 | atol_onnx_torch=atol_onnx_torch, 32 | ) 33 | 34 | 35 | @pytest.mark.parametrize( 36 | 'op', 37 | ( 38 | 'MaxPool', 39 | 'AveragePool', 40 | ), 41 | ) 42 | @pytest.mark.parametrize( 43 | 'input_shape,kernel_shape,optional_attrs', 44 | ( 45 | # 1d 46 | ([2, 3, 16], [2], {}), 47 | ([2, 3, 16], [1], {}), 48 | ([2, 3, 16], [3], {}), 49 | ([2, 3, 16], [2], {'strides': [3]}), 50 | ([2, 3, 16], [2], {'ceil_mode': 1}), 51 | # 2d 52 | ([2, 3, 16, 16], [2, 2], {}), 53 | ([2, 3, 16, 16], [1, 2], {}), 54 | ([2, 3, 16, 16], [3, 2], {}), 55 | ([2, 3, 16, 16], [2, 2], {'strides': [2, 3]}), 56 | ([2, 3, 16, 16], [2, 2], {'ceil_mode': 1}), 57 | # 3d 58 | ([2, 3, 16, 16, 16], [2, 2, 2], {}), 59 | ([2, 3, 16, 16, 16], [1, 2, 3], {}), 60 | ([2, 3, 16, 16, 16], [3, 2, 1], {}), 61 | ([2, 3, 16, 16, 16], [2, 2, 2], {'strides': [1, 2, 3]}), 62 | ([2, 3, 16, 16, 16], [2, 2, 2], {'ceil_mode': 1}), 63 | ), 64 | ) 65 | def test_max_pool_average_pool( # pylint: disable=missing-function-docstring 66 | op: str, # pylint: disable=invalid-name 67 | input_shape: List[int], 68 | kernel_shape: List[int], 69 | optional_attrs: Dict, 70 | ) -> None: 71 | if op == 'AveragePool': 72 | optional_attrs['atol_onnx_torch'] = 10**-7 73 | 74 | _test_pool_op(op, input_shape=input_shape, kernel_shape=kernel_shape, **optional_attrs) 75 | 76 | 77 | @pytest.mark.parametrize( 78 | 'input_shape,kernel_shape,optional_attrs', 79 | ( 80 | # 1d 81 | ([2, 3, 16], [2], {'pads': [1] * 2}), 82 | ([2, 3, 16], [3], {'pads': [0, 1]}), 83 | ([2, 3, 16], [3], {'pads': [2, 0]}), 84 | # 2d 85 | ([2, 3, 16, 16], [2, 2], {'pads': 
[1] * 4}), 86 | ([2, 3, 16, 16], [2, 2], {'pads': [0] * 2 + [1] * 2}), 87 | ([2, 3, 16, 16], [3, 3], {'pads': [0, 1, 1, 0]}), 88 | # 3d 89 | ([2, 3, 16, 16, 16], [2, 2, 2], {'pads': [1] * 6}), 90 | ([2, 3, 16, 16, 16], [2, 2, 2], {'pads': [0] * 3 + [1] * 3}), 91 | ([2, 3, 16, 16, 16], [3, 3, 3], {'pads': [0, 1, 2, 2, 1, 0]}), 92 | ), 93 | ) 94 | def test_max_pool_padding( # pylint: disable=missing-function-docstring 95 | input_shape: List[int], 96 | kernel_shape: List[int], 97 | optional_attrs: Dict, 98 | ) -> None: 99 | _test_pool_op('MaxPool', input_shape=input_shape, kernel_shape=kernel_shape, **optional_attrs) 100 | -------------------------------------------------------------------------------- /tests/node_converters/batch_norm_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import onnx 5 | import pytest 6 | 7 | from tests.utils.common import check_onnx_model 8 | from tests.utils.common import make_model_from_nodes 9 | 10 | 11 | @pytest.mark.parametrize( 12 | 'parameters_as_inputs', 13 | (True, False), 14 | ) 15 | @pytest.mark.parametrize( 16 | 'input_shape', 17 | ( 18 | # 1d 19 | [2, 3, 16], 20 | [2, 1, 7], 21 | # 2d 22 | [2, 3, 16, 16], 23 | [2, 1, 7, 16], 24 | # 3d 25 | [2, 3, 16, 16, 16], 26 | [2, 1, 16, 7, 16], 27 | ), 28 | ) 29 | def test_batch_norm( # pylint: disable=missing-function-docstring 30 | input_shape: List[int], 31 | parameters_as_inputs: bool, 32 | ) -> None: 33 | num_features = input_shape[1] 34 | x = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32) 35 | scale = np.random.uniform(low=0.0, high=1.0, size=num_features).astype(np.float32) 36 | bias = np.random.uniform(low=-1.0, high=1.0, size=num_features).astype(np.float32) 37 | mean = np.random.uniform(low=-1.0, high=1.0, size=num_features).astype(np.float32) 38 | var = np.random.uniform(low=0.001, high=0.5, size=num_features).astype(np.float32) 39 | 40 | test_inputs 
= {'x': x} 41 | initializers = {} 42 | parameters = { 43 | 'scale': scale, 44 | 'bias': bias, 45 | 'mean': mean, 46 | 'var': var, 47 | } 48 | if parameters_as_inputs: 49 | initializers.update(parameters) 50 | else: 51 | test_inputs.update(parameters) 52 | 53 | node = onnx.helper.make_node( 54 | op_type='BatchNormalization', 55 | inputs=['x', 'scale', 'bias', 'mean', 'var'], 56 | outputs=['y'], 57 | ) 58 | 59 | model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) 60 | check_onnx_model( 61 | model, 62 | test_inputs, 63 | atol_onnx_torch=10**-6, 64 | atol_torch_cpu_cuda=10**-6, 65 | ) 66 | -------------------------------------------------------------------------------- /tests/node_converters/binary_operations_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | import pytest 4 | 5 | from tests.utils.common import check_onnx_model 6 | from tests.utils.common import make_model_from_nodes 7 | 8 | 9 | @pytest.mark.parametrize( 10 | 'op_type', 11 | ('Add', 'Sub', 'Mul', 'Div'), 12 | ) 13 | def test_math_binary_operation(op_type: str) -> None: # pylint: disable=missing-function-docstring 14 | input_shape = [10, 3, 128, 128] 15 | x = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32) 16 | y_variants = [ 17 | np.random.uniform(low=-1.0, high=1.0, size=1).astype(np.float32), 18 | np.random.uniform(low=-1.0, high=1.0, size=[1] * len(input_shape)).astype(np.float32), 19 | np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32), 20 | np.array([0.0], dtype=np.float32), 21 | ] 22 | for y in y_variants: 23 | test_inputs = {'x': x, 'y': y} 24 | initializers = {} 25 | node = onnx.helper.make_node( 26 | op_type=op_type, 27 | inputs=['x', 'y'], 28 | outputs=['z'], 29 | ) 30 | 31 | model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) 32 | check_onnx_model(model, test_inputs) 33 | 
34 | 35 | @pytest.mark.parametrize( 36 | 'x, y', 37 | [ 38 | (1, 2), 39 | (1, 5), 40 | (5, 30), 41 | (-1, 2), 42 | (-1, 5), 43 | (5, -30), 44 | (5, 2), 45 | (-5, 2), 46 | ], 47 | ) 48 | def test_div_operation(x: int, y: int) -> None: # pylint: disable=missing-function-docstring 49 | x_ = np.array(x, dtype=np.int64) # pylint: disable=invalid-name 50 | y_ = np.array(y, dtype=np.int64) # pylint: disable=invalid-name 51 | test_inputs = {'x': x_, 'y': y_} 52 | 53 | node = onnx.helper.make_node( 54 | op_type='Div', 55 | inputs=['x', 'y'], 56 | outputs=['z'], 57 | ) 58 | 59 | model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs) 60 | check_onnx_model(model, test_inputs) 61 | -------------------------------------------------------------------------------- /tests/node_converters/clip_test.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | import onnx 6 | 7 | from tests.utils.common import check_onnx_model 8 | from tests.utils.common import make_model_from_nodes 9 | 10 | 11 | def _test_clip( 12 | input_shape: Tuple[int, int, int, int], 13 | min_value: Optional[float] = None, 14 | max_value: Optional[float] = None, 15 | **kwargs, 16 | ) -> None: 17 | x_range = 2 * max_value if max_value is not None else 5 18 | x = np.random.uniform(low=-x_range, high=x_range, size=input_shape).astype(np.float32) 19 | test_inputs = {'x': x} 20 | 21 | initializers = {} 22 | if min_value is not None: 23 | initializers['min'] = np.array(min_value, dtype=np.float32) 24 | 25 | if max_value is not None: 26 | initializers['max'] = np.array(max_value, dtype=np.float32) 27 | 28 | node = onnx.helper.make_node( 29 | op_type='Clip', 30 | inputs=list(test_inputs) + list(initializers), 31 | outputs=['y'], 32 | **kwargs, 33 | ) 34 | model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) 35 | 
check_onnx_model(model, test_inputs) 36 | 37 | 38 | def _test_clip_opset9( 39 | input_shape: Tuple[int, int, int, int], 40 | **kwargs, 41 | ) -> None: 42 | x = np.random.uniform(low=-10.0, high=10.0, size=input_shape).astype(np.float32) 43 | test_inputs = {'x': x} 44 | 45 | node = onnx.helper.make_node( 46 | op_type='Clip', 47 | inputs=list(test_inputs), 48 | outputs=['y'], 49 | **kwargs, 50 | ) 51 | model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs, opset_version=9) 52 | check_onnx_model(model, test_inputs) 53 | 54 | 55 | def test_clip() -> None: # pylint: disable=missing-function-docstring 56 | _test_clip(input_shape=(2, 3, 16, 16), min_value=0.0, max_value=6.0) 57 | _test_clip(input_shape=(2, 3, 16, 16), min_value=0.0) 58 | _test_clip(input_shape=(2, 3, 16, 16), min_value=-1.5, max_value=2.5) 59 | _test_clip_opset9(input_shape=(2, 3, 16, 16), min=0.0, max=6.0) 60 | _test_clip_opset9(input_shape=(2, 3, 16, 16), min=0.0) 61 | _test_clip_opset9(input_shape=(2, 3, 16, 16), min=-1.7, max=2.8) 62 | -------------------------------------------------------------------------------- /tests/node_converters/comparisons_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import onnx 5 | import pytest 6 | from onnx.helper import make_tensor_value_info 7 | from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE 8 | 9 | from tests.utils.common import check_onnx_model 10 | from tests.utils.common import make_model_from_nodes 11 | 12 | 13 | def _test_comparison(op_type: str, x: np.ndarray, y: np.ndarray, opset_version: int = 13) -> None: 14 | test_inputs = {'x': x, 'y': y} 15 | 16 | node = onnx.helper.make_node( 17 | op_type=op_type, 18 | inputs=list(test_inputs), 19 | outputs=['out'], 20 | ) 21 | outputs_info = [ 22 | make_tensor_value_info( 23 | name='out', 24 | elem_type=NP_TYPE_TO_TENSOR_TYPE[np.dtype('bool')], 25 | shape=x.shape, 26 | ), 27 | ] 28 | 29 
| model = make_model_from_nodes( 30 | nodes=node, 31 | initializers={}, 32 | inputs_example=test_inputs, 33 | outputs_info=outputs_info, 34 | opset_version=opset_version, 35 | ) 36 | check_onnx_model(model, test_inputs) 37 | 38 | 39 | @pytest.mark.parametrize( 40 | 'op_type,x_shape,y_shape', 41 | ( 42 | ('Equal', [3, 4, 5], [5]), 43 | ('Equal', [3, 4, 5], [3, 4, 5]), 44 | ('Less', [3, 4, 5], [5]), 45 | ('Less', [3, 4, 5], [3, 4, 5]), 46 | ('Greater', [3, 4, 5], [5]), 47 | ('Greater', [3, 4, 5], [3, 4, 5]), 48 | ('LessOrEqual', [3, 4, 5], [5]), 49 | ('LessOrEqual', [3, 4, 5], [3, 4, 5]), 50 | ('GreaterOrEqual', [3, 4, 5], [5]), 51 | ('GreaterOrEqual', [3, 4, 5], [3, 4, 5]), 52 | ), 53 | ) 54 | def test_comparison( # pylint: disable=missing-function-docstring 55 | op_type: str, 56 | x_shape: List[int], 57 | y_shape: List[int], 58 | ) -> None: 59 | _test_comparison( 60 | op_type=op_type, 61 | x=np.random.randn(*x_shape), 62 | y=np.random.randn(*y_shape), 63 | ) 64 | -------------------------------------------------------------------------------- /tests/node_converters/concat_test.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | from typing import List 3 | 4 | import numpy as np 5 | import onnx 6 | from onnx.helper import make_tensor_value_info 7 | from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE 8 | 9 | from tests.utils.common import check_onnx_model 10 | from tests.utils.common import make_model_from_nodes 11 | 12 | 13 | def _test_concat( 14 | input_arrays_shapes: List[List[int]], 15 | opset_version: int, 16 | **kwargs, 17 | ) -> None: 18 | test_inputs = {} 19 | for index, input_array_shape in enumerate(input_arrays_shapes): 20 | x = np.random.uniform(low=-1.0, high=1.0, size=input_array_shape).astype(np.float32) 21 | node_name = f'x_{index}' 22 | test_inputs[node_name] = x 23 | 24 | node = onnx.helper.make_node( 25 | 'Concat', 26 | inputs=list(test_inputs), 27 | outputs=['y'], 28 | **kwargs, 29 
| ) 30 | 31 | onnx_type = NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')] 32 | outputs_info = [make_tensor_value_info(name='y', elem_type=onnx_type, shape=None)] 33 | model = make_model_from_nodes( 34 | nodes=node, 35 | initializers={}, 36 | inputs_example=test_inputs, 37 | outputs_info=outputs_info, 38 | opset_version=opset_version, 39 | ) 40 | check_onnx_model(model, test_inputs) 41 | 42 | 43 | def test_concat() -> None: # pylint: disable=missing-function-docstring 44 | opset_variants = (9, 13) 45 | axis_variants = (0, 1) 46 | for opset_version, axis in product(opset_variants, axis_variants): 47 | _test_concat( 48 | input_arrays_shapes=[[1, 3, 16, 16], [1, 3, 16, 16], [1, 3, 16, 16]], 49 | axis=axis, 50 | opset_version=opset_version, 51 | ) 52 | -------------------------------------------------------------------------------- /tests/node_converters/constant_of_shape_test.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import onnx 5 | import pytest 6 | from onnx import numpy_helper 7 | from onnx.helper import make_tensor_value_info 8 | from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE 9 | 10 | from tests.utils.common import check_onnx_model 11 | from tests.utils.common import make_model_from_nodes 12 | 13 | 14 | def _test_constant_of_shape(shape: np.ndarray, value: np.ndarray) -> None: 15 | test_inputs = {'shape': shape} 16 | onnx_type = NP_TYPE_TO_TENSOR_TYPE[value.dtype] 17 | 18 | node = onnx.helper.make_node( 19 | 'ConstantOfShape', 20 | inputs=list(test_inputs), 21 | outputs=['output'], 22 | value=numpy_helper.from_array(value, name='value'), 23 | ) 24 | 25 | outputs_info = [make_tensor_value_info(name='output', elem_type=onnx_type, shape=shape.tolist())] 26 | 27 | model = make_model_from_nodes( 28 | nodes=node, 29 | initializers={}, 30 | inputs_example=test_inputs, 31 | outputs_info=outputs_info, 32 | ) 33 | check_onnx_model(model, test_inputs) 34 | 35 | 36 | 
@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning') 37 | def test_constant_of_shape() -> None: # pylint: disable=missing-function-docstring 38 | for _ in range(10): 39 | size = random.randint(1, 6) 40 | shape = np.random.randint(low=1, high=2, size=(size,)) 41 | value = np.random.uniform(low=-10000, high=10000, size=(1,)) 42 | _test_constant_of_shape(shape, value) 43 | 44 | _test_constant_of_shape(np.asarray([3, 3]), np.asarray([True])) 45 | -------------------------------------------------------------------------------- /tests/node_converters/constant_test.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import onnx 5 | import pytest 6 | from onnx import numpy_helper 7 | from onnx.helper import make_tensor_value_info 8 | from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE 9 | 10 | from tests.utils.common import check_onnx_model 11 | from tests.utils.common import make_model_from_nodes 12 | 13 | 14 | def _test_constant_as_tensor(shape: Tuple[int, ...], dtype: np.dtype) -> None: 15 | values = np.random.randn(*shape).astype(dtype) 16 | onnx_type = NP_TYPE_TO_TENSOR_TYPE[values.dtype] 17 | node = onnx.helper.make_node( 18 | 'Constant', 19 | inputs=[], 20 | outputs=['values'], 21 | value=numpy_helper.from_array(values, name='const_tensor'), 22 | ) 23 | 24 | outputs_info = [make_tensor_value_info(name='values', elem_type=onnx_type, shape=values.shape)] 25 | model = make_model_from_nodes( 26 | nodes=node, 27 | initializers={}, 28 | inputs_example={}, 29 | outputs_info=outputs_info, 30 | ) 31 | check_onnx_model(model, onnx_inputs={}) 32 | 33 | 34 | @pytest.mark.filterwarnings('ignore:No input args') 35 | def test_constant() -> None: # pylint: disable=missing-function-docstring 36 | _test_constant_as_tensor((16, 16, 16), np.dtype('int32')) 37 | _test_constant_as_tensor((16, 16, 16), np.dtype('int32')) 38 | _test_constant_as_tensor((16, 16, 16), np.dtype('float32')) 
39 | -------------------------------------------------------------------------------- /tests/node_converters/cumsum_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | import pytest 4 | from onnx.helper import make_tensor_value_info 5 | from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE 6 | 7 | from tests.utils.common import check_onnx_model 8 | from tests.utils.common import make_model_from_nodes 9 | 10 | 11 | def _test_cumsum( 12 | input_tensor: np.ndarray, 13 | axis: int, 14 | exclusive: int, 15 | reverse: int, 16 | ) -> None: 17 | test_inputs = {'x': input_tensor, 'axis': np.array(axis)} 18 | node = onnx.helper.make_node( 19 | op_type='CumSum', 20 | inputs=list(test_inputs.keys()), 21 | outputs=['y'], 22 | exclusive=exclusive, 23 | reverse=reverse, 24 | ) 25 | 26 | outputs_info = [ 27 | make_tensor_value_info( 28 | name='y', 29 | elem_type=NP_TYPE_TO_TENSOR_TYPE[input_tensor.dtype], 30 | shape=input_tensor.shape, 31 | ), 32 | ] 33 | model = make_model_from_nodes( 34 | nodes=node, 35 | initializers={}, 36 | inputs_example=test_inputs, 37 | outputs_info=outputs_info, 38 | ) 39 | check_onnx_model(model, test_inputs) 40 | 41 | 42 | @pytest.mark.parametrize( 43 | 'tensor_size', 44 | ( 45 | (10,), 46 | (10, 10), 47 | (10, 10, 5), 48 | (10, 10, 5, 6), 49 | ), 50 | ) 51 | @pytest.mark.parametrize( 52 | 'exclusive,reverse', 53 | ( 54 | (0, 0), 55 | (0, 1), 56 | (1, 0), 57 | (1, 1), 58 | ), 59 | ) 60 | def test_cumsum(tensor_size, exclusive, reverse) -> None: # pylint: disable=missing-function-docstring 61 | input_tensor = np.random.randint(low=-10, high=10, size=tensor_size) 62 | for axis in range(-len(tensor_size), len(tensor_size) - 1): 63 | _test_cumsum( 64 | input_tensor=input_tensor, 65 | axis=axis, 66 | exclusive=exclusive, 67 | reverse=reverse, 68 | ) 69 | -------------------------------------------------------------------------------- /tests/node_converters/depth_to_space_test.py: 
-------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | from typing import List 3 | 4 | import numpy as np 5 | import onnx 6 | import pytest 7 | 8 | from tests.utils.common import check_onnx_model 9 | from tests.utils.common import make_model_from_nodes 10 | 11 | 12 | def _test_depth_to_space( 13 | input_shape: List[int], 14 | blocksize: int, 15 | mode: str, 16 | opset: int, 17 | ) -> None: 18 | x = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32) 19 | test_inputs = {'x': x} 20 | 21 | node = onnx.helper.make_node( # type: ignore 22 | op_type='DepthToSpace', 23 | inputs=['x'], 24 | outputs=['y'], 25 | blocksize=blocksize, 26 | mode=mode, 27 | ) 28 | model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs, opset_version=opset) 29 | check_onnx_model(model, test_inputs) 30 | 31 | 32 | @pytest.mark.parametrize( 33 | 'input_shape, blocksize', 34 | [ 35 | ([1, 12, 3, 3], 2), 36 | ([5, 75, 3, 3], 5), 37 | ([7, 588, 3, 4], 7), 38 | ], 39 | ) 40 | @pytest.mark.parametrize('opset', [11, 13]) 41 | def test_depth_to_space(input_shape: List[int], blocksize: int, opset: int) -> None: 42 | _test_depth_to_space(input_shape=input_shape, blocksize=blocksize, mode='CRD', opset=opset) 43 | -------------------------------------------------------------------------------- /tests/node_converters/dropout_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import onnx 6 | import pytest 7 | 8 | from tests.utils.common import check_onnx_model 9 | from tests.utils.common import make_model_from_nodes 10 | 11 | 12 | def _test_dropout(data: np.ndarray, opset_version: int, **kwargs) -> None: 13 | test_inputs = {'input_tensor': data} 14 | 15 | if opset_version >= 12: 16 | if 'ratio' in kwargs: 17 | test_inputs['ratio'] = 
np.array(kwargs.pop('ratio'), dtype=np.float16) 18 | if 'training_mode' in kwargs: 19 | test_inputs['training_mode'] = np.array(kwargs.pop('training_mode'), dtype=bool) 20 | 21 | node = onnx.helper.make_node(op_type='Dropout', inputs=list(test_inputs), outputs=['y'], **kwargs) 22 | model = make_model_from_nodes( 23 | nodes=node, 24 | initializers={}, 25 | inputs_example=test_inputs, 26 | opset_version=opset_version, 27 | ) 28 | 29 | check_onnx_model(model, test_inputs) 30 | 31 | 32 | @pytest.mark.parametrize( 33 | 'input_shape,ratio,training_mode,opset_version', 34 | ( 35 | ([3, 32, 32], None, None, 10), 36 | ([3, 32, 32], None, None, 12), 37 | ([3, 32, 32], None, None, 13), 38 | ([3, 32, 32], 0.8, None, 10), 39 | ([3, 32, 32], 0.8, None, 12), 40 | ([3, 32, 32], 0.8, None, 13), 41 | ([3, 32, 32], 0.8, False, 13), 42 | ([3, 32, 32], 0.8, False, 13), 43 | ([8, 3, 32, 32], None, None, 10), 44 | ([8, 3, 32, 32, 32], None, None, 10), 45 | ), 46 | ) 47 | def test_dropout( # pylint: disable=missing-function-docstring 48 | input_shape: List[int], 49 | ratio: Optional[float], 50 | training_mode: Optional[bool], 51 | opset_version: int, 52 | ) -> None: 53 | data = np.random.randn(*input_shape).astype(np.float32) 54 | kwargs = {} 55 | if ratio is not None: 56 | kwargs['ratio'] = ratio 57 | if training_mode is not None: 58 | kwargs['training_mode'] = training_mode 59 | _test_dropout(data=data, opset_version=opset_version, **kwargs) 60 | -------------------------------------------------------------------------------- /tests/node_converters/einsum_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | import onnx 6 | import pytest 7 | from onnx.helper import make_tensor_value_info 8 | from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE 9 | 10 | from tests.utils.common import check_onnx_model 11 | from tests.utils.common import make_model_from_nodes 12 | 13 | 14 
| @pytest.mark.parametrize( 15 | 'equation,input_shapes,output_shape', 16 | ( 17 | ('...ii ->...i', [(3, 5, 5)], (3, 5)), 18 | ('i,i', [(5,), (5,)], None), 19 | ('ij->i', [(3, 4)], (3,)), 20 | ('ij->ji', [(3, 4)], (4, 3)), 21 | ), 22 | ) 23 | def test_einsum( # pylint: disable=missing-function-docstring 24 | equation: str, 25 | input_shapes: List[Tuple[int, ...]], 26 | output_shape: Tuple[int, ...], 27 | ) -> None: 28 | test_inputs = {f'input_{index}': np.random.randn(*shape) for index, shape in enumerate(input_shapes)} 29 | 30 | node = onnx.helper.make_node( 31 | op_type='Einsum', 32 | inputs=list(test_inputs), 33 | outputs=['out'], 34 | equation=equation, 35 | ) 36 | outputs_info = [ 37 | make_tensor_value_info( 38 | name='out', 39 | elem_type=NP_TYPE_TO_TENSOR_TYPE[np.dtype('float')], 40 | shape=output_shape, 41 | ), 42 | ] 43 | 44 | model = make_model_from_nodes( 45 | nodes=node, 46 | initializers={}, 47 | inputs_example=test_inputs, 48 | outputs_info=outputs_info, 49 | opset_version=13, 50 | ) 51 | check_onnx_model(model, test_inputs) 52 | -------------------------------------------------------------------------------- /tests/node_converters/expand_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import onnx 5 | import pytest 6 | from onnx.helper import make_tensor_value_info 7 | from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE 8 | 9 | from tests.utils.common import check_onnx_model 10 | from tests.utils.common import make_model_from_nodes 11 | 12 | 13 | def _test_expand( 14 | data: np.ndarray, 15 | shape: List[int], 16 | ) -> None: 17 | test_inputs = { 18 | 'x': data, 19 | 'shape': np.array(shape, dtype=np.int64), 20 | } 21 | 22 | node = onnx.helper.make_node(op_type='Expand', inputs=list(test_inputs), outputs=['y']) 23 | outputs_info = [ 24 | make_tensor_value_info( 25 | name='y', 26 | elem_type=NP_TYPE_TO_TENSOR_TYPE[data.dtype], 27 | shape=[None] * 
len(shape), 28 | ), 29 | ] 30 | 31 | model = make_model_from_nodes( 32 | nodes=node, 33 | initializers={}, 34 | inputs_example=test_inputs, 35 | outputs_info=outputs_info, 36 | ) 37 | check_onnx_model(model, test_inputs) 38 | 39 | 40 | @pytest.mark.parametrize( 41 | 'src_shape,dst_shape', 42 | ( 43 | ([3, 1], [2, 1, 6]), 44 | ([3, 1], [3, 4]), 45 | ), 46 | ) 47 | def test_expand(src_shape: List[int], dst_shape: List[int]) -> None: # pylint: disable=missing-function-docstring 48 | data = np.reshape(np.arange(1, np.prod(src_shape) + 1, dtype=np.float32), src_shape) 49 | _test_expand( 50 | data=data, 51 | shape=dst_shape, 52 | ) 53 | -------------------------------------------------------------------------------- /tests/node_converters/eye_like_test.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | import onnx 6 | import pytest 7 | from onnx.helper import make_tensor_value_info 8 | 9 | from tests.utils.common import check_onnx_model 10 | from tests.utils.common import make_model_from_nodes 11 | 12 | 13 | @pytest.mark.parametrize('dtype', [None, 1, 6, 7, 11]) 14 | @pytest.mark.parametrize('k', [-2, -1, 0, 1, 2]) 15 | @pytest.mark.parametrize('shape', [[2, 3], [3, 4], [3, 3]]) 16 | def test_eye_like( # pylint: disable=missing-function-docstring 17 | shape: Tuple[int], 18 | dtype: Optional[int], 19 | k: int, # pylint: disable=invalid-name 20 | ) -> None: 21 | input_values = np.random.randn(*shape).astype(np.float32) 22 | test_inputs = {'x': input_values} 23 | 24 | node = onnx.helper.make_node(op_type='EyeLike', inputs=['x'], outputs=['z'], dtype=dtype, k=k) 25 | model = make_model_from_nodes( 26 | nodes=node, 27 | initializers={}, 28 | inputs_example=test_inputs, 29 | outputs_info=[make_tensor_value_info(name='z', elem_type=dtype, shape=shape)] if dtype else None, 30 | ) 31 | check_onnx_model(model, test_inputs) 32 | 
-------------------------------------------------------------------------------- /tests/node_converters/flatten_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import onnx 5 | 6 | from tests.utils.common import check_onnx_model 7 | from tests.utils.common import make_model_from_nodes 8 | 9 | 10 | def _test_flatten( # build one Flatten node on a float32 U(-1, 1) input; kwargs (e.g. axis) become node attributes 11 | input_shape: List[int], 12 | **kwargs, 13 | ) -> None: 14 | x = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32) 15 | test_inputs = {'x': x} 16 | 17 | node = onnx.helper.make_node( 18 | op_type='Flatten', 19 | inputs=['x'], 20 | outputs=['y'], 21 | **kwargs, 22 | ) 23 | model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs) 24 | check_onnx_model(model, test_inputs) 25 | 26 | 27 | def test_flatten() -> None: # pylint: disable=missing-function-docstring 28 | _test_flatten(input_shape=[2, 3, 16, 16, 16]) # no axis attribute is set in this case 29 | _test_flatten(input_shape=[2, 3, 16, 16], axis=2) 30 | _test_flatten(input_shape=[2, 3, 16], axis=-1) 31 | -------------------------------------------------------------------------------- /tests/node_converters/gather_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from typing import cast 3 | 4 | import numpy as np 5 | import onnx 6 | import pytest 7 | 8 | from tests.utils.common import check_onnx_model 9 | from tests.utils.common import make_model_from_nodes 10 | 11 | 12 | def _test_gather( # build one op_type node (Gather, GatherElements or GatherND in this file) over inputs 'x' and 'indices'; kwargs become node attributes 13 | op_type: str, 14 | input_array: np.ndarray, 15 | indices: np.ndarray, 16 | opset_version: int, 17 | **kwargs, 18 | ) -> None: 19 | test_inputs = { 20 | 'x': input_array, 21 | 'indices': indices, 22 | } 23 | 24 | node = onnx.helper.make_node( 25 | op_type, 26 | inputs=list(test_inputs), 27 | outputs=['y'], 28 | **kwargs, 29 | ) 30 | 31 | model = make_model_from_nodes( 32 | nodes=node, 33 | initializers={}, 34 |
inputs_example=test_inputs, 35 | opset_version=opset_version, 36 | ) 37 | check_onnx_model(model, test_inputs) 38 | 39 | 40 | @pytest.mark.parametrize( 41 | 'op_type,axis,opset_version', 42 | ( 43 | ('Gather', 0, 13), 44 | ('Gather', 0, 11), 45 | ('Gather', 0, 9), 46 | ('Gather', 1, 13), 47 | ('Gather', 1, 11), 48 | ('Gather', 1, 9), 49 | ('GatherElements', 0, 13), 50 | ('GatherElements', 0, 11), 51 | ('GatherElements', 1, 13), 52 | ('GatherElements', 1, 11), 53 | ), 54 | ) 55 | def test_gather(op_type: str, axis: int, opset_version: int) -> None: # pylint: disable=missing-function-docstring 56 | input_tensor = np.asarray( 57 | [ 58 | [1.0, 1.2, 1.9], 59 | [2.3, 3.4, 3.9], 60 | [4.5, 5.7, 5.9], 61 | ], 62 | dtype=np.float32, 63 | ) 64 | indices = np.asarray( # one row [1, 0] of int64 indices, valid along either axis of the 3x3 input 65 | [ 66 | [1, 0], 67 | ], 68 | dtype=np.int64, 69 | ) 70 | _test_gather(op_type=op_type, input_array=input_tensor, indices=indices, axis=axis, opset_version=opset_version) 71 | 72 | 73 | @pytest.mark.parametrize('opset_version', (11, 12, 13)) 74 | @pytest.mark.parametrize( 75 | 'data_shape, indices_shape, batch_dims', 76 | ( 77 | # Examples from ONNX opset doc: https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GatherND-13. 78 | ([2, 2], [2, 2], 0), 79 | ([2, 2], [2, 1], 0), 80 | ([2, 2, 2], [2, 2], 0), 81 | ([2, 2, 2], [2, 1, 2], 0), 82 | pytest.param([2, 2, 2], [2, 1], 1, marks=pytest.mark.xfail(reason='implemented for batch_dims = 0 only')), 83 | # Our tests.
84 | ([8, 3, 16, 16], [16, 3], 0), 85 | ([16, 3, 224, 224], [32, 1, 3], 0), 86 | ), 87 | ) 88 | def test_gather_nd( # pylint: disable=missing-function-docstring 89 | data_shape: List[int], 90 | indices_shape: List[int], 91 | batch_dims: int, 92 | opset_version: int, 93 | ) -> None: 94 | input_tensor = cast(np.ndarray, np.random.rand(*data_shape)) 95 | indices_high = data_shape[: indices_shape[-1]] # each index tuple addresses the first indices_shape[-1] data dims, so bound it by those dims 96 | indices = np.random.randint(low=0, high=indices_high, size=indices_shape, dtype=np.int64) 97 | 98 | _test_gather( 99 | op_type='GatherND', 100 | input_array=input_tensor, 101 | indices=indices, 102 | batch_dims=batch_dims if opset_version > 11 else None, # NOTE(review): presumably make_node drops None-valued attributes, leaving batch_dims unset for opset 11; confirm 103 | opset_version=opset_version, 104 | ) 105 | -------------------------------------------------------------------------------- /tests/node_converters/global_avg_pool_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import onnx 5 | import pytest 6 | 7 | from tests.utils.common import check_onnx_model 8 | from tests.utils.common import make_model_from_nodes 9 | 10 | 11 | @pytest.mark.parametrize( 12 | 'input_shape', 13 | ( 14 | [2, 3, 16, 16, 16], 15 | [2, 3, 16, 16], 16 | [2, 3, 16], 17 | ), 18 | ) 19 | def test_global_avg_pool(input_shape: List[int]) -> None: # pylint: disable=missing-function-docstring 20 | x = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32) 21 | test_inputs = {'x': x} 22 | 23 | node = onnx.helper.make_node( 24 | op_type='GlobalAveragePool', 25 | inputs=['x'], 26 | outputs=['y'], 27 | ) 28 | model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs) 29 | check_onnx_model( # both tolerance arguments are set to 10**-7 30 | model, 31 | test_inputs, 32 | atol_onnx_torch=10**-7, 33 | atol_torch_cpu_cuda=10**-7, 34 | ) 35 | -------------------------------------------------------------------------------- /tests/node_converters/instance_norm_test.py:
-------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import onnx 5 | import pytest 6 | 7 | from tests.utils.common import check_onnx_model 8 | from tests.utils.common import make_model_from_nodes 9 | 10 | 11 | @pytest.mark.parametrize('parameters_as_inputs', (True, False)) 12 | @pytest.mark.parametrize( 13 | 'input_shape', 14 | ( 15 | # 1d 16 | [2, 3, 16], 17 | [2, 1, 7], 18 | # 2d 19 | [2, 3, 16, 16], 20 | [2, 1, 7, 16], 21 | # 3d 22 | [2, 3, 16, 16, 16], 23 | [2, 1, 16, 7, 16], 24 | ), 25 | ) 26 | def test_instance_norm( # pylint: disable=missing-function-docstring 27 | input_shape: List[int], 28 | parameters_as_inputs: bool, 29 | ) -> None: 30 | num_features = input_shape[1] # scale and bias are vectors of this length (dim 1 of the input) 31 | x = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32) 32 | scale = np.random.uniform(low=0.0, high=1.0, size=num_features).astype(np.float32) 33 | bias = np.random.uniform(low=-1.0, high=1.0, size=num_features).astype(np.float32) 34 | 35 | inputs = {'input': x} 36 | parameters = {'scale': scale, 'bias': bias} 37 | initializers = {} 38 | 39 | if parameters_as_inputs: # feed scale/bias either as graph inputs or as initializers 40 | inputs.update(parameters) 41 | else: 42 | initializers.update(parameters) 43 | 44 | node = onnx.helper.make_node(op_type='InstanceNormalization', inputs=['input', 'scale', 'bias'], outputs=['y']) 45 | 46 | model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=inputs) 47 | check_onnx_model(onnx_model=model, onnx_inputs=inputs, atol_onnx_torch=1e-6, atol_torch_cpu_cuda=1e-6) 48 | -------------------------------------------------------------------------------- /tests/node_converters/layer_norm_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | from typing import List 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import onnx 7 | import pytest 8 | 9 | from tests.utils.common
import check_onnx_model 10 | from tests.utils.common import make_model_from_nodes 11 | 12 | 13 | def _test_layer_norm( # build one opset-17 LayerNormalization node; bias is optional and omitted from the inputs when None 14 | x: np.ndarray, 15 | scale: np.ndarray, 16 | bias: Optional[np.ndarray], 17 | axis: int, 18 | parameters_as_inputs: bool, 19 | ) -> None: 20 | inputs = {'input': x} 21 | parameters = {'scale': scale} 22 | if bias is not None: 23 | parameters['bias'] = bias 24 | 25 | initializers = {} 26 | 27 | if parameters_as_inputs: # feed scale (and bias) either as graph inputs or as initializers 28 | inputs.update(parameters) 29 | else: 30 | initializers.update(parameters) 31 | 32 | node = onnx.helper.make_node( 33 | op_type='LayerNormalization', 34 | inputs=['input', 'scale', 'bias'] if bias is not None else ['input', 'scale'], 35 | outputs=['y'], 36 | axis=axis, 37 | ) 38 | model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=inputs, opset_version=17) 39 | check_onnx_model( 40 | onnx_model=model, 41 | onnx_inputs=inputs, 42 | atol_onnx_torch=1e-5, 43 | atol_torch_cpu_cuda=1e-5, 44 | atol_onnx_torch2onnx=1e-5, 45 | ) 46 | 47 | 48 | @pytest.mark.parametrize('parameters_as_inputs', (True, False)) 49 | @pytest.mark.parametrize( 50 | 'input_shape', 51 | ( 52 | [2, 3, 16], 53 | [3, 1, 224], 54 | [4, 3, 16, 16], 55 | [5, 1, 32, 32], 56 | [6, 3, 16, 16, 8], 57 | [7, 1, 7, 7, 16], 58 | ), 59 | ) 60 | def test_layer_norm(input_shape: List[int], parameters_as_inputs: bool) -> None: 61 | x = np.random.randn(*input_shape).astype(np.float32) 62 | 63 | for axis in [*range(len(input_shape))] + [-1]: # every non-negative axis, plus -1 (which repeats the last axis) 64 | normalized_shape = input_shape[axis:] # scale/bias cover the trailing dims from axis onward 65 | 66 | scale = np.random.randn(*normalized_shape).astype(np.float32) 67 | bias = np.random.randn(*normalized_shape).astype(np.float32) 68 | 69 | for bias_ in [bias, None]: 70 | _test_layer_norm( 71 | x=x, 72 | scale=scale, 73 | bias=bias_, 74 | axis=axis, 75 | parameters_as_inputs=parameters_as_inputs, 76 | ) 77 | -------------------------------------------------------------------------------- /tests/node_converters/logical_test.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | import pytest 4 | 5 | from tests.utils.common import check_onnx_model 6 | from tests.utils.common import make_model_from_nodes 7 | 8 | 9 | @pytest.mark.parametrize( 10 | 'op_type', 11 | ('Or', 'And', 'Xor'), 12 | ) 13 | def test_logical_operation(op_type: str) -> None: # pylint: disable=missing-function-docstring 14 | x = np.random.randn(10, 1, 64, 128) > 0 # random boolean tensor; the y shapes below differ from x and presumably rely on broadcasting 15 | y_variants = ( 16 | (np.random.randn(128) > 0), 17 | (np.random.randn(64, 128) > 0), 18 | (np.random.randn(1, 64, 128) > 0), 19 | (np.random.randn(1, 3, 1, 128) > 0), 20 | (np.random.randn(10, 1, 64, 128) > 0), 21 | ) 22 | for y in y_variants: 23 | test_inputs = {'x': x, 'y': y} 24 | initializers = {} 25 | node = onnx.helper.make_node( 26 | op_type=op_type, 27 | inputs=['x', 'y'], 28 | outputs=['z'], 29 | ) 30 | 31 | model = make_model_from_nodes( 32 | nodes=node, 33 | initializers=initializers, 34 | inputs_example=test_inputs, 35 | ) 36 | check_onnx_model(model, test_inputs) 37 | 38 | 39 | def test_not() -> None: # pylint: disable=missing-function-docstring 40 | x_variants = ( # boolean inputs of rank 1 to 4 41 | (np.random.randn(128) > 0), 42 | (np.random.randn(64, 128) > 0), 43 | (np.random.randn(1, 64, 128) > 0), 44 | (np.random.randn(10, 1, 64, 128) > 0), 45 | ) 46 | 47 | for x in x_variants: 48 | test_inputs = {'x': x} 49 | initializers = {} 50 | node = onnx.helper.make_node( 51 | op_type='Not', 52 | inputs=['x'], 53 | outputs=['y'], 54 | ) 55 | 56 | model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) 57 | check_onnx_model(model, test_inputs) 58 | -------------------------------------------------------------------------------- /tests/node_converters/lrn_test.py: -------------------------------------------------------------------------------- 1 | from random import randrange 2 | 3 | import numpy as np 4 | import onnx 5 | 6 | from tests.utils.common import
check_onnx_model 7 | from tests.utils.common import make_model_from_nodes 8 | 9 | 10 | def _test_lrn(data: np.ndarray, alpha: float, beta: float, bias: float, size: int) -> None: # build one LRN node with the given attributes and run check_onnx_model on it 11 | test_inputs = {'input_tensor': data} 12 | node = onnx.helper.make_node( 13 | op_type='LRN', 14 | inputs=list(test_inputs), 15 | outputs=['y'], 16 | alpha=alpha, # ONNX attributes are passed as regular keyword arguments. 17 | beta=beta, 18 | bias=bias, 19 | size=size, 20 | ) 21 | 22 | model = make_model_from_nodes( 23 | nodes=node, 24 | initializers={}, 25 | inputs_example=test_inputs, 26 | ) 27 | check_onnx_model(model, test_inputs) 28 | 29 | 30 | def test_lrn() -> None: # pylint: disable=missing-function-docstring 31 | shape = (1, 3, 227, 227) 32 | data = np.random.random_sample(shape).astype(np.float32) 33 | alpha = np.random.uniform(low=0.0, high=1.0) # attributes are drawn at random on every run 34 | beta = np.random.uniform(low=0.0, high=1.0) 35 | bias = np.random.uniform(low=1.0, high=5.0) 36 | size = randrange(start=1, stop=10, step=2) # diameter of channels, not radius, must be odd 37 | _test_lrn(data=data, alpha=alpha, beta=beta, bias=bias, size=size) 38 | -------------------------------------------------------------------------------- /tests/node_converters/matmul_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | 4 | from tests.utils.common import check_onnx_model 5 | from tests.utils.common import make_model_from_nodes 6 | 7 | 8 | def test_matmul() -> None: # pylint: disable=missing-function-docstring 9 | x_variants = [ # x_variants[i] is multiplied with y_variants[i]: [..., 3, 4] @ [..., 4, 3] 10 | np.random.randn(3, 4).astype(np.float32), 11 | np.random.randn(2, 3, 4).astype(np.float32), 12 | np.random.randn(1, 2, 3, 4).astype(np.float32), 13 | ] 14 | 15 | y_variants = [ 16 | np.random.randn(4, 3).astype(np.float32), 17 | np.random.randn(2, 4, 3).astype(np.float32), 18 | np.random.randn(1, 2, 4, 3).astype(np.float32), 19 | ] 20 | 21 | for x, y in zip(x_variants, y_variants): 22 | test_inputs = {'x': x, 'y': y}
23 | initializers = {} 24 | node = onnx.helper.make_node( 25 | op_type='MatMul', 26 | inputs=['x', 'y'], 27 | outputs=['z'], 28 | ) 29 | 30 | model = make_model_from_nodes( 31 | nodes=node, 32 | initializers=initializers, 33 | inputs_example=test_inputs, 34 | ) 35 | check_onnx_model( 36 | model, 37 | test_inputs, 38 | atol_onnx_torch=10**-6, 39 | atol_torch_cpu_cuda=10**-6, 40 | ) 41 | -------------------------------------------------------------------------------- /tests/node_converters/mean_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import onnx 5 | import pytest 6 | from onnx.helper import make_tensor_value_info 7 | from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE 8 | 9 | from tests.utils.common import check_onnx_model 10 | from tests.utils.common import make_model_from_nodes 11 | 12 | 13 | def _test_mean( 14 | data_list: List[np.ndarray], 15 | ) -> None: 16 | test_inputs = {f'data_{i}': data for i, data in enumerate(data_list)} # inputs named data_0, data_1, ...; output dtype is taken from the first one 17 | 18 | node = onnx.helper.make_node(op_type='Mean', inputs=list(test_inputs), outputs=['y']) 19 | outputs_info = [ 20 | make_tensor_value_info( 21 | name='y', 22 | elem_type=NP_TYPE_TO_TENSOR_TYPE[data_list[0].dtype], 23 | shape=None, 24 | ), 25 | ] 26 | 27 | model = make_model_from_nodes( 28 | nodes=node, 29 | initializers={}, 30 | inputs_example=test_inputs, 31 | outputs_info=outputs_info, 32 | ) 33 | check_onnx_model(model, test_inputs) 34 | 35 | 36 | @pytest.mark.parametrize( 37 | 'input_shapes', 38 | ( 39 | ([],), # a single input generated with size=[] (0-d) 40 | ([2, 3, 4],), 41 | ([3, 1], [2, 1, 6]), 42 | ([3, 1], [3, 4]), 43 | ), 44 | ) 45 | def test_mean(input_shapes: List[List[int]]) -> None: # pylint: disable=missing-function-docstring 46 | input_tensors = [np.random.normal(size=i_shape).astype(np.float32) for i_shape in input_shapes] 47 | _test_mean(data_list=input_tensors) 48 | --------------------------------------------------------------------------------
/tests/node_converters/min_max_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import onnx 5 | import pytest 6 | from onnx.helper import make_tensor_value_info 7 | from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE 8 | 9 | from tests.utils.common import check_onnx_model 10 | from tests.utils.common import make_model_from_nodes 11 | 12 | 13 | def _test_min_max( 14 | data_list: List[np.ndarray], 15 | operation_type: str, 16 | ) -> None: 17 | test_inputs = {f'data_{i}': data for i, data in enumerate(data_list)} 18 | 19 | node = onnx.helper.make_node(op_type=operation_type, inputs=list(test_inputs), outputs=['y']) 20 | outputs_info = [ 21 | make_tensor_value_info( 22 | name='y', 23 | elem_type=NP_TYPE_TO_TENSOR_TYPE[data_list[0].dtype], 24 | shape=None, 25 | ), 26 | ] 27 | 28 | model = make_model_from_nodes( 29 | nodes=node, 30 | initializers={}, 31 | inputs_example=test_inputs, 32 | outputs_info=outputs_info, 33 | ) 34 | check_onnx_model(model, test_inputs) 35 | 36 | 37 | @pytest.mark.parametrize( 38 | 'input_shapes', 39 | ( 40 | ([],), 41 | ([2, 3, 4],), 42 | ([3, 1], [2, 1, 6]), 43 | ([3, 1], [3, 4]), 44 | ), 45 | ) 46 | @pytest.mark.parametrize('operation_type', ['Min', 'Max']) 47 | def test_min_amx( # pylint: disable=missing-function-docstring 48 | input_shapes: List[List[int]], 49 | operation_type: str, 50 | ) -> None: 51 | input_tensors = [np.random.normal(size=i_shape).astype(np.float32) for i_shape in input_shapes] 52 | 53 | _test_min_max( 54 | data_list=input_tensors, 55 | operation_type=operation_type, 56 | ) 57 | -------------------------------------------------------------------------------- /tests/node_converters/mod_test.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import onnx 5 | import pytest 6 | 7 | from tests.utils.common import check_onnx_model 8 | from 
tests.utils.common import make_model_from_nodes 9 | 10 | 11 | @pytest.mark.parametrize( 12 | 'dividend', 13 | [ 14 | [-4, 7, 5, 4, -7, 8], 15 | [-4.3, 7.2, 5.0, 4.3, -7.2, 8.0], 16 | ], 17 | ) 18 | @pytest.mark.parametrize( 19 | 'divisor', 20 | [ 21 | [2, -3, 8, -2, 3, 5], 22 | [2.1, -3.4, 8.0, -2.1, 3.4, 5.0], 23 | ], 24 | ) 25 | @pytest.mark.parametrize('fmod', [0, 1]) 26 | def test_mod( # pylint: disable=missing-function-docstring 27 | dividend: List[float], 28 | divisor: List[float], 29 | fmod: int, 30 | ) -> None: 31 | x_variants = np.array(dividend).astype(np.float32 if fmod else np.int32) 32 | y_variants = np.array(divisor).astype(np.float32 if fmod else np.int32) 33 | 34 | test_inputs = {'x': x_variants, 'y': y_variants} 35 | 36 | node = onnx.helper.make_node(op_type='Mod', inputs=['x', 'y'], outputs=['z'], fmod=fmod) 37 | model = make_model_from_nodes( 38 | nodes=node, 39 | initializers={}, 40 | inputs_example=test_inputs, 41 | ) 42 | check_onnx_model(model, test_inputs) 43 | -------------------------------------------------------------------------------- /tests/node_converters/neg_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | 4 | from tests.utils.common import check_onnx_model 5 | from tests.utils.common import make_model_from_nodes 6 | 7 | 8 | def test_neg() -> None: # pylint: disable=missing-function-docstring 9 | x_variants = ( 10 | np.random.randn(128), 11 | np.random.randn(64, 128), 12 | np.random.randn(1, 64, 128), 13 | np.random.randn(10, 1, 64, 128), 14 | ) 15 | 16 | for x in x_variants: 17 | test_inputs = {'x': x} 18 | initializers = {} 19 | node = onnx.helper.make_node( 20 | op_type='Neg', 21 | inputs=['x'], 22 | outputs=['y'], 23 | ) 24 | 25 | model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs) 26 | check_onnx_model(model, test_inputs) 27 | 
# --------------------------------------------------------------------------------
# /tests/node_converters/pad_test.py:
# --------------------------------------------------------------------------------
from typing import List

import numpy as np
import onnx
import pytest

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_pad(
    input_array: np.ndarray,
    opset_version: int,
    **kwargs,
) -> None:
    """Build a single-node ONNX ``Pad`` model and run ``check_onnx_model`` on it.

    Args:
        input_array: Data fed to the graph input ``x``.
        opset_version: Opset used for the model. For opset 2 the ``pads`` keyword stays in
            ``kwargs`` and becomes a node attribute; for any other opset it is moved into a
            second graph input ``pads`` with dtype int64.
        **kwargs: Node attributes (``mode``, and ``pads`` for opset 2).
    """
    test_inputs = {
        'x': input_array,
    }

    if opset_version != 2:
        test_inputs['pads'] = np.array(kwargs.pop('pads'), dtype=np.int64)

    node = onnx.helper.make_node(
        'Pad',
        inputs=list(test_inputs),
        outputs=['y'],
        **kwargs,
    )

    model = make_model_from_nodes(
        nodes=node,
        initializers={},
        inputs_example=test_inputs,
        opset_version=opset_version,
    )
    check_onnx_model(model, test_inputs)


@pytest.mark.parametrize(
    'input_shape,pads,mode',
    (
        ([1, 1, 1, 3, 3], [0, 1, 1, 1, 1, 0, 0, 0, 1, 1], 'constant'),
        ([1, 1, 1, 3, 3], [0, 0, 5, 3, 7, 0, 0, 2, 3, 11], 'edge'),
        ([1, 1, 3, 3, 3], [0, 0, 1, 2, 1, 0, 0, 1, 2, 1], 'reflect'),
        ([1, 1, 3, 3], [0, 0, 0, 0, 0, 0, 0, 0], 'constant'),
        ([1, 1, 3, 3], [0, 1, 1, 1, 1, 0, 0, 0], 'constant'),
        ([1, 1, 3, 3], [0, 2, 0, 2, 0, 2, 0, 2], 'constant'),
        ([1, 1, 3, 3], [1, 2, 4, 2, 5, 4, 4, 2], 'constant'),
        ([1, 1, 3, 3], [0, 0, 0, 0, 0, 0, 0, 0], 'edge'),
        ([1, 1, 3, 3], [0, 0, 2, 3, 0, 0, 2, 3], 'edge'),
        ([1, 1, 3, 3], [0, 0, 0, 0, 0, 0, 0, 0], 'reflect'),
        ([1, 1, 3, 3], [0, 0, 2, 1, 0, 0, 2, 1], 'reflect'),
        ([1, 3, 3], [0, 4, 0, 1, 0, 1], 'constant'),
        ([1, 3, 3], [0, 0, 3, 0, 0, 3], 'edge'),
        ([1, 3, 3], [0, 0, 1, 0, 0, 1], 'reflect'),
        # negative padding
        ([3, 3, 3, 3, 3], [0, -1, 1, -1, 1, 0, 0, 0, 1, 1], 'constant'),
        ([3, 3, 3, 3], [0, -1, -1, -1, -1, 0, 0, 0], 'constant'),
        ([5, 7, 6], [0, -4, 0, -1, 0, 1], 'constant'),
    ),
)
@pytest.mark.parametrize('opset_version', (2, 11, 13))
def test_pad(  # pylint: disable=missing-function-docstring
    input_shape: List[int],
    pads: List[int],
    mode: str,
    opset_version: int,
) -> None:
    input_array = np.random.random(size=input_shape).astype(np.float32)
    # ONNX Pad takes one begin and one end value per axis; catch malformed fixtures up front
    # instead of printing the lengths (leftover debug output).
    assert len(pads) == 2 * input_array.ndim, 'pads must hold 2 * rank values'
    _test_pad(input_array=input_array, mode=mode, opset_version=opset_version, pads=pads)


# --------------------------------------------------------------------------------
# /tests/node_converters/pow_test.py:
# --------------------------------------------------------------------------------
import numpy as np
import onnx

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def test_pow() -> None:  # pylint: disable=missing-function-docstring
    input_shape = [10, 3, 128, 128]
    x_variants = [
        np.random.uniform(low=0.0, high=4.0, size=input_shape).astype(np.float32),
        np.random.uniform(low=-4.0, high=4.0, size=input_shape).astype(np.float32),
        np.random.uniform(low=-4.0, high=0.001, size=input_shape).astype(np.float32),
        np.random.uniform(low=-4.0, high=4.0, size=input_shape).astype(np.float32),
    ]

    # Exponents paired element-wise with x_variants: a size-1 float, a broadcastable all-ones-shape
    # integer-valued tensor, a full-shape negative integer-valued tensor and a literal zero.
    y_variants = [
        np.random.uniform(low=-3.0, high=3.0, size=1).astype(np.float32),
        np.random.randint(low=0, high=4, size=[1] * len(input_shape)).astype(np.float32),
        np.random.randint(low=-4, high=0, size=input_shape).astype(np.float32),
        np.array([0.0], dtype=np.float32),
    ]

    for x, y in zip(x_variants, y_variants):
        test_inputs = {'x': x, 'y': y}
        initializers = {}
        node = onnx.helper.make_node(
            op_type='Pow',
            inputs=['x', 'y'],
            outputs=['z'],
        )

        model = make_model_from_nodes(
            nodes=node,
            initializers=initializers,
            inputs_example=test_inputs,
        )
        check_onnx_model(model, test_inputs)


def test_sqrt() -> None:  # pylint: disable=missing-function-docstring
    input_shape = [10, 3, 128, 128]
    x = np.random.uniform(low=0.0, high=10.0, size=input_shape).astype(np.float32)

    test_inputs = {'x': x}
    initializers = {}
    node = onnx.helper.make_node(
        op_type='Sqrt',
        inputs=['x'],
        outputs=['z'],
    )

    model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs)
    check_onnx_model(model, test_inputs)


# --------------------------------------------------------------------------------
# /tests/node_converters/range_test.py:
# --------------------------------------------------------------------------------
import numpy as np
import onnx
import pytest
from onnx.helper import make_tensor_value_info
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_range(
    start: np.ndarray,
    limit: np.ndarray,
    delta: np.ndarray,
) -> None:
    """Build a single-node ONNX ``Range`` model and run ``check_onnx_model`` on it.

    The output is declared with ``delta``'s dtype and a static length of
    ``max(ceil((limit - start) / delta), 0)``.
    """
    test_inputs = {'start': start, 'limit': limit, 'delta': delta}
    node = onnx.helper.make_node(op_type='Range', inputs=list(test_inputs), outputs=['y'])

    num_elements = int(max(np.ceil((limit - start) / delta), 0))
    outputs_info = [
        make_tensor_value_info(
            name='y',
            elem_type=NP_TYPE_TO_TENSOR_TYPE[delta.dtype],
            shape=[num_elements],
        ),
    ]
    model = make_model_from_nodes(
        nodes=node,
        initializers={},
        inputs_example=test_inputs,
        outputs_info=outputs_info,
    )
    check_onnx_model(model, test_inputs)


@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning')
def test_range() -> None:  # pylint: disable=missing-function-docstring
    _test_range(
        start=np.array(1, dtype=np.int32),
        limit=np.array(5, dtype=np.int32),
        delta=np.array(2, dtype=np.int32),
    )
    _test_range(
        start=np.array(10.0, dtype=np.float32),
        limit=np.array(6.0, dtype=np.float32),
        delta=np.array(-2.3, dtype=np.float32),
    )
    _test_range(
        start=np.array(1, dtype=np.int64),
        limit=np.array(60, dtype=np.int64),
        delta=np.array(7, dtype=np.int64),
    )


# --------------------------------------------------------------------------------
# /tests/node_converters/reciprocal_test.py:
# --------------------------------------------------------------------------------
import numpy as np
import onnx

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def test_reciprocal() -> None:  # pylint: disable=missing-function-docstring
    x_variants = (
        np.random.randn(128),
        np.random.randn(64, 128),
        np.random.randn(3, 64, 128),
        np.random.randn(10, 2, 64, 128),
        np.zeros([3, 3, 5]),  # reciprocal of zero: exercises the inf path
    )

    for x in x_variants:
        test_inputs = {'x': x}
        initializers = {}
        node = onnx.helper.make_node(
            op_type='Reciprocal',
            inputs=['x'],
            outputs=['y'],
        )

        model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=test_inputs)
        check_onnx_model(model, test_inputs)


# --------------------------------------------------------------------------------
# /tests/node_converters/reshape_test.py:
# --------------------------------------------------------------------------------
from typing import List

import numpy as np
import onnx
import pytest

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_reshape(
    input_shape: List[int],
    output_shape: List[int],
    opset_version: int,
    **kwargs,
) -> None:
    """Build an ONNX ``Reshape`` model whose target shape is an int64 initializer and check it.

    ``**kwargs`` are forwarded to ``onnx.helper.make_node`` as node attributes.
    """
    test_inputs = {'x': np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32)}
    initializers = {'output_shape': np.asarray(output_shape, dtype=np.int64)}

    node = onnx.helper.make_node(
        op_type='Reshape',
        inputs=['x', 'output_shape'],
        outputs=['y'],
        **kwargs,
    )
    model = make_model_from_nodes(
        nodes=node,
        initializers=initializers,
        inputs_example=test_inputs,
        opset_version=opset_version,
    )
    check_onnx_model(model, test_inputs)


@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning')
@pytest.mark.parametrize(
    'input_shape,output_shape,opset_version',
    (
        ([2, 3, 16, 16], [2, -1, 3], 9),
        ([2, 3, 16, 16], [2, 0, -1], 9),
        ([2, 3, 16, 16], [2, 0, 1, 1, 1, 1, 1, 1, -1], 9),
        ([2, 3, 16, 16], [-1, 1, 1, 2, 1, 1, 1, 2, 1, 1], 14),
    ),
)
def test_reshape(  # pylint: disable=missing-function-docstring
    input_shape: List[int],
    output_shape: List[int],
    opset_version: int,
) -> None:
    _test_reshape(
        input_shape=input_shape,
        output_shape=output_shape,
        opset_version=opset_version,
    )


# --------------------------------------------------------------------------------
# /tests/node_converters/scatter_nd_test.py:
# --------------------------------------------------------------------------------
import numpy as np
import onnx
import pytest

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_scatter_nd(
    data: np.ndarray,
    indices: np.ndarray,
    updates: np.ndarray,
    opset_version: int,
    **kwargs,
) -> None:
    """Build an ONNX ``ScatterND`` model with data/indices/updates as graph inputs and check it.

    ``opset_version`` is passed both to the model builder and to ``check_onnx_model``;
    ``**kwargs`` become node attributes.
    """
    test_inputs = {'data': data, 'indices': indices, 'updates': updates}

    node = onnx.helper.make_node(
        op_type='ScatterND',
        inputs=['data', 'indices', 'updates'],
        outputs=['y'],
        **kwargs,
    )

    model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs, opset_version=opset_version)
    check_onnx_model(model, test_inputs, opset_version=opset_version)


@pytest.mark.parametrize('opset_version', (11, 13, 14, 16))
@pytest.mark.parametrize('reduction', ('none',))
@pytest.mark.parametrize(
    'data',
    (
        np.array(
            [
                [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]],
                [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]],
                [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]],
                [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]],
            ],
            dtype=np.float32,
        ),
    ),
)
@pytest.mark.parametrize(
    'indices, updates',
    (
        (
            np.array([[0, 1, 2], [1, 2, 3]], dtype=np.int64),
            np.array([1232, 5463], dtype=np.float32),
        ),
        (
            np.array([[0, 1], [1, 2]], dtype=np.int64),
            np.array([[8, 7, 6, 5], [4, 3, 2, 1]], dtype=np.float32),
        ),
        (
            np.array([[0], [2]], dtype=np.int64),
            np.array(
                [
                    [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
                    [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]],
                ],
                dtype=np.float32,
            ),
        ),
    ),
)
def test_scatter_nd(  # pylint: disable=missing-function-docstring
    data: np.ndarray, indices: np.ndarray, updates: np.ndarray, opset_version: int, reduction: str
) -> None:
    _test_scatter_nd(
        data=data,
        indices=indices,
        updates=updates,
        opset_version=opset_version,
        # Only set the `reduction` attribute for opset >= 16; for older opsets None is passed,
        # which relies on onnx.helper.make_node skipping None-valued attributes.
        reduction=reduction if opset_version >= 16 else None,
    )


# --------------------------------------------------------------------------------
# /tests/node_converters/shape_test.py:
# --------------------------------------------------------------------------------
from typing import List

import numpy as np
import onnx
import pytest
from onnx.helper import make_tensor_value_info
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_shape(
    input_shape: List[int],
    opset_version: int,
    **kwargs,
) -> None:
    """Build an ONNX ``Shape`` model over a random float32 input and run ``check_onnx_model``.

    The output ``y`` is declared as int64 with unknown shape; ``**kwargs`` become node attributes.
    """
    x = np.random.uniform(low=-1.0, high=1.0, size=input_shape).astype(np.float32)
    test_inputs = {'x': x}

    node = onnx.helper.make_node(
        op_type='Shape',
        inputs=list(test_inputs),
        outputs=['y'],
        **kwargs,
    )
    onnx_type = NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
    outputs_info = [make_tensor_value_info(name='y', elem_type=onnx_type, shape=None)]
    model = make_model_from_nodes(
        nodes=node,
        initializers={},
        inputs_example=test_inputs,
        outputs_info=outputs_info,
        opset_version=opset_version,
    )
    check_onnx_model(model, test_inputs)


@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning')
def test_shape() -> None:  # pylint: disable=missing-function-docstring
    _test_shape(input_shape=[2, 3, 16, 16, 16], opset_version=9)
    _test_shape(input_shape=[2, 3, 16, 16], opset_version=9)
    _test_shape(input_shape=[2, 3, 16], opset_version=9)


# --------------------------------------------------------------------------------
# /tests/node_converters/slice_test.py:
# --------------------------------------------------------------------------------
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import onnx
import pytest
from onnx.helper import make_tensor_value_info
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_slice(
    input_tensor: np.ndarray,
    starts: np.ndarray,
    ends: np.ndarray,
    axes: Optional[np.ndarray] = None,
    steps: Optional[np.ndarray] = None,
) -> None:
    """Build an ONNX ``Slice`` model (starts/ends/axes/steps as initializers) and check it.

    Args:
        input_tensor: Data fed to the graph input ``input_tensor``.
        starts: Start indices.
        ends: End indices.
        axes: Optional axes the slice applies to.
        steps: Optional step per sliced axis. When given without ``axes``, explicit axes
            ``[0, ..., len(starts) - 1]`` are supplied so that ``steps`` lands in its own input slot.
    """
    test_inputs = {'input_tensor': input_tensor}

    if steps is not None and axes is None:
        # Slice inputs are positional (data, starts, ends, axes, steps): without an explicit
        # axes entry, `steps` would be bound to the axes slot.
        axes = np.arange(len(starts), dtype=np.int64)

    initializers = {'starts': starts, 'ends': ends}
    if axes is not None:
        initializers['axes'] = axes
    if steps is not None:
        initializers['steps'] = steps

    node = onnx.helper.make_node(
        op_type='Slice',
        inputs=list(test_inputs.keys()) + list(initializers.keys()),
        outputs=['y'],
    )
    outputs_info = [
        make_tensor_value_info(
            name='y',
            elem_type=NP_TYPE_TO_TENSOR_TYPE[input_tensor.dtype],
            shape=None,
        ),
    ]
    model = make_model_from_nodes(
        nodes=node,
        initializers=initializers,
        inputs_example=test_inputs,
        outputs_info=outputs_info,
    )
    # onnx checker in torch 1.12 has problems with negative steps in Slice, so we disable it
    ignore_export_checker = steps is not None and np.any(steps < 0)
    check_onnx_model(model, test_inputs, ignore_export_checker=ignore_export_checker)


@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning')
@pytest.mark.parametrize(
    'input_shape,starts,ends,axes,steps',
    (
        ((20, 10, 15), [0, 0], [3, 10], [0, 1], [1, 1]),
        ((20, 10, 15), [0, 0, 3], [20, 10, 4], None, None),
        ((20, 10, 15), [1], [1000], [1], [1]),
        ((20, 10, 15), [0], [-1], [1], [1]),
        ((20, 10, 15), [20, 10, 4], [0, 0, 1], [0, 1, 2], [-1, -3, -2]),
        ((20, 10, 15), [0, 0, 3], [20, 10, 4], [0, -2, -1], None),
    ),
)
def test_slice(  # pylint: disable=missing-function-docstring
    input_shape: Tuple[int, ...],
    starts: List[int],
    ends: List[int],
    axes: Optional[List[int]],
    steps: Optional[List[int]],
) -> None:
    x = np.random.randn(*input_shape).astype(np.float32)
    _test_slice(
        input_tensor=x,
        starts=np.array(starts, dtype=np.int64),
        ends=np.array(ends, dtype=np.int64),
        axes=np.array(axes, dtype=np.int64) if axes is not None else None,
        steps=np.array(steps, dtype=np.int64) if steps is not None else None,
    )
# --------------------------------------------------------------------------------
# /tests/node_converters/split_test.py:
# --------------------------------------------------------------------------------
from typing import List
from typing import Optional

import numpy as np
import onnx
import pytest
from onnx.helper import make_tensor_value_info
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_split(
    x: np.ndarray,
    expected_output: List[np.ndarray],
    opset_version: int,
    **kwargs,
) -> None:
    """Build an ONNX ``Split`` model and run ``check_onnx_model`` on it.

    Args:
        x: Data fed to the graph input ``x``.
        expected_output: One array per node output; only its dtype and shape are used, to declare
            the outputs ``output_0 .. output_{N-1}``.
        opset_version: For opset >= 13 a non-None ``split`` kwarg is moved into a second graph
            input; otherwise it stays a node attribute.
        **kwargs: Node attributes (``axis``, ``split``).
    """
    inputs = ['x']
    test_inputs = {'x': x}

    if opset_version >= 13 and kwargs.get('split') is not None:
        split = kwargs.pop('split')
        test_inputs['split'] = split
        inputs.append('split')

    node = onnx.helper.make_node(
        op_type='Split',
        inputs=inputs,
        outputs=[f'output_{i}' for i, _ in enumerate(expected_output)],
        **kwargs,
    )

    outputs_info = [
        make_tensor_value_info(
            name=f'output_{i}',
            elem_type=NP_TYPE_TO_TENSOR_TYPE[out.dtype],
            shape=out.shape,
        )
        for i, out in enumerate(expected_output)
    ]

    model = make_model_from_nodes(
        nodes=node,
        initializers={},
        inputs_example=test_inputs,
        outputs_info=outputs_info,
        opset_version=opset_version,
    )
    check_onnx_model(model, test_inputs)


INPUT_1D = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).astype(np.float32)
INPUT_2D = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]]).astype(np.float32)

EMPTY_INPUT = np.array([]).astype(np.float32)
# Three empty outputs, matching the split=[0, 0, 0] case below.
EXPECTED_EMPTY_OUT = [np.array([]).astype(np.float32) for _ in range(3)]


@pytest.mark.parametrize(
    'input_array,expected_out,axis,split',
    (
        (INPUT_1D, np.split(INPUT_1D, 3), None, None),
        (INPUT_1D, np.split(INPUT_1D, 3), 0, None),
        (INPUT_1D, np.split(INPUT_1D, [2]), None, np.array([2, 4]).astype(np.int64)),
        (INPUT_1D, np.split(INPUT_1D, [2]), 0, np.array([2, 4]).astype(np.int64)),
        (INPUT_2D, np.split(INPUT_2D, 2, axis=1), 1, None),
        (INPUT_2D, np.split(INPUT_2D, [2], axis=1), 1, np.array([2, 4]).astype(np.int64)),
        (EMPTY_INPUT, EXPECTED_EMPTY_OUT, None, np.array([0, 0, 0]).astype(np.int64)),
    ),
)
@pytest.mark.parametrize('opset_version', (13, 11, 2))
def test_split(  # pylint: disable=missing-function-docstring
    input_array: np.ndarray,
    expected_out: List[np.ndarray],
    axis: Optional[int],
    split: Optional[np.ndarray],
    opset_version: int,
) -> None:
    kwargs = {}
    if axis is not None:
        kwargs['axis'] = axis
    if split is not None:
        kwargs['split'] = split

    _test_split(input_array, expected_out, opset_version=opset_version, **kwargs)


# --------------------------------------------------------------------------------
# /tests/node_converters/squeeze_test.py:
# --------------------------------------------------------------------------------
from typing import Any
from typing import Dict
from typing import List
from typing import Optional

import numpy as np
import onnx
import pytest
from onnx.helper import make_tensor_value_info
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_squeeze(
    input_tensor: np.ndarray,
    axes: Optional[List[int]],
    opset_version: int,
    **kwargs,
) -> None:
    """Build an ONNX ``Squeeze`` model and run ``check_onnx_model`` on it.

    Args:
        input_tensor: Data fed to the graph input ``input_tensor``.
        axes: Axes to squeeze. ``None`` or empty means the node gets no axes and the expected
            output shape is ``np.squeeze`` over all unit dimensions.
        opset_version: For opset >= 13 non-empty axes are passed as an int64 graph input;
            otherwise as a node attribute.
        **kwargs: Extra node attributes.
    """
    test_inputs: Dict[str, Any] = {'input_tensor': input_tensor}

    if axes is not None and len(axes) > 0:
        if opset_version >= 13:
            test_inputs['axes'] = np.array(axes, dtype=np.int64)
        else:
            kwargs['axes'] = axes

        output_shape = np.squeeze(input_tensor, axis=tuple(a for a in axes if input_tensor.shape[a] == 1)).shape
    else:
        output_shape = np.squeeze(input_tensor).shape

    node = onnx.helper.make_node(
        op_type='Squeeze',
        inputs=list(test_inputs),
        outputs=['y'],
        **kwargs,
    )

    model = make_model_from_nodes(
        nodes=node,
        initializers={},
        inputs_example=test_inputs,
        opset_version=opset_version,
        outputs_info=(
            make_tensor_value_info(
                name='y',
                elem_type=NP_TYPE_TO_TENSOR_TYPE[input_tensor.dtype],
                shape=output_shape,
            ),
        ),
    )
    check_onnx_model(model, test_inputs)


@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning')
@pytest.mark.parametrize('opset_version', [11, 13, 21])
@pytest.mark.parametrize(
    'shape, axes',
    (
        ([1, 3, 4, 5], [0]),
        ([1, 3, 1, 5], [-2]),
        ([1, 3, 1, 5], [0, 2]),
        ([1, 3, 1, 5], [2, 0]),
        ([1, 3, 1, 1, 1, 5, 1], [2, 0, 6]),
        ([1, 3, 1, 5], [0, -2]),
        ([1, 3, 1, 5], [-2, 0]),
        ([1, 3, 1, 5], None),
        ([1, 1, 1, 1], None),
        ([1], None),
        ([3, 3, 3], None),
    ),
)
def test_squeeze(  # pylint: disable=missing-function-docstring
    shape: List[int],
    axes: Optional[List[int]],
    opset_version: int,
) -> None:
    x = np.random.randn(*shape).astype(np.float32)
    _test_squeeze(input_tensor=x, axes=axes, opset_version=opset_version)


# --------------------------------------------------------------------------------
# /tests/node_converters/sum_test.py:
# --------------------------------------------------------------------------------
from typing import List

import numpy as np
import onnx
import pytest
from onnx.helper import make_tensor_value_info
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_sum(
    data_list: List[np.ndarray],
) -> None:
    """Build an ONNX ``Sum`` model over inputs ``data_0 .. data_{N-1}`` and check it.

    The output is declared with the first input's dtype and an unknown shape.
    """
    test_inputs = {f'data_{i}': data for i, data in enumerate(data_list)}

    node = onnx.helper.make_node(op_type='Sum', inputs=list(test_inputs), outputs=['y'])
    outputs_info = [
        make_tensor_value_info(
            name='y',
            elem_type=NP_TYPE_TO_TENSOR_TYPE[data_list[0].dtype],
            shape=None,
        ),
    ]

    model = make_model_from_nodes(
        nodes=node,
        initializers={},
        inputs_example=test_inputs,
        outputs_info=outputs_info,
    )
    check_onnx_model(model, test_inputs)


@pytest.mark.parametrize(
    'input_shapes',
    (
        ([],),
        ([2, 3, 4],),
        ([3, 1], [2, 1, 6]),
        ([3, 1], [3, 4]),
    ),
)
def test_sum(input_shapes: List[List[int]]) -> None:  # pylint: disable=missing-function-docstring
    input_tensors = [np.random.normal(size=i_shape).astype(np.float32) for i_shape in input_shapes]
    _test_sum(data_list=input_tensors)


# --------------------------------------------------------------------------------
# /tests/node_converters/test_functions.py:
# --------------------------------------------------------------------------------
from typing import List

import numpy as np
import onnx
import pytest

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_functions(function: str, data: np.ndarray, opset_version: int, **kwargs) -> None:
    """Build a single-input ONNX model for op type ``function`` and run ``check_onnx_model``.

    ``**kwargs`` are forwarded to ``onnx.helper.make_node`` as node attributes.
    """
    test_inputs = {'input_tensor': data}

    node = onnx.helper.make_node(op_type=function, inputs=['input_tensor'], outputs=['y'], **kwargs)
    model = make_model_from_nodes(
        nodes=node,
        initializers={},
        inputs_example=test_inputs,
        opset_version=opset_version,
    )

    check_onnx_model(model, test_inputs)


@pytest.mark.parametrize(
    'function,input_shape',
    (
        ('Ceil', [8, 3, 32, 32]),
        ('Floor', [8, 3, 32, 32]),
        ('Round', [8, 3, 32, 32]),
    ),
)
def test_roundings(function: str, input_shape: List[int]) -> None:  # pylint: disable=missing-function-docstring
    data = np.random.randn(*input_shape).astype(np.float32)
    _test_functions(function, data=data, opset_version=11)


@pytest.mark.parametrize(
    'function,input_shape',
    (
        ('Abs', [8, 3, 32, 32]),
        ('Cos', [8, 3, 32, 32]),
        ('Exp', [8, 3, 32, 32]),
        ('Log', [8, 3, 32, 32]),
        ('Sign', [8, 3, 32, 32]),
        ('Sin', [8, 3, 32, 32]),
        ('Tan', [8, 3, 32, 32]),
    ),
)
def test_common_functions(function: str, input_shape: List[int]) -> None:  # pylint: disable=missing-function-docstring
    data = np.random.randn(*input_shape).astype(np.float32)
    if function == 'Log':
        # Keep Log's input strictly positive.
        data[data <= 0] = 10**-4
    _test_functions(function, data=data, opset_version=11)


@pytest.mark.parametrize(
    'function,input_shape',
    (
        ('Acos', [8, 3, 32, 32]),
        ('Asin', [8, 3, 32, 32]),
        ('Atan', [8, 3, 32, 32]),
    ),
)
def test_arc_functions(function: str, input_shape: List[int]) -> None:  # pylint: disable=missing-function-docstring
    if function in ['Acos', 'Asin']:
        # Acos/Asin are only defined on [-1, 1].
        data = np.random.uniform(-1, 1, input_shape).astype(np.float32)
    else:
        data = np.random.randn(*input_shape).astype(np.float32)

    _test_functions(function, data=data, opset_version=11)


@pytest.mark.parametrize(
    'function,input_shape',
    (('Tanh', [8, 3, 32, 32]),),
)
def test_hyperbolic_functions(  # pylint: disable=missing-function-docstring
    function: str,
    input_shape: List[int],
) -> None:
    data = np.random.randn(*input_shape).astype(np.float32)
    _test_functions(function, data=data, opset_version=11)


# --------------------------------------------------------------------------------
# /tests/node_converters/tile_test.py:
# --------------------------------------------------------------------------------
import numpy as np
import onnx
import pytest
from onnx.helper import make_tensor_value_info
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_tile(
    data: np.ndarray,
    repeats: np.ndarray,
    desired_out: np.ndarray,
) -> None:
    """Build an ONNX ``Tile`` model and run ``check_onnx_model`` on it.

    Args:
        data: Data fed to the graph input ``input_tensor``.
        repeats: Repeat count per axis, fed as the graph input ``repeats``.
        desired_out: Expected result; checked against ``np.tile`` and used to declare the
            output dtype/shape.
    """
    # Guard against typos in hand-written fixtures.
    np.testing.assert_array_equal(np.tile(data, repeats), desired_out)

    test_inputs = {'input_tensor': data, 'repeats': repeats}
    node = onnx.helper.make_node(
        op_type='Tile',
        inputs=list(test_inputs),
        outputs=['y'],
    )
    outputs_info = [
        make_tensor_value_info(name='y', elem_type=NP_TYPE_TO_TENSOR_TYPE[data.dtype], shape=desired_out.shape),
    ]
    model = make_model_from_nodes(
        nodes=node,
        initializers={},
        inputs_example=test_inputs,
        outputs_info=outputs_info,
    )
    check_onnx_model(model, test_inputs)


@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning')
def test_tile() -> None:  # pylint: disable=missing-function-docstring
    data = np.random.rand(2, 3, 4, 5).astype(np.float32)
    repeats = np.random.randint(low=1, high=10, size=(np.ndim(data),)).astype(np.int64)
    _test_tile(
        data=data,
        repeats=repeats,
        desired_out=np.tile(data, repeats),
    )

    data = np.array([[0, 1], [2, 3]], dtype=np.float32)

    repeats = np.array([2, 2], dtype=np.int64)
    _test_tile(
        data=data,
        repeats=repeats,
        desired_out=np.array([[0, 1, 0, 1], [2, 3, 2, 3], [0, 1, 0, 1], [2, 3, 2, 3]], dtype=np.float32),
    )


# --------------------------------------------------------------------------------
# /tests/node_converters/topk_test.py:
# --------------------------------------------------------------------------------
import numpy as np
import onnx
from onnx.helper import make_tensor_value_info
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_topk(data: np.ndarray, k: np.ndarray, **kwargs) -> None:  # pylint: disable=invalid-name
    """Build an ONNX ``TopK`` model (values ``y_0``, int64 indices ``y_1``) and check it.

    ``**kwargs`` (``axis``, ``largest``, ``sorted``) become node attributes.
    """
    test_inputs = {'input_tensor': data, 'k': k}

    node = onnx.helper.make_node(
        op_type='TopK',
        inputs=list(test_inputs),
        outputs=['y_0', 'y_1'],
        **kwargs,
    )
    outputs_info = [
        make_tensor_value_info(name='y_0', elem_type=NP_TYPE_TO_TENSOR_TYPE[data.dtype], shape=None),
        make_tensor_value_info(name='y_1', elem_type=NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')], shape=None),
    ]
    model = make_model_from_nodes(
        nodes=node,
        initializers={},
        inputs_example=test_inputs,
        outputs_info=outputs_info,
    )
    check_onnx_model(model, test_inputs)


def test_topk() -> None:  # pylint: disable=missing-function-docstring
    x = np.array(
        [
            [0, 1, 2, 3],
            [4, 5, 6, 7],
            [8, 9, 10, 11],
        ],
        dtype=np.float32,
    )

    _test_topk(data=x, k=np.array([3], dtype=np.int64), axis=1, largest=1)
    _test_topk(data=x, k=np.array([3], dtype=np.int64), axis=-1, largest=1)
    _test_topk(data=x, k=np.array([3], dtype=np.int64), axis=1, largest=1, sorted=1)


# --------------------------------------------------------------------------------
# /tests/node_converters/transpose_test.py:
# --------------------------------------------------------------------------------
import itertools

import numpy as np
import onnx

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_transpose(data: np.ndarray, **kwargs) -> None:
    """Build an ONNX ``Transpose`` model and check it; ``**kwargs`` (e.g. ``perm``) are attributes."""
    test_inputs = {'input_tensor': data}
    node = onnx.helper.make_node(
        op_type='Transpose',
        inputs=list(test_inputs),
        outputs=['y'],
        **kwargs,
    )
    model = make_model_from_nodes(
        nodes=node,
        initializers={},
        inputs_example=test_inputs,
    )
    check_onnx_model(model, test_inputs)


def test_transpose() -> None:  # pylint: disable=missing-function-docstring
    shape = (2, 3, 4)
    data = np.random.random_sample(shape).astype(np.float32)
    for permutation in itertools.permutations(np.arange(len(shape))):
        _test_transpose(
            data=data,
            perm=np.array(permutation, dtype=np.int64),
        )

    # No `perm` attribute: exercises the op's default permutation.
    _test_transpose(data=data)


# --------------------------------------------------------------------------------
# /tests/node_converters/unsqueeze_test.py:
# --------------------------------------------------------------------------------
from typing import Any
from typing import Dict
from typing import List

import numpy as np
import onnx
import pytest
from onnx.helper import make_tensor_value_info
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_unsqueeze(
    input_tensor: np.ndarray,
    axes: List[int],
    opset_version: int,
    **kwargs,
) -> None:
    """Build an ONNX ``Unsqueeze`` model and run ``check_onnx_model`` on it.

    For opset >= 13 ``axes`` is passed as an int64 graph input, otherwise as a node attribute.
    The output shape is declared as ``np.expand_dims(input_tensor, axis=axes).shape``.
    """
    test_inputs: Dict[str, Any] = {'input_tensor': input_tensor}

    if opset_version >= 13:
        test_inputs['axes'] = np.array(axes, dtype=np.int64)
    else:
        kwargs['axes'] = axes

    node = onnx.helper.make_node(
        op_type='Unsqueeze',
        inputs=list(test_inputs),
        outputs=['y'],
        **kwargs,
    )

    model = make_model_from_nodes(
        nodes=node,
        initializers={},
        inputs_example=test_inputs,
        opset_version=opset_version,
        outputs_info=(
            make_tensor_value_info(
                name='y',
                elem_type=NP_TYPE_TO_TENSOR_TYPE[input_tensor.dtype],
                shape=np.expand_dims(input_tensor, axis=axes).shape,
            ),
        ),
    )
    check_onnx_model(model, test_inputs)


# Known warning (kept for reference): shape inference does not work properly with
# opset_version=9 and negative indices; opset 9 is not among the cases below.
# [W:onnxruntime:, execution_frame.cc:721 VerifyOutputSizes]
# Expected shape from model of {2,3,16,16} does not match actual shape of {2,1,3,16,1,16} for output y
@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning')
@pytest.mark.parametrize(
    'shape,axes,opset_version',
    (
        ([2, 3, 16, 16], [0], 11),
        ([2, 3, 16, 16], [2], 11),
        ([2, 3, 16, 16], [-1], 11),
        ([2, 3, 16, 16], [-3], 11),
        ([2, 3, 16, 16], [0, 1], 11),
        ([2, 3, 16, 16], [1, 2, 3, 4, 5], 11),
        ([2, 3, 16, 16], [1, -2], 11),
        ([2, 3, 16, 16], [-2, 1], 11),
        ([2, 3, 16, 16], [0], 13),
        ([2, 3, 16, 16], [2], 13),
        ([2, 3, 16, 16], [-1], 13),
        ([2, 3, 16, 16], [-3], 13),
        ([2, 3, 16, 16], [0, 1], 13),
        ([2, 3, 16, 16], [1, 2, 3, 4, 5], 13),
        ([2, 3, 16, 16], [1, -2], 13),
        ([2, 3, 16, 16], [-2, 1], 13),
    ),
)
def test_unsqueeze(  # pylint: disable=missing-function-docstring
    shape: List[int],
    axes: List[int],
    opset_version: int,
) -> None:
    x = np.random.randn(*shape).astype(np.float32)
    _test_unsqueeze(input_tensor=x, axes=axes, opset_version=opset_version)


# --------------------------------------------------------------------------------
# /tests/node_converters/where_test.py:
# --------------------------------------------------------------------------------
import numpy as np
import onnx
from onnx.helper import make_tensor_value_info
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def where_test(
    condition: np.ndarray,
    x: np.ndarray,
    y: np.ndarray,
) -> None:
    """Build an ONNX ``Where`` model and run ``check_onnx_model`` on it.

    The output ``z`` is declared with ``x``'s dtype and an unknown shape.
    """
    test_inputs = {'condition': condition, 'x': x, 'y': y}
    node = onnx.helper.make_node(
        op_type='Where',
        inputs=list(test_inputs),
        outputs=['z'],
    )
    outputs_info = [
        make_tensor_value_info(
            name='z',
            elem_type=NP_TYPE_TO_TENSOR_TYPE[x.dtype],
            shape=None,
        )
    ]
    model = make_model_from_nodes(nodes=node, initializers={}, inputs_example=test_inputs, outputs_info=outputs_info)
    check_onnx_model(model, test_inputs)


def test_where() -> None:  # pylint: disable=missing-function-docstring
    where_test(
        condition=np.array([[1, 0], [1, 1]], dtype=bool),
        x=np.array([[1, 2], [3, 4]], dtype=np.int64),
        y=np.array([[9, 8], [7, 6]], dtype=np.int64),
    )

    where_test(
        condition=np.array([[1, 0], [1, 1]], dtype=bool),
        x=np.array([[1, 2], [3, 4]], dtype=np.float32),
        y=np.array([[9, 8], [7, 6]], dtype=np.float32),
    )

    # x has shape (2, 1) and is broadcast against condition / y.
    where_test(
        condition=np.array([[1, 0], [1, 1]], dtype=bool),
        x=np.array([[1], [3]], dtype=np.float32),
        y=np.array([[9, 8], [7, 6]], dtype=np.float32),
    )


# --------------------------------------------------------------------------------
# /tests/pytest.ini:
# --------------------------------------------------------------------------------
# [pytest]
#
# log_level=ERROR
# log_cli=True
# log_cli_level=INFO
#
# filterwarnings =
#     ignore::DeprecationWarning
#
# --------------------------------------------------------------------------------
# /tests/utils/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ENOT-AutoDL/onnx2torch/369412ad62c81ca5b360554572820755e31b9b7a/tests/utils/__init__.py
#
# --------------------------------------------------------------------------------
# /tests/utils/resources.py:
# --------------------------------------------------------------------------------
import tarfile
import urllib.request
from pathlib import Path

import onnx
from google_drive_downloader import GoogleDriveDownloader
from onnx import ModelProto  # pylint: disable=no-name-in-module

from tests import DATASETS_DIR
from tests import MODELS_DIR

_BASE_URL = 'https://gitlab.expasoft.com/p.ivanov/onnx2torch_data/-/raw/main/models_for_tests'

_CHKP_DETECTION_URL = f'{_BASE_URL}/detection'
_CHKP_SEGMENTATION_URL = f'{_BASE_URL}/segmentation'
_CHKP_TRANSFORMERS_URL = f'{_BASE_URL}/transformers'
_CHKP_KEYPOINTS_URL = f'{_BASE_URL}/keypoints'
_CHKP_OTHER_URL = f'{_BASE_URL}/other'

# Model name -> download URL of its .onnx file.
_ONNX_MODELS_IDS = {
    'deeplabv3_mnv3_large': f'{_CHKP_SEGMENTATION_URL}/deeplabv3_mobilenet_v3_large.onnx',
    'deeplabv3_plus_resnet101': f'{_CHKP_SEGMENTATION_URL}/deeplabv3_resnet101_dimans.onnx',
    'hrnet': f'{_CHKP_SEGMENTATION_URL}/hrnet.onnx',
    'unet': f'{_CHKP_SEGMENTATION_URL}/unet_resnet34.onnx',
    'retinanet': f'{_CHKP_DETECTION_URL}/retinanet_r50_fpn.onnx',
    'ssd300_vgg': f'{_CHKP_DETECTION_URL}/ssd300.onnx',
    'ssdlite': f'{_CHKP_DETECTION_URL}/ssdlite.onnx',
    'yolov3_d53': f'{_CHKP_DETECTION_URL}/yolov3_d53_tuned_shape.onnx',
    'yolov5_ultralitics': f'{_CHKP_DETECTION_URL}/yolov5_ultralitics.onnx',
    'swin': f'{_CHKP_TRANSFORMERS_URL}/swin.onnx',
    'vit': f'{_CHKP_TRANSFORMERS_URL}/vit.onnx',
    'gptj_2_random_blocks': f'{_CHKP_TRANSFORMERS_URL}/gptj_2_random_blocks.onnx',
    'resnet50': 'https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-v2-7.onnx',
    '3d_gan': f'{_CHKP_OTHER_URL}/3d_gan.onnx',
    'shelfnet': f'{_CHKP_KEYPOINTS_URL}/shelfnet.onnx',
    'point_arch': f'{_CHKP_OTHER_URL}/point_arch_nq.onnx',
}

# Google Drive file id of the minimal dataset .tar.gz archive.
_MINIMAL_DATASETS_ID = '1Vd7qfQotrRADPLFxViA2tRpz7tBymR31'


def get_model_path(name: str) -> Path:
    """Return the local path of model ``name``, downloading it into MODELS_DIR if missing.

    Raises:
        RuntimeError: If the model is neither cached locally nor listed in ``_ONNX_MODELS_IDS``.
    """
    model_path = MODELS_DIR / f'{name}.onnx'
    if model_path.exists():
        return model_path

    if name not in _ONNX_MODELS_IDS:
        raise RuntimeError('Cannot find model path.')

    model_path.parent.mkdir(parents=True, exist_ok=True)
    # Download into a sibling temp file and rename only on success, so an interrupted download
    # never leaves a truncated .onnx that the exists() check above would accept next time.
    tmp_path = model_path.with_name(model_path.name + '.part')
    try:
        urllib.request.urlretrieve(url=_ONNX_MODELS_IDS[name], filename=tmp_path)
        tmp_path.replace(model_path)
    finally:
        if tmp_path.exists():
            tmp_path.unlink()

    return model_path


def get_model(name: str) -> ModelProto:
    """Load model ``name`` (downloading it first if needed) with ``onnx.load_model``."""
    model_path = get_model_path(name)
    return onnx.load_model(str(model_path))


def get_minimal_dataset_path() -> Path:
    """Return DATASETS_DIR / 'minimal_dataset', downloading and extracting the archive if missing."""
    dataset_path = DATASETS_DIR / 'minimal_dataset'
    if not dataset_path.exists():
        arch_path = dataset_path.with_suffix('.tar.gz')
        GoogleDriveDownloader.download_file_from_google_drive(
            file_id=_MINIMAL_DATASETS_ID,
            dest_path=arch_path,
            overwrite=True,
        )
        with tarfile.open(arch_path, 'r:gz') as arch_file:
            # The archive is downloaded from the network: use the 'data' extraction filter
            # (rejects absolute paths, '..' traversal, special files) where tarfile supports it.
            if hasattr(tarfile, 'data_filter'):
                arch_file.extractall(path=dataset_path, filter='data')
            else:
                arch_file.extractall(path=dataset_path)

    return dataset_path
# --------------------------------------------------------------------------------