├── .github
├── FUNDING.yml
└── workflows
│ ├── ci.yml
│ ├── format.yml
│ ├── nightly-test.yml
│ └── python-publish.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── VERSION
├── bin
└── onnxslim
├── docs
├── _static
│ └── style.css
├── _templates
│ └── layout.html
├── conf.py
├── index.rst
├── main
│ └── toc.rst
└── requirements.txt
├── examples
├── common_subexpression_elimination
│ ├── README.md
│ └── cse_demo.py
├── input_shape_modification
│ └── README.md
├── model_inspect
│ └── README.md
└── output_modification
│ └── README.md
├── format.sh
├── images
├── after_cse.png
├── before_cse.png
├── cse.png
├── input_shape_modification.jpg
├── model_inspect.jpg
├── onnxslim.gif
└── output_modification.jpg
├── onnxslim
├── __init__.py
├── __main__.py
├── argparser.py
├── cli
│ ├── __init__.py
│ └── _main.py
├── core
│ ├── __init__.py
│ ├── optimization
│ │ ├── __init__.py
│ │ ├── dead_node_elimination.py
│ │ ├── subexpression_elimination.py
│ │ └── weight_tying.py
│ └── pattern
│ │ ├── __init__.py
│ │ ├── elimination
│ │ ├── __init__.py
│ │ ├── concat.py
│ │ ├── reshape.py
│ │ ├── reshape_as.py
│ │ ├── slice.py
│ │ └── unsqueeze.py
│ │ ├── fusion
│ │ ├── __init__.py
│ │ ├── convadd.py
│ │ ├── convbn.py
│ │ ├── gelu.py
│ │ ├── gemm.py
│ │ ├── padconv.py
│ │ └── reduce.py
│ │ └── registry.py
├── misc
│ ├── __init__.py
│ ├── font.py
│ └── tabulate.py
├── third_party
│ ├── __init__.py
│ ├── onnx_graphsurgeon
│ │ ├── __init__.py
│ │ ├── exporters
│ │ │ ├── __init__.py
│ │ │ ├── base_exporter.py
│ │ │ └── onnx_exporter.py
│ │ ├── graph_pattern
│ │ │ ├── __init__.py
│ │ │ └── graph_pattern.py
│ │ ├── importers
│ │ │ ├── __init__.py
│ │ │ ├── base_importer.py
│ │ │ └── onnx_importer.py
│ │ ├── ir
│ │ │ ├── __init__.py
│ │ │ ├── function.py
│ │ │ ├── graph.py
│ │ │ ├── node.py
│ │ │ └── tensor.py
│ │ ├── logger
│ │ │ ├── __init__.py
│ │ │ └── logger.py
│ │ └── util
│ │ │ ├── __init__.py
│ │ │ ├── exception.py
│ │ │ └── misc.py
│ └── symbolic_shape_infer.py
└── utils.py
├── setup.py
└── tests
├── test_benchmark.py
├── test_folder.py
├── test_modelzoo.py
├── test_onnx_nets.py
├── test_onnxslim.py
├── test_pattern_generator.py
├── test_pattern_matcher.py
├── test_yolo.py
└── utils.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: inisis
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
12 | polar: # Replace with a single Polar username
13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
14 | thanks_dev: # Replace with a single thanks.dev username
15 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
16 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
# Test workflow: runs the pytest suites under coverage on every supported
# Python version and uploads the combined coverage report to Codecov.
name: CI

on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]

jobs:
  test:
    runs-on: self-hosted
    strategy:
      matrix:
        # Quoted so that "3.10" is not parsed as the float 3.1.
        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

    steps:
      - uses: actions/checkout@v3

      - uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: install dependency
        run: |
          python -m pip install --upgrade pip wheel setuptools
          pip install .
          pip install pytest onnxruntime
          pip install pytest pytest-xdist onnxruntime timm torchvision --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu
          pip install coverage

      - name: yolo test
        # Skip only on 3.12. The previous `<= '3.11'` test was not
        # version-aware: "3.8" and "3.9" sort after "3.11" (as strings and as
        # numbers alike), so this step was silently skipped on them as well.
        if: matrix.python-version != '3.12'
        run: |
          pip install ultralytics
          coverage run --append -m pytest tests/test_yolo.py -sv

      # Every suite appends to the same .coverage data file; without --append
      # each `coverage run` starts clean and only the last suite is reported.
      - name: onnxslim api and binary test
        run: |
          pip install onnxconverter_common
          coverage run --append -m pytest tests/test_onnxslim.py

      - name: model zoo test
        run: |
          coverage run --append -m pytest tests/test_modelzoo.py

      - name: pattern matcher test
        run: |
          coverage run --append -m pytest tests/test_pattern_matcher.py

      - name: pattern generator test
        run: |
          coverage run --append -m pytest tests/test_pattern_generator.py

      - name: Merge Coverage Reports
        run: |
          coverage xml -o coverage-ci.xml

      - name: Upload coverage reports to Codecov
        uses: codecov/codecov-action@v5
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
62 |
--------------------------------------------------------------------------------
/.github/workflows/format.yml:
--------------------------------------------------------------------------------
1 | # Ultralytics 🚀 - AGPL-3.0 license
2 | # Ultralytics Actions https://github.com/ultralytics/actions
3 | # This workflow automatically formats code and documentation in PRs to official Ultralytics standards
4 |
5 | name: Ultralytics Actions
6 |
7 | on:
8 | push:
9 | branches: [main]
10 | pull_request_target:
11 | branches: [main]
12 | types: [opened, closed, synchronize]
13 |
14 | jobs:
15 | format:
16 | runs-on: ubuntu-latest
17 | steps:
18 | - name: Run Ultralytics Formatting
19 | uses: ultralytics/actions@main
20 | with:
21 | token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, do not modify
22 | python: true # format Python code and docstrings
23 | markdown: true # format Markdown
24 | prettier: true # format YAML
25 | spelling: true # check spelling
26 | links: false # check broken links
27 |
--------------------------------------------------------------------------------
/.github/workflows/nightly-test.yml:
--------------------------------------------------------------------------------
1 | name: nightly-test
2 |
3 | on:
4 | schedule:
5 | - cron: "0 18 * * *" # Runs at 6:00 PM UTC every day, which is 2:00 AM Beijing Time the next day
6 |
7 | jobs:
8 | build:
9 | runs-on: self-hosted
10 |
11 | steps:
12 | - uses: actions/checkout@v3
13 |
14 | - uses: actions/setup-python@v4
15 | with:
16 | python-version: "3.10"
17 |
18 | - name: install dependency
19 | run: |
20 | python -m pip install --upgrade pip wheel setuptools
21 | pip install .
22 | pip install pytest onnxruntime
23 |
24 | - name: benchmark test
25 | run: |
26 | python tests/test_benchmark.py
27 |
28 | - name: model test
29 | run: |
30 | pip install .
31 | pip install pytest pytest-xdist onnxruntime timm torchvision --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu
32 | python tests/test_onnx_nets.py
33 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v3
24 | - name: Set up Python
25 | uses: actions/setup-python@v3
26 | with:
27 | python-version: "3.x"
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip wheel setuptools
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | tmp/
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
91 | # atom remote-sync package
92 | .remote-sync.json
93 |
94 | # weights
95 | weights/
96 |
97 | #DS_Store
98 | .DS_Store
99 |
100 | # dev stuff
101 | eval/
102 | eval.ipynb
103 | dev.ipynb
104 | .vscode/
105 |
106 | # not ready
107 | videos/
108 | templates/
109 | data/ssd_dataloader.py
110 | data/datasets/
111 | doc/visualize.py
112 | read_results.py
113 | ssd300_120000/
114 | demos/live
115 | webdemo.py
116 | test_data_aug.py
117 |
118 | # attributes
119 |
120 | # pycharm
121 | .idea/
122 |
123 | # temp checkout soln
124 | data/datasets/
125 | data/ssd_dataloader.py
126 | tests/model
127 |
128 | # pylint
129 | .pylintrc
130 |
131 | # onnx
132 | *.onnx
133 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 inisis
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include VERSION
2 | recursive-exclude tests *
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OnnxSlim
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 | OnnxSlim can help you slim your ONNX model: fewer operators, the same accuracy, and better inference speed.
22 |
23 | - 🚀 2025/05/17: OnnxSlim is merged into [optimum](https://github.com/huggingface/optimum) 🤗🤗🤗
24 | - 🚀 2025/04/30: Rank 1st in the [AICAS 2025 LLM inference optimization challenge](https://tianchi.aliyun.com/competition/entrance/532289/customize588)
25 | - 🚀 2025/01/28: Achieved 1M downloads
26 | - 🚀 2024/06/23: OnnxSlim is merged into [transformers.js](https://github.com/huggingface/transformers.js) 🤗🤗🤗
27 | - 🚀 2024/06/02: OnnxSlim is merged into [ultralytics](https://github.com/ultralytics/ultralytics) ❤️❤️❤️
28 | - 🚀 2024/04/30: Rank 1st in the [AICAS 2024 LLM inference optimization challenge](https://tianchi.aliyun.com/competition/entrance/532170/customize440) held by Arm and T-head
29 | - 🚀 2024/01/25: OnnxSlim is merged to [mnn-llm](https://github.com/wangzhaode/mnn-llm), performance increased by 5%
30 |
31 | # Installation
32 |
33 | ## Using Prebuilt
34 |
35 | ```bash
36 | pip install onnxslim
37 | ```
38 |
39 | ## Install From Source
40 |
41 | ```bash
42 | pip install git+https://github.com/inisis/OnnxSlim@main
43 | ```
44 |
45 | ## Install From Local
46 |
47 | ```bash
48 | git clone https://github.com/inisis/OnnxSlim && cd OnnxSlim/
49 | pip install .
50 | ```
51 |
52 | # How to use
53 |
54 | ```
55 | onnxslim your_onnx_model slimmed_onnx_model
56 | ```
57 |
58 |
59 |
60 | For more usage, see onnxslim -h or refer to our [examples](./examples)
61 |
62 | # Projects using OnnxSlim
63 |
64 | - [Mozilla/smart_autofill](https://github.com/mozilla/smart_autofill)
65 | - [alibaba/MNN](https://github.com/alibaba/MNN)
66 | - [PaddlePaddle/PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
67 | - [huggingface/transformers.js](https://github.com/huggingface/transformers.js)
68 | - [huggingface/optimum](https://github.com/huggingface/optimum)
69 | - [THU-MIG/yolov10](https://github.com/THU-MIG/yolov10)
70 | - [ultralytics/ultralytics](https://github.com/ultralytics/ultralytics)
71 | - [ModelScope/FunASR](https://github.com/modelscope/FunASR)
72 | - [alibaba/MNN-LLM](https://github.com/wangzhaode/mnn-llm)
73 | - [deepghs/imgutils](https://github.com/deepghs/imgutils)
74 | - [sunsmarterjie/yolov12](https://github.com/sunsmarterjie/yolov12)
75 | - [nndeploy/nndeploy](https://github.com/nndeploy/nndeploy)
76 |
77 | # References
78 |
79 | > - [onnx-graphsurgeon](https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon)
80 | > - [Polygraphy](https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy/polygraphy)
81 | > - [onnx-simplifier](https://github.com/daquexian/onnx-simplifier)
82 | > - [tabulate](https://github.com/astanin/python-tabulate)
83 | > - [onnxruntime](https://github.com/microsoft/onnxruntime)
84 |
85 | # Contact
86 |
87 | Discord: https://discord.gg/nRw2Fd3VUS QQ Group: `873569894`
88 |
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 0.1.56
2 |
--------------------------------------------------------------------------------
/bin/onnxslim:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""Launcher script: runs the onnxslim command-line entry point from this source tree."""

import os
import sys

# Real (symlink-resolved) location of this script, and the directory one level above it.
G_SCRIPT_FILE = os.path.realpath(__file__)
G_ROOT_DIR = os.path.join(os.path.dirname(G_SCRIPT_FILE), os.pardir)

# Searched first, so the `onnxslim` import below is looked up under G_ROOT_DIR
# before any other sys.path entry.
sys.path.insert(0, G_ROOT_DIR)


from onnxslim import cli

if __name__ == "__main__":
    # Hand whatever cli.main() returns to the interpreter as the exit status.
    raise SystemExit(cli.main())
--------------------------------------------------------------------------------
/docs/_static/style.css:
--------------------------------------------------------------------------------
1 | .wy-nav-content {
2 | max-width: 1100px !important;
3 | }
4 |
--------------------------------------------------------------------------------
/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends '!layout.html' %} {% block extrahead %}
2 |
3 | {% endblock %} {% block footer %}
4 |
7 | {% endblock %}
8 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | ROOT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), os.path.pardir)
5 | sys.path.insert(0, ROOT_DIR)
6 |
7 |
8 | project = "OnnxSlim"
9 | copyright = "2022, WeLoveAI"
10 | author = "inisis"
11 |
12 | import onnxslim
13 |
14 | version = onnxslim.__version__
15 |
16 | extensions = [
17 | "sphinx.ext.autodoc",
18 | "sphinx.ext.intersphinx",
19 | "sphinx.ext.autosummary",
20 | "sphinx.ext.napoleon",
21 | "sphinx.ext.mathjax",
22 | ]
23 |
24 | intersphinx_mapping = {
25 | "rtd": ("https://docs.readthedocs.io/en/stable/", None),
26 | "python": ("https://docs.python.org/3/", None),
27 | "sphinx": ("https://www.sphinx-doc.org/en/master/", None),
28 | }
29 | intersphinx_disabled_domains = ["std"]
30 |
31 | templates_path = ["_templates"]
32 | epub_show_urls = "footnote"
33 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
34 | html_theme = "sphinx_rtd_theme"
35 | napoleon_preprocess_types = True
36 |
37 | html_static_path = ["_static"]
38 |
39 |
40 | def setup(app):
41 | """Configure the Sphinx app by adding a custom CSS file ('style.css')."""
42 | app.add_css_file("style.css")
43 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | OnnxSlim
2 | ===================================
3 |
4 | Onnxslim is a toolkit to help optimize large onnx models
5 |
6 |
7 | .. toctree::
8 | :hidden:
9 |
10 | self
11 |
12 | .. toctree::
13 | :caption: API Reference
14 | :maxdepth: 2
15 |
16 | main/toc
17 |
--------------------------------------------------------------------------------
/docs/main/toc.rst:
--------------------------------------------------------------------------------
1 | main
2 | ============
3 |
4 | .. autoclass:: onnxslim.slim
5 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | sphinx_rtd_theme
3 |
--------------------------------------------------------------------------------
/examples/common_subexpression_elimination/README.md:
--------------------------------------------------------------------------------
1 | # Common SubExpression Elimination
2 |
3 | ## Introduction
4 |
5 | Common Subexpression Elimination (CSE) is a powerful optimization technique commonly employed in compilers to improve the efficiency of code execution. It targets redundant computations within a program by identifying and removing duplicate expressions, thus reducing both computational overhead and memory usage. By eliminating redundant computations, CSE enhances the overall performance of the slimmed ONNX model.
6 |
7 | ## How CSE Works
8 |
9 | In many programs, certain expressions are computed multiple times within a given scope, even though their results remain constant across these computations. Common subexpressions refer to these redundant expressions. CSE identifies such common subexpressions and replaces subsequent occurrences with references to the original computation result. This process effectively reduces the number of computations required during program execution.
10 |
11 | For example, consider the following code snippet:
12 |
13 | ```
14 | int a = b + c;
15 | int x = b + c;
16 | ```
17 |
18 | In this code, `b + c` is a common subexpression computed twice. With CSE, the redundant computation of `b + c` is eliminated, and `x` directly reuses the result already computed for `a`.
19 |
20 | ## Running the example
21 |
22 | ```bash
23 | python cse_demo.py # to generate onnx model for demo
24 | ```
25 |
26 | 
27 |
28 | There are two identical blocks that are doing the same things.
29 |
30 | ```bash
31 | onnxslim ln_cse.onnx slim.onnx
32 | ```
33 |
34 | After onnxslim, the output will look like this:
35 |
36 | 
37 |
38 | and the summary is as follows:
39 |
40 | 
41 |
--------------------------------------------------------------------------------
/examples/common_subexpression_elimination/cse_demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
class Model(torch.nn.Module):
    """Demo module that layer-normalizes the same input twice and sums the results.

    One normalization goes through an ``nn.LayerNorm`` submodule and the other
    through ``F.layer_norm``, both over the same trailing dimension.
    """

    def __init__(self, embedding_dim=10):
        """Create the module.

        Args:
            embedding_dim (int): Size of the last input dimension to normalize over.
                Defaults to 10, the value the demo previously hard-coded.
        """
        super().__init__()
        self.layer_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Return ``self.layer_norm(x)`` plus a functional layer norm of the same ``x``.

        The functional call reuses the submodule's ``normalized_shape`` so both
        normalizations always agree on the dimension, instead of repeating a literal.
        """
        return self.layer_norm(x) + F.layer_norm(x, self.layer_norm.normalized_shape)
18 |
19 |
# Build the demo model and export it to "ln_cse.onnx" using a random example input.
model = Model()

input_shape = (20, 5, 10)  # (batch, sentence_length, embedding_dim)
example_input = torch.randn(*input_shape)
torch.onnx.export(model, example_input, "ln_cse.onnx", opset_version=13)
25 |
--------------------------------------------------------------------------------
/examples/input_shape_modification/README.md:
--------------------------------------------------------------------------------
1 | # Input Shape Modification
2 |
3 | ## Introduction
4 |
5 | OnnxSlim includes an exploration of essential input shape modification techniques for ONNX models.
6 |
7 | This concise guide unveils techniques for seamlessly adjusting input tensor dimensions, ensuring optimal compatibility and performance within the dynamic landscape of neural network architectures.
8 |
9 | ## Running the example
10 |
11 | Change the input model by running:
12 |
13 | ```bash
14 | onnxslim UNetModel-fp16.onnx slim.onnx --input_shapes cc:1,1,768
15 | ```
16 |
17 | The slimmed model will look like this:
18 |
19 | 
20 |
--------------------------------------------------------------------------------
/examples/model_inspect/README.md:
--------------------------------------------------------------------------------
1 | # Model Inspect
2 |
3 | ## Introduction
4 |
5 | Dive deep into the intricacies of your ONNX model using the powerful --inspect argument with OnnxSlim. This feature provides detailed insights into various aspects of your model, including input and output details, operator information, opset version, and more.
6 |
7 | ## Running the example
8 |
9 | Unveil the secrets of your ONNX model by executing the following command:
10 |
11 | ```bash
12 | onnxslim --inspect UNetModel-fp16.onnx
13 | ```
14 |
15 | The output will look like this:
16 |
17 | 
18 |
--------------------------------------------------------------------------------
/examples/output_modification/README.md:
--------------------------------------------------------------------------------
1 | # Output Modification
2 |
3 | ## Introduction
4 |
5 | OnnxSlim provides capabilities for modifying the output specifications of ONNX models.
6 |
7 | This section explores techniques to customize the outputs, allowing for flexibility in handling diverse model requirements.
8 |
9 | ## Running the example
10 |
11 | Change the output of one model by running:
12 |
13 | ```bash
14 | onnxslim yolov5m.onnx slim.onnx --outputs 591 739 443
15 | ```
16 |
17 | The slimmed model will look like this:
18 |
19 | 
20 |
--------------------------------------------------------------------------------
/format.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Run the code formatters over the whole repository, in place.
# -e: stop at the first failing command; -x: echo each command before running it.
set -e -x
ufmt format -- .
autoflake --in-place --recursive .
5 |
--------------------------------------------------------------------------------
/images/after_cse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/after_cse.png
--------------------------------------------------------------------------------
/images/before_cse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/before_cse.png
--------------------------------------------------------------------------------
/images/cse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/cse.png
--------------------------------------------------------------------------------
/images/input_shape_modification.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/input_shape_modification.jpg
--------------------------------------------------------------------------------
/images/model_inspect.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/model_inspect.jpg
--------------------------------------------------------------------------------
/images/onnxslim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/onnxslim.gif
--------------------------------------------------------------------------------
/images/output_modification.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/images/output_modification.jpg
--------------------------------------------------------------------------------
/onnxslim/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 |
4 | from onnxslim.cli import slim
5 | from onnxslim.core.optimization import OptimizationSettings
6 | from onnxslim.core.pattern.registry import (
7 | DEFAULT_FUSION_PATTERNS,
8 | register_fusion_pattern,
9 | )
10 | from onnxslim.version import __version__
11 |
# Warn when the current working directory is the folder that directly contains
# this package, i.e. the interpreter was started inside the source checkout.
if os.path.join(os.path.realpath(os.getcwd()), "onnxslim") == os.path.dirname(os.path.realpath(__file__)):
    message = (
        "You are importing onnxslim within its own root folder ({}). "
        "This is not expected to work and may give errors. Please exit the "
        "onnxslim project source and relaunch your python interpreter."
    )
    warnings.warn(message.format(os.getcwd()))
19 |
--------------------------------------------------------------------------------
/onnxslim/__main__.py:
--------------------------------------------------------------------------------
1 | from onnxslim.cli._main import main
2 |
3 | if __name__ == "__main__":
4 | main()
5 |
--------------------------------------------------------------------------------
/onnxslim/argparser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dataclasses
3 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
4 | from dataclasses import dataclass, field
5 | from typing import List, Optional, Type, Union, get_args, get_origin
6 |
7 | import onnxslim
8 |
9 |
10 | def _get_inner_type(arg_type):
11 | if get_origin(arg_type) is Union:
12 | return next((t for t in get_args(arg_type) if t is not type(None)), str)
13 | return arg_type
14 |
15 |
@dataclass
class ModelArguments:
    """
    Input and output model locations.

    Args:
        input_model (str): The ONNX model to be slimmed. Required.

        output_model (str, optional): File path to save the slimmed model. Default is None.
    """

    input_model: str = field(metadata={"help": "input onnx model"})
    output_model: Optional[str] = field(default=None, metadata={"help": "output onnx model"})
27 |
28 |
@dataclass
class OptimizationArguments:
    """
    Switches that control which optimizations are applied.

    Args:
        no_shape_infer (bool, optional): Whether to disable shape inference. Default is False.

        skip_optimizations (List[str], optional): Optimizations to skip. Allowed values are the keys of
            `onnxslim.OptimizationSettings`. Default is None.

        skip_fusion_patterns (List[str], optional): Fusion patterns to skip. Allowed values are the keys of
            `onnxslim.DEFAULT_FUSION_PATTERNS`. Default is None.

        size_threshold (int, optional): Size threshold in bytes; constants larger than this are not folded.
            Default is None, which means fold all constants.
    """

    no_shape_infer: bool = field(default=False, metadata={"help": "whether to disable shape_infer, default false."})
    skip_optimizations: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "whether to skip some optimizations",
            "choices": list(onnxslim.OptimizationSettings.keys()),
        },
    )
    skip_fusion_patterns: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "whether to skip the fusion of some patterns",
            "choices": list(onnxslim.DEFAULT_FUSION_PATTERNS.keys()),
        },
    )
    # Optional because the default is None; the unwrapped argument type is still int.
    size_threshold: Optional[int] = field(
        default=None,
        metadata={
            "help": "size threshold in bytes, size larger than this value will not be folded, default None, which means fold all constants",
        },
    )
61 |
62 |
@dataclass
class ModificationArguments:
    """
    Options that change the model's interface or storage format.

    Args:
        input_shapes (List[str], optional): Input shapes, each given as INPUT_NAME:SHAPE. Default is None.

        inputs (List[str], optional): Model inputs, each given as INPUT_NAME:DTYPE. Default is None.

        outputs (List[str], optional): Model outputs, each given as OUTPUT_NAME:DTYPE. Default is None.

        dtype (str, optional): Data format to convert to, "fp16" or "fp32". Default is None.

        save_as_external_data (bool, optional): Flag indicating whether to split onnx as model and weight. Default is False.
    """

    input_shapes: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "input shape of the model, INPUT_NAME:SHAPE, e.g. x:1,3,224,224 or x1:1,3,224,224 x2:1,3,224,224"
        },
    )
    inputs: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "input of the model, INPUT_NAME:DTYPE, e.g. y:fp32 or y1:fp32 y2:fp32. If dtype is not specified, the dtype of the input will be the same as the original model if it has dtype, otherwise it will be fp32, available dtype: fp16, fp32, int32"
        },
    )
    outputs: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "output of the model, OUTPUT_NAME:DTYPE, e.g. y:fp32 or y1:fp32 y2:fp32. If dtype is not specified, the dtype of the output will be the same as the original model if it has dtype, otherwise it will be fp32, available dtype: fp16, fp32, int32"
        },
    )
    dtype: Optional[str] = field(
        default=None, metadata={"help": "convert data format to fp16 or fp32.", "choices": ["fp16", "fp32"]}
    )
    save_as_external_data: bool = field(
        default=False, metadata={"help": "split onnx as model and weight, default False."}
    )
100 |
101 |
@dataclass
class CheckerArguments:
    """
    Options for checking, inspecting and logging.

    Args:
        model_check (bool, optional): Enable model check. Default is False.

        model_check_inputs (List[str], optional): Shapes or numpy data paths used for the model check,
            each given as INPUT_NAME:SHAPE or INPUT_NAME:DATAPATH. Default is None.

        inspect (bool, optional): Inspect the model. Default is False.

        dump_to_disk (bool, optional): Dump model info to disk. Default is False.

        verbose (bool, optional): Verbose mode. Default is False.
    """

    model_check: bool = field(
        default=False,
        metadata={"help": "enable model check"},
    )
    model_check_inputs: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "Works only when model_check is enabled, Input shape of the model or numpy data path, INPUT_NAME:SHAPE or INPUT_NAME:DATAPATH, e.g. x:1,3,224,224 or x1:1,3,224,224 x2:data.npy. Useful when input shapes are dynamic."
        },
    )
    inspect: bool = field(
        default=False,
        metadata={"help": "inspect model, default False."},
    )
    dump_to_disk: bool = field(
        default=False,
        metadata={"help": "dump model info to disk, default False."},
    )
    verbose: bool = field(
        default=False,
        metadata={"help": "verbose mode, default False."},
    )
127 |
128 |
class OnnxSlimArgumentParser(ArgumentParser):
    """
    Argument parser that derives its command-line options from dataclass field definitions.

    Every field of the supplied dataclasses (except those of ModelArguments) becomes a
    ``--kebab-case`` option on an internal ``argparse.ArgumentParser`` stored in ``self.parser``.
    ModelArguments is represented by the positional ``input_model`` / ``output_model`` arguments.
    """

    def __init__(self, *argument_dataclasses: Type, **kwargs):
        """
        Build the parser.

        Args:
            *argument_dataclasses (Type): Dataclass types whose fields describe the options.
            **kwargs: Forwarded to ``ArgumentParser.__init__``; ``formatter_class`` defaults to
                ``ArgumentDefaultsHelpFormatter`` when not supplied.
        """
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
        super().__init__(**kwargs)
        self.argument_dataclasses = argument_dataclasses
        self.parser = argparse.ArgumentParser(
            description="OnnxSlim: A Toolkit to Help Optimizer Onnx Model",
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )
        self._add_arguments()

    def _add_arguments(self):
        """
        Register one option per dataclass field, then the positional and version arguments.

        Raises:
            ValueError: If a field declares ``choices`` and its default is not one of them.
        """
        for dataclass_type in self.argument_dataclasses:
            # ModelArguments is handled through the positional arguments added below.
            if dataclass_type is ModelArguments:
                continue
            for field_name, field_def in dataclass_type.__dataclass_fields__.items():
                arg_type = _get_inner_type(field_def.type)
                # Resolve the default against the dataclasses.MISSING sentinel explicitly, so a field
                # declared with default_factory yields a real value instead of leaking the sentinel.
                if field_def.default is not dataclasses.MISSING:
                    default_value = field_def.default
                elif field_def.default_factory is not dataclasses.MISSING:
                    default_value = field_def.default_factory()
                else:
                    default_value = None
                help_text = field_def.metadata.get("help", "")
                nargs = "+" if get_origin(arg_type) == list else None
                choices = field_def.metadata.get("choices", None)
                if choices and default_value is not None and default_value not in choices:
                    raise ValueError(
                        f"Invalid default value '{default_value}' for argument '{field_name}'. Must be one of {choices}."
                    )
                # For a parameterised type such as List[str], argparse needs the element type.
                arg_type = get_args(arg_type)[0] if get_args(arg_type) else arg_type
                option_name = f"--{field_name.replace('_', '-')}"
                if arg_type == bool:
                    self.parser.add_argument(
                        option_name,
                        action="store_true",
                        default=default_value,
                        help=help_text,
                    )
                else:
                    self.parser.add_argument(
                        option_name,
                        type=arg_type,
                        default=default_value,
                        nargs=nargs,
                        choices=choices,
                        help=help_text,
                    )

        # Add positional arguments separately for ModelArguments
        self.parser.add_argument("input_model", help="input onnx model")
        self.parser.add_argument("output_model", nargs="?", default=None, help="output onnx model")
        self.parser.add_argument("-v", "--version", action="version", version=onnxslim.__version__)

    def parse_args_into_dataclasses(self):
        """
        Parse the command line and return one populated instance per registered dataclass.

        Returns:
            tuple: Dataclass instances, in the order the types were passed to the constructor.
        """
        # Pre-parse arguments to check for `--inspect`; in that mode several input models are accepted.
        pre_parsed_args, _ = self.parser.parse_known_args()
        if pre_parsed_args.inspect:
            for action in self.parser._actions:
                if action.dest == "input_model":
                    action.nargs = "+"
                    break

        args = self.parser.parse_args()
        args_dict = vars(args)

        outputs = []
        for dtype in self.argument_dataclasses:
            # Only pass each dataclass the parsed values that match its own init fields.
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in args_dict.items() if k in keys}
            obj = dtype(**inputs)
            outputs.append(obj)

        return (*outputs,)
198 |
--------------------------------------------------------------------------------
/onnxslim/cli/__init__.py:
--------------------------------------------------------------------------------
1 | from onnxslim.cli._main import main, slim
2 |
--------------------------------------------------------------------------------
/onnxslim/cli/_main.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | import onnx
4 |
5 |
def slim(model: Union[str, onnx.ModelProto, List[Union[str, onnx.ModelProto]]], *args, **kwargs):
    """
    Slim an ONNX model, or print a summary table when a list of models is given.

    Args:
        model: A model path, a loaded ``onnx.ModelProto``, or a list of either. A list switches
            to inspection mode: each model is summarized and the summaries are printed.
        *args: Optional first positional value is the output model path.
        **kwargs: Recognised keys are output_model, model_check, input_shapes, inputs, outputs,
            no_shape_infer, skip_optimizations, dtype, skip_fusion_patterns, size_threshold,
            dump_to_disk, save_as_external_data, model_check_inputs and verbose. Other keys
            (for example ``inspect``) are accepted and ignored.

    Returns:
        The slimmed ``onnx.ModelProto`` when no output path is given. ``None`` in inspection mode,
        when the model check fails, or after the slimmed model has been saved to the output path.
    """
    import os
    import time
    from pathlib import Path

    from onnxslim.core import (
        OptimizationSettings,
        convert_data_format,
        freeze,
        input_modification,
        input_shape_modification,
        optimize,
        output_modification,
        shape_infer,
    )
    from onnxslim.utils import (
        TensorInfo,
        check_onnx,
        check_point,
        check_result,
        dump_model_info_to_disk,
        init_logging,
        onnxruntime_inference,
        print_model_info_as_table,
        save,
        summarize_model,
        update_outputs_dims,
    )

    output_model = args[0] if len(args) > 0 else kwargs.get("output_model", None)
    model_check = kwargs.get("model_check", False)
    input_shapes = kwargs.get("input_shapes", None)
    inputs = kwargs.get("inputs", None)
    outputs = kwargs.get("outputs", None)
    no_shape_infer = kwargs.get("no_shape_infer", False)
    skip_optimizations = kwargs.get("skip_optimizations", None)
    dtype = kwargs.get("dtype", None)
    skip_fusion_patterns = kwargs.get("skip_fusion_patterns", None)
    size_threshold = kwargs.get("size_threshold", None)
    size_threshold = int(size_threshold) if size_threshold else None
    dump_to_disk = kwargs.get("dump_to_disk", False)
    save_as_external_data = kwargs.get("save_as_external_data", False)
    model_check_inputs = kwargs.get("model_check_inputs", None)
    verbose = kwargs.get("verbose", False)

    logger = init_logging(verbose)

    # Upper bound on optimize/shape-infer rounds; overridable through ONNXSLIM_MAX_ITER.
    max_iter_env = os.getenv("ONNXSLIM_MAX_ITER")
    MAX_ITER = int(max_iter_env) if max_iter_env else 10

    start_time = time.time()

    def get_info(model, inspect=False):
        """Load (if given a path) and freeze a model; return its summary when inspect is set."""
        if isinstance(model, str):
            model_name = Path(model).name
            model = onnx.load(model)
        else:
            model_name = "OnnxModel"

        freeze(model)

        if not inspect:
            return model_name, model

        return summarize_model(model, model_name)

    if isinstance(model, list):
        # Inspection mode: summarize every model and stop without optimizing anything.
        model_info_list = [get_info(m, inspect=True) for m in model]

        if dump_to_disk:
            for info in model_info_list:
                dump_model_info_to_disk(info)

        print_model_info_as_table(model_info_list)

        return

    model_name, model = get_info(model)
    if output_model:
        # Only needed for the before/after table printed once the model has been saved.
        original_info = summarize_model(model, model_name)

    if inputs:
        model = input_modification(model, inputs)

    if input_shapes:
        model = input_shape_modification(model, input_shapes)

    if outputs:
        model = output_modification(model, outputs)

    if model_check:
        input_data_dict, raw_onnx_output, model = check_onnx(model, model_check_inputs)

    # Remember the output shapes before optimization so they can be restored afterwards.
    output_info = {TensorInfo(o).name: TensorInfo(o).shape for o in model.graph.output}

    if not no_shape_infer:
        model = shape_infer(model)

    OptimizationSettings.reset(skip_optimizations)
    if OptimizationSettings.enabled():
        graph_check_point = check_point(model)
        # Repeat until the graph check point stops changing or the iteration budget is used up.
        while MAX_ITER > 0:
            logger.debug(f"iter: {MAX_ITER}")
            model = optimize(model, skip_fusion_patterns, size_threshold)
            if not no_shape_infer:
                model = shape_infer(model)
            graph = check_point(model)
            if graph == graph_check_point:
                logger.debug(f"converged at iter: {MAX_ITER}")
                break
            graph_check_point = graph

            MAX_ITER -= 1

    if dtype:
        model = convert_data_format(model, dtype)

    model = update_outputs_dims(model, output_dims=output_info)

    if model_check:
        slimmed_onnx_output, model = onnxruntime_inference(model, input_data_dict)
        if not check_result(raw_onnx_output, slimmed_onnx_output):
            return None

    if not output_model:
        return model

    slimmed_info = summarize_model(model, output_model)
    save(model, output_model, model_check, save_as_external_data, slimmed_info)

    end_time = time.time()
    elapsed_time = end_time - start_time
    print_model_info_as_table(
        [original_info, slimmed_info],
        elapsed_time,
    )
144 |
145 |
def main():
    """Command-line entry point: parse the arguments, validate them and hand them to slim()."""
    from onnxslim.argparser import (
        CheckerArguments,
        ModelArguments,
        ModificationArguments,
        OnnxSlimArgumentParser,
        OptimizationArguments,
    )

    argument_parser = OnnxSlimArgumentParser(
        ModelArguments,
        OptimizationArguments,
        ModificationArguments,
        CheckerArguments,
    )
    parsed_groups = argument_parser.parse_args_into_dataclasses()
    model_args, optimization_args, modification_args, checker_args = parsed_groups

    # Dumping model info is only meaningful in inspection mode.
    if checker_args.dump_to_disk and not checker_args.inspect:
        argument_parser.error("dump_to_disk can only be used with --inspect")

    shape_inference_enabled = not optimization_args.no_shape_infer
    if shape_inference_enabled:
        from onnxslim.utils import check_onnx_compatibility, is_onnxruntime_available

        if is_onnxruntime_available():
            check_onnx_compatibility()

    slim(
        model_args.input_model,
        model_args.output_model,
        **vars(optimization_args),
        **vars(modification_args),
        **vars(checker_args),
    )

    return 0
179 |
--------------------------------------------------------------------------------
/onnxslim/core/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import tempfile
4 |
5 | import numpy as np
6 | import onnx
7 | from onnx import checker
8 |
9 | import onnxslim.third_party.onnx_graphsurgeon as gs
10 | from onnxslim.core.optimization import OptimizationSettings, optimize_model
11 | from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx
12 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant
13 | from onnxslim.third_party.symbolic_shape_infer import SymbolicShapeInference
14 | from onnxslim.utils import save
15 |
16 | logger = logging.getLogger("onnxslim")
17 |
18 |
# Debug switch: bool() is applied to the raw string, so any non-empty ONNXSLIM_DEBUG value
# (including "0") turns it on. When set, shape_infer() and optimize() save debug_*.onnx files.
DEBUG = bool(os.getenv("ONNXSLIM_DEBUG"))
# Passed as auto_merge to SymbolicShapeInference.infer_shapes(). Defaults to True when
# ONNXSLIM_AUTO_MERGE is unset; otherwise the value is parsed as an integer (0 disables it).
AUTO_MERGE = True if os.getenv("ONNXSLIM_AUTO_MERGE") is None else bool(int(os.getenv("ONNXSLIM_AUTO_MERGE")))
21 |
22 |
def input_shape_modification(model: onnx.ModelProto, input_shapes: str) -> onnx.ModelProto:
    """
    Overwrite the shapes of graph inputs and reset the shapes of all other non-constant tensors.

    Args:
        model (onnx.ModelProto): Model to modify.
        input_shapes: Iterable of "INPUT_NAME:D1,D2,..." entries with integer dimensions. The name
            is everything before the last ":".

    Returns:
        onnx.ModelProto: The re-exported model, or the unchanged input model when ``input_shapes``
        is empty.

    Raises:
        Exception: If an entry names a tensor that is not a graph input.
    """
    if not input_shapes:
        # Nothing to change: hand the model back instead of None so the declared return type holds.
        return model

    graph = gs.import_onnx(model)
    input_names = [graph_input.name for graph_input in graph.inputs]
    tensors = graph.tensors()

    for input_shape in input_shapes:
        key, values = input_shape.rsplit(":", 1)
        values_list = [int(value) for value in values.split(",")]
        if key not in input_names:
            raise Exception(f"Input name {key} not found in model, available keys: {' '.join(input_names)}")
        tensors[key].shape = values_list

    # Shapes recorded on the other tensors were derived from the old input shapes, so clear them
    # on everything except graph inputs and constants.
    for tensor in tensors.values():
        if tensor.name not in input_names:
            if isinstance(tensor, Constant):
                continue
            tensor.shape = None

    model = gs.export_onnx(graph)

    return model
48 |
49 |
def output_modification(model: onnx.ModelProto, outputs: str) -> onnx.ModelProto:
    """
    Replace the graph outputs with the requested tensors, optionally overriding their dtypes.

    Args:
        model (onnx.ModelProto): Model to modify.
        outputs: Iterable of "OUTPUT_NAME" or "OUTPUT_NAME:DTYPE" entries, where DTYPE is one of
            fp16, fp32, int32 or bool. Without a DTYPE the tensor's own dtype is used, falling back
            to float32 (with a warning) when the tensor has none.

    Returns:
        onnx.ModelProto: The re-exported model with the new outputs.

    Raises:
        Exception: If a name is not a tensor of the graph or a DTYPE is not supported.
    """
    supported_dtypes = {
        "fp16": np.float16,
        "fp32": np.float32,
        "int32": np.int32,
        "bool": bool,
    }

    graph = gs.import_onnx(model)
    graph.outputs.clear()
    tensors = graph.tensors()
    for output in outputs:
        values = output.rsplit(":", 1)
        key = values[0]
        # Validate the name in both forms so "NAME:DTYPE" with an unknown name reports the
        # available tensors instead of failing later with a bare KeyError.
        if key not in tensors:
            raise Exception(f"Output name {key} not found in model, available keys: {' '.join(tensors.keys())}")

        if len(values) == 1:
            dtype = tensors[key].dtype
            if dtype is None:
                dtype = np.float32
                logger.warning(f"Output layer {key} has no dtype, set to default {dtype}")
        else:
            dtype = values[1]
            if dtype not in supported_dtypes:
                raise Exception(f"Output layer {key} assigned unsupported dtype {dtype}")
            dtype = supported_dtypes[dtype]

        graph.outputs.append(tensors[key].to_variable(dtype=dtype, shape=tensors[key].shape))

    graph.cleanup(remove_unused_graph_inputs=True).toposort()
    model = gs.export_onnx(graph)

    return model
84 |
85 |
def input_modification(model: onnx.ModelProto, inputs: str) -> onnx.ModelProto:
    """
    Replace the graph inputs with the requested tensors, optionally overriding their dtypes.

    Args:
        model (onnx.ModelProto): Model to modify.
        inputs: Iterable of "INPUT_NAME" or "INPUT_NAME:DTYPE" entries, where DTYPE is one of
            fp16, fp32, int32 or bool. Without a DTYPE the tensor's own dtype is used, falling back
            to float32 (with a warning) when the tensor has none.

    Returns:
        onnx.ModelProto: The re-exported model with the new inputs.

    Raises:
        Exception: If a name is not a tensor of the graph or a DTYPE is not supported.
    """
    supported_dtypes = {
        "fp16": np.float16,
        "fp32": np.float32,
        "int32": np.int32,
        "bool": bool,
    }

    graph = gs.import_onnx(model)
    graph.inputs.clear()
    tensors = graph.tensors()
    for requested_input in inputs:
        values = requested_input.rsplit(":", 1)
        key = values[0]
        # Validate the name in both forms so "NAME:DTYPE" with an unknown name reports the
        # available tensors instead of failing later with a bare KeyError.
        if key not in tensors:
            raise Exception(f"Input name {key} not found in model, available keys: {' '.join(tensors.keys())}")

        if len(values) == 1:
            dtype = tensors[key].dtype
            if dtype is None:
                dtype = np.float32
                logger.warning(f"Input layer {key} has no dtype, set to default {dtype}")
        else:
            dtype = values[1]
            if dtype not in supported_dtypes:
                raise Exception(f"Input layer {key} assigned unsupported dtype {dtype}")
            dtype = supported_dtypes[dtype]

        graph.inputs.append(tensors[key].to_variable(dtype=dtype, shape=tensors[key].shape))

    graph.cleanup(remove_unused_graph_inputs=True).toposort()
    model = gs.export_onnx(graph)

    return model
120 |
121 |
def shape_infer(model: onnx.ModelProto):
    """
    Run shape inference on the model, preferring symbolic inference and falling back to ONNX's own.

    Args:
        model (onnx.ModelProto): Model to annotate with inferred shapes.

    Returns:
        onnx.ModelProto: The model returned by whichever inference path succeeded.
    """
    logger.debug("Start shape inference.")
    try:
        logger.debug("try onnxruntime shape infer.")
        model = SymbolicShapeInference.infer_shapes(model, auto_merge=AUTO_MERGE)
    except Exception as err:
        logger.debug(f"onnxruntime shape infer failed, try onnx shape infer. {err}")
        if model.ByteSize() >= checker.MAXIMUM_PROTOBUF:
            # Models at or above the protobuf size limit go through the path-based API. The context
            # manager removes the scratch directory deterministically once the result is loaded.
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_path = os.path.join(tmp_dir, "tmp.onnx")
                tmp_infer_path = os.path.join(tmp_dir, "tmp_infer.onnx")
                save(model, tmp_path)
                onnx.shape_inference.infer_shapes_path(tmp_path, tmp_infer_path)
                model = onnx.load(tmp_infer_path)
        else:
            model = onnx.shape_inference.infer_shapes(model)
    if DEBUG:
        onnx.save(model, "debug_shape_infer.onnx")
    logger.debug("Finish shape inference.")
    return model
143 |
144 |
def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None, size_threshold: int = None):
    """Run one optimization round on the model: optional constant folding followed by optimize_model()."""
    logger.debug("Start converting model to gs.")
    imported_graph = gs.import_onnx(model)
    graph = imported_graph.toposort()
    logger.debug("Finish converting model to gs.")

    if OptimizationSettings.constant_folding:
        logger.debug("Start constant folding.")
        folded_graph = graph.fold_constants(size_threshold=size_threshold)
        folded_graph.cleanup().toposort()
        logger.debug("Finish constant folding.")

    logger.debug("Start optimize model.")
    optimized_model = optimize_model(graph, skip_fusion_patterns)
    logger.debug("Finish optimize model.")

    # In debug mode keep a copy of the intermediate result on disk.
    if DEBUG:
        onnx.save(optimized_model, "debug_slim.onnx")

    return optimized_model
161 |
162 |
def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:
    """
    Convert the floating-point data format of an ONNX model.

    Args:
        model (onnx.ModelProto): Model to convert.
        dtype (str): "fp16" or "fp32". Any other value leaves the model untouched.

    Returns:
        onnx.ModelProto: The converted model, or the input model when ``dtype`` is not recognised.
    """
    if dtype == "fp16":
        # Imported here so onnxconverter_common is only required on the fp16 path.
        from onnxconverter_common import float16

        model = float16.convert_float_to_float16(model)
    elif dtype == "fp32":
        graph = gs.import_onnx(model).toposort()

        for node in graph.nodes:
            if node.op == "Cast":
                # Only the dtype of the first input decides what happens to the Cast.
                inp_dtype = [input.dtype for input in node.inputs][0]
                if inp_dtype in [np.float16, np.float32]:
                    # NOTE(review): the cast target is not inspected here, so a float -> non-float
                    # Cast would be erased as well — confirm this is intended.
                    node.erase()
                else:
                    # Non-float source: keep the Cast but retarget an fp16 result to fp32.
                    outp_dtype = [output.dtype for output in node.outputs][0]
                    if outp_dtype == np.float16:
                        node.attrs["to"] = dtype_to_onnx(np.float32)
                        node.outputs[0].dtype = np.float32
            elif node.op == "ConstantOfShape":
                # Promote an fp16 fill value (and the node's output dtype) to fp32.
                if hasattr(node, "attrs") and "value" in node.attrs:
                    if node.attrs["value"].dtype == np.float16:
                        node.attrs["value"].values = node.attrs["value"].values.astype(np.float32)
                        node.outputs[0].dtype = np.float32

        # Promote every remaining fp16 tensor: variables change dtype, constants convert their values.
        for tensor in graph.tensors().values():
            if isinstance(tensor, gs.Variable) and tensor.dtype == np.float16:
                tensor.dtype = np.float32
            elif isinstance(tensor, gs.Constant) and tensor.dtype == np.float16:
                tensor.values = tensor.values.astype(np.float32)

        graph.cleanup(remove_unused_graph_inputs=True).toposort()
        model = gs.export_onnx(graph)

    return model
198 |
199 |
def freeze(model: onnx.ModelProto):
    """Remove, in place, every graph input whose name matches one of the model's initializers."""
    graph_inputs = model.graph.input

    # Index the inputs by name; a later duplicate replaces the earlier entry after a warning.
    input_by_name = {}
    for graph_input in graph_inputs:
        if graph_input.name in input_by_name:
            logger.warning(f"Duplicate input name: {graph_input.name}")
        input_by_name[graph_input.name] = graph_input

    for initializer in model.graph.initializer:
        if initializer.name not in input_by_name:
            continue
        graph_inputs.remove(input_by_name.pop(initializer.name))
213 |
--------------------------------------------------------------------------------
/onnxslim/core/optimization/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import Counter
3 | from typing import List, Optional, Union
4 |
5 | import onnx
6 |
7 | import onnxslim.third_party.onnx_graphsurgeon as gs
8 | from onnxslim.core.pattern.registry import get_fusion_patterns
9 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
10 |
11 | logger = logging.getLogger("onnxslim")
12 |
13 | from .dead_node_elimination import dead_node_elimination
14 | from .subexpression_elimination import subexpression_elimination
15 | from .weight_tying import tie_weights
16 |
17 |
class OptimizationSettings:
    """Class-level on/off switches for the individual optimization passes."""

    constant_folding = True
    graph_fusion = True
    dead_node_elimination = True
    subexpression_elimination = True
    weight_tying = True

    @classmethod
    def keys(cls):
        """Return the names of all optimization switches."""
        names = []
        names.append("constant_folding")
        names.append("graph_fusion")
        names.append("dead_node_elimination")
        names.append("subexpression_elimination")
        names.append("weight_tying")
        return names

    @classmethod
    def reset(cls, skip_optimizations: Optional[List[str]] = None):
        """Enable every switch except those listed in skip_optimizations."""
        for name in cls.keys():
            is_skipped = skip_optimizations and name in skip_optimizations
            setattr(cls, name, not is_skipped)

    @classmethod
    def stats(cls):
        """Return a mapping from switch name to its current value."""
        current = {}
        for name in cls.keys():
            current[name] = getattr(cls, name)
        return current

    @classmethod
    def enabled(cls):
        """Return True when at least one switch is on."""
        return any(getattr(cls, name) for name in cls.keys())
50 |
51 |
def optimize_model(model: Union[onnx.ModelProto, gs.Graph], skip_fusion_patterns: str = None) -> onnx.ModelProto:
    """Apply the enabled graph passes (fusion, dead-node, subexpression, weight tying) and export to ONNX."""
    if isinstance(model, gs.Graph):
        graph = model
    else:
        graph = gs.import_onnx(model)

    if OptimizationSettings.graph_fusion:
        logger.debug("Start graph_fusion.")
        patterns = get_fusion_patterns(skip_fusion_patterns)
        matched_layers = find_matches(graph, patterns)
        for layer_kwargs in matched_layers.values():
            graph.replace_custom_layer(**layer_kwargs)
        graph.cleanup(remove_unused_graph_inputs=True).toposort()
        logger.debug("Finish graph_fusion.")

    if OptimizationSettings.dead_node_elimination:
        logger.debug("Start dead_node_elimination.")
        dead_node_elimination(graph)
        graph.cleanup(remove_unused_graph_inputs=True).toposort()
        logger.debug("Finish dead_node_elimination.")

    if OptimizationSettings.subexpression_elimination:
        logger.debug("Start subexpression_elimination.")
        subexpression_elimination(graph)
        graph.cleanup(remove_unused_graph_inputs=True).toposort()
        logger.debug("Finish subexpression_elimination.")

    if OptimizationSettings.weight_tying:
        logger.debug("Start weight_tying.")
        tie_weights(graph)
        logger.debug("Finish weight_tying.")

    return gs.export_onnx(graph)
80 |
81 |
@gs.Graph.register()
def replace_custom_layer(
    self,
    op: str,
    inputs,
    outputs: List[str],
    name: str,
    attrs: dict = None,
    domain: str = "ai.onnx.contrib",
):
    """Insert a layer built from the given op, inputs, outputs, name, attrs and domain via self.layer()."""
    layer_kwargs = {
        "op": op,
        "inputs": inputs,
        "outputs": outputs,
        "name": name,
        "attrs": attrs,
        "domain": domain,
    }
    return self.layer(**layer_kwargs)
101 |
102 |
def find_matches(graph: Graph, fusion_patterns: dict):
    """Walk the graph nodes in reverse and collect the rewrite specifications of every matching pattern."""
    matches_by_name = {}
    pattern_counts = Counter()

    for node in reversed(graph.nodes):
        # Skip nodes that an earlier match already produced an entry for.
        if node.name in matches_by_name:
            continue

        for layer_type, pattern_matcher in fusion_patterns.items():
            if not pattern_matcher.match(node):
                continue

            rewritten = pattern_matcher.rewrite(opset=graph.opset)
            logger.debug(f"matched pattern {layer_type}")

            # Fill in the op type and a generated name where the pattern did not provide them.
            for layer_spec in rewritten.values():
                if "op" not in layer_spec:
                    layer_spec["op"] = layer_type
                if "name" not in layer_spec:
                    layer_spec["name"] = f"{layer_type.lower()}_{pattern_counts[layer_type]}"

            pattern_counts[layer_type] += 1
            matches_by_name.update(rewritten)

    return matches_by_name
123 |
124 |
def get_previous_node_by_type(node, op_type, trajectory=None):
    """
    Search upstream of ``node`` for the nearest producer whose op equals ``op_type``.

    The search is depth-first and tries the feeds of each node in order. Previously only the
    first feed of every node was ever followed, so a match reachable through a later feed was
    missed; now a dead end is abandoned and the next feed is tried.

    Args:
        node: Start node; it and its producers must expose ``feeds`` and ``op``.
        op_type (str): Operator type to look for.
        trajectory (list, optional): List the visited path is appended to. A new list is used
            when omitted.

    Returns:
        list | None: The path of nodes from the first producer down to the matching node
        (inclusive), or None when no producer of that type is reachable.
    """
    if trajectory is None:
        trajectory = []

    # Nodes already proven not to lead to a match; avoids re-walking shared subgraphs.
    dead_ends = set()

    def _search(current):
        for node_feed in current.feeds:
            if id(node_feed) in dead_ends:
                continue
            trajectory.append(node_feed)
            if node_feed.op == op_type or _search(node_feed):
                return True
            # This feed leads nowhere: drop it from the path and remember it.
            trajectory.pop()
            dead_ends.add(id(node_feed))
        return False

    return trajectory if _search(node) else None
136 |
--------------------------------------------------------------------------------
/onnxslim/core/optimization/dead_node_elimination.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 |
5 | import onnxslim.third_party.onnx_graphsurgeon as gs
6 | from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx
7 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Variable
8 |
9 | logger = logging.getLogger("onnxslim")
10 |
11 |
def dead_node_elimination(graph, is_subgraph=False):
    """
    Remove or simplify nodes that have no effect, in place.

    Subgraphs are processed first (recursively). Nodes are erased when they are recognised as
    no-ops: Identity/Dropout (top-level graph only), zero Pad, Cast to the same dtype, shape
    preserving Reshape/Expand/Split, Mul/Div by one, Add/Sub of zero and single-input Concat.
    In addition, a Reshape with a sufficiently static output shape gets a constant shape input,
    empty constants are dropped from Concat inputs, and Resize nodes without a mode get "nearest".

    Args:
        graph: Graph whose nodes are examined and modified.
        is_subgraph (bool): True when ``graph`` is a subgraph; Identity and Dropout are kept then.
    """
    for subgraph in graph.subgraphs():
        dead_node_elimination(subgraph, is_subgraph=True)

    for node in graph.nodes:
        if node.op in {"Identity", "Dropout"}:
            if not is_subgraph:
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Pad":
            # A Pad whose constant pads input is all zeros does nothing.
            if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant):
                pad_value = node.inputs[1].values.tolist()
                pad_value = pad_value if isinstance(pad_value, list) else [pad_value]
                if all(value == 0 for value in pad_value):
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Cast":
            # A Cast to the dtype its first input already has is redundant.
            inp_dtype = [dtype_to_onnx(tensor.dtype) for tensor in node.inputs][0]
            if inp_dtype == node.attrs["to"]:
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Reshape":
            if (node.inputs[0].shape and len(node.inputs[0].shape) == 1) and (
                node.outputs[0].shape and len(node.outputs[0].shape) == 1
            ):
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
            elif node.inputs[0].shape and node.outputs[0].shape and node.inputs[0].shape == node.outputs[0].shape:
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
            else:
                # With at most one symbolic dimension the target shape can be written as a
                # constant, using -1 for the symbolic dimension.
                node_output_shape = node.outputs[0].shape
                if node_output_shape and check_shape(node_output_shape) and not isinstance(node.inputs[1], gs.Constant):
                    shapes = [shape if isinstance(shape, int) else -1 for shape in node_output_shape]
                    reshape_const = gs.Constant(
                        f"{node.inputs[1].name}_",
                        values=np.array(shapes, dtype=np.int64),
                    )
                    node.inputs.pop(1)
                    node.inputs.insert(1, reshape_const)
                    logger.debug(f"replacing {node.op} op: {node.name}")
        elif node.op == "Mul":
            if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or (
                isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable)
            ):
                idx, constant_variable = get_constant_variable(node, return_idx=True)
                # NOTE(review): unlike Add, no shape comparison is made here, so a broadcasting
                # multiplication by ones is also removed — confirm this is intended.
                if np.all(constant_variable.values == 1):
                    var_idx = 0 if idx == 1 else 1
                    node.erase(var_idx, 0)
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Add":
            if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or (
                isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable)
            ):
                idx, constant_variable = get_constant_variable(node, return_idx=True)
                value = constant_variable.values
                var_idx = 0 if idx == 1 else 1
                if value.ndim == 0 and value == 0:
                    node.erase(var_idx, 0)
                    logger.debug(f"removing {node.op} op: {node.name}")
                elif np.all(value == 0) and (node.inputs[var_idx].shape == node.outputs[0].shape):
                    node.erase(var_idx, 0)
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Expand":
            # tests/test_onnx_nets.py::TestTimmClass::test_timm[lambda_resnet26rpt_256]
            if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant):
                constant_variable = node.inputs[1]
                value = constant_variable.values
                if node.inputs[0].shape == node.outputs[0].shape:
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
                elif value.ndim == 0 and value == 1:
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Concat":
            if len(node.inputs) == 1:
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
            else:
                # Iterate over a snapshot: removing from node.inputs while iterating it directly
                # would skip the element that follows each removal.
                for tensor in list(node.inputs):
                    if isinstance(tensor, Constant) and tensor.values.size == 0:
                        node.inputs.remove(tensor)
        elif node.op == "Sub":
            if isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable):
                constant_variable = node.inputs[1]
                value = constant_variable.values
                if value.ndim == 0 and value == 0:
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
                elif np.all(value == 0) and (node.inputs[0].shape == node.outputs[0].shape):
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Div":
            if isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable):
                constant_variable = node.inputs[1]
                value = constant_variable.values
                if value.ndim == 0 and value == 1:
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
                elif np.all(value == 1) and (node.inputs[0].shape == node.outputs[0].shape):
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Split":
            # A Split with a single output of the same shape as its input is a pass-through.
            if (
                len(node.outputs) == 1
                and node.outputs[0].shape
                and node.inputs[0].shape
                and node.outputs[0].shape == node.inputs[0].shape
            ):
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Resize":
            mode = node.attrs.get("mode")
            if mode is None:
                node.attrs["mode"] = "nearest"
                logger.debug(f"setting mode to nearest for {node.op} op: {node.name} since it is not set")
131 |
132 |
def check_shape(shapes):
    """Return True when every entry is a positive int, or all but exactly one string entry are."""
    total = len(shapes)
    symbolic_dims = sum(1 for dim in shapes if isinstance(dim, str))
    positive_dims = sum(1 for dim in shapes if not isinstance(dim, str) and isinstance(dim, int) and dim > 0)

    if positive_dims == total:
        return True
    return symbolic_dims == 1 and positive_dims == total - 1
145 |
146 |
def get_constant_variable(node, return_idx=False):
    """Return the first Constant among the node's inputs (with its index when return_idx is set), else None."""
    for position, candidate in enumerate(list(node.inputs)):
        if not isinstance(candidate, Constant):
            continue
        if return_idx:
            return position, candidate
        return candidate
152 |
--------------------------------------------------------------------------------
/onnxslim/core/optimization/subexpression_elimination.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Variable
4 |
5 | logger = logging.getLogger("onnxslim")
6 |
7 |
def find_and_remove_replaceable_nodes(nodes):
    """Redirect the uses of duplicated nodes in ``nodes`` to the first equivalent node found."""

    def get_node_key(node):
        # Bucket key: the names of the node's Variable inputs joined by "_" (None when it has none).
        variable_names = [tensor.name for tensor in node.inputs if isinstance(tensor, Variable)]
        return "_".join(variable_names) if variable_names else None

    # Group nodes that share the same variable inputs; only those can be duplicates of each other.
    buckets = {}
    for node in nodes:
        bucket_key = get_node_key(node)
        if bucket_key:
            buckets.setdefault(bucket_key, []).append(node)

    for bucketed_nodes in buckets.values():
        if len(bucketed_nodes) < 2:
            continue

        still_kept = [True] * len(bucketed_nodes)
        for i, existing_node in enumerate(bucketed_nodes):
            if not still_kept[i]:
                continue
            for j in range(i + 1, len(bucketed_nodes)):
                to_be_removed_node = bucketed_nodes[j]
                if not (still_kept[j] and can_be_replaced(existing_node, to_be_removed_node)):
                    continue
                still_kept[j] = False
                to_be_removed_node.replace_all_uses_with(existing_node)
                logger.debug(
                    f"Node {to_be_removed_node.name} Op {to_be_removed_node.op} can be replaced by {existing_node.name}"
                )
41 |
42 |
def sequences_equal(seq1, seq2):
    """Return True when both sequences have the same length and pairwise equal elements."""
    if len(seq1) != len(seq2):
        return False

    for left, right in zip(seq1, seq2):
        if not (left == right):
            return False
    return True
50 |
51 |
def can_be_replaced(node, other_node):
    """Tell whether two nodes share the same op, attributes and non-empty inputs."""
    same_signature = node.op == other_node.op and node.attrs == other_node.attrs

    # Empty inputs are ignored on both sides before comparing.
    own_inputs = [tensor for tensor in node.inputs if not tensor.is_empty()]
    other_inputs = [tensor for tensor in other_node.inputs if not tensor.is_empty()]
    same_inputs = sequences_equal(own_inputs, other_inputs)

    return same_signature and same_inputs
60 |
61 |
def subexpression_elimination(graph):
    """Group the graph's nodes by op type and deduplicate equivalent nodes within each group."""
    nodes_by_op = {}
    for node in graph.nodes:
        nodes_by_op.setdefault(node.op, []).append(node)

    for same_op_nodes in nodes_by_op.values():
        find_and_remove_replaceable_nodes(same_op_nodes)
74 |
--------------------------------------------------------------------------------
/onnxslim/core/optimization/weight_tying.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logger = logging.getLogger("onnxslim")
4 | import onnxslim.third_party.onnx_graphsurgeon as gs
5 |
6 |
def tie_weights(graph):
    """Share one constant between all users of constants that compare equal.

    Constants are gathered from `graph` and from every subgraph returned by
    `graph.subgraphs(recursive=True)`. They are compared pairwise with `==`; whenever a later
    constant equals an earlier one, each input slot holding the later constant is re-pointed at
    the earlier one.
    """
    constant_tensors = []
    for tensor in graph.tensors().values():
        if isinstance(tensor, gs.Constant):
            constant_tensors.append(tensor)

    for sub_graph in graph.subgraphs(recursive=True):
        for tensor in sub_graph.tensors().values():
            if isinstance(tensor, gs.Constant):
                constant_tensors.append(tensor)

    def replace_constant_references(existing_constant, to_be_removed_constant):
        """Swap `to_be_removed_constant` for `existing_constant` in every consumer's inputs."""
        # Snapshot the consumers first: editing their inputs may alter `.outputs` underneath us.
        for user in list(to_be_removed_constant.outputs):
            for idx, inp in enumerate(user.inputs):
                if (inp == to_be_removed_constant) and (inp.name == to_be_removed_constant.name):
                    user.inputs.pop(idx)
                    user.inputs.insert(idx, existing_constant)

    total = len(constant_tensors)
    still_kept = [True] * total
    for i in range(total):
        if not still_kept[i]:
            continue
        survivor = constant_tensors[i]
        for j in range(i + 1, total):
            if not still_kept[j]:
                continue
            duplicate = constant_tensors[j]
            if survivor == duplicate:
                # Mark as handled so it is never used as a survivor or compared again.
                still_kept[j] = False
                replace_constant_references(survivor, duplicate)
                logger.debug(
                    f"Constant {constant_tensors[j].name} can be replaced by {constant_tensor.name}"
                    if False
                    else f"Constant {duplicate.name} can be replaced by {survivor.name}"
                )
41 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | from abc import abstractmethod
4 |
5 | import onnxslim.third_party.onnx_graphsurgeon as gs
6 | from onnxslim.third_party.onnx_graphsurgeon import Constant
7 |
8 | logger = logging.getLogger("onnxslim")
9 |
10 |
def get_name(name):
    """Turn `name` into an identifier-like token.

    Every run of characters outside ``[0-9a-zA-Z_]`` collapses into a single underscore, and a
    result made up only of digits gets a leading underscore.
    """
    sanitized_name = re.sub("[^0-9a-zA-Z_]+", "_", name)
    return f"_{sanitized_name}" if sanitized_name.isdigit() else sanitized_name
21 |
22 |
class NodeDescriptor:
    """
    One parsed line of a pattern description: ``Optype Name <inputs> <outputs> names...``.

    case 0: input [1, 2, 3, 4, 5] output [0] Optype Name 5 1 i0 i1 i2 i3 i4 o0
    case 1: input [1, ...] output [0] Optype Name 1+ 1 i0 o0
    case 2: input [..., 1, ...] output [0] Optype Name 1* 1 i0 o0.
    """

    def __init__(self, node_spec):
        """Parse `node_spec`, a list of at least 4 string tokens, into descriptor attributes.

        Raises:
            ValueError: if `node_spec` is not a list, has fewer than 4 elements, or carries a
                count token that is neither an integer nor an integer followed by ``+``/``*``.
            AssertionError: if the number of listed names disagrees with the declared counts.
        """
        if not isinstance(node_spec, list):
            raise ValueError("node_spec must be a list")
        if len(node_spec) < 4:
            raise ValueError(f"node_spec must have at least 4 elements {node_spec}")

        def get_input_info(io_spec):
            """Parse a count token into ``(count, is_coarse, mode)``.

            ``"N"`` gives ``(N, False, None)``, ``"N+"`` gives ``(N, True, "append")`` and
            ``"N*"`` gives ``(N, True, "free-match")``.
            """
            if not io_spec.isdigit():
                match = re.search(r"(\d+)([+*])", io_spec)
                if match:
                    # The regex guarantees digits in group 1 and '+' or '*' in group 2.
                    mode = "append" if match.group(2) == "+" else "free-match"
                    return int(match.group(1)), True, mode

            try:
                return int(io_spec), False, None
            except ValueError:
                # Same exception type as the bare int() failure, but naming the offending token.
                raise ValueError(f"input_num and output_num must be integers {io_spec}") from None

        self.op = node_spec[0]
        self.name = node_spec[1]
        self.input_num, self.coarse_input_num, self.input_mode = get_input_info(node_spec[2])
        self.output_num, self.coarse_output_num, self.output_mode = get_input_info(node_spec[3])
        # Names follow the four header tokens: first the inputs, then everything else as outputs.
        self.input_names = node_spec[4 : 4 + self.input_num]
        self.output_names = node_spec[4 + self.input_num :]
        assert len(self.input_names) == self.input_num, f"{self.name} {len(self.input_names)} != {self.input_num}"
        assert len(self.output_names) == self.output_num, f"{self.name} {len(self.output_names)} != {self.output_num}"

    def __repr__(self):
        """Return a one-line summary with name, op type, input/output counts and names."""
        return f"name: {self.name}, type: {self.op}, input_num: {self.input_num}, output_num: {self.output_num}, input_names: {self.input_names}, output_names: {self.output_names}"

    # NOTE(review): this defines a *method* called `__dict__`, so `obj.__dict__` is a bound
    # method rather than the attribute mapping. Kept unchanged for compatibility; confirm intent.
    def __dict__(self):
        """Return a dictionary mapping the key 'name' to this descriptor object."""
        return {
            "name": self,
        }
79 |
80 |
class Pattern:
    """A textual graph pattern together with its parsed node descriptors."""

    def __init__(self, pattern):
        """Store the raw `pattern` text and parse it into `self.nodes`."""
        self.pattern = pattern
        self.nodes = self.parse_nodes()

    def parse_nodes(self):
        """Build one NodeDescriptor per line of the pattern that contains at least one token.

        Lines are split on newlines, stripped and whitespace-tokenised; empty and
        whitespace-only lines contribute nothing.
        """
        token_rows = []
        for raw_line in self.pattern.split("\n"):
            if raw_line:
                token_rows.append(raw_line.strip().split())

        descriptors = []
        for tokens in token_rows:
            if tokens:
                descriptors.append(NodeDescriptor(tokens))
        return descriptors

    def match(self, node):
        """Delegate to `self.pattern.match(node)`.

        NOTE(review): `parse_nodes` treats `self.pattern` as a `str`, which has no `match`
        method — confirm whether this method is ever called.
        """
        return self.pattern.match(node)

    def __repr__(self):
        """Return the raw pattern text."""
        return self.pattern
101 |
102 |
class PatternMatcher:
    """Base class that matches a parsed pattern against graph nodes, walking producers backwards."""

    def __init__(self, pattern, priority):
        """Store the pattern and priority; index pattern nodes by name and collect output names."""
        self.pattern = pattern
        self.priority = priority
        self.pattern_dict = {}
        for descriptor in pattern.nodes:
            self.pattern_dict[descriptor.name] = descriptor
        self.output_names = [descriptor.name for descriptor in pattern.nodes if descriptor.op == "output"]

    def get_match_point(self):
        """Return the descriptor named by the first input of the first ``output`` pattern node."""
        first_output = self.pattern_dict[self.output_names[0]]
        return self.pattern_dict[first_output.input_names[0]]

    def match(self, node):
        """Try to match the pattern with `node` as the match point.

        On the way, every matched graph object is stored on `self` under the name of its pattern
        node, and `node.outputs` is stored as `self.output`. Returns True only if the structural
        match succeeds and `parameter_check()` is truthy; otherwise returns False.
        """
        match_point = self.get_match_point()

        def expected_inputs(descriptor):
            # "?" is a wildcard slot: represented as None and skipped during matching.
            return [self.pattern_dict[name] if name != "?" else None for name in descriptor.input_names]

        def match_(candidate, descriptor):
            """Recursively match `candidate` and its feeds against `descriptor` and its inputs."""
            if descriptor.op == "input":
                return True

            # candidate is an input variable
            if not hasattr(candidate, "op"):
                return False

            if candidate.op != descriptor.op:
                return False

            setattr(self, descriptor.name, candidate)

            feeds = candidate.feeds
            declared = len(descriptor.input_names)
            if descriptor.coarse_input_num:
                # Coarse counts ("N+"/"N*") only require at least the declared number of feeds.
                if len(feeds) < declared:
                    return False
            elif len(feeds) != declared:
                return False

            mode = descriptor.input_mode
            if mode == "append" or mode is None:
                # Positional matching: the i-th feed must match the i-th declared input.
                for feed, expected in zip(feeds, expected_inputs(descriptor)):
                    if expected is None:
                        continue
                    if not match_(feed, expected):
                        return False
                    setattr(self, expected.name, feed)
                return True

            if mode == "free-match":
                # Order-free matching: each declared input must match some feed.
                for expected in expected_inputs(descriptor):
                    if expected is None:
                        continue
                    found = False
                    for feed in feeds:
                        if match_(feed, expected):
                            found = True
                            break
                    if not found:
                        return False
                    setattr(self, expected.name, feed)
                return True

            return False

        if not match_(node, match_point):
            return False

        setattr(self, "output", node.outputs)
        if self.parameter_check():
            return True
        return False

    @abstractmethod
    def rewrite(self, opset=11):
        """Produce the replacement for a matched pattern; subclasses must override this."""
        raise NotImplementedError("rewrite method must be implemented")

    def parameter_check(self):
        """Hook for extra validation after a structural match; the base version accepts all."""
        return True
192 |
193 |
class PatternGenerator:
    """Produces a textual pattern description from an ONNX model's graph."""

    def __init__(self, onnx_model):
        """Import `onnx_model` into a graph, then fold constants, clean up and toposort it."""
        self.graph = gs.import_onnx(onnx_model)
        self.graph.fold_constants().cleanup().toposort()

    def generate(self):
        """Return the pattern text: one line per graph input, non-Constant node and graph output.

        Each line is ``<op> <name> <n_inputs> <n_outputs> <names...>`` joined by single spaces;
        names are sanitised with `get_name`, and Constant feeds/users are written as ``?``.
        """
        graph_inputs = self.graph.inputs
        graph_outputs = self.graph.outputs
        graph_nodes = self.graph.nodes

        lines = []
        for graph_input in graph_inputs:
            fields = ["input", get_name(graph_input.name), "0", str(len(graph_input.outputs))]
            fields += [get_name(consumer.name) for consumer in graph_input.outputs]
            lines.append(" ".join(fields))

        for node in graph_nodes:
            if node.op == "Constant":
                continue
            node_name = get_name(node.name)
            feeds = node.feeds
            users = node.users
            fields = [node.op, node_name, str(len(feeds)), str(len(users))]
            fields += ["?" if isinstance(feed, Constant) else get_name(feed.name) for feed in feeds]
            fields += ["?" if isinstance(user, Constant) else get_name(user.name) for user in users]
            lines.append(" ".join(fields))

        for graph_output in graph_outputs:
            fields = ["output", get_name(graph_output.name), str(len(graph_output.inputs)), "0"]
            fields += [get_name(producer.name) for producer in graph_output.inputs]
            lines.append(" ".join(fields))

        return "\n".join(lines)
237 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/elimination/__init__.py:
--------------------------------------------------------------------------------
1 | from .concat import *
2 | from .reshape import *
3 | from .reshape_as import *
4 | from .slice import *
5 | from .unsqueeze import *
6 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/elimination/concat.py:
--------------------------------------------------------------------------------
1 | from onnxslim.core.pattern import Pattern, PatternMatcher
2 | from onnxslim.core.pattern.registry import register_fusion_pattern
3 |
4 |
class ConcatPatternMatcher(PatternMatcher):
    """Matches a Concat feeding another Concat so the two can be flattened into one."""

    def __init__(self, priority):
        """Initializes the ConcatPatternMatcher with a specified priority using a predefined graph pattern."""
        pattern = Pattern(
            """
            input input 0 1 concat_0
            Concat concat_0 1+ 1 input concat_1
            Concat concat_1 1* 1 concat_0 output
            output output 1 0 concat_1
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the elimination pattern, 'EliminationConcat'."""
        return "EliminationConcat"

    def rewrite(self, opset=11):
        """Inline the inputs of the inner Concat into the outer Concat.

        The rewrite only happens when every user of the inner Concat is itself a Concat on the
        same axis. Returns a dict mapping the outer node's name to the description of the
        replacement Concat, or an empty dict when nothing was rewritten.
        """
        match_case = {}

        node_concat_0 = self.concat_0
        users_node_concat_0 = node_concat_0.users
        node_concat_1 = self.concat_1
        node_concat_0_axis = node_concat_0.attrs.get("axis", 0)

        if all(user.op == "Concat" and user.attrs.get("axis", 0) == node_concat_0_axis for user in users_node_concat_0):
            # Splice the inner Concat's inputs into the slot its output occupied.
            index = node_concat_1.inputs.index(node_concat_0.outputs[0])
            node_concat_1.inputs.pop(index)
            for i, item in enumerate(node_concat_0.inputs):
                node_concat_1.inputs.insert(index + i, item)
            inputs = list(node_concat_1.inputs)
            outputs = list(node_concat_1.outputs)
            node_concat_1.inputs.clear()
            node_concat_1.outputs.clear()

            # NOTE(review): `users_node_concat_0` was read before the rewiring above; if `users`
            # returns a snapshot this condition reflects the pre-rewrite state — confirm intent.
            if len(users_node_concat_0) == 0:
                node_concat_0.inputs.clear()
                node_concat_0.outputs.clear()

            attrs = {"axis": node_concat_0_axis}

            match_case[node_concat_1.name] = {
                "op": "Concat",
                "inputs": inputs,
                "outputs": outputs,
                "name": node_concat_1.name,
                "attrs": attrs,
                "domain": None,
            }

        return match_case
59 |
60 |
61 | register_fusion_pattern(ConcatPatternMatcher(1))
62 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/elimination/reshape.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import onnxslim.third_party.onnx_graphsurgeon as gs
4 | from onnxslim.core.pattern import Pattern, PatternMatcher
5 | from onnxslim.core.pattern.registry import register_fusion_pattern
6 |
7 |
class ReshapePatternMatcher(PatternMatcher):
    """Matches two consecutive Reshape nodes so they can be collapsed into one."""

    def __init__(self, priority):
        """Set up the Reshape -> Reshape pattern with the given priority."""
        pattern = Pattern(
            """
            input input 0 1 reshape_0
            Reshape reshape_0 2 1 input ? reshape_1
            Reshape reshape_1 2 1 reshape_0 ? output
            output output 1 0 reshape_1
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name 'EliminationReshape'."""
        return "EliminationReshape"

    def rewrite(self, opset=11):
        """Replace Reshape -> Reshape by a single Reshape using the second node's shape input.

        Nothing is rewritten unless the first Reshape has exactly one user and both nodes pass
        `check_constant_mergeable`. Returns a dict describing the replacement node, keyed by the
        first Reshape's name, or an empty dict.
        """
        match_case = {}
        second_reshape_node = self.reshape_1
        first_reshape_node = second_reshape_node.i(0)
        first_reshape_node_inputs = list(first_reshape_node.inputs)
        if len(first_reshape_node.users) != 1:
            return match_case

        def check_constant_mergeable(reshape_node):
            """Return False only when a constant shape has 0-entries that resolve to non-int dims."""
            shape_input = reshape_node.inputs[1]
            if not isinstance(shape_input, gs.Constant):
                return True

            input_shape = reshape_node.inputs[0].shape
            reshape_shape = shape_input.values.tolist()
            if input_shape is None or not np.any(np.array(reshape_shape) == 0):
                return True

            # A 0 in the target shape takes the dimension from the input at that position.
            resolved = [
                input_shape[i] if dim_size == 0 else reshape_shape[i] for i, dim_size in enumerate(reshape_shape)
            ]
            return all(isinstance(item, int) for item in resolved)

        if check_constant_mergeable(first_reshape_node) and check_constant_mergeable(second_reshape_node):
            inputs = [first_reshape_node_inputs[0], second_reshape_node.inputs[1]]
            outputs = list(second_reshape_node.outputs)
            # Detach both original nodes before the replacement is described.
            first_reshape_node.outputs.clear()
            second_reshape_node.inputs.clear()
            second_reshape_node.outputs.clear()

            match_case[first_reshape_node.name] = {
                "op": "Reshape",
                "inputs": inputs,
                "outputs": outputs,
                "name": first_reshape_node.name,
                "attrs": first_reshape_node.attrs,
                "domain": None,
            }

        return match_case
75 |
76 |
77 | register_fusion_pattern(ReshapePatternMatcher(1))
78 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/elimination/reshape_as.py:
--------------------------------------------------------------------------------
1 | import onnxslim.third_party.onnx_graphsurgeon as gs
2 | from onnxslim.core.pattern import Pattern, PatternMatcher
3 | from onnxslim.core.pattern.registry import register_fusion_pattern
4 |
5 |
class ReshapeAsPatternMatcher(PatternMatcher):
    """Matches Shape -> Gather -> Unsqueeze -> Concat chains that merely rebuild the Shape output."""

    def __init__(self, priority):
        """Initializes the ReshapeAsPatternMatcher with a priority and a specific pattern for reshape as operations."""
        pattern = Pattern(
            """
            input input 0 1 shape
            Shape shape 1+ 1 input gather
            Gather gather 1+ 1 shape unsqueeze
            Unsqueeze unsqueeze 1+ 1 gather output
            Concat concat 1+ 1 unsqueeze output
            output output 1 0 concat
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name 'EliminationReshapeAs'."""
        return "EliminationReshapeAs"

    def parameter_check(self) -> bool:
        """Return True when the matched Concat has one input per Gather user of the Shape node.

        Requires that the Shape output has a known shape, that the Shape node has exactly as
        many users as its output has elements, that every user is a Gather whose index input is
        a scalar Constant equal to the user's position, and that the Concat has the same number
        of inputs as the Shape node has users.
        """
        shape_node = self.shape
        if shape_node.outputs[0].shape is None:
            return False

        if len(shape_node.users) != shape_node.outputs[0].shape[0]:
            return False

        if not all(user.op == "Gather" for user in shape_node.users):
            return False

        for idx, user in enumerate(shape_node.users):
            if not isinstance(user.inputs[1], gs.Constant):
                return False

            if user.inputs[1].values.shape != ():
                return False

            if user.inputs[1].values != idx:
                return False

        concat_node = self.concat
        # Compare counts; comparing the length against the users collection itself is always
        # unequal and made this check reject every match.
        if len(concat_node.inputs) != len(shape_node.users):
            return False

        return True

    def rewrite(self, opset=11):
        """Redirect all uses of the Concat node to the Shape node; returns an empty dict."""
        match_case = {}
        shape_node = self.shape
        concat_node = self.concat

        concat_node.replace_all_uses_with(shape_node)

        return match_case
62 |
63 |
64 | register_fusion_pattern(ReshapeAsPatternMatcher(1))
65 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/elimination/slice.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import onnxslim.third_party.onnx_graphsurgeon as gs
4 | from onnxslim.core.pattern import Pattern, PatternMatcher
5 | from onnxslim.core.pattern.registry import register_fusion_pattern
6 |
7 |
class SlicePatternMatcher(PatternMatcher):
    """Matches a Slice feeding another Slice so that both can be merged into one."""

    def __init__(self, priority):
        """Set up the Slice -> Slice pattern with the given priority."""
        pattern = Pattern(
            """
            input input 0 1 slice_0
            Slice slice_0 5 1 input ? ? ? ? slice_1
            Slice slice_1 5 1 slice_0 ? ? ? ? output
            output output 1 0 slice_1
            """
        )  # to check here slice_0
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the elimination pattern, 'EliminationSlice'."""
        return "EliminationSlice"

    def rewrite(self, opset=11):
        """Merge the first Slice into each of its Slice users by concatenating their parameters.

        Only applies when the starts/ends/axes/steps inputs of the first Slice and of all its
        users are Constants and every user is a Slice. A user whose axes overlap with the first
        Slice's axes is skipped. Returns a dict of replacement node descriptions.
        """
        match_case = {}
        first_slice_node = self.slice_0
        first_slice_node_inputs = list(first_slice_node.inputs)
        if not all(isinstance(tensor, gs.Constant) for tensor in first_slice_node_inputs[1:]):
            return match_case

        first_slice_node_users = first_slice_node.users

        def is_constant_slice(candidate):
            return candidate.op == "Slice" and all(
                isinstance(tensor, gs.Constant) for tensor in list(candidate.inputs)[1:]
            )

        if not all(is_constant_slice(user) for user in first_slice_node_users):
            return match_case

        first_starts, first_ends, first_axes, first_steps = (
            first_slice_node_inputs[position].values.tolist() for position in range(1, 5)
        )

        for second_slice_node in first_slice_node_users:
            second_slice_node_inputs = list(second_slice_node.inputs)
            second_starts, second_ends, second_axes, second_steps = (
                second_slice_node_inputs[position].values.tolist() for position in range(1, 5)
            )

            new_starts = first_starts + second_starts
            new_ends = first_ends + second_ends
            new_axes = first_axes + second_axes
            new_steps = first_steps + second_steps

            # Both slices touching the same axis cannot be merged by plain concatenation.
            if len(new_axes) != len(set(new_axes)):
                continue

            inputs = [list(first_slice_node.inputs)[0]]
            merged_parameters = (new_starts, new_ends, new_axes, new_steps)
            # The merged constants reuse the names of the second Slice's parameter inputs.
            for template, merged in zip(second_slice_node_inputs[1:5], merged_parameters):
                inputs.append(gs.Constant(template.name, values=np.array(merged, dtype=np.int64)))
            outputs = list(second_slice_node.outputs)

            first_slice_node.outputs.clear()
            second_slice_node.inputs.clear()
            second_slice_node.outputs.clear()

            # With a single user the replacement takes the first Slice's identity, otherwise
            # each user keeps its own name and attrs.
            replaced = first_slice_node if len(first_slice_node_users) == 1 else second_slice_node
            match_case[replaced.name] = {
                "op": "Slice",
                "inputs": inputs,
                "outputs": outputs,
                "name": replaced.name,
                "attrs": replaced.attrs,
                "domain": None,
            }

        return match_case
106 |
107 |
108 | register_fusion_pattern(SlicePatternMatcher(1))
109 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/elimination/unsqueeze.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import onnxslim.third_party.onnx_graphsurgeon as gs
4 | from onnxslim.core.pattern import Pattern, PatternMatcher
5 | from onnxslim.core.pattern.registry import register_fusion_pattern
6 |
7 |
class UnsqueezePatternMatcher(PatternMatcher):
    """Matches two consecutive Unsqueeze nodes so they can be fused into one."""

    def __init__(self, priority):
        """Set up the Unsqueeze -> Unsqueeze pattern with the given priority."""
        pattern = Pattern(
            """
            input input 0 1 unsqueeze_0
            Unsqueeze unsqueeze_0 1+ 1 input unsqueeze_1
            Unsqueeze unsqueeze_1 1+ 1 unsqueeze_0 output
            output output 1 0 unsqueeze_1
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the elimination pattern, 'EliminationUnsqueeze'."""
        return "EliminationUnsqueeze"

    def rewrite(self, opset=11):
        """Fuse Unsqueeze -> Unsqueeze into a single Unsqueeze carrying the combined axes.

        Requires that the first Unsqueeze has exactly one user and that both data inputs have a
        truthy shape. For `opset` < 13 the axes are read from the ``axes`` attribute; otherwise
        both axes inputs must be Constants and the combined axes are emitted as a new Constant.
        Returns a dict describing the replacement, keyed by the first node's name, or an empty
        dict.
        """
        match_case = {}
        node_unsqueeze_0 = self.unsqueeze_0
        users_node_unsqueeze_0 = node_unsqueeze_0.users
        node_unsqueeze_1 = self.unsqueeze_1

        shapes_known = (
            len(users_node_unsqueeze_0) == 1 and node_unsqueeze_0.inputs[0].shape and node_unsqueeze_1.inputs[0].shape
        )
        if not shapes_known:
            return match_case

        axes_readable = opset < 13 or (
            isinstance(node_unsqueeze_0.inputs[1], gs.Constant) and isinstance(node_unsqueeze_1.inputs[1], gs.Constant)
        )
        if not axes_readable:
            return match_case

        def get_unsqueeze_axes(unsqueeze_node, opset):
            """Return the node's axes with negative values shifted by input rank plus axes count."""
            dim = len(unsqueeze_node.inputs[0].shape)
            axes = unsqueeze_node.attrs["axes"] if opset < 13 else unsqueeze_node.inputs[1].values
            return [axis + dim + len(axes) if axis < 0 else axis for axis in axes]

        axes_node_unsqueeze_0 = get_unsqueeze_axes(node_unsqueeze_0, opset)
        axes_node_unsqueeze_1 = get_unsqueeze_axes(node_unsqueeze_1, opset)

        # Shift each first-stage axis by the number of second-stage axes at or before it.
        axes_node_unsqueeze_0 = [
            axis + sum(1 for later_axis in axes_node_unsqueeze_1 if later_axis <= axis)
            for axis in axes_node_unsqueeze_0
        ]
        combined_axes = axes_node_unsqueeze_0 + axes_node_unsqueeze_1

        inputs = [node_unsqueeze_0.inputs[0]]
        outputs = list(node_unsqueeze_1.outputs)
        for detached in (node_unsqueeze_0, node_unsqueeze_1):
            detached.inputs.clear()
            detached.outputs.clear()

        if opset < 13:
            attrs = {"axes": combined_axes}
        else:
            attrs = None
            inputs.append(
                gs.Constant(
                    name=f"{node_unsqueeze_0.name}_axes",
                    values=np.array(combined_axes, dtype=np.int64),
                )
            )

        match_case[node_unsqueeze_0.name] = {
            "op": "Unsqueeze",
            "inputs": inputs,
            "outputs": outputs,
            "name": node_unsqueeze_0.name,
            "attrs": attrs,
            "domain": None,
        }

        return match_case
81 |
82 |
83 | register_fusion_pattern(UnsqueezePatternMatcher(1))
84 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/fusion/__init__.py:
--------------------------------------------------------------------------------
1 | from .convadd import *
2 | from .convbn import *
3 | from .gelu import *
4 | from .gemm import *
5 | from .padconv import *
6 | from .reduce import *
7 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/fusion/convadd.py:
--------------------------------------------------------------------------------
1 | import onnxslim.third_party.onnx_graphsurgeon as gs
2 | from onnxslim.core.pattern import Pattern, PatternMatcher
3 | from onnxslim.core.pattern.registry import register_fusion_pattern
4 |
5 |
class ConvAddMatcher(PatternMatcher):
    """Matches a Conv followed by an Add so the added constant can be folded into the Conv bias."""

    def __init__(self, priority):
        """Initializes the ConvAddMatcher for fusing Conv and Add layers in an ONNX graph."""
        pattern = Pattern(
            """
            input input 0 1 conv_0
            Conv conv_0 1+ 1 input bn_0
            Add add_0 2 1 conv_0 ? output
            output output 1 0 add_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the FusionConvAdd pattern."""
        return "FusionConvAdd"

    def rewrite(self, opset=11):
        """Fold the Add's constant operand into the Conv bias and drop the Add node.

        The fusion is applied only when the Conv has a single user, the Conv weight, the Add's
        second input and any existing Conv bias are Constants, and the squeezed Add constant is
        1-D with as many entries as the weight's first dimension. Returns a dict describing the
        replacement node, keyed by the Conv's name, or an empty dict.
        """
        match_case = {}
        conv_node = self.conv_0
        conv_weight = list(conv_node.inputs)[1]
        conv_node_users = conv_node.users
        add_node = self.add_0

        if len(conv_node_users) != 1:
            return match_case
        if not isinstance(add_node.inputs[1], gs.Constant) or not isinstance(conv_weight, gs.Constant):
            return match_case

        has_bias = len(conv_node.inputs) != 2
        # An existing bias is summed numerically below, so it must be a Constant too.
        if has_bias and not isinstance(conv_node.inputs[2], gs.Constant):
            return match_case

        add_values = add_node.inputs[1].values.squeeze()
        if add_values.ndim != 1 or add_values.shape[0] != conv_weight.shape[0]:
            return match_case

        if has_bias:
            conv_bias = conv_node.inputs[2].values + add_values
        else:
            conv_bias = add_values

        weight_name = conv_weight.name
        if weight_name.endswith("weight"):
            bias_name = f"{weight_name[:-6]}bias"
        else:
            bias_name = f"{weight_name}_bias"

        inputs = [list(conv_node.inputs)[0], conv_weight, gs.Constant(bias_name, values=conv_bias)]
        outputs = list(add_node.outputs)

        conv_node.outputs.clear()
        add_node.inputs.clear()
        add_node.outputs.clear()

        match_case[conv_node.name] = {
            "op": conv_node.op,
            "inputs": inputs,
            "outputs": outputs,
            "name": conv_node.name,
            "attrs": conv_node.attrs,
            "domain": None,
        }

        return match_case
68 |
69 |
70 | register_fusion_pattern(ConvAddMatcher(1))
71 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/fusion/convbn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import onnxslim.third_party.onnx_graphsurgeon as gs
4 | from onnxslim.core.pattern import Pattern, PatternMatcher
5 | from onnxslim.core.pattern.registry import register_fusion_pattern
6 |
7 |
class ConvBatchNormMatcher(PatternMatcher):
    """Matches a Conv followed by BatchNormalization so the normalization can be folded away."""

    def __init__(self, priority):
        """Initializes the ConvBatchNormMatcher for fusing Conv and BatchNormalization layers in an ONNX graph."""
        pattern = Pattern(
            """
            input input 0 1 conv_0
            Conv conv_0 1+ 1 input bn_0
            BatchNormalization bn_0 5 1 conv_0 ? ? ? ? output
            output output 1 0 bn_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the FusionConvBN pattern."""
        return "FusionConvBN"

    def rewrite(self, opset=11):
        """Fold the BatchNormalization parameters into the preceding convolution's weight and bias.

        The fusion is applied only when the convolution has a single user and its weight, any
        existing bias and all four BatchNormalization parameters are Constants. Returns a dict
        describing the replacement node, keyed by the convolution's name, or an empty dict.
        """
        match_case = {}
        conv_node = self.conv_0
        conv_node_users = conv_node.users
        bn_node = self.bn_0

        if len(conv_node_users) != 1 or not isinstance(conv_node.inputs[1], gs.Constant):
            return match_case

        has_bias = len(conv_node.inputs) != 2
        # Every tensor read via `.values` below must be a Constant, otherwise skip the fusion.
        if has_bias and not isinstance(conv_node.inputs[2], gs.Constant):
            return match_case
        if not all(isinstance(bn_node.inputs[position], gs.Constant) for position in range(1, 5)):
            return match_case

        conv_weight = conv_node.inputs[1].values
        bn_scale = bn_node.inputs[1].values
        bn_bias = bn_node.inputs[2].values
        bn_running_mean = bn_node.inputs[3].values
        bn_running_var = bn_node.inputs[4].values
        # `epsilon` is optional in ONNX; fall back to the operator's documented default.
        bn_eps = bn_node.attrs.get("epsilon", 1e-5)

        if has_bias:
            conv_bias = conv_node.inputs[2].values
        else:
            conv_bias = np.zeros_like(bn_running_mean)

        bn_var_rsqrt = 1.0 / np.sqrt(bn_running_var + bn_eps)
        # Broadcast the per-channel factor along axis 0 for Conv, along axis 1 otherwise.
        shape = [1] * len(conv_weight.shape)
        if bn_node.i(0).op == "Conv":
            shape[0] = -1
        else:
            shape[1] = -1
        conv_w = conv_weight * (bn_scale * bn_var_rsqrt).reshape(shape)
        conv_b = (conv_bias - bn_running_mean) * bn_var_rsqrt * bn_scale + bn_bias

        weight_name = list(conv_node.inputs)[1].name
        if weight_name.endswith("weight"):
            bias_name = f"{weight_name[:-6]}bias"
        else:
            bias_name = f"{weight_name}_bias"

        inputs = [
            list(conv_node.inputs)[0],
            gs.Constant(weight_name, values=conv_w),
            gs.Constant(bias_name, values=conv_b),
        ]
        outputs = list(bn_node.outputs)

        conv_node.outputs.clear()
        bn_node.inputs.clear()
        bn_node.outputs.clear()

        match_case[conv_node.name] = {
            "op": conv_node.op,
            "inputs": inputs,
            "outputs": outputs,
            "name": conv_node.name,
            "attrs": conv_node.attrs,
            "domain": None,
        }

        return match_case
84 |
85 |
86 | register_fusion_pattern(ConvBatchNormMatcher(1))
87 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/fusion/gelu.py:
--------------------------------------------------------------------------------
1 | from onnxslim.core.pattern import Pattern, PatternMatcher
2 |
3 |
class GeluPatternMatcher(PatternMatcher):
    """Matches the Div -> Erf -> Add -> Mul -> Mul subgraph so it can be replaced by a Gelu node."""

    def __init__(self, priority):
        """Set up the erf-based GELU pattern with the given priority."""
        pattern = Pattern(
            """
            input input 0 2 mul_0 div_0
            Div div_0 2 1 input ? erf_0
            Erf erf_0 1 1 div_0 add_0
            Add add_0 2 1 erf_0 ? mul_0
            Mul mul_0 2 1 input add_0 mul_1
            Mul mul_1 2 1 mul_0 ? output
            output output 1 0 mul_1
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the fusion pattern, 'FusionGelu'."""
        return "FusionGelu"

    def rewrite(self, opset=11):
        """Detach the matched subgraph from its input/output tensors and describe a Gelu node.

        Returns a dict keyed by the final Mul's name whose value describes a ``Gelu`` op reading
        the pattern's input tensor and writing the final Mul's output tensor.
        """
        div_node = self.div_0
        mul_node = self.mul_0
        final_mul_node = self.mul_1

        input_variable = div_node.inputs[0]
        # The shared input no longer feeds the first Mul or the Div.
        input_variable.outputs.remove(mul_node)
        input_variable.outputs.remove(div_node)

        output_variable = final_mul_node.outputs[0]
        output_variable.inputs.clear()

        replacement = {
            "op": "Gelu",
            "inputs": [input_variable],
            "outputs": [output_variable],
            "domain": None,
        }
        return {final_mul_node.name: replacement}
45 |
46 |
47 | # register_fusion_pattern(GeluPatternMatcher(1))
48 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/fusion/gemm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import onnxslim.third_party.onnx_graphsurgeon as gs
4 | from onnxslim.core.optimization.dead_node_elimination import get_constant_variable
5 | from onnxslim.core.pattern import Pattern, PatternMatcher
6 | from onnxslim.core.pattern.registry import register_fusion_pattern
7 |
8 |
class MatMulAddPatternMatcher(PatternMatcher):
    def __init__(self, priority):
        """Initializes a matcher for fusing MatMul and Add operations in ONNX graph optimization."""
        pattern = Pattern(
            """
            input input 0 1 matmul_0
            MatMul matmul_0 2 1 input ? add_0
            Add add_0 2 1 matmul_0 ? output
            output output 1 0 add_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the fusion pattern as a string 'FusionGemm'."""
        return "FusionGemm"

    def rewrite(self, opset=11):
        """Replace a matched `MatMul -> Add` pair with a `Gemm` node (wrapped in Reshapes for inputs of rank > 2).

        The fusion is only performed when all of the following hold:
          * the MatMul has exactly one user;
          * the MatMul's *second* input is a 2-D constant and its first input is not a constant. The emitted
            node is `Gemm(x, W^T, bias, transB=1)`, i.e. `x @ W + bias`, which only equals the original
            expression when the constant is the right-hand operand;
          * the Add has a constant input;
          * the non-constant MatMul input has a fully static (all-`int`) shape of rank >= 2.

        Args:
            opset: Unused here; kept for interface parity with the other matchers.

        Returns:
            dict: Node descriptions keyed by node name; empty when the match is not fusable.
        """
        match_case = {}
        node = self.add_0
        matmul_node = self.matmul_0
        # NOTE(review): get_constant_variable is defined elsewhere; presumably it returns a constant
        # input of the node, or a falsy value when there is none - confirm against its definition.
        matmul_bias_variable = get_constant_variable(matmul_node)
        add_bias_variable = get_constant_variable(node)
        input_variable = (
            matmul_node.inputs[0] if isinstance(matmul_node.inputs[1], gs.Constant) else matmul_node.inputs[1]
        )
        # Gemm multiplies A by B, so the constant weight must be MatMul's right-hand operand.
        # Without this guard `W @ x` would silently be rewritten to `x @ W`.
        weight_is_rhs = isinstance(matmul_node.inputs[1], gs.Constant) and not isinstance(
            matmul_node.inputs[0], gs.Constant
        )
        users = matmul_node.users
        if (
            len(users) == 1
            and weight_is_rhs
            and matmul_bias_variable
            and add_bias_variable
            and len(matmul_bias_variable.shape) == 2
        ):
            if (
                input_variable.shape
                and len(input_variable.shape) > 2
                and all(isinstance(value, int) for value in input_variable.shape)
            ):
                # Step 1: flatten the leading dimensions so the activation becomes 2-D for Gemm.
                pre_reshape_const = gs.Constant(
                    f"{matmul_node.name}_pre_reshape_in",
                    values=np.array([-1, matmul_bias_variable.values.shape[0]], dtype=np.int64),
                )
                inputs = [input_variable, pre_reshape_const]

                reshape_out_variable = gs.Variable(
                    f"{matmul_node.name}_pre_reshape_out",
                    dtype=input_variable.dtype,
                )
                outputs = [reshape_out_variable]

                match_case.update(
                    {
                        f"{matmul_node.name}_pre_reshape": {
                            "op": "Reshape",
                            "inputs": inputs,
                            "outputs": outputs,
                            "name": f"{matmul_node.name}_pre_reshape",
                            "domain": None,
                        }
                    }
                )

                # Step 2: the Gemm itself. The Add is unhooked from its first input tensor.
                add_node = node

                output_variable = add_node.inputs[0]
                output_variable.outputs.remove(add_node)

                # transB=1 below, so the weight is stored transposed.
                matmul_bias_transpose_constant = gs.Constant(
                    matmul_bias_variable.name, values=matmul_bias_variable.values.T
                )

                inputs = [reshape_out_variable, matmul_bias_transpose_constant, add_bias_variable]

                gemm_out_variable = gs.Variable(f"{matmul_node.name}_gemm_out", dtype=output_variable.dtype)
                outputs = [gemm_out_variable]

                match_case.update(
                    {
                        matmul_node.name: {
                            "op": "Gemm",
                            "inputs": inputs,
                            "outputs": outputs,
                            "name": matmul_node.name,
                            "attrs": {
                                "alpha": 1.0,
                                "beta": 1.0,
                                "transA": 0,
                                "transB": 1,
                            },
                            "domain": None,
                        }
                    }
                )

                # Step 3: restore the original leading dimensions on the Gemm result.
                # list(...) keeps the concatenation valid whether the shape is a list or a tuple.
                values = list(input_variable.shape[:-1]) + [matmul_bias_variable.values.shape[-1]]
                post_reshape_const = gs.Constant(
                    f"{matmul_node.name}_post_reshape_in",
                    values=np.array(values, dtype=np.int64),
                )

                inputs = [gemm_out_variable, post_reshape_const]
                outputs = list(add_node.outputs)

                matmul_node.outputs.clear()
                add_node.inputs.clear()
                add_node.outputs.clear()

                match_case.update(
                    {
                        f"{matmul_node.name}_post_reshape": {
                            "op": "Reshape",
                            "inputs": inputs,
                            "outputs": outputs,
                            "name": f"{matmul_node.name}_post_reshape",
                            "domain": None,
                        }
                    }
                )
            elif (
                input_variable.shape
                and len(input_variable.shape) == 2
                and all(isinstance(value, int) for value in input_variable.shape)
            ):
                # Already 2-D: a single Gemm takes over the Add's outputs, no Reshapes needed.
                add_node = node

                output_variable = add_node.inputs[0]
                output_variable.outputs.remove(add_node)

                # transB=1 below, so the weight is stored transposed.
                matmul_bias_transpose_constant = gs.Constant(
                    matmul_bias_variable.name, values=matmul_bias_variable.values.T
                )

                inputs = [input_variable, matmul_bias_transpose_constant, add_bias_variable]

                outputs = list(add_node.outputs)
                add_node.inputs.clear()
                add_node.outputs.clear()
                match_case.update(
                    {
                        matmul_node.name: {
                            "op": "Gemm",
                            "inputs": inputs,
                            "outputs": outputs,
                            "name": matmul_node.name,
                            "attrs": {
                                "alpha": 1.0,
                                "beta": 1.0,
                                "transA": 0,
                                "transB": 1,
                            },
                            "domain": None,
                        }
                    }
                )
        return match_case


register_fusion_pattern(MatMulAddPatternMatcher(1))
178 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/fusion/padconv.py:
--------------------------------------------------------------------------------
1 | import onnxslim.third_party.onnx_graphsurgeon as gs
2 | from onnxslim.core.pattern import Pattern, PatternMatcher
3 | from onnxslim.core.pattern.registry import register_fusion_pattern
4 |
5 |
class PadConvMatcher(PatternMatcher):
    def __init__(self, priority):
        """Initializes the PadConvMatcher with a specified priority and defines its matching pattern."""
        pattern = Pattern(
            """
            input input 0 1 pad_0
            Pad pad_0 1+ 1 input conv_0
            Conv conv_0 1+ 1 pad_0 output
            output output 1 0 conv_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the fusion pattern used."""
        return "FusionPadConv"

    def parameter_check(self) -> bool:
        """Returns True only when the Pad node carries its pads as a constant second input."""
        pad_node = self.pad_0
        # The pattern allows a Pad with a single input, so check the length before indexing.
        return len(pad_node.inputs) > 1 and isinstance(pad_node.inputs[1], gs.Constant)

    def rewrite(self, opset=11):
        """Fold a constant, zero-valued Pad into the `pads` attribute of the Conv that consumes it.

        The fusion is applied only when:
          * the Pad has no third input, or that input is a constant equal to 0;
          * the pads (second Pad input) are constant and the Pad mode is "constant"
            (a missing `mode` attribute is treated as "constant", the ONNX default);
          * the Conv weight has a known shape and the Conv does not use `auto_pad`;
          * the batch and channel dimensions are not padded.

        Args:
            opset: Unused here; kept for interface parity with the other matchers.

        Returns:
            dict: Description of the replacement Conv keyed by its name; empty when nothing was fused.
        """
        match_case = {}
        conv_node = self.conv_0
        pad_node = self.pad_0

        pad_inputs = len(pad_node.inputs)
        if pad_inputs < 3 or (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0):
            if (
                pad_inputs > 1
                and isinstance(pad_node.inputs[1], gs.Constant)
                and pad_node.attrs.get("mode", "constant") == "constant"
                and conv_node.inputs[1].shape
                # Explicit pads and auto_pad cannot be combined on a Conv, so leave such nodes alone.
                and conv_node.attrs.get("auto_pad", "NOTSET") == "NOTSET"
            ):
                conv_weight_dim = len(conv_node.inputs[1].shape)
                pad_value = pad_node.inputs[1].values.tolist()

                # pad_value is laid out as [begin_0..begin_n, end_0..end_n]; the first two entries of
                # each half belong to the batch and channel axes, which a Conv cannot pad.
                if all(pad == 0 for pad in (pad_value[:2] + pad_value[conv_weight_dim : conv_weight_dim + 2])):
                    attrs = conv_node.attrs

                    # Work out the merged pads before touching the graph so a failure cannot leave it
                    # half rewritten. A Conv without a `pads` attribute pads with zeros.
                    pads = pad_value[2:conv_weight_dim] + pad_value[conv_weight_dim + 2 :]
                    conv_pads = attrs.get("pads", [0] * len(pads))
                    pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)]

                    input_variable = pad_node.inputs[0]
                    pad_variable = pad_node.outputs[0]  # pad output variable
                    index = conv_node.inputs.index(pad_variable)
                    conv_node.inputs.pop(index)
                    conv_node.inputs.insert(index, input_variable)

                    inputs = list(conv_node.inputs)
                    outputs = list(conv_node.outputs)

                    conv_node.inputs.clear()
                    conv_node.outputs.clear()
                    # Remove the Pad node once nothing consumes it any more. The users are queried
                    # *after* the Conv has been detached; a list captured before the rewiring would
                    # still include the Conv itself.
                    if len(pad_node.users) == 0:
                        input_variable.outputs.remove(pad_node)
                        pad_node.inputs.clear()
                        pad_node.outputs.clear()

                    attrs["pads"] = pads
                    match_case[conv_node.name] = {
                        "op": "Conv",
                        "inputs": inputs,
                        "outputs": outputs,
                        "name": conv_node.name,
                        "attrs": conv_node.attrs,
                        "domain": None,
                    }

        return match_case


register_fusion_pattern(PadConvMatcher(1))
88 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/fusion/reduce.py:
--------------------------------------------------------------------------------
1 | from onnxslim.core.pattern import Pattern, PatternMatcher
2 | from onnxslim.core.pattern.registry import register_fusion_pattern
3 |
4 |
class ReducePatternMatcher(PatternMatcher):
    def __init__(self, priority):
        """Set up the ReduceSum -> Unsqueeze pattern and forward it, with `priority`, to the base matcher."""
        reduce_pattern = Pattern(
            """
            input input 0 1 reduce_0
            ReduceSum reduce_0 1 1 input unsqueeze_0
            Unsqueeze unsqueeze_0 1 1 reduce_0 output
            output output 1 0 unsqueeze_0
            """
        )
        super().__init__(reduce_pattern, priority)

    @property
    def name(self):
        """Name under which this fusion is known: 'FusionReduce'."""
        return "FusionReduce"

    def rewrite(self, opset=11):
        """Collapse `ReduceSum(axes=[-1], keepdims=0)` followed by `Unsqueeze(axes=[-1])` into one reduce node.

        Only applies for opset < 13 and when the reduce node has a single user. The replacement reuses the
        reduce node's op, name and attributes with `keepdims` set to 1.

        Returns:
            dict: Description of the replacement node keyed by its name; empty when nothing was fused.
        """
        fused = {}
        unsqueeze = self.unsqueeze_0
        reducer = self.reduce_0

        if len(reducer.users) != 1:
            return fused

        reduce_axes = reducer.attrs.get("axes", None)
        keepdims = reducer.attrs.get("keepdims", 1)
        unsqueeze_axes = unsqueeze.attrs.get("axes", None)

        fusable = opset < 13 and reduce_axes == [-1] and unsqueeze_axes == [-1] and keepdims == 0
        if not fusable:
            return fused

        new_inputs = list(reducer.inputs)
        new_outputs = list(unsqueeze.outputs)
        new_attrs = reducer.attrs

        # Disconnect the two original nodes before handing their tensors to the replacement.
        reducer.outputs.clear()
        unsqueeze.inputs.clear()
        unsqueeze.outputs.clear()

        # Keeping the reduced axis makes the separate Unsqueeze unnecessary.
        new_attrs["keepdims"] = 1
        fused[reducer.name] = {
            "op": reducer.op,
            "inputs": new_inputs,
            "outputs": new_outputs,
            "name": reducer.name,
            "attrs": new_attrs,
            "domain": None,
        }
        return fused


register_fusion_pattern(ReducePatternMatcher(1))
57 |
--------------------------------------------------------------------------------
/onnxslim/core/pattern/registry.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
# Registry of fusion pattern matchers, keyed by pattern name, in registration order.
DEFAULT_FUSION_PATTERNS = OrderedDict()


def register_fusion_pattern(fusion_pattern):
    """Registers a fusion pattern in DEFAULT_FUSION_PATTERNS under its `name`.

    Args:
        fusion_pattern: Object exposing a `name` attribute that identifies the pattern.

    Raises:
        RuntimeError: If a pattern with the same name has already been registered.
    """
    layer_type = fusion_pattern.name

    if layer_type in DEFAULT_FUSION_PATTERNS:
        # A bare `raise` here had no active exception to re-raise and only produced an opaque
        # RuntimeError; raise the same exception type with a message that names the culprit.
        raise RuntimeError(f"Fusion pattern '{layer_type}' is already registered.")
    DEFAULT_FUSION_PATTERNS[layer_type] = fusion_pattern


def get_fusion_patterns(skip_fusion_patterns: str = None):
    """Returns a copy of the default fusion patterns, optionally excluding specific patterns.

    Args:
        skip_fusion_patterns: A pattern name, or an iterable of pattern names, to leave out.
            Names that are not registered are ignored.

    Returns:
        OrderedDict: Shallow copy of the registry without the skipped patterns.
    """
    default_fusion_patterns = DEFAULT_FUSION_PATTERNS.copy()
    if skip_fusion_patterns:
        # A plain string is one name, not a sequence of single-character names.
        if isinstance(skip_fusion_patterns, str):
            skip_fusion_patterns = [skip_fusion_patterns]
        for pattern in skip_fusion_patterns:
            default_fusion_patterns.pop(pattern, None)

    return default_fusion_patterns
23 |
24 |
25 | from .elimination import *
26 | from .fusion import *
27 |
--------------------------------------------------------------------------------
/onnxslim/misc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/onnxslim/misc/__init__.py
--------------------------------------------------------------------------------
/onnxslim/misc/font.py:
--------------------------------------------------------------------------------
# ANSI escape sequences that switch the terminal text colour.
RED, WHITE, GREEN = (
    "\033[31m",  # Red text
    "\033[37m",  # White text
    "\033[32m",  # Green text
)
4 |
--------------------------------------------------------------------------------
/onnxslim/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/onnxslim/third_party/__init__.py
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/__init__.py:
--------------------------------------------------------------------------------
1 | from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import export_onnx
2 | from onnxslim.third_party.onnx_graphsurgeon.graph_pattern import (
3 | GraphPattern,
4 | PatternMapping,
5 | )
6 | from onnxslim.third_party.onnx_graphsurgeon.importers.onnx_importer import import_onnx
7 | from onnxslim.third_party.onnx_graphsurgeon.ir.function import Function
8 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
9 | from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
10 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
11 | from onnxslim.third_party.onnx_graphsurgeon.util.exception import (
12 | OnnxGraphSurgeonException,
13 | )
14 |
# Version string of this onnx_graphsurgeon package.
__version__ = "0.5.1"
16 |
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py:
--------------------------------------------------------------------------------
1 | from onnxslim.third_party.onnx_graphsurgeon.exporters.base_exporter import BaseExporter
2 |
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
19 |
20 |
class BaseExporter:
    @staticmethod
    def export_graph(graph: Graph):
        """
        Convert a graph into some destination representation.

        This base implementation always raises; subclasses provide the real conversion.

        Args:
            graph (Graph): The source graph to export.

        Returns:
            object: The exported graph. For example, this might be an onnx.GraphProto

        Raises:
            NotImplementedError: Always, because this class is abstract.
        """
        reason = "BaseExporter is an abstract class"
        raise NotImplementedError(reason)
34 |
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | from collections import OrderedDict
18 | from typing import List, Sequence, Union
19 |
20 | import numpy as np
21 | import onnx
22 | import onnx.numpy_helper
23 |
24 | from onnxslim.third_party.onnx_graphsurgeon.exporters.base_exporter import BaseExporter
25 | from onnxslim.third_party.onnx_graphsurgeon.ir.function import Function
26 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
27 | from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
28 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import (
29 | Constant,
30 | LazyValues,
31 | SparseValues,
32 | Tensor,
33 | Variable,
34 | )
35 | from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
36 | from onnxslim.third_party.onnx_graphsurgeon.util import misc
37 |
38 |
def dtype_to_onnx(dtype: Union[np.dtype, "onnx.TensorProto.DataType"]) -> int:
    """Return the integer ONNX data type for `dtype`.

    Integers are assumed to already be ONNX data type codes and are returned unchanged; anything else is
    normalised through `np.dtype` and mapped with `onnx.helper.np_dtype_to_tensor_dtype`.
    """
    already_onnx = isinstance(dtype, int)
    return dtype if already_onnx else onnx.helper.np_dtype_to_tensor_dtype(np.dtype(dtype))
44 |
45 |
def check_duplicate_node_names(nodes: Sequence[Node], level=G_LOGGER.WARNING):
    """Log, at severity `level`, every node whose name was already used by an earlier node in `nodes`.

    Nodes with an empty or None name are not considered duplicates. Each report pairs the offending
    node with the first node that carried the name.
    """
    first_seen = {}
    for current in nodes:
        if not current.name:
            continue
        if current.name not in first_seen:
            first_seen[current.name] = current
            continue
        earlier = first_seen[current.name]
        msg = f"Found distinct Nodes that share the same name:\n[id: {id(earlier)}]:\n {earlier}---\n[id: {id(current)}]:\n {current}\n"
        G_LOGGER.log(msg, level)
60 |
def update_import_domains(graph):
    """Make sure `graph.import_domains` lists the ONNX opset and every domain used by the graph's nodes.

    Nodes of the graph and of all its (recursive) subgraphs are inspected. Domains that are not yet
    listed are appended with opset version 1. The updated `import_domains` list is returned.
    """
    # Start from the standard ONNX opset when nothing has been recorded yet.
    if graph.import_domains is None:
        graph.import_domains = [onnx.helper.make_opsetid("", graph.opset)]

    # Collect the domains of every node in this graph and its subgraphs; None means "no domain".
    used_domains = {node.domain for node in graph.nodes}
    for nested in graph.subgraphs(recursive=True):
        used_domains |= {n.domain for n in nested.nodes}
    used_domains.discard(None)

    # Append whichever domains are not declared yet.
    DEFAULT_CUSTOM_OPSET_VERSION = 1
    declared = {opsetid.domain for opsetid in graph.import_domains}
    missing = [domain for domain in used_domains if domain not in declared]
    for domain in missing:
        graph.import_domains.append(onnx.helper.make_opsetid(domain, DEFAULT_CUSTOM_OPSET_VERSION))
    return graph.import_domains
84 |
85 |
def tensor_to_onnx_bf16(tensor: Constant):
    """Build a BFLOAT16 onnx.TensorProto from the values of a gs.Constant.

    Every element is passed through `onnx.helper.float32_to_bfloat16`; the resulting 16-bit patterns
    are stored as the proto's raw data with the original dimensions.
    """
    source = tensor.values
    # Convert element by element, holding each bfloat16 bit pattern in a uint16 slot.
    bf16_bits = np.fromiter(
        (onnx.helper.float32_to_bfloat16(element) for element in source.flatten()),
        dtype=np.uint16,
        count=source.size,
    ).reshape(source.shape)

    proto = onnx.TensorProto()
    proto.data_type = onnx.TensorProto.BFLOAT16
    proto.dims.extend(bf16_bits.shape)
    proto.raw_data = bf16_bits.tobytes()
    return proto
105 |
106 |
class OnnxExporter(BaseExporter):
    @staticmethod
    def export_tensor_proto(tensor: Constant) -> onnx.TensorProto:
        """Converts a gs.Constant tensor to an onnx.TensorProto with type and data location handling."""
        # Do *not* load LazyValues into an intermediate numpy array - instead, use
        # the original onnx.TensorProto directly.
        if isinstance(tensor._values, LazyValues):
            onnx_tensor = tensor._values.tensor
        else:
            # A differing export dtype is only supported for float32 -> bfloat16.
            if dtype_to_onnx(tensor.dtype) != dtype_to_onnx(tensor.export_dtype):
                assert tensor.dtype == np.float32, (
                    f"Cannot convert onnx dtype {dtype_to_onnx(tensor.dtype)} to {dtype_to_onnx(tensor.export_dtype)}."
                    "Only float32 to bfloat16 is supported"
                )
                assert tensor.export_dtype == onnx.TensorProto.BFLOAT16, (
                    f"Cannot convert onnx dtype {dtype_to_onnx(tensor.dtype)} to {dtype_to_onnx(tensor.export_dtype)}."
                    "Only float32 to bfloat16 is supported"
                )
                onnx_tensor = tensor_to_onnx_bf16(tensor)
            else:
                onnx_tensor = onnx.numpy_helper.from_array(tensor.values)

            if tensor.data_location is not None:
                onnx_tensor.data_location = tensor.data_location
        onnx_tensor.name = tensor.name
        return onnx_tensor

    @staticmethod
    def export_sparse_tensor_proto(tensor: Constant) -> onnx.SparseTensorProto:
        """Exports a given Constant tensor as an ONNX SparseTensorProto."""
        # The proto stored on the tensor's values object is returned as-is.
        return tensor._values.tensor

    @staticmethod
    def export_value_info_proto(tensor: Tensor, do_type_check: bool) -> onnx.ValueInfoProto:
        """Creates an ONNX ValueInfoProto from a Tensor, optionally checking for dtype information."""
        if do_type_check and tensor.dtype is None:
            G_LOGGER.critical(
                f"Graph input and output tensors must include dtype information. Please set the dtype attribute for: {tensor}"
            )

        # Pick the value-info helper that matches the tensor's kind.
        # NOTE(review): a tensor whose `type` is none of the three handled values leaves
        # `onnx_tensor` unassigned - confirm that no other type can reach this point.
        if tensor.dtype is None:
            onnx_tensor = onnx.helper.make_empty_tensor_value_info(tensor.name)
        elif isinstance(tensor, Constant) or tensor.type == "tensor_type":
            onnx_tensor = onnx.helper.make_tensor_value_info(tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape)
        elif tensor.type == "sequence_type":
            onnx_tensor = onnx.helper.make_tensor_sequence_value_info(
                tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape
            )
        elif tensor.type == "sparse_tensor_type":
            onnx_tensor = onnx.helper.make_sparse_tensor_value_info(
                tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape
            )
        return onnx_tensor

    @staticmethod
    def export_attributes(attrs: dict, subgraph_tensor_map=None) -> List[onnx.AttributeProto]:
        """Convert function attributes to ONNX AttributeProtos for model export."""
        onnx_attrs: List[onnx.AttributeProto] = []
        for key, val in attrs.items():
            if isinstance(val, Tensor):
                val = OnnxExporter.export_tensor_proto(val)
            elif isinstance(val, Graph):
                # Subgraphs don't need to have types specified for their tensors.
                val = OnnxExporter.export_graph(val, subgraph_tensor_map=subgraph_tensor_map, do_type_check=False)
            elif isinstance(val, Node.AttributeRef):
                # Attribute references are built by hand and skip make_attribute below.
                onnx_attr = onnx.AttributeProto()
                onnx_attr.name = key
                onnx_attr.type = misc.convert_to_onnx_attr_type(val.type)

                # Netron has a bug which makes it crash if a Tensor attribute has no tensor data.
                # So provide some meaningless tensor data for Netron to read.
                if val.type == Tensor:
                    tensor_proto = OnnxExporter.export_tensor_proto(Constant("", np.array([0], dtype=np.float32)))
                    onnx_attr.t.CopyFrom(tensor_proto)

                onnx_attr.ref_attr_name = val.name
                onnx_attrs.append(onnx_attr)
                continue
            elif isinstance(val, type):
                # May be a numpy type
                try:
                    val = dtype_to_onnx(val)
                except TypeError:
                    # Not convertible to an ONNX dtype; pass the value through unchanged.
                    pass
            onnx_attrs.append(onnx.helper.make_attribute(key, val))
        return onnx_attrs

    @staticmethod
    def export_node(node: Node, subgraph_tensor_map=None) -> onnx.NodeProto:
        """Static method to convert an internal node to an ONNX node representation."""
        # Cannot pass in attrs directly as make_node will change the order
        onnx_node = onnx.helper.make_node(
            node.op,
            inputs=[t.name for t in node.inputs],
            outputs=[t.name for t in node.outputs],
            name=node.name,
            domain=node.domain,
        )
        onnx_node.attribute.extend(OnnxExporter.export_attributes(node.attrs, subgraph_tensor_map))
        return onnx_node

    @staticmethod
    def export_function(func: Function) -> onnx.FunctionProto:
        """
        Export an onnx-graphsurgeon Function to an ONNX FunctionProto.

        Args:
            func (Function): The function to export.
        """
        # Unlike onnx Graphs, onnx Functions don't have an 'initializer' field.
        # So we need to replace all Constant tensors with onnx Constant nodes which produce them.
        # We need to be careful to (a) preserve topological ordering and (b) not make the new nodes visible to the user.
        func_nodes = func.nodes.copy()
        new_const_nodes = [
            Node("Constant", attrs={"value": tensor}, outputs=[tensor.copy()])
            for tensor in func.tensors().values()
            if isinstance(tensor, Constant)
        ]
        # Const nodes have no inputs, so this maintains a topological ordering.
        func_nodes = new_const_nodes + func_nodes

        check_duplicate_node_names(func_nodes, level=G_LOGGER.WARNING)
        nodes = [OnnxExporter.export_node(node) for node in func_nodes]

        # Update the import_domains field to include all domains used by this function.
        opset_imports = update_import_domains(func)

        onnx_inputs = [inp.name for inp in func.inputs]
        onnx_outputs = [out.name for out in func.outputs]

        # Attributes without a default value are exported by name only; the others are exported
        # as AttributeProtos carrying their default.
        attributes = []
        attribute_protos = {}
        for attr_name, default_val in func.attrs.items():
            if default_val is None:
                attributes.append(attr_name)
            else:
                attribute_protos[attr_name] = default_val
        attribute_protos = OnnxExporter.export_attributes(attribute_protos)

        return onnx.helper.make_function(
            func.domain or "",
            func.name,
            onnx_inputs,
            onnx_outputs,
            nodes,
            opset_imports,
            attributes=attributes,
            attribute_protos=attribute_protos,
            doc_string=func.doc_string,
        )

    @staticmethod
    def export_graph(
        graph: Graph,
        tensor_map: "OrderedDict[str, Tensor]" = None,
        subgraph_tensor_map: "OrderedDict[str, Tensor]" = None,
        do_type_check=True,
    ) -> onnx.GraphProto:
        """
        Export an onnx-graphsurgeon Graph to an ONNX GraphProto.

        Args:
            graph (Graph): The graph to export.
            tensor_map (OrderedDict[str, Tensor]): Tensors to export as initializers and value info.
                When None, the graph's own tensors are used.
            subgraph_tensor_map (OrderedDict[str, Tensor]): Tensors that are merged with `tensor_map`
                via the `misc` dict helpers and forwarded to nested subgraph exports.

            do_type_check (bool): Whether to check that input and output tensors have data types defined, and fail if not.
                    Defaults to True.
        """
        check_duplicate_node_names(graph.nodes, level=G_LOGGER.WARNING)
        nodes = [OnnxExporter.export_node(node, subgraph_tensor_map) for node in graph.nodes]
        inputs = [OnnxExporter.export_value_info_proto(inp, do_type_check) for inp in graph.inputs]
        outputs = [OnnxExporter.export_value_info_proto(out, do_type_check) for out in graph.outputs]
        # NOTE(review): unique_dicts / combine_dicts live in util.misc, which is not visible here;
        # their exact merge semantics should be confirmed there.
        if tensor_map is None:
            tensor_map = graph.tensors()
            tensor_map = misc.unique_dicts(tensor_map, subgraph_tensor_map)
        else:
            tensor_map = misc.combine_dicts(tensor_map, subgraph_tensor_map)
        # Dense constants become initializers; constants backed by SparseValues are exported separately.
        initializer = [
            OnnxExporter.export_tensor_proto(tensor)
            for tensor in tensor_map.values()
            if isinstance(tensor, Constant) and not isinstance(tensor._values, SparseValues)
        ]

        sparse_initializer = [
            OnnxExporter.export_sparse_tensor_proto(tensor)
            for tensor in tensor_map.values()
            if isinstance(tensor, Constant) and isinstance(tensor._values, SparseValues)
        ]

        # Remove inputs and outputs to export ValueInfoProtos
        for tensor in graph.inputs + graph.outputs:
            if tensor.name in tensor_map:
                del tensor_map[tensor.name]

        # Omit tensors from value_info if we don't know their shape/dtype
        def has_value_info(tensor):
            """Check if a tensor is a Variable with either a defined dtype or shape."""
            return isinstance(tensor, Variable) and (tensor.dtype is not None or tensor.shape is not None)

        value_info = [
            OnnxExporter.export_value_info_proto(tensor, do_type_check)
            for tensor in tensor_map.values()
            if has_value_info(tensor)
        ]

        return onnx.helper.make_graph(
            nodes=nodes,
            name=graph.name,
            inputs=inputs,
            outputs=outputs,
            initializer=initializer,
            sparse_initializer=sparse_initializer,
            doc_string=graph.doc_string,
            value_info=value_info,
        )
321 |
322 |
def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> "onnx.ModelProto":
    """
    Build an ONNX model from an onnx-graphsurgeon Graph.

    Args:
        graph (Graph): The graph to export

        do_type_check (bool): Whether to check that input and output tensors have data types defined, and fail if not.
                Defaults to True.
        kwargs: Additional arguments to onnx.helper.make_model

    Returns:
        onnx.ModelProto: A corresponding ONNX model.
    """
    # Constants of every (recursive) subgraph, one name -> tensor dict per subgraph.
    constants_per_subgraph = [
        {name: tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, Constant)}
        for sub_graph in graph.subgraphs(recursive=True)
    ]

    # Keep only the constants that appear, with equal value, in all subgraphs.
    shared_constants = None
    if constants_per_subgraph:
        first = constants_per_subgraph[0]
        others = constants_per_subgraph[1:]
        shared_constants = {
            key: first[key] for key in first if all(key in other and first[key] == other[key] for other in others)
        }

    onnx_graph = OnnxExporter.export_graph(
        graph, tensor_map=graph.tensors(), subgraph_tensor_map=shared_constants, do_type_check=do_type_check
    )
    kwargs["functions"] = [OnnxExporter.export_function(func) for func in graph.functions]

    # Caller-supplied opset imports / IR version take precedence over the graph's own values.
    if "opset_imports" not in kwargs:
        kwargs["opset_imports"] = update_import_domains(graph)

    if "ir_version" not in kwargs and graph.ir_version is not None:
        kwargs["ir_version"] = graph.ir_version

    model = onnx.helper.make_model(onnx_graph, **kwargs)
    if graph.metadata_props is not None:
        model.metadata_props.extend(graph.metadata_props)
    model.producer_name = graph.producer_name
    model.producer_version = graph.producer_version
    return model
375 |
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py:
--------------------------------------------------------------------------------
1 | from onnxslim.third_party.onnx_graphsurgeon.graph_pattern.graph_pattern import (
2 | GraphPattern,
3 | PatternMapping,
4 | )
5 |
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py:
--------------------------------------------------------------------------------
1 | from onnxslim.third_party.onnx_graphsurgeon.importers.base_importer import BaseImporter
2 |
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
19 |
20 |
class BaseImporter:
    """Abstract interface for importers that build an onnx-graphsurgeon ``Graph`` from a source graph."""

    @staticmethod
    def import_graph(graph) -> Graph:
        """
        Convert a source graph into an onnx-graphsurgeon graph.

        This base implementation always raises; concrete importers are expected to provide the conversion.

        Args:
            graph (object): The source graph to import. For example, this might be an onnx.GraphProto.

        Returns:
            Graph: The equivalent onnx-graphsurgeon graph.

        Raises:
            NotImplementedError: Always, because this class is abstract.
        """
        reason = "BaseImporter is an abstract class"
        raise NotImplementedError(reason)
34 |
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/ir/function.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import copy
19 | from typing import List, Sequence
20 |
21 | from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
22 | from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
23 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Tensor, Variable
24 | from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
25 | from onnxslim.third_party.onnx_graphsurgeon.util import misc
26 |
27 |
class Function(Graph):
    """
    Represents a local function, which is a default implementation of a Custom Op. This default implementation is
    represented as a Graph of other Ops.

    Functions are used in a model by creating a Node with the same name and domain as the function. This can be done
    using the __call__() method of a Function, which creates this new node and appends it to a Graph. A Function is not
    a subgraph of a Graph, and its Nodes, Tensors, and subgraphs are entirely separate from the main Graph.

    Functions can be composed of other functions, but cyclical or recursive definitions are not allowed in ONNX.
    """

    # Domain assigned to a Function when the caller does not provide one.
    DEFAULT_DOMAIN = "onnx_graphsurgeon"

    def __init__(
        self,
        name: str,
        domain: str = None,
        nodes: Sequence[Node] = None,
        inputs: Sequence[Tensor] = None,
        outputs: Sequence[Tensor] = None,
        doc_string: str = None,
        opset: int = None,
        import_domains: "Sequence[onnx.OperatorSetIdProto]" = None,
        functions: "Sequence[Function]" = None,
        attrs: dict = None,
    ):
        """
        Args:
            name (str): The name of the function.
            domain (str): The domain/namespace of this function. Defaults to ``Function.DEFAULT_DOMAIN``.
            nodes (Sequence[Node]): A list of the nodes in this function.
            inputs (Sequence[Tensor]): A list of graph input Tensors.
            outputs (Sequence[Tensor]): A list of graph output Tensors.
            doc_string (str): A doc_string for the function. Defaults to "".
            opset (int): The ONNX opset used by nodes in this function.
            import_domains (Sequence[onnx.OperatorSetIdProto]): The list of domains used by nodes in this function.
            functions (Sequence[Function]): The list of functions in this model.
            attrs (dict): A mapping of attribute names to their default values.
                Nodes within this function can have attributes which take on the values of the Function attributes.
                When a Function is instantiated into a Node, providing attributes to that Node will override the Function's
                default attribute values. A default value of `None` means that the instantiated Node must provide the value
                of that attribute (in other words, it is a required attribute).
        """
        self.domain = misc.default_value(domain, Function.DEFAULT_DOMAIN)
        self.attrs = misc.default_value(attrs, {})

        super().__init__(
            nodes,
            inputs,
            outputs,
            name=name,
            doc_string=doc_string,
            opset=opset,
            import_domains=import_domains,
            functions=functions,
        )

        # Properties of Graph that Function doesn't have.
        del self.producer_name
        del self.producer_version

    @property
    def unique_id(self):
        """Returns a tuple ``(domain, name)`` which uniquely identifies this function."""
        return (self.domain, self.name)

    def cleanup(
        self,
        remove_unused_node_outputs=False,
        recurse_subgraphs=True,
        remove_unused_graph_inputs=False,
        recurse_functions=False,
    ):
        """See Graph.cleanup() The only difference is that 'recurse_functions' defaults to False, so that only this
        Function is cleaned up.
        """
        if recurse_functions:
            G_LOGGER.warning(
                "Function.cleanup() called with recurse_functions=True, meaning that other functions will also be cleaned up."
            )
        return super().cleanup(
            remove_unused_node_outputs=remove_unused_node_outputs,
            recurse_subgraphs=recurse_subgraphs,
            remove_unused_graph_inputs=remove_unused_graph_inputs,
            recurse_functions=recurse_functions,
        )

    def fold_constants(self, recurse_functions=False, **kwargs):
        """See Graph.fold_constants() The only difference is that 'recurse_functions' defaults to False, so that only
        this Function's constants are folded.
        """
        if recurse_functions:
            G_LOGGER.warning(
                "Function.fold_constants() called with recurse_functions=True, meaning that other functions will also be const-folded."
            )
        return super().fold_constants(recurse_functions=recurse_functions, **kwargs)

    def toposort(
        self,
        recurse_subgraphs=True,
        recurse_functions=False,
        mode="nodes",
    ):
        """See Graph.toposort() The only difference is that 'recurse_functions' defaults to False and mode defaults to
        "nodes", so that by default only this function's nodes will be sorted.
        """
        if recurse_functions:
            G_LOGGER.warning(
                "Function.toposort() called with recurse_functions=True, meaning that other functions will be sorted."
            )
        return super().toposort(
            recurse_subgraphs=recurse_subgraphs,
            recurse_functions=recurse_functions,
            mode=mode,
        )

    def __call__(self, graph, inputs=None, outputs=None, *args, **kwargs) -> List[Tensor]:
        """
        Creates a Node which is an instance of this function. The created node can be used in a Graph or another
        Function.

        The provided inputs are processed the same way as in Graph.layer().
        If outputs are not provided, they are created based on the Function's outputs.

        Mismatched input/output counts and missing required attributes are reported as warnings; they do not
        prevent the node from being created.

        Args:
            graph (Union[Graph, Function]): The Graph or Function to add the new node to.
            inputs (List[Union[Tensor, str, numpy.ndarray]]): The list of inputs.
            outputs (List[Union[Tensor, str, numpy.ndarray]]): The list of outputs.
            attrs (Dict[str, Any]): A list of attributes for the node.
                The attribute names should be a subset of this Function's attribute names.
            args/kwargs: These are passed directly to the constructor of Node.

        Returns:
            List[Tensor]: The output tensors of the node.
        """
        if inputs is not None and len(inputs) != len(self.inputs):
            msg_template = "Function {} expects {} inputs, but was called with {} inputs."
            G_LOGGER.warning(msg_template.format(self.name, len(self.inputs), len(inputs)))

        # Indices of outputs that are created by this call (as opposed to Tensors supplied by the caller).
        new_output_indices = []
        if outputs is None:
            # Graph.layer() will create Tensors and make sure the names do not conflict.
            outputs = [out.name for out in self.outputs]
            new_output_indices = list(range(len(outputs)))
        elif len(outputs) != len(self.outputs):
            msg_template = "Function {} expects {} outputs, but was called with {} outputs."
            G_LOGGER.warning(msg_template.format(self.name, len(self.outputs), len(outputs)))
        else:
            new_output_indices = [i for i in range(len(outputs)) if not isinstance(outputs[i], Tensor)]

        # A Function attribute whose default is None is required. Treat an omitted `attrs` argument as an empty
        # mapping so that required attributes are still reported when the caller provides no attributes at all.
        provided_attrs = kwargs.get("attrs", None)
        if provided_attrs is None:
            provided_attrs = {}
        for attr_name, default_val in self.attrs.items():
            if default_val is None and attr_name not in provided_attrs:
                msg_template = "Function {} called without required attribute: {}"
                G_LOGGER.warning(msg_template.format(self.name, attr_name))

        inputs = misc.default_value(inputs, [])
        outputs = misc.default_value(outputs, [])
        outputs = graph.layer(
            *args,
            **kwargs,
            op=self.name,
            domain=self.domain,
            inputs=inputs,
            outputs=outputs,
        )

        # For newly created output tensors, set their shape and dtype to match the Function definition.
        for i in new_output_indices:
            outputs[i].dtype = self.outputs[i].dtype
            outputs[i].shape = self.outputs[i].shape

        return outputs

    def copy(self):
        """
        Copy the function.

        This makes copies of all nodes and tensors in the function, but will not
        do a deep-copy of weights or attributes (with the exception of ``Graph``
        attributes, which will be copied using their ``copy`` method).

        Returns:
            Function: A copy of the function.
        """
        # First, copy every tensor of this function, keyed by tensor name.
        local_tensor_copies = {n: t.copy() for n, t in self.tensors().items()}

        def get_tensor(name):
            """Return the copied tensor registered under `name`, or `Variable.empty()` when `name` is empty."""
            return local_tensor_copies[name] if name else Variable.empty()

        # Next, copy nodes, and update inputs/outputs
        new_nodes = []
        for node in self.nodes:
            new_node = node.copy(
                inputs=[get_tensor(inp.name) for inp in node.inputs],
                outputs=[get_tensor(out.name) for out in node.outputs],
                tensor_map=local_tensor_copies,
            )
            new_nodes.append(new_node)
        new_func_inputs = [get_tensor(inp.name) for inp in self.inputs]
        new_func_outputs = [get_tensor(out.name) for out in self.outputs]

        # Attribute default values are shallow-copied.
        new_attrs = {name: copy.copy(val) for name, val in self.attrs.items()}

        return Function(
            self.name,
            self.domain,
            nodes=new_nodes,
            inputs=new_func_inputs,
            outputs=new_func_outputs,
            doc_string=self.doc_string,
            opset=self.opset,
            import_domains=self.import_domains,
            functions=self.functions,
            attrs=new_attrs,
        )

    def __eq__(self, other: "Function"):
        """
        Checks equality of self with another Function based on unique_id, opset, import domains, inputs, outputs
        and nodes.

        Returns ``NotImplemented`` when `other` is not a Function, so that Python can fall back to the other
        operand's comparison instead of failing on a missing attribute.
        """
        if not isinstance(other, Function):
            return NotImplemented

        def sequences_equal(seq1, seq2):
            """Checks if two sequences are equal in length and elements."""
            return len(seq1) == len(seq2) and all(elem1 == elem2 for elem1, elem2 in zip(seq1, seq2))

        return (
            self.unique_id == other.unique_id
            and self.opset == other.opset
            and self.import_domains == other.import_domains
            and sequences_equal(self.inputs, other.inputs)
            and sequences_equal(self.outputs, other.outputs)
            and sequences_equal(self.nodes, other.nodes)
        )

    def __str__(self):
        """Returns a string representation of the function including its name, domain, opset, inputs, nodes, and
        outputs.
        """
        nodes_str = "\n".join([str(node) for node in self.nodes])
        out = f"Function {self.name}, Domain {self.domain}, Opset {self.opset}"
        out += f"\nInputs: {self.inputs}"
        out += f"\nNodes: {nodes_str}"
        out += f"\nOutputs: {self.outputs}"
        return out
274 |
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/ir/node.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from collections import OrderedDict
19 | from dataclasses import dataclass
20 | from typing import Dict, List
21 |
22 | from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
23 | from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
24 | from onnxslim.third_party.onnx_graphsurgeon.util import misc
25 |
26 |
class Node:
    """A single operation in a graph, connected to other nodes through its input and output tensors."""

    @dataclass
    class AttributeRef:
        """
        An AttributeRef is an attribute value which references an attribute in the parent function. A node's attribute
        can only be an AttributeRef if the node lives inside a Function.

        Args:
            name (str): The name of the referenced attribute in the parent Function.
            type (type): The attribute's type.
        """

        name: str
        type: type

    def __init__(
        self,
        op: str,
        name: str = None,
        attrs: Dict[str, object] = None,
        inputs: List["Tensor"] = None,
        outputs: List["Tensor"] = None,
        domain: str = None,
    ):
        """
        A node represents an operation in a graph, and consumes zero or more Tensors, and produces zero or more Tensors.

        Args:
            op (str): The operation this node performs.

            name (str): The name of this node. Defaults to "".
            attrs (Dict[str, object]): A dictionary that maps attribute names to their values.
            inputs (List[Tensor]): A list of zero or more input Tensors.
            outputs (List[Tensor]): A list of zero or more output Tensors.
            domain (str): The domain of this node.
        """
        self.op = op
        self.name = misc.default_value(name, "")
        self.attrs = misc.default_value(attrs, OrderedDict())
        self.inputs = misc.SynchronizedList(self, field_name="outputs", initial=misc.default_value(inputs, []))
        self.outputs = misc.SynchronizedList(self, field_name="inputs", initial=misc.default_value(outputs, []))
        self.domain = domain

    def i(self, tensor_idx=0, producer_idx=0):
        """
        Convenience function to get a producer node of one of this node's input tensors. Note that the parameters are
        swapped compared to the o() function; this is because tensors are likely to have only a single producer.

        For example:
        ::

            assert node.i() == node.inputs[0].inputs[0]
            assert node.i(1, 2) == node.inputs[1].inputs[2]

        Args:
            tensor_idx (int): The index of the input tensor of this node. Defaults to 0.
            producer_idx (int): The index of the producer of the input tensor, if the tensor has multiple producers. Defaults to 0

        Returns:
            Node: The specified producer (input) node.
        """
        return self.inputs[tensor_idx].inputs[producer_idx]

    def o(self, consumer_idx=0, tensor_idx=0):
        """
        Convenience function to get a consumer node of one of this node's output tensors.

        For example:
        ::

            assert node.o() == node.outputs[0].outputs[0]
            assert node.o(2, 1) == node.outputs[1].outputs[2]

        Args:
            consumer_idx (int): The index of the consumer of the output tensor. Defaults to 0.
            tensor_idx (int): The index of the output tensor of this node, if the node has multiple outputs. Defaults to 0.

        Returns:
            Node: The specified consumer (output) node
        """
        return self.outputs[tensor_idx].outputs[consumer_idx]

    def subgraphs(self, recursive=False):
        """
        Convenience function to iterate over all subgraphs which are contained in this node. Node subgraphs are found in
        attributes of ONNX control flow nodes such as 'If' and 'Loop'.

        Args:
            recursive (bool): Whether to recurse into the subgraph nodes when looking for subgraphs. Defaults to False.

        Returns:
            A generator which iterates over this node's subgraphs.
        """
        # Imported here rather than at module level; presumably to avoid a circular import with ir.graph.
        from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph

        visit_queue = [self]

        # This prevents infinite recursion in the (illegal) case of cyclical graphs.
        visited = set()

        while visit_queue:
            node = visit_queue.pop()
            for attr in node.attrs.values():
                if isinstance(attr, Graph) and id(attr) not in visited:
                    visited.add(id(attr))
                    if recursive:
                        visit_queue.extend(attr.nodes)
                    yield attr

    def __setattr__(self, name, value):
        """
        Sets the attribute 'name' to 'value', handling special cases for 'inputs' and 'outputs' attributes.

        Once 'inputs'/'outputs' exist, assigning to them replaces the contents of the existing list in place
        instead of rebinding the attribute.
        """
        if name in {"inputs", "outputs"}:
            try:
                attr = getattr(self, name)
                if value is attr:
                    # This can happen when using things like +=
                    # The __iadd__ is executed followed by an assignment
                    return

                attr.clear()
                attr.extend(value)
            except AttributeError:
                # The attribute does not exist yet (first assignment), so bind it normally.
                super().__setattr__(name, value)
        else:
            super().__setattr__(name, value)

    def copy(
        self,
        inputs: List["Tensor"] = None,
        outputs: List["Tensor"] = None,
        tensor_map=None,
    ):
        """
        Makes a shallow copy of this node, overriding input and output information.

        Attributes that are Graphs are copied through their own ``copy(tensor_map)``; every other attribute value is
        shared with the original node.

        Note: Generally, you should only ever make a copy of a Graph.
        """
        from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph

        new_attrs = OrderedDict()
        for name, attr in self.attrs.items():
            new_attrs[name] = attr.copy(tensor_map) if isinstance(attr, Graph) else attr
        return Node(
            self.op,
            self.name,
            new_attrs,
            inputs=inputs,
            outputs=outputs,
            domain=self.domain,
        )

    def __str__(self):
        """Return a multi-line description of the node: name, op, inputs, outputs, and (when set) attributes and domain."""
        ret = f"{self.name} ({self.op})"

        def add_io(name, io):
            """Append a labelled, bracketed listing of the given tensors to the description."""
            nonlocal ret
            ret += f"\n\t{name}: ["
            for elem in io:
                ret += f"\n\t\t{elem}"
            ret += "\n\t]"

        add_io("Inputs", self.inputs)
        add_io("Outputs", self.outputs)

        if self.attrs:
            ret += f"\nAttributes: {self.attrs}"

        if self.domain:
            ret += f"\nDomain: {self.domain}"

        return ret

    def __repr__(self):
        """Return the same string representation as ``__str__``."""
        return self.__str__()

    def __eq__(self, other):
        """
        Check whether two nodes are equal by comparing name, op, attributes, inputs, outputs, and domain.

        Returns ``NotImplemented`` when `other` is not a Node so that Python can fall back to the other operand's
        comparison instead of failing on a missing attribute.
        """
        if not isinstance(other, Node):
            return NotImplemented

        # Pass a callable so the message is only formatted when verbose logging is actually enabled.
        G_LOGGER.verbose(lambda: f"Comparing node: {self.name} with {other.name}")
        attrs_match = self.name == other.name and self.op == other.op and self.attrs == other.attrs
        if not attrs_match:
            return False

        inputs_match = misc.sequences_equal(self.inputs, other.inputs)
        if not inputs_match:
            return False

        outputs_match = misc.sequences_equal(self.outputs, other.outputs)
        return self.domain == other.domain if outputs_match else False

    @property
    def users(self):
        """
        Collect everything that uses this node's outputs.

        For each output tensor, the tensor itself is included when its ``is_output`` flag is set, followed by every
        entry of that tensor's ``outputs`` (its consumers).
        """
        users = []
        for output in self.outputs:  # output is a Variable
            if output.is_output:
                users.append(output)
            users.extend(iter(output.outputs))
        return users

    @property
    def feeds(self):
        """
        Collect what feeds this node's inputs.

        For each input tensor: a Constant, or a tensor without any producer, is included as-is. Otherwise each
        producer node of the tensor is included, except that a producer whose op is "Split" contributes the tensor
        itself instead of the node.
        """
        feeds = []
        for tensor in self.inputs:
            if isinstance(tensor, Constant) or len(tensor.inputs) == 0:
                feeds.append(tensor)
            else:
                feeds.extend(tensor if producer.op == "Split" else producer for producer in tensor.inputs)
        return feeds

    def erase(self, input_var_idx=0, output_var_idx=0):
        """
        Bypass this node by rewiring the selected input and output tensors onto each other, then detach the node.

        Nothing happens unless the selected input is a Variable. When that input has ``is_input`` set, uses of the
        selected output are replaced with the input; otherwise uses of the input are replaced with the output.
        Afterwards this node's inputs and outputs are cleared.

        Args:
            input_var_idx (int): Index into this node's inputs. Defaults to 0.
            output_var_idx (int): Index into this node's outputs. Defaults to 0.
        """
        input_var = self.inputs[input_var_idx]
        if not isinstance(input_var, Variable):
            return

        output_var = self.outputs[output_var_idx]
        if input_var.is_input:
            output_var.replace_all_uses_with(input_var)
        else:
            input_var.replace_all_uses_with(output_var)

        self.inputs.clear()
        self.outputs.clear()

    def replace_all_uses_with(self, node: "Node"):
        """
        Replace all uses of this node with the given node.

        Every input of every user that is one of this node's outputs is swapped for the output of `node` at the same
        position. Afterwards this node's inputs and outputs are cleared.
        """
        for user in self.users:
            for inp in user.inputs:
                if inp in self.outputs:
                    for i, candidate in enumerate(user.inputs):
                        if candidate == inp:
                            user.inputs[i] = node.outputs[self.outputs.index(inp)]

        self.inputs.clear()
        self.outputs.clear()
263 |
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py:
--------------------------------------------------------------------------------
1 | from onnxslim.third_party.onnx_graphsurgeon.logger.logger import G_LOGGER, LogMode
2 |
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/logger/logger.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import enum
19 | import inspect
20 | import os
21 | import sys
22 | import time
23 |
24 | from onnxslim.third_party.onnx_graphsurgeon.util.exception import (
25 | OnnxGraphSurgeonException,
26 | )
27 |
28 |
29 | # Context manager to apply indentation to messages
class LoggerIndent:
    """Context manager which sets a logger's indentation level for the duration of a ``with`` block."""

    def __init__(self, logger, indent):
        """
        Args:
            logger: The logger whose ``logging_indent`` is changed while the context is active.
            indent: The indentation level to apply inside the context.
        """
        self.logger = logger
        self.old_indent = self.logger.logging_indent
        self.indent = indent

    def __enter__(self):
        """Remember the logger's current indentation level, then apply the requested one."""
        # Capture on entry (not only at construction) so that a context manager which is entered some time after
        # it was created still restores the level that was in effect when the block started.
        self.old_indent = self.logger.logging_indent
        self.logger.logging_indent = self.indent
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the indentation level that was in effect when the context was entered."""
        self.logger.logging_indent = self.old_indent
45 |
46 |
47 | # Context manager to suppress messages
class LoggerSuppress:
    """Context manager which sets a logger's severity for the duration of a ``with`` block."""

    def __init__(self, logger, severity):
        """
        Args:
            logger: The logger whose ``severity`` is changed while the context is active.
            severity: The severity to apply inside the context.
        """
        self.logger = logger
        self.old_severity = self.logger.severity
        self.severity = severity

    def __enter__(self):
        """Remember the logger's current severity, then apply the requested one."""
        # Capture on entry (not only at construction) so that a context manager which is entered some time after
        # it was created still restores the severity that was in effect when the block started.
        self.old_severity = self.logger.severity
        self.logger.severity = self.severity
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the severity that was in effect when the context was entered."""
        self.logger.severity = self.old_severity
63 |
64 |
class LogMode(enum.IntEnum):
    """Controls how often an identical message may be emitted by ``Logger.log``."""

    EACH = 0  # Log the message each time
    ONCE = 1  # Log the message only once. The same message will not be logged again.


class Logger:
    """Severity-filtered logger which prints messages to stdout, with optional color, timestamp and line info."""

    # Severity levels, from least to most severe.
    ULTRA_VERBOSE = -10
    VERBOSE = 0
    DEBUG = 10
    INFO = 20
    WARNING = 30
    ERROR = 40
    CRITICAL = 50

    # Prefix printed in front of each message when `letter` is enabled.
    SEVERITY_LETTER_MAPPING = {
        ULTRA_VERBOSE: "[UV]",
        VERBOSE: "[V]",
        DEBUG: "[D]",
        INFO: "[I]",
        WARNING: "[W]",
        ERROR: "[E]",
        CRITICAL: "[C]",
    }

    # Foreground color name passed to the `colored` module when `colors` is enabled.
    SEVERITY_COLOR_MAPPING = {
        ULTRA_VERBOSE: "cyan",
        VERBOSE: "dark_gray",
        DEBUG: "light_gray",
        INFO: "light_green",
        WARNING: "light_yellow",
        ERROR: "red_1",
        CRITICAL: "red_1",
    }

    def __init__(self, severity=INFO, colors=True, letter=True, timestamp=False, line_info=False):
        """
        Logger.

        Args:
            severity (Logger.Severity): Messages below this severity are ignored.
            colors (bool): Whether to use colored output.
            letter (bool): Whether to prepend each logging message with a letter indicating it's severity. Defaults to True.
            timestamp (bool): Whether to include a timestamp in the logging output. Defaults to False.
            line_info (bool): Whether to include file and line number information in the logging output. Defaults to False.
        """
        self._severity = severity
        self.logging_indent = 0
        # Two directories above this file; file names in line info are reported relative to it.
        self.root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
        # Hashes of the messages that were already seen with LogMode.ONCE.
        self.once_logged = set()
        self.colors = colors
        self.letter = letter
        self.timestamp = timestamp
        self.line_info = line_info
        self.logger_callbacks = []

    @property
    def severity(self):
        """Returns the logging severity level."""
        return self._severity

    @severity.setter
    def severity(self, value):
        """Sets the logging severity level and notifies every registered callback with the new value."""
        self._severity = value
        for callback in self.logger_callbacks:
            callback(self._severity)

    def register_callback(self, callback):
        """
        Registers a callback with the logger, which will be invoked when the logging severity is modified. The callback
        is guaranteed to be called at least once in the register_callback function.

        Args:
            callback (Callable(Logger.Severity)): A callback that accepts the current logger severity.
        """
        callback(self._severity)
        self.logger_callbacks.append(callback)

    def indent(self, level=1):
        """Returns a context manager that indents all strings logged by the specified amount."""
        return LoggerIndent(self, level + self.logging_indent)

    def suppress(self, severity=CRITICAL):
        """
        Returns a context manager that temporarily changes the severity of the logger for its duration.

        Args:
            severity (Logger.Severity): The severity to set the logger to. Defaults to Logger.CRITICAL, which will suppress all messages.
        """
        return LoggerSuppress(self, severity)

    # If once is True, the logger will only log this message a single time. Useful in loops.
    # message may be a callable which returns a message. This way, only if the message needs to be logged is it ever generated.
    def log(self, message, severity, mode=LogMode.EACH, stack_depth=2):
        """
        Prints a message if its severity is at least the logger's severity.

        Args:
            message: The message, or a callable returning the message. A callable is only invoked when the message
                is actually going to be printed.
            severity: The severity of the message (one of the Logger severity constants).
            mode (LogMode): With LogMode.ONCE, a message whose hash was already seen is not printed again.
            stack_depth (int): Frame offset used to compute the file/line prefix when `line_info` is enabled.
        """

        def process_message(message, stack_depth):
            """Builds the final output string: prefix, indentation and (optionally) color."""

            def get_prefix():
                """Builds the prefix made of the severity letter, timestamp and line info, as enabled."""

                def get_line_info():
                    """Returns a "[file:line] " prefix derived from the call stack."""
                    # NOTE(review): the module is looked up at stack_depth + 3 (falling back to + 2) while the line
                    # number is read at stack_depth - confirm this asymmetry is intended.
                    module = inspect.getmodule(sys._getframe(stack_depth + 3)) or inspect.getmodule(
                        sys._getframe(stack_depth + 2)
                    )
                    filename = module.__file__
                    filename = os.path.relpath(filename, self.root_dir)
                    # If the file is not located under root_dir, use its basename instead.
                    if os.pardir in filename:
                        filename = os.path.basename(filename)
                    return f"[{filename}:{sys._getframe(stack_depth).f_lineno}] "

                prefix = ""
                if self.letter:
                    prefix += f"{Logger.SEVERITY_LETTER_MAPPING[severity]} "
                if self.timestamp:
                    prefix += "({:}) ".format(time.strftime("%X"))
                if self.line_info:
                    prefix += get_line_info()
                return prefix

            def apply_indentation(message):
                """Indent each line in the message by the specified logging_indent level."""
                message_lines = str(message).splitlines()
                return "\n".join(["\t" * self.logging_indent + line for line in message_lines])

            def apply_color(message):
                """Apply color formatting to the message if color support is enabled."""
                if self.colors:
                    try:
                        import colored

                        color = Logger.SEVERITY_COLOR_MAPPING[severity]
                        return colored.stylize(message, [colored.fg(color)])
                    except ImportError:
                        # Disable colors while emitting the warning so that it does not recurse into this branch.
                        self.colors = False
                        # LogMode.ONCE: the import is retried for every message, but the hint is only shown once
                        # instead of being repeated in front of every logged line.
                        self.warning(
                            "colored module is not installed, will not use colors when logging. To enable colors, please install the colored module: python3 -m pip install colored",
                            mode=LogMode.ONCE,
                        )
                        self.colors = True
                return message

            prefix = get_prefix()
            message = apply_indentation(message)
            return apply_color(f"{prefix}{message}")

        def should_log(message):
            """Determines if a message should be logged based on the severity level and logging mode."""
            should = severity >= self._severity
            if mode == LogMode.ONCE:
                message_hash = hash(message)
                should &= message_hash not in self.once_logged
                self.once_logged.add(message_hash)
            return should

        if not should_log(message):
            return

        if callable(message):
            message = message()
        message = str(message)
        print(process_message(message, stack_depth=stack_depth))

    def ultra_verbose(self, message, mode=LogMode.EACH):
        """Logs an ultra-verbose message with a specified mode and stack depth of 3."""
        self.log(message, Logger.ULTRA_VERBOSE, mode=mode, stack_depth=3)

    def verbose(self, message, mode=LogMode.EACH):
        """Logs a verbose message with a specified mode and stack depth of 3."""
        self.log(message, Logger.VERBOSE, mode=mode, stack_depth=3)

    def debug(self, message, mode=LogMode.EACH):
        """Logs a debug message with a specified mode and stack depth of 3."""
        self.log(message, Logger.DEBUG, mode=mode, stack_depth=3)

    def info(self, message, mode=LogMode.EACH):
        """Logs an informational message with a specified mode and stack depth of 3."""
        self.log(message, Logger.INFO, mode=mode, stack_depth=3)

    def warning(self, message, mode=LogMode.EACH):
        """Logs a warning message with a specified mode and stack depth of 3."""
        self.log(message, Logger.WARNING, mode=mode, stack_depth=3)

    def error(self, message, mode=LogMode.EACH):
        """Logs an error message with a specified mode and stack depth of 3."""
        self.log(message, Logger.ERROR, mode=mode, stack_depth=3)

    # Like error, but immediately exits.
    def critical(self, message):
        """Logs a critical message with a stack depth of 3 and raises an OnnxGraphSurgeonException."""
        self.log(message, Logger.CRITICAL, stack_depth=3)
        raise OnnxGraphSurgeonException(message) from None  # Erase exception chain
258 |
259 |
# Module-level logger instance. A `global` declaration is unnecessary here:
# an assignment at module scope already binds a module global.
G_LOGGER = Logger()
262 |
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/inisis/OnnxSlim/4e820ced75ed909f96b5b92bc46cebf4ca3a2156/onnxslim/third_party/onnx_graphsurgeon/util/__init__.py
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/util/exception.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 |
class OnnxGraphSurgeonException(Exception):
    """Error type used by ONNX-GraphSurgeon code to signal a failure."""
21 |
--------------------------------------------------------------------------------
/onnxslim/third_party/onnx_graphsurgeon/util/misc.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from collections import OrderedDict
19 | from typing import List, Sequence
20 |
21 | import numpy as np
22 | from onnx import AttributeProto
23 |
24 | from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
25 |
26 |
27 | # default_value exists to solve issues that might result from Python's normal default argument behavior.
28 | # Specifically, consider the following class:
29 | #
30 | # class MyClass(object):
31 | # def __init__(self, value=[]):
32 | # self.value = value
33 | #
34 | # This leads to unwanted behavior when the default value is used:
35 | #
36 | # >>> x = MyClass()
37 | # >>> x.value.append("SHOULD NOT BE IN Y")
38 | # >>> y = MyClass()
39 | # >>> y.value
40 | # ['SHOULD NOT BE IN Y']
41 | #
42 | # If we rewrite the class using default value:
43 | #
44 | # class MyClass(object):
45 | # def __init__(self, value=None):
46 | # self.value = default_value(value, [])
47 | #
48 | # Then we get the desired behavior:
49 | #
50 | # >>> x = MyClass()
51 | # >>> x.value.append("SHOULD NOT BE IN Y")
52 | # >>> y = MyClass()
53 | # >>> y.value
54 | # []
def default_value(value, default):
    """
    Substitute ``default`` for a missing (``None``) value.

    Only ``None`` triggers the substitution; other falsy values such as ``0`` or ``[]``
    are returned unchanged.
    """
    if value is None:
        return default
    return value
58 |
59 |
def combine_dicts(dict0, dict1):
    """
    Merge two dictionaries, with entries from ``dict1`` taking precedence.

    When ``dict1`` is ``None``, ``dict0`` itself is returned (not a copy). Otherwise a new
    ``OrderedDict`` is built from ``dict0`` and then updated with ``dict1``; neither input
    is modified.
    """
    if dict1 is None:
        return dict0
    merged = OrderedDict(dict0)
    merged.update(dict1)
    return merged
72 |
73 |
def unique_dicts(dict0, dict1):
    """
    Return the entries of ``dict0`` whose keys do not appear in ``dict1``.

    When ``dict1`` is empty or ``None``, ``dict0`` itself is returned (not a copy);
    otherwise a new plain ``dict`` is built.
    """
    if not dict1:
        return dict0

    remaining = {}
    for key, val in dict0.items():
        if key not in dict1:
            remaining[key] = val
    return remaining
81 |
82 |
def is_dynamic_dimension(dim):
    """Report whether ``dim`` is dynamic: anything that is not an ``int``, or a negative ``int``."""
    if isinstance(dim, int):
        return dim < 0
    return True
86 |
87 |
def is_dynamic_shape(shape):
    """Report whether at least one dimension of ``shape`` is dynamic according to ``is_dynamic_dimension``."""
    for dim in shape:
        if is_dynamic_dimension(dim):
            return True
    return False
91 |
92 |
def volume(obj):
    """
    Multiply together all elements of the iterable ``obj``.

    An empty iterable yields ``1``.
    """
    product = 1
    for factor in obj:
        product *= factor
    return product
99 |
100 |
# Lookup tables between ONNX AttributeProto types and their GraphSurgeon counterparts.
# Both start out empty and are filled in by _init_dicts() the first time a conversion is requested.
_ONNX_ATTR_TYPE_TO_GS_TYPE = {}
_GS_TYPE_TO_ONNX_ATTR_TYPE = {}
103 |
104 |
# Graph and Tensor are imported inside this function (not at module import time)
# to avoid a circular import.
def _init_dicts():
    """
    Fill the two module-level attribute-type lookup tables on first use.

    Does nothing when both tables are already non-empty. Otherwise imports ``Graph`` and
    ``Tensor``, builds the ONNX -> GS table, and derives the GS -> ONNX table by inverting it.
    """
    global _ONNX_ATTR_TYPE_TO_GS_TYPE
    global _GS_TYPE_TO_ONNX_ATTR_TYPE

    already_built = bool(_ONNX_ATTR_TYPE_TO_GS_TYPE) and bool(_GS_TYPE_TO_ONNX_ATTR_TYPE)
    if already_built:
        return

    from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
    from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Tensor

    forward = {
        AttributeProto.UNDEFINED: None,
        AttributeProto.FLOAT: float,
        AttributeProto.INT: int,
        AttributeProto.STRING: str,
        AttributeProto.TENSOR: Tensor,
        AttributeProto.GRAPH: Graph,
        AttributeProto.SPARSE_TENSOR: AttributeProto.SPARSE_TENSOR,
        AttributeProto.TYPE_PROTO: AttributeProto.TYPE_PROTO,
        AttributeProto.FLOATS: List[float],
        AttributeProto.INTS: List[int],
        AttributeProto.STRINGS: List[str],
        AttributeProto.TENSORS: List[Tensor],
        AttributeProto.GRAPHS: List[Graph],
        AttributeProto.SPARSE_TENSORS: AttributeProto.SPARSE_TENSORS,
        AttributeProto.TYPE_PROTOS: AttributeProto.TYPE_PROTOS,
    }
    _ONNX_ATTR_TYPE_TO_GS_TYPE = forward

    # Invert the forward table; on duplicate values the later ONNX type wins.
    backward = {}
    for onnx_type, gs_type in forward.items():
        backward[gs_type] = onnx_type
    _GS_TYPE_TO_ONNX_ATTR_TYPE = backward
134 |
135 |
def convert_from_onnx_attr_type(onnx_attr_type):
    """
    Look up the GS type registered for an ONNX attribute type.

    The lookup tables are initialized on demand. An attribute type that is not in the
    table raises ``KeyError``.
    """
    _init_dicts()
    gs_type = _ONNX_ATTR_TYPE_TO_GS_TYPE[onnx_attr_type]
    return gs_type
140 |
141 |
def convert_to_onnx_attr_type(any_type):
    """
    Map a Python, GS or NumPy type to the matching ONNX ``AttributeProto`` type.

    Resolution order:
      1. An exact entry in the GS -> ONNX lookup table (initialized on demand).
      2. Any NumPy floating type -> ``AttributeProto.FLOAT``.
      3. Any NumPy integer type -> ``AttributeProto.INT``.

    Returns:
        The ONNX attribute type, or ``None`` (after logging a warning) when no mapping exists.
    """
    _init_dicts()
    if any_type in _GS_TYPE_TO_ONNX_ATTR_TYPE:
        return _GS_TYPE_TO_ONNX_ATTR_TYPE[any_type]
    try:
        if np.issubdtype(any_type, np.floating):
            return AttributeProto.FLOAT
        if np.issubdtype(any_type, np.integer):
            return AttributeProto.INT
    except TypeError:
        # np.issubdtype raises TypeError for objects NumPy cannot interpret as a dtype.
        # Treat that as "no mapping" so the warning below is reached instead of a crash.
        pass
    G_LOGGER.warning(f"Unable to convert {any_type} into an ONNX AttributeType")
    return None
152 |
153 |
154 | # Special type of list that synchronizes contents with another list.
155 | # Concrete example: Assume some node, n, contains an input tensor, t. If we remove t from n.inputs,
156 | # we also need to remove n from t.outputs. To avoid having to do this manually, we use SynchronizedList,
157 | # which takes an attribute name as a parameter, and then synchronizes to that attribute of each of its elements.
158 | # So, in the example above, we can make n.inputs a synchronized list whose field_name is set to "outputs".
159 | # See test_ir.TestNodeIO for functional tests
class SynchronizedList(list):
    """
    A list that keeps a back-reference list on each of its elements in sync.

    Every element is expected to expose a list-valued attribute named ``field_name``.
    Whenever an element is added to this list, ``parent_obj`` is appended to that
    attribute; whenever an element is removed, ``parent_obj`` is removed from it.
    The element-side updates go through plain ``list.append`` / ``list.remove``.
    """

    def __init__(self, parent_obj, field_name, initial):
        """
        Create the list and register ``parent_obj`` with every initial element.

        Args:
            parent_obj: The object that owns this list.
            field_name: Name of the list attribute on each element to keep in sync.
            initial: Iterable of elements to start with.
        """
        self.parent_obj = parent_obj
        self.field_name = field_name
        self.extend(initial)

    def _add_to_elem(self, elem):
        """Append ``parent_obj`` to ``elem``'s ``field_name`` list using plain ``list.append``."""
        list.append(getattr(elem, self.field_name), self.parent_obj)

    def _remove_from_elem(self, elem):
        """Remove one occurrence of ``parent_obj`` from ``elem``'s ``field_name`` list using plain ``list.remove``."""
        list.remove(getattr(elem, self.field_name), self.parent_obj)

    def __delitem__(self, index):
        """Delete the element at ``index`` and unregister ``parent_obj`` from it."""
        self._remove_from_elem(self[index])
        super().__delitem__(index)

    def __setitem__(self, index, elem):
        """Replace the element at ``index``: unregister from the old element, register with the new one."""
        self._remove_from_elem(self[index])
        super().__setitem__(index, elem)
        self._add_to_elem(elem)

    def append(self, x):
        """Append ``x`` and register ``parent_obj`` with it."""
        super().append(x)
        self._add_to_elem(x)

    def extend(self, iterable: Sequence[object]):
        """
        Append every element of ``iterable`` and register ``parent_obj`` with each one.

        The input is materialized once before use. A one-shot iterator (such as a generator)
        would otherwise be exhausted by ``list.extend`` and leave the new elements
        unregistered, and extending with this very list would register against the
        already-grown list.
        """
        items = list(iterable)
        super().extend(items)
        for elem in items:
            self._add_to_elem(elem)

    def insert(self, i, x):
        """Insert ``x`` before position ``i`` and register ``parent_obj`` with it."""
        super().insert(i, x)
        self._add_to_elem(x)

    def remove(self, x):
        """Remove the first occurrence of ``x`` and unregister ``parent_obj`` from ``x``."""
        super().remove(x)
        self._remove_from_elem(x)

    def pop(self, i=-1):
        """Remove and return the element at index ``i`` (default: last), unregistering ``parent_obj`` from it."""
        elem = super().pop(i)
        self._remove_from_elem(elem)
        return elem

    def clear(self):
        """Unregister ``parent_obj`` from every element, then empty the list."""
        for elem in self:
            self._remove_from_elem(elem)
        super().clear()

    def __add__(self, other_list: List[object]):
        """Return a plain ``list`` holding this list's elements followed by ``other_list``'s."""
        return list(self) + list(other_list)

    def __iadd__(self, other_list: List[object]):
        """Extend in place (with synchronization) and return ``self``."""
        self.extend(other_list)
        return self

    def __copy__(self):
        """Return a plain ``list`` with the same elements."""
        return list(self)

    def __deepcopy__(self, memo):
        """Return a plain ``list`` with the same elements; the elements themselves are not copied."""
        return list(self)
243 |
244 |
def sequences_equal(seq1, seq2):
    """
    Compare two sequences element by element.

    Returns ``False`` when the lengths differ; otherwise ``True`` only if every pair of
    elements at the same position compares equal.
    """
    if len(seq1) != len(seq2):
        return False

    for left, right in zip(seq1, seq2):
        if not left == right:
            return False
    return True
252 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
# The VERSION file is the single source of truth for the package version.
with open("VERSION") as f:
    version = f.read().strip()

# Generate onnxslim/version.py so the version string is available from the package.
with open("onnxslim/version.py", "w") as f:
    f.write(f'__version__ = "{version}"\n')

# Read the README through a context manager so the file handle is closed deterministically.
with open("README.md", encoding="utf-8") as f:
    long_description = f.read()

setup(
    name="onnxslim",
    version=version,
    description="OnnxSlim: A Toolkit to Help Optimize Onnx Model",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/inisis/OnnxSlim",
    author="inisis",
    author_email="desmond.yao@buaa.edu.cn",
    project_urls={
        "Bug Tracker": "https://github.com/inisis/OnnxSlim/issues",
    },
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Intended Audience :: Developers",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    license="MIT",
    install_requires=["onnx", "sympy", "packaging"],
    packages=find_packages(exclude=("tests", "tests.*")),
    entry_points={"console_scripts": ["onnxslim=onnxslim.cli:main"]},
    zip_safe=True,
    python_requires=">=3.6",
)
34 |
--------------------------------------------------------------------------------
/tests/test_benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import tempfile
4 |
5 | import numpy as np
6 | import onnxruntime as ort
7 | import pytest
8 |
9 | ort.set_default_logger_severity(3)
10 |
11 | from onnxslim.utils import print_model_info_as_table, summarize_model
12 |
# Root directory of the model zoo; each test loads "<MODELZOO_PATH>/<name>/<name>.onnx".
MODELZOO_PATH = "/data/modelzoo"
14 |
15 |
def bench_main(command):
    """Run ``command`` through the shell, capturing stdout and stderr as text, and return the completed process."""
    return subprocess.run(command, shell=True, capture_output=True, text=True)
19 |
20 |
def bench_onnxslim(input, output):
    """Invoke the ``onnxslim`` command line on ``input``, writing to ``output``; return the completed process."""
    return bench_main(f"onnxslim {input} {output}")
25 |
26 |
def bench_onnxsim(input, output):
    """Invoke the ``onnxsim`` command line on ``input``, writing to ``output``; return the completed process."""
    return bench_main(f"onnxsim {input} {output}")
31 |
32 |
def bench_polygraphy(input, output):
    """Invoke ``polygraphy surgeon sanitize --fold-constants`` on ``input``, writing to ``output``; return the completed process."""
    return bench_main(f"polygraphy surgeon sanitize --fold-constants {input} -o {output}")
37 |
38 |
def bench_onnxruntime(input, output):
    """
    Optimize ``input`` with onnxruntime and serialize the optimized graph to ``output``.

    A session is created with the ORT_ENABLE_EXTENDED optimization level and
    ``optimized_model_filepath`` set to ``output``.

    This function was previously defined three times verbatim; a single definition is kept.

    Returns:
        ``True`` on success, or ``None`` when anything fails (including onnxruntime not
        being importable); the exception is printed.
    """
    try:
        import onnxruntime as rt

        sess_options = rt.SessionOptions()
        # Set graph optimization level
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        # To enable model serialization after graph optimization set this
        sess_options.optimized_model_filepath = output
        rt.InferenceSession(input, sess_options)
        return True

    except Exception as e:
        print(e)
        return None
88 |
89 |
class TestModelZoo:
    """Compare onnxslim with onnxsim, polygraphy and onnxruntime on models from the model zoo."""

    def transform_and_check(self, name, filename, transformation_func, suffix, check_func):
        """
        Run one optimizer on ``filename`` and summarize its output.

        Args:
            name: Model name, used to build the temporary output file name.
            filename: Path of the input model.
            transformation_func: Callable ``(input_path, output_path)`` returning ``True``,
                an object with a ``returncode`` attribute, or ``None``.
            suffix: Tool label; used in the output file name and as the summary tag.
            check_func: Optional callable invoked with the output path to validate the result.

        Returns:
            The model summary, or ``None`` when the tool failed or ``check_func`` raised.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            output_file = os.path.join(tempdir, f"{name}_{suffix}.onnx")
            result = transformation_func(filename, output_file)
            if result is None:
                return None
            if result is True or (hasattr(result, "returncode") and result.returncode == 0):
                if check_func:
                    try:
                        check_func(output_file)
                    except Exception:
                        # Best effort: a tool whose output fails validation is left out of the
                        # table. Catch Exception (not a bare except) so KeyboardInterrupt and
                        # SystemExit still propagate.
                        return None
                return summarize_model(output_file, suffix)
        return None

    def run_model_test(self, name, filename, check_func=None):
        """Summarize the original model plus each tool's output and print them as one table."""
        summary_list = [summarize_model(filename)]
        summary_list.append(self.transform_and_check(name, filename, bench_onnxslim, "onnxslim", check_func))
        summary_list.append(self.transform_and_check(name, filename, bench_onnxsim, "onnxsim", check_func))
        summary_list.append(self.transform_and_check(name, filename, bench_polygraphy, "polygraphy", check_func))
        summary_list.append(self.transform_and_check(name, filename, bench_onnxruntime, "onnxruntime", check_func))

        # Tools that failed (or whose output failed validation) contribute None; drop them.
        summary_list = [summary for summary in summary_list if summary is not None]

        print()
        print_model_info_as_table(summary_list)

    def test_silero_vad(self, request):
        """Benchmark silero_vad, validating each output with a zero-input inference run."""

        def check_model_inference(model_path):
            batch_size = 2
            input_data = np.zeros((batch_size, 256), dtype=np.float32)
            sr = np.array(16000)
            state = np.zeros((2, batch_size, 128), dtype=np.float32)

            ort_sess = ort.InferenceSession(model_path)
            outputs = ort_sess.run(None, {"input": input_data, "sr": sr, "state": state})
            assert outputs is not None, "Inference failed on transformed model."

        # The model name is the test function name without its "test_" prefix.
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
        self.run_model_test(name, filename, check_model_inference)

    def test_decoder_with_past_model(self, request):
        """Benchmark decoder_with_past_model, validating each output with an inference run."""

        def check_model_inference(model_path):
            batch_size = 2
            input_ids = np.ones((batch_size, 256), dtype=np.int64)
            encoder_hidden_states = np.zeros((batch_size, 128, 16), dtype=np.float32)

            ort_sess = ort.InferenceSession(model_path)
            ort_sess.run(None, {"input_ids": input_ids, "encoder_hidden_states": encoder_hidden_states})

        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
        self.run_model_test(name, filename, check_model_inference)

    def test_tiny_en_decoder(self, request):
        """Benchmark tiny_en_decoder without an inference check."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
        self.run_model_test(name, filename)

    def test_transformer_encoder(self, request):
        """Benchmark transformer_encoder without an inference check."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
        self.run_model_test(name, filename)

    def test_uiex(self, request):
        """Benchmark uiex without an inference check."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
        self.run_model_test(name, filename)

    def test_en_number_mobile_v2_0_rec_infer(self, request):
        """Benchmark en_number_mobile_v2_0_rec_infer without an inference check."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
        self.run_model_test(name, filename)

    def test_paddleocr(self, request):
        """Benchmark paddleocr without an inference check."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
        self.run_model_test(name, filename)
170 |
171 |
# Allow running this file directly: delegate to pytest and propagate its exit code.
if __name__ == "__main__":
    import sys

    pytest_args = ["-p", "no:warnings", "-sv", "tests/test_benchmark.py"]
    sys.exit(pytest.main(pytest_args))
185 |
--------------------------------------------------------------------------------
/tests/test_folder.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import subprocess
5 |
6 | import pytest
7 |
8 |
def parse_arguments():
    """
    Parse the command line of this test script.

    Returns:
        The parsed namespace; ``model_dir`` holds the value of the required
        ``--model-dir`` option.
    """
    cli = argparse.ArgumentParser(description="Test script for ONNX models")
    model_dir_options = {
        "type": str,
        "required": True,
        "help": "Directory containing ONNX model files",
    }
    cli.add_argument("--model-dir", **model_dir_options)
    return cli.parse_args()
19 |
20 |
# Parsed at import time: the fixture's params are expanded from --model-dir while the module is loaded.
args = parse_arguments()
22 |
23 |
@pytest.fixture(params=glob.glob(f"{args.model_dir}/*/*.onnx"))
def model_file(request):
    """Provide one ``.onnx`` path per test run, taken from the ``<model-dir>/*/*.onnx`` glob."""
    return request.param
28 |
29 |
def test_model_file(model_file):
    """
    Slim one ONNX model with the ``onnxslim`` command and clean up the generated files.

    The test fails when the command exits with a non-zero status (its stderr is printed).
    On success the command's stdout is printed and the slimmed model, plus its external
    ``.data`` file if one was written, is removed.
    """
    slim_model_file = model_file.replace(".onnx", "_slim.onnx")
    # Pass the arguments as a list without a shell, so model paths containing spaces or
    # shell metacharacters reach onnxslim intact.
    result = subprocess.run(["onnxslim", model_file, slim_model_file], capture_output=True, text=True)
    if result.returncode != 0:
        print(result.stderr)
        raise AssertionError("Failed to slim model")

    output = result.stdout
    print(f"\n{output}")
    os.remove(slim_model_file)
    slim_data_file = model_file.replace(".onnx", "_slim.onnx.data")
    if os.path.exists(slim_data_file):
        os.remove(slim_data_file)
47 |
48 |
# Allow running this file directly: delegate to pytest and propagate its exit code.
if __name__ == "__main__":
    import sys

    pytest_args = ["-p", "no:warnings", "-sv", "tests/test_folder.py"]
    sys.exit(pytest.main(pytest_args))
62 |
--------------------------------------------------------------------------------
/tests/test_modelzoo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import numpy as np
5 | import onnxruntime as ort
6 | import pytest
7 |
8 | from onnxslim import slim
9 | from onnxslim.utils import print_model_info_as_table, summarize_model
10 |
# Root directory of the model zoo; each test loads "<MODELZOO_PATH>/<name>/<name>.onnx".
MODELZOO_PATH = "/data/modelzoo"
12 |
13 |
class TestModelZoo:
    """Slims model-zoo models through the Python API and checks the result.

    Each test derives the model name from its own function name (minus the ``test_`` prefix).
    """

    def test_silero_vad(self, request):
        """Slim silero_vad and run one inference on the slimmed model."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

        with tempfile.TemporaryDirectory() as tempdir:
            slim_path = os.path.join(tempdir, f"{name}_slim.onnx")
            slim(filename, slim_path)

            batch_size = 2
            audio = np.zeros((batch_size, 256), dtype=np.float32)
            sr = np.array(16000)
            state = np.zeros((2, batch_size, 128), dtype=np.float32)

            session = ort.InferenceSession(slim_path)
            session.run(None, {"input": audio, "sr": sr, "state": state})

    def test_decoder_with_past_model(self, request):
        """Slim decoder_with_past_model and run one inference on the slimmed model."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

        with tempfile.TemporaryDirectory() as tempdir:
            slim_path = os.path.join(tempdir, f"{name}_slim.onnx")
            slim(filename, slim_path)

            batch_size = 2
            feeds = {
                "input_ids": np.ones((batch_size, 256), dtype=np.int64),
                "encoder_hidden_states": np.zeros((batch_size, 128, 16), dtype=np.float32),
            }

            session = ort.InferenceSession(slim_path)
            session.run(None, feeds)

    def test_tiny_en_decoder(self, request):
        """Slim tiny_en_decoder with ``model_check`` enabled."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

        with tempfile.TemporaryDirectory() as tempdir:
            slim_path = os.path.join(tempdir, f"{name}_slim.onnx")
            slim(filename, slim_path, model_check=True)

    def test_transformer_encoder(self, request):
        """Slim transformer_encoder in memory and check the expected Mul/Div counts."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

        slimmed = slim(filename)
        summary = summarize_model(slimmed, tag=request.node.name)
        print_model_info_as_table(summary)
        assert summary.op_type_counts["Mul"] == 57
        assert summary.op_type_counts["Div"] == 53

    def test_uiex(self, request):
        """Slim uiex in memory and check the expected Range/Floor/Concat counts."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

        slimmed = slim(filename)
        summary = summarize_model(slimmed, tag=request.node.name)
        print_model_info_as_table(summary)
        assert summary.op_type_counts["Range"] == 0
        assert summary.op_type_counts["Floor"] == 0
        assert summary.op_type_counts["Concat"] == 54

    def test_qwen_vl_vision_encoder(self, request):
        """Slim qwen_vl_vision_encoder, print its summary, then run inference on the saved model."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

        slimmed = slim(filename)
        summary = summarize_model(slimmed, tag=request.node.name)
        print_model_info_as_table(summary)

        with tempfile.TemporaryDirectory() as tempdir:
            slim_path = os.path.join(tempdir, f"{name}_slim.onnx")
            slim(filename, slim_path)

            # numpy and onnxruntime are already imported at module level as np / ort.
            session = ort.InferenceSession(slim_path)
            feeds = {
                "pixel_values": np.random.rand(256, 1176).astype(np.float32),
                "grid_thw": np.array([[1, 16, 16]]),
            }
            outputs = session.run(None, feeds)
            print(f"{outputs[0].shape=}")  # (64, 16)

    def test_layer_normalization_2d_axis0_expanded_ver18(self, request):
        """Slim the expanded LayerNormalization model and expect exactly one Reshape to remain."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

        with tempfile.TemporaryDirectory() as tempdir:
            slim_path = os.path.join(tempdir, f"{name}_slim.onnx")
            slim(filename, slim_path, model_check=True)
            summary = summarize_model(slim_path, tag=request.node.name)
            assert summary.op_type_counts["Reshape"] == 1

    def test_padconv(self, request):
        """Slim padconv with ``model_check`` enabled and an explicit input shape."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

        with tempfile.TemporaryDirectory() as tempdir:
            slim_path = os.path.join(tempdir, f"{name}_slim.onnx")
            shape_overrides = ["/encoder/encoders0/encoders0.0/self_attn/Transpose_2_output_0:1,516,32"]
            slim(filename, slim_path, model_check=True, input_shapes=shape_overrides)

    # def test_wav2vec2_conformer(self, request):
    #     name = request.node.originalname[len("test_") :]
    #     filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

    #     with tempfile.TemporaryDirectory() as tempdir:
    #         slim(filename, os.path.join(tempdir, f"{name}_slim.onnx"))
    #         batch_size = 2
    #         input = np.zeros((batch_size, 256), dtype=np.float32)

    #         ort_sess = ort.InferenceSession(os.path.join(tempdir, f"{name}_slim.onnx"))
    #         ort_sess.run(None, {"input_values": input})

    def test_yolo11n_pose(self, request):
        """Slim yolo11n_pose and run one inference on the slimmed model."""
        name = request.node.originalname[len("test_") :]
        filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"

        with tempfile.TemporaryDirectory() as tempdir:
            slim_path = os.path.join(tempdir, f"{name}_slim.onnx")
            slim(filename, slim_path)
            images = np.zeros((1, 3, 256, 256), dtype=np.float32)

            session = ort.InferenceSession(slim_path)
            session.run(None, {"images": images})
126 |
127 |
# Allow running this file directly: delegate to pytest and propagate its exit code.
if __name__ == "__main__":
    import sys

    pytest_args = ["-p", "no:warnings", "-sv", "tests/test_modelzoo.py"]
    sys.exit(pytest.main(pytest_args))
141 |
--------------------------------------------------------------------------------
/tests/test_onnx_nets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 | import warnings
5 |
6 | import pytest
7 | import timm
8 | import torch
9 | import torchvision.models as models
10 |
# NOTE(review): FUSE is not referenced by the tests visible in this file; confirm it is still needed.
FUSE = True
# Passed as `pretrained=` when building torchvision and timm models.
PRETRAINED = False
MEMORY_LIMIT_GB = 0.75  # User's memory limit
MEMORY_PER_PARAM = 4e-9  # Approximate memory required per parameter in GB

# Scratch directory for exported ONNX files; created when this module is imported.
os.makedirs("tmp", exist_ok=True)
17 |
18 |
class TestTorchVisionClass:
    """Exports a few TorchVision classifiers to ONNX and slims them with the onnxslim command."""

    @pytest.mark.parametrize(
        "model",
        (
            models.resnet18,
            models.alexnet,
            models.squeezenet1_0,
            models.googlenet,
        ),
    )
    def test_torchvision(self, request, model, shape=(1, 3, 224, 224)):
        """Export the model with a random input of ``shape``, slim it, and expect exit code 0."""
        model = model(pretrained=PRETRAINED)
        x = torch.rand(shape)

        test_id = request.node.name
        directory = f"tmp/{test_id}"
        os.makedirs(directory, exist_ok=True)
        filename = f"{directory}/{test_id}.onnx"
        slim_filename = f"{directory}/{test_id}_slim.onnx"

        torch.onnx.export(model, x, filename)

        result = subprocess.run(
            f"onnxslim {filename} {slim_filename}",
            shell=True,
            capture_output=True,
            text=True,
        )
        # Show what the tool wrote to stderr before asserting on the return code.
        print(result.stderr.strip())
        assert result.returncode == 0

        shutil.rmtree(directory, ignore_errors=True)
49 |
50 |
class TestTimmClass:
    """Exports every TIMM model that fits the memory budget to ONNX and slims it with onnxslim."""

    @pytest.fixture(params=timm.list_models())
    def model_name(self, request):
        """Yields names of models available in TIMM (https://github.com/rwightman/pytorch-image-models) for pytest fixture parameterization."""
        yield request.param

    # Models whose name contains one of these keywords are skipped as too large.
    skip_keywords = ["enormous", "giant", "huge", "xlarge"]

    def test_timm(self, request, model_name):
        """
        Export one TIMM model to ONNX and slim it, expecting the onnxslim command to exit with 0.

        The model is skipped when its name contains a size keyword, when creating it runs out
        of memory, or when its estimated parameter memory exceeds ``MEMORY_LIMIT_GB``. A failed
        ONNX export is reported and the test returns without asserting.
        """
        if any(keyword in model_name.lower() for keyword in self.skip_keywords):
            pytest.skip(f"Skipping model due to size keyword in name: {model_name}")

        try:
            model = timm.create_model(model_name, pretrained=PRETRAINED)
        except RuntimeError as e:
            if "out of memory" in str(e):
                pytest.skip(f"Skipping model {model_name} due to memory error.")
            # Any other RuntimeError is a real failure: re-raise it instead of falling
            # through with `model` unbound, which would surface as a NameError below.
            raise

        num_params = sum(p.numel() for p in model.parameters())

        # Calculate estimated memory requirement
        estimated_memory = num_params * MEMORY_PER_PARAM

        if estimated_memory > MEMORY_LIMIT_GB:
            pytest.skip(f"Skipping model {model_name}: estimated memory {estimated_memory:.2f} GB exceeds limit.")

        input_size = model.default_cfg.get("input_size")
        x = torch.randn((1,) + input_size)
        directory = f"tmp/{request.node.name}"
        filename = f"{directory}/{request.node.name}.onnx"
        slim_filename = f"{directory}/{request.node.name}_slim.onnx"
        try:
            try:
                os.makedirs(directory, exist_ok=True)
                torch.onnx.export(model, x, filename)
            except Exception as e:
                print(f"An unexpected error occurred: {str(e)}")
                return
            if not os.path.exists(filename):
                return

            command = f"onnxslim {filename} {slim_filename}"
            result = subprocess.run(command, shell=True, capture_output=True, text=True)
            output = result.stderr.strip()
            # Assert the expected return code
            print(output)
            assert result.returncode == 0
        finally:
            # Always remove the scratch directory, including on the early-return and
            # assertion-failure paths.
            shutil.rmtree(directory, ignore_errors=True)
101 |
102 |
# Allow running this file directly: silence warnings, delegate to pytest and propagate its exit code.
if __name__ == "__main__":
    import sys

    warnings.filterwarnings("ignore")
    exit_code = pytest.main(["-p", "no:warnings", "-v", "tests/test_onnx_nets.py"])
    sys.exit(exit_code)
108 |
--------------------------------------------------------------------------------
/tests/test_onnxslim.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import tempfile
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from onnxslim import slim
9 | from onnxslim.utils import summarize_model
10 |
# Root directory of the model zoo.
MODELZOO_PATH = "/data/modelzoo"
# Model used as the input by the functional tests in this file.
FILENAME = f"{MODELZOO_PATH}/resnet18/resnet18.onnx"
13 |
14 |
class TestFunctional:
    """Functional tests that run onnxslim through the Python API and the command line."""

    def run_basic_test(self, in_model_path, out_model_path, **kwargs):
        """Slim ``in_model_path`` three times and validate every result.

        The model is slimmed through the API without an output path, through the
        API with ``out_model_path``, and finally through the ``onnxslim`` shell
        command writing to ``out_model_path``.

        Recognised keyword arguments:
            check_func: optional callable invoked with each model summary.
            api: dict of keyword arguments forwarded to ``slim``.
            bash: string of extra flags appended to the shell command.

        Raises:
            AssertionError: if the shell command exits with a non-zero status or
                ``check_func`` rejects a summary.
        """
        verify = kwargs.get("check_func", None)
        api_options = kwargs.get("api", {})
        cli_options = kwargs.get("bash", "")

        def summarize_and_verify(model, tag):
            report = summarize_model(model, tag)
            if verify:
                verify(report)

        # API call without an output path: summarize what slim() returns.
        summarize_and_verify(slim(in_model_path, **api_options), in_model_path)

        # API call with an output path: summarize the file on disk.
        slim(in_model_path, out_model_path, **api_options)
        summarize_and_verify(out_model_path, out_model_path)

        # Command-line invocation with the same paths plus the extra flags.
        completed = subprocess.run(
            f'onnxslim "{in_model_path}" "{out_model_path}" {cli_options}',
            shell=True,
            capture_output=True,
            text=True,
        )
        # Print the captured stderr before asserting on the exit status.
        print(completed.stderr.strip())
        assert completed.returncode == 0

        summarize_and_verify(out_model_path, out_model_path)

    def test_basic(self, request):
        """Slim the ResNet-18 model with no extra options."""
        with tempfile.TemporaryDirectory() as workdir:
            self.run_basic_test(FILENAME, os.path.join(workdir, "resnet18.onnx"))

    def test_input_shape_modification(self, request):
        """Pass an explicit input shape and check the first input reports it."""

        def check_func(summary):
            assert summary.input_info[0].shape == (1, 3, 224, 224)

        with tempfile.TemporaryDirectory() as workdir:
            self.run_basic_test(
                FILENAME,
                os.path.join(workdir, "resnet18.onnx"),
                bash="--input-shapes input:1,3,224,224",
                api={"input_shapes": ["input:1,3,224,224"]},
                check_func=check_func,
            )

    def test_input_modification(self, request):
        """Request two intermediate tensors as inputs and check both appear in the input map."""

        def check_func(summary):
            assert "/maxpool/MaxPool_output_0" in summary.input_maps
            assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps

        with tempfile.TemporaryDirectory() as workdir:
            self.run_basic_test(
                FILENAME,
                os.path.join(workdir, "resnet18.onnx"),
                bash="--inputs /maxpool/MaxPool_output_0 /layer1/layer1.0/relu/Relu_output_0",
                api={"inputs": ["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]},
                check_func=check_func,
            )

    def test_output_modification(self, request):
        """Request an intermediate tensor as output and check it appears in the output map."""

        def check_func(summary):
            assert "/Flatten_output_0" in summary.output_maps

        with tempfile.TemporaryDirectory() as workdir:
            self.run_basic_test(
                FILENAME,
                os.path.join(workdir, "resnet18.onnx"),
                bash="--outputs /Flatten_output_0",
                api={"outputs": ["/Flatten_output_0"]},
                check_func=check_func,
            )

    def test_dtype_conversion(self, request):
        """Convert to fp16, then convert that result back to fp32, checking the input dtype each time."""

        def check_func_fp16(summary):
            assert summary.input_info[0].dtype == np.float16

        def check_func_fp32(summary):
            assert summary.input_info[0].dtype == np.float32

        with tempfile.TemporaryDirectory() as workdir:
            fp16_path = os.path.join(workdir, "resnet18_fp16.onnx")
            self.run_basic_test(
                FILENAME,
                fp16_path,
                bash="--dtype fp16",
                api={"dtype": "fp16"},
                check_func=check_func_fp16,
            )

            # The fp16 model produced above is the input of the second conversion.
            fp32_path = os.path.join(workdir, "resnet18_fp32.onnx")
            self.run_basic_test(
                fp16_path,
                fp32_path,
                bash="--dtype fp32",
                api={"dtype": "fp32"},
                check_func=check_func_fp32,
            )

    def test_save_as_external_data(self, request):
        """Save with external data enabled and check the model file stays below 1e5 bytes."""
        with tempfile.TemporaryDirectory() as workdir:
            slim_path = os.path.join(workdir, "resnet18.onnx")
            self.run_basic_test(
                FILENAME,
                slim_path,
                bash="--save-as-external-data",
                api={"save_as_external_data": True},
            )
            assert os.path.getsize(slim_path) < 1e5

    def test_model_check(self, request):
        """Run with model checking enabled, feeding a random float32 input saved as ``.npy``."""
        with tempfile.TemporaryDirectory() as workdir:
            slim_path = os.path.join(workdir, "resnet18.onnx")
            npy_path = os.path.join(workdir, "input.npy")
            sample = np.random.random((1, 3, 224, 224)).astype(np.float32)
            np.save(npy_path, sample)
            self.run_basic_test(
                FILENAME,
                slim_path,
                bash=f"--model-check --model-check-inputs input:{npy_path}",
                api={"model_check": True, "model_check_inputs": [f"input:{npy_path}"]},
            )

    def test_inspect(self, request):
        """Run in inspect mode, passing the source model path as the output path as well."""
        with tempfile.TemporaryDirectory():
            self.run_basic_test(FILENAME, FILENAME, bash="--inspect", api={"inspect": True})
131 |
132 |
if __name__ == "__main__":
    import sys

    # Run this test module through pytest and exit with the status it returns.
    pytest_args = ["-p", "no:warnings", "-v", "tests/test_onnxslim.py"]
    sys.exit(pytest.main(pytest_args))
146 |
--------------------------------------------------------------------------------
/tests/test_pattern_generator.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import onnx
4 | import pytest
5 | import torch
6 | import torch.nn as nn
7 |
8 | from onnxslim import register_fusion_pattern, slim
9 | from onnxslim.core.pattern import Pattern, PatternGenerator, PatternMatcher
10 |
11 |
class TestPatternGenerator:
    """Checks that a pattern generated from an exported model is applied while slimming another model."""

    def test_gelu(self, request):
        """Generate a pattern from a GELU-only export and expect it to fire on a ReLU-GELU-ReLU model.

        The matcher registered here raises from ``rewrite``; the test passes when
        that exception propagates out of ``slim`` with the expected message.
        """

        class PatternModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.gelu = nn.GELU()

            def forward(self, x):
                """Applies the GELU activation function to the input tensor."""
                x = self.gelu(x)
                return x

        class Model(nn.Module):
            def __init__(self):
                """Initializes the Model class with ReLU and PatternModel components."""
                super().__init__()
                self.relu0 = nn.ReLU()
                self.pattern = PatternModel()
                self.relu1 = nn.ReLU()

            def forward(self, x):
                """Applies the ReLU activation function, the PatternModel, and another ReLU activation sequentially to
                the input tensor.
                """
                x = self.relu0(x)
                x = self.pattern(x)
                x = self.relu1(x)
                return x

        input = torch.randn(2)
        p = PatternModel()
        m = Model()
        directory = f"tmp/{request.node.name}"
        os.makedirs(directory, exist_ok=True)

        # The pattern export gets its own file name. When both exports shared one
        # path, the full-model export below overwrote the pattern file and the
        # pattern was then generated from the whole model instead of the GELU part.
        pattern_filename = f"{directory}/{request.node.name}_pattern.onnx"
        torch.onnx.export(p, input, pattern_filename)

        model_filename = f"{directory}/{request.node.name}.onnx"
        torch.onnx.export(m, input, model_filename)

        # Build the pattern from the GELU-only export.
        model = onnx.load(pattern_filename)
        pgen = PatternGenerator(model)
        template = pgen.generate()
        pattern = Pattern(template)

        class GeluMatcher(PatternMatcher):
            def __init__(self, pattern, priority):
                """Initialize a GeluMatcher with a given pattern and priority."""
                super().__init__(pattern, priority)

            @property
            def name(self):
                """Return the name of the matcher as 'FusionGelu'."""
                return "FusionGelu"

            def rewrite(self, opset=11):
                """Raise an exception indicating a pattern match in GeluMatcher."""
                raise Exception("Pattern Matched")

        register_fusion_pattern(GeluMatcher(pattern, 1))
        with pytest.raises(Exception) as excinfo:
            slim(model_filename, f"{directory}/{request.node.name}_slim.onnx")

        assert str(excinfo.value) == "Pattern Matched"
79 |
80 |
if __name__ == "__main__":
    import sys

    # Run this test module through pytest and exit with the status it returns.
    pytest_args = ["-p", "no:warnings", "-sv", "tests/test_pattern_generator.py"]
    sys.exit(pytest.main(pytest_args))
94 |
--------------------------------------------------------------------------------
/tests/test_pattern_matcher.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | import torch
5 | import torch.nn as nn
6 |
7 | from onnxslim import slim
8 | from onnxslim.utils import print_model_info_as_table, summarize_model
9 |
10 |
class TestPatternMatcher:
    """Exports small PyTorch modules to ONNX, slims them with model checking on, and checks operator counts."""

    def _export_and_summarize(self, request, module, sample, **export_options):
        """Export ``module`` to ``tmp/<test name>/<test name>.onnx``, slim it and return the summary.

        ``export_options`` are forwarded to ``torch.onnx.export``. The summary is
        also printed as a table before it is returned.
        """
        test_name = request.node.name
        workdir = f"tmp/{test_name}"
        os.makedirs(workdir, exist_ok=True)

        onnx_path = f"{workdir}/{test_name}.onnx"
        torch.onnx.export(module, sample, onnx_path, **export_options)

        report = summarize_model(slim(onnx_path, model_check=True), test_name)
        print_model_info_as_table(report)
        return report

    def test_gelu(self, request):
        """Slim a ReLU -> GELU -> ReLU model; only checks that the run completes."""

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.relu0 = nn.ReLU()
                self.gelu = nn.GELU()
                self.relu1 = nn.ReLU()

            def forward(self, x):
                """Apply ReLU, GELU and ReLU in sequence."""
                return self.relu1(self.gelu(self.relu0(x)))

        sample = torch.randn(2)
        net = Model()
        self._export_and_summarize(request, net, sample)

    def test_pad_conv(self, request):
        """Slim two summed pad -> conv branches and expect two Conv ops and one Add op."""

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.pad_0 = nn.ConstantPad2d(3, 0)
                self.conv_0 = nn.Conv2d(1, 1, 3)

                self.pad_1 = nn.ConstantPad2d(3, 2)
                self.conv_1 = nn.Conv2d(1, 1, 3, bias=False)

            def forward(self, x):
                """Run both pad -> conv branches on ``x`` and add their outputs."""
                first = self.conv_0(self.pad_0(x))
                second = self.conv_1(self.pad_1(x))
                return first + second

        sample = torch.randn(1, 1, 24, 24)
        net = Model()
        report = self._export_and_summarize(request, net, sample)

        assert report.op_type_counts["Conv"] == 2
        assert report.op_type_counts["Add"] == 1

    def test_conv_bn(self, request):
        """Slim a conv -> batch-norm model exported without constant folding and expect one Conv op."""

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 3)
                self.bn = nn.BatchNorm2d(1)

            def forward(self, x):
                """Apply the convolution, then batch normalization."""
                return self.bn(self.conv(x))

        sample = torch.randn(1, 1, 24, 24)
        net = Model()
        report = self._export_and_summarize(request, net, sample, do_constant_folding=False)

        assert report.op_type_counts["Conv"] == 1

    def test_consecutive_slice(self, request):
        """Slim a model that slices two dimensions and expect a single Slice op."""

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                # These layers are constructed but never used by forward().
                self.conv = nn.Conv2d(1, 1, 3)
                self.bn = nn.BatchNorm2d(1)

            def forward(self, x):
                """Return the slice ``x[1:2, :2]``."""
                return x[1:2, :2]

        sample = torch.randn(3, 4)
        net = Model()
        report = self._export_and_summarize(request, net, sample)

        assert report.op_type_counts["Slice"] == 1

    def test_consecutive_reshape(self, request):
        """Slim a model with two chained views and expect a single Reshape op."""

        class Model(nn.Module):
            def forward(self, x):
                """View the input as (2, 6) and then as (12, 1)."""
                return x.view(2, 6).view(12, 1)

        sample = torch.randn(3, 4)
        net = Model()
        report = self._export_and_summarize(request, net, sample)

        assert report.op_type_counts["Reshape"] == 1

    def test_matmul_add(self, request):
        """Slim a matmul followed by an in-place add of 1 and expect a single Gemm op."""

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.data = torch.randn(4, 3)

            def forward(self, x):
                """Multiply ``x`` by the stored matrix, then add 1 in place."""
                product = torch.matmul(x, self.data)
                product += 1
                return product

        sample = torch.randn(3, 4)
        net = Model()
        report = self._export_and_summarize(request, net, sample)

        assert report.op_type_counts["Gemm"] == 1

    def test_reduce(self, request):
        """Slim a sum over the last dimension followed by unsqueeze (opset 11) and expect one ReduceSum op."""

        class Model(nn.Module):
            def forward(self, x):
                """Sum over the last dimension without keeping it, then unsqueeze that dimension again."""
                reduced = torch.sum(x, dim=[-1], keepdim=False)
                return reduced.unsqueeze(-1)

        sample = torch.randn(3, 4)
        net = Model()
        report = self._export_and_summarize(request, net, sample, opset_version=11)

        assert report.op_type_counts["ReduceSum"] == 1

    @pytest.mark.parametrize("opset", (11, 13))
    def test_consecutive_unsqueeze(self, request, opset):
        """Slim four chained unsqueeze calls at the given opset and expect a single Unsqueeze op."""

        class Model(nn.Module):
            def forward(self, x):
                """Unsqueeze at -1 twice, then at 1, then at 0."""
                return x.unsqueeze(-1).unsqueeze(-1).unsqueeze(1).unsqueeze(0)

        sample = torch.randn(3, 4)
        net = Model()
        report = self._export_and_summarize(request, net, sample, opset_version=opset)

        assert report.op_type_counts["Unsqueeze"] == 1
240 |
241 |
if __name__ == "__main__":
    import sys

    # Run this test module through pytest and exit with the status it returns.
    pytest_args = ["-p", "no:warnings", "-sv", "tests/test_pattern_matcher.py"]
    sys.exit(pytest.main(pytest_args))
255 |
--------------------------------------------------------------------------------