├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.md │ └── request-operator.md ├── codecov.yml └── workflows │ ├── build.yaml │ ├── codeql-analysis.yml │ ├── publish.yml │ ├── rebase.yml │ ├── sanity.yaml │ └── stale.yml ├── .gitignore ├── LICENSE ├── README.md ├── _config.yml ├── assets └── tests │ ├── abs-add-rsqrt.float32.tflite │ ├── abs-sqrt.float32.tflite │ ├── abs.float32.tflite │ ├── add-broadcast.float32.tflite │ ├── add-broadcast2.float32.tflite │ ├── add-relu.float32.tflite │ ├── add.float32.tflite │ ├── avgpool-concat.float32.tflite │ ├── avgpooling.float32.tflite │ ├── concat.float32.tflite │ ├── concat2.float32.tflite │ ├── conv-dilation.float32.tflite │ ├── conv-quant-fp16.float32.tflite │ ├── conv-relu.float32.tflite │ ├── conv-relu.uint8.tflite │ ├── conv-relu6.float32.tflite │ ├── conv-reshape-multiple-conv.float32.tflite │ ├── conv-reshape.float32.tflite │ ├── conv-stride.float32.tflite │ ├── conv-transpose_relu.float32.tflite │ ├── conv.float32.tflite │ ├── conv.uint8.tflite │ ├── depthwise-conv-stride.float32.tflite │ ├── depthwise-conv.float32.tflite │ ├── depthwise-conv.uint8.tflite │ ├── fullyconnected-relu6.float32.tflite │ ├── fullyconnected.float32.tflite │ ├── maxpooling.float32.tflite │ ├── mean-keepdims.float32.tflite │ ├── mean.float32.tflite │ ├── mirror-pad.int32.tflite │ ├── mobilenet_v1_0.25_128.tflite │ ├── mobilenet_v1_0.25_128_quant.tflite │ ├── mul.float32.tflite │ ├── padding.float32.tflite │ ├── prelu.float32.tflite │ ├── relu.float32.tflite │ ├── relu6-power.float32.tflite │ ├── relu6.float32.tflite │ ├── reshape-conv.float32.tflite │ ├── reshape.float32.tflite │ ├── resize-bilinear.float32.tflite │ ├── resize-nearest-neighbor.float32.tflite │ ├── sigmoid.float32.tflite │ ├── softmax.float32.tflite │ ├── split.float32.tflite │ ├── squared-diff.float32.tflite │ ├── stridedslice-beginmask.float32.tflite │ ├── stridedslice-endmask.float32.tflite │ ├── stridedslice-stride.float32.tflite │ ├── stridedslice.float32.tflite │ ├── 
sub.float32.tflite │ ├── transpose.float32.tflite │ ├── transposeconv-samepad-stride2.float32.tflite │ ├── transposeconv-samepad.float32.tflite │ ├── transposeconv-validpad-stride2.float32.tflite │ └── transposeconv-validpad.float32.tflite ├── docs ├── contribution-guide.md ├── data-layout-semantic.md ├── faq.md ├── how-to-enable-new-operator.md ├── images │ └── propagate-nasnet.jpg └── release-notes.md ├── pyproject.toml ├── requirements.txt ├── scripts ├── build-wheel.sh ├── open-github-connection.sh ├── source-me.sh └── upload-pip.sh ├── setup.cfg ├── tests ├── test_cmd.py ├── test_explicit_layout.py ├── test_layout.py ├── test_mapping.py ├── test_networks.py ├── test_ops.py ├── test_padding.py ├── test_quantize.py └── test_utils.py └── tflite2onnx ├── __init__.py ├── common.py ├── convert.py ├── graph.py ├── layout.py ├── mapping.py ├── model.py ├── op ├── __init__.py ├── activation.py ├── binary.py ├── common.py ├── concat.py ├── conv.py ├── fullyconnected.py ├── padding.py ├── pooling.py ├── quantize.py ├── reduce.py ├── reshape.py ├── resize.py ├── rsqrt.py ├── slice.py ├── softmax.py ├── split.py ├── squared_difference.py ├── transpose.py └── unary.py ├── quantize.py ├── tensor.py └── utils.py /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | Please check the [FAQ](https://github.com/zhenhuaw-me/tflite2onnx/blob/master/docs/faq.md) 13 | and search around before creating a real issue. 14 | 15 | 16 | **To Reproduce** 17 | 1. Please help to narrow down the TFLite model as far as you can. 18 | 2. Your Python code to convert the model. 19 | 3. Attach the TFLite model. 20 | 4. Attach debug log, please rerun with logging config below. 
21 | ```python 22 | logging.basicConfig(format='%(asctime)s %(levelname).1s [%(name)s][%(filename)s:%(lineno)d] %(message)s', level=logging.DEBUG) 23 | ``` 24 | 25 | **Version** 26 | Which version are you using? Example `v0.3.0`. 27 | 28 | Have you tried with latest `master` branch? 29 | 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/request-operator.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Request new operator 3 | about: Request to enable support for new TensorFlow Lite operator 4 | title: 'Operator request:' 5 | labels: Operator 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Note**: Currently, we only accept [TensorFlow Lite builtin operators](https://jackwish.net/tflite/docs/BuiltinOperator.m.html) request. 11 | 12 | **What is the TensorFlow Lite operator you need** 13 | (Please attach a tiny TFLite model `*.tflite`, which contains the operator you need ONLY.) 14 | 15 | 16 | **What kind of service you are trying to deploy your model?** 17 | 18 | 19 | 20 | **Would you like to contribute the operator?** 21 | 22 | 23 | 24 | **Additional context** 25 | Add any other context about the problem here. 
26 | -------------------------------------------------------------------------------- /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | threshold: 1% 6 | target: auto 7 | base: auto 8 | branches: 9 | - master 10 | only_pulls: false 11 | 12 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | 7 | test: 8 | name: Build and Test 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v2 13 | 14 | - name: Set up Python 3.8 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: 3.8 18 | 19 | - name: Update pip 20 | run: python -m pip install --upgrade pip 21 | 22 | - name: Build package 23 | run: scripts/build-wheel.sh 24 | shell: bash 25 | 26 | - name: Install package 27 | run: pip install -U assets/dist/tflite2onnx-*.whl 28 | 29 | - name: Install development dependencies 30 | run: pip install -r requirements.txt 31 | 32 | - name: Testing (collecting coverage data) 33 | run: coverage run --source=./tflite2onnx -m pytest 34 | 35 | # - name: SSH via Ngrok if fail 36 | # if: ${{ failure() }} 37 | # env: 38 | # # Find token in: https://dashboard.ngrok.com/get-started/setup 39 | # NGROK_TOKEN: ${{ secrets.NGROK_TOKEN }} 40 | # NGROK_LOCAL_PASS: ${{ secrets.NGROK_LOCAL_PASS }} 41 | # run: scripts/open-github-connection.sh 42 | # - name: Live instance if fail 43 | # if: ${{ failure() }} 44 | # run: sleep 1h 45 | 46 | - name: Upload coverage report 47 | continue-on-error: true 48 | env: 49 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 50 | run: | 51 | coverage xml 52 | bash <(curl -s https://codecov.io/bash) 53 | -------------------------------------------------------------------------------- 
/.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ master ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ master ] 20 | schedule: 21 | - cron: '44 6 * * 4' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://git.io/codeql-language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v2 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v1 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 52 | 53 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 
54 | # If this step fails, then you should remove it and run the build manually (see below) 55 | - name: Autobuild 56 | uses: github/codeql-action/autobuild@v1 57 | 58 | # ℹ️ Command-line programs to run using the OS shell. 59 | # 📚 https://git.io/JvXDl 60 | 61 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 62 | # and modify them (or add more) to build your code if your project 63 | # uses a compiled language 64 | 65 | #- run: | 66 | # make bootstrap 67 | # make release 68 | 69 | - name: Perform CodeQL Analysis 70 | uses: github/codeql-action/analyze@v1 71 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 3.8 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: 3.8 21 | 22 | - name: Update pip 23 | run: python -m pip install --upgrade pip 24 | 25 | - name: Build package 26 | run: scripts/build-wheel.sh 27 | shell: bash 28 | 29 | - name: Publish 30 | env: 31 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 32 | TWINE_PASSWORD: ${{ secrets.PYPI_RELEASE_TOKEN }} 33 | run: | 34 | pip install twine 35 | twine upload --verbose assets/dist/* 36 | -------------------------------------------------------------------------------- /.github/workflows/rebase.yml: -------------------------------------------------------------------------------- 1 | name: Automatic Rebase 2 | 3 | on: 4 | issue_comment: 5 | types: [created] 6 
| 7 | jobs: 8 | rebase: 9 | name: Rebase 10 | # github.event.comment.author_association == 'MEMBER' 11 | # https://docs.github.com/en/graphql/reference/enums#commentauthorassociation 12 | if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/rebase') 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout the latest code 17 | uses: actions/checkout@v2 18 | with: 19 | fetch-depth: 0 20 | - name: Automatic Rebase 21 | uses: cirrus-actions/rebase@1.3.1 22 | env: 23 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 24 | -------------------------------------------------------------------------------- /.github/workflows/sanity.yaml: -------------------------------------------------------------------------------- 1 | name: Sanity 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | 7 | pychecker: 8 | name: Check Python Style 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 3.8 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: 3.8 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install flake8 21 | - name: Check Python Style 22 | run: | 23 | flake8 24 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: "Close stale issues" 2 | on: 3 | schedule: 4 | - cron: "30 1 * * *" 5 | 6 | jobs: 7 | stale: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/stale@v3 11 | with: 12 | repo-token: ${{ secrets.GITHUB_TOKEN }} 13 | days-before-stale: 30 14 | days-before-close: 14 15 | stale-issue-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 14 days.' 
16 | stale-issue-label: 'stale' 17 | exempt-issue-labels: 'Story,help wanted,Enhancement,bug' 18 | stale-pr-message: 'This PR is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 14 days.' 19 | stale-pr-label: 'stale' 20 | exempt-pr-labels: 'Story,help wanted,Enhancement,bug' 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | # tflite2onnx 133 | *.onnx 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright (c) 2020 王振华 (Zhenhua WANG) 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | tflite2onnx - Convert TensorFlow Lite models to ONNX 2 | ==================================================== 3 | 4 | [![Build and Test](https://github.com/zhenhuaw-me/tflite2onnx/workflows/Build%20and%20Test/badge.svg)](https://github.com/zhenhuaw-me/tflite2onnx/actions?query=workflow%3A%22Build+and+Test%22) 5 | [![Sanity](https://github.com/zhenhuaw-me/tflite2onnx/workflows/Sanity/badge.svg)](https://github.com/zhenhuaw-me/tflite2onnx/actions?query=workflow%3ASanity) 6 | [![Coverage](https://codecov.io/gh/zhenhuaw-me/tflite2onnx/branch/master/graph/badge.svg)](https://codecov.io/gh/zhenhuaw-me/tflite2onnx) 7 | [![Download](https://img.shields.io/pypi/dm/tflite2onnx)](https://img.shields.io/pypi/dm/tflite2onnx) 8 | 9 | `tflite2onnx` converts TensorFlow Lite (TFLite) models (`*.tflite`) to ONNX models (`*.onnx`), 10 | with data 
layout and quantization semantic properly handled (check the [introduction blog][intro] for detail). 11 | 12 | If you'd like to convert a TensorFlow model (frozen graph `*.pb`, `SavedModel` 13 | or whatever) to ONNX, try [`tf2onnx`](https://github.com/onnx/tensorflow-onnx). 14 | Or, you can firstly [convert][tf2tflite] it to a TFLite (`*.tflite`) model, 15 | and then convert the TFLite model to ONNX. 16 | 17 | 18 | ## Call for Contribution 19 | 20 | Currently we have [14 open issues](https://github.com/zhenhuaw-me/tflite2onnx/issues). 21 | I am sorry that I don't have more bandwidth to work on them. 22 | *Please help to contribute to this project!* See [Contributing](#contributing) below. 23 | 24 | Started from the beginning, I have written [docs](#documentation) to help you ramp up this project. 25 | Therefore I am not the critical path of this project. 26 | 27 | Microsoft has implemented another _TensorFlow Lite to ONNX model converter_ in `tf2onnx` 28 | [at Feb 2021](https://github.com/onnx/sigs/blob/master/converters/meetings/019-20210212.md) 29 | (we open sourced `tflite2onnx` in May 2020). `tf2onnx` seems to able to convert Quantization 30 | just like us, and it seems able to convert RNN networks which we are not supported yet. 31 | Yet, due to the architecture, I think it's non-trivial to fix a bug, which means that, 32 | `tflite2onnx` is a rather better choice if you are blocked and don't wait for a fix from the maintainer. 33 | 34 | 35 | ## Installation 36 | 37 | Install via [pip][pypi] `pip install tflite2onnx`. 38 | 39 | Or install from source to get latest features (please try out with [virtualenv](https://virtualenv.pypa.io)): 40 | 41 | 1. Download the repo: `git clone https://github.com/zhenhuaw-me/tflite2onnx.git && cd tflite2onnx` 42 | 2. Build the package: `./scripts/build-wheel.sh` 43 | 3. Install the built package: `pip install assets/dist/tflite2onnx-*.whl` 44 | 45 | Or you can just add the code tree to your `$PYTHONPATH`. 
46 | (Command line tool is not avaiable in this mode.) 47 | 48 | ```sh 49 | export PYTHONPATH=$(pwd):${PYTHONPATH} 50 | ``` 51 | 52 | 53 | ## Usage 54 | 55 | ### Python Interface 56 | 57 | ```py 58 | import tflite2onnx 59 | 60 | tflite_path = '/path/to/original/tflite/model' 61 | onnx_path = '/path/to/save/converted/onnx/model' 62 | 63 | tflite2onnx.convert(tflite_path, onnx_path) 64 | ``` 65 | 66 | `tflite2onnx` now supports *explicit layout*, check the 67 | [test example](https://github.com/zhenhuaw-me/tflite2onnx/blob/master/tests/test_explicit_layout.py). 68 | 69 | 70 | ### Command Line 71 | 72 | ```sh 73 | tflite2onnx /path/to/original/tflite/model /path/to/save/converted/onnx/model 74 | ``` 75 | 76 | 77 | ## Documentation 78 | 79 | * [FAQ](docs/faq.md) 80 | * [Release note](docs/release-notes.md) 81 | * [Contribution guide](docs/contribution-guide.md) 82 | * [Introduction blog - the background, design and implementation][intro] 83 | * [How to enable a new operator](docs/how-to-enable-new-operator.md) 84 | * [Data layout semantic](docs/data-layout-semantic.md) 85 | 86 | 87 | ## Contributing 88 | 89 | * If something seems wrong to you, [report bugs](https://github.com/zhenhuaw-me/tflite2onnx/issues/new?assignees=&labels=bug&template=bug-report.md&title=). 90 | * If some operators are not supported yet, you may [request a new operator](https://github.com/zhenhuaw-me/tflite2onnx/issues/new?assignees=&labels=operator%2C+help+wanted&template=request-operator.md&title=Operator+request%3A). 91 | * It would be great if you can help to enable new operators, please join us with [How to enable a new operator](docs/how-to-enable-new-operator.md). 92 | * Feel free to open any other related discussions. 93 | 94 | Check [contribution guide](docs/contribution-guide.md) for more. 95 | 96 | 97 | ## License 98 | 99 | Apache License Version 2.0. 
100 | 101 | [intro]: https://zhenhuaw.me/blog/2020/Convert-TensorFlow-Lite-models-to-ONNX.html 102 | [pypi]: https://pypi.org/project/tflite2onnx 103 | [github]: https://github.com/zhenhuaw-me/tflite2onnx 104 | [tf2tflite]: https://www.tensorflow.org/lite/convert 105 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /assets/tests/abs-add-rsqrt.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/abs-add-rsqrt.float32.tflite -------------------------------------------------------------------------------- /assets/tests/abs-sqrt.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/abs-sqrt.float32.tflite -------------------------------------------------------------------------------- /assets/tests/abs.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/abs.float32.tflite -------------------------------------------------------------------------------- /assets/tests/add-broadcast.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/add-broadcast.float32.tflite -------------------------------------------------------------------------------- /assets/tests/add-broadcast2.float32.tflite: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/add-broadcast2.float32.tflite -------------------------------------------------------------------------------- /assets/tests/add-relu.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/add-relu.float32.tflite -------------------------------------------------------------------------------- /assets/tests/add.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/add.float32.tflite -------------------------------------------------------------------------------- /assets/tests/avgpool-concat.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/avgpool-concat.float32.tflite -------------------------------------------------------------------------------- /assets/tests/avgpooling.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/avgpooling.float32.tflite -------------------------------------------------------------------------------- /assets/tests/concat.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/concat.float32.tflite -------------------------------------------------------------------------------- 
/assets/tests/concat2.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/concat2.float32.tflite -------------------------------------------------------------------------------- /assets/tests/conv-dilation.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/conv-dilation.float32.tflite -------------------------------------------------------------------------------- /assets/tests/conv-quant-fp16.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/conv-quant-fp16.float32.tflite -------------------------------------------------------------------------------- /assets/tests/conv-relu.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/conv-relu.float32.tflite -------------------------------------------------------------------------------- /assets/tests/conv-relu.uint8.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/conv-relu.uint8.tflite -------------------------------------------------------------------------------- /assets/tests/conv-relu6.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/conv-relu6.float32.tflite 
-------------------------------------------------------------------------------- /assets/tests/conv-reshape-multiple-conv.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/conv-reshape-multiple-conv.float32.tflite -------------------------------------------------------------------------------- /assets/tests/conv-reshape.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/conv-reshape.float32.tflite -------------------------------------------------------------------------------- /assets/tests/conv-stride.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/conv-stride.float32.tflite -------------------------------------------------------------------------------- /assets/tests/conv-transpose_relu.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/conv-transpose_relu.float32.tflite -------------------------------------------------------------------------------- /assets/tests/conv.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/conv.float32.tflite -------------------------------------------------------------------------------- /assets/tests/conv.uint8.tflite: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/conv.uint8.tflite -------------------------------------------------------------------------------- /assets/tests/depthwise-conv-stride.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/depthwise-conv-stride.float32.tflite -------------------------------------------------------------------------------- /assets/tests/depthwise-conv.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/depthwise-conv.float32.tflite -------------------------------------------------------------------------------- /assets/tests/depthwise-conv.uint8.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/depthwise-conv.uint8.tflite -------------------------------------------------------------------------------- /assets/tests/fullyconnected-relu6.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/fullyconnected-relu6.float32.tflite -------------------------------------------------------------------------------- /assets/tests/fullyconnected.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/fullyconnected.float32.tflite -------------------------------------------------------------------------------- 
/assets/tests/maxpooling.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/maxpooling.float32.tflite -------------------------------------------------------------------------------- /assets/tests/mean-keepdims.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/mean-keepdims.float32.tflite -------------------------------------------------------------------------------- /assets/tests/mean.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/mean.float32.tflite -------------------------------------------------------------------------------- /assets/tests/mirror-pad.int32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/mirror-pad.int32.tflite -------------------------------------------------------------------------------- /assets/tests/mobilenet_v1_0.25_128.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/mobilenet_v1_0.25_128.tflite -------------------------------------------------------------------------------- /assets/tests/mobilenet_v1_0.25_128_quant.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/mobilenet_v1_0.25_128_quant.tflite 
-------------------------------------------------------------------------------- /assets/tests/mul.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/mul.float32.tflite -------------------------------------------------------------------------------- /assets/tests/padding.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/padding.float32.tflite -------------------------------------------------------------------------------- /assets/tests/prelu.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/prelu.float32.tflite -------------------------------------------------------------------------------- /assets/tests/relu.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/relu.float32.tflite -------------------------------------------------------------------------------- /assets/tests/relu6-power.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/relu6-power.float32.tflite -------------------------------------------------------------------------------- /assets/tests/relu6.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/relu6.float32.tflite 
-------------------------------------------------------------------------------- /assets/tests/reshape-conv.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/reshape-conv.float32.tflite -------------------------------------------------------------------------------- /assets/tests/reshape.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/reshape.float32.tflite -------------------------------------------------------------------------------- /assets/tests/resize-bilinear.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/resize-bilinear.float32.tflite -------------------------------------------------------------------------------- /assets/tests/resize-nearest-neighbor.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/resize-nearest-neighbor.float32.tflite -------------------------------------------------------------------------------- /assets/tests/sigmoid.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/sigmoid.float32.tflite -------------------------------------------------------------------------------- /assets/tests/softmax.float32.tflite: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/softmax.float32.tflite -------------------------------------------------------------------------------- /assets/tests/split.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/split.float32.tflite -------------------------------------------------------------------------------- /assets/tests/squared-diff.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/squared-diff.float32.tflite -------------------------------------------------------------------------------- /assets/tests/stridedslice-beginmask.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/stridedslice-beginmask.float32.tflite -------------------------------------------------------------------------------- /assets/tests/stridedslice-endmask.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/stridedslice-endmask.float32.tflite -------------------------------------------------------------------------------- /assets/tests/stridedslice-stride.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/stridedslice-stride.float32.tflite -------------------------------------------------------------------------------- 
/assets/tests/stridedslice.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/stridedslice.float32.tflite -------------------------------------------------------------------------------- /assets/tests/sub.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/sub.float32.tflite -------------------------------------------------------------------------------- /assets/tests/transpose.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/transpose.float32.tflite -------------------------------------------------------------------------------- /assets/tests/transposeconv-samepad-stride2.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/transposeconv-samepad-stride2.float32.tflite -------------------------------------------------------------------------------- /assets/tests/transposeconv-samepad.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/transposeconv-samepad.float32.tflite -------------------------------------------------------------------------------- /assets/tests/transposeconv-validpad-stride2.float32.tflite: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/transposeconv-validpad-stride2.float32.tflite -------------------------------------------------------------------------------- /assets/tests/transposeconv-validpad.float32.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/assets/tests/transposeconv-validpad.float32.tflite -------------------------------------------------------------------------------- /docs/contribution-guide.md: -------------------------------------------------------------------------------- 1 | Contribution Guide 2 | ================== 3 | 4 | Welcome and thank you for reaching this contribution guide. 5 | Materials are split into sections, just jump to topics you are interested in. 6 | 7 | 8 | ## Reporting Issues 9 | 10 | * If something seems wrong to you, [report bugs](https://github.com/zhenhuaw-me/tflite2onnx/issues/new?assignees=&labels=bug&template=bug-report.md&title=). 11 | * If some operators are not supported yet, you may [request a new operator](https://github.com/zhenhuaw-me/tflite2onnx/issues/new?assignees=&labels=operator%2C+help+wanted&template=request-operator.md&title=Operator+request%3A). 12 | * Feel free to open any other related discussions. 13 | 14 | It's high recommended to attach a narrow down-ed TFLite model and 15 | debug logs (generate it with `tflite2onnx.enableDebugLog()`). 16 | 17 | 18 | ## Contributing Code 19 | 20 | We work on enabling operators most of the time. 21 | In this way, there is [a dedicated step-by-step guide](docs/how-to-enable-new-operator.md) 22 | to help you enable new operators. 23 | 24 | Please help to upstream your operator enabling. 25 | _Unus pro omnibus, omnes pro uno._ 26 | 27 | We have GitHub Actions based CI for pull requests. 
28 | I know sometimes it's annoying but it's very important as it helps us to protect the code. 29 | In this way we can keep the quality and save time of debugging. 30 | 31 | In general we need: 32 | * Dedicated test of the new operator. 33 | * It would be great if we can have models to test different attributes. 34 | * `pytest` at root directory to run all tests. 35 | * Clean code. 36 | * Code style. 37 | * `flake8` at root directory to check. 38 | * No significant code coverage drop (guarded by `CodeCov`). 39 | * Automatically checked when open/update PR. 40 | 41 | Like many other python packages, you can set `PYTHONPATH` to `tflite2onnx` 42 | instead of building and installing to try out your changes. 43 | -------------------------------------------------------------------------------- /docs/data-layout-semantic.md: -------------------------------------------------------------------------------- 1 | Data Layout Semantic Conversion 2 | =============================== 3 | 4 | This document, which is originally written in the [introducing blog](https://zhenhuaw.me/blog/2020/Convert-TensorFlow-Lite-models-to-ONNX.html), covers the _layout semantic divergence_ between TensorFlow Lite (TFLite) models which is NHWC and ONNX which is NCHW. 5 | 6 | The data layout format of TFLite has not been mentioned in either the document or the [model representation](https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/lite/schema/schema.fbs) but in the implicit agreement of the [TFLite converter](https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/lite/toco/import_tensorflow.cc#L795) (the TensorFlow model needs to be NHWC) and the [kernels](https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/lite/kernels/conv.cc#L251).
On the contrary, ONNX explicitly declares that it uses NCHW in both [operator representation](https://github.com/onnx/onnx/blob/6bdac246617682f9696f0dac40362ef4f4de2cde/onnx/defs/nn/defs.cc#L713) and document (which is generated from operator representation). 7 | 8 | `tflite2onnx` introduces _propagation based approach_ to handle the layout issue in general, together with some other mechanisms for dedicated corner cases. 9 | 10 | 11 | ## Propagation based Approach 12 | 13 | The _propagation based approach_ is introduced to resolve this by propagating the *layout semantic divergence* across the graph, in which way the _transpose pattern_ is not needed. 14 | 15 | By default (for most cases), given a graph, some of the tensors have implicit layout semantic, e.g. tensors that are connected to `Conv` directly, while others are not, e.g. `Abs` and `Add`. The latter ones are transparent to layout, where _transparent_ means all of the tensors that are connected to the operator must have the same layout semantic or don't hold such semantic. 16 | 17 | So when an operator that is _transparent_ to layout is connected to an operator that has implicit layout tensors, then all tensors of the _transparent_ operator have the same layout semantic as the tensor connecting these two operators, named as _propagation_. 18 | 19 | For example, when converting the TFLite graph (omitted _kernel_ and _bias_) 20 | $$\left< A_{nhwc} \right> \rightarrow [Conv] \rightarrow \left< B_{nhwc} \right> \rightarrow [Abs] \rightarrow \left< C_{?} \right>$$ 21 | to ONNX, tensor $$\left< A_{nhwc} \right>$$ becomes $$\left< A_{nchw} \right>$$ and $$\left< B_{nhwc} \right>$$ becomes $$\left< B_{nchw} \right>$$. Hence, the output $$\left< C \right>$$ of $$[Abs]$$ should have the same format as the input $$\left< B \right>$$. _Propagation based approach_ propagates the conversion from $$\left< B \right>$$ to $$\left< C \right>$$.
Therefore we have the ONNX graph 22 | $$\left< A_{nchw} \right> \rightarrow [Conv] \rightarrow \left< B_{nchw} \right> \rightarrow [Abs] \rightarrow \left< C_{nchw} \right>$$, where no additional operators nor tensors are introduced. 23 | 24 | During layout propagation, the layout transformation permutes the shape of tensors if they are _activations_, i.e. value info in ONNX, and transposes the data of _weights_ in addition, i.e. initializer in ONNX. 25 | 26 | In practice, operators are categorized into four (as marked in _Figure 5_): 27 | 28 | * *Implicit*: operators have *layout semantic divergence*, e.g. `Conv`. They are the source of *layout semantic divergence*. 29 | * *Transparent*: operators that are insensitive to layout, e.g. `Abs`. If any tensor has *layout semantic divergence*, propagate it to all tensors that are connected to such operators. 30 | * *Attribute*: operators that can propagate *layout semantic divergence* just like _Transparent_, but have layout sensitive attributes that need special handling, e.g. attribute `axis` of `Concat`. An additional pass after propagation to adjust these attributes is required. 31 | * *Terminate*: operators that don't have and cannot propagate *layout semantic divergence*, e.g. `Reshape`. The propagation across the graph terminates at such operators. 32 | 33 | ![Part the ONNX model generated by *propagation based approach*.](images/propagate-nasnet.jpg) 34 | *Figure 1: Part of the ONNX model generated by propagation based approach of TFLite2ONNX* 35 | 36 | When propagating *layout semantic divergence* across the graph, for a particular operator: if it is *Transparent* or *Attribute*, propagate *layout semantic divergence* among its tensors; if it is *Implicit* or *Terminate*, terminates the propagation in this direction. *Figure 1* is part of the ONNX model generated by *propagation based approach* from the [NASNet](https://tfhub.dev/tensorflow/lite-model/nasnet/mobile/1/metadata/1) TFLite model. 
37 | 38 | 39 | ## Explicit Layout 40 | 41 | With *propagation based approach*, the converted ONNX model includes zero effort to handle _layout semantic divergence_, i.e. no additional operators or tensors are introduced. 42 | 43 | However, sometimes there could be incompatible layouts. Consider `Reshape`, which is *Terminate*, as below. If $$\left< A \right>$$ is propagated while other tensors are not, the output layout could be unexpected as the user may assume the dimensions of $$\left< B \right>$$ has something to do with $$\left< A \right>$$. (*Transpose based approach* doesn't have this issue as its layout is TFLite style at the model level, *layout semantic divergence* is handled inside the $$[Transpose] \rightarrow [OP] \rightarrow [Transpose]$$ pattern.) 44 | 45 | $$ 46 | \left. 47 | \begin{aligned} 48 | \{Graph\} \rightarrow \left< A \right> \rightarrow [Reshape] \rightarrow \left< B \right> \\ 49 | \left< C \right> \\ 50 | \end{aligned} 51 | \right\} \rightarrow [Concat] \rightarrow \left< D \right> 52 | $$ 53 | 54 | *Explicit layout* is introduced to handle such a scenario. Users can feed a mapping of $$\{Tensor\ name : tuple(TFLite\ layout, ONNX\ layout)\}$$ that describes the data layout of TFLite and ONNX to TFLite2ONNX. And, it's flexible for the user to define the layout conversion for non-_Transparent_ operators. For example, we have performed the NHWC to NCHW layout conversion for a TFLite graph that has only an `Add` operator. 55 | 56 | 57 | ## Broadcast of Propagation 58 | 59 | Another problem is the [broadcast](https://numpy.org/doc/stable/user/basics.broadcasting.html) of binary operators such as `Add` (see [this issue](https://github.com/zhenhuaw-me/tflite2onnx/issues/13) for more). Taking the example below, in which tensor $$\left< B \right>$$ needs to be broadcasted. If $$\left< A \right>$$ is converted from NHWC to NCHW, i.e. $$\left< A_{(2 \times 5 \times 3 \times 4)} \right>$$, $$\left< B \right>$$ is no longer broadcastable in ONNX. 
Even worse, the _layout semantic divergence_ will fail when propagated to $$\left< B \right>$$ as $$\left< A \right>$$ and $$\left< B \right>$$ have different dimensions. 60 | 61 | $$ 62 | \left. 63 | \begin{aligned} 64 | \{Graph\} \rightarrow \left< A_{(2 \times 3\times 4 \times5)} \right> \\ 65 | \left< B_{(4 \times 5)} \right> \\ 66 | \end{aligned} 67 | \right\} \rightarrow [Add] \rightarrow \left< C \right> 68 | $$ 69 | 70 | To manage broadcasting in the ONNX model, `tflite2onnx` introduces the _Reshape pattern_. Any tensors like $$\left< B \right>$$ are reshaped to extend (inserting $$1$$) their dimensions to be equal with the other, such that propagation and broadcasting can correctly do their jobs. The example of the intermediate graph before propagation is as below. 71 | 72 | $$ 73 | \left. 74 | \begin{aligned} 75 | \{Graph\} \rightarrow \left< A_{(2 \times 3\times 4 \times5)} \right> \\ 76 | \left< B_{(4 \times 5)} \right> \rightarrow [Reshape] \rightarrow \left< B^{'}_{(1 \times 1 \times 4 \times 5)} \right>\\ 77 | \end{aligned} 78 | \right\} \rightarrow [Add] \rightarrow \left< C \right> 79 | $$ 80 | 81 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | Frequently Asked Questions 2 | ========================== 3 | 4 | _Just jump to the sections that you are interested in. 5 | Please help to raise issues if any of the document here is wrong._ 6 | 7 | 8 | ## Unsupported TFLite OP Error 9 | 10 | As of today, `tflite2onnx` supports about 31 TFLite operators (check 11 | the `tflite2onnx.getSupportedOperators()` API). However, there are 12 | 127 builtin operators (check via `tflite.BUILTIN_OPCODE2NAME`) in TFLite. 
13 | For the operators that are unsupported, an error like below will be thrown 14 | 15 | ``` 16 | NotImplementedError: Unsupported TFLite OP: 123 {OPNAME} 17 | ``` 18 | 19 | Usually, we need to enable that operator in `tflite2onnx`. 20 | Please [report and contribute](contribution-guide.md)! 21 | 22 | However, sometimes, there are operators that are not _TFLite builtin operators_ 23 | in the original model. For example, a TensorFlow operator is added 24 | when converting TensorFlow model to TFLite like below. 25 | 26 | ```py 27 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, 28 | tf.lite.OpsSet.SELECT_TF_OPS] 29 | converter.allow_custom_ops = True 30 | ``` 31 | 32 | This is not supported currently as it requires significant end to end effort. 33 | To workaround it, you may need to replace complicated TensorFlow operators 34 | with [TFLite builtin operators](https://zhenhuaw.me/tflite/docs/BuiltinOperator.m.html), 35 | and then try again. 36 | 37 | 38 | 39 | ## FP16 Error When Converting 40 | 41 | Related issue: [#30](https://github.com/zhenhuaw-me/tflite2onnx/issues/30). 42 | 43 | As of TensorFlow `v2.3.0`, FP16 is not natively supported by TFLite. 44 | Operators such as [`Add`](https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/lite/kernels/add.cc#L196) 45 | and [`Conv`](https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/lite/kernels/conv.cc#L998) 46 | don't support FP16 - no related kernels. 47 | 48 | In practice, TensorFlow inserts `tf.Cast` to convert FP16 data to FP32 49 | for further computation. However, the MLIR based TensorFlow Lite converter 50 | doesn't support `tf.Cast`. An example of converting FP16 `tf.math.Add` is as below.
51 | 52 | ``` 53 | :0: error: failed while converting: 'main': Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag): 54 | tf.Cast {Truncate = false, device = ""} 55 | :0: note: see current operation: "func"() ( { 56 | ^bb0(%arg0: tensor<1x2x3x4xf16>, %arg1: tensor<1x2x3x4xf16>): // no predecessors 57 | %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x2x3x4xf16>) -> tensor<1x2x3x4xf32> 58 | %1 = "tf.Cast"(%arg1) {Truncate = false, device = ""} : (tensor<1x2x3x4xf16>) -> tensor<1x2x3x4xf32> 59 | %2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> 60 | "std.return"(%2) : (tensor<1x2x3x4xf32>) -> () 61 | }) {sym_name = "main", tf.entry_function = {control_outputs = "", inputs = "A,B", outputs = "Identity"}, type = (tensor<1x2x3x4xf16>, tensor<1x2x3x4xf16>) -> t 62 | ensor<1x2x3x4xf32>} : () -> () 63 | ``` 64 | 65 | In general, FP16 in a TFLite model exists due to 66 | [FP16 quantization](https://www.tensorflow.org/lite/performance/post_training_quantization#float16_quantization). 67 | As of today, I'd recommend to use 68 | [full integer quantization](https://www.tensorflow.org/lite/performance/post_training_quantization#full_integer_quantization) 69 | and quantization-aware training. 70 | Or keep the TensorFlow/TFLite model in FP32 format. 71 | 72 | 73 | ## FP16 Quantization Model 74 | 75 | Many people are using TFLite 76 | [FP16 quantization](https://www.tensorflow.org/lite/performance/post_training_quantization#float16_quantization), 77 | and some models ([example](https://github.com/zhenhuaw-me/tflite2onnx/issues/33)) 78 | are published in such format. 79 | 80 | The FP16 weights in these models will be converted to FP32 online by a TFLite 81 | operator `Dequantize`. In general, we convert TFLite `Dequantize` to ONNX 82 | [`DequantizeLinear`](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#DequantizeLinear-10). 
83 | However, `DequantizeLinear` in ONNX supports only dequantizing an integer 84 | (`uint8`, `int8`, `int32`). 85 | 86 | We enabled [*FP16 Quantization Pattern Folding*](https://github.com/zhenhuaw-me/tflite2onnx/issues/35) 87 | to work around this issue. In the resulting model, the FP16 tensors are converted into FP32. 88 | Be careful when feeding or retrieving data to and from the model. 89 | 90 | Still, I'd recommend 91 | [full integer quantization](https://www.tensorflow.org/lite/performance/post_training_quantization#full_integer_quantization) 92 | if possible. 93 | 94 | 95 | ## TFLite Model Contains Custom Operators 96 | 97 | Custom operator in TFLite requires 98 | [developer provided kernels](https://www.tensorflow.org/lite/guide/ops_custom#defining_the_kernel_in_the_tensorflow_lite_runtime). 99 | 100 | `tflite2onnx` doesn't support custom operator as the TFLite model file 101 | itself doesn't know how to perform the computation - which is the knowledge 102 | of the model owner. And we don't have a plan to support it yet. 103 | 104 | If your model contains custom operator, you may either break the model 105 | into several sub-models which have no custom operator, convert to ONNX 106 | and integrate them in ONNX backend. Or you can rewrite your TensorFlow 107 | model such that it composes the custom operator with builtin operators. 108 | 109 | I believe you have met similar issues in other deep learning systems... 110 | And I believe this can be resolved in the future. 111 | But before that, we need to work around... 112 | 113 | 114 | ## Customize the ONNX Opset Version 115 | 116 | `tflite2onnx` is bound to ONNX Opset 11 currently. 117 | We don't plan to support a _custom_ opset version, 118 | since it requires opset semantic conversion 119 | which could be a burden to handle but I don't see the value of it. 120 | 121 | If you really need a custom opset, try 122 | [the ONNX Version Converter](https://github.com/onnx/onnx/blob/master/docs/VersionConverter.md).
123 | And ask them to fix if there is any bug. 124 | (I have not used it :)) 125 | -------------------------------------------------------------------------------- /docs/how-to-enable-new-operator.md: -------------------------------------------------------------------------------- 1 | How to Enable New Operator 2 | ========================== 3 | 4 | This document will walk you through steps to enable new operators in `tflite2onnx`. 5 | 6 | It's highly recommended to read the [blog][blog] which introduces the 7 | background of `tflite2onnx` and [general contribution guide](contribution-guide.md). 8 | I am sure it will help when enabling new operators. 9 | Also, make sure that the operator is has not been enabled, 10 | i.e. not included in [Operator Support Status](how-to-enable-new-operator.md). 11 | 12 | 13 | ## Prepare Your Development Environment 14 | 15 | This is pretty simple: 16 | 17 | ```sh 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | 22 | ## Generate the TensorFlow Lite model 23 | 24 | First of all, we need a TensorFlow Lite model (`model.tflite`) to get started. 25 | Currently, to generate a TFLite model, we build a TensorFlow or Keras model, 26 | and convert it into TFLite model. 27 | 28 | Below is an example of generating a TFLite model which contains `Concat` operator 29 | only. 
30 | 31 | ```py 32 | import tensorflow as tf 33 | 34 | # operator inputs 35 | a = tf.keras.Input(dtype='float32', name='a', shape=(1, 2, 3, 1)) 36 | b = tf.keras.Input(dtype='float32', name='b', shape=(1, 2, 3, 2)) 37 | c = tf.keras.Input(dtype='float32', name='c', shape=(1, 2, 3, 3)) 38 | 39 | # operator 40 | concat = tf.keras.layers.Concatenate(axis=-1, name='output')([a, b, c]) 41 | 42 | # build Keras model 43 | model = tf.keras.Model(inputs=[a, b, c], outputs=[concat]) 44 | 45 | # convert to TFLite model 46 | converter = tf.lite.TFLiteConverter.from_keras_model(model) 47 | tflite_model = converter.convert() 48 | 49 | # save it 50 | with open('model.tflite', 'wb') as f: 51 | f.write(tflite_model) 52 | ``` 53 | 54 | Usually, tensor sizes are kept small to generate small model, as some of 55 | them will be hosted in `tflite2onnx` repository. In addition, we gave 56 | each dimension different extent, such that layout errors can be easily 57 | identified. 58 | 59 | Once generated, it's recommended to use visualization tool such as 60 | [Netron](https://github.com/lutzroeder/netron) to verify if the tensors 61 | and operator are what you expect. 62 | 63 | 64 | ## Setup Test for the Operator 65 | 66 | `tflite2onnx` requires test for every operator to ensure that functionality 67 | is not broken across development. The test for the operator is also very 68 | helpful when enabling new operators. 69 | 70 | Copy the newly generated `model.tflite` to `tflite2onnx` repository, put it 71 | in `${tflite2onnx}/assets/tests`. Naming convesion is `{operator}.{data type}.tflite`, 72 | for example `concat.float32.tflite`. The pattern `{operator}.{data type}` 73 | will be used in our test. Also, `{operator}` doesn't necessarily to be 74 | operator type only, check files in `${tflite2onnx}/assets/tests` for details. 
75 | 76 | Add the pattern `{operator}.{data type}` into operator test in 77 | `${tflite2onnx}/tests/test_ops.py`, depending the data layout attribution of 78 | the operator (if you don't know which sub test shall the op goes, check the [blog][blog]). 79 | It would help to comment out all other operators and tests when trying around. 80 | 81 | Invoke the test `python tests/test_ops.py`. You should be able to see errors 82 | like below, which indicates that one operator has not been supported. 83 | 84 | ``` 85 | wzh@Mac[✓]tflite2onnx.git (master*) $ python tests/test_ops.py 86 | 2020-11-09 20:51:00,439 D [tflite2onnx][convert.py:37] tflite: /Users/wzh/workspace/onnx/tflite2onnx.git/assets/tests/concat.float32.tflite 87 | 2020-11-09 20:51:00,439 D [tflite2onnx][convert.py:38] onnx: concat.float32.onnx 88 | 2020-11-09 20:51:00,439 D [tflite2onnx][model.py:21] Parsing the Model... 89 | 2020-11-09 20:51:00,439 D [tflite2onnx][graph.py:58] Parsing the Graph... 90 | 2020-11-09 20:51:00,439 D [tflite2onnx][graph.py:61] Parsing operator: 0 91 | Traceback (most recent call last): 92 | File "tests/test_ops.py", line 85, in 93 | test_ops_post_propagation() 94 | File "tests/test_ops.py", line 64, in test_ops_post_propagation 95 | end2end_test(op, 'NHWC') 96 | File "tests/test_ops.py", line 16, in end2end_test 97 | t2o.convert(tflm_path, onnx_name) 98 | File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/convert.py", line 44, in convert 99 | model.convert(explicit_layouts) 100 | File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/model.py", line 39, in convert 101 | self.parse() 102 | File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/model.py", line 31, in parse 103 | g.parse() 104 | File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/graph.py", line 62, in parse 105 | op = self.OPCFactory.create(i) 106 | File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/op/common.py", line 151, in create 107 | raise NotImplementedError("Unsupported 
TFLite OP: {}".format(opcode)) 108 | NotImplementedError: Unsupported TFLite OP: 2 109 | ``` 110 | 111 | The `2` of `NotImplementedError: Unsupported TFLite OP: 2` indicates which operator 112 | has not been enabled yet. It is `CONCATENATION` in 113 | [`tflite.BuiltinOperator`](https://github.com/zhenhuaw-me/tflite/blob/master/tflite/BuiltinOperator.py). 114 | 115 | With this, we can really start to write some code. 116 | 117 | 118 | ## Get the Workflow of the Operator Ready 119 | 120 | To start with, we add a operator converter class to handle 121 | converting of this operator. 122 | In this example, we created `tflite.op.Concat` initially as: 123 | 124 | ```py 125 | class Concat(Operator): 126 | TypeMapping = { 127 | tflite.BuiltinOperator.CONCATENATION: 'Concat', 128 | } 129 | 130 | def __init__(self, TFactory, index): 131 | super().__init__(TFactory, index) 132 | 133 | self.attrs['axis'] = -1 # operator attribute 134 | 135 | self.setInited() 136 | 137 | @property 138 | def type(self): 139 | return 'Concat' 140 | 141 | def parse(self): 142 | logger.debug("Parsing %s...", self.type) 143 | op = self.tflite 144 | opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode() 145 | assert(opcode is tflite.BuiltinOperator.CONCATENATION) 146 | 147 | assert(op.InputsLength() >= 1) 148 | assert(op.OutputsLength() == 1) 149 | 150 | # TODO: parse tensors 151 | 152 | self.setParsed() 153 | 154 | def propagatableTensors(self): 155 | return list() 156 | 157 | def transform(self): 158 | # TODO: handle layout transform 159 | pass 160 | ``` 161 | 162 | This can be done by copying an existing similar operator, and make several 163 | modifications. 164 | * `Operator.TypeMapping` maps TFLite operator type to ONNX operator type. Different TFLite operator may map to same ONNX operator type. An operator converter may be able to handle many TFLite operators. 165 | * `Operator.__init__()` collects the TFLite objects and initializes attributes of the operator. 
You can take [ONNX Operator Schemas][onnx-op] as reference. When it's done, the object switches to status `INITIALIZED`. 166 | * `Operator.type` is the operator type of ONNX. It's a string which you can find in operator examples in [ONNX Operator Schemas][onnx-op] - usually simply the operator type name, e.g. `Concat` in our example. The type may be mapped from the TFLite operator type via `Operator.TypeMapping` or other information sometimes depending on the implementation. 167 | * `Operator.parse()` parses the tensors used by the operator, attributes of the operator. Let's left it *to be done* in next section. After finished, set object to status `PARSED`. Mostly, an object should not been parsed multiple times. 168 | * `Operator.propagatableTensors()` describes which tensors of this operator are layout propagatable. This is a bit tricky, please look into [Data layout semantic and converting procedure][layout-handling]. 169 | * `Operator.transform()` transforms operator attributes that are sensitive to layout. This is sort tricky which requires serious consideration. Leave it empty currently. 170 | 171 | Now, let's integrate the operator converter class into framework. 172 | This is simple (as we are trying to make it easy to extend :) ). 173 | Import and register the operator converter class. In `${tflite2onnx}/op/__init__.py`, add code below. 174 | 175 | ```py 176 | from tflite2onnx.op.concat import Concat 177 | # ... 178 | 179 | OpFactory.register(Concat) 180 | ``` 181 | 182 | That's it! Simple! 183 | 184 | Now let's try it. You may see errors like below - take it easy, 185 | as we have not finish our jobs. But we can see that the `Concat` 186 | class is parsing something (nothing so far). That means we have 187 | enabled basic workflow for the operator. 
188 | 189 | ``` 190 | 2020-11-09 21:04:45,344 D [tflite2onnx][convert.py:37] tflite: /Users/wzh/workspace/onnx/tflite2onnx.git/assets/tests/concat.float32.tflite 191 | 2020-11-09 21:04:45,345 D [tflite2onnx][convert.py:38] onnx: concat.float32.onnx 192 | 2020-11-09 21:04:45,345 D [tflite2onnx][model.py:21] Parsing the Model... 193 | 2020-11-09 21:04:45,345 D [tflite2onnx][graph.py:58] Parsing the Graph... 194 | 2020-11-09 21:04:45,345 D [tflite2onnx][graph.py:61] Parsing operator: 0 195 | 2020-11-09 21:04:45,345 D [tflite2onnx][concat.py:27] Parsing [None](Concat)... 196 | Traceback (most recent call last): 197 | File "tests/test_ops.py", line 85, in 198 | test_ops_post_propagation() 199 | File "tests/test_ops.py", line 64, in test_ops_post_propagation 200 | end2end_test(op, 'NHWC') 201 | File "tests/test_ops.py", line 16, in end2end_test 202 | t2o.convert(tflm_path, onnx_name) 203 | File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/convert.py", line 44, in convert 204 | model.convert(explicit_layouts) 205 | File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/model.py", line 39, in convert 206 | self.parse() 207 | File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/model.py", line 31, in parse 208 | g.parse() 209 | File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/graph.py", line 63, in parse 210 | op.parse() 211 | File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/op/concat.py", line 47, in parse 212 | self.setParsed() 213 | File "/Users/wzh/workspace/onnx/tflite2onnx.git/tflite2onnx/op/common.py", line 104, in setParsed 214 | self.name = self.outputs[0].name if self.name is None else self.name 215 | IndexError: list index out of range 216 | ``` 217 | 218 | 219 | ## Make the Operator Converter Work 220 | 221 | ### Understand Operator Semantic Divergence 222 | 223 | TFLite and ONNX operator semantic are sometimes different. 
Make sure to review 224 | [TFLite API documentation][tflite-api] for operator options, 225 | and [ONNX documents][onnx-op] for operator attributes. To be noted, an operator 226 | option or attribute of one could be described by an input tensor in the other. 227 | 228 | For this `Concat` example, it accepts several inputs and generates one output in both 229 | TFLite and [ONNX](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Concat). 230 | Unfortunately, TFLite doesn't provide rich documents about operators, so we may check 231 | [Compatible operations of TensorFlow and TFLite](https://www.tensorflow.org/lite/guide/ops_compatibility#compatible_operations) 232 | and sometimes even the [source code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/reference/concatenation.h). 233 | 234 | For options or attributes, we can check the 235 | [*OperatorOption* of TFLite](https://zhenhuaw.me/tflite/docs/ConcatenationOptions.m.html#tflite.ConcatenationOptions.ConcatenationOptions). 236 | In our `Concat` example, it has two: 237 | * `Axis` indicates concatenating on which dimension. This attribute is sensitive to how we handle the layout issue. For example, if TFLite concatenates on axis `-1` and has a `NHWC` data layout - which means it's concatenating on the `C` dimension. While ONNX uses layout `NCHW`, the ONNX version needs to concatenate on axis `1` for its dimension `C` in ONNX. This is needed when `Concat` feeds to a `Conv`. The interesting part is, if the model contains no `Conv`, for example has only one `Concat` in our case, we'd better keep it unchanged. We will discuss this more in a dedicated document. 238 | * `FusedActivationFunction` describes which [activation function](https://zhenhuaw.me/tflite/docs/ActivationFunctionType.m.html) has been fused into this operator. This is common in operators like `Conv` and `FullyConnected`.
239 | 240 | 241 | ### Parse the Tensors 242 | 243 | Parsing input and output tensors is simple, and we have provided 244 | well wrapped helpers to make this easy. 245 | 246 | In our case, `Concat` may has multiple inputs and one output, so just 247 | 248 | ```py 249 | for i in range(op.InputsLength()): 250 | self.parseInput(i) 251 | 252 | self.parseOutput(0) 253 | ``` 254 | 255 | Now, if you invoke the test, you may see that it completes without erros. 256 | You may also catch the tensors and graph log (well, we have many debug 257 | log to make investigation easier): 258 | 259 | ```sh 260 | wzh@Mac[✗]tflite2onnx.git (master*) $ python tests/test_ops.py 261 | 2020-11-09 21:11:24,589 D [tflite2onnx][convert.py:37] tflite: /Users/wzh/workspace/onnx/tflite2onnx.git/assets/tests/concat.float32.tflite 262 | 2020-11-09 21:11:24,589 D [tflite2onnx][convert.py:38] onnx: concat.float32.onnx 263 | 2020-11-09 21:11:24,589 D [tflite2onnx][model.py:21] Parsing the Model... 264 | 2020-11-09 21:11:24,589 D [tflite2onnx][graph.py:58] Parsing the Graph... 265 | 2020-11-09 21:11:24,589 D [tflite2onnx][graph.py:61] Parsing operator: 0 266 | 2020-11-09 21:11:24,590 D [tflite2onnx][concat.py:27] Parsing [None](Concat)... 267 | 2020-11-09 21:11:24,590 D [tflite2onnx][tensor.py:103] Parsing a... 268 | 2020-11-09 21:11:24,590 D [tflite2onnx][tensor.py:103] Parsing b... 269 | 2020-11-09 21:11:24,591 D [tflite2onnx][tensor.py:103] Parsing c... 270 | 2020-11-09 21:11:24,591 D [tflite2onnx][tensor.py:103] Parsing Identity... 271 | 2020-11-09 21:11:24,591 D [tflite2onnx][model.py:40] Converting... 272 | 2020-11-09 21:11:24,591 D [tflite2onnx][graph.py:91] Converting... 273 | 2020-11-09 21:11:24,592 D [tflite2onnx][graph.py:93] Handling data layout... 274 | 2020-11-09 21:11:24,592 D [tflite2onnx][graph.py:130] Propragating layout across graph... 
275 | 2020-11-09 21:11:24,592 D [tflite2onnx][graph.py:141] Propagation: 4 tensors in total, 0 to walk, 4 at wild 276 | 2020-11-09 21:11:24,592 D [tflite2onnx][graph.py:164] Propagation: wild tensors 4, ignored tensors 0 277 | 2020-11-09 21:11:24,592 D [tflite2onnx][graph.py:104] Translating quantization semantic... 278 | 2020-11-09 21:11:24,592 D [tflite2onnx][graph.py:112] Graph: 279 | [OP] [Identity](Concat) attr{'axis': -1}: ['a', 'b', 'c'] -> ['Identity'] 280 | [Input] (float32,[1, 1, 2, 3, 1]): {[]} -> {['[Identity](Concat)']} 281 | [Input] (float32,[1, 1, 2, 3, 2]): {[]} -> {['[Identity](Concat)']} 282 | [Input] (float32,[1, 1, 2, 3, 3]): {[]} -> {['[Identity](Concat)']} 283 | [Output] (float32,[1, 1, 2, 3, 6]): {['[Identity](Concat)']} -> {[]} 284 | [Value Info] (float32,[1, 1, 2, 3, 1]): {[]} -> {['[Identity](Concat)']} 285 | [Value Info] (float32,[1, 1, 2, 3, 6]): {['[Identity](Concat)']} -> {[]} 286 | [Value Info] (float32,[1, 1, 2, 3, 2]): {[]} -> {['[Identity](Concat)']} 287 | [Value Info] (float32,[1, 1, 2, 3, 3]): {[]} -> {['[Identity](Concat)']} 288 | 289 | 2020-11-09 21:11:24,592 D [tflite2onnx][common.py:111] Converting [Identity](Concat)... 290 | 2020-11-09 21:11:24,592 D [tflite2onnx][tensor.py:142] Converting (float32,[1, 1, 2, 3, 1])... 291 | 2020-11-09 21:11:24,593 D [tflite2onnx][tensor.py:142] Converting (float32,[1, 1, 2, 3, 2])... 292 | 2020-11-09 21:11:24,593 D [tflite2onnx][tensor.py:142] Converting (float32,[1, 1, 2, 3, 3])... 293 | 2020-11-09 21:11:24,593 D [tflite2onnx][tensor.py:142] Converting (float32,[1, 1, 2, 3, 6])... 294 | 2020-11-09 21:11:24,593 D [tflite2onnx][graph.py:118] Making ONNX... 
295 | 2020-11-09 21:11:24,594 D [tflite2onnx][model.py:56] saving model as concat.float32.onnx 296 | 2020-11-09 21:11:24,614 I [tflite2onnx][convert.py:46] Converted ONNX model: concat.float32.onnx 297 | 2020-11-09 21:11:24,622 D [shrub][onnx.py:55] running concat.float32.onnx 298 | 2020-11-09 21:11:24,709 D [shrub][onnx.py:22] parsing concat.float32.onnx 299 | 2020-11-09 21:11:25,384 D [tensorflow][tpu_cluster_resolver.py:34] Falling back to TensorFlow client; we recommended you install the Cloud TPU client directly with pip install cloud-tpu-client. 300 | 2020-11-09 21:11:26,565 D [shrub][tflite.py:98] running /Users/wzh/workspace/onnx/tflite2onnx.git/assets/tests/concat.float32.tflite 301 | 2020-11-09 21:11:26,568 D [shrub][tflite.py:104] Inputs: [{'name': 'a', 'index': 0, 'shape': array([1, 1, 2, 3, 1], dtype=int32), 'shape_signature': array([1, 1, 2, 3, 1], dtype=int32), 'dtype': , 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'b', 'index': 1, 'shape': array([1, 1, 2, 3, 2], dtype=int32), 'shape_signature': array([1, 1, 2, 3, 2], dtype=int32), 'dtype': , 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'c', 'index': 2, 'shape': array([1, 1, 2, 3, 3], dtype=int32), 'shape_signature': array([1, 1, 2, 3, 3], dtype=int32), 'dtype': , 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}] 302 | 2020-11-09 21:11:26,569 D [shrub][tflite.py:105] Outputs: [{'name': 'Identity', 'index': 3, 'shape': array([1, 1, 2, 3, 6], dtype=int32), 'shape_signature': array([1, 1, 2, 3, 6], dtype=int32), 'dtype': , 'quantization': (0.0, 0), 'quantization_parameters': 
{'scales': array([], dtype=float32), 'zero_points': array([], dtyp 303 | ``` 304 | 305 | But, that's not the end necessarily, keep going! 306 | 307 | 308 | ## Parse Operator Attributes 309 | 310 | TFLite model stores *operator option* with dedicated class per operator, 311 | which needs to be handled seperately. 312 | 313 | Taking `Concat` example, the options are aviable to obtain after a option 314 | object has *init* from memory. See below. 315 | 316 | ```py 317 | op_opt = op.BuiltinOptions() 318 | option = tflite.ConcatenationOptions() 319 | option.Init(op_opt.Bytes, op_opt.Pos) 320 | self.attrs['axis'] = option.Axis() 321 | ``` 322 | 323 | Each operator option has a funtion to extract the information, please refer 324 | to the [TFLite parser API][tflite-api]. And all ONNX attributes are collected 325 | in `Operator.attrs`. 326 | 327 | A TFLite operator option doesn't necessarily have a peer ONNX operator 328 | attribute, vice verse. A TFLite operator option may become ONNX operator input, 329 | or implicit ONNX operator semantic. Please do take care consideration for these 330 | functionalities. If you are not sure, take existing operator converter as 331 | reference, or open issue to ask. 332 | 333 | Among all the options, *fused activation function* is one special, for which 334 | we need to add one more ONNX operator to the graph. But, don't worry, it can be 335 | handled by simply calling `handleFusedActivation(self, option, ot)`, if that 336 | operator has a `FusedActivationFunction()` 337 | ([`Concat` example](https://zhenhuaw.me/tflite/docs/ConcatenationOptions.m.html#tflite.ConcatenationOptions.ConcatenationOptions.FusedActivationFunction)) 338 | method of its option class. If that is the case, please don't add output tensor 339 | of the operator directly, but do something like below. 340 | 341 | ```py 342 | handleFusedActivation(self, option, ot) 343 | ``` 344 | 345 | Do remeber to initialize ONNX attributes in `Operator.__init__()`. 
346 | And do NOT miss any when creating the ONNX operator. 347 | 348 | 349 | ## Handling Data Layout Issue 350 | 351 | Given that the data layouts of TFLite models and ONNX models are NHWC 352 | and NCHW respectively, some additional efforts are needed when enabling 353 | operators. If you have not read the [blog][blog] or the [layout handling 354 | story][layout-handling], now it's the time. 355 | 356 | `Operator.propagatableTensors()` describes which tensors of this operator 357 | are layout propagatable. For most cases like `Concat`, all tensors are 358 | propagatable, so we can write it like this. 359 | 360 | ```py 361 | def propagatableTensors(self): 362 | return self.inputs + self.outputs 363 | ``` 364 | 365 | But for operators like `Conv`, none of its tensors is propagatable. 366 | Be careful for this part as it may require significant effort to debug 367 | if it's not correctly coded at the beginning. 368 | 369 | `Operator.transform()` transforms operator attributes that are sensitive 370 | to layout. For most cases, this function can be left as empty. But, just like 371 | `Operator.propagatableTensors()`, we need to check what should be done. 372 | 373 | For `Concat`, which requires attribute transform, we get the `layout` 374 | description (source layout and target layout) from the output, and transform 375 | the attribute `axis` accordingly. 376 | 377 | ```py 378 | def transform(self): 379 | logger.debug("Transforming %s...", self.shorty) 380 | layout = self.outputs[0].layout 381 | if layout is not None: 382 | axis = self.attrs['axis'] 383 | axis = axis if axis >= 0 else (axis + len(layout.perm)) 384 | self.attrs['axis'] = layout.perm.index(axis) 385 | ``` 386 | 387 | To be noted, the layout issue is case by case, this document only shows 388 | `Concat` as an example. You may find other operator converter classes 389 | as a hint for the operator you are trying to enable. 390 | If you have any question, just open an issue to discuss.
391 | 392 | 393 | ## Going Further 394 | 395 | Congratulation! You have basically finished the implementation of 396 | a new operator. If everthing looks good, open pull request. 397 | Let your work empower the community. 398 | 399 | Thank you for your contribution! 400 | 401 | Cheers! 402 | 403 | 404 | [onnx-op]: https://github.com/onnx/onnx/blob/master/docs/Operators.md 405 | [layout-handling]: https://github.com/zhenhuaw-me/tflite2onnx/issues/2 406 | [tflite-api]: https://zhenhuaw.me/tflite/docs 407 | [blog]: https://zhenhuaw.me/blog/2020/Convert-TensorFlow-Lite-models-to-ONNX.html 408 | -------------------------------------------------------------------------------- /docs/images/propagate-nasnet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhenhuaw-me/tflite2onnx/32d4d5731be506d5797c648fa73a91b36a80222b/docs/images/propagate-nasnet.jpg -------------------------------------------------------------------------------- /docs/release-notes.md: -------------------------------------------------------------------------------- 1 | Release Notes 2 | ============= 3 | 4 | 5 | ## v0.3.2 6 | 7 | 2021-02-03, [Project](https://github.com/zhenhuaw-me/tflite2onnx/projects/5) 8 | 9 | * New API `getSupportedOperators()` to know what operators have been supported. No longer need to maintain a list manually. 10 | * FP16 quantization: fold FP16 tensors to unblock MediaPipe models. See [this issue](https://github.com/zhenhuaw-me/tflite2onnx/issues/35) for details. 11 | 12 | Thanks for the contribution of @briangrifiin! 13 | 14 | 15 | ## v0.3.1 16 | 17 | 2020-12-28, [Project](https://github.com/zhenhuaw-me/tflite2onnx/projects/5) 18 | 19 | * More operators, check [the support list](https://github.com/zhenhuaw-me/tflite2onnx/blob/v0.3.1/docs/operator-support-status.md). 20 | * Relax data type check, most for FP16 and INT8. 21 | * Interface `enableDebugLog()` to dump log for debugging purpose. 
22 | 23 | Thanks for the contribution of @erizmr @briangrifiin and @IkbeomJeon! 24 | 25 | 26 | ## v0.3.0 27 | 28 | 2020-09-30, [Project](https://github.com/zhenhuaw-me/tflite2onnx/projects/4) 29 | 30 | * Now open source with [annocement blog](https://zhenhuaw.me/blog/2020/Convert-TensorFlow-Lite-models-to-ONNX.html). 31 | * [Quantization support](https://github.com/zhenhuaw-me/tflite2onnx/issues/10) enabled, and tried quantized MobileNetV1 an MobileNetV2. 32 | * Drop [Transpose based layout handling](https://github.com/zhenhuaw-me/tflite2onnx/issues/2) to save effort of managing quantization. 33 | * More [operators](https://github.com/zhenhuaw-me/tflite2onnx/blob/v0.3.0/docs/operator-support-status.md) added, and [tested models](https://github.com/zhenhuaw-me/tflite2onnx/tree/more-model-test/assets/networks): 34 | * MobileNetV1 35 | * MobileNetV2 36 | * DenseNet 37 | * EfficientNet 38 | * MnasNet 39 | * SqueezeNet 40 | * NasNet 41 | 42 | ## v0.2.0 43 | 44 | 2020-07-15, [Project](https://github.com/zhenhuaw-me/tflite2onnx/projects/2) 45 | 46 | * Operator support of MobileNetV2. 47 | * Infrastructure improvements. 48 | * [Propagation based layout handling](https://github.com/zhenhuaw-me/tflite2onnx/issues/2). 49 | 50 | 51 | ## v0.1.0 52 | 53 | 2020-05-24, [Project](https://github.com/zhenhuaw-me/tflite2onnx/projects/1) 54 | 55 | * Model converting Workflow. 56 | * Basic operator support of MobileNetV1. 57 | * [Transpose based layout handling](https://github.com/zhenhuaw-me/tflite2onnx/issues/2). 
58 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | build 2 | coverage 3 | flake8 4 | numpy 5 | onnx 6 | pytest 7 | shrub 8 | tflite>=2.4.0 9 | twine 10 | -------------------------------------------------------------------------------- /scripts/build-wheel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$(uname -s)" == "Darwin" ]; then 4 | root_dir=$(dirname $(dirname $(greadlink -f $0}))) 5 | else 6 | root_dir=$(dirname $(dirname $(readlink -f $0}))) 7 | fi 8 | rm -f ${root_dir}/assets/dist/tflite2onnx-*.whl 9 | 10 | pip install build numpy onnx "tflite>=2.4.0" 11 | 12 | python -m build \ 13 | --outdir ${root_dir}/assets/dist 14 | rm -rf ${root_dir}/tflite2onnx.egg-info 15 | rm -rf ${root_dir}/build 16 | -------------------------------------------------------------------------------- /scripts/open-github-connection.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ -z "$NGROK_TOKEN" ]]; then 4 | echo "Please set 'NGROK_TOKEN'" 5 | exit 2 6 | fi 7 | 8 | if [[ -z "$NGROK_LOCAL_PASS" ]]; then 9 | echo "Please set 'NGROK_LOCAL_PASS' for user: $USER" 10 | exit 3 11 | fi 12 | 13 | echo "1. Installing Ngrok..." 14 | 15 | wget -q https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-386.zip 16 | unzip ngrok-stable-linux-386.zip 17 | chmod +x ./ngrok 18 | 19 | echo "2. Updating password for user: $USER" 20 | echo -e "$NGROK_LOCAL_PASS\n$NGROK_LOCAL_PASS" | sudo passwd "$USER" 21 | 22 | echo "3. Starting ngrok proxy..." 
#!/bin/bash
# Build the wheel, then upload: test.pypi.org by default, the real PyPI only
# after the operator types "Release" and confirms with "Y".

# BUG FIX: the readlink argument was written as `$0}` (stray closing brace),
# which corrupted the resolved script path.
if [ "$(uname -s)" == "Darwin" ]; then
    root_dir=$(dirname "$(dirname "$(greadlink -f "$0")")")
else
    root_dir=$(dirname "$(dirname "$(readlink -f "$0")")")
fi

"${root_dir}/scripts/build-wheel.sh"

read -p "Will upload to test.pypi.org, for real publishment type \"Release\": " input_str
# BUG FIX: ${input_str} was unquoted inside `[ ... ]`, so an empty input made
# `test` fail with a syntax error instead of taking the test.pypi.org branch.
# Quote it and use `||`/`&&` instead of the obsolescent `-o`/`-a` operators.
if [ -z "${input_str}" ] || [ "${input_str}" != "Release" ]; then
    python3 -m twine upload \
        --repository-url https://test.pypi.org/legacy/ \
        "${root_dir}"/assets/dist/tflite2onnx-*
else
    read -p "Will publish the package, are you sure to continue [Y|N] ? " input_str
    if [ -n "${input_str}" ] && [ "${input_str}" = "Y" ]; then
        echo "Uploading..."
        python3 -m twine upload "${root_dir}"/assets/dist/tflite2onnx-*
    fi
fi
def cmd_convert(model_name):
    """Convert one bundled test model via the `tflite2onnx` CLI entry point."""
    here = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.abspath(os.path.join(here, '..', 'assets', 'tests'))
    tflm_path = os.path.join(model_dir, model_name + '.tflite')
    onnx_name = model_name + '.onnx'

    # Invoke the installed console script; a zero return code means success.
    completed = subprocess.run(['tflite2onnx', tflm_path, onnx_name])
    assert(completed.returncode == 0)


def test_cmd_convert():
    """Smoke-test the command line interface on a tiny model."""
    for model in ('abs.float32',):
        cmd_convert(model)


if __name__ == '__main__':
    test_cmd_convert()
def test_transform():
    """layout.transform reorders a shape from one layout string to another."""
    from tflite2onnx.layout import transform
    shape = [1, 2, 6, 8]
    assert(transform(shape, 'NCHW', 'NHWC') == [1, 6, 8, 2])
    assert(transform(shape, 'NHWC', 'NCHW') == [1, 8, 2, 6])


def test_getPerm():
    """layout.getPerm returns the permutation mapping source to target."""
    from tflite2onnx.layout import getPerm
    cases = {
        ('01', '01'): [0, 1],
        ('01', '10'): [1, 0],
        ('0123', '0123'): [0, 1, 2, 3],
        ('0123', '0312'): [0, 3, 1, 2],
        ('0123', '3021'): [3, 0, 2, 1],
    }
    for (src, dst), expected in cases.items():
        assert(getPerm(src, dst) == expected)


def test_align_dimension():
    """Binary-op broadcasting pads the shorter shape with leading ones."""
    from tflite2onnx.op.binary import alignDimension
    # cases from: https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
    assert(alignDimension([2, 3, 4, 5], list()) == (False, [1, 1, 1, 1]))
    assert(alignDimension([2, 3, 4, 5], [5, ]) == (False, [1, 1, 1, 5]))
    assert(alignDimension([4, 5], [2, 3, 4, 5]) == (True, [1, 1, 4, 5]))
    assert(alignDimension([1, 4, 5], [2, 3, 1, 1]) == (True, [1, 1, 4, 5]))
    assert(alignDimension([3, 4, 5], [2, 1, 1, 1]) == (True, [1, 3, 4, 5]))


if __name__ == '__main__':
    test_transform()
    test_getPerm()
    test_align_dimension()
def end2end_test(model_name, use_layout):
    """Convert a bundled test model and compare ONNX vs. TFLite outputs."""
    here = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.abspath(os.path.join(here, '..', 'assets', 'tests'))
    tflm_path = os.path.join(model_dir, model_name + '.tflite')
    onnx_name = model_name + '.onnx'
    t2o.convert(tflm_path, onnx_name)

    # Generate random inputs from the TFLite model description.
    m = shrub.tflite.parse(tflm_path)
    m.genInput()

    # Run both backends on the same inputs and require matching tensors.
    onnx_ret = shrub.onnx.run(onnx_name, m.inputs, use_layout)
    tflite_ret = shrub.tflite.run(tflm_path, m.inputs)
    assert(shrub.network.cmpTensors(onnx_ret, tflite_ret, useLayout=use_layout))
def test_same_trival():
    """SAME padding for a 3x3, stride-1, dilation-1 kernel on 10x10 input
    pads one pixel on every edge."""
    computed = computePaddingSize(tflite.Padding.SAME,
                                  [10, 10],   # input size
                                  [3, 3],     # kernel size
                                  [1, 1],     # stride
                                  [1, 1])     # dilation
    assert((computed == [1, 1, 1, 1]).all())


if __name__ == '__main__':
    test_same_trival()
"""Converting TensorFlow Lite models (*.tflite) to ONNX models (*.onnx)"""

from tflite2onnx.convert import convert
from tflite2onnx.utils import enableDebugLog, getSupportedOperators

# package metadata
__name__ = 'tflite2onnx'
__version__ = '0.4.1'
DESCRIPTION = "Convert TensorFlow Lite models to ONNX"

# BUG FIX: `__all__` must contain attribute *names* (strings). The previous
# version listed the objects themselves, which makes
# `from tflite2onnx import *` raise `TypeError: attribute name must be string`.
__all__ = [
    'convert',
    'enableDebugLog',
    'getSupportedOperators',
    '__name__',
    '__version__',
    'DESCRIPTION',
]
10 | INITIALIZED = 1 11 | 12 | # Objects and any members have been parsed from TFLite model. 13 | PARSED = 2 14 | 15 | # ONNX object has been created. 16 | CONVERTED = 3 17 | 18 | # Reserved. 19 | INVALID = 10 20 | 21 | @property 22 | def uninitialized(self): 23 | return self == self.UNINITIALIZED 24 | 25 | @property 26 | def initialized(self): 27 | return self == self.INITIALIZED 28 | 29 | @property 30 | def parsed(self): 31 | return self == self.PARSED 32 | 33 | @property 34 | def converted(self): 35 | return self == self.CONVERTED 36 | 37 | 38 | class T2OBase(ABC): 39 | """Holding objects of TFLite and ONNX""" 40 | def __init__(self, model=None, graph=None, index=None): 41 | # Overall fields 42 | self.status = Status.UNINITIALIZED 43 | self.name = None 44 | 45 | # TFLite objects 46 | self.model = model 47 | self.graph = graph 48 | self.index = index # index of tensor or op 49 | self.tflite = None 50 | 51 | # ONNX object 52 | self.onnx = None 53 | 54 | def setInited(self): 55 | assert(self.status.uninitialized) 56 | self.status = Status.INITIALIZED 57 | 58 | def parse(self): 59 | raise NotImplementedError("method parse() should be overrided!") 60 | 61 | def setParsed(self): 62 | assert(self.status.initialized) 63 | self.status = Status.PARSED 64 | 65 | def validate(self): 66 | raise NotImplementedError("method validate() should be overrided!") 67 | 68 | def convert(self): 69 | raise NotImplementedError("method convert() should be overrided!") 70 | 71 | def setConverted(self): 72 | assert(self.status.parsed) 73 | self.status = Status.CONVERTED 74 | 75 | def setInvalid(self): 76 | self.status = Status.INVALID 77 | 78 | @property 79 | def shorty(self): 80 | """A short readable description for the class/object. 81 | 82 | This aims to be different from `__str__` which is exepcted to be 83 | long description on this package. 
84 | """ 85 | raise NotImplementedError("method shorty() should be overrided!") 86 | 87 | def __str__(self): 88 | """A readable description for the class/object.""" 89 | raise NotImplementedError("method __str__() should be overrided!") 90 | -------------------------------------------------------------------------------- /tflite2onnx/convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | import tflite 6 | import tflite2onnx 7 | from tflite2onnx.model import Model 8 | 9 | logger = logging.getLogger('tflite2onnx') 10 | 11 | 12 | def convert(tflite_path: str, onnx_path: str, explicit_layouts=None): 13 | """Converting TensorFlow Lite model (*.tflite) to ONNX model. 14 | 15 | Args: 16 | tflite_path (str): the path to TFLite model. 17 | onnx_path (str): the path where to save the converted ONNX model. 18 | explicit_layouts (dict, optinal): Dict of `str -> tuple(str, str)`. 19 | For each items, its *tensor name* `->` *tflite layout* and *onnx layout*. 20 | This can be safely ignored usually - tflite2onnx can handle most 21 | layout semantic divergence automatically. 22 | """ 23 | 24 | if not os.path.exists(tflite_path): 25 | raise ValueError("Invalid TFLite model path (%s)!" 
% tflite_path) 26 | if os.path.exists(onnx_path): 27 | logger.warning("ONNX model path (%s) existed!", onnx_path) 28 | 29 | if explicit_layouts: 30 | for k, v in explicit_layouts.items(): 31 | if not (isinstance(k, str) and isinstance(v, tuple) and 32 | (len(v) == 2) and isinstance(v[0], str) or isinstance(v[1], str)): 33 | raise ValueError("Invalid explicit layouts!") 34 | else: 35 | explicit_layouts = dict() 36 | 37 | logger.debug("tflite: %s", tflite_path) 38 | logger.debug("onnx: %s", onnx_path) 39 | with open(tflite_path, 'rb') as f: 40 | buf = f.read() 41 | im = tflite.Model.GetRootAsModel(buf, 0) 42 | 43 | model = Model(im) 44 | model.convert(explicit_layouts) 45 | model.save(onnx_path) 46 | logger.info("Converted ONNX model: %s", onnx_path) 47 | 48 | 49 | def cmd_convert(): 50 | description = "tflite2onnx " + tflite2onnx.__version__ + ", " + tflite2onnx.DESCRIPTION 51 | parser = argparse.ArgumentParser(description=description, 52 | formatter_class=argparse.RawTextHelpFormatter) 53 | parser.add_argument('tflite_path', help="Path to the input TFLite mode") 54 | parser.add_argument('onnx_path', help="Path to save the converted ONNX mode") 55 | 56 | args = parser.parse_args() 57 | 58 | convert(args.tflite_path, args.onnx_path) 59 | -------------------------------------------------------------------------------- /tflite2onnx/graph.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import tflite 4 | from onnx import helper 5 | 6 | from tflite2onnx.tensor import TensorFactory 7 | from tflite2onnx.common import T2OBase 8 | from tflite2onnx.layout import Layout 9 | from tflite2onnx.op import OpFactory 10 | from tflite2onnx.quantize import handleQuantizationTensor 11 | from tflite2onnx.quantize import foldFP16QuantPattern 12 | 13 | logger = logging.getLogger('tflite2onnx') 14 | 15 | 16 | class Graph(T2OBase): 17 | def __init__(self, model: tflite.Model, graph: tflite.SubGraph): 18 | 
    def _collectOpAndTensor(self):
        """Rebuild `op_all`, `initializer` and `value_info` from `self.ops`.

        `self.ops` holds only operators with a TFLite peer; helper operators
        (see `op_all` in `__init__`) hang off them via `pre`/`post` links, so
        a depth-first walk over those links recovers every operator, placing
        each op after its `pre` chain and before its `post` chain.
        """
        self.op_all.clear()

        # collect operators: pre-chain first, then the op, then its post-chain
        def _recursive(op):
            for cur_op in op.pre:
                _recursive(cur_op)
            self.op_all.append(op)
            for cur_op in op.post:
                _recursive(cur_op)
        for op in self.ops:
            _recursive(op)

        # collect tensors, partitioned into ONNX initializer vs. value_info
        assert(len(self.op_all) > 0)
        self.initializer.clear()
        self.value_info.clear()
        for op in self.op_all:
            for t in op.inputs + op.outputs:
                if t.isInitializer:
                    self.initializer.add(t)
                else:
                    self.value_info.add(t)
    def _propagateLayout(self):  # noqa: C901
        """Spread layout information across the graph.

        Tensors that already carry a `Layout` (e.g. set from explicit layouts
        in `convert()`) seed a worklist; the walk follows producer/consumer
        operators and copies the layout onto every reachable tensor that the
        operator reports as propagatable, then transforms the walked tensors
        and all operators in place.
        """
        logger.debug("Propragating layout across graph...")

        # Partition tensors: those with a layout seed the walk (T_toWalk);
        # the rest (T_wild) may receive a layout during propagation.
        T_toWalk = set()
        T_wild = set()
        tensor_count = len(self.value_info) + len(self.initializer)
        for t in self.value_info | self.initializer:
            if t.layout is None:
                T_wild.add(t)
            else:
                T_toWalk.add(t)
        logger.debug("Propagation: %d tensors in total, %d to walk, %d at wild",
                     tensor_count, len(T_toWalk), len(T_wild))

        # Worklist walk: pop a laid-out tensor, push a deep copy of its layout
        # onto any wild tensor reachable through a neighboring operator.
        T_ignored = set()
        T_walked = set()
        while (len(T_toWalk) != 0):
            T = T_toWalk.pop()
            logger.debug("Propagation: walking %s", T.shorty)
            for n in T.producers + T.consumers:
                for t in n.propagatableTensors():
                    if t is T:
                        continue
                    if t in T_wild:
                        logger.debug("Propagation: propagated to %s", t.shorty)
                        assert(t.layout is None)
                        T_wild.remove(t)
                        if t.isScalar:
                            # scalars have no layout to track; don't walk them
                            T_ignored.add(t)
                        else:
                            t.layout = copy.deepcopy(T.layout)
                            T_toWalk.add(t)
            T_walked.add(T)
        logger.debug("Propagation: wild tensors %d, ignored tensors %d",
                     len(T_wild), len(T_ignored))

        # Apply the propagated layouts: transform walked tensors, then let
        # every operator adjust itself to the new layouts.
        for t in T_walked:
            t.transform()
        self._collectOpAndTensor()
        for op in self.op_all:
            op.transform()
def getPerm(ilayout: str, olayout: str):
    """Return the permutation that maps `ilayout` to `olayout`.

    E.g. getPerm('NHWC', 'NCHW') -> [0, 3, 1, 2].
    """
    char2index = {c: i for i, c in enumerate(ilayout)}
    return [char2index[c] for c in olayout]


def transform(input, ilayout: str, olayout: str):
    """Reorder `input` (e.g. a shape) from `ilayout` to `olayout`."""
    if ilayout == olayout:
        return input
    return [input[p] for p in getPerm(ilayout, olayout)]


class Layout(object):
    """Tracks a tensor's layout transition, e.g. NHWC -> NCHW."""

    def __init__(self, source: str, target: str):
        self.source = source
        self.target = target
        # `current` starts at `source` and flips to `target` once transformed.
        self.current = source

    def transform(self, input):
        """Reorder `input` from source to target layout, marking the move."""
        reordered = transform(input, self.source, self.target)
        self.current = self.target
        return reordered

    @property
    def perm(self):
        """The source->target permutation of this layout pair."""
        return getPerm(self.source, self.target)

    def __str__(self):
        return self.current + '(' + self.source + '->' + self.target + ')'
class Model(T2OBase):
    """Everything helps to convert TFLite model to ONNX model"""

    def __init__(self, model: tflite.Model):
        super().__init__(model)
        self.tflite = model
        # Parsed Graph wrappers; ONNX supports exactly one graph per model.
        self.graphes = []
        self.setInited()

    def parse(self):
        """Parse the TFLite model, wrapping its single subgraph as a Graph."""
        logger.debug("Parsing the Model...")
        graph_count = self.model.SubgraphsLength()
        if graph_count != 1:
            # BUG FIX: extra positional args to an exception render `str(e)`
            # as a tuple; format the count into the message instead.
            raise NotImplementedError("ONNX supports one graph per model only, "
                                      "while TFLite has %d" % graph_count)
        tflg = self.model.Subgraphs(0)
        graph = Graph(self.model, tflg)
        self.graphes.append(graph)

        for g in self.graphes:
            g.parse()

        self.setParsed()

    def validate(self):
        # Validation is performed per-graph inside Graph.convert().
        pass

    def convert(self, explicit_layouts):
        """Convert all graphs and assemble the ONNX ModelProto."""
        self.parse()
        logger.debug("Converting...")
        for g in self.graphes:
            g.convert(explicit_layouts)

        # ONNXRuntime restrictions
        opset = helper.make_operatorsetid(onnx.defs.ONNX_DOMAIN, 11)
        attrs = {
            'producer_name': 'tflite2onnx',
            'ir_version': 6,
            'opset_imports': [opset],
        }

        self.onnx = helper.make_model(self.graphes[0].onnx, **attrs)
        self.setConverted()

    def save(self, path: str):
        """Serialize the converted ONNX model to `path` and check it."""
        logger.debug("saving model as %s", path)
        assert(self.status.converted)
        onnx.save(self.onnx, path)
        onnx.checker.check_model(path)

    @property
    def shorty(self):
        return "Model holder"

    def __str__(self):
        return self.shorty
from tflite2onnx.op.unary import Unary


# Register every converter with the factory so OpFactory.create() can
# dispatch on the TFLite builtin opcode found in the model.
OpFactory.register(Activation)
OpFactory.register(Binary)
OpFactory.register(Concat)
OpFactory.register(Conv)
OpFactory.register(FullyConnected)
OpFactory.register(Padding)
OpFactory.register(Pooling)
OpFactory.register(Quantize)
OpFactory.register(Reduce)
OpFactory.register(Reshape)
OpFactory.register(Resize)
OpFactory.register(Rsqrt)
OpFactory.register(Slice)
OpFactory.register(Softmax)
OpFactory.register(Split)
OpFactory.register(SquaredDifference)
OpFactory.register(Transpose)
OpFactory.register(TransposeConv)
OpFactory.register(Unary)
--------------------------------------------------------------------------------
/tflite2onnx/op/activation.py:
--------------------------------------------------------------------------------
import logging
import tflite

from tflite2onnx.op.common import Operator

logger = logging.getLogger('tflite2onnx')


class Activation(Operator):
    """Converter for TFLite activation operators (standalone or fused)."""

    # TFLite activation builtin -> ONNX operator type.
    # Note RELU6 maps to ONNX `Clip` with min=0, max=6 scalar inputs.
    TypeMapping = {
        tflite.BuiltinOperator.LOGISTIC: 'Sigmoid',
        tflite.BuiltinOperator.PRELU: 'PRelu',
        tflite.BuiltinOperator.RELU6: 'Clip',
        tflite.BuiltinOperator.RELU: 'Relu',
    }

    def __init__(self, TFactory, index, preset_opcode=None):
        super().__init__(TFactory, index)

        # TFLite op code of the activation, e.g. tflite.BuiltinOperator.RELU
        # Used for fused activation, where we cannot parse type from tflite object.
        self.preset_opcode = preset_opcode

        self.setInited()

    @property
    def type(self):
        # Before initialization there is nothing to inspect: report a
        # generic name (used by logging only).
        if self.status.uninitialized:
            return 'Activation'
        else:
            # Either a real TFLite operator or a preset opcode (fused case)
            # must be available to resolve the concrete ONNX type.
            assert(self.tflite or self.preset_opcode), "One of the two must be provided"
            if self.preset_opcode:
                opcode = self.preset_opcode
            else:
                op = self.tflite
                opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
            assert(opcode in self.TypeMapping)
            return self.TypeMapping[opcode]

    def parse(self):
        """Wire up input/output tensors for a standalone activation op.

        RELU6 additionally gets min/max scalar inputs since ONNX `Clip`
        takes its bounds as inputs; PRELU gets its `alpha` slope tensor.
        """
        logger.debug("Parsing %s...", self.type)

        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        if opcode == tflite.BuiltinOperator.PRELU:
            assert(op.InputsLength() == 2)
        else:
            assert(op.InputsLength() == 1)
        assert(op.OutputsLength() == 1)

        self.parseInput(0)

        if opcode == tflite.BuiltinOperator.RELU6:
            # ONNX Clip takes min/max as scalar tensor inputs.
            tmin = self.TFactory.createScalar('float32', 0.0)
            tmin.addConsumer(self)
            self.inputs.append(tmin)
            tmax = self.TFactory.createScalar('float32', 6.0)
            tmax.addConsumer(self)
            self.inputs.append(tmax)

        if opcode == tflite.BuiltinOperator.PRELU:
            # `alpha` should be a learned array with the same shape as `X`
            # But there is no `batch_size` dimension in its shape,
            # which will cause `out of index` exception during axis transform
            # so we expand its dimension by insert 1 to its shape
            alpha = self.parseInput(1)
            alpha.shape.insert(0, 1)

        self.parseOutput(0)

        self.setParsed()

    def propagatableTensors(self):
        # Activations are element-wise: every tensor propagates layout.
        return self.inputs + self.outputs

    def transform(self):
        # Nothing layout-sensitive in the attributes.
        pass


def handleFusedActivation(master, option, output, intermediate=None):
    """Handle FusedActivationFunction for master node.

    For a master node such as Conv and FC, there could be a
    FusedActivationFunction. If there is, create an activation node
    `ActOp` and a corresponding intermediate tensor, and insert them
    into the original graph. E.g. for subgraph `[Conv] -> <t>`
    with fused `ReLU`, we generate `[Conv] -> <t'> -> [ActOp(ReLU)] -> <t>`.

    Sometimes, there will be other nodes (quantization node for example)
    inserted between the *master* node and activation node. For such case,
    we cannot attach activation node to master node directly, e.g. the input graph
    will be like `[Conv] -> <t> -> [Dequantize] -> <u>`. Therefore, generating
    `[Conv] -> <t> -> [Dequantize] -> <u'> -> [ActOp(ReLU)] -> <u>`.

    So this util generates a pattern `<new tensor> -> [ActOp(ReLU)]` and
    inserts it into the original graph. In general, we need:
    * `master`: the *master* node, usually the one the activation is attached to.
    * `option`: the option parsed from the original master node.
    * `output`: the tensor that acts as output of the whole pattern.
    * `intermediate`: the node that the activation attaches to, usually same as `master`.
    """
    FusedActFunc2OpType = {
        tflite.ActivationFunctionType.RELU6: tflite.BuiltinOperator.RELU6,
        tflite.ActivationFunctionType.RELU: tflite.BuiltinOperator.RELU,
    }

    logger.debug("Handling FusedActivationFunction for %s", master.shorty)
    faf = option.FusedActivationFunction()
    if faf is tflite.ActivationFunctionType.NONE:
        return
    intermediate = master if intermediate is None else intermediate

    assert(faf in FusedActFunc2OpType)
    act_type = FusedActFunc2OpType[faf]
    assert(output.status.parsed)

    # create tensor that from Conv/FC to Activation
    iname = 'TFLITE2ONNX_FAF_%s' % output.name
    input = intermediate.TFactory.getWithRef(output, iname, True)
    input.setParsed()

    # Splice the new tensor between the intermediate node and its output.
    intermediate.replaceOutput(output, input)
    input.addProducer(intermediate)

    # Create the activation node and give it the intermediate node's
    # original output tensor as its own output.
    if act_type in [tflite.BuiltinOperator.RELU, tflite.BuiltinOperator.RELU6]:
        act = Activation(intermediate.TFactory, -1, preset_opcode=act_type)

        input.addConsumer(act)
        act.inputs.append(input)

        if act_type == tflite.BuiltinOperator.RELU6:
            # ONNX Clip takes min/max as scalar tensor inputs.
            tmin = intermediate.TFactory.createScalar('float32', 0.0)
            tmin.addConsumer(act)
            act.inputs.append(tmin)
            tmax = intermediate.TFactory.createScalar('float32', 6.0)
            tmax.addConsumer(act)
            act.inputs.append(tmax)

        output.replaceProducer(intermediate, act)
        act.outputs.append(output)

        act.setParsed()

        # this is where we need *master* node, all tflite2onnx generated
        # node shall be added as `pre` or `post` of the node that has a TFLite op.
        master.post.append(act)
    else:
        raise NotImplementedError("Unsupported fused ActivationFunctionType")
--------------------------------------------------------------------------------
/tflite2onnx/op/binary.py:
--------------------------------------------------------------------------------
import copy
import logging
import tflite
import numpy as np

from tflite2onnx import mapping
from tflite2onnx.op.common import Operator
from tflite2onnx.op.activation import handleFusedActivation
from tflite2onnx.op.reshape import Reshape

logger = logging.getLogger('tflite2onnx')
type(self): 34 | if self.status.uninitialized: 35 | return 'Binary' 36 | else: 37 | op = self.tflite 38 | opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode() 39 | assert(opcode in self.TypeMapping) 40 | return self.TypeMapping[opcode] 41 | 42 | def fakeBroadcast(self): 43 | # Binary operators need to broadcast shape explicitly here since 44 | # they may not be broadcastable after layout propagration. 45 | # We don't really broadcast here, but extend shape with 1. 46 | assert(self.status.initialized) 47 | a = self.inputs[0] 48 | b = self.inputs[1] 49 | output = self.outputs[0] 50 | if (len(a.shape) == len(b.shape)): 51 | return 52 | logger.info("Inserting `Reshape` for fake broadcasting, be carefull for the layout") 53 | 54 | align_a, new_shape = alignDimension(a.shape, b.shape) 55 | todo = a if align_a else b 56 | assert(len(new_shape) == len(output.shape)) 57 | 58 | new_t_name = 'TFLITE2ONNX_Reshape_%s' % todo.name 59 | new_t = self.TFactory.getWithRef(todo, new_t_name, True) 60 | new_t.shape = new_shape 61 | new_t.setParsed() 62 | 63 | shape_t_name = 'TFLITE2ONNX_NewShape_%s' % todo.name 64 | shape_t = self.TFactory.getWithRef(todo, shape_t_name, True) 65 | shape_t.dtype = mapping.DTYPE_NAME2ONNX['int64'] 66 | shape_t.shape = (len(new_shape),) 67 | shape_t.data = np.array(new_shape) 68 | shape_t.setParsed() 69 | 70 | reshape = Reshape(self.TFactory, -1) 71 | reshape.forFakeBroadcasting = True 72 | 73 | reshape.inputs.append(todo) 74 | todo.replaceConsumer(self, reshape) 75 | self.replaceInput(todo, new_t) 76 | 77 | reshape.inputs.append(shape_t) 78 | shape_t.addConsumer(reshape) 79 | 80 | reshape.outputs.append(new_t) 81 | new_t.addProducer(reshape) 82 | new_t.addConsumer(self) 83 | reshape.setParsed() 84 | 85 | self.pre.append(reshape) 86 | 87 | def parse(self): 88 | logger.debug("Parsing %s...", self.type) 89 | 90 | op = self.tflite 91 | opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode() 92 | assert(opcode in self.TypeMapping) 93 | 
94 | assert(op.InputsLength() == 2) 95 | assert(op.OutputsLength() == 1) 96 | 97 | self.parseInput(0) 98 | self.parseInput(1) 99 | ot = self.parseOutput(0) 100 | 101 | self.fakeBroadcast() 102 | 103 | # options 104 | op_opt = op.BuiltinOptions() 105 | if opcode in self.OptionMapping: 106 | option = self.OptionMapping[opcode]() 107 | option.Init(op_opt.Bytes, op_opt.Pos) 108 | 109 | handleFusedActivation(self, option, ot) 110 | 111 | self.setParsed() 112 | 113 | def propagatableTensors(self): 114 | return self.inputs + self.outputs 115 | 116 | def transform(self): 117 | pass 118 | 119 | 120 | def alignDimension(a, b): 121 | """Align the dimension of the shorter one to the longer one. 122 | 123 | We don't really broadcast tensors during converting, instead, align 124 | dimensions of the two inputs such that the tensors have same dimensions 125 | which is _layout handling compatible_. 126 | """ 127 | align_a = len(a) < len(b) 128 | to_align = a if align_a else b 129 | ref = b if align_a else a 130 | 131 | size = len(ref) - len(to_align) 132 | aligned = copy.deepcopy(to_align) 133 | for i in range(size): 134 | aligned.insert(0, 1) 135 | 136 | return (align_a, aligned) 137 | 138 | 139 | # wrapper is used here to override Binary.type property 140 | class PowerWrapper(Binary): 141 | @property 142 | def type(self): 143 | return 'Pow' 144 | -------------------------------------------------------------------------------- /tflite2onnx/op/common.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import tflite 3 | from onnx import helper 4 | 5 | from tflite2onnx.common import T2OBase 6 | 7 | logger = logging.getLogger('tflite2onnx') 8 | 9 | 10 | class Operator(T2OBase): 11 | TypeMapping = dict() 12 | 13 | def __init__(self, TFactory, index): 14 | super().__init__(TFactory.model, TFactory.graph, index) 15 | self.TFactory = TFactory 16 | self.tflite = self.graph.Operators(index) if index >= 0 else None 17 | self.inputs = [] 
--------------------------------------------------------------------------------
import logging
import tflite
from onnx import helper

from tflite2onnx.common import T2OBase

logger = logging.getLogger('tflite2onnx')


class Operator(T2OBase):
    """Base class of all TFLite-to-ONNX operator converters."""

    # Subclasses map TFLite builtin opcodes to ONNX operator type names.
    TypeMapping = dict()

    def __init__(self, TFactory, index):
        super().__init__(TFactory.model, TFactory.graph, index)
        self.TFactory = TFactory
        # index < 0 marks a tflite2onnx-generated helper op with no
        # backing TFLite operator.
        self.tflite = self.graph.Operators(index) if index >= 0 else None
        self.inputs = []
        self.outputs = []
        self.pre = []  # ops that before this op which to enable TFLite op
        self.post = []  # ops that after this op which to enable TFLite op
        self.attrs = dict()  # One dict to hold all ONNX operator attributes

    @property
    def type(self):
        raise NotImplementedError("Method Operator.type() must be overrided!")

    def propagatableTensors(self):
        """Get all layout propagable tensors of this operator.

        When we propagate layouts across the graph:
        1. Some operators may stop the propagation
            a) An operator assumes layouts of its tensors, `Conv` for example.
               Such operator needs to define the layouts of its tensors explicitly.
            b) An operator breaks layout semantic, `Reshape` for example.
               Tensors connected to this operator should be propagated.
               And the operator may need special handling regarding layout.
        2. Others may not - propagatable:
            a) An operator that is transparent to layout, such as Add.
               Just propagate the layouts.
            b) Layout can propagate across tensors of an operator, but the operator
               itself has attribution that is sensitive to layout.
               Operator needs special handling after propagation.
               This is defined per operator.

        To handle this, we firstly propagate layouts of tensors across the graph,
        and then update attributes of operators accordingly.
        """
        raise NotImplementedError("Method %s.propagatableTensors() must be overrided!" % self.type)

    def transform(self):
        """Transform the operator attributions w.r.t. propagated layouts.

        The attributions could be a tensor that describing layout related things.
        Operators that defined as 1.a, 1.b and 2.b in `layoutPropagatable()`
        are such cases. But not all of them need special treatment.
        For example, `Conv` doesn't need additional processing after propagation.

        This must be called after the layouts have been propagated across graph.
        """
        raise NotImplementedError("Method %s.transform() must be overrided!" % self.type)

    @property
    def str(self):
        return '[' + self.name + '] (' + self.type + ')'

    def parseInput(self, index, layout=None, is_bias=False):
        """Parse the `index`-th TFLite input into a tensor and wire it up."""
        ii = self.tflite.Inputs(index)
        it = self.TFactory.get(ii, layout, is_bias)
        it.parse()
        it.addConsumer(self)
        self.inputs.append(it)
        return it

    def parseOutput(self, index, layout=None):
        """Parse the `index`-th TFLite output into a tensor and wire it up."""
        oi = self.tflite.Outputs(index)
        ot = self.TFactory.get(oi, layout)
        ot.parse()
        ot.addProducer(self)
        self.outputs.append(ot)
        return ot

    def replaceInput(self, original, new):
        """Swap `original` for `new` in this op's input list (identity match)."""
        logger.debug("Replacing %s input %s with %s", self.shorty, original.shorty, new.shorty)
        assert(original in self.inputs)
        for i, item in enumerate(self.inputs):
            if item is original:
                self.inputs[i] = new
                return

    def replaceOutput(self, original, new):
        """Swap `original` for `new` in this op's output list (identity match)."""
        logger.debug("Replacing %s output %s with %s", self.shorty, original.shorty, new.shorty)
        assert(original in self.outputs)
        for i, item in enumerate(self.outputs):
            if item is original:
                self.outputs[i] = new
                return

    def setParsed(self):
        """Name the operator (if not yet) and change to initialized.

        Assume that the outputs won't change after parsed.
        * If the operator is a helper in TFLITE2ONNX, it should have been named already.
        * If the operator is original in TFLite, using name of its first output tensor.
        """
        self.name = self.outputs[0].name if self.name is None else self.name
        super().setParsed()

    def validate(self):
        assert(len(self.outputs) >= 1), "Operator should produce something"

    def convert(self):
        """Convert tensors and emit the ONNX NodeProto for this operator."""
        logger.debug("Converting %s...", self.shorty)
        for t in self.inputs + self.outputs:
            t.convert()
        self.attrs['name'] = self.name
        inames = [t.name for t in self.inputs]
        onames = [t.name for t in self.outputs]
        self.onnx = helper.make_node(self.type, inames, onames, **self.attrs)
        self.setConverted()

    @property
    def shorty(self):
        return '[%s](%s)' % (self.name, self.type)

    def __str__(self):
        inames = str([t.name for t in self.inputs])
        onames = str([t.name for t in self.outputs])
        return '%s attr%s: %s -> %s' % (self.shorty, self.attrs, inames, onames)


class OpFactory:
    """The factory for creating operator converter objects."""

    # TFLite builtin opcode -> converter class, filled by register().
    registry = dict()

    @staticmethod
    def register(converter):
        # Each converter claims every opcode in its TypeMapping; an opcode
        # may be claimed by exactly one converter.
        opcs = converter.TypeMapping.keys()
        for opc in opcs:
            assert(opc not in OpFactory.registry)
            OpFactory.registry[opc] = converter

    def __init__(self, TFactory):
        self.model = TFactory.model
        self.graph = TFactory.graph
        self.TFactory = TFactory

    def create(self, index):
        """Instantiate the registered converter for the `index`-th operator."""
        op = self.graph.Operators(index)
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        if opcode not in OpFactory.registry:
            if opcode in tflite.BUILTIN_OPCODE2NAME:
                name = tflite.opcode2name(opcode)
                raise NotImplementedError("Unsupported TFLite OP: {} {}!".format(opcode, name))
            else:
                raise ValueError("Opcode {} is not a TFLite builtin operator!".format(opcode))

        op_converter = OpFactory.registry[opcode]
        return op_converter(self.TFactory, index)

    @staticmethod
    def dump():
        return "Registered OP converter: %d" % len(OpFactory.registry)

    def __str__(self):
        return OpFactory.dump()
"Registered OP converter: %d" % len(OpFactory.registry) 164 | 165 | def __str__(self): 166 | return OpFactory.dump() 167 | -------------------------------------------------------------------------------- /tflite2onnx/op/concat.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import tflite 3 | 4 | from tflite2onnx.op.activation import handleFusedActivation 5 | from tflite2onnx.op.common import Operator 6 | 7 | logger = logging.getLogger('tflite2onnx') 8 | 9 | 10 | class Concat(Operator): 11 | TypeMapping = { 12 | tflite.BuiltinOperator.CONCATENATION: 'Concat', 13 | } 14 | 15 | def __init__(self, TFactory, index): 16 | super().__init__(TFactory, index) 17 | 18 | self.attrs['axis'] = -1 19 | 20 | self.setInited() 21 | 22 | @property 23 | def type(self): 24 | return 'Concat' 25 | 26 | def parse(self): 27 | logger.debug("Parsing %s...", self.shorty) 28 | op = self.tflite 29 | opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode() 30 | assert(opcode in self.TypeMapping) 31 | 32 | assert(op.InputsLength() >= 1) 33 | assert(op.OutputsLength() == 1) 34 | 35 | for i in range(op.InputsLength()): 36 | self.parseInput(i) 37 | 38 | op_opt = op.BuiltinOptions() 39 | option = tflite.ConcatenationOptions() 40 | option.Init(op_opt.Bytes, op_opt.Pos) 41 | self.attrs['axis'] = option.Axis() 42 | 43 | self.parseOutput(0) 44 | 45 | handleFusedActivation(self, option, self.outputs[0]) 46 | 47 | self.setParsed() 48 | 49 | def propagatableTensors(self): 50 | return self.inputs + self.outputs 51 | 52 | def transform(self): 53 | logger.debug("Transforming %s...", self.shorty) 54 | layout = self.outputs[0].layout 55 | if layout is not None: 56 | axis = self.attrs['axis'] 57 | axis = axis if axis >= 0 else (axis + len(layout.perm)) 58 | self.attrs['axis'] = layout.perm.index(axis) 59 | -------------------------------------------------------------------------------- /tflite2onnx/op/conv.py: 
--------------------------------------------------------------------------------
import logging
import tflite

from tflite2onnx.layout import Layout
from tflite2onnx.op.activation import handleFusedActivation
from tflite2onnx.op.common import Operator
from tflite2onnx.op.padding import computePaddingSize

logger = logging.getLogger('tflite2onnx')


class Conv(Operator):
    """Converter for TFLite CONV_2D / DEPTHWISE_CONV_2D -> ONNX Conv."""

    TypeMapping = {
        tflite.BuiltinOperator.CONV_2D: 'Conv',
        tflite.BuiltinOperator.DEPTHWISE_CONV_2D: 'Conv',
    }

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)

        self.attrs['kernel_shape'] = []
        self.attrs['strides'] = []
        # ONNX: This attribute cannot be used simultaneously with `auto_pad` attribute.
        # re-initialize during self.parse(), as it needs the shape of input.
        # We prefer `auto_pad`, however ONNXRuntime doesn't support
        # `dilation` + `auto_pad`, such that we use `pads` to workaround it.
        self.attrs['pads'] = [0, 0, 0, 0]
        # XXX Not enabled as ONNXRuntime has limitation to infer pads for non-1 dilation
        # self.attrs['auto_pad'] = 'SAME_UPPER'  # See ComputePaddingHeightWidth() of TFLite
        self.attrs['dilations'] = []
        self.attrs['group'] = -1

        self.setInited()

    @property
    def type(self):
        return 'Conv'

    @property
    def isDepthwise(self):
        # Depthwise convolution maps to ONNX Conv with group == channels.
        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        return (opcode is tflite.BuiltinOperator.DEPTHWISE_CONV_2D)

    def parse(self):
        """Parse tensors (with explicit layouts) and Conv attributes."""
        logger.debug("Parsing %s...", self.type)
        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        assert(op.InputsLength() == 3), "TFLite Conv always has bias"
        assert(op.OutputsLength() == 1)

        # input: TFLite NHWC -> ONNX NCHW
        ilayout = Layout('NHWC', 'NCHW')
        it = self.parseInput(0, ilayout)

        # weight: depthwise weights are CHWM in TFLite (M = depth multiplier);
        # regular weights are OHWI.
        wlayout = Layout('CHWM', 'MCHW') if self.isDepthwise else Layout('OHWI', 'OIHW')
        wt = self.parseInput(1, wlayout)

        # bias
        self.parseInput(2, is_bias=True)

        # output
        olayout = Layout('NHWC', 'NCHW')
        ot = self.parseOutput(0, olayout)

        # options
        op_opt = op.BuiltinOptions()
        option = tflite.DepthwiseConv2DOptions() if self.isDepthwise else tflite.Conv2DOptions()
        option.Init(op_opt.Bytes, op_opt.Pos)

        self.attrs['dilations'] = [option.DilationHFactor(), option.DilationWFactor()]
        # NOTE(review): shapes here are still in TFLite order at parse time,
        # so wt.shape[3] is the channel dimension of the CHWM weight.
        self.attrs['group'] = wt.shape[3] if self.isDepthwise else 1
        self.attrs['kernel_shape'] = wt.shape[1:3]
        self.attrs['strides'] = [option.StrideH(), option.StrideW()]
        # XXX Not enabled as ONNXRuntime has limitation to infer pads for non-1 dilation
        # self.attrs['auto_pad'] = PaddingMapping[option.Padding()]
        if self.isDepthwise:
            assert(option.DepthMultiplier() == 1)
        # Compute explicit pads from the TFLite padding mode and the
        # spatial dims (H, W) of the NHWC input.
        self.attrs['pads'] = computePaddingSize(option.Padding(), it.shape[1:3],
                                                self.attrs['kernel_shape'],
                                                self.attrs['strides'], self.attrs['dilations'])

        handleFusedActivation(self, option, ot)

        self.setParsed()

    def propagatableTensors(self):
        # Conv defines its tensor layouts explicitly; nothing propagates.
        return list()

    def transform(self):
        pass


class TransposeConv(Operator):
    """Converter for TFLite TRANSPOSE_CONV -> ONNX ConvTranspose."""

    TypeMapping = {
        tflite.BuiltinOperator.TRANSPOSE_CONV: 'ConvTranspose',
    }

    # FIXME: cases that untested yet (we are not fully understand the semantic gap)
    # 1. Special output shape for VALID padding
    # 2. Different input/output shape for SAME padding

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)

        self.attrs['dilations'] = [1, 1]  # TFLite TransposeConv doesn't have dilation
        self.attrs['group'] = 1  # TFLite TransposeConv doesn't have group
        self.attrs['kernel_shape'] = []
        # self.attrs['output_padding'] = []
        self.attrs['output_shape'] = []
        # pads are overwrited by output_shape
        # self.attrs['auto_pad'] = 'NOTSET'
        # self.attrs['pads'] = []
        self.attrs['strides'] = []

        self.setInited()

    @property
    def type(self):
        return 'ConvTranspose'

    def parse(self):
        """Parse tensors and attributes.

        TFLite TRANSPOSE_CONV inputs are: 0 = output shape tensor,
        1 = weight, 2 = the actual input data.
        """
        logger.debug("Parsing %s...", self.type)
        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()

        assert(opcode in self.TypeMapping)
        assert(op.InputsLength() == 3)
        assert(op.OutputsLength() == 1)

        # oshape
        osi = op.Inputs(0)
        oshape = self.TFactory.getData(self.model, self.graph, osi, 'int32')

        # X
        ilayout = Layout('NHWC', 'NCHW')
        self.parseInput(2, ilayout)

        # weight: TFLite OHWI -> ONNX IOHW (ConvTranspose weight layout)
        wlayout = Layout('OHWI', 'IOHW')
        wt = self.parseInput(1, wlayout)

        # FIXME: we don't have a model containing bias.

        # output
        olayout = Layout('NHWC', 'NCHW')
        ot = self.parseOutput(0, olayout)
        assert((ot.shape == oshape).all())

        # options
        op_opt = op.BuiltinOptions()
        option = tflite.TransposeConvOptions()
        option.Init(op_opt.Bytes, op_opt.Pos)

        self.attrs['kernel_shape'] = wt.shape[1:3]
        self.attrs['strides'] = [option.StrideH(), option.StrideW()]
        # output_shape attribute must be in ONNX (NCHW) order.
        oslayout = Layout('NHWC', 'NCHW')
        self.attrs['output_shape'] = oslayout.transform(oshape)
        self.setParsed()

    def propagatableTensors(self):
        # Layouts are defined explicitly; nothing propagates.
        return list()

    def transform(self):
        pass
--------------------------------------------------------------------------------
/tflite2onnx/op/fullyconnected.py:
is_bias=True) 48 | 49 | # output 50 | ot = self.parseOutput(0) 51 | 52 | # options 53 | op_opt = op.BuiltinOptions() 54 | option = tflite.FullyConnectedOptions() 55 | option.Init(op_opt.Bytes, op_opt.Pos) 56 | 57 | assert(not option.KeepNumDims()) 58 | assert(option.WeightsFormat() is tflite.FullyConnectedOptionsWeightsFormat.DEFAULT) 59 | 60 | handleFusedActivation(self, option, ot) 61 | 62 | self.setParsed() 63 | 64 | def propagatableTensors(self): 65 | return list() 66 | 67 | def transform(self): 68 | pass 69 | -------------------------------------------------------------------------------- /tflite2onnx/op/padding.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import tflite 4 | 5 | from tflite2onnx.op.common import Operator 6 | 7 | logger = logging.getLogger('tflite2onnx') 8 | 9 | PaddingMapping = { 10 | tflite.Padding.SAME: 'SAME_UPPER', 11 | tflite.Padding.VALID: 'VALID', 12 | } 13 | 14 | 15 | class Padding(Operator): 16 | TypeMapping = { 17 | tflite.BuiltinOperator.PAD: 'Pad', 18 | tflite.BuiltinOperator.MIRROR_PAD: 'Pad', 19 | } 20 | 21 | def __init__(self, TFactory, index): 22 | super().__init__(TFactory, index) 23 | 24 | self.attrs['mode'] = 'constant' 25 | 26 | self.setInited() 27 | 28 | @property 29 | def type(self): 30 | if self.status.uninitialized: 31 | return 'Pad' 32 | else: 33 | opcode = self.model.OperatorCodes(self.tflite.OpcodeIndex()).BuiltinCode() 34 | assert(opcode in self.TypeMapping) 35 | return self.TypeMapping[opcode] 36 | 37 | def parse(self): 38 | logger.debug("Parsing %s...", self.shorty) 39 | op = self.tflite 40 | opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode() 41 | assert(opcode in self.TypeMapping) 42 | 43 | if opcode is tflite.BuiltinOperator.MIRROR_PAD: 44 | self.attrs['mode'] = 'reflect' 45 | else: 46 | self.attrs['mode'] = 'constant' 47 | 48 | assert(op.InputsLength() == 2) 49 | assert(op.OutputsLength() == 1) 50 | 51 | it = 
class Padding(Operator):
    """Converter for TFLite PAD / MIRROR_PAD -> ONNX Pad."""

    TypeMapping = {
        tflite.BuiltinOperator.PAD: 'Pad',
        tflite.BuiltinOperator.MIRROR_PAD: 'Pad',
    }

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)

        self.attrs['mode'] = 'constant'

        self.setInited()

    @property
    def type(self):
        if self.status.uninitialized:
            return 'Pad'
        else:
            opcode = self.model.OperatorCodes(self.tflite.OpcodeIndex()).BuiltinCode()
            assert(opcode in self.TypeMapping)
            return self.TypeMapping[opcode]

    def parse(self):
        """Parse data and pads tensors; MIRROR_PAD becomes reflect mode."""
        logger.debug("Parsing %s...", self.shorty)
        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        if opcode is tflite.BuiltinOperator.MIRROR_PAD:
            self.attrs['mode'] = 'reflect'
        else:
            self.attrs['mode'] = 'constant'

        assert(op.InputsLength() == 2)
        assert(op.OutputsLength() == 1)

        it = self.parseInput(0)

        # pads tensor must be a constant of shape [rank, 2]:
        # [i, 0] is the leading pad of dim i, [i, 1] the trailing pad.
        pt = self.parseInput(1)
        assert(len(pt.shape) == 2)
        assert(pt.shape[0] == len(it.shape))
        assert(pt.shape[1] == 2)
        assert(pt.isInitializer)
        # bridge semantic gap: ONNX Pad expects int64 pads
        pt.asDtype('int64')

        self.parseOutput(0)

        self.setParsed()

    def propagatableTensors(self):
        # Only the data tensors propagate; the pads tensor is handled
        # explicitly in transform().
        return [self.inputs[0], self.outputs[0]]

    def transform(self):
        # Padding.transform() handls TFLite/ONNX semantic gap in addition to layout gap
        # TensorFlow (Lite) pads is `[n, 2]` where `[i, 0]` is _begin_ and `[i, 1]` is _end_
        # ONNX pads is `[n * 2]` sequenced as `[x1_begin, x2_begin,...,x1_end, x2_end,...]`
        layout = self.inputs[0].layout
        pt = self.inputs[1]
        pads = pt.data
        pads = np.reshape(pads, pt.shape)
        if layout is None:
            # [n, 2] -> [2, n]; flattening then yields all begins, all ends.
            pads = np.transpose(pads)
        else:
            # Permute begin/end vectors into the propagated layout first.
            pads_begin = pads[:, 0]
            pads_end = pads[:, 1]
            pads_begin = layout.transform(pads_begin)
            pads_end = layout.transform(pads_end)
            pads = np.array([pads_begin, pads_end])
        pt.data = pads.flatten()
        pt.shape = [np.prod(pt.shape), ]


# https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/lite/kernels/padding.h#L58
def computePaddingSize(padding_mode, input_size, kernel_size, stride, dilation):
    """Compute explicit per-dimension padding à la TFLite, in ONNX order.

    Returns a flat array `[begin_0, begin_1, ..., end_0, end_1, ...]`
    (e.g. `[top, left, bottom, right]` for 2D) suitable for Conv `pads`.
    """
    assert(len(input_size) == len(kernel_size))
    assert(len(input_size) == len(stride))
    assert(len(input_size) == len(dilation))

    # compute output shape (astype('int') floors the positive quotients,
    # matching TFLite's integer arithmetic)
    ones = np.ones_like(input_size)
    effective_filter_size = np.add(np.multiply(np.subtract(kernel_size, ones), dilation), ones)
    if padding_mode is tflite.Padding.SAME:
        oshape = np.divide(np.subtract(np.add(input_size, stride), ones), stride)
    elif padding_mode is tflite.Padding.VALID:
        oshape = np.divide(np.subtract(np.add(input_size, stride), effective_filter_size), stride)
    else:
        raise ValueError("Unknown padding mode!")
    oshape = oshape.astype('int')

    # infer the padding needed to realize that output shape
    total_padding = np.add(np.multiply(np.subtract(oshape, ones), stride),
                           np.subtract(effective_filter_size, input_size))
    total_padding = np.maximum(total_padding, np.zeros_like(input_size))
    total_padding = total_padding.astype('int')

    # ONNX semantic: split as pre (floor) / post (remainder), begins first
    pre_padding = total_padding // 2
    post_padding = np.subtract(total_padding, pre_padding)
    padding = np.concatenate((pre_padding, post_padding))

    return padding.flatten()
--------------------------------------------------------------------------------
/tflite2onnx/op/pooling.py:
class Quantize(Operator):
    """Map TFLite QUANTIZE/DEQUANTIZE to ONNX QuantizeLinear/DequantizeLinear."""
    TypeMapping = {
        tflite.BuiltinOperator.QUANTIZE: 'QuantizeLinear',
        tflite.BuiltinOperator.DEQUANTIZE: 'DequantizeLinear',
    }

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)
        self.setInited()

    @property
    def type(self):
        # A float input means we are quantizing; otherwise dequantizing.
        if self.isQuantize:
            return 'QuantizeLinear'
        return 'DequantizeLinear'

    @property
    def isQuantize(self):
        if not self.status.parsed:
            # May be queried (e.g. while logging) before parsing has
            # populated self.inputs, so fall back to the raw TFLite opcode.
            raw = self.model.OperatorCodes(self.tflite.OpcodeIndex()).BuiltinCode()
            return raw == tflite.BuiltinOperator.QUANTIZE
        return self.inputs[0].dtype is TensorProto.FLOAT

    def parse(self):
        """Parse the single-input/single-output (de)quantize operator."""
        logger.debug("Parsing %s...", self.shorty)
        node = self.tflite
        opcode = self.model.OperatorCodes(node.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        assert(node.InputsLength() == 1)
        assert(node.OutputsLength() == 1)
        self.parseInput(0)
        self.parseOutput(0)

        self.setParsed()

    def propagatableTensors(self):
        # Elementwise: layout may propagate through both sides.
        return self.inputs + self.outputs

    def dequantize(self):
        # Pick the quantized-side tensor; the float side is dequantized.
        if self.isQuantize:
            floating, quantized = self.inputs[0], self.outputs[0]
        else:
            quantized, floating = self.inputs[0], self.outputs[0]

        floating.dequantize()

        # ONNX carries scale and zero point as explicit inputs (in that order).
        for extra in (self.TFactory.createQuantScale(quantized),
                      self.TFactory.createQuantZeroPoint(quantized)):
            extra.addConsumer(self)
            self.inputs.append(extra)

    def transform(self):
        pass

    def validate(self):
        """Reject quantized dtypes that ONNX (de)quantize cannot represent."""
        quant_dtype = self.outputs[0].dtype if self.isQuantize else self.inputs[0].dtype
        if quant_dtype not in (TensorProto.UINT8, TensorProto.INT8, TensorProto.INT32):
            raise ValueError("Unsupported quantization type due to ONNX operator semantic. "
                             "See https://github.com/zhenhuaw-me/tflite2onnx/blob/master/docs/faq.md#fp16-quantization-model-doesnt-work")  # noqa: E501
" 78 | "See https://github.com/zhenhuaw-me/tflite2onnx/blob/master/docs/faq.md#fp16-quantization-model-doesnt-work") # noqa: E501 79 | -------------------------------------------------------------------------------- /tflite2onnx/op/reduce.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import tflite 3 | 4 | from tflite2onnx.tensor import TensorFactory 5 | from tflite2onnx.op.common import Operator 6 | 7 | logger = logging.getLogger('tflite2onnx') 8 | 9 | 10 | class Reduce(Operator): 11 | TypeMapping = { 12 | tflite.BuiltinOperator.MEAN: 'ReduceMean', 13 | } 14 | 15 | def __init__(self, TFactory, index): 16 | super().__init__(TFactory, index) 17 | 18 | self.attrs['axes'] = None 19 | self.attrs['keepdims'] = 0 20 | 21 | self.setInited() 22 | 23 | @property 24 | def type(self): 25 | if self.status.uninitialized: 26 | return 'Reduce' 27 | else: 28 | op = self.tflite 29 | opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode() 30 | assert(opcode in self.TypeMapping) 31 | return self.TypeMapping[opcode] 32 | 33 | def parse(self): 34 | logger.debug("Parsing %s...", self.type) 35 | 36 | op = self.tflite 37 | opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode() 38 | assert(opcode in self.TypeMapping) 39 | 40 | assert(op.InputsLength() == 2) 41 | assert(op.OutputsLength() == 1) 42 | it = self.parseInput(0) 43 | ot = self.parseOutput(0) 44 | 45 | # options 46 | ai = op.Inputs(1) 47 | self.attrs['axes'] = TensorFactory.getData(self.model, self.graph, ai, 'int32') 48 | self.attrs['keepdims'] = 1 if (len(ot.shape) == len(it.shape)) else 0 49 | 50 | self.setParsed() 51 | 52 | def propagatableTensors(self): 53 | return list() 54 | 55 | def transform(self): 56 | layout = self.inputs[0].layout 57 | if layout is None: 58 | return 59 | else: 60 | axes = self.attrs['axes'] 61 | axes = [axe if axe >= 0 else (axes + len(layout.perm)) for axe in axes] 62 | self.attrs['axes'] = [layout.perm.index(axe) for 
class Reshape(Operator):
    """Converter for TFLite RESHAPE."""
    TypeMapping = {
        tflite.BuiltinOperator.RESHAPE: 'Reshape',
    }

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)

        # Set by graph code when this op was synthesized to emulate broadcasting.
        self.forFakeBroadcasting = False

        self.setInited()

    @property
    def type(self):
        return 'Reshape'

    def parse(self):
        """Parse the input, the target shape (option or tensor) and the output."""
        logger.debug("Parsing %s...", self.type)

        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        assert(op.InputsLength() >= 1)
        assert(op.OutputsLength() == 1)

        # input
        self.parseInput(0)

        n_inputs = op.InputsLength()
        if n_inputs == 1:
            # Shape comes from the operator option rather than a tensor.
            # This path has not been tested by CI as we don't have a simple model for it.
            # See https://github.com/tensorflow/tensorflow/issues/45150
            raw_options = op.BuiltinOptions()
            options = tflite.ReshapeOptions()
            options.Init(raw_options.Bytes, raw_options.Pos)
            shape_tensor = self.TFactory.createVector(
                options.NewShapeAsNumpy().astype('int64'))
            shape_tensor.addConsumer(self)
            self.inputs.append(shape_tensor)
        elif n_inputs == 2:
            # Shape is the second input tensor.
            shape_tensor = self.parseInput(1)

            # TFLite stores shape as int32 while ONNX requires int64.
            shape_tensor.dtype = mapping.DTYPE_NAME2ONNX['int64']
            if shape_tensor.isInitializer:
                shape_tensor.data = shape_tensor.data.astype('int64')
            if len(shape_tensor.shape) > 1:
                logger.warning("ONNXRuntime doesn't support 2+rank shape, "
                               "flatten if the shape is initialzier, ignore otherwise."
                               "https://github.com/zhenhuaw-me/tflite2onnx/issues/16")
                if shape_tensor.isInitializer:
                    shape_tensor.shape = (np.prod(np.array(shape_tensor.shape)),)

        # output
        self.parseOutput(0)

        self.setParsed()

    def preserveInputSpatialSemantic(self):
        """Insert a Transpose before this Reshape to keep spatial semantic.

        https://github.com/zhenhuaw-me/tflite2onnx/issues/28
        The original input (e.g. NCHW) is transposed back into a new tensor
        in the source layout (e.g. NHWC), which then feeds the Reshape.
        """
        assert(self.status.parsed)
        original = self.inputs[0]

        # New tensor holding the transposed (source-layout) view of the input.
        transposed = self.TFactory.getWithRef(
            original, 'TFLITE2ONNX_Transposed_%s' % original.name, True)
        layout = Layout(original.layout.target, original.layout.source)
        transposed.shape = layout.transform(original.shape)
        transposed.setParsed()

        # The Transpose operator that produces it.
        trans = Transpose(self.TFactory, -1)
        trans.attrs['perm'] = layout.perm

        trans.inputs.append(original)
        original.replaceConsumer(self, trans)
        self.replaceInput(original, transposed)

        trans.outputs.append(transposed)
        transposed.addProducer(trans)
        transposed.addConsumer(self)
        trans.setParsed()

        self.pre.append(trans)

    def preserveOutputSpatialSemantic(self):
        """Insert a Transpose after this Reshape to keep spatial semantic.

        https://github.com/zhenhuaw-me/tflite2onnx/issues/28
        The Reshape now produces a new source-layout tensor (e.g. NHWC)
        which a Transpose converts into the original output (e.g. NCHW).
        """
        assert(self.status.parsed)
        original = self.outputs[0]

        # New tensor the Reshape will write, in the source layout.
        pre_transpose = self.TFactory.getWithRef(
            original, 'TFLITE2ONNX_ToTranspose_%s' % original.name, True)
        layout = Layout(original.layout.target, original.layout.source)
        pre_transpose.shape = layout.transform(original.shape)
        pre_transpose.setParsed()

        # The Transpose operator after the Reshape.
        trans = Transpose(self.TFactory, -1)
        trans.attrs['perm'] = original.layout.perm

        trans.inputs.append(pre_transpose)
        original.replaceProducer(self, trans)
        self.replaceOutput(original, pre_transpose)

        trans.outputs.append(original)
        pre_transpose.addProducer(self)
        pre_transpose.addConsumer(trans)
        trans.setParsed()
        # Rename to avoid clashing with the Reshape's own name.
        trans.name = 'TFLITE2ONNX_Transpose_%s' % original.name

        self.post.append(trans)

    def propagatableTensors(self):
        return list()

    def transform(self):
        it = self.inputs[0]
        ot = self.outputs[0]

        if self.forFakeBroadcasting:
            # Rewrite the target shape for the transposed layout.
            assert(len(it.shape) != len(ot.shape))
            shape_t = self.inputs[1]
            layout = copy.deepcopy(ot.layout)
            if layout is None:
                raise ValueError("Requires layout description for <%s>" % it.name)
            shape_t.data = np.array(layout.transform(shape_t.data))
        else:
            # Keep spatial semantic with explicit Transpose ops when needed.
            if it.layout:
                self.preserveInputSpatialSemantic()
            if ot.layout:
                self.preserveOutputSpatialSemantic()
class Resize(Operator):
    """Converter for TFLite RESIZE_NEAREST_NEIGHBOR / RESIZE_BILINEAR."""
    TypeMapping = {
        tflite.BuiltinOperator.RESIZE_NEAREST_NEIGHBOR: 'Resize',
        tflite.BuiltinOperator.RESIZE_BILINEAR: 'Resize',
    }

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)
        # ONNX Resize attribute defaults; see the ONNX operator spec.
        # coordinate_transformation_mode: one of half_pixel,
        # pytorch_half_pixel, align_corners, asymmetric, tf_crop_and_resize.
        self.attrs['coordinate_transformation_mode'] = 'half_pixel'
        # cubic_coeff_a only matters when mode == 'cubic';
        # common choices are -0.5 (some TensorFlow cases) and -0.75 (PyTorch).
        self.attrs['cubic_coeff_a'] = -0.75
        self.attrs['exclude_outside'] = 0
        self.attrs['extrapolation_value'] = 0.0
        # mode: nearest (default), linear or cubic; the latter two generalize
        # to N-linear / N-cubic interpolation for N-D tensors.
        self.attrs['mode'] = 'nearest'
        # nearest_mode: round_prefer_floor (default), round_prefer_ceil,
        # floor or ceil; only consulted when mode == 'nearest'.
        self.attrs['nearest_mode'] = 'round_prefer_floor'

        self.setInited()

    @property
    def type(self):
        return 'Resize'

    def parse(self):
        """Parse the image input, synthesize ROI/scale/size inputs, and options."""
        logger.debug("Parsing %s...", self.type)
        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        assert(op.InputsLength() == 2), "TFLite has only two inputs"
        assert(op.OutputsLength() == 1)

        # ONNX Resize has no layout semantic while TFLite assumes NHWC.
        image = self.parseInput(0, Layout('NHWC', 'NCHW'))

        # ROI and Scale are not optional until Resize v13; for v11 provide
        # them as one shared empty initializer. After v13 they could be
        # dropped from the graph entirely.
        placeholder = self.TFactory.createEmptyTensor()
        placeholder.addConsumer(self)
        self.inputs.append(placeholder)  # ROI
        self.inputs.append(placeholder)  # Scale

        # Output size: TFLite gives (H_new, W_new); ONNX wants (N, C, H_new, W_new).
        size = self.parseInput(1)
        assert len(size.data) == 2
        assert len(image.shape) == 4
        size.shape = [len(image.shape)]
        size.data = np.concatenate(
            (np.array([image.shape[0], image.shape[-1]]), size.data))
        size.dtype = mapping.DTYPE_NAME2ONNX['int64']

        # output
        self.parseOutput(0, Layout('NHWC', 'NCHW'))

        # Interpolation mode is determined by the opcode; the rest by options.
        if opcode is tflite.BuiltinOperator.RESIZE_BILINEAR:
            self.attrs['mode'] = 'linear'
            option = tflite.ResizeBilinearOptions()
        elif opcode is tflite.BuiltinOperator.RESIZE_NEAREST_NEIGHBOR:
            self.attrs['mode'] = 'nearest'
            option = tflite.ResizeNearestNeighborOptions()
        else:
            assert False, "Unreachable path!"

        raw_options = op.BuiltinOptions()
        option.Init(raw_options.Bytes, raw_options.Pos)

        if option.AlignCorners():
            self.attrs['coordinate_transformation_mode'] = 'align_corners'
        elif option.HalfPixelCenters():
            self.attrs['coordinate_transformation_mode'] = 'half_pixel'
        else:
            raise NotImplementedError("This path has not been tried")

        self.setParsed()

    def propagatableTensors(self):
        return list()

    def transform(self):
        pass
class Rsqrt(Operator):
    """Convert TFLite RSQRT as Sqrt followed by Pow(x, -1)."""
    TypeMapping = {
        tflite.BuiltinOperator.RSQRT: 'Sqrt',
    }

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)
        self.setInited()

    @property
    def type(self):
        if self.status.uninitialized:
            return 'Sqrt'
        opcode = self.model.OperatorCodes(self.tflite.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)
        return self.TypeMapping[opcode]

    def parse(self):
        logger.debug("Parsing %s...", self.type)

        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        assert(op.InputsLength() == 1)
        assert(op.OutputsLength() == 1)

        self.parseInput(0)
        self.parseOutput(0)

        # 1/sqrt(x): chain a Pow(-1) after the Sqrt result.
        self.appendInvert()

        self.setParsed()

    def propagatableTensors(self):
        return self.inputs + self.outputs

    def transform(self):
        pass

    def appendInvert(self):
        """Append a Pow operator that inverts this op's Sqrt output."""
        invert = PowerWrapper(self.TFactory, -1)

        # Intermediate tensor carrying sqrt(x).
        mid = self.TFactory.getWithRef(
            self.outputs[0], 'TFLITE2ONNX_Invert_%s' % self.outputs[0].name, True)
        mid.setParsed()
        mid.addProducer(self)
        mid.addConsumer(invert)

        # Constant exponent tensor filled with -1.
        exponent = self.TFactory.getWithRef(
            self.outputs[0], 'TFLITE2ONNX_PowData_%s' % self.outputs[0].name, True)
        exponent.data = np.full(shape=exponent.shape, fill_value=-1,
                                dtype=mapping.DTYPE_ONNX2NAME[exponent.dtype])
        exponent.setParsed()
        exponent.addConsumer(invert)

        invert.inputs.append(mid)
        invert.inputs.append(exponent)
        invert.outputs.append(self.outputs[0])
        self.replaceOutput(self.outputs[0], mid)

        invert.setParsed()
        self.post.append(invert)
class Slice(Operator):
    """Converter for TFLite STRIDED_SLICE -> ONNX Slice."""
    TypeMapping = {
        tflite.BuiltinOperator.STRIDED_SLICE: 'Slice',
    }

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)
        self.setInited()

    @property
    def type(self):
        return 'Slice'

    def parse(self):
        """Parse input/begin/end/strides, applying begin/end masks."""
        logger.debug("Parsing %s...", self.type)

        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        assert(op.InputsLength() == 4)
        assert(op.OutputsLength() == 1)

        # input
        it = self.parseInput(0)
        rank = len(it.shape)

        # output
        ot = self.parseOutput(0)
        assert(rank == len(ot.shape))

        # options
        raw_options = op.BuiltinOptions()
        option = tflite.StridedSliceOptions()
        option.Init(raw_options.Bytes, raw_options.Pos)
        m_begin = option.BeginMask()
        m_end = option.EndMask()
        assert(option.EllipsisMask() == 0), "EllipsisMask not supported!"
        assert(option.NewAxisMask() == 0), "NewAxisMask not supported!"
        assert(option.ShrinkAxisMask() == 0), "ShrinkAxisMask not supported!"

        def _intToBitsList(data, size):
            return [int(x) for x in '{:0{size}b}'.format(data, size=size)]

        m_begin = _intToBitsList(m_begin, rank)
        m_end = _intToBitsList(m_end, rank)

        # begin: masked dimensions start from 0
        bt = self.parseInput(1)
        assert(bt.isInitializer)
        assert(rank == bt.shape[0])
        for i, (mask, begin) in enumerate(zip(m_begin, list(bt.data))):
            bt.data[i] = 0 if mask == 1 else begin

        # end: masked dimensions extend to the full extent
        et = self.parseInput(2)
        assert(et.isInitializer)
        assert(rank == et.shape[0])
        for i, (extent, mask, end) in enumerate(zip(it.shape, m_end, list(et.data))):
            et.data[i] = extent if mask == 1 else end

        # axes: ONNX needs them explicitly, TFLite implies all of them
        at = self.TFactory.createVector(np.arange(rank).astype('int32'))
        at.addConsumer(self)
        self.inputs.append(at)

        # strides
        st = self.parseInput(3)
        assert(st.isInitializer)
        assert(rank == st.shape[0])

        self.setParsed()

    def propagatableTensors(self):
        return [self.inputs[0], self.outputs[0]]

    def transform(self):
        """Permute begin/end/strides data per the output layout."""
        logger.debug("Transforming %s...", self.shorty)
        layout = copy.deepcopy(self.outputs[0].layout)
        if layout is None:
            logger.warning("layout of %s should not be None", self.shorty)
            return
        assert(len(self.inputs) == 5)
        for idx in (1, 2, 4):  # begin, end, strides (index 3 is axes)
            tensor = self.inputs[idx]
            tensor.data = layout.transform(tensor.data)
class Softmax(Operator):
    """Converter for TFLite SOFTMAX."""
    TypeMapping = {
        tflite.BuiltinOperator.SOFTMAX: 'Softmax',
    }

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)

        self.attrs['axis'] = -1

        self.setInited()

    @property
    def type(self):
        return 'Softmax'

    def parse(self):
        logger.debug("Parsing %s...", self.type)
        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        assert(op.InputsLength() == 1)
        assert(op.OutputsLength() == 1)
        self.parseInput(0)
        self.parseOutput(0)

        # TFLite always applies softmax over the last (`-1`) axis, while
        # ONNX defaults to axis `1` — set it explicitly. Note 4D tensors
        # are transposed NHWC -> NCHW elsewhere in the converter.
        # FIXME: axis is the dim index of 'C'.
        self.attrs['axis'] = -1

        self.setParsed()

    def propagatableTensors(self):
        return self.inputs + self.outputs

    def transform(self):
        pass
class Split(Operator):
    """Converter for TFLite SPLIT -> ONNX Split."""
    TypeMapping = {
        tflite.BuiltinOperator.SPLIT: 'Split',
    }

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)

        # `axis` to split along; `split` holds the per-output sizes.
        self.attrs['axis'] = -1
        self.attrs['split'] = None

        self.setInited()

    @property
    def type(self):
        return 'Split'

    def parse(self):
        """Parse the data input (index 1), axis input (index 0) and outputs."""
        logger.debug("Parsing %s...", self.type)

        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        assert(op.InputsLength() == 2)

        # input: TFLite stores the data tensor second, the axis first
        it = self.parseInput(1)

        # options
        ai = op.Inputs(0)
        axis = TensorFactory.getData(self.model, self.graph, ai, 'int32')
        assert(axis.size == 1)
        self.attrs['axis'] = int(axis[0])

        op_opt = op.BuiltinOptions()
        option = tflite.SplitOptions()
        option.Init(op_opt.Bytes, op_opt.Pos)
        # TFLite SPLIT yields equally-shaped outputs
        split_size = option.NumSplits()
        assert(isinstance(split_size, int))
        assert(op.OutputsLength() == split_size)

        # XXX workaround for ONNXRuntime: doesn't support all-zero `split`.
        # Use integer division (`//`) rather than float `/` + truncation:
        # TFLite SPLIT requires the dimension to divide evenly anyway.
        split = it.shape[self.attrs['axis']] // split_size
        self.attrs['split'] = np.full((split_size,), split).astype('int')

        # output
        for i in range(split_size):
            self.parseOutput(i)

        self.setParsed()

    def propagatableTensors(self):
        return self.inputs + self.outputs

    def transform(self):
        """Remap the split axis through the output layout permutation."""
        logger.debug("Transforming %s...", self.shorty)
        layout = self.outputs[0].layout
        if layout is not None:
            axis = self.attrs['axis']
            axis = axis if axis >= 0 else (axis + len(layout.perm))
            self.attrs['axis'] = layout.perm.index(axis)
class SquaredDifference(Operator):
    """Convert TFLite SQUARED_DIFFERENCE as Sub followed by Pow(x, 2)."""
    TypeMapping = {
        tflite.BuiltinOperator.SQUARED_DIFFERENCE: 'Sub',
    }

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)
        self.setInited()

    @property
    def type(self):
        if self.status.uninitialized:
            return 'Sub'
        opcode = self.model.OperatorCodes(self.tflite.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)
        return self.TypeMapping[opcode]

    def parse(self):
        logger.debug("Parsing %s...", self.type)

        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        assert(op.InputsLength() == 2)
        assert(op.OutputsLength() == 1)

        self.parseInput(0)
        self.parseInput(1)
        self.parseOutput(0)

        assert(len(self.inputs[0].shape) == len(self.inputs[1].shape))

        # (a - b)^2: chain a Pow(2) after the Sub result.
        self.appendSquare()

        self.setParsed()

    def propagatableTensors(self):
        return self.inputs + self.outputs

    def transform(self):
        pass

    def appendSquare(self):
        """Append a Pow operator that squares this op's Sub output."""
        square = PowerWrapper(self.TFactory, -1)

        # Intermediate tensor carrying a - b.
        mid = self.TFactory.getWithRef(
            self.outputs[0], 'TFLITE2ONNX_Square_%s' % self.outputs[0].name, True)
        mid.setParsed()
        mid.addProducer(self)
        mid.addConsumer(square)

        # Constant exponent tensor filled with 2.
        exponent = self.TFactory.getWithRef(
            self.outputs[0], 'TFLITE2ONNX_PowData_%s' % self.outputs[0].name, True)
        exponent.data = np.full(shape=exponent.shape, fill_value=2,
                                dtype=mapping.DTYPE_ONNX2NAME[exponent.dtype])
        exponent.setParsed()
        exponent.addConsumer(square)

        square.inputs.append(mid)
        square.inputs.append(exponent)
        square.outputs.append(self.outputs[0])
        self.replaceOutput(self.outputs[0], mid)

        square.setParsed()
        self.post.append(square)
class Transpose(Operator):
    """Converter for TFLite TRANSPOSE."""
    TypeMapping = {
        tflite.BuiltinOperator.TRANSPOSE: 'Transpose'
    }

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)

        self.attrs['perm'] = []

        self.setInited()

    @property
    def type(self):
        return 'Transpose'

    def parse(self):
        """Parse the data input, the permutation input and the output."""
        logger.debug("Parsing %s...", self.type)
        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        assert(op.InputsLength() == 2)
        assert(op.OutputsLength() == 1)
        self.parseInput(0)
        self.parseOutput(0)

        # The permutation is the second input tensor.
        perm_index = op.Inputs(1)
        self.attrs['perm'] = TensorFactory.getData(
            self.model, self.graph, perm_index, 'int32')

        self.setParsed()

    def propagatableTensors(self):
        return list()

    def transform(self):
        logger.warning("Transforming %s, doing nothing now...", self.type)
class Unary(Operator):
    """Converter for elementwise single-input TFLite operators."""
    TypeMapping = {
        tflite.BuiltinOperator.ABS: 'Abs',
        tflite.BuiltinOperator.SQRT: 'Sqrt',
    }

    def __init__(self, TFactory, index):
        super().__init__(TFactory, index)
        self.setInited()

    @property
    def type(self):
        if self.status.uninitialized:
            return 'Unary'
        opcode = self.model.OperatorCodes(self.tflite.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)
        return self.TypeMapping[opcode]

    def parse(self):
        logger.debug("Parsing %s...", self.type)

        op = self.tflite
        opcode = self.model.OperatorCodes(op.OpcodeIndex()).BuiltinCode()
        assert(opcode in self.TypeMapping)

        assert(op.InputsLength() == 1)
        assert(op.OutputsLength() == 1)
        self.parseInput(0)
        self.parseOutput(0)

        self.setParsed()

    def propagatableTensors(self):
        # Elementwise: layout flows through freely.
        return self.inputs + self.outputs

    def transform(self):
        pass
def handleQuantizationTensor(TFactory, t):
    """Translate a UINT8 TFLite tensor to a Quantize-Dequantize pair in ONNX.

    ONNX quantization support is limited, so all operators stay in floating
    point and only TFLite's quantization parameters are preserved: the TFLite
    pattern `[OP A] -> <t> -> [OP B]` becomes
    `[OP A] -> <t> -> [Quantize] -> <q> -> [Dequantize] -> <deq> -> [OP B]`,
    where `[OP A]`/`[OP B]` may be absent for graph inputs/outputs. `<t>` is
    mostly UINT8, or INT32 for bias.

    The inserted operators are stored on `t.producers[0].post`, or on
    `t.consumers[0].pre` when `t` has no producer, and the graph links are
    rewired accordingly. Returns the dequantized tensor, or `t` itself when
    no pattern is needed.
    """
    if not t.quantized:
        return t
    logger.debug("Generating quantization pattern for {}".format(t.name))

    t.dequantize()

    if t.is_bias:
        # Bias is INT32, which Quantize/Dequantize cannot process;
        # it only needs to become float, which dequantize() above did.
        return t

    prefix = 'TFLITE2ONNX_Quant_' + t.name

    # quantized tensor and the Quantize op producing it
    qtensor = TFactory.getWithRef(t, prefix + '_quantized', True)
    qtensor.dtype = TensorProto.UINT8
    qtensor.setParsed()

    qop = Quantize(TFactory, -1)
    qop.name = prefix + '_Quantize'
    qop.inputs.append(t)
    qop.outputs.append(qtensor)
    qop.setParsed()
    qop.dequantize()

    # dequantized tensor and the Dequantize op producing it
    deqtensor = TFactory.getWithRef(t, prefix + '_dequantized', True)
    deqtensor.dtype = TensorProto.FLOAT
    deqtensor.setParsed()

    deqop = Quantize(TFactory, -1)
    deqop.name = prefix + '_Dequantize'
    deqop.inputs.append(qtensor)
    deqop.outputs.append(deqtensor)
    deqop.setParsed()
    deqop.dequantize()

    # wire the local pattern together
    qtensor.addProducer(qop)
    qtensor.addConsumer(deqop)
    deqtensor.addProducer(deqop)

    # attach Quantize/Dequantize to the graph next to an adjacent operator
    if t.producers:
        host = t.producers[0]
        host.post.insert(0, deqop)
        host.post.insert(0, qop)
    elif t.consumers:
        host = t.consumers[0]
        host.pre.insert(0, deqop)
        host.pre.insert(0, qop)
    else:
        assert(False), "No place to add op"

    # downstream consumers now read the dequantized tensor instead of `t`
    for c in t.consumers:
        c.replaceInput(t, deqtensor)
        deqtensor.addConsumer(c)
    t.consumers.clear()
    t.addConsumer(qop)

    return deqtensor
def foldFP16QuantPattern(ops):
    """Fold the TFLite FP16 quantization pattern.

    * `FP16 weights - Dequantize - FP32 tensor - Conv` -> `FP32 weights - Conv`
    * `FP16 input - Dequantize - FP32 tensor - Conv` -> `FP32 input - Conv`
    """
    logger.debug("FP16 Quant Folder: Folding FP16 quantization subgraph across graph...")

    # Use `Graph.ops` here as these operators should be raw TFLite ones.
    candidates = [op for op in ops
                  if op.type == 'DequantizeLinear' and op.inputs[0].dtype is TensorProto.FLOAT16]  # noqa: E501

    folded = 0
    for dequant in candidates:
        logger.debug("FP16 Quant Folder: Folding FP16-Quant for op %s", dequant.name)
        # promote the FP16 input to FP32 in place
        src = dequant.inputs[0]
        src.dtype = TensorProto.FLOAT
        if src.isInitializer:
            src.data = src.data.astype('float32')

        # route the promoted tensor to every consumer of the Dequantize output
        dst = dequant.outputs[0]
        for consumer in dst.consumers:
            consumer.replaceInput(dst, src)
            src.addConsumer(consumer)
        src.removeConsumer(dequant)

        # drop the Dequantize operator; the now-unused tensors are
        # removed by the graph automatically
        ops.remove(dequant)
        folded += 1

    if folded > 0:
        logger.info("FP16 Quant Folder: %d FP16 Quant Pattern are folded!", folded)
17 | self.shape = [] 18 | self.dtype = None 19 | self.data = None 20 | 21 | # the defaults of quantization parameter 22 | self.scale = 1.0 23 | self.zero_point = 127 24 | 25 | self.layout = layout 26 | self.producers = [] 27 | self.consumers = [] 28 | 29 | # we only accept INT32 as quantized tensor type for bias 30 | self.is_bias = is_bias 31 | 32 | self.setInited() 33 | 34 | @property 35 | def isInitializer(self): 36 | return self.data is not None 37 | 38 | def addProducer(self, op): 39 | assert(len(self.producers) == 0) 40 | self.producers.append(op) 41 | assert(len(self.producers) == 1) 42 | 43 | def removeProducer(self, op): 44 | assert(len(self.producers) == 1) 45 | assert(self.producers[0] == op) 46 | self.producers.remove(op) 47 | 48 | def replaceProducer(self, original, new): 49 | assert(len(self.producers) == 1) 50 | assert(self.producers[0] == original) 51 | self.producers[0] = new 52 | 53 | def addConsumer(self, op): 54 | assert(op not in self.consumers) 55 | self.consumers.append(op) 56 | 57 | def removeConsumer(self, op): 58 | assert(op in self.consumers) 59 | self.consumers.remove(op) 60 | 61 | def replaceConsumer(self, original, new): 62 | assert(original in self.consumers) 63 | for i, op in enumerate(self.consumers): 64 | if op is original: 65 | self.consumers[i] = new 66 | return 67 | 68 | @property 69 | def quantized(self): 70 | is_quant_dtype = ((self.dtype == TensorProto.UINT8) or 71 | ((self.dtype == TensorProto.INT32) and self.is_bias)) 72 | if self.tflite is None: 73 | return is_quant_dtype 74 | else: 75 | has_quant = self.tflite.Quantization() is not None 76 | return is_quant_dtype and has_quant 77 | 78 | def dequantize(self): 79 | if not self.quantized: 80 | return 81 | logger.debug("Dequantizing %s", self.shorty) 82 | if self.isInitializer: 83 | int32 = self.data.astype('int32') 84 | shiftted = np.subtract(int32, self.zero_point) 85 | fp32 = np.multiply(shiftted.astype('float32'), self.scale) 86 | self.data = fp32 87 | self.dtype = 
TensorProto.FLOAT 88 | 89 | @property 90 | def isScalar(self): 91 | return (self.layout is None) and (len(self.shape) == 0) and (len(self.data) == 1) 92 | 93 | def asDtype(self, dtype: str): 94 | self.dtype = mapping.DTYPE_NAME2ONNX[dtype] 95 | if self.isInitializer: 96 | self.data = self.data.astype(dtype) 97 | 98 | def parse(self): 99 | if self.status.parsed: 100 | return 101 | tensor = self.tflite 102 | self.name = tensor.Name().decode('utf-8') 103 | logger.debug("Parsing %s...", self.name) 104 | self.shape = [int(i) for i in tensor.ShapeAsNumpy()] 105 | 106 | assert(tensor.Type() in mapping.DTYPE_TFLITE2ONNX) 107 | self.dtype = mapping.DTYPE_TFLITE2ONNX[tensor.Type()] 108 | 109 | if self.quantized: 110 | quant = tensor.Quantization() 111 | assert(quant.ScaleAsNumpy().size == 1), "Per-tensor support only currently" 112 | assert(quant.ZeroPointAsNumpy().size == 1), "Per-tensor support only currently" 113 | self.scale = float(quant.ScaleAsNumpy()[0]) 114 | self.zero_point = int(quant.ZeroPointAsNumpy()[0]) 115 | 116 | self.data = TensorFactory.getData(self.model, self.graph, self.index, 117 | mapping.DTYPE_ONNX2NAME[self.dtype]) 118 | 119 | self.setParsed() 120 | 121 | def transform(self): 122 | assert(self.status.parsed) 123 | assert(self.layout is not None) 124 | if self.isInitializer: 125 | data = self.data.reshape(self.shape) 126 | self.shape = self.layout.transform(self.shape) 127 | self.data = data.transpose(self.layout.perm) 128 | else: 129 | self.shape = self.layout.transform(self.shape) 130 | 131 | def validate(self): 132 | if self.isInitializer: 133 | assert(len(self.producers) == 0), "Initializer should not have producer" 134 | else: 135 | assert(len(self.producers) <= 1), "Tensor should have 1 producer or no" 136 | assert(len(self.name) > 0), "Tensor must have valid name" 137 | 138 | def convert(self): 139 | if self.status.converted: 140 | return 141 | logger.debug("Converting %s...", self.shorty) 142 | if self.isInitializer: 143 | if 
class TensorFactory:
    """Holds all tensors of a TFLite subgraph in a name -> Tensor map."""

    def __init__(self, model, graph):
        self.model = model
        self.graph = graph
        self.registery = dict()

    def get(self, index, layout=None, is_bias=False):
        """Get-or-create the Tensor wrapping the TFLite tensor at `index`."""
        tft = self.graph.Tensors(index)
        name = tft.Name().decode('utf-8')
        if name not in self.registery:
            t = Tensor(self.model, self.graph, index, layout, is_bias)
            self.registery[name] = t
        else:
            t = self.registery[name]
            # A later caller may know the layout the first one didn't.
            if t.layout is None:
                t.layout = layout
        return t

    def getWithRef(self, ref, name, forceUnique=False):
        """Create a copy of the ref tensor.

        This is used to create helper tensors for activations, layout
        handling, quantization and so on. Constant data and the
        producer/consumer links are intentionally not copied.
        """
        if name not in self.registery:
            t = Tensor(self.model, self.graph, -1)
            t.name = name
            t.dtype = ref.dtype
            t.layout = copy.deepcopy(ref.layout)
            t.shape = copy.deepcopy(ref.shape)
            t.scale = copy.deepcopy(ref.scale)
            t.zero_point = copy.deepcopy(ref.zero_point)
            self.registery[name] = t
        else:
            assert(not forceUnique)
            t = self.registery[name]
        return t

    def createScalar(self, dtype, value):
        """Get-or-create a scalar initializer named after its dtype/value."""
        name = 'TFLITE2ONNX_Scalar_' + dtype + '_' + str(value)
        return self._createScalarCore(name, dtype, value)

    def createVector(self, ndarray):
        """Get-or-create a 1-to-N element initializer keyed by its content."""
        dtype = str(ndarray.dtype)
        # FIX: plain str(ndarray) truncates arrays above numpy's print
        # threshold with '...', so two *different* large vectors could
        # collide on the same registry key and silently share one tensor.
        # threshold=size+1 prints every element; the output is identical
        # to str() for small arrays, keeping existing names stable.
        array2key = np.array2string(ndarray, threshold=ndarray.size + 1).replace(' ', '_')
        name = 'TFLITE2ONNX_Vector_' + dtype + '_' + array2key
        if name not in self.registery:
            t = Tensor(self.model, self.graph, -1, None)
            t.name = name
            t.dtype = mapping.DTYPE_NAME2ONNX[dtype]
            t.data = ndarray.copy()
            t.shape = t.data.shape
            t.setParsed()
            self.registery[name] = t
        return self.registery[name]

    def createEmptyTensor(self):
        # Used for optional inputs that we need it to be empty.
        logger.warning("Empty tensor used, please double confirm your code path!")
        name = 'TFLITE2ONNX_EmptyTensor'
        if name not in self.registery:
            t = Tensor(self.model, self.graph, -1, None)
            t.name = name
            t.dtype = mapping.DTYPE_NAME2ONNX['float32']
            t.data = []
            t.shape = [0]
            t.setParsed()
            self.registery[name] = t
        return self.registery[name]

    def _createScalarCore(self, name, dtype, value):
        """Shared get-or-create for scalar helper tensors."""
        if name not in self.registery:
            t = Tensor(self.model, self.graph, -1, None)
            t.name = name
            t.dtype = mapping.DTYPE_NAME2ONNX[dtype]
            t.data = [value]  # cannot use NDArray for cases such as min/max of ReLU6
            t.setParsed()
            self.registery[name] = t
        return self.registery[name]

    def createQuantScale(self, tensor):
        """Create the FP32 scale scalar of a quantized tensor."""
        value = tensor.scale
        # FIX: the original asserted len(value) == 1 for sequences but then
        # stored the sequence itself, yielding nested data [[x]] and a
        # bracketed name. Unwrap to a true scalar; the common float path
        # (parse() always stores float) is unchanged.
        if not isinstance(value, float):
            assert(len(value) == 1), "Per-tensor support only currently"
            value = value[0]
        dtype = 'float32'
        name = 'TFLITE2ONNX_Scalar_' + dtype + '_' + str(value)
        return self._createScalarCore(name, dtype, value)

    def createQuantZeroPoint(self, tensor):
        """Create the UINT8 zero-point scalar of a quantized tensor."""
        value = tensor.zero_point
        # FIX: unwrap a length-1 sequence for the same reason as in
        # createQuantScale; parse() always stores a plain int here.
        if not isinstance(value, int):
            assert(len(value) == 1), "Per-tensor support only currently"
            value = value[0]
        assert(value >= 0 and value <= 255)
        dtype = 'uint8'
        name = 'TFLITE2ONNX_Scalar_' + dtype + '_' + str(value)
        return self._createScalarCore(name, dtype, value)

    @staticmethod
    def getData(model, graph, index, dtype):
        """Read the constant buffer of TFLite tensor `index` as a numpy array.

        Returns None when the tensor has no constant data (an activation).
        """
        if (dtype not in ['int32', 'float32', 'uint8']):
            logger.warning("Data type {} not supported/tested yet, "
                           "the generated model may contain error".format(dtype))
        assert(index < graph.TensorsLength())
        t = graph.Tensors(index)
        bi = t.Buffer()
        shape = t.ShapeAsNumpy()
        assert(bi < model.BuffersLength())
        raw = model.Buffers(bi).DataAsNumpy()
        # Flatbuffers returns the int 0 (not an array) for an empty buffer.
        if isinstance(raw, int) and raw == 0:
            return None
        data = np.frombuffer(raw, dtype=dtype)
        if len(shape) > 0:
            data = data.reshape(shape)
        # Copy: np.frombuffer returns a read-only view over the model blob.
        return data.copy()
import tflite
from tflite2onnx.op.common import OpFactory


def enableDebugLog():
    """Switch the root logger to DEBUG with a detailed per-line format."""
    import logging
    fmt = '%(asctime)s %(levelname).1s [%(name)s][%(filename)s:%(lineno)d] %(message)s'
    logging.basicConfig(format=fmt, level=logging.DEBUG)


def getSupportedOperators():
    """Return the sorted names of the TFLite operators we can convert."""
    return [tflite.BUILTIN_OPCODE2NAME[opc] for opc in sorted(OpFactory.registry.keys())]