├── .github └── workflows │ ├── main.yml │ └── python-publish.yml ├── .gitignore ├── .style.yapf ├── LICENSE ├── README.md ├── onnx_model_maker ├── __init__.py ├── code_gen.py └── ops │ ├── __init__.py │ ├── op_helper.py │ ├── op_ver_1.py │ ├── op_ver_10.py │ ├── op_ver_11.py │ ├── op_ver_12.py │ ├── op_ver_13.py │ ├── op_ver_14.py │ ├── op_ver_15.py │ ├── op_ver_16.py │ ├── op_ver_17.py │ ├── op_ver_2.py │ ├── op_ver_3.py │ ├── op_ver_4.py │ ├── op_ver_5.py │ ├── op_ver_6.py │ ├── op_ver_7.py │ ├── op_ver_8.py │ └── op_ver_9.py ├── onnx_pytorch ├── __init__.py ├── _version.py ├── code_gen.py ├── code_gen_template.py ├── op_code_generators │ ├── Abs.py │ ├── Acos.py │ ├── Acosh.py │ ├── Add.py │ ├── And.py │ ├── ArgMax.py │ ├── ArgMin.py │ ├── Asin.py │ ├── Asinh.py │ ├── Atan.py │ ├── Atanh.py │ ├── AveragePool.py │ ├── BatchNormalization.py │ ├── BitShift.py │ ├── Cast.py │ ├── Ceil.py │ ├── Clip.py │ ├── Concat.py │ ├── Constant.py │ ├── ConstantOfShape.py │ ├── Conv.py │ ├── ConvTranspose.py │ ├── Cos.py │ ├── Cosh.py │ ├── Div.py │ ├── Dropout.py │ ├── Elu.py │ ├── Equal.py │ ├── Exp.py │ ├── Expand.py │ ├── Flatten.py │ ├── Floor.py │ ├── Gather.py │ ├── GatherND.py │ ├── Gemm.py │ ├── GlobalAveragePool.py │ ├── Greater.py │ ├── Identity.py │ ├── InstanceNormalization.py │ ├── LRN.py │ ├── LayerNormalization.py │ ├── LeakyRelu.py │ ├── Less.py │ ├── Log.py │ ├── MatMul.py │ ├── Max.py │ ├── MaxPool.py │ ├── Mul.py │ ├── NonMaxSuppression.py │ ├── NonZero.py │ ├── Not.py │ ├── PRelu.py │ ├── Pad.py │ ├── Reciprocal.py │ ├── ReduceMean.py │ ├── ReduceMin.py │ ├── ReduceProd.py │ ├── ReduceSum.py │ ├── Relu.py │ ├── Reshape.py │ ├── Resize.py │ ├── RoiAlign.py │ ├── Round.py │ ├── Scatter.py │ ├── ScatterElements.py │ ├── Shape.py │ ├── Sigmoid.py │ ├── Slice.py │ ├── Softmax.py │ ├── Split.py │ ├── Sqrt.py │ ├── Squeeze.py │ ├── Sub.py │ ├── Tanh.py │ ├── TopK.py │ ├── Transpose.py │ ├── Unsqueeze.py │ ├── Upsample.py │ └── __init__.py ├── tests │ ├── 
__init__.py │ ├── test_base.py │ └── test_onnx_model_zoo.py └── utils │ ├── __init__.py │ └── embedding_config_helper.py ├── requirements.txt ├── setup.py └── tutorial.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ main, develop, ci ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | build: 13 | strategy: 14 | matrix: 15 | os: [ ubuntu-18.04, macos-12 ] 16 | python-version: [ 3.9 ] 17 | runs-on: ${{ matrix.os }} 18 | steps: 19 | - name: Cache ONNX model zoo 20 | uses: actions/cache@v2 21 | env: 22 | cache-name: cache-onnx-model-zoo 23 | with: 24 | path: ~/onnx_model_zoo 25 | key: ${{ runner.os }}-build-${{ env.cache-name }}-${{ hashFiles('**/version.txt') }} 26 | - uses: actions/checkout@v2 27 | - name: Set up Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v2 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install sympy yapf pytest tqdm 35 | pip install -r requirements.txt 36 | pip install -e . 37 | - name: Test format 38 | run: | 39 | yapf -ri onnx_pytorch --exclude onnx_model_maker/ops 40 | if [ $(git diff --cached --exit-code HEAD -- >/dev/null && (git ls-files --other --exclude-standard --directory | grep -c -v '/$')) != 0 ]; then 41 | echo "yapf formatter check failed." 42 | exit 1 43 | else 44 | echo "yapf formatter check passed." 
45 | fi 46 | exit 0 47 | - name: Test with pytest 48 | run: | 49 | pytest onnx_pytorch/tests -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | 
*.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit tests / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | 4 | indent_width = 2 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # onnx-pytorch 2 | 3 | 4 | ![Build Status](https://github.com/fumihwh/onnx-pytorch/actions/workflows/main.yml/badge.svg?branch=main) 5 | 6 | 7 | Generates PyTorch code from ONNX. 8 | 9 | ## Installation 10 | 11 | - From PyPI 12 | ```bash 13 | pip install onnx-pytorch 14 | ``` 15 | 16 | - From source 17 | ```bash 18 | git clone https://github.com/fumihwh/onnx-pytorch.git 19 | cd onnx-pytorch 20 | pip install -r requirements.txt 21 | pip install -e . 
22 | ``` 23 | 24 | 25 | ## Usage 26 | ### By Command Line 27 | ```bash 28 | python -m onnx_pytorch.code_gen -h 29 | 30 | usage: code_gen.py [-h] [--onnx_model_path ONNX_MODEL_PATH] [--output_dir OUTPUT_DIR] [--overwrite OVERWRITE] [--tensor_inplace TENSOR_INPLACE] [--continue_on_error CONTINUE_ON_ERROR] [--simplify_names SIMPLIFY_NAMES] 31 | 32 | optional arguments: 33 | -h, --help show this help message and exit 34 | --onnx_model_path ONNX_MODEL_PATH 35 | The onnx model path. 36 | --output_dir OUTPUT_DIR 37 | The output dir 38 | --overwrite OVERWRITE 39 | Should overwrite the output dir. 40 | --tensor_inplace TENSOR_INPLACE 41 | Try best to inplace tensor. 42 | --continue_on_error CONTINUE_ON_ERROR 43 | Continue on error. 44 | --simplify_names SIMPLIFY_NAMES 45 | Use indexing shorten name instead of original name. 46 | ``` 47 | 48 | ### By Python 49 | ```python 50 | from onnx_pytorch import code_gen 51 | code_gen.gen("/path/to/onnx_model", "/path/to/output_dir") 52 | ``` 53 | 54 | A `model.py` file and `variables/` folder will be created under `output_dir/`. 55 | 56 | ## Tutorial 57 | 1. Download resnet18 ONNX model. 58 | 59 | ```bash 60 | wget https://github.com/onnx/models/raw/master/vision/classification/resnet/model/resnet18-v2-7.onnx 61 | ``` 62 | 63 | 2. Use `onnx-pytorch` to generate PyTorch code and variables. 64 | ```python 65 | from onnx_pytorch import code_gen 66 | code_gen.gen("resnet18-v2-7.onnx", "./") 67 | ``` 68 | 69 | 3. Test result. 
70 | ```python 71 | import numpy as np 72 | import onnx 73 | import onnxruntime 74 | import torch 75 | torch.set_printoptions(8) 76 | 77 | from model import Model 78 | 79 | model = Model() 80 | model.eval() 81 | inp = np.random.randn(1, 3, 224, 224).astype(np.float32) 82 | with torch.no_grad(): 83 | torch_outputs = model(torch.from_numpy(inp)) 84 | 85 | onnx_model = onnx.load("resnet18-v2-7.onnx") 86 | sess_options = onnxruntime.SessionOptions() 87 | session = onnxruntime.InferenceSession(onnx_model.SerializeToString(), 88 | sess_options) 89 | inputs = {session.get_inputs()[0].name: inp} 90 | ort_outputs = session.run(None, inputs) 91 | 92 | print( 93 | "Comparison result:", 94 | np.allclose(torch_outputs.detach().numpy(), 95 | ort_outputs[0], 96 | atol=1e-5, 97 | rtol=1e-5)) 98 | ``` 99 | 100 | ## Test 101 | ```bash 102 | pytest onnx_pytorch/tests 103 | ``` 104 | -------------------------------------------------------------------------------- /onnx_model_maker/__init__.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import sys 3 | 4 | import onnx 5 | import onnx.onnx_cpp2py_export.checker as C 6 | from onnx.helper import make_opsetid, make_model_gen_version 7 | 8 | __all__ = ["omm", "mod_name", "reset_model", "set_model", "onnx_mm_export"] 9 | 10 | # OPSET_VER = onnx.defs.onnx_opset_version() 11 | OPSET_VER = 13 12 | 13 | 14 | class OnnxModelMaker: 15 | 16 | def __init__(self, opset_ver=OPSET_VER): 17 | self.opset_import = make_opsetid("", opset_ver) 18 | self.model = make_model_gen_version(onnx.GraphProto(), 19 | opset_imports=[self.opset_import]) 20 | self.op_counter = collections.Counter() 21 | self.ctx = C.CheckerContext() 22 | self.ctx.ir_version = self.model.ir_version 23 | self.ctx.opset_imports = {'': opset_ver} 24 | 25 | def reset_model(self, opset_ver=None): 26 | if opset_ver is not None: 27 | opset_imports = [make_opsetid("", opset_ver)] 28 | global OPSET_VER 29 | OPSET_VER = opset_ver 30 | 
self.ctx.opset_imports = {'': opset_ver} 31 | else: 32 | opset_imports = [self.opset_import] 33 | self.model = make_model_gen_version(onnx.GraphProto(), 34 | opset_imports=opset_imports) 35 | self.op_counter = collections.Counter() 36 | 37 | def set_model(self, model): 38 | self.model = model 39 | 40 | 41 | omm = OnnxModelMaker() 42 | mod_name = __name__ 43 | reset_model = omm.reset_model 44 | set_model = omm.set_model 45 | ctx = omm.ctx 46 | 47 | 48 | class onnx_mm_export(object): 49 | 50 | def __init__(self, *args, **kwargs): 51 | self._names = args 52 | 53 | def __call__(self, func): 54 | for a in self._names: 55 | mod = sys.modules[f"{mod_name}.ops"] 56 | setattr(mod, a, func) 57 | pass 58 | 59 | return func 60 | -------------------------------------------------------------------------------- /onnx_model_maker/code_gen.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import shutil 4 | 5 | import onnx 6 | 7 | TENSOR_PREFIX = "_t_" 8 | AUTO_GEN_HEAD = "# Autogenerated by onnx-model-maker. Don't modify it manually." 
9 | HEADER = f'''{AUTO_GEN_HEAD} 10 | 11 | import onnx 12 | import onnx.helper 13 | import onnx.numpy_helper 14 | from onnx_model_maker import omm 15 | from onnx_model_maker import onnx_mm_export 16 | from onnx_model_maker.ops.op_helper import _add_input 17 | ''' 18 | 19 | OP_HELPER_PY = f'''{AUTO_GEN_HEAD} 20 | 21 | from uuid import uuid4 22 | 23 | import numpy 24 | import onnx 25 | 26 | from onnx_model_maker import omm 27 | 28 | 29 | def _add_input(target, inputs): 30 | if target is None: 31 | return 32 | if type(target) == numpy.ndarray: 33 | t = onnx.numpy_helper.from_array(target, f"{TENSOR_PREFIX}{{uuid4().hex[:4]}}") 34 | omm.model.graph.initializer.append(t) 35 | inputs.append(t.name) 36 | elif type(target) == str: 37 | inputs.append(target) 38 | elif type(target) == list: 39 | _add_list(target, inputs) 40 | elif type(target) == onnx.NodeProto: 41 | inputs.append(target.output[0]) 42 | 43 | 44 | def _add_list(target, inputs): 45 | for t in target: 46 | _add_input(t, inputs) 47 | ''' 48 | 49 | INIT_PY = f'''{AUTO_GEN_HEAD} 50 | 51 | import glob 52 | import importlib 53 | import os 54 | import sys 55 | 56 | import onnx 57 | import numpy 58 | 59 | import onnx_model_maker 60 | from onnx_model_maker import mod_name 61 | from onnx_model_maker import omm 62 | 63 | 64 | modules = glob.glob(os.path.join(os.path.dirname(__file__), "op_ver_*.py")) 65 | for m in modules: 66 | spec = importlib.util.spec_from_file_location(os.path.basename(m)[:-3], m) 67 | spec.loader.exec_module(importlib.util.module_from_spec(spec)) 68 | 69 | 70 | def Input(*args): 71 | inputs = [] 72 | for i, a in enumerate(args): 73 | t = onnx.numpy_helper.from_array(a) 74 | vi = onnx.helper.make_tensor_value_info(f"{TENSOR_PREFIX}Input_{{i}}", 75 | t.data_type, t.dims) 76 | omm.model.graph.input.append(vi) 77 | inputs.append(vi.name) 78 | return inputs 79 | 80 | 81 | def Output(*args, output_num=None): 82 | for i, a in enumerate(args): 83 | if type(a) == numpy.ndarray: 84 | t = 
onnx.numpy_helper.from_array(a) 85 | vi = onnx.helper.make_tensor_value_info(f"{TENSOR_PREFIX}Output_{{i}}", t.data_type, 86 | t.dims) 87 | omm.model.graph.output.append(vi) 88 | elif type(a) == str: 89 | vi = onnx.helper.make_empty_tensor_value_info(a) 90 | omm.model.graph.output.append(vi) 91 | elif type(a) == onnx.NodeProto: 92 | for j, o in enumerate(a.output): 93 | if output_num is not None and j == output_num: 94 | break 95 | vi = onnx.helper.make_empty_tensor_value_info(o) 96 | omm.model.graph.output.append(vi) 97 | else: 98 | raise Exception 99 | 100 | 101 | ''' 102 | 103 | NEW_LINE = ''' 104 | 105 | ''' 106 | 107 | 108 | def _gen_op_maker(schema): 109 | onnx_op = schema.name 110 | inputs_args = [ 111 | i.name if idx < schema.min_input else f"{i.name}=None" 112 | for idx, i in enumerate(schema.inputs) 113 | ] 114 | inputs_forloop = [i.name for i in schema.inputs] 115 | if len(schema.inputs) == 1: 116 | inputs_forloop.append("") 117 | if len(schema.inputs) != 0: 118 | inputs_args.append("") 119 | 120 | outputs_str = [ 121 | f"f'{TENSOR_PREFIX}{onnx_op}_{{idx}}_{i.name}'" for i in schema.outputs 122 | ] 123 | # outputs_str = f'[f"{TENSOR_PREFIX}{onnx_op}_{{idx}}"]' 124 | if schema.name == "Split": 125 | if schema.since_version == 13: 126 | outputs_str = f'[f"{TENSOR_PREFIX}{onnx_op}_{{idx}}_{{i}}" for i in range(len(split))]' 127 | else: 128 | outputs_str = f'[f"{TENSOR_PREFIX}{onnx_op}_{{idx}}_{{i}}" for i in range(len(kwargs["split"]))]' 129 | if schema.name == "BatchNormalization": 130 | outputs_str = [outputs_str[0]] 131 | if type(outputs_str) in (list,): 132 | outputs_str = f"[{', '.join(outputs_str)}]" 133 | 134 | return f'''@onnx_mm_export("v{schema.since_version}.{onnx_op}") 135 | def {onnx_op}({', '.join(inputs_args)}**kwargs): 136 | _inputs = [] 137 | for i in ({', '.join(inputs_forloop)}): 138 | _add_input(i, _inputs) 139 | 140 | idx = omm.op_counter[\"{onnx_op}\"] 141 | omm.op_counter[\"{onnx_op}\"] += 1 142 | node = 
onnx.helper.make_node(\"{onnx_op}\", 143 | _inputs, {outputs_str}, 144 | name=f"{onnx_op}_{{idx}}", 145 | **kwargs) 146 | onnx.checker.check_node(node, omm.ctx) 147 | omm.model.graph.node.append(node) 148 | return node 149 | ''' 150 | 151 | 152 | def _gen_abs_op_maker(schema): 153 | onnx_op = schema.name 154 | return f'''def {onnx_op}(*args, **kwargs): 155 | schema = onnx.defs.get_schema("{onnx_op}", 156 | max_inclusive_version=onnx_model_maker.OPSET_VER, 157 | domain="") 158 | return getattr(sys.modules[f"{{mod_name}}.ops"], 159 | f"v{{schema.since_version}}.{onnx_op}")(*args, **kwargs) 160 | ''' 161 | 162 | 163 | def gen(output_dir=None, overwrite=False): 164 | if overwrite: 165 | shutil.rmtree(output_dir) 166 | os.makedirs(output_dir) 167 | if not os.path.exists(output_dir): 168 | os.makedirs(output_dir) 169 | abs_op_contents = {} 170 | file_contents = collections.defaultdict(list) 171 | all_schemas = onnx.defs.get_all_schemas_with_history() 172 | for schema in all_schemas: 173 | since_version = schema.since_version 174 | if str(since_version) not in file_contents: 175 | file_contents[str(since_version)].append(HEADER) 176 | if schema.name not in abs_op_contents: 177 | abs_op_contents[schema.name] = _gen_abs_op_maker(schema) 178 | file_contents[str(since_version)].append(_gen_op_maker(schema)) 179 | for v, c in file_contents.items(): 180 | with open(os.path.join(output_dir, f"op_ver_{v}.py"), "w") as f: 181 | f.write(NEW_LINE.join(c)) 182 | with open(os.path.join(output_dir, "__init__.py"), "w") as f: 183 | f.write(INIT_PY) 184 | f.write( 185 | NEW_LINE.join( 186 | [abs_op_contents[key] for key in sorted(abs_op_contents.keys())])) 187 | all_str = ', '.join([f'"{key}"' for key in sorted(abs_op_contents.keys())]) 188 | f.write(f'''{NEW_LINE}__all__ = [\"Input\", \"Output\", {all_str}]''') 189 | with open(os.path.join(output_dir, "op_helper.py"), "w") as f: 190 | f.write(OP_HELPER_PY) 191 | 192 | 193 | gen("./ops") 194 | 
-------------------------------------------------------------------------------- /onnx_model_maker/ops/op_helper.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by onnx-model-maker. Don't modify it manually. 2 | 3 | from uuid import uuid4 4 | 5 | import numpy 6 | import onnx 7 | 8 | from onnx_model_maker import omm 9 | 10 | 11 | def _add_input(target, inputs): 12 | if target is None: 13 | return 14 | if type(target) == numpy.ndarray: 15 | t = onnx.numpy_helper.from_array(target, f"_t_{uuid4().hex[:4]}") 16 | omm.model.graph.initializer.append(t) 17 | inputs.append(t.name) 18 | elif type(target) == str: 19 | inputs.append(target) 20 | elif type(target) == list: 21 | _add_list(target, inputs) 22 | elif type(target) == onnx.NodeProto: 23 | inputs.append(target.output[0]) 24 | 25 | 26 | def _add_list(target, inputs): 27 | for t in target: 28 | _add_input(t, inputs) 29 | -------------------------------------------------------------------------------- /onnx_model_maker/ops/op_ver_10.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by onnx-model-maker. Don't modify it manually. 

import onnx
import onnx.helper
import onnx.numpy_helper
from onnx_model_maker import omm
from onnx_model_maker import onnx_mm_export
from onnx_model_maker.ops.op_helper import _add_input

# Every wrapper below follows the same generated pattern:
#   1. resolve the positional arguments into graph-input names via _add_input
#      (None optional inputs are skipped),
#   2. take a running per-op index from omm.op_counter to build unique node
#      and output names,
#   3. build the node with onnx.helper.make_node (attributes come in through
#      **kwargs), validate it with onnx.checker.check_node, append it to the
#      default model graph and return it.


@onnx_mm_export("v10.RoiAlign")
def RoiAlign(X, rois, batch_indices, **kwargs):
  """Append an opset-10 RoiAlign node to the default graph; return it."""
  _inputs = []
  for i in (X, rois, batch_indices):
    _add_input(i, _inputs)

  idx = omm.op_counter["RoiAlign"]
  omm.op_counter["RoiAlign"] += 1
  node = onnx.helper.make_node("RoiAlign",
                               _inputs, [f'_t_RoiAlign_{idx}_Y'],
                               name=f"RoiAlign_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.ReverseSequence")
def ReverseSequence(input, sequence_lens, **kwargs):
  """Append an opset-10 ReverseSequence node to the default graph; return it."""
  _inputs = []
  for i in (input, sequence_lens):
    _add_input(i, _inputs)

  idx = omm.op_counter["ReverseSequence"]
  omm.op_counter["ReverseSequence"] += 1
  node = onnx.helper.make_node("ReverseSequence",
                               _inputs, [f'_t_ReverseSequence_{idx}_Y'],
                               name=f"ReverseSequence_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.NonMaxSuppression")
def NonMaxSuppression(boxes, scores, max_output_boxes_per_class=None, iou_threshold=None, score_threshold=None, **kwargs):
  """Append an opset-10 NonMaxSuppression node to the default graph; return it."""
  _inputs = []
  for i in (boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
    _add_input(i, _inputs)

  idx = omm.op_counter["NonMaxSuppression"]
  omm.op_counter["NonMaxSuppression"] += 1
  node = onnx.helper.make_node("NonMaxSuppression",
                               _inputs, [f'_t_NonMaxSuppression_{idx}_selected_indices'],
                               name=f"NonMaxSuppression_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.IsInf")
def IsInf(X, **kwargs):
  """Append an opset-10 IsInf node to the default graph; return it."""
  _inputs = []
  for i in (X, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["IsInf"]
  omm.op_counter["IsInf"] += 1
  node = onnx.helper.make_node("IsInf",
                               _inputs, [f'_t_IsInf_{idx}_Y'],
                               name=f"IsInf_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.QuantizeLinear")
def QuantizeLinear(x, y_scale, y_zero_point=None, **kwargs):
  """Append an opset-10 QuantizeLinear node to the default graph; return it."""
  _inputs = []
  for i in (x, y_scale, y_zero_point):
    _add_input(i, _inputs)

  idx = omm.op_counter["QuantizeLinear"]
  omm.op_counter["QuantizeLinear"] += 1
  node = onnx.helper.make_node("QuantizeLinear",
                               _inputs, [f'_t_QuantizeLinear_{idx}_y'],
                               name=f"QuantizeLinear_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.QLinearConv")
def QLinearConv(x, x_scale, x_zero_point, w, w_scale, w_zero_point, y_scale, y_zero_point, B=None, **kwargs):
  """Append an opset-10 QLinearConv node to the default graph; return it."""
  _inputs = []
  for i in (x, x_scale, x_zero_point, w, w_scale, w_zero_point, y_scale, y_zero_point, B):
    _add_input(i, _inputs)

  idx = omm.op_counter["QLinearConv"]
  omm.op_counter["QLinearConv"] += 1
  node = onnx.helper.make_node("QLinearConv",
                               _inputs, [f'_t_QLinearConv_{idx}_y'],
                               name=f"QLinearConv_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.ConvInteger")
def ConvInteger(x, w, x_zero_point=None, w_zero_point=None, **kwargs):
  """Append an opset-10 ConvInteger node to the default graph; return it."""
  _inputs = []
  for i in (x, w, x_zero_point, w_zero_point):
    _add_input(i, _inputs)

  idx = omm.op_counter["ConvInteger"]
  omm.op_counter["ConvInteger"] += 1
  node = onnx.helper.make_node("ConvInteger",
                               _inputs, [f'_t_ConvInteger_{idx}_y'],
                               name=f"ConvInteger_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.QLinearMatMul")
def QLinearMatMul(a, a_scale, a_zero_point, b, b_scale, b_zero_point, y_scale, y_zero_point, **kwargs):
  """Append an opset-10 QLinearMatMul node to the default graph; return it."""
  _inputs = []
  for i in (a, a_scale, a_zero_point, b, b_scale, b_zero_point, y_scale, y_zero_point):
    _add_input(i, _inputs)

  idx = omm.op_counter["QLinearMatMul"]
  omm.op_counter["QLinearMatMul"] += 1
  node = onnx.helper.make_node("QLinearMatMul",
                               _inputs, [f'_t_QLinearMatMul_{idx}_y'],
                               name=f"QLinearMatMul_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.MatMulInteger")
def MatMulInteger(A, B, a_zero_point=None, b_zero_point=None, **kwargs):
  """Append an opset-10 MatMulInteger node to the default graph; return it."""
  _inputs = []
  for i in (A, B, a_zero_point, b_zero_point):
    _add_input(i, _inputs)

  idx = omm.op_counter["MatMulInteger"]
  omm.op_counter["MatMulInteger"] += 1
  node = onnx.helper.make_node("MatMulInteger",
                               _inputs, [f'_t_MatMulInteger_{idx}_Y'],
                               name=f"MatMulInteger_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.StringNormalizer")
def StringNormalizer(X, **kwargs):
  """Append an opset-10 StringNormalizer node to the default graph; return it."""
  _inputs = []
  for i in (X, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["StringNormalizer"]
  omm.op_counter["StringNormalizer"] += 1
  node = onnx.helper.make_node("StringNormalizer",
                               _inputs, [f'_t_StringNormalizer_{idx}_Y'],
                               name=f"StringNormalizer_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.Mod")
def Mod(A, B, **kwargs):
  """Append an opset-10 Mod node to the default graph; return it."""
  _inputs = []
  for i in (A, B):
    _add_input(i, _inputs)

  idx = omm.op_counter["Mod"]
  omm.op_counter["Mod"] += 1
  node = onnx.helper.make_node("Mod",
                               _inputs, [f'_t_Mod_{idx}_C'],
                               name=f"Mod_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.DequantizeLinear")
def DequantizeLinear(x, x_scale, x_zero_point=None, **kwargs):
  """Append an opset-10 DequantizeLinear node to the default graph; return it."""
  _inputs = []
  for i in (x, x_scale, x_zero_point):
    _add_input(i, _inputs)

  idx = omm.op_counter["DequantizeLinear"]
  omm.op_counter["DequantizeLinear"] += 1
  node = onnx.helper.make_node("DequantizeLinear",
                               _inputs, [f'_t_DequantizeLinear_{idx}_y'],
                               name=f"DequantizeLinear_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.ThresholdedRelu")
def ThresholdedRelu(X, **kwargs):
  """Append an opset-10 ThresholdedRelu node to the default graph; return it."""
  _inputs = []
  for i in (X, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["ThresholdedRelu"]
  omm.op_counter["ThresholdedRelu"] += 1
  node = onnx.helper.make_node("ThresholdedRelu",
                               _inputs, [f'_t_ThresholdedRelu_{idx}_Y'],
                               name=f"ThresholdedRelu_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.Upsample")
def Upsample(X, scales, **kwargs):
  """Append an opset-10 Upsample node to the default graph; return it."""
  _inputs = []
  for i in (X, scales):
    _add_input(i, _inputs)

  idx = omm.op_counter["Upsample"]
  omm.op_counter["Upsample"] += 1
  node = onnx.helper.make_node("Upsample",
                               _inputs, [f'_t_Upsample_{idx}_Y'],
                               name=f"Upsample_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.AveragePool")
def AveragePool(X, **kwargs):
  """Append an opset-10 AveragePool node to the default graph; return it."""
  _inputs = []
  for i in (X, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["AveragePool"]
  omm.op_counter["AveragePool"] += 1
  node = onnx.helper.make_node("AveragePool",
                               _inputs, [f'_t_AveragePool_{idx}_Y'],
                               name=f"AveragePool_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.TopK")
def TopK(X, K, **kwargs):
  """Append an opset-10 TopK node (outputs Values, Indices) to the default graph."""
  _inputs = []
  for i in (X, K):
    _add_input(i, _inputs)

  idx = omm.op_counter["TopK"]
  omm.op_counter["TopK"] += 1
  node = onnx.helper.make_node("TopK",
                               _inputs, [f'_t_TopK_{idx}_Values', f'_t_TopK_{idx}_Indices'],
                               name=f"TopK_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.Slice")
def Slice(data, starts, ends, axes=None, steps=None, **kwargs):
  """Append an opset-10 Slice node to the default graph; return it."""
  _inputs = []
  for i in (data, starts, ends, axes, steps):
    _add_input(i, _inputs)

  idx = omm.op_counter["Slice"]
  omm.op_counter["Slice"] += 1
  node = onnx.helper.make_node("Slice",
                               _inputs, [f'_t_Slice_{idx}_output'],
                               name=f"Slice_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.Resize")
def Resize(X, scales, **kwargs):
  """Append an opset-10 Resize node to the default graph; return it."""
  _inputs = []
  for i in (X, scales):
    _add_input(i, _inputs)

  idx = omm.op_counter["Resize"]
  omm.op_counter["Resize"] += 1
  node = onnx.helper.make_node("Resize",
                               _inputs, [f'_t_Resize_{idx}_Y'],
                               name=f"Resize_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.MaxPool")
def MaxPool(X, **kwargs):
  """Append an opset-10 MaxPool node (outputs Y, Indices) to the default graph."""
  _inputs = []
  for i in (X, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["MaxPool"]
  omm.op_counter["MaxPool"] += 1
  node = onnx.helper.make_node("MaxPool",
                               _inputs, [f'_t_MaxPool_{idx}_Y', f'_t_MaxPool_{idx}_Indices'],
                               name=f"MaxPool_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v10.Dropout")
def Dropout(data, **kwargs):
  """Append an opset-10 Dropout node (outputs output, mask) to the default graph."""
  _inputs = []
  for i in (data, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["Dropout"]
  omm.op_counter["Dropout"] += 1
  node = onnx.helper.make_node("Dropout",
                               _inputs, [f'_t_Dropout_{idx}_output', f'_t_Dropout_{idx}_mask'],
                               name=f"Dropout_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node
--------------------------------------------------------------------------------
/onnx_model_maker/ops/op_ver_12.py:
--------------------------------------------------------------------------------
# Autogenerated by onnx-model-maker. Don't modify it manually.

import onnx
import onnx.helper
import onnx.numpy_helper
from onnx_model_maker import omm
from onnx_model_maker import onnx_mm_export
from onnx_model_maker.ops.op_helper import _add_input

# Every wrapper below follows the same generated pattern: resolve inputs via
# _add_input (None optional inputs are skipped), take a running per-op index
# from omm.op_counter for unique node/output names, build the node with
# onnx.helper.make_node (attributes via **kwargs), validate it with
# onnx.checker.check_node, append it to the default graph and return it.


@onnx_mm_export("v12.LessOrEqual")
def LessOrEqual(A, B, **kwargs):
  """Append an opset-12 LessOrEqual node to the default graph; return it."""
  _inputs = []
  for i in (A, B):
    _add_input(i, _inputs)

  idx = omm.op_counter["LessOrEqual"]
  omm.op_counter["LessOrEqual"] += 1
  node = onnx.helper.make_node("LessOrEqual",
                               _inputs, [f'_t_LessOrEqual_{idx}_C'],
                               name=f"LessOrEqual_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.Celu")
def Celu(X, **kwargs):
  """Append an opset-12 Celu node to the default graph; return it."""
  _inputs = []
  for i in (X, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["Celu"]
  omm.op_counter["Celu"] += 1
  node = onnx.helper.make_node("Celu",
                               _inputs, [f'_t_Celu_{idx}_Y'],
                               name=f"Celu_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.GatherND")
def GatherND(data, indices, **kwargs):
  """Append an opset-12 GatherND node to the default graph; return it."""
  _inputs = []
  for i in (data, indices):
    _add_input(i, _inputs)

  idx = omm.op_counter["GatherND"]
  omm.op_counter["GatherND"] += 1
  node = onnx.helper.make_node("GatherND",
                               _inputs, [f'_t_GatherND_{idx}_output'],
                               name=f"GatherND_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.Max")
def Max(data_0, **kwargs):
  """Append an opset-12 Max node (variadic input via list) to the default graph."""
  _inputs = []
  for i in (data_0, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["Max"]
  omm.op_counter["Max"] += 1
  node = onnx.helper.make_node("Max",
                               _inputs, [f'_t_Max_{idx}_max'],
                               name=f"Max_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.Einsum")
def Einsum(Inputs, **kwargs):
  """Append an opset-12 Einsum node (variadic input via list) to the default graph."""
  _inputs = []
  for i in (Inputs, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["Einsum"]
  omm.op_counter["Einsum"] += 1
  node = onnx.helper.make_node("Einsum",
                               _inputs, [f'_t_Einsum_{idx}_Output'],
                               name=f"Einsum_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.GreaterOrEqual")
def GreaterOrEqual(A, B, **kwargs):
  """Append an opset-12 GreaterOrEqual node to the default graph; return it."""
  _inputs = []
  for i in (A, B):
    _add_input(i, _inputs)

  idx = omm.op_counter["GreaterOrEqual"]
  omm.op_counter["GreaterOrEqual"] += 1
  node = onnx.helper.make_node("GreaterOrEqual",
                               _inputs, [f'_t_GreaterOrEqual_{idx}_C'],
                               name=f"GreaterOrEqual_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.NegativeLogLikelihoodLoss")
def NegativeLogLikelihoodLoss(input, target, weight=None, **kwargs):
  """Append an opset-12 NegativeLogLikelihoodLoss node to the default graph."""
  _inputs = []
  for i in (input, target, weight):
    _add_input(i, _inputs)

  idx = omm.op_counter["NegativeLogLikelihoodLoss"]
  omm.op_counter["NegativeLogLikelihoodLoss"] += 1
  node = onnx.helper.make_node("NegativeLogLikelihoodLoss",
                               _inputs, [f'_t_NegativeLogLikelihoodLoss_{idx}_loss'],
                               name=f"NegativeLogLikelihoodLoss_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.ReduceMin")
def ReduceMin(data, **kwargs):
  """Append an opset-12 ReduceMin node to the default graph; return it."""
  _inputs = []
  for i in (data, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["ReduceMin"]
  omm.op_counter["ReduceMin"] += 1
  node = onnx.helper.make_node("ReduceMin",
                               _inputs, [f'_t_ReduceMin_{idx}_reduced'],
                               name=f"ReduceMin_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.Min")
def Min(data_0, **kwargs):
  """Append an opset-12 Min node (variadic input via list) to the default graph."""
  _inputs = []
  for i in (data_0, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["Min"]
  omm.op_counter["Min"] += 1
  node = onnx.helper.make_node("Min",
                               _inputs, [f'_t_Min_{idx}_min'],
                               name=f"Min_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.ReduceMax")
def ReduceMax(data, **kwargs):
  """Append an opset-12 ReduceMax node to the default graph; return it."""
  _inputs = []
  for i in (data, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["ReduceMax"]
  omm.op_counter["ReduceMax"] += 1
  node = onnx.helper.make_node("ReduceMax",
                               _inputs, [f'_t_ReduceMax_{idx}_reduced'],
                               name=f"ReduceMax_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.ArgMax")
def ArgMax(data, **kwargs):
  """Append an opset-12 ArgMax node to the default graph; return it."""
  _inputs = []
  for i in (data, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["ArgMax"]
  omm.op_counter["ArgMax"] += 1
  node = onnx.helper.make_node("ArgMax",
                               _inputs, [f'_t_ArgMax_{idx}_reduced'],
                               name=f"ArgMax_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.SoftmaxCrossEntropyLoss")
def SoftmaxCrossEntropyLoss(scores, labels, weights=None, **kwargs):
  """Append an opset-12 SoftmaxCrossEntropyLoss node (outputs output, log_prob)."""
  _inputs = []
  for i in (scores, labels, weights):
    _add_input(i, _inputs)

  idx = omm.op_counter["SoftmaxCrossEntropyLoss"]
  omm.op_counter["SoftmaxCrossEntropyLoss"] += 1
  node = onnx.helper.make_node("SoftmaxCrossEntropyLoss",
                               _inputs, [f'_t_SoftmaxCrossEntropyLoss_{idx}_output', f'_t_SoftmaxCrossEntropyLoss_{idx}_log_prob'],
                               name=f"SoftmaxCrossEntropyLoss_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.Clip")
def Clip(input, min=None, max=None, **kwargs):
  """Append an opset-12 Clip node to the default graph; return it."""
  _inputs = []
  for i in (input, min, max):
    _add_input(i, _inputs)

  idx = omm.op_counter["Clip"]
  omm.op_counter["Clip"] += 1
  node = onnx.helper.make_node("Clip",
                               _inputs, [f'_t_Clip_{idx}_output'],
                               name=f"Clip_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.ArgMin")
def ArgMin(data, **kwargs):
  """Append an opset-12 ArgMin node to the default graph; return it."""
  _inputs = []
  for i in (data, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["ArgMin"]
  omm.op_counter["ArgMin"] += 1
  node = onnx.helper.make_node("ArgMin",
                               _inputs, [f'_t_ArgMin_{idx}_reduced'],
                               name=f"ArgMin_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.Constant")
def Constant(**kwargs):
  """Append an opset-12 Constant node to the default graph; return it."""
  _inputs = []
  # Constant has no tensor inputs; its value arrives via attributes in kwargs.
  for i in ():
    _add_input(i, _inputs)

  idx = omm.op_counter["Constant"]
  omm.op_counter["Constant"] += 1
  node = onnx.helper.make_node("Constant",
                               _inputs, [f'_t_Constant_{idx}_output'],
                               name=f"Constant_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.Pow")
def Pow(X, Y, **kwargs):
  """Append an opset-12 Pow node to the default graph; return it."""
  _inputs = []
  for i in (X, Y):
    _add_input(i, _inputs)

  idx = omm.op_counter["Pow"]
  omm.op_counter["Pow"] += 1
  node = onnx.helper.make_node("Pow",
                               _inputs, [f'_t_Pow_{idx}_Z'],
                               name=f"Pow_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.MaxPool")
def MaxPool(X, **kwargs):
  """Append an opset-12 MaxPool node (outputs Y, Indices) to the default graph."""
  _inputs = []
  for i in (X, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["MaxPool"]
  omm.op_counter["MaxPool"] += 1
  node = onnx.helper.make_node("MaxPool",
                               _inputs, [f'_t_MaxPool_{idx}_Y', f'_t_MaxPool_{idx}_Indices'],
                               name=f"MaxPool_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v12.Dropout")
def Dropout(data, ratio=None, training_mode=None, **kwargs):
  """Append an opset-12 Dropout node (outputs output, mask) to the default graph."""
  _inputs = []
  for i in (data, ratio, training_mode):
    _add_input(i, _inputs)

  idx = omm.op_counter["Dropout"]
  omm.op_counter["Dropout"] += 1
  node = onnx.helper.make_node("Dropout",
                               _inputs, [f'_t_Dropout_{idx}_output', f'_t_Dropout_{idx}_mask'],
                               name=f"Dropout_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node
--------------------------------------------------------------------------------
/onnx_model_maker/ops/op_ver_14.py:
--------------------------------------------------------------------------------
# Autogenerated by onnx-model-maker. Don't modify it manually.

import onnx
import onnx.helper
import onnx.numpy_helper
from onnx_model_maker import omm
from onnx_model_maker import onnx_mm_export
from onnx_model_maker.ops.op_helper import _add_input

# Every wrapper below follows the same generated pattern: resolve inputs via
# _add_input (None optional inputs are skipped), take a running per-op index
# from omm.op_counter for unique node/output names, build the node with
# onnx.helper.make_node (attributes via **kwargs), validate it with
# onnx.checker.check_node, append it to the default graph and return it.


@onnx_mm_export("v14.HardSwish")
def HardSwish(X, **kwargs):
  """Append an opset-14 HardSwish node to the default graph; return it."""
  _inputs = []
  for i in (X, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["HardSwish"]
  omm.op_counter["HardSwish"] += 1
  node = onnx.helper.make_node("HardSwish",
                               _inputs, [f'_t_HardSwish_{idx}_Y'],
                               name=f"HardSwish_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.CumSum")
def CumSum(x, axis, **kwargs):
  """Append an opset-14 CumSum node to the default graph; return it."""
  _inputs = []
  for i in (x, axis):
    _add_input(i, _inputs)

  idx = omm.op_counter["CumSum"]
  omm.op_counter["CumSum"] += 1
  node = onnx.helper.make_node("CumSum",
                               _inputs, [f'_t_CumSum_{idx}_y'],
                               name=f"CumSum_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.Trilu")
def Trilu(input, k=None, **kwargs):
  """Append an opset-14 Trilu node to the default graph; return it."""
  _inputs = []
  for i in (input, k):
    _add_input(i, _inputs)

  idx = omm.op_counter["Trilu"]
  omm.op_counter["Trilu"] += 1
  node = onnx.helper.make_node("Trilu",
                               _inputs, [f'_t_Trilu_{idx}_output'],
                               name=f"Trilu_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.Sub")
def Sub(A, B, **kwargs):
  """Append an opset-14 Sub node to the default graph; return it."""
  _inputs = []
  for i in (A, B):
    _add_input(i, _inputs)

  idx = omm.op_counter["Sub"]
  omm.op_counter["Sub"] += 1
  node = onnx.helper.make_node("Sub",
                               _inputs, [f'_t_Sub_{idx}_C'],
                               name=f"Sub_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.Relu")
def Relu(X, **kwargs):
  """Append an opset-14 Relu node to the default graph; return it."""
  _inputs = []
  for i in (X, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["Relu"]
  omm.op_counter["Relu"] += 1
  node = onnx.helper.make_node("Relu",
                               _inputs, [f'_t_Relu_{idx}_Y'],
                               name=f"Relu_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.Mul")
def Mul(A, B, **kwargs):
  """Append an opset-14 Mul node to the default graph; return it."""
  _inputs = []
  for i in (A, B):
    _add_input(i, _inputs)

  idx = omm.op_counter["Mul"]
  omm.op_counter["Mul"] += 1
  node = onnx.helper.make_node("Mul",
                               _inputs, [f'_t_Mul_{idx}_C'],
                               name=f"Mul_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.RNN")
def RNN(X, W, R, B=None, sequence_lens=None, initial_h=None, **kwargs):
  """Append an opset-14 RNN node (outputs Y, Y_h) to the default graph."""
  _inputs = []
  for i in (X, W, R, B, sequence_lens, initial_h):
    _add_input(i, _inputs)

  idx = omm.op_counter["RNN"]
  omm.op_counter["RNN"] += 1
  node = onnx.helper.make_node("RNN",
                               _inputs, [f'_t_RNN_{idx}_Y', f'_t_RNN_{idx}_Y_h'],
                               name=f"RNN_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.Reshape")
def Reshape(data, shape, **kwargs):
  """Append an opset-14 Reshape node to the default graph; return it."""
  _inputs = []
  for i in (data, shape):
    _add_input(i, _inputs)

  idx = omm.op_counter["Reshape"]
  omm.op_counter["Reshape"] += 1
  node = onnx.helper.make_node("Reshape",
                               _inputs, [f'_t_Reshape_{idx}_reshaped'],
                               name=f"Reshape_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.BatchNormalization")
def BatchNormalization(X, scale, B, input_mean, input_var, **kwargs):
  """Append an opset-14 BatchNormalization node to the default graph."""
  _inputs = []
  for i in (X, scale, B, input_mean, input_var):
    _add_input(i, _inputs)

  idx = omm.op_counter["BatchNormalization"]
  omm.op_counter["BatchNormalization"] += 1
  node = onnx.helper.make_node("BatchNormalization",
                               _inputs, [f'_t_BatchNormalization_{idx}_Y'],
                               name=f"BatchNormalization_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.LSTM")
def LSTM(X, W, R, B=None, sequence_lens=None, initial_h=None, initial_c=None, P=None, **kwargs):
  """Append an opset-14 LSTM node (outputs Y, Y_h, Y_c) to the default graph."""
  _inputs = []
  for i in (X, W, R, B, sequence_lens, initial_h, initial_c, P):
    _add_input(i, _inputs)

  idx = omm.op_counter["LSTM"]
  omm.op_counter["LSTM"] += 1
  node = onnx.helper.make_node("LSTM",
                               _inputs, [f'_t_LSTM_{idx}_Y', f'_t_LSTM_{idx}_Y_h', f'_t_LSTM_{idx}_Y_c'],
                               name=f"LSTM_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.GRU")
def GRU(X, W, R, B=None, sequence_lens=None, initial_h=None, **kwargs):
  """Append an opset-14 GRU node (outputs Y, Y_h) to the default graph."""
  _inputs = []
  for i in (X, W, R, B, sequence_lens, initial_h):
    _add_input(i, _inputs)

  idx = omm.op_counter["GRU"]
  omm.op_counter["GRU"] += 1
  node = onnx.helper.make_node("GRU",
                               _inputs, [f'_t_GRU_{idx}_Y', f'_t_GRU_{idx}_Y_h'],
                               name=f"GRU_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.Identity")
def Identity(input, **kwargs):
  """Append an opset-14 Identity node to the default graph; return it."""
  _inputs = []
  for i in (input, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["Identity"]
  omm.op_counter["Identity"] += 1
  node = onnx.helper.make_node("Identity",
                               _inputs, [f'_t_Identity_{idx}_output'],
                               name=f"Identity_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.Add")
def Add(A, B, **kwargs):
  """Append an opset-14 Add node to the default graph; return it."""
  _inputs = []
  for i in (A, B):
    _add_input(i, _inputs)

  idx = omm.op_counter["Add"]
  omm.op_counter["Add"] += 1
  node = onnx.helper.make_node("Add",
                               _inputs, [f'_t_Add_{idx}_C'],
                               name=f"Add_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v14.Div")
def Div(A, B, **kwargs):
  """Append an opset-14 Div node to the default graph; return it."""
  _inputs = []
  for i in (A, B):
    _add_input(i, _inputs)

  idx = omm.op_counter["Div"]
  omm.op_counter["Div"] += 1
  node = onnx.helper.make_node("Div",
                               _inputs, [f'_t_Div_{idx}_C'],
                               name=f"Div_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node
--------------------------------------------------------------------------------
/onnx_model_maker/ops/op_ver_15.py:
--------------------------------------------------------------------------------
# Autogenerated by onnx-model-maker. Don't modify it manually.

import onnx
import onnx.helper
import onnx.numpy_helper
from onnx_model_maker import omm
from onnx_model_maker import onnx_mm_export
from onnx_model_maker.ops.op_helper import _add_input

# Every wrapper below follows the same generated pattern: resolve inputs via
# _add_input (None optional inputs are skipped), take a running per-op index
# from omm.op_counter for unique node/output names, build the node with
# onnx.helper.make_node (attributes via **kwargs), validate it with
# onnx.checker.check_node, append it to the default graph and return it.


@onnx_mm_export("v15.CastLike")
def CastLike(input, target_type, **kwargs):
  """Append an opset-15 CastLike node to the default graph; return it."""
  _inputs = []
  for i in (input, target_type):
    _add_input(i, _inputs)

  idx = omm.op_counter["CastLike"]
  omm.op_counter["CastLike"] += 1
  node = onnx.helper.make_node("CastLike",
                               _inputs, [f'_t_CastLike_{idx}_output'],
                               name=f"CastLike_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v15.OptionalGetElement")
def OptionalGetElement(input, **kwargs):
  """Append an opset-15 OptionalGetElement node to the default graph."""
  _inputs = []
  for i in (input, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["OptionalGetElement"]
  omm.op_counter["OptionalGetElement"] += 1
  node = onnx.helper.make_node("OptionalGetElement",
                               _inputs, [f'_t_OptionalGetElement_{idx}_output'],
                               name=f"OptionalGetElement_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v15.Optional")
def Optional(input=None, **kwargs):
  """Append an opset-15 Optional node to the default graph; return it."""
  _inputs = []
  for i in (input, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["Optional"]
  omm.op_counter["Optional"] += 1
  node = onnx.helper.make_node("Optional",
                               _inputs, [f'_t_Optional_{idx}_output'],
                               name=f"Optional_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v15.Shape")
def Shape(data, **kwargs):
  """Append an opset-15 Shape node to the default graph; return it."""
  _inputs = []
  for i in (data, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["Shape"]
  omm.op_counter["Shape"] += 1
  node = onnx.helper.make_node("Shape",
                               _inputs, [f'_t_Shape_{idx}_shape'],
                               name=f"Shape_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v15.OptionalHasElement")
def OptionalHasElement(input, **kwargs):
  """Append an opset-15 OptionalHasElement node to the default graph."""
  _inputs = []
  for i in (input, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["OptionalHasElement"]
  omm.op_counter["OptionalHasElement"] += 1
  node = onnx.helper.make_node("OptionalHasElement",
                               _inputs, [f'_t_OptionalHasElement_{idx}_output'],
                               name=f"OptionalHasElement_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v15.Bernoulli")
def Bernoulli(input, **kwargs):
  """Append an opset-15 Bernoulli node to the default graph; return it."""
  _inputs = []
  for i in (input, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["Bernoulli"]
  omm.op_counter["Bernoulli"] += 1
  node = onnx.helper.make_node("Bernoulli",
                               _inputs, [f'_t_Bernoulli_{idx}_output'],
                               name=f"Bernoulli_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v15.BatchNormalization")
def BatchNormalization(X, scale, B, input_mean, input_var, **kwargs):
  """Append an opset-15 BatchNormalization node to the default graph."""
  _inputs = []
  for i in (X, scale, B, input_mean, input_var):
    _add_input(i, _inputs)

  idx = omm.op_counter["BatchNormalization"]
  omm.op_counter["BatchNormalization"] += 1
  node = onnx.helper.make_node("BatchNormalization",
                               _inputs, [f'_t_BatchNormalization_{idx}_Y'],
                               name=f"BatchNormalization_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v15.Pow")
def Pow(X, Y, **kwargs):
  """Append an opset-15 Pow node to the default graph; return it."""
  _inputs = []
  for i in (X, Y):
    _add_input(i, _inputs)

  idx = omm.op_counter["Pow"]
  omm.op_counter["Pow"] += 1
  node = onnx.helper.make_node("Pow",
                               _inputs, [f'_t_Pow_{idx}_Z'],
                               name=f"Pow_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node
--------------------------------------------------------------------------------
/onnx_model_maker/ops/op_ver_16.py:
--------------------------------------------------------------------------------
# Autogenerated by onnx-model-maker. Don't modify it manually.

import onnx
import onnx.helper
import onnx.numpy_helper
from onnx_model_maker import omm
from onnx_model_maker import onnx_mm_export
from onnx_model_maker.ops.op_helper import _add_input

# Same generated pattern as the other op_ver_* modules (see comment above).


@onnx_mm_export("v16.LessOrEqual")
def LessOrEqual(A, B, **kwargs):
  """Append an opset-16 LessOrEqual node to the default graph; return it."""
  _inputs = []
  for i in (A, B):
    _add_input(i, _inputs)

  idx = omm.op_counter["LessOrEqual"]
  omm.op_counter["LessOrEqual"] += 1
  node = onnx.helper.make_node("LessOrEqual",
                               _inputs, [f'_t_LessOrEqual_{idx}_C'],
                               name=f"LessOrEqual_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v16.ScatterND")
def ScatterND(data, indices, updates, **kwargs):
  """Append an opset-16 ScatterND node to the default graph; return it."""
  _inputs = []
  for i in (data, indices, updates):
    _add_input(i, _inputs)

  idx = omm.op_counter["ScatterND"]
  omm.op_counter["ScatterND"] += 1
  node = onnx.helper.make_node("ScatterND",
                               _inputs, [f'_t_ScatterND_{idx}_output'],
                               name=f"ScatterND_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v16.ScatterElements")
def ScatterElements(data, indices, updates, **kwargs):
  """Append an opset-16 ScatterElements node to the default graph; return it."""
  _inputs = []
  for i in (data, indices, updates):
    _add_input(i, _inputs)

  idx = omm.op_counter["ScatterElements"]
  omm.op_counter["ScatterElements"] += 1
  node = onnx.helper.make_node("ScatterElements",
                               _inputs, [f'_t_ScatterElements_{idx}_output'],
                               name=f"ScatterElements_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v16.RoiAlign")
def RoiAlign(X, rois, batch_indices, **kwargs):
  """Append an opset-16 RoiAlign node to the default graph; return it."""
  _inputs = []
  for i in (X, rois, batch_indices):
    _add_input(i, _inputs)

  idx = omm.op_counter["RoiAlign"]
  omm.op_counter["RoiAlign"] += 1
  node = onnx.helper.make_node("RoiAlign",
                               _inputs, [f'_t_RoiAlign_{idx}_Y'],
                               name=f"RoiAlign_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v16.Scan")
def Scan(initial_state_and_scan_inputs, **kwargs):
  """Append an opset-16 Scan node (variadic input via list) to the default graph."""
  _inputs = []
  for i in (initial_state_and_scan_inputs, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["Scan"]
  omm.op_counter["Scan"] += 1
  node = onnx.helper.make_node("Scan",
                               _inputs, [f'_t_Scan_{idx}_final_state_and_scan_outputs'],
                               name=f"Scan_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v16.Where")
def Where(condition, X, Y, **kwargs):
  """Append an opset-16 Where node to the default graph; return it."""
  _inputs = []
  for i in (condition, X, Y):
    _add_input(i, _inputs)

  idx = omm.op_counter["Where"]
  omm.op_counter["Where"] += 1
  node = onnx.helper.make_node("Where",
                               _inputs, [f'_t_Where_{idx}_output'],
                               name=f"Where_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v16.GreaterOrEqual")
def GreaterOrEqual(A, B, **kwargs):
  """Append an opset-16 GreaterOrEqual node to the default graph; return it."""
  _inputs = []
  for i in (A, B):
    _add_input(i, _inputs)

  idx = omm.op_counter["GreaterOrEqual"]
  omm.op_counter["GreaterOrEqual"] += 1
  node = onnx.helper.make_node("GreaterOrEqual",
                               _inputs, [f'_t_GreaterOrEqual_{idx}_C'],
                               name=f"GreaterOrEqual_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v16.Loop")
def Loop(M, cond, v_initial=None, **kwargs):
  """Append an opset-16 Loop node to the default graph; return it."""
  _inputs = []
  for i in (M, cond, v_initial):
    _add_input(i, _inputs)

  idx = omm.op_counter["Loop"]
  omm.op_counter["Loop"] += 1
  node = onnx.helper.make_node("Loop",
                               _inputs, [f'_t_Loop_{idx}_v_final_and_scan_outputs'],
                               name=f"Loop_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v16.LeakyRelu")
def LeakyRelu(X, **kwargs):
  """Append an opset-16 LeakyRelu node to the default graph; return it."""
  _inputs = []
  for i in (X, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["LeakyRelu"]
  omm.op_counter["LeakyRelu"] += 1
  node = onnx.helper.make_node("LeakyRelu",
                               _inputs, [f'_t_LeakyRelu_{idx}_Y'],
                               name=f"LeakyRelu_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v16.If")
def If(cond, **kwargs):
  """Append an opset-16 If node to the default graph; return it."""
  _inputs = []
  for i in (cond, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["If"]
  omm.op_counter["If"] += 1
  node = onnx.helper.make_node("If",
                               _inputs, [f'_t_If_{idx}_outputs'],
                               name=f"If_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v16.GridSample")
def GridSample(X, grid, **kwargs):
  """Append an opset-16 GridSample node to the default graph; return it."""
  _inputs = []
  for i in (X, grid):
    _add_input(i, _inputs)

  idx = omm.op_counter["GridSample"]
  omm.op_counter["GridSample"] += 1
  node = onnx.helper.make_node("GridSample",
                               _inputs, [f'_t_GridSample_{idx}_Y'],
                               name=f"GridSample_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v16.Identity")
def Identity(input, **kwargs):
  """Append an opset-16 Identity node to the default graph; return it."""
  _inputs = []
  for i in (input, ):
    _add_input(i, _inputs)

  idx = omm.op_counter["Identity"]
  omm.op_counter["Identity"] += 1
  node = onnx.helper.make_node("Identity",
                               _inputs, [f'_t_Identity_{idx}_output'],
                               name=f"Identity_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node


@onnx_mm_export("v16.PRelu")
def PRelu(X, slope, **kwargs):
  """Append an opset-16 PRelu node to the default graph; return it."""
  _inputs = []
  for i in (X, slope):
    _add_input(i, _inputs)

  idx = omm.op_counter["PRelu"]
  omm.op_counter["PRelu"] += 1
  node = onnx.helper.make_node("PRelu",
                               _inputs, [f'_t_PRelu_{idx}_Y'],
                               name=f"PRelu_{idx}",
                               **kwargs)
  onnx.checker.check_node(node, omm.ctx)
  omm.model.graph.node.append(node)
  return node
--------------------------------------------------------------------------------
/onnx_model_maker/ops/op_ver_17.py:
--------------------------------------------------------------------------------
# Autogenerated by onnx-model-maker. Don't modify it manually.
2 | 3 | import onnx 4 | import onnx.helper 5 | import onnx.numpy_helper 6 | from onnx_model_maker import omm 7 | from onnx_model_maker import onnx_mm_export 8 | from onnx_model_maker.ops.op_helper import _add_input 9 | 10 | 11 | @onnx_mm_export("v17.STFT") 12 | def STFT(signal, frame_step, window=None, frame_length=None, **kwargs): 13 | _inputs = [] 14 | for i in (signal, frame_step, window, frame_length): 15 | _add_input(i, _inputs) 16 | 17 | idx = omm.op_counter["STFT"] 18 | omm.op_counter["STFT"] += 1 19 | node = onnx.helper.make_node("STFT", 20 | _inputs, [f'_t_STFT_{idx}_output'], 21 | name=f"STFT_{idx}", 22 | **kwargs) 23 | onnx.checker.check_node(node, omm.ctx) 24 | omm.model.graph.node.append(node) 25 | return node 26 | 27 | 28 | @onnx_mm_export("v17.MelWeightMatrix") 29 | def MelWeightMatrix(num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz, **kwargs): 30 | _inputs = [] 31 | for i in (num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz): 32 | _add_input(i, _inputs) 33 | 34 | idx = omm.op_counter["MelWeightMatrix"] 35 | omm.op_counter["MelWeightMatrix"] += 1 36 | node = onnx.helper.make_node("MelWeightMatrix", 37 | _inputs, [f'_t_MelWeightMatrix_{idx}_output'], 38 | name=f"MelWeightMatrix_{idx}", 39 | **kwargs) 40 | onnx.checker.check_node(node, omm.ctx) 41 | omm.model.graph.node.append(node) 42 | return node 43 | 44 | 45 | @onnx_mm_export("v17.SequenceMap") 46 | def SequenceMap(input_sequence, additional_inputs=None, **kwargs): 47 | _inputs = [] 48 | for i in (input_sequence, additional_inputs): 49 | _add_input(i, _inputs) 50 | 51 | idx = omm.op_counter["SequenceMap"] 52 | omm.op_counter["SequenceMap"] += 1 53 | node = onnx.helper.make_node("SequenceMap", 54 | _inputs, [f'_t_SequenceMap_{idx}_out_sequence'], 55 | name=f"SequenceMap_{idx}", 56 | **kwargs) 57 | onnx.checker.check_node(node, omm.ctx) 58 | omm.model.graph.node.append(node) 59 | return node 60 | 61 | 62 | @onnx_mm_export("v17.HannWindow") 63 | def 
HannWindow(size, **kwargs): 64 | _inputs = [] 65 | for i in (size, ): 66 | _add_input(i, _inputs) 67 | 68 | idx = omm.op_counter["HannWindow"] 69 | omm.op_counter["HannWindow"] += 1 70 | node = onnx.helper.make_node("HannWindow", 71 | _inputs, [f'_t_HannWindow_{idx}_output'], 72 | name=f"HannWindow_{idx}", 73 | **kwargs) 74 | onnx.checker.check_node(node, omm.ctx) 75 | omm.model.graph.node.append(node) 76 | return node 77 | 78 | 79 | @onnx_mm_export("v17.BlackmanWindow") 80 | def BlackmanWindow(size, **kwargs): 81 | _inputs = [] 82 | for i in (size, ): 83 | _add_input(i, _inputs) 84 | 85 | idx = omm.op_counter["BlackmanWindow"] 86 | omm.op_counter["BlackmanWindow"] += 1 87 | node = onnx.helper.make_node("BlackmanWindow", 88 | _inputs, [f'_t_BlackmanWindow_{idx}_output'], 89 | name=f"BlackmanWindow_{idx}", 90 | **kwargs) 91 | onnx.checker.check_node(node, omm.ctx) 92 | omm.model.graph.node.append(node) 93 | return node 94 | 95 | 96 | @onnx_mm_export("v17.LayerNormalization") 97 | def LayerNormalization(X, Scale, B=None, **kwargs): 98 | _inputs = [] 99 | for i in (X, Scale, B): 100 | _add_input(i, _inputs) 101 | 102 | idx = omm.op_counter["LayerNormalization"] 103 | omm.op_counter["LayerNormalization"] += 1 104 | node = onnx.helper.make_node("LayerNormalization", 105 | _inputs, [f'_t_LayerNormalization_{idx}_Y', f'_t_LayerNormalization_{idx}_Mean', f'_t_LayerNormalization_{idx}_InvStdDev'], 106 | name=f"LayerNormalization_{idx}", 107 | **kwargs) 108 | onnx.checker.check_node(node, omm.ctx) 109 | omm.model.graph.node.append(node) 110 | return node 111 | 112 | 113 | @onnx_mm_export("v17.HammingWindow") 114 | def HammingWindow(size, **kwargs): 115 | _inputs = [] 116 | for i in (size, ): 117 | _add_input(i, _inputs) 118 | 119 | idx = omm.op_counter["HammingWindow"] 120 | omm.op_counter["HammingWindow"] += 1 121 | node = onnx.helper.make_node("HammingWindow", 122 | _inputs, [f'_t_HammingWindow_{idx}_output'], 123 | name=f"HammingWindow_{idx}", 124 | **kwargs) 125 | 
onnx.checker.check_node(node, omm.ctx) 126 | omm.model.graph.node.append(node) 127 | return node 128 | 129 | 130 | @onnx_mm_export("v17.DFT") 131 | def DFT(input, dft_length=None, **kwargs): 132 | _inputs = [] 133 | for i in (input, dft_length): 134 | _add_input(i, _inputs) 135 | 136 | idx = omm.op_counter["DFT"] 137 | omm.op_counter["DFT"] += 1 138 | node = onnx.helper.make_node("DFT", 139 | _inputs, [f'_t_DFT_{idx}_output'], 140 | name=f"DFT_{idx}", 141 | **kwargs) 142 | onnx.checker.check_node(node, omm.ctx) 143 | omm.model.graph.node.append(node) 144 | return node 145 | -------------------------------------------------------------------------------- /onnx_model_maker/ops/op_ver_2.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by onnx-model-maker. Don't modify it manually. 2 | 3 | import onnx 4 | import onnx.helper 5 | import onnx.numpy_helper 6 | from onnx_model_maker import omm 7 | from onnx_model_maker import onnx_mm_export 8 | from onnx_model_maker.ops.op_helper import _add_input 9 | 10 | 11 | @onnx_mm_export("v2.LabelEncoder") 12 | def LabelEncoder(X, **kwargs): 13 | _inputs = [] 14 | for i in (X, ): 15 | _add_input(i, _inputs) 16 | 17 | idx = omm.op_counter["LabelEncoder"] 18 | omm.op_counter["LabelEncoder"] += 1 19 | node = onnx.helper.make_node("LabelEncoder", 20 | _inputs, [f'_t_LabelEncoder_{idx}_Y'], 21 | name=f"LabelEncoder_{idx}", 22 | **kwargs) 23 | onnx.checker.check_node(node, omm.ctx) 24 | omm.model.graph.node.append(node) 25 | return node 26 | 27 | 28 | @onnx_mm_export("v2.LpPool") 29 | def LpPool(X, **kwargs): 30 | _inputs = [] 31 | for i in (X, ): 32 | _add_input(i, _inputs) 33 | 34 | idx = omm.op_counter["LpPool"] 35 | omm.op_counter["LpPool"] += 1 36 | node = onnx.helper.make_node("LpPool", 37 | _inputs, [f'_t_LpPool_{idx}_Y'], 38 | name=f"LpPool_{idx}", 39 | **kwargs) 40 | onnx.checker.check_node(node, omm.ctx) 41 | omm.model.graph.node.append(node) 42 | return node 43 | 44 | 45 | 
@onnx_mm_export("v2.Split") 46 | def Split(input, **kwargs): 47 | _inputs = [] 48 | for i in (input, ): 49 | _add_input(i, _inputs) 50 | 51 | idx = omm.op_counter["Split"] 52 | omm.op_counter["Split"] += 1 53 | node = onnx.helper.make_node("Split", 54 | _inputs, [f"_t_Split_{idx}_{i}" for i in range(len(kwargs["split"]))], 55 | name=f"Split_{idx}", 56 | **kwargs) 57 | onnx.checker.check_node(node, omm.ctx) 58 | omm.model.graph.node.append(node) 59 | return node 60 | 61 | 62 | @onnx_mm_export("v2.Pad") 63 | def Pad(data, **kwargs): 64 | _inputs = [] 65 | for i in (data, ): 66 | _add_input(i, _inputs) 67 | 68 | idx = omm.op_counter["Pad"] 69 | omm.op_counter["Pad"] += 1 70 | node = onnx.helper.make_node("Pad", 71 | _inputs, [f'_t_Pad_{idx}_output'], 72 | name=f"Pad_{idx}", 73 | **kwargs) 74 | onnx.checker.check_node(node, omm.ctx) 75 | omm.model.graph.node.append(node) 76 | return node 77 | 78 | 79 | @onnx_mm_export("v2.GlobalLpPool") 80 | def GlobalLpPool(X, **kwargs): 81 | _inputs = [] 82 | for i in (X, ): 83 | _add_input(i, _inputs) 84 | 85 | idx = omm.op_counter["GlobalLpPool"] 86 | omm.op_counter["GlobalLpPool"] += 1 87 | node = onnx.helper.make_node("GlobalLpPool", 88 | _inputs, [f'_t_GlobalLpPool_{idx}_Y'], 89 | name=f"GlobalLpPool_{idx}", 90 | **kwargs) 91 | onnx.checker.check_node(node, omm.ctx) 92 | omm.model.graph.node.append(node) 93 | return node 94 | -------------------------------------------------------------------------------- /onnx_model_maker/ops/op_ver_3.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by onnx-model-maker. Don't modify it manually. 
2 | 3 | import onnx 4 | import onnx.helper 5 | import onnx.numpy_helper 6 | from onnx_model_maker import omm 7 | from onnx_model_maker import onnx_mm_export 8 | from onnx_model_maker.ops.op_helper import _add_input 9 | 10 | 11 | @onnx_mm_export("v3.TreeEnsembleClassifier") 12 | def TreeEnsembleClassifier(X, **kwargs): 13 | _inputs = [] 14 | for i in (X, ): 15 | _add_input(i, _inputs) 16 | 17 | idx = omm.op_counter["TreeEnsembleClassifier"] 18 | omm.op_counter["TreeEnsembleClassifier"] += 1 19 | node = onnx.helper.make_node("TreeEnsembleClassifier", 20 | _inputs, [f'_t_TreeEnsembleClassifier_{idx}_Y', f'_t_TreeEnsembleClassifier_{idx}_Z'], 21 | name=f"TreeEnsembleClassifier_{idx}", 22 | **kwargs) 23 | onnx.checker.check_node(node, omm.ctx) 24 | omm.model.graph.node.append(node) 25 | return node 26 | 27 | 28 | @onnx_mm_export("v3.TreeEnsembleRegressor") 29 | def TreeEnsembleRegressor(X, **kwargs): 30 | _inputs = [] 31 | for i in (X, ): 32 | _add_input(i, _inputs) 33 | 34 | idx = omm.op_counter["TreeEnsembleRegressor"] 35 | omm.op_counter["TreeEnsembleRegressor"] += 1 36 | node = onnx.helper.make_node("TreeEnsembleRegressor", 37 | _inputs, [f'_t_TreeEnsembleRegressor_{idx}_Y'], 38 | name=f"TreeEnsembleRegressor_{idx}", 39 | **kwargs) 40 | onnx.checker.check_node(node, omm.ctx) 41 | omm.model.graph.node.append(node) 42 | return node 43 | 44 | 45 | @onnx_mm_export("v3.GRU") 46 | def GRU(X, W, R, B=None, sequence_lens=None, initial_h=None, **kwargs): 47 | _inputs = [] 48 | for i in (X, W, R, B, sequence_lens, initial_h): 49 | _add_input(i, _inputs) 50 | 51 | idx = omm.op_counter["GRU"] 52 | omm.op_counter["GRU"] += 1 53 | node = onnx.helper.make_node("GRU", 54 | _inputs, [f'_t_GRU_{idx}_Y', f'_t_GRU_{idx}_Y_h'], 55 | name=f"GRU_{idx}", 56 | **kwargs) 57 | onnx.checker.check_node(node, omm.ctx) 58 | omm.model.graph.node.append(node) 59 | return node 60 | -------------------------------------------------------------------------------- /onnx_model_maker/ops/op_ver_4.py: 
-------------------------------------------------------------------------------- 1 | # Autogenerated by onnx-model-maker. Don't modify it manually. 2 | 3 | import onnx 4 | import onnx.helper 5 | import onnx.numpy_helper 6 | from onnx_model_maker import omm 7 | from onnx_model_maker import onnx_mm_export 8 | from onnx_model_maker.ops.op_helper import _add_input 9 | 10 | 11 | @onnx_mm_export("v4.Concat") 12 | def Concat(inputs, **kwargs): 13 | _inputs = [] 14 | for i in (inputs, ): 15 | _add_input(i, _inputs) 16 | 17 | idx = omm.op_counter["Concat"] 18 | omm.op_counter["Concat"] += 1 19 | node = onnx.helper.make_node("Concat", 20 | _inputs, [f'_t_Concat_{idx}_concat_result'], 21 | name=f"Concat_{idx}", 22 | **kwargs) 23 | onnx.checker.check_node(node, omm.ctx) 24 | omm.model.graph.node.append(node) 25 | return node 26 | -------------------------------------------------------------------------------- /onnx_model_maker/ops/op_ver_5.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by onnx-model-maker. Don't modify it manually. 
2 | 3 | import onnx 4 | import onnx.helper 5 | import onnx.numpy_helper 6 | from onnx_model_maker import omm 7 | from onnx_model_maker import onnx_mm_export 8 | from onnx_model_maker.ops.op_helper import _add_input 9 | 10 | 11 | @onnx_mm_export("v5.Reshape") 12 | def Reshape(data, shape, **kwargs): 13 | _inputs = [] 14 | for i in (data, shape): 15 | _add_input(i, _inputs) 16 | 17 | idx = omm.op_counter["Reshape"] 18 | omm.op_counter["Reshape"] += 1 19 | node = onnx.helper.make_node("Reshape", 20 | _inputs, [f'_t_Reshape_{idx}_reshaped'], 21 | name=f"Reshape_{idx}", 22 | **kwargs) 23 | onnx.checker.check_node(node, omm.ctx) 24 | omm.model.graph.node.append(node) 25 | return node 26 | -------------------------------------------------------------------------------- /onnx_model_maker/ops/op_ver_8.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by onnx-model-maker. Don't modify it manually. 2 | 3 | import onnx 4 | import onnx.helper 5 | import onnx.numpy_helper 6 | from onnx_model_maker import omm 7 | from onnx_model_maker import onnx_mm_export 8 | from onnx_model_maker.ops.op_helper import _add_input 9 | 10 | 11 | @onnx_mm_export("v8.Mean") 12 | def Mean(data_0, **kwargs): 13 | _inputs = [] 14 | for i in (data_0, ): 15 | _add_input(i, _inputs) 16 | 17 | idx = omm.op_counter["Mean"] 18 | omm.op_counter["Mean"] += 1 19 | node = onnx.helper.make_node("Mean", 20 | _inputs, [f'_t_Mean_{idx}_mean'], 21 | name=f"Mean_{idx}", 22 | **kwargs) 23 | onnx.checker.check_node(node, omm.ctx) 24 | omm.model.graph.node.append(node) 25 | return node 26 | 27 | 28 | @onnx_mm_export("v8.Max") 29 | def Max(data_0, **kwargs): 30 | _inputs = [] 31 | for i in (data_0, ): 32 | _add_input(i, _inputs) 33 | 34 | idx = omm.op_counter["Max"] 35 | omm.op_counter["Max"] += 1 36 | node = onnx.helper.make_node("Max", 37 | _inputs, [f'_t_Max_{idx}_max'], 38 | name=f"Max_{idx}", 39 | **kwargs) 40 | onnx.checker.check_node(node, omm.ctx) 41 | 
omm.model.graph.node.append(node) 42 | return node 43 | 44 | 45 | @onnx_mm_export("v8.Scan") 46 | def Scan(sequence_lens, initial_state_and_scan_inputs, **kwargs): 47 | _inputs = [] 48 | for i in (sequence_lens, initial_state_and_scan_inputs): 49 | _add_input(i, _inputs) 50 | 51 | idx = omm.op_counter["Scan"] 52 | omm.op_counter["Scan"] += 1 53 | node = onnx.helper.make_node("Scan", 54 | _inputs, [f'_t_Scan_{idx}_final_state_and_scan_outputs'], 55 | name=f"Scan_{idx}", 56 | **kwargs) 57 | onnx.checker.check_node(node, omm.ctx) 58 | omm.model.graph.node.append(node) 59 | return node 60 | 61 | 62 | @onnx_mm_export("v8.Expand") 63 | def Expand(input, shape, **kwargs): 64 | _inputs = [] 65 | for i in (input, shape): 66 | _add_input(i, _inputs) 67 | 68 | idx = omm.op_counter["Expand"] 69 | omm.op_counter["Expand"] += 1 70 | node = onnx.helper.make_node("Expand", 71 | _inputs, [f'_t_Expand_{idx}_output'], 72 | name=f"Expand_{idx}", 73 | **kwargs) 74 | onnx.checker.check_node(node, omm.ctx) 75 | omm.model.graph.node.append(node) 76 | return node 77 | 78 | 79 | @onnx_mm_export("v8.Sum") 80 | def Sum(data_0, **kwargs): 81 | _inputs = [] 82 | for i in (data_0, ): 83 | _add_input(i, _inputs) 84 | 85 | idx = omm.op_counter["Sum"] 86 | omm.op_counter["Sum"] += 1 87 | node = onnx.helper.make_node("Sum", 88 | _inputs, [f'_t_Sum_{idx}_sum'], 89 | name=f"Sum_{idx}", 90 | **kwargs) 91 | onnx.checker.check_node(node, omm.ctx) 92 | omm.model.graph.node.append(node) 93 | return node 94 | 95 | 96 | @onnx_mm_export("v8.Min") 97 | def Min(data_0, **kwargs): 98 | _inputs = [] 99 | for i in (data_0, ): 100 | _add_input(i, _inputs) 101 | 102 | idx = omm.op_counter["Min"] 103 | omm.op_counter["Min"] += 1 104 | node = onnx.helper.make_node("Min", 105 | _inputs, [f'_t_Min_{idx}_min'], 106 | name=f"Min_{idx}", 107 | **kwargs) 108 | onnx.checker.check_node(node, omm.ctx) 109 | omm.model.graph.node.append(node) 110 | return node 111 | 112 | 113 | @onnx_mm_export("v8.MaxPool") 114 | def MaxPool(X, 
class RenameHelper:
  """Maps ONNX tensor/node names onto Python-safe identifiers.

  In plain mode, purely numeric tensor names receive a ``t_`` prefix and
  node names a ``n_`` prefix; in simplify mode every name is replaced by a
  short, sequentially numbered alias (``t_0``, ``n_Conv_0``, ...).
  """

  def __init__(self, simplify_names=False):
    self.simplify_names = simplify_names

    # original name -> generated alias, plus occurrence counters that the
    # caller maintains to detect collisions after sanitisation.
    self.tensor_name_mapping = {}
    self.tensor_name_counter = Counter()
    self.node_name_mapping = {}
    self.node_name_counter = Counter()

    # Sequence numbers used by the simplify-mode generators.
    self.tensor_counter = 0
    self.node_counter = Counter()

    self.init_name_set = set()
    self.sim_tensor_name_set = set()

  def get_tensor_name(self, tensor_name):
    """Return the Python-safe name for ``tensor_name``."""
    if self.simplify_names:
      return self.get_simplify_tensor_name(tensor_name)
    if not tensor_name.isnumeric():
      return tensor_name
    # A pure-digit name is not a valid Python identifier; prefix it.
    prefixed = f"t_{tensor_name}"
    self.tensor_name_mapping[tensor_name] = prefixed
    return prefixed

  def get_node_name(self, node_name, op_type):
    """Return the Python-safe name for a node, generating one if unnamed."""
    if not self.simplify_names and node_name:
      return f"n_{node_name}"
    return self.get_simplify_node_name(node_name, op_type)

  def get_simplify_node_name(self, node_name, op_type):
    """Generate a sequential ``n_<op_type>_<i>`` alias for ``node_name``."""
    seq = self.node_counter[op_type]
    self.node_counter[op_type] = seq + 1
    alias = f"n_{op_type}_{seq}"
    self.node_name_mapping[node_name] = alias
    return alias

  def get_simplify_tensor_name(self, tensor_name):
    """Generate (or reuse) a sequential ``t_<i>`` alias for ``tensor_name``."""
    alias = self.tensor_name_mapping.get(tensor_name)
    if alias is None:
      alias = f"t_{self.tensor_counter}"
      self.tensor_counter += 1
      self.sim_tensor_name_set.add(alias)
      self.tensor_name_mapping[tensor_name] = alias
    return alias
m): 93 | if type(m) in (list, tuple, set): 94 | self.forward_parts.extend(m) 95 | else: 96 | self.forward_parts.append(m) 97 | 98 | def add_forward_return(self, outputs_value_infos): 99 | return_list = [ 100 | f"{self.rename_helper.get_tensor_name(o.name)}" 101 | for o in outputs_value_infos 102 | ] 103 | self.forward_parts.append(f"return {', '.join(return_list)}") 104 | 105 | def add_forward_input(self, inputs_value_infos): 106 | initializer_names = {i.name for i in self.onnx_model.graph.initializer} 107 | return_list = [ 108 | f"{self.rename_helper.get_tensor_name(i.name)}" 109 | for i in inputs_value_infos 110 | if i.name not in initializer_names 111 | ] 112 | if len(return_list) == 1: 113 | self.forward_parts.append(f"{return_list[0]}, = inputs") 114 | else: 115 | self.forward_parts.append(f"{', '.join(return_list)} = inputs") 116 | 117 | def gen_model_code(self): 118 | return CodeGenTemplate.model(model_init=''' 119 | '''.join(self.init_parts), 120 | model_forward=''' 121 | '''.join(self.forward_parts), 122 | model_method=''' 123 | '''.join(self.method_parts.values()), 124 | test_run_model=self.gen_test_run_model_code()) 125 | 126 | def gen_test_run_model_code(self): 127 | numpy_input_str = [] 128 | initializer_names = {i.name for i in self.onnx_model.graph.initializer} 129 | for i in self.onnx_model.graph.input: 130 | if i.name in initializer_names: 131 | continue 132 | dtype = TENSOR_TYPE_TO_NP_TYPE[i.type.tensor_type.elem_type] 133 | shape = [] 134 | for d in i.type.tensor_type.shape.dim: 135 | if d.dim_param != "": 136 | shape.append(1) 137 | else: 138 | shape.append(d.dim_value) 139 | if shape: 140 | numpy_input_str.append( 141 | f"torch.from_numpy(np.random.randn(*{[s if s > 1 else 1 for s in shape].__repr__()}).astype(np.{dtype.name}))" 142 | ) 143 | else: 144 | numpy_input_str.append( 145 | f"torch.from_numpy(np.random.randn(1).astype(np.{dtype.name}))") 146 | test_run_model = [ 147 | f'''@torch.no_grad() 148 | def test_run_model(inputs=[{', 
'.join(numpy_input_str)}]):''', 149 | "model = Model()", "model.eval()" 150 | ] 151 | test_run_model.extend(["rs = model(*inputs)", "print(rs)", "return rs"]) 152 | return ''' 153 | '''.join(test_run_model) 154 | 155 | def preprocess_onnx_model(self): 156 | for n in self.onnx_model.graph.node: 157 | inputs, outputs = [], [] 158 | for ls, f in ((inputs, n.input), (outputs, n.output)): 159 | for i in f: 160 | new_i = re.sub("[:/.]", "_", i) 161 | ls.append(new_i) 162 | if i != ls[-1] and not self.rename_helper.simplify_names: 163 | logging.info(f"Tensor name {i} is changed to {ls[-1]}.") 164 | self.rename_helper.tensor_name_counter[ls[-1]] += 1 165 | 166 | n.ClearField("input") 167 | n.input.extend(inputs) 168 | n.ClearField("output") 169 | n.output.extend(outputs) 170 | 171 | old_name = n.name 172 | n.name = re.sub("[:/.]", "_", n.name) 173 | if old_name != n.name and not self.rename_helper.simplify_names: 174 | logging.info(f"Node name {old_name} is changed to {n.name}.") 175 | self.rename_helper.node_name_counter[n.name] += 1 176 | 177 | for f in (self.onnx_model.graph.input, self.onnx_model.graph.output, 178 | self.onnx_model.graph.initializer): 179 | for i in f: 180 | old_name = i.name 181 | i.name = re.sub("[:/.]", "_", i.name) 182 | if old_name != i.name and not self.rename_helper.simplify_names: 183 | logging.info(f"Tensor name {i.name} is changed to {i.name}.") 184 | self.rename_helper.tensor_name_counter[i.name] += 1 185 | 186 | model = self.onnx_model 187 | for f in (model.graph.input, model.graph.output): 188 | for i in f: 189 | for d in i.type.tensor_type.shape.dim: 190 | if d.dim_param != "": 191 | d.dim_param = "" 192 | d.dim_value = -1 193 | elif d.dim_value == 0: 194 | d.dim_value = -1 195 | # TODO how to deal with custom op? 
196 | if self.shape_infer: 197 | try: 198 | model.graph.ClearField("value_info") 199 | model = SymbolicShapeInference.infer_shapes(model, 2**31 - 1, True, 200 | True, 1) 201 | except: 202 | logging.warning("Shape infer by onnxruntime failed.") 203 | else: 204 | for f in (self.onnx_model.graph.value_info,): 205 | for i in f: 206 | old_name = i.name 207 | i.name = re.sub("[:/.]", "_", i.name) 208 | if old_name != i.name and not self.rename_helper.simplify_names: 209 | logging.info(f"Tensor name {i.name} is changed to {i.name}.") 210 | self.rename_helper.tensor_name_counter[i.name] += 1 211 | onnx.save(model, os.path.join(self.output_dir, "tmp_processed.onnx")) 212 | self.onnx_model = model 213 | 214 | def add_attr_to_op_code_generator(self, op_code_gen): 215 | for k, v in { 216 | "rename_helper": self.rename_helper, 217 | "tensor_inplace": self.tensor_inplace, 218 | "embedding_conf": self.embedding_conf 219 | }.items(): 220 | if hasattr(op_code_gen, k): 221 | setattr(op_code_gen, k, v) 222 | 223 | def run(self): 224 | self.preprocess_onnx_model() 225 | initializers = {i.name: i for i in self.onnx_model.graph.initializer} 226 | input_value_infos = {i.name: i for i in self.onnx_model.graph.input} 227 | output_value_infos = {i.name: i for i in self.onnx_model.graph.output} 228 | value_infos = {} 229 | value_infos.update(input_value_infos) 230 | value_infos.update(output_value_infos) 231 | value_infos.update({i.name: i for i in self.onnx_model.graph.value_info}) 232 | 233 | for i in self.onnx_model.graph.initializer: 234 | self.rename_helper.get_tensor_name(i.name) 235 | 236 | self.add_forward_input(self.onnx_model.graph.input) 237 | for n in self.onnx_model.graph.node: 238 | op_code_gen = get_op_code_generator(n.op_type) 239 | self.add_attr_to_op_code_generator(op_code_gen) 240 | if op_code_gen is None: 241 | if self.continue_on_error: 242 | self.add_forward_part(n.__repr__()) 243 | logging.warning(f"OpCodeGenerator is unimplemented for {n.op_type}. 
" 244 | "Please modify this part by manual later.") 245 | else: 246 | raise NotImplementedError( 247 | f"OpCodeGenerator is unimplemented for {n.op_type}.") 248 | else: 249 | try: 250 | if hasattr(op_code_gen, 251 | "gen_method") and n.op_type not in self.method_parts: 252 | self.method_parts[n.op_type] = op_code_gen.gen_method() 253 | gened = op_code_gen.gen(n, value_infos, initializers) 254 | self.add_init_part(gened["init"]) 255 | self.add_forward_part(gened["forward"]) 256 | except BaseException as e: 257 | if self.continue_on_error: 258 | logging.warning(e) 259 | self.add_forward_part(n.__repr__()) 260 | else: 261 | raise e 262 | self.add_forward_return(self.onnx_model.graph.output) 263 | 264 | gened_code = self.gen_model_code() 265 | print(gened_code) 266 | with open(os.path.join(self.output_dir, "model.py"), "w") as f: 267 | f.write(gened_code) 268 | shutil.rmtree(os.path.join(self.output_dir, "variables"), 269 | ignore_errors=True) 270 | os.makedirs(os.path.join(self.output_dir, "variables")) 271 | for k, v in initializers.items(): 272 | np.save( 273 | os.path.join(self.output_dir, "variables", 274 | f"{self.rename_helper.get_tensor_name(k)}.npy"), 275 | to_array(v)) 276 | 277 | 278 | def gen( 279 | onnx_model, 280 | output_dir, 281 | overwrite=False, 282 | tensor_inplace=False, 283 | simplify_names=False, 284 | continue_on_error=False, 285 | embedding_conf_file=None, 286 | shape_infer=True, 287 | ): 288 | model_code_generator = get_model_code_generator( 289 | onnx_model, output_dir, overwrite, tensor_inplace, simplify_names, 290 | continue_on_error, embedding_conf_file, shape_infer) 291 | model_code_generator.run() 292 | 293 | 294 | def get_model_code_generator( 295 | onnx_model, 296 | output_dir, 297 | overwrite=False, 298 | tensor_inplace=False, 299 | simplify_names=False, 300 | continue_on_error=False, 301 | embedding_conf_file=None, 302 | shape_infer=False, 303 | ): 304 | kwargs = { 305 | "output_dir": output_dir, 306 | "simplify_names": 
simplify_names, 307 | "tensor_inplace": tensor_inplace, 308 | "continue_on_error": continue_on_error, 309 | "shape_infer": shape_infer 310 | } 311 | if type(onnx_model) == onnx.ModelProto: 312 | kwargs["onnx_model"] = onnx_model 313 | else: 314 | assert os.path.exists( 315 | onnx_model), f"ONNX model {onnx_model} does not exist." 316 | assert os.path.isfile(onnx_model), f"{onnx_model} is not a file." 317 | assert os.path.exists( 318 | output_dir 319 | ) and overwrite is not True, f"{output_dir} is not empty and overwrite is not True." 320 | assert os.path.isdir(output_dir), f"{output_dir} is not directory." 321 | kwargs["onnx_model"] = onnx.load(onnx_model) 322 | if overwrite: 323 | shutil.rmtree(output_dir, ignore_errors=True) 324 | os.makedirs(output_dir) 325 | if embedding_conf_file is not None: 326 | assert os.path.exists( 327 | embedding_conf_file 328 | ), f"Embedding config file {embedding_conf_file} does not exist." 329 | kwargs["embedding_conf"] = load_embedding_config(embedding_conf_file) 330 | return ModelCodeGenerator(**kwargs) 331 | 332 | 333 | def main(): 334 | debug = True 335 | parser = argparse.ArgumentParser() 336 | parser.add_argument("--onnx_model_path", 337 | default=None, 338 | type=str, 339 | required=not debug, 340 | help="The ONNX model path.") 341 | parser.add_argument("--output_dir", 342 | default=None, 343 | type=str, 344 | required=not debug, 345 | help="The output dir") 346 | parser.add_argument("--overwrite", 347 | default=False, 348 | type=bool, 349 | help="Should overwrite the output dir.") 350 | parser.add_argument("--tensor_inplace", 351 | default=False, 352 | type=bool, 353 | help="Try best to inplace tensor.") 354 | parser.add_argument("--continue_on_error", 355 | default=False, 356 | type=bool, 357 | help="Continue on error.") 358 | parser.add_argument("--embedding_conf_file", 359 | type=str, 360 | help="Embedding config file path.") 361 | parser.add_argument( 362 | "--simplify_names", 363 | default=False, 364 | type=int, 365 | 
help="Use indexing shorten name instead of original name.") 366 | args = parser.parse_args() 367 | 368 | gen(onnx_model=args.onnx_model_path, 369 | output_dir=args.output_dir, 370 | overwrite=args.overwrite, 371 | tensor_inplace=args.tensor_inplace, 372 | simplify_names=args.simplify_names, 373 | continue_on_error=args.continue_on_error, 374 | embedding_conf_file=args.embedding_conf_file) 375 | 376 | 377 | if __name__ == '__main__': 378 | main() 379 | -------------------------------------------------------------------------------- /onnx_pytorch/code_gen_template.py: -------------------------------------------------------------------------------- 1 | class CodeGenTemplate: 2 | 3 | @classmethod 4 | def autogen_head(cls): 5 | return '''# Autogenerated by onnx-pytorch. 6 | ''' 7 | 8 | @classmethod 9 | def imports(cls): 10 | return '''import glob 11 | import os 12 | import math 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torchvision 19 | ''' 20 | 21 | @classmethod 22 | def model(cls, model_init, model_forward, model_method, test_run_model): 23 | return f'''{cls.autogen_head()} 24 | {cls.imports()} 25 | 26 | class Model(nn.Module): 27 | def __init__(self): 28 | super(Model, self).__init__() 29 | self._vars = nn.ParameterDict() 30 | self._regularizer_params = [] 31 | for b in glob.glob( 32 | os.path.join(os.path.dirname(__file__), "variables", "*.npy")): 33 | v = torch.from_numpy(np.load(b)) 34 | requires_grad = v.dtype.is_floating_point or v.dtype.is_complex 35 | self._vars[os.path.basename(b)[:-4]] = nn.Parameter(v, requires_grad=requires_grad) 36 | {model_init} 37 | 38 | def forward(self, *inputs): 39 | {model_forward} 40 | 41 | {model_method} 42 | {test_run_model} 43 | ''' 44 | -------------------------------------------------------------------------------- /onnx_pytorch/op_code_generators/Abs.py: -------------------------------------------------------------------------------- 1 | import 
class AcosOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX Acos operator via ``torch.acos``."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(AcosOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``."""
    inputs, outputs = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    call_args = ', '.join(inputs)
    return {
        "init": [],
        "forward": [f"{outputs[0]} = torch.acos({call_args})"]
    }
class AddOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX Add operator via ``torch.add``."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(AddOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``."""
    inputs, outputs = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    call_args = ', '.join(inputs)
    return {
        "init": [],
        "forward": [f"{outputs[0]} = torch.add({call_args})"]
    }
class ArgMaxOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX ArgMax operator via ``torch.argmax``."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ArgMaxOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``.

    Raises:
      NotImplementedError: if the node sets ``select_last_index=1`` —
        ``torch.argmax`` always returns the first occurrence.
    """
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    # The original `assert cond, NotImplementedError` raised AssertionError
    # (with the exception *class* as the message) and disappeared under
    # `python -O`; raise the intended exception explicitly instead.
    if attr_value_dict.get("select_last_index", 0) != 0:
      raise NotImplementedError(
          "ArgMax with select_last_index=1 is not supported.")
    params_str = self.gen_params_str(keepdim=bool(
        attr_value_dict.get("keepdims", 1)),
                                     axis=attr_value_dict.get("axis", 0))
    forward_str.append(
        f"{outputs_str[0]} = torch.argmax({', '.join(inputs_str)}, **{{{params_str}}})"
    )
    return {"init": init_str, "forward": forward_str}
class AsinOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX Asin operator via ``torch.asin``."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(AsinOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``."""
    inputs, outputs = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    call_args = ', '.join(inputs)
    return {
        "init": [],
        "forward": [f"{outputs[0]} = torch.asin({call_args})"]
    }
class AtanOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX Atan operator via ``torch.atan``."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(AtanOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``."""
    inputs, outputs = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    call_args = ', '.join(inputs)
    return {
        "init": [],
        "forward": [f"{outputs[0]} = torch.atan({call_args})"]
    }
class AveragePoolOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX AveragePool operator (nn.AvgPoolNd)."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(AveragePoolOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``.

    ONNX allows asymmetric pads (begin/end per spatial axis) while
    nn.AvgPoolNd only supports symmetric padding, so the generated module
    pads with the larger of the two and the forward line slices the excess
    back off the output.
    """
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)

    # Number of spatial dims: input rank minus batch and channel dims.
    # Requires a shape in value_infos — assumes shape inference ran; TODO confirm.
    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim) - 2
    assert (d in (1, 2, 3))

    nn_name = f"AvgPool{d}d"
    node_name = self.rename_helper.get_node_name(node.name, node.op_type)
    init_str, forward_str = [], []
    # Keep batch and channel dims untouched; spatial slices appended below.
    slice_str = [":", ":"]

    param = {
        "kernel_size": attr_value_dict["kernel_shape"][:].__repr__(),
        "ceil_mode": bool(attr_value_dict["ceil_mode"]),
        "stride": attr_value_dict.get("strides", 1),
        "count_include_pad": bool(attr_value_dict.get("count_include_pad", 0))
    }
    if "pads" in attr_value_dict:
      # ONNX "pads" layout is [x1_begin, x2_begin, ..., x1_end, x2_end, ...]:
      # pads[i] is the begin pad and pads[i + d] the end pad of axis i.
      padding = []
      for i in range(d):
        # Symmetric padding applied by the module: the max of begin/end.
        padding_size = max(attr_value_dict['pads'][i],
                           attr_value_dict['pads'][i + d])
        padding.append(padding_size)
        # Trim the over-padded side from the output. An empty string keeps
        # the full extent; the end bound is negative (counting from the end).
        slice_begin = "" if padding_size == attr_value_dict['pads'][i] else str(
            padding_size - attr_value_dict['pads'][i])
        slice_end = "" if padding_size == attr_value_dict['pads'][
            i + d] else str(attr_value_dict['pads'][i + d] - padding_size)
        slice_str.append(":".join([slice_begin, slice_end]))
      param["padding"] = padding.__repr__()
    params_str = self.gen_params_str(**param,)
    init_str.append(f"self.{node_name} = nn.{nn_name}(**{{{params_str}}})")
    forward_str.append(
        f"{outputs_str[0]} = self.{node_name}({inputs_str[0]})[{', '.join(slice_str)}]"
    )

    return {"init": init_str, "forward": forward_str}
class BitShiftOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX BitShift operator.

  Generates a plain ``<<`` or ``>>`` expression, chosen by the node's
  ``direction`` attribute (bytes ``b"LEFT"`` / ``b"RIGHT"``).
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(BitShiftOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``."""
    attrs = self.get_attr_value_dict(node)
    inputs, outputs = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    op = "<<" if attrs["direction"] == b"LEFT" else ">>"
    shift_expr = f" {op} ".join(inputs)
    return {
        "init": [],
        "forward": [f"{outputs[0]} = {shift_expr}"]
    }
class CastOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX Cast operator.

  Maps the ONNX ``to`` attribute (a TensorProto dtype enum) to a numpy dtype
  whose name doubles as the matching ``torch.<dtype>`` attribute.
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(CastOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``."""
    attrs = self.get_attr_value_dict(node)
    inputs, outputs = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    # f-string interpolation applies str() to the numpy dtype, yielding e.g.
    # "float32", which names the corresponding torch dtype.
    target_dtype = TENSOR_TYPE_TO_NP_TYPE[attrs['to']]
    src = inputs[0]
    forward = [
        f"{outputs[0]} = {src}.to(device={src}.device, dtype=torch.{target_dtype}, copy=True)"
    ]
    return {"init": [], "forward": forward}
class ClipOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX Clip operator via ``torch.clip``.

  Older opsets carry the bounds as ``min``/``max`` attributes; newer opsets
  pass them as optional second/third inputs. Missing bounds fall back to
  literal ``float("-inf")`` / ``float("inf")`` expressions in the generated
  code.
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ClipOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``."""
    attrs = self.get_attr_value_dict(node)
    inputs, outputs = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    lower = attrs.get("min", 'float("-inf")')
    upper = attrs.get("max", 'float("inf")')
    # Pad the argument list out to (data, min, max) when the bounds were not
    # supplied as inputs.
    if len(inputs) == 1:
      inputs.append(str(lower))
    if len(inputs) < 3:
      inputs.append(str(upper))
    call_args = ', '.join(inputs)
    return {
        "init": [],
        "forward": [f"{outputs[0]} = torch.clip({call_args})"]
    }
class ConstantOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX Constant operator.

  The tensor payload is stashed into ``initializers`` under the node's output
  name; the generated forward line then reads it back out of ``self._vars``
  at runtime.
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ConstantOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``."""
    attrs = self.get_attr_value_dict(node)
    _, outputs = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    if "value" not in attrs:
      # Other payload attributes (value_float, sparse_value, ...) are
      # unsupported.
      raise NotImplementedError
    initializers[node.output[0]] = attrs["value"]
    out = outputs[0]
    return {
        "init": [],
        "forward": [f"{out} = self._vars['{out}']"]
    }
class ConstantOfShapeOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX ConstantOfShape operator.

  Generates a ``new_full`` call whose size comes from the runtime shape
  tensor and whose fill value/dtype come from the optional ``value``
  attribute.
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ConstantOfShapeOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``."""
    attrs = self.get_attr_value_dict(node)
    inputs, outputs = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    # ONNX default when the `value` attribute is absent: float32 zeros.
    fill_value, np_dtype = 0., "float32"
    if "value" in attrs:
      arr = to_array(attrs["value"])
      np_dtype = arr.dtype
      fill_value = arr[0]
    forward = [
        f"{outputs[0]} = torch.Tensor().new_full(size={inputs[0]}.tolist(), fill_value={fill_value}, dtype=torch.{np_dtype})"
    ]
    return {"init": [], "forward": forward}
"pads" in attr_value_dict: 30 | padding = [attr_value_dict["pads"][i] for i in range(d)] 31 | elif attr_value_dict["auto_pad"] not in (b"NOTSET", b""): 32 | logging.warning( 33 | "auto_pad is a DEPRECATED attribute, will not guarantee the result.") 34 | forward_str.append( 35 | f"{inputs_str[0]} = self.compatible_auto_pad({inputs_str[0]}, self.{node_name}.weight.data.shape[2:], self.{node_name}, '{attr_value_dict['auto_pad'].decode('utf-8')}')" 36 | ) 37 | weights = onnx.numpy_helper.to_array(initializers[node.input[1]]) 38 | params_str = self.gen_params_str( 39 | groups=attr_value_dict["group"], 40 | dilation=attr_value_dict.get("dilations", 1), 41 | out_channels=weights.shape[0], 42 | padding=padding, 43 | kernel_size=weights.shape[2:].__repr__(), 44 | stride=attr_value_dict.get("strides", 1), 45 | in_channels=weights.shape[1] * attr_value_dict["group"], 46 | bias=len(node.input) > 2) 47 | 48 | init_str.append(f"self.{node_name} = nn.{nn_name}(**{{{params_str}}})") 49 | init_str.append(f"self.{node_name}.weight.data = {inputs_str[1]}") 50 | if len(node.input) > 2: 51 | init_str.append(f"self.{node_name}.bias.data = {inputs_str[2]}") 52 | 53 | forward_str.append(f"{outputs_str[0]} = self.{node_name}({inputs_str[0]})") 54 | 55 | return {"init": init_str, "forward": forward_str} 56 | 57 | @staticmethod 58 | def gen_method(): 59 | return '''def compatible_auto_pad(self, input, kernel_spatial_shape, nn_mod, auto_pad=None, **kwargs): 60 | input_spatial_shape = input.shape[2:] 61 | d = len(input_spatial_shape) 62 | strides = nn_mod.stride 63 | dilations = nn_mod.dilation 64 | output_spatial_shape = [math.ceil(float(l) / float(r)) for l, r in zip(input.shape[2:], strides)] 65 | pt_padding = [0] * 2 * d 66 | pad_shape = [0] * d 67 | for i in range(d): 68 | pad_shape[i] = (output_spatial_shape[i] - 1) * strides[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i] 69 | mean = pad_shape[i] // 2 70 | if auto_pad == b"SAME_UPPER": 71 | l, r = 
class ConvTransposeOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX ConvTranspose operator."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ConvTransposeOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``.

    Builds an ``nn.ConvTransposeNd`` module in __init__ (weights/bias copied
    from the ONNX initializers) and a plain call in forward. When the node
    fixes ``output_shape``, the required symmetric padding is derived from
    the ONNX formula below.
    """
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)

    # Number of spatial dims; requires a concrete input shape in value_infos.
    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim) - 2
    # Spatial extents of the input (batch/channel dims dropped). Note the
    # comprehension variable `d` does not leak in Python 3, so the outer `d`
    # is untouched.
    input_size = [
        d.dim_value
        for d in value_infos[node.input[0]].type.tensor_type.shape.dim
    ][2:]
    assert (d in (1, 2, 3))

    weights = onnx.numpy_helper.to_array(initializers[node.input[1]])
    padding = [0] * d
    output_padding = [0] * d
    stride = attr_value_dict.get("strides", [1] * d)
    kernel_shape = weights.shape[2:]
    dilation = attr_value_dict.get("dilations", [1] * d)
    if "pads" in attr_value_dict:
      # Only the `begin` half of the ONNX pads is used — assumes symmetric
      # pads; TODO confirm asymmetric pads are rejected/unsupported upstream.
      padding = [attr_value_dict["pads"][i] for i in range(d)]
    if "output_padding" in attr_value_dict:
      output_padding = [attr_value_dict["output_padding"][i] for i in range(d)]
    if "output_shape" in attr_value_dict:
      # An explicit output_shape overrides pads: solve the ONNX size formula
      # for the total padding per axis and split it evenly.
      output_shape = attr_value_dict["output_shape"]
      total_padding = [0] * d

      # total_padding[i] = stride[i] * (input_size[i] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]
      # If (auto_pads == SAME_UPPER): pads[start_i] = total_padding[i]/2; pads[end_i] = total_padding[i] - (total_padding[i]/2)
      # Else: pads[start_i] = total_padding[i] - (total_padding[i]/2); pads[end_i] = (total_padding[i]/2).

      for i in range(d):
        total_padding[i] = stride[i] * (
            input_size[i] - 1) + output_padding[i] + (
                (kernel_shape[i] - 1) * dilation[i] + 1) - output_shape[i]
        # nn.ConvTransposeNd only takes symmetric padding, so an odd total
        # cannot be represented.
        assert total_padding[
            i] % 2 == 0, "Padding for ConvTranspose should be even."
        padding[i] = total_padding[i] // 2
    # ONNX ConvTranspose weight layout is (in_channels, out_channels/groups,
    # *kernel). NOTE(review): out_channels=weights.shape[1] looks correct only
    # for group == 1 — confirm behavior for grouped ConvTranspose.
    params_str = self.gen_params_str(groups=attr_value_dict["group"],
                                     dilation=dilation,
                                     out_channels=weights.shape[1],
                                     padding=padding,
                                     output_padding=output_padding,
                                     kernel_size=weights.shape[2:],
                                     stride=stride,
                                     in_channels=weights.shape[0],
                                     bias=len(node.input) > 2)

    nn_name = f"ConvTranspose{d}d"
    node_name = self.rename_helper.get_node_name(node.name, node.op_type)
    init_str, forward_str = [], []
    init_str.append(f"self.{node_name} = nn.{nn_name}(**{{{params_str}}})")
    init_str.append(f"self.{node_name}.weight.data = {inputs_str[1]}")
    if len(node.input) > 2:
      init_str.append(f"self.{node_name}.bias.data = {inputs_str[2]}")

    forward_str.append(f"{outputs_str[0]} = self.{node_name}({inputs_str[0]})")

    return {"init": init_str, "forward": forward_str}
class CoshOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX Cosh operator via ``torch.cosh``."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(CoshOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``."""
    inputs, outputs = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    call_args = ', '.join(inputs)
    return {
        "init": [],
        "forward": [f"{outputs[0]} = torch.cosh({call_args})"]
    }
class DropoutOpCodeGenerator(OpCodeGenerator):
  """Emits PyTorch code for the ONNX Dropout operator (nn.Dropout)."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(DropoutOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return ``{"init": [...], "forward": [...]}`` source lines for ``node``.

    ``ratio`` and ``training_mode`` may arrive as attributes (older opsets)
    or as optional constant inputs (opset >= 12); inputs take precedence.
    """
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    check_list = []
    # Attribute-based defaults; overridden below if the optional inputs exist.
    ratio = attr_value_dict.get("ratio", 0.5)
    training_mode = attr_value_dict.get("training_mode", False)
    if len(node.input) > 1:
      check_list.append((node.input[1], "ratio"))
    if len(node.input) > 2:
      check_list.append((node.input[2], "training_mode"))
    # check_in_init presumably asserts these inputs are constant initializers
    # and returns their tensors in order — TODO confirm against the base class.
    inits = self.check_in_init(check_list, initializers)
    if len(node.input) > 1:
      ratio = to_array(inits[0])[0]
    if len(node.input) > 2:
      training_mode = bool(to_array(inits[1])[0])
    # NOTE(review): training_mode is parsed but never used — the generated
    # nn.Dropout follows the module's train()/eval() state instead; confirm
    # this is intentional.
    params_str = self.gen_params_str(p=ratio)
    nn_name = "Dropout"
    node_name = self.rename_helper.get_node_name(node.name, node.op_type)
    init_str.append(f"self.{node_name} = nn.{nn_name}(**{{{params_str}}})")
    forward_str.append(f"{outputs_str[0]} = self.{node_name}({inputs_str[0]})")
    return {"init": init_str, "forward": forward_str}
class EluOpCodeGenerator(OpCodeGenerator):
  """Generates PyTorch code for the ONNX Elu operator."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(EluOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Emit `out = F.elu(in, alpha=...)` honoring the node's alpha attribute.

    The previous implementation dropped the ONNX `alpha` attribute and always
    used F.elu's default, producing wrong results for any model with
    alpha != 1.0.
    """
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    # ONNX Elu: alpha defaults to 1.0, matching F.elu's own default.
    alpha = attr_value_dict.get("alpha", 1.0)
    forward_str.append(
        f"{outputs_str[0]} = F.elu({inputs_str[0]}, alpha={alpha})")
    return {"init": init_str, "forward": forward_str}
class ExpandOpCodeGenerator(OpCodeGenerator):
  """Generates PyTorch code for the ONNX Expand operator."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ExpandOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Emit `out = data.expand(*list(shape))` for the node's renamed tensors."""
    in_names, out_names = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    data, shape = in_names[0], in_names[1]
    assign = f"{out_names[0]} = {data}.expand(*list({shape}))"
    return {"init": [], "forward": [assign]}
class FloorOpCodeGenerator(OpCodeGenerator):
  """Generates PyTorch code for the ONNX Floor operator (element-wise)."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(FloorOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Emit `out = torch.floor(in)` using the node's renamed tensor names."""
    in_names, out_names = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    joined = ", ".join(in_names)
    assign = f"{out_names[0]} = torch.floor({joined})"
    return {"init": [], "forward": [assign]}
class GatherOpCodeGenerator(OpCodeGenerator):
  """Generates PyTorch code for the ONNX Gather operator.

  Two strategies are used:
  - When an embedding configuration is registered for this node, the Gather
    is rewritten as an nn.Embedding lookup (with optional initializer and
    regularizer hooks).
  - Otherwise a generic `self.gather(...)` helper (see gen_method) is emitted
    that emulates ONNX Gather semantics, including negative indices.
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(GatherOpCodeGenerator, self).__init__(onnx_ver, torch_ver)
    # Optional mapping of node name -> embedding config; set externally
    # (presumably by code_gen via the embedding_config_helper — confirm).
    self.embedding_conf = None

  def gen(self, node, value_infos, initializers):
    """Return generated `init`/`forward` code lines for this Gather node."""
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    if self.embedding_conf is not None and node.name in self.embedding_conf:
      # Embedding path: replace Gather(weight, indices) with nn.Embedding.
      conf = self.embedding_conf[node.name]
      node_name = self.rename_helper.get_node_name(node.name, node.op_type)
      params_str = self.gen_params_str(num_embeddings=conf.num_embeddings,
                                       embedding_dim=conf.embedding_dim)
      init_str.append(f"self.{node_name} = nn.Embedding(**{{{params_str}}})")
      # nn.Embedding requires integer indices; pick the cast from the
      # indices tensor's declared element type (default: 32-bit int()).
      dtype = "int"
      if node.input[1] in value_infos:
        np_type = TENSOR_TYPE_TO_NP_TYPE[value_infos[
            node.input[1]].type.tensor_type.elem_type]
        if np_type.name == "int32":
          pass
        elif np_type.name == "int64":
          dtype = "long"
      forward_str.append(
          f"{outputs_str[0]} = self.{node_name}({inputs_str[1]}.{dtype}())")
      if conf.initializer is not None:
        # Keras-style initializer spec: {"class_name": ..., "config": ...}.
        class_name = conf.initializer["class_name"]
        init_conf = conf.initializer["config"]
        if class_name == "RandomNormal":
          init_str.append(
              f"nn.init.normal_(self.{node_name}.weight, mean={init_conf['mean']}, std=math.sqrt({init_conf['stddev']}))"
          )
        elif class_name == "Zeros":
          init_str.append(f"nn.init.constant_(self.{node_name}.weight, 0.0)")
      if conf.regularizer is not None:
        # Record (weight, l1, l2) so the generated model can apply the
        # penalty during training.
        reg_conf = conf.regularizer["config"]
        init_str.append(
            f"self._regularizer_params.append((self.{node_name}.weight, {reg_conf.get('l1', 0.0)}, {reg_conf.get('l2', 0.0)}))"
        )
    else:
      axis = attr_value_dict.get("axis", 0)

      # Simple solution
      # forward_str.append(
      #     f"{outputs_str[0]} = {inputs_str[0]}.__getitem__([slice(None) for _ in range({axis})] + [{inputs_str[1]}.to(device={inputs_str[0]}.device, dtype=torch.int64)])"
      # )
      forward_str.append(
          f'''{outputs_str[0]} = self.gather({inputs_str[0]}, {axis}, {inputs_str[1]})'''
      )
    return {"init": init_str, "forward": forward_str}

  @staticmethod
  def gen_method():
    # Source of the runtime helper spliced into the generated model class.
    # It broadcasts the (flattened) indices to the input's shape around
    # `dim`, maps negative indices per the ONNX spec, runs torch.gather,
    # then reshapes to shape_l[:dim] + indices.shape + shape_l[dim+1:].
    return '''def gather(self, input, dim, indices, **kwargs):
  shape_l, shape_r = list(input.shape), list(indices.shape)
  indices = indices.flatten().to(device=indices.device, dtype=torch.int64)
  for r in range(0, dim):
    indices = indices.unsqueeze(0)
  for r in range(dim, len(shape_l) - 1):
    indices = indices.unsqueeze(-1)
  indices = indices.expand(*(shape_l[:dim] + [int(np.prod(shape_r))] + shape_l[dim + 1:]))
  indices = torch.where(indices >= 0, indices, indices + shape_l[dim])
  output = torch.gather(input, dim, indices)
  output = torch.reshape(output, shape_l[:dim] + shape_r + shape_l[dim + 1:])
  return output
'''
class GemmOpCodeGenerator(OpCodeGenerator):
  """Generates PyTorch code for the ONNX Gemm operator.

  Gemm computes `alpha * A' * B' + beta * C` where A'/B' are optionally
  transposed by the transA/transB attributes and C is optional (opset >= 11).
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(GemmOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return generated `init`/`forward` code lines for this Gemm node."""
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)

    # Apply the transA/transB attributes by wrapping the operand expressions.
    if attr_value_dict["transA"] == 1:
      inputs_str[0] = f"torch.transpose({inputs_str[0]}, 0, 1)"
    if attr_value_dict["transB"] == 1:
      inputs_str[1] = f"torch.transpose({inputs_str[1]}, 0, 1)"

    init_str, forward_str = [], []
    expr = f"{attr_value_dict['alpha']} * torch.matmul({', '.join(inputs_str[:2])})"
    # Input C is optional since opset 11; the previous code indexed
    # inputs_str[2] unconditionally and raised IndexError for 2-input Gemm.
    if len(inputs_str) > 2:
      expr += f" + {attr_value_dict['beta']} * {inputs_str[2]}"
    forward_str.append(f"{outputs_str[0]} = {expr}")

    return {"init": init_str, "forward": forward_str}
class GreaterOpCodeGenerator(OpCodeGenerator):
  """Generates PyTorch code for the ONNX Greater operator (element-wise >)."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(GreaterOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Emit `out = torch.gt(a, b)` using the node's renamed tensor names."""
    in_names, out_names = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    joined = ", ".join(in_names)
    assign = f"{out_names[0]} = torch.gt({joined})"
    return {"init": [], "forward": [assign]}
class InstanceNormalizationOpCodeGenerator(OpCodeGenerator):
  """Generates an nn.InstanceNorm{1,2,3}d module for ONNX InstanceNormalization."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(InstanceNormalizationOpCodeGenerator,
          self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return generated `init`/`forward` code lines for this node."""
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)

    # Number of spatial dims: input is (N, C, d1, ..., dk).
    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim) - 2

    # InstanceNorm needs at least one spatial dim; for a rank-2 input, add a
    # trailing singleton dim, run InstanceNorm1d, then squeeze it back off.
    view = False
    if d == 0:
      d = 1
      view = True

    nn_name = f"InstanceNorm{d}d"
    node_name = self.rename_helper.get_node_name(node.name, node.op_type)

    # num_features is taken from the scale initializer's length (= C).
    params_str = self.gen_params_str(num_features=onnx.numpy_helper.to_array(
        initializers[node.input[1]]).shape[0],
                                     eps=attr_value_dict["epsilon"])

    init_str, forward_str = [], []
    init_str.append(f"self.{node_name} = nn.{nn_name}(**{{{params_str}}})")
    # NOTE(review): these assign directly to .weight/.bias, unlike
    # LayerNormalization which assigns to .weight.data/.bias.data, and
    # nn.InstanceNorm*d defaults to affine=False (weight/bias registered as
    # None). This only works if the generated RHS expression evaluates to an
    # nn.Parameter — confirm against gen_input_output_string's output.
    init_str.append(f"self.{node_name}.weight = {inputs_str[1]}")
    init_str.append(f"self.{node_name}.bias = {inputs_str[2]}")
    curr_input = inputs_str[0]
    if view:
      forward_str.append(f"{curr_input} = torch.unsqueeze({curr_input}, -1)")
    forward_str.append(f"{outputs_str[0]} = self.{node_name}({curr_input})")
    if view:
      forward_str.append(
          f"{outputs_str[0]} = torch.squeeze({outputs_str[0]}, -1)")
    return {"init": init_str, "forward": forward_str}
class LayerNormalizationOpCodeGenerator(OpCodeGenerator):
  """Generates an nn.LayerNorm module call for ONNX LayerNormalization."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(LayerNormalizationOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return generated `init`/`forward` code lines for this node."""
    attrs = self.get_attr_value_dict(node)
    in_names, out_names = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)

    # The axis attribute is implied by normalized_shape passed below.
    axis = attrs["axis"]

    module_name = self.rename_helper.get_node_name(node.name, node.op_type)

    # normalized_shape comes from the scale initializer's shape.
    normalized_shape = onnx.numpy_helper.to_array(
        initializers[node.input[1]]).shape
    params_str = self.gen_params_str(normalized_shape=normalized_shape,
                                     eps=attrs["epsilon"])

    init_str = [
        f"self.{module_name} = nn.LayerNorm(**{{{params_str}}})",
        f"self.{module_name}.weight.data = {in_names[1]}",
        f"self.{module_name}.bias.data = {in_names[2]}",
    ]
    forward_str = [f"{out_names[0]} = self.{module_name}({in_names[0]})"]

    return {"init": init_str, "forward": forward_str}
class LessOpCodeGenerator(OpCodeGenerator):
  """Generates PyTorch code for the ONNX Less operator (element-wise a < b)."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(LessOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Emit `out = torch.lt(a, b)` using the node's renamed tensor names.

    Bug fix: the previous code emitted `torch.le` (less-than-or-equal),
    which returns True for equal elements, whereas ONNX Less is strict `<`.
    """
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    forward_str.append(f"{outputs_str[0]} = torch.lt({', '.join(inputs_str)})")
    return {"init": init_str, "forward": forward_str}
class MaxOpCodeGenerator(OpCodeGenerator):
  """Generates PyTorch code for the ONNX Max operator (element-wise, variadic)."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(MaxOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Emit an element-wise max over all of the node's inputs.

    ONNX Max takes 1..N inputs. The previous code passed all inputs to a
    single torch.max call: with one input that is a full reduction (wrong —
    ONNX Max of one input is the identity), and with three or more it hits
    the `(input, dim)` overload. Emit the identity for one input and a
    nested pairwise torch.max chain otherwise (identical output for the
    common two-input case).
    """
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    if len(inputs_str) == 1:
      forward_str.append(f"{outputs_str[0]} = {inputs_str[0]}")
    else:
      expr = inputs_str[0]
      for rhs in inputs_str[1:]:
        expr = f"torch.max({expr}, {rhs})"
      forward_str.append(f"{outputs_str[0]} = {expr}")
    return {"init": init_str, "forward": forward_str}
class MaxPoolOpCodeGenerator(OpCodeGenerator):
  """Generates an nn.MaxPool{1,2,3}d module for the ONNX MaxPool operator.

  Symmetric ONNX pads map onto the module's `padding` parameter; asymmetric
  pads fall back to an explicit F.pad with -inf fill before pooling (which
  makes any returned indices refer to the padded tensor — see warning).
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(MaxPoolOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Return generated `init`/`forward` code lines for this MaxPool node."""
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)

    # Number of spatial dims: input is (N, C, d1, ..., dk), k in {1, 2, 3}.
    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim) - 2
    assert (d in (1, 2, 3))

    params = {
        "dilation": attr_value_dict.get("dilations", 1),
        "kernel_size": attr_value_dict["kernel_shape"][:].__repr__(),
        "ceil_mode": bool(attr_value_dict["ceil_mode"]),
        "stride": attr_value_dict.get("strides", 1),
        # A second ONNX output means the caller wants argmax indices.
        "return_indices": len(node.output) == 2
    }

    nn_name = f"MaxPool{d}d"
    node_name = self.rename_helper.get_node_name(node.name, node.op_type)
    init_str, forward_str = [], []
    if "pads" in attr_value_dict:
      # ONNX pads layout: [begin_1..begin_d, end_1..end_d]. Build both the
      # F.pad list (last dim first) and, when every begin == end, the
      # symmetric per-dim list usable as the module's `padding`.
      padding = []
      pt_padding = []
      for i in range(d):
        if attr_value_dict['pads'][i] == attr_value_dict['pads'][
            i + d] and pt_padding is not None:
          pt_padding.append(attr_value_dict['pads'][i])
        else:
          pt_padding = None
        padding.insert(0, attr_value_dict['pads'][i + d])
        padding.insert(0, attr_value_dict['pads'][i])
      if pt_padding is None:
        # Asymmetric pads: pre-pad with -inf so padded cells never win the
        # max; indices from return_indices are then off (hence the warning).
        logging.warning(
            "MaxPool with asymmetric padding will get incorrect indices.")
        forward_str.append(
            f"{inputs_str[0]} = F.pad({inputs_str[0]}, {padding.__repr__()}, value=float('-inf'))"
        )
      else:
        params["padding"] = pt_padding.__repr__()
    params_str = self.gen_params_str(**params)
    init_str.append(f"self.{node_name} = nn.{nn_name}(**{{{params_str}}})")
    forward_str.append(
        f"{', '.join(outputs_str)} = self.{node_name}({inputs_str[0]})")
    return {"init": init_str, "forward": forward_str}
node, initializers, self.rename_helper, self.tensor_inplace) 18 | init_str, forward_str = [], [] 19 | iou_threshold = inputs_str[3] if len(inputs_str) > 3 else "0.0" 20 | score_threshold = inputs_str[4] if len(inputs_str) > 4 else "0.0" 21 | max_output_boxes_per_class = inputs_str[2] if len(inputs_str) > 2 else "0" 22 | forward_str.append( 23 | f"{outputs_str[0]} = self.nms({inputs_str[0]}, {inputs_str[1]}, {max_output_boxes_per_class}, {iou_threshold}, {score_threshold}, center_point_box={attr_value_dict.get('center_point_box', 0)})" 24 | ) 25 | return {"init": init_str, "forward": forward_str} 26 | 27 | @staticmethod 28 | def gen_method(): 29 | return f'''def nms(self, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, center_point_box=0, **kwargs): 30 | nms_rs_list = [] 31 | for i in range(boxes.shape[0]): 32 | for j in range(scores.shape[1]): 33 | for k in range(boxes.shape[1]): 34 | if center_point_box == 1: 35 | boxes[i][k] = torchvision.ops.box_convert(boxes[i][k], "cxcywh", "xyxy") 36 | else: 37 | x1, y1, x2, y2 = boxes[i][k] 38 | if x1 < x2 and y1 < y2: 39 | continue 40 | indices = [0, 1, 2, 3] 41 | if x1 > x2: 42 | indices = [indices[l] for l in (2, 1, 0, 3)] 43 | if y1 > y2: 44 | indices = [indices[l] for l in (0, 3, 2, 1)] 45 | boxes[i][k] = boxes[i][k].gather(0, torch.tensor(indices)) 46 | mask = scores[i][j] >= score_threshold 47 | nms_rs = torchvision.ops.nms(boxes[i], scores[i][j], float(iou_threshold))[:max_output_boxes_per_class] 48 | nms_rs_masked = nms_rs[:mask[nms_rs].nonzero(as_tuple=False).flatten().shape[0]] 49 | batch_index = torch.full((nms_rs_masked.shape[0], 1), i) 50 | class_index = torch.full((nms_rs_masked.shape[0], 1), j) 51 | nms_rs_list.append(torch.cat((batch_index, class_index, nms_rs_masked.unsqueeze(1)), dim=1)) 52 | return torch.cat(nms_rs_list, dim=0) 53 | ''' 54 | -------------------------------------------------------------------------------- /onnx_pytorch/op_code_generators/NonZero.py: 
class NotOpCodeGenerator(OpCodeGenerator):
  """Generates PyTorch code for the ONNX Not operator (element-wise negation)."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(NotOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Emit `out = torch.logical_not(in)`.

    Bug fix: the previous code appended `.T` to the result — apparently
    copy-pasted from the NonZero generator, where transposing the index
    matrix is required. ONNX Not is purely element-wise, so transposing the
    output is wrong for any non-symmetric shape.
    """
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    forward_str.append(
        f"{outputs_str[0]} = torch.logical_not({', '.join(inputs_str)})")
    return {"init": init_str, "forward": forward_str}
class PReluOpCodeGenerator(OpCodeGenerator):
  """Generates PyTorch code for the ONNX PRelu operator."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(PReluOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    """Emit `out = F.prelu(input, slope)` — the two ONNX inputs map directly."""
    in_names, out_names = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    joined = ", ".join(in_names)
    assign = f"{out_names[0]} = F.prelu({joined})"
    return {"init": [], "forward": [assign]}
28 | pads = onnx.numpy_helper.to_array(pads) 29 | else: 30 | pads = attr_value_dict["pads"] 31 | pt_pads = [0, 0] * d 32 | for i in range(d): 33 | pt_pads[2 * (d - i - 1)] = pads[2 + i] 34 | pt_pads[2 * (d - i - 1) + 1] = pads[d + 2 + 2 + i] 35 | forward_str.append( 36 | f"{outputs_str[0]} = F.pad({inputs_str[0]}, {pt_pads.__repr__()}, \"{mode.decode()}\", {value})" 37 | ) 38 | return {"init": init_str, "forward": forward_str} 39 | -------------------------------------------------------------------------------- /onnx_pytorch/op_code_generators/Reciprocal.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | import torch 3 | 4 | from onnx_pytorch.op_code_generators import OpCodeGenerator 5 | 6 | 7 | class ReciprocalOpCodeGenerator(OpCodeGenerator): 8 | 9 | def __init__(self, 10 | onnx_ver=onnx.defs.onnx_opset_version(), 11 | torch_ver=torch.__version__): 12 | super(ReciprocalOpCodeGenerator, self).__init__(onnx_ver, torch_ver) 13 | 14 | def gen(self, node, value_infos, initializers): 15 | inputs_str, outputs_str = self.gen_input_output_string( 16 | node, initializers, self.rename_helper, self.tensor_inplace) 17 | init_str, forward_str = [], [] 18 | forward_str.append(f"{outputs_str[0]} = 1 / {inputs_str[0]}") 19 | return {"init": init_str, "forward": forward_str} 20 | -------------------------------------------------------------------------------- /onnx_pytorch/op_code_generators/ReduceMean.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | import torch 3 | 4 | from onnx_pytorch.op_code_generators import ReduceOpCodeGenerator 5 | 6 | 7 | class ReduceMeanOpCodeGenerator(ReduceOpCodeGenerator): 8 | 9 | def __init__(self, 10 | onnx_ver=onnx.defs.onnx_opset_version(), 11 | torch_ver=torch.__version__): 12 | super(ReduceMeanOpCodeGenerator, self).__init__(onnx_ver, torch_ver) 13 | 14 | def gen(self, node, value_infos, initializers): 15 | attr_value_dict = 
# --- onnx_pytorch/op_code_generators/ReduceMean.py ---
import onnx
import torch

from onnx_pytorch.op_code_generators import ReduceOpCodeGenerator


class ReduceMeanOpCodeGenerator(ReduceOpCodeGenerator):
  """Emits code mapping ONNX ReduceMean onto a single torch.mean call."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ReduceMeanOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    # Input rank; _get_dim defaults to reducing all axes when none given.
    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim)
    dim = self._get_dim(attr_value_dict, d, node, initializers)
    params_str = self.gen_params_str(keepdim=bool(attr_value_dict["keepdims"]))
    forward_str.append(
        f"{outputs_str[0]} = torch.mean({inputs_str[0]}, {dim.__repr__()}, **{{{params_str}}})"
    )
    return {"init": init_str, "forward": forward_str}


# --- onnx_pytorch/op_code_generators/ReduceMin.py ---
class ReduceMinOpCodeGenerator(ReduceOpCodeGenerator):
  """Emits code mapping ONNX ReduceMin onto torch.min, one axis at a time.

  torch.min only reduces a single dim per call, so multiple axes are
  reduced in a chain, highest axis first so lower indices stay valid.
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ReduceMinOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim)
    dim = self._get_dim(attr_value_dict, d, node, initializers)
    params_str = self.gen_params_str(keepdim=bool(attr_value_dict["keepdims"]))
    curr_input = inputs_str[0]
    # NOTE(review): loop variable ``d`` shadows the rank above; harmless here
    # since the rank is no longer used, but worth renaming.
    for d in reversed(dim):
      # torch.min with a dim returns (values, indices); indices discarded.
      forward_str.append(
          f"{outputs_str[0]}, _ = torch.min({curr_input}, {d}, **{{{params_str}}})"
      )
      curr_input = outputs_str[0]
    return {"init": init_str, "forward": forward_str}
# --- onnx_pytorch/op_code_generators/ReduceProd.py ---
import onnx
import torch

from onnx_pytorch.op_code_generators import ReduceOpCodeGenerator


class ReduceProdOpCodeGenerator(ReduceOpCodeGenerator):
  """Emits code mapping ONNX ReduceProd onto torch.prod, one axis per call.

  torch.prod only reduces a single dim, so axes are chained highest-first
  to keep the remaining axis indices valid.
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ReduceProdOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim)
    dim = self._get_dim(attr_value_dict, d, node, initializers)
    params_str = self.gen_params_str(keepdim=bool(attr_value_dict["keepdims"]))
    curr_input = inputs_str[0]
    # NOTE(review): loop variable ``d`` shadows the rank above (unused after).
    for d in reversed(dim):
      forward_str.append(
          f"{outputs_str[0]} = torch.prod({curr_input}, {d}, **{{{params_str}}})"
      )
      curr_input = outputs_str[0]
    return {"init": init_str, "forward": forward_str}


# --- onnx_pytorch/op_code_generators/ReduceSum.py ---
class ReduceSumOpCodeGenerator(ReduceOpCodeGenerator):
  """Emits code mapping ONNX ReduceSum onto a single torch.sum call."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ReduceSumOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim)
    dim = self._get_dim(attr_value_dict, d, node, initializers)
    params_str = self.gen_params_str(keepdim=bool(attr_value_dict["keepdims"]))
    # torch.sum accepts a list of dims, so no per-axis chaining is needed.
    forward_str.append(
        f"{outputs_str[0]} = torch.sum({inputs_str[0]}, {dim.__repr__()}, **{{{params_str}}})"
    )
    return {"init": init_str, "forward": forward_str}


# --- onnx_pytorch/op_code_generators/Relu.py ---
class ReluOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Relu onto F.relu."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ReluOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    forward_str.append(f"{outputs_str[0]} = F.relu({inputs_str[0]})")
    return {"init": init_str, "forward": forward_str}
# --- onnx_pytorch/op_code_generators/Reshape.py ---
import onnx
import torch

from onnx_pytorch.op_code_generators import OpCodeGenerator


class ReshapeOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Reshape onto torch.reshape.

  A 0 entry in the ONNX shape means "keep the input's extent at that
  position"; the emitted comprehension substitutes the input dim there.
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ReshapeOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    forward_str.append(
        f"{outputs_str[0]} = torch.reshape({inputs_str[0]}, [s if s != 0 else {inputs_str[0]}.shape[i] for i, s in enumerate({inputs_str[1]})])"
    )
    return {"init": init_str, "forward": forward_str}


# --- onnx_pytorch/op_code_generators/Resize.py ---
class ResizeOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Resize onto F.interpolate.

  The input layout varies by opset: 4 inputs -> (X, roi, scales, sizes),
  3 inputs -> (X, roi, scales), 2 inputs -> opset-10 (X, scales).  The
  leading (N, C) entries of scales/sizes are dropped via ``[2:]`` because
  F.interpolate only resizes spatial dims.
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ResizeOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    scales, sizes = None, None
    if len(node.input) == 4:
      sizes = tuple(onnx.numpy_helper.to_array(initializers[node.input[3]])[2:])
    elif len(node.input) == 3:
      scales = tuple(
          onnx.numpy_helper.to_array(initializers[node.input[2]])[2:])
    # Resize opset version 10
    elif len(node.input) == 2:
      if node.input[1] in initializers:
        scales = tuple(
            onnx.numpy_helper.to_array(initializers[node.input[1]])[2:])
      else:
        # Scales only known at runtime: emit an expression, not a constant.
        scales = f"list({self.rename_helper.tensor_name_mapping.get(node.input[1], node.input[1])})[2:]"

    align_corners = None
    if attr_value_dict["coordinate_transformation_mode"].decode(
    ) == "align_corners":
      align_corners = True
    mode = attr_value_dict['mode'].decode()
    # d = spatial rank; picks linear/bilinear/trilinear below.
    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim) - 2
    assert d < 4, "Currently temporal, spatial and volumetric sampling are supported."
    if mode == "linear":
      modes = ["linear", "bilinear", "trilinear"]
      mode = modes[d - 1]
    params_str = self.gen_params_str(
        size=sizes,
        scale_factor=scales,
        mode=f"'{mode}'",
        align_corners=align_corners,
        recompute_scale_factor=scales is not None,
    )
    init_str, forward_str = [], []

    forward_str.append(
        f"{outputs_str[0]} = F.interpolate({inputs_str[0]}, **{{{params_str}}})"
    )
    return {"init": init_str, "forward": forward_str}


# --- onnx_pytorch/op_code_generators/RoiAlign.py ---
class RoiAlignOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX RoiAlign onto torchvision.ops.roi_align.

  ONNX provides rois and batch_indices as separate inputs; torchvision
  expects a single [K, 5] boxes tensor with the batch index prepended,
  hence the emitted ``torch.cat`` with an unsqueezed index column.
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(RoiAlignOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    params_str = self.gen_params_str(
        output_size=(attr_value_dict["output_height"],
                     attr_value_dict["output_width"]),
        sampling_ratio=attr_value_dict["sampling_ratio"],
        spatial_scale=attr_value_dict["spatial_scale"],
    )
    forward_str.append(
        f"boxes = torch.cat((torch.unsqueeze({inputs_str[2]}, 1), {inputs_str[1]}), axis=1)"
    )
    forward_str.append(
        f"{outputs_str[0]} = torchvision.ops.roi_align({inputs_str[0]}, boxes, **{{{params_str}}})"
    )
    return {"init": init_str, "forward": forward_str}
# --- onnx_pytorch/op_code_generators/Round.py ---
import onnx
import torch

from onnx_pytorch.op_code_generators import OpCodeGenerator


class RoundOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Round onto torch.round."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(RoundOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    ins, outs = self.gen_input_output_string(node, initializers,
                                             self.rename_helper,
                                             self.tensor_inplace)
    args = ', '.join(ins)
    forward = [f"{outs[0]} = torch.round({args})"]
    return {"init": [], "forward": forward}


# --- onnx_pytorch/op_code_generators/Scatter.py ---
from onnx.numpy_helper import to_array


class ScatterOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Scatter onto torch.scatter."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ScatterOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attrs = self.get_attr_value_dict(node)
    ins, outs = self.gen_input_output_string(node, initializers,
                                             self.rename_helper,
                                             self.tensor_inplace)
    # data, axis, indices, updates — same argument order as torch.scatter.
    stmt = (f"{outs[0]} = torch.scatter({ins[0]}, {attrs['axis']}, "
            f"{ins[1]}, {ins[2]})")
    return {"init": [], "forward": [stmt]}
# --- onnx_pytorch/op_code_generators/ScatterElements.py ---
import onnx
import torch
from onnx.numpy_helper import to_array

from onnx_pytorch.op_code_generators import OpCodeGenerator


class ScatterElementsOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX ScatterElements onto torch.scatter."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ScatterElementsOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attrs = self.get_attr_value_dict(node)
    ins, outs = self.gen_input_output_string(node, initializers,
                                             self.rename_helper,
                                             self.tensor_inplace)
    axis = attrs['axis']
    forward = [
        f"{outs[0]} = torch.scatter({ins[0]}, {axis}, {ins[1]}, {ins[2]})"
    ]
    return {"init": [], "forward": forward}


# --- onnx_pytorch/op_code_generators/Shape.py ---
class ShapeOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Shape onto a tensor built from ``.shape``."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(ShapeOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    ins, outs = self.gen_input_output_string(node, initializers,
                                             self.rename_helper,
                                             self.tensor_inplace)
    src = ins[0]
    # Keep the shape tensor on the same device as its source tensor.
    forward = [f"{outs[0]} = torch.tensor({src}.shape, device={src}.device)"]
    return {"init": [], "forward": forward}


# --- onnx_pytorch/op_code_generators/Sigmoid.py ---
class SigmoidOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Sigmoid onto torch.sigmoid."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(SigmoidOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    ins, outs = self.gen_input_output_string(node, initializers,
                                             self.rename_helper,
                                             self.tensor_inplace)
    forward = [f"{outs[0]} = torch.sigmoid({ins[0]})"]
    return {"init": [], "forward": forward}
# --- onnx_pytorch/op_code_generators/Slice.py ---
import onnx
import torch
from onnx.numpy_helper import to_array

from onnx_pytorch.op_code_generators import OpCodeGenerator


class SliceOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Slice onto Python slicing syntax."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(SliceOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper)
    init_str, forward_str = [], []
    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim)
    starts, ends, axes, steps = self._get_starts_ends_axes_steps(
        attr_value_dict, d, node, initializers)
    slice_str = []
    # Build one slice expression per input axis; axes not listed in
    # ``axes`` pass through untouched as ":".
    for i in range(d):
      if i in axes:
        j = axes.index(i)
        s = ["", ""]
        if type(starts) == str and type(ends) == str:
          # Runtime (non-initializer) starts/ends: emit index expressions.
          # NOTE(review): this produces "a if b" with no "else", which is
          # not valid Python inside a slice — looks broken for the
          # runtime-tensor case; confirm against a model that hits it.
          s[0] = f'{starts}[{j}] if {starts}[{j}]'
          s[1] = f'{ends}[{j}] if {ends}[{j}]'
        else:
          s = [
              # Omit a 0 start and an INT_MAX-style end so the emitted
              # slice reads naturally (e.g. "x[:5]" instead of "x[0:5]").
              str(starts[j]) if starts[j] != 0 else "",
              str(ends[j]) if ends[j] < 2**31 else ""
          ]
        if steps[j] != 1:
          s.append(str(steps[j]))
        slice_str.append(":".join(s))
      else:
        slice_str.append(":")

    forward_str.append(
        f"{outputs_str[0]} = {inputs_str[0]}[{', '.join(slice_str)}]")
    return {"init": init_str, "forward": forward_str}

  def _get_starts_ends_axes_steps(self, attr_value_dict, d, node, initializers):
    """Resolve starts/ends/axes/steps from inputs (opset>1) or attributes.

    Returns (starts, ends, axes, steps); starts/ends may be tensor *names*
    (str) when they are not initializers.
    """
    axes = list(range(d))
    steps = [1] * len(axes)
    if self.onnx_ver > 1 and len(node.input) > 1:
      starts = initializers.get(node.input[1], None)
      ends = initializers.get(node.input[2], None)
      if starts is None:
        # Not a constant: fall back to the runtime tensor name.
        starts = node.input[1]
      else:
        starts = to_array(starts)
      if ends is None:
        ends = node.input[2]
      else:
        ends = to_array(ends)
      if len(node.input) > 3:
        axes = initializers.get(node.input[3], None)
      if len(node.input) > 4:
        steps = initializers.get(node.input[4], None)
      assert starts is not None or ends is not None or axes is not None or steps is not None, "Currently SliceOpCodeGenerator only support all of [starts, ends, axes, steps] is in initializers."
      if len(node.input) > 3:
        axes = to_array(axes)
      if len(node.input) > 4:
        steps = to_array(steps)
    else:
      # Opset 1: everything lives in node attributes.
      starts = attr_value_dict["starts"]
      ends = attr_value_dict["ends"]
      axes = attr_value_dict.get("axes", axes)
    return starts, ends, list(axes), list(steps)


# --- onnx_pytorch/op_code_generators/Softmax.py ---
class SoftmaxOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Softmax onto F.softmax along ``axis``."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(SoftmaxOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []

    params_str = self.gen_params_str(dim=attr_value_dict["axis"])
    forward_str.append(
        f"{outputs_str[0]} = F.softmax({inputs_str[0]}, **{{{params_str}}})")
    return {"init": init_str, "forward": forward_str}
# --- onnx_pytorch/op_code_generators/Split.py ---
import onnx
import torch
from onnx.numpy_helper import to_array

from onnx_pytorch.op_code_generators import OpCodeGenerator


class SplitOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Split onto torch.split."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(SplitOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper)
    init_str, forward_str = [], []
    if self.onnx_ver > 11 and len(node.input) > 1:
      # Opset 13+: split sizes arrive as a second input (must be constant).
      split = to_array(initializers[node.input[1]]).tolist()
    else:
      # NOTE(review): when the "split" attribute is absent this passes
      # None as split_size_or_sections, which torch.split rejects —
      # confirm the equal-parts default is handled upstream.
      split = attr_value_dict.get("split", None)
    axis = attr_value_dict["axis"]

    params_str = self.gen_params_str(split_size_or_sections=split, dim=axis)
    forward_str.append(
        f"{', '.join(outputs_str)} = torch.split({inputs_str[0]}, **{{{params_str}}})"
    )
    return {"init": init_str, "forward": forward_str}


# --- onnx_pytorch/op_code_generators/Sqrt.py ---
class SqrtOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Sqrt onto torch.sqrt."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(SqrtOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    forward_str.append(
        f"{outputs_str[0]} = torch.sqrt({', '.join(inputs_str)})")
    return {"init": init_str, "forward": forward_str}


# --- onnx_pytorch/op_code_generators/Squeeze.py ---
class SqueezeOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Squeeze onto chained torch.squeeze calls."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(SqueezeOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    axes = attr_value_dict.get("axes", [])
    if len(node.input) == 2:
      # Opset 13+: axes arrive as a second input (must be constant).
      assert node.input[
          1] in initializers, "Currently SqueezeOpCodeGenerator only support all of [axes] is in initializers."
      axes = to_array(initializers[node.input[1]])
    init_str, forward_str = [], []
    curr_input = inputs_str[0]
    if len(axes) != 0:
      # Squeeze highest axis first so lower axis indices remain valid.
      for a in reversed(axes):
        forward_str.append(
            f"{outputs_str[0]} = torch.squeeze({curr_input}, {a})")
        curr_input = outputs_str[0]
    else:
      # No axes given: squeeze every size-1 dim.
      forward_str.append(f"{outputs_str[0]} = torch.squeeze({curr_input})")

    return {"init": init_str, "forward": forward_str}


# --- onnx_pytorch/op_code_generators/Sub.py ---
class SubOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Sub onto torch.sub."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(SubOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    forward_str.append(f"{outputs_str[0]} = torch.sub({', '.join(inputs_str)})")
    return {"init": init_str, "forward": forward_str}
# --- onnx_pytorch/op_code_generators/Tanh.py ---
import onnx
import torch

from onnx_pytorch.op_code_generators import OpCodeGenerator


class TanhOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Tanh onto torch.tanh."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(TanhOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    forward_str.append(f"{outputs_str[0]} = torch.tanh({inputs_str[0]})")
    return {"init": init_str, "forward": forward_str}


# --- onnx_pytorch/op_code_generators/TopK.py ---
class TopKOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX TopK onto torch.topk (values, indices)."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(TopKOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []

    params_str = self.gen_params_str(
        dim=attr_value_dict.get("axis", -1),
        largest=bool(attr_value_dict.get("largest", 1)),
        sorted=bool(attr_value_dict.get("sorted", 1)))
    # ONNX supplies K as a tensor input; torch.topk needs a Python int.
    inputs_str[1] = f"int({inputs_str[1]})"
    forward_str.append(
        f"{', '.join(outputs_str)} = torch.topk({', '.join(inputs_str)}, **{{{params_str}}})"
    )
    return {"init": init_str, "forward": forward_str}


# --- onnx_pytorch/op_code_generators/Transpose.py ---
class TransposeOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Transpose onto ``permute`` (or ``.T``)."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(TransposeOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    init_str, forward_str = [], []
    if "perm" in attr_value_dict:
      forward_str.append(
          f"{outputs_str[0]} = {inputs_str[0]}.permute(*{attr_value_dict['perm'].__repr__()})"
      )
    else:
      # No perm attribute: ONNX default reverses all dims.
      # NOTE(review): Tensor.T for rank > 2 is deprecated in newer torch;
      # confirm against the supported torch versions.
      forward_str.append(f"{outputs_str[0]} = {inputs_str[0]}.T")
    return {"init": init_str, "forward": forward_str}
# --- onnx_pytorch/op_code_generators/Unsqueeze.py ---
import onnx
import torch
from onnx.numpy_helper import to_array

from onnx_pytorch.op_code_generators import OpCodeGenerator


class UnsqueezeOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Unsqueeze onto chained torch.unsqueeze calls."""

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(UnsqueezeOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    axes = attr_value_dict.get("axes", [])
    if len(node.input) == 2:
      # Opset 13+: axes arrive as a second input (must be constant).
      assert node.input[
          1] in initializers, "Currently UnsqueezeOpCodeGenerator only support all of [axes] is in initializers."
      axes = to_array(initializers[node.input[1]])
    init_str, forward_str = [], []
    curr_input = inputs_str[0]
    # Inserting dims in ascending axis order keeps later indices valid.
    for a in axes:
      forward_str.append(
          f"{outputs_str[0]} = torch.unsqueeze({curr_input}, {a})")
      curr_input = outputs_str[0]

    return {"init": init_str, "forward": forward_str}


# --- onnx_pytorch/op_code_generators/Upsample.py ---
class UpsampleOpCodeGenerator(OpCodeGenerator):
  """Emits code mapping ONNX Upsample (deprecated op) onto F.interpolate.

  The leading (N, C) scale entries are stripped via ``[2:]`` because
  F.interpolate only scales spatial dims.
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    super(UpsampleOpCodeGenerator, self).__init__(onnx_ver, torch_ver)

  def gen(self, node, value_infos, initializers):
    attr_value_dict = self.get_attr_value_dict(node)
    inputs_str, outputs_str = self.gen_input_output_string(
        node, initializers, self.rename_helper, self.tensor_inplace)
    if node.input[1] in initializers:
      scales = tuple(
          onnx.numpy_helper.to_array(initializers[node.input[1]])[2:])
    else:
      # Scales only known at runtime: emit an expression, not a constant.
      scales = f"list({self.rename_helper.tensor_name_mapping.get(node.input[1], node.input[1])})[2:]"

    align_corners = None
    mode = attr_value_dict['mode'].decode()
    d = len(value_infos[node.input[0]].type.tensor_type.shape.dim) - 2
    assert d < 4, "Currently temporal, spatial and volumetric sampling are supported."
    if mode == "linear":
      modes = ["linear", "bilinear", "trilinear"]
      mode = modes[d - 1]
    params_str = self.gen_params_str(
        scale_factor=scales,
        mode=f"'{mode}'",
        align_corners=align_corners,
        recompute_scale_factor=scales is not None,
    )
    init_str, forward_str = [], []

    forward_str.append(
        f"{outputs_str[0]} = F.interpolate({inputs_str[0]}, **{{{params_str}}})"
    )
    return {"init": init_str, "forward": forward_str}


# --- onnx_pytorch/op_code_generators/__init__.py ---
import logging

import onnx
import onnx.numpy_helper
from onnx.numpy_helper import to_array
import torch

import glob
import os

# Export every per-op generator module in this package by filename.
modules = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in modules
    if os.path.isfile(f) and not f.endswith('__init__.py')
] + ["get_op_code_generator"]


class OpCodeGenerator:
  """Base class for per-op code generators.

  Subclasses are named ``<OnnxOp>OpCodeGenerator``; the op name and its
  ONNX schema (attribute defaults included) are derived from the class
  name in ``__init__``.
  """

  def __init__(self,
               onnx_ver=onnx.defs.onnx_opset_version(),
               torch_ver=torch.__version__):
    self.onnx_ver = onnx_ver
    self.torch_ver = torch_ver
    # Derive the ONNX op name from the subclass name.
    self.onnx_op = self.__class__.__name__.replace("OpCodeGenerator", "")
    self.schema = onnx.defs.get_schema(self.onnx_op,
                                       max_inclusive_version=onnx_ver)

    # Should inherit from ModelCodeGenerator
    self.rename_helper = None
    self.tensor_inplace = None

    if self.schema is not None:
      # Cache the schema's attribute defaults for get_attr_value_dict.
      self.attr_default = {}
      for a, i in self.schema.attributes.items():
        try:
          default_value = onnx.helper.get_attribute_value(i.default_value)
          self.attr_default[a] = default_value
        except Exception as e:
          logging.warning(
              f"Cannot get default value for {a} of {self.onnx_op}.")

  def gen(self, node, value_infos, initializers):
| raise Exception 46 | 47 | def get_attr_value_dict(self, node): 48 | attr_value_dict = {} 49 | for a in node.attribute: 50 | attr_value_dict[a.name] = onnx.helper.get_attribute_value(a) 51 | attr_value_dict = dict( 52 | list(self.attr_default.items()) + list(attr_value_dict.items())) 53 | return attr_value_dict 54 | 55 | def gen_input_output_string(self, 56 | node, 57 | initializers, 58 | rename_helper, 59 | tensor_inplace=False, 60 | input_num=None, 61 | output_num=None): 62 | inputs_str, outputs_str = [], [] 63 | input_num, output_num = input_num or len(node.input), output_num or len( 64 | node.output) 65 | for idx, (num, f, ls) in enumerate( 66 | ((input_num, node.input, inputs_str), (output_num, node.output, 67 | outputs_str))): 68 | for i in range(num): 69 | # tensor_inplace condition: 70 | # idx == 1: output 71 | # i == 0: first output tensor (Currently only support first tensor inplace) 72 | # node.input[0] not in initializers: Could not inplace initializer 73 | # rename_helper.tensor_name_counter[f[i]] == 2: output tensor 0 should only be counted twice 74 | # rename_helper.tensor_name_counter[node.input[0]] == 2: input tensor 0 should only be counted twice 75 | if idx == 1 \ 76 | and i == 0 \ 77 | and tensor_inplace \ 78 | and len(node.input) > 0 \ 79 | and node.input[0] not in initializers \ 80 | and rename_helper.tensor_name_counter[f[i]] == 2 \ 81 | and rename_helper.tensor_name_counter[node.input[0]] == 2: 82 | tensor_name = node.input[0] 83 | rename_helper.tensor_name_mapping[ 84 | f[i]] = rename_helper.get_tensor_name(tensor_name) 85 | else: 86 | tensor_name = f[i] 87 | formatter = "{}" 88 | if tensor_name in initializers: 89 | formatter = "self._vars[\"{}\"]" 90 | s = formatter.format(rename_helper.get_tensor_name(tensor_name)) 91 | ls.append(s) 92 | 93 | return inputs_str, outputs_str 94 | 95 | def gen_params_str(self, **kwargs): 96 | params = [] 97 | for k, v in kwargs.items(): 98 | v_str = v if type(v) == str else v.__repr__() 99 | 
params.append(f"'{k}': {v_str}") 100 | return ', '.join(params).__repr__()[1:-1] 101 | 102 | def check_in_init(self, targets, initializers): 103 | lacks = [] 104 | rs = [None] * len(targets) 105 | for i, (t, n) in enumerate(targets): 106 | init = initializers.get(n, None) 107 | if init is None: 108 | lacks.append(n) 109 | rs[i] = init 110 | if lacks: 111 | raise Exception( 112 | f"Currently {self.__class__} only support all of {lacks.__repr__()} is in initializers." 113 | ) 114 | return rs 115 | 116 | def get_shape(self, value, value_infos): 117 | if value not in value_infos: 118 | return None 119 | shape = [] 120 | for d in value_infos[value].type.tensor_type.shape.dim: 121 | if d.dim_param != "": 122 | shape.append(-1) 123 | else: 124 | shape.append(d.dim_value) 125 | return shape 126 | 127 | 128 | class ReduceOpCodeGenerator(OpCodeGenerator): 129 | 130 | def __init__(self, 131 | onnx_ver=onnx.defs.onnx_opset_version(), 132 | torch_ver=torch.__version__): 133 | super(ReduceOpCodeGenerator, self).__init__(onnx_ver, torch_ver) 134 | 135 | def _get_dim(self, attr_value_dict, d, node, initializers): 136 | if "axes" in attr_value_dict: 137 | dim = attr_value_dict["axes"] 138 | else: 139 | dim = list(range(d)) 140 | if len(node.input) > 1: 141 | dim = initializers.get(node.input[1], None) 142 | assert dim is not None, "Currently ReduceOpCodeGenerator only support all of [axes] is in initializers." 
143 | dim = list(to_array(dim)) 144 | return dim 145 | 146 | 147 | __op_gen_dict = {} 148 | 149 | 150 | def get_op_code_generator(op, **kwargs): 151 | op_code_gen_name = "{}OpCodeGenerator".format(op) 152 | if op_code_gen_name in __op_gen_dict: 153 | return __op_gen_dict[op_code_gen_name] 154 | mod = globals().get(op, None) 155 | if mod is None: 156 | return None 157 | __op_gen_dict[op_code_gen_name] = getattr(mod, op_code_gen_name)(**kwargs) 158 | return __op_gen_dict[op_code_gen_name] 159 | 160 | 161 | def clear_op_code_generator(): 162 | __op_gen_dict = {} 163 | -------------------------------------------------------------------------------- /onnx_pytorch/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fumihwh/onnx-pytorch/b2c10eb640aab6f51e8f1e6f1c2a2cc97a4741f1/onnx_pytorch/tests/__init__.py -------------------------------------------------------------------------------- /onnx_pytorch/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fumihwh/onnx-pytorch/b2c10eb640aab6f51e8f1e6f1c2a2cc97a4741f1/onnx_pytorch/utils/__init__.py -------------------------------------------------------------------------------- /onnx_pytorch/utils/embedding_config_helper.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | 4 | import onnx 5 | from onnx.numpy_helper import to_array 6 | import yaml 7 | 8 | 9 | class EmbeddingParam: 10 | 11 | def __init__(self, 12 | name, 13 | num_embeddings, 14 | embedding_dim, 15 | padding_idx=None, 16 | max_norm=None, 17 | norm_type=2.0, 18 | scale_grad_by_freq=False, 19 | sparse=False, 20 | embeddings_initializer=None, 21 | embeddings_regularizer=None): 22 | self.name = name 23 | self.num_embeddings = num_embeddings 24 | self.embedding_dim = embedding_dim 25 | self.padding_idx = padding_idx 26 | self.max_norm = 
max_norm 27 | self.norm_type = norm_type 28 | self.scale_grad_by_freq = scale_grad_by_freq 29 | self.sparse = sparse 30 | self.initializer = embeddings_initializer 31 | self.regularizer = embeddings_regularizer 32 | 33 | 34 | def gen_embedding_config(onnx_model_path, embedding_conf_file): 35 | model = onnx.load(onnx_model_path) 36 | initializers = {i.name: i for i in model.graph.initializer} 37 | inputs = {i.name: i for i in model.graph.input if i.name not in initializers} 38 | gathers = [ 39 | n for n in model.graph.node 40 | if n.op_type == "Gather" and len(n.input) > 1 and n.input[1] in inputs 41 | ] 42 | embeddings = [ 43 | EmbeddingParam(name=n.name, 44 | num_embeddings=to_array(initializers[n.input[0]]).shape[0], 45 | embedding_dim=to_array(initializers[n.input[0]]).shape[1]) 46 | for n in gathers 47 | ] 48 | with open(embedding_conf_file, "w") as f: 49 | f.write( 50 | yaml.dump([{ 51 | "name": e.name, 52 | "num_embeddings": e.num_embeddings, 53 | "embedding_dim": e.embedding_dim 54 | } for e in embeddings], 55 | sort_keys=False)) 56 | 57 | 58 | def load_embedding_config(embedding_conf_file): 59 | with open(embedding_conf_file, "r") as f: 60 | embeddings = yaml.load(f) 61 | embeddings = { 62 | f'{re.sub("[:/.]", "_", e["name"])}': EmbeddingParam(**e) 63 | for e in embeddings 64 | } 65 | return embeddings 66 | 67 | 68 | def main(): 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument("--onnx_model_path", 71 | default=None, 72 | type=str, 73 | help="The onnx model path.") 74 | parser.add_argument("--embedding_conf_file", 75 | type=str, 76 | help="Embedding config file path.") 77 | args = parser.parse_args() 78 | 79 | gen_embedding_config(onnx_model_path=args.onnx_model_path, 80 | embedding_conf_file=args.embedding_conf_file) 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | numpy>=1.19.5 2 | onnx>=1.10.2 3 | onnxruntime>=1.9.0 4 | pytest>=6.2.5 5 | PyYAML>=6.0 6 | setuptools>=59.2.0 7 | sympy>=1.9 8 | torch>=1.10.0 9 | torchvision>=0.11.1 10 | tqdm>=4.62.3 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | from setuptools import setup, find_packages 4 | 5 | 6 | def _version() -> str: 7 | from onnx_pytorch import _version 8 | return _version.__version__ 9 | 10 | 11 | def _parse_requirements() -> List[str]: 12 | file_path = "requirements.txt" 13 | if not os.path.exists(file_path): 14 | file_path = "onnx_pytorch.egg-info/requires.txt" 15 | with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 16 | file_path)) as f: 17 | required = f.read().splitlines() 18 | return required 19 | 20 | 21 | setup(name="onnx-pytorch", 22 | version=_version(), 23 | description="Convert ONNX to PyTorch code.", 24 | long_description=open("README.md").read(), 25 | long_description_content_type="text/markdown", 26 | author="fumihwh", 27 | author_email="fumihwh@gmail.com", 28 | url="https://github.com/fumihwh/onnx-pytorch", 29 | packages=find_packages(), 30 | license="Apache 2.0", 31 | scripts=["onnx_pytorch/code_gen.py"], 32 | install_requires=_parse_requirements(), 33 | classifiers=["Programming Language :: Python :: 3"]) 34 | -------------------------------------------------------------------------------- /tutorial.py: -------------------------------------------------------------------------------- 1 | from onnx_pytorch import code_gen 2 | code_gen.gen("resnet18-v2-7.onnx", "./") 3 | 4 | import numpy as np 5 | import onnx 6 | import onnxruntime 7 | import torch 8 | torch.set_printoptions(8) 9 | 10 | from model import Model 11 | 12 | model = Model() 13 | model.eval() 14 | inp = np.random.randn(1, 
3, 224, 224).astype(np.float32) 15 | with torch.no_grad(): 16 | torch_outputs = model(torch.from_numpy(inp)) 17 | 18 | onnx_model = onnx.load("resnet18-v2-7.onnx") 19 | sess_options = onnxruntime.SessionOptions() 20 | session = onnxruntime.InferenceSession(onnx_model.SerializeToString(), 21 | sess_options) 22 | inputs = {session.get_inputs()[0].name: inp} 23 | ort_outputs = session.run(None, inputs) 24 | 25 | print( 26 | "Comparison result:", 27 | np.allclose(torch_outputs.detach().numpy(), 28 | ort_outputs[0], 29 | atol=1e-5, 30 | rtol=1e-5)) 31 | --------------------------------------------------------------------------------