├── .github
│   ├── FUNDING.yml
│   └── workflows
│       ├── ci.yml
│       ├── format.yml
│       ├── nightly-test.yml
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── VERSION
├── bin
│   └── onnxslim
├── docs
│   ├── _static
│   │   └── style.css
│   ├── _templates
│   │   └── layout.html
│   ├── conf.py
│   ├── index.rst
│   ├── main
│   │   └── toc.rst
│   └── requirements.txt
├── examples
│   ├── common_subexpression_elimination
│   │   ├── README.md
│   │   └── cse_demo.py
│   ├── input_shape_modification
│   │   └── README.md
│   ├── model_inspect
│   │   └── README.md
│   └── output_modification
│       └── README.md
├── format.sh
├── images
│   ├── after_cse.png
│   ├── before_cse.png
│   ├── cse.png
│   ├── input_shape_modification.jpg
│   ├── model_inspect.jpg
│   ├── onnxslim.gif
│   └── output_modification.jpg
├── onnxslim
│   ├── __init__.py
│   ├── __main__.py
│   ├── argparser.py
│   ├── cli
│   │   ├── __init__.py
│   │   └── _main.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── optimization
│   │   │   ├── __init__.py
│   │   │   ├── dead_node_elimination.py
│   │   │   ├── subexpression_elimination.py
│   │   │   └── weight_tying.py
│   │   └── pattern
│   │       ├── __init__.py
│   │       ├── elimination
│   │       │   ├── __init__.py
│   │       │   ├── concat.py
│   │       │   ├── reshape.py
│   │       │   ├── reshape_as.py
│   │       │   ├── slice.py
│   │       │   └── unsqueeze.py
│   │       ├── fusion
│   │       │   ├── __init__.py
│   │       │   ├── convadd.py
│   │       │   ├── convbn.py
│   │       │   ├── gelu.py
│   │       │   ├── gemm.py
│   │       │   ├── padconv.py
│   │       │   └── reduce.py
│   │       └── registry.py
│   ├── misc
│   │   ├── __init__.py
│   │   ├── font.py
│   │   └── tabulate.py
│   ├── third_party
│   │   ├── __init__.py
│   │   ├── onnx_graphsurgeon
│   │   │   ├── __init__.py
│   │   │   ├── exporters
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_exporter.py
│   │   │   │   └── onnx_exporter.py
│   │   │   ├── graph_pattern
│   │   │   │   ├── __init__.py
│   │   │   │   └── graph_pattern.py
│   │   │   ├── importers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_importer.py
│   │   │   │   └── onnx_importer.py
│   │   │   ├── ir
│   │   │   │   ├── __init__.py
│   │   │   │   ├── function.py
│   │   │   │   ├── graph.py
│   │   │   │   ├── node.py
│   │   │   │   └── tensor.py
│   │   │   ├── logger
│   │   │   │   ├── __init__.py
│   │   │   │   └── logger.py
│   │   │   └── util
│   │   │       ├── __init__.py
│   │   │       ├── exception.py
│   │   │       └── misc.py
│   │   └── symbolic_shape_infer.py
│   └── utils.py
├── setup.py
└── tests
    ├── test_benchmark.py
    ├── test_folder.py
    ├── test_modelzoo.py
    ├── test_onnx_nets.py
    ├── test_onnxslim.py
    ├── test_pattern_generator.py
    ├── test_pattern_matcher.py
    ├── test_yolo.py
    └── utils.py

/.github/FUNDING.yml:
--------------------------------------------------------------------------------
 1 | # These are supported funding model platforms
 2 | 
 3 | github: inisis
 4 | patreon: # Replace with a single Patreon username
 5 | open_collective: # Replace with a single Open Collective username
 6 | ko_fi: # Replace with a single Ko-fi username
 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
 9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
12 | polar: # Replace with a single Polar username
13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
14 | thanks_dev: # Replace with a single thanks.dev username
15 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
16 | 
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
 1 | name: CI
 2 | 
 3 | on:
 4 |   push:
 5 |     branches: ["main"]
 6 |   pull_request:
 7 |     branches: ["main"]
 8 | 
 9 | jobs:
10 |
test: 11 | runs-on: self-hosted 12 | strategy: 13 | matrix: 14 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - uses: actions/setup-python@v4 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | 23 | - name: install dependency 24 | run: | 25 | python -m pip install --upgrade pip wheel setuptools 26 | pip install . 27 | pip install pytest onnxruntime 28 | pip install pytest pytest-xdist onnxruntime timm torchvision --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu 29 | pip install coverage 30 | 31 | - name: yolo test 32 | if: matrix.python-version <= '3.11' 33 | run: | 34 | pip install ultralytics 35 | coverage run -m pytest tests/test_yolo.py -sv 36 | 37 | - name: onnxslim api and binary test 38 | run: | 39 | pip install onnxconverter_common 40 | coverage run -m pytest tests/test_onnxslim.py 41 | 42 | - name: model zoo test 43 | run: | 44 | coverage run -m pytest tests/test_modelzoo.py 45 | 46 | - name: pattern matcher test 47 | run: | 48 | coverage run -m pytest tests/test_pattern_matcher.py 49 | 50 | - name: pattern generator test 51 | run: | 52 | coverage run -m pytest tests/test_pattern_generator.py 53 | 54 | - name: Merge Coverage Reports 55 | run: | 56 | coverage xml -o coverage-ci.xml 57 | 58 | - name: Upload coverage reports to Codecov 59 | uses: codecov/codecov-action@v5 60 | with: 61 | token: ${{ secrets.CODECOV_TOKEN }} 62 | -------------------------------------------------------------------------------- /.github/workflows/format.yml: -------------------------------------------------------------------------------- 1 | # Ultralytics 🚀 - AGPL-3.0 license 2 | # Ultralytics Actions https://github.com/ultralytics/actions 3 | # This workflow automatically formats code and documentation in PRs to official Ultralytics standards 4 | 5 | name: Ultralytics Actions 6 | 7 | on: 8 | push: 9 | branches: [main] 10 | pull_request_target: 11 | branches: [main] 12 | types: [opened, closed, synchronize] 13 | 14 | jobs: 15 | format: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - name: Run Ultralytics Formatting 19 | uses: ultralytics/actions@main 20 | with: 21 | token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, do not modify 22 | python: true # format Python code and docstrings 23 | markdown: true # format Markdown 24 | prettier: true # format YAML 25 | spelling: true # check spelling 26 | links: false # check broken links 27 | -------------------------------------------------------------------------------- /.github/workflows/nightly-test.yml: -------------------------------------------------------------------------------- 1 | name: nightly-test 2 | 3 | on: 4 | schedule: 5 | - cron: "0 18 * * *" # Runs at 6:00 PM UTC every day, which is 2:00 AM Beijing Time the next day 6 | 7 | jobs: 8 | build: 9 | runs-on: self-hosted 10 | 11 | steps: 12 | - uses: actions/checkout@v3 13 | 14 | - uses: actions/setup-python@v4 15 | with: 16 | python-version: "3.10" 17 | 18 | - name: install dependency 19 | run: | 20 | python -m pip install --upgrade pip wheel setuptools 21 | pip install . 22 | pip install pytest onnxruntime 23 | 24 | - name: benchmark test 25 | run: | 26 | python tests/test_benchmark.py 27 | 28 | - name: model test 29 | run: | 30 | pip install . 
31 | pip install pytest pytest-xdist onnxruntime timm torchvision --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu 32 | python tests/test_onnx_nets.py 33 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v3 24 | - name: Set up Python 25 | uses: actions/setup-python@v3 26 | with: 27 | python-version: "3.x" 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip wheel setuptools 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | tmp/ 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | # atom remote-sync package 92 | .remote-sync.json 93 | 94 | # weights 95 | weights/ 96 | 97 | #DS_Store 98 | .DS_Store 99 | 100 | # dev stuff 101 | eval/ 102 | eval.ipynb 103 | dev.ipynb 104 | .vscode/ 105 | 106 | # not ready 107 | videos/ 108 | templates/ 109 | data/ssd_dataloader.py 110 | data/datasets/ 111 | doc/visualize.py 112 | read_results.py 113 | ssd300_120000/ 114 | demos/live 115 | webdemo.py 116 | test_data_aug.py 117 | 118 | # attributes 119 | 120 | # pycharm 121 | .idea/ 122 | 123 | # temp checkout soln 124 | data/datasets/ 125 | data/ssd_dataloader.py 126 | tests/model 127 | 128 | # pylint 129 | .pylintrc 130 | 131 | # onnx 132 | *.onnx 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 inisis 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include VERSION 2 | recursive-exclude tests * -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OnnxSlim 2 | 3 |


20 | 
21 | OnnxSlim can help you slim your ONNX model, with fewer operators but the same accuracy and better inference speed.
22 | 
23 | - 🚀 2025/05/17: OnnxSlim is merged into [optimum](https://github.com/huggingface/optimum) 🤗🤗🤗
24 | - 🚀 2025/04/30: Ranked 1st in the [AICAS 2025 LLM inference optimization challenge](https://tianchi.aliyun.com/competition/entrance/532289/customize588)
25 | - 🚀 2025/01/28: Achieved 1M downloads
26 | - 🚀 2024/06/23: OnnxSlim is merged into [transformers.js](https://github.com/huggingface/transformers.js) 🤗🤗🤗
27 | - 🚀 2024/06/02: OnnxSlim is merged into [ultralytics](https://github.com/ultralytics/ultralytics) ❤️❤️❤️
28 | - 🚀 2024/04/30: Ranked 1st in the [AICAS 2024 LLM inference optimization challenge](https://tianchi.aliyun.com/competition/entrance/532170/customize440) held by Arm and T-Head
29 | - 🚀 2024/01/25: OnnxSlim is merged into [mnn-llm](https://github.com/wangzhaode/mnn-llm); performance increased by 5%
30 | 
31 | # Installation
32 | 
33 | ## Using Prebuilt
34 | 
35 | ```bash
36 | pip install onnxslim
37 | ```
38 | 
39 | ## Install From Source
40 | 
41 | ```bash
42 | pip install git+https://github.com/inisis/OnnxSlim@main
43 | ```
44 | 
45 | ## Install From Local
46 | 
47 | ```bash
48 | git clone https://github.com/inisis/OnnxSlim && cd OnnxSlim/
49 | pip install .
50 | ```
51 | 
52 | # How to use
53 | 
54 | ```bash
55 | onnxslim your_onnx_model slimmed_onnx_model
56 | ```
57 | 
58 |
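The same pipeline is available from Python. A minimal sketch (the `slim` helper is re-exported from `onnxslim/__init__.py`; file names here are placeholders):

```python
import onnx
from onnxslim import slim

# One-shot: load, optimize, and save
slim("your_onnx_model.onnx", "slimmed_onnx_model.onnx")

# Or stay in memory: with no output path, slim() returns
# the optimized onnx.ModelProto instead of writing a file
model = onnx.load("your_onnx_model.onnx")
slimmed = slim(model)
onnx.save(slimmed, "slimmed_onnx_model.onnx")
```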
59 | 60 | For more usage, see onnxslim -h or refer to our [examples](./examples) 61 | 62 | # Projects using OnnxSlim 63 | 64 | - [Mozilla/smart_autofill](https://github.com/mozilla/smart_autofill) 65 | - [alibaba/MNN](https://github.com/alibaba/MNN) 66 | - [PaddlePaddle/PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) 67 | - [huggingface/transformers.js](https://github.com/huggingface/transformers.js) 68 | - [huggingface/optimum](https://github.com/huggingface/optimum) 69 | - [THU-MIG/yolov10](https://github.com/THU-MIG/yolov10) 70 | - [ultralytics/ultralytics](https://github.com/ultralytics/ultralytics) 71 | - [ModelScope/FunASR](https://github.com/modelscope/FunASR) 72 | - [alibaba/MNN-LLM](https://github.com/wangzhaode/mnn-llm) 73 | - [deepghs/imgutils](https://github.com/deepghs/imgutils) 74 | - [sunsmarterjie/yolov12](https://github.com/sunsmarterjie/yolov12) 75 | - [nndeploy/nndeploy](https://github.com/nndeploy/nndeploy) 76 | 77 | # References 78 | 79 | > - [onnx-graphsurgeon](https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon) 80 | > - [Polygraphy](https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy/polygraphy) 81 | > - [onnx-simplifier](https://github.com/daquexian/onnx-simplifier) 82 | > - [tabulate](https://github.com/astanin/python-tabulate) 83 | > - [onnxruntime](https://github.com/microsoft/onnxruntime) 84 | 85 | # Contact 86 | 87 | Discord: https://discord.gg/nRw2Fd3VUS QQ Group: `873569894` 88 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.1.56 2 | -------------------------------------------------------------------------------- /bin/onnxslim: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | 6 | G_SCRIPT_FILE = os.path.realpath(__file__) 7 | G_ROOT_DIR = os.path.join(os.path.dirname(G_SCRIPT_FILE), os.pardir) 8 | sys.path.insert(0, G_ROOT_DIR) 9 | 10 | 11 | from onnxslim import cli 12 | 13 | if __name__ == "__main__": 14 | sys.exit(cli.main()) -------------------------------------------------------------------------------- /docs/_static/style.css: -------------------------------------------------------------------------------- 1 | .wy-nav-content { 2 | max-width: 1100px !important; 3 | } 4 | -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends '!layout.html' %} {% block extrahead %} 2 | 3 | {% endblock %} {% block footer %} 4 | 7 | {% endblock %} 8 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | ROOT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), os.path.pardir) 5 | sys.path.insert(0, ROOT_DIR) 6 | 7 | 8 | project = "OnnxSlim" 9 | copyright = "2022, WeLoveAI" 10 | author = "inisis" 11 | 12 | import onnxslim 13 | 14 | version = onnxslim.__version__ 15 | 16 | extensions = [ 17 | "sphinx.ext.autodoc", 18 | "sphinx.ext.intersphinx", 19 | "sphinx.ext.autosummary", 20 | "sphinx.ext.napoleon", 21 | "sphinx.ext.mathjax", 22 | ] 23 | 24 | intersphinx_mapping = { 25 | "rtd": ("https://docs.readthedocs.io/en/stable/", None), 26 | "python": ("https://docs.python.org/3/", None), 27 | "sphinx": 
("https://www.sphinx-doc.org/en/master/", None), 28 | } 29 | intersphinx_disabled_domains = ["std"] 30 | 31 | templates_path = ["_templates"] 32 | epub_show_urls = "footnote" 33 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 34 | html_theme = "sphinx_rtd_theme" 35 | napoleon_preprocess_types = True 36 | 37 | html_static_path = ["_static"] 38 | 39 | 40 | def setup(app): 41 | """Configure the Sphinx app by adding a custom CSS file ('style.css').""" 42 | app.add_css_file("style.css") 43 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | OnnxSlim 2 | =================================== 3 | 4 | Onnxslim is a toolkit to help optimize large onnx models 5 | 6 | 7 | .. toctree:: 8 | :hidden: 9 | 10 | self 11 | 12 | .. toctree:: 13 | :caption: API Reference 14 | :maxdepth: 2 15 | 16 | main/toc 17 | -------------------------------------------------------------------------------- /docs/main/toc.rst: -------------------------------------------------------------------------------- 1 | main 2 | ============ 3 | 4 | .. autoclass:: onnxslim.slim 5 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx_rtd_theme 3 | -------------------------------------------------------------------------------- /examples/common_subexpression_elimination/README.md: -------------------------------------------------------------------------------- 1 | # Common SubExpression Elimination 2 | 3 | ## Introduction 4 | 5 | Common Subexpression Elimination (CSE) is a powerful optimization technique commonly employed in compilers to improve the efficiency of code execution. It targets redundant computations within a program by identifying and removing duplicate expressions, thus reducing both computational overhead and memory usage. By eliminating redundant computations, CSE enhances the overall performance of slimmed onnx model. 6 | 7 | ## How CSE Works 8 | 9 | In many programs, certain expressions are computed multiple times within a given scope, even though their results remain constant across these computations. Common subexpressions refer to these redundant expressions. CSE identifies such common subexpressions and replaces subsequent occurrences with references to the original computation result. This process effectively reduces the number of computations required during program execution. 10 | 11 | For example, consider the following code snippet: 12 | 13 | ``` 14 | int a = b + c; 15 | int x = b + c; 16 | ``` 17 | 18 | In this code, b + c is a common subexpression computed twice. With CSE, the redundant computation of b + c would be eliminated, and both occurrences of x would directly reference the computation result of a. 19 | 20 | ## Running the example 21 | 22 | ```bash 23 | python cse_demo.py # to generate onnx model for demo 24 | ``` 25 | 26 | ![../../image/before_cse.png](../../images/before_cse.png) 27 | 28 | There are two identical blocks that are doing the same things. 
29 | 
30 | ```bash
31 | onnxslim ln_cse.onnx slim.onnx
32 | ```
33 | 
34 | After onnxslim, the output will look like this:
35 | 
36 | ![../../images/after_cse.png](../../images/after_cse.png)
37 | 
38 | and the summary is as follows:
39 | 
40 | ![../../images/cse.png](../../images/cse.png)
41 | 
--------------------------------------------------------------------------------
/examples/common_subexpression_elimination/cse_demo.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | 
 5 | 
 6 | class Model(torch.nn.Module):
 7 |     def __init__(self):
 8 |         """Initializes the Model class with a single LayerNorm layer of embedding dimension 10."""
 9 |         super().__init__()
10 |         embedding_dim = 10
11 |         self.layer_norm = nn.LayerNorm(embedding_dim)
12 | 
13 |     def forward(self, x):
14 |         """Applies LayerNorm to the input tensor and adds it to an independently computed LayerNorm of the same
15 |         tensor.
16 |         """
17 |         return self.layer_norm(x) + F.layer_norm(x, [10])
18 | 
19 | 
20 | layer_norm = Model()
21 | 
22 | batch, sentence_length, embedding_dim = 20, 5, 10
23 | embedding = torch.randn(batch, sentence_length, embedding_dim)
24 | torch.onnx.export(layer_norm, embedding, "ln_cse.onnx", opset_version=13)
25 | 
--------------------------------------------------------------------------------
/examples/input_shape_modification/README.md:
--------------------------------------------------------------------------------
 1 | # Input Shape Modification
 2 | 
 3 | ## Introduction
 4 | 
 5 | OnnxSlim can rewrite the input shapes of an ONNX model.
 6 | 
 7 | This guide shows how to adjust input tensor dimensions so that the model matches the shapes your deployment expects, which also helps downstream shape inference and optimization.
 8 | 
 9 | ## Running the example
10 | 
11 | Change the model's input shape by running:
12 | 
13 | ```bash
14 | onnxslim UNetModel-fp16.onnx slim.onnx --input_shapes cc:1,1,768
15 | ```
16 | 
17 | The slimmed model will look like this:
18 | 
19 | ![../../images/input_shape_modification.jpg](../../images/input_shape_modification.jpg)
20 | 
--------------------------------------------------------------------------------
/examples/model_inspect/README.md:
--------------------------------------------------------------------------------
 1 | # Model Inspect
 2 | 
 3 | ## Introduction
 4 | 
 5 | The --inspect argument gives a detailed view of an ONNX model, including its inputs and outputs, operator information, opset version, and more.
 6 | 
 7 | ## Running the example
 8 | 
 9 | Inspect a model by executing the following command:
10 | 
11 | ```bash
12 | onnxslim --inspect UNetModel-fp16.onnx
13 | ```
14 | 
15 | The output will look like this:
16 | 
17 | ![../../images/model_inspect.jpg](../../images/model_inspect.jpg)
18 | 
--------------------------------------------------------------------------------
/examples/output_modification/README.md:
--------------------------------------------------------------------------------
1 | # Output Modification
2 | 
3 | ## Introduction
4 | 
5 | OnnxSlim can modify the output specification of an ONNX model.
6 | 
7 | This lets you truncate a graph at intermediate tensors and expose them as model outputs, with optional control over their data types.
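Each output may also carry an explicit dtype using NAME:DTYPE syntax; the output_modification pass in onnxslim/core/__init__.py accepts fp16, fp32, int32, and bool. A hypothetical dtype-annotated variant of the example command in the next section:

```bash
onnxslim yolov5m.onnx slim.onnx --outputs 591:fp32 739:fp16 443:fp32
```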
8 | 9 | ## Running the example 10 | 11 | Change the output of one model by running: 12 | 13 | ```bash 14 | onnxslim yolov5m.onnx slim.onnx --outputs 591 739 443 15 | ``` 16 | 17 | The slimmed model will look like this: 18 | 19 | ![../../image/output_modification.jpg](../../images/output_modification.jpg) 20 | -------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -ex 3 | ufmt format -- . 4 | autoflake --in-place --recursive . 5 | -------------------------------------------------------------------------------- /images/after_cse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/after_cse.png -------------------------------------------------------------------------------- /images/before_cse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/before_cse.png -------------------------------------------------------------------------------- /images/cse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/cse.png -------------------------------------------------------------------------------- /images/input_shape_modification.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/input_shape_modification.jpg -------------------------------------------------------------------------------- /images/model_inspect.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/model_inspect.jpg -------------------------------------------------------------------------------- /images/onnxslim.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/onnxslim.gif -------------------------------------------------------------------------------- /images/output_modification.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/output_modification.jpg -------------------------------------------------------------------------------- /onnxslim/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | from onnxslim.cli import slim 5 | from onnxslim.core.optimization import OptimizationSettings 6 | from onnxslim.core.pattern.registry import ( 7 | DEFAULT_FUSION_PATTERNS, 8 | register_fusion_pattern, 9 | ) 10 | from onnxslim.version import __version__ 11 | 12 | if os.path.dirname(os.path.realpath(__file__)) == os.path.join(os.path.realpath(os.getcwd()), "onnxslim"): 13 | message = ( 14 | "You are importing onnxslim within its own root folder ({}). " 15 | "This is not expected to work and may give errors. Please exit the " 16 | "onnxslim project source and relaunch your python interpreter." 
17 | ) 18 | warnings.warn(message.format(os.getcwd())) 19 | -------------------------------------------------------------------------------- /onnxslim/__main__.py: -------------------------------------------------------------------------------- 1 | from onnxslim.cli._main import main 2 | 3 | if __name__ == "__main__": 4 | main() 5 | -------------------------------------------------------------------------------- /onnxslim/argparser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 4 | from dataclasses import dataclass, field 5 | from typing import List, Optional, Type, Union, get_args, get_origin 6 | 7 | import onnxslim 8 | 9 | 10 | def _get_inner_type(arg_type): 11 | if get_origin(arg_type) is Union: 12 | return next((t for t in get_args(arg_type) if t is not type(None)), str) 13 | return arg_type 14 | 15 | 16 | @dataclass 17 | class ModelArguments: 18 | """ 19 | Args: 20 | model (Union[str, onnx.ModelProto]): The ONNX model to be slimmed. It can be either a file path or an `onnx.ModelProto` object. 21 | 22 | output_model (str, optional): File path to save the slimmed model. If None, the model will not be saved. 23 | """ 24 | 25 | input_model: str = field(metadata={"help": "input onnx model"}) 26 | output_model: Optional[str] = field(default=None, metadata={"help": "output onnx model"}) 27 | 28 | 29 | @dataclass 30 | class OptimizationArguments: 31 | """ 32 | Args: 33 | no_shape_infer (bool, optional): Flag indicating whether to perform shape inference. Default is False. 34 | 35 | no_constant_folding (bool, optional): Flag indicating whether to perform constant folding. Default is False. 36 | 37 | skip_fusion_patterns (str, optional): String representing fusion patterns to skip. Default is None. 38 | """ 39 | 40 | no_shape_infer: bool = field(default=False, metadata={"help": "whether to disable shape_infer, default false."}) 41 | skip_optimizations: Optional[List[str]] = field( 42 | default=None, 43 | metadata={ 44 | "help": "whether to skip some optimizations", 45 | "choices": list(onnxslim.OptimizationSettings.keys()), 46 | }, 47 | ) 48 | skip_fusion_patterns: Optional[List[str]] = field( 49 | default=None, 50 | metadata={ 51 | "help": "whether to skip the fusion of some patterns", 52 | "choices": list(onnxslim.DEFAULT_FUSION_PATTERNS.keys()), 53 | }, 54 | ) 55 | size_threshold: int = field( 56 | default=None, 57 | metadata={ 58 | "help": "size threshold in bytes, size larger than this value will not be folded, default None, which means fold all constants", 59 | }, 60 | ) 61 | 62 | 63 | @dataclass 64 | class ModificationArguments: 65 | """ 66 | Args: 67 | input_shapes (str, optional): String representing the input shapes. Default is None. 68 | 69 | outputs (str, optional): String representing the outputs. Default is None. 70 | 71 | dtype (str, optional): Data type. Default is None. 72 | 73 | save_as_external_data (bool, optional): Flag indicating whether to split onnx as model and weight. Default is False. 74 | """ 75 | 76 | input_shapes: Optional[List[str]] = field( 77 | default=None, 78 | metadata={ 79 | "help": "input shape of the model, INPUT_NAME:SHAPE, e.g. x:1,3,224,224 or x1:1,3,224,224 x2:1,3,224,224" 80 | }, 81 | ) 82 | inputs: Optional[List[str]] = field( 83 | default=None, 84 | metadata={ 85 | "help": "input of the model, INPUT_NAME:DTYPE, e.g. y:fp32 or y1:fp32 y2:fp32. 
If dtype is not specified, the dtype of the input will be the same as the original model if it has dtype, otherwise it will be fp32, available dtype: fp16, fp32, int32" 86 | }, 87 | ) 88 | outputs: Optional[List[str]] = field( 89 | default=None, 90 | metadata={ 91 | "help": "output of the model, OUTPUT_NAME:DTYPE, e.g. y:fp32 or y1:fp32 y2:fp32. If dtype is not specified, the dtype of the output will be the same as the original model if it has dtype, otherwise it will be fp32, available dtype: fp16, fp32, int32" 92 | }, 93 | ) 94 | dtype: Optional[str] = field( 95 | default=None, metadata={"help": "convert data format to fp16 or fp32.", "choices": ["fp16", "fp32"]} 96 | ) 97 | save_as_external_data: bool = field( 98 | default=False, metadata={"help": "split onnx as model and weight, default False."} 99 | ) 100 | 101 | 102 | @dataclass 103 | class CheckerArguments: 104 | """ 105 | Args: 106 | model_check (bool, optional): Flag indicating whether to perform model checking. Default is False. 107 | 108 | model_check_inputs (str, optional): The shape or tensor used for model check. Default is None. 109 | 110 | inspect (bool, optional): Flag indicating whether to inspect the model. Default is False. 111 | 112 | dump_to_disk (bool, optional): Flag indicating whether to dump the model detail to disk. Default is False. 113 | 114 | verbose (bool, optional): Flag indicating whether to print verbose logs. Default is False. 115 | """ 116 | 117 | model_check: bool = field(default=False, metadata={"help": "enable model check"}) 118 | model_check_inputs: Optional[List[str]] = field( 119 | default=None, 120 | metadata={ 121 | "help": "Works only when model_check is enabled, Input shape of the model or numpy data path, INPUT_NAME:SHAPE or INPUT_NAME:DATAPATH, e.g. x:1,3,224,224 or x1:1,3,224,224 x2:data.npy. Useful when input shapes are dynamic." 122 | }, 123 | ) 124 | inspect: bool = field(default=False, metadata={"help": "inspect model, default False."}) 125 | dump_to_disk: bool = field(default=False, metadata={"help": "dump model info to disk, default False."}) 126 | verbose: bool = field(default=False, metadata={"help": "verbose mode, default False."}) 127 | 128 | 129 | class OnnxSlimArgumentParser(ArgumentParser): 130 | def __init__(self, *argument_dataclasses: Type, **kwargs): 131 | if "formatter_class" not in kwargs: 132 | kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter 133 | super().__init__(**kwargs) 134 | self.argument_dataclasses = argument_dataclasses 135 | self.parser = argparse.ArgumentParser( 136 | description="OnnxSlim: A Toolkit to Help Optimizer Onnx Model", 137 | formatter_class=argparse.RawDescriptionHelpFormatter, 138 | ) 139 | self._add_arguments() 140 | 141 | def _add_arguments(self): 142 | for dataclass_type in self.argument_dataclasses: 143 | if dataclass_type is ModelArguments: 144 | continue 145 | for field_name, field_def in dataclass_type.__dataclass_fields__.items(): 146 | arg_type = _get_inner_type(field_def.type) 147 | default_value = field_def.default if field_def.default is not field_def.default_factory else None 148 | help_text = field_def.metadata.get("help", "") 149 | nargs = "+" if get_origin(arg_type) == list else None 150 | choices = field_def.metadata.get("choices", None) 151 | if choices and default_value is not None and default_value not in choices: 152 | raise ValueError( 153 | f"Invalid default value '{default_value}' for argument '{field_name}'. Must be one of {choices}." 
154 | ) 155 | arg_type = get_args(arg_type)[0] if get_args(arg_type) else arg_type 156 | if arg_type == bool: 157 | self.parser.add_argument( 158 | f"--{field_name.replace('_', '-')}", 159 | action="store_true", 160 | default=default_value, 161 | help=help_text, 162 | ) 163 | else: 164 | self.parser.add_argument( 165 | f"--{field_name.replace('_', '-')}", 166 | type=arg_type, 167 | default=default_value, 168 | nargs=nargs, 169 | choices=choices, 170 | help=help_text, 171 | ) 172 | 173 | # Add positional arguments separately for ModelArguments 174 | self.parser.add_argument("input_model", help="input onnx model") 175 | self.parser.add_argument("output_model", nargs="?", default=None, help="output onnx model") 176 | self.parser.add_argument("-v", "--version", action="version", version=onnxslim.__version__) 177 | 178 | def parse_args_into_dataclasses(self): 179 | # Pre-parse arguments to check for `--inspect` 180 | pre_parsed_args, _ = self.parser.parse_known_args() 181 | if pre_parsed_args.inspect: 182 | for action in self.parser._actions: 183 | if action.dest == "input_model": 184 | action.nargs = "+" 185 | break 186 | 187 | args = self.parser.parse_args() 188 | args_dict = vars(args) 189 | 190 | outputs = [] 191 | for dtype in self.argument_dataclasses: 192 | keys = {f.name for f in dataclasses.fields(dtype) if f.init} 193 | inputs = {k: v for k, v in args_dict.items() if k in keys} 194 | obj = dtype(**inputs) 195 | outputs.append(obj) 196 | 197 | return (*outputs,) 198 | -------------------------------------------------------------------------------- /onnxslim/cli/__init__.py: -------------------------------------------------------------------------------- 1 | from onnxslim.cli._main import main, slim 2 | -------------------------------------------------------------------------------- /onnxslim/cli/_main.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import onnx 4 | 5 | 6 | def slim(model: Union[str, onnx.ModelProto, List[Union[str, onnx.ModelProto]]], *args, **kwargs): 7 | import os 8 | import time 9 | from pathlib import Path 10 | 11 | from onnxslim.core import ( 12 | OptimizationSettings, 13 | convert_data_format, 14 | freeze, 15 | input_modification, 16 | input_shape_modification, 17 | optimize, 18 | output_modification, 19 | shape_infer, 20 | ) 21 | from onnxslim.utils import ( 22 | TensorInfo, 23 | check_onnx, 24 | check_point, 25 | check_result, 26 | dump_model_info_to_disk, 27 | init_logging, 28 | onnxruntime_inference, 29 | print_model_info_as_table, 30 | save, 31 | summarize_model, 32 | update_outputs_dims, 33 | ) 34 | 35 | output_model = args[0] if len(args) > 0 else kwargs.get("output_model", None) 36 | model_check = kwargs.get("model_check", False) 37 | input_shapes = kwargs.get("input_shapes", None) 38 | inputs = kwargs.get("inputs", None) 39 | outputs = kwargs.get("outputs", None) 40 | no_shape_infer = kwargs.get("no_shape_infer", False) 41 | skip_optimizations = kwargs.get("skip_optimizations", None) 42 | dtype = kwargs.get("dtype", None) 43 | skip_fusion_patterns = kwargs.get("skip_fusion_patterns", None) 44 | size_threshold = kwargs.get("size_threshold", None) 45 | size_threshold = int(size_threshold) if size_threshold else None 46 | kwargs.get("inspect", False) 47 | dump_to_disk = kwargs.get("dump_to_disk", False) 48 | save_as_external_data = kwargs.get("save_as_external_data", False) 49 | model_check_inputs = kwargs.get("model_check_inputs", None) 50 | verbose = kwargs.get("verbose", 
False) 51 | 52 | logger = init_logging(verbose) 53 | 54 | MAX_ITER = int(os.getenv("ONNXSLIM_MAX_ITER")) if os.getenv("ONNXSLIM_MAX_ITER") else 10 55 | 56 | start_time = time.time() 57 | 58 | def get_info(model, inspect=False): 59 | if isinstance(model, str): 60 | model_name = Path(model).name 61 | model = onnx.load(model) 62 | else: 63 | model_name = "OnnxModel" 64 | 65 | freeze(model) 66 | 67 | if not inspect: 68 | return model_name, model 69 | 70 | model_info = summarize_model(model, model_name) 71 | 72 | return model_info 73 | 74 | if isinstance(model, list): 75 | model_info_list = [get_info(m, inspect=True) for m in model] 76 | 77 | if dump_to_disk: 78 | [dump_model_info_to_disk(info) for info in model_info_list] 79 | 80 | print_model_info_as_table(model_info_list) 81 | 82 | return 83 | else: 84 | model_name, model = get_info(model) 85 | if output_model: 86 | original_info = summarize_model(model, model_name) 87 | 88 | if inputs: 89 | model = input_modification(model, inputs) 90 | 91 | if input_shapes: 92 | model = input_shape_modification(model, input_shapes) 93 | 94 | if outputs: 95 | model = output_modification(model, outputs) 96 | 97 | if model_check: 98 | input_data_dict, raw_onnx_output, model = check_onnx(model, model_check_inputs) 99 | 100 | output_info = {TensorInfo(o).name: TensorInfo(o).shape for o in model.graph.output} 101 | 102 | if not no_shape_infer: 103 | model = shape_infer(model) 104 | 105 | OptimizationSettings.reset(skip_optimizations) 106 | if OptimizationSettings.enabled(): 107 | graph_check_point = check_point(model) 108 | while MAX_ITER > 0: 109 | logger.debug(f"iter: {MAX_ITER}") 110 | model = optimize(model, skip_fusion_patterns, size_threshold) 111 | if not no_shape_infer: 112 | model = shape_infer(model) 113 | graph = check_point(model) 114 | if graph == graph_check_point: 115 | logger.debug(f"converged at iter: {MAX_ITER}") 116 | break 117 | else: 118 | graph_check_point = graph 119 | 120 | MAX_ITER -= 1 121 | 122 | if dtype: 123 | model = convert_data_format(model, dtype) 124 | 125 | model = update_outputs_dims(model, output_dims=output_info) 126 | 127 | if model_check: 128 | slimmed_onnx_output, model = onnxruntime_inference(model, input_data_dict) 129 | if not check_result(raw_onnx_output, slimmed_onnx_output): 130 | return None 131 | 132 | if not output_model: 133 | return model 134 | 135 | slimmed_info = summarize_model(model, output_model) 136 | save(model, output_model, model_check, save_as_external_data, slimmed_info) 137 | 138 | end_time = time.time() 139 | elapsed_time = end_time - start_time 140 | print_model_info_as_table( 141 | [original_info, slimmed_info], 142 | elapsed_time, 143 | ) 144 | 145 | 146 | def main(): 147 | """Entry point for the OnnxSlim toolkit, processes command-line arguments and passes them to the slim function.""" 148 | from onnxslim.argparser import ( 149 | CheckerArguments, 150 | ModelArguments, 151 | ModificationArguments, 152 | OnnxSlimArgumentParser, 153 | OptimizationArguments, 154 | ) 155 | 156 | argument_parser = OnnxSlimArgumentParser( 157 | ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments 158 | ) 159 | model_args, optimization_args, modification_args, checker_args = argument_parser.parse_args_into_dataclasses() 160 | 161 | if not checker_args.inspect and checker_args.dump_to_disk: 162 | argument_parser.error("dump_to_disk can only be used with --inspect") 163 | 164 | if not optimization_args.no_shape_infer: 165 | from onnxslim.utils import check_onnx_compatibility, 
is_onnxruntime_available 166 | 167 | if is_onnxruntime_available(): 168 | check_onnx_compatibility() 169 | 170 | slim( 171 | model_args.input_model, 172 | model_args.output_model, 173 | **optimization_args.__dict__, 174 | **modification_args.__dict__, 175 | **checker_args.__dict__, 176 | ) 177 | 178 | return 0 179 | -------------------------------------------------------------------------------- /onnxslim/core/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import tempfile 4 | 5 | import numpy as np 6 | import onnx 7 | from onnx import checker 8 | 9 | import onnxslim.third_party.onnx_graphsurgeon as gs 10 | from onnxslim.core.optimization import OptimizationSettings, optimize_model 11 | from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx 12 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant 13 | from onnxslim.third_party.symbolic_shape_infer import SymbolicShapeInference 14 | from onnxslim.utils import save 15 | 16 | logger = logging.getLogger("onnxslim") 17 | 18 | 19 | DEBUG = bool(os.getenv("ONNXSLIM_DEBUG")) 20 | AUTO_MERGE = True if os.getenv("ONNXSLIM_AUTO_MERGE") is None else bool(int(os.getenv("ONNXSLIM_AUTO_MERGE"))) 21 | 22 | 23 | def input_shape_modification(model: onnx.ModelProto, input_shapes: str) -> onnx.ModelProto: 24 | """Modifies input tensor shapes in the ONNX model according to the specified input_shapes string.""" 25 | if not input_shapes: 26 | return 27 | 28 | graph = gs.import_onnx(model) 29 | input_names = [input.name for input in graph.inputs] 30 | tensors = graph.tensors() 31 | 32 | for input_shape in input_shapes: 33 | key, values = input_shape.rsplit(":", 1) 34 | values_list = [int(value) for value in values.split(",")] 35 | if key not in input_names: 36 | raise Exception(f"Input name {key} not found in model, available keys: {' '.join(input_names)}") 37 | tensors[key].shape = values_list 38 | 39 | for tensor in tensors.values(): 40 | if tensor.name not in input_names: 41 | if isinstance(tensor, Constant): 42 | continue 43 | tensor.shape = None 44 | 45 | model = gs.export_onnx(graph) 46 | 47 | return model 48 | 49 | 50 | def output_modification(model: onnx.ModelProto, outputs: str) -> onnx.ModelProto: 51 | """Modifies the output layers of the ONNX model based on specified output names and data types.""" 52 | graph = gs.import_onnx(model) 53 | graph.outputs.clear() 54 | tensors = graph.tensors() 55 | for output in outputs: 56 | values = output.rsplit(":", 1) 57 | if len(values) == 1: 58 | key = values[0] 59 | if key not in tensors.keys(): 60 | raise Exception(f"Output name {key} not found in model, available keys: {' '.join(tensors.keys())}") 61 | dtype = tensors[key].dtype 62 | if dtype is None: 63 | dtype = np.float32 64 | logger.warning(f"Output layer {key} has no dtype, set to default {dtype}") 65 | else: 66 | key, dtype = values 67 | if dtype == "fp16": 68 | dtype = np.float16 69 | elif dtype == "fp32": 70 | dtype = np.float32 71 | elif dtype == "int32": 72 | dtype = np.int32 73 | elif dtype == "bool": 74 | dtype = bool 75 | else: 76 | raise Exception(f"Output layer {key} assigned unsupported dtype {dtype}") 77 | 78 | graph.outputs.append(tensors[key].to_variable(dtype=dtype, shape=tensors[key].shape)) 79 | 80 | graph.cleanup(remove_unused_graph_inputs=True).toposort() 81 | model = gs.export_onnx(graph) 82 | 83 | return model 84 | 85 | 86 | def input_modification(model: onnx.ModelProto, inputs: str) -> onnx.ModelProto: 
87 | """Modifies the output layers of the ONNX model based on specified output names and data types.""" 88 | graph = gs.import_onnx(model) 89 | graph.inputs.clear() 90 | tensors = graph.tensors() 91 | for input in inputs: 92 | values = input.rsplit(":", 1) 93 | if len(values) == 1: 94 | key = values[0] 95 | if key not in tensors.keys(): 96 | raise Exception(f"Input name {key} not found in model, available keys: {' '.join(tensors.keys())}") 97 | dtype = tensors[key].dtype 98 | if dtype is None: 99 | dtype = np.float32 100 | logger.warning(f"Input layer {key} has no dtype, set to default {dtype}") 101 | else: 102 | key, dtype = values 103 | if dtype == "fp16": 104 | dtype = np.float16 105 | elif dtype == "fp32": 106 | dtype = np.float32 107 | elif dtype == "int32": 108 | dtype = np.int32 109 | elif dtype == "bool": 110 | dtype = bool 111 | else: 112 | raise Exception(f"Input layer {key} assigned unsupported dtype {dtype}") 113 | 114 | graph.inputs.append(tensors[key].to_variable(dtype=dtype, shape=tensors[key].shape)) 115 | 116 | graph.cleanup(remove_unused_graph_inputs=True).toposort() 117 | model = gs.export_onnx(graph) 118 | 119 | return model 120 | 121 | 122 | def shape_infer(model: onnx.ModelProto): 123 | """Infer tensor shapes in an ONNX model using symbolic and static shape inference techniques.""" 124 | logger.debug("Start shape inference.") 125 | try: 126 | logger.debug("try onnxruntime shape infer.") 127 | model = SymbolicShapeInference.infer_shapes(model, auto_merge=AUTO_MERGE) 128 | except Exception as err: 129 | logger.debug(f"onnxruntime shape infer failed, try onnx shape infer. {err}") 130 | if model.ByteSize() >= checker.MAXIMUM_PROTOBUF: 131 | tmp_dir = tempfile.TemporaryDirectory() 132 | tmp_path = os.path.join(tmp_dir.name, "tmp.onnx") 133 | tmp_infer_path = os.path.join(tmp_dir.name, "tmp_infer.onnx") 134 | save(model, tmp_path) 135 | onnx.shape_inference.infer_shapes_path(tmp_path, tmp_infer_path) 136 | model = onnx.load(tmp_infer_path) 137 | else: 138 | model = onnx.shape_inference.infer_shapes(model) 139 | if DEBUG: 140 | onnx.save(model, "debug_shape_infer.onnx") 141 | logger.debug("Finish shape inference.") 142 | return model 143 | 144 | 145 | def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None, size_threshold: int = None): 146 | """Optimize the given ONNX model with options to skip specific fusion patterns and return the optimized model.""" 147 | logger.debug("Start converting model to gs.") 148 | graph = gs.import_onnx(model).toposort() 149 | logger.debug("Finish converting model to gs.") 150 | if OptimizationSettings.constant_folding: 151 | logger.debug("Start constant folding.") 152 | graph.fold_constants(size_threshold=size_threshold).cleanup().toposort() 153 | logger.debug("Finish constant folding.") 154 | logger.debug("Start optimize model.") 155 | model = optimize_model(graph, skip_fusion_patterns) 156 | logger.debug("Finish optimize model.") 157 | if DEBUG: 158 | onnx.save(model, "debug_slim.onnx") 159 | 160 | return model 161 | 162 | 163 | def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto: 164 | """Convert ONNX model data format to specified dtype, supporting 'fp16' and 'fp32'.""" 165 | if dtype == "fp16": 166 | from onnxconverter_common import float16 167 | 168 | model = float16.convert_float_to_float16(model) 169 | elif dtype == "fp32": 170 | graph = gs.import_onnx(model).toposort() 171 | 172 | for node in graph.nodes: 173 | if node.op == "Cast": 174 | inp_dtype = [input.dtype for input in node.inputs][0] 175 
| if inp_dtype in [np.float16, np.float32]: 176 | node.erase() 177 | else: 178 | outp_dtype = [output.dtype for output in node.outputs][0] 179 | if outp_dtype == np.float16: 180 | node.attrs["to"] = dtype_to_onnx(np.float32) 181 | node.outputs[0].dtype = np.float32 182 | elif node.op == "ConstantOfShape": 183 | if hasattr(node, "attrs") and "value" in node.attrs: 184 | if node.attrs["value"].dtype == np.float16: 185 | node.attrs["value"].values = node.attrs["value"].values.astype(np.float32) 186 | node.outputs[0].dtype = np.float32 187 | 188 | for tensor in graph.tensors().values(): 189 | if isinstance(tensor, gs.Variable) and tensor.dtype == np.float16: 190 | tensor.dtype = np.float32 191 | elif isinstance(tensor, gs.Constant) and tensor.dtype == np.float16: 192 | tensor.values = tensor.values.astype(np.float32) 193 | 194 | graph.cleanup(remove_unused_graph_inputs=True).toposort() 195 | model = gs.export_onnx(graph) 196 | 197 | return model 198 | 199 | 200 | def freeze(model: onnx.ModelProto): 201 | """Freeze the input layers of an ONNX model by removing the initializers from the input graph.""" 202 | inputs = model.graph.input 203 | name_to_input = {} 204 | for input in inputs: 205 | if input.name in name_to_input: 206 | logger.warning(f"Duplicate input name: {input.name}") 207 | name_to_input[input.name] = input 208 | 209 | for initializer in model.graph.initializer: 210 | if initializer.name in name_to_input: 211 | inputs.remove(name_to_input[initializer.name]) 212 | name_to_input.pop(initializer.name) 213 | -------------------------------------------------------------------------------- /onnxslim/core/optimization/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import Counter 3 | from typing import List, Optional, Union 4 | 5 | import onnx 6 | 7 | import onnxslim.third_party.onnx_graphsurgeon as gs 8 | from onnxslim.core.pattern.registry import get_fusion_patterns 9 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph 10 | 11 | logger = logging.getLogger("onnxslim") 12 | 13 | from .dead_node_elimination import dead_node_elimination 14 | from .subexpression_elimination import subexpression_elimination 15 | from .weight_tying import tie_weights 16 | 17 | 18 | class OptimizationSettings: 19 | constant_folding = True 20 | graph_fusion = True 21 | dead_node_elimination = True 22 | subexpression_elimination = True 23 | weight_tying = True 24 | 25 | @classmethod 26 | def keys(cls): 27 | return [ 28 | "constant_folding", 29 | "graph_fusion", 30 | "dead_node_elimination", 31 | "subexpression_elimination", 32 | "weight_tying", 33 | ] 34 | 35 | @classmethod 36 | def reset(cls, skip_optimizations: Optional[List[str]] = None): 37 | for key in cls.keys(): 38 | if skip_optimizations and key in skip_optimizations: 39 | setattr(cls, key, False) 40 | else: 41 | setattr(cls, key, True) 42 | 43 | @classmethod 44 | def stats(cls): 45 | return {key: getattr(cls, key) for key in cls.keys()} 46 | 47 | @classmethod 48 | def enabled(cls): 49 | return any([getattr(cls, key) for key in cls.keys()]) 50 | 51 | 52 | def optimize_model(model: Union[onnx.ModelProto, gs.Graph], skip_fusion_patterns: str = None) -> onnx.ModelProto: 53 | """Optimize and transform the given ONNX model using various fusion patterns and graph rewriting techniques.""" 54 | graph = model if isinstance(model, gs.Graph) else gs.import_onnx(model) 55 | if OptimizationSettings.graph_fusion: 56 | logger.debug("Start graph_fusion.") 57 | 
fusion_patterns = get_fusion_patterns(skip_fusion_patterns) 58 | fusion_pairs = find_matches(graph, fusion_patterns) 59 | for match in fusion_pairs.values(): 60 | graph.replace_custom_layer(**match) 61 | graph.cleanup(remove_unused_graph_inputs=True).toposort() 62 | logger.debug("Finish graph_fusion.") 63 | if OptimizationSettings.dead_node_elimination: 64 | logger.debug("Start dead_node_elimination.") 65 | dead_node_elimination(graph) 66 | graph.cleanup(remove_unused_graph_inputs=True).toposort() 67 | logger.debug("Finish dead_node_elimination.") 68 | if OptimizationSettings.subexpression_elimination: 69 | logger.debug("Start subexpression_elimination.") 70 | subexpression_elimination(graph) 71 | graph.cleanup(remove_unused_graph_inputs=True).toposort() 72 | logger.debug("Finish subexpression_elimination.") 73 | if OptimizationSettings.weight_tying: 74 | logger.debug("Start weight_tying.") 75 | tie_weights(graph) 76 | logger.debug("Finish weight_tying.") 77 | model = gs.export_onnx(graph) 78 | 79 | return model 80 | 81 | 82 | @gs.Graph.register() 83 | def replace_custom_layer( 84 | self, 85 | op: str, 86 | inputs, 87 | outputs: List[str], 88 | name: str, 89 | attrs: dict = None, 90 | domain: str = "ai.onnx.contrib", 91 | ): 92 | """Replace a custom layer in the computational graph with specified parameters and domain.""" 93 | return self.layer( 94 | op=op, 95 | inputs=inputs, 96 | outputs=outputs, 97 | name=name, 98 | attrs=attrs, 99 | domain=domain, 100 | ) 101 | 102 | 103 | def find_matches(graph: Graph, fusion_patterns: dict): 104 | """Find matching patterns in the graph based on provided fusion patterns.""" 105 | match_map = {} 106 | counter = Counter() 107 | for node in reversed(graph.nodes): 108 | if node.name not in match_map: 109 | for layer_type, pattern_matcher in fusion_patterns.items(): 110 | match = pattern_matcher.match(node) 111 | if match: 112 | match_case = pattern_matcher.rewrite(opset=graph.opset) 113 | logger.debug(f"matched pattern {layer_type}") 114 | for _, match in match_case.items(): 115 | if "op" not in match: 116 | match.update({"op": layer_type}) 117 | if "name" not in match: 118 | match.update({"name": f"{layer_type.lower()}_{counter[layer_type]}"}) 119 | counter.update([layer_type]) 120 | match_map.update(match_case) 121 | 122 | return match_map 123 | 124 | 125 | def get_previous_node_by_type(node, op_type, trajectory=None): 126 | """Recursively find and return the first preceding node of a specified type in the computation graph.""" 127 | if trajectory is None: 128 | trajectory = [] 129 | node_feeds = node.feeds 130 | for node_feed in node_feeds: 131 | trajectory.append(node_feed) 132 | if node_feed.op == op_type: 133 | return trajectory 134 | else: 135 | return get_previous_node_by_type(node_feed, op_type, trajectory) 136 | -------------------------------------------------------------------------------- /onnxslim/core/optimization/dead_node_elimination.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | 5 | import onnxslim.third_party.onnx_graphsurgeon as gs 6 | from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx 7 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Variable 8 | 9 | logger = logging.getLogger("onnxslim") 10 | 11 | 12 | def dead_node_elimination(graph, is_subgraph=False): 13 | """Perform in-place constant folding optimizations on the given computational graph by eliminating redundant 14 | nodes. 
15 | """ 16 | for subgraph in graph.subgraphs(): 17 | dead_node_elimination(subgraph, is_subgraph=True) 18 | 19 | for node in graph.nodes: 20 | if node.op in {"Identity", "Dropout"}: 21 | if not is_subgraph: 22 | node.erase() 23 | logger.debug(f"removing {node.op} op: {node.name}") 24 | elif node.op == "Pad": 25 | if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant): 26 | pad_value = node.inputs[1].values.tolist() 27 | pad_value = pad_value if isinstance(pad_value, list) else [pad_value] 28 | if all(value == 0 for value in pad_value): 29 | node.erase() 30 | logger.debug(f"removing {node.op} op: {node.name}") 31 | elif node.op == "Cast": 32 | inp_dtype = [dtype_to_onnx(input.dtype) for input in node.inputs][0] 33 | if inp_dtype == node.attrs["to"]: 34 | node.erase() 35 | logger.debug(f"removing {node.op} op: {node.name}") 36 | elif node.op == "Reshape": 37 | if (node.inputs[0].shape and len(node.inputs[0].shape) == 1) and ( 38 | node.outputs[0].shape and len(node.outputs[0].shape) == 1 39 | ): 40 | node.erase() 41 | logger.debug(f"removing {node.op} op: {node.name}") 42 | elif node.inputs[0].shape and node.outputs[0].shape and node.inputs[0].shape == node.outputs[0].shape: 43 | node.erase() 44 | logger.debug(f"removing {node.op} op: {node.name}") 45 | else: 46 | node_output_shape = node.outputs[0].shape 47 | if node_output_shape and check_shape(node_output_shape) and not isinstance(node.inputs[1], gs.Constant): 48 | shapes = [shape if isinstance(shape, int) else -1 for shape in node_output_shape] 49 | reshape_const = gs.Constant( 50 | f"{node.inputs[1].name}_", 51 | values=np.array(shapes, dtype=np.int64), 52 | ) 53 | node.inputs.pop(1) 54 | node.inputs.insert(1, reshape_const) 55 | logger.debug(f"replacing {node.op} op: {node.name}") 56 | elif node.op == "Mul": 57 | if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or ( 58 | isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable) 59 | ): 60 | idx, constant_variable = get_constant_variable(node, return_idx=True) 61 | if np.all(constant_variable.values == 1): 62 | var_idx = 0 if idx == 1 else 1 63 | node.erase(var_idx, 0) 64 | logger.debug(f"removing {node.op} op: {node.name}") 65 | elif node.op == "Add": 66 | if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or ( 67 | isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable) 68 | ): 69 | idx, constant_variable = get_constant_variable(node, return_idx=True) 70 | value = constant_variable.values 71 | var_idx = 0 if idx == 1 else 1 72 | if value.ndim == 0 and value == 0: 73 | node.erase(var_idx, 0) 74 | logger.debug(f"removing {node.op} op: {node.name}") 75 | elif np.all(value == 0) and (node.inputs[var_idx].shape == node.outputs[0].shape): 76 | node.erase(var_idx, 0) 77 | logger.debug(f"removing {node.op} op: {node.name}") 78 | elif node.op == "Expand": 79 | # tests/test_onnx_nets.py::TestTimmClass::test_timm[lambda_resnet26rpt_256] 80 | if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant): 81 | constant_variable = node.inputs[1] 82 | value = constant_variable.values 83 | if node.inputs[0].shape == node.outputs[0].shape: 84 | node.erase() 85 | logger.debug(f"removing {node.op} op: {node.name}") 86 | elif value.ndim == 0 and value == 1: 87 | node.erase() 88 | logger.debug(f"removing {node.op} op: {node.name}") 89 | elif node.op == "Concat": 90 | if len(node.inputs) == 1: 91 | node.erase() 92 | logger.debug(f"removing {node.op} op: {node.name}") 93 | else: 94 | for 
input in node.inputs:
95 |                     if isinstance(input, Constant) and input.values.size == 0:
96 |                         node.inputs.remove(input)
97 |         elif node.op == "Sub":
98 |             if isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable):
99 |                 constant_variable = node.inputs[1]
100 |                 value = constant_variable.values
101 |                 if value.ndim == 0 and value == 0:
102 |                     node.erase()
103 |                     logger.debug(f"removing {node.op} op: {node.name}")
104 |                 elif np.all(value == 0) and (node.inputs[0].shape == node.outputs[0].shape):
105 |                     node.erase()
106 |                     logger.debug(f"removing {node.op} op: {node.name}")
107 |         elif node.op == "Div":
108 |             if isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable):
109 |                 constant_variable = node.inputs[1]
110 |                 value = constant_variable.values
111 |                 if value.ndim == 0 and value == 1:
112 |                     node.erase()
113 |                     logger.debug(f"removing {node.op} op: {node.name}")
114 |                 elif np.all(value == 1) and (node.inputs[0].shape == node.outputs[0].shape):
115 |                     node.erase()
116 |                     logger.debug(f"removing {node.op} op: {node.name}")
117 |         elif node.op == "Split":
118 |             if (
119 |                 len(node.outputs) == 1
120 |                 and node.outputs[0].shape
121 |                 and node.inputs[0].shape
122 |                 and node.outputs[0].shape == node.inputs[0].shape
123 |             ):
124 |                 node.erase()
125 |                 logger.debug(f"removing {node.op} op: {node.name}")
126 |         elif node.op == "Resize":
127 |             mode = node.attrs.get("mode")
128 |             if mode is None:
129 |                 node.attrs["mode"] = "nearest"
130 |                 logger.debug(f"setting mode to nearest for {node.op} op: {node.name} since it is not set")
131 | 
132 | 
133 | def check_shape(shapes):
134 |     """Verify that 'shapes' has at most one string (symbolic) dimension and all other elements are positive integers."""
135 |     string_count = 0
136 |     positive_int_count = 0
137 | 
138 |     for item in shapes:
139 |         if isinstance(item, str):
140 |             string_count += 1
141 |         elif isinstance(item, int) and item > 0:
142 |             positive_int_count += 1
143 | 
144 |     return (string_count == 1 and positive_int_count == len(shapes) - 1) or positive_int_count == len(shapes)
145 | 
146 | 
147 | def get_constant_variable(node, return_idx=False):
148 |     """Return the first constant tensor found among a node's inputs, optionally including its index."""
149 |     for idx, input in enumerate(list(node.inputs)):
150 |         if isinstance(input, Constant):
151 |             return (idx, input) if return_idx else input
152 | 
--------------------------------------------------------------------------------
/onnxslim/core/optimization/subexpression_elimination.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Variable
4 | 
5 | logger = logging.getLogger("onnxslim")
6 | 
7 | 
8 | def find_and_remove_replaceable_nodes(nodes):
9 |     """Find and remove duplicate or replaceable nodes in a given list of computational graph nodes."""
10 | 
11 |     def get_node_key(node):
12 |         input_names = []
13 |         for input_node in node.inputs:
14 |             if isinstance(input_node, Variable):
15 |                 input_names.append(input_node.name)
16 |         return "_".join(input_names) if input_names else None
17 | 
18 |     node_dict = {}
19 |     for node in nodes:
20 |         key = get_node_key(node)
21 |         if key:
22 |             if key in node_dict:
23 |                 node_dict[key].append(node)
24 |             else:
25 |                 node_dict[key] = [node]
26 | 
27 |     for key, bucketed_nodes in node_dict.items():
28 |         if len(bucketed_nodes) > 1:
29 |             keep_nodes = [True] * len(bucketed_nodes)
30 |             for i, node in enumerate(bucketed_nodes):
31 |                 if keep_nodes[i]:
32 |                     for 
j in range(i + 1, len(bucketed_nodes)): 33 | if keep_nodes[j] and can_be_replaced(node, bucketed_nodes[j]): 34 | keep_nodes[j] = False 35 | existing_node = node 36 | to_be_removed_node = bucketed_nodes[j] 37 | to_be_removed_node.replace_all_uses_with(existing_node) 38 | logger.debug( 39 | f"Node {to_be_removed_node.name} Op {to_be_removed_node.op} can be replaced by {existing_node.name}" 40 | ) 41 | 42 | 43 | def sequences_equal(seq1, seq2): 44 | """Check if two sequences are equal by comparing their lengths and elements.""" 45 | length_match = len(seq1) == len(seq2) 46 | if not length_match: 47 | return False 48 | 49 | return all(elem1 == elem2 for elem1, elem2 in zip(seq1, seq2)) 50 | 51 | 52 | def can_be_replaced(node, other_node): 53 | """Check if two nodes can be replaced based on their operations, attributes, and inputs.""" 54 | attrs_match = node.op == other_node.op and node.attrs == other_node.attrs 55 | node_input = [input for input in node.inputs if not input.is_empty()] 56 | other_node_input = [input for input in other_node.inputs if not input.is_empty()] 57 | inputs_match = sequences_equal(node_input, other_node_input) 58 | 59 | return attrs_match and inputs_match 60 | 61 | 62 | def subexpression_elimination(graph): 63 | """Perform subexpression elimination on a computational graph to optimize node operations.""" 64 | nodes_by_op = {} 65 | 66 | for node in graph.nodes: 67 | op = node.op 68 | if op not in nodes_by_op: 69 | nodes_by_op[op] = [] 70 | nodes_by_op[op].append(node) 71 | 72 | for nodes in nodes_by_op.values(): 73 | find_and_remove_replaceable_nodes(nodes) 74 | -------------------------------------------------------------------------------- /onnxslim/core/optimization/weight_tying.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger("onnxslim") 4 | import onnxslim.third_party.onnx_graphsurgeon as gs 5 | 6 | 7 | def tie_weights(graph): 8 | """Tie weights in a computational graph to reduce the number of parameters.""" 9 | tensor_map = graph.tensors() 10 | constant_tensors = [tensor for tensor in tensor_map.values() if isinstance(tensor, gs.Constant)] 11 | 12 | sub_graphs = graph.subgraphs(recursive=True) 13 | sub_graphs_constant_tensors = [ 14 | [tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, gs.Constant)] 15 | for sub_graph in sub_graphs 16 | ] 17 | 18 | constant_tensors.extend([tensor for tensors in sub_graphs_constant_tensors for tensor in tensors]) 19 | 20 | def replace_constant_references(existing_constant, to_be_removed_constant): 21 | users = list(to_be_removed_constant.outputs) 22 | 23 | for user in users: 24 | for idx, inp in enumerate(user.inputs): 25 | if (inp == to_be_removed_constant) and (inp.name == to_be_removed_constant.name): 26 | user.inputs.pop(idx) 27 | user.inputs.insert(idx, existing_constant) 28 | 29 | if len(constant_tensors) > 1: 30 | keep_constants = [True] * len(constant_tensors) 31 | for i, constant_tensor in enumerate(constant_tensors): 32 | if keep_constants[i]: 33 | for j in range(i + 1, len(constant_tensors)): 34 | if keep_constants[j]: 35 | if constant_tensor == constant_tensors[j]: 36 | keep_constants[j] = False 37 | replace_constant_references(constant_tensor, constant_tensors[j]) 38 | logger.debug( 39 | f"Constant {constant_tensors[j].name} can be replaced by {constant_tensor.name}" 40 | ) 41 | -------------------------------------------------------------------------------- /onnxslim/core/pattern/__init__.py: 
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | from abc import abstractmethod
4 | 
5 | import onnxslim.third_party.onnx_graphsurgeon as gs
6 | from onnxslim.third_party.onnx_graphsurgeon import Constant
7 | 
8 | logger = logging.getLogger("onnxslim")
9 | 
10 | 
11 | def get_name(name):
12 |     """Sanitize the input string by replacing illegal characters with underscores and prefixing it with an underscore
13 |     if it is purely numeric.
14 |     """
15 |     _illegal_char_regex = re.compile("[^0-9a-zA-Z_]+")
16 |     sanitized_name = _illegal_char_regex.sub("_", name)
17 |     if sanitized_name.isdigit():
18 |         sanitized_name = f"_{sanitized_name}"
19 | 
20 |     return sanitized_name
21 | 
22 | 
23 | class NodeDescriptor:
24 |     """
25 |     case 0: input [1, 2, 3, 4, 5] output [0] -> "Optype Name 5 1 i0 i1 i2 i3 i4 o0" (exact input count)
26 |     case 1: input [1, ...] output [0] -> "Optype Name 1+ 1 i0 o0" (one or more inputs, matched in order)
27 |     case 2: input [..., 1, ...] output [0] -> "Optype Name 1* 1 i0 o0" (one or more inputs, matched at any position)
28 |     """
29 | 
30 |     def __init__(self, node_spec):
31 |         """Initialize NodeDescriptor from a node_spec list with at least 4 elements: op, name, input and output counts."""
32 |         if not isinstance(node_spec, list):
33 |             raise ValueError("node_spec must be a list")
34 |         if len(node_spec) < 4:
35 |             raise ValueError(f"node_spec must have at least 4 elements {node_spec}")
36 | 
37 |         def get_input_info(io_spec):
38 |             """Parse io_spec into a (count, is_coarse, mode) tuple, where mode is None for an exact count, "append"
39 |             for a trailing "+", or "free-match" for a trailing "*".
40 |             """
41 |             if not io_spec.isdigit():
42 |                 match = re.search(r"(\d+)([+*])", io_spec)
43 |                 if match:
44 |                     number = match.group(1)
45 |                     operator = match.group(2)
46 | 
47 |                     if not number.isdigit():
48 |                         raise ValueError(f"input_num and output_num must be integers {io_spec}")
49 | 
50 |                     if operator == "+":
51 |                         return int(number), True, "append"
52 |                     elif operator == "*":
53 |                         return int(number), True, "free-match"
54 |                 else:
55 |                     raise ValueError(f"operator must be + or * {io_spec}")
56 | 
57 |             return int(io_spec), False, None
58 | 
59 |         self.op = node_spec[0]
60 |         self.name = node_spec[1]
61 |         self.input_num, self.coarse_input_num, self.input_mode = get_input_info(node_spec[2])
62 |         self.output_num, self.coarse_output_num, self.output_mode = get_input_info(node_spec[3])
63 |         self.input_names = node_spec[4 : 4 + self.input_num]
64 |         self.output_names = node_spec[4 + self.input_num :]
65 |         assert len(self.input_names) == self.input_num
66 |         assert len(self.output_names) == self.output_num, f"{self.name} {len(self.output_names)} != {self.output_num}"
67 | 
68 |     def __repr__(self):
69 |         """Return a string representation of the object, including its name, operation type, input/output counts, and
70 |         input/output names. 
71 | """ 72 | return f"name: {self.name}, type: {self.op}, input_num: {self.input_num}, output_num: {self.output_num}, input_names: {self.input_names}, output_names: {self.output_names}" 73 | 74 | def __dict__(self): 75 | """Returns a dictionary representation of the object, with 'name' as the key.""" 76 | return { 77 | "name": self, 78 | } 79 | 80 | 81 | class Pattern: 82 | def __init__(self, pattern): 83 | """Initialize the Pattern class with a given pattern and parse its nodes.""" 84 | self.pattern = pattern 85 | self.nodes = self.parse_nodes() 86 | 87 | def parse_nodes(self): 88 | """Parse pattern into a list of NodeDescriptor objects from non-empty, stripped, and split lines.""" 89 | nodes = self.pattern.split("\n") 90 | nodes = [line.strip().split() for line in nodes if line] 91 | nodes = [NodeDescriptor(node) for node in nodes if node] 92 | return nodes 93 | 94 | def match(self, node): 95 | """Match a node against a precompiled pattern.""" 96 | return self.pattern.match(node) 97 | 98 | def __repr__(self): 99 | """Return a string representation of the pattern attribute.""" 100 | return self.pattern 101 | 102 | 103 | class PatternMatcher: 104 | def __init__(self, pattern, priority): 105 | """Initialize the PatternMatcher with a given pattern and priority, and prepare node references and output 106 | names. 107 | """ 108 | self.pattern = pattern 109 | self.priority = priority 110 | self.pattern_dict = {node.name: node for node in pattern.nodes} 111 | self.output_names = [node.name for node in pattern.nodes if node.op == "output"] 112 | 113 | def get_match_point(self): 114 | """Retrieve the match point node from the pattern dictionary based on output node input names.""" 115 | return self.pattern_dict[self.pattern_dict[self.output_names[0]].input_names[0]] 116 | 117 | def match(self, node): 118 | """Match a given node to a pattern by comparing input names with the match point node from the pattern 119 | dictionary. 120 | """ 121 | match_point = self.get_match_point() 122 | 123 | def match_(node, pattern_node): 124 | """Match a given node to a pattern by comparing input names with the match point node from the pattern 125 | dictionary. 126 | """ 127 | if pattern_node.op == "input": 128 | return True 129 | 130 | # node is an input variable 131 | if not hasattr(node, "op"): 132 | return False 133 | 134 | if node.op == pattern_node.op: 135 | setattr(self, pattern_node.name, node) 136 | 137 | node_feeds = node.feeds 138 | if pattern_node.coarse_input_num: 139 | if len(node_feeds) < len(pattern_node.input_names): 140 | return False 141 | else: 142 | if len(node_feeds) != len(pattern_node.input_names): 143 | return False 144 | 145 | if pattern_node.input_mode == "append" or pattern_node.input_mode is None: 146 | pattern_nodes = [ 147 | self.pattern_dict[name] if name != "?" else None for name in pattern_node.input_names 148 | ] 149 | all_match = True 150 | for node_feed, pattern_node in zip(node_feeds, pattern_nodes): 151 | if pattern_node is not None: 152 | node_match = match_(node_feed, pattern_node) 153 | if not node_match: 154 | return False 155 | setattr(self, pattern_node.name, node_feed) 156 | 157 | return all_match 158 | elif pattern_node.input_mode == "free-match": 159 | pattern_nodes = [ 160 | self.pattern_dict[name] if name != "?" 
else None for name in pattern_node.input_names 161 | ] 162 | all_match = True 163 | for pattern_node in pattern_nodes: 164 | if pattern_node is not None: 165 | node_match = False 166 | for node_feed in node_feeds: 167 | node_match = match_(node_feed, pattern_node) 168 | if node_match: 169 | break 170 | if not node_match: 171 | return False 172 | setattr(self, pattern_node.name, node_feed) 173 | 174 | return all_match 175 | return False 176 | 177 | if match_(node, match_point): 178 | setattr(self, "output", node.outputs) 179 | if self.parameter_check(): 180 | return True 181 | 182 | return False 183 | 184 | @abstractmethod 185 | def rewrite(self, opset=11): 186 | """Abstract method to rewrite the graph based on matched patterns, to be implemented by subclasses.""" 187 | raise NotImplementedError("rewrite method must be implemented") 188 | 189 | def parameter_check(self): 190 | """Check and validate parameters, returning True if valid.""" 191 | return True 192 | 193 | 194 | class PatternGenerator: 195 | def __init__(self, onnx_model): 196 | """Initialize the PatternGenerator class with an ONNX model and process its graph.""" 197 | self.graph = gs.import_onnx(onnx_model) 198 | self.graph.fold_constants().cleanup().toposort() 199 | 200 | def generate(self): 201 | """Generate the inputs, outputs, and nodes from the graph of the initialized ONNX model.""" 202 | inputs = self.graph.inputs 203 | outputs = self.graph.outputs 204 | nodes = self.graph.nodes 205 | 206 | template = [] 207 | for input in inputs: 208 | name = get_name(input.name) 209 | template.append( 210 | " ".join( 211 | ["input", name, "0", str(len(input.outputs))] + [get_name(output.name) for output in input.outputs] 212 | ) 213 | ) 214 | 215 | for node in nodes: 216 | if node.op != "Constant": 217 | name = get_name(node.name) 218 | feeds = node.feeds 219 | users = node.users 220 | template.append( 221 | " ".join( 222 | [node.op, name, str(len(feeds)), str(len(users))] 223 | + ["?" if isinstance(feed, Constant) else get_name(feed.name) for feed in feeds] 224 | + ["?" 
if isinstance(user, Constant) else get_name(user.name) for user in users] 225 | ) 226 | ) 227 | 228 | for output in outputs: 229 | name = get_name(output.name) 230 | template.append( 231 | " ".join( 232 | ["output", name, str(len(output.inputs)), "0"] + [get_name(input.name) for input in output.inputs] 233 | ) 234 | ) 235 | 236 | return "\n".join(template) 237 | -------------------------------------------------------------------------------- /onnxslim/core/pattern/elimination/__init__.py: -------------------------------------------------------------------------------- 1 | from .concat import * 2 | from .reshape import * 3 | from .reshape_as import * 4 | from .slice import * 5 | from .unsqueeze import * 6 | -------------------------------------------------------------------------------- /onnxslim/core/pattern/elimination/concat.py: -------------------------------------------------------------------------------- 1 | from onnxslim.core.pattern import Pattern, PatternMatcher 2 | from onnxslim.core.pattern.registry import register_fusion_pattern 3 | 4 | 5 | class ConcatPatternMatcher(PatternMatcher): 6 | def __init__(self, priority): 7 | """Initializes the ConcatPatternMatcher with a specified priority using a predefined graph pattern.""" 8 | pattern = Pattern( 9 | """ 10 | input input 0 1 concat_0 11 | Concat concat_0 1+ 1 input concat_1 12 | Concat concat_1 1* 1 concat_0 output 13 | output output 1 0 concat_1 14 | """ 15 | ) 16 | super().__init__(pattern, priority) 17 | 18 | @property 19 | def name(self): 20 | """Returns the name of the elimination pattern, 'EliminationConcat'.""" 21 | return "EliminationConcat" 22 | 23 | def rewrite(self, opset=11): 24 | """Rewrites an elimination pattern for concat nodes by optimizing nested slice operations.""" 25 | match_case = {} 26 | 27 | node_concat_0 = self.concat_0 28 | users_node_concat_0 = node_concat_0.users 29 | node_concat_1 = self.concat_1 30 | node_concat_0_axis = node_concat_0.attrs.get("axis", 0) 31 | node_concat_1.attrs.get("axis", 0) 32 | 33 | if all(user.op == "Concat" and user.attrs.get("axis", 0) == node_concat_0_axis for user in users_node_concat_0): 34 | index = node_concat_1.inputs.index(node_concat_0.outputs[0]) 35 | node_concat_1.inputs.pop(index) 36 | for i, item in enumerate(node_concat_0.inputs): 37 | node_concat_1.inputs.insert(index + i, item) 38 | inputs = list(node_concat_1.inputs) 39 | outputs = list(node_concat_1.outputs) 40 | node_concat_1.inputs.clear() 41 | node_concat_1.outputs.clear() 42 | 43 | if len(users_node_concat_0) == 0: 44 | node_concat_0.inputs.clear() 45 | node_concat_0.outputs.clear() 46 | 47 | attrs = {"axis": node_concat_0_axis} 48 | 49 | match_case[node_concat_1.name] = { 50 | "op": "Concat", 51 | "inputs": inputs, 52 | "outputs": outputs, 53 | "name": node_concat_1.name, 54 | "attrs": attrs, 55 | "domain": None, 56 | } 57 | 58 | return match_case 59 | 60 | 61 | register_fusion_pattern(ConcatPatternMatcher(1)) 62 | -------------------------------------------------------------------------------- /onnxslim/core/pattern/elimination/reshape.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import onnxslim.third_party.onnx_graphsurgeon as gs 4 | from onnxslim.core.pattern import Pattern, PatternMatcher 5 | from onnxslim.core.pattern.registry import register_fusion_pattern 6 | 7 | 8 | class ReshapePatternMatcher(PatternMatcher): 9 | def __init__(self, priority): 10 | """Initializes the ReshapePatternMatcher with a priority and a specific pattern 
for detecting nested reshape 11 | operations. 12 | """ 13 | pattern = Pattern( 14 | """ 15 | input input 0 1 reshape_0 16 | Reshape reshape_0 2 1 input ? reshape_1 17 | Reshape reshape_1 2 1 reshape_0 ? output 18 | output output 1 0 reshape_1 19 | """ 20 | ) 21 | super().__init__(pattern, priority) 22 | 23 | @property 24 | def name(self): 25 | """Returns the name 'EliminationReshape'.""" 26 | return "EliminationReshape" 27 | 28 | def rewrite(self, opset=11): 29 | """Rewrite the computational graph by eliminating redundant reshape operations when certain conditions are 30 | met. 31 | """ 32 | match_case = {} 33 | node = self.reshape_1 34 | first_reshape_node = node.i(0) 35 | first_reshape_node_inputs = list(first_reshape_node.inputs) 36 | first_reshape_node_users = first_reshape_node.users 37 | if len(first_reshape_node_users) == 1: 38 | second_reshape_node = node 39 | 40 | def check_constant_mergeable(reshape_node): 41 | """Check if a reshape node's shape input, containing zero dimensions, can be merged with its input 42 | node's shape. 43 | """ 44 | if isinstance(reshape_node.inputs[1], gs.Constant): 45 | input_shape = reshape_node.inputs[0].shape 46 | reshape_shape = reshape_node.inputs[1].values.tolist() 47 | if input_shape is not None and np.any(np.array(reshape_shape) == 0): 48 | shape = [ 49 | input_shape[i] if dim_size == 0 else reshape_shape[i] 50 | for i, dim_size in enumerate(reshape_shape) 51 | ] 52 | if not all(isinstance(item, int) for item in shape): 53 | return False 54 | return True 55 | 56 | if check_constant_mergeable(first_reshape_node) and check_constant_mergeable(second_reshape_node): 57 | inputs = [] 58 | inputs.append(first_reshape_node_inputs[0]) 59 | inputs.append(second_reshape_node.inputs[1]) 60 | outputs = list(second_reshape_node.outputs) 61 | first_reshape_node.outputs.clear() 62 | second_reshape_node.inputs.clear() 63 | second_reshape_node.outputs.clear() 64 | 65 | match_case[first_reshape_node.name] = { 66 | "op": "Reshape", 67 | "inputs": inputs, 68 | "outputs": outputs, 69 | "name": first_reshape_node.name, 70 | "attrs": first_reshape_node.attrs, 71 | "domain": None, 72 | } 73 | 74 | return match_case 75 | 76 | 77 | register_fusion_pattern(ReshapePatternMatcher(1)) 78 | -------------------------------------------------------------------------------- /onnxslim/core/pattern/elimination/reshape_as.py: -------------------------------------------------------------------------------- 1 | import onnxslim.third_party.onnx_graphsurgeon as gs 2 | from onnxslim.core.pattern import Pattern, PatternMatcher 3 | from onnxslim.core.pattern.registry import register_fusion_pattern 4 | 5 | 6 | class ReshapeAsPatternMatcher(PatternMatcher): 7 | def __init__(self, priority): 8 | """Initializes the ReshapeAsPatternMatcher with a priority and a specific pattern for reshape as operations.""" 9 | pattern = Pattern( 10 | """ 11 | input input 0 1 shape 12 | Shape shape 1+ 1 input gather 13 | Gather gather 1+ 1 shape unsqueeze 14 | Unsqueeze unsqueeze 1+ 1 gather output 15 | Concat concat 1+ 1 unsqueeze output 16 | output output 1 0 concat 17 | """ 18 | ) 19 | super().__init__(pattern, priority) 20 | 21 | @property 22 | def name(self): 23 | """Returns the name 'EliminationReshapeAs'.""" 24 | return "EliminationReshapeAs" 25 | 26 | def parameter_check(self) -> bool: 27 | shape_node = self.shape 28 | if shape_node.outputs[0].shape is None: 29 | return False 30 | 31 | if len(shape_node.users) != shape_node.outputs[0].shape[0]: 32 | return False 33 | 34 | if not all([user.op == "Gather" 
for user in shape_node.users]):
35 |             return False
36 | 
37 |         for idx, user in enumerate(shape_node.users):
38 |             if not isinstance(user.inputs[1], gs.Constant):
39 |                 return False
40 | 
41 |             if user.inputs[1].values.shape != ():
42 |                 return False
43 | 
44 |             if user.inputs[1].values != idx:
45 |                 return False
46 | 
47 |         concat_node = self.concat
48 |         if len(concat_node.inputs) != len(shape_node.users):
49 |             return False
50 | 
51 |         return True
52 | 
53 |     def rewrite(self, opset=11):
54 |         """Rewrites the pattern by replacing the Concat node with the Shape node."""
55 |         match_case = {}
56 |         shape_node = self.shape
57 |         concat_node = self.concat
58 | 
59 |         concat_node.replace_all_uses_with(shape_node)
60 | 
61 |         return match_case
62 | 
63 | 
64 | register_fusion_pattern(ReshapeAsPatternMatcher(1))
65 | 
--------------------------------------------------------------------------------
/onnxslim/core/pattern/elimination/slice.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import onnxslim.third_party.onnx_graphsurgeon as gs
4 | from onnxslim.core.pattern import Pattern, PatternMatcher
5 | from onnxslim.core.pattern.registry import register_fusion_pattern
6 | 
7 | 
8 | class SlicePatternMatcher(PatternMatcher):
9 |     def __init__(self, priority):
10 |         """Initializes the SlicePatternMatcher with a specified priority using a predefined graph pattern."""
11 |         pattern = Pattern(
12 |             """
13 |             input  input   0 1 slice_0
14 |             Slice  slice_0 5 1 input ? ? ? ? slice_1
15 |             Slice  slice_1 5 1 slice_0 ? ? ? ? output
16 |             output output  1 0 slice_1
17 |             """
18 |         )  # to check here slice_0
19 |         super().__init__(pattern, priority)
20 | 
21 |     @property
22 |     def name(self):
23 |         """Returns the name of the elimination pattern, 'EliminationSlice'."""
24 |         return "EliminationSlice"
25 | 
26 |     def rewrite(self, opset=11):
27 |         """Rewrites an elimination pattern for slice nodes by merging nested slice operations."""
28 |         match_case = {}
29 |         first_slice_node = self.slice_0
30 |         first_slice_node_inputs = list(first_slice_node.inputs)
31 |         if all(isinstance(input, gs.Constant) for input in first_slice_node_inputs[1:]):
32 |             first_slice_node_users = first_slice_node.users
33 |             if all(
34 |                 user.op == "Slice" and all(isinstance(input, gs.Constant) for input in list(user.inputs)[1:])
35 |                 for user in first_slice_node_users
36 |             ):
37 |                 first_slice_node_starts = first_slice_node_inputs[1].values.tolist()
38 |                 first_slice_node_ends = first_slice_node_inputs[2].values.tolist()
39 |                 first_slice_node_axes = first_slice_node_inputs[3].values.tolist()
40 |                 first_slice_node_steps = first_slice_node_inputs[4].values.tolist()
41 | 
42 |                 for user_node in first_slice_node_users:
43 |                     second_slice_node = user_node
44 |                     second_slice_node_inputs = list(second_slice_node.inputs)
45 |                     second_slice_node_starts = second_slice_node_inputs[1].values.tolist()
46 |                     second_slice_node_ends = second_slice_node_inputs[2].values.tolist()
47 |                     second_slice_node_axes = second_slice_node_inputs[3].values.tolist()
48 |                     second_slice_node_steps = second_slice_node_inputs[4].values.tolist()
49 | 
50 |                     new_starts = first_slice_node_starts + second_slice_node_starts
51 |                     new_ends = first_slice_node_ends + second_slice_node_ends
52 |                     new_axes = first_slice_node_axes + second_slice_node_axes
53 |                     new_steps = first_slice_node_steps + second_slice_node_steps
54 | 
55 |                     if len(new_axes) != len(set(new_axes)):
56 |                         continue
57 | 
58 |                     inputs = []
59 |                     inputs.extend(
60 |                         (
61 |                             list(first_slice_node.inputs)[0],
62 |                             gs.Constant( 
63 | second_slice_node_inputs[1].name, 64 | values=np.array(new_starts, dtype=np.int64), 65 | ), 66 | gs.Constant( 67 | second_slice_node_inputs[2].name, 68 | values=np.array(new_ends, dtype=np.int64), 69 | ), 70 | gs.Constant( 71 | second_slice_node_inputs[3].name, 72 | values=np.array(new_axes, dtype=np.int64), 73 | ), 74 | gs.Constant( 75 | second_slice_node_inputs[4].name, 76 | values=np.array(new_steps, dtype=np.int64), 77 | ), 78 | ) 79 | ) 80 | outputs = list(second_slice_node.outputs) 81 | 82 | first_slice_node.outputs.clear() 83 | second_slice_node.inputs.clear() 84 | second_slice_node.outputs.clear() 85 | 86 | if len(first_slice_node_users) == 1: 87 | match_case[first_slice_node.name] = { 88 | "op": "Slice", 89 | "inputs": inputs, 90 | "outputs": outputs, 91 | "name": first_slice_node.name, 92 | "attrs": first_slice_node.attrs, 93 | "domain": None, 94 | } 95 | else: 96 | match_case[second_slice_node.name] = { 97 | "op": "Slice", 98 | "inputs": inputs, 99 | "outputs": outputs, 100 | "name": second_slice_node.name, 101 | "attrs": second_slice_node.attrs, 102 | "domain": None, 103 | } 104 | 105 | return match_case 106 | 107 | 108 | register_fusion_pattern(SlicePatternMatcher(1)) 109 | -------------------------------------------------------------------------------- /onnxslim/core/pattern/elimination/unsqueeze.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import onnxslim.third_party.onnx_graphsurgeon as gs 4 | from onnxslim.core.pattern import Pattern, PatternMatcher 5 | from onnxslim.core.pattern.registry import register_fusion_pattern 6 | 7 | 8 | class UnsqueezePatternMatcher(PatternMatcher): 9 | def __init__(self, priority): 10 | """Initializes the UnsqueezePatternMatcher with a specified priority using a predefined graph pattern.""" 11 | pattern = Pattern( 12 | """ 13 | input input 0 1 unsqueeze_0 14 | Unsqueeze unsqueeze_0 1+ 1 input unsqueeze_1 15 | Unsqueeze unsqueeze_1 1+ 1 unsqueeze_0 output 16 | output output 1 0 unsqueeze_1 17 | """ 18 | ) 19 | super().__init__(pattern, priority) 20 | 21 | @property 22 | def name(self): 23 | """Returns the name of the elimination pattern, 'EliminationUnsqueeze'.""" 24 | return "EliminationUnsqueeze" 25 | 26 | def rewrite(self, opset=11): 27 | """Rewrites an elimination pattern for unsqueeze nodes by optimizing nested slice operations.""" 28 | match_case = {} 29 | node_unsqueeze_0 = self.unsqueeze_0 30 | users_node_unsqueeze_0 = node_unsqueeze_0.users 31 | node_unsqueeze_1 = self.unsqueeze_1 32 | if len(users_node_unsqueeze_0) == 1 and node_unsqueeze_0.inputs[0].shape and node_unsqueeze_1.inputs[0].shape: 33 | if opset < 13 or ( 34 | isinstance(node_unsqueeze_0.inputs[1], gs.Constant) 35 | and isinstance(node_unsqueeze_1.inputs[1], gs.Constant) 36 | ): 37 | 38 | def get_unsqueeze_axes(unsqueeze_node, opset): 39 | dim = len(unsqueeze_node.inputs[0].shape) 40 | if opset < 13: 41 | axes = unsqueeze_node.attrs["axes"] 42 | else: 43 | axes = unsqueeze_node.inputs[1].values 44 | return [axis + dim + len(axes) if axis < 0 else axis for axis in axes] 45 | 46 | axes_node_unsqueeze_0 = get_unsqueeze_axes(node_unsqueeze_0, opset) 47 | axes_node_unsqueeze_1 = get_unsqueeze_axes(node_unsqueeze_1, opset) 48 | 49 | axes_node_unsqueeze_0 = [ 50 | axis + sum(1 for axis_ in axes_node_unsqueeze_1 if axis_ <= axis) for axis in axes_node_unsqueeze_0 51 | ] 52 | 53 | inputs = [node_unsqueeze_0.inputs[0]] 54 | outputs = list(node_unsqueeze_1.outputs) 55 | node_unsqueeze_0.inputs.clear() 56 | 
node_unsqueeze_0.outputs.clear() 57 | node_unsqueeze_1.inputs.clear() 58 | node_unsqueeze_1.outputs.clear() 59 | 60 | if opset < 13: 61 | attrs = {"axes": axes_node_unsqueeze_0 + axes_node_unsqueeze_1} 62 | else: 63 | attrs = None 64 | inputs.append( 65 | gs.Constant( 66 | name=f"{node_unsqueeze_0.name}_axes", 67 | values=np.array(axes_node_unsqueeze_0 + axes_node_unsqueeze_1, dtype=np.int64), 68 | ) 69 | ) 70 | 71 | match_case[node_unsqueeze_0.name] = { 72 | "op": "Unsqueeze", 73 | "inputs": inputs, 74 | "outputs": outputs, 75 | "name": node_unsqueeze_0.name, 76 | "attrs": attrs, 77 | "domain": None, 78 | } 79 | 80 | return match_case 81 | 82 | 83 | register_fusion_pattern(UnsqueezePatternMatcher(1)) 84 | -------------------------------------------------------------------------------- /onnxslim/core/pattern/fusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .convadd import * 2 | from .convbn import * 3 | from .gelu import * 4 | from .gemm import * 5 | from .padconv import * 6 | from .reduce import * 7 | -------------------------------------------------------------------------------- /onnxslim/core/pattern/fusion/convadd.py: -------------------------------------------------------------------------------- 1 | import onnxslim.third_party.onnx_graphsurgeon as gs 2 | from onnxslim.core.pattern import Pattern, PatternMatcher 3 | from onnxslim.core.pattern.registry import register_fusion_pattern 4 | 5 | 6 | class ConvAddMatcher(PatternMatcher): 7 | def __init__(self, priority): 8 | """Initializes the ConvAddMatcher for fusing Conv and Add layers in an ONNX graph.""" 9 | pattern = Pattern( 10 | """ 11 | input input 0 1 conv_0 12 | Conv conv_0 1+ 1 input bn_0 13 | Add add_0 2 1 conv_0 ? output 14 | output output 1 0 add_0 15 | """ 16 | ) 17 | super().__init__(pattern, priority) 18 | 19 | @property 20 | def name(self): 21 | """Returns the name of the FusionConvAdd pattern.""" 22 | return "FusionConvAdd" 23 | 24 | def rewrite(self, opset=11): 25 | match_case = {} 26 | conv_node = self.conv_0 27 | conv_weight = list(conv_node.inputs)[1] 28 | conv_node_users = conv_node.users 29 | node = self.add_0 30 | if ( 31 | len(conv_node_users) == 1 32 | and isinstance(node.inputs[1], gs.Constant) 33 | and isinstance(conv_weight, gs.Constant) 34 | and node.inputs[1].values.squeeze().ndim == 1 35 | and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[0] 36 | ): 37 | add_node = node 38 | if len(conv_node.inputs) == 2: 39 | conv_bias = node.inputs[1].values.squeeze() 40 | else: 41 | conv_bias = conv_node.inputs[2].values + node.inputs[1].values.squeeze() 42 | 43 | inputs = [] 44 | inputs.append(list(conv_node.inputs)[0]) 45 | inputs.append(conv_weight) 46 | weight_name = list(conv_node.inputs)[1].name 47 | if weight_name.endswith("weight"): 48 | bias_name = f"{weight_name[:-6]}bias" 49 | else: 50 | bias_name = f"{weight_name}_bias" 51 | inputs.append(gs.Constant(bias_name, values=conv_bias)) 52 | outputs = list(add_node.outputs) 53 | 54 | conv_node.outputs.clear() 55 | add_node.inputs.clear() 56 | add_node.outputs.clear() 57 | 58 | match_case[conv_node.name] = { 59 | "op": conv_node.op, 60 | "inputs": inputs, 61 | "outputs": outputs, 62 | "name": conv_node.name, 63 | "attrs": conv_node.attrs, 64 | "domain": None, 65 | } 66 | 67 | return match_case 68 | 69 | 70 | register_fusion_pattern(ConvAddMatcher(1)) 71 | -------------------------------------------------------------------------------- /onnxslim/core/pattern/fusion/convbn.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import onnxslim.third_party.onnx_graphsurgeon as gs 4 | from onnxslim.core.pattern import Pattern, PatternMatcher 5 | from onnxslim.core.pattern.registry import register_fusion_pattern 6 | 7 | 8 | class ConvBatchNormMatcher(PatternMatcher): 9 | def __init__(self, priority): 10 | """Initializes the ConvBatchNormMatcher for fusing Conv and BatchNormalization layers in an ONNX graph.""" 11 | pattern = Pattern( 12 | """ 13 | input input 0 1 conv_0 14 | Conv conv_0 1+ 1 input bn_0 15 | BatchNormalization bn_0 5 1 conv_0 ? ? ? ? output 16 | output output 1 0 bn_0 17 | """ 18 | ) 19 | super().__init__(pattern, priority) 20 | 21 | @property 22 | def name(self): 23 | """Returns the name of the FusionConvBN pattern.""" 24 | return "FusionConvBN" 25 | 26 | def rewrite(self, opset=11): 27 | """Rewrites the weights and biases of a BatchNormalization layer fused with a convolution layer.""" 28 | match_case = {} 29 | conv_transpose_node = self.conv_0 30 | conv_transpose_node_users = conv_transpose_node.users 31 | node = self.bn_0 32 | if len(conv_transpose_node_users) == 1 and isinstance(conv_transpose_node.inputs[1], gs.Constant): 33 | conv_transpose_weight = conv_transpose_node.inputs[1].values 34 | bn_node = node 35 | bn_scale = bn_node.inputs[1].values 36 | bn_bias = bn_node.inputs[2].values 37 | bn_running_mean = bn_node.inputs[3].values 38 | bn_running_var = bn_node.inputs[4].values 39 | bn_eps = bn_node.attrs["epsilon"] 40 | 41 | if len(conv_transpose_node.inputs) == 2: 42 | conv_transpose_bias = np.zeros_like(bn_running_mean) 43 | else: 44 | conv_transpose_bias = conv_transpose_node.inputs[2].values 45 | 46 | bn_var_rsqrt = 1.0 / np.sqrt(bn_running_var + bn_eps) 47 | shape = [1] * len(conv_transpose_weight.shape) 48 | if bn_node.i(0).op == "Conv": 49 | shape[0] = -1 50 | else: 51 | shape[1] = -1 52 | conv_w = conv_transpose_weight * (bn_scale * bn_var_rsqrt).reshape(shape) 53 | conv_b = (conv_transpose_bias - bn_running_mean) * bn_var_rsqrt * bn_scale + bn_bias 54 | 55 | inputs = [] 56 | inputs.append(list(conv_transpose_node.inputs)[0]) 57 | weight_name = list(conv_transpose_node.inputs)[1].name 58 | if weight_name.endswith("weight"): 59 | bias_name = f"{weight_name[:-6]}bias" 60 | else: 61 | bias_name = f"{weight_name}_bias" 62 | inputs.extend( 63 | ( 64 | gs.Constant(weight_name, values=conv_w), 65 | gs.Constant(bias_name, values=conv_b), 66 | ) 67 | ) 68 | outputs = list(bn_node.outputs) 69 | 70 | conv_transpose_node.outputs.clear() 71 | bn_node.inputs.clear() 72 | bn_node.outputs.clear() 73 | 74 | match_case[conv_transpose_node.name] = { 75 | "op": conv_transpose_node.op, 76 | "inputs": inputs, 77 | "outputs": outputs, 78 | "name": conv_transpose_node.name, 79 | "attrs": conv_transpose_node.attrs, 80 | "domain": None, 81 | } 82 | 83 | return match_case 84 | 85 | 86 | register_fusion_pattern(ConvBatchNormMatcher(1)) 87 | -------------------------------------------------------------------------------- /onnxslim/core/pattern/fusion/gelu.py: -------------------------------------------------------------------------------- 1 | from onnxslim.core.pattern import Pattern, PatternMatcher 2 | 3 | 4 | class GeluPatternMatcher(PatternMatcher): 5 | def __init__(self, priority): 6 | """Initializes a `GeluPatternMatcher` to identify and fuse GELU patterns in a computational graph.""" 7 | pattern = Pattern( 8 | """ 9 | input input 0 2 mul_0 div_0 10 | Div div_0 2 1 input ? 
erf_0 11 | Erf erf_0 1 1 div_0 add_0 12 | Add add_0 2 1 erf_0 ? mul_0 13 | Mul mul_0 2 1 input add_0 mul_1 14 | Mul mul_1 2 1 mul_0 ? output 15 | output output 1 0 mul_1 16 | """ 17 | ) 18 | super().__init__(pattern, priority) 19 | 20 | @property 21 | def name(self): 22 | """Returns the name of the fusion pattern, 'FusionGelu'.""" 23 | return "FusionGelu" 24 | 25 | def rewrite(self, opset=11): 26 | """Rewrite the computation graph pattern to fuse GELU operations.""" 27 | input_variable = self.div_0.inputs[0] 28 | mul_node = self.mul_0 29 | div_node = self.div_0 30 | 31 | input_variable.outputs.remove(mul_node) 32 | input_variable.outputs.remove(div_node) 33 | 34 | output_variable = self.mul_1.outputs[0] 35 | output_variable.inputs.clear() 36 | 37 | return { 38 | self.mul_1.name: { 39 | "op": "Gelu", 40 | "inputs": [input_variable], 41 | "outputs": [output_variable], 42 | "domain": None, 43 | } 44 | } 45 | 46 | 47 | # register_fusion_pattern(GeluPatternMatcher(1)) 48 | -------------------------------------------------------------------------------- /onnxslim/core/pattern/fusion/gemm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import onnxslim.third_party.onnx_graphsurgeon as gs 4 | from onnxslim.core.optimization.dead_node_elimination import get_constant_variable 5 | from onnxslim.core.pattern import Pattern, PatternMatcher 6 | from onnxslim.core.pattern.registry import register_fusion_pattern 7 | 8 | 9 | class MatMulAddPatternMatcher(PatternMatcher): 10 | def __init__(self, priority): 11 | """Initializes a matcher for fusing MatMul and Add operations in ONNX graph optimization.""" 12 | pattern = Pattern( 13 | """ 14 | input input 0 1 matmul_0 15 | MatMul matmul_0 2 1 input ? add_0 16 | Add add_0 2 1 matmul_0 ? output 17 | output output 1 0 add_0 18 | """ 19 | ) 20 | super().__init__(pattern, priority) 21 | 22 | @property 23 | def name(self): 24 | """Returns the name of the fusion pattern as a string 'FusionGemm'.""" 25 | return "FusionGemm" 26 | 27 | def rewrite(self, opset=11): 28 | """Rewrites the graph for the fusion pattern 'FusionGemm' based on matching criteria and constant variables in 29 | matmul nodes. 
30 | """ 31 | match_case = {} 32 | node = self.add_0 33 | matmul_node = self.matmul_0 34 | matmul_bias_variable = get_constant_variable(matmul_node) 35 | add_bias_variable = get_constant_variable(node) 36 | input_variable = ( 37 | matmul_node.inputs[0] if isinstance(matmul_node.inputs[1], gs.Constant) else matmul_node.inputs[1] 38 | ) 39 | users = matmul_node.users 40 | if len(users) == 1 and matmul_bias_variable and add_bias_variable and len(matmul_bias_variable.shape) == 2: 41 | if ( 42 | input_variable.shape 43 | and len(input_variable.shape) > 2 44 | and all([isinstance(value, int) for value in input_variable.shape]) 45 | ): 46 | pre_reshape_const = gs.Constant( 47 | f"{matmul_node.name}_pre_reshape_in", 48 | values=np.array([-1, matmul_bias_variable.values.shape[0]], dtype=np.int64), 49 | ) 50 | inputs = [] 51 | inputs.append(input_variable) 52 | inputs.append(pre_reshape_const) 53 | 54 | reshape_out_variable = gs.Variable( 55 | f"{matmul_node.name}_pre_reshape_out", 56 | dtype=input_variable.dtype, 57 | ) 58 | outputs = [reshape_out_variable] 59 | 60 | match_case.update( 61 | { 62 | f"{matmul_node.name}_pre_reshape": { 63 | "op": "Reshape", 64 | "inputs": inputs, 65 | "outputs": outputs, 66 | "name": f"{matmul_node.name}_pre_reshape", 67 | "domain": None, 68 | } 69 | } 70 | ) 71 | 72 | add_node = node 73 | add_bias_variable = get_constant_variable(add_node) 74 | 75 | output_variable = add_node.inputs[0] 76 | output_variable.outputs.remove(add_node) 77 | 78 | matmul_bias_transpose_constant = gs.Constant( 79 | matmul_bias_variable.name, values=matmul_bias_variable.values.T 80 | ) 81 | 82 | inputs = [] 83 | inputs.append(reshape_out_variable) 84 | inputs.append(matmul_bias_transpose_constant) 85 | inputs.append(add_bias_variable) 86 | 87 | gemm_out_variable = gs.Variable(f"{matmul_node.name}_gemm_out", dtype=output_variable.dtype) 88 | outputs = [gemm_out_variable] 89 | 90 | match_case.update( 91 | { 92 | matmul_node.name: { 93 | "op": "Gemm", 94 | "inputs": inputs, 95 | "outputs": outputs, 96 | "name": matmul_node.name, 97 | "attrs": { 98 | "alpha": 1.0, 99 | "beta": 1.0, 100 | "transA": 0, 101 | "transB": 1, 102 | }, 103 | "domain": None, 104 | } 105 | } 106 | ) 107 | 108 | values = input_variable.shape[:-1] + [matmul_bias_variable.values.shape[-1]] 109 | post_reshape_const = gs.Constant( 110 | f"{matmul_node.name}_post_reshape_in", 111 | values=np.array(values, dtype=np.int64), 112 | ) 113 | 114 | inputs = [] 115 | inputs.append(gemm_out_variable) 116 | inputs.append(post_reshape_const) 117 | outputs = list(add_node.outputs) 118 | 119 | matmul_node.outputs.clear() 120 | add_node.inputs.clear() 121 | add_node.outputs.clear() 122 | 123 | match_case.update( 124 | { 125 | f"{matmul_node.name}_post_reshape": { 126 | "op": "Reshape", 127 | "inputs": inputs, 128 | "outputs": outputs, 129 | "name": f"{matmul_node.name}_post_reshape", 130 | "domain": None, 131 | } 132 | } 133 | ) 134 | elif ( 135 | input_variable.shape 136 | and len(input_variable.shape) == 2 137 | and all([isinstance(value, int) for value in input_variable.shape]) 138 | ): 139 | add_node = node 140 | add_bias_variable = get_constant_variable(add_node) 141 | 142 | output_variable = add_node.inputs[0] 143 | output_variable.outputs.remove(add_node) 144 | 145 | matmul_bias_transpose_constant = gs.Constant( 146 | matmul_bias_variable.name, values=matmul_bias_variable.values.T 147 | ) 148 | 149 | inputs = [] 150 | inputs.append(input_variable) 151 | inputs.append(matmul_bias_transpose_constant) 152 | 
inputs.append(add_bias_variable)
153 | 
154 |                 outputs = list(add_node.outputs)
155 |                 add_node.inputs.clear()
156 |                 add_node.outputs.clear()
157 |                 match_case.update(
158 |                     {
159 |                         matmul_node.name: {
160 |                             "op": "Gemm",
161 |                             "inputs": inputs,
162 |                             "outputs": outputs,
163 |                             "name": matmul_node.name,
164 |                             "attrs": {
165 |                                 "alpha": 1.0,
166 |                                 "beta": 1.0,
167 |                                 "transA": 0,
168 |                                 "transB": 1,
169 |                             },
170 |                             "domain": None,
171 |                         }
172 |                     }
173 |                 )
174 |         return match_case
175 | 
176 | 
177 | register_fusion_pattern(MatMulAddPatternMatcher(1))
178 | 
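A quick way to see this MatMul+Add-to-Gemm rewrite in action is to slim a model whose Linear layers run on 3-D activations, since PyTorch typically exports those as MatMul + Add rather than a single Gemm. A minimal sketch (the file names are placeholders, and the `slim` entry point is assumed to be importable from the top-level onnxslim package):

    import torch
    from onnxslim import slim

    # A Linear applied to a 3-D input usually exports as MatMul + Add.
    torch.onnx.export(torch.nn.Linear(64, 32), torch.randn(4, 8, 64), "linear.onnx")

    # After slimming, FusionGemm should leave a Reshape -> Gemm -> Reshape chain
    # (the pre/post Reshapes flatten the leading batch dims, as in the rewrite above).
    slim("linear.onnx", "linear_slim.onnx")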
32 | """ 33 | match_case = {} 34 | conv_node = self.conv_0 35 | pad_node = self.pad_0 36 | pad_node_users = pad_node.users 37 | 38 | pad_inputs = len(pad_node.inputs) 39 | if pad_inputs < 3 or ( 40 | pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0) 41 | ): 42 | if ( 43 | isinstance(pad_node.inputs[1], gs.Constant) 44 | and pad_node.attrs["mode"] == "constant" 45 | and conv_node.inputs[1].shape 46 | ): 47 | conv_weight_dim = len(conv_node.inputs[1].shape) 48 | pad_value = pad_node.inputs[1].values.tolist() 49 | 50 | if all(pad == 0 for pad in (pad_value[:2] + pad_value[conv_weight_dim : conv_weight_dim + 2])): 51 | conv_weight_dim - 2 52 | input_variable = self.pad_0.inputs[0] 53 | pad_variable = pad_node.outputs[0] # pad output variable 54 | index = conv_node.inputs.index(pad_variable) 55 | conv_node.inputs.pop(index) 56 | conv_node.inputs.insert(index, input_variable) 57 | 58 | inputs = list(conv_node.inputs) 59 | outputs = list(conv_node.outputs) 60 | attrs = conv_node.attrs 61 | 62 | conv_node.inputs.clear() 63 | conv_node.outputs.clear() 64 | # remove pad node if it has only one user 65 | if len(pad_node_users) == 0: 66 | input_variable.outputs.remove(pad_node) 67 | pad_node.inputs.clear() 68 | pad_node.outputs.clear() 69 | 70 | conv_pads = attrs["pads"] 71 | pads = pad_value[2:conv_weight_dim] + pad_value[conv_weight_dim + 2 :] 72 | pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)] 73 | 74 | attrs["pads"] = pads 75 | match_case[conv_node.name] = { 76 | "op": "Conv", 77 | "inputs": inputs, 78 | "outputs": outputs, 79 | "name": conv_node.name, 80 | "attrs": conv_node.attrs, 81 | "domain": None, 82 | } 83 | 84 | return match_case 85 | 86 | 87 | register_fusion_pattern(PadConvMatcher(1)) 88 | -------------------------------------------------------------------------------- /onnxslim/core/pattern/fusion/reduce.py: -------------------------------------------------------------------------------- 1 | from onnxslim.core.pattern import Pattern, PatternMatcher 2 | from onnxslim.core.pattern.registry import register_fusion_pattern 3 | 4 | 5 | class ReducePatternMatcher(PatternMatcher): 6 | def __init__(self, priority): 7 | """Initializes the ReducePatternMatcher with a specified pattern matching priority level.""" 8 | pattern = Pattern( 9 | """ 10 | input input 0 1 reduce_0 11 | ReduceSum reduce_0 1 1 input unsqueeze_0 12 | Unsqueeze unsqueeze_0 1 1 reduce_0 output 13 | output output 1 0 unsqueeze_0 14 | """ 15 | ) 16 | super().__init__(pattern, priority) 17 | 18 | @property 19 | def name(self): 20 | """Returns the name of the fusion pattern 'FusionReduce'.""" 21 | return "FusionReduce" 22 | 23 | def rewrite(self, opset=11): 24 | """Rewrites the graph pattern based on opset version; reuses Reduce and Unsqueeze nodes if possible.""" 25 | match_case = {} 26 | node = self.unsqueeze_0 27 | reduce_node = self.reduce_0 28 | reduce_node_node_users = reduce_node.users 29 | if len(reduce_node_node_users) == 1: 30 | unsqueeze_node = node 31 | 32 | reduce_node_axes = reduce_node.attrs.get("axes", None) 33 | reduce_node_keepdims = reduce_node.attrs.get("keepdims", 1) 34 | unsqueeze_node_axes = unsqueeze_node.attrs.get("axes", None) 35 | 36 | if opset < 13 and reduce_node_axes == [-1] and unsqueeze_node_axes == [-1] and reduce_node_keepdims == 0: 37 | inputs = list(reduce_node.inputs) 38 | outputs = list(unsqueeze_node.outputs) 39 | attrs = reduce_node.attrs 40 | reduce_node.outputs.clear() 41 | unsqueeze_node.inputs.clear() 42 | 
--------------------------------------------------------------------------------
/onnxslim/core/pattern/fusion/reduce.py:
--------------------------------------------------------------------------------
1 | from onnxslim.core.pattern import Pattern, PatternMatcher
2 | from onnxslim.core.pattern.registry import register_fusion_pattern
3 | 
4 | 
5 | class ReducePatternMatcher(PatternMatcher):
6 |     def __init__(self, priority):
7 |         """Initializes the ReducePatternMatcher with a specified pattern matching priority level."""
8 |         pattern = Pattern(
9 |             """
10 |             input     input       0 1 reduce_0
11 |             ReduceSum reduce_0    1 1 input unsqueeze_0
12 |             Unsqueeze unsqueeze_0 1 1 reduce_0 output
13 |             output    output      1 0 unsqueeze_0
14 |             """
15 |         )
16 |         super().__init__(pattern, priority)
17 | 
18 |     @property
19 |     def name(self):
20 |         """Returns the name of the fusion pattern 'FusionReduce'."""
21 |         return "FusionReduce"
22 | 
23 |     def rewrite(self, opset=11):
24 |         """Rewrites the graph pattern based on opset version; reuses Reduce and Unsqueeze nodes if possible."""
25 |         match_case = {}
26 |         node = self.unsqueeze_0
27 |         reduce_node = self.reduce_0
28 |         reduce_node_node_users = reduce_node.users
29 |         if len(reduce_node_node_users) == 1:
30 |             unsqueeze_node = node
31 | 
32 |             reduce_node_axes = reduce_node.attrs.get("axes", None)
33 |             reduce_node_keepdims = reduce_node.attrs.get("keepdims", 1)
34 |             unsqueeze_node_axes = unsqueeze_node.attrs.get("axes", None)
35 | 
36 |             if opset < 13 and reduce_node_axes == [-1] and unsqueeze_node_axes == [-1] and reduce_node_keepdims == 0:
37 |                 inputs = list(reduce_node.inputs)
38 |                 outputs = list(unsqueeze_node.outputs)
39 |                 attrs = reduce_node.attrs
40 |                 reduce_node.outputs.clear()
41 |                 unsqueeze_node.inputs.clear()
42 |                 unsqueeze_node.outputs.clear()
43 |                 attrs["keepdims"] = 1
44 |                 match_case[reduce_node.name] = {
45 |                     "op": reduce_node.op,
46 |                     "inputs": inputs,
47 |                     "outputs": outputs,
48 |                     "name": reduce_node.name,
49 |                     "attrs": attrs,
50 |                     "domain": None,
51 |                 }
52 | 
53 |         return match_case
54 | 
55 | 
56 | register_fusion_pattern(ReducePatternMatcher(1))
57 | 
--------------------------------------------------------------------------------
/onnxslim/core/pattern/registry.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | 
3 | DEFAULT_FUSION_PATTERNS = OrderedDict()
4 | 
5 | 
6 | def register_fusion_pattern(fusion_pattern):
7 |     """Registers a fusion pattern instance under its name in the DEFAULT_FUSION_PATTERNS dictionary."""
8 |     layer_type = fusion_pattern.name
9 | 
10 |     if layer_type in DEFAULT_FUSION_PATTERNS:
11 |         raise ValueError(f"Fusion pattern {layer_type} is already registered")
12 |     DEFAULT_FUSION_PATTERNS[layer_type] = fusion_pattern
13 | 
14 | 
15 | def get_fusion_patterns(skip_fusion_patterns: str = None):
16 |     """Returns a copy of the default fusion patterns, optionally excluding specific patterns."""
17 |     default_fusion_patterns = DEFAULT_FUSION_PATTERNS.copy()
18 |     if skip_fusion_patterns:
19 |         for pattern in skip_fusion_patterns:
20 |             default_fusion_patterns.pop(pattern)
21 | 
22 |     return default_fusion_patterns
23 | 
24 | 
25 | from .elimination import *
26 | from .fusion import *
27 | 
--------------------------------------------------------------------------------
/onnxslim/misc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/onnxslim/misc/__init__.py
--------------------------------------------------------------------------------
/onnxslim/misc/font.py:
--------------------------------------------------------------------------------
1 | RED = "\033[31m"  # Red text
2 | WHITE = "\033[37m"  # White text
3 | GREEN = "\033[32m"  # Green text
4 | 
--------------------------------------------------------------------------------
/onnxslim/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/onnxslim/third_party/__init__.py
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/__init__.py:
--------------------------------------------------------------------------------
1 | from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import export_onnx
2 | from onnxslim.third_party.onnx_graphsurgeon.graph_pattern import (
3 |     GraphPattern,
4 |     PatternMapping,
5 | )
6 | from onnxslim.third_party.onnx_graphsurgeon.importers.onnx_importer import import_onnx
7 | from onnxslim.third_party.onnx_graphsurgeon.ir.function import Function
8 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
9 | from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
10 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
11 | from onnxslim.third_party.onnx_graphsurgeon.util.exception import (
12 |     OnnxGraphSurgeonException,
13 | )
14 | 
15 | __version__ = "0.5.1"
16 | 
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py:
--------------------------------------------------------------------------------
1 | 
from onnxslim.third_party.onnx_graphsurgeon.exporters.base_exporter import BaseExporter 2 | -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph 19 | 20 | 21 | class BaseExporter: 22 | @staticmethod 23 | def export_graph(graph: Graph): 24 | """ 25 | Export a graph to some destination graph. 26 | 27 | Args: 28 | graph (Graph): The source graph to export. 29 | 30 | Returns: 31 | object: The exported graph. For example, this might be an onnx.GraphProto 32 | """ 33 | raise NotImplementedError("BaseExporter is an abstract class") 34 | -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | from collections import OrderedDict 18 | from typing import List, Sequence, Union 19 | 20 | import numpy as np 21 | import onnx 22 | import onnx.numpy_helper 23 | 24 | from onnxslim.third_party.onnx_graphsurgeon.exporters.base_exporter import BaseExporter 25 | from onnxslim.third_party.onnx_graphsurgeon.ir.function import Function 26 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph 27 | from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node 28 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import ( 29 | Constant, 30 | LazyValues, 31 | SparseValues, 32 | Tensor, 33 | Variable, 34 | ) 35 | from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER 36 | from onnxslim.third_party.onnx_graphsurgeon.util import misc 37 | 38 | 39 | def dtype_to_onnx(dtype: Union[np.dtype, "onnx.TensorProto.DataType"]) -> int: 40 | """Converts a numpy dtype or ONNX data type to its integer representation.""" 41 | if isinstance(dtype, int): 42 | return dtype 43 | return onnx.helper.np_dtype_to_tensor_dtype(np.dtype(dtype)) 44 | 45 | 46 | def check_duplicate_node_names(nodes: Sequence[Node], level=G_LOGGER.WARNING): 47 | """Check if node names are unique and log any duplicates based on the specified severity level.""" 48 | # Note: 49 | # Empty string or None attribute values are not considered duplicates. 50 | name_map = {} 51 | for node in nodes: 52 | if not node.name: 53 | continue 54 | if node.name in name_map: 55 | msg = f"Found distinct Nodes that share the same name:\n[id: {id(name_map[node.name])}]:\n {name_map[node.name]}---\n[id: {id(node)}]:\n {node}\n" 56 | G_LOGGER.log(msg, level) 57 | else: 58 | name_map[node.name] = node 59 | 60 | 61 | def update_import_domains(graph): 62 | """Update the import_domains field of a graph to include its ONNX opset and other used non-ONNX domains.""" 63 | # as well as other non-ONNX domains which are used by this graph's nodes. 64 | # Returns the updated value of the import_domains field. 65 | 66 | # Add domain of the standard ONNX opset. 67 | if graph.import_domains is None: 68 | graph.import_domains = [onnx.helper.make_opsetid("", graph.opset)] 69 | 70 | # Crawl over all nodes in this graph and its subgraphs, and add the nodes' domains. 71 | all_used_domains = {node.domain for node in graph.nodes} 72 | for subgraph in graph.subgraphs(recursive=True): 73 | all_used_domains |= {n.domain for n in subgraph.nodes} 74 | all_used_domains.discard(None) 75 | 76 | # Update self.import_domains with any missing domains. 
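# For example (hypothetical values): a graph at opset 17 whose nodes also use the
# "com.microsoft" domain ends up with import_domains equivalent to
# [onnx.helper.make_opsetid("", 17), onnx.helper.make_opsetid("com.microsoft", 1)].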
77 | current_domains = {opsetid.domain for opsetid in graph.import_domains} 78 | DEFAULT_CUSTOM_OPSET_VERSION = 1 79 | for used_domain in all_used_domains: 80 | if used_domain not in current_domains: 81 | graph.import_domains.append(onnx.helper.make_opsetid(used_domain, DEFAULT_CUSTOM_OPSET_VERSION)) 82 | current_domains.add(used_domain) 83 | return graph.import_domains 84 | 85 | 86 | # Converts a fp32 gs.Constant to a bf16 onnx.TensorProto 87 | def tensor_to_onnx_bf16(tensor: Constant): 88 | """Converts an fp32 gs.Constant tensor to a bf16 onnx.TensorProto.""" 89 | 90 | def np_float32_to_bf16_as_uint16(arr): 91 | new_arr = np.empty(arr.size, dtype=np.uint16) 92 | flatten = arr.flatten() 93 | for i in range(arr.size): 94 | new_arr[i] = onnx.helper.float32_to_bfloat16(flatten[i]) 95 | return new_arr.reshape(arr.shape) 96 | 97 | arr_bf16_as_uint16 = np_float32_to_bf16_as_uint16(tensor.values) 98 | 99 | onnx_tensor = onnx.TensorProto() 100 | onnx_tensor.data_type = onnx.TensorProto.BFLOAT16 101 | onnx_tensor.dims.extend(arr_bf16_as_uint16.shape) 102 | onnx_tensor.raw_data = arr_bf16_as_uint16.tobytes() 103 | 104 | return onnx_tensor 105 | 106 | 107 | class OnnxExporter(BaseExporter): 108 | @staticmethod 109 | def export_tensor_proto(tensor: Constant) -> onnx.TensorProto: 110 | # Do *not* load LazyValues into an intermediate numpy array - instead, use 111 | """Converts a gs.Constant tensor to an onnx.TensorProto with type and data location handling.""" 112 | # the original onnx.TensorProto directly. 113 | if isinstance(tensor._values, LazyValues): 114 | onnx_tensor = tensor._values.tensor 115 | else: 116 | if dtype_to_onnx(tensor.dtype) != dtype_to_onnx(tensor.export_dtype): 117 | assert tensor.dtype == np.float32, ( 118 | f"Cannot convert onnx dtype {dtype_to_onnx(tensor.dtype)} to {dtype_to_onnx(tensor.export_dtype)}." 119 | "Only float32 to bfloat16 is supported" 120 | ) 121 | assert tensor.export_dtype == onnx.TensorProto.BFLOAT16, ( 122 | f"Cannot convert onnx dtype {dtype_to_onnx(tensor.dtype)} to {dtype_to_onnx(tensor.export_dtype)}." 123 | "Only float32 to bfloat16 is supported" 124 | ) 125 | onnx_tensor = tensor_to_onnx_bf16(tensor) 126 | else: 127 | onnx_tensor = onnx.numpy_helper.from_array(tensor.values) 128 | 129 | if tensor.data_location is not None: 130 | onnx_tensor.data_location = tensor.data_location 131 | onnx_tensor.name = tensor.name 132 | return onnx_tensor 133 | 134 | @staticmethod 135 | def export_sparse_tensor_proto(tensor: Constant) -> onnx.SparseTensorProto: 136 | """Exports a given Constant tensor as an ONNX SparseTensorProto.""" 137 | return tensor._values.tensor 138 | 139 | @staticmethod 140 | def export_value_info_proto(tensor: Tensor, do_type_check: bool) -> onnx.ValueInfoProto: 141 | """Creates an ONNX ValueInfoProto from a Tensor, optionally checking for dtype information.""" 142 | if do_type_check and tensor.dtype is None: 143 | G_LOGGER.critical( 144 | f"Graph input and output tensors must include dtype information. 
Please set the dtype attribute for: {tensor}" 145 | ) 146 | 147 | if tensor.dtype is None: 148 | onnx_tensor = onnx.helper.make_empty_tensor_value_info(tensor.name) 149 | elif isinstance(tensor, Constant) or tensor.type == "tensor_type": 150 | onnx_tensor = onnx.helper.make_tensor_value_info(tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape) 151 | elif tensor.type == "sequence_type": 152 | onnx_tensor = onnx.helper.make_tensor_sequence_value_info( 153 | tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape 154 | ) 155 | elif tensor.type == "sparse_tensor_type": 156 | onnx_tensor = onnx.helper.make_sparse_tensor_value_info( 157 | tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape 158 | ) 159 | return onnx_tensor 160 | 161 | @staticmethod 162 | def export_attributes(attrs: dict, subgraph_tensor_map=None) -> List[onnx.AttributeProto]: 163 | """Convert function attributes to ONNX AttributeProtos for model export.""" 164 | onnx_attrs: List[onnx.AttributeProto] = [] 165 | for key, val in attrs.items(): 166 | if isinstance(val, Tensor): 167 | val = OnnxExporter.export_tensor_proto(val) 168 | elif isinstance(val, Graph): 169 | # Subgraphs don't need to have types specified for their tensors. 170 | val = OnnxExporter.export_graph(val, subgraph_tensor_map=subgraph_tensor_map, do_type_check=False) 171 | elif isinstance(val, Node.AttributeRef): 172 | onnx_attr = onnx.AttributeProto() 173 | onnx_attr.name = key 174 | onnx_attr.type = misc.convert_to_onnx_attr_type(val.type) 175 | 176 | # Netron has a bug which makes it crash if a Tensor attribute has no tensor data. 177 | # So provide some meaningless tensor data for Netron to read. 178 | if val.type == Tensor: 179 | tensor_proto = OnnxExporter.export_tensor_proto(Constant("", np.array([0], dtype=np.float32))) 180 | onnx_attr.t.CopyFrom(tensor_proto) 181 | 182 | onnx_attr.ref_attr_name = val.name 183 | onnx_attrs.append(onnx_attr) 184 | continue 185 | elif isinstance(val, type): 186 | # May be a numpy type 187 | try: 188 | val = dtype_to_onnx(val) 189 | except TypeError: 190 | pass 191 | onnx_attrs.append(onnx.helper.make_attribute(key, val)) 192 | return onnx_attrs 193 | 194 | @staticmethod 195 | def export_node(node: Node, subgraph_tensor_map=None) -> onnx.NodeProto: 196 | # Cannot pass in attrs directly as make_node will change the order 197 | """Static method to convert an internal node to an ONNX node representation.""" 198 | onnx_node = onnx.helper.make_node( 199 | node.op, 200 | inputs=[t.name for t in node.inputs], 201 | outputs=[t.name for t in node.outputs], 202 | name=node.name, 203 | domain=node.domain, 204 | ) 205 | onnx_node.attribute.extend(OnnxExporter.export_attributes(node.attrs, subgraph_tensor_map)) 206 | return onnx_node 207 | 208 | @staticmethod 209 | def export_function(func: Function) -> onnx.FunctionProto: 210 | """ 211 | Export an onnx-graphsurgeon Function to an ONNX FunctionProto. 212 | 213 | Args: 214 | func (Function): The function to export. 215 | """ 216 | # Unlike onnx Graphs, onnx Functions don't have an 'initializer' field. 217 | # So we need to replace all Constant tensors with onnx Constant nodes which produce them. 218 | # We need to be careful to (a) preserve topological ordering and (b) not make the new nodes visible to the user. 
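        # Illustrative sketch (tensor name assumed): a Constant `w` used inside the function is
        # re-emitted as a producer node,
        #   Node("Constant", attrs={"value": w}, outputs=[w.copy()]),
        # and prepended to the node list, so the exported FunctionProto carries the weight as a
        # regular node rather than as an initializer.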
219 | func_nodes = func.nodes.copy() 220 | new_const_nodes = [ 221 | Node("Constant", attrs={"value": tensor}, outputs=[tensor.copy()]) 222 | for tensor in func.tensors().values() 223 | if isinstance(tensor, Constant) 224 | ] 225 | # Const nodes have no inputs, so this maintains a topological ordering. 226 | func_nodes = new_const_nodes + func_nodes 227 | 228 | check_duplicate_node_names(func_nodes, level=G_LOGGER.WARNING) 229 | nodes = [OnnxExporter.export_node(node) for node in func_nodes] 230 | 231 | # Update the import_domains field to include all domains used by this function. 232 | opset_imports = update_import_domains(func) 233 | 234 | onnx_inputs = [inp.name for inp in func.inputs] 235 | onnx_outputs = [out.name for out in func.outputs] 236 | 237 | attributes = [] 238 | attribute_protos = {} 239 | for attr_name, default_val in func.attrs.items(): 240 | if default_val is None: 241 | attributes.append(attr_name) 242 | else: 243 | attribute_protos[attr_name] = default_val 244 | attribute_protos = OnnxExporter.export_attributes(attribute_protos) 245 | 246 | return onnx.helper.make_function( 247 | func.domain or "", 248 | func.name, 249 | onnx_inputs, 250 | onnx_outputs, 251 | nodes, 252 | opset_imports, 253 | attributes=attributes, 254 | attribute_protos=attribute_protos, 255 | doc_string=func.doc_string, 256 | ) 257 | 258 | @staticmethod 259 | def export_graph( 260 | graph: Graph, 261 | tensor_map: "OrderedDict[str, Tensor]" = None, 262 | subgraph_tensor_map: "OrderedDict[str, Tensor]" = None, 263 | do_type_check=True, 264 | ) -> onnx.GraphProto: 265 | """ 266 | Export an onnx-graphsurgeon Graph to an ONNX GraphProto. 267 | 268 | Args: 269 | graph (Graph): The graph to export. 270 | 271 | do_type_check (bool): Whether to check that input and output tensors have data types defined, and fail if not. 272 | Defaults to True. 
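            tensor_map (OrderedDict[str, Tensor]): Optionally, a precomputed mapping of tensor names to
                tensors for this graph; when omitted, it is recomputed via graph.tensors().
                (Description inferred from the implementation below.)
            subgraph_tensor_map (OrderedDict[str, Tensor]): Optionally, tensors shared by all of this
                graph's subgraphs; export_onnx below passes the constants common to every subgraph here
                so they are emitted once rather than duplicated per subgraph.
                (Description inferred from the implementation below.)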
273 | """ 274 | check_duplicate_node_names(graph.nodes, level=G_LOGGER.WARNING) 275 | nodes = [OnnxExporter.export_node(node, subgraph_tensor_map) for node in graph.nodes] 276 | inputs = [OnnxExporter.export_value_info_proto(inp, do_type_check) for inp in graph.inputs] 277 | outputs = [OnnxExporter.export_value_info_proto(out, do_type_check) for out in graph.outputs] 278 | if tensor_map is None: 279 | tensor_map = graph.tensors() 280 | tensor_map = misc.unique_dicts(tensor_map, subgraph_tensor_map) 281 | else: 282 | tensor_map = misc.combine_dicts(tensor_map, subgraph_tensor_map) 283 | initializer = [ 284 | OnnxExporter.export_tensor_proto(tensor) 285 | for tensor in tensor_map.values() 286 | if isinstance(tensor, Constant) and not isinstance(tensor._values, SparseValues) 287 | ] 288 | 289 | sparse_initializer = [ 290 | OnnxExporter.export_sparse_tensor_proto(tensor) 291 | for tensor in tensor_map.values() 292 | if isinstance(tensor, Constant) and isinstance(tensor._values, SparseValues) 293 | ] 294 | 295 | # Remove inputs and outputs to export ValueInfoProtos 296 | for tensor in graph.inputs + graph.outputs: 297 | if tensor.name in tensor_map: 298 | del tensor_map[tensor.name] 299 | 300 | # Omit tensors from value_info if we don't know their shape/dtype 301 | def has_value_info(tensor): 302 | """Check if a tensor is a Variable with either a defined dtype or shape.""" 303 | return isinstance(tensor, Variable) and (tensor.dtype is not None or tensor.shape is not None) 304 | 305 | value_info = [ 306 | OnnxExporter.export_value_info_proto(tensor, do_type_check) 307 | for tensor in tensor_map.values() 308 | if has_value_info(tensor) 309 | ] 310 | 311 | return onnx.helper.make_graph( 312 | nodes=nodes, 313 | name=graph.name, 314 | inputs=inputs, 315 | outputs=outputs, 316 | initializer=initializer, 317 | sparse_initializer=sparse_initializer, 318 | doc_string=graph.doc_string, 319 | value_info=value_info, 320 | ) 321 | 322 | 323 | def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> "onnx.ModelProto": 324 | """ 325 | Exports an onnx-graphsurgeon Graph to an ONNX model. 326 | 327 | Args: 328 | graph (Graph): The graph to export 329 | 330 | do_type_check (bool): Whether to check that input and output tensors have data types defined, and fail if not. 331 | Defaults to True. 332 | kwargs: Additional arguments to onnx.helper.make_model 333 | 334 | Returns: 335 | onnx.ModelProto: A corresponding ONNX model. 
336 | """ 337 | sub_graphs = graph.subgraphs(recursive=True) 338 | 339 | graph_constants_list = [ 340 | {name: tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, Constant)} 341 | for sub_graph in sub_graphs 342 | ] 343 | 344 | if not graph_constants_list: 345 | intersection = None 346 | else: 347 | intersection = ( 348 | { 349 | key: graph_constants_list[0][key] 350 | for key in graph_constants_list[0] 351 | if all(key in d and graph_constants_list[0][key] == d[key] for d in graph_constants_list[1:]) 352 | } 353 | if graph_constants_list 354 | else None 355 | ) 356 | 357 | onnx_graph = OnnxExporter.export_graph( 358 | graph, tensor_map=graph.tensors(), subgraph_tensor_map=intersection, do_type_check=do_type_check 359 | ) 360 | onnx_functions = [OnnxExporter.export_function(func) for func in graph.functions] 361 | kwargs["functions"] = onnx_functions 362 | 363 | if "opset_imports" not in kwargs: 364 | kwargs["opset_imports"] = update_import_domains(graph) 365 | 366 | if "ir_version" not in kwargs and graph.ir_version is not None: 367 | kwargs["ir_version"] = graph.ir_version 368 | 369 | model = onnx.helper.make_model(onnx_graph, **kwargs) 370 | if graph.metadata_props is not None: 371 | model.metadata_props.extend(graph.metadata_props) 372 | model.producer_name = graph.producer_name 373 | model.producer_version = graph.producer_version 374 | return model 375 | -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py: -------------------------------------------------------------------------------- 1 | from onnxslim.third_party.onnx_graphsurgeon.graph_pattern.graph_pattern import ( 2 | GraphPattern, 3 | PatternMapping, 4 | ) 5 | -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py: -------------------------------------------------------------------------------- 1 | from onnxslim.third_party.onnx_graphsurgeon.importers.base_importer import BaseImporter 2 | -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph 19 | 20 | 21 | class BaseImporter: 22 | @staticmethod 23 | def import_graph(graph) -> Graph: 24 | """ 25 | Import a graph from some source graph. 26 | 27 | Args: 28 | graph (object): The source graph to import. For example, this might be an onnx.GraphProto. 29 | 30 | Returns: 31 | Graph: The equivalent onnx-graphsurgeon graph. 
32 | """ 33 | raise NotImplementedError("BaseImporter is an abstract class") 34 | -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/ir/function.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import copy 19 | from typing import List, Sequence 20 | 21 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph 22 | from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node 23 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Tensor, Variable 24 | from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER 25 | from onnxslim.third_party.onnx_graphsurgeon.util import misc 26 | 27 | 28 | class Function(Graph): 29 | """ 30 | Represents a local function, which is a default implementation of a Custom Op. This default implementation is 31 | represented as a Graph of other Ops. 32 | 33 | Functions are used in a model by creating a Node with the same name and domain as the function. This can be done 34 | using the __call__() method of a Function, which creates this new node and appends it to a Graph. A Function is not 35 | a subgraph of a Graph, and its Nodes, Tensors, and subgraphs are entirely separate from the main Graph. 36 | 37 | Functions can be composed of other functions, but cyclical or recursive definitions are not allowed in ONNX. 38 | """ 39 | 40 | DEFAULT_DOMAIN = "onnx_graphsurgeon" 41 | 42 | def __init__( 43 | self, 44 | name: str, 45 | domain: str = None, 46 | nodes: Sequence[Node] = None, 47 | inputs: Sequence[Tensor] = None, 48 | outputs: Sequence[Tensor] = None, 49 | doc_string: str = None, 50 | opset: int = None, 51 | import_domains: "Sequence[onnx.OperatorSetIdProto]" = None, 52 | functions: "Sequence[Function]" = None, 53 | attrs: dict = None, 54 | ): 55 | """ 56 | Args: 57 | name (str): The name of the function. 58 | domain (str): The domain/namespace of this function. 59 | nodes (Sequence[Node]): A list of the nodes in this function. 60 | inputs (Sequence[Tensor]): A list of graph input Tensors. 61 | outputs (Sequence[Tensor]): A list of graph output Tensors. 62 | doc_string (str): A doc_string for the function. Defaults to "". 63 | opset (int): The ONNX opset used by nodes in this function. 64 | import_domains (Sequence[onnx.OperatorSetIdProto]): The list of domains used by nodes in this function. 
65 | functions (Sequence[Function]): The list of functions in this model. 66 | attrs (dict): A mapping of attribute names to their default values. 67 | Nodes within this function can have attributes which take on the values of the Function attributes. 68 | When a Function is instantiated into a Node, providing attributes to that Node will override the Function's 69 | default attribute values. A default value of `None` means that the instantiated Node must provide the value 70 | of that attribute (in other words, it is a required attribute). 71 | """ 72 | self.domain = misc.default_value(domain, Function.DEFAULT_DOMAIN) 73 | self.attrs = misc.default_value(attrs, {}) 74 | 75 | super().__init__( 76 | nodes, 77 | inputs, 78 | outputs, 79 | name=name, 80 | doc_string=doc_string, 81 | opset=opset, 82 | import_domains=import_domains, 83 | functions=functions, 84 | ) 85 | 86 | # Properties of Graph that Function doesn't have. 87 | del self.producer_name 88 | del self.producer_version 89 | 90 | @property 91 | def unique_id(self): 92 | """Returns a tuple which uniquely identifies this function.""" 93 | return (self.domain, self.name) 94 | 95 | def cleanup( 96 | self, 97 | remove_unused_node_outputs=False, 98 | recurse_subgraphs=True, 99 | remove_unused_graph_inputs=False, 100 | recurse_functions=False, 101 | ): 102 | """See Graph.cleanup() The only difference is that 'recurse_functions' defaults to False, so that only this 103 | Function is cleaned up. 104 | """ 105 | if recurse_functions: 106 | G_LOGGER.warning( 107 | "Function.cleanup() called with recurse_functions=True, meaning that other functions will also be cleaned up." 108 | ) 109 | return super().cleanup( 110 | remove_unused_node_outputs=remove_unused_node_outputs, 111 | recurse_subgraphs=recurse_subgraphs, 112 | remove_unused_graph_inputs=remove_unused_graph_inputs, 113 | recurse_functions=recurse_functions, 114 | ) 115 | 116 | def fold_constants(self, recurse_functions=False, **kwargs): 117 | """See Graph.fold_constants() The only difference is that 'recurse_functions' defaults to False, so that only 118 | this Function's constants are folded. 119 | """ 120 | if recurse_functions: 121 | G_LOGGER.warning( 122 | "Function.fold_constants() called with recurse_functions=True, meaning that other functions will also be const-folded." 123 | ) 124 | return super().fold_constants(recurse_functions=recurse_functions, **kwargs) 125 | 126 | def toposort( 127 | self, 128 | recurse_subgraphs=True, 129 | recurse_functions=False, 130 | mode="nodes", 131 | ): 132 | """See Graph.toposort() The only difference is that 'recurse_functions' defaults to False and mode defaults to 133 | "nodes", so that by default only this function's nodes will be sorted. 134 | """ 135 | if recurse_functions: 136 | G_LOGGER.warning( 137 | "Function.toposort() called with recurse_functions=True, meaning that other functions will be sorted." 138 | ) 139 | return super().toposort( 140 | recurse_subgraphs=recurse_subgraphs, 141 | recurse_functions=recurse_functions, 142 | mode=mode, 143 | ) 144 | 145 | def __call__(self, graph, inputs=None, outputs=None, *args, **kwargs) -> List[Tensor]: 146 | """ 147 | Creates a Node which is an instance of this function. The created node can be used in a Graph or another 148 | Function. 149 | 150 | The provided inputs are processed the same way as in Graph.layer(). 151 | If outputs are not provided, they are created based on the Function's outputs. 
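        Example (illustrative, names assumed):
            outs = my_func(graph, inputs=[x])  # appends an instance of this Function to graph.nodes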
152 | 
153 |         Args:
154 |             graph (Union[Graph, Function]): The Graph or Function to add the new node to.
155 |             inputs (List[Union[Tensor, str, numpy.ndarray]]): The list of inputs.
156 |             outputs (List[Union[Tensor, str, numpy.ndarray]]): The list of outputs.
157 |             attrs (Dict[str, Any]): A dictionary of attributes for the node.
158 |                 The attribute names should be a subset of this Function's attribute names.
159 |             args/kwargs: These are passed directly to the constructor of Node.
160 | 
161 |         Returns:
162 |             List[Tensor]: The output tensors of the node.
163 |         """
164 |         if inputs is not None and len(inputs) != len(self.inputs):
165 |             msg_template = "Function {} expects {} inputs, but was called with {} inputs."
166 |             G_LOGGER.warning(msg_template.format(self.name, len(self.inputs), len(inputs)))
167 | 
168 |         new_output_indices = []
169 |         if outputs is None:
170 |             # Graph.layer() will create Tensors and make sure the names do not conflict.
171 |             outputs = [out.name for out in self.outputs]
172 |             new_output_indices = list(range(len(outputs)))
173 |         elif len(outputs) != len(self.outputs):
174 |             msg_template = "Function {} expects {} outputs, but was called with {} outputs."
175 |             G_LOGGER.warning(msg_template.format(self.name, len(self.outputs), len(outputs)))
176 |         else:
177 |             new_output_indices = [i for i in range(len(outputs)) if not isinstance(outputs[i], Tensor)]
178 | 
179 |         attrs = kwargs.get("attrs", None)
180 |         if attrs is not None:
181 |             for attr_name, default_val in self.attrs.items():
182 |                 if default_val is None and attr_name not in attrs:
183 |                     msg_template = "Function {} called without required attribute: {}"
184 |                     G_LOGGER.warning(msg_template.format(self.name, attr_name))
185 | 
186 |         inputs = misc.default_value(inputs, [])
187 |         outputs = misc.default_value(outputs, [])
188 |         outputs = graph.layer(
189 |             *args,
190 |             **kwargs,
191 |             op=self.name,
192 |             domain=self.domain,
193 |             inputs=inputs,
194 |             outputs=outputs,
195 |         )
196 | 
197 |         # For newly created output tensors, set their shape and dtype to match the Function definition.
198 |         for i in new_output_indices:
199 |             outputs[i].dtype = self.outputs[i].dtype
200 |             outputs[i].shape = self.outputs[i].shape
201 | 
202 |         return outputs
203 | 
204 |     def copy(self):
205 |         """
206 |         Copy the function.
207 | 
208 |         This makes copies of all nodes and tensors in the function, but will not
209 |         do a deep-copy of weights or attributes (with the exception of ``Graph``
210 |         attributes, which will be copied using their ``copy`` method).
211 | 
212 |         Returns:
213 |             Function: A copy of the function. 
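        Example (illustrative):
            func_copy = func.copy()  # shares func's name and domain, so both have the same unique_id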
214 | """ 215 | local_tensor_copies = {n: t.copy() for n, t in self.tensors().items()} 216 | 217 | def get_tensor(name): 218 | """Retrieve a tensor by name from a deep-copied dictionary of tensors.""" 219 | return local_tensor_copies[name] if name else Variable.empty() 220 | 221 | # Next, copy nodes, and update inputs/outputs 222 | new_nodes = [] 223 | for node in self.nodes: 224 | new_node = node.copy( 225 | inputs=[get_tensor(inp.name) for inp in node.inputs], 226 | outputs=[get_tensor(out.name) for out in node.outputs], 227 | tensor_map=local_tensor_copies, 228 | ) 229 | new_nodes.append(new_node) 230 | new_func_inputs = [get_tensor(inp.name) for inp in self.inputs] 231 | new_func_outputs = [get_tensor(out.name) for out in self.outputs] 232 | 233 | new_attrs = {name: copy.copy(val) for name, val in self.attrs.items()} 234 | 235 | return Function( 236 | self.name, 237 | self.domain, 238 | nodes=new_nodes, 239 | inputs=new_func_inputs, 240 | outputs=new_func_outputs, 241 | doc_string=self.doc_string, 242 | opset=self.opset, 243 | import_domains=self.import_domains, 244 | functions=self.functions, 245 | attrs=new_attrs, 246 | ) 247 | 248 | def __eq__(self, other: "Function"): 249 | """Checks equality of self with another Function object based on their attributes.""" 250 | 251 | def sequences_equal(seq1, seq2): 252 | """Checks if two sequences are equal in length and elements.""" 253 | return len(seq1) == len(seq2) and all(elem1 == elem2 for elem1, elem2 in zip(seq1, seq2)) 254 | 255 | return ( 256 | self.unique_id == other.unique_id 257 | and self.opset == other.opset 258 | and self.import_domains == other.import_domains 259 | and sequences_equal(self.inputs, other.inputs) 260 | and sequences_equal(self.outputs, other.outputs) 261 | and sequences_equal(self.nodes, other.nodes) 262 | ) 263 | 264 | def __str__(self): 265 | """Returns a string representation of the function including its name, domain, opset, inputs, nodes, and 266 | outputs. 267 | """ 268 | nodes_str = "\n".join([str(node) for node in self.nodes]) 269 | out = f"Function {self.name}, Domain {self.domain}, Opset {self.opset}" 270 | out += f"\nInputs: {self.inputs}" 271 | out += f"\nNodes: {nodes_str}" 272 | out += f"\nOutputs: {self.outputs}" 273 | return out 274 | -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/ir/node.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | from collections import OrderedDict 19 | from dataclasses import dataclass 20 | from typing import Dict, List 21 | 22 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable 23 | from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER 24 | from onnxslim.third_party.onnx_graphsurgeon.util import misc 25 | 26 | 27 | class Node: 28 | @dataclass 29 | class AttributeRef: 30 | """ 31 | An AttributeRef is an attribute value which references an attribute in the parent function. A node's attribute 32 | can only be an AttributeRef if the node lives inside a Function. 33 | 34 | Args: 35 | name (str): The name of the referenced attribute in the parent Function. 36 | type (type): The attribute's type. 37 | """ 38 | 39 | name: str 40 | type: type 41 | 42 | def __init__( 43 | self, 44 | op: str, 45 | name: str = None, 46 | attrs: Dict[str, object] = None, 47 | inputs: List["Tensor"] = None, 48 | outputs: List["Tensor"] = None, 49 | domain: str = None, 50 | ): 51 | """ 52 | A node represents an operation in a graph, and consumes zero or more Tensors, and produces zero or more Tensors. 53 | 54 | Args: 55 | op (str): The operation this node performs. 56 | 57 | name (str): The name of this node. 58 | attrs (Dict[str, object]): A dictionary that maps attribute names to their values. 59 | inputs (List[Tensor]): A list of zero or more input Tensors. 60 | outputs (List[Tensor]): A list of zero or more output Tensors. 61 | domain (str): The domain of this node, 62 | """ 63 | self.op = op 64 | self.name = misc.default_value(name, "") 65 | self.attrs = misc.default_value(attrs, OrderedDict()) 66 | self.inputs = misc.SynchronizedList(self, field_name="outputs", initial=misc.default_value(inputs, [])) 67 | self.outputs = misc.SynchronizedList(self, field_name="inputs", initial=misc.default_value(outputs, [])) 68 | self.domain = domain 69 | 70 | def i(self, tensor_idx=0, producer_idx=0): 71 | """ 72 | Convenience function to get a producer node of one of this node's input tensors. Note that the parameters are 73 | swapped compared to the o() function; this is because tensors are likely to have only a single producer. 74 | 75 | For example: 76 | :: 77 | 78 | assert node.i() == node.inputs[0].inputs[0] 79 | assert node.i(1, 2) == node.inputs[1].inputs[2] 80 | 81 | Args: 82 | tensor_idx (int): The index of the input tensor of this node. Defaults to 0. 83 | producer_idx (int): The index of the producer of the input tensor, if the tensor has multiple producers. Defaults to 0 84 | 85 | Returns: 86 | Node: The specified producer (input) node. 87 | """ 88 | return self.inputs[tensor_idx].inputs[producer_idx] 89 | 90 | def o(self, consumer_idx=0, tensor_idx=0): 91 | """ 92 | Convenience function to get a consumer node of one of this node's output tensors. 93 | 94 | For example: 95 | :: 96 | 97 | assert node.o() == node.outputs[0].outputs[0] 98 | assert node.o(2, 1) == node.outputs[1].outputs[2] 99 | 100 | Args: 101 | consumer_idx (int): The index of the consumer of the input tensor. Defaults to 0. 102 | tensor_idx (int): The index of the output tensor of this node, if the node has multiple outputs. Defaults to 0. 103 | 104 | Returns: 105 | Node: The specified consumer (output) node 106 | """ 107 | return self.outputs[tensor_idx].outputs[consumer_idx] 108 | 109 | def subgraphs(self, recursive=False): 110 | """ 111 | Convenience function to iterate over all subgraphs which are contained in this node. 
Node subgraphs are found in
112 |         attributes of ONNX control flow nodes such as 'If' and 'Loop'.
113 | 
114 |         Args:
115 |             recursive (bool): Whether to recurse into the subgraph nodes when looking for subgraphs. Defaults to False.
116 | 
117 |         Returns:
118 |             A generator which iterates over this node's subgraphs.
119 |         """
120 |         from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
121 | 
122 |         visit_queue = [self]
123 | 
124 |         # This prevents infinite recursion in the (illegal) case of cyclical graphs.
125 |         visited = set()
126 | 
127 |         while visit_queue:
128 |             node = visit_queue.pop()
129 |             for attr in node.attrs.values():
130 |                 if isinstance(attr, Graph) and id(attr) not in visited:
131 |                     visited.add(id(attr))
132 |                     if recursive:
133 |                         visit_queue.extend(attr.nodes)
134 |                     yield attr
135 | 
136 |     def __setattr__(self, name, value):
137 |         """Sets the attribute 'name' to 'value', handling special cases for 'inputs' and 'outputs' attributes."""
138 |         if name in {"inputs", "outputs"}:
139 |             try:
140 |                 attr = getattr(self, name)
141 |                 if value is attr:
142 |                     # This can happen when using things like +=
143 |                     # The __iadd__ is executed followed by an assignment
144 |                     return
145 | 
146 |                 attr.clear()
147 |                 attr.extend(value)
148 |             except AttributeError:
149 |                 super().__setattr__(name, value)
150 |         else:
151 |             super().__setattr__(name, value)
152 | 
153 |     def copy(
154 |         self,
155 |         inputs: List["Tensor"] = None,
156 |         outputs: List["Tensor"] = None,
157 |         tensor_map=None,
158 |     ):
159 |         """
160 |         Makes a shallow copy of this node, overriding input and output information.
161 | 
162 |         Note: Generally, you should only ever make a copy of a Graph.
163 |         """
164 |         from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
165 | 
166 |         new_attrs = OrderedDict()
167 |         for name, attr in self.attrs.items():
168 |             new_attrs[name] = attr.copy(tensor_map) if isinstance(attr, Graph) else attr
169 |         return Node(
170 |             self.op,
171 |             self.name,
172 |             new_attrs,
173 |             inputs=inputs,
174 |             outputs=outputs,
175 |             domain=self.domain,
176 |         )
177 | 
178 |     def __str__(self):
179 |         """Return a string representation of the object showing its name and operation."""
180 |         ret = f"{self.name} ({self.op})"
181 | 
182 |         def add_io(name, io):
183 |             """Add the input or output operations and their names to the string representation of the object."""
184 |             nonlocal ret
185 |             ret += f"\n\t{name}: ["
186 |             for elem in io:
187 |                 ret += f"\n\t\t{elem}"
188 |             ret += "\n\t]"
189 | 
190 |         add_io("Inputs", self.inputs)
191 |         add_io("Outputs", self.outputs)
192 | 
193 |         if self.attrs:
194 |             ret += f"\nAttributes: {self.attrs}"
195 | 
196 |         if self.domain:
197 |             ret += f"\nDomain: {self.domain}"
198 | 
199 |         return ret
200 | 
201 |     def __repr__(self):
202 |         """Return the string representation of this Node."""
203 |         return self.__str__()
204 | 
205 |     def __eq__(self, other):
206 |         """Check whether two nodes are equal by comparing name, attributes, op, inputs, and outputs."""
207 |         G_LOGGER.verbose(f"Comparing node: {self.name} with {other.name}")
208 |         attrs_match = self.name == other.name and self.op == other.op and self.attrs == other.attrs
209 |         if not attrs_match:
210 |             return False
211 | 
212 |         inputs_match = misc.sequences_equal(self.inputs, other.inputs)
213 |         if not inputs_match:
214 |             return False
215 | 
216 |         outputs_match = misc.sequences_equal(self.outputs, other.outputs)
217 |         return self.domain == other.domain if outputs_match else False
218 | 
219 |     @property
220 |     def users(self):
221 |         users = []
222 |         for 
output in self.outputs: # output is a Variable 223 | if output.is_output: 224 | users.append(output) 225 | users.extend(iter(output.outputs)) 226 | return users 227 | 228 | @property 229 | def feeds(self): 230 | """Retrieve the list of nodes that provide inputs to the given node.""" 231 | feeds = [] 232 | for input in self.inputs: 233 | if len(input.inputs) == 0 and not isinstance(input, Constant): 234 | feeds.append(input) 235 | elif isinstance(input, Constant): 236 | feeds.append(input) 237 | else: 238 | feeds.extend(input if feed.op == "Split" else feed for feed in input.inputs) 239 | return feeds 240 | 241 | def erase(self, input_var_idx=0, output_var_idx=0): 242 | if isinstance(self.inputs[input_var_idx], Variable): 243 | if self.inputs[input_var_idx].is_input: 244 | self.outputs[output_var_idx].replace_all_uses_with(self.inputs[input_var_idx]) 245 | self.inputs.clear() 246 | self.outputs.clear() 247 | else: 248 | self.inputs[input_var_idx].replace_all_uses_with(self.outputs[output_var_idx]) 249 | self.inputs.clear() 250 | self.outputs.clear() 251 | 252 | def replace_all_uses_with(self, node: "Node"): 253 | """Replace all uses of this node with the given node.""" 254 | for user in self.users: 255 | for inp in user.inputs: 256 | if inp in self.outputs: 257 | for i, input in enumerate(user.inputs): 258 | if input == inp: 259 | user.inputs[i] = node.outputs[self.outputs.index(inp)] 260 | 261 | self.inputs.clear() 262 | self.outputs.clear() 263 | -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from onnxslim.third_party.onnx_graphsurgeon.logger.logger import G_LOGGER, LogMode 2 | -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/logger/logger.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | import enum 19 | import inspect 20 | import os 21 | import sys 22 | import time 23 | 24 | from onnxslim.third_party.onnx_graphsurgeon.util.exception import ( 25 | OnnxGraphSurgeonException, 26 | ) 27 | 28 | 29 | # Context manager to apply indentation to messages 30 | class LoggerIndent: 31 | def __init__(self, logger, indent): 32 | """Initialize the LoggerIndent context manager with the specified logger and indentation level.""" 33 | self.logger = logger 34 | self.old_indent = self.logger.logging_indent 35 | self.indent = indent 36 | 37 | def __enter__(self): 38 | """Set logger indentation level on entering the context.""" 39 | self.logger.logging_indent = self.indent 40 | return self 41 | 42 | def __exit__(self, exc_type, exc_value, traceback): 43 | """Reset logger indentation level on exiting the context.""" 44 | self.logger.logging_indent = self.old_indent 45 | 46 | 47 | # Context manager to suppress messages 48 | class LoggerSuppress: 49 | def __init__(self, logger, severity): 50 | """Initialize a LoggerSuppress object with a logger and severity level.""" 51 | self.logger = logger 52 | self.old_severity = self.logger.severity 53 | self.severity = severity 54 | 55 | def __enter__(self): 56 | """Set logger severity to a specified level when entering the context.""" 57 | self.logger.severity = self.severity 58 | return self 59 | 60 | def __exit__(self, exc_type, exc_value, traceback): 61 | """Reset logger severity to its original level when exiting the context.""" 62 | self.logger.severity = self.old_severity 63 | 64 | 65 | class LogMode(enum.IntEnum): 66 | EACH = 0 # Log the message each time 67 | ONCE = 1 # Log the message only once. The same message will not be logged again. 68 | 69 | 70 | class Logger: 71 | ULTRA_VERBOSE = -10 72 | VERBOSE = 0 73 | DEBUG = 10 74 | INFO = 20 75 | WARNING = 30 76 | ERROR = 40 77 | CRITICAL = 50 78 | 79 | SEVERITY_LETTER_MAPPING = { 80 | ULTRA_VERBOSE: "[UV]", 81 | VERBOSE: "[V]", 82 | DEBUG: "[D]", 83 | INFO: "[I]", 84 | WARNING: "[W]", 85 | ERROR: "[E]", 86 | CRITICAL: "[C]", 87 | } 88 | 89 | SEVERITY_COLOR_MAPPING = { 90 | ULTRA_VERBOSE: "cyan", 91 | VERBOSE: "dark_gray", 92 | DEBUG: "light_gray", 93 | INFO: "light_green", 94 | WARNING: "light_yellow", 95 | ERROR: "red_1", 96 | CRITICAL: "red_1", 97 | } 98 | 99 | def __init__(self, severity=INFO, colors=True, letter=True, timestamp=False, line_info=False): 100 | """ 101 | Logger. 102 | 103 | Args: 104 | severity (Logger.Severity): Messages below this severity are ignored. 105 | colors (bool): Whether to use colored output. 106 | letter (bool): Whether to prepend each logging message with a letter indicating it's severity. Defaults to True. 107 | timestamp (bool): Whether to include a timestamp in the logging output. Defaults to False. 108 | line_info (bool): Whether to include file and line number information in the logging output. Defaults to False. 
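        Example (illustrative):
            logger = Logger(severity=Logger.DEBUG, timestamp=True, line_info=True)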
109 | """ 110 | self._severity = severity 111 | self.logging_indent = 0 112 | self.root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) 113 | self.once_logged = set() 114 | self.colors = colors 115 | self.letter = letter 116 | self.timestamp = timestamp 117 | self.line_info = line_info 118 | self.logger_callbacks = [] 119 | 120 | @property 121 | def severity(self): 122 | """Returns the logging severity level.""" 123 | return self._severity 124 | 125 | @severity.setter 126 | def severity(self, value): 127 | """Returns or sets the logging severity level with callback updates.""" 128 | self._severity = value 129 | for callback in self.logger_callbacks: 130 | callback(self._severity) 131 | 132 | def register_callback(self, callback): 133 | """ 134 | Registers a callback with the logger, which will be invoked when the logging severity is modified. The callback 135 | is guaranteed to be called at least once in the register_callback function. 136 | 137 | Args: 138 | callback (Callable(Logger.Severity)): A callback that accepts the current logger severity. 139 | """ 140 | callback(self._severity) 141 | self.logger_callbacks.append(callback) 142 | 143 | def indent(self, level=1): 144 | """Returns a context manager that indents all strings logged by the specified amount.""" 145 | return LoggerIndent(self, level + self.logging_indent) 146 | 147 | def suppress(self, severity=CRITICAL): 148 | """ 149 | Returns a context manager that temporarily changes the severity of the logger for its duration. 150 | 151 | Args: 152 | severity (Logger.Severity): The severity to set the logger to. Defaults to Logger.CRITICAL, which will suppress all messages. 153 | """ 154 | return LoggerSuppress(self, severity) 155 | 156 | # If once is True, the logger will only log this message a single time. Useful in loops. 157 | # message may be a callable which returns a message. This way, only if the message needs to be logged is it ever generated. 158 | def log(self, message, severity, mode=LogMode.EACH, stack_depth=2): 159 | """Logs a message with a specified severity and mode, supporting both single and repeated logging based on 160 | conditions. 161 | """ 162 | 163 | def process_message(message, stack_depth): 164 | """Generates a log message prefix with file name and line number based on the specified stack depth.""" 165 | 166 | def get_prefix(): 167 | def get_line_info(): 168 | module = inspect.getmodule(sys._getframe(stack_depth + 3)) or inspect.getmodule( 169 | sys._getframe(stack_depth + 2) 170 | ) 171 | filename = module.__file__ 172 | filename = os.path.relpath(filename, self.root_dir) 173 | # If the file is not located in trt_smeagol, use its basename instead. 
174 | if os.pardir in filename: 175 | filename = os.path.basename(filename) 176 | return f"[{filename}:{sys._getframe(stack_depth).f_lineno}] " 177 | 178 | prefix = "" 179 | if self.letter: 180 | prefix += f"{Logger.SEVERITY_LETTER_MAPPING[severity]} " 181 | if self.timestamp: 182 | prefix += "({:}) ".format(time.strftime("%X")) 183 | if self.line_info: 184 | prefix += get_line_info() 185 | return prefix 186 | 187 | def apply_indentation(message): 188 | """Indent each line in the message by the specified logging_indent level.""" 189 | message_lines = str(message).splitlines() 190 | return "\n".join(["\t" * self.logging_indent + line for line in message_lines]) 191 | 192 | def apply_color(message): 193 | """Apply color formatting to the message if color support is enabled.""" 194 | if self.colors: 195 | try: 196 | import colored 197 | 198 | color = Logger.SEVERITY_COLOR_MAPPING[severity] 199 | return colored.stylize(message, [colored.fg(color)]) 200 | except ImportError: 201 | self.colors = False 202 | self.warning( 203 | "colored module is not installed, will not use colors when logging. To enable colors, please install the colored module: python3 -m pip install colored" 204 | ) 205 | self.colors = True 206 | return message 207 | 208 | prefix = get_prefix() 209 | message = apply_indentation(message) 210 | return apply_color(f"{prefix}{message}") 211 | 212 | def should_log(message): 213 | """Determines if a message should be logged based on the severity level and logging mode.""" 214 | should = severity >= self._severity 215 | if mode == LogMode.ONCE: 216 | message_hash = hash(message) 217 | should &= message_hash not in self.once_logged 218 | self.once_logged.add(message_hash) 219 | return should 220 | 221 | if not should_log(message): 222 | return 223 | 224 | if callable(message): 225 | message = message() 226 | message = str(message) 227 | print(process_message(message, stack_depth=stack_depth)) 228 | 229 | def ultra_verbose(self, message, mode=LogMode.EACH): 230 | """Logs an ultra-verbose message with a specified mode and stack depth of 3.""" 231 | self.log(message, Logger.ULTRA_VERBOSE, mode=mode, stack_depth=3) 232 | 233 | def verbose(self, message, mode=LogMode.EACH): 234 | """Logs a verbose message with a specified mode and stack depth of 3.""" 235 | self.log(message, Logger.VERBOSE, mode=mode, stack_depth=3) 236 | 237 | def debug(self, message, mode=LogMode.EACH): 238 | """Logs a debug message with a specified mode and stack depth of 3.""" 239 | self.log(message, Logger.DEBUG, mode=mode, stack_depth=3) 240 | 241 | def info(self, message, mode=LogMode.EACH): 242 | """Logs an informational message with a specified mode and stack depth of 3.""" 243 | self.log(message, Logger.INFO, mode=mode, stack_depth=3) 244 | 245 | def warning(self, message, mode=LogMode.EACH): 246 | """Logs a warning message with a specified mode and stack depth of 3.""" 247 | self.log(message, Logger.WARNING, mode=mode, stack_depth=3) 248 | 249 | def error(self, message, mode=LogMode.EACH): 250 | """Logs an error message with a specified mode and stack depth of 3.""" 251 | self.log(message, Logger.ERROR, mode=mode, stack_depth=3) 252 | 253 | # Like error, but immediately exits. 
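    # A minimal usage sketch (illustrative, not from the source): critical() logs and then raises,
    # so callers typically treat it as a hard failure:
    #   try:
    #       G_LOGGER.critical("graph is malformed")
    #   except OnnxGraphSurgeonException:
    #       pass  # the message has already been printed at CRITICAL severity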
254 | def critical(self, message): 255 | """Logs a critical message with a stack depth of 3 and raises an OnnxGraphSurgeonException.""" 256 | self.log(message, Logger.CRITICAL, stack_depth=3) 257 | raise OnnxGraphSurgeonException(message) from None # Erase exception chain 258 | 259 | 260 | global G_LOGGER 261 | G_LOGGER = Logger() 262 | -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/onnxslim/third_party/onnx_graphsurgeon/util/__init__.py -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/util/exception.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | 19 | class OnnxGraphSurgeonException(Exception): 20 | """An exception raised by ONNX-GraphSurgeon.""" 21 | -------------------------------------------------------------------------------- /onnxslim/third_party/onnx_graphsurgeon/util/misc.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from collections import OrderedDict 19 | from typing import List, Sequence 20 | 21 | import numpy as np 22 | from onnx import AttributeProto 23 | 24 | from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER 25 | 26 | 27 | # default_value exists to solve issues that might result from Python's normal default argument behavior. 
28 | # Specifically, consider the following class: 29 | # 30 | # class MyClass(object): 31 | # def __init__(self, value=[]): 32 | # self.value = value 33 | # 34 | # This leads to unwanted behavior when the default value is used: 35 | # 36 | # >>> x = MyClass() 37 | # >>> x.value.append("SHOULD NOT BE IN Y") 38 | # >>> y = MyClass() 39 | # >>> y.value 40 | # ['SHOULD NOT BE IN Y'] 41 | # 42 | # If we rewrite the class using default value: 43 | # 44 | # class MyClass(object): 45 | # def __init__(self, value=None): 46 | # self.value = default_value(value, []) 47 | # 48 | # Then we get the desired behavior: 49 | # 50 | # >>> x = MyClass() 51 | # >>> x.value.append("SHOULD NOT BE IN Y") 52 | # >>> y = MyClass() 53 | # >>> y.value 54 | # [] 55 | def default_value(value, default): 56 | """Return the value if not None, otherwise return the default value.""" 57 | return value if value is not None else default 58 | 59 | 60 | def combine_dicts(dict0, dict1): 61 | """ 62 | Combine two dictionaries. 63 | 64 | Values in the second will overwrite values in the first. 65 | """ 66 | if dict1 is None: 67 | return dict0 68 | combined = OrderedDict() 69 | combined.update(dict0) 70 | combined.update(dict1) 71 | return combined 72 | 73 | 74 | def unique_dicts(dict0, dict1): 75 | """ 76 | Subtract two dictionaries. 77 | 78 | Values in the second will be subtracted from the first. 79 | """ 80 | return {k: v for k, v in dict0.items() if k not in dict1} if dict1 else dict0 81 | 82 | 83 | def is_dynamic_dimension(dim): 84 | """Check if a dimension is dynamic (non-integer or negative).""" 85 | return not isinstance(dim, int) or dim < 0 86 | 87 | 88 | def is_dynamic_shape(shape): 89 | """Determine if any dimension in the given shape is dynamic (non-integer or negative).""" 90 | return any(is_dynamic_dimension(dim) for dim in shape) 91 | 92 | 93 | def volume(obj): 94 | """Calculate the volume by multiplying the elements of an iterable object.""" 95 | vol = 1 96 | for elem in obj: 97 | vol *= elem 98 | return vol 99 | 100 | 101 | _ONNX_ATTR_TYPE_TO_GS_TYPE = {} 102 | _GS_TYPE_TO_ONNX_ATTR_TYPE = {} 103 | 104 | 105 | # This method prevents circular import of Tensor and Graph 106 | def _init_dicts(): 107 | """Initialize mapping dictionaries to prevent circular imports of Tensor and Graph.""" 108 | global _ONNX_ATTR_TYPE_TO_GS_TYPE 109 | global _GS_TYPE_TO_ONNX_ATTR_TYPE 110 | if _ONNX_ATTR_TYPE_TO_GS_TYPE and _GS_TYPE_TO_ONNX_ATTR_TYPE: 111 | return 112 | 113 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph 114 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Tensor 115 | 116 | _ONNX_ATTR_TYPE_TO_GS_TYPE = { 117 | AttributeProto.UNDEFINED: None, 118 | AttributeProto.FLOAT: float, 119 | AttributeProto.INT: int, 120 | AttributeProto.STRING: str, 121 | AttributeProto.TENSOR: Tensor, 122 | AttributeProto.GRAPH: Graph, 123 | AttributeProto.SPARSE_TENSOR: AttributeProto.SPARSE_TENSOR, 124 | AttributeProto.TYPE_PROTO: AttributeProto.TYPE_PROTO, 125 | AttributeProto.FLOATS: List[float], 126 | AttributeProto.INTS: List[int], 127 | AttributeProto.STRINGS: List[str], 128 | AttributeProto.TENSORS: List[Tensor], 129 | AttributeProto.GRAPHS: List[Graph], 130 | AttributeProto.SPARSE_TENSORS: AttributeProto.SPARSE_TENSORS, 131 | AttributeProto.TYPE_PROTOS: AttributeProto.TYPE_PROTOS, 132 | } 133 | _GS_TYPE_TO_ONNX_ATTR_TYPE = {v: k for k, v in _ONNX_ATTR_TYPE_TO_GS_TYPE.items()} 134 | 135 | 136 | def convert_from_onnx_attr_type(onnx_attr_type): 137 | """Converts an ONNX attribute type to its corresponding 
GS attribute type.""" 138 | _init_dicts() 139 | return _ONNX_ATTR_TYPE_TO_GS_TYPE[onnx_attr_type] 140 | 141 | 142 | def convert_to_onnx_attr_type(any_type): 143 | """Converts a given type to its corresponding ONNX attribute type.""" 144 | _init_dicts() 145 | if any_type in _GS_TYPE_TO_ONNX_ATTR_TYPE: 146 | return _GS_TYPE_TO_ONNX_ATTR_TYPE[any_type] 147 | if np.issubdtype(any_type, np.floating): 148 | return AttributeProto.FLOAT 149 | if np.issubdtype(any_type, np.integer): 150 | return AttributeProto.INT 151 | G_LOGGER.warning(f"Unable to convert {any_type} into an ONNX AttributeType") 152 | 153 | 154 | # Special type of list that synchronizes contents with another list. 155 | # Concrete example: Assume some node, n, contains an input tensor, t. If we remove t from n.inputs, 156 | # we also need to remove n from t.outputs. To avoid having to do this manually, we use SynchronizedList, 157 | # which takes an attribute name as a parameter, and then synchronizes to that attribute of each of its elements. 158 | # So, in the example above, we can make n.inputs a synchronized list whose field_name is set to "outputs". 159 | # See test_ir.TestNodeIO for functional tests 160 | class SynchronizedList(list): 161 | def __init__(self, parent_obj, field_name, initial): 162 | """Initialize a SynchronizedList with a parent object, a field name, and an initial set of elements.""" 163 | self.parent_obj = parent_obj 164 | self.field_name = field_name 165 | self.extend(initial) 166 | 167 | def _add_to_elem(self, elem): 168 | """Append the parent_obj to the list attribute defined by field_name in the provided elem object.""" 169 | list.append(getattr(elem, self.field_name), self.parent_obj) 170 | 171 | def _remove_from_elem(self, elem): 172 | """Remove the parent_obj from the list attribute defined by field_name in the provided elem object.""" 173 | list.remove(getattr(elem, self.field_name), self.parent_obj) 174 | 175 | def __delitem__(self, index): 176 | """Remove the element at the specified index and update the corresponding list attribute in the parent 177 | object. 178 | """ 179 | self._remove_from_elem(self[index]) 180 | super().__delitem__(index) 181 | 182 | def __setitem__(self, index, elem): 183 | """Update the element at the specified index and modify the corresponding list attribute in the parent 184 | object. 185 | """ 186 | self._remove_from_elem(self[index]) 187 | super().__setitem__(index, elem) 188 | self._add_to_elem(elem) 189 | 190 | def append(self, x): 191 | """Append an element to the list and update the parent object's corresponding list attribute.""" 192 | super().append(x) 193 | self._add_to_elem(x) 194 | 195 | def extend(self, iterable: Sequence[object]): 196 | """Extend the list with elements from an iterable and update the parent object's corresponding list 197 | attribute. 198 | """ 199 | super().extend(iterable) 200 | for elem in iterable: 201 | self._add_to_elem(elem) 202 | 203 | def insert(self, i, x): 204 | """Insert an element at a given position and update the parent object's corresponding list attribute.""" 205 | super().insert(i, x) 206 | self._add_to_elem(x) 207 | 208 | def remove(self, x): 209 | """Remove an element from the list and update the parent object's corresponding list attribute.""" 210 | super().remove(x) 211 | self._remove_from_elem(x) 212 | 213 | def pop(self, i=-1): 214 | """Remove and return the element at index i (default last) from the list and update the parent object's 215 | corresponding list attribute. 
216 | """ 217 | elem = super().pop(i) 218 | self._remove_from_elem(elem) 219 | return elem 220 | 221 | def clear(self): 222 | """Clear all elements from the list and update the parent object's corresponding list attribute.""" 223 | for elem in self: 224 | self._remove_from_elem(elem) 225 | super().clear() 226 | 227 | def __add__(self, other_list: List[object]): 228 | """Concatenate the current list with another list and return the resulting list.""" 229 | return list(self) + list(other_list) 230 | 231 | def __iadd__(self, other_list: List[object]): 232 | """Append elements from another list to the current list and return the modified list.""" 233 | self.extend(other_list) 234 | return self 235 | 236 | def __copy__(self): 237 | """Return a shallow copy of the current list.""" 238 | return list(self) 239 | 240 | def __deepcopy__(self, memo): 241 | """Return a deep copy of the current list.""" 242 | return list(self) 243 | 244 | 245 | def sequences_equal(seq1, seq2): 246 | """Check if two sequences are equal by comparing their lengths and elements.""" 247 | length_match = len(seq1) == len(seq2) 248 | if not length_match: 249 | return False 250 | 251 | return all(elem1 == elem2 for elem1, elem2 in zip(seq1, seq2)) 252 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open("VERSION") as f: 4 | version = f.read().strip() 5 | 6 | with open("onnxslim/version.py", "w") as f: 7 | f.write(f'__version__ = "{version}"\n') 8 | 9 | setup( 10 | name="onnxslim", 11 | version=version, 12 | description="OnnxSlim: A Toolkit to Help Optimize Onnx Model", 13 | long_description=open("README.md", encoding="utf-8").read(), 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/inisis/OnnxSlim", 16 | author="inisis", 17 | author_email="desmond.yao@buaa.edu.cn", 18 | project_urls={ 19 | "Bug Tracker": "https://github.com/inisis/OnnxSlim/issues", 20 | }, 21 | classifiers=[ 22 | "Programming Language :: Python :: 3", 23 | "License :: OSI Approved :: MIT License", 24 | "Intended Audience :: Developers", 25 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 26 | ], 27 | license="MIT", 28 | install_requires=["onnx", "sympy", "packaging"], 29 | packages=find_packages(exclude=("tests", "tests.*")), 30 | entry_points={"console_scripts": ["onnxslim=onnxslim.cli:main"]}, 31 | zip_safe=True, 32 | python_requires=">=3.6", 33 | ) 34 | -------------------------------------------------------------------------------- /tests/test_benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import tempfile 4 | 5 | import numpy as np 6 | import onnxruntime as ort 7 | import pytest 8 | 9 | ort.set_default_logger_severity(3) 10 | 11 | from onnxslim.utils import print_model_info_as_table, summarize_model 12 | 13 | MODELZOO_PATH = "/data/modelzoo" 14 | 15 | 16 | def bench_main(command): 17 | result = subprocess.run(command, shell=True, capture_output=True, text=True) 18 | return result 19 | 20 | 21 | def bench_onnxslim(input, output): 22 | command = f"onnxslim {input} {output}" 23 | result = bench_main(command) 24 | return result 25 | 26 | 27 | def bench_onnxsim(input, output): 28 | command = f"onnxsim {input} {output}" 29 | result = bench_main(command) 30 | return result 31 | 32 | 33 | def bench_polygraphy(input, output): 34 | command = 
f"polygraphy surgeon sanitize --fold-constants {input} -o {output}" 35 | result = bench_main(command) 36 | return result 37 | 38 | 39 | def bench_onnxruntime(input, output): 40 | try: 41 | import onnxruntime as rt 42 | 43 | sess_options = rt.SessionOptions() 44 | # Set graph optimization level 45 | sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED 46 | # To enable model serialization after graph optimization set this 47 | sess_options.optimized_model_filepath = output 48 | rt.InferenceSession(input, sess_options) 49 | return True 50 | 51 | except Exception as e: 52 | print(e) 53 | return None 54 | 55 | 56 | def bench_onnxruntime(input, output): 57 | try: 58 | import onnxruntime as rt 59 | 60 | sess_options = rt.SessionOptions() 61 | # Set graph optimization level 62 | sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED 63 | # To enable model serialization after graph optimization set this 64 | sess_options.optimized_model_filepath = output 65 | rt.InferenceSession(input, sess_options) 66 | return True 67 | 68 | except Exception as e: 69 | print(e) 70 | return None 71 | 72 | 73 | def bench_onnxruntime(input, output): 74 | try: 75 | import onnxruntime as rt 76 | 77 | sess_options = rt.SessionOptions() 78 | # Set graph optimization level 79 | sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED 80 | # To enable model serialization after graph optimization set this 81 | sess_options.optimized_model_filepath = output 82 | rt.InferenceSession(input, sess_options) 83 | return True 84 | 85 | except Exception as e: 86 | print(e) 87 | return None 88 | 89 | 90 | class TestModelZoo: 91 | def transform_and_check(self, name, filename, transformation_func, suffix, check_func): 92 | with tempfile.TemporaryDirectory() as tempdir: 93 | output_file = os.path.join(tempdir, f"{name}_{suffix}.onnx") 94 | result = transformation_func(filename, output_file) 95 | if result is None: 96 | return None 97 | if result is True or (hasattr(result, "returncode") and result.returncode == 0): 98 | if check_func: 99 | try: 100 | check_func(output_file) 101 | except: 102 | return None 103 | return summarize_model(output_file, suffix) 104 | return None 105 | 106 | def run_model_test(self, name, filename, check_func=None): 107 | summary_list = [summarize_model(filename)] 108 | summary_list.append(self.transform_and_check(name, filename, bench_onnxslim, "onnxslim", check_func)) 109 | summary_list.append(self.transform_and_check(name, filename, bench_onnxsim, "onnxsim", check_func)) 110 | summary_list.append(self.transform_and_check(name, filename, bench_polygraphy, "polygraphy", check_func)) 111 | summary_list.append(self.transform_and_check(name, filename, bench_onnxruntime, "onnxruntime", check_func)) 112 | 113 | summary_list = [summary for summary in summary_list if summary is not None] 114 | 115 | print() 116 | print_model_info_as_table(summary_list) 117 | 118 | def test_silero_vad(self, request): 119 | def check_model_inference(model_path): 120 | batch_size = 2 121 | input_data = np.zeros((batch_size, 256), dtype=np.float32) 122 | sr = np.array(16000) 123 | state = np.zeros((2, batch_size, 128), dtype=np.float32) 124 | 125 | ort_sess = ort.InferenceSession(model_path) 126 | outputs = ort_sess.run(None, {"input": input_data, "sr": sr, "state": state}) 127 | assert outputs is not None, "Inference failed on transformed model." 
128 | 129 | name = request.node.originalname[len("test_") :] 130 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 131 | self.run_model_test(name, filename, check_model_inference) 132 | 133 | def test_decoder_with_past_model(self, request): 134 | def check_model_inference(model_path): 135 | batch_size = 2 136 | input_ids = np.ones((batch_size, 256), dtype=np.int64) 137 | encoder_hidden_states = np.zeros((batch_size, 128, 16), dtype=np.float32) 138 | 139 | ort_sess = ort.InferenceSession(model_path) 140 | ort_sess.run(None, {"input_ids": input_ids, "encoder_hidden_states": encoder_hidden_states}) 141 | 142 | name = request.node.originalname[len("test_") :] 143 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 144 | self.run_model_test(name, filename, check_model_inference) 145 | 146 | def test_tiny_en_decoder(self, request): 147 | name = request.node.originalname[len("test_") :] 148 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 149 | self.run_model_test(name, filename) 150 | 151 | def test_transformer_encoder(self, request): 152 | name = request.node.originalname[len("test_") :] 153 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 154 | self.run_model_test(name, filename) 155 | 156 | def test_uiex(self, request): 157 | name = request.node.originalname[len("test_") :] 158 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 159 | self.run_model_test(name, filename) 160 | 161 | def test_en_number_mobile_v2_0_rec_infer(self, request): 162 | name = request.node.originalname[len("test_") :] 163 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 164 | self.run_model_test(name, filename) 165 | 166 | def test_paddleocr(self, request): 167 | name = request.node.originalname[len("test_") :] 168 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 169 | self.run_model_test(name, filename) 170 | 171 | 172 | if __name__ == "__main__": 173 | import sys 174 | 175 | sys.exit( 176 | pytest.main( 177 | [ 178 | "-p", 179 | "no:warnings", 180 | "-sv", 181 | "tests/test_benchmark.py", 182 | ] 183 | ) 184 | ) 185 | -------------------------------------------------------------------------------- /tests/test_folder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import subprocess 5 | 6 | import pytest 7 | 8 | 9 | def parse_arguments(): 10 | """Parses command-line arguments for specifying the ONNX model directory.""" 11 | parser = argparse.ArgumentParser(description="Test script for ONNX models") 12 | parser.add_argument( 13 | "--model-dir", 14 | type=str, 15 | required=True, 16 | help="Directory containing ONNX model files", 17 | ) 18 | return parser.parse_args() 19 | 20 | 21 | args = parse_arguments() 22 | 23 | 24 | @pytest.fixture(params=glob.glob(f"{args.model_dir}/*/*.onnx")) 25 | def model_file(request): 26 | """Yields ONNX model file paths from the specified directory for parameterized testing.""" 27 | yield request.param 28 | 29 | 30 | def test_model_file(model_file): 31 | """Tests the slimming of an ONNX model file using the onnxslim command, and validates the process by checking the 32 | command output. 
33 | """ 34 | slim_model_file = model_file.replace(".onnx", "_slim.onnx") 35 | command = f"onnxslim {model_file} {slim_model_file}" 36 | result = subprocess.run(command, shell=True, capture_output=True, text=True) 37 | if result.returncode != 0: 38 | print(result.stderr) 39 | raise AssertionError("Failed to slim model") 40 | else: 41 | output = result.stdout 42 | print(f"\n{output}") 43 | os.remove(slim_model_file) 44 | slim_data_file = model_file.replace(".onnx", "_slim.onnx.data") 45 | if os.path.exists(slim_data_file): 46 | os.remove(slim_data_file) 47 | 48 | 49 | if __name__ == "__main__": 50 | import sys 51 | 52 | sys.exit( 53 | pytest.main( 54 | [ 55 | "-p", 56 | "no:warnings", 57 | "-sv", 58 | "tests/test_folder.py", 59 | ] 60 | ) 61 | ) 62 | -------------------------------------------------------------------------------- /tests/test_modelzoo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import numpy as np 5 | import onnxruntime as ort 6 | import pytest 7 | 8 | from onnxslim import slim 9 | from onnxslim.utils import print_model_info_as_table, summarize_model 10 | 11 | MODELZOO_PATH = "/data/modelzoo" 12 | 13 | 14 | class TestModelZoo: 15 | def test_silero_vad(self, request): 16 | name = request.node.originalname[len("test_") :] 17 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 18 | 19 | with tempfile.TemporaryDirectory() as tempdir: 20 | slim(filename, os.path.join(tempdir, f"{name}_slim.onnx")) 21 | batch_size = 2 22 | input = np.zeros((batch_size, 256), dtype=np.float32) 23 | sr = np.array(16000) 24 | state = np.zeros((2, batch_size, 128), dtype=np.float32) 25 | 26 | ort_sess = ort.InferenceSession(os.path.join(tempdir, f"{name}_slim.onnx")) 27 | ort_sess.run(None, {"input": input, "sr": sr, "state": state}) 28 | 29 | def test_decoder_with_past_model(self, request): 30 | name = request.node.originalname[len("test_") :] 31 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 32 | 33 | with tempfile.TemporaryDirectory() as tempdir: 34 | slim(filename, os.path.join(tempdir, f"{name}_slim.onnx")) 35 | batch_size = 2 36 | input_ids = np.ones((batch_size, 256), dtype=np.int64) 37 | encoder_hidden_states = np.zeros((batch_size, 128, 16), dtype=np.float32) 38 | 39 | ort_sess = ort.InferenceSession(os.path.join(tempdir, f"{name}_slim.onnx")) 40 | ort_sess.run(None, {"input_ids": input_ids, "encoder_hidden_states": encoder_hidden_states}) 41 | 42 | def test_tiny_en_decoder(self, request): 43 | name = request.node.originalname[len("test_") :] 44 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 45 | 46 | with tempfile.TemporaryDirectory() as tempdir: 47 | slim(filename, os.path.join(tempdir, f"{name}_slim.onnx"), model_check=True) 48 | 49 | def test_transformer_encoder(self, request): 50 | name = request.node.originalname[len("test_") :] 51 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 52 | summary = summarize_model(slim(filename), tag=request.node.name) 53 | print_model_info_as_table(summary) 54 | assert summary.op_type_counts["Mul"] == 57 55 | assert summary.op_type_counts["Div"] == 53 56 | 57 | def test_uiex(self, request): 58 | name = request.node.originalname[len("test_") :] 59 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 60 | summary = summarize_model(slim(filename), tag=request.node.name) 61 | print_model_info_as_table(summary) 62 | assert summary.op_type_counts["Range"] == 0 63 | assert summary.op_type_counts["Floor"] == 0 64 | assert summary.op_type_counts["Concat"] == 54 65 | 66 | def 
test_qwen_vl_vision_encoder(self, request): 67 | name = request.node.originalname[len("test_") :] 68 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 69 | summary = summarize_model(slim(filename), tag=request.node.name) 70 | print_model_info_as_table(summary) 71 | with tempfile.TemporaryDirectory() as tempdir: 72 | slim(filename, os.path.join(tempdir, f"{name}_slim.onnx")) 73 | import numpy as np 74 | import onnxruntime as ort 75 | 76 | ort_sess = ort.InferenceSession(os.path.join(tempdir, f"{name}_slim.onnx")) 77 | outputs = ort_sess.run( 78 | None, 79 | {"pixel_values": np.random.rand(256, 1176).astype(np.float32), "grid_thw": np.array([[1, 16, 16]])}, 80 | ) 81 | print(f"{outputs[0].shape=}") # (64, 16) 82 | 83 | def test_layer_normalization_2d_axis0_expanded_ver18(self, request): 84 | name = request.node.originalname[len("test_") :] 85 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 86 | 87 | with tempfile.TemporaryDirectory() as tempdir: 88 | slim(filename, os.path.join(tempdir, f"{name}_slim.onnx"), model_check=True) 89 | summary = summarize_model(os.path.join(tempdir, f"{name}_slim.onnx"), tag=request.node.name) 90 | assert summary.op_type_counts["Reshape"] == 1 91 | 92 | def test_padconv(self, request): 93 | name = request.node.originalname[len("test_") :] 94 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 95 | 96 | with tempfile.TemporaryDirectory() as tempdir: 97 | slim( 98 | filename, 99 | os.path.join(tempdir, f"{name}_slim.onnx"), 100 | model_check=True, 101 | input_shapes=["/encoder/encoders0/encoders0.0/self_attn/Transpose_2_output_0:1,516,32"], 102 | ) 103 | 104 | # def test_wav2vec2_conformer(self, request): 105 | # name = request.node.originalname[len("test_") :] 106 | # filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 107 | 108 | # with tempfile.TemporaryDirectory() as tempdir: 109 | # slim(filename, os.path.join(tempdir, f"{name}_slim.onnx")) 110 | # batch_size = 2 111 | # input = np.zeros((batch_size, 256), dtype=np.float32) 112 | 113 | # ort_sess = ort.InferenceSession(os.path.join(tempdir, f"{name}_slim.onnx")) 114 | # ort_sess.run(None, {"input_values": input}) 115 | 116 | def test_yolo11n_pose(self, request): 117 | name = request.node.originalname[len("test_") :] 118 | filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" 119 | 120 | with tempfile.TemporaryDirectory() as tempdir: 121 | slim(filename, os.path.join(tempdir, f"{name}_slim.onnx")) 122 | input = np.zeros((1, 3, 256, 256), dtype=np.float32) 123 | 124 | ort_sess = ort.InferenceSession(os.path.join(tempdir, f"{name}_slim.onnx")) 125 | ort_sess.run(None, {"images": input}) 126 | 127 | 128 | if __name__ == "__main__": 129 | import sys 130 | 131 | sys.exit( 132 | pytest.main( 133 | [ 134 | "-p", 135 | "no:warnings", 136 | "-sv", 137 | "tests/test_modelzoo.py", 138 | ] 139 | ) 140 | ) 141 | -------------------------------------------------------------------------------- /tests/test_onnx_nets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | import warnings 5 | 6 | import pytest 7 | import timm 8 | import torch 9 | import torchvision.models as models 10 | 11 | FUSE = True 12 | PRETRAINED = False 13 | MEMORY_LIMIT_GB = 0.75 # User's memory limit 14 | MEMORY_PER_PARAM = 4e-9 # Approximate memory required per parameter in GB 15 | 16 | os.makedirs("tmp", exist_ok=True) 17 | 18 | 19 | class TestTorchVisionClass: 20 | @pytest.mark.parametrize( 21 | "model", 22 | ( 23 | models.resnet18, 24 | models.alexnet, 25 | 
models.squeezenet1_0, 26 | models.googlenet, 27 | ), 28 | ) 29 | def test_torchvision(self, request, model, shape=(1, 3, 224, 224)): 30 | """Test various TorchVision models with random input tensors of a specified shape.""" 31 | model = model(pretrained=PRETRAINED) 32 | x = torch.rand(shape) 33 | directory = f"tmp/{request.node.name}" 34 | os.makedirs(directory, exist_ok=True) 35 | 36 | filename = f"{directory}/{request.node.name}.onnx" 37 | slim_filename = f"{directory}/{request.node.name}_slim.onnx" 38 | 39 | torch.onnx.export(model, x, filename) 40 | 41 | command = f"onnxslim {filename} {slim_filename}" 42 | result = subprocess.run(command, shell=True, capture_output=True, text=True) 43 | output = result.stderr.strip() 44 | # Assert the expected return code 45 | print(output) 46 | assert result.returncode == 0 47 | 48 | shutil.rmtree(directory, ignore_errors=True) 49 | 50 | 51 | class TestTimmClass: 52 | @pytest.fixture(params=timm.list_models()) 53 | def model_name(self, request): 54 | """Yields names of models available in TIMM (https://github.com/rwightman/pytorch-image-models) for pytest fixture parameterization.""" 55 | yield request.param 56 | 57 | skip_keywords = ["enormous", "giant", "huge", "xlarge"] 58 | 59 | def test_timm(self, request, model_name): 60 | """Tests a TIMM model's forward pass with a random input tensor of the appropriate size.""" 61 | if any(keyword in model_name.lower() for keyword in self.skip_keywords): 62 | pytest.skip(f"Skipping model due to size keyword in name: {model_name}") 63 | 64 | try: 65 | model = timm.create_model(model_name, pretrained=PRETRAINED) 66 | except RuntimeError as e: 67 | if "out of memory" not in str(e): 68 | raise  # re-raise non-OOM errors; otherwise `model` would be unbound below 69 | pytest.skip(f"Skipping model {model_name} due to memory error.") 70 | num_params = sum(p.numel() for p in model.parameters()) 71 | 72 | # Calculate estimated memory requirement 73 | estimated_memory = num_params * MEMORY_PER_PARAM 74 | 75 | if estimated_memory > MEMORY_LIMIT_GB: 76 | pytest.skip(f"Skipping model {model_name}: estimated memory {estimated_memory:.2f} GB exceeds limit.") 77 | 78 | input_size = model.default_cfg.get("input_size") 79 | x = torch.randn((1,) + input_size) 80 | directory = f"tmp/{request.node.name}" 81 | try: 82 | os.makedirs(directory, exist_ok=True) 83 | 84 | filename = f"{directory}/{request.node.name}.onnx" 85 | slim_filename = f"{directory}/{request.node.name}_slim.onnx" 86 | torch.onnx.export(model, x, filename) 87 | except Exception as e: 88 | print(f"An unexpected error occurred: {str(e)}") 89 | return 90 | if not os.path.exists(filename): 91 | return 92 | 93 | command = f"onnxslim {filename} {slim_filename}" 94 | result = subprocess.run(command, shell=True, capture_output=True, text=True) 95 | output = result.stderr.strip() 96 | # Assert the expected return code 97 | print(output) 98 | assert result.returncode == 0 99 | 100 | shutil.rmtree(directory, ignore_errors=True) 101 | 102 | 103 | if __name__ == "__main__": 104 | warnings.filterwarnings("ignore") 105 | import sys 106 | 107 | sys.exit(pytest.main(["-p", "no:warnings", "-v", "tests/test_onnx_nets.py"])) 108 | -------------------------------------------------------------------------------- /tests/test_onnxslim.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import tempfile 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from onnxslim import slim 9 | from onnxslim.utils import summarize_model 10 | 11 | MODELZOO_PATH = "/data/modelzoo" 12 | FILENAME = 
f"{MODELZOO_PATH}/resnet18/resnet18.onnx" 13 | 14 | 15 | class TestFunctional: 16 | def run_basic_test(self, in_model_path, out_model_path, **kwargs): 17 | check_func = kwargs.get("check_func", None) 18 | kwargs_api = kwargs.get("api", {}) 19 | kwargs_bash = kwargs.get("bash", "") 20 | summary = summarize_model(slim(in_model_path, **kwargs_api), in_model_path) 21 | if check_func: 22 | check_func(summary) 23 | 24 | slim(in_model_path, out_model_path, **kwargs_api) 25 | summary = summarize_model(out_model_path, out_model_path) 26 | if check_func: 27 | check_func(summary) 28 | 29 | command = f'onnxslim "{in_model_path}" "{out_model_path}" {kwargs_bash}' 30 | 31 | result = subprocess.run(command, shell=True, capture_output=True, text=True) 32 | output = result.stderr.strip() 33 | # Assert the expected return code 34 | print(output) 35 | assert result.returncode == 0 36 | 37 | summary = summarize_model(out_model_path, out_model_path) 38 | if check_func: 39 | check_func(summary) 40 | 41 | def test_basic(self, request): 42 | with tempfile.TemporaryDirectory() as tempdir: 43 | out_model_path = os.path.join(tempdir, "resnet18.onnx") 44 | self.run_basic_test(FILENAME, out_model_path) 45 | 46 | def test_input_shape_modification(self, request): 47 | def check_func(summary): 48 | assert summary.input_info[0].shape == (1, 3, 224, 224) 49 | 50 | with tempfile.TemporaryDirectory() as tempdir: 51 | out_model_path = os.path.join(tempdir, "resnet18.onnx") 52 | kwargs = {} 53 | kwargs["bash"] = "--input-shapes input:1,3,224,224" 54 | kwargs["api"] = {"input_shapes": ["input:1,3,224,224"]} 55 | kwargs["check_func"] = check_func 56 | self.run_basic_test(FILENAME, out_model_path, **kwargs) 57 | 58 | def test_input_modification(self, request): 59 | def check_func(summary): 60 | assert "/maxpool/MaxPool_output_0" in summary.input_maps 61 | assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps 62 | 63 | with tempfile.TemporaryDirectory() as tempdir: 64 | out_model_path = os.path.join(tempdir, "resnet18.onnx") 65 | kwargs = {} 66 | kwargs["bash"] = "--inputs /maxpool/MaxPool_output_0 /layer1/layer1.0/relu/Relu_output_0" 67 | kwargs["api"] = {"inputs": ["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]} 68 | kwargs["check_func"] = check_func 69 | self.run_basic_test(FILENAME, out_model_path, **kwargs) 70 | 71 | def test_output_modification(self, request): 72 | def check_func(summary): 73 | assert "/Flatten_output_0" in summary.output_maps 74 | 75 | with tempfile.TemporaryDirectory() as tempdir: 76 | out_model_path = os.path.join(tempdir, "resnet18.onnx") 77 | kwargs = {} 78 | kwargs["bash"] = "--outputs /Flatten_output_0" 79 | kwargs["api"] = {"outputs": ["/Flatten_output_0"]} 80 | kwargs["check_func"] = check_func 81 | self.run_basic_test(FILENAME, out_model_path, **kwargs) 82 | 83 | def test_dtype_conversion(self, request): 84 | def check_func_fp16(summary): 85 | assert summary.input_info[0].dtype == np.float16 86 | 87 | def check_func_fp32(summary): 88 | assert summary.input_info[0].dtype == np.float32 89 | 90 | with tempfile.TemporaryDirectory() as tempdir: 91 | out_fp16_model_path = os.path.join(tempdir, "resnet18_fp16.onnx") 92 | kwargs = {} 93 | kwargs["bash"] = "--dtype fp16" 94 | kwargs["api"] = {"dtype": "fp16"} 95 | kwargs["check_func"] = check_func_fp16 96 | self.run_basic_test(FILENAME, out_fp16_model_path, **kwargs) 97 | 98 | out_fp32_model_path = os.path.join(tempdir, "resnet18_fp32.onnx") 99 | kwargs = {} 100 | kwargs["bash"] = "--dtype fp32" 101 | kwargs["api"] = 
{"dtype": "fp32"} 102 | kwargs["check_func"] = check_func_fp32 103 | self.run_basic_test(out_fp16_model_path, out_fp32_model_path, **kwargs) 104 | 105 | def test_save_as_external_data(self, request): 106 | with tempfile.TemporaryDirectory() as tempdir: 107 | out_model_path = os.path.join(tempdir, "resnet18.onnx") 108 | kwargs = {} 109 | kwargs["bash"] = "--save-as-external-data" 110 | kwargs["api"] = {"save_as_external_data": True} 111 | self.run_basic_test(FILENAME, out_model_path, **kwargs) 112 | assert os.path.getsize(out_model_path) < 1e5 113 | 114 | def test_model_check(self, request): 115 | with tempfile.TemporaryDirectory() as tempdir: 116 | out_model_path = os.path.join(tempdir, "resnet18.onnx") 117 | input_data = os.path.join(tempdir, "input.npy") 118 | test_data = np.random.random((1, 3, 224, 224)).astype(np.float32) 119 | np.save(input_data, test_data) 120 | kwargs = {} 121 | kwargs["bash"] = f"--model-check --model-check-inputs input:{input_data}" 122 | kwargs["api"] = {"model_check": True, "model_check_inputs": [f"input:{input_data}"]} 123 | self.run_basic_test(FILENAME, out_model_path, **kwargs) 124 | 125 | def test_inspect(self, request): 126 | with tempfile.TemporaryDirectory(): 127 | kwargs = {} 128 | kwargs["bash"] = "--inspect" 129 | kwargs["api"] = {"inspect": True} 130 | self.run_basic_test(FILENAME, FILENAME, **kwargs) 131 | 132 | 133 | if __name__ == "__main__": 134 | import sys 135 | 136 | sys.exit( 137 | pytest.main( 138 | [ 139 | "-p", 140 | "no:warnings", 141 | "-v", 142 | "tests/test_onnxslim.py", 143 | ] 144 | ) 145 | ) 146 | -------------------------------------------------------------------------------- /tests/test_pattern_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import onnx 4 | import pytest 5 | import torch 6 | import torch.nn as nn 7 | 8 | from onnxslim import register_fusion_pattern, slim 9 | from onnxslim.core.pattern import Pattern, PatternGenerator, PatternMatcher 10 | 11 | 12 | class TestPatternGenerator: 13 | def test_gelu(self, request): 14 | """Test the GELU activation function within the PatternModel class.""" 15 | 16 | class PatternModel(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.gelu = nn.GELU() 20 | 21 | def forward(self, x): 22 | """Applies the GELU activation function to the input tensor.""" 23 | x = self.gelu(x) 24 | return x 25 | 26 | class Model(nn.Module): 27 | def __init__(self): 28 | """Initializes the Model class with ReLU and PatternModel components.""" 29 | super().__init__() 30 | self.relu0 = nn.ReLU() 31 | self.pattern = PatternModel() 32 | self.relu1 = nn.ReLU() 33 | 34 | def forward(self, x): 35 | """Applies the ReLU activation function, the PatternModel, and another ReLU activation sequentially to 36 | the input tensor. 
37 | """ 38 | x = self.relu0(x) 39 | x = self.pattern(x) 40 | x = self.relu1(x) 41 | return x 42 | 43 | input = torch.randn(2) 44 | p = PatternModel() 45 | m = Model() 46 | directory = f"tmp/{request.node.name}" 47 | os.makedirs(directory, exist_ok=True) 48 | 49 | pattern_filename = f"{directory}/{request.node.name}.onnx" 50 | torch.onnx.export(p, input, pattern_filename) 51 | 52 | model_filename = f"{directory}/{request.node.name}.onnx" 53 | torch.onnx.export(m, input, model_filename) 54 | 55 | model = onnx.load(pattern_filename) 56 | pgen = PatternGenerator(model) 57 | template = pgen.generate() 58 | pattern = Pattern(template) 59 | 60 | class GeluMatcher(PatternMatcher): 61 | def __init__(self, pattern, priority): 62 | """Initialize a GeluMatcher with a given pattern and priority.""" 63 | super().__init__(pattern, priority) 64 | 65 | @property 66 | def name(self): 67 | """Return the name of the matcher as 'FusionGelu'.""" 68 | return "FusionGelu" 69 | 70 | def rewrite(self, opset=11): 71 | """Raise an exception indicating a pattern match in GeluMatcher.""" 72 | raise Exception("Pattern Matched") 73 | 74 | register_fusion_pattern(GeluMatcher(pattern, 1)) 75 | with pytest.raises(Exception) as excinfo: 76 | slim(model_filename, f"{directory}/{request.node.name}_slim.onnx") 77 | 78 | assert str(excinfo.value) == "Pattern Matched" 79 | 80 | 81 | if __name__ == "__main__": 82 | import sys 83 | 84 | sys.exit( 85 | pytest.main( 86 | [ 87 | "-p", 88 | "no:warnings", 89 | "-sv", 90 | "tests/test_pattern_generator.py", 91 | ] 92 | ) 93 | ) 94 | -------------------------------------------------------------------------------- /tests/test_pattern_matcher.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | import torch.nn as nn 6 | 7 | from onnxslim import slim 8 | from onnxslim.utils import print_model_info_as_table, summarize_model 9 | 10 | 11 | class TestPatternMatcher: 12 | def test_gelu(self, request): 13 | """Test the GELU activation function in a neural network model using an instance of nn.Module.""" 14 | 15 | class Model(nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | self.relu0 = nn.ReLU() 19 | self.gelu = nn.GELU() 20 | self.relu1 = nn.ReLU() 21 | 22 | def forward(self, x): 23 | """Performs a forward pass through the model applying ReLU, GELU, and ReLU activations sequentially to 24 | the input tensor x. 
25 | """ 26 | x = self.relu0(x) 27 | x = self.gelu(x) 28 | x = self.relu1(x) 29 | return x 30 | 31 | input = torch.randn(2) 32 | m = Model() 33 | directory = f"tmp/{request.node.name}" 34 | os.makedirs(directory, exist_ok=True) 35 | 36 | filename = f"{directory}/{request.node.name}.onnx" 37 | torch.onnx.export(m, input, filename) 38 | 39 | summary = summarize_model(slim(filename, model_check=True), request.node.name) 40 | print_model_info_as_table(summary) 41 | 42 | def test_pad_conv(self, request): 43 | """Test padding followed by 2D convolution within a neural network module.""" 44 | 45 | class Model(nn.Module): 46 | def __init__(self): 47 | super().__init__() 48 | self.pad_0 = nn.ConstantPad2d(3, 0) 49 | self.conv_0 = nn.Conv2d(1, 1, 3) 50 | 51 | self.pad_1 = nn.ConstantPad2d(3, 2) 52 | self.conv_1 = nn.Conv2d(1, 1, 3, bias=False) 53 | 54 | def forward(self, x): 55 | """Applies padding and convolutional layers to the input tensor x.""" 56 | x0 = self.pad_0(x) 57 | x0 = self.conv_0(x0) 58 | 59 | x1 = self.pad_1(x) 60 | x1 = self.conv_1(x1) 61 | 62 | return x0 + x1 63 | 64 | input = torch.randn(1, 1, 24, 24) 65 | m = Model() 66 | directory = f"tmp/{request.node.name}" 67 | os.makedirs(directory, exist_ok=True) 68 | 69 | filename = f"{directory}/{request.node.name}.onnx" 70 | torch.onnx.export(m, input, filename) 71 | 72 | summary = summarize_model(slim(filename, model_check=True), request.node.name) 73 | print_model_info_as_table(summary) 74 | 75 | assert summary.op_type_counts["Conv"] == 2 76 | assert summary.op_type_counts["Add"] == 1 77 | 78 | def test_conv_bn(self, request): 79 | """Test the convolutional layer followed by batch normalization export and re-import via ONNX.""" 80 | 81 | class Model(nn.Module): 82 | def __init__(self): 83 | super().__init__() 84 | self.conv = nn.Conv2d(1, 1, 3) 85 | self.bn = nn.BatchNorm2d(1) 86 | 87 | def forward(self, x): 88 | """Perform convolution followed by batch normalization on input tensor x.""" 89 | x = self.conv(x) 90 | x = self.bn(x) 91 | return x 92 | 93 | input = torch.randn(1, 1, 24, 24) 94 | m = Model() 95 | directory = f"tmp/{request.node.name}" 96 | os.makedirs(directory, exist_ok=True) 97 | 98 | filename = f"{directory}/{request.node.name}.onnx" 99 | torch.onnx.export(m, input, filename, do_constant_folding=False) 100 | 101 | summary = summarize_model(slim(filename, model_check=True), request.node.name) 102 | print_model_info_as_table(summary) 103 | assert summary.op_type_counts["Conv"] == 1 104 | 105 | def test_consecutive_slice(self, request): 106 | """Tests consecutive slicing operations on a model by exporting it to ONNX format and then slimming the ONNX 107 | file. 
108 | """ 109 | 110 | class Model(nn.Module): 111 | def __init__(self): 112 | super().__init__() 113 | self.conv = nn.Conv2d(1, 1, 3) 114 | self.bn = nn.BatchNorm2d(1) 115 | 116 | def forward(self, x): 117 | """Performs slicing operation on the input tensor x by returning the section x[1:2, :2].""" 118 | return x[1:2, :2] 119 | 120 | input = torch.randn(3, 4) 121 | m = Model() 122 | directory = f"tmp/{request.node.name}" 123 | os.makedirs(directory, exist_ok=True) 124 | 125 | filename = f"{directory}/{request.node.name}.onnx" 126 | torch.onnx.export(m, input, filename) 127 | 128 | summary = summarize_model(slim(filename, model_check=True), request.node.name) 129 | print_model_info_as_table(summary) 130 | assert summary.op_type_counts["Slice"] == 1 131 | 132 | def test_consecutive_reshape(self, request): 133 | """Test the functionality of consecutive reshape operations in a model and export it to ONNX format.""" 134 | 135 | class Model(nn.Module): 136 | def __init__(self): 137 | super().__init__() 138 | 139 | def forward(self, x): 140 | """Reshape tensor sequentially to (2, 6) and then to (12, 1).""" 141 | return x.view(2, 6).view(12, 1) 142 | 143 | input = torch.randn(3, 4) 144 | m = Model() 145 | directory = f"tmp/{request.node.name}" 146 | os.makedirs(directory, exist_ok=True) 147 | 148 | filename = f"{directory}/{request.node.name}.onnx" 149 | torch.onnx.export(m, input, filename) 150 | 151 | summary = summarize_model(slim(filename, model_check=True), request.node.name) 152 | print_model_info_as_table(summary) 153 | assert summary.op_type_counts["Reshape"] == 1 154 | 155 | def test_matmul_add(self, request): 156 | """Tests matrix multiplication followed by an addition operation within a neural network model.""" 157 | 158 | class Model(nn.Module): 159 | def __init__(self): 160 | super().__init__() 161 | self.data = torch.randn(4, 3) 162 | 163 | def forward(self, x): 164 | """Performs matrix multiplication of input 'x' with pre-defined data, adds 1, and returns the result.""" 165 | x = torch.matmul(x, self.data) 166 | x += 1 167 | return x 168 | 169 | input = torch.randn(3, 4) 170 | m = Model() 171 | directory = f"tmp/{request.node.name}" 172 | os.makedirs(directory, exist_ok=True) 173 | 174 | filename = f"{directory}/{request.node.name}.onnx" 175 | torch.onnx.export(m, input, filename) 176 | 177 | summary = summarize_model(slim(filename, model_check=True), request.node.name) 178 | print_model_info_as_table(summary) 179 | assert summary.op_type_counts["Gemm"] == 1 180 | 181 | def test_reduce(self, request): 182 | """Tests model reduction by exporting a PyTorch model to ONNX format, slimming it, and saving to a specified 183 | directory. 184 | """ 185 | 186 | class Model(nn.Module): 187 | def __init__(self): 188 | super().__init__() 189 | 190 | def forward(self, x): 191 | """Performs a reduction summing over the last dimension of the input tensor and then unsqueezes the 192 | tensor along the same dimension. 
193 | """ 194 | x = torch.sum(x, dim=[-1], keepdim=False) 195 | x = x.unsqueeze(-1) 196 | return x 197 | 198 | input = torch.randn(3, 4) 199 | m = Model() 200 | directory = f"tmp/{request.node.name}" 201 | os.makedirs(directory, exist_ok=True) 202 | 203 | filename = f"{directory}/{request.node.name}.onnx" 204 | torch.onnx.export(m, input, filename, opset_version=11) 205 | 206 | summary = summarize_model(slim(filename, model_check=True), request.node.name) 207 | print_model_info_as_table(summary) 208 | assert summary.op_type_counts["ReduceSum"] == 1 209 | 210 | @pytest.mark.parametrize( 211 | "opset", 212 | ( 213 | 11, 214 | 13, 215 | ), 216 | ) 217 | def test_consecutive_unsqueeze(self, request, opset): 218 | class Model(nn.Module): 219 | def __init__(self): 220 | super().__init__() 221 | 222 | def forward(self, x): 223 | x = x.unsqueeze(-1) 224 | x = x.unsqueeze(-1) 225 | x = x.unsqueeze(1) 226 | x = x.unsqueeze(0) 227 | return x 228 | 229 | input = torch.randn(3, 4) 230 | m = Model() 231 | directory = f"tmp/{request.node.name}" 232 | os.makedirs(directory, exist_ok=True) 233 | 234 | filename = f"{directory}/{request.node.name}.onnx" 235 | torch.onnx.export(m, input, filename, opset_version=opset) 236 | 237 | summary = summarize_model(slim(filename, model_check=True), request.node.name) 238 | print_model_info_as_table(summary) 239 | assert summary.op_type_counts["Unsqueeze"] == 1 240 | 241 | 242 | if __name__ == "__main__": 243 | import sys 244 | 245 | sys.exit( 246 | pytest.main( 247 | [ 248 | "-p", 249 | "no:warnings", 250 | "-sv", 251 | "tests/test_pattern_matcher.py", 252 | ] 253 | ) 254 | ) 255 | --------------------------------------------------------------------------------