├── .github
│   └── workflows
│       └── lint_and_test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── download_fixtures.sh
├── onnx2pytorch
│   ├── __init__.py
│   ├── constants.py
│   ├── convert
│   │   ├── __init__.py
│   │   ├── attribute.py
│   │   ├── debug.py
│   │   ├── layer.py
│   │   ├── model.py
│   │   └── operations.py
│   ├── helpers.py
│   ├── operations
│   │   ├── __init__.py
│   │   ├── add.py
│   │   ├── base.py
│   │   ├── batchnorm.py
│   │   ├── bitshift.py
│   │   ├── cast.py
│   │   ├── clip.py
│   │   ├── constant.py
│   │   ├── constantofshape.py
│   │   ├── div.py
│   │   ├── expand.py
│   │   ├── flatten.py
│   │   ├── gather.py
│   │   ├── gathernd.py
│   │   ├── globalaveragepool.py
│   │   ├── hardsigmoid.py
│   │   ├── instancenorm.py
│   │   ├── loop.py
│   │   ├── lstm.py
│   │   ├── matmul.py
│   │   ├── nonmaxsuppression.py
│   │   ├── onehot.py
│   │   ├── pad.py
│   │   ├── prelu.py
│   │   ├── range.py
│   │   ├── reducel2.py
│   │   ├── reducemax.py
│   │   ├── reducesum.py
│   │   ├── reshape.py
│   │   ├── resize.py
│   │   ├── scatter.py
│   │   ├── scatterelements.py
│   │   ├── scatternd.py
│   │   ├── shape.py
│   │   ├── slice.py
│   │   ├── split.py
│   │   ├── squeeze.py
│   │   ├── thresholdedrelu.py
│   │   ├── tile.py
│   │   ├── topk.py
│   │   ├── transpose.py
│   │   ├── unsqueeze.py
│   │   └── where.py
│   └── utils.py
├── requirements.txt
├── setup.py
├── tests
│   ├── __init__.py
│   ├── onnx2pytorch
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── convert
│   │   │   ├── __init__.py
│   │   │   ├── test_attribute.py
│   │   │   ├── test_debug.py
│   │   │   ├── test_loop.py
│   │   │   ├── test_lstm.py
│   │   │   ├── test_maxpool.py
│   │   │   └── test_train.py
│   │   ├── operations
│   │   │   ├── __init__.py
│   │   │   ├── test_add.py
│   │   │   ├── test_bitshift.py
│   │   │   ├── test_cast.py
│   │   │   ├── test_clip.py
│   │   │   ├── test_constantofshape.py
│   │   │   ├── test_div.py
│   │   │   ├── test_expand.py
│   │   │   ├── test_flatten.py
│   │   │   ├── test_gather.py
│   │   │   ├── test_gathernd.py
│   │   │   ├── test_globalaveragepool.py
│   │   │   ├── test_hardsigmoid.py
│   │   │   ├── test_instancenorm.py
│   │   │   ├── test_nonmaxsuppression.py
│   │   │   ├── test_onehot.py
│   │   │   ├── test_pad.py
│   │   │   ├── test_prelu.py
│   │   │   ├── test_range.py
│   │   │   ├── test_reducel2.py
│   │   │   ├── test_reducemax.py
│   │   │   ├── test_reducesum.py
│   │   │   ├── test_reshape.py
│   │   │   ├── test_resize.py
│   │   │   ├── test_scatter.py
│   │   │   ├── test_scatterelements.py
│   │   │   ├── test_scatternd.py
│   │   │   ├── test_slice.py
│   │   │   ├── test_split.py
│   │   │   ├── test_squeeze.py
│   │   │   ├── test_thresholdedrelu.py
│   │   │   ├── test_tile.py
│   │   │   ├── test_topk.py
│   │   │   ├── test_unsqueeze.py
│   │   │   └── test_where.py
│   │   ├── test_convert.py
│   │   ├── test_onnx2pytorch.py
│   │   └── test_utils.py
│   └── test_imports.py
└── tox.ini

/.github/workflows/lint_and_test.yml:
--------------------------------------------------------------------------------
1 | name: Lint and Test
2 | 
3 | on: [pull_request]
4 | 
5 | jobs:
6 |   test:
7 |     runs-on: ubuntu-latest
8 |     strategy:
9 |       matrix:
10 |         python-version: [ '3.9', '3.10', '3.11', '3.12' ]
11 | 
12 |     steps:
13 |       - name: Checkout code
14 |         uses: actions/checkout@v4
15 | 
16 |       - name: Set up Python ${{ matrix.python-version }}
17 |         uses: actions/setup-python@v5
18 |         with:
19 |           python-version: ${{ matrix.python-version }}
20 | 
21 |       - name: Cache dependencies
22 |         uses: actions/cache@v4
23 |         with:
24 |           path: ~/.cache/pip
25 |           key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements.txt') }}
26 |           restore-keys: |
27 |             ${{ runner.os }}-pip-${{ matrix.python-version }}-
28 |             ${{ runner.os }}-pip-
29 | 
30 |       - name: Install dependencies
31 |         run: |
32 |           pip install tox tox-gh-actions
33 | 
34 |       - name: Run tests
35 |         run: |
36 |           bash download_fixtures.sh
37 |           tox
38 | 
39 |       - name: Upload coverage to GitHub Artifacts
40 |         uses: actions/upload-artifact@v4
41 |         with:
42 |           name: coverage-${{ matrix.python-version }}
43 |           path: htmlcov/
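Note: the workflow above relies on `tox-gh-actions` to pick which tox environments run for each matrix Python version, via a `[gh-actions]` section in `tox.ini`. The repository's actual `tox.ini` is not reproduced in this section; a minimal sketch of the mapping such a setup assumes (the `py39`–`py312` env names are illustrative):

```
[gh-actions]
python =
    3.9: py39
    3.10: py310
    3.11: py311
    3.12: py312
```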
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 | 
7 | # C extensions and cython compiled files
8 | *.so
9 | *.c
10 | 
11 | # Distribution / packaging
12 | .Python
13 | env/
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *,cover
48 | .hypothesis/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | 
58 | # Sphinx documentation
59 | docs/_build/
60 | 
61 | # PyBuilder
62 | target/
63 | 
64 | # IPython Notebook
65 | .ipynb_checkpoints
66 | 
67 | # pyenv
68 | .python-version
69 | 
70 | # IntelliJ and macOS
71 | .DS_Store
72 | .idea/
73 | 
74 | # models
75 | *.onnx
76 | *.pb
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/psf/black
3 |     rev: 24.8.0
4 |     hooks:
5 |       - id: black
6 |         language_version: python3.10
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 | 
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ONNX to PyTorch
2 | ![PyPI - License](https://img.shields.io/pypi/l/onnx2pytorch?color)
3 | [![Lint and Test](https://github.com/Talmaj/onnx2pytorch/actions/workflows/lint_and_test.yml/badge.svg)](https://github.com/Talmaj/onnx2pytorch/actions/workflows/lint_and_test.yml)
4 | [![Downloads](https://pepy.tech/badge/onnx2pytorch)](https://pepy.tech/project/onnx2pytorch)
5 | ![PyPI](https://img.shields.io/pypi/v/onnx2pytorch)
6 | 
7 | A library to transform ONNX models into PyTorch. This library enables use of the
8 | PyTorch backend and all of its great features for manipulating neural networks.
9 | 
10 | ## Installation
11 | ```pip install onnx2pytorch```
12 | 
13 | ## Usage
14 | ```
15 | import onnx
16 | from onnx2pytorch import ConvertModel
17 | 
18 | onnx_model = onnx.load(path_to_onnx_model)
19 | pytorch_model = ConvertModel(onnx_model)
20 | ```
21 | 
22 | Currently supported and tested models from the [onnx_zoo](https://github.com/onnx/models):
23 | - [MobileNet](https://github.com/onnx/models/tree/master/vision/classification/mobilenet)
24 | - [ResNet](https://github.com/onnx/models/tree/master/vision/classification/resnet)
25 | - [ShuffleNet_V2](https://github.com/onnx/models/tree/master/vision/classification/shufflenet)
26 | - [BERT-Squad](https://github.com/onnx/models/tree/master/text/machine_comprehension/bert-squad)
27 | - [EfficientNet-Lite4](https://github.com/onnx/models/tree/master/vision/classification/efficientnet-lite4)
28 | - [Fast Neural Style Transfer](https://github.com/onnx/models/tree/master/vision/style_transfer/fast_neural_style)
29 | - [Super Resolution](https://github.com/onnx/models/tree/master/vision/super_resolution/sub_pixel_cnn_2016)
30 | - [YOLOv4](https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/yolov4)
31 | (Not exactly the same: nearest-neighbour interpolation in pytorch differs)
32 | - [U-net](https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/)
33 | (Converted from pytorch to onnx and then back)
34 | 
35 | ## Limitations
36 | Known limitations of the current version are:
37 | - `batch_size > 1` could deliver unexpected results due to ambiguity in onnx's BatchNorm layer.
38 | For now we therefore raise an assertion error in this case.
39 | Set `experimental=True` in `ConvertModel` to allow `batch_size > 1`.
40 | - Fine-tuning and training of converted models have not been tested yet, only inference.
41 | 
42 | ## Development
43 | ### Dependency installation
44 | ```pip install -r requirements.txt```
45 | 
46 | With onnxruntime>=1.5.0, if you are running macOS, you need to add the
47 | following to your .bashrc or .zshrc:
48 | ```export KMP_DUPLICATE_LIB_OK=True```
49 | 
50 | ### Code formatting
51 | The Uncompromising Code Formatter: [Black](https://github.com/psf/black)
52 | ```black {source_file_or_directory}```
53 | 
54 | Install it as a pre-commit hook to always commit nicely formatted code:
55 | ```pre-commit install```
56 | 
57 | ### Testing
58 | [Pytest](https://docs.pytest.org/en/latest/) and [tox](https://tox.readthedocs.io/en/latest/).
59 | ```tox```
60 | #### Test fixtures
61 | To test the complete conversion of an onnx model, download pre-trained models:
62 | ```./download_fixtures.sh```
63 | Use the `--all` flag to download more models.
64 | Add any custom models to the `./fixtures` folder to test their conversion.
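As a sketch of what such a conversion check can look like, the converted model can be compared against onnxruntime on a downloaded fixture (the fixture path, the input name `data`, and the input shape are assumptions for the mobilenet fixture; query `ort_session.get_inputs()` for the real values):

```
import numpy as np
import onnx
import onnxruntime
import torch

from onnx2pytorch import ConvertModel

path = "fixtures/mobilenetv2-1.0.onnx"
pytorch_model = ConvertModel(onnx.load(path))

x = np.random.randn(1, 3, 224, 224).astype(np.float32)
ort_session = onnxruntime.InferenceSession(path, providers=["CPUExecutionProvider"])
ort_out = ort_session.run(None, {"data": x})[0]

with torch.no_grad():
    torch_out = pytorch_model(torch.from_numpy(x)).numpy()

np.testing.assert_allclose(ort_out, torch_out, rtol=1e-3, atol=1e-4)
```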
65 | 66 | ### Debugging 67 | Set `ConvertModel(..., debug=True)` to compare each converted 68 | activation from pytorch with the activation from onnxruntime. 69 | This helps identify where in the graph the activations start to differ. 70 | -------------------------------------------------------------------------------- /download_fixtures.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | mkdir -p fixtures 3 | cd fixtures 4 | 5 | if [[ ! -f mobilenetv2-1.0.onnx ]]; then 6 | echo Downloading mobilenetv2-1.0 7 | curl -o mobilenetv2-1.0.onnx https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/mobilenetv2-1.0.onnx 8 | fi 9 | 10 | if [[ ! -f shufflenet_v2.onnx ]]; then 11 | echo Downloading shufflenet_v2 12 | curl -LJo shufflenet_v2.onnx https://github.com/onnx/models/raw/main/validated/vision/classification/shufflenet/model/shufflenet-v2-10.onnx 13 | fi 14 | 15 | if [[ $1 == "--all" ]]; then 16 | if [[ ! -f resnet18v1.onnx ]]; then 17 | echo Downloading resnet18v1 18 | curl -o resnet18v1.onnx https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet18v1/resnet18v1.onnx 19 | fi 20 | 21 | if [[ ! -f bertsquad-10.onnx ]]; then 22 | echo Downloading bertsquad-10 23 | curl -LJo bertsquad-10.onnx https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx 24 | fi 25 | 26 | if [[ ! -f yolo_v4.onnx ]]; then 27 | echo Downloading yolo_v4 28 | curl -LJo yolo_v4.onnx https://github.com/onnx/models/raw/main/validated/vision/object_detection_segmentation/yolov4/model/yolov4.onnx 29 | fi 30 | 31 | if [[ ! -f super_res.onnx ]]; then 32 | echo Downloading super_res 33 | curl -LJo super_res.onnx https://github.com/onnx/models/raw/main/validated/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx 34 | fi 35 | 36 | if [[ ! -f fast_neural_style.onnx ]]; then 37 | echo Downloading fast_neural_style 38 | curl -LJo fast_neural_style.onnx https://github.com/onnx/models/raw/main/validated/vision/style_transfer/fast_neural_style/model/rain-princess-9.onnx 39 | fi 40 | 41 | if [[ ! -f efficientnet-lite4.onnx ]]; then 42 | echo Downloading efficientnet-lite4 43 | curl -LJo efficientnet-lite4.onnx https://github.com/onnx/models/raw/main/validated/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx 44 | fi 45 | 46 | if [[ ! -f mobilenetv2-7.onnx ]]; then 47 | echo Downloading mobilenetv2-7 48 | curl -LJo mobilenetv2-7.onnx https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-7.onnx 49 | fi 50 | 51 | fi 52 | 53 | echo "All models downloaded." 
54 | -------------------------------------------------------------------------------- /onnx2pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .convert import ConvertModel 2 | 3 | __version__ = "0.5.1" 4 | -------------------------------------------------------------------------------- /onnx2pytorch/constants.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.conv import _ConvNd 3 | from torch.nn.modules.pooling import _MaxPoolNd 4 | from onnx2pytorch.operations import ( 5 | BatchNormWrapper, 6 | InstanceNormWrapper, 7 | Loop, 8 | LSTMWrapper, 9 | Split, 10 | TopK, 11 | ) 12 | 13 | 14 | COMPOSITE_LAYERS = (nn.Sequential,) 15 | MULTIOUTPUT_LAYERS = (_MaxPoolNd, Loop, LSTMWrapper, Split, TopK) 16 | STANDARD_LAYERS = ( 17 | _ConvNd, 18 | BatchNormWrapper, 19 | InstanceNormWrapper, 20 | LSTMWrapper, 21 | nn.Linear, 22 | ) 23 | -------------------------------------------------------------------------------- /onnx2pytorch/convert/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import ConvertModel 2 | from .layer import * 3 | -------------------------------------------------------------------------------- /onnx2pytorch/convert/attribute.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import onnx 4 | from onnx import numpy_helper 5 | 6 | from onnx2pytorch.utils import ( 7 | extract_padding_params_for_conv_layer, 8 | extract_padding_params, 9 | ) 10 | 11 | TENSOR_PROTO_MAPPING = dict([i[::-1] for i in onnx.TensorProto.DataType.items()]) 12 | 13 | AttributeType = dict( 14 | UNDEFINED=0, 15 | FLOAT=1, 16 | INT=2, 17 | STRING=3, 18 | TENSOR=4, 19 | GRAPH=5, 20 | SPARSE_TENSOR=11, 21 | FLOATS=6, 22 | INTS=7, 23 | STRINGS=8, 24 | TENSORS=9, 25 | GRAPHS=10, 26 | SPARSE_TENSORS=12, 27 | ) 28 | 29 | 30 | def extract_attr_values(attr): 31 | """Extract onnx attribute values.""" 32 | if attr.type == AttributeType["INT"]: 33 | value = attr.i 34 | elif attr.type == AttributeType["FLOAT"]: 35 | value = attr.f 36 | elif attr.type == AttributeType["INTS"]: 37 | value = tuple(attr.ints) 38 | elif attr.type == AttributeType["FLOATS"]: 39 | value = tuple(attr.floats) 40 | elif attr.type == AttributeType["TENSOR"]: 41 | value = numpy_helper.to_array(attr.t) 42 | elif attr.type == AttributeType["STRING"]: 43 | value = attr.s.decode() 44 | elif attr.type == AttributeType["GRAPH"]: 45 | value = attr.g 46 | else: 47 | raise NotImplementedError( 48 | "Extraction of attribute type {} not implemented.".format(attr.type) 49 | ) 50 | return value 51 | 52 | 53 | def extract_attributes(node): 54 | """Extract onnx attributes. 
Map onnx feature naming to pytorch."""
55 |     kwargs = {}
56 |     for attr in node.attribute:
57 |         if attr.name == "activation_alpha":
58 |             kwargs["activation_alpha"] = extract_attr_values(attr)
59 |         elif attr.name == "activation_beta":
60 |             kwargs["activation_beta"] = extract_attr_values(attr)
61 |         elif attr.name == "activations":
62 |             kwargs["activations"] = extract_attr_values(attr)
63 |         elif attr.name == "alpha":
64 |             if node.op_type == "LeakyRelu":
65 |                 kwargs["negative_slope"] = extract_attr_values(attr)
66 |             elif node.op_type in ("Elu", "ThresholdedRelu"):
67 |                 kwargs["alpha"] = extract_attr_values(attr)
68 |             elif node.op_type == "HardSigmoid":
69 |                 kwargs["alpha"] = extract_attr_values(attr)
70 |             else:
71 |                 kwargs["weight_multiplier"] = extract_attr_values(attr)
72 |         elif attr.name == "auto_pad":
73 |             value = extract_attr_values(attr)
74 |             if value == "NOTSET":
75 |                 pass
76 |             else:
77 |                 raise NotImplementedError(
78 |                     "auto_pad={} functionality not implemented.".format(value)
79 |                 )
80 |         elif attr.name == "axis" and node.op_type == "Flatten":
81 |             kwargs["start_dim"] = extract_attr_values(attr)
82 |         elif attr.name == "axis" or attr.name == "axes":
83 |             v = extract_attr_values(attr)
84 |             if isinstance(v, (tuple, list)) and len(v) == 1:
85 |                 kwargs["dim"] = v[0]
86 |             else:
87 |                 kwargs["dim"] = v
88 |         elif attr.name == "beta":
89 |             if node.op_type == "HardSigmoid":
90 |                 kwargs["beta"] = extract_attr_values(attr)
91 |             else:
92 |                 kwargs["bias_multiplier"] = extract_attr_values(attr)
93 |         elif attr.name == "body":
94 |             kwargs["body"] = extract_attr_values(attr)
95 |         elif attr.name == "ceil_mode":
96 |             kwargs["ceil_mode"] = bool(extract_attr_values(attr))
97 |         elif attr.name == "center_point_box":
98 |             kwargs["center_point_box"] = extract_attr_values(attr)
99 |         elif attr.name == "clip":
100 |             kwargs["clip"] = extract_attr_values(attr)
101 |         elif attr.name == "coordinate_transformation_mode":
102 |             arg = extract_attr_values(attr)
103 |             if arg == "align_corners":
104 |                 kwargs["align_corners"] = True
105 |             else:
106 |                 warnings.warn(
107 |                     "PyTorch's interpolate does not support coordinate_transformation_mode={}. 
" 108 | "Result might differ.".format(arg) 109 | ) 110 | elif attr.name == "dilations": 111 | kwargs["dilation"] = extract_attr_values(attr) 112 | elif attr.name == "direction": 113 | kwargs["direction"] = extract_attr_values(attr) 114 | elif attr.name == "ends": 115 | kwargs["ends"] = extract_attr_values(attr) 116 | elif attr.name == "epsilon": 117 | kwargs["eps"] = extract_attr_values(attr) 118 | elif attr.name == "group": 119 | kwargs["groups"] = extract_attr_values(attr) 120 | elif attr.name == "hidden_size": 121 | kwargs["hidden_size"] = extract_attr_values(attr) 122 | elif attr.name == "input_forget": 123 | kwargs["input_forget"] = extract_attr_values(attr) 124 | elif attr.name == "keepdims": 125 | kwargs["keepdim"] = bool(extract_attr_values(attr)) 126 | elif attr.name == "kernel_shape": 127 | kwargs["kernel_size"] = extract_attr_values(attr) 128 | elif attr.name == "largest": 129 | kwargs["largest"] = extract_attr_values(attr) 130 | elif attr.name == "layout": 131 | kwargs["layout"] = extract_attr_values(attr) 132 | elif attr.name == "max": 133 | kwargs["max"] = extract_attr_values(attr) 134 | elif attr.name == "min": 135 | kwargs["min"] = extract_attr_values(attr) 136 | elif attr.name == "mode": 137 | kwargs["mode"] = extract_attr_values(attr) 138 | elif attr.name == "momentum": 139 | kwargs["momentum"] = extract_attr_values(attr) 140 | elif attr.name == "noop_with_empty_axes": 141 | kwargs["noop_with_empty_axes"] = extract_attr_values(attr) 142 | elif attr.name == "output_shape" and node.op_type == "ConvTranspose": 143 | raise NotImplementedError( 144 | "ConvTranspose with dynamic padding not implemented." 145 | ) 146 | elif attr.name == "pads": 147 | params = extract_attr_values(attr) 148 | if node.op_type == "Pad": 149 | kwargs["padding"] = extract_padding_params(params) 150 | else: 151 | # Works for Conv, MaxPooling and other layers from convert_layer func 152 | kwargs["padding"] = extract_padding_params_for_conv_layer(params) 153 | elif attr.name == "perm": 154 | kwargs["dims"] = extract_attr_values(attr) 155 | elif attr.name == "repeats": 156 | kwargs["repeats"] = extract_attr_values(attr) 157 | elif attr.name == "sorted": 158 | kwargs["sorted"] = extract_attr_values(attr) 159 | elif attr.name == "sparse_value": 160 | kwargs["constant"] = extract_attr_values(attr) 161 | elif attr.name == "spatial": 162 | kwargs["spatial"] = extract_attr_values(attr) # Batch norm parameter 163 | elif attr.name == "split": 164 | kwargs["split_size_or_sections"] = extract_attr_values(attr) 165 | elif attr.name == "strides": 166 | kwargs["stride"] = extract_attr_values(attr) 167 | elif attr.name == "starts": 168 | kwargs["starts"] = extract_attr_values(attr) 169 | elif attr.name == "to": 170 | kwargs["dtype"] = TENSOR_PROTO_MAPPING[extract_attr_values(attr)].lower() 171 | elif attr.name == "transB": 172 | kwargs["transpose_weight"] = not extract_attr_values(attr) 173 | elif attr.name == "transA": 174 | kwargs["transpose_activation"] = bool(extract_attr_values(attr)) 175 | elif attr.name == "value": 176 | kwargs["constant"] = extract_attr_values(attr) 177 | elif attr.name == "value_float": 178 | kwargs["constant"] = extract_attr_values(attr) 179 | elif attr.name == "value_floats": 180 | kwargs["constant"] = extract_attr_values(attr) 181 | elif attr.name == "value_int": 182 | kwargs["constant"] = extract_attr_values(attr) 183 | elif attr.name == "value_ints": 184 | kwargs["constant"] = extract_attr_values(attr) 185 | elif attr.name == "value_string": 186 | kwargs["constant"] = 
extract_attr_values(attr) 187 | elif attr.name == "value_strings": 188 | kwargs["constant"] = extract_attr_values(attr) 189 | elif node.op_type == "Resize": 190 | # These parameters are not used, warn in Resize operator 191 | kwargs[attr.name] = extract_attr_values(attr) 192 | else: 193 | raise NotImplementedError( 194 | "Extraction of attribute {} not implemented.".format(attr.name) 195 | ) 196 | return kwargs 197 | -------------------------------------------------------------------------------- /onnx2pytorch/convert/debug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from onnx2pytorch.utils import get_activation_value 5 | 6 | 7 | def debug_model_conversion(onnx_model, inputs, pred_act, node, rtol=1e-3, atol=1e-4): 8 | """Compare if the activations of pytorch are the same as from onnxruntime.""" 9 | if not isinstance(inputs, list): 10 | raise TypeError("inputs should be in a list.") 11 | 12 | if not all(isinstance(x, np.ndarray) for x in inputs): 13 | inputs = [x.detach().cpu().numpy() for x in inputs] 14 | 15 | exp_act = get_activation_value(onnx_model, inputs, list(node.output)) 16 | if isinstance(pred_act, list): 17 | assert len(exp_act) == len(pred_act) 18 | for a, b in zip(exp_act, pred_act): 19 | exp = torch.from_numpy(a).cpu() 20 | pred = b.cpu() 21 | assert torch.equal(torch.tensor(exp.shape), torch.tensor(pred.shape)) 22 | assert torch.allclose(exp, pred, rtol=rtol, atol=atol) 23 | else: 24 | exp = torch.from_numpy(exp_act[0]).cpu() 25 | pred = pred_act.cpu() 26 | assert torch.equal(torch.tensor(exp.shape), torch.tensor(pred.shape)) 27 | assert torch.allclose(exp, pred, rtol=rtol, atol=atol) 28 | -------------------------------------------------------------------------------- /onnx2pytorch/convert/model.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from copy import deepcopy 3 | from functools import partial 4 | import warnings 5 | 6 | import numpy as np 7 | import onnx 8 | import torch 9 | from onnx import numpy_helper 10 | from torch import nn 11 | from torch.jit import TracerWarning 12 | from torch.nn.modules.linear import Identity 13 | 14 | from onnx2pytorch.constants import ( 15 | COMPOSITE_LAYERS, 16 | MULTIOUTPUT_LAYERS, 17 | STANDARD_LAYERS, 18 | ) 19 | from onnx2pytorch.convert.debug import debug_model_conversion 20 | from onnx2pytorch.convert.operations import ( 21 | convert_operations, 22 | get_buffer_name, 23 | get_init_parameter, 24 | Loop, 25 | ) 26 | from onnx2pytorch.utils import ( 27 | get_inputs_names, 28 | get_outputs_names, 29 | ) 30 | 31 | 32 | def compute_activation_dependencies(onnx_graph, model, mapping): 33 | """ 34 | Compute activation dependencies, mapping each node to its dependents. 35 | 36 | Parameters 37 | ---------- 38 | onnx_graph: onnx.GraphProto 39 | ONNX graph. 40 | model: onnx2pytorch.ConvertModel 41 | Module which contains converted submodules. 42 | mapping: dict 43 | Dictionary mapping from node name to name of submodule. 44 | 45 | Returns 46 | ------- 47 | needed_by: dict 48 | Dictionary mapping from node name to names of its dependents. 
49 |     """
50 |     needed_by = defaultdict(set)
51 |     for node in onnx_graph.node:
52 |         out_op_id = node.output[0]
53 |         for in_op_id in node.input:
54 |             needed_by[in_op_id].add(out_op_id)
55 |         if node.op_type == "Loop":
56 |             # Look at nodes in the loop body
57 |             l1 = getattr(model, mapping[out_op_id])  # Loop object
58 |             loop_body_l1 = l1.body
59 |             for node_l1 in loop_body_l1.node:
60 |                 for in_op_id in node_l1.input:
61 |                     # Treating node (outer loop) as dependent, not node_l1
62 |                     needed_by[in_op_id].add(out_op_id)
63 |                 if node_l1.op_type == "Loop":
64 |                     # Look at nodes in the loop body
65 |                     l2 = getattr(model, l1.mapping[node_l1.output[0]])  # Loop object
66 |                     loop_body_l2 = l2.body
67 |                     for node_l2 in loop_body_l2.node:
68 |                         for in_op_id in node_l2.input:
69 |                             # Treating node (outer loop) as dependent, not node_l2
70 |                             needed_by[in_op_id].add(out_op_id)
71 |                         if node_l2.op_type == "Loop":
72 |                             # TODO: make this recursive for nested loops
73 |                             raise NotImplementedError(
74 |                                 "Activation garbage collection not implemented for >2 nested loops."
75 |                             )
76 |     needed_by.default_factory = None
77 |     return needed_by
78 | 
79 | 
80 | class ConvertModel(nn.Module):
81 |     def __init__(
82 |         self,
83 |         onnx_model: onnx.ModelProto,
84 |         batch_dim=0,
85 |         experimental=False,
86 |         debug=False,
87 |         enable_pruning=False,
88 |     ):
89 |         """
90 |         Convert onnx model to pytorch.
91 | 
92 |         Parameters
93 |         ----------
94 |         onnx_model: onnx.ModelProto
95 |             Loaded onnx model.
96 |         batch_dim: int
97 |             Dimension of the batch.
98 |         experimental: bool
99 |             Experimental implementation allows batch_size > 1. However,
100 |             batchnorm layers could potentially produce false outputs.
101 |         enable_pruning: bool
102 |             Track kept/pruned indices between different calls to forward pass.
103 | 
104 |         Returns
105 |         -------
106 |         model: torch.nn.Module
107 |             A converted pytorch model.
108 |         """
109 |         super().__init__()
110 |         self.onnx_model = onnx_model
111 |         self.batch_dim = batch_dim
112 |         self.experimental = experimental
113 |         self.debug = debug
114 |         self.enable_pruning = enable_pruning
115 | 
116 |         self.input_names = get_inputs_names(onnx_model.graph)
117 |         self.output_names = get_outputs_names(onnx_model.graph)
118 |         opset_version = onnx_model.opset_import[0].version
119 | 
120 |         # Create mapping from node (identified by first output) to submodule
121 |         self.mapping = {}
122 |         for op_id, op_name, op in convert_operations(
123 |             onnx_model.graph,
124 |             opset_version,
125 |             batch_dim,
126 |             enable_pruning,
127 |         ):
128 |             setattr(self, op_name, op)
129 |             if isinstance(op, Loop) and debug:
130 |                 raise NotImplementedError("debug-mode with Loop node not implemented.")
131 |             self.mapping[op_id] = op_name
132 | 
133 |         # Store initializers as buffers
134 |         for tensor in self.onnx_model.graph.initializer:
135 |             buffer_name = get_buffer_name(tensor.name)
136 |             self.register_buffer(
137 |                 buffer_name,
138 |                 torch.from_numpy(numpy_helper.to_array(tensor)),
139 |             )
140 | 
141 |         # Compute activation dependencies, mapping each node to its dependents
142 |         self.needed_by = compute_activation_dependencies(
143 |             self.onnx_model.graph, self, self.mapping
144 |         )
145 | 
146 |         if experimental:
147 |             warnings.warn(
148 |                 "Using experimental implementation that allows 'batch_size > 1'. "
149 |                 "Batchnorm layers could potentially produce false outputs."
150 |             )
151 | 
152 |     def forward(self, *input_list, **input_dict):
153 |         if len(input_list) > 0 and len(input_dict) > 0:
154 |             raise ValueError(
155 |                 "forward-pass accepts either input_list (positional args) or "
156 |                 "input_dict (keyword args) but not both"
157 |             )
158 |         if len(input_list) > 0:
159 |             inputs = input_list
160 |         elif len(input_dict) > 0:
161 |             inputs = [input_dict[key] for key in self.input_names]
162 |         else: raise ValueError("forward-pass requires at least one input")
163 |         if not self.experimental and inputs[0].shape[self.batch_dim] > 1:
164 |             raise NotImplementedError(
165 |                 "Input with larger batch size than 1 not supported yet."
166 |             )
167 |         activations = dict(zip(self.input_names, inputs))
168 |         still_needed_by = deepcopy(self.needed_by)
169 | 
170 |         for node in self.onnx_model.graph.node:
171 |             # Identifying the layer ids and names
172 |             out_op_id = node.output[0]
173 |             out_op_name = self.mapping[out_op_id]
174 |             in_op_names = [
175 |                 self.mapping.get(in_op_id, in_op_id)
176 |                 for in_op_id in node.input
177 |                 if in_op_id in activations
178 |             ]
179 | 
180 |             # getting correct layer
181 |             op = getattr(self, out_op_name)
182 | 
183 |             # if first layer choose input as in_activations
184 |             # if not in_op_names and len(node.input) == 1:
185 |             #     in_activations = input
186 |             if isinstance(op, STANDARD_LAYERS) or (
187 |                 isinstance(op, COMPOSITE_LAYERS)
188 |                 and any(isinstance(x, STANDARD_LAYERS) for x in op.modules())
189 |             ):
190 |                 in_activations = [
191 |                     activations[in_op_id]
192 |                     for in_op_id in node.input
193 |                     if in_op_id in activations
194 |                 ]
195 |             else:
196 |                 in_activations = [
197 |                     activations[in_op_id] if in_op_id in activations
198 |                     # if in_op_id not in activations neither in parameters then
199 |                     # it must be the initial input
200 |                     else get_init_parameter([self], in_op_id, inputs[0])
201 |                     for in_op_id in node.input
202 |                 ]
203 | 
204 |             in_activations = [in_act for in_act in in_activations if in_act is not None]
205 | 
206 |             # store activations for next layer
207 |             if isinstance(op, Loop):
208 |                 outputs = op((self,), activations, *in_activations)
209 |                 for out_op_id, output in zip(node.output, outputs):
210 |                     activations[out_op_id] = output
211 |             elif isinstance(op, partial) and op.func == torch.cat:
212 |                 activations[out_op_id] = op(in_activations)
213 |             elif isinstance(op, Identity):
214 |                 # After batch norm fusion the batch norm parameters
215 |                 # were all passed to identity instead of only the first one
216 |                 activations[out_op_id] = op(in_activations[0])
217 |             elif isinstance(op, MULTIOUTPUT_LAYERS) or (
218 |                 isinstance(op, COMPOSITE_LAYERS)
219 |                 and any(isinstance(x, MULTIOUTPUT_LAYERS) for x in op.modules())
220 |             ):
221 |                 for out_op_id, output in zip(node.output, op(*in_activations)):
222 |                     activations[out_op_id] = output
223 |             else:
224 |                 activations[out_op_id] = op(*in_activations)
225 | 
226 |             # Remove activations that are no longer needed
227 |             for in_op_id in node.input:
228 |                 if in_op_id in still_needed_by:
229 |                     still_needed_by[in_op_id].discard(out_op_id)
230 |                     if len(still_needed_by[in_op_id]) == 0:
231 |                         if in_op_id in activations:
232 |                             del activations[in_op_id]
233 | 
234 |             if self.debug:
235 |                 # compare if the activations of pytorch are the same as from onnxruntime
236 |                 debug_model_conversion(
237 |                     self.onnx_model,
238 |                     [activations[x] for x in self.input_names],
239 |                     [activations[out_op_id] for out_op_id in node.output],
240 |                     node,
241 |                 )
242 | 
243 |         # collect all outputs
244 |         outputs = [activations[x] for x in self.output_names]
245 |         if len(outputs) == 1:
246 |             outputs = outputs[0]
247 |         return outputs
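A note on the forward pass above: inputs can be given positionally, in the order of `model.input_names`, or as keyword arguments keyed by the ONNX input names, but not both at once. A minimal sketch, assuming a single-input model at a placeholder path whose ONNX input name is a valid Python identifier:

```
import onnx
import torch

from onnx2pytorch import ConvertModel

model = ConvertModel(onnx.load("model.onnx"))  # placeholder path

x = torch.rand(1, 3, 224, 224)  # shape is an assumption about the model
out = model(x)  # positional, ordered as model.input_names
out = model(**{model.input_names[0]: x})  # keyword, keyed by ONNX input name
```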
248 | -------------------------------------------------------------------------------- /onnx2pytorch/helpers.py: -------------------------------------------------------------------------------- 1 | import io 2 | import torch 3 | import onnx 4 | import torch.onnx 5 | 6 | from onnx2pytorch import ConvertModel 7 | 8 | 9 | def to_onnx(model, inp_size, device=torch.device("cpu"), do_constant_folding=False): 10 | if isinstance(inp_size, (tuple, list)) and not isinstance(inp_size[0], int): 11 | input_image = tuple([torch.rand(i, device=device) for i in inp_size]) 12 | else: 13 | input_image = torch.rand(inp_size, device=device) 14 | 15 | model.to(device) 16 | bitstream = io.BytesIO() 17 | torch.onnx.export( 18 | model, 19 | input_image, 20 | bitstream, 21 | export_params=True, 22 | opset_version=11, 23 | do_constant_folding=do_constant_folding, 24 | input_names=["input"], 25 | output_names=["output"], 26 | ) 27 | return onnx.ModelProto.FromString(bitstream.getvalue()) 28 | 29 | 30 | def to_converted(model, inp_size): 31 | onnx_model = to_onnx(model, inp_size) 32 | model = ConvertModel(onnx_model) 33 | return model 34 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/__init__.py: -------------------------------------------------------------------------------- 1 | from .add import Add 2 | from .batchnorm import BatchNormWrapper 3 | from .bitshift import BitShift 4 | from .cast import Cast 5 | from .clip import Clip 6 | from .constant import Constant 7 | from .constantofshape import ConstantOfShape 8 | from .div import Div 9 | from .expand import Expand 10 | from .flatten import Flatten 11 | from .gather import Gather 12 | from .gathernd import GatherND 13 | from .globalaveragepool import GlobalAveragePool 14 | from .hardsigmoid import Hardsigmoid 15 | from .instancenorm import InstanceNormWrapper 16 | from .loop import Loop 17 | from .lstm import LSTMWrapper 18 | from .matmul import MatMul 19 | from .nonmaxsuppression import NonMaxSuppression 20 | from .onehot import OneHot 21 | from .pad import Pad 22 | from .prelu import PRelu 23 | from .range import Range 24 | from .reducemax import ReduceMax 25 | from .reducesum import ReduceSum 26 | from .reducel2 import ReduceL2 27 | from .reshape import Reshape 28 | from .resize import Resize, Upsample 29 | from .scatter import Scatter 30 | from .scatterelements import ScatterElements 31 | from .scatternd import ScatterND 32 | from .shape import Shape 33 | from .slice import Slice 34 | from .split import Split 35 | from .squeeze import Squeeze 36 | from .thresholdedrelu import ThresholdedRelu 37 | from .tile import Tile 38 | from .topk import TopK 39 | from .transpose import Transpose 40 | from .unsqueeze import Unsqueeze 41 | from .where import Where 42 | 43 | __all__ = [ 44 | "Add", 45 | "BatchNormWrapper", 46 | "BitShift", 47 | "Cast", 48 | "Clip", 49 | "Constant", 50 | "ConstantOfShape", 51 | "Div", 52 | "Expand", 53 | "Flatten", 54 | "Gather", 55 | "GatherND", 56 | "GlobalAveragePool", 57 | "InstanceNormWrapper", 58 | "Loop", 59 | "LSTMWrapper", 60 | "MatMul", 61 | "NonMaxSuppression", 62 | "OneHot", 63 | "Pad", 64 | "PRelu", 65 | "Range", 66 | "ReduceMax", 67 | "ReduceSum", 68 | "ReduceL2", 69 | "Reshape", 70 | "Resize", 71 | "Scatter", 72 | "ScatterElements", 73 | "ScatterND", 74 | "Shape", 75 | "Slice", 76 | "Split", 77 | "Squeeze", 78 | "ThresholdedRelu", 79 | "Tile", 80 | "TopK", 81 | "Transpose", 82 | "Unsqueeze", 83 | "Upsample", 84 | "Where", 85 | ] 86 | 
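The operations exported above are ordinary `nn.Module`s, so they can also be used standalone, outside of a converted model. A small sketch with two of them (shapes chosen arbitrarily):

```
import torch

from onnx2pytorch.operations import Clip, Flatten

x = torch.randn(2, 3, 4)
y = Flatten(start_dim=1)(x)  # ONNX Flatten semantics: result has shape (2, 12)
z = Clip(min=0.0, max=6.0)(x)  # equivalent to torch.clamp(x, 0.0, 6.0)
```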
-------------------------------------------------------------------------------- /onnx2pytorch/operations/add.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from onnx2pytorch.utils import is_constant, get_selection 7 | from onnx2pytorch.operations.base import Operator 8 | 9 | 10 | class Add(Operator): 11 | def __init__(self, input_shape=None, input_indices=None, feature_dim=1): 12 | self.input_shape = input_shape 13 | self.input_indices = input_indices 14 | self.feature_dim = feature_dim # 2 for transformers else 1 15 | self.out = None 16 | 17 | if input_shape and input_indices: 18 | self.out = torch.zeros(input_shape) 19 | 20 | super().__init__() 21 | 22 | def forward(self, *input): 23 | if self.input_indices: 24 | out = self.out * 0 25 | for inp, idx in zip(input, self.input_indices): 26 | selection = get_selection(idx, self.feature_dim) 27 | out[selection] += inp 28 | return out 29 | 30 | # Reorder input so that the matrix is first 31 | if is_constant(input[0]): 32 | input = sorted(input, key=lambda x: -len(x.shape)) 33 | # Reorder input so that the broadcasted matrix is last 34 | elif all(x == 1 for x in input[0].shape): 35 | input = sorted(input, key=lambda x: -sum(x.shape)) 36 | out = input[0].clone() 37 | for inp in input[1:]: 38 | out += inp 39 | return out 40 | 41 | def set_input_indices(self, input): 42 | assert isinstance(input, (list, tuple)) 43 | 44 | # If all but one of the inputs are constants do nothing 45 | # One tensor can easily add together with any number of constants 46 | if sum(is_constant(inp) for inp in input) >= len(input) - 1: 47 | return 48 | 49 | input_shape = input[0].shape 50 | if not all(input_shape == inp.shape for inp in input[1:]): 51 | warnings.warn("Addition might be corrupted.", RuntimeWarning) 52 | assert all( 53 | is_constant(inp) or input_shape[-1] == inp.shape[-1] for inp in input 54 | ) 55 | 56 | # HACK 57 | while self.feature_dim >= len(input_shape): 58 | self.feature_dim -= 1 59 | axis = self.get_axis(input_shape, self.feature_dim) 60 | 61 | input_indices = [] 62 | for inp in input: 63 | mask = inp != 0 64 | if len(inp.shape) > 1: 65 | # Where mask is == 0, the complete input channel can be removed 66 | s = mask.sum(axis=tuple(axis)) 67 | # If inp is triangular matrix do not remove zero rows. 68 | # Immediately return. 
69 | seq = torch.arange(len(s)) 70 | if torch.equal(s, seq) or torch.equal(s.flip(0), seq): 71 | return 72 | mask = s != 0 73 | (non_zeros,) = torch.where(mask) 74 | input_indices.append(non_zeros) 75 | 76 | # if all elements are non zero, no indices necessary 77 | if all(len(i) == len(mask) for i in input_indices): 78 | return 79 | 80 | unique_indices = torch.cat(input_indices).unique() 81 | input_shape = list(input[0].shape) 82 | input_shape[self.feature_dim] = len(unique_indices) 83 | 84 | _, input_indices[0] = torch.where( 85 | input_indices[0][:, None] == unique_indices[None] 86 | ) 87 | _, input_indices[1] = torch.where( 88 | input_indices[1][:, None] == unique_indices[None] 89 | ) 90 | 91 | self.input_indices = input_indices 92 | self.input_shape = tuple(input_shape) 93 | self.out = nn.Parameter( 94 | torch.zeros(self.input_shape, device=input[0].device, dtype=input[0].dtype), 95 | requires_grad=False, 96 | ) 97 | 98 | def __str__(self): 99 | if self.input_indices: 100 | return "Add({}, {}, {})".format( 101 | tuple(self.input_shape), 102 | len(self.input_indices[0]), 103 | len(self.input_indices[1]), 104 | ) 105 | else: 106 | return "Add(None, None)" 107 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | from torch import nn 4 | 5 | 6 | class Operator(nn.Module, ABC): 7 | @staticmethod 8 | def get_axis(input_shape, input_feature_axis): 9 | """ 10 | Parameters 11 | ---------- 12 | input_shape: torch.Size 13 | input_feature_axis: int 14 | 15 | Returns 16 | ------- 17 | axis: tuple 18 | Axis to aggregate over. 19 | """ 20 | if input_feature_axis < 0: 21 | input_feature_axis += len(input_shape) 22 | # select and sum all axes except the feature one 23 | axis = set(range(len(input_shape))) - {input_feature_axis} 24 | return tuple(axis) 25 | 26 | 27 | class OperatorWrapper(Operator, ABC): 28 | def __init__(self, op): 29 | """ 30 | This class enables any function to become a subclass of nn.Module 31 | The class name is equal to the op.__name__ 32 | 33 | Parameters 34 | ---------- 35 | op: function or builtin_function_or_method 36 | Any torch function. It is used in-place of forward method. 
37 |         """
38 |         self.forward = op
39 |         self.__class__.__name__ = op.__name__
40 |         super().__init__()
41 | 
--------------------------------------------------------------------------------
/onnx2pytorch/operations/batchnorm.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | from torch import nn
4 | from torch.nn.modules.batchnorm import _BatchNorm
5 | 
6 | try:
7 |     from torch.nn.modules.batchnorm import _LazyNormBase
8 | 
9 |     class _LazyBatchNorm(_LazyNormBase, _BatchNorm):
10 | 
11 |         cls_to_become = _BatchNorm
12 | 
13 | 
14 | except ImportError:
15 |     # for torch < 1.10.0
16 |     from torch.nn.modules.batchnorm import _LazyBatchNorm
17 | 
18 | 
19 | class LazyBatchNormUnsafe(_LazyBatchNorm):
20 |     def __init__(self, *args, spatial=True, **kwargs):
21 |         if not spatial:
22 |             warnings.warn("Non-spatial BatchNorm not implemented.", RuntimeWarning)
23 |         super().__init__(*args, **kwargs)
24 | 
25 |     def _check_input_dim(self, input):
26 |         return
27 | 
28 | 
29 | class BatchNormUnsafe(_BatchNorm):
30 |     def __init__(self, *args, spatial=True, **kwargs):
31 |         if not spatial:
32 |             warnings.warn("Non-spatial BatchNorm not implemented.", RuntimeWarning)
33 |         super().__init__(*args, **kwargs)
34 | 
35 |     def _check_input_dim(self, input):
36 |         return
37 | 
38 | 
39 | class BatchNormWrapper(nn.Module):
40 |     def __init__(self, torch_params, *args, **kwargs):
41 |         super().__init__()
42 |         self.has_lazy = len(torch_params) == 0
43 |         if self.has_lazy:
44 |             self.bnu = LazyBatchNormUnsafe(*args, **kwargs)
45 |         else:
46 |             kwargs["num_features"] = torch_params[0].shape[0]
47 |             self.bnu = BatchNormUnsafe(*args, **kwargs)
48 |             keys = ["weight", "bias", "running_mean", "running_var"]
49 |             for key, value in zip(keys, torch_params):
50 |                 getattr(self.bnu, key).data = value
51 | 
52 |     def forward(self, X, scale=None, B=None, input_mean=None, input_var=None):
53 |         if self.has_lazy:
54 |             self.bnu.initialize_parameters(X)
55 | 
56 |         if scale is not None:
57 |             getattr(self.bnu, "weight").data = scale
58 |         if B is not None:
59 |             getattr(self.bnu, "bias").data = B
60 |         if input_mean is not None:
61 |             getattr(self.bnu, "running_mean").data = input_mean
62 |         if input_var is not None:
63 |             getattr(self.bnu, "running_var").data = input_var
64 | 
65 |         return self.bnu.forward(X)
--------------------------------------------------------------------------------
/onnx2pytorch/operations/bitshift.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | 
5 | class BitShift(nn.Module):
6 |     def __init__(self, direction):
7 |         if direction not in ("LEFT", "RIGHT"):
8 |             raise ValueError("invalid BitShift direction {}".format(direction))
9 | 
10 |         self.direction = direction
11 |         super().__init__()
12 | 
13 |     def forward(self, X, Y):
14 |         if self.direction == "LEFT":
15 |             return X << Y
16 |         elif self.direction == "RIGHT":
17 |             return X >> Y
--------------------------------------------------------------------------------
/onnx2pytorch/operations/cast.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | 
5 | class Cast(nn.Module):
6 |     def __init__(self, dtype):
7 |         if isinstance(dtype, str):
8 |             dtype = getattr(torch, dtype.lower())
9 |         self.dtype = dtype
10 |         super().__init__()
11 | 
12 |     def forward(self, input: torch.Tensor):
13 |         return input.to(self.dtype)
14 | 
15 |     def extra_repr(self) -> str:
16 |         return "dtype={}".format(self.dtype)
17 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Clip(nn.Module): 6 | def __init__(self, min=None, max=None): 7 | super().__init__() 8 | self.min = min 9 | self.max = max 10 | 11 | def forward(self, input, min=None, max=None): 12 | if min is None: 13 | min = self.min 14 | if max is None: 15 | max = self.max 16 | if min is None and max is None: 17 | return input 18 | else: 19 | return torch.clamp(input, min=min, max=max) 20 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/constant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class Constant(nn.Module): 7 | def __init__(self, constant): 8 | super().__init__() 9 | self.register_buffer("constant", torch.from_numpy(np.copy(constant))) 10 | 11 | def forward(self): 12 | return self.constant 13 | 14 | def extra_repr(self) -> str: 15 | return "constant={}".format(self.constant) 16 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/constantofshape.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class ConstantOfShape(nn.Module): 7 | def __init__(self, constant=None): 8 | super().__init__() 9 | if constant is None: 10 | const = torch.tensor(1.0, dtype=torch.float32) 11 | else: 12 | const = torch.from_numpy(np.copy(constant)) 13 | self.register_buffer("constant", const) 14 | 15 | def forward(self, shape: torch.Tensor): 16 | return self.constant.expand(*shape).to(shape.device) 17 | 18 | def extra_repr(self) -> str: 19 | return "constant={}".format(self.constant) 20 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/div.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Div(nn.Module): 6 | def forward(self, input, other): 7 | res_type = torch.result_type(input, other) 8 | true_quotient = torch.true_divide(input, other) 9 | if res_type.is_floating_point: 10 | res = true_quotient 11 | else: 12 | res = torch.floor(true_quotient).to(res_type) 13 | return res 14 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/expand.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Expand(nn.Module): 6 | def forward(self, input: torch.Tensor, shape: torch.Tensor): 7 | if isinstance(shape, torch.Tensor): 8 | shape = shape.to(torch.int64) 9 | try: 10 | out = input.expand(torch.Size(shape)) 11 | except RuntimeError: 12 | out = input * torch.ones( 13 | torch.Size(shape), dtype=input.dtype, device=input.device 14 | ) 15 | return out 16 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/flatten.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Flatten(nn.Module): 6 | def __init__(self, start_dim=1, end_dim=-1): 7 | super().__init__() 8 | self.start_dim = start_dim 9 | self.end_dim = 
end_dim 10 | 11 | def forward(self, input: torch.Tensor): 12 | return torch.flatten(input, start_dim=self.start_dim, end_dim=self.end_dim) 13 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Gather(nn.Module): 6 | def __init__(self, dim=0): 7 | self.dim = dim 8 | self.selection = [slice(None) for _ in range(dim)] 9 | super().__init__() 10 | 11 | def forward(self, data: torch.Tensor, indices: torch.Tensor): 12 | selection = self.selection + [indices.to(torch.int64)] 13 | return data.__getitem__(selection) 14 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/gathernd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class GatherND(nn.Module): 6 | def __init__(self, batch_dims=0): 7 | if batch_dims != 0: 8 | raise NotImplementedError( 9 | f"GatherND for batch_dims={batch_dims} not implemented." 10 | ) 11 | self.batch_dims = batch_dims 12 | super().__init__() 13 | 14 | def forward(self, data: torch.Tensor, indices: torch.Tensor): 15 | orig_shape = list(indices.shape) 16 | num_samples = torch.prod(torch.tensor(orig_shape[:-1])) 17 | m = orig_shape[-1] 18 | n = len(data.shape) 19 | 20 | if m > n: 21 | raise ValueError( 22 | f"The last dimension of indices must be <= the rank of data. " 23 | f"Got indices:{indices.shape}, data:{data.shape}. {m} > {n}" 24 | ) 25 | out_shape = orig_shape[:-1] + list(data.shape)[m:] 26 | 27 | indices = indices.reshape((num_samples, m)).transpose(0, 1) 28 | indices = torch.split(indices, 1, 0) 29 | output = data[indices] # (num_samples, ...) 30 | return output.reshape(out_shape).contiguous() 31 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/globalaveragepool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class GlobalAveragePool(nn.Module): 6 | def forward(self, input: torch.Tensor): 7 | spatial_shape = input.ndimension() - 2 8 | dim = tuple(range(2, spatial_shape + 2)) 9 | return torch.mean(input, dim=dim, keepdim=True) 10 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/hardsigmoid.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class Hardsigmoid(nn.Module): 8 | def __new__(cls, alpha=0.2, beta=0.5): 9 | """ 10 | If alpha and beta match the default values of torch's Hardsigmoid, 11 | return torch's Hardsigmoid. Otherwise, return the custom Hardsigmoid.
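        Illustrative examples: Hardsigmoid() keeps the ONNX defaults
        (alpha=0.2, beta=0.5) and uses the custom forward below, while
        Hardsigmoid(alpha=1 / 6) returns an nn.Hardsigmoid instance.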
12 | """ 13 | if math.isclose(alpha, 1 / 6, abs_tol=1e-2) and beta == 0.5: 14 | return nn.Hardsigmoid() 15 | else: 16 | return super().__new__(cls) 17 | 18 | def __init__(self, alpha=0.2, beta=0.5): 19 | super().__init__() 20 | self.alpha = alpha 21 | self.beta = beta 22 | 23 | def forward(self, input): 24 | return torch.clip(input * self.alpha + self.beta, 0, 1) 25 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/instancenorm.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | 4 | from torch.nn.modules.instancenorm import _InstanceNorm 5 | 6 | try: 7 | from torch.nn.modules.batchnorm import _LazyNormBase 8 | 9 | class _LazyInstanceNorm(_LazyNormBase, _InstanceNorm): 10 | cls_to_become = _InstanceNorm 11 | 12 | except ImportError: 13 | from torch.nn.modules.lazy import LazyModuleMixin 14 | from torch.nn.parameter import UninitializedBuffer, UninitializedParameter 15 | 16 | class _LazyInstanceNorm(LazyModuleMixin, _InstanceNorm): 17 | weight: UninitializedParameter # type: ignore[assignment] 18 | bias: UninitializedParameter # type: ignore[assignment] 19 | 20 | cls_to_become = _InstanceNorm 21 | 22 | def __init__( 23 | self, 24 | eps=1e-5, 25 | momentum=0.1, 26 | affine=True, 27 | track_running_stats=True, 28 | device=None, 29 | dtype=None, 30 | ) -> None: 31 | factory_kwargs = {"device": device, "dtype": dtype} 32 | super(_LazyInstanceNorm, self).__init__( 33 | # affine and track_running_stats are hardcoded to False to 34 | # avoid creating tensors that will soon be overwritten. 35 | 0, 36 | eps, 37 | momentum, 38 | False, 39 | False, 40 | **factory_kwargs, 41 | ) 42 | self.affine = affine 43 | self.track_running_stats = track_running_stats 44 | if self.affine: 45 | self.weight = UninitializedParameter(**factory_kwargs) 46 | self.bias = UninitializedParameter(**factory_kwargs) 47 | if self.track_running_stats: 48 | self.running_mean = UninitializedBuffer(**factory_kwargs) 49 | self.running_var = UninitializedBuffer(**factory_kwargs) 50 | self.num_batches_tracked = torch.tensor( 51 | 0, 52 | dtype=torch.long, 53 | **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, 54 | ) 55 | 56 | def reset_parameters(self) -> None: 57 | if not self.has_uninitialized_params() and self.num_features != 0: 58 | super().reset_parameters() 59 | 60 | def initialize_parameters(self, input) -> None: # type: ignore[override] 61 | if self.has_uninitialized_params(): 62 | self.num_features = input.shape[1] 63 | if self.affine: 64 | assert isinstance(self.weight, UninitializedParameter) 65 | assert isinstance(self.bias, UninitializedParameter) 66 | self.weight.materialize((self.num_features,)) 67 | self.bias.materialize((self.num_features,)) 68 | if self.track_running_stats: 69 | self.running_mean.materialize( 70 | (self.num_features,) 71 | ) # type:ignore[union-attr] 72 | self.running_var.materialize( 73 | (self.num_features,) 74 | ) # type:ignore[union-attr] 75 | self.reset_parameters() 76 | 77 | 78 | class InstanceNormMixin: 79 | """Skips dimension check.""" 80 | 81 | def __init__(self, *args, affine=True, **kwargs): 82 | self.no_batch_dim = None # no_batch_dim has to be set at runtime 83 | super().__init__(*args, affine=affine, **kwargs) 84 | 85 | def set_no_dim_batch_dim(self, no_batch_dim): 86 | self.no_batch_dim = no_batch_dim 87 | 88 | def _check_input_dim(self, input): 89 | return 90 | 91 | def _get_no_batch_dim(self): 92 | return self.no_batch_dim 93 | 94 | 95 | class 
LazyInstanceNormUnsafe(InstanceNormMixin, _LazyInstanceNorm): 96 | pass 97 | 98 | 99 | class InstanceNormUnsafe(InstanceNormMixin, _InstanceNorm): 100 | pass 101 | 102 | 103 | class InstanceNormWrapper(torch.nn.Module): 104 | def __init__(self, torch_params, *args, affine=True, **kwargs): 105 | super().__init__() 106 | self.has_lazy = len(torch_params) == 0 107 | if self.has_lazy: 108 | self.inu = LazyInstanceNormUnsafe(*args, affine=affine, **kwargs) 109 | else: 110 | kwargs["num_features"] = torch_params[0].shape[0] 111 | self.inu = InstanceNormUnsafe(*args, affine=affine, **kwargs) 112 | keys = ["weight", "bias"] 113 | for key, value in zip(keys, torch_params): 114 | getattr(self.inu, key).data = value 115 | 116 | def forward(self, input, scale=None, B=None): 117 | if self.has_lazy: 118 | self.inu.initialize_parameters(input) 119 | 120 | if scale is not None: 121 | getattr(self.inu, "weight").data = scale 122 | if B is not None: 123 | getattr(self.inu, "bias").data = B 124 | 125 | if self.inu.no_batch_dim is None: 126 | self.inu.set_no_dim_batch_dim(input.dim() - 1) 127 | 128 | return self.inu.forward(input) 129 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/loop.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from copy import deepcopy 3 | from functools import partial 4 | from importlib import import_module 5 | import warnings 6 | 7 | import numpy as np 8 | import onnx 9 | import torch 10 | from onnx import numpy_helper 11 | from torch import nn 12 | from torch.nn.modules.linear import Identity 13 | 14 | from onnx2pytorch.utils import ( 15 | get_inputs_names, 16 | get_outputs_names, 17 | ) 18 | 19 | 20 | class Loop(nn.Module): 21 | def __init__( 22 | self, 23 | opset_version, 24 | batch_dim, 25 | body: onnx.GraphProto, 26 | ): 27 | super().__init__() 28 | self.ops = import_module("onnx2pytorch.convert.operations") 29 | self.c = import_module("onnx2pytorch.constants") 30 | 31 | self.body = body 32 | self.batch_dim = batch_dim 33 | 34 | self.input_names = get_inputs_names(body) 35 | self.output_names = get_outputs_names(body) 36 | 37 | # Creates mapping from node (identified by first output) to submodule 38 | self.mapping = {} 39 | for op_id, op_name, op in self.ops.convert_operations( 40 | body, opset_version, batch_dim 41 | ): 42 | setattr(self, op_name, op) 43 | self.mapping[op_id] = op_name 44 | 45 | # Store initializers as buffers 46 | for tensor in self.body.initializer: 47 | buffer_name = self.ops.get_buffer_name(tensor.name) 48 | self.register_buffer( 49 | buffer_name, 50 | torch.from_numpy(numpy_helper.to_array(tensor)), 51 | ) 52 | 53 | # We do not track dependencies (for memory reduction) within loops. 54 | # This would be complicated due to loop-carried dependencies. 55 | 56 | def forward(self, enclosing_modules, enclosing_activations, *inputs): 57 | """ 58 | Parameters 59 | ---------- 60 | enclosing_modules: tuple of nn.Modules 61 | Module(s) from enclosing scope(s), containing initializers as buffers. 62 | enclosing_activations: dict 63 | All activations from the enclosing scope. 64 | inputs: list 65 | Inputs to Loop node (length >= 2), comprising M, cond, and v_initial. 66 | 67 | Returns 68 | ------- 69 | v_final_and_scan_outputs: list 70 | Final N loop carried dependency values, then K scan_outputs. 
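        Schematic example (assuming N=1 loop-carried value and K=1 scan
        output): with inputs (M=3, cond=True, v0) the body runs as
        c1, v1, s1 = body(0, cond, v0); c2, v2, s2 = body(1, c1, v1);
        c3, v3, s3 = body(2, c2, v2), and this method returns
        [v3, stack(s1, s2, s3)].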
71 | """ 72 | 73 | N = len(self.input_names) - 2 74 | K = len(self.output_names) - (1 + N) 75 | 76 | M = inputs[0] 77 | cond = inputs[1] 78 | v_initial = inputs[2:] 79 | 80 | iteration_num_name = self.input_names[0] 81 | cond_in_name = self.input_names[1] 82 | loop_carried_in_names = self.input_names[2:] 83 | cond_out_name = self.output_names[0] 84 | loop_carried_out_names = self.output_names[1 : N + 1] 85 | scan_outputs_names = self.output_names[1 + N :] 86 | 87 | buffer_modules = enclosing_modules + (self,) 88 | 89 | activations = {} 90 | activations.update(zip(loop_carried_in_names, v_initial)) 91 | activations.update(enclosing_activations) 92 | 93 | scan_outputs = defaultdict(list) 94 | i = torch.tensor(0) 95 | while i < M and cond: 96 | activations[iteration_num_name] = i 97 | activations[cond_in_name] = cond 98 | for node in self.body.node: 99 | # Identifying the layer ids and names 100 | out_op_id = node.output[0] 101 | out_op_name = self.mapping[out_op_id] 102 | in_op_names = [ 103 | self.mapping.get(in_op_id, in_op_id) 104 | for in_op_id in node.input 105 | if in_op_id in activations 106 | ] 107 | 108 | # getting correct layer 109 | op = getattr(self, out_op_name) 110 | 111 | # if first layer choose input as in_activations 112 | # if not in_op_names and len(node.input) == 1: 113 | # in_activations = input 114 | if isinstance(op, self.c.STANDARD_LAYERS) or ( 115 | isinstance(op, self.c.COMPOSITE_LAYERS) 116 | and any(isinstance(x, self.c.STANDARD_LAYERS) for x in op.modules()) 117 | ): 118 | in_activations = [ 119 | activations[in_op_id] 120 | for in_op_id in node.input 121 | if in_op_id in activations 122 | ] 123 | else: 124 | in_activations = [ 125 | activations[in_op_id] if in_op_id in activations 126 | # if in_op_id not in activations neither in parameters then 127 | # it must be the initial input 128 | else self.ops.get_init_parameter( 129 | buffer_modules, in_op_id, inputs[0] 130 | ) 131 | for in_op_id in node.input 132 | ] 133 | 134 | in_activations = [ 135 | in_act for in_act in in_activations if in_act is not None 136 | ] 137 | 138 | # store activations for next layer 139 | if isinstance(op, Loop): 140 | outputs = op(buffer_modules, activations, *in_activations) 141 | for out_act_name, output in zip(node.output, outputs): 142 | activations[out_op_id] = output 143 | if out_act_name in scan_outputs_names: 144 | scan_outputs[out_act_name].append(output) 145 | elif isinstance(op, partial) and op.func == torch.cat: 146 | output = op(in_activations) 147 | activations[out_op_id] = output 148 | if out_op_id in scan_outputs_names: 149 | scan_outputs[out_op_id].append(output) 150 | elif isinstance(op, Identity): 151 | # After batch norm fusion the batch norm parameters 152 | # were all passed to identity instead of first one only 153 | output = op(in_activations[0]) 154 | activations[out_op_id] = op(output) 155 | if out_op_id in scan_outputs_names: 156 | scan_outputs[out_op_id].append(output) 157 | elif isinstance(op, self.c.MULTIOUTPUT_LAYERS) or ( 158 | isinstance(op, self.c.COMPOSITE_LAYERS) 159 | and any( 160 | isinstance(x, self.c.MULTIOUTPUT_LAYERS) for x in op.modules() 161 | ) 162 | ): 163 | outputs = op(*in_activations) 164 | for out_act_name, output in zip(node.output, outputs): 165 | activations[out_act_name] = output 166 | if out_act_name in scan_outputs_names: 167 | scan_outputs[out_act_name].append(output) 168 | else: 169 | output = op(*in_activations) 170 | activations[out_op_id] = output 171 | if out_op_id in scan_outputs_names: 172 | 
scan_outputs[out_op_id].append(output) 173 | 174 | # Prepare for next iteration 175 | cond = activations[cond_out_name] 176 | i += 1 177 | loop_carried_outputs = [ 178 | activations[aname] for aname in loop_carried_out_names 179 | ] 180 | activations.update(zip(loop_carried_in_names, loop_carried_outputs)) 181 | 182 | # Set outputs to N loop carried final values, followed by K scan outputs 183 | outputs = [activations[lcn] for lcn in loop_carried_out_names] 184 | for son in scan_outputs_names: 185 | outputs.append(torch.cat([so.unsqueeze(dim=0) for so in scan_outputs[son]])) 186 | return outputs 187 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/lstm.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class LSTMWrapper(nn.Module): 5 | """Wraps a 1-layer nn.LSTM to match the API of an ONNX LSTM. 6 | 7 | It expects h_0 and c_0 as separate inputs rather than as a tuple, 8 | and returns h_n and c_n as separate outputs rather than as a tuple. 9 | """ 10 | 11 | def __init__(self, lstm_module: nn.LSTM): 12 | super().__init__() 13 | self.lstm = lstm_module 14 | 15 | def forward(self, input, h_0=None, c_0=None): 16 | (seq_len, batch, input_size) = input.shape 17 | num_layers = 1 18 | num_directions = self.lstm.bidirectional + 1 19 | hidden_size = self.lstm.hidden_size 20 | if h_0 is None or c_0 is None or h_0.numel() == 0 or c_0.numel() == 0: 21 | tuple_0 = None 22 | else: 23 | tuple_0 = (h_0, c_0) 24 | output, (h_n, c_n) = self.lstm(input, tuple_0) 25 | 26 | # Y has shape (seq_length, num_directions, batch_size, hidden_size) 27 | Y = output.view(seq_len, batch, num_directions, hidden_size).transpose(1, 2) 28 | # Y_h has shape (num_directions, batch_size, hidden_size) 29 | Y_h = h_n.view(num_layers, num_directions, batch, hidden_size).squeeze(0) 30 | # Y_c has shape (num_directions, batch_size, hidden_size) 31 | Y_c = c_n.view(num_layers, num_directions, batch, hidden_size).squeeze(0) 32 | return Y, Y_h, Y_c 33 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/matmul.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class MatMul(nn.Module): 6 | def forward(self, A, V): 7 | return torch.matmul(A, V) 8 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/nonmaxsuppression.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | 5 | 6 | class NonMaxSuppression(nn.Module): 7 | def __init__(self, center_point_box=0): 8 | self.center_point_box = center_point_box 9 | super().__init__() 10 | 11 | def forward( 12 | self, 13 | boxes, 14 | scores, 15 | max_output_boxes_per_class=0, 16 | iou_threshold=0.0, 17 | score_threshold=0.0, 18 | ): 19 | nms_rs_list = [] 20 | for i in range(boxes.shape[0]): 21 | for j in range(scores.shape[1]): 22 | for k in range(boxes.shape[1]): 23 | if self.center_point_box == 1: 24 | boxes[i][k] = torchvision.ops.box_convert( 25 | boxes[i][k], "cxcywh", "xyxy" 26 | ) 27 | else: 28 | x1, y1, x2, y2 = boxes[i][k] 29 | if x1 < x2 and y1 < y2: 30 | continue 31 | indices = [0, 1, 2, 3] 32 | if x1 > x2: 33 | indices = [indices[l] for l in (2, 1, 0, 3)] 34 | if y1 > y2: 35 | indices = [indices[l] for l in (0, 3, 2, 1)] 36 | boxes[i][k] = boxes[i][k].gather(0, 
torch.tensor(indices)) 37 | mask = scores[i][j] >= score_threshold 38 | nms_rs = torchvision.ops.nms( 39 | boxes[i], scores[i][j], float(iou_threshold) 40 | )[:max_output_boxes_per_class] 41 | nms_rs_masked = nms_rs[ 42 | : mask[nms_rs].nonzero(as_tuple=False).flatten().shape[0] 43 | ] 44 | batch_index = torch.full((nms_rs_masked.shape[0], 1), i) 45 | class_index = torch.full((nms_rs_masked.shape[0], 1), j) 46 | nms_rs_list.append( 47 | torch.cat( 48 | (batch_index, class_index, nms_rs_masked.unsqueeze(1)), dim=1 49 | ) 50 | ) 51 | return torch.cat(nms_rs_list, dim=0) 52 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/onehot.py: -------------------------------------------------------------------------------- 1 | from torch.nn.functional import one_hot 2 | from onnx2pytorch.operations.base import Operator 3 | 4 | 5 | class OneHot(Operator): 6 | def __init__(self, dim=-1, non_zero_values_only=False): 7 | self.dim = dim 8 | self.non_zero_values_only = non_zero_values_only 9 | super().__init__() 10 | 11 | def forward(self, indices, depth, values): 12 | if self.non_zero_values_only: 13 | off_value, on_value = -1, 1 14 | else: 15 | off_value, on_value = values 16 | out = one_hot(indices.to(int), depth.to(int).item()) 17 | out = out * (on_value - off_value) + off_value 18 | 19 | rank = len(indices.shape) 20 | if self.dim < 0: 21 | self.dim += rank + 1 22 | if not rank == self.dim: # permute only if dim not last dimension 23 | order = list(range(len(indices.shape))) 24 | order.insert(self.dim, -1) 25 | out = out.permute(order) 26 | return out 27 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/pad.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | from onnx2pytorch.operations.base import Operator 4 | 5 | 6 | class Pad(Operator): 7 | def __init__(self, mode="constant", padding=None): 8 | self.mode = mode 9 | self.padding = padding 10 | super().__init__() 11 | 12 | def forward(self, input, pads=None, value=0): 13 | if self.padding is not None: 14 | pads = self.padding 15 | elif pads is None: 16 | raise TypeError("forward() missing 1 required positional argument: 'pads'") 17 | out = F.pad(input, list(pads), mode=self.mode, value=value) 18 | return out 19 | 20 | def extra_repr(self) -> str: 21 | return "mode={}, padding={}".format(self.mode, self.padding) 22 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/prelu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class PRelu(nn.Module): 6 | def forward(self, X: torch.Tensor, slope: torch.Tensor): 7 | return torch.clamp(X, min=0) + torch.clamp(X, max=0) * slope 8 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/range.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Range(nn.Module): 6 | def forward(self, start: torch.Tensor, limit: torch.Tensor, delta: torch.Tensor): 7 | return torch.arange(start=start, end=limit, step=delta) 8 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/reducel2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import 
nn 3 | 4 | 5 | class ReduceL2(nn.Module): 6 | def __init__( 7 | self, opset_version, dim=None, keepdim=True, noop_with_empty_axes=False 8 | ): 9 | self.opset_version = opset_version 10 | self.dim = dim 11 | self.keepdim = bool(keepdim) 12 | self.noop_with_empty_axes = noop_with_empty_axes 13 | super().__init__() 14 | 15 | def forward(self, data: torch.Tensor, axes: torch.Tensor = None): 16 | if self.opset_version < 13: 17 | dims = self.dim 18 | else: 19 | dims = axes 20 | if dims is None: 21 | if self.noop_with_empty_axes: 22 | return data 23 | else: 24 | dims = tuple(range(data.ndim)) 25 | 26 | if isinstance(dims, int): 27 | dim = dims 28 | else: 29 | dim = tuple(list(dims)) 30 | 31 | ret = torch.sqrt(torch.sum(torch.square(data), dim=dim, keepdim=self.keepdim)) 32 | return ret 33 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/reducemax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ReduceMax(nn.Module): 6 | def __init__(self, dim=None, keepdim=True): 7 | self.dim = dim 8 | self.keepdim = keepdim 9 | super().__init__() 10 | 11 | def forward(self, data: torch.Tensor): 12 | dim = self.dim 13 | if dim is None: 14 | dim = tuple(range(data.ndim)) 15 | return torch.amax(data, dim=dim, keepdim=self.keepdim) 16 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/reducesum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ReduceSum(nn.Module): 6 | def __init__( 7 | self, opset_version, dim=None, keepdim=True, noop_with_empty_axes=False 8 | ): 9 | self.opset_version = opset_version 10 | self.dim = dim 11 | self.keepdim = keepdim 12 | self.noop_with_empty_axes = noop_with_empty_axes 13 | super().__init__() 14 | 15 | def forward(self, data: torch.Tensor, axes: torch.Tensor = None): 16 | if self.opset_version < 13: 17 | dims = self.dim 18 | else: 19 | dims = axes 20 | if dims is None: 21 | if self.noop_with_empty_axes: 22 | return data 23 | else: 24 | dims = tuple(range(data.ndim)) 25 | if isinstance(dims, int): 26 | return torch.sum(data, dim=dims, keepdim=self.keepdim) 27 | else: 28 | return torch.sum(data, dim=tuple(list(dims)), keepdim=self.keepdim) 29 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/reshape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from onnx2pytorch.operations.base import Operator 5 | from onnx2pytorch.utils import assign_values_to_dim, get_selection 6 | 7 | 8 | class Reshape(Operator): 9 | """ 10 | In the initial pass it stores the initial_input_shape. 11 | It uses it to infer the new reshape value from a 12 | smaller pruned input in the following passes. 13 | """ 14 | 15 | def __init__(self, enable_pruning, shape=None, keep_size=True): 16 | super().__init__() 17 | self.enable_pruning = enable_pruning 18 | self.shape = shape 19 | self.initial_input_shape = None 20 | self.feature_dim = -1 21 | self.input_indices = None 22 | self.placeholder = None 23 | self.keep_size = keep_size 24 | 25 | def forward(self, input: torch.Tensor, shape=None): 26 | shape = shape if shape is not None else self.shape 27 | # This raises RuntimeWarning: iterating over a tensor. 
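        # Illustrative example: for input of shape (4, 3, 2) and shape=(0, -1),
        # ONNX semantics treat 0 as "copy the input dim", so the list below
        # becomes [4, -1] and the reshaped output has shape (4, 6).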
28 | shape = [x if x != 0 else input.size(i) for i, x in enumerate(shape)] 29 | 30 | if not self.enable_pruning: 31 | return torch.reshape(input, tuple(shape)) 32 | 33 | inp_shape = torch.tensor(input.shape) 34 | if self.initial_input_shape is None: 35 | self.initial_input_shape = inp_shape 36 | elif len(shape) == 2 and shape[-1] == -1: 37 | pass 38 | elif torch.equal(self.initial_input_shape, inp_shape): 39 | # input's shape did not change 40 | pass 41 | elif self.input_indices is not None: 42 | self.placeholder *= 0 43 | selection = get_selection(self.input_indices, self.feature_dim) 44 | self.placeholder[selection] += input 45 | input = self.placeholder 46 | elif torch.prod(inp_shape) == torch.prod(torch.tensor(shape)): 47 | # If input's shape changed but shape changed to account for this, 48 | # no additional work is needed. 49 | # This happens when shape is dynamically computed by the network. 50 | pass 51 | else: 52 | # If input's shape changed but shape has not accounted for this, 53 | # the reshaped shape must change as well. 54 | c = torch.true_divide(inp_shape, self.initial_input_shape) 55 | if len(c) < len(shape) and shape[0] == 1: 56 | c = torch.cat((torch.tensor([1]), c)) 57 | shape = (c * torch.tensor(shape)).to(int) 58 | return torch.reshape(input, tuple(shape)) 59 | 60 | def set_input_indices(self, input): 61 | input_shape = input[0].shape 62 | if self.feature_dim < 0: 63 | self.feature_dim += len(input_shape) 64 | axis = self.get_axis(input_shape, self.feature_dim) 65 | mask = input[0] != 0 66 | s = mask.sum(axis=tuple(axis)) 67 | mask = s != 0 68 | (non_zeros,) = torch.where(mask) 69 | self.input_indices = non_zeros 70 | self.placeholder = nn.Parameter( 71 | torch.zeros( 72 | *self.initial_input_shape, device=input[0].device, dtype=input[0].dtype 73 | ), 74 | requires_grad=False, 75 | ) 76 | 77 | def extra_repr(self) -> str: 78 | return "shape={}".format(self.shape) 79 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/resize.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | from onnx2pytorch.operations.base import Operator 7 | 8 | empty_tensor = torch.Tensor([]) 9 | 10 | 11 | class Resize(Operator): 12 | def __init__(self, mode="nearest", align_corners=None, **kwargs): 13 | self.mode = mode 14 | self.align_corners = align_corners 15 | for key in kwargs.keys(): 16 | warnings.warn( 17 | "Pytorch's interpolate uses no {}. " "Result might differ.".format(key) 18 | ) 19 | super().__init__() 20 | 21 | def forward(self, inp, roi=empty_tensor, scales=empty_tensor, sizes=empty_tensor): 22 | if roi.nelement() > 0: 23 | warnings.warn("Pytorch's interpolate uses no roi. Result might differ.") 24 | 25 | scales = list(scales) 26 | sizes = list(sizes) 27 | shape = list(inp.shape) 28 | if shape[:2] == sizes[:2]: 29 | sizes = sizes[2:] # Pytorch's interpolate takes only H and W params 30 | elif scales[:2] == [1, 1]: 31 | scales = scales[2:] 32 | elif len(scales) == 0 and len(sizes) == 0: 33 | raise ValueError("One of the two, scales or sizes, needs to be defined.") 34 | else: 35 | raise NotImplementedError( 36 | "Pytorch's interpolate does not scale batch and channel dimensions." 37 | ) 38 | 39 | if len(scales) == 0: 40 | scales = None 41 | elif len(sizes) == 0: 42 | sizes = None 43 | else: 44 | raise ValueError( 45 | "Only one of the two, scales or sizes, needs to be defined." 
46 | ) 47 | 48 | return F.interpolate( 49 | inp, 50 | scale_factor=scales, 51 | size=sizes, 52 | mode=self.mode, 53 | align_corners=self.align_corners, 54 | ) 55 | 56 | 57 | class Upsample(Resize): 58 | """Deprecated onnx operator.""" 59 | 60 | def forward(self, inp, scales): 61 | return super().forward(inp, torch.tensor([]), scales, torch.tensor([])) 62 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/scatter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Scatter(nn.Module): 6 | def __init__(self, dim=0): 7 | self.dim = dim 8 | super().__init__() 9 | 10 | def forward(self, data: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor): 11 | return torch.scatter(data, self.dim, indices, updates) 12 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/scatterelements.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ScatterElements(nn.Module): 6 | def __init__(self, dim=0): 7 | self.dim = dim 8 | super().__init__() 9 | 10 | def forward(self, data: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor): 11 | indices[indices < 0] = indices[indices < 0] + data.size(self.dim) 12 | return torch.scatter(data, self.dim, indices, updates) 13 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/scatternd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ScatterND(nn.Module): 6 | def forward(self, data: torch.Tensor, indices: torch.Tensor, updates: torch.Tensor): 7 | output = data.clone() 8 | k = indices.shape[-1] 9 | indices_list = [] 10 | for i in range(k): 11 | indices_list.append(indices[:, i]) 12 | output[indices_list] = updates 13 | return output 14 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/shape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Shape(nn.Module): 6 | def forward(self, input: torch.Tensor): 7 | return torch.tensor(input.shape, device=input.device) 8 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/slice.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def _to_positive_step(orig_slice, N): 6 | """ 7 | Convert a slice object with a negative step to one with a positive step. 8 | Accessing an iterable with the positive-stepped slice, followed by flipping 9 | the result, should be equivalent to accessing the tensor with the original 10 | slice. Computing positive-step slice requires using N, the length of the 11 | iterable being sliced. This is because PyTorch currently does not support 12 | slicing a tensor with a negative step. 
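    Example: on a length-5 iterable, slice(None, None, -2) selects indices
    [4, 2, 0]; this helper rewrites it to slice(0, 5, 2), which selects
    [0, 2, 4] and reproduces the original selection once flipped.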
13 | """ 14 | # Get rid of backward slices 15 | start, stop, step = orig_slice.indices(N) 16 | 17 | # Get number of steps and remainder 18 | n, r = divmod(stop - start, step) 19 | if n < 0 or (n == 0 and r == 0): 20 | return slice(0, 0, 1) 21 | if r != 0: # a "stop" index, not a last index 22 | n += 1 23 | 24 | if step < 0: 25 | start, stop, step = start + (n - 1) * step, start - step, -step 26 | else: # step > 0, step == 0 is not allowed 27 | stop = start + n * step 28 | stop = min(stop, N) 29 | 30 | return slice(start, stop, step) 31 | 32 | 33 | class Slice(nn.Module): 34 | def __init__(self, dim=None, starts=None, ends=None, steps=None): 35 | self.dim = [dim] if isinstance(dim, int) else dim 36 | self.starts = starts 37 | self.ends = ends 38 | self.steps = steps 39 | super().__init__() 40 | 41 | def forward( 42 | self, data: torch.Tensor, starts=None, ends=None, axes=None, steps=None 43 | ): 44 | if axes is None: 45 | axes = self.dim 46 | if starts is None: 47 | starts = self.starts 48 | if ends is None: 49 | ends = self.ends 50 | if steps is None: 51 | steps = self.steps 52 | 53 | if isinstance(starts, (tuple, list)): 54 | starts = torch.tensor(starts, device=data.device) 55 | if isinstance(ends, (tuple, list)): 56 | ends = torch.tensor(ends, device=data.device) 57 | if isinstance(steps, (tuple, list)): 58 | steps = torch.tensor(steps, device=data.device) 59 | 60 | # If axes=None set them to (0, 1, 2, ...) 61 | if axes is None: 62 | axes = tuple(torch.arange(len(starts))) 63 | if steps is None: 64 | steps = tuple(torch.tensor(1) for _ in axes) 65 | 66 | axes = [data.ndim + x if x < 0 else x for x in axes] 67 | 68 | selection = [slice(None) for _ in range(max(axes) + 1)] 69 | 70 | flip_dims = [] 71 | for i, axis in enumerate(axes): 72 | raw_slice = slice( 73 | starts[i].to(dtype=torch.long, device=data.device), 74 | ends[i].to(dtype=torch.long, device=data.device), 75 | steps[i].to(dtype=torch.long, device=data.device), 76 | ) 77 | if steps[i] < 0: 78 | selection[axis] = _to_positive_step(raw_slice, data.shape[axis]) 79 | flip_dims.append(axis) 80 | else: 81 | selection[axis] = raw_slice 82 | if len(flip_dims) > 0: 83 | return torch.flip(data.__getitem__(selection), flip_dims) 84 | else: 85 | # For torch < 1.8.1, torch.flip cannot handle empty dims 86 | return data.__getitem__(selection) 87 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/split.py: -------------------------------------------------------------------------------- 1 | from itertools import accumulate 2 | 3 | import torch 4 | 5 | from onnx2pytorch.operations.base import Operator 6 | from onnx2pytorch.utils import assign_values_to_dim 7 | 8 | 9 | class Split(Operator): 10 | def __init__( 11 | self, 12 | enable_pruning=False, 13 | split_size_or_sections=None, 14 | number_of_splits=None, 15 | dim=0, 16 | keep_size=True, 17 | ): 18 | """ 19 | Parameters 20 | ---------- 21 | enable_pruning: bool 22 | split_size_or_sections: tuple[int] 23 | number_of_splits: int 24 | The number of equal splits along dim. 25 | dim: int 26 | Split dimension. Tensor is split over this axis. 27 | keep_size: bool 28 | If True it keeps the size of the split the same as in initial pass. 29 | Else it splits it accordingly to the pruned input. 30 | """ 31 | if enable_pruning: 32 | assert ( 33 | split_size_or_sections is not None or number_of_splits is not None 34 | ), "One of the parameters needs to be set." 
35 | self.enable_pruning = enable_pruning 36 | self.dim = dim 37 | self.split_size_or_sections = split_size_or_sections 38 | self.number_of_splits = number_of_splits 39 | self.keep_size = keep_size 40 | self.input_indices = None 41 | self.placeholder = None 42 | super().__init__() 43 | 44 | def _get_sections(self, input): 45 | """Calculate sections from number of splits.""" 46 | dim_size = input[0].shape[self.dim] 47 | assert ( 48 | dim_size % self.number_of_splits == 0 49 | ), "Dimension size {} not equally divisible by {}.".format( 50 | dim_size, self.number_of_splits 51 | ) 52 | s = dim_size // self.number_of_splits 53 | sections = tuple(s for _ in range(self.number_of_splits)) 54 | return sections 55 | 56 | def forward(self, *input): 57 | if not self.enable_pruning and len(input) == 2: 58 | return torch.split(input[0], list(input[1]), dim=self.dim) 59 | if self.split_size_or_sections is None: 60 | self.split_size_or_sections = self._get_sections(input) 61 | 62 | if self.input_indices is not None: 63 | self.placeholder *= 0 64 | assign_values_to_dim( 65 | self.placeholder, input[0], self.input_indices, self.dim 66 | ) 67 | split = torch.split(self.placeholder, self.split_size_or_sections, self.dim) 68 | else: 69 | split = torch.split(*input, self.split_size_or_sections, dim=self.dim) 70 | return split 71 | 72 | def set_input_indices(self, input: tuple): 73 | assert isinstance(input, (tuple, list)) 74 | 75 | inp = input[0] 76 | # We assume that aggregation dimensions correspond to split dimension 77 | axis = self.get_axis(inp.shape, self.dim) 78 | 79 | # Mask shows where features are non zero in the whole axis. 80 | mask = inp != 0 81 | if len(inp.shape) > 1: 82 | mask = mask.sum(axis=tuple(axis)) != 0 83 | 84 | if not self.keep_size: 85 | # Read docstrings 86 | if isinstance(self.split_size_or_sections, tuple): 87 | indices = list(accumulate(self.split_size_or_sections)) 88 | indices = torch.tensor(indices) - 1 89 | else: 90 | raise NotImplementedError("Not implemented for split size.") 91 | cs = torch.cumsum(mask, 0) 92 | ind = [0] + cs[indices].tolist() 93 | sec = [ind[i + 1] - ind[i] for i in range(len(ind) - 1)] 94 | self.split_size_or_sections = sec 95 | else: 96 | (self.input_indices,) = torch.where(mask) 97 | self.placeholder = torch.zeros(inp.shape) 98 | 99 | def __str__(self): 100 | return "Split(dim={})".format(self.dim) 101 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/squeeze.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from onnx2pytorch.operations.base import Operator 4 | from onnx2pytorch.utils import get_selection 5 | 6 | 7 | class Squeeze(Operator): 8 | def __init__(self, opset_version, dim=None): 9 | self.opset_version = opset_version 10 | self.dim = dim 11 | super().__init__() 12 | 13 | def forward(self, input: torch.Tensor, axes: torch.Tensor = None): 14 | if self.opset_version < 13: 15 | dims = self.dim 16 | else: 17 | dims = axes 18 | 19 | if dims is None: 20 | return torch.squeeze(input) 21 | elif isinstance(dims, int): 22 | return torch.squeeze(input, dim=dims) 23 | else: 24 | for dim in sorted(dims, reverse=True): 25 | input = torch.squeeze(input, dim=dim) 26 | return input 27 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/thresholdedrelu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | 
class ThresholdedRelu(nn.Module): 6 | def __init__(self, alpha=1.0): 7 | self.alpha = alpha 8 | super().__init__() 9 | 10 | def forward(self, X: torch.Tensor): 11 | Y = torch.clamp(X, min=self.alpha) 12 | Y[Y == self.alpha] = 0.0 13 | return Y 14 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/tile.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Tile(nn.Module): 6 | def forward(self, input: torch.Tensor, repeats: torch.Tensor): 7 | return torch.tile(input, tuple(repeats)) 8 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/topk.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class TopK(nn.Module): 6 | def __init__(self, axis=-1, largest=1, sorted=1): 7 | self.axis = axis 8 | self.largest = bool(largest) 9 | self.sorted = bool(sorted) 10 | super().__init__() 11 | 12 | def forward(self, X: torch.Tensor, K: torch.Tensor): 13 | return torch.topk( 14 | X, int(K), dim=self.axis, largest=self.largest, sorted=self.sorted 15 | ) 16 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/transpose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Transpose(nn.Module): 6 | def __init__(self, dims=None): 7 | self.dims = dims 8 | super().__init__() 9 | 10 | def forward(self, data: torch.Tensor): 11 | if not self.dims: 12 | dims = tuple(reversed(range(data.dim()))) 13 | else: 14 | dims = self.dims 15 | transposed = data.permute(dims) 16 | return transposed 17 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/unsqueeze.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from onnx2pytorch.operations.base import Operator 5 | 6 | 7 | class Unsqueeze(Operator): 8 | def __init__(self, opset_version, dim=None): 9 | self.opset_version = opset_version 10 | self.dim = dim 11 | super().__init__() 12 | 13 | def forward(self, data: torch.Tensor, axes: torch.Tensor = None): 14 | if self.opset_version < 13: 15 | dims = self.dim 16 | else: 17 | dims = torch.Size(axes) 18 | if dims is None: 19 | raise ValueError("Unsqueeze expects axes") 20 | elif isinstance(dims, int): 21 | return torch.unsqueeze(data, dim=dims) 22 | else: 23 | for dim in sorted(dims): 24 | data = torch.unsqueeze(data, dim=dim) 25 | return data 26 | -------------------------------------------------------------------------------- /onnx2pytorch/operations/where.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Where(nn.Module): 6 | def forward(self, condition: torch.Tensor, X: torch.Tensor, Y: torch.Tensor): 7 | res_type = torch.result_type(X, Y) 8 | output = torch.where(condition, X.to(res_type), Y.to(res_type)) 9 | return output 10 | -------------------------------------------------------------------------------- /onnx2pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import torch 4 | import numpy as np 5 | import onnx 6 | 7 | try: 8 | import onnxruntime as ort 9 | except ImportError: 10 | ort = None 11 | 12 
| 13 | def value_wrapper(value): 14 | def callback(*args, **kwargs): 15 | return value 16 | 17 | return callback 18 | 19 | 20 | def is_constant(value): 21 | return value.ndim == 0 or value.shape == torch.Size([1]) 22 | 23 | 24 | def is_symmetric(params): 25 | """ 26 | Check if parameters are symmetric, all values [2,2,2,2]. 27 | Then we can use only [2,2]. 28 | """ 29 | assert len(params) // 2 == len(params) / 2, "Non even number of parameters." 30 | idx = len(params) // 2 31 | for i in range(0, idx): 32 | if params[i] != params[idx + i]: 33 | return False 34 | return True 35 | 36 | 37 | def extract_padding_params(params): 38 | """Extract padding parameters for Pad layers.""" 39 | pad_dim = len(params) // 2 40 | if pad_dim == 0: 41 | return [] 42 | pads = np.array(params).reshape(-1, pad_dim).T.flatten() # .tolist() 43 | 44 | # Some padding modes do not support padding in batch and channel dimension. 45 | # If batch and channel dimension have no padding, discard. 46 | if (pads[:4] == 0).all(): 47 | pads = pads[4:] 48 | pads = pads.tolist() 49 | # Reverse, because for pytorch first two numbers correspond to last dimension, etc. 50 | pads.reverse() 51 | return pads 52 | 53 | 54 | def extract_padding_params_for_conv_layer(params): 55 | """ 56 | Padding params in onnx are different than in pytorch. That is why we need to 57 | check if they are symmetric and cut half or return a padding layer. 58 | """ 59 | if is_symmetric(params): 60 | return params[: len(params) // 2] 61 | else: 62 | pad_dim = len(params) // 2 63 | pad_layer = getattr(torch.nn, "ConstantPad{}d".format(pad_dim)) 64 | pads = extract_padding_params(params)[::-1] 65 | return pad_layer(pads, value=0) 66 | 67 | 68 | def get_selection(indices, dim): 69 | """ 70 | Give selection to assign values to specific indices at given dimension. 71 | Enables dimension to be dynamic: 72 | tensor[get_selection(indices, dim=2)] = values 73 | Alternatively the dimension is fixed in code syntax: 74 | tensor[:, :, indices] = values 75 | """ 76 | assert dim >= 0, "Negative dimension not supported." 77 | # Behaviour with python lists is unfortunately not working the same. 78 | if isinstance(indices, list): 79 | indices = torch.tensor(indices) 80 | assert isinstance(indices, (torch.Tensor, np.ndarray)) 81 | selection = [slice(None) for _ in range(dim + 1)] 82 | selection[dim] = indices 83 | return selection 84 | 85 | 86 | def assign_values_to_dim(tensor, values, indices, dim, inplace=True): 87 | """ 88 | Inplace tensor operation that assigns values to corresponding indices 89 | at given dimension. 90 | """ 91 | if dim < 0: 92 | dim = dim + len(tensor.shape) 93 | selection = get_selection(indices, dim) 94 | if not inplace: 95 | tensor = tensor.clone() 96 | tensor[selection] = values 97 | return tensor 98 | 99 | 100 | def get_type(x): 101 | """ 102 | Extract type from onnxruntime input. 103 | 104 | Parameters 105 | ---------- 106 | x: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg 107 | """ 108 | if x.type.startswith("tensor"): 109 | typ = x.type[7:-1] 110 | else: 111 | raise NotImplementedError("For type: {}".format(x.type)) 112 | 113 | if typ == "float": 114 | typ = "float32" 115 | elif typ == "double": 116 | typ = "float64" 117 | return typ 118 | 119 | 120 | def get_shape(x, unknown_dim_size=1): 121 | """ 122 | Extract shape from onnxruntime input. 123 | Replace unknown dimension by default with 1. 
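    Example: a NodeArg with shape ["batch", 3, 224, 224] yields
    [1, 3, 224, 224] under the default unknown_dim_size=1.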
124 | 125 | Parameters 126 | ---------- 127 | x: onnxruntime.capi.onnxruntime_pybind11_state.NodeArg 128 | unknown_dim_size: int 129 | Default: 1 130 | """ 131 | shape = x.shape 132 | # replace unknown dimensions by default with 1 133 | shape = [i if isinstance(i, int) else unknown_dim_size for i in shape] 134 | return shape 135 | 136 | 137 | def get_activation_value(onnx_model, inputs, activation_names): 138 | """ 139 | Get activation value from an onnx model. 140 | 141 | Parameters 142 | ---------- 143 | onnx_model: onnx.ModelProto 144 | inputs: list[np.ndarray] 145 | activation_names: list[str] 146 | Can be retrieved from onnx node: list(node.output) 147 | 148 | Returns 149 | ------- 150 | value: list[np.ndarray] 151 | Value of the activation with activation_name. 152 | """ 153 | assert ort is not None, "onnxruntime needed. pip install onnxruntime" 154 | assert all(isinstance(x, np.ndarray) for x in inputs) 155 | 156 | if not isinstance(activation_names, (list, tuple)): 157 | activation_names = [activation_names] 158 | 159 | # clear output 160 | while len(onnx_model.graph.output): 161 | onnx_model.graph.output.pop() 162 | 163 | for activation_name in activation_names: 164 | activation_value = onnx.helper.ValueInfoProto() 165 | activation_value.name = activation_name 166 | onnx_model.graph.output.append(activation_value) 167 | 168 | buffer = io.BytesIO() 169 | onnx.save(onnx_model, buffer) 170 | buffer.seek(0) 171 | onnx_model_new = onnx.load(buffer) 172 | sess = ort.InferenceSession(onnx_model_new.SerializeToString()) 173 | 174 | input_names = [x.name for x in sess.get_inputs()] 175 | if not isinstance(inputs, list): 176 | inputs = [inputs] 177 | inputs = dict(zip(input_names, inputs)) 178 | 179 | return sess.run(None, inputs) 180 | 181 | 182 | def get_inputs_names(onnx_graph): 183 | param_names = set([x.name for x in onnx_graph.initializer]) 184 | input_names = [x.name for x in onnx_graph.input] 185 | input_names = [x for x in input_names if x not in param_names] 186 | return input_names 187 | 188 | 189 | def get_inputs_sample(onnx_model, to_torch=False): 190 | """Get inputs sample from onnx model.""" 191 | assert ort is not None, "onnxruntime needed. 
pip install onnxruntime" 192 | 193 | sess = ort.InferenceSession(onnx_model.SerializeToString()) 194 | inputs = sess.get_inputs() 195 | input_names = get_inputs_names(onnx_model.graph) 196 | input_tensors = [ 197 | np.abs(np.random.rand(*get_shape(x)).astype(get_type(x))) for x in inputs 198 | ] 199 | if to_torch: 200 | input_tensors = [torch.from_numpy(x) for x in input_tensors] 201 | return dict(zip(input_names, input_tensors)) 202 | 203 | 204 | def get_outputs_names(onnx_graph): 205 | output_names = [x.name for x in onnx_graph.output] 206 | return output_names 207 | 208 | 209 | def get_ops_names(onnx_graph): 210 | ops_used = set(node.op_type for node in onnx_graph.node) 211 | for node in onnx_graph.node: 212 | if node.op_type == "Loop": 213 | for attr in node.attribute: 214 | if attr.name == "body": 215 | ops_used |= get_ops_names(attr.g) 216 | elif node.op_type == "If": 217 | for attr in node.attribute: 218 | if attr.name == "then_branch": 219 | ops_used |= get_ops_names(attr.g) 220 | elif attr.name == "else_branch": 221 | ops_used |= get_ops_names(attr.g) 222 | return ops_used 223 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytest>=5.4.1 2 | pre-commit>=2.2.0 3 | torch>=1.4.0 4 | torchvision>=0.9.0 5 | onnx>=1.6.0 6 | onnxruntime>=1.5.0 7 | numpy>=1.18.1 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | from setuptools import setup, find_packages 4 | 5 | try: 6 | with open("README.md", "r", encoding="utf-8") as f: 7 | long_description = f.read() 8 | except IOError: 9 | long_description = "" 10 | 11 | # Extract version. Cannot import directly because of import error. 
12 | root_dir = os.path.dirname(__file__) 13 | with open(os.path.join(root_dir, "onnx2pytorch/__init__.py"), "r") as f: 14 | for line in f.readlines(): 15 | if line.startswith("__version__"): 16 | version = line.split("=")[-1].strip().strip('"') 17 | break 18 | 19 | PACKAGES = find_packages(exclude=("tests.*", "tests")) 20 | 21 | setup( 22 | name="onnx2pytorch", 23 | version=version, 24 | description="Library to transform onnx model to pytorch.", 25 | license="apache-2.0", 26 | author="Talmaj Marinc", 27 | packages=PACKAGES, 28 | install_requires=["torch>=1.4.0", "onnx>=1.6.0", "torchvision>=0.9.0"], 29 | long_description=long_description, 30 | long_description_content_type="text/markdown", 31 | url="https://github.com/ToriML/onnx2pytorch", 32 | classifiers=[ 33 | "Programming Language :: Python :: 3", 34 | "License :: OSI Approved :: Apache Software License", 35 | "Operating System :: OS Independent", 36 | ], 37 | python_requires=">=3.6", 38 | ) 39 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Talmaj/onnx2pytorch/0b72a11d0772ec0fc014690019bfd302f8dc5f5f/tests/__init__.py -------------------------------------------------------------------------------- /tests/onnx2pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Talmaj/onnx2pytorch/0b72a11d0772ec0fc014690019bfd302f8dc5f5f/tests/onnx2pytorch/__init__.py -------------------------------------------------------------------------------- /tests/onnx2pytorch/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | import pytest 5 | import onnx 6 | import numpy as np 7 | import onnxruntime as ort 8 | 9 | from onnx2pytorch.utils import get_inputs_sample 10 | 11 | RANDOM_SEED = 100 12 | FIXTURES_DIR = os.path.join( 13 | os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "fixtures" 14 | ) 15 | 16 | 17 | @pytest.fixture(params=glob.glob(os.path.join(FIXTURES_DIR, "*.onnx"))) 18 | def onnx_model_path(request): 19 | return request.param 20 | 21 | 22 | @pytest.fixture 23 | def onnx_model(onnx_model_path): 24 | onnx_model = onnx.load(onnx_model_path) 25 | return onnx_model 26 | 27 | 28 | @pytest.fixture 29 | def onnx_inputs(onnx_model): 30 | np.random.seed(RANDOM_SEED) 31 | return get_inputs_sample(onnx_model) 32 | 33 | 34 | @pytest.fixture 35 | def onnx_model_outputs(onnx_model_path, onnx_model, onnx_inputs): 36 | ort_session = ort.InferenceSession(onnx_model_path) 37 | onnx_output = ort_session.run(None, onnx_inputs) 38 | return onnx_output 39 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/convert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Talmaj/onnx2pytorch/0b72a11d0772ec0fc014690019bfd302f8dc5f5f/tests/onnx2pytorch/convert/__init__.py -------------------------------------------------------------------------------- /tests/onnx2pytorch/convert/test_attribute.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import pytest 4 | import onnx 5 | import numpy as np 6 | 7 | from onnx2pytorch.convert import extract_attr_values, extract_attributes 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "kwargs, value", 12 
| [ 13 | [dict(type="INT", i=1), 1], 14 | [dict(type="FLOAT", f=np.float64(2.5)), 2.5], 15 | [dict(type="INTS", ints=(1, 2)), (1, 2)], 16 | [dict(type="FLOATS", floats=np.array((1.5, 2.5))), (1.5, 2.5)], 17 | [dict(type="STRING", s="nearest".encode()), "nearest"], 18 | ], 19 | ) 20 | def test_extract_attr_values(kwargs, value): 21 | attr = onnx.AttributeProto(**kwargs) 22 | assert extract_attr_values(attr) == value 23 | 24 | 25 | @pytest.mark.parametrize( 26 | "node, exp_kwargs", 27 | [ 28 | [ 29 | onnx.helper.make_node( 30 | "Conv", 31 | inputs=["x", "W"], 32 | outputs=["y"], 33 | kernel_shape=[3, 3], 34 | strides=[1, 1], 35 | dilations=[1, 1], 36 | group=1, 37 | pads=[1, 1, 1, 1], 38 | ), 39 | dict( 40 | kernel_size=(3, 3), 41 | stride=(1, 1), 42 | dilation=(1, 1), 43 | groups=1, 44 | padding=(1, 1), 45 | ), 46 | ], 47 | [ 48 | onnx.helper.make_node( 49 | "Pad", 50 | inputs=["x", "pads", "value"], 51 | outputs=["y"], 52 | mode="constant", 53 | pads=[1, 0, 0, 1, 0, 0], 54 | ), 55 | dict( 56 | mode="constant", 57 | padding=[0, 0, 0, 0, 1, 1], 58 | ), 59 | ], 60 | [ 61 | onnx.helper.make_node( 62 | "Flatten", 63 | inputs=["a"], 64 | outputs=["b"], 65 | axis=1, 66 | ), 67 | dict( 68 | start_dim=1, 69 | ), 70 | ], 71 | [ 72 | onnx.helper.make_node( 73 | "Slice", 74 | inputs=["x", "starts", "ends", "axes", "steps"], 75 | outputs=["y"], 76 | starts=[0, 0, 3], 77 | ends=[20, 10, 4], 78 | ), 79 | dict(starts=(0, 0, 3), ends=(20, 10, 4)), 80 | ], 81 | [ 82 | onnx.helper.make_node( 83 | "Resize", 84 | inputs=["X", "", "scales"], 85 | outputs=["Y"], 86 | mode="nearest", 87 | coordinate_transformation_mode="align_corners", 88 | extrapolation_value=1, 89 | ), 90 | dict(mode="nearest", align_corners=True, extrapolation_value=1), 91 | ], 92 | [ 93 | onnx.helper.make_node( 94 | "AveragePool", 95 | inputs=["x"], 96 | outputs=["y"], 97 | kernel_shape=[3, 3], 98 | strides=[2, 2], 99 | ceil_mode=True, 100 | auto_pad="NOTSET", 101 | ), 102 | dict(kernel_size=(3, 3), stride=(2, 2), ceil_mode=True), 103 | ], 104 | [ 105 | onnx.helper.make_node("LeakyRelu", inputs=["x"], outputs=["y"], alpha=0.5), 106 | dict(negative_slope=0.5), 107 | ], 108 | ], 109 | ) 110 | def test_extract_attributes(node, exp_kwargs): 111 | extracted_kwargs = extract_attributes(node) 112 | TestCase().assertDictEqual(exp_kwargs, extracted_kwargs) 113 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/convert/test_debug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.convert.debug import debug_model_conversion 5 | from onnx2pytorch.helpers import to_onnx 6 | 7 | 8 | @pytest.fixture 9 | def inp_size(): 10 | return (1, 3, 10, 10) 11 | 12 | 13 | @pytest.fixture 14 | def inp(inp_size): 15 | return torch.rand(*inp_size) 16 | 17 | 18 | @pytest.fixture 19 | def model(): 20 | return torch.nn.Sequential( 21 | torch.nn.Conv2d(3, 10, 3, 1, 1), torch.nn.Conv2d(10, 3, 3, 1, 1) 22 | ) 23 | 24 | 25 | @pytest.fixture 26 | def onnx_model(model, inp_size): 27 | return to_onnx(model, inp_size) 28 | 29 | 30 | def test_debug_model_conversion(model, onnx_model, inp): 31 | pred_act = model[0](inp) 32 | debug_model_conversion(onnx_model, [inp], pred_act, onnx_model.graph.node[0]) 33 | 34 | 35 | def test_debug_model_conversion_raise_error(model, onnx_model, inp): 36 | model.eval() 37 | pred_act = torch.rand(1, 10, 10, 10) 38 | 39 | with pytest.raises(AssertionError): 40 | debug_model_conversion(onnx_model, [inp], 
pred_act, onnx_model.graph.node[0]) 41 | 42 | with pytest.raises(TypeError): 43 | debug_model_conversion(onnx_model, inp, pred_act, onnx_model.graph.node[0]) 44 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/convert/test_loop.py: -------------------------------------------------------------------------------- 1 | import io 2 | import numpy as np 3 | import onnx 4 | import onnxruntime as ort 5 | import pytest 6 | import torch 7 | 8 | from onnx2pytorch.convert import ConvertModel 9 | 10 | 11 | def test_loop_sum(): 12 | y_in = onnx.helper.make_tensor_value_info("y_in", onnx.TensorProto.FLOAT, [1]) 13 | y_out = onnx.helper.make_tensor_value_info("y_out", onnx.TensorProto.FLOAT, [1]) 14 | scan_out = onnx.helper.make_tensor_value_info( 15 | "scan_out", onnx.TensorProto.FLOAT, [] 16 | ) 17 | cond_in = onnx.helper.make_tensor_value_info("cond_in", onnx.TensorProto.BOOL, []) 18 | cond_out = onnx.helper.make_tensor_value_info("cond_out", onnx.TensorProto.BOOL, []) 19 | iter_count = onnx.helper.make_tensor_value_info( 20 | "iter_count", onnx.TensorProto.INT64, [] 21 | ) 22 | 23 | x = np.array([1, 2, 3, 4, 5]).astype(np.float32) 24 | 25 | x_const_node = onnx.helper.make_node( 26 | "Constant", 27 | inputs=[], 28 | outputs=["x"], 29 | value=onnx.helper.make_tensor( 30 | name="const_tensor_x", 31 | data_type=onnx.TensorProto.FLOAT, 32 | dims=x.shape, 33 | vals=x.flatten().astype(float), 34 | ), 35 | ) 36 | 37 | one_const_node = onnx.helper.make_node( 38 | "Constant", 39 | inputs=[], 40 | outputs=["one"], 41 | value=onnx.helper.make_tensor( 42 | name="const_tensor_one", data_type=onnx.TensorProto.INT64, dims=(), vals=[1] 43 | ), 44 | ) 45 | 46 | i_add_node = onnx.helper.make_node( 47 | "Add", inputs=["iter_count", "one"], outputs=["end"] 48 | ) 49 | 50 | start_unsqueeze_node = onnx.helper.make_node( 51 | "Unsqueeze", inputs=["iter_count"], outputs=["slice_start"], axes=[0] 52 | ) 53 | 54 | end_unsqueeze_node = onnx.helper.make_node( 55 | "Unsqueeze", inputs=["end"], outputs=["slice_end"], axes=[0] 56 | ) 57 | 58 | slice_node = onnx.helper.make_node( 59 | "Slice", inputs=["x", "slice_start", "slice_end"], outputs=["slice_out"] 60 | ) 61 | 62 | y_add_node = onnx.helper.make_node( 63 | "Add", inputs=["y_in", "slice_out"], outputs=["y_out"] 64 | ) 65 | 66 | identity_node = onnx.helper.make_node( 67 | "Identity", inputs=["cond_in"], outputs=["cond_out"] 68 | ) 69 | 70 | scan_identity_node = onnx.helper.make_node( 71 | "Identity", inputs=["y_out"], outputs=["scan_out"] 72 | ) 73 | 74 | loop_body = onnx.helper.make_graph( 75 | [ 76 | identity_node, 77 | x_const_node, 78 | one_const_node, 79 | i_add_node, 80 | start_unsqueeze_node, 81 | end_unsqueeze_node, 82 | slice_node, 83 | y_add_node, 84 | scan_identity_node, 85 | ], 86 | "loop_body", 87 | [iter_count, cond_in, y_in], 88 | [cond_out, y_out, scan_out], 89 | ) 90 | 91 | node = onnx.helper.make_node( 92 | "Loop", 93 | inputs=["trip_count", "cond", "y"], 94 | outputs=["res_y", "res_scan"], 95 | body=loop_body, 96 | ) 97 | 98 | trip_count = onnx.helper.make_tensor_value_info( 99 | "trip_count", onnx.TensorProto.INT64, [] 100 | ) 101 | cond = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []) 102 | y = onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1]) 103 | res_y = onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [1]) 104 | res_scan = onnx.helper.make_tensor_value_info( 105 | "res_scan", onnx.TensorProto.FLOAT, [] 106 | ) 107 | 108 | graph_def = 
onnx.helper.make_graph( 109 | nodes=[node], 110 | name="test-model", 111 | inputs=[trip_count, cond, y], 112 | outputs=[res_y, res_scan], 113 | ) 114 | 115 | model_def = onnx.helper.make_model_gen_version( 116 | graph_def, 117 | producer_name="loop-example", 118 | opset_imports=[onnx.helper.make_opsetid("", 11)], 119 | ) 120 | onnx.checker.check_model(model_def) 121 | bitstream = io.BytesIO() 122 | onnx.save(model_def, bitstream) 123 | bitstream_data = bitstream.getvalue() 124 | 125 | trip_count_input = np.array(5).astype(np.int64) 126 | cond_input = np.array(1).astype(np.bool_) 127 | y_input = np.array([-2]).astype(np.float32) 128 | exp_res_y = np.array([13]).astype(np.float32) 129 | exp_res_scan = np.array([-1, 1, 4, 8, 13]).astype(np.float32).reshape((5, 1)) 130 | 131 | ort_session = ort.InferenceSession(bitstream_data) 132 | ort_inputs = {"trip_count": trip_count_input, "cond": cond_input, "y": y_input} 133 | ort_outputs = ort_session.run(None, ort_inputs) 134 | ort_res_y, ort_res_scan = ort_outputs 135 | np.testing.assert_allclose(ort_res_y, exp_res_y, rtol=1e-5, atol=1e-5) 136 | np.testing.assert_allclose(ort_res_scan, exp_res_scan, rtol=1e-5, atol=1e-5) 137 | 138 | o2p_model = ConvertModel(model_def, experimental=True) 139 | o2p_inputs = { 140 | "trip_count": torch.from_numpy(trip_count_input), 141 | "cond": torch.from_numpy(cond_input), 142 | "y": torch.from_numpy(y_input), 143 | } 144 | o2p_outputs = o2p_model(**o2p_inputs) 145 | o2p_res_y, o2p_res_scan = o2p_outputs 146 | np.testing.assert_allclose( 147 | o2p_res_y.detach().numpy(), exp_res_y, rtol=1e-5, atol=1e-5 148 | ) 149 | np.testing.assert_allclose( 150 | o2p_res_scan.detach().numpy(), exp_res_scan, rtol=1e-5, atol=1e-5 151 | ) 152 | 153 | 154 | if __name__ == "__main__": 155 | test_loop_sum() 156 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/convert/test_lstm.py: -------------------------------------------------------------------------------- 1 | import io 2 | import onnx 3 | import pytest 4 | import torch 5 | 6 | from onnx2pytorch.convert import ConvertModel 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "bidirectional, input_size, hidden_size, seq_len, batch, test_seq_len, test_batch", 11 | [ 12 | (False, 3, 5, 23, 4, 23, 4), 13 | (False, 3, 5, 23, 4, 37, 4), 14 | (False, 3, 5, 23, 4, 23, 7), 15 | (True, 3, 5, 23, 4, 23, 4), 16 | (True, 3, 5, 23, 4, 37, 4), 17 | (True, 3, 5, 23, 4, 23, 7), 18 | ], 19 | ) 20 | def test_single_layer_lstm( 21 | bidirectional, input_size, hidden_size, seq_len, batch, test_seq_len, test_batch 22 | ): 23 | torch.manual_seed(42) 24 | num_layers = 1 25 | num_directions = bidirectional + 1 26 | lstm = torch.nn.LSTM( 27 | input_size=input_size, 28 | hidden_size=hidden_size, 29 | num_layers=num_layers, 30 | bidirectional=bidirectional, 31 | ) 32 | input = torch.randn(seq_len, batch, input_size) 33 | h_0 = torch.randn(num_layers * num_directions, batch, hidden_size) 34 | c_0 = torch.randn(num_layers * num_directions, batch, hidden_size) 35 | output, (h_n, c_n) = lstm(input, (h_0, c_0)) 36 | bitstream = io.BytesIO() 37 | torch.onnx.export( 38 | model=lstm, 39 | args=(input, (h_0, c_0)), 40 | f=bitstream, 41 | input_names=["input", "h_0", "c_0"], 42 | opset_version=11, 43 | dynamic_axes={ 44 | "input": {0: "seq_len", 1: "batch"}, 45 | "h_0": {1: "batch"}, 46 | "c_0": {1: "batch"}, 47 | }, 48 | ) 49 | bitstream_data = bitstream.getvalue() 50 | 51 | onnx_lstm = onnx.ModelProto.FromString(bitstream_data) 52 | o2p_lstm = ConvertModel(onnx_lstm, 
experimental=True) 53 | with torch.no_grad(): 54 | o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(input, h_0, c_0) 55 | torch.testing.assert_close(o2p_output, output, rtol=1e-6, atol=1e-6) 56 | torch.testing.assert_close(o2p_h_n, h_n, rtol=1e-6, atol=1e-6) 57 | torch.testing.assert_close(o2p_c_n, c_n, rtol=1e-6, atol=1e-6) 58 | 59 | onnx_lstm = onnx.ModelProto.FromString(bitstream_data) 60 | o2p_lstm = ConvertModel(onnx_lstm, experimental=True) 61 | with torch.no_grad(): 62 | o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(h_0=h_0, input=input, c_0=c_0) 63 | torch.testing.assert_close(o2p_output, output, rtol=1e-6, atol=1e-6) 64 | torch.testing.assert_close(o2p_h_n, h_n, rtol=1e-6, atol=1e-6) 65 | torch.testing.assert_close(o2p_c_n, c_n, rtol=1e-6, atol=1e-6) 66 | with pytest.raises(KeyError): 67 | o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(h_0=h_0, input=input) 68 | with pytest.raises(Exception): 69 | # Even though initial states are optional for nn.LSTM(), 70 | # we adhere to onnxruntime convention that inputs are provided 71 | # as either all positional or all keyword arguments. 72 | o2p_output, o2p_h_n, o2p_c_n = o2p_lstm(input, h_0=h_0, c_0=c_0) 73 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/convert/test_maxpool.py: -------------------------------------------------------------------------------- 1 | import io 2 | import onnx 3 | import pytest 4 | import torch 5 | 6 | from onnx2pytorch.convert import ConvertModel 7 | from torch import nn 8 | 9 | 10 | class UsedIndices(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | self.mp = nn.MaxPool2d( 14 | kernel_size=[3, 3], 15 | stride=[2, 2], 16 | ceil_mode=True, 17 | return_indices=True, 18 | ) 19 | 20 | def forward(self, x): 21 | y, indices = self.mp(x) 22 | return y - 42, indices + 42 23 | 24 | 25 | class UnusedIndices(nn.Module): 26 | def __init__(self): 27 | super().__init__() 28 | self.mp = nn.MaxPool2d( 29 | kernel_size=[3, 3], 30 | stride=[2, 2], 31 | ceil_mode=True, 32 | ) 33 | 34 | def forward(self, x): 35 | return self.mp(x) - 42 36 | 37 | 38 | def test_maxpool_2d_ceil(): 39 | x = torch.tensor( 40 | [ 41 | [ 42 | [ 43 | [1, 2, 3, 4], 44 | [5, 6, 7, 8], 45 | [9, 10, 11, 12], 46 | [13, 14, 15, 16], 47 | ] 48 | ] 49 | ], 50 | dtype=torch.float32, 51 | ) 52 | exp_y = ( 53 | torch.tensor( 54 | [ 55 | [ 56 | [ 57 | [11, 12], 58 | [15, 16], 59 | ] 60 | ] 61 | ], 62 | dtype=torch.float32, 63 | ) 64 | - 42 65 | ) 66 | exp_indices = ( 67 | torch.tensor( 68 | [ 69 | [ 70 | [ 71 | [10, 11], 72 | [14, 15], 73 | ] 74 | ] 75 | ] 76 | ) 77 | + 42 78 | ) 79 | 80 | model = UsedIndices() 81 | bitstream = io.BytesIO() 82 | torch.onnx.export( 83 | model=model, 84 | args=(x,), 85 | f=bitstream, 86 | input_names=["x"], 87 | opset_version=11, 88 | ) 89 | onnx_model = onnx.ModelProto.FromString(bitstream.getvalue()) 90 | o2p_model = ConvertModel(onnx_model) 91 | y, indices = o2p_model(x) 92 | assert torch.equal(exp_y, y) 93 | assert torch.equal(exp_indices, indices) 94 | 95 | model = UnusedIndices() 96 | bitstream = io.BytesIO() 97 | torch.onnx.export( 98 | model=model, 99 | args=(x,), 100 | f=bitstream, 101 | input_names=["x"], 102 | opset_version=11, 103 | ) 104 | onnx_model = onnx.ModelProto.FromString(bitstream.getvalue()) 105 | o2p_model = ConvertModel(onnx_model) 106 | y = o2p_model(x) 107 | assert torch.equal(exp_y, y) 108 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/convert/test_train.py: 
-------------------------------------------------------------------------------- 1 | import io 2 | import numpy as np 3 | import onnx 4 | import onnxruntime as ort 5 | import pytest 6 | import torch 7 | 8 | from onnx2pytorch.convert import ConvertModel 9 | 10 | 11 | def test_train_multiple_models(): 12 | original_model = torch.nn.Linear(in_features=3, out_features=5) 13 | dummy_input = torch.randn(4, 3) 14 | original_output = original_model(dummy_input) 15 | 16 | bitstream = io.BytesIO() 17 | torch.onnx.export(original_model, dummy_input, bitstream, opset_version=11) 18 | 19 | bitstream.seek(0) 20 | onnx_model = onnx.load(bitstream) 21 | o2p_model = ConvertModel(onnx_model, experimental=True) 22 | o2p_model2 = ConvertModel(onnx_model, experimental=True) 23 | o2p_output = o2p_model(dummy_input) 24 | assert torch.equal(original_output, o2p_output) 25 | 26 | onnx_model_serial = onnx_model.SerializeToString() 27 | with torch.no_grad(): 28 | for name, param in o2p_model.named_parameters(): 29 | param.copy_(torch.zeros(*param.shape)) 30 | assert onnx_model.SerializeToString() == onnx_model_serial 31 | 32 | o2p_output_after = o2p_model(dummy_input) 33 | assert not torch.equal(original_output, o2p_output_after) 34 | o2p_output2_after = o2p_model2(dummy_input) 35 | assert torch.equal(original_output, o2p_output2_after) 36 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Talmaj/onnx2pytorch/0b72a11d0772ec0fc014690019bfd302f8dc5f5f/tests/onnx2pytorch/operations/__init__.py -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_add.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from onnx2pytorch.operations.add import Add 5 | 6 | 7 | @pytest.fixture( 8 | params=[[[10, 20, 30], list(range(32))], [[10, 20, 30], [8, 9, 10, 11, 12]]] 9 | ) 10 | def input_indices(request): 11 | return request.param 12 | 13 | 14 | @pytest.fixture 15 | def updated_input_indices(input_indices): 16 | """Re-index indices after removal of channels that contain only zeros.""" 17 | s = list(set(input_indices[0]).union(input_indices[1])) 18 | new = [[s.index(x) for x in indices] for indices in input_indices] 19 | return new 20 | 21 | 22 | @pytest.fixture 23 | def input_shape(input_indices): 24 | n = len(set(input_indices[0]).union(input_indices[1])) 25 | return (32, n, 8, 8) 26 | 27 | 28 | @pytest.fixture 29 | def original_shape(): 30 | return (32, 32, 8, 8) 31 | 32 | 33 | @pytest.fixture 34 | def inputs(input_indices): 35 | """Pruned smaller inputs.""" 36 | return [ 37 | torch.ones(32, len(input_indices[0]), 8, 8), 38 | 2 * torch.ones(32, len(input_indices[1]), 8, 8), 39 | ] 40 | 41 | 42 | def test_add(input_shape, inputs, updated_input_indices): 43 | """Test addition of differently sized inputs.""" 44 | op = Add(input_shape, updated_input_indices) 45 | out = op(*inputs) 46 | 47 | for i in range(32): 48 | if i in updated_input_indices[0] and i in updated_input_indices[1]: 49 | assert out[0, i, 0, 0] == 3 50 | elif i in updated_input_indices[0]: 51 | assert out[0, i, 0, 0] == 1 52 | elif i in updated_input_indices[1]: 53 | assert out[0, i, 0, 0] == 2 54 | 55 | 56 | def test_simple_add(): 57 | a = torch.zeros(2, 2) 58 | b = torch.zeros(2, 2) 59 | b[0] = 1 60 | op = Add() 61 | out = op(a, b) 62 | assert 
torch.allclose(out, torch.tensor([[1.0, 1], [0, 0]])) 63 | 64 | 65 | def test_add_2(input_shape, inputs, updated_input_indices): 66 | """Test addition of differently sized inputs.""" 67 | inputs_2 = [torch.zeros(*input_shape), torch.zeros(*input_shape)] 68 | for i in range(2): 69 | inp = inputs_2[i] 70 | idx = updated_input_indices[i] 71 | inp[:, idx] = i + 1 72 | 73 | op = Add() 74 | out_true = op(*inputs_2) 75 | 76 | op = Add(input_shape, updated_input_indices) 77 | out = op(*inputs) 78 | 79 | assert torch.allclose(out_true, out, atol=1e-7), "Outputs differ." 80 | 81 | 82 | @pytest.mark.parametrize( 83 | "inputs", 84 | [ 85 | [torch.tensor(1), torch.ones(10, 10)], 86 | [torch.ones(10, 10), torch.tensor(1)], 87 | [torch.tensor(1), torch.ones(10, 10), torch.tensor(0)], 88 | ], 89 | ) 90 | def test_add_constant(inputs): 91 | op = Add() 92 | out = op(*inputs) 93 | assert torch.equal(out, 2 * torch.ones(10, 10)) 94 | 95 | 96 | def test_set_input_indices( 97 | input_indices, updated_input_indices, input_shape, original_shape 98 | ): 99 | inputs = [torch.zeros(*original_shape), torch.zeros(*original_shape)] 100 | for inp, idx in zip(inputs, input_indices): 101 | inp[:, idx] = 1 102 | 103 | op = Add() 104 | op.set_input_indices(inputs) 105 | assert op.input_shape == input_shape 106 | for true_idx, calc_idx in zip(updated_input_indices, op.input_indices): 107 | assert true_idx == calc_idx.tolist() 108 | 109 | 110 | @pytest.mark.parametrize("inp", [torch.triu(torch.ones(10, 10), 1), torch.tensor(1)]) 111 | def test_set_input_indices_triu(inp): 112 | act = torch.ones(10, 10) 113 | inputs = [act, inp] 114 | op = Add() 115 | op.set_input_indices(inputs) 116 | assert op.input_indices is None 117 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_bitshift.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from onnx2pytorch.operations.bitshift import BitShift 5 | 6 | 7 | def test_bitshift_left_uint8(): 8 | op = BitShift(direction="LEFT") 9 | 10 | x = torch.tensor([16, 4, 1], dtype=torch.uint8) 11 | y = torch.tensor([1, 2, 3], dtype=torch.uint8) 12 | exp_z = torch.tensor([32, 16, 8], dtype=torch.uint8) 13 | assert torch.equal(op(x, y), exp_z) 14 | 15 | 16 | def test_bitshift_left_int64(): 17 | op = BitShift(direction="LEFT") 18 | 19 | x = torch.tensor([16, 4, 1], dtype=torch.int64) 20 | y = torch.tensor([1, 2, 3], dtype=torch.int64) 21 | exp_z = torch.tensor([32, 16, 8], dtype=torch.int64) 22 | assert torch.equal(op(x, y), exp_z) 23 | 24 | 25 | def test_bitshift_right_uint8(): 26 | op = BitShift(direction="RIGHT") 27 | 28 | x = torch.tensor([16, 4, 1], dtype=torch.uint8) 29 | y = torch.tensor([1, 2, 3], dtype=torch.uint8) 30 | exp_z = torch.tensor([8, 1, 0], dtype=torch.uint8) 31 | assert torch.equal(op(x, y), exp_z) 32 | 33 | 34 | def test_bitshift_right_int64(): 35 | op = BitShift(direction="RIGHT") 36 | 37 | x = torch.tensor([16, 4, 1], dtype=torch.int64) 38 | y = torch.tensor([1, 2, 3], dtype=torch.int64) 39 | exp_z = torch.tensor([8, 1, 0], dtype=torch.int64) 40 | assert torch.equal(op(x, y), exp_z) 41 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_cast.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE 5 | 6 | from onnx2pytorch.operations 
import Cast 7 | 8 | 9 | @pytest.mark.parametrize("dtype", ["double", "float32", "float16"]) 10 | def test_cast(dtype): 11 | shape = (3, 4) 12 | x_np = np.array( 13 | [ 14 | u"0.47892547", 15 | u"0.48033667", 16 | u"0.49968487", 17 | u"0.81910545", 18 | u"0.47031248", 19 | u"0.816468", 20 | u"0.21087195", 21 | u"0.7229038", 22 | u"NaN", 23 | u"INF", 24 | u"+INF", 25 | u"-INF", 26 | ], 27 | dtype=np.dtype(object), 28 | ).reshape(shape) 29 | x = torch.from_numpy(x_np.astype(dtype)) 30 | op = Cast(dtype) 31 | y = x_np.astype(getattr(np, dtype.lower())) 32 | assert np.allclose(op(x).numpy(), y, rtol=0, atol=0, equal_nan=True) 33 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_clip.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from onnx2pytorch.operations.clip import Clip 6 | 7 | 8 | def test_clip(): 9 | x_np = np.random.randn(3, 4, 5).astype(np.float32) 10 | x = torch.from_numpy(x_np) 11 | 12 | op = Clip(min=-1, max=1) 13 | exp_y_np = np.clip(x_np, -1, 1) 14 | exp_y = torch.from_numpy(exp_y_np) 15 | assert torch.equal(op(x), exp_y) 16 | 17 | op = Clip(min=0) 18 | exp_y_np = np.clip(x_np, 0, np.inf) 19 | exp_y = torch.from_numpy(exp_y_np) 20 | assert torch.equal(op(x), exp_y) 21 | 22 | op = Clip(max=0) 23 | exp_y_np = np.clip(x_np, -np.inf, 0) 24 | exp_y = torch.from_numpy(exp_y_np) 25 | assert torch.equal(op(x), exp_y) 26 | 27 | op = Clip() 28 | exp_y_np = np.clip(x_np, -np.inf, np.inf) 29 | exp_y = torch.from_numpy(exp_y_np) 30 | assert torch.equal(op(x), exp_y) 31 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_constantofshape.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | 5 | from onnx2pytorch.operations.constantofshape import ConstantOfShape 6 | 7 | 8 | def test_constantofshape_float_ones(): 9 | op = ConstantOfShape() 10 | x = torch.tensor([4, 3, 2], dtype=torch.int64) 11 | y = torch.ones(*x, dtype=torch.float32) 12 | assert torch.equal(op(x), y) 13 | 14 | 15 | def test_constantofshape_int32_shape_zero(): 16 | constant = torch.tensor([0], dtype=torch.int32) 17 | op = ConstantOfShape(constant=constant) 18 | x = torch.tensor([0], dtype=torch.int64) 19 | y = torch.zeros(*x, dtype=torch.int32) 20 | assert torch.equal(op(x), y) 21 | 22 | 23 | def test_constantofshape_int32_zeros(): 24 | constant = torch.tensor([0], dtype=torch.int32) 25 | op = ConstantOfShape(constant=constant) 26 | x = torch.tensor([10, 6], dtype=torch.int64) 27 | y = torch.zeros(*x, dtype=torch.int32) 28 | assert torch.equal(op(x), y) 29 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_div.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations.div import Div 5 | 6 | 7 | def test_div(): 8 | op = Div() 9 | x = torch.tensor([3, 4], dtype=torch.float32) 10 | y = torch.tensor([1, 2], dtype=torch.float32) 11 | z = x / y 12 | assert torch.equal(op(x, y), z) 13 | 14 | x = torch.randn(3, 4, 5) 15 | y = torch.rand(3, 4, 5) + 1.0 16 | z = x / y 17 | assert torch.equal(op(x, y), z) 18 | 19 | x = torch.randint(24, size=(3, 4, 5), dtype=torch.uint8) 20 | y = torch.randint(24, size=(3, 4, 5), dtype=torch.uint8) + 1 21 | z = x // y 
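# Editorial note: ONNX Div truncates toward zero on integer inputs, and for
# the non-negative uint8 operands above truncation coincides with floor
# division, so `//` yields the expected values. Were signed integer inputs
# ever added, torch.div(x, y, rounding_mode="trunc") would express the ONNX
# semantics directly.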
22 | assert torch.equal(op(x, y), z) 23 | 24 | 25 | def test_div_broadcast(): 26 | op = Div() 27 | x = torch.randn(3, 4, 5, dtype=torch.float32) 28 | y = torch.rand(5, dtype=torch.float32) + 1.0 29 | z = x / y 30 | assert torch.equal(op(x, y), z) 31 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_expand.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations.expand import Expand 5 | 6 | 7 | def test_expand_dim_changed(): 8 | op = Expand() 9 | inp = torch.reshape(torch.arange(0, 3, dtype=torch.float32), [3, 1]) 10 | new_shape = [2, 1, 6] 11 | exp = inp * torch.ones(new_shape) 12 | exp_shape = (2, 3, 6) 13 | assert tuple(op(inp, new_shape).shape) == exp_shape 14 | assert torch.equal(op(inp, new_shape), exp) 15 | 16 | 17 | def test_expand_dim_unchanged(): 18 | op = Expand() 19 | inp = torch.reshape(torch.arange(0, 3, dtype=torch.int32), [3, 1]) 20 | new_shape = [3, 4] 21 | exp = torch.cat([inp] * 4, dim=1) 22 | exp_shape = (3, 4) 23 | ret = op(inp, new_shape) 24 | assert tuple(ret.shape) == exp_shape 25 | assert torch.equal(ret, exp) 26 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_flatten.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations import Flatten 5 | 6 | 7 | @pytest.fixture 8 | def inp(): 9 | return torch.rand(1, 3, 10, 10) 10 | 11 | 12 | def test_flatten(inp): 13 | """Pass padding in initialization and in forward pass.""" 14 | op = Flatten(1, 2) 15 | out = op(inp) 16 | assert list(out.shape) == [1, 30, 10] 17 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_gather.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from onnx2pytorch.operations.gather import Gather 6 | 7 | 8 | def test_gather_0(): 9 | op = Gather(dim=0) 10 | data = torch.randn(5, 4, 3, 2, dtype=torch.float32) 11 | indices = torch.tensor([0, 1, 3], dtype=torch.int64) 12 | output_np = np.take(data.detach().numpy(), indices.detach().numpy(), axis=0) 13 | exp_output = torch.from_numpy(output_np).to(dtype=torch.float32) 14 | assert torch.equal(op(data, indices), exp_output) 15 | 16 | 17 | def test_gather_1(): 18 | op = Gather(dim=1) 19 | data = torch.randn(5, 4, 3, 2, dtype=torch.float32) 20 | indices = torch.tensor([0, 1, 3], dtype=torch.int64) 21 | output_np = np.take(data.detach().numpy(), indices.detach().numpy(), axis=1) 22 | exp_output = torch.from_numpy(output_np).to(dtype=torch.float32) 23 | assert torch.equal(op(data, indices), exp_output) 24 | 25 | 26 | def test_gather_2d_indices(): 27 | op = Gather(dim=1) 28 | data = torch.randn(3, 3, dtype=torch.float32) 29 | indices = torch.tensor([[0, 2]]) 30 | output_np = np.take(data.detach().numpy(), indices.detach().numpy(), axis=1) 31 | exp_output = torch.from_numpy(output_np).to(dtype=torch.float32) 32 | assert torch.equal(op(data, indices), exp_output) 33 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_gathernd.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from onnx2pytorch.operations.gathernd import GatherND 
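# Editorial sketch (not part of the original file): with batch_dims=0,
# GatherND treats the last axis of `indices` as coordinate tuples into
# `data`. A minimal NumPy reference, assuming `import numpy as np` and the
# hypothetical helper name `gather_nd_ref`:
#
#     def gather_nd_ref(data, indices):
#         k = indices.shape[-1]                 # length of each index tuple
#         flat = indices.reshape(-1, k)         # one coordinate tuple per row
#         out = np.stack([data[tuple(idx)] for idx in flat])
#         return out.reshape(indices.shape[:-1] + data.shape[k:])
#
# The tests below exercise exactly this indexing for float32 and int32 data.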
5 | 6 | 7 | def test_gathernd_float32(): 8 | op = GatherND(batch_dims=0) 9 | data = torch.tensor([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=torch.float32) 10 | indices = torch.tensor([[[0, 1]], [[1, 0]]], dtype=torch.int64) 11 | exp_output = torch.tensor([[[2, 3]], [[4, 5]]], dtype=torch.float32) 12 | assert torch.equal(op(data, indices), exp_output) 13 | 14 | 15 | def test_gathernd_int32(): 16 | op = GatherND(batch_dims=0) 17 | data = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32) 18 | indices = torch.tensor([[0, 0], [1, 1]], dtype=torch.int64) 19 | exp_output = torch.tensor([0, 3], dtype=torch.int32) 20 | assert torch.equal(op(data, indices), exp_output) 21 | 22 | 23 | @pytest.mark.skip("GatherND batch_dims > 0 not implemented yet") 24 | def test_gathernd_int32_batch_dim1(): 25 | op = GatherND(batch_dims=1) 26 | data = torch.tensor([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype=torch.int32) 27 | indices = torch.tensor([[1], [0]], dtype=torch.int64) 28 | exp_output = torch.tensor([[2, 3], [4, 5]], dtype=torch.int32) 29 | assert torch.equal(op(data, indices), exp_output) 30 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_globalaveragepool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations.globalaveragepool import GlobalAveragePool 5 | 6 | 7 | def test_globalaveragepool_2d(): 8 | op = GlobalAveragePool() 9 | x = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float32) 10 | y = torch.tensor([[[[5]]]], dtype=torch.float32) 11 | assert torch.equal(op(x), y) 12 | 13 | 14 | def test_globalaveragepool_3d(): 15 | op = GlobalAveragePool() 16 | x = torch.tensor( 17 | [ 18 | [ 19 | [ 20 | [[1, 1], [2, 2], [3, 3]], 21 | [[4, 4], [5, 5], [6, 6]], 22 | [[7, 7], [8, 8], [9, 9]], 23 | ] 24 | ] 25 | ], 26 | dtype=torch.float32, 27 | ) 28 | y = torch.tensor([[[[[5]]]]], dtype=torch.float32) 29 | assert torch.equal(op(x), y) 30 | 31 | 32 | def test_globalaveragepool_channels(): 33 | op = GlobalAveragePool() 34 | x = torch.tensor([[[[0, 0]], [[1, 1]]]], dtype=torch.float32) 35 | y = torch.tensor([[[[0]], [[1]]]], dtype=torch.float32) 36 | assert torch.equal(op(x), y) 37 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_hardsigmoid.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | import numpy as np 4 | import onnx 5 | import torch 6 | import pytest 7 | 8 | from onnx2pytorch.convert.operations import convert_operations 9 | from onnx2pytorch.operations import Hardsigmoid 10 | 11 | 12 | @pytest.fixture 13 | def x(): 14 | return np.random.randn(3, 4, 5).astype(np.float32) 15 | 16 | 17 | def test_hardsigmoid(x): 18 | alpha = 1 / 6 19 | beta = 1 / 2 20 | op = Hardsigmoid(alpha=alpha, beta=beta) 21 | # For pytorch's default values it should use torch's Hardsigmoid 22 | assert isinstance(op, torch.nn.Hardsigmoid) 23 | x = np.random.randn(3, 4, 5).astype(np.float32) 24 | y = np.clip(x * alpha + beta, 0, 1) 25 | out = op(torch.from_numpy(x)) 26 | np.testing.assert_allclose(out, torch.from_numpy(y), rtol=1e-6, atol=1e-6) 27 | 28 | 29 | def test_hardsigmoid_with_custom_alpha_and_beta(x): 30 | alpha = 0.2 31 | beta = 0.5 32 | op = Hardsigmoid(alpha=alpha, beta=beta) 33 | assert not isinstance(op, torch.nn.Hardsigmoid) 34 | y = np.clip(x * alpha + beta, 0, 1) 35 | out = 
op(torch.from_numpy(x)) 36 | np.testing.assert_allclose(out, torch.from_numpy(y), rtol=1e-6, atol=1e-6) 37 | 38 | 39 | def test_hardsigmoid_conversion(): 40 | alpha = np.float32(0.2) 41 | beta = np.float32(0.5) 42 | node = onnx.helper.make_node( 43 | "HardSigmoid", 44 | inputs=["x"], 45 | outputs=["y"], 46 | alpha=alpha, 47 | beta=beta, 48 | ) 49 | 50 | graph = MagicMock() 51 | graph.initializers = [] 52 | graph.node = [node] 53 | converted_ops = list(convert_operations(graph, 10)) 54 | op_id, op_name, op = converted_ops[0] 55 | assert isinstance(op, Hardsigmoid) 56 | assert op.alpha == alpha 57 | assert op.beta == beta 58 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_instancenorm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | 5 | from onnx2pytorch.operations import InstanceNormWrapper 6 | 7 | 8 | def instancenorm_reference(x, s, bias, eps): 9 | dims_x = len(x.shape) 10 | axis = tuple(range(2, dims_x)) 11 | mean = np.mean(x, axis=axis, keepdims=True) 12 | var = np.var(x, axis=axis, keepdims=True) 13 | dim_ones = (1,) * (dims_x - 2) 14 | s = s.reshape(-1, *dim_ones) 15 | bias = bias.reshape(-1, *dim_ones) 16 | return s * (x - mean) / np.sqrt(var + eps) + bias 17 | 18 | 19 | @pytest.fixture 20 | def x_np(): 21 | # input size: (1, 2, 1, 3) 22 | return np.array([[[[-1, 0, 1]], [[2, 3, 4]]]]).astype(np.float32) 23 | 24 | 25 | @pytest.fixture 26 | def s_np(): 27 | return np.array([1.0, 1.5]).astype(np.float32) 28 | 29 | 30 | @pytest.fixture 31 | def b_np(): 32 | return np.array([0, 1]).astype(np.float32) 33 | 34 | 35 | def test_instancenorm(x_np, s_np, b_np): 36 | eps = 1e-5 37 | x = torch.from_numpy(x_np) 38 | s = torch.from_numpy(s_np) 39 | b = torch.from_numpy(b_np) 40 | 41 | exp_y = instancenorm_reference(x_np, s_np, b_np, eps).astype(np.float32) 42 | exp_y_shape = (1, 2, 1, 3) 43 | op = InstanceNormWrapper([s, b], eps=eps) 44 | y = op(x) 45 | 46 | assert y.shape == exp_y_shape 47 | assert np.allclose(y.detach().numpy(), exp_y, rtol=1e-5, atol=1e-5) 48 | 49 | 50 | def test_instancenorm_lazy(x_np, s_np, b_np): 51 | eps = 1e-5 52 | x = torch.from_numpy(x_np) 53 | s = torch.from_numpy(s_np) 54 | b = torch.from_numpy(b_np) 55 | 56 | exp_y = instancenorm_reference(x_np, s_np, b_np, eps).astype(np.float32) 57 | exp_y_shape = (1, 2, 1, 3) 58 | op = InstanceNormWrapper([], eps=eps) 59 | y = op(x, s, b) 60 | 61 | assert y.shape == exp_y_shape 62 | assert np.allclose(y.detach().numpy(), exp_y, rtol=1e-5, atol=1e-5) 63 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_nonmaxsuppression.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from onnx2pytorch.operations.nonmaxsuppression import NonMaxSuppression 5 | 6 | 7 | def test_nonmaxsuppression_center_point_box_format(): 8 | op = NonMaxSuppression(center_point_box=1) 9 | boxes = torch.tensor( 10 | [ 11 | [ 12 | [0.5, 0.5, 1.0, 1.0], 13 | [0.5, 0.6, 1.0, 1.0], 14 | [0.5, 0.4, 1.0, 1.0], 15 | [0.5, 10.5, 1.0, 1.0], 16 | [0.5, 10.6, 1.0, 1.0], 17 | [0.5, 100.5, 1.0, 1.0], 18 | ] 19 | ], 20 | dtype=torch.float32, 21 | ) 22 | scores = torch.tensor([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]], dtype=torch.float32) 23 | max_output_boxes_per_class = torch.tensor([3], dtype=torch.int64) 24 | iou_threshold = torch.tensor([0.5], dtype=torch.float32) 25 | 
score_threshold = torch.tensor([0.0], dtype=torch.float32) 26 | exp_selected_indices = torch.tensor( 27 | [[0, 0, 3], [0, 0, 0], [0, 0, 5]], dtype=torch.int64 28 | ) 29 | selected_indices = op( 30 | boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold 31 | ) 32 | assert torch.equal(selected_indices, exp_selected_indices) 33 | 34 | 35 | def test_nonmaxsuppression_flipped_coordinates(): 36 | op = NonMaxSuppression() 37 | boxes = torch.tensor( 38 | [ 39 | [ 40 | [1.0, 1.0, 0.0, 0.0], 41 | [0.0, 0.1, 1.0, 1.1], 42 | [0.0, 0.9, 1.0, -0.1], 43 | [0.0, 10.0, 1.0, 11.0], 44 | [1.0, 10.1, 0.0, 11.1], 45 | [1.0, 101.0, 0.0, 100.0], 46 | ] 47 | ], 48 | dtype=torch.float32, 49 | ) 50 | scores = torch.tensor([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]], dtype=torch.float32) 51 | max_output_boxes_per_class = torch.tensor([3], dtype=torch.int64) 52 | iou_threshold = torch.tensor([0.5], dtype=torch.float32) 53 | score_threshold = torch.tensor([0.0], dtype=torch.float32) 54 | exp_selected_indices = torch.tensor( 55 | [[0, 0, 3], [0, 0, 0], [0, 0, 5]], dtype=torch.int64 56 | ) 57 | selected_indices = op( 58 | boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold 59 | ) 60 | assert torch.equal(selected_indices, exp_selected_indices) 61 | 62 | 63 | def test_nonmaxsuppression_identical_boxes(): 64 | op = NonMaxSuppression() 65 | boxes = torch.tensor( 66 | [ 67 | [ 68 | [0.0, 0.0, 1.0, 1.0], 69 | [0.0, 0.0, 1.0, 1.0], 70 | [0.0, 0.0, 1.0, 1.0], 71 | [0.0, 0.0, 1.0, 1.0], 72 | [0.0, 0.0, 1.0, 1.0], 73 | [0.0, 0.0, 1.0, 1.0], 74 | [0.0, 0.0, 1.0, 1.0], 75 | [0.0, 0.0, 1.0, 1.0], 76 | [0.0, 0.0, 1.0, 1.0], 77 | [0.0, 0.0, 1.0, 1.0], 78 | ] 79 | ], 80 | dtype=torch.float32, 81 | ) 82 | scores = torch.tensor( 83 | [[[0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9]]], dtype=torch.float32 84 | ) 85 | max_output_boxes_per_class = torch.tensor([3], dtype=torch.int64) 86 | iou_threshold = torch.tensor([0.5], dtype=torch.float32) 87 | score_threshold = torch.tensor([0.0], dtype=torch.float32) 88 | exp_selected_indices = torch.tensor([[0, 0, 0]], dtype=torch.int64) 89 | selected_indices = op( 90 | boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold 91 | ) 92 | assert torch.equal(selected_indices, exp_selected_indices) 93 | 94 | 95 | def test_nonmaxsuppression_limit_output_size(): 96 | op = NonMaxSuppression() 97 | boxes = torch.tensor( 98 | [ 99 | [ 100 | [0.0, 0.0, 1.0, 1.0], 101 | [0.0, 0.1, 1.0, 1.1], 102 | [0.0, -0.1, 1.0, 0.9], 103 | [0.0, 10.0, 1.0, 11.0], 104 | [0.0, 10.1, 1.0, 11.1], 105 | [0.0, 100.0, 1.0, 101.0], 106 | ] 107 | ], 108 | dtype=torch.float32, 109 | ) 110 | scores = torch.tensor([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]], dtype=torch.float32) 111 | max_output_boxes_per_class = torch.tensor([2], dtype=torch.int64) 112 | iou_threshold = torch.tensor([0.5], dtype=torch.float32) 113 | score_threshold = torch.tensor([0.0], dtype=torch.float32) 114 | exp_selected_indices = torch.tensor([[0, 0, 3], [0, 0, 0]], dtype=torch.int64) 115 | selected_indices = op( 116 | boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold 117 | ) 118 | assert torch.equal(selected_indices, exp_selected_indices) 119 | 120 | 121 | def test_nonmaxsuppression_single_box(): 122 | op = NonMaxSuppression() 123 | boxes = torch.tensor([[[0.0, 0.0, 1.0, 1.0]]], dtype=torch.float32) 124 | scores = torch.tensor([[[0.9]]], dtype=torch.float32) 125 | max_output_boxes_per_class = torch.tensor([3], dtype=torch.int64) 126 | iou_threshold = torch.tensor([0.5], 
dtype=torch.float32) 127 | score_threshold = torch.tensor([0.0], dtype=torch.float32) 128 | exp_selected_indices = torch.tensor([[0, 0, 0]], dtype=torch.int64) 129 | selected_indices = op( 130 | boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold 131 | ) 132 | assert torch.equal(selected_indices, exp_selected_indices) 133 | 134 | 135 | def test_nonmaxsuppression_suppress_by_IOU(): 136 | op = NonMaxSuppression() 137 | boxes = torch.tensor( 138 | [ 139 | [ 140 | [0.0, 0.0, 1.0, 1.0], 141 | [0.0, 0.1, 1.0, 1.1], 142 | [0.0, -0.1, 1.0, 0.9], 143 | [0.0, 10.0, 1.0, 11.0], 144 | [0.0, 10.1, 1.0, 11.1], 145 | [0.0, 100.0, 1.0, 101.0], 146 | ] 147 | ], 148 | dtype=torch.float32, 149 | ) 150 | scores = torch.tensor([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]], dtype=torch.float32) 151 | max_output_boxes_per_class = torch.tensor([3], dtype=torch.int64) 152 | iou_threshold = torch.tensor([0.5], dtype=torch.float32) 153 | score_threshold = torch.tensor([0.0], dtype=torch.float32) 154 | exp_selected_indices = torch.tensor( 155 | [[0, 0, 3], [0, 0, 0], [0, 0, 5]], dtype=torch.int64 156 | ) 157 | selected_indices = op( 158 | boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold 159 | ) 160 | assert torch.equal(selected_indices, exp_selected_indices) 161 | 162 | 163 | def test_nonmaxsuppression_suppress_by_IOU_and_scores(): 164 | op = NonMaxSuppression() 165 | boxes = torch.tensor( 166 | [ 167 | [ 168 | [0.0, 0.0, 1.0, 1.0], 169 | [0.0, 0.1, 1.0, 1.1], 170 | [0.0, -0.1, 1.0, 0.9], 171 | [0.0, 10.0, 1.0, 11.0], 172 | [0.0, 10.1, 1.0, 11.1], 173 | [0.0, 100.0, 1.0, 101.0], 174 | ] 175 | ], 176 | dtype=torch.float32, 177 | ) 178 | scores = torch.tensor([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]], dtype=torch.float32) 179 | max_output_boxes_per_class = torch.tensor([3], dtype=torch.int64) 180 | iou_threshold = torch.tensor([0.5], dtype=torch.float32) 181 | score_threshold = torch.tensor([0.4], dtype=torch.float32) 182 | exp_selected_indices = torch.tensor([[0, 0, 3], [0, 0, 0]], dtype=torch.int64) 183 | selected_indices = op( 184 | boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold 185 | ) 186 | assert torch.equal(selected_indices, exp_selected_indices) 187 | 188 | 189 | def test_nonmaxsuppression_two_batches(): 190 | op = NonMaxSuppression() 191 | boxes = torch.tensor( 192 | [ 193 | [ 194 | [0.0, 0.0, 1.0, 1.0], 195 | [0.0, 0.1, 1.0, 1.1], 196 | [0.0, -0.1, 1.0, 0.9], 197 | [0.0, 10.0, 1.0, 11.0], 198 | [0.0, 10.1, 1.0, 11.1], 199 | [0.0, 100.0, 1.0, 101.0], 200 | ], 201 | [ 202 | [0.0, 0.0, 1.0, 1.0], 203 | [0.0, 0.1, 1.0, 1.1], 204 | [0.0, -0.1, 1.0, 0.9], 205 | [0.0, 10.0, 1.0, 11.0], 206 | [0.0, 10.1, 1.0, 11.1], 207 | [0.0, 100.0, 1.0, 101.0], 208 | ], 209 | ], 210 | dtype=torch.float32, 211 | ) 212 | scores = torch.tensor( 213 | [[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]], [[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]], 214 | dtype=torch.float32, 215 | ) 216 | max_output_boxes_per_class = torch.tensor([2], dtype=torch.int64) 217 | iou_threshold = torch.tensor([0.5], dtype=torch.float32) 218 | score_threshold = torch.tensor([0.0], dtype=torch.float32) 219 | exp_selected_indices = torch.tensor( 220 | [[0, 0, 3], [0, 0, 0], [1, 0, 3], [1, 0, 0]], dtype=torch.int64 221 | ) 222 | selected_indices = op( 223 | boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold 224 | ) 225 | assert torch.equal(selected_indices, exp_selected_indices) 226 | 227 | 228 | def test_nonmaxsuppression_two_classes(): 229 | op = NonMaxSuppression() 230 | boxes = torch.tensor( 231 | [ 
232 | [ 233 | [0.0, 0.0, 1.0, 1.0], 234 | [0.0, 0.1, 1.0, 1.1], 235 | [0.0, -0.1, 1.0, 0.9], 236 | [0.0, 10.0, 1.0, 11.0], 237 | [0.0, 10.1, 1.0, 11.1], 238 | [0.0, 100.0, 1.0, 101.0], 239 | ] 240 | ], 241 | dtype=torch.float32, 242 | ) 243 | scores = torch.tensor( 244 | [[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3], [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]], 245 | dtype=torch.float32, 246 | ) 247 | max_output_boxes_per_class = torch.tensor([2], dtype=torch.int64) 248 | iou_threshold = torch.tensor([0.5], dtype=torch.float32) 249 | score_threshold = torch.tensor([0.0], dtype=torch.float32) 250 | exp_selected_indices = torch.tensor( 251 | [[0, 0, 3], [0, 0, 0], [0, 1, 3], [0, 1, 0]], dtype=torch.int64 252 | ) 253 | selected_indices = op( 254 | boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold 255 | ) 256 | assert torch.equal(selected_indices, exp_selected_indices) 257 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_onehot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import numpy as np 4 | from onnx.backend.test.case.node.onehot import one_hot 5 | 6 | from onnx2pytorch.operations import OneHot 7 | 8 | 9 | @pytest.mark.parametrize("axis", [1, -2]) 10 | @pytest.mark.parametrize( 11 | "indices", 12 | [ 13 | torch.tensor([[1, 9], [2, 4]], dtype=torch.float32), 14 | torch.tensor([0, 7, 8], dtype=torch.int64), 15 | ], 16 | ) 17 | def test_onehot(indices, axis): 18 | on_value = 3 19 | off_value = 1 20 | output_type = torch.float32 21 | depth = torch.tensor([10], dtype=torch.float32) 22 | values = torch.tensor([off_value, on_value], dtype=output_type) 23 | y = one_hot(indices.numpy(), depth.numpy(), axis=axis, dtype=np.float32) 24 | y = y * (on_value - off_value) + off_value 25 | y = torch.from_numpy(y) 26 | 27 | op = OneHot(axis) 28 | out = op(indices, depth, values) 29 | assert torch.equal(y, out) 30 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_pad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations import Pad 5 | 6 | 7 | @pytest.fixture 8 | def inp(): 9 | return torch.rand(1, 3, 10, 10) 10 | 11 | 12 | @pytest.mark.parametrize("init", [True, False]) 13 | @pytest.mark.parametrize( 14 | "pads, new_shape", 15 | [ 16 | ([1, 1], [1, 3, 10, 12]), 17 | ([1, 1, 2, 2], [1, 3, 14, 12]), 18 | ([1, 1, 2, 2, 3, 3, 4, 4], [9, 9, 14, 12]), 19 | ], 20 | ) 21 | def test_pad(inp, pads, new_shape, init): 22 | """Pass padding in initialization and in forward pass.""" 23 | if init: 24 | op = Pad(padding=pads) 25 | out = op(inp) 26 | else: 27 | op = Pad() 28 | out = op(inp, pads) 29 | assert list(out.shape) == new_shape 30 | 31 | 32 | def test_pad_raise_error(inp): 33 | op = Pad() 34 | 35 | # padding should be passed either in init or forward 36 | with pytest.raises(TypeError): 37 | op(inp) 38 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_prelu.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from onnx2pytorch.operations.prelu import PRelu 5 | 6 | 7 | def test_prelu(): 8 | op = PRelu() 9 | x = torch.randn(3, 4, 5, dtype=torch.float32) 10 | slope = torch.randn(3, 4, 5, dtype=torch.float32) 11 | exp_y = torch.maximum(torch.zeros_like(x), x) + 
slope * torch.minimum( 12 | torch.zeros_like(x), x 13 | ) 14 | assert torch.equal(op(x, slope), exp_y) 15 | 16 | 17 | def test_prelu_broadcast(): 18 | op = PRelu() 19 | x = torch.randn(3, 4, 5, dtype=torch.float32) 20 | slope = torch.randn(5, dtype=torch.float32) 21 | exp_y = torch.clamp(x, min=0) + torch.clamp(x, max=0) * slope 22 | assert torch.equal(op(x, slope), exp_y) 23 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_range.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations.range import Range 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "start, limit, delta, expected", 9 | [ 10 | (3, 9, 3, torch.tensor([3, 6])), 11 | (10, 4, -2, torch.tensor([10, 8, 6])), 12 | (10, 6, -3, torch.tensor([10, 7])), 13 | ], 14 | ) 15 | def test_range(start, limit, delta, expected): 16 | op = Range() 17 | assert torch.equal(op(start, limit, delta), expected) 18 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_reducel2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | import pytest 4 | import torch 5 | 6 | from onnx2pytorch.convert.operations import convert_operations 7 | from onnx2pytorch.operations import ReduceL2 8 | 9 | 10 | @pytest.fixture 11 | def tensor(): 12 | return torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 13 | 14 | 15 | def test_reduce_l2_older_opset_version(tensor): 16 | shape = [3, 2, 2] 17 | axes = np.array([2], dtype=np.int64) 18 | keepdims = 0 19 | 20 | data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) 21 | op = ReduceL2(opset_version=10, keepdim=keepdims, dim=axes) 22 | 23 | reduced = np.sqrt( 24 | np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) 25 | ) 26 | 27 | out = op(torch.from_numpy(data), axes=axes) 28 | np.testing.assert_array_equal(out, reduced) 29 | 30 | 31 | def test_do_not_keepdims_older_opset_version() -> None: 32 | opset_version = 10 33 | shape = [3, 2, 2] 34 | axes = np.array([2], dtype=np.int64) 35 | keepdims = 0 36 | 37 | node = onnx.helper.make_node( 38 | "ReduceL2", 39 | inputs=["data"], 40 | outputs=["reduced"], 41 | keepdims=keepdims, 42 | axes=axes, 43 | ) 44 | graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], []) 45 | 46 | ops = list(convert_operations(graph, opset_version)) 47 | op = ops[0][2] 48 | 49 | assert isinstance(op, ReduceL2) 50 | 51 | data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) 52 | # print(data) 53 | # [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]] 54 | 55 | reduced = np.sqrt( 56 | np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) 57 | ) 58 | # print(reduced) 59 | # [[2.23606798, 5.], 60 | # [7.81024968, 10.63014581], 61 | # [13.45362405, 16.2788206]] 62 | 63 | out = op(torch.from_numpy(data)) 64 | np.testing.assert_array_equal(out, reduced) 65 | 66 | np.random.seed(0) 67 | data = np.random.uniform(-10, 10, shape).astype(np.float32) 68 | reduced = np.sqrt( 69 | np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) 70 | ) 71 | 72 | out = op(torch.from_numpy(data)) 73 | np.testing.assert_array_equal(out, reduced) 74 | 75 | 76 | def test_do_not_keepdims() -> None: 77 | shape = [3, 2, 2] 78 | axes = np.array([2], dtype=np.int64) 79 | keepdims = 0 80 | 81 | node = 
onnx.helper.make_node( 82 | "ReduceL2", 83 | inputs=["data", "axes"], 84 | outputs=["reduced"], 85 | keepdims=keepdims, 86 | ) 87 | graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], []) 88 | ops = list(convert_operations(graph, 18)) 89 | op = ops[0][2] 90 | 91 | assert isinstance(op, ReduceL2) 92 | 93 | data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) 94 | # print(data) 95 | # [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]] 96 | 97 | reduced = np.sqrt( 98 | np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) 99 | ) 100 | # print(reduced) 101 | # [[2.23606798, 5.], 102 | # [7.81024968, 10.63014581], 103 | # [13.45362405, 16.2788206]] 104 | 105 | out = op(torch.from_numpy(data), axes=axes) 106 | np.testing.assert_array_equal(out, reduced) 107 | 108 | np.random.seed(0) 109 | data = np.random.uniform(-10, 10, shape).astype(np.float32) 110 | reduced = np.sqrt( 111 | np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) 112 | ) 113 | 114 | out = op(torch.from_numpy(data), axes=axes) 115 | np.testing.assert_array_equal(out, reduced) 116 | 117 | 118 | def test_export_keepdims() -> None: 119 | shape = [3, 2, 2] 120 | axes = np.array([2], dtype=np.int64) 121 | keepdims = 1 122 | 123 | node = onnx.helper.make_node( 124 | "ReduceL2", 125 | inputs=["data", "axes"], 126 | outputs=["reduced"], 127 | keepdims=keepdims, 128 | ) 129 | graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], []) 130 | ops = list(convert_operations(graph, 18)) 131 | op = ops[0][2] 132 | 133 | data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) 134 | # print(data) 135 | # [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]] 136 | 137 | reduced = np.sqrt( 138 | np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) 139 | ) 140 | # print(reduced) 141 | # [[[2.23606798], [5.]] 142 | # [[7.81024968], [10.63014581]] 143 | # [[13.45362405], [16.2788206 ]]] 144 | 145 | out = op(torch.from_numpy(data), axes=axes) 146 | np.testing.assert_array_equal(out, reduced) 147 | 148 | np.random.seed(0) 149 | data = np.random.uniform(-10, 10, shape).astype(np.float32) 150 | reduced = np.sqrt( 151 | np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) 152 | ) 153 | 154 | out = op(torch.from_numpy(data), axes=axes) 155 | np.testing.assert_array_equal(out, reduced) 156 | 157 | 158 | def test_export_default_axes_keepdims() -> None: 159 | shape = [3, 2, 2] 160 | axes = np.array([], dtype=np.int64) 161 | keepdims = 1 162 | 163 | node = onnx.helper.make_node( 164 | "ReduceL2", inputs=["data", "axes"], outputs=["reduced"], keepdims=keepdims 165 | ) 166 | graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], []) 167 | ops = list(convert_operations(graph, 18)) 168 | op = ops[0][2] 169 | 170 | data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) 171 | # print(data) 172 | # [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]] 173 | 174 | reduced = np.sqrt(np.sum(a=np.square(data), axis=None, keepdims=keepdims == 1)) 175 | # print(reduced) 176 | # [[[25.49509757]]] 177 | 178 | out = op(torch.from_numpy(data), axes=axes) 179 | np.testing.assert_array_equal(out, reduced) 180 | 181 | np.random.seed(0) 182 | data = np.random.uniform(-10, 10, shape).astype(np.float32) 183 | reduced = np.sqrt(np.sum(a=np.square(data), axis=None, keepdims=keepdims == 1)) 184 | 185 | out = op(torch.from_numpy(data), axes=axes) 
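# With an empty `axes` input and `noop_with_empty_axes` left at its default
# of 0, ONNX ReduceL2 reduces over every axis, which is why the NumPy
# reference above uses axis=None; the result is the full L2 norm
# sqrt(sum(data ** 2)), kept at shape (1, 1, 1) because keepdims=1.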
186 | np.testing.assert_array_equal(out, reduced) 187 | 188 | 189 | def test_export_negative_axes_keepdims() -> None: 190 | shape = [3, 2, 2] 191 | axes = np.array([-1], dtype=np.int64) 192 | keepdims = 1 193 | 194 | node = onnx.helper.make_node( 195 | "ReduceL2", 196 | inputs=["data", "axes"], 197 | outputs=["reduced"], 198 | keepdims=keepdims, 199 | ) 200 | graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], []) 201 | ops = list(convert_operations(graph, 18)) 202 | op = ops[0][2] 203 | 204 | data = np.reshape(np.arange(1, np.prod(shape) + 1, dtype=np.float32), shape) 205 | # print(data) 206 | # [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]], [[9., 10.], [11., 12.]]] 207 | 208 | reduced = np.sqrt( 209 | np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) 210 | ) 211 | # print(reduced) 212 | # [[[2.23606798], [5.]] 213 | # [[7.81024968], [10.63014581]] 214 | # [[13.45362405], [16.2788206 ]]] 215 | 216 | out = op(torch.from_numpy(data), axes=axes) 217 | np.testing.assert_array_equal(out, reduced) 218 | 219 | np.random.seed(0) 220 | data = np.random.uniform(-10, 10, shape).astype(np.float32) 221 | reduced = np.sqrt( 222 | np.sum(a=np.square(data), axis=tuple(axes), keepdims=keepdims == 1) 223 | ) 224 | 225 | out = op(torch.from_numpy(data), axes=axes) 226 | np.testing.assert_array_equal(out, reduced) 227 | 228 | 229 | def test_export_empty_set() -> None: 230 | shape = [2, 0, 4] 231 | keepdims = 1 232 | reduced_shape = [2, 1, 4] 233 | 234 | node = onnx.helper.make_node( 235 | "ReduceL2", 236 | inputs=["data", "axes"], 237 | outputs=["reduced"], 238 | keepdims=keepdims, 239 | ) 240 | graph = onnx.helper.make_graph([node], "test_reduce_l2_do_not_keepdims", [], []) 241 | ops = list(convert_operations(graph, 18)) 242 | op = ops[0][2] 243 | 244 | data = np.array([], dtype=np.float32).reshape(shape) 245 | axes = np.array([1], dtype=np.int64) 246 | reduced = np.array(np.zeros(reduced_shape, dtype=np.float32)) 247 | 248 | out = op(torch.from_numpy(data), axes=axes) 249 | np.testing.assert_array_equal(out, reduced) 250 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_reducemax.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from onnx2pytorch.operations import ReduceMax 5 | 6 | 7 | @pytest.fixture 8 | def tensor(): 9 | return torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 10 | 11 | 12 | def test_reduce_max_with_dim(tensor): 13 | reduce_max = ReduceMax(dim=0, keepdim=True) 14 | output = reduce_max(tensor) 15 | expected_output = torch.tensor([[7, 8, 9]]) 16 | 17 | assert output.ndim == tensor.ndim 18 | assert torch.equal(output, expected_output) 19 | 20 | 21 | def test_reduce_max(tensor): 22 | reduce_max = ReduceMax(keepdim=False) 23 | output = reduce_max(tensor) 24 | expected_output = torch.tensor(9) 25 | 26 | assert output.ndim == 0 27 | assert torch.equal(output, expected_output) 28 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_reducesum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations.reducesum import ReduceSum 5 | 6 | 7 | @pytest.fixture 8 | def inp(): 9 | return torch.tensor( 10 | [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]], 11 | dtype=torch.float32, 12 | ) 13 | 14 | 15 | def test_reducesum_default_axes_keepdims(inp): 16 
| op = ReduceSum(opset_version=13) 17 | exp = torch.tensor([[[78]]], dtype=torch.float32) 18 | assert torch.equal(op(inp), exp) 19 | 20 | 21 | def test_reducesum_do_not_keepdims(inp): 22 | op = ReduceSum(opset_version=13, keepdim=False) 23 | axes = torch.tensor([1]) 24 | exp = torch.tensor( 25 | [[4, 6], [12, 14], [20, 22]], 26 | dtype=torch.float32, 27 | ) 28 | assert torch.equal(op(inp, axes), exp) 29 | 30 | 31 | @pytest.mark.parametrize("keepdim", [True, False]) 32 | @pytest.mark.parametrize("axes", [None, torch.tensor([])]) 33 | def test_reducesum_empty_axes_input_noop(inp, keepdim, axes): 34 | op = ReduceSum(opset_version=13, keepdim=keepdim, noop_with_empty_axes=True) 35 | exp = inp 36 | assert torch.equal(op(inp), exp) 37 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_reshape.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations import Reshape 5 | 6 | 7 | @pytest.fixture 8 | def inp(): 9 | return torch.rand(35, 1, 200) 10 | 11 | 12 | @pytest.fixture 13 | def pruned_inp(): 14 | return torch.rand(35, 1, 160) 15 | 16 | 17 | @pytest.mark.parametrize("enable_pruning", [True, False]) 18 | def test_reshape(inp, pruned_inp, enable_pruning): 19 | """Pass shape in forward.""" 20 | op = Reshape(enable_pruning=True) 21 | shape = torch.Size((35, 2, 100)) 22 | out = op(inp, shape) 23 | assert out.shape == shape 24 | 25 | # with the same input, the output shape should not change 26 | out = op(inp, shape) 27 | assert out.shape == shape 28 | 29 | # if input changes due to pruning, reshape should work 30 | # and output shape should change accordingly 31 | expected_shape = torch.Size((35, 2, 80)) 32 | out = op(pruned_inp, shape) 33 | assert out.shape == expected_shape 34 | 35 | 36 | @pytest.mark.parametrize("enable_pruning", [True, False]) 37 | def test_reshape_2(inp, pruned_inp, enable_pruning): 38 | """Pass shape in init.""" 39 | shape = torch.Size((35, 2, 100)) 40 | op = Reshape(enable_pruning=True, shape=shape) 41 | out = op(inp) 42 | assert out.shape == shape 43 | 44 | # input changes due to pruning, reshape should work 45 | expected_shape = torch.Size((35, 2, 80)) 46 | out = op(pruned_inp) 47 | assert out.shape == expected_shape 48 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_resize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations import Resize, Upsample 5 | 6 | 7 | @pytest.fixture 8 | def inp(): 9 | return torch.rand(1, 3, 10, 10) 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "scales, new_shape", 14 | [ 15 | ([1, 1, 2, 2], [1, 3, 20, 20]), 16 | ([1, 1, 0.5, 0.5], [1, 3, 5, 5]), 17 | ], 18 | ) 19 | def test_resize_scales(inp, scales, new_shape): 20 | op = Resize() 21 | out = op(inp, scales=scales) 22 | assert list(out.shape) == new_shape 23 | 24 | 25 | @pytest.mark.parametrize("sizes", [[1, 3, 20, 20], [1, 3, 5, 5]]) 26 | def test_resize_sizes(inp, sizes): 27 | op = Resize() 28 | out = op(inp, sizes=sizes) 29 | assert list(out.shape) == sizes 30 | 31 | 32 | def test_resize_raise_error(inp): 33 | op = Resize() 34 | 35 | # cannot scale batch and channel dimension 36 | with pytest.raises(NotImplementedError): 37 | op(inp, scales=[2, 2, 1, 1]) 38 | with pytest.raises(NotImplementedError): 39 | op(inp, sizes=[2, 6, 10, 10]) 40 | 41 | # need to define 
scales or sizes 42 | with pytest.raises(ValueError): 43 | op(inp) 44 | 45 | # need to define only scales or sizes 46 | with pytest.raises(ValueError): 47 | op(inp, scales=[1, 1, 2, 2], sizes=[1, 3, 20, 20]) 48 | 49 | 50 | @pytest.mark.parametrize( 51 | "scales, new_shape", 52 | [ 53 | ([1, 1, 2, 2], [1, 3, 20, 20]), 54 | ([1, 1, 0.5, 0.5], [1, 3, 5, 5]), 55 | ], 56 | ) 57 | def test_upsample(inp, scales, new_shape): 58 | op = Upsample() 59 | out = op(inp, scales) 60 | assert list(out.shape) == new_shape 61 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_scatter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations.scatter import Scatter 5 | 6 | 7 | def test_scatter_with_axis(): 8 | op = Scatter(dim=1) 9 | data = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=torch.float32) 10 | indices = torch.tensor([[1, 3]], dtype=torch.int64) 11 | updates = torch.tensor([[1.1, 2.1]], dtype=torch.float32) 12 | exp_output = torch.tensor([[1.0, 1.1, 3.0, 2.1, 5.0]], dtype=torch.float32) 13 | assert torch.equal(op(data, indices, updates), exp_output) 14 | 15 | 16 | def test_scatter_without_axis(): 17 | op = Scatter() 18 | data = torch.zeros((3, 3), dtype=torch.float32) 19 | indices = torch.tensor([[1, 0, 2], [0, 2, 1]], dtype=torch.int64) 20 | updates = torch.tensor([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=torch.float32) 21 | exp_output = torch.tensor( 22 | [[2.0, 1.1, 0.0], [1.0, 0.0, 2.2], [0.0, 2.1, 1.2]], dtype=torch.float32 23 | ) 24 | assert torch.equal(op(data, indices, updates), exp_output) 25 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_scatterelements.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations.scatterelements import ScatterElements 5 | 6 | 7 | def test_scatter_elements_with_axis(): 8 | op = ScatterElements(dim=1) 9 | data = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=torch.float32) 10 | indices = torch.tensor([[1, 3]], dtype=torch.int64) 11 | updates = torch.tensor([[1.1, 2.1]], dtype=torch.float32) 12 | exp_output = torch.tensor([[1.0, 1.1, 3.0, 2.1, 5.0]], dtype=torch.float32) 13 | output = op(data, indices, updates) 14 | assert torch.equal(output, exp_output) 15 | 16 | 17 | def test_scatter_elements_with_negative_indices(): 18 | op = ScatterElements(dim=1) 19 | data = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=torch.float32) 20 | indices = torch.tensor([[1, -3]], dtype=torch.int64) 21 | updates = torch.tensor([[1.1, 2.1]], dtype=torch.float32) 22 | exp_output = torch.tensor([[1.0, 1.1, 2.1, 4.0, 5.0]], dtype=torch.float32) 23 | output = op(data, indices, updates) 24 | assert torch.equal(output, exp_output) 25 | 26 | 27 | def test_scatter_elements_without_axis(): 28 | op = ScatterElements() 29 | data = torch.zeros((3, 3), dtype=torch.float32) 30 | indices = torch.tensor([[1, 0, 2], [0, 2, 1]], dtype=torch.int64) 31 | updates = torch.tensor([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=torch.float32) 32 | exp_output = torch.tensor( 33 | [[2.0, 1.1, 0.0], [1.0, 0.0, 2.2], [0.0, 2.1, 1.2]], dtype=torch.float32 34 | ) 35 | output = op(data, indices, updates) 36 | assert torch.equal(output, exp_output) 37 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_scatternd.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations.scatternd import ScatterND 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "data, indices, updates, exp_output", 9 | [ 10 | ( 11 | torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), 12 | torch.tensor([[4], [3], [1], [7]]), 13 | torch.tensor([9, 10, 11, 12]), 14 | torch.tensor([1, 11, 3, 10, 9, 6, 7, 12]), 15 | ), 16 | ( 17 | torch.zeros((4, 4, 4), dtype=torch.int64), 18 | torch.tensor([[0, 1], [2, 3]]), 19 | torch.tensor([[5, 5, 5, 5], [6, 6, 6, 6]]), 20 | torch.tensor( 21 | [ 22 | [[0, 0, 0, 0], [5, 5, 5, 5], [0, 0, 0, 0], [0, 0, 0, 0]], 23 | [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 24 | [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [6, 6, 6, 6]], 25 | [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 26 | ] 27 | ), 28 | ), 29 | ( 30 | torch.tensor( 31 | [ 32 | [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], 33 | [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], 34 | [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]], 35 | [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]], 36 | ] 37 | ), 38 | torch.tensor([[0], [2]]), 39 | torch.tensor( 40 | [ 41 | [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], 42 | [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]], 43 | ] 44 | ), 45 | torch.tensor( 46 | [ 47 | [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], 48 | [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], 49 | [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]], 50 | [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]], 51 | ] 52 | ), 53 | ), 54 | ], 55 | ) 56 | def test_scatternd(data, indices, updates, exp_output): 57 | op = ScatterND() 58 | assert torch.equal(op(data, indices, updates), exp_output) 59 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_slice.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | 5 | from onnx2pytorch.operations import Slice 6 | from onnx2pytorch.operations.slice import _to_positive_step 7 | 8 | 9 | @pytest.fixture 10 | def x(): 11 | return torch.randn(20, 10, 5).to(torch.float32) 12 | 13 | 14 | @pytest.mark.parametrize("init", [True, False]) 15 | def test_slice_1(x, init): 16 | starts = torch.tensor([0, 0], dtype=torch.int64) 17 | ends = torch.tensor([3, 10], dtype=torch.int64) 18 | axes = torch.tensor([0, 1], dtype=torch.int64) 19 | steps = torch.tensor([1, 1], dtype=torch.int64) 20 | y = x[0:3, 0:10] 21 | 22 | if init: 23 | op = Slice(axes, starts, ends, steps) 24 | assert torch.equal(op(x), y) 25 | else: 26 | op = Slice() 27 | assert torch.equal(op(x, starts, ends, axes, steps), y) 28 | 29 | 30 | @pytest.mark.parametrize("init", [True, False]) 31 | def test_slice_2(x, init): 32 | starts = torch.tensor([1], dtype=torch.int64) 33 | ends = torch.tensor([1000], dtype=torch.int64) 34 | axes = torch.tensor([1], dtype=torch.int64) 35 | steps = torch.tensor([2], dtype=torch.int64) 36 | y = x[:, 1:1000:2] 37 | 38 | if init: 39 | op = Slice(axes, starts, ends, steps) 40 | assert torch.equal(op(x), y) 41 | else: 42 | op = Slice() 43 | assert torch.equal(op(x, starts, ends, axes, steps), y) 44 | 45 | 46 | @pytest.mark.parametrize("init", [True, False]) 47 | def test_slice_neg_axes(x, init): 48 | starts = torch.tensor([1], dtype=torch.int64) 49 | ends = torch.tensor([4], 
dtype=torch.int64) 50 | axes = torch.tensor([-1], dtype=torch.int64) 51 | steps = torch.tensor([2], dtype=torch.int64) 52 | y = x[:, :, 1:4:2] 53 | 54 | if init: 55 | op = Slice(axes, starts, ends, steps) 56 | assert torch.equal(op(x), y) 57 | else: 58 | op = Slice() 59 | assert torch.equal(op(x, starts, ends, axes, steps), y) 60 | 61 | 62 | @pytest.mark.parametrize("init", [True, False]) 63 | def test_slice_neg_axes_2(x, init): 64 | print(x.shape) 65 | starts = torch.tensor([1], dtype=torch.int64) 66 | ends = torch.tensor([4], dtype=torch.int64) 67 | axes = torch.tensor([-2], dtype=torch.int64) 68 | steps = torch.tensor([2], dtype=torch.int64) 69 | y = x[:, 1:4:2] 70 | 71 | if init: 72 | op = Slice(axes, starts, ends, steps) 73 | assert torch.equal(op(x), y) 74 | else: 75 | op = Slice() 76 | assert torch.equal(op(x, starts, ends, axes, steps), y) 77 | 78 | 79 | @pytest.mark.parametrize("init", [True, False]) 80 | def test_slice_default_axes(x, init): 81 | starts = torch.tensor([1, 2], dtype=torch.int64) 82 | ends = torch.tensor([9, 5], dtype=torch.int64) 83 | steps = torch.tensor([1, 2], dtype=torch.int64) 84 | y = x[1:9, 2:5:2] 85 | 86 | if init: 87 | op = Slice(starts=starts, ends=ends, steps=steps) 88 | assert torch.equal(op(x), y) 89 | else: 90 | op = Slice() 91 | assert torch.equal(op(x, starts, ends, steps=steps), y) 92 | 93 | 94 | @pytest.mark.parametrize("init", [True, False]) 95 | def test_slice_default_steps(x, init): 96 | starts = torch.tensor([1], dtype=torch.int64) 97 | ends = torch.tensor([9], dtype=torch.int64) 98 | axes = torch.tensor([1], dtype=torch.int64) 99 | y = x[:, 1:9] 100 | 101 | if init: 102 | op = Slice(axes, starts, ends) 103 | assert torch.equal(op(x), y) 104 | else: 105 | op = Slice() 106 | assert torch.equal(op(x, starts, ends, axes), y) 107 | 108 | 109 | @pytest.mark.parametrize("init", [True, False]) 110 | def test_slice_neg_steps(x, init): 111 | starts = torch.tensor([20, 10, 4], dtype=torch.int64) 112 | ends = torch.tensor([0, 0, 1], dtype=torch.int64) 113 | axes = torch.tensor([0, 1, 2], dtype=torch.int64) 114 | steps = torch.tensor([-1, -3, -2], dtype=torch.int64) 115 | y = torch.tensor(np.copy(x.numpy()[20:0:-1, 10:0:-3, 4:1:-2])) 116 | 117 | if init: 118 | op = Slice(axes, starts=starts, ends=ends, steps=steps) 119 | print(op, flush=True) 120 | assert torch.equal(op(x), y) 121 | else: 122 | op = Slice() 123 | assert torch.equal(op(x, starts, ends, axes, steps), y) 124 | 125 | 126 | def test_to_positive_step(): 127 | assert _to_positive_step(slice(-1, None, -1), 8) == slice(0, 8, 1) 128 | assert _to_positive_step(slice(-2, None, -1), 8) == slice(0, 7, 1) 129 | assert _to_positive_step(slice(None, -1, -1), 8) == slice(0, 0, 1) 130 | assert _to_positive_step(slice(None, -2, -1), 8) == slice(7, 8, 1) 131 | assert _to_positive_step(slice(None, None, -1), 8) == slice(0, 8, 1) 132 | assert _to_positive_step(slice(8, 1, -2), 8) == slice(3, 8, 2) 133 | assert _to_positive_step(slice(8, 0, -2), 8) == slice(1, 8, 2) 134 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_split.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations import Split 5 | 6 | 7 | @pytest.fixture 8 | def weight(): 9 | a = torch.rand(15) 10 | a[[4, 7, 12]] = 0 11 | return a 12 | 13 | 14 | @pytest.mark.parametrize("enable_pruning", [True, False]) 15 | @pytest.mark.parametrize( 16 | "split_size_or_sections, 
number_of_splits", [((5, 5, 5), None), (None, 3)] 17 | ) 18 | def test_split(weight, enable_pruning, split_size_or_sections, number_of_splits): 19 | """keep_size=False""" 20 | op = Split( 21 | enable_pruning, split_size_or_sections, number_of_splits, keep_size=False 22 | ) 23 | s = op(weight) 24 | assert all(len(x) == 5 for x in s) 25 | 26 | op.set_input_indices((weight,)) 27 | s = op(torch.rand(12)) 28 | assert all(len(x) == 4 for x in s) 29 | 30 | 31 | @pytest.mark.parametrize("enable_pruning", [True, False]) 32 | @pytest.mark.parametrize( 33 | "split_size_or_sections, number_of_splits", [((5, 5, 5), None), (None, 3)] 34 | ) 35 | def test_split_2(weight, enable_pruning, split_size_or_sections, number_of_splits): 36 | """keep_size=True""" 37 | op = Split(enable_pruning, split_size_or_sections, number_of_splits, keep_size=True) 38 | s = op(weight) 39 | assert all(len(x) == 5 for x in s) 40 | 41 | op.set_input_indices((weight,)) 42 | s = op(torch.rand(12)) 43 | assert all(len(x) == 5 for x in s) 44 | 45 | # keep_size=True expands the input with zeros 46 | location_of_zeros_in_splits = [4, 2, 2] 47 | for x, i in zip(s, location_of_zeros_in_splits): 48 | (idx,) = torch.where(x == 0) 49 | assert idx == torch.tensor([i]) 50 | 51 | 52 | def test_split_parameter_check(weight): 53 | with pytest.raises(AssertionError): 54 | Split(enable_pruning=True, split_size_or_sections=None, number_of_splits=None) 55 | 56 | 57 | @pytest.mark.parametrize("split_size_or_sections", [(5, 5, 5)]) 58 | def test_split_no_enable_pruning(weight, split_size_or_sections): 59 | op = Split(enable_pruning=False, keep_size=False) 60 | s = op(weight, split_size_or_sections) 61 | assert all(len(x) == 5 for x in s) 62 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_squeeze.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations.squeeze import Squeeze 5 | 6 | 7 | @pytest.fixture 8 | def inp(): 9 | return torch.ones(1, 2, 1, 2) 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "dim, exp_shape", 14 | [ 15 | (None, (2, 2)), 16 | (0, (2, 1, 2)), 17 | (2, (1, 2, 2)), 18 | (-2, (1, 2, 2)), 19 | (torch.tensor([0, 2]), (2, 2)), 20 | ], 21 | ) 22 | def test_squeeze_v11(inp, dim, exp_shape): 23 | op = Squeeze(opset_version=11, dim=dim) 24 | assert tuple(op(inp).shape) == exp_shape 25 | 26 | 27 | @pytest.mark.parametrize( 28 | "dim, exp_shape", 29 | [ 30 | (None, (2, 2)), 31 | (0, (2, 1, 2)), 32 | (2, (1, 2, 2)), 33 | (-2, (1, 2, 2)), 34 | (torch.tensor([0, 2]), (2, 2)), 35 | ], 36 | ) 37 | def test_squeeze_v13(inp, dim, exp_shape): 38 | op = Squeeze(opset_version=13) 39 | assert tuple(op(inp, dim).shape) == exp_shape 40 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_thresholdedrelu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | 5 | from onnx2pytorch.operations.thresholdedrelu import ThresholdedRelu 6 | 7 | 8 | def test_thresholdedrelu(): 9 | x = torch.tensor([-1.5, 0.0, 1.2, 2.0, 2.2]) 10 | 11 | op = ThresholdedRelu() 12 | exp_y = torch.tensor([0.0, 0.0, 1.2, 2.0, 2.2]) 13 | assert torch.equal(op(x), exp_y) 14 | 15 | op = ThresholdedRelu(alpha=2.0) 16 | exp_y = torch.tensor([0.0, 0.0, 0.0, 0.0, 2.2]) 17 | assert torch.equal(op(x), exp_y) 18 | 
-------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_tile.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | 5 | from onnx2pytorch.operations.tile import Tile 6 | 7 | 8 | def test_tile(): 9 | op = Tile() 10 | x = torch.rand(2, 3, 4, 5) 11 | repeats = torch.randint(low=1, high=10, size=(x.ndim,)) 12 | z = torch.from_numpy(np.tile(x.numpy(), repeats.numpy())) 13 | assert torch.equal(op(x, repeats), z) 14 | 15 | 16 | def test_tile_precomputed(): 17 | op = Tile() 18 | x = torch.tensor( 19 | [ 20 | [0, 1], 21 | [2, 3], 22 | ], 23 | dtype=torch.float32, 24 | ) 25 | repeats = torch.tensor([2, 2], dtype=torch.int64) 26 | 27 | z = torch.tensor( 28 | [[0, 1, 0, 1], [2, 3, 2, 3], [0, 1, 0, 1], [2, 3, 2, 3]], dtype=torch.float32 29 | ) 30 | 31 | assert torch.equal(op(x, repeats), z) 32 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_topk.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from onnx2pytorch.operations.topk import TopK 5 | 6 | 7 | def test_topk(): 8 | axis = 1 9 | largest = 1 10 | op = TopK(axis=axis, largest=largest) 11 | 12 | X = torch.tensor( 13 | [ 14 | [0, 1, 2, 3], 15 | [4, 5, 6, 7], 16 | [8, 9, 10, 11], 17 | ], 18 | dtype=torch.float32, 19 | ) 20 | k = 3 21 | K = torch.tensor([k], dtype=torch.int64) 22 | values_exp = torch.tensor( 23 | [ 24 | [3, 2, 1], 25 | [7, 6, 5], 26 | [11, 10, 9], 27 | ], 28 | dtype=torch.float32, 29 | ) 30 | indices_exp = torch.tensor( 31 | [ 32 | [3, 2, 1], 33 | [3, 2, 1], 34 | [3, 2, 1], 35 | ] 36 | ) 37 | values, indices = op(X, K) 38 | assert torch.equal(values_exp, values) 39 | assert torch.equal(indices_exp, indices) 40 | 41 | 42 | def test_topk_negative_axis(): 43 | op = TopK() 44 | 45 | X = torch.tensor( 46 | [ 47 | [0, 1, 2, 3], 48 | [4, 5, 6, 7], 49 | [8, 9, 10, 11], 50 | ], 51 | dtype=torch.float32, 52 | ) 53 | k = 3 54 | K = torch.tensor([k], dtype=torch.int64) 55 | values_exp = torch.tensor( 56 | [ 57 | [3, 2, 1], 58 | [7, 6, 5], 59 | [11, 10, 9], 60 | ], 61 | dtype=torch.float32, 62 | ) 63 | indices_exp = torch.tensor( 64 | [ 65 | [3, 2, 1], 66 | [3, 2, 1], 67 | [3, 2, 1], 68 | ] 69 | ) 70 | values, indices = op(X, K) 71 | assert torch.equal(values_exp, values) 72 | assert torch.equal(indices_exp, indices) 73 | 74 | 75 | def test_topk_smallest(): 76 | axis = 1 77 | largest = 0 78 | op = TopK(axis=axis, largest=largest) 79 | 80 | X = torch.tensor( 81 | [ 82 | [0, 1, 2, 3], 83 | [4, 5, 6, 7], 84 | [11, 10, 9, 8], 85 | ], 86 | dtype=torch.float32, 87 | ) 88 | k = 3 89 | K = torch.tensor([k], dtype=torch.int64) 90 | values_exp = torch.tensor( 91 | [ 92 | [0, 1, 2], 93 | [4, 5, 6], 94 | [8, 9, 10], 95 | ], 96 | dtype=torch.float32, 97 | ) 98 | indices_exp = torch.tensor( 99 | [ 100 | [0, 1, 2], 101 | [0, 1, 2], 102 | [3, 2, 1], 103 | ] 104 | ) 105 | values, indices = op(X, K) 106 | assert torch.equal(values_exp, values) 107 | assert torch.equal(indices_exp, indices) 108 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/operations/test_unsqueeze.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from onnx2pytorch.operations.unsqueeze import Unsqueeze 6 | 7 | 8 | def test_unsqueeze_negative_axes(): 
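    """Negative axes count back from the end of the output shape; from opset
    13 onward, Unsqueeze receives its axes as a tensor input rather than as
    a node attribute, which is why the test passes them to the call."""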
9 |     op = Unsqueeze(opset_version=13)
10 |     x = torch.randn(1, 3, 1, 5)
11 |     axes = torch.tensor([-2], dtype=torch.int64)
12 |     y = torch.from_numpy(np.expand_dims(x.detach().numpy(), axis=-2))
13 |     assert torch.equal(op(x, axes), y)
14 | 
15 | 
16 | def test_unsqueeze_unsorted_axes():
17 |     op = Unsqueeze(opset_version=13)
18 |     x = torch.randn(3, 4, 5)
19 |     axes = torch.tensor([5, 4, 2], dtype=torch.int64)
20 |     x_np = x.detach().numpy()
21 |     y_np = np.expand_dims(x_np, axis=2)
22 |     y_np = np.expand_dims(y_np, axis=4)
23 |     y_np = np.expand_dims(y_np, axis=5)
24 |     y = torch.from_numpy(y_np)
25 |     assert torch.equal(op(x, axes), y)
26 | 
--------------------------------------------------------------------------------
/tests/onnx2pytorch/operations/test_where.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 | 
4 | from onnx2pytorch.operations.where import Where
5 | 
6 | 
7 | @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
8 | def test_where(dtype):
9 |     op = Where()
10 |     condition = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
11 |     x = torch.tensor([[1, 2], [3, 4]])
12 |     y = torch.tensor([[9, 8], [7, 6]])
13 |     z = torch.tensor([[1, 8], [3, 4]])
14 |     assert torch.equal(op(condition, x.to(dtype), y.to(dtype)), z.to(dtype))
15 | 
--------------------------------------------------------------------------------
/tests/onnx2pytorch/test_convert.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import onnx
4 | import numpy as np
5 | from onnx import numpy_helper
6 | from onnx.backend.test.case.node.gemm import gemm_reference_implementation
7 | from torch import nn
8 | 
9 | from onnx2pytorch.convert import convert_linear_layer
10 | from onnx2pytorch.helpers import to_converted
11 | 
12 | 
13 | @pytest.fixture
14 | def embedding():
15 |     # inp (35, 1)
16 |     return nn.Embedding(28785, 200)
17 | 
18 | 
19 | @pytest.fixture
20 | def encoder():
21 |     # inp (35, 1, 200)
22 |     encoder_layers = nn.TransformerEncoderLayer(200, 2, 100, 0.2)
23 |     model = nn.TransformerEncoder(encoder_layers, 2)
24 |     model.eval()
25 |     return model
26 | 
27 | 
28 | @pytest.mark.skip("Not passing tox test")
29 | def test_convert(encoder):
30 |     inp = torch.rand(35, 1, 200).to(torch.float32)
31 |     mask = torch.ones(35, 35).to(torch.float32)
32 |     with torch.no_grad():
33 |         out_true = encoder(inp, mask)
34 | 
35 |     converted_model = to_converted(encoder, ((35, 1, 200), (35, 35)))
36 |     converted_model.batch_dim = 1
37 | 
38 |     out = converted_model(inp, mask)
39 |     assert torch.allclose(out_true, out, atol=1e-6)
40 | 
41 | 
42 | def test_convert_linear_layer_transB1():
43 |     node = onnx.helper.make_node(
44 |         "Gemm", inputs=["a", "b", "c"], outputs=["y"], transB=1
45 |     )
46 |     a = np.random.random([3, 6]).astype(np.float32)
47 |     b = np.random.random([4, 6]).astype(np.float32)
48 |     c = np.zeros([1, 4]).astype(np.float32)
49 |     y = gemm_reference_implementation(a, b, c, transB=1)
50 | 
51 |     params = [numpy_helper.from_array(b), numpy_helper.from_array(c)]
52 |     op = convert_linear_layer(node, params)
53 |     op.eval()
54 |     out = op(torch.from_numpy(a))
55 |     assert torch.allclose(torch.from_numpy(y), out)
56 | 
57 | 
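# Note on these Gemm tests: ONNX Gemm computes Y = alpha * A' @ B' + beta * C,
# where A' is A.T when transA=1 (else A), and likewise for B'. nn.Linear stores
# its weight as (out_features, in_features) and computes x @ W.T + b, so a Gemm
# initializer with transB=1 can be copied into a Linear layer as-is. A rough
# sketch of that special case (a hypothetical helper for illustration;
# convert_linear_layer is the real converter and also handles alpha, beta, and
# the transB=0 layout):
#
#     def linear_from_gemm_transB1(b: np.ndarray, c: np.ndarray) -> nn.Linear:
#         out_features, in_features = b.shape  # transB=1 weight layout
#         lin = nn.Linear(in_features, out_features)
#         with torch.no_grad():
#             lin.weight.copy_(torch.from_numpy(b))
#             lin.bias.copy_(torch.from_numpy(c).reshape(-1))
#         return lin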
58 | def test_convert_linear_layer_default():
59 |     node = onnx.helper.make_node("Gemm", inputs=["a", "b", "c"], outputs=["y"])
60 |     a = np.random.random([3, 6]).astype(np.float32)
61 |     b = np.random.random([6, 4]).astype(np.float32)
62 |     c = np.random.random([3, 4]).astype(np.float32)
63 |     y = gemm_reference_implementation(a, b, c)
64 | 
65 |     params = [numpy_helper.from_array(b), numpy_helper.from_array(c)]
66 |     op = convert_linear_layer(node, params)
67 |     op.eval()
68 |     out = op(torch.from_numpy(a))
69 |     assert torch.allclose(torch.from_numpy(y), out)
70 | 
71 | 
72 | def test_convert_linear_layer_transB0():
73 |     node = onnx.helper.make_node(
74 |         "Gemm", inputs=["a", "b", "c"], outputs=["y"], transB=0
75 |     )
76 |     a = np.random.random([3, 6]).astype(np.float32)
77 |     b = np.random.random([6, 4]).astype(np.float32)
78 |     c = np.random.random([3, 4]).astype(np.float32)
79 |     y = gemm_reference_implementation(a, b, c, transB=0)
80 | 
81 |     params = [numpy_helper.from_array(b), numpy_helper.from_array(c)]
82 |     op = convert_linear_layer(node, params)
83 |     op.eval()
84 |     out = op(torch.from_numpy(a))
85 |     assert torch.allclose(torch.from_numpy(y), out)
86 | 
87 | 
88 | def test_convert_linear_layer_alpha():
89 |     node = onnx.helper.make_node(
90 |         "Gemm", inputs=["a", "b", "c"], outputs=["y"], alpha=0.5
91 |     )
92 |     a = np.random.random([3, 5]).astype(np.float32)
93 |     b = np.random.random([5, 4]).astype(np.float32)
94 |     c = np.zeros([1, 4]).astype(np.float32)
95 |     y = gemm_reference_implementation(a, b, c, alpha=0.5)
96 | 
97 |     params = [numpy_helper.from_array(b), numpy_helper.from_array(c)]
98 |     op = convert_linear_layer(node, params)
99 |     op.eval()
100 |     out = op(torch.from_numpy(a))
101 |     assert torch.allclose(torch.from_numpy(y), out)
102 | 
103 | 
104 | def test_convert_linear_layer_all():
105 |     node = onnx.helper.make_node(
106 |         "Gemm",
107 |         inputs=["a", "b", "c"],
108 |         outputs=["y"],
109 |         alpha=0.25,
110 |         beta=0.35,
111 |         transA=0,
112 |         transB=1,
113 |     )
114 |     a = np.random.random([4, 3]).astype(np.float32).transpose()
115 |     b = np.random.random([5, 4]).astype(np.float32)
116 |     c = np.random.random([1, 5]).astype(np.float32)
117 |     y = gemm_reference_implementation(
118 |         a, b, c, transA=0, transB=1, alpha=0.25, beta=0.35
119 |     )
120 | 
121 |     params = [numpy_helper.from_array(b), numpy_helper.from_array(c)]
122 |     op = convert_linear_layer(node, params)
123 |     op.eval()
124 |     out = op(torch.from_numpy(a))
125 |     assert torch.allclose(torch.from_numpy(y), out)
126 | 
--------------------------------------------------------------------------------
/tests/onnx2pytorch/test_onnx2pytorch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | 
4 | import onnx
5 | import numpy as np
6 | import torch
7 | import onnxruntime as ort
8 | 
9 | from onnx2pytorch import convert
10 | 
11 | 
12 | def test_onnx2pytorch(onnx_model, onnx_model_outputs, onnx_inputs):
13 |     model = convert.ConvertModel(onnx_model)
14 |     model.eval()
15 |     model.cpu()
16 |     with torch.no_grad():
17 |         outputs = model(*(torch.from_numpy(i) for i in onnx_inputs.values()))
18 | 
19 |     if not isinstance(outputs, list):
20 |         outputs = [outputs]
21 | 
22 |     outputs = [x.cpu().numpy() for x in outputs]
23 | 
24 |     for output, onnx_model_output in zip(outputs, onnx_model_outputs):
25 |         print("mse", ((onnx_model_output - output) ** 2).sum() / onnx_model_output.size)
26 |         np.testing.assert_allclose(onnx_model_output, output, atol=1e-5, rtol=1e-3)
27 | 
28 | 
29 | def test_onnx2pytorch2onnx(onnx_model, onnx_model_outputs, onnx_inputs):
30 |     """Test that conversion works both ways."""
31 |     torch_inputs = [torch.from_numpy(x) for x in onnx_inputs.values()]
32 | 
33 |     model = convert.ConvertModel(onnx_model)
34 |     model.eval()
35 |     model.cpu()
36 | 
37 |     bitstream = io.BytesIO()
38 |     torch.onnx.export(
39 |         model,
40 |         tuple(torch_inputs),
41 | 
bitstream, 42 | export_params=True, 43 | opset_version=11, 44 | do_constant_folding=True, 45 | input_names=list(onnx_inputs.keys()), 46 | ) 47 | 48 | # for some reason the following check fails the circleci with segmentation fault 49 | if not os.environ.get("CIRCLECI"): 50 | onnx_model = onnx.ModelProto.FromString(bitstream.getvalue()) 51 | onnx.checker.check_model(onnx_model) 52 | 53 | ort_session = ort.InferenceSession(bitstream.getvalue()) 54 | outputs = ort_session.run(None, onnx_inputs) 55 | 56 | for output, onnx_model_output in zip(outputs, onnx_model_outputs): 57 | print("mse", ((onnx_model_output - output) ** 2).sum() / onnx_model_output.size) 58 | np.testing.assert_allclose(onnx_model_output, output, atol=1e-5, rtol=1e-3) 59 | -------------------------------------------------------------------------------- /tests/onnx2pytorch/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | import pytest 4 | import torch 5 | from torch import nn 6 | from onnx.backend.test.case.node.pad import pad_impl 7 | 8 | from onnx2pytorch.helpers import to_onnx 9 | from onnx2pytorch.utils import ( 10 | is_constant, 11 | get_ops_names, 12 | get_selection, 13 | assign_values_to_dim, 14 | get_activation_value, 15 | extract_padding_params_for_conv_layer, 16 | extract_padding_params, 17 | ) 18 | 19 | 20 | @pytest.fixture 21 | def inp(): 22 | return torch.rand(10, 10) 23 | 24 | 25 | def test_is_constant(): 26 | a = torch.tensor([1]) 27 | assert is_constant(a) 28 | 29 | a = torch.tensor(1) 30 | assert is_constant(a) 31 | 32 | a = torch.tensor([1, 2]) 33 | assert not is_constant(a) 34 | 35 | 36 | def test_get_selection(): 37 | indices = torch.tensor([1, 2, 5]) 38 | with pytest.raises(AssertionError): 39 | get_selection(indices, -1) 40 | 41 | assert [indices] == get_selection(indices, 0) 42 | assert [slice(None), indices] == get_selection(indices, 1) 43 | 44 | 45 | def test_get_selection_2(): 46 | """Behaviour with python lists is unfortunately not working the same.""" 47 | inp = torch.rand(3, 3, 3) 48 | indices = torch.tensor(0) 49 | 50 | selection = get_selection(indices, 0) 51 | assert torch.equal(inp[selection], inp[0]) 52 | 53 | selection = get_selection(indices, 1) 54 | assert torch.equal(inp[selection], inp[:, 0]) 55 | 56 | 57 | @pytest.mark.parametrize( 58 | "val, dim, inplace", [[torch.zeros(4, 10), 0, False], [torch.zeros(10, 4), 1, True]] 59 | ) 60 | def test_assign_values_to_dim(inp, val, dim, inplace): 61 | indices = torch.tensor([2, 4, 6, 8]) 62 | 63 | out = inp.clone() 64 | if dim == 0: 65 | out[indices] = val 66 | elif dim == 1: 67 | out[:, indices] = val 68 | 69 | res = assign_values_to_dim(inp, val, indices, dim, inplace) 70 | if inplace: 71 | assert torch.equal(inp, out) 72 | assert torch.equal(res, out) 73 | else: 74 | # input should not be changed when inplace=False 75 | assert not torch.equal(inp, out) 76 | assert torch.equal(res, out) 77 | 78 | 79 | def test_get_activation_value(): 80 | inp = torch.ones(1, 1, 10, 10).numpy() 81 | model = nn.Sequential(nn.Conv2d(1, 3, 3), nn.Conv2d(3, 1, 3)) 82 | model[0].weight.data *= 0 83 | model[0].weight.data += 1 84 | model.eval() 85 | 86 | onnx_model = to_onnx(model, inp.shape) 87 | 88 | activation_name = onnx_model.graph.node[0].output[0] 89 | value = get_activation_value(onnx_model, inp, activation_name) 90 | assert value[0].shape == (1, 3, 8, 8) 91 | a = value[0].round() 92 | b = 9 * np.ones((1, 3, 8, 8), dtype=np.float32) 93 | assert (a == b).all() 94 | 95 | 96 | 
def test_get_activation_value_2(): 97 | """Get multiple outputs from onnx model.""" 98 | inp = torch.ones(1, 1, 10, 10).numpy() 99 | model = nn.Sequential(nn.Conv2d(1, 3, 3), nn.Conv2d(3, 1, 3)) 100 | onnx_model = to_onnx(model, inp.shape) 101 | 102 | activation_names = [x.output[0] for x in onnx_model.graph.node] 103 | values = get_activation_value(onnx_model, inp, activation_names) 104 | assert values[0].shape == (1, 3, 8, 8) 105 | assert values[1].shape == (1, 1, 6, 6) 106 | 107 | 108 | @pytest.mark.parametrize( 109 | "pads, output", 110 | [ 111 | ([1, 1, 1, 1], [1, 1]), 112 | ([0, 0, 0, 0], [0, 0]), 113 | ([1, 0], nn.ConstantPad1d([1, 0], 0)), 114 | ([1, 2], nn.ConstantPad1d([1, 2], 0)), 115 | ([1, 1, 0, 0], nn.ConstantPad2d([1, 0, 1, 0], 0)), 116 | ([1, 1, 1, 0, 0, 0], nn.ConstantPad3d([1, 0, 1, 0, 1, 0], 0)), 117 | ], 118 | ) 119 | def test_extract_padding_params_for_conv_layer(pads, output): 120 | out = extract_padding_params_for_conv_layer(pads) 121 | if isinstance(output, nn.Module): 122 | s = len(pads) // 2 123 | inp = np.random.rand(*s * [3]) 124 | expected_out = pad_impl(inp, np.array(pads), "constant", 0) 125 | infered_out = out(torch.from_numpy(inp)).numpy() 126 | assert (expected_out == infered_out).all() 127 | assert output._get_name() == out._get_name() 128 | assert output.padding == out.padding 129 | assert output.value == out.value 130 | else: 131 | assert out == output 132 | 133 | 134 | @pytest.fixture 135 | def weight(): 136 | return torch.rand(1, 3, 10, 10) 137 | 138 | 139 | @pytest.mark.parametrize( 140 | "onnx_pads, torch_pads", 141 | [ 142 | ([2, 2], [2, 2]), 143 | ([1, 2, 1, 2], [2, 2, 1, 1]), 144 | ([1, 2, 3, 4, 1, 2, 3, 4], [4, 4, 3, 3, 2, 2, 1, 1]), 145 | ([0, 0, 1, 2, 0, 0, 1, 2], [2, 2, 1, 1]), 146 | ], 147 | ) 148 | def test_extract_padding_params(weight, onnx_pads, torch_pads): 149 | out_pads = extract_padding_params(onnx_pads) 150 | assert out_pads == torch_pads 151 | 152 | 153 | def test_get_ops_names(): 154 | y_in = onnx.helper.make_tensor_value_info("y_in", onnx.TensorProto.FLOAT, [1]) 155 | y_out = onnx.helper.make_tensor_value_info("y_out", onnx.TensorProto.FLOAT, [1]) 156 | scan_out = onnx.helper.make_tensor_value_info( 157 | "scan_out", onnx.TensorProto.FLOAT, [] 158 | ) 159 | cond_in = onnx.helper.make_tensor_value_info("cond_in", onnx.TensorProto.BOOL, []) 160 | cond_out = onnx.helper.make_tensor_value_info("cond_out", onnx.TensorProto.BOOL, []) 161 | iter_count = onnx.helper.make_tensor_value_info( 162 | "iter_count", onnx.TensorProto.INT64, [] 163 | ) 164 | 165 | x = np.array([1, 2, 3, 4, 5]).astype(np.float32) 166 | 167 | x_const_node = onnx.helper.make_node( 168 | "Constant", 169 | inputs=[], 170 | outputs=["x"], 171 | value=onnx.helper.make_tensor( 172 | name="const_tensor_x", 173 | data_type=onnx.TensorProto.FLOAT, 174 | dims=x.shape, 175 | vals=x.flatten().astype(float), 176 | ), 177 | ) 178 | 179 | one_const_node = onnx.helper.make_node( 180 | "Constant", 181 | inputs=[], 182 | outputs=["one"], 183 | value=onnx.helper.make_tensor( 184 | name="const_tensor_one", data_type=onnx.TensorProto.INT64, dims=(), vals=[1] 185 | ), 186 | ) 187 | 188 | i_add_node = onnx.helper.make_node( 189 | "Add", inputs=["iter_count", "one"], outputs=["end"] 190 | ) 191 | 192 | start_unsqueeze_node = onnx.helper.make_node( 193 | "Unsqueeze", inputs=["iter_count"], outputs=["slice_start"], axes=[0] 194 | ) 195 | 196 | end_unsqueeze_node = onnx.helper.make_node( 197 | "Unsqueeze", inputs=["end"], outputs=["slice_end"], axes=[0] 198 | ) 199 | 200 | slice_node = 
onnx.helper.make_node(
201 |         "Slice", inputs=["x", "slice_start", "slice_end"], outputs=["slice_out"]
202 |     )
203 | 
204 |     y_add_node = onnx.helper.make_node(
205 |         "Add", inputs=["y_in", "slice_out"], outputs=["y_out"]
206 |     )
207 | 
208 |     identity_node = onnx.helper.make_node(
209 |         "Identity", inputs=["cond_in"], outputs=["cond_out"]
210 |     )
211 | 
212 |     scan_identity_node = onnx.helper.make_node(
213 |         "Identity", inputs=["y_out"], outputs=["scan_out"]
214 |     )
215 | 
216 |     loop_body = onnx.helper.make_graph(
217 |         [
218 |             identity_node,
219 |             x_const_node,
220 |             one_const_node,
221 |             i_add_node,
222 |             start_unsqueeze_node,
223 |             end_unsqueeze_node,
224 |             slice_node,
225 |             y_add_node,
226 |             scan_identity_node,
227 |         ],
228 |         "loop_body",
229 |         [iter_count, cond_in, y_in],
230 |         [cond_out, y_out, scan_out],
231 |     )
232 | 
233 |     node = onnx.helper.make_node(
234 |         "Loop",
235 |         inputs=["trip_count", "cond", "y"],
236 |         outputs=["res_y", "res_scan"],
237 |         body=loop_body,
238 |     )
239 | 
240 |     trip_count = onnx.helper.make_tensor_value_info(
241 |         "trip_count", onnx.TensorProto.INT64, []
242 |     )
243 |     cond = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, [])
244 |     y = onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1])
245 |     res_y = onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [1])
246 |     res_scan = onnx.helper.make_tensor_value_info(
247 |         "res_scan", onnx.TensorProto.FLOAT, []
248 |     )
249 | 
250 |     graph_def = onnx.helper.make_graph(
251 |         nodes=[node],
252 |         name="test-model",
253 |         inputs=[trip_count, cond, y],
254 |         outputs=[res_y, res_scan],
255 |     )
256 | 
257 |     ops_names = set(["Add", "Constant", "Identity", "Loop", "Slice", "Unsqueeze"])
258 |     assert get_ops_names(graph_def) == ops_names
--------------------------------------------------------------------------------
/tests/test_imports.py:
--------------------------------------------------------------------------------
1 | def test_import_packages():
2 |     """Test that importing our package works."""
3 |     import onnx2pytorch
4 | 
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | # tox (https://tox.readthedocs.io/) is a tool for running tests
2 | # in multiple virtualenvs. This configuration file will run the
3 | # test suite on all supported python versions. To use it, "pip install tox"
4 | # and then run "tox" from this directory.
5 | 
6 | [tox]
7 | envlist = clean,py39,py310,py311,py312
8 | 
9 | [gh-actions]
10 | python =
11 |     3.9: py39
12 |     3.10: py310
13 |     3.11: py311
14 |     3.12: py312
15 | 
16 | [testenv]
17 | passenv =
18 |     CIRCLE*
19 |     KMP_DUPLICATE_LIB_OK
20 | deps =
21 |     -rrequirements.txt
22 |     torch19: torch <= 1.9.0
23 |     pytest-cov
24 | commands =
25 |     pytest --cov --cov-append --cov-report term --cov-report html tests/
26 | 
27 | # https://pytest-cov.readthedocs.io/en/latest/tox.html
28 | [testenv:clean]
29 | deps = coverage
30 | skip_install = true
31 | commands = coverage erase
32 | 
33 | [coverage:report]
34 | omit =
35 |     .tox/*
36 |     tests/*
37 | 
38 | # Ignore some checks due to python black
39 | [flake8]
40 | ignore = E203, E266, E501, W503, F403, F401
41 | max-line-length = 88
42 | max-complexity = 10
43 | select = B,C,E,F,W,T4,B9
--------------------------------------------------------------------------------